"""TTS generation endpoints.""" import asyncio import logging import uuid from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session logger = logging.getLogger(__name__) from .. import models from ..services import history, profiles, tts from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db from ..services.generation import run_generation from ..services.task_queue import cancel_generation as cancel_generation_job, enqueue_generation from ..utils.tasks import get_task_manager router = APIRouter() def _resolve_generation_engine(data: models.GenerationRequest, profile) -> str: return data.engine or getattr(profile, "default_engine", None) or getattr(profile, "preset_engine", None) or "qwen" @router.post("/generate", response_model=models.GenerationResponse) async def generate_speech( data: models.GenerationRequest, db: Session = Depends(get_db), ): """Generate speech from text using a voice profile.""" task_manager = get_task_manager() generation_id = str(uuid.uuid4()) profile = await profiles.get_profile(data.profile_id, db) if not profile: raise HTTPException(status_code=404, detail="Profile not found") from ..backends import engine_has_model_sizes engine = _resolve_generation_engine(data, profile) try: profiles.validate_profile_engine(profile, engine) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) model_size = (data.model_size or "1.7B") if engine_has_model_sizes(engine) else None generation = await history.create_generation( profile_id=data.profile_id, text=data.text, language=data.language, audio_path="", duration=0, seed=data.seed, db=db, instruct=data.instruct, generation_id=generation_id, status="generating", engine=engine, model_size=model_size if engine_has_model_sizes(engine) else None, ) task_manager.start_generation( task_id=generation_id, profile_id=data.profile_id, text=data.text, ) effects_chain_config = None if data.effects_chain is not None: effects_chain_config = [e.model_dump() for e in data.effects_chain] else: import json as _json profile_obj = db.query(DBVoiceProfile).filter_by(id=data.profile_id).first() if profile_obj and profile_obj.effects_chain: try: effects_chain_config = _json.loads(profile_obj.effects_chain) except Exception: pass enqueue_generation( generation_id, run_generation( generation_id=generation_id, profile_id=data.profile_id, text=data.text, language=data.language, engine=engine, model_size=model_size, seed=data.seed, normalize=data.normalize, effects_chain=effects_chain_config, instruct=data.instruct, mode="generate", max_chunk_chars=data.max_chunk_chars, crossfade_ms=data.crossfade_ms, ) ) return generation @router.post("/generate/{generation_id}/retry", response_model=models.GenerationResponse) async def retry_generation(generation_id: str, db: Session = Depends(get_db)): """Retry a failed generation using the same parameters.""" gen = db.query(DBGeneration).filter_by(id=generation_id).first() if not gen: raise HTTPException(status_code=404, detail="Generation not found") if (gen.status or "completed") != "failed": raise HTTPException(status_code=400, detail="Only failed generations can be retried") gen.status = "generating" gen.error = None gen.audio_path = "" gen.duration = 0 db.commit() db.refresh(gen) task_manager = get_task_manager() task_manager.start_generation( task_id=generation_id, profile_id=gen.profile_id, text=gen.text, ) enqueue_generation( generation_id, run_generation( 
            generation_id=generation_id,
            profile_id=gen.profile_id,
            text=gen.text,
            language=gen.language,
            engine=gen.engine or "qwen",
            model_size=gen.model_size or "1.7B",
            seed=gen.seed,
            instruct=gen.instruct,
            mode="retry",
        ),
    )

    return models.GenerationResponse.model_validate(gen)


@router.post(
    "/generate/{generation_id}/regenerate",
    response_model=models.GenerationResponse,
)
async def regenerate_generation(generation_id: str, db: Session = Depends(get_db)):
    """Re-run TTS with the same parameters and save the result as a new version."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if (gen.status or "completed") != "completed":
        raise HTTPException(status_code=400, detail="Generation must be completed to regenerate")

    gen.status = "generating"
    gen.error = None
    db.commit()
    db.refresh(gen)

    task_manager = get_task_manager()
    task_manager.start_generation(
        task_id=generation_id,
        profile_id=gen.profile_id,
        text=gen.text,
    )

    # A fresh version id keeps the previous audio around as an earlier version.
    version_id = str(uuid.uuid4())
    enqueue_generation(
        generation_id,
        run_generation(
            generation_id=generation_id,
            profile_id=gen.profile_id,
            text=gen.text,
            language=gen.language,
            engine=gen.engine or "qwen",
            model_size=gen.model_size or "1.7B",
            seed=gen.seed,
            instruct=gen.instruct,
            mode="regenerate",
            version_id=version_id,
        ),
    )

    return models.GenerationResponse.model_validate(gen)


@router.post("/generate/{generation_id}/cancel")
async def cancel_generation(generation_id: str, db: Session = Depends(get_db)):
    """Cancel a queued or running generation."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if (gen.status or "completed") not in ("loading_model", "generating"):
        raise HTTPException(status_code=400, detail="Only active generations can be cancelled")

    cancellation_state = cancel_generation_job(generation_id)
    if cancellation_state is None:
        raise HTTPException(status_code=409, detail="Generation is no longer cancellable")

    if cancellation_state == "queued":
        # The job never started, so it can be marked failed right away instead
        # of waiting for a worker to observe the cancellation.
        task_manager = get_task_manager()
        task_manager.complete_generation(generation_id)
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=db,
            error="Generation cancelled",
        )
        return {"message": "Queued generation cancelled"}

    return {"message": "Generation cancellation requested"}


@router.get("/generate/{generation_id}/status")
async def get_generation_status(generation_id: str, db: Session = Depends(get_db)):
    """SSE endpoint that streams generation status updates."""
    import json

    async def event_stream():
        try:
            while True:
                # Expire cached ORM state so each poll sees fresh data.
                db.expire_all()
                gen = db.query(DBGeneration).filter_by(id=generation_id).first()
                if not gen:
                    yield f"data: {json.dumps({'status': 'not_found', 'id': generation_id})}\n\n"
                    return
                payload = {
                    "id": gen.id,
                    "status": gen.status or "completed",
                    "duration": gen.duration,
                    "error": gen.error,
                }
                yield f"data: {json.dumps(payload)}\n\n"
                if (gen.status or "completed") in ("completed", "failed"):
                    return
                await asyncio.sleep(1)
        except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
            logger.debug("SSE client disconnected for generation %s", generation_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.post("/generate/stream")
async def stream_speech(
    data: models.GenerationRequest,
    db: Session = Depends(get_db),
):
    """Generate speech and stream the WAV audio directly
    without saving to disk."""
    from ..backends import (
        get_tts_backend_for_engine,
        ensure_model_cached_or_raise,
        load_engine_model,
        engine_needs_trim,
    )

    profile = await profiles.get_profile(data.profile_id, db)
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")

    engine = _resolve_generation_engine(data, profile)
    try:
        profiles.validate_profile_engine(profile, engine)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    tts_model = get_tts_backend_for_engine(engine)
    model_size = data.model_size or "1.7B"
    await ensure_model_cached_or_raise(engine, model_size)
    await load_engine_model(engine, model_size)

    voice_prompt = await profiles.create_voice_prompt_for_profile(
        data.profile_id,
        db,
        engine=engine,
    )

    from ..utils.chunked_tts import generate_chunked

    # Some engines need their raw output trimmed before use.
    trim_fn = None
    if engine_needs_trim(engine):
        from ..utils.audio import trim_tts_output

        trim_fn = trim_tts_output

    audio, sample_rate = await generate_chunked(
        tts_model,
        data.text,
        voice_prompt,
        language=data.language,
        seed=data.seed,
        instruct=data.instruct,
        max_chunk_chars=data.max_chunk_chars,
        crossfade_ms=data.crossfade_ms,
        trim_fn=trim_fn,
    )

    # Request-level effects take precedence over the profile's stored chain.
    effects_chain_config = None
    if data.effects_chain is not None:
        effects_chain_config = [e.model_dump() for e in data.effects_chain]
    elif profile.effects_chain:
        import json as _json

        try:
            effects_chain_config = _json.loads(profile.effects_chain)
        except Exception:
            effects_chain_config = None

    if effects_chain_config:
        from ..utils.effects import apply_effects

        audio = apply_effects(audio, sample_rate, effects_chain_config)

    if data.normalize:
        from ..utils.audio import normalize_audio

        audio = normalize_audio(audio)

    wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate)

    async def _wav_stream():
        # Stream the in-memory WAV to the client in 64 KiB chunks.
        try:
            chunk_size = 64 * 1024
            for i in range(0, len(wav_bytes), chunk_size):
                yield wav_bytes[i : i + chunk_size]
        except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
            logger.debug("Client disconnected during audio stream")

    return StreamingResponse(
        _wav_stream(),
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
    )
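

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way a client might drive the async
# generate + status flow above. The base URL, the absence of a router prefix,
# and the `httpx` dependency are assumptions, not guarantees of this module.
#
#   import httpx
#
#   # Kick off a generation; `profile_id` must reference an existing profile.
#   resp = httpx.post(
#       "http://localhost:8000/generate",
#       json={"profile_id": "<profile-uuid>", "text": "Hello, world."},
#   )
#   generation_id = resp.json()["id"]
#
#   # Follow the SSE status stream until the generation completes or fails.
#   with httpx.stream(
#       "GET", f"http://localhost:8000/generate/{generation_id}/status"
#   ) as stream:
#       for line in stream.iter_lines():
#           if line.startswith("data: "):
#               print(line)
# ---------------------------------------------------------------------------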