Initial commit

commit fbcbe08696
2026-04-24 19:18:15 +08:00
555 changed files with 96692 additions and 0 deletions


@@ -0,0 +1 @@
# Services layer — generation orchestration and background task management.


@@ -0,0 +1,263 @@
"""
Audio channel management module.
"""
from typing import List, Optional
from datetime import datetime
import uuid
from sqlalchemy.orm import Session
from ..models import (
AudioChannelCreate,
AudioChannelUpdate,
AudioChannelResponse,
ChannelVoiceAssignment,
ProfileChannelAssignment,
)
from ..database import (
AudioChannel as DBAudioChannel,
ChannelDeviceMapping as DBChannelDeviceMapping,
ProfileChannelMapping as DBProfileChannelMapping,
VoiceProfile as DBVoiceProfile,
)
async def list_channels(db: Session) -> List[AudioChannelResponse]:
"""List all audio channels."""
channels = db.query(DBAudioChannel).all()
result = []
for channel in channels:
# Get device IDs for this channel
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
channel_id=channel.id
).all()
device_ids = [m.device_id for m in device_mappings]
result.append(AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=device_ids,
created_at=channel.created_at,
))
return result
async def get_channel(channel_id: str, db: Session) -> Optional[AudioChannelResponse]:
"""Get a channel by ID."""
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
return None
# Get device IDs
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
channel_id=channel.id
).all()
device_ids = [m.device_id for m in device_mappings]
return AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=device_ids,
created_at=channel.created_at,
)
async def create_channel(
data: AudioChannelCreate,
db: Session,
) -> AudioChannelResponse:
"""Create a new audio channel."""
# Check if name already exists
existing = db.query(DBAudioChannel).filter_by(name=data.name).first()
if existing:
raise ValueError(f"Channel with name '{data.name}' already exists")
# Create channel
channel = DBAudioChannel(
id=str(uuid.uuid4()),
name=data.name,
is_default=False,
created_at=datetime.utcnow(),
)
db.add(channel)
db.flush()
# Add device mappings
for device_id in data.device_ids:
mapping = DBChannelDeviceMapping(
id=str(uuid.uuid4()),
channel_id=channel.id,
device_id=device_id,
)
db.add(mapping)
db.commit()
db.refresh(channel)
return AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=data.device_ids,
created_at=channel.created_at,
)
async def update_channel(
channel_id: str,
data: AudioChannelUpdate,
db: Session,
) -> Optional[AudioChannelResponse]:
"""Update an audio channel."""
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
return None
if channel.is_default:
raise ValueError("Cannot modify the default channel")
# Update name if provided
if data.name is not None:
# Check if name already exists (excluding current channel)
existing = db.query(DBAudioChannel).filter(
DBAudioChannel.name == data.name,
DBAudioChannel.id != channel_id
).first()
if existing:
raise ValueError(f"Channel with name '{data.name}' already exists")
channel.name = data.name
# Update device mappings if provided
if data.device_ids is not None:
# Delete existing mappings
db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()
# Add new mappings
for device_id in data.device_ids:
mapping = DBChannelDeviceMapping(
id=str(uuid.uuid4()),
channel_id=channel.id,
device_id=device_id,
)
db.add(mapping)
db.commit()
db.refresh(channel)
# Get updated device IDs
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
channel_id=channel.id
).all()
device_ids = [m.device_id for m in device_mappings]
return AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=device_ids,
created_at=channel.created_at,
)
async def delete_channel(channel_id: str, db: Session) -> bool:
"""Delete an audio channel."""
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
return False
if channel.is_default:
raise ValueError("Cannot delete the default channel")
# Delete device mappings
db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()
# Delete profile-channel mappings
db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
# Delete channel
db.delete(channel)
db.commit()
return True
async def get_channel_voices(channel_id: str, db: Session) -> List[str]:
"""Get list of profile IDs assigned to a channel."""
mappings = db.query(DBProfileChannelMapping).filter_by(
channel_id=channel_id
).all()
return [m.profile_id for m in mappings]
async def set_channel_voices(
channel_id: str,
data: ChannelVoiceAssignment,
db: Session,
) -> None:
"""Set which voices are assigned to a channel."""
# Verify channel exists
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
raise ValueError(f"Channel {channel_id} not found")
# Verify all profiles exist
for profile_id in data.profile_ids:
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Delete existing mappings for this channel
db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
# Add new mappings
for profile_id in data.profile_ids:
mapping = DBProfileChannelMapping(
profile_id=profile_id,
channel_id=channel_id,
)
db.add(mapping)
db.commit()
async def get_profile_channels(profile_id: str, db: Session) -> List[str]:
"""Get list of channel IDs assigned to a profile."""
mappings = db.query(DBProfileChannelMapping).filter_by(
profile_id=profile_id
).all()
return [m.channel_id for m in mappings]
async def set_profile_channels(
profile_id: str,
data: ProfileChannelAssignment,
db: Session,
) -> None:
"""Set which channels a profile is assigned to."""
# Verify profile exists
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Verify all channels exist
for channel_id in data.channel_ids:
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
raise ValueError(f"Channel {channel_id} not found")
# Delete existing mappings for this profile
db.query(DBProfileChannelMapping).filter_by(profile_id=profile_id).delete()
# Add new mappings
for channel_id in data.channel_ids:
mapping = DBProfileChannelMapping(
profile_id=profile_id,
channel_id=channel_id,
)
db.add(mapping)
db.commit()
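# --- Usage sketch (illustrative, not part of this module) --------------
# Minimal flow wiring a new channel to a voice profile. Assumes a
# SQLAlchemy session factory `SessionLocal` (hypothetical name) and that
# AudioChannelCreate / ChannelVoiceAssignment expose the fields used above:
#
#     db = SessionLocal()
#     channel = await create_channel(
#         AudioChannelCreate(name="Stream", device_ids=["dev-1"]), db
#     )
#     await set_channel_voices(
#         channel.id, ChannelVoiceAssignment(profile_ids=["profile-123"]), db
#     )
#     assert "profile-123" in await get_channel_voices(channel.id, db)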

backend/services/cuda.py

@@ -0,0 +1,422 @@
"""
CUDA backend download, assembly, and verification.
Downloads two archives from GitHub Releases:
1. Server core (voicebox-server-cuda.tar.gz) — the exe + non-NVIDIA deps,
versioned with the app.
2. CUDA libs (cuda-libs-{version}.tar.gz) — NVIDIA runtime libraries,
versioned independently (only redownloaded on CUDA toolkit bump).
Both archives are extracted into {data_dir}/backends/cuda/ which forms the
complete PyInstaller --onedir directory structure that torch expects.
"""
import asyncio
import hashlib
import json
import logging
import os
import sys
import tarfile
from pathlib import Path
from typing import Optional
from ..config import get_data_dir
from ..utils.progress import get_progress_manager
from .. import __version__
logger = logging.getLogger(__name__)
GITHUB_RELEASES_URL = "https://github.com/jamiepine/voicebox/releases/download"
PROGRESS_KEY = "cuda-backend"
# The current expected CUDA libs version. Bump this when we change the
# CUDA toolkit version or torch's CUDA dependency changes (e.g. cu126 -> cu128).
CUDA_LIBS_VERSION = "cu128-v1"
# Prevents concurrent download_cuda_binary() calls from racing on the same
# temp file. The auto-update background task and the manual HTTP endpoint
# can both invoke download_cuda_binary(); without this lock the progress-
# manager status check is a TOCTOU race.
_download_lock = asyncio.Lock()
def get_backends_dir() -> Path:
"""Directory where downloaded backend binaries are stored."""
d = get_data_dir() / "backends"
d.mkdir(parents=True, exist_ok=True)
return d
def get_cuda_dir() -> Path:
"""Directory where the CUDA backend (onedir) is extracted."""
d = get_backends_dir() / "cuda"
d.mkdir(parents=True, exist_ok=True)
return d
def get_cuda_exe_name() -> str:
"""Platform-specific CUDA executable filename."""
if sys.platform == "win32":
return "voicebox-server-cuda.exe"
return "voicebox-server-cuda"
def get_cuda_binary_path() -> Optional[Path]:
"""Return path to the CUDA executable if it exists inside the onedir."""
p = get_cuda_dir() / get_cuda_exe_name()
if p.exists():
return p
return None
def get_cuda_libs_manifest_path() -> Path:
"""Path to the cuda-libs.json manifest inside the CUDA dir."""
return get_cuda_dir() / "cuda-libs.json"
def get_installed_cuda_libs_version() -> Optional[str]:
"""Read the installed CUDA libs version from cuda-libs.json, or None."""
manifest_path = get_cuda_libs_manifest_path()
if not manifest_path.exists():
return None
try:
data = json.loads(manifest_path.read_text())
return data.get("version")
except Exception as e:
logger.warning(f"Could not read cuda-libs.json: {e}")
return None
def is_cuda_active() -> bool:
"""Check if the current process is the CUDA binary.
The CUDA binary sets this env var on startup (see server.py).
"""
return os.environ.get("VOICEBOX_BACKEND_VARIANT") == "cuda"
def get_cuda_status() -> dict:
"""Get current CUDA backend status for the API."""
progress_manager = get_progress_manager()
cuda_path = get_cuda_binary_path()
progress = progress_manager.get_progress(PROGRESS_KEY)
cuda_libs_version = get_installed_cuda_libs_version()
return {
"available": cuda_path is not None,
"active": is_cuda_active(),
"binary_path": str(cuda_path) if cuda_path else None,
"cuda_libs_version": cuda_libs_version,
"downloading": progress is not None and progress.get("status") == "downloading",
"download_progress": progress,
}
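# Example return value while a download is in flight (illustrative):
#     {
#         "available": False,
#         "active": False,
#         "binary_path": None,
#         "cuda_libs_version": None,
#         "downloading": True,
#         "download_progress": {...},  # whatever the progress manager tracks
#     }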
def _needs_server_download(version: Optional[str] = None) -> bool:
"""Check if the server core archive needs to be (re)downloaded."""
cuda_path = get_cuda_binary_path()
if not cuda_path:
return True
# Check if the binary version matches the expected app version
installed = get_cuda_binary_version()
expected = version or __version__
if expected.startswith("v"):
expected = expected[1:]
return installed != expected
def _needs_cuda_libs_download() -> bool:
"""Check if the CUDA libs archive needs to be (re)downloaded."""
installed = get_installed_cuda_libs_version()
if installed is None:
return True
return installed != CUDA_LIBS_VERSION
async def _download_and_extract_archive(
client,
url: str,
sha256_url: Optional[str],
dest_dir: Path,
label: str,
progress_offset: int,
total_size: int,
):
"""Download a .tar.gz archive and extract it into dest_dir.
Args:
client: httpx.AsyncClient
url: URL of the .tar.gz archive
sha256_url: URL of the .sha256 checksum file (optional)
dest_dir: Directory to extract into
label: Human-readable label for progress updates
progress_offset: Byte offset for progress reporting (when downloading
multiple archives sequentially)
total_size: Total bytes across all downloads (for progress bar)
"""
progress = get_progress_manager()
temp_path = dest_dir / f".download-{label.replace(' ', '-')}.tmp"
# Clean up leftover partial download
if temp_path.exists():
temp_path.unlink()
# Fetch expected checksum (fail-fast: never extract an unverified archive)
expected_sha = None
if sha256_url:
try:
sha_resp = await client.get(sha256_url)
sha_resp.raise_for_status()
expected_sha = sha_resp.text.strip().split()[0]
logger.info(f"{label}: expected SHA-256: {expected_sha[:16]}...")
except Exception as e:
raise RuntimeError(f"{label}: failed to fetch checksum from {sha256_url}") from e
# Stream download, verify, and extract — always clean up temp file
downloaded = 0
try:
async with client.stream("GET", url) as response:
response.raise_for_status()
with open(temp_path, "wb") as f:
async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
f.write(chunk)
downloaded += len(chunk)
progress.update_progress(
PROGRESS_KEY,
current=progress_offset + downloaded,
total=total_size,
filename=f"Downloading {label}",
status="downloading",
)
# Verify integrity
if expected_sha:
progress.update_progress(
PROGRESS_KEY,
current=progress_offset + downloaded,
total=total_size,
filename=f"Verifying {label}...",
status="downloading",
)
sha256 = hashlib.sha256()
with open(temp_path, "rb") as f:
while True:
data = f.read(1024 * 1024)
if not data:
break
sha256.update(data)
actual = sha256.hexdigest()
if actual != expected_sha:
raise ValueError(
f"{label} integrity check failed: expected {expected_sha[:16]}..., got {actual[:16]}..."
)
logger.info(f"{label}: integrity verified")
# Extract (use data filter for path traversal protection on Python 3.12+)
progress.update_progress(
PROGRESS_KEY,
current=progress_offset + downloaded,
total=total_size,
filename=f"Extracting {label}...",
status="downloading",
)
with tarfile.open(temp_path, "r:gz") as tar:
if sys.version_info >= (3, 12):
tar.extractall(path=dest_dir, filter="data")
else:
tar.extractall(path=dest_dir)
logger.info(f"{label}: extracted to {dest_dir}")
finally:
if temp_path.exists():
temp_path.unlink()
return downloaded
async def download_cuda_binary(version: Optional[str] = None):
"""Download the CUDA backend (server core + CUDA libs if needed).
Downloads both archives from GitHub Releases, extracts them into
{data_dir}/backends/cuda/, and writes the cuda-libs.json manifest.
Only downloads what's needed:
- Server core: always redownloaded (versioned with app)
- CUDA libs: only if missing or version mismatch
Args:
version: Version tag (e.g. "v0.3.0"). Defaults to current app version.
"""
if _download_lock.locked():
logger.info("CUDA download already in progress, skipping duplicate request")
return
async with _download_lock:
await _download_cuda_binary_locked(version)
async def _download_cuda_binary_locked(version: Optional[str] = None):
"""Inner implementation of download_cuda_binary, called under _download_lock."""
import httpx
if version is None:
version = f"v{__version__}"
progress = get_progress_manager()
cuda_dir = get_cuda_dir()
need_server = _needs_server_download(version)
need_libs = _needs_cuda_libs_download()
if not need_server and not need_libs:
logger.info("CUDA backend is up to date, nothing to download")
return
logger.info(
f"Starting CUDA backend download for {version} "
f"(server={'yes' if need_server else 'cached'}, "
f"libs={'yes' if need_libs else 'cached'})"
)
progress.update_progress(
PROGRESS_KEY,
current=0,
total=0,
filename="Preparing download...",
status="downloading",
)
base_url = f"{GITHUB_RELEASES_URL}/{version}"
server_archive = "voicebox-server-cuda.tar.gz"
libs_archive = f"cuda-libs-{CUDA_LIBS_VERSION}.tar.gz"
try:
async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
# Estimate total download size
total_size = 0
if need_server:
try:
head = await client.head(f"{base_url}/{server_archive}")
total_size += int(head.headers.get("content-length", 0))
except Exception:
pass
if need_libs:
try:
head = await client.head(f"{base_url}/{libs_archive}")
total_size += int(head.headers.get("content-length", 0))
except Exception:
pass
logger.info(f"Total download size: {total_size / 1024 / 1024:.1f} MB")
offset = 0
# Download server core
if need_server:
server_downloaded = await _download_and_extract_archive(
client,
url=f"{base_url}/{server_archive}",
sha256_url=f"{base_url}/{server_archive}.sha256",
dest_dir=cuda_dir,
label="CUDA server",
progress_offset=offset,
total_size=total_size,
)
offset += server_downloaded
# Make executable on Unix
exe_path = cuda_dir / get_cuda_exe_name()
if sys.platform != "win32" and exe_path.exists():
exe_path.chmod(0o755)
# Download CUDA libs
if need_libs:
await _download_and_extract_archive(
client,
url=f"{base_url}/{libs_archive}",
sha256_url=f"{base_url}/{libs_archive}.sha256",
dest_dir=cuda_dir,
label="CUDA libraries",
progress_offset=offset,
total_size=total_size,
)
# Write local cuda-libs.json manifest
manifest = {"version": CUDA_LIBS_VERSION}
get_cuda_libs_manifest_path().write_text(json.dumps(manifest, indent=2) + "\n")
logger.info(f"CUDA backend ready at {cuda_dir}")
progress.mark_complete(PROGRESS_KEY)
except Exception as e:
logger.error(f"CUDA backend download failed: {e}")
progress.mark_error(PROGRESS_KEY, str(e))
raise
def get_cuda_binary_version() -> Optional[str]:
"""Get the version of the installed CUDA binary, or None if not installed."""
import subprocess
cuda_path = get_cuda_binary_path()
if not cuda_path:
return None
try:
result = subprocess.run(
[str(cuda_path), "--version"],
capture_output=True,
text=True,
timeout=30,
cwd=str(cuda_path.parent), # Run from the onedir directory
)
# Output format: "voicebox-server 0.3.0"
for line in result.stdout.strip().splitlines():
if "voicebox-server" in line:
return line.split()[-1]
except Exception as e:
logger.warning(f"Could not get CUDA binary version: {e}")
return None
async def check_and_update_cuda_binary():
"""Check if the CUDA binary is outdated and auto-download if so.
Called on server startup. Checks both server version and CUDA libs
version. Downloads only what's needed.
"""
cuda_path = get_cuda_binary_path()
if not cuda_path:
return # No CUDA binary installed, nothing to update
need_server = _needs_server_download()
need_libs = _needs_cuda_libs_download()
if not need_server and not need_libs:
logger.info(f"CUDA binary is up to date (server=v{__version__}, libs={get_installed_cuda_libs_version()})")
return
reasons = []
if need_server:
cuda_version = get_cuda_binary_version()
reasons.append(f"server v{cuda_version} != v{__version__}")
if need_libs:
installed_libs = get_installed_cuda_libs_version()
reasons.append(f"libs {installed_libs} != {CUDA_LIBS_VERSION}")
logger.info(f"CUDA backend needs update ({', '.join(reasons)}). Auto-downloading...")
try:
await download_cuda_binary()
except Exception as e:
logger.error(f"Auto-update of CUDA binary failed: {e}")
async def delete_cuda_binary() -> bool:
"""Delete the downloaded CUDA backend directory. Returns True if deleted."""
import shutil
cuda_dir = get_cuda_dir()
if cuda_dir.exists() and any(cuda_dir.iterdir()):
shutil.rmtree(cuda_dir)
logger.info(f"Deleted CUDA backend directory: {cuda_dir}")
return True
return False
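# --- Startup wiring sketch (illustrative, not part of this module) ----------
# check_and_update_cuda_binary() is meant to run once on server startup.
# A minimal sketch, assuming the package is importable as backend.services:
#
#     import asyncio
#     from backend.services import cuda
#
#     async def on_startup() -> None:
#         # No-op if the CUDA backend was never installed; otherwise it
#         # redownloads only the stale archive(s).
#         await cuda.check_and_update_cuda_binary()
#
#     asyncio.run(on_startup())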

backend/services/effects.py

@@ -0,0 +1,120 @@
"""
Effect presets CRUD operations.
"""
from __future__ import annotations
import json
import uuid
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from ..utils.effects import validate_effects_chain
from ..database import EffectPreset as DBEffectPreset
from ..models import EffectPresetResponse, EffectPresetCreate, EffectPresetUpdate, EffectConfig
def _preset_response(p: DBEffectPreset) -> EffectPresetResponse:
"""Convert a DB preset row to a Pydantic response."""
effects_chain = [EffectConfig(**e) for e in json.loads(p.effects_chain)]
return EffectPresetResponse(
id=p.id,
name=p.name,
description=p.description,
effects_chain=effects_chain,
is_builtin=p.is_builtin or False,
created_at=p.created_at,
)
def list_presets(db: Session) -> List[EffectPresetResponse]:
"""List all effect presets (built-in + user-created)."""
presets = db.query(DBEffectPreset).order_by(DBEffectPreset.sort_order, DBEffectPreset.name).all()
return [_preset_response(p) for p in presets]
def get_preset(preset_id: str, db: Session) -> Optional[EffectPresetResponse]:
"""Get a preset by ID."""
p = db.query(DBEffectPreset).filter_by(id=preset_id).first()
if not p:
return None
return _preset_response(p)
def get_preset_by_name(name: str, db: Session) -> Optional[EffectPresetResponse]:
"""Get a preset by name."""
p = db.query(DBEffectPreset).filter_by(name=name).first()
if not p:
return None
return _preset_response(p)
def create_preset(data: EffectPresetCreate, db: Session) -> EffectPresetResponse:
"""Create a new user effect preset."""
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise ValueError(error)
# Check for duplicate name before insert
existing = db.query(DBEffectPreset).filter_by(name=data.name).first()
if existing:
raise ValueError(f"A preset named '{data.name}' already exists")
preset = DBEffectPreset(
id=str(uuid.uuid4()),
name=data.name,
description=data.description,
effects_chain=json.dumps(chain_dicts),
is_builtin=False,
)
db.add(preset)
try:
db.commit()
except IntegrityError:
db.rollback()
raise ValueError(f"A preset named '{data.name}' already exists")
db.refresh(preset)
return _preset_response(preset)
def update_preset(preset_id: str, data: EffectPresetUpdate, db: Session) -> Optional[EffectPresetResponse]:
"""Update a user effect preset. Cannot modify built-in presets."""
preset = db.query(DBEffectPreset).filter_by(id=preset_id).first()
if not preset:
return None
if preset.is_builtin:
raise ValueError("Cannot modify built-in presets")
if data.name is not None:
preset.name = data.name
if data.description is not None:
preset.description = data.description
if data.effects_chain is not None:
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise ValueError(error)
preset.effects_chain = json.dumps(chain_dicts)
db.commit()
db.refresh(preset)
return _preset_response(preset)
def delete_preset(preset_id: str, db: Session) -> bool:
"""Delete a user effect preset. Cannot delete built-in presets."""
preset = db.query(DBEffectPreset).filter_by(id=preset_id).first()
if not preset:
return False
if preset.is_builtin:
raise ValueError("Cannot delete built-in presets")
db.delete(preset)
db.commit()
return True
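# --- Usage sketch (illustrative, not part of this module) -------------------
# Creating and renaming a user preset. The EffectConfig fields below are
# hypothetical; only `enabled` is referenced elsewhere in this codebase:
#
#     chain = [EffectConfig(type="reverb", enabled=True)]  # hypothetical fields
#     preset = create_preset(
#         EffectPresetCreate(name="Radio", description=None, effects_chain=chain),
#         db,
#     )
#     update_preset(preset.id, EffectPresetUpdate(name="Radio v2"), db)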


@@ -0,0 +1,461 @@
"""
Voice profile export/import module.
Handles exporting profiles to ZIP archives and importing them back.
Also handles exporting individual generations.
"""
import json
import zipfile
import io
from pathlib import Path
from typing import Optional
from sqlalchemy.orm import Session
from ..models import VoiceProfileResponse
from ..database import VoiceProfile as DBVoiceProfile, ProfileSample as DBProfileSample, Generation as DBGeneration, GenerationVersion as DBGenerationVersion
from .profiles import create_profile, add_profile_sample
from ..models import VoiceProfileCreate
from .. import config
def _get_unique_profile_name(name: str, db: Session) -> str:
"""
Get a unique profile name by appending a number if needed.
Args:
name: Original profile name
db: Database session
Returns:
Unique profile name
"""
base_name = name
counter = 1
while True:
existing = db.query(DBVoiceProfile).filter_by(name=name).first()
if not existing:
return name
name = f"{base_name} ({counter})"
counter += 1
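# Example: if "Alice" and "Alice (1)" already exist, this returns "Alice (2)".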
def export_profile_to_zip(profile_id: str, db: Session) -> bytes:
"""
Export a voice profile to a ZIP archive.
Args:
profile_id: Profile ID to export
db: Database session
Returns:
ZIP file contents as bytes
Raises:
ValueError: If profile not found or has no samples
"""
# Get profile
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Get all samples
samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
if not samples:
raise ValueError(f"Profile {profile_id} has no samples")
# Create ZIP in memory
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
# Check if profile has avatar
has_avatar = False
if profile.avatar_path:
avatar_path = config.resolve_storage_path(profile.avatar_path)
if avatar_path is not None and avatar_path.exists():
has_avatar = True
# Add avatar to ZIP root with original extension
avatar_ext = avatar_path.suffix
zip_file.write(avatar_path, f"avatar{avatar_ext}")
# Create manifest.json
manifest = {
"version": "1.0",
"profile": {
"name": profile.name,
"description": profile.description,
"language": profile.language,
},
"has_avatar": has_avatar,
}
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
# Create samples.json mapping
samples_data = {}
for sample in samples:
# Get filename from audio_path (should be {sample_id}.wav)
audio_path = config.resolve_storage_path(sample.audio_path)
if audio_path is None:
raise ValueError(f"Audio file not found: {sample.audio_path}")
filename = audio_path.name
# Read audio file
if not audio_path.exists():
raise ValueError(f"Audio file not found: {audio_path}")
# Add to samples directory in ZIP
zip_path = f"samples/{filename}"
zip_file.write(audio_path, zip_path)
# Map filename to reference text
samples_data[filename] = sample.reference_text
zip_file.writestr("samples.json", json.dumps(samples_data, indent=2))
zip_buffer.seek(0)
return zip_buffer.read()
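# Archive layout produced above (for reference):
#
#     manifest.json      # {"version": "1.0", "profile": {...}, "has_avatar": bool}
#     samples.json       # {"<sample_id>.wav": "<reference text>", ...}
#     avatar.<ext>       # only present when has_avatar is true
#     samples/
#         <sample_id>.wav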
async def import_profile_from_zip(file_bytes: bytes, db: Session) -> VoiceProfileResponse:
"""
Import a voice profile from a ZIP archive.
Args:
file_bytes: ZIP file contents
db: Database session
Returns:
Created profile
Raises:
ValueError: If ZIP is invalid or missing required files
"""
zip_buffer = io.BytesIO(file_bytes)
try:
with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
# Validate ZIP structure
namelist = zip_file.namelist()
if "manifest.json" not in namelist:
raise ValueError("ZIP archive missing manifest.json")
if "samples.json" not in namelist:
raise ValueError("ZIP archive missing samples.json")
# Read manifest
manifest_data = json.loads(zip_file.read("manifest.json"))
if "version" not in manifest_data:
raise ValueError("Invalid manifest.json: missing version")
if "profile" not in manifest_data:
raise ValueError("Invalid manifest.json: missing profile")
profile_data = manifest_data["profile"]
# Read samples mapping
samples_data = json.loads(zip_file.read("samples.json"))
if not isinstance(samples_data, dict):
raise ValueError("Invalid samples.json: must be a dictionary")
# Get unique profile name
original_name = profile_data.get("name", "Imported Profile")
unique_name = _get_unique_profile_name(original_name, db)
# Create profile
profile_create = VoiceProfileCreate(
name=unique_name,
description=profile_data.get("description"),
language=profile_data.get("language", "en"),
)
profile = await create_profile(profile_create, db)
# Extract and add samples
profile_dir = config.get_profiles_dir() / profile.id
profile_dir.mkdir(parents=True, exist_ok=True)
# Handle avatar if present
avatar_files = [f for f in namelist if f.startswith("avatar.")]
if avatar_files:
try:
avatar_file = avatar_files[0]
# Extract to temporary file
import tempfile
with tempfile.NamedTemporaryFile(suffix=Path(avatar_file).suffix, delete=False) as tmp:
tmp.write(zip_file.read(avatar_file))
tmp_path = tmp.name
try:
from .profiles import upload_avatar
await upload_avatar(profile.id, tmp_path, db)
finally:
Path(tmp_path).unlink(missing_ok=True)
except Exception:
# Avatar import is optional - continue even if it fails
pass
for filename, reference_text in samples_data.items():
# Validate filename
if not filename.endswith('.wav'):
raise ValueError(f"Invalid sample filename: {filename} (must be .wav)")
# Extract audio file to temp location
zip_path = f"samples/{filename}"
if zip_path not in namelist:
raise ValueError(f"Sample file not found in ZIP: {zip_path}")
# Extract to temporary file
import tempfile
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(zip_file.read(zip_path))
tmp_path = tmp.name
try:
# Add sample to profile
await add_profile_sample(
profile.id,
tmp_path,
reference_text,
db,
)
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)
return profile
except zipfile.BadZipFile:
raise ValueError("Invalid ZIP file")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in archive: {e}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error importing profile: {str(e)}")
def export_generation_to_zip(generation_id: str, db: Session) -> bytes:
"""
Export a generation to a ZIP archive.
Args:
generation_id: Generation ID to export
db: Database session
Returns:
ZIP file contents as bytes
Raises:
ValueError: If generation not found
"""
# Get generation
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise ValueError(f"Generation {generation_id} not found")
# Get profile info
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
if not profile:
raise ValueError(f"Profile {generation.profile_id} not found")
# Get all versions for this generation
versions = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.all()
)
# Create ZIP in memory
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
# Build version manifest entries
version_entries = []
for v in versions:
v_path = config.resolve_storage_path(v.audio_path)
effects_chain = None
if v.effects_chain:
effects_chain = json.loads(v.effects_chain)
version_entries.append({
"id": v.id,
"label": v.label,
"is_default": v.is_default,
"effects_chain": effects_chain,
"filename": v_path.name,
})
manifest = {
"version": "1.0",
"generation": {
"id": generation.id,
"text": generation.text,
"language": generation.language,
"duration": generation.duration,
"seed": generation.seed,
"instruct": generation.instruct,
"created_at": generation.created_at.isoformat(),
},
"profile": {
"id": profile.id,
"name": profile.name,
"description": profile.description,
"language": profile.language,
},
"versions": version_entries,
}
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
# Add all version audio files
for v in versions:
v_path = config.resolve_storage_path(v.audio_path)
if v_path is not None and v_path.exists():
zip_file.write(v_path, f"audio/{v_path.name}")
# Fallback: if no versions exist, include the generation's main audio
if not versions:
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
zip_file.write(audio_path, f"audio/{audio_path.name}")
zip_buffer.seek(0)
return zip_buffer.read()
async def import_generation_from_zip(file_bytes: bytes, db: Session) -> dict:
"""
Import a generation from a ZIP archive.
Args:
file_bytes: ZIP file contents
db: Database session
Returns:
Dictionary with generation ID and profile info
Raises:
ValueError: If ZIP is invalid or missing required files
"""
import tempfile
import shutil
import uuid
from datetime import datetime
zip_buffer = io.BytesIO(file_bytes)
try:
with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
# Validate ZIP structure
namelist = zip_file.namelist()
if "manifest.json" not in namelist:
raise ValueError("ZIP archive missing manifest.json")
# Read manifest
manifest_data = json.loads(zip_file.read("manifest.json"))
if "version" not in manifest_data:
raise ValueError("Invalid manifest.json: missing version")
if "generation" not in manifest_data:
raise ValueError("Invalid manifest.json: missing generation data")
generation_data = manifest_data["generation"]
profile_data = manifest_data.get("profile", {})
# Validate required fields
required_fields = ["text", "language", "duration"]
for field in required_fields:
if field not in generation_data:
raise ValueError(f"Invalid manifest.json: missing generation.{field}")
# Find audio file in archive
audio_files = [f for f in namelist if f.startswith("audio/") and f.endswith(".wav")]
if not audio_files:
raise ValueError("No audio file found in ZIP archive")
audio_file_path = audio_files[0]
# Check if we should match an existing profile or create metadata
profile_id = None
profile_name = profile_data.get("name", "Unknown Profile")
# Try to find matching profile by name
if profile_name and profile_name != "Unknown Profile":
existing_profile = db.query(DBVoiceProfile).filter_by(name=profile_name).first()
if existing_profile:
profile_id = existing_profile.id
# If no matching profile, use a placeholder or the first available profile
if not profile_id:
# Get any profile, or None if no profiles exist
any_profile = db.query(DBVoiceProfile).first()
if any_profile:
profile_id = any_profile.id
profile_name = any_profile.name
else:
raise ValueError("No voice profiles found. Please create a profile before importing generations.")
# Extract audio file to temporary location
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(zip_file.read(audio_file_path))
tmp_path = tmp.name
try:
# Create generations directory
generations_dir = config.get_generations_dir()
generations_dir.mkdir(parents=True, exist_ok=True)
# Generate new ID for this generation
new_generation_id = str(uuid.uuid4())
# Copy audio to generations directory
audio_dest = generations_dir / f"{new_generation_id}.wav"
shutil.copy(tmp_path, audio_dest)
# Create generation record
db_generation = DBGeneration(
id=new_generation_id,
profile_id=profile_id,
text=generation_data["text"],
language=generation_data["language"],
audio_path=config.to_storage_path(audio_dest),
duration=generation_data["duration"],
seed=generation_data.get("seed"),
instruct=generation_data.get("instruct"),
created_at=datetime.utcnow(),
)
db.add(db_generation)
db.commit()
db.refresh(db_generation)
return {
"id": db_generation.id,
"profile_id": profile_id,
"profile_name": profile_name,
"text": db_generation.text,
"message": f"Generation imported successfully (assigned to profile: {profile_name})"
}
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)
except zipfile.BadZipFile:
raise ValueError("Invalid ZIP file")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in archive: {e}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error importing generation: {str(e)}")


@@ -0,0 +1,263 @@
"""
Unified TTS generation orchestration.
Replaces the three near-identical closures (_run_generation, _run_retry,
_run_regenerate) that lived in main.py with a single ``run_generation()``
function parameterized by *mode*.
Mode differences:
- "generate" : full pipeline -- save clean version, optionally apply
effects and create a processed version.
- "retry" : re-runs a failed generation with the same seed.
No effects, no version creation.
- "regenerate" : re-runs with seed=None for variation. Creates a new
version with an auto-incremented "take-N" label.
"""
from __future__ import annotations
import asyncio
import traceback
from typing import Literal, Optional
from .. import config
from . import history, profiles
from ..database import get_db
from ..utils.tasks import get_task_manager
async def run_generation(
*,
generation_id: str,
profile_id: str,
text: str,
language: str,
engine: str,
model_size: str,
seed: Optional[int],
normalize: bool = False,
effects_chain: Optional[list] = None,
instruct: Optional[str] = None,
mode: Literal["generate", "retry", "regenerate"],
max_chunk_chars: Optional[int] = None,
crossfade_ms: Optional[int] = None,
version_id: Optional[str] = None,
) -> None:
"""Execute TTS inference and persist the result.
This is the single entry point for all background generation work.
It is designed to be enqueued via ``services.task_queue.enqueue_generation``.
"""
from ..backends import load_engine_model, get_tts_backend_for_engine, engine_needs_trim
from ..utils.chunked_tts import generate_chunked
from ..utils.audio import normalize_audio, save_audio, trim_tts_output
task_manager = get_task_manager()
bg_db = next(get_db())
try:
tts_model = get_tts_backend_for_engine(engine)
if not tts_model.is_loaded():
await history.update_generation_status(generation_id, "loading_model", bg_db)
await load_engine_model(engine, model_size)
voice_prompt = await profiles.create_voice_prompt_for_profile(
profile_id,
bg_db,
use_cache=True,
engine=engine,
)
await history.update_generation_status(generation_id, "generating", bg_db)
trim_fn = trim_tts_output if engine_needs_trim(engine) else None
gen_kwargs: dict = dict(
language=language,
seed=seed if mode != "regenerate" else None,
instruct=instruct,
trim_fn=trim_fn,
)
if max_chunk_chars is not None:
gen_kwargs["max_chunk_chars"] = max_chunk_chars
if crossfade_ms is not None:
gen_kwargs["crossfade_ms"] = crossfade_ms
audio, sample_rate = await generate_chunked(tts_model, text, voice_prompt, **gen_kwargs)
# --- Normalize (always for regenerate; otherwise only when requested) ---
if normalize or mode == "regenerate":
audio = normalize_audio(audio)
duration = len(audio) / sample_rate
# --- Persist audio and update status -----------------------------
if mode == "generate":
final_path = _save_generate(
generation_id=generation_id,
audio=audio,
sample_rate=sample_rate,
effects_chain=effects_chain,
save_audio=save_audio,
db=bg_db,
)
elif mode == "retry":
final_path = _save_retry(
generation_id=generation_id,
audio=audio,
sample_rate=sample_rate,
save_audio=save_audio,
)
elif mode == "regenerate":
final_path = _save_regenerate(
generation_id=generation_id,
version_id=version_id,
audio=audio,
sample_rate=sample_rate,
save_audio=save_audio,
db=bg_db,
)
await history.update_generation_status(
generation_id=generation_id,
status="completed",
db=bg_db,
audio_path=final_path,
duration=duration,
)
except asyncio.CancelledError:
await history.update_generation_status(
generation_id=generation_id,
status="failed",
db=bg_db,
error="Generation cancelled",
)
except Exception as e:
traceback.print_exc()
await history.update_generation_status(
generation_id=generation_id,
status="failed",
db=bg_db,
error=str(e),
)
finally:
task_manager.complete_generation(generation_id)
bg_db.close()
def _save_generate(
*,
generation_id: str,
audio,
sample_rate: int,
effects_chain: Optional[list],
save_audio,
db,
) -> str:
"""Save clean version and optionally an effects-processed version.
Returns the final audio path (processed if effects were applied,
otherwise clean).
"""
from . import versions as versions_mod
clean_audio_path = config.get_generations_dir() / f"{generation_id}.wav"
save_audio(audio, str(clean_audio_path), sample_rate)
has_effects = effects_chain and any(e.get("enabled", True) for e in effects_chain)
versions_mod.create_version(
generation_id=generation_id,
label="original",
audio_path=config.to_storage_path(clean_audio_path),
db=db,
effects_chain=None,
is_default=not has_effects,
)
final_audio_path = str(clean_audio_path)
if has_effects:
from ..utils.effects import apply_effects, validate_effects_chain
assert effects_chain is not None
error_msg = validate_effects_chain(effects_chain)
if error_msg:
import logging
logging.getLogger(__name__).warning("invalid effects chain, skipping: %s", error_msg)
versions_mod.set_default_version(
versions_mod.list_versions(generation_id, db)[0].id, db
)
else:
processed_audio = apply_effects(audio, sample_rate, effects_chain)
processed_path = config.get_generations_dir() / f"{generation_id}_processed.wav"
save_audio(processed_audio, str(processed_path), sample_rate)
final_audio_path = str(processed_path)
versions_mod.create_version(
generation_id=generation_id,
label="version-2",
audio_path=config.to_storage_path(processed_path),
db=db,
effects_chain=effects_chain,
is_default=True,
)
return config.to_storage_path(final_audio_path)
def _save_retry(
*,
generation_id: str,
audio,
sample_rate: int,
save_audio,
) -> str:
"""Save retry output -- single file, no versions.
Returns the audio path.
"""
audio_path = config.get_generations_dir() / f"{generation_id}.wav"
save_audio(audio, str(audio_path), sample_rate)
return config.to_storage_path(audio_path)
def _save_regenerate(
*,
generation_id: str,
version_id: Optional[str],
audio,
sample_rate: int,
save_audio,
db,
) -> str:
"""Save regeneration output as a new version with auto-label.
Returns the audio path.
"""
from . import versions as versions_mod
import uuid as _uuid
suffix = _uuid.uuid4().hex[:8]
audio_path = config.get_generations_dir() / f"{generation_id}_{suffix}.wav"
save_audio(audio, str(audio_path), sample_rate)
# Count via DB query rather than list length to avoid TOCTOU race
from ..database import GenerationVersion as DBGenerationVersion
count = db.query(DBGenerationVersion).filter_by(generation_id=generation_id).count()
label = f"take-{count + 1}"
versions_mod.create_version(
generation_id=generation_id,
label=label,
audio_path=config.to_storage_path(audio_path),
db=db,
effects_chain=None,
is_default=True,
)
return config.to_storage_path(audio_path)
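# --- Usage sketch (illustrative, not part of this module) --------------------
# Running a fresh generation end-to-end. Every keyword below appears in the
# run_generation() signature; the IDs are placeholders:
#
#     await run_generation(
#         generation_id=gen_id,   # pre-created history row
#         profile_id=profile_id,
#         text="Hello world",
#         language="en",
#         engine="qwen",
#         model_size="1.7B",
#         seed=None,
#         normalize=True,
#         effects_chain=None,
#         mode="generate",
#     )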

backend/services/history.py

@@ -0,0 +1,368 @@
"""
Generation history management module.
"""
from typing import List, Optional, Tuple
from datetime import datetime
import uuid
import shutil
from pathlib import Path
from sqlalchemy.orm import Session
from sqlalchemy import or_
from ..models import GenerationRequest, GenerationResponse, HistoryQuery, HistoryResponse, HistoryListResponse, GenerationVersionResponse, EffectConfig
from ..database import Generation as DBGeneration, GenerationVersion as DBGenerationVersion, VoiceProfile as DBVoiceProfile
from .. import config
def _get_versions_for_generation(generation_id: str, db: Session) -> Tuple[Optional[List[GenerationVersionResponse]], Optional[str]]:
"""Get versions list and active version ID for a generation."""
import json
versions_rows = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.all()
)
if not versions_rows:
return None, None
versions = []
active_version_id = None
for v in versions_rows:
effects_chain = None
if v.effects_chain:
try:
raw = json.loads(v.effects_chain)
effects_chain = [EffectConfig(**e) for e in raw]
except Exception:
pass
versions.append(GenerationVersionResponse(
id=v.id,
generation_id=v.generation_id,
label=v.label,
audio_path=v.audio_path,
effects_chain=effects_chain,
is_default=v.is_default,
created_at=v.created_at,
))
if v.is_default:
active_version_id = v.id
return versions, active_version_id
async def create_generation(
profile_id: str,
text: str,
language: str,
audio_path: str,
duration: float,
seed: Optional[int],
db: Session,
instruct: Optional[str] = None,
generation_id: Optional[str] = None,
status: str = "completed",
engine: Optional[str] = "qwen",
model_size: Optional[str] = None,
) -> GenerationResponse:
"""
Create a new generation history entry.
Args:
profile_id: Profile ID used for generation
text: Generated text
language: Language code
audio_path: Path where audio was saved
duration: Audio duration in seconds
seed: Random seed used (if any)
db: Database session
instruct: Natural language instruction used (if any)
generation_id: Pre-assigned ID (for async generation flow)
status: Generation status (generating, completed, failed)
engine: TTS engine used (qwen, luxtts, chatterbox, chatterbox_turbo)
model_size: Model size variant (1.7B, 0.6B) — only relevant for qwen
Returns:
Created generation entry
"""
db_generation = DBGeneration(
id=generation_id or str(uuid.uuid4()),
profile_id=profile_id,
text=text,
language=language,
audio_path=audio_path,
duration=duration,
seed=seed,
instruct=instruct,
engine=engine,
model_size=model_size,
status=status,
created_at=datetime.utcnow(),
)
db.add(db_generation)
db.commit()
db.refresh(db_generation)
return GenerationResponse.model_validate(db_generation)
async def update_generation_status(
generation_id: str,
status: str,
db: Session,
audio_path: Optional[str] = None,
duration: Optional[float] = None,
error: Optional[str] = None,
) -> Optional[GenerationResponse]:
"""Update the status of a generation (used by async generation flow)."""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
return None
generation.status = status
if audio_path is not None:
generation.audio_path = audio_path
if duration is not None:
generation.duration = duration
if error is not None:
generation.error = error
db.commit()
db.refresh(generation)
return GenerationResponse.model_validate(generation)
async def get_generation(
generation_id: str,
db: Session,
) -> Optional[GenerationResponse]:
"""
Get a generation by ID.
Args:
generation_id: Generation ID
db: Database session
Returns:
Generation or None if not found
"""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
return None
return GenerationResponse.model_validate(generation)
async def list_generations(
query: HistoryQuery,
db: Session,
) -> HistoryListResponse:
"""
List generations with optional filters.
Args:
query: Query parameters (filters, pagination)
db: Database session
Returns:
HistoryListResponse with items and total count
"""
# Build base query with join to get profile name
q = db.query(
DBGeneration,
DBVoiceProfile.name.label('profile_name')
).join(
DBVoiceProfile,
DBGeneration.profile_id == DBVoiceProfile.id
)
# Apply profile filter
if query.profile_id:
q = q.filter(DBGeneration.profile_id == query.profile_id)
# Apply search filter (searches in text content)
if query.search:
search_pattern = f"%{query.search}%"
q = q.filter(DBGeneration.text.like(search_pattern))
# Get total count before pagination
total_count = q.count()
# Apply ordering (newest first)
q = q.order_by(DBGeneration.created_at.desc())
# Apply pagination
q = q.offset(query.offset).limit(query.limit)
# Execute query
results = q.all()
# Convert to HistoryResponse with profile_name
items = []
for generation, profile_name in results:
versions, active_version_id = _get_versions_for_generation(generation.id, db)
items.append(HistoryResponse(
id=generation.id,
profile_id=generation.profile_id,
profile_name=profile_name,
text=generation.text,
language=generation.language,
audio_path=generation.audio_path,
duration=generation.duration,
seed=generation.seed,
instruct=generation.instruct,
engine=generation.engine or "qwen",
model_size=generation.model_size,
status=generation.status or "completed",
error=generation.error,
is_favorited=bool(generation.is_favorited),
created_at=generation.created_at,
versions=versions,
active_version_id=active_version_id,
))
return HistoryListResponse(
items=items,
total=total_count,
)
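# Example query (illustrative): newest 50 matches containing "hello":
#     await list_generations(HistoryQuery(search="hello", offset=0, limit=50), db)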
async def delete_generation(
generation_id: str,
db: Session,
) -> bool:
"""
Delete a generation.
Args:
generation_id: Generation ID
db: Database session
Returns:
True if deleted, False if not found
"""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
return False
# Delete all version files and records
from . import versions as versions_mod
versions_mod.delete_versions_for_generation(generation_id, db)
# Delete main audio file (if not already removed by version cleanup)
if generation.audio_path:
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
# Delete from database
db.delete(generation)
db.commit()
return True
async def delete_failed_generations(db: Session) -> int:
"""
Delete every generation whose status is 'failed'.
Used by the "Clear failed" action in the UI so users can tidy up
history after the model wasn't loaded, the app was closed mid-run,
or a generation otherwise errored out (see issue #410).
Returns:
Number of generations deleted.
"""
from . import versions as versions_mod
failed = db.query(DBGeneration).filter(DBGeneration.status == "failed").all()
count = 0
for generation in failed:
# Clean up version files/rows first.
versions_mod.delete_versions_for_generation(generation.id, db)
# Remove the main audio file if it somehow made it to disk.
if generation.audio_path:
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
try:
audio_path.unlink()
except OSError:
# Best-effort cleanup — don't abort the whole sweep
# if a single file can't be removed.
pass
db.delete(generation)
count += 1
db.commit()
return count
async def delete_generations_by_profile(
profile_id: str,
db: Session,
) -> int:
"""
Delete all generations for a profile.
Args:
profile_id: Profile ID
db: Database session
Returns:
Number of generations deleted
"""
generations = db.query(DBGeneration).filter_by(profile_id=profile_id).all()
count = 0
for generation in generations:
# Delete associated version files and rows first
from . import versions as versions_mod
versions_mod.delete_versions_for_generation(generation.id, db)
# Delete audio file
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
# Delete from database
db.delete(generation)
count += 1
db.commit()
return count
async def get_generation_stats(db: Session) -> dict:
"""
Get generation statistics.
Args:
db: Database session
Returns:
Statistics dictionary
"""
from sqlalchemy import func
total = db.query(func.count(DBGeneration.id)).scalar()
total_duration = db.query(func.sum(DBGeneration.duration)).scalar() or 0
# Get generations by profile
by_profile = db.query(
DBGeneration.profile_id,
func.count(DBGeneration.id).label('count')
).group_by(DBGeneration.profile_id).all()
return {
"total_generations": total,
"total_duration_seconds": total_duration,
"generations_by_profile": {
profile_id: count for profile_id, count in by_profile
},
}


@@ -0,0 +1,686 @@
"""Voice profile management module."""
import json as _json
import logging
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from sqlalchemy import func
from sqlalchemy.orm import Session
from .. import config
from ..database import Generation as DBGeneration, ProfileSample as DBProfileSample, VoiceProfile as DBVoiceProfile
from ..models import (
EffectConfig,
ProfileSampleResponse,
VoiceProfileCreate,
VoiceProfileResponse,
)
from ..utils.audio import save_audio, validate_and_load_reference_audio
from ..utils.cache import _get_cache_dir, clear_profile_cache
from ..utils.images import process_avatar, validate_image
logger = logging.getLogger(__name__)
CLONING_ENGINES = {"qwen", "luxtts", "chatterbox", "chatterbox_turbo", "tada"}
def _profile_to_response(
profile: DBVoiceProfile,
generation_count: int = 0,
sample_count: int = 0,
) -> VoiceProfileResponse:
"""Convert a DB profile to a VoiceProfileResponse, deserializing effects_chain."""
effects_chain = None
if profile.effects_chain:
try:
raw = _json.loads(profile.effects_chain)
effects_chain = [EffectConfig(**e) for e in raw]
except Exception as e:
logger.warning(f"Failed to parse effects_chain for profile {profile.id}: {e}")
return VoiceProfileResponse(
id=profile.id,
name=profile.name,
description=profile.description,
language=profile.language,
avatar_path=profile.avatar_path,
effects_chain=effects_chain,
voice_type=getattr(profile, "voice_type", None) or "cloned",
preset_engine=getattr(profile, "preset_engine", None),
preset_voice_id=getattr(profile, "preset_voice_id", None),
design_prompt=getattr(profile, "design_prompt", None),
default_engine=getattr(profile, "default_engine", None),
generation_count=generation_count,
sample_count=sample_count,
created_at=profile.created_at,
updated_at=profile.updated_at,
)
def _get_preset_voice_ids(engine: str) -> set[str]:
if engine == "kokoro":
from ..backends.kokoro_backend import KOKORO_VOICES
return {voice_id for voice_id, _name, _gender, _lang in KOKORO_VOICES}
if engine == "qwen_custom_voice":
from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES
return {voice_id for voice_id, _name, _gender, _lang, _desc in QWEN_CUSTOM_VOICES}
return set()
def _validate_profile_fields(
*,
voice_type: str,
preset_engine: str | None,
preset_voice_id: str | None,
design_prompt: str | None,
default_engine: str | None,
) -> str | None:
if voice_type == "preset":
if not preset_engine or not preset_voice_id:
return "Preset profiles require both preset_engine and preset_voice_id"
if default_engine and default_engine != preset_engine:
return "Preset profiles must use their preset_engine as default_engine"
available_voice_ids = _get_preset_voice_ids(preset_engine)
if available_voice_ids and preset_voice_id not in available_voice_ids:
return f"Preset voice '{preset_voice_id}' is not valid for engine '{preset_engine}'"
return None
if voice_type == "designed":
if not design_prompt or not design_prompt.strip():
return "Designed profiles require a design_prompt"
if preset_engine or preset_voice_id:
return "Designed profiles cannot set preset_engine or preset_voice_id"
return None
if preset_engine or preset_voice_id:
return "Cloned profiles cannot set preset_engine or preset_voice_id"
if design_prompt:
return "Cloned profiles cannot set design_prompt"
if default_engine and default_engine not in CLONING_ENGINES:
return f"Cloned profiles cannot use default engine '{default_engine}'"
return None
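# Summary of the rules above (illustrative):
#   preset   -> requires preset_engine + preset_voice_id; default_engine,
#               if set, must equal preset_engine.
#   designed -> requires design_prompt; must not set preset_engine/voice_id.
#   cloned   -> must not set preset fields or design_prompt; default_engine,
#               if set, must be one of CLONING_ENGINES.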
def validate_profile_engine(profile, engine: str) -> None:
voice_type = getattr(profile, "voice_type", None) or "cloned"
if voice_type == "preset":
preset_engine = getattr(profile, "preset_engine", None)
preset_voice_id = getattr(profile, "preset_voice_id", None)
if not preset_engine or not preset_voice_id:
raise ValueError(f"Preset profile {profile.id} is missing preset engine metadata")
if preset_engine != engine:
raise ValueError(
f"Preset profile {profile.id} only supports engine '{preset_engine}', not '{engine}'"
)
return
if voice_type == "designed":
design_prompt = getattr(profile, "design_prompt", None)
if not design_prompt or not design_prompt.strip():
raise ValueError(f"Designed profile {profile.id} is missing design_prompt")
return
if engine not in CLONING_ENGINES:
raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")
async def create_profile(
data: VoiceProfileCreate,
db: Session,
) -> VoiceProfileResponse:
"""
Create a new voice profile.
Args:
data: Profile creation data
db: Database session
Returns:
Created profile
Raises:
ValueError: If a profile with the same name already exists
"""
existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
if existing_profile:
raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")
# Auto-set default_engine for preset profiles
default_engine = data.default_engine
voice_type = data.voice_type or "cloned"
if voice_type == "preset" and data.preset_engine and not default_engine:
default_engine = data.preset_engine
validation_error = _validate_profile_fields(
voice_type=voice_type,
preset_engine=data.preset_engine,
preset_voice_id=data.preset_voice_id,
design_prompt=data.design_prompt,
default_engine=default_engine,
)
if validation_error:
raise ValueError(validation_error)
db_profile = DBVoiceProfile(
id=str(uuid.uuid4()),
name=data.name,
description=data.description,
language=data.language,
voice_type=voice_type,
preset_engine=data.preset_engine,
preset_voice_id=data.preset_voice_id,
design_prompt=data.design_prompt,
default_engine=default_engine,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
db.add(db_profile)
db.commit()
db.refresh(db_profile)
profile_dir = config.get_profiles_dir() / db_profile.id
profile_dir.mkdir(parents=True, exist_ok=True)
return _profile_to_response(db_profile)
async def add_profile_sample(
profile_id: str,
audio_path: str,
reference_text: str,
db: Session,
) -> ProfileSampleResponse:
"""
Add a sample to a voice profile.
Args:
profile_id: Profile ID
audio_path: Path to temporary audio file
reference_text: Transcript of audio
db: Database session
Returns:
Created sample
"""
import asyncio
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Validate and load audio in a single pass, off the event loop
is_valid, error_msg, audio, sr = await asyncio.to_thread(
validate_and_load_reference_audio, audio_path
)
if not is_valid:
raise ValueError(f"Invalid reference audio: {error_msg}")
sample_id = str(uuid.uuid4())
profile_dir = config.get_profiles_dir() / profile_id
profile_dir.mkdir(parents=True, exist_ok=True)
dest_path = profile_dir / f"{sample_id}.wav"
await asyncio.to_thread(save_audio, audio, str(dest_path), sr)
db_sample = DBProfileSample(
id=sample_id,
profile_id=profile_id,
audio_path=config.to_storage_path(dest_path),
reference_text=reference_text,
)
db.add(db_sample)
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(db_sample)
# Invalidate combined audio cache for this profile
# Since a new sample was added, any cached combined audio is now stale
clear_profile_cache(profile_id)
return ProfileSampleResponse.model_validate(db_sample)
async def get_profile(
profile_id: str,
db: Session,
) -> VoiceProfileResponse | None:
"""
Get a voice profile by ID.
Args:
profile_id: Profile ID
db: Database session
Returns:
Profile or None if not found
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
return None
return _profile_to_response(profile)
async def get_profile_samples(
profile_id: str,
db: Session,
) -> list[ProfileSampleResponse]:
"""
Get all samples for a profile.
Args:
profile_id: Profile ID
db: Database session
Returns:
List of samples
"""
samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
return [ProfileSampleResponse.model_validate(s) for s in samples]
async def list_profiles(db: Session) -> list[VoiceProfileResponse]:
"""
List all voice profiles with generation and sample counts.
Args:
db: Database session
Returns:
List of profiles
"""
profiles = db.query(DBVoiceProfile).order_by(DBVoiceProfile.created_at.desc()).all()
if not profiles:
return []
# Batch-fetch generation counts
gen_counts_rows = (
db.query(DBGeneration.profile_id, func.count(DBGeneration.id)).group_by(DBGeneration.profile_id).all()
)
gen_counts = {row[0]: row[1] for row in gen_counts_rows}
# Batch-fetch sample counts
sample_counts_rows = (
db.query(DBProfileSample.profile_id, func.count(DBProfileSample.id)).group_by(DBProfileSample.profile_id).all()
)
sample_counts = {row[0]: row[1] for row in sample_counts_rows}
return [
_profile_to_response(
p,
generation_count=gen_counts.get(p.id, 0),
sample_count=sample_counts.get(p.id, 0),
)
for p in profiles
]
async def update_profile(
profile_id: str,
data: VoiceProfileCreate,
db: Session,
) -> VoiceProfileResponse | None:
"""
Update a voice profile.
Args:
profile_id: Profile ID
data: Updated profile data
db: Database session
Returns:
Updated profile or None if not found
Raises:
ValueError: If a profile with the same name already exists (different profile)
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
return None
if profile.name != data.name:
existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
if existing_profile:
raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")
voice_type = getattr(profile, "voice_type", None) or "cloned"
preset_engine = getattr(profile, "preset_engine", None)
preset_voice_id = getattr(profile, "preset_voice_id", None)
design_prompt = getattr(profile, "design_prompt", None)
default_engine = data.default_engine if data.default_engine is not None else getattr(profile, "default_engine", None)
validation_error = _validate_profile_fields(
voice_type=voice_type,
preset_engine=preset_engine,
preset_voice_id=preset_voice_id,
design_prompt=design_prompt,
default_engine=default_engine,
)
if validation_error:
raise ValueError(validation_error)
profile.name = data.name
profile.description = data.description
profile.language = data.language
if data.default_engine is not None:
profile.default_engine = data.default_engine or None # empty string → NULL
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(profile)
return _profile_to_response(profile)
async def delete_profile(
profile_id: str,
db: Session,
) -> bool:
"""
Delete a voice profile and all associated data.
Args:
profile_id: Profile ID
db: Database session
Returns:
True if deleted, False if not found
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
return False
db.query(DBProfileSample).filter_by(profile_id=profile_id).delete()
db.delete(profile)
db.commit()
profile_dir = config.get_profiles_dir() / profile_id
if profile_dir.exists():
shutil.rmtree(profile_dir)
# Clean up combined audio cache files for this profile
clear_profile_cache(profile_id)
return True
async def delete_profile_sample(
sample_id: str,
db: Session,
) -> bool:
"""
Delete a profile sample.
Args:
sample_id: Sample ID
db: Database session
Returns:
True if deleted, False if not found
"""
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
if not sample:
return False
# Store profile_id before deleting
profile_id = sample.profile_id
audio_path = config.resolve_storage_path(sample.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
db.delete(sample)
db.commit()
# Invalidate combined audio cache for this profile
# Since the sample set changed, any cached combined audio is now stale
clear_profile_cache(profile_id)
return True
async def update_profile_sample(
sample_id: str,
reference_text: str,
db: Session,
) -> ProfileSampleResponse | None:
"""
Update a profile sample's reference text.
Args:
sample_id: Sample ID
reference_text: Updated reference text
db: Database session
Returns:
Updated sample or None if not found
"""
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
if not sample:
return None
# Store profile_id before updating
profile_id = sample.profile_id
sample.reference_text = reference_text
db.commit()
db.refresh(sample)
# Invalidate combined audio cache for this profile
# Since the reference text changed, cache keys and combined text are now stale
clear_profile_cache(profile_id)
return ProfileSampleResponse.model_validate(sample)
async def create_voice_prompt_for_profile(
profile_id: str,
db: Session,
use_cache: bool = True,
engine: str = "qwen",
) -> dict:
"""
Create a voice prompt from a profile.
For cloned profiles: combines all audio samples into a voice prompt.
For preset profiles: returns the engine-specific preset voice reference.
For designed profiles: returns the text design prompt (future).
Args:
profile_id: Profile ID
db: Database session
use_cache: Whether to use cached prompts
engine: TTS engine to create prompt for
Returns:
Voice prompt dictionary
Raises:
ValueError: If the profile is not found, misconfigured for its voice type, or incompatible with the requested engine
"""
from ..backends import get_tts_backend_for_engine
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile not found: {profile_id}")
voice_type = getattr(profile, "voice_type", None) or "cloned"
validate_profile_engine(profile, engine)
# ── Preset profiles: return engine-specific voice reference ──
if voice_type == "preset":
if not profile.preset_engine or not profile.preset_voice_id:
raise ValueError(f"Preset profile {profile_id} is missing preset engine metadata")
if profile.preset_engine != engine:
raise ValueError(
f"Preset profile {profile_id} only supports engine '{profile.preset_engine}', not '{engine}'"
)
return {
"voice_type": "preset",
"preset_engine": profile.preset_engine,
"preset_voice_id": profile.preset_voice_id,
}
# ── Designed profiles: return text description (future) ──
if voice_type == "designed":
if not profile.design_prompt or not profile.design_prompt.strip():
raise ValueError(f"Designed profile {profile_id} is missing design_prompt")
return {
"voice_type": "designed",
"design_prompt": profile.design_prompt,
}
if engine not in CLONING_ENGINES:
raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")
# ── Cloned profiles: create from audio samples ──
samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
if not samples:
raise ValueError(f"No samples found for profile {profile_id}")
tts_model = get_tts_backend_for_engine(engine)
if len(samples) == 1:
sample = samples[0]
sample_audio_path = config.resolve_storage_path(sample.audio_path)
if sample_audio_path is None:
raise ValueError(f"Sample audio not found for profile {profile_id}")
voice_prompt, _ = await tts_model.create_voice_prompt(
str(sample_audio_path),
sample.reference_text,
use_cache=use_cache,
)
return voice_prompt
audio_paths = []
for sample in samples:
sample_audio_path = config.resolve_storage_path(sample.audio_path)
if sample_audio_path is None:
raise ValueError(f"Sample audio not found for profile {profile_id}")
audio_paths.append(str(sample_audio_path))
reference_texts = [s.reference_text for s in samples]
combined_audio, combined_text = await tts_model.combine_voice_prompts(
audio_paths,
reference_texts,
)
# Save combined audio to cache directory (persistent)
# Create a hash of sample IDs to identify this specific combination
import hashlib
sample_ids_str = "-".join(sorted([s.id for s in samples]))
combination_hash = hashlib.md5(sample_ids_str.encode()).hexdigest()[:12]
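# Sorting the sample IDs makes the hash order-independent, so the same set of
# samples always maps to the same cached file regardless of query order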
cache_dir = _get_cache_dir()
cache_dir.mkdir(parents=True, exist_ok=True)
combined_path = cache_dir / f"combined_{profile_id}_{combination_hash}.wav"
save_audio(combined_audio, str(combined_path), 24000)
voice_prompt, _ = await tts_model.create_voice_prompt(
str(combined_path),
combined_text,
use_cache=use_cache,
)
return voice_prompt
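# Illustrative dispatch on the returned prompt (a sketch; `backend.synthesize`
# is an assumed method, not defined in this module):
#
#   prompt = await create_voice_prompt_for_profile(profile_id, db, engine="qwen")
#   if prompt.get("voice_type") == "preset":
#       audio = await backend.synthesize(text, voice=prompt["preset_voice_id"])
#   else:
#       audio = await backend.synthesize(text, voice_prompt=prompt)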
async def upload_avatar(
profile_id: str,
image_path: str,
db: Session,
) -> VoiceProfileResponse:
"""
Upload and process avatar image for a profile.
Args:
profile_id: Profile ID
image_path: Path to uploaded image file
db: Database session
Returns:
Updated profile
Raises:
ValueError: If the profile is not found or the image fails validation
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
is_valid, error_msg = validate_image(image_path)
if not is_valid:
raise ValueError(error_msg)
if profile.avatar_path:
old_avatar = config.resolve_storage_path(profile.avatar_path)
if old_avatar is not None and old_avatar.exists():
old_avatar.unlink()
# Determine file extension from uploaded file
from PIL import Image
with Image.open(image_path) as img:
# Normalize JPEG variants (MPO is multi-picture format from some cameras)
img_format = img.format
if img_format in ("MPO", "JPG"):
img_format = "JPEG"
ext_map = {"PNG": ".png", "JPEG": ".jpg", "WEBP": ".webp"}
ext = ext_map.get(img_format, ".png")
profile_dir = config.get_profiles_dir() / profile_id
profile_dir.mkdir(parents=True, exist_ok=True)
output_path = profile_dir / f"avatar{ext}"
process_avatar(image_path, str(output_path))
profile.avatar_path = config.to_storage_path(output_path)
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(profile)
return _profile_to_response(profile)
async def delete_avatar(
profile_id: str,
db: Session,
) -> bool:
"""
Delete avatar image for a profile.
Args:
profile_id: Profile ID
db: Database session
Returns:
True if deleted, False if not found or no avatar
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile or not profile.avatar_path:
return False
avatar_path = config.resolve_storage_path(profile.avatar_path)
if avatar_path is not None and avatar_path.exists():
avatar_path.unlink()
profile.avatar_path = None
profile.updated_at = datetime.utcnow()
db.commit()
return True

925
backend/services/stories.py Normal file
View File

@@ -0,0 +1,925 @@
"""
Story management module.
"""
from typing import List, Optional
from datetime import datetime
import uuid
import tempfile
from pathlib import Path
from sqlalchemy.orm import Session
from sqlalchemy import func
from .. import config
from ..models import (
StoryCreate,
StoryResponse,
StoryDetailResponse,
StoryItemDetail,
StoryItemCreate,
StoryItemBatchUpdate,
StoryItemMove,
StoryItemTrim,
StoryItemSplit,
StoryItemVersionUpdate,
)
from ..database import (
Story as DBStory,
StoryItem as DBStoryItem,
Generation as DBGeneration,
VoiceProfile as DBVoiceProfile,
)
from .history import _get_versions_for_generation
from ..utils.audio import load_audio, save_audio
import numpy as np
def _build_item_detail(
item: DBStoryItem,
generation: DBGeneration,
profile_name: str,
db: Session,
) -> StoryItemDetail:
"""Build a StoryItemDetail with version info from a story item and its generation."""
versions, active_version_id = _get_versions_for_generation(generation.id, db)
# Resolve the audio path: if version_id is set, use that version's audio
audio_path = generation.audio_path
if item.version_id and versions:
for v in versions:
if v.id == item.version_id:
audio_path = v.audio_path
break
return StoryItemDetail(
id=item.id,
story_id=item.story_id,
generation_id=item.generation_id,
version_id=getattr(item, "version_id", None),
start_time_ms=item.start_time_ms,
track=item.track,
trim_start_ms=getattr(item, "trim_start_ms", 0),
trim_end_ms=getattr(item, "trim_end_ms", 0),
created_at=item.created_at,
profile_id=generation.profile_id,
profile_name=profile_name,
text=generation.text,
language=generation.language,
audio_path=audio_path,
duration=generation.duration,
seed=generation.seed,
instruct=generation.instruct,
generation_created_at=generation.created_at,
versions=versions,
active_version_id=active_version_id,
)
async def create_story(
data: StoryCreate,
db: Session,
) -> StoryResponse:
"""
Create a new story.
Args:
data: Story creation data
db: Database session
Returns:
Created story
"""
db_story = DBStory(
id=str(uuid.uuid4()),
name=data.name,
description=data.description,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
db.add(db_story)
db.commit()
db.refresh(db_story)
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == db_story.id).scalar()
response = StoryResponse.model_validate(db_story)
response.item_count = item_count
return response
async def list_stories(
db: Session,
) -> List[StoryResponse]:
"""
List all stories.
Args:
db: Database session
Returns:
List of stories with item counts
"""
stories = db.query(DBStory).order_by(DBStory.updated_at.desc()).all()
result = []
for story in stories:
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()
response = StoryResponse.model_validate(story)
response.item_count = item_count
result.append(response)
return result
async def get_story(
story_id: str,
db: Session,
) -> Optional[StoryDetailResponse]:
"""
Get a story with all its items.
Args:
story_id: Story ID
db: Database session
Returns:
Story with items or None if not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
items = (
db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
.filter(DBStoryItem.story_id == story_id)
.order_by(DBStoryItem.start_time_ms)
.all()
)
item_details = []
for item, generation, profile_name in items:
item_details.append(_build_item_detail(item, generation, profile_name, db))
response = StoryDetailResponse.model_validate(story)
response.items = item_details
return response
async def update_story(
story_id: str,
data: StoryCreate,
db: Session,
) -> Optional[StoryResponse]:
"""
Update a story.
Args:
story_id: Story ID
data: Update data
db: Database session
Returns:
Updated story or None if not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
story.name = data.name
story.description = data.description
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(story)
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()
response = StoryResponse.model_validate(story)
response.item_count = item_count
return response
async def delete_story(
story_id: str,
db: Session,
) -> bool:
"""
Delete a story and all its items.
Args:
story_id: Story ID
db: Database session
Returns:
True if deleted, False if not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return False
# Delete all items
db.query(DBStoryItem).filter_by(story_id=story_id).delete()
# Delete story
db.delete(story)
db.commit()
return True
async def add_item_to_story(
story_id: str,
data: StoryItemCreate,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Add a generation to a story.
Args:
story_id: Story ID
data: Item creation data
db: Database session
Returns:
Created item detail or None if story/generation not found
"""
# Verify story exists
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
# Verify generation exists
generation = db.query(DBGeneration).filter_by(id=data.generation_id).first()
if not generation:
return None
# Check if generation is already in story
existing = db.query(DBStoryItem).filter_by(story_id=story_id, generation_id=data.generation_id).first()
if existing:
# Return existing item
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(existing, generation, profile.name if profile else "Unknown", db)
# Get track from data or default to 0
track = data.track if data.track is not None else 0
# Calculate start_time_ms if not provided
if data.start_time_ms is not None:
start_time_ms = data.start_time_ms
else:
existing_items = (
db.query(DBStoryItem, DBGeneration)
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.filter(
DBStoryItem.story_id == story_id,
DBStoryItem.track == track,
)
.all()
)
if not existing_items:
start_time_ms = 0
else:
max_end_time_ms = 0
for item, gen in existing_items:
item_end_ms = item.start_time_ms + int(gen.duration * 1000)
max_end_time_ms = max(max_end_time_ms, item_end_ms)
# Add 200ms gap after the last item
start_time_ms = max_end_time_ms + 200
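# e.g. if the last clip on this track ends at 5,000 ms, the new clip starts at 5,200 ms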
# Create item
item = DBStoryItem(
id=str(uuid.uuid4()),
story_id=story_id,
generation_id=data.generation_id,
start_time_ms=start_time_ms,
track=track,
created_at=datetime.utcnow(),
)
db.add(item)
# Update story updated_at
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def move_story_item(
story_id: str,
item_id: str,
data: StoryItemMove,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Move a story item (update position and/or track).
Args:
story_id: Story ID
item_id: Story item ID
data: New position and track data
db: Database session
Returns:
Updated item detail or None if not found
"""
# Get the item
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Update position and track
item.start_time_ms = data.start_time_ms
item.track = data.track
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def remove_item_from_story(
story_id: str,
item_id: str,
db: Session,
) -> bool:
"""
Remove a story item from a story.
Args:
story_id: Story ID
item_id: Story item ID to remove
db: Database session
Returns:
True if removed, False if not found
"""
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return False
# Delete item
db.delete(item)
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
return True
async def trim_story_item(
story_id: str,
item_id: str,
data: StoryItemTrim,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Trim a story item (update trim_start_ms and trim_end_ms).
Args:
story_id: Story ID
item_id: Story item ID
data: Trim data (trim_start_ms, trim_end_ms)
db: Database session
Returns:
Updated item detail or None if not found
"""
# Get the item
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Validate trim values don't exceed duration
max_duration_ms = int(generation.duration * 1000)
if data.trim_start_ms + data.trim_end_ms >= max_duration_ms:
return None # Invalid trim - would result in zero or negative duration
# Update trim values
item.trim_start_ms = data.trim_start_ms
item.trim_end_ms = data.trim_end_ms
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def split_story_item(
story_id: str,
item_id: str,
data: StoryItemSplit,
db: Session,
) -> Optional[List[StoryItemDetail]]:
"""
Split a story item at a given time, creating two clips.
Args:
story_id: Story ID
item_id: Story item ID to split
data: Split data (split_time_ms - time within clip to split at)
db: Database session
Returns:
List of two updated item details (original and new) or None if not found/invalid
"""
# Get the item with a row lock to prevent concurrent splits on the
# same clip (e.g. from rapid double-clicks racing each other).
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.with_for_update()
.first()
)
if not item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Calculate effective duration and validate split point
current_trim_start = getattr(item, "trim_start_ms", 0)
current_trim_end = getattr(item, "trim_end_ms", 0)
original_duration_ms = int(generation.duration * 1000)
effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end
# Validate split_time_ms is within the effective duration
if data.split_time_ms <= 0 or data.split_time_ms >= effective_duration_ms:
return None # Invalid split point
# Calculate the absolute time in the original audio where we're splitting
absolute_split_ms = current_trim_start + data.split_time_ms
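# Worked example: a 10,000 ms clip with trim_start=1,000 / trim_end=1,000 has an
# effective span of 8,000 ms. Splitting at split_time_ms=3,000 gives
# absolute_split_ms=4,000, so the original keeps trim_end=6,000 (first 3,000 ms
# of the span) and the new clip gets trim_start=4,000 (remaining 5,000 ms)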
# Update original clip: trim from the end
item.trim_end_ms = original_duration_ms - absolute_split_ms
# Create new clip: starts after the split, trimmed from the start
new_item = DBStoryItem(
id=str(uuid.uuid4()),
story_id=story_id,
generation_id=item.generation_id, # Same generation, different trim
version_id=getattr(item, "version_id", None), # Preserve pinned version
start_time_ms=item.start_time_ms + data.split_time_ms,
track=item.track,
trim_start_ms=absolute_split_ms,
trim_end_ms=current_trim_end,
created_at=datetime.utcnow(),
)
db.add(new_item)
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
db.refresh(new_item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
profile_name = profile.name if profile else "Unknown"
return [
_build_item_detail(item, generation, profile_name, db),
_build_item_detail(new_item, generation, profile_name, db),
]
async def duplicate_story_item(
story_id: str,
item_id: str,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Duplicate a story item, creating a copy with all properties.
Args:
story_id: Story ID
item_id: Story item ID to duplicate
db: Database session
Returns:
New item detail or None if not found
"""
# Get the original item
original_item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not original_item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=original_item.generation_id).first()
if not generation:
return None
# Calculate effective duration
current_trim_start = getattr(original_item, "trim_start_ms", 0)
current_trim_end = getattr(original_item, "trim_end_ms", 0)
original_duration_ms = int(generation.duration * 1000)
effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end
# Create duplicate item - place it right after the original
new_item = DBStoryItem(
id=str(uuid.uuid4()),
story_id=story_id,
generation_id=original_item.generation_id, # Same generation as original
version_id=getattr(original_item, "version_id", None), # Preserve pinned version
start_time_ms=original_item.start_time_ms + effective_duration_ms + 200, # 200ms gap
track=original_item.track,
trim_start_ms=current_trim_start,
trim_end_ms=current_trim_end,
created_at=datetime.utcnow(),
)
db.add(new_item)
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(new_item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(new_item, generation, profile.name if profile else "Unknown", db)
async def update_story_item_times(
story_id: str,
data: StoryItemBatchUpdate,
db: Session,
) -> bool:
"""
Update story item timecodes.
Args:
story_id: Story ID
data: Batch update data with timecodes
db: Database session
Returns:
True if updated, False if story not found or invalid
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return False
# Get all items for this story
items = db.query(DBStoryItem).filter_by(story_id=story_id).all()
item_map = {item.generation_id: item for item in items}
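# NOTE: keying by generation_id assumes at most one item per generation in the
# story; split or duplicated clips share a generation_id and would collapse here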
# Verify all generation IDs belong to this story and update timecodes
for update in data.updates:
if update.generation_id not in item_map:
return False
item_map[update.generation_id].start_time_ms = update.start_time_ms
# Update story updated_at
story.updated_at = datetime.utcnow()
db.commit()
return True
async def reorder_story_items(
story_id: str,
generation_ids: List[str],
db: Session,
gap_ms: int = 200,
) -> Optional[List[StoryItemDetail]]:
"""
Reorder story items and recalculate timecodes.
Args:
story_id: Story ID
generation_ids: List of generation IDs in the desired order
db: Database session
gap_ms: Gap in milliseconds between items (default 200ms)
Returns:
Updated list of story items with new timecodes, or None if invalid
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
# Get all items for this story with their generation data
items_with_gen = (
db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
.filter(DBStoryItem.story_id == story_id)
.all()
)
# Create maps for quick lookup
item_map = {item.generation_id: (item, gen, profile_name) for item, gen, profile_name in items_with_gen}
# Verify all generation IDs belong to this story
if set(generation_ids) != set(item_map.keys()):
return None
# Recalculate timecodes based on new order
current_time_ms = 0
updated_items = []
for gen_id in generation_ids:
item, generation, profile_name = item_map[gen_id]
# Update the item's start time
item.start_time_ms = current_time_ms
# Calculate the duration in ms
duration_ms = int(generation.duration * 1000)
# Move to next position (current end + gap)
current_time_ms += duration_ms + gap_ms
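# e.g. durations of 2,000 and 3,500 ms with a 200 ms gap land at 0 ms and 2,200 ms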
# Build the response item
updated_items.append(_build_item_detail(item, generation, profile_name, db))
# Update story updated_at
story.updated_at = datetime.utcnow()
db.commit()
return updated_items
async def set_story_item_version(
story_id: str,
item_id: str,
data: StoryItemVersionUpdate,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Pin a story item to a specific generation version.
Args:
story_id: Story ID
item_id: Story item ID
data: Version update data (version_id or null for default)
db: Database session
Returns:
Updated item detail or None if not found
"""
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return None
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Validate version_id belongs to this generation if provided
if data.version_id:
from ..database import GenerationVersion as DBGenerationVersion
version = (
db.query(DBGenerationVersion)
.filter_by(
id=data.version_id,
generation_id=item.generation_id,
)
.first()
)
if not version:
return None
item.version_id = data.version_id
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def export_story_audio(
story_id: str,
db: Session,
) -> Optional[bytes]:
"""
Export a story as a single audio file, mixing items at their timecode positions.
Args:
story_id: Story ID
db: Database session
Returns:
Audio file bytes or None if story not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
# Get all items ordered by start_time_ms
items = (
db.query(DBStoryItem, DBGeneration)
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.filter(DBStoryItem.story_id == story_id)
.order_by(DBStoryItem.start_time_ms)
.all()
)
if not items:
return None
# Load all audio files and calculate total duration
audio_data = []
sample_rate = 24000 # Default sample rate
for item, generation in items:
# Resolve audio path: use pinned version if set, otherwise generation default
resolved_audio_path = generation.audio_path
if getattr(item, "version_id", None):
from ..database import GenerationVersion as DBGenerationVersion
version = db.query(DBGenerationVersion).filter_by(id=item.version_id).first()
if version:
resolved_audio_path = version.audio_path
audio_path = config.resolve_storage_path(resolved_audio_path)
if audio_path is None or not audio_path.exists():
continue
try:
audio, sr = load_audio(str(audio_path), sample_rate=sample_rate)
sample_rate = sr  # Track the rate actually returned; later loads and the mix buffer use it
# Get trim values
trim_start_ms = getattr(item, "trim_start_ms", 0)
trim_end_ms = getattr(item, "trim_end_ms", 0)
# Calculate effective duration
original_duration_ms = int(generation.duration * 1000)
effective_duration_ms = original_duration_ms - trim_start_ms - trim_end_ms
# Slice audio based on trim values
trim_start_sample = int((trim_start_ms / 1000.0) * sample_rate)
trim_end_sample = int((trim_end_ms / 1000.0) * sample_rate)
# Extract the trimmed portion
if trim_end_ms > 0:
trimmed_audio = (
audio[trim_start_sample:-trim_end_sample] if trim_end_sample > 0 else audio[trim_start_sample:]
)
else:
trimmed_audio = audio[trim_start_sample:]
# Store audio with its timecode info
start_time_ms = item.start_time_ms
audio_data.append(
{
"audio": trimmed_audio,
"start_time_ms": start_time_ms,
"duration_ms": effective_duration_ms,
}
)
except Exception:
# Skip files that can't be loaded
continue
if not audio_data:
return None
# Calculate total duration: max(start_time_ms + duration_ms)
max_end_time_ms = max((data["start_time_ms"] + data["duration_ms"] for data in audio_data), default=0)
# Convert to samples
total_samples = int((max_end_time_ms / 1000.0) * sample_rate)
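# e.g. a story ending at 2,500 ms mixed at 24 kHz allocates 60,000 samples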
# Create output buffer initialized to zeros
final_audio = np.zeros(total_samples, dtype=np.float32)
# Mix each audio segment at its timecode position
for data in audio_data:
audio = data["audio"]
start_time_ms = data["start_time_ms"]
# Calculate start sample index
start_sample = int((start_time_ms / 1000.0) * sample_rate)
# Ensure we don't exceed buffer bounds
audio_length = len(audio)
end_sample = min(start_sample + audio_length, total_samples)
if start_sample < total_samples:
# Trim audio if it extends beyond buffer
audio_to_mix = audio[: end_sample - start_sample]
# Mix: add audio into the buffer (overlapping segments sum; the normalization pass below prevents clipping)
final_audio[start_sample:end_sample] += audio_to_mix
# Normalize to prevent clipping
max_val = np.abs(final_audio).max()
if max_val > 1.0:
final_audio = final_audio / max_val
# Save to temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_path = tmp.name
try:
save_audio(final_audio, tmp_path, sample_rate)
# Read file bytes
with open(tmp_path, "rb") as f:
audio_bytes = f.read()
return audio_bytes
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)

View File

@@ -0,0 +1,108 @@
"""
Serial generation queue — ensures only one TTS inference runs at a time
to avoid GPU contention.
"""
import asyncio
import traceback
from dataclasses import dataclass
from typing import Coroutine, Literal
# Keep references to fire-and-forget background tasks to prevent GC
_background_tasks: set = set()
@dataclass
class GenerationJob:
"""Queued generation work plus the generation ID it belongs to."""
generation_id: str
coro: Coroutine
# Generation queue — serializes TTS inference to avoid GPU contention
_generation_queue: asyncio.Queue = None # type: ignore # initialized at startup
_generation_worker_task: asyncio.Task | None = None
_queued_generation_ids: set[str] = set()
_running_generation_tasks: dict[str, asyncio.Task] = {}
_cancelled_generation_ids: set[str] = set()
def create_background_task(coro) -> asyncio.Task:
"""Create a background task and prevent it from being garbage collected."""
task = asyncio.create_task(coro)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return task
async def _generation_worker():
"""Worker that processes generation tasks one at a time."""
while True:
job = await _generation_queue.get()
try:
if job.generation_id in _cancelled_generation_ids:
_cancelled_generation_ids.discard(job.generation_id)
job.coro.close()
continue
task = asyncio.create_task(job.coro)
_running_generation_tasks[job.generation_id] = task
_queued_generation_ids.discard(job.generation_id)
try:
await task
except asyncio.CancelledError:
if not task.cancelled():
raise
except Exception:
traceback.print_exc()
finally:
_running_generation_tasks.pop(job.generation_id, None)
_queued_generation_ids.discard(job.generation_id)
_generation_queue.task_done()
def enqueue_generation(generation_id: str, coro):
"""Add a generation coroutine to the serial queue."""
if _generation_queue is None:
raise RuntimeError("Generation queue has not been initialized")
_queued_generation_ids.add(generation_id)
_generation_queue.put_nowait(GenerationJob(generation_id=generation_id, coro=coro))
def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None:
"""Cancel a queued or running generation if it is still active."""
running_task = _running_generation_tasks.get(generation_id)
if running_task is not None:
running_task.cancel()
return "running"
if generation_id in _queued_generation_ids:
_queued_generation_ids.discard(generation_id)
_cancelled_generation_ids.add(generation_id)
return "queued"
return None
def init_queue(force: bool = False):
"""Initialize the generation queue and start the worker.
Must be called once during application startup (inside a running event loop).
"""
global _generation_queue, _generation_worker_task
global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids
if _generation_worker_task is not None and not _generation_worker_task.done():
if not force:
return
_generation_worker_task.cancel()
for task in list(_running_generation_tasks.values()):
task.cancel()
_generation_queue = asyncio.Queue()
_queued_generation_ids = set()
_running_generation_tasks = {}
_cancelled_generation_ids = set()
_generation_worker_task = create_background_task(_generation_worker())
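# Illustrative startup wiring (a sketch; the lifespan hook and `run_generation`
# coroutine are assumptions, not part of this module):
#
#   async def lifespan(app):  # e.g. a FastAPI lifespan
#       init_queue()
#       yield
#
#   enqueue_generation(gen_id, run_generation(gen_id))  # runs one at a time
#   status = cancel_generation(gen_id)  # "queued" | "running" | None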

View File

@@ -0,0 +1,22 @@
"""
STT (Speech-to-Text) module - delegates to backend abstraction layer.
"""
from ..backends import get_stt_backend, STTBackend
def get_whisper_model() -> STTBackend:
"""
Get STT backend instance (MLX or PyTorch based on platform).
Returns:
STT backend instance
"""
return get_stt_backend()
def unload_whisper_model():
"""Unload Whisper model to free memory."""
backend = get_stt_backend()
backend.unload_model()

34
backend/services/tts.py Normal file
View File

@@ -0,0 +1,34 @@
"""
TTS inference module - delegates to backend abstraction layer.
"""
import numpy as np
import io
import soundfile as sf
from ..backends import get_tts_backend, TTSBackend
def get_tts_model() -> TTSBackend:
"""
Get TTS backend instance (MLX or PyTorch based on platform).
Returns:
TTS backend instance
"""
return get_tts_backend()
def unload_tts_model():
"""Unload TTS model to free memory."""
backend = get_tts_backend()
backend.unload_model()
def audio_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
"""Convert audio array to WAV bytes."""
buffer = io.BytesIO()
sf.write(buffer, audio, sample_rate, format="WAV")
buffer.seek(0)
return buffer.read()
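# Illustrative use (a sketch; the Response import and route are assumed, not
# defined here):
#
#   from fastapi import Response
#   wav = audio_to_wav_bytes(audio, sample_rate)
#   return Response(content=wav, media_type="audio/wav")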

View File

@@ -0,0 +1,211 @@
"""
Generation versions management module.
Each generation can have multiple audio versions: a clean (unprocessed)
version and any number of processed versions with different effects chains.
"""
from __future__ import annotations
import json
import uuid
from typing import List, Optional
from sqlalchemy.orm import Session
from ..database import (
GenerationVersion as DBGenerationVersion,
Generation as DBGeneration,
)
from ..models import GenerationVersionResponse, EffectConfig
from .. import config
def _version_response(v: DBGenerationVersion) -> GenerationVersionResponse:
"""Convert a DB version row to a Pydantic response."""
effects_chain = None
if v.effects_chain:
raw = json.loads(v.effects_chain)
effects_chain = [EffectConfig(**e) for e in raw]
return GenerationVersionResponse(
id=v.id,
generation_id=v.generation_id,
label=v.label,
audio_path=v.audio_path,
effects_chain=effects_chain,
source_version_id=v.source_version_id,
is_default=v.is_default,
created_at=v.created_at,
)
def list_versions(generation_id: str, db: Session) -> List[GenerationVersionResponse]:
"""List all versions for a generation."""
versions = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.all()
)
return [_version_response(v) for v in versions]
def get_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
"""Get a specific version by ID."""
v = db.query(DBGenerationVersion).filter_by(id=version_id).first()
if not v:
return None
return _version_response(v)
def get_default_version(generation_id: str, db: Session) -> Optional[GenerationVersionResponse]:
"""Get the default version for a generation."""
v = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id, is_default=True)
.first()
)
if not v:
# Fallback: return the first version
v = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.first()
)
if not v:
return None
return _version_response(v)
def create_version(
generation_id: str,
label: str,
audio_path: str,
db: Session,
effects_chain: Optional[List[dict]] = None,
is_default: bool = False,
source_version_id: Optional[str] = None,
) -> GenerationVersionResponse:
"""Create a new version for a generation.
If ``is_default`` is True, all other versions for this generation
are un-defaulted first.
"""
if is_default:
_clear_defaults(generation_id, db)
version = DBGenerationVersion(
id=str(uuid.uuid4()),
generation_id=generation_id,
label=label,
audio_path=audio_path,
effects_chain=json.dumps(effects_chain) if effects_chain else None,
source_version_id=source_version_id,
is_default=is_default,
)
db.add(version)
db.commit()
db.refresh(version)
# If this version is the default, update the generation's audio_path
if is_default:
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if gen:
gen.audio_path = audio_path
db.commit()
return _version_response(version)
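# Illustrative flow (a sketch; paths, labels, and the effect dict's fields are
# placeholders; EffectConfig's real schema lives in ..models):
#
#   clean = create_version(gen.id, "Clean", clean_path, db, is_default=True)
#   fx = create_version(
#       gen.id, "Processed", fx_path, db,
#       effects_chain=[{"type": "reverb", "wet": 0.3}],
#       source_version_id=clean.id,
#   )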
def set_default_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
"""Set a version as the default for its generation."""
version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
if not version:
return None
_clear_defaults(version.generation_id, db)
version.is_default = True
db.commit()
db.refresh(version)
# Update generation's audio_path to point to this version
gen = db.query(DBGeneration).filter_by(id=version.generation_id).first()
if gen:
gen.audio_path = version.audio_path
db.commit()
return _version_response(version)
def delete_version(version_id: str, db: Session) -> bool:
"""Delete a version. Cannot delete the last remaining version."""
version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
if not version:
return False
# Don't allow deleting the last version
count = (
db.query(DBGenerationVersion)
.filter_by(generation_id=version.generation_id)
.count()
)
if count <= 1:
return False
was_default = version.is_default
gen_id = version.generation_id
# Delete audio file
audio_path = config.resolve_storage_path(version.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
db.delete(version)
db.commit()
# If this was the default, promote the first remaining version
if was_default:
first = (
db.query(DBGenerationVersion)
.filter_by(generation_id=gen_id)
.order_by(DBGenerationVersion.created_at)
.first()
)
if first:
first.is_default = True
db.commit()
gen = db.query(DBGeneration).filter_by(id=gen_id).first()
if gen:
gen.audio_path = first.audio_path
db.commit()
return True
def delete_versions_for_generation(generation_id: str, db: Session) -> int:
"""Delete all versions for a generation (used when deleting a generation)."""
versions = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.all()
)
count = 0
for v in versions:
audio_path = config.resolve_storage_path(v.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
db.delete(v)
count += 1
if count > 0:
db.commit()
return count
def _clear_defaults(generation_id: str, db: Session) -> None:
"""Clear the is_default flag on all versions for a generation."""
db.query(DBGenerationVersion).filter_by(
generation_id=generation_id, is_default=True
).update({"is_default": False})
db.flush()