264 lines
8.1 KiB
Python
264 lines
8.1 KiB
Python
"""
|
|
Unified TTS generation orchestration.
|
|
|
|
Replaces the three near-identical closures (_run_generation, _run_retry,
|
|
_run_regenerate) that lived in main.py with a single ``run_generation()``
|
|
function parameterized by *mode*.
|
|
|
|
Mode differences:
|
|
- "generate" : full pipeline -- save clean version, optionally apply
|
|
effects and create a processed version.
|
|
- "retry" : re-runs a failed generation with the same seed.
|
|
No effects, no version creation.
|
|
- "regenerate" : re-runs with seed=None for variation. Creates a new
|
|
version with an auto-incremented "take-N" label.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import traceback
|
|
from typing import Literal, Optional
|
|
|
|
from .. import config
|
|
from . import history, profiles
|
|
from ..database import get_db
|
|
from ..utils.tasks import get_task_manager
|
|
|
|
|
|
async def run_generation(
    *,
    generation_id: str,
    profile_id: str,
    text: str,
    language: str,
    engine: str,
    model_size: str,
    seed: Optional[int],
    normalize: bool = False,
    effects_chain: Optional[list] = None,
    instruct: Optional[str] = None,
    mode: Literal["generate", "retry", "regenerate"],
    max_chunk_chars: Optional[int] = None,
    crossfade_ms: Optional[int] = None,
    version_id: Optional[str] = None,
) -> None:
    """Execute TTS inference and persist the result.

    This is the single entry point for all background generation work.
    It is designed to be enqueued via ``services.task_queue.enqueue_generation``.

    The run moves the generation through the statuses ``loading_model``
    (only when the backend is not loaded yet), ``generating`` and finally
    ``completed`` or ``failed``.

    Args:
        generation_id: Generation record to update; also used to name the
            audio files written to the generations directory.
        profile_id: Voice profile used to build the voice prompt.
        text: Text to synthesize.
        language: Language forwarded to the chunked generator.
        engine: Engine name used to select and load the TTS backend.
        model_size: Model size forwarded to ``load_engine_model``.
        seed: Seed forwarded to the generator. Ignored (replaced by
            ``None``) when *mode* is ``"regenerate"``.
        normalize: Normalize the audio before saving. ``"regenerate"``
            normalizes regardless of this flag.
        effects_chain: Effects to apply; only used by ``"generate"``.
        instruct: Optional instruction forwarded to the generator.
        mode: ``"generate"``, ``"retry"`` or ``"regenerate"``; selects how
            the output is persisted.
        max_chunk_chars: Forwarded to the generator only when not ``None``.
        crossfade_ms: Forwarded to the generator only when not ``None``.
        version_id: Forwarded to the regenerate save step.

    Exceptions raised by the pipeline (including an unknown *mode*) are
    caught here and recorded on the generation as ``failed``; they are not
    re-raised to the caller.
    """
    from ..backends import load_engine_model, get_tts_backend_for_engine, engine_needs_trim
    from ..utils.chunked_tts import generate_chunked
    from ..utils.audio import normalize_audio, save_audio, trim_tts_output

    task_manager = get_task_manager()
    bg_db = next(get_db())

    try:
        # Reject an unknown mode before any expensive work. Without this
        # check the run would only fail after inference, with an
        # UnboundLocalError on ``final_path``.
        if mode not in ("generate", "retry", "regenerate"):
            raise ValueError(f"Unknown generation mode: {mode!r}")

        tts_model = get_tts_backend_for_engine(engine)

        if not tts_model.is_loaded():
            await history.update_generation_status(generation_id, "loading_model", bg_db)

        await load_engine_model(engine, model_size)

        voice_prompt = await profiles.create_voice_prompt_for_profile(
            profile_id,
            bg_db,
            use_cache=True,
            engine=engine,
        )

        await history.update_generation_status(generation_id, "generating", bg_db)
        trim_fn = trim_tts_output if engine_needs_trim(engine) else None

        gen_kwargs: dict = dict(
            language=language,
            # Regeneration drops the seed so the new take can differ.
            seed=seed if mode != "regenerate" else None,
            instruct=instruct,
            trim_fn=trim_fn,
        )
        # Optional tuning knobs are only passed when set, so the
        # generator's own defaults apply otherwise.
        if max_chunk_chars is not None:
            gen_kwargs["max_chunk_chars"] = max_chunk_chars
        if crossfade_ms is not None:
            gen_kwargs["crossfade_ms"] = crossfade_ms

        audio, sample_rate = await generate_chunked(tts_model, text, voice_prompt, **gen_kwargs)

        # --- Normalize (when requested; regenerate always normalizes) ----
        if normalize or mode == "regenerate":
            audio = normalize_audio(audio)

        duration = len(audio) / sample_rate

        # --- Persist audio and update status -----------------------------
        if mode == "generate":
            final_path = _save_generate(
                generation_id=generation_id,
                audio=audio,
                sample_rate=sample_rate,
                effects_chain=effects_chain,
                save_audio=save_audio,
                db=bg_db,
            )
        elif mode == "retry":
            final_path = _save_retry(
                generation_id=generation_id,
                audio=audio,
                sample_rate=sample_rate,
                save_audio=save_audio,
            )
        else:  # "regenerate" -- mode was validated at the top of the try
            final_path = _save_regenerate(
                generation_id=generation_id,
                version_id=version_id,
                audio=audio,
                sample_rate=sample_rate,
                save_audio=save_audio,
                db=bg_db,
            )

        await history.update_generation_status(
            generation_id=generation_id,
            status="completed",
            db=bg_db,
            audio_path=final_path,
            duration=duration,
        )

    except asyncio.CancelledError:
        # NOTE(review): the cancellation is recorded but not re-raised;
        # confirm the task queue relies on this function returning normally.
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=bg_db,
            error="Generation cancelled",
        )
    except Exception as e:
        traceback.print_exc()
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=bg_db,
            error=str(e),
        )
    finally:
        # Always release the task slot and the DB session, whatever happened.
        task_manager.complete_generation(generation_id)
        bg_db.close()
|
|
|
|
|
|
def _save_generate(
    *,
    generation_id: str,
    audio,
    sample_rate: int,
    effects_chain: Optional[list],
    save_audio,
    db,
) -> str:
    """Save the clean version and optionally an effects-processed version.

    The clean audio is always written to ``<generation_id>.wav`` and
    registered as the ``"original"`` version. When *effects_chain* has at
    least one enabled entry and passes validation, the processed audio is
    written to ``<generation_id>_processed.wav`` and registered as
    ``"version-2"``, which then becomes the default version.

    Args:
        generation_id: Generation the versions belong to; also the file stem.
        audio: Clean audio to save (and to feed into the effects chain).
        sample_rate: Sample rate passed to ``save_audio`` and ``apply_effects``.
        effects_chain: Optional list of effect dicts. An entry without an
            ``"enabled"`` key counts as enabled.
        save_audio: Callable invoked as ``save_audio(audio, path, sample_rate)``.
        db: Database session passed through to the versions service.

    Returns:
        The storage path of the final audio (processed if effects were
        applied, otherwise clean).
    """
    from . import versions as versions_mod

    generations_dir = config.get_generations_dir()
    clean_audio_path = generations_dir / f"{generation_id}.wav"
    save_audio(audio, str(clean_audio_path), sample_rate)

    # Decide up front whether a processed version will be produced, so the
    # "original" version can be created with the correct default flag in a
    # single step instead of being patched afterwards.
    active_chain = (
        effects_chain
        if effects_chain and any(e.get("enabled", True) for e in effects_chain)
        else None
    )
    if active_chain is not None:
        from ..utils.effects import validate_effects_chain

        error_msg = validate_effects_chain(active_chain)
        if error_msg:
            import logging

            # An invalid chain is skipped rather than failing the generation.
            logging.getLogger(__name__).warning("invalid effects chain, skipping: %s", error_msg)
            active_chain = None

    clean_storage_path = config.to_storage_path(clean_audio_path)
    versions_mod.create_version(
        generation_id=generation_id,
        label="original",
        audio_path=clean_storage_path,
        db=db,
        effects_chain=None,
        # The clean take is the default unless a processed take follows.
        is_default=active_chain is None,
    )

    if active_chain is None:
        return clean_storage_path

    from ..utils.effects import apply_effects

    processed_audio = apply_effects(audio, sample_rate, active_chain)
    processed_path = generations_dir / f"{generation_id}_processed.wav"
    save_audio(processed_audio, str(processed_path), sample_rate)

    processed_storage_path = config.to_storage_path(processed_path)
    versions_mod.create_version(
        generation_id=generation_id,
        label="version-2",
        audio_path=processed_storage_path,
        db=db,
        effects_chain=active_chain,
        is_default=True,
    )

    return processed_storage_path
|
|
|
|
|
|
def _save_retry(
    *,
    generation_id: str,
    audio,
    sample_rate: int,
    save_audio,
) -> str:
    """Write the retried audio to a single file; no version rows are created.

    Args:
        generation_id: Used as the file stem (``<generation_id>.wav``).
        audio: Audio to save.
        sample_rate: Sample rate passed to ``save_audio``.
        save_audio: Callable invoked as ``save_audio(audio, path, sample_rate)``.

    Returns:
        The storage path of the written file.
    """
    target = config.get_generations_dir() / f"{generation_id}.wav"
    save_audio(audio, str(target), sample_rate)
    stored_path = config.to_storage_path(target)
    return stored_path
|
|
|
|
|
|
def _save_regenerate(
    *,
    generation_id: str,
    version_id: Optional[str],
    audio,
    sample_rate: int,
    save_audio,
    db,
) -> str:
    """Persist a regenerated take as a new default version.

    The audio is written to ``<generation_id>_<8 hex chars>.wav`` and
    registered with a ``"take-N"`` label, where N is one more than the
    number of versions already stored for the generation.

    Args:
        generation_id: Generation the new version belongs to.
        version_id: Accepted for interface compatibility; not used here.
        audio: Audio to save.
        sample_rate: Sample rate passed to ``save_audio``.
        save_audio: Callable invoked as ``save_audio(audio, path, sample_rate)``.
        db: Database session used for the count query and version creation.

    Returns:
        The storage path of the written file.
    """
    import uuid as _uuid

    from . import versions as versions_mod

    # A random suffix keeps each take in its own file.
    unique_tag = _uuid.uuid4().hex[:8]
    take_path = config.get_generations_dir() / f"{generation_id}_{unique_tag}.wav"
    save_audio(audio, str(take_path), sample_rate)

    # Number the take from a COUNT query on the versions table rather than
    # from the length of a previously fetched list.
    from ..database import GenerationVersion as DBGenerationVersion

    version_query = db.query(DBGenerationVersion).filter_by(generation_id=generation_id)
    existing_versions = version_query.count()
    take_label = f"take-{existing_versions + 1}"

    versions_mod.create_version(
        generation_id=generation_id,
        label=take_label,
        audio_path=config.to_storage_path(take_path),
        db=db,
        effects_chain=None,
        is_default=True,
    )

    return config.to_storage_path(take_path)
|