Initial commit
This commit is contained in:
263
backend/services/generation.py
Normal file
263
backend/services/generation.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Unified TTS generation orchestration.
|
||||
|
||||
Replaces the three near-identical closures (_run_generation, _run_retry,
|
||||
_run_regenerate) that lived in main.py with a single ``run_generation()``
|
||||
function parameterized by *mode*.
|
||||
|
||||
Mode differences:
|
||||
- "generate" : full pipeline -- save clean version, optionally apply
|
||||
effects and create a processed version.
|
||||
- "retry" : re-runs a failed generation with the same seed.
|
||||
No effects, no version creation.
|
||||
- "regenerate" : re-runs with seed=None for variation. Creates a new
|
||||
version with an auto-incremented "take-N" label.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Literal, Optional
|
||||
|
||||
from .. import config
|
||||
from . import history, profiles
|
||||
from ..database import get_db
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
|
||||
async def run_generation(
|
||||
*,
|
||||
generation_id: str,
|
||||
profile_id: str,
|
||||
text: str,
|
||||
language: str,
|
||||
engine: str,
|
||||
model_size: str,
|
||||
seed: Optional[int],
|
||||
normalize: bool = False,
|
||||
effects_chain: Optional[list] = None,
|
||||
instruct: Optional[str] = None,
|
||||
mode: Literal["generate", "retry", "regenerate"],
|
||||
max_chunk_chars: Optional[int] = None,
|
||||
crossfade_ms: Optional[int] = None,
|
||||
version_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Execute TTS inference and persist the result.
|
||||
|
||||
This is the single entry point for all background generation work.
|
||||
It is designed to be enqueued via ``services.task_queue.enqueue_generation``.
|
||||
"""
|
||||
from ..backends import load_engine_model, get_tts_backend_for_engine, engine_needs_trim
|
||||
from ..utils.chunked_tts import generate_chunked
|
||||
from ..utils.audio import normalize_audio, save_audio, trim_tts_output
|
||||
|
||||
task_manager = get_task_manager()
|
||||
bg_db = next(get_db())
|
||||
|
||||
try:
|
||||
tts_model = get_tts_backend_for_engine(engine)
|
||||
|
||||
if not tts_model.is_loaded():
|
||||
await history.update_generation_status(generation_id, "loading_model", bg_db)
|
||||
|
||||
await load_engine_model(engine, model_size)
|
||||
|
||||
voice_prompt = await profiles.create_voice_prompt_for_profile(
|
||||
profile_id,
|
||||
bg_db,
|
||||
use_cache=True,
|
||||
engine=engine,
|
||||
)
|
||||
|
||||
await history.update_generation_status(generation_id, "generating", bg_db)
|
||||
trim_fn = trim_tts_output if engine_needs_trim(engine) else None
|
||||
|
||||
gen_kwargs: dict = dict(
|
||||
language=language,
|
||||
seed=seed if mode != "regenerate" else None,
|
||||
instruct=instruct,
|
||||
trim_fn=trim_fn,
|
||||
)
|
||||
if max_chunk_chars is not None:
|
||||
gen_kwargs["max_chunk_chars"] = max_chunk_chars
|
||||
if crossfade_ms is not None:
|
||||
gen_kwargs["crossfade_ms"] = crossfade_ms
|
||||
|
||||
audio, sample_rate = await generate_chunked(tts_model, text, voice_prompt, **gen_kwargs)
|
||||
|
||||
# --- Normalize (generate and regenerate always; retry skips) -----
|
||||
if normalize or mode == "regenerate":
|
||||
audio = normalize_audio(audio)
|
||||
|
||||
duration = len(audio) / sample_rate
|
||||
|
||||
# --- Persist audio and update status -----------------------------
|
||||
if mode == "generate":
|
||||
final_path = _save_generate(
|
||||
generation_id=generation_id,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
effects_chain=effects_chain,
|
||||
save_audio=save_audio,
|
||||
db=bg_db,
|
||||
)
|
||||
elif mode == "retry":
|
||||
final_path = _save_retry(
|
||||
generation_id=generation_id,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
save_audio=save_audio,
|
||||
)
|
||||
elif mode == "regenerate":
|
||||
final_path = _save_regenerate(
|
||||
generation_id=generation_id,
|
||||
version_id=version_id,
|
||||
audio=audio,
|
||||
sample_rate=sample_rate,
|
||||
save_audio=save_audio,
|
||||
db=bg_db,
|
||||
)
|
||||
|
||||
await history.update_generation_status(
|
||||
generation_id=generation_id,
|
||||
status="completed",
|
||||
db=bg_db,
|
||||
audio_path=final_path,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
await history.update_generation_status(
|
||||
generation_id=generation_id,
|
||||
status="failed",
|
||||
db=bg_db,
|
||||
error="Generation cancelled",
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
await history.update_generation_status(
|
||||
generation_id=generation_id,
|
||||
status="failed",
|
||||
db=bg_db,
|
||||
error=str(e),
|
||||
)
|
||||
finally:
|
||||
task_manager.complete_generation(generation_id)
|
||||
bg_db.close()
|
||||
|
||||
|
||||
def _save_generate(
|
||||
*,
|
||||
generation_id: str,
|
||||
audio,
|
||||
sample_rate: int,
|
||||
effects_chain: Optional[list],
|
||||
save_audio,
|
||||
db,
|
||||
) -> str:
|
||||
"""Save clean version and optionally an effects-processed version.
|
||||
|
||||
Returns the final audio path (processed if effects were applied,
|
||||
otherwise clean).
|
||||
"""
|
||||
from . import versions as versions_mod
|
||||
|
||||
clean_audio_path = config.get_generations_dir() / f"{generation_id}.wav"
|
||||
save_audio(audio, str(clean_audio_path), sample_rate)
|
||||
|
||||
has_effects = effects_chain and any(e.get("enabled", True) for e in effects_chain)
|
||||
|
||||
versions_mod.create_version(
|
||||
generation_id=generation_id,
|
||||
label="original",
|
||||
audio_path=config.to_storage_path(clean_audio_path),
|
||||
db=db,
|
||||
effects_chain=None,
|
||||
is_default=not has_effects,
|
||||
)
|
||||
|
||||
final_audio_path = str(clean_audio_path)
|
||||
|
||||
if has_effects:
|
||||
from ..utils.effects import apply_effects, validate_effects_chain
|
||||
|
||||
assert effects_chain is not None
|
||||
|
||||
error_msg = validate_effects_chain(effects_chain)
|
||||
if error_msg:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning("invalid effects chain, skipping: %s", error_msg)
|
||||
versions_mod.set_default_version(
|
||||
versions_mod.list_versions(generation_id, db)[0].id, db
|
||||
)
|
||||
else:
|
||||
processed_audio = apply_effects(audio, sample_rate, effects_chain)
|
||||
processed_path = config.get_generations_dir() / f"{generation_id}_processed.wav"
|
||||
save_audio(processed_audio, str(processed_path), sample_rate)
|
||||
final_audio_path = str(processed_path)
|
||||
versions_mod.create_version(
|
||||
generation_id=generation_id,
|
||||
label="version-2",
|
||||
audio_path=config.to_storage_path(processed_path),
|
||||
db=db,
|
||||
effects_chain=effects_chain,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
return config.to_storage_path(final_audio_path)
|
||||
|
||||
|
||||
def _save_retry(
|
||||
*,
|
||||
generation_id: str,
|
||||
audio,
|
||||
sample_rate: int,
|
||||
save_audio,
|
||||
) -> str:
|
||||
"""Save retry output -- single file, no versions.
|
||||
|
||||
Returns the audio path.
|
||||
"""
|
||||
audio_path = config.get_generations_dir() / f"{generation_id}.wav"
|
||||
save_audio(audio, str(audio_path), sample_rate)
|
||||
return config.to_storage_path(audio_path)
|
||||
|
||||
|
||||
def _save_regenerate(
|
||||
*,
|
||||
generation_id: str,
|
||||
version_id: Optional[str],
|
||||
audio,
|
||||
sample_rate: int,
|
||||
save_audio,
|
||||
db,
|
||||
) -> str:
|
||||
"""Save regeneration output as a new version with auto-label.
|
||||
|
||||
Returns the audio path.
|
||||
"""
|
||||
from . import versions as versions_mod
|
||||
|
||||
import uuid as _uuid
|
||||
|
||||
suffix = _uuid.uuid4().hex[:8]
|
||||
audio_path = config.get_generations_dir() / f"{generation_id}_{suffix}.wav"
|
||||
save_audio(audio, str(audio_path), sample_rate)
|
||||
|
||||
# Count via DB query rather than list length to avoid TOCTOU race
|
||||
from ..database import GenerationVersion as DBGenerationVersion
|
||||
|
||||
count = db.query(DBGenerationVersion).filter_by(generation_id=generation_id).count()
|
||||
label = f"take-{count + 1}"
|
||||
|
||||
versions_mod.create_version(
|
||||
generation_id=generation_id,
|
||||
label=label,
|
||||
audio_path=config.to_storage_path(audio_path),
|
||||
db=db,
|
||||
effects_chain=None,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
return config.to_storage_path(audio_path)
|
||||
Reference in New Issue
Block a user