Initial commit
This commit is contained in:
1
backend/services/__init__.py
Normal file
1
backend/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Services layer — generation orchestration and background task management.
|
||||
263
backend/services/channels.py
Normal file
263
backend/services/channels.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Audio channel management module.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import (
|
||||
AudioChannelCreate,
|
||||
AudioChannelUpdate,
|
||||
AudioChannelResponse,
|
||||
ChannelVoiceAssignment,
|
||||
ProfileChannelAssignment,
|
||||
)
|
||||
from ..database import (
|
||||
AudioChannel as DBAudioChannel,
|
||||
ChannelDeviceMapping as DBChannelDeviceMapping,
|
||||
ProfileChannelMapping as DBProfileChannelMapping,
|
||||
VoiceProfile as DBVoiceProfile,
|
||||
)
|
||||
|
||||
|
||||
async def list_channels(db: Session) -> List[AudioChannelResponse]:
    """List all audio channels with their assigned device IDs.

    Fetches all device mappings in a single query and groups them in
    memory, instead of issuing one mapping query per channel (N+1).

    Args:
        db: Database session.

    Returns:
        One response object per channel, each carrying its device IDs.
    """
    channels = db.query(DBAudioChannel).all()

    # Group device IDs by channel id in one pass over a single query.
    devices_by_channel = {}
    for mapping in db.query(DBChannelDeviceMapping).all():
        devices_by_channel.setdefault(mapping.channel_id, []).append(mapping.device_id)

    return [
        AudioChannelResponse(
            id=channel.id,
            name=channel.name,
            is_default=channel.is_default,
            device_ids=devices_by_channel.get(channel.id, []),
            created_at=channel.created_at,
        )
        for channel in channels
    ]
|
||||
|
||||
|
||||
async def get_channel(channel_id: str, db: Session) -> Optional[AudioChannelResponse]:
    """Fetch a single channel by its ID, or None if it does not exist."""
    row = db.query(DBAudioChannel).filter_by(id=channel_id).first()
    if row is None:
        return None

    # Collect the audio device IDs mapped to this channel.
    mappings = db.query(DBChannelDeviceMapping).filter_by(channel_id=row.id).all()

    return AudioChannelResponse(
        id=row.id,
        name=row.name,
        is_default=row.is_default,
        device_ids=[m.device_id for m in mappings],
        created_at=row.created_at,
    )
|
||||
|
||||
|
||||
async def create_channel(
    data: AudioChannelCreate,
    db: Session,
) -> AudioChannelResponse:
    """Create a new audio channel and its device mappings.

    Raises:
        ValueError: If a channel with the same name already exists.
    """
    # Enforce unique channel names up front.
    if db.query(DBAudioChannel).filter_by(name=data.name).first():
        raise ValueError(f"Channel with name '{data.name}' already exists")

    new_channel = DBAudioChannel(
        id=str(uuid.uuid4()),
        name=data.name,
        is_default=False,
        created_at=datetime.utcnow(),
    )
    db.add(new_channel)
    db.flush()  # Persist the channel row before mappings reference it.

    # One mapping row per assigned device.
    for did in data.device_ids:
        db.add(DBChannelDeviceMapping(
            id=str(uuid.uuid4()),
            channel_id=new_channel.id,
            device_id=did,
        ))

    db.commit()
    db.refresh(new_channel)

    return AudioChannelResponse(
        id=new_channel.id,
        name=new_channel.name,
        is_default=new_channel.is_default,
        device_ids=data.device_ids,
        created_at=new_channel.created_at,
    )
|
||||
|
||||
|
||||
async def update_channel(
    channel_id: str,
    data: AudioChannelUpdate,
    db: Session,
) -> Optional[AudioChannelResponse]:
    """Update an audio channel's name and/or device assignments.

    Fields left as None in ``data`` are not touched. Device assignment is
    replace-all semantics: when ``device_ids`` is provided, every existing
    mapping is deleted and recreated from the new list.

    Args:
        channel_id: ID of the channel to update.
        data: Partial update payload.
        db: Database session.

    Returns:
        The updated channel, or None if no channel with ``channel_id`` exists.

    Raises:
        ValueError: If the channel is the default channel, or the new name
            is already used by another channel.
    """
    channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
    if not channel:
        return None

    # The default channel is immutable by design.
    if channel.is_default:
        raise ValueError("Cannot modify the default channel")

    # Update name if provided
    if data.name is not None:
        # Check if name already exists (excluding current channel)
        existing = db.query(DBAudioChannel).filter(
            DBAudioChannel.name == data.name,
            DBAudioChannel.id != channel_id
        ).first()
        if existing:
            raise ValueError(f"Channel with name '{data.name}' already exists")
        channel.name = data.name

    # Update device mappings if provided
    if data.device_ids is not None:
        # Delete existing mappings, then recreate from the new list —
        # both happen in the same transaction, committed below.
        db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()

        # Add new mappings
        for device_id in data.device_ids:
            mapping = DBChannelDeviceMapping(
                id=str(uuid.uuid4()),
                channel_id=channel.id,
                device_id=device_id,
            )
            db.add(mapping)

    db.commit()
    db.refresh(channel)

    # Re-read the mappings so the response reflects exactly what was stored.
    device_mappings = db.query(DBChannelDeviceMapping).filter_by(
        channel_id=channel.id
    ).all()
    device_ids = [m.device_id for m in device_mappings]

    return AudioChannelResponse(
        id=channel.id,
        name=channel.name,
        is_default=channel.is_default,
        device_ids=device_ids,
        created_at=channel.created_at,
    )
|
||||
|
||||
|
||||
async def delete_channel(channel_id: str, db: Session) -> bool:
    """Delete an audio channel and all of its mappings.

    Returns:
        True if the channel was deleted, False if it does not exist.

    Raises:
        ValueError: If the channel is the default channel.
    """
    row = db.query(DBAudioChannel).filter_by(id=channel_id).first()
    if row is None:
        return False
    if row.is_default:
        raise ValueError("Cannot delete the default channel")

    # Remove dependent mapping rows first, then the channel itself.
    db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()
    db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
    db.delete(row)
    db.commit()
    return True
|
||||
|
||||
|
||||
async def get_channel_voices(channel_id: str, db: Session) -> List[str]:
    """Return the profile IDs currently assigned to the given channel."""
    rows = db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).all()
    return [row.profile_id for row in rows]
|
||||
|
||||
|
||||
async def set_channel_voices(
    channel_id: str,
    data: ChannelVoiceAssignment,
    db: Session,
) -> None:
    """Replace the set of voice profiles assigned to a channel.

    All existing assignments for the channel are removed and replaced
    with ``data.profile_ids`` in a single transaction.

    Raises:
        ValueError: If the channel or any referenced profile does not exist.
    """
    # Verify channel exists
    channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
    if not channel:
        raise ValueError(f"Channel {channel_id} not found")

    # Verify all profiles exist with one query instead of one per profile.
    if data.profile_ids:
        found_ids = {
            row.id
            for row in db.query(DBVoiceProfile)
            .filter(DBVoiceProfile.id.in_(data.profile_ids))
            .all()
        }
        # Report the first missing profile, matching the original per-item
        # validation order and error message.
        for profile_id in data.profile_ids:
            if profile_id not in found_ids:
                raise ValueError(f"Profile {profile_id} not found")

    # Delete existing mappings for this channel, then add the new set.
    db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
    for profile_id in data.profile_ids:
        db.add(DBProfileChannelMapping(
            profile_id=profile_id,
            channel_id=channel_id,
        ))

    db.commit()
|
||||
|
||||
|
||||
async def get_profile_channels(profile_id: str, db: Session) -> List[str]:
    """Return the channel IDs the given profile is assigned to."""
    rows = db.query(DBProfileChannelMapping).filter_by(profile_id=profile_id).all()
    return [row.channel_id for row in rows]
|
||||
|
||||
|
||||
async def set_profile_channels(
    profile_id: str,
    data: ProfileChannelAssignment,
    db: Session,
) -> None:
    """Replace the set of channels a voice profile is assigned to.

    Existing assignments for the profile are dropped and rebuilt from
    ``data.channel_ids`` in a single transaction.

    Raises:
        ValueError: If the profile or any referenced channel does not exist.
    """
    if db.query(DBVoiceProfile).filter_by(id=profile_id).first() is None:
        raise ValueError(f"Profile {profile_id} not found")

    # Every requested channel must exist before we touch the mappings.
    for cid in data.channel_ids:
        if db.query(DBAudioChannel).filter_by(id=cid).first() is None:
            raise ValueError(f"Channel {cid} not found")

    # Drop the old assignments and write the new set.
    db.query(DBProfileChannelMapping).filter_by(profile_id=profile_id).delete()
    for cid in data.channel_ids:
        db.add(DBProfileChannelMapping(profile_id=profile_id, channel_id=cid))

    db.commit()
|
||||
422
backend/services/cuda.py
Normal file
422
backend/services/cuda.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
CUDA backend download, assembly, and verification.
|
||||
|
||||
Downloads two archives from GitHub Releases:
|
||||
1. Server core (voicebox-server-cuda.tar.gz) — the exe + non-NVIDIA deps,
|
||||
versioned with the app.
|
||||
2. CUDA libs (cuda-libs-{version}.tar.gz) — NVIDIA runtime libraries,
|
||||
versioned independently (only redownloaded on CUDA toolkit bump).
|
||||
|
||||
Both archives are extracted into {data_dir}/backends/cuda/ which forms the
|
||||
complete PyInstaller --onedir directory structure that torch expects.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ..config import get_data_dir
|
||||
from ..utils.progress import get_progress_manager
|
||||
from .. import __version__
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GITHUB_RELEASES_URL = "https://github.com/jamiepine/voicebox/releases/download"
|
||||
|
||||
PROGRESS_KEY = "cuda-backend"
|
||||
|
||||
# The current expected CUDA libs version. Bump this when we change the
|
||||
# CUDA toolkit version or torch's CUDA dependency changes (e.g. cu126 -> cu128).
|
||||
CUDA_LIBS_VERSION = "cu128-v1"
|
||||
|
||||
# Prevents concurrent download_cuda_binary() calls from racing on the same
|
||||
# temp file. The auto-update background task and the manual HTTP endpoint
|
||||
# can both invoke download_cuda_binary(); without this lock the progress-
|
||||
# manager status check is a TOCTOU race.
|
||||
_download_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def get_backends_dir() -> Path:
    """Return (creating it if missing) the backend-binaries directory."""
    backends = get_data_dir() / "backends"
    backends.mkdir(parents=True, exist_ok=True)
    return backends
|
||||
|
||||
|
||||
def get_cuda_dir() -> Path:
    """Return (creating it if missing) the CUDA onedir extraction directory."""
    cuda = get_backends_dir() / "cuda"
    cuda.mkdir(parents=True, exist_ok=True)
    return cuda
|
||||
|
||||
|
||||
def get_cuda_exe_name() -> str:
    """Filename of the CUDA server executable for the current platform."""
    base = "voicebox-server-cuda"
    return base + ".exe" if sys.platform == "win32" else base
|
||||
|
||||
|
||||
def get_cuda_binary_path() -> Optional[Path]:
    """Path to the CUDA executable inside the onedir, or None if absent."""
    candidate = get_cuda_dir() / get_cuda_exe_name()
    return candidate if candidate.exists() else None
|
||||
|
||||
|
||||
def get_cuda_libs_manifest_path() -> Path:
    """Location of the cuda-libs.json version manifest in the CUDA dir."""
    return get_cuda_dir() / "cuda-libs.json"
|
||||
|
||||
|
||||
def get_installed_cuda_libs_version() -> Optional[str]:
    """Installed CUDA libs version from cuda-libs.json, or None.

    A missing or unreadable manifest is treated the same way: None,
    with a warning logged for the unreadable case.
    """
    path = get_cuda_libs_manifest_path()
    if not path.exists():
        return None
    try:
        return json.loads(path.read_text()).get("version")
    except Exception as e:
        # Best-effort: a corrupt manifest just means "not installed".
        logger.warning(f"Could not read cuda-libs.json: {e}")
        return None
|
||||
|
||||
|
||||
def is_cuda_active() -> bool:
    """Whether the current process is running as the CUDA binary.

    The CUDA binary sets this env var on startup (see server.py).
    """
    variant = os.environ.get("VOICEBOX_BACKEND_VARIANT")
    return variant == "cuda"
|
||||
|
||||
|
||||
def get_cuda_status() -> dict:
    """Assemble the CUDA backend status payload served by the API."""
    manager = get_progress_manager()
    binary = get_cuda_binary_path()
    progress = manager.get_progress(PROGRESS_KEY)

    return {
        "available": binary is not None,
        "active": is_cuda_active(),
        "binary_path": str(binary) if binary else None,
        "cuda_libs_version": get_installed_cuda_libs_version(),
        "downloading": progress is not None and progress.get("status") == "downloading",
        "download_progress": progress,
    }
|
||||
|
||||
|
||||
def _needs_server_download(version: Optional[str] = None) -> bool:
    """True if the server core archive must be (re)downloaded.

    Redownload whenever the binary is missing or its reported version
    differs from the expected app version.
    """
    if get_cuda_binary_path() is None:
        return True

    # Release tags may carry a leading "v"; normalize before comparing.
    expected = version or __version__
    if expected.startswith("v"):
        expected = expected[1:]
    return get_cuda_binary_version() != expected
|
||||
|
||||
|
||||
def _needs_cuda_libs_download() -> bool:
    """True if the CUDA libs archive is missing or version-mismatched."""
    installed = get_installed_cuda_libs_version()
    # A None (not installed / unreadable manifest) also fails the comparison.
    return installed != CUDA_LIBS_VERSION
|
||||
|
||||
|
||||
async def _download_and_extract_archive(
    client,
    url: str,
    sha256_url: Optional[str],
    dest_dir: Path,
    label: str,
    progress_offset: int,
    total_size: int,
):
    """Download a .tar.gz archive, verify its checksum, and extract it.

    Args:
        client: httpx.AsyncClient
        url: URL of the .tar.gz archive
        sha256_url: URL of the .sha256 checksum file (optional)
        dest_dir: Directory to extract into
        label: Human-readable label for progress updates
        progress_offset: Byte offset for progress reporting (when downloading
            multiple archives sequentially)
        total_size: Total bytes across all downloads (for progress bar)

    Returns:
        Number of bytes downloaded (callers use this to advance the offset).

    Raises:
        RuntimeError: If the checksum file cannot be fetched.
        ValueError: If the archive fails SHA-256 verification.
    """
    progress = get_progress_manager()
    temp_path = dest_dir / f".download-{label.replace(' ', '-')}.tmp"

    # Clean up leftover partial download
    if temp_path.exists():
        temp_path.unlink()

    # Fetch expected checksum (fail-fast: never extract an unverified archive)
    expected_sha = None
    if sha256_url:
        try:
            sha_resp = await client.get(sha256_url)
            sha_resp.raise_for_status()
            # Checksum file is "<hex digest>  <filename>"; keep the digest.
            expected_sha = sha_resp.text.strip().split()[0]
            logger.info(f"{label}: expected SHA-256: {expected_sha[:16]}...")
        except Exception as e:
            raise RuntimeError(f"{label}: failed to fetch checksum from {sha256_url}") from e

    # Stream download, verify, and extract — always clean up temp file
    downloaded = 0
    try:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            with open(temp_path, "wb") as f:
                # Stream in 1 MiB chunks, updating shared progress as we go.
                async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                    f.write(chunk)
                    downloaded += len(chunk)
                    progress.update_progress(
                        PROGRESS_KEY,
                        current=progress_offset + downloaded,
                        total=total_size,
                        filename=f"Downloading {label}",
                        status="downloading",
                    )

        # Verify integrity
        if expected_sha:
            progress.update_progress(
                PROGRESS_KEY,
                current=progress_offset + downloaded,
                total=total_size,
                filename=f"Verifying {label}...",
                status="downloading",
            )
            # Hash in 1 MiB chunks so large archives never load fully into memory.
            sha256 = hashlib.sha256()
            with open(temp_path, "rb") as f:
                while True:
                    data = f.read(1024 * 1024)
                    if not data:
                        break
                    sha256.update(data)
            actual = sha256.hexdigest()
            if actual != expected_sha:
                raise ValueError(
                    f"{label} integrity check failed: expected {expected_sha[:16]}..., got {actual[:16]}..."
                )
            logger.info(f"{label}: integrity verified")

        # Extract (use data filter for path traversal protection on Python 3.12+)
        progress.update_progress(
            PROGRESS_KEY,
            current=progress_offset + downloaded,
            total=total_size,
            filename=f"Extracting {label}...",
            status="downloading",
        )
        with tarfile.open(temp_path, "r:gz") as tar:
            if sys.version_info >= (3, 12):
                tar.extractall(path=dest_dir, filter="data")
            else:
                tar.extractall(path=dest_dir)

        logger.info(f"{label}: extracted to {dest_dir}")
    finally:
        # Remove the temp file on both success and failure paths.
        if temp_path.exists():
            temp_path.unlink()
    return downloaded
|
||||
|
||||
|
||||
async def download_cuda_binary(version: Optional[str] = None):
    """Download the CUDA backend (server core + CUDA libs if needed).

    Fetches both archives from GitHub Releases, extracts them into
    {data_dir}/backends/cuda/, and writes the cuda-libs.json manifest.
    Only what is actually outdated gets downloaded.

    Serialized via a module-level lock: if another task is already
    downloading, this call returns immediately instead of queueing.

    Args:
        version: Version tag (e.g. "v0.3.0"). Defaults to current app version.
    """
    if _download_lock.locked():
        # A concurrent download is running; don't start a duplicate.
        logger.info("CUDA download already in progress, skipping duplicate request")
        return

    async with _download_lock:
        await _download_cuda_binary_locked(version)
|
||||
|
||||
|
||||
async def _download_cuda_binary_locked(version: Optional[str] = None):
    """Inner implementation of download_cuda_binary, called under _download_lock.

    Determines which of the two archives (server core, CUDA libs) are
    stale, downloads and extracts only those, and records progress and
    errors through the shared progress manager. Any failure is re-raised
    after being surfaced via mark_error.

    Args:
        version: Release tag (e.g. "v0.3.0"); defaults to "v" + app version.
    """
    import httpx

    if version is None:
        version = f"v{__version__}"

    progress = get_progress_manager()
    cuda_dir = get_cuda_dir()

    need_server = _needs_server_download(version)
    need_libs = _needs_cuda_libs_download()

    if not need_server and not need_libs:
        logger.info("CUDA backend is up to date, nothing to download")
        return

    logger.info(
        f"Starting CUDA backend download for {version} "
        f"(server={'yes' if need_server else 'cached'}, "
        f"libs={'yes' if need_libs else 'cached'})"
    )
    progress.update_progress(
        PROGRESS_KEY,
        current=0,
        total=0,
        filename="Preparing download...",
        status="downloading",
    )

    base_url = f"{GITHUB_RELEASES_URL}/{version}"
    server_archive = "voicebox-server-cuda.tar.gz"
    libs_archive = f"cuda-libs-{CUDA_LIBS_VERSION}.tar.gz"

    try:
        async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
            # Estimate total download size via HEAD requests (best-effort:
            # a failure just leaves the progress bar without a total).
            total_size = 0
            if need_server:
                try:
                    head = await client.head(f"{base_url}/{server_archive}")
                    total_size += int(head.headers.get("content-length", 0))
                except Exception:
                    pass
            if need_libs:
                try:
                    head = await client.head(f"{base_url}/{libs_archive}")
                    total_size += int(head.headers.get("content-length", 0))
                except Exception:
                    pass

            logger.info(f"Total download size: {total_size / 1024 / 1024:.1f} MB")

            offset = 0

            # Download server core
            if need_server:
                server_downloaded = await _download_and_extract_archive(
                    client,
                    url=f"{base_url}/{server_archive}",
                    sha256_url=f"{base_url}/{server_archive}.sha256",
                    dest_dir=cuda_dir,
                    label="CUDA server",
                    progress_offset=offset,
                    total_size=total_size,
                )
                offset += server_downloaded

                # Make executable on Unix
                exe_path = cuda_dir / get_cuda_exe_name()
                if sys.platform != "win32" and exe_path.exists():
                    exe_path.chmod(0o755)

            # Download CUDA libs
            if need_libs:
                await _download_and_extract_archive(
                    client,
                    url=f"{base_url}/{libs_archive}",
                    sha256_url=f"{base_url}/{libs_archive}.sha256",
                    dest_dir=cuda_dir,
                    label="CUDA libraries",
                    progress_offset=offset,
                    total_size=total_size,
                )

                # Write local cuda-libs.json manifest so future runs can
                # detect whether the installed libs match CUDA_LIBS_VERSION.
                manifest = {"version": CUDA_LIBS_VERSION}
                get_cuda_libs_manifest_path().write_text(json.dumps(manifest, indent=2) + "\n")

            logger.info(f"CUDA backend ready at {cuda_dir}")
            progress.mark_complete(PROGRESS_KEY)

    except Exception as e:
        logger.error(f"CUDA backend download failed: {e}")
        progress.mark_error(PROGRESS_KEY, str(e))
        raise
|
||||
|
||||
|
||||
def get_cuda_binary_version() -> Optional[str]:
    """Version string reported by the installed CUDA binary, or None.

    Runs ``<binary> --version`` and parses its output; any failure
    (missing binary, timeout, bad output) yields None.
    """
    import subprocess

    exe = get_cuda_binary_path()
    if exe is None:
        return None

    try:
        proc = subprocess.run(
            [str(exe), "--version"],
            capture_output=True,
            text=True,
            timeout=30,
            cwd=str(exe.parent),  # Run from the onedir directory
        )
        # Output format: "voicebox-server 0.3.0"
        for line in proc.stdout.strip().splitlines():
            if "voicebox-server" in line:
                return line.split()[-1]
    except Exception as e:
        logger.warning(f"Could not get CUDA binary version: {e}")
    return None
|
||||
|
||||
|
||||
async def check_and_update_cuda_binary():
    """Auto-download the CUDA backend on startup if any piece is outdated.

    Checks both the server core version and the CUDA libs version and
    downloads only what's needed. A download failure is logged, not raised.
    """
    if get_cuda_binary_path() is None:
        # CUDA was never installed; nothing to keep up to date.
        return

    need_server = _needs_server_download()
    need_libs = _needs_cuda_libs_download()

    if not (need_server or need_libs):
        logger.info(f"CUDA binary is up to date (server=v{__version__}, libs={get_installed_cuda_libs_version()})")
        return

    # Build a human-readable summary of why we're redownloading.
    reasons = []
    if need_server:
        cuda_version = get_cuda_binary_version()
        reasons.append(f"server v{cuda_version} != v{__version__}")
    if need_libs:
        installed_libs = get_installed_cuda_libs_version()
        reasons.append(f"libs {installed_libs} != {CUDA_LIBS_VERSION}")

    logger.info(f"CUDA backend needs update ({', '.join(reasons)}). Auto-downloading...")

    try:
        await download_cuda_binary()
    except Exception as e:
        # Startup must not fail because an update couldn't be fetched.
        logger.error(f"Auto-update of CUDA binary failed: {e}")
|
||||
|
||||
|
||||
async def delete_cuda_binary() -> bool:
    """Remove the downloaded CUDA backend directory.

    Returns:
        True if the directory had contents and was deleted, False otherwise.
    """
    import shutil

    target = get_cuda_dir()
    # get_cuda_dir() creates the directory, so emptiness is the real gate.
    if not (target.exists() and any(target.iterdir())):
        return False

    shutil.rmtree(target)
    logger.info(f"Deleted CUDA backend directory: {target}")
    return True
|
||||
120
backend/services/effects.py
Normal file
120
backend/services/effects.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Effect presets CRUD operations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ..utils.effects import validate_effects_chain
|
||||
|
||||
from ..database import EffectPreset as DBEffectPreset
|
||||
from ..models import EffectPresetResponse, EffectPresetCreate, EffectPresetUpdate, EffectConfig
|
||||
|
||||
|
||||
def _preset_response(p: DBEffectPreset) -> EffectPresetResponse:
    """Build the API response model from a DB preset row."""
    # effects_chain is stored as a JSON string; decode into typed configs.
    chain = [EffectConfig(**entry) for entry in json.loads(p.effects_chain)]
    return EffectPresetResponse(
        id=p.id,
        name=p.name,
        description=p.description,
        effects_chain=chain,
        is_builtin=p.is_builtin or False,
        created_at=p.created_at,
    )
|
||||
|
||||
|
||||
def list_presets(db: Session) -> List[EffectPresetResponse]:
    """Return every effect preset (built-in and user-created), sorted."""
    query = db.query(DBEffectPreset).order_by(DBEffectPreset.sort_order, DBEffectPreset.name)
    return [_preset_response(row) for row in query.all()]
|
||||
|
||||
|
||||
def get_preset(preset_id: str, db: Session) -> Optional[EffectPresetResponse]:
    """Look up a preset by ID; None if it does not exist."""
    row = db.query(DBEffectPreset).filter_by(id=preset_id).first()
    return _preset_response(row) if row else None
|
||||
|
||||
|
||||
def get_preset_by_name(name: str, db: Session) -> Optional[EffectPresetResponse]:
    """Look up a preset by its unique name; None if it does not exist."""
    row = db.query(DBEffectPreset).filter_by(name=name).first()
    return _preset_response(row) if row else None
|
||||
|
||||
|
||||
def create_preset(data: EffectPresetCreate, db: Session) -> EffectPresetResponse:
    """Create a new user effect preset.

    Raises:
        ValueError: If the effects chain is invalid or the name is taken.
    """
    chain = [e.model_dump() for e in data.effects_chain]
    problem = validate_effects_chain(chain)
    if problem:
        raise ValueError(problem)

    # Check for duplicate name before insert
    if db.query(DBEffectPreset).filter_by(name=data.name).first():
        raise ValueError(f"A preset named '{data.name}' already exists")

    row = DBEffectPreset(
        id=str(uuid.uuid4()),
        name=data.name,
        description=data.description,
        effects_chain=json.dumps(chain),
        is_builtin=False,
    )
    db.add(row)
    try:
        db.commit()
    except IntegrityError:
        # A concurrent insert of the same name raced past the check above.
        db.rollback()
        raise ValueError(f"A preset named '{data.name}' already exists")
    db.refresh(row)
    return _preset_response(row)
|
||||
|
||||
|
||||
def update_preset(preset_id: str, data: EffectPresetUpdate, db: Session) -> Optional[EffectPresetResponse]:
    """Update a user effect preset. Cannot modify built-in presets.

    Returns:
        The updated preset, or None if it does not exist.

    Raises:
        ValueError: If the preset is built-in, the new effects chain is
            invalid, or the new name collides with another preset.
    """
    preset = db.query(DBEffectPreset).filter_by(id=preset_id).first()
    if not preset:
        return None
    if preset.is_builtin:
        raise ValueError("Cannot modify built-in presets")

    if data.name is not None:
        # Consistent with create_preset: reject a duplicate name with a
        # clear ValueError instead of surfacing a raw IntegrityError on
        # commit (previously uncaught here).
        existing = db.query(DBEffectPreset).filter(
            DBEffectPreset.name == data.name,
            DBEffectPreset.id != preset_id,
        ).first()
        if existing:
            raise ValueError(f"A preset named '{data.name}' already exists")
        preset.name = data.name
    if data.description is not None:
        preset.description = data.description
    if data.effects_chain is not None:
        # Validate before serializing back into the JSON column.
        chain_dicts = [e.model_dump() for e in data.effects_chain]
        error = validate_effects_chain(chain_dicts)
        if error:
            raise ValueError(error)
        preset.effects_chain = json.dumps(chain_dicts)

    try:
        db.commit()
    except IntegrityError:
        # Race with a concurrent rename/insert on the unique name column.
        db.rollback()
        raise ValueError(f"A preset named '{data.name}' already exists")
    db.refresh(preset)
    return _preset_response(preset)
|
||||
|
||||
|
||||
def delete_preset(preset_id: str, db: Session) -> bool:
    """Delete a user effect preset. Cannot delete built-in presets.

    Returns:
        True if deleted, False if no preset with that ID exists.

    Raises:
        ValueError: If the preset is built-in.
    """
    row = db.query(DBEffectPreset).filter_by(id=preset_id).first()
    if row is None:
        return False
    if row.is_builtin:
        raise ValueError("Cannot delete built-in presets")

    db.delete(row)
    db.commit()
    return True
|
||||
461
backend/services/export_import.py
Normal file
461
backend/services/export_import.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
Voice profile export/import module.
|
||||
|
||||
Handles exporting profiles to ZIP archives and importing them back.
|
||||
Also handles exporting individual generations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import zipfile
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import VoiceProfileResponse
|
||||
from ..database import VoiceProfile as DBVoiceProfile, ProfileSample as DBProfileSample, Generation as DBGeneration, GenerationVersion as DBGenerationVersion
|
||||
from .profiles import create_profile, add_profile_sample
|
||||
from ..models import VoiceProfileCreate
|
||||
from .. import config
|
||||
|
||||
|
||||
def _get_unique_profile_name(name: str, db: Session) -> str:
    """
    Return *name*, suffixed with " (N)" if needed to avoid a collision.

    Args:
        name: Original profile name
        db: Database session

    Returns:
        A profile name not currently present in the database.
    """
    candidate = name
    suffix = 1
    # Keep bumping the suffix until the candidate name is free.
    while db.query(DBVoiceProfile).filter_by(name=candidate).first() is not None:
        candidate = f"{name} ({suffix})"
        suffix += 1
    return candidate
|
||||
|
||||
|
||||
def export_profile_to_zip(profile_id: str, db: Session) -> bytes:
    """
    Export a voice profile to a ZIP archive.

    The archive contains manifest.json (profile metadata), samples.json
    (sample filename -> reference text), the sample audio files under
    samples/, and optionally the avatar image at the archive root.

    Args:
        profile_id: Profile ID to export
        db: Database session

    Returns:
        ZIP file contents as bytes

    Raises:
        ValueError: If profile not found, has no samples, or a sample's
            audio file is missing on disk.
    """
    # Get profile
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile {profile_id} not found")

    # Get all samples
    samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
    if not samples:
        raise ValueError(f"Profile {profile_id} has no samples")

    # Create ZIP in memory
    zip_buffer = io.BytesIO()

    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        # Add the avatar (if any) to the ZIP root with its original extension.
        has_avatar = False
        if profile.avatar_path:
            avatar_path = config.resolve_storage_path(profile.avatar_path)
            if avatar_path is not None and avatar_path.exists():
                has_avatar = True
                avatar_ext = avatar_path.suffix
                zip_file.write(avatar_path, f"avatar{avatar_ext}")

        # Create manifest.json
        manifest = {
            "version": "1.0",
            "profile": {
                "name": profile.name,
                "description": profile.description,
                "language": profile.language,
            },
            "has_avatar": has_avatar,
        }
        zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))

        # samples.json maps each archived filename to its reference text.
        samples_data = {}

        for sample in samples:
            audio_path = config.resolve_storage_path(sample.audio_path)
            if audio_path is None:
                raise ValueError(f"Audio file not found: {sample.audio_path}")
            if not audio_path.exists():
                raise ValueError(f"Audio file not found: {audio_path}")
            filename = audio_path.name

            # BUG FIX: write each sample under its own archive name.
            # Previously every sample was written to the same literal entry,
            # clobbering all but one and breaking the samples.json mapping.
            zip_file.write(audio_path, f"samples/{filename}")

            # Map filename to reference text
            samples_data[filename] = sample.reference_text

        zip_file.writestr("samples.json", json.dumps(samples_data, indent=2))

    zip_buffer.seek(0)
    return zip_buffer.read()
|
||||
|
||||
|
||||
async def import_profile_from_zip(file_bytes: bytes, db: Session) -> VoiceProfileResponse:
    """
    Import a voice profile from a ZIP archive.

    The archive must contain ``manifest.json`` (profile metadata) and
    ``samples.json`` (a ``{filename: reference_text}`` mapping), plus a
    ``samples/`` directory holding the referenced ``.wav`` files.  An
    optional ``avatar.*`` file at the archive root is imported as the
    profile avatar (best-effort).

    Args:
        file_bytes: ZIP file contents
        db: Database session

    Returns:
        Created profile

    Raises:
        ValueError: If ZIP is invalid or missing required files
    """
    import logging
    import tempfile

    logger = logging.getLogger(__name__)
    zip_buffer = io.BytesIO(file_bytes)

    try:
        with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
            # Validate ZIP structure
            namelist = zip_file.namelist()

            if "manifest.json" not in namelist:
                raise ValueError("ZIP archive missing manifest.json")
            if "samples.json" not in namelist:
                raise ValueError("ZIP archive missing samples.json")

            # Read and validate manifest
            manifest_data = json.loads(zip_file.read("manifest.json"))
            if "version" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing version")
            if "profile" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing profile")
            profile_data = manifest_data["profile"]

            # Read samples mapping
            samples_data = json.loads(zip_file.read("samples.json"))
            if not isinstance(samples_data, dict):
                raise ValueError("Invalid samples.json: must be a dictionary")

            # Deduplicate the imported name against existing profiles
            original_name = profile_data.get("name", "Imported Profile")
            unique_name = _get_unique_profile_name(original_name, db)

            # Create profile
            profile_create = VoiceProfileCreate(
                name=unique_name,
                description=profile_data.get("description"),
                language=profile_data.get("language", "en"),
            )
            profile = await create_profile(profile_create, db)

            # Ensure the profile's storage directory exists
            profile_dir = config.get_profiles_dir() / profile.id
            profile_dir.mkdir(parents=True, exist_ok=True)

            # Handle avatar if present (optional — failures are logged, not fatal)
            avatar_files = [f for f in namelist if f.startswith("avatar.")]
            if avatar_files:
                try:
                    avatar_file = avatar_files[0]
                    # Extract to a temporary file so upload_avatar can read it
                    with tempfile.NamedTemporaryFile(suffix=Path(avatar_file).suffix, delete=False) as tmp:
                        tmp.write(zip_file.read(avatar_file))
                        tmp_path = tmp.name

                    try:
                        from .profiles import upload_avatar
                        await upload_avatar(profile.id, tmp_path, db)
                    finally:
                        Path(tmp_path).unlink(missing_ok=True)
                except Exception as e:
                    # Avatar import is optional - continue even if it fails,
                    # but leave a trace instead of swallowing silently.
                    logger.warning("avatar import failed for profile %s: %s", profile.id, e)

            for filename, reference_text in samples_data.items():
                # Validate filename
                if not filename.endswith('.wav'):
                    # BUG FIX: the message previously contained a literal
                    # placeholder instead of interpolating the filename.
                    raise ValueError(f"Invalid sample filename: {filename} (must be .wav)")

                # BUG FIX: the archive path must interpolate the sample
                # filename (was a non-interpolating f-string, so every
                # sample lookup failed).
                zip_path = f"samples/{filename}"
                if zip_path not in namelist:
                    raise ValueError(f"Sample file not found in ZIP: {zip_path}")

                # Extract to temporary file
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    tmp.write(zip_file.read(zip_path))
                    tmp_path = tmp.name

                try:
                    # Add sample to profile
                    await add_profile_sample(
                        profile.id,
                        tmp_path,
                        reference_text,
                        db,
                    )
                finally:
                    # Clean up temp file
                    Path(tmp_path).unlink(missing_ok=True)

            return profile

    except zipfile.BadZipFile:
        raise ValueError("Invalid ZIP file")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in archive: {e}")
    except Exception as e:
        if isinstance(e, ValueError):
            raise
        raise ValueError(f"Error importing profile: {str(e)}")
|
||||
|
||||
|
||||
def export_generation_to_zip(generation_id: str, db: Session) -> bytes:
    """
    Export a generation to a ZIP archive.

    The archive contains ``manifest.json`` (generation + profile
    metadata and a version index) and an ``audio/`` directory with one
    WAV file per version — or the generation's main audio file when no
    versions exist.

    Args:
        generation_id: Generation ID to export
        db: Database session

    Returns:
        ZIP file contents as bytes

    Raises:
        ValueError: If generation not found
    """
    # Get generation
    generation = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not generation:
        raise ValueError(f"Generation {generation_id} not found")

    # Get profile info
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
    if not profile:
        raise ValueError(f"Profile {generation.profile_id} not found")

    # Get all versions for this generation, oldest first
    versions = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .order_by(DBGenerationVersion.created_at)
        .all()
    )

    # Create ZIP in memory
    zip_buffer = io.BytesIO()

    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        # Build version manifest entries; resolve each audio path once
        # and reuse the result when writing the audio files below.
        version_entries = []
        resolved_paths = []
        for v in versions:
            v_path = config.resolve_storage_path(v.audio_path)
            resolved_paths.append(v_path)
            effects_chain = None
            if v.effects_chain:
                effects_chain = json.loads(v.effects_chain)
            version_entries.append({
                "id": v.id,
                "label": v.label,
                "is_default": v.is_default,
                "effects_chain": effects_chain,
                # BUG FIX: resolve_storage_path may return None (missing
                # or unresolvable file); don't crash on `.name`.
                "filename": v_path.name if v_path is not None else None,
            })

        manifest = {
            "version": "1.0",
            "generation": {
                "id": generation.id,
                "text": generation.text,
                "language": generation.language,
                "duration": generation.duration,
                "seed": generation.seed,
                "instruct": generation.instruct,
                "created_at": generation.created_at.isoformat(),
            },
            "profile": {
                "id": profile.id,
                "name": profile.name,
                "description": profile.description,
                "language": profile.language,
            },
            "versions": version_entries,
        }
        zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))

        # Add all version audio files that actually exist on disk
        for v_path in resolved_paths:
            if v_path is not None and v_path.exists():
                zip_file.write(v_path, f"audio/{v_path.name}")

        # Fallback: if no versions exist, include the generation's main audio
        if not versions:
            audio_path = config.resolve_storage_path(generation.audio_path)
            if audio_path is not None and audio_path.exists():
                zip_file.write(audio_path, f"audio/{audio_path.name}")

    zip_buffer.seek(0)
    return zip_buffer.read()
|
||||
|
||||
|
||||
async def import_generation_from_zip(file_bytes: bytes, db: Session) -> dict:
    """
    Import a generation from a ZIP archive.

    Reads ``manifest.json`` and the first ``audio/*.wav`` entry, copies
    the audio into the generations directory under a freshly minted ID,
    and attaches the generation to a profile: the profile named in the
    manifest when one exists locally, otherwise any available profile.

    Args:
        file_bytes: ZIP file contents
        db: Database session

    Returns:
        Dictionary with generation ID and profile info

    Raises:
        ValueError: If ZIP is invalid, missing required files, or no
            voice profile exists to attach the generation to
    """
    from pathlib import Path
    import tempfile
    import shutil
    import uuid
    from datetime import datetime
    from .. import config

    zip_buffer = io.BytesIO(file_bytes)

    try:
        with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
            # Validate ZIP structure
            namelist = zip_file.namelist()

            if "manifest.json" not in namelist:
                raise ValueError("ZIP archive missing manifest.json")

            # Read and validate manifest
            manifest_data = json.loads(zip_file.read("manifest.json"))
            if "version" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing version")
            if "generation" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing generation data")

            generation_data = manifest_data["generation"]
            profile_data = manifest_data.get("profile", {})

            # Validate required generation fields
            for field in ("text", "language", "duration"):
                if field not in generation_data:
                    raise ValueError(f"Invalid manifest.json: missing generation.{field}")

            # Find audio file in archive (first WAV under audio/)
            audio_files = [f for f in namelist if f.startswith("audio/") and f.endswith(".wav")]
            if not audio_files:
                raise ValueError("No audio file found in ZIP archive")
            audio_file_path = audio_files[0]

            # Resolve the target profile: prefer an exact name match,
            # otherwise fall back to any existing profile.
            profile_id = None
            profile_name = profile_data.get("name", "Unknown Profile")

            if profile_name and profile_name != "Unknown Profile":
                existing_profile = db.query(DBVoiceProfile).filter_by(name=profile_name).first()
                if existing_profile:
                    profile_id = existing_profile.id

            if not profile_id:
                any_profile = db.query(DBVoiceProfile).first()
                if any_profile:
                    profile_id = any_profile.id
                    profile_name = any_profile.name
                else:
                    raise ValueError("No voice profiles found. Please create a profile before importing generations.")

            # Extract audio file to temporary location
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(zip_file.read(audio_file_path))
                tmp_path = tmp.name

            try:
                # Create generations directory
                generations_dir = config.get_generations_dir()
                generations_dir.mkdir(parents=True, exist_ok=True)

                # Mint a fresh ID so imports never collide with existing
                # rows/files (uses the local `uuid` import instead of the
                # old `__import__('uuid')` hack).
                new_generation_id = str(uuid.uuid4())

                # Copy audio to generations directory
                audio_dest = generations_dir / f"{new_generation_id}.wav"
                shutil.copy(tmp_path, audio_dest)

                # Create generation record
                db_generation = DBGeneration(
                    id=new_generation_id,
                    profile_id=profile_id,
                    text=generation_data["text"],
                    language=generation_data["language"],
                    audio_path=config.to_storage_path(audio_dest),
                    duration=generation_data["duration"],
                    seed=generation_data.get("seed"),
                    instruct=generation_data.get("instruct"),
                    created_at=datetime.utcnow(),
                )

                db.add(db_generation)
                db.commit()
                db.refresh(db_generation)

                return {
                    "id": db_generation.id,
                    "profile_id": profile_id,
                    "profile_name": profile_name,
                    "text": db_generation.text,
                    "message": f"Generation imported successfully (assigned to profile: {profile_name})"
                }

            finally:
                # Clean up temp file
                Path(tmp_path).unlink(missing_ok=True)

    except zipfile.BadZipFile:
        raise ValueError("Invalid ZIP file")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in archive: {e}")
    except Exception as e:
        if isinstance(e, ValueError):
            raise
        raise ValueError(f"Error importing generation: {str(e)}")
|
||||
263
backend/services/generation.py
Normal file
263
backend/services/generation.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Unified TTS generation orchestration.
|
||||
|
||||
Replaces the three near-identical closures (_run_generation, _run_retry,
|
||||
_run_regenerate) that lived in main.py with a single ``run_generation()``
|
||||
function parameterized by *mode*.
|
||||
|
||||
Mode differences:
|
||||
- "generate" : full pipeline -- save clean version, optionally apply
|
||||
effects and create a processed version.
|
||||
- "retry" : re-runs a failed generation with the same seed.
|
||||
No effects, no version creation.
|
||||
- "regenerate" : re-runs with seed=None for variation. Creates a new
|
||||
version with an auto-incremented "take-N" label.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Literal, Optional
|
||||
|
||||
from .. import config
|
||||
from . import history, profiles
|
||||
from ..database import get_db
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
|
||||
async def run_generation(
    *,
    generation_id: str,
    profile_id: str,
    text: str,
    language: str,
    engine: str,
    model_size: str,
    seed: Optional[int],
    normalize: bool = False,
    effects_chain: Optional[list] = None,
    instruct: Optional[str] = None,
    mode: Literal["generate", "retry", "regenerate"],
    max_chunk_chars: Optional[int] = None,
    crossfade_ms: Optional[int] = None,
    version_id: Optional[str] = None,
) -> None:
    """Execute TTS inference and persist the result.

    This is the single entry point for all background generation work.
    It is designed to be enqueued via ``services.task_queue.enqueue_generation``.

    Status transitions written to the generation row:
    ``loading_model`` (only when a model load is pending) -> ``generating``
    -> ``completed`` on success, or ``failed`` on error/cancellation.

    Args:
        generation_id: Pre-created generation row to update.
        profile_id: Voice profile used to build the voice prompt.
        text: Text to synthesize.
        language: Language code passed through to the engine.
        engine: TTS engine key (resolved via get_tts_backend_for_engine).
        model_size: Model size variant passed to load_engine_model.
        seed: Random seed; ignored (forced to None) for mode "regenerate".
        normalize: Apply loudness normalization to the output.
        effects_chain: Optional effects to apply (mode "generate" only).
        instruct: Optional natural-language style instruction.
        mode: "generate" | "retry" | "regenerate" (see module docstring).
        max_chunk_chars: Optional chunking override for generate_chunked.
        crossfade_ms: Optional crossfade override for generate_chunked.
        version_id: Forwarded to _save_regenerate for mode "regenerate".
    """
    # Local imports keep heavy backend/audio modules out of module import time.
    from ..backends import load_engine_model, get_tts_backend_for_engine, engine_needs_trim
    from ..utils.chunked_tts import generate_chunked
    from ..utils.audio import normalize_audio, save_audio, trim_tts_output

    task_manager = get_task_manager()
    # Dedicated session for this background task (closed in `finally`).
    bg_db = next(get_db())

    try:
        tts_model = get_tts_backend_for_engine(engine)

        # Surface a "loading_model" status only when a (slow) load is pending.
        if not tts_model.is_loaded():
            await history.update_generation_status(generation_id, "loading_model", bg_db)

        # NOTE(review): assumed to be cheap/no-op when the requested
        # engine/model_size is already loaded — confirm in backends module.
        await load_engine_model(engine, model_size)

        voice_prompt = await profiles.create_voice_prompt_for_profile(
            profile_id,
            bg_db,
            use_cache=True,
            engine=engine,
        )

        await history.update_generation_status(generation_id, "generating", bg_db)
        # Some engines emit artifacts that must be trimmed from the output.
        trim_fn = trim_tts_output if engine_needs_trim(engine) else None

        gen_kwargs: dict = dict(
            language=language,
            # "regenerate" deliberately drops the seed to get a fresh take.
            seed=seed if mode != "regenerate" else None,
            instruct=instruct,
            trim_fn=trim_fn,
        )
        # Only forward chunking overrides when set, so generate_chunked's
        # own defaults apply otherwise.
        if max_chunk_chars is not None:
            gen_kwargs["max_chunk_chars"] = max_chunk_chars
        if crossfade_ms is not None:
            gen_kwargs["crossfade_ms"] = crossfade_ms

        audio, sample_rate = await generate_chunked(tts_model, text, voice_prompt, **gen_kwargs)

        # --- Normalize: always for "regenerate"; other modes honor the
        # `normalize` flag ------------------------------------------------
        if normalize or mode == "regenerate":
            audio = normalize_audio(audio)

        duration = len(audio) / sample_rate

        # --- Persist audio and update status -----------------------------
        if mode == "generate":
            final_path = _save_generate(
                generation_id=generation_id,
                audio=audio,
                sample_rate=sample_rate,
                effects_chain=effects_chain,
                save_audio=save_audio,
                db=bg_db,
            )
        elif mode == "retry":
            final_path = _save_retry(
                generation_id=generation_id,
                audio=audio,
                sample_rate=sample_rate,
                save_audio=save_audio,
            )
        elif mode == "regenerate":
            final_path = _save_regenerate(
                generation_id=generation_id,
                version_id=version_id,
                audio=audio,
                sample_rate=sample_rate,
                save_audio=save_audio,
                db=bg_db,
            )

        await history.update_generation_status(
            generation_id=generation_id,
            status="completed",
            db=bg_db,
            audio_path=final_path,
            duration=duration,
        )

    except asyncio.CancelledError:
        # Task was cancelled (e.g. user abort) — record as failed.
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=bg_db,
            error="Generation cancelled",
        )
    except Exception as e:
        # Any other error: keep the traceback in the server log and store
        # the message on the generation row for the UI.
        traceback.print_exc()
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=bg_db,
            error=str(e),
        )
    finally:
        # Always release the queue slot and close the dedicated session.
        task_manager.complete_generation(generation_id)
        bg_db.close()
|
||||
|
||||
|
||||
def _save_generate(
    *,
    generation_id: str,
    audio,
    sample_rate: int,
    effects_chain: Optional[list],
    save_audio,
    db,
) -> str:
    """Save clean version and optionally an effects-processed version.

    Always writes the unprocessed audio to ``<generations>/<id>.wav`` and
    records it as the "original" version.  When an enabled effects chain
    is supplied and validates, a second ``<id>_processed.wav`` file is
    written, recorded as "version-2", and made the default; if the chain
    fails validation it is skipped and the clean version is made default
    instead.

    Returns the final audio path (processed if effects were applied,
    otherwise clean), converted to a storage-relative path.
    """
    from . import versions as versions_mod

    clean_audio_path = config.get_generations_dir() / f"{generation_id}.wav"
    save_audio(audio, str(clean_audio_path), sample_rate)

    # An effects chain only counts if at least one entry is enabled
    # (entries default to enabled when the key is absent).
    has_effects = effects_chain and any(e.get("enabled", True) for e in effects_chain)

    versions_mod.create_version(
        generation_id=generation_id,
        label="original",
        audio_path=config.to_storage_path(clean_audio_path),
        db=db,
        effects_chain=None,
        is_default=not has_effects,
    )

    final_audio_path = str(clean_audio_path)

    if has_effects:
        from ..utils.effects import apply_effects, validate_effects_chain

        assert effects_chain is not None

        error_msg = validate_effects_chain(effects_chain)
        if error_msg:
            import logging
            logging.getLogger(__name__).warning("invalid effects chain, skipping: %s", error_msg)
            # Chain was invalid, so no processed version will exist; flip
            # the default back to the clean version created above.
            # NOTE(review): assumes list_versions returns oldest-first so
            # index 0 is the "original" version — confirm in versions module.
            versions_mod.set_default_version(
                versions_mod.list_versions(generation_id, db)[0].id, db
            )
        else:
            processed_audio = apply_effects(audio, sample_rate, effects_chain)
            processed_path = config.get_generations_dir() / f"{generation_id}_processed.wav"
            save_audio(processed_audio, str(processed_path), sample_rate)
            final_audio_path = str(processed_path)
            versions_mod.create_version(
                generation_id=generation_id,
                label="version-2",
                audio_path=config.to_storage_path(processed_path),
                db=db,
                effects_chain=effects_chain,
                is_default=True,
            )

    return config.to_storage_path(final_audio_path)
|
||||
|
||||
|
||||
def _save_retry(
    *,
    generation_id: str,
    audio,
    sample_rate: int,
    save_audio,
) -> str:
    """Persist the audio from a retry run.

    A retry overwrites the generation's single canonical WAV file and
    creates no version records.

    Returns the storage-relative audio path.
    """
    target = config.get_generations_dir() / f"{generation_id}.wav"
    save_audio(audio, str(target), sample_rate)
    return config.to_storage_path(target)
|
||||
|
||||
|
||||
def _save_regenerate(
    *,
    generation_id: str,
    version_id: Optional[str],
    audio,
    sample_rate: int,
    save_audio,
    db,
) -> str:
    """Persist a regeneration as a brand-new default version.

    The file gets a random suffix so it never collides with earlier
    takes, and the version is labelled "take-<n>" where n is one more
    than the number of versions already stored.

    Returns the storage-relative audio path.
    """
    from . import versions as versions_mod

    import uuid as _uuid

    unique_suffix = _uuid.uuid4().hex[:8]
    out_path = config.get_generations_dir() / f"{generation_id}_{unique_suffix}.wav"
    save_audio(audio, str(out_path), sample_rate)

    # Count existing versions with a DB query (not a fetched list) to
    # avoid a TOCTOU race on the label number.
    from ..database import GenerationVersion as DBGenerationVersion

    existing = db.query(DBGenerationVersion).filter_by(generation_id=generation_id).count()

    versions_mod.create_version(
        generation_id=generation_id,
        label=f"take-{existing + 1}",
        audio_path=config.to_storage_path(out_path),
        db=db,
        effects_chain=None,
        is_default=True,
    )

    return config.to_storage_path(out_path)
|
||||
368
backend/services/history.py
Normal file
368
backend/services/history.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Generation history management module.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
|
||||
from ..models import GenerationRequest, GenerationResponse, HistoryQuery, HistoryResponse, HistoryListResponse, GenerationVersionResponse, EffectConfig
|
||||
from ..database import Generation as DBGeneration, GenerationVersion as DBGenerationVersion, VoiceProfile as DBVoiceProfile
|
||||
from .. import config
|
||||
|
||||
|
||||
def _get_versions_for_generation(generation_id: str, db: Session) -> tuple:
    """Return ``(versions, active_version_id)`` for a generation.

    Versions are ordered oldest-first; the active version is the one
    flagged ``is_default``.  Returns ``(None, None)`` when the
    generation has no version rows at all.
    """
    import json

    rows = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .order_by(DBGenerationVersion.created_at)
        .all()
    )
    if not rows:
        return None, None

    responses = []
    default_id = None
    for row in rows:
        chain = None
        if row.effects_chain:
            # Best-effort parse: a malformed chain renders as "no effects"
            # rather than breaking the whole history listing.
            try:
                chain = [EffectConfig(**entry) for entry in json.loads(row.effects_chain)]
            except Exception:
                pass
        responses.append(GenerationVersionResponse(
            id=row.id,
            generation_id=row.generation_id,
            label=row.label,
            audio_path=row.audio_path,
            effects_chain=chain,
            is_default=row.is_default,
            created_at=row.created_at,
        ))
        if row.is_default:
            default_id = row.id

    return responses, default_id
|
||||
|
||||
|
||||
async def create_generation(
    profile_id: str,
    text: str,
    language: str,
    audio_path: str,
    duration: float,
    seed: Optional[int],
    db: Session,
    instruct: Optional[str] = None,
    generation_id: Optional[str] = None,
    status: str = "completed",
    engine: Optional[str] = "qwen",
    model_size: Optional[str] = None,
) -> GenerationResponse:
    """
    Insert a new generation history row and return it.

    Args:
        profile_id: Profile ID used for generation
        text: Generated text
        language: Language code
        audio_path: Path where audio was saved
        duration: Audio duration in seconds
        seed: Random seed used (if any)
        db: Database session
        instruct: Natural language instruction used (if any)
        generation_id: Pre-assigned ID (for async generation flow);
            a fresh UUID is minted when omitted
        status: Generation status (generating, completed, failed)
        engine: TTS engine used (qwen, luxtts, chatterbox, chatterbox_turbo)
        model_size: Model size variant (1.7B, 0.6B) — only relevant for qwen

    Returns:
        Created generation entry
    """
    row = DBGeneration(
        id=generation_id or str(uuid.uuid4()),
        profile_id=profile_id,
        text=text,
        language=language,
        audio_path=audio_path,
        duration=duration,
        seed=seed,
        instruct=instruct,
        engine=engine,
        model_size=model_size,
        status=status,
        created_at=datetime.utcnow(),
    )

    db.add(row)
    db.commit()
    db.refresh(row)
    return GenerationResponse.model_validate(row)
|
||||
|
||||
|
||||
async def update_generation_status(
    generation_id: str,
    status: str,
    db: Session,
    audio_path: Optional[str] = None,
    duration: Optional[float] = None,
    error: Optional[str] = None,
) -> Optional[GenerationResponse]:
    """Set a generation's status and, optionally, its result fields.

    Only the keyword fields that are not ``None`` are written.  Returns
    the refreshed row as a response model, or ``None`` when no
    generation with this ID exists.
    """
    row = db.query(DBGeneration).filter_by(id=generation_id).first()
    if row is None:
        return None

    row.status = status
    # Apply only the optional fields that were actually supplied.
    for attr, value in (("audio_path", audio_path), ("duration", duration), ("error", error)):
        if value is not None:
            setattr(row, attr, value)

    db.commit()
    db.refresh(row)
    return GenerationResponse.model_validate(row)
|
||||
|
||||
|
||||
async def get_generation(
    generation_id: str,
    db: Session,
) -> Optional[GenerationResponse]:
    """Fetch a single generation by ID.

    Args:
        generation_id: Generation ID
        db: Database session

    Returns:
        The generation as a response model, or ``None`` if not found.
    """
    row = db.query(DBGeneration).filter_by(id=generation_id).first()
    return GenerationResponse.model_validate(row) if row else None
|
||||
|
||||
|
||||
async def list_generations(
    query: HistoryQuery,
    db: Session,
) -> HistoryListResponse:
    """List generations (newest first) with optional filters.

    Supports filtering by profile, substring search over the generated
    text, and offset/limit pagination.  Each item carries the owning
    profile's name plus its version list and active version ID.

    Args:
        query: Query parameters (filters, pagination)
        db: Database session

    Returns:
        HistoryListResponse with items and total count
    """
    # Base query joins the profile so each row carries its profile name.
    base = (
        db.query(DBGeneration, DBVoiceProfile.name.label('profile_name'))
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
    )

    if query.profile_id:
        base = base.filter(DBGeneration.profile_id == query.profile_id)
    if query.search:
        base = base.filter(DBGeneration.text.like(f"%{query.search}%"))

    # Count before pagination so the client can page correctly.
    total = base.count()

    rows = (
        base.order_by(DBGeneration.created_at.desc())
        .offset(query.offset)
        .limit(query.limit)
        .all()
    )

    items = []
    for generation, profile_name in rows:
        versions, active_version_id = _get_versions_for_generation(generation.id, db)
        items.append(HistoryResponse(
            id=generation.id,
            profile_id=generation.profile_id,
            profile_name=profile_name,
            text=generation.text,
            language=generation.language,
            audio_path=generation.audio_path,
            duration=generation.duration,
            seed=generation.seed,
            instruct=generation.instruct,
            # Older rows may predate these columns; fall back to defaults.
            engine=generation.engine or "qwen",
            model_size=generation.model_size,
            status=generation.status or "completed",
            error=generation.error,
            is_favorited=bool(generation.is_favorited),
            created_at=generation.created_at,
            versions=versions,
            active_version_id=active_version_id,
        ))

    return HistoryListResponse(items=items, total=total)
|
||||
|
||||
|
||||
async def delete_generation(
    generation_id: str,
    db: Session,
) -> bool:
    """Delete a generation, its versions, and its audio file.

    Args:
        generation_id: Generation ID
        db: Database session

    Returns:
        ``True`` if deleted, ``False`` if not found.
    """
    from . import versions as versions_mod

    row = db.query(DBGeneration).filter_by(id=generation_id).first()
    if row is None:
        return False

    # Version files and rows go first so nothing dangles.
    versions_mod.delete_versions_for_generation(generation_id, db)

    # Then the main audio file, if it still exists on disk.
    if row.audio_path:
        resolved = config.resolve_storage_path(row.audio_path)
        if resolved is not None and resolved.exists():
            resolved.unlink()

    db.delete(row)
    db.commit()
    return True
|
||||
|
||||
|
||||
async def delete_failed_generations(db: Session) -> int:
    """
    Delete every generation whose status is 'failed'.

    Used by the "Clear failed" action in the UI so users can tidy up
    history after the model wasn't loaded, the app was closed mid-run,
    or a generation otherwise errored out (see issue #410).

    Returns:
        Number of generations deleted.
    """
    from . import versions as versions_mod

    removed = 0
    for row in db.query(DBGeneration).filter(DBGeneration.status == "failed").all():
        # Clean up version files/rows first.
        versions_mod.delete_versions_for_generation(row.id, db)

        # Remove the main audio file if it somehow made it to disk.
        if row.audio_path:
            resolved = config.resolve_storage_path(row.audio_path)
            if resolved is not None and resolved.exists():
                try:
                    resolved.unlink()
                except OSError:
                    # Best-effort cleanup — don't abort the whole sweep
                    # if a single file can't be removed.
                    pass

        db.delete(row)
        removed += 1

    db.commit()
    return removed
|
||||
|
||||
|
||||
async def delete_generations_by_profile(
    profile_id: str,
    db: Session,
) -> int:
    """
    Delete all generations for a profile, including version records and
    audio files on disk.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        Number of generations deleted
    """
    # Hoisted out of the loop — no reason to re-import per generation.
    from . import versions as versions_mod

    generations = db.query(DBGeneration).filter_by(profile_id=profile_id).all()

    count = 0
    for generation in generations:
        # Delete associated version files and rows first
        versions_mod.delete_versions_for_generation(generation.id, db)

        # Delete audio file.  Guard against a NULL/empty audio_path for
        # consistency with delete_generation().
        if generation.audio_path:
            audio_path = config.resolve_storage_path(generation.audio_path)
            if audio_path is not None and audio_path.exists():
                audio_path.unlink()

        # Delete from database
        db.delete(generation)
        count += 1

    db.commit()

    return count
|
||||
|
||||
|
||||
async def get_generation_stats(db: Session) -> dict:
    """Aggregate generation history statistics.

    Args:
        db: Database session

    Returns:
        Dict with the total generation count, summed audio duration in
        seconds, and a per-profile generation count map.
    """
    from sqlalchemy import func

    total_count = db.query(func.count(DBGeneration.id)).scalar()
    # Coalesce to 0 when there are no rows (SUM over empty set is NULL).
    duration_sum = db.query(func.sum(DBGeneration.duration)).scalar() or 0

    per_profile = (
        db.query(DBGeneration.profile_id, func.count(DBGeneration.id).label('count'))
        .group_by(DBGeneration.profile_id)
        .all()
    )

    return {
        "total_generations": total_count,
        "total_duration_seconds": duration_sum,
        "generations_by_profile": dict(per_profile),
    }
|
||||
686
backend/services/profiles.py
Normal file
686
backend/services/profiles.py
Normal file
@@ -0,0 +1,686 @@
|
||||
"""Voice profile management module."""
|
||||
|
||||
import json as _json
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config
|
||||
from ..database import Generation as DBGeneration, ProfileSample as DBProfileSample, VoiceProfile as DBVoiceProfile
|
||||
from ..models import (
|
||||
EffectConfig,
|
||||
ProfileSampleResponse,
|
||||
VoiceProfileCreate,
|
||||
VoiceProfileResponse,
|
||||
)
|
||||
from ..utils.audio import save_audio, validate_and_load_reference_audio
|
||||
from ..utils.cache import _get_cache_dir, clear_profile_cache
|
||||
from ..utils.images import process_avatar, validate_image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CLONING_ENGINES = {"qwen", "luxtts", "chatterbox", "chatterbox_turbo", "tada"}
|
||||
|
||||
|
||||
def _profile_to_response(
    profile: DBVoiceProfile,
    generation_count: int = 0,
    sample_count: int = 0,
) -> VoiceProfileResponse:
    """Convert a DB profile to a VoiceProfileResponse, deserializing effects_chain.

    Args:
        profile: Database profile row.
        generation_count: Number of generations attributed to the profile.
        sample_count: Number of audio samples attached to the profile.

    Returns:
        API response model. A malformed effects_chain is logged and treated
        as absent rather than failing the whole conversion.
    """
    effects_chain = None
    if profile.effects_chain:
        try:
            raw = _json.loads(profile.effects_chain)
            effects_chain = [EffectConfig(**e) for e in raw]
        except Exception as e:
            # Use the module-level logger with lazy %-args instead of
            # re-importing logging and emitting through the root logger.
            logger.warning(
                "Failed to parse effects_chain for profile %s: %s", profile.id, e
            )

    return VoiceProfileResponse(
        id=profile.id,
        name=profile.name,
        description=profile.description,
        language=profile.language,
        avatar_path=profile.avatar_path,
        effects_chain=effects_chain,
        # getattr defaults guard rows created before these columns existed;
        # voice_type falls back to "cloned".
        voice_type=getattr(profile, "voice_type", None) or "cloned",
        preset_engine=getattr(profile, "preset_engine", None),
        preset_voice_id=getattr(profile, "preset_voice_id", None),
        design_prompt=getattr(profile, "design_prompt", None),
        default_engine=getattr(profile, "default_engine", None),
        generation_count=generation_count,
        sample_count=sample_count,
        created_at=profile.created_at,
        updated_at=profile.updated_at,
    )
|
||||
|
||||
|
||||
def _get_preset_voice_ids(engine: str) -> set[str]:
|
||||
if engine == "kokoro":
|
||||
from ..backends.kokoro_backend import KOKORO_VOICES
|
||||
|
||||
return {voice_id for voice_id, _name, _gender, _lang in KOKORO_VOICES}
|
||||
|
||||
if engine == "qwen_custom_voice":
|
||||
from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES
|
||||
|
||||
return {voice_id for voice_id, _name, _gender, _lang, _desc in QWEN_CUSTOM_VOICES}
|
||||
|
||||
return set()
|
||||
|
||||
|
||||
def _validate_profile_fields(
|
||||
*,
|
||||
voice_type: str,
|
||||
preset_engine: str | None,
|
||||
preset_voice_id: str | None,
|
||||
design_prompt: str | None,
|
||||
default_engine: str | None,
|
||||
) -> str | None:
|
||||
if voice_type == "preset":
|
||||
if not preset_engine or not preset_voice_id:
|
||||
return "Preset profiles require both preset_engine and preset_voice_id"
|
||||
if default_engine and default_engine != preset_engine:
|
||||
return "Preset profiles must use their preset_engine as default_engine"
|
||||
|
||||
available_voice_ids = _get_preset_voice_ids(preset_engine)
|
||||
if available_voice_ids and preset_voice_id not in available_voice_ids:
|
||||
return f"Preset voice '{preset_voice_id}' is not valid for engine '{preset_engine}'"
|
||||
return None
|
||||
|
||||
if voice_type == "designed":
|
||||
if not design_prompt or not design_prompt.strip():
|
||||
return "Designed profiles require a design_prompt"
|
||||
if preset_engine or preset_voice_id:
|
||||
return "Designed profiles cannot set preset_engine or preset_voice_id"
|
||||
return None
|
||||
|
||||
if preset_engine or preset_voice_id:
|
||||
return "Cloned profiles cannot set preset_engine or preset_voice_id"
|
||||
if design_prompt:
|
||||
return "Cloned profiles cannot set design_prompt"
|
||||
if default_engine and default_engine not in CLONING_ENGINES:
|
||||
return f"Cloned profiles cannot use default engine '{default_engine}'"
|
||||
return None
|
||||
|
||||
|
||||
def validate_profile_engine(profile, engine: str) -> None:
    """Raise ValueError when *profile* cannot be synthesized with *engine*.

    Preset profiles are locked to their preset_engine; designed profiles
    must carry a design_prompt; cloned profiles require a cloning-capable
    engine. Returns None when the combination is valid.
    """
    kind = getattr(profile, "voice_type", None) or "cloned"

    if kind == "preset":
        locked_engine = getattr(profile, "preset_engine", None)
        locked_voice = getattr(profile, "preset_voice_id", None)
        if not (locked_engine and locked_voice):
            raise ValueError(f"Preset profile {profile.id} is missing preset engine metadata")
        if locked_engine != engine:
            raise ValueError(
                f"Preset profile {profile.id} only supports engine '{locked_engine}', not '{engine}'"
            )
    elif kind == "designed":
        prompt = getattr(profile, "design_prompt", None)
        if not (prompt and prompt.strip()):
            raise ValueError(f"Designed profile {profile.id} is missing design_prompt")
    elif engine not in CLONING_ENGINES:
        raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")
|
||||
|
||||
|
||||
async def create_profile(
    data: VoiceProfileCreate,
    db: Session,
) -> VoiceProfileResponse:
    """
    Create a new voice profile.

    Args:
        data: Profile creation data
        db: Database session

    Returns:
        Created profile

    Raises:
        ValueError: If a profile with the same name already exists, or if
            the voice-type-specific fields are inconsistent
    """
    # Profile names are user-facing identifiers; enforce uniqueness up front.
    existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
    if existing_profile:
        raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")

    # Auto-set default_engine for preset profiles
    default_engine = data.default_engine
    voice_type = data.voice_type or "cloned"
    if voice_type == "preset" and data.preset_engine and not default_engine:
        default_engine = data.preset_engine

    # Reject inconsistent voice-type/preset/design/engine combinations
    # before touching the database.
    validation_error = _validate_profile_fields(
        voice_type=voice_type,
        preset_engine=data.preset_engine,
        preset_voice_id=data.preset_voice_id,
        design_prompt=data.design_prompt,
        default_engine=default_engine,
    )
    if validation_error:
        raise ValueError(validation_error)

    db_profile = DBVoiceProfile(
        id=str(uuid.uuid4()),
        name=data.name,
        description=data.description,
        language=data.language,
        voice_type=voice_type,
        preset_engine=data.preset_engine,
        preset_voice_id=data.preset_voice_id,
        design_prompt=data.design_prompt,
        default_engine=default_engine,
        created_at=datetime.utcnow(),
        updated_at=datetime.utcnow(),
    )

    db.add(db_profile)
    db.commit()
    db.refresh(db_profile)

    # Pre-create the on-disk directory that will hold samples and the avatar.
    profile_dir = config.get_profiles_dir() / db_profile.id
    profile_dir.mkdir(parents=True, exist_ok=True)

    return _profile_to_response(db_profile)
|
||||
|
||||
|
||||
async def add_profile_sample(
    profile_id: str,
    audio_path: str,
    reference_text: str,
    db: Session,
) -> ProfileSampleResponse:
    """
    Add a sample to a voice profile.

    Args:
        profile_id: Profile ID
        audio_path: Path to temporary audio file
        reference_text: Transcript of audio
        db: Database session

    Returns:
        Created sample

    Raises:
        ValueError: If the profile does not exist or the audio fails validation
    """
    import asyncio

    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile {profile_id} not found")

    # Validate and load audio in a single pass, off the event loop
    is_valid, error_msg, audio, sr = await asyncio.to_thread(
        validate_and_load_reference_audio, audio_path
    )
    if not is_valid:
        raise ValueError(f"Invalid reference audio: {error_msg}")

    # Persist the audio as a WAV under the profile's own directory.
    sample_id = str(uuid.uuid4())
    profile_dir = config.get_profiles_dir() / profile_id
    profile_dir.mkdir(parents=True, exist_ok=True)

    dest_path = profile_dir / f"{sample_id}.wav"
    # save_audio does blocking file I/O — run it in a worker thread too.
    await asyncio.to_thread(save_audio, audio, str(dest_path), sr)

    db_sample = DBProfileSample(
        id=sample_id,
        profile_id=profile_id,
        # Store a storage-relative path rather than an absolute one.
        audio_path=config.to_storage_path(dest_path),
        reference_text=reference_text,
    )

    db.add(db_sample)

    # Adding a sample counts as modifying the parent profile.
    profile.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(db_sample)

    # Invalidate combined audio cache for this profile
    # Since a new sample was added, any cached combined audio is now stale
    clear_profile_cache(profile_id)

    return ProfileSampleResponse.model_validate(db_sample)
|
||||
|
||||
|
||||
async def get_profile(
    profile_id: str,
    db: Session,
) -> VoiceProfileResponse | None:
    """
    Get a voice profile by ID.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        Profile or None if not found
    """
    record = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    return _profile_to_response(record) if record else None
|
||||
|
||||
|
||||
async def get_profile_samples(
    profile_id: str,
    db: Session,
) -> list[ProfileSampleResponse]:
    """
    Get all samples for a profile.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        List of samples (empty when the profile has none)
    """
    rows = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
    return [ProfileSampleResponse.model_validate(row) for row in rows]
|
||||
|
||||
|
||||
async def list_profiles(db: Session) -> list[VoiceProfileResponse]:
    """
    List all voice profiles with generation and sample counts.

    Args:
        db: Database session

    Returns:
        Profiles ordered newest-first, each annotated with its
        generation and sample counts.
    """
    profiles = db.query(DBVoiceProfile).order_by(DBVoiceProfile.created_at.desc()).all()

    if not profiles:
        return []

    # One grouped query per counter avoids an N+1 pattern over the profiles.
    generation_totals = dict(
        db.query(DBGeneration.profile_id, func.count(DBGeneration.id))
        .group_by(DBGeneration.profile_id)
        .all()
    )
    sample_totals = dict(
        db.query(DBProfileSample.profile_id, func.count(DBProfileSample.id))
        .group_by(DBProfileSample.profile_id)
        .all()
    )

    return [
        _profile_to_response(
            profile,
            generation_count=generation_totals.get(profile.id, 0),
            sample_count=sample_totals.get(profile.id, 0),
        )
        for profile in profiles
    ]
|
||||
|
||||
|
||||
async def update_profile(
    profile_id: str,
    data: VoiceProfileCreate,
    db: Session,
) -> VoiceProfileResponse | None:
    """
    Update a voice profile.

    Only name, description, language, and default_engine are updatable;
    the voice_type / preset / design fields are fixed at creation time.

    Args:
        profile_id: Profile ID
        data: Updated profile data
        db: Database session

    Returns:
        Updated profile or None if not found

    Raises:
        ValueError: If a profile with the same name already exists (different profile),
            or if the new default_engine conflicts with the profile's voice_type
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        return None

    # Only re-check name uniqueness when the name actually changes, so a
    # no-op rename does not collide with this profile's own row.
    if profile.name != data.name:
        existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
        if existing_profile:
            raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")

    # Validate using the *stored* voice-type metadata combined with the
    # incoming default_engine (falling back to the stored one when omitted).
    voice_type = getattr(profile, "voice_type", None) or "cloned"
    preset_engine = getattr(profile, "preset_engine", None)
    preset_voice_id = getattr(profile, "preset_voice_id", None)
    design_prompt = getattr(profile, "design_prompt", None)
    default_engine = data.default_engine if data.default_engine is not None else getattr(profile, "default_engine", None)

    validation_error = _validate_profile_fields(
        voice_type=voice_type,
        preset_engine=preset_engine,
        preset_voice_id=preset_voice_id,
        design_prompt=design_prompt,
        default_engine=default_engine,
    )
    if validation_error:
        raise ValueError(validation_error)

    profile.name = data.name
    profile.description = data.description
    profile.language = data.language
    if data.default_engine is not None:
        profile.default_engine = data.default_engine or None  # empty string → NULL
    profile.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(profile)

    return _profile_to_response(profile)
|
||||
|
||||
|
||||
async def delete_profile(
    profile_id: str,
    db: Session,
) -> bool:
    """
    Delete a voice profile and all associated data.

    Removes the profile row, its sample rows, the on-disk profile directory
    (samples and avatar), and any cached combined audio.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        return False

    # Delete child sample rows explicitly before the profile row.
    db.query(DBProfileSample).filter_by(profile_id=profile_id).delete()

    db.delete(profile)
    db.commit()

    # Only touch the filesystem after the DB commit has succeeded.
    profile_dir = config.get_profiles_dir() / profile_id
    if profile_dir.exists():
        shutil.rmtree(profile_dir)

    # Clean up combined audio cache files for this profile
    clear_profile_cache(profile_id)

    return True
|
||||
|
||||
|
||||
async def delete_profile_sample(
    sample_id: str,
    db: Session,
) -> bool:
    """
    Delete a profile sample.

    Args:
        sample_id: Sample ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    record = db.query(DBProfileSample).filter_by(id=sample_id).first()
    if record is None:
        return False

    # Remember the owning profile before the row disappears, so the
    # cache can be invalidated afterwards.
    owner_id = record.profile_id

    # Remove the audio file from disk, then drop the row.
    file_path = config.resolve_storage_path(record.audio_path)
    if file_path is not None and file_path.exists():
        file_path.unlink()

    db.delete(record)
    db.commit()

    # The sample set changed, so any cached combined audio is now stale.
    clear_profile_cache(owner_id)

    return True
|
||||
|
||||
|
||||
async def update_profile_sample(
    sample_id: str,
    reference_text: str,
    db: Session,
) -> ProfileSampleResponse | None:
    """
    Update a profile sample's reference text.

    Args:
        sample_id: Sample ID
        reference_text: Updated reference text
        db: Database session

    Returns:
        Updated sample or None if not found
    """
    record = db.query(DBProfileSample).filter_by(id=sample_id).first()
    if record is None:
        return None

    # Capture the owning profile for cache invalidation below.
    owner_id = record.profile_id

    record.reference_text = reference_text
    db.commit()
    db.refresh(record)

    # The transcript feeds into cache keys and combined prompt text, so
    # any cached combined audio for the owning profile is now stale.
    clear_profile_cache(owner_id)

    return ProfileSampleResponse.model_validate(record)
|
||||
|
||||
|
||||
async def create_voice_prompt_for_profile(
    profile_id: str,
    db: Session,
    use_cache: bool = True,
    engine: str = "qwen",
) -> dict:
    """
    Create a voice prompt from a profile.

    For cloned profiles: combines all audio samples into a voice prompt.
    For preset profiles: returns the engine-specific preset voice reference.
    For designed profiles: returns the text design prompt (future).

    Args:
        profile_id: Profile ID
        db: Database session
        use_cache: Whether to use cached prompts
        engine: TTS engine to create prompt for

    Returns:
        Voice prompt dictionary

    Raises:
        ValueError: If the profile is missing, has no samples, lacks required
            metadata, or is incompatible with the requested engine
    """
    from ..backends import get_tts_backend_for_engine

    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile not found: {profile_id}")

    voice_type = getattr(profile, "voice_type", None) or "cloned"
    validate_profile_engine(profile, engine)

    # ── Preset profiles: return engine-specific voice reference ──
    if voice_type == "preset":
        if not profile.preset_engine or not profile.preset_voice_id:
            raise ValueError(f"Preset profile {profile_id} is missing preset engine metadata")
        if profile.preset_engine != engine:
            raise ValueError(
                f"Preset profile {profile_id} only supports engine '{profile.preset_engine}', not '{engine}'"
            )
        return {
            "voice_type": "preset",
            "preset_engine": profile.preset_engine,
            "preset_voice_id": profile.preset_voice_id,
        }

    # ── Designed profiles: return text description (future) ──
    if voice_type == "designed":
        if not profile.design_prompt or not profile.design_prompt.strip():
            raise ValueError(f"Designed profile {profile_id} is missing design_prompt")
        return {
            "voice_type": "designed",
            "design_prompt": profile.design_prompt,
        }

    if engine not in CLONING_ENGINES:
        raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")

    # ── Cloned profiles: create from audio samples ──
    samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()

    if not samples:
        raise ValueError(f"No samples found for profile {profile_id}")

    tts_model = get_tts_backend_for_engine(engine)

    # Single sample: build the prompt directly from that one recording.
    if len(samples) == 1:
        sample = samples[0]
        sample_audio_path = config.resolve_storage_path(sample.audio_path)
        if sample_audio_path is None:
            raise ValueError(f"Sample audio not found for profile {profile_id}")
        voice_prompt, _ = await tts_model.create_voice_prompt(
            str(sample_audio_path),
            sample.reference_text,
            use_cache=use_cache,
        )
        return voice_prompt

    # Multiple samples: resolve every sample path up front so a missing
    # file fails fast before any backend work.
    audio_paths = []
    for sample in samples:
        sample_audio_path = config.resolve_storage_path(sample.audio_path)
        if sample_audio_path is None:
            raise ValueError(f"Sample audio not found for profile {profile_id}")
        audio_paths.append(str(sample_audio_path))
    reference_texts = [s.reference_text for s in samples]

    combined_audio, combined_text = await tts_model.combine_voice_prompts(
        audio_paths,
        reference_texts,
    )

    # Save combined audio to cache directory (persistent)
    # Create a hash of sample IDs to identify this specific combination
    import hashlib

    # Sorted so the hash is stable regardless of query ordering.
    sample_ids_str = "-".join(sorted([s.id for s in samples]))
    combination_hash = hashlib.md5(sample_ids_str.encode()).hexdigest()[:12]

    cache_dir = _get_cache_dir()
    cache_dir.mkdir(parents=True, exist_ok=True)
    combined_path = cache_dir / f"combined_{profile_id}_{combination_hash}.wav"

    # NOTE(review): assumes the combined audio is 24 kHz — confirm against
    # the backend's combine_voice_prompts output rate.
    save_audio(combined_audio, str(combined_path), 24000)

    voice_prompt, _ = await tts_model.create_voice_prompt(
        str(combined_path),
        combined_text,
        use_cache=use_cache,
    )
    return voice_prompt
|
||||
|
||||
|
||||
async def upload_avatar(
    profile_id: str,
    image_path: str,
    db: Session,
) -> VoiceProfileResponse:
    """
    Upload and process avatar image for a profile.

    Args:
        profile_id: Profile ID
        image_path: Path to uploaded image file
        db: Database session

    Returns:
        Updated profile

    Raises:
        ValueError: If the profile does not exist or the image is invalid
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile {profile_id} not found")

    is_valid, error_msg = validate_image(image_path)
    if not is_valid:
        raise ValueError(error_msg)

    # Replace semantics: remove the previous avatar file (which may have a
    # different extension) before saving a new one.
    if profile.avatar_path:
        old_avatar = config.resolve_storage_path(profile.avatar_path)
        if old_avatar is not None and old_avatar.exists():
            old_avatar.unlink()

    # Determine file extension from uploaded file
    from PIL import Image

    with Image.open(image_path) as img:
        # Normalize JPEG variants (MPO is multi-picture format from some cameras)
        img_format = img.format
        if img_format in ("MPO", "JPG"):
            img_format = "JPEG"

    # Unrecognized formats fall back to .png.
    ext_map = {"PNG": ".png", "JPEG": ".jpg", "WEBP": ".webp"}
    ext = ext_map.get(img_format, ".png")

    profile_dir = config.get_profiles_dir() / profile_id
    profile_dir.mkdir(parents=True, exist_ok=True)
    output_path = profile_dir / f"avatar{ext}"

    process_avatar(image_path, str(output_path))

    profile.avatar_path = config.to_storage_path(output_path)
    profile.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(profile)

    return _profile_to_response(profile)
|
||||
|
||||
|
||||
async def delete_avatar(
    profile_id: str,
    db: Session,
) -> bool:
    """
    Delete avatar image for a profile.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        True if deleted, False if not found or no avatar
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if profile is None or not profile.avatar_path:
        return False

    # Remove the image file from disk, then clear the DB reference.
    stored = config.resolve_storage_path(profile.avatar_path)
    if stored is not None and stored.exists():
        stored.unlink()

    profile.avatar_path = None
    profile.updated_at = datetime.utcnow()

    db.commit()

    return True
|
||||
925
backend/services/stories.py
Normal file
925
backend/services/stories.py
Normal file
@@ -0,0 +1,925 @@
|
||||
"""
|
||||
Story management module.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from .. import config
|
||||
from ..models import (
|
||||
StoryCreate,
|
||||
StoryResponse,
|
||||
StoryDetailResponse,
|
||||
StoryItemDetail,
|
||||
StoryItemCreate,
|
||||
StoryItemBatchUpdate,
|
||||
StoryItemMove,
|
||||
StoryItemTrim,
|
||||
StoryItemSplit,
|
||||
StoryItemVersionUpdate,
|
||||
)
|
||||
from ..database import (
|
||||
Story as DBStory,
|
||||
StoryItem as DBStoryItem,
|
||||
Generation as DBGeneration,
|
||||
VoiceProfile as DBVoiceProfile,
|
||||
)
|
||||
from .history import _get_versions_for_generation
|
||||
from ..utils.audio import load_audio, save_audio
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _build_item_detail(
    item: DBStoryItem,
    generation: DBGeneration,
    profile_name: str,
    db: Session,
) -> StoryItemDetail:
    """Build a StoryItemDetail with version info from a story item and its generation.

    Args:
        item: Story item row (timeline placement, trim, optional version pin).
        generation: The generation this item points at.
        profile_name: Display name of the generation's voice profile.
        db: Database session (used to look up the generation's versions).

    Returns:
        Fully populated detail model for API responses.
    """
    versions, active_version_id = _get_versions_for_generation(generation.id, db)

    # Resolve the audio path: if version_id is set, use that version's audio
    audio_path = generation.audio_path
    if item.version_id and versions:
        for v in versions:
            if v.id == item.version_id:
                audio_path = v.audio_path
                break

    return StoryItemDetail(
        id=item.id,
        story_id=item.story_id,
        generation_id=item.generation_id,
        # getattr defaults guard rows created before these columns existed.
        version_id=getattr(item, "version_id", None),
        start_time_ms=item.start_time_ms,
        track=item.track,
        trim_start_ms=getattr(item, "trim_start_ms", 0),
        trim_end_ms=getattr(item, "trim_end_ms", 0),
        created_at=item.created_at,
        profile_id=generation.profile_id,
        profile_name=profile_name,
        text=generation.text,
        language=generation.language,
        audio_path=audio_path,
        duration=generation.duration,
        seed=generation.seed,
        instruct=generation.instruct,
        generation_created_at=generation.created_at,
        versions=versions,
        active_version_id=active_version_id,
    )
|
||||
|
||||
|
||||
async def create_story(
    data: StoryCreate,
    db: Session,
) -> StoryResponse:
    """
    Create a new story.

    Args:
        data: Story creation data
        db: Database session

    Returns:
        Created story (item_count is 0 — a new story has no items yet)
    """
    db_story = DBStory(
        id=str(uuid.uuid4()),
        name=data.name,
        description=data.description,
        created_at=datetime.utcnow(),
        updated_at=datetime.utcnow(),
    )

    db.add(db_story)
    db.commit()
    db.refresh(db_story)

    response = StoryResponse.model_validate(db_story)
    # A just-created story cannot have any items; skip the redundant
    # COUNT query the original ran here.
    response.item_count = 0
    return response
|
||||
|
||||
|
||||
async def list_stories(
    db: Session,
) -> List[StoryResponse]:
    """
    List all stories.

    Args:
        db: Database session

    Returns:
        List of stories with item counts, most recently updated first
    """
    stories = db.query(DBStory).order_by(DBStory.updated_at.desc()).all()

    if not stories:
        return []

    # Batch the item counts into a single grouped query instead of one
    # COUNT per story (N+1), mirroring the batching used for profiles.
    counts = dict(
        db.query(DBStoryItem.story_id, func.count(DBStoryItem.id))
        .group_by(DBStoryItem.story_id)
        .all()
    )

    result = []
    for story in stories:
        response = StoryResponse.model_validate(story)
        response.item_count = counts.get(story.id, 0)
        result.append(response)

    return result
|
||||
|
||||
|
||||
async def get_story(
    story_id: str,
    db: Session,
) -> Optional[StoryDetailResponse]:
    """
    Get a story with all its items.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        Story with items (ordered by timeline position) or None if not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # Join each item with its generation and the profile name so the whole
    # timeline is fetched in one query, ordered by playback position.
    items = (
        db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
        .filter(DBStoryItem.story_id == story_id)
        .order_by(DBStoryItem.start_time_ms)
        .all()
    )

    item_details = []
    for item, generation, profile_name in items:
        item_details.append(_build_item_detail(item, generation, profile_name, db))

    response = StoryDetailResponse.model_validate(story)
    response.items = item_details
    return response
|
||||
|
||||
|
||||
async def update_story(
    story_id: str,
    data: StoryCreate,
    db: Session,
) -> Optional[StoryResponse]:
    """
    Update a story's name and description.

    Args:
        story_id: Story ID
        data: Update data
        db: Database session

    Returns:
        Updated story (with item count) or None if not found
    """
    record = db.query(DBStory).filter_by(id=story_id).first()
    if record is None:
        return None

    record.name = data.name
    record.description = data.description
    record.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(record)

    total_items = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == record.id).scalar()

    payload = StoryResponse.model_validate(record)
    payload.item_count = total_items
    return payload
|
||||
|
||||
|
||||
async def delete_story(
    story_id: str,
    db: Session,
) -> bool:
    """
    Delete a story and all its items.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    record = db.query(DBStory).filter_by(id=story_id).first()
    if record is None:
        return False

    # Drop the timeline items first, then the story row, in one commit.
    db.query(DBStoryItem).filter_by(story_id=story_id).delete()
    db.delete(record)
    db.commit()

    return True
|
||||
|
||||
|
||||
async def add_item_to_story(
    story_id: str,
    data: StoryItemCreate,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Add a generation to a story.

    The operation is idempotent: if the generation is already in the story,
    the existing item is returned unchanged. When no explicit start_time_ms
    is given, the item is appended after the last item on the target track
    with a 200 ms gap.

    Args:
        story_id: Story ID
        data: Item creation data
        db: Database session

    Returns:
        Created item detail or None if story/generation not found
    """
    # Verify story exists
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # Verify generation exists
    generation = db.query(DBGeneration).filter_by(id=data.generation_id).first()
    if not generation:
        return None

    # Check if generation is already in story
    existing = db.query(DBStoryItem).filter_by(story_id=story_id, generation_id=data.generation_id).first()
    if existing:
        # Return existing item
        profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
        return _build_item_detail(existing, generation, profile.name if profile else "Unknown", db)

    # Get track from data or default to 0
    track = data.track if data.track is not None else 0

    # Calculate start_time_ms if not provided
    if data.start_time_ms is not None:
        start_time_ms = data.start_time_ms
    else:
        # Find where the target track currently ends so the new item can
        # be appended after it.
        existing_items = (
            db.query(DBStoryItem, DBGeneration)
            .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
            .filter(
                DBStoryItem.story_id == story_id,
                DBStoryItem.track == track,
            )
            .all()
        )

        if not existing_items:
            start_time_ms = 0
        else:
            max_end_time_ms = 0
            for item, gen in existing_items:
                # Item end = start offset + full generation duration in ms.
                item_end_ms = item.start_time_ms + int(gen.duration * 1000)
                max_end_time_ms = max(max_end_time_ms, item_end_ms)

            # Add 200ms gap after the last item
            start_time_ms = max_end_time_ms + 200

    # Create item
    item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=data.generation_id,
        start_time_ms=start_time_ms,
        track=track,
        created_at=datetime.utcnow(),
    )

    db.add(item)

    # Update story updated_at
    story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)

    # Get profile name
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()

    return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
|
||||
|
||||
|
||||
async def move_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemMove,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Move a story item (update position and/or track).

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: New position and track data
        db: Database session

    Returns:
        Updated item detail or None if the item or its generation is missing
    """
    record = (
        db.query(DBStoryItem)
        .filter_by(id=item_id, story_id=story_id)
        .first()
    )
    if record is None:
        return None

    generation = db.query(DBGeneration).filter_by(id=record.generation_id).first()
    if generation is None:
        return None

    # Apply the new timeline placement.
    record.start_time_ms = data.start_time_ms
    record.track = data.track

    # Moving an item counts as modifying the story.
    parent = db.query(DBStory).filter_by(id=story_id).first()
    if parent:
        parent.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(record)

    owner = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
    return _build_item_detail(record, generation, owner.name if owner else "Unknown", db)
|
||||
|
||||
|
||||
async def remove_item_from_story(
    story_id: str,
    item_id: str,
    db: Session,
) -> bool:
    """Delete a story item from its story.

    Also bumps the parent story's ``updated_at`` timestamp.

    Args:
        story_id: Story ID
        item_id: Story item ID to remove
        db: Database session

    Returns:
        True if removed, False if not found
    """
    target = db.query(DBStoryItem).filter_by(id=item_id, story_id=story_id).first()
    if target is None:
        return False

    db.delete(target)

    # Touch the parent story so clients see it changed.
    parent = db.query(DBStory).filter_by(id=story_id).first()
    if parent is not None:
        parent.updated_at = datetime.utcnow()

    db.commit()
    return True
|
||||
|
||||
|
||||
async def trim_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemTrim,
    db: Session,
) -> Optional[StoryItemDetail]:
    """Apply new trim values (``trim_start_ms``/``trim_end_ms``) to a clip.

    Rejects trims that would leave the clip with zero or negative audible
    duration.

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: Trim data (trim_start_ms, trim_end_ms)
        db: Database session

    Returns:
        Updated item detail or None if not found or the trim is invalid
    """
    clip = db.query(DBStoryItem).filter_by(id=item_id, story_id=story_id).first()
    if clip is None:
        return None

    gen = db.query(DBGeneration).filter_by(id=clip.generation_id).first()
    if gen is None:
        return None

    # The combined trims must leave some audible audio.
    total_ms = int(gen.duration * 1000)
    if data.trim_start_ms + data.trim_end_ms >= total_ms:
        return None  # Invalid trim - would result in zero or negative duration

    clip.trim_start_ms = data.trim_start_ms
    clip.trim_end_ms = data.trim_end_ms

    parent = db.query(DBStory).filter_by(id=story_id).first()
    if parent is not None:
        parent.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(clip)

    voice = db.query(DBVoiceProfile).filter_by(id=gen.profile_id).first()
    return _build_item_detail(clip, gen, voice.name if voice else "Unknown", db)
|
||||
|
||||
|
||||
async def split_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemSplit,
    db: Session,
) -> Optional[List[StoryItemDetail]]:
    """
    Split a story item at a given time, creating two clips.

    The original clip keeps the audio before the split point (by growing
    its ``trim_end_ms``); a new clip referencing the same generation
    carries the audio after it (via ``trim_start_ms``). The new clip is
    placed on the timeline exactly at the split point, on the same track.

    Args:
        story_id: Story ID
        item_id: Story item ID to split
        data: Split data (split_time_ms - time within clip to split at)
        db: Database session

    Returns:
        List of two updated item details (original and new) or None if not found/invalid
    """
    # Get the item with a row lock to prevent concurrent splits on the
    # same clip (e.g. from rapid double-clicks racing each other).
    item = (
        db.query(DBStoryItem)
        .filter_by(
            id=item_id,
            story_id=story_id,
        )
        .with_for_update()
        .first()
    )
    if not item:
        return None

    # Get the generation
    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    # Calculate effective duration and validate split point
    current_trim_start = getattr(item, "trim_start_ms", 0)
    current_trim_end = getattr(item, "trim_end_ms", 0)
    original_duration_ms = int(generation.duration * 1000)
    # Audible length of the clip once existing trims are applied.
    effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end

    # Validate split_time_ms is within the effective duration
    # (a split at 0 or at the very end would produce an empty clip).
    if data.split_time_ms <= 0 or data.split_time_ms >= effective_duration_ms:
        return None  # Invalid split point

    # Calculate the absolute time in the original audio where we're splitting
    absolute_split_ms = current_trim_start + data.split_time_ms

    # Update original clip: trim from the end
    item.trim_end_ms = original_duration_ms - absolute_split_ms

    # Create new clip: starts after the split, trimmed from the start
    new_item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=item.generation_id,  # Same generation, different trim
        version_id=getattr(item, "version_id", None),  # Preserve pinned version
        start_time_ms=item.start_time_ms + data.split_time_ms,
        track=item.track,
        trim_start_ms=absolute_split_ms,
        trim_end_ms=current_trim_end,
        created_at=datetime.utcnow(),
    )

    db.add(new_item)

    # Update story updated_at
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)
    db.refresh(new_item)

    # Get profile name
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
    profile_name = profile.name if profile else "Unknown"

    return [
        _build_item_detail(item, generation, profile_name, db),
        _build_item_detail(new_item, generation, profile_name, db),
    ]
|
||||
|
||||
|
||||
async def duplicate_story_item(
    story_id: str,
    item_id: str,
    db: Session,
) -> Optional[StoryItemDetail]:
    """Clone a story item, placing the copy just after the original.

    The copy references the same generation (and pinned version, if any),
    preserves the original's trim values, and is positioned 200 ms after
    the original clip's effective (trimmed) end, on the same track.

    Args:
        story_id: Story ID
        item_id: Story item ID to duplicate
        db: Database session

    Returns:
        New item detail or None if not found
    """
    source = db.query(DBStoryItem).filter_by(id=item_id, story_id=story_id).first()
    if source is None:
        return None

    gen = db.query(DBGeneration).filter_by(id=source.generation_id).first()
    if gen is None:
        return None

    # Effective (trimmed) length of the source clip, in milliseconds.
    trim_start = getattr(source, "trim_start_ms", 0)
    trim_end = getattr(source, "trim_end_ms", 0)
    effective_ms = int(gen.duration * 1000) - trim_start - trim_end

    copy = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=source.generation_id,  # Same generation as original
        version_id=getattr(source, "version_id", None),  # Preserve pinned version
        start_time_ms=source.start_time_ms + effective_ms + 200,  # 200ms gap
        track=source.track,
        trim_start_ms=trim_start,
        trim_end_ms=trim_end,
        created_at=datetime.utcnow(),
    )
    db.add(copy)

    parent = db.query(DBStory).filter_by(id=story_id).first()
    if parent is not None:
        parent.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(copy)

    voice = db.query(DBVoiceProfile).filter_by(id=gen.profile_id).first()
    return _build_item_detail(copy, gen, voice.name if voice else "Unknown", db)
|
||||
|
||||
|
||||
async def update_story_item_times(
    story_id: str,
    data: StoryItemBatchUpdate,
    db: Session,
) -> bool:
    """
    Update story item timecodes in one batch.

    The whole batch is validated before any item is mutated, so a bad
    generation ID rejects the request cleanly instead of leaving
    half-applied, uncommitted changes in the session (the previous
    implementation mutated items while still validating).

    Args:
        story_id: Story ID
        data: Batch update data with timecodes
        db: Database session

    Returns:
        True if updated, False if story not found or invalid
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return False

    # Get all items for this story
    items = db.query(DBStoryItem).filter_by(story_id=story_id).all()
    # NOTE(review): keyed by generation_id — if split/duplicate leave
    # multiple items sharing one generation in a story, only one of them
    # is addressable here. Confirm callers only use this endpoint on
    # stories without duplicated clips.
    item_map = {item.generation_id: item for item in items}

    # Validate the entire batch first: reject before touching any item.
    if any(update.generation_id not in item_map for update in data.updates):
        return False

    # All IDs verified — now apply the new timecodes.
    for update in data.updates:
        item_map[update.generation_id].start_time_ms = update.start_time_ms

    # Update story updated_at
    story.updated_at = datetime.utcnow()

    db.commit()
    return True
|
||||
|
||||
|
||||
async def reorder_story_items(
    story_id: str,
    generation_ids: List[str],
    db: Session,
    gap_ms: int = 200,
) -> Optional[List[StoryItemDetail]]:
    """
    Reorder story items and recalculate timecodes.

    Items are laid out sequentially in the given order starting at t=0,
    each separated from the previous one by ``gap_ms`` of silence.

    Args:
        story_id: Story ID
        generation_ids: List of generation IDs in the desired order
        db: Database session
        gap_ms: Gap in milliseconds between items (default 200ms)

    Returns:
        Updated list of story items with new timecodes, or None if invalid
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # Get all items for this story with their generation data
    items_with_gen = (
        db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
        .filter(DBStoryItem.story_id == story_id)
        .all()
    )

    # Create maps for quick lookup
    # NOTE(review): keyed by generation_id — if split/duplicate ever leave
    # several items sharing one generation, entries collapse here and only
    # one item per generation gets repositioned. Confirm this endpoint is
    # only used on stories without duplicated clips.
    item_map = {item.generation_id: (item, gen, profile_name) for item, gen, profile_name in items_with_gen}

    # Verify all generation IDs belong to this story
    if set(generation_ids) != set(item_map.keys()):
        return None

    # Recalculate timecodes based on new order
    current_time_ms = 0
    updated_items = []

    for gen_id in generation_ids:
        item, generation, profile_name = item_map[gen_id]

        # Update the item's start time
        item.start_time_ms = current_time_ms

        # Calculate the duration in ms
        # NOTE(review): uses the generation's full duration and ignores any
        # trim_start_ms/trim_end_ms on the item, so trimmed clips are
        # spaced as if untrimmed — confirm whether that is intentional.
        duration_ms = int(generation.duration * 1000)

        # Move to next position (current end + gap)
        current_time_ms += duration_ms + gap_ms

        # Build the response item
        updated_items.append(_build_item_detail(item, generation, profile_name, db))

    # Update story updated_at
    story.updated_at = datetime.utcnow()

    db.commit()
    return updated_items
|
||||
|
||||
|
||||
async def set_story_item_version(
    story_id: str,
    item_id: str,
    data: StoryItemVersionUpdate,
    db: Session,
) -> Optional[StoryItemDetail]:
    """Pin a story item to a specific generation version (or unpin it).

    A non-null ``version_id`` must belong to the item's own generation;
    a null value restores the generation's default version.

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: Version update data (version_id or null for default)
        db: Database session

    Returns:
        Updated item detail or None if not found / version invalid
    """
    clip = db.query(DBStoryItem).filter_by(id=item_id, story_id=story_id).first()
    if clip is None:
        return None

    gen = db.query(DBGeneration).filter_by(id=clip.generation_id).first()
    if gen is None:
        return None

    if data.version_id:
        # A pinned version must belong to this item's generation.
        from ..database import GenerationVersion as DBGenerationVersion

        pinned = (
            db.query(DBGenerationVersion)
            .filter_by(id=data.version_id, generation_id=clip.generation_id)
            .first()
        )
        if pinned is None:
            return None

    clip.version_id = data.version_id

    # Bump the parent story's modification timestamp.
    parent = db.query(DBStory).filter_by(id=story_id).first()
    if parent is not None:
        parent.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(clip)

    voice = db.query(DBVoiceProfile).filter_by(id=gen.profile_id).first()
    return _build_item_detail(clip, gen, voice.name if voice else "Unknown", db)
|
||||
|
||||
|
||||
async def export_story_audio(
    story_id: str,
    db: Session,
) -> Optional[bytes]:
    """
    Export story as single mixed audio file with timecode-based mixing.

    Each clip's audio is loaded (honoring a pinned version when set),
    trimmed per its trim_start_ms/trim_end_ms, summed into a silent
    buffer at its timeline offset, peak-normalized if the mix exceeds
    full scale, and written out as a WAV file.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        Audio file bytes or None if story not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # Get all items ordered by start_time_ms
    items = (
        db.query(DBStoryItem, DBGeneration)
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .filter(DBStoryItem.story_id == story_id)
        .order_by(DBStoryItem.start_time_ms)
        .all()
    )

    if not items:
        return None

    # Load all audio files and calculate total duration
    audio_data = []
    sample_rate = 24000  # Default sample rate

    for item, generation in items:
        # Resolve audio path: use pinned version if set, otherwise generation default
        resolved_audio_path = generation.audio_path
        if getattr(item, "version_id", None):
            from ..database import GenerationVersion as DBGenerationVersion

            version = db.query(DBGenerationVersion).filter_by(id=item.version_id).first()
            if version:
                resolved_audio_path = version.audio_path

        audio_path = config.resolve_storage_path(resolved_audio_path)
        if audio_path is None or not audio_path.exists():
            # Missing files are skipped rather than failing the export.
            continue

        try:
            audio, sr = load_audio(str(audio_path), sample_rate=sample_rate)
            # NOTE(review): reassigned on every iteration, not just the
            # first — presumably load_audio resamples to the requested
            # rate so all clips agree; confirm against load_audio.
            sample_rate = sr  # Use actual sample rate from first file

            # Get trim values
            trim_start_ms = getattr(item, "trim_start_ms", 0)
            trim_end_ms = getattr(item, "trim_end_ms", 0)

            # Calculate effective duration
            original_duration_ms = int(generation.duration * 1000)
            effective_duration_ms = original_duration_ms - trim_start_ms - trim_end_ms

            # Slice audio based on trim values
            trim_start_sample = int((trim_start_ms / 1000.0) * sample_rate)
            trim_end_sample = int((trim_end_ms / 1000.0) * sample_rate)

            # Extract the trimmed portion
            # (negative-index slicing drops trim_end_sample samples from the tail)
            if trim_end_ms > 0:
                trimmed_audio = (
                    audio[trim_start_sample:-trim_end_sample] if trim_end_sample > 0 else audio[trim_start_sample:]
                )
            else:
                trimmed_audio = audio[trim_start_sample:]

            # Store audio with its timecode info
            start_time_ms = item.start_time_ms

            audio_data.append(
                {
                    "audio": trimmed_audio,
                    "start_time_ms": start_time_ms,
                    "duration_ms": effective_duration_ms,
                }
            )
        except Exception:
            # Skip files that can't be loaded
            continue

    if not audio_data:
        return None

    # Calculate total duration: max(start_time_ms + duration_ms)
    max_end_time_ms = max((data["start_time_ms"] + data["duration_ms"] for data in audio_data), default=0)

    # Convert to samples
    total_samples = int((max_end_time_ms / 1000.0) * sample_rate)

    # Create output buffer initialized to zeros
    final_audio = np.zeros(total_samples, dtype=np.float32)

    # Mix each audio segment at its timecode position
    for data in audio_data:
        audio = data["audio"]
        start_time_ms = data["start_time_ms"]

        # Calculate start sample index
        start_sample = int((start_time_ms / 1000.0) * sample_rate)

        # Ensure we don't exceed buffer bounds
        audio_length = len(audio)
        end_sample = min(start_sample + audio_length, total_samples)

        if start_sample < total_samples:
            # Trim audio if it extends beyond buffer
            audio_to_mix = audio[: end_sample - start_sample]

            # Mix: add audio to existing buffer (overlapping audio will sum)
            # Normalize to prevent clipping (simple approach: divide by max)
            final_audio[start_sample:end_sample] += audio_to_mix

    # Normalize to prevent clipping
    max_val = np.abs(final_audio).max()
    if max_val > 1.0:
        final_audio = final_audio / max_val

    # Save to temporary file
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        save_audio(final_audio, tmp_path, sample_rate)

        # Read file bytes
        with open(tmp_path, "rb") as f:
            audio_bytes = f.read()

        return audio_bytes
    finally:
        # Clean up temp file
        Path(tmp_path).unlink(missing_ok=True)
|
||||
108
backend/services/task_queue.py
Normal file
108
backend/services/task_queue.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Serial generation queue — ensures only one TTS inference runs at a time
|
||||
to avoid GPU contention.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Coroutine, Literal
|
||||
|
||||
# Keep references to fire-and-forget background tasks to prevent GC
# (tasks are discarded from this set on completion; see create_background_task).
_background_tasks: set = set()
|
||||
|
||||
|
||||
@dataclass
class GenerationJob:
    """Queued generation work plus the generation ID it belongs to."""

    # ID of the generation this job will produce audio for.
    generation_id: str
    # The not-yet-awaited coroutine performing the actual work; the worker
    # wraps it in a Task (or close()s it if cancelled before starting).
    coro: Coroutine
|
||||
|
||||
|
||||
# Generation queue — serializes TTS inference to avoid GPU contention
# Created by init_queue() at startup; None until then.
_generation_queue: asyncio.Queue = None  # type: ignore # initialized at startup
# The long-lived worker task draining the queue (see _generation_worker).
_generation_worker_task: asyncio.Task | None = None
# IDs enqueued but not yet started — lets cancel_generation() mark them.
_queued_generation_ids: set[str] = set()
# In-flight generation tasks keyed by generation ID, for mid-run cancellation.
_running_generation_tasks: dict[str, asyncio.Task] = {}
# IDs cancelled while still queued; the worker discards these on dequeue.
_cancelled_generation_ids: set[str] = set()
|
||||
|
||||
|
||||
def create_background_task(coro) -> asyncio.Task:
    """Schedule *coro* as a task and hold a strong reference to it.

    The task is stashed in the module-level ``_background_tasks`` set until
    it completes, preventing it from being garbage collected mid-flight.

    Returns:
        The created asyncio.Task.
    """
    background_task = asyncio.create_task(coro)
    _background_tasks.add(background_task)
    # Drop the strong reference once the task finishes.
    background_task.add_done_callback(_background_tasks.discard)
    return background_task
|
||||
|
||||
|
||||
async def _generation_worker():
    """Worker that processes generation tasks one at a time.

    Loops forever: dequeues the next GenerationJob, skips it if it was
    cancelled while still queued, otherwise runs it to completion before
    taking the next one — serializing inference on the GPU.
    """
    while True:
        job = await _generation_queue.get()
        try:
            if job.generation_id in _cancelled_generation_ids:
                _cancelled_generation_ids.discard(job.generation_id)
                # Close the never-started coroutine so it does not emit a
                # "coroutine was never awaited" warning.
                job.coro.close()
                continue

            # Wrap in a Task so cancel_generation() can cancel it mid-run.
            task = asyncio.create_task(job.coro)
            _running_generation_tasks[job.generation_id] = task
            _queued_generation_ids.discard(job.generation_id)
            try:
                await task
            except asyncio.CancelledError:
                # Swallow the error only when the *job* was cancelled; if
                # the worker itself is being cancelled (task not cancelled),
                # re-raise so the worker can shut down.
                if not task.cancelled():
                    raise
            except Exception:
                # A failed generation must not kill the worker loop.
                traceback.print_exc()
        finally:
            # Runs on every path (skip, success, failure, cancel) so the
            # bookkeeping and queue accounting never leak.
            _running_generation_tasks.pop(job.generation_id, None)
            _queued_generation_ids.discard(job.generation_id)
            _generation_queue.task_done()
|
||||
|
||||
|
||||
def enqueue_generation(generation_id: str, coro):
    """Add a generation coroutine to the serial queue.

    Args:
        generation_id: ID of the generation the coroutine belongs to.
        coro: The not-yet-awaited generation coroutine.

    Raises:
        RuntimeError: if init_queue() has not been called yet.
    """
    if _generation_queue is None:
        raise RuntimeError("Generation queue has not been initialized")

    # Record the ID first so cancel_generation() can see it as queued.
    _queued_generation_ids.add(generation_id)
    job = GenerationJob(generation_id=generation_id, coro=coro)
    _generation_queue.put_nowait(job)
|
||||
|
||||
|
||||
def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None:
    """Cancel a queued or running generation if it is still active.

    Returns:
        "running" if an in-flight task was cancelled, "queued" if the job
        was flagged for skipping before it started, None if nothing
        matched the given ID.
    """
    active = _running_generation_tasks.get(generation_id)
    if active is not None:
        active.cancel()
        return "running"

    if generation_id not in _queued_generation_ids:
        return None

    # Not started yet: flag the ID so the worker discards it on dequeue.
    _queued_generation_ids.discard(generation_id)
    _cancelled_generation_ids.add(generation_id)
    return "queued"
|
||||
|
||||
|
||||
def init_queue(force: bool = False):
    """Initialize the generation queue and start the worker.

    Must be called once during application startup (inside a running event loop).

    Args:
        force: When True, tear down a still-running worker (cancelling any
            in-flight generation tasks) and start fresh. When False, calling
            this while a worker is alive is a no-op.
    """
    global _generation_queue, _generation_worker_task
    global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids

    # A live worker means we're already initialized.
    if _generation_worker_task is not None and not _generation_worker_task.done():
        if not force:
            return
        _generation_worker_task.cancel()
        # Also cancel any generation currently executing under the old worker.
        for task in list(_running_generation_tasks.values()):
            task.cancel()

    # Reset all queue state and launch a fresh worker.
    _generation_queue = asyncio.Queue()
    _queued_generation_ids = set()
    _running_generation_tasks = {}
    _cancelled_generation_ids = set()
    _generation_worker_task = create_background_task(_generation_worker())
|
||||
22
backend/services/transcribe.py
Normal file
22
backend/services/transcribe.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
STT (Speech-to-Text) module - delegates to backend abstraction layer.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from ..backends import get_stt_backend, STTBackend
|
||||
|
||||
|
||||
def get_whisper_model() -> STTBackend:
    """Return the platform-appropriate STT backend (MLX or PyTorch).

    Returns:
        STT backend instance
    """
    backend = get_stt_backend()
    return backend
|
||||
|
||||
|
||||
def unload_whisper_model():
    """Unload the Whisper model to free memory."""
    get_stt_backend().unload_model()
|
||||
34
backend/services/tts.py
Normal file
34
backend/services/tts.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
TTS inference module - delegates to backend abstraction layer.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import io
|
||||
import soundfile as sf
|
||||
|
||||
from ..backends import get_tts_backend, TTSBackend
|
||||
|
||||
|
||||
def get_tts_model() -> TTSBackend:
    """Return the platform-appropriate TTS backend (MLX or PyTorch).

    Returns:
        TTS backend instance
    """
    backend = get_tts_backend()
    return backend
|
||||
|
||||
|
||||
def unload_tts_model():
    """Unload the TTS model to free memory."""
    get_tts_backend().unload_model()
|
||||
|
||||
|
||||
def audio_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode an audio array as an in-memory WAV file and return its bytes."""
    with io.BytesIO() as wav_buffer:
        sf.write(wav_buffer, audio, sample_rate, format="WAV")
        return wav_buffer.getvalue()
|
||||
211
backend/services/versions.py
Normal file
211
backend/services/versions.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Generation versions management module.
|
||||
|
||||
Each generation can have multiple audio versions: a clean (unprocessed)
|
||||
version and any number of processed versions with different effects chains.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..database import (
|
||||
GenerationVersion as DBGenerationVersion,
|
||||
Generation as DBGeneration,
|
||||
)
|
||||
from ..models import GenerationVersionResponse, EffectConfig
|
||||
from .. import config
|
||||
|
||||
|
||||
def _version_response(v: DBGenerationVersion) -> GenerationVersionResponse:
    """Convert a DB version row to a Pydantic response."""
    # effects_chain is stored as a JSON string; decode it into models.
    if v.effects_chain:
        effects = [EffectConfig(**entry) for entry in json.loads(v.effects_chain)]
    else:
        effects = None
    return GenerationVersionResponse(
        id=v.id,
        generation_id=v.generation_id,
        label=v.label,
        audio_path=v.audio_path,
        effects_chain=effects,
        source_version_id=v.source_version_id,
        is_default=v.is_default,
        created_at=v.created_at,
    )
|
||||
|
||||
|
||||
def list_versions(generation_id: str, db: Session) -> List[GenerationVersionResponse]:
    """List all versions for a generation, oldest first."""
    query = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .order_by(DBGenerationVersion.created_at)
    )
    return [_version_response(row) for row in query.all()]
|
||||
|
||||
|
||||
def get_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
    """Get a specific version by ID, or None if it does not exist."""
    row = db.query(DBGenerationVersion).filter_by(id=version_id).first()
    return _version_response(row) if row else None
|
||||
|
||||
|
||||
def get_default_version(generation_id: str, db: Session) -> Optional[GenerationVersionResponse]:
    """Get the default version for a generation.

    Falls back to the oldest version when none is flagged as default;
    returns None when the generation has no versions at all.
    """
    row = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id, is_default=True)
        .first()
    )
    if row is None:
        # Fallback: the oldest version wins when no explicit default exists.
        row = (
            db.query(DBGenerationVersion)
            .filter_by(generation_id=generation_id)
            .order_by(DBGenerationVersion.created_at)
            .first()
        )
    return _version_response(row) if row else None
|
||||
|
||||
|
||||
def create_version(
    generation_id: str,
    label: str,
    audio_path: str,
    db: Session,
    effects_chain: Optional[List[dict]] = None,
    is_default: bool = False,
    source_version_id: Optional[str] = None,
) -> GenerationVersionResponse:
    """Create a new version for a generation.

    If ``is_default`` is True, all other versions for this generation
    are un-defaulted first.

    Args:
        generation_id: Generation this version belongs to.
        label: Human-readable label for the version.
        audio_path: Storage path of the version's audio file.
        db: Database session.
        effects_chain: Optional list of effect configs (persisted as JSON).
        is_default: Whether this version becomes the generation's default.
        source_version_id: Optional ID of the version this was derived from.

    Returns:
        The newly created version as a response model.
    """
    # Clear existing defaults before inserting so at most one default exists.
    if is_default:
        _clear_defaults(generation_id, db)

    version = DBGenerationVersion(
        id=str(uuid.uuid4()),
        generation_id=generation_id,
        label=label,
        audio_path=audio_path,
        # Stored as a JSON string; None when the version is unprocessed.
        effects_chain=json.dumps(effects_chain) if effects_chain else None,
        source_version_id=source_version_id,
        is_default=is_default,
    )
    db.add(version)
    db.commit()
    db.refresh(version)

    # If this version is the default, update the generation's audio_path
    # so consumers reading the generation row directly get this audio.
    if is_default:
        gen = db.query(DBGeneration).filter_by(id=generation_id).first()
        if gen:
            gen.audio_path = audio_path
            db.commit()

    return _version_response(version)
|
||||
|
||||
|
||||
def set_default_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
    """Set a version as the default for its generation.

    Also repoints the parent generation's ``audio_path`` at this version's
    audio so consumers reading the generation row directly pick it up.

    Returns:
        The updated version response, or None if the version was not found.
    """
    version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
    if version is None:
        return None

    # Un-default every sibling first so exactly one default remains.
    _clear_defaults(version.generation_id, db)
    version.is_default = True
    db.commit()
    db.refresh(version)

    gen = db.query(DBGeneration).filter_by(id=version.generation_id).first()
    if gen is not None:
        gen.audio_path = version.audio_path
        db.commit()

    return _version_response(version)
|
||||
|
||||
|
||||
def delete_version(version_id: str, db: Session) -> bool:
    """Delete a version. Cannot delete the last remaining version.

    Removes the version's audio file from disk, deletes the row, and — if
    the deleted version was the default — promotes the oldest remaining
    version to default and repoints the generation's audio_path at it.

    Args:
        version_id: ID of the version to delete.
        db: Database session.

    Returns:
        True on success; False when the version does not exist or is the
        only version left for its generation.
    """
    version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
    if not version:
        return False

    # Don't allow deleting the last version
    count = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=version.generation_id)
        .count()
    )
    if count <= 1:
        return False

    # Capture these before the row is deleted; needed for default promotion.
    was_default = version.is_default
    gen_id = version.generation_id

    # Delete audio file
    audio_path = config.resolve_storage_path(version.audio_path)
    if audio_path is not None and audio_path.exists():
        audio_path.unlink()

    db.delete(version)
    db.commit()

    # If this was the default, promote the first remaining version
    if was_default:
        first = (
            db.query(DBGenerationVersion)
            .filter_by(generation_id=gen_id)
            .order_by(DBGenerationVersion.created_at)
            .first()
        )
        if first:
            first.is_default = True
            db.commit()
            # Keep the generation row pointing at the new default's audio.
            gen = db.query(DBGeneration).filter_by(id=gen_id).first()
            if gen:
                gen.audio_path = first.audio_path
                db.commit()

    return True
|
||||
|
||||
|
||||
def delete_versions_for_generation(generation_id: str, db: Session) -> int:
    """Delete all versions (rows and audio files) for a generation.

    Used when the generation itself is being deleted.

    Returns:
        Number of version rows removed.
    """
    rows = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .all()
    )
    removed = 0
    for row in rows:
        # Remove the on-disk audio for this version, if present.
        path = config.resolve_storage_path(row.audio_path)
        if path is not None and path.exists():
            path.unlink()
        db.delete(row)
        removed += 1
    if removed:
        db.commit()
    return removed
|
||||
|
||||
|
||||
def _clear_defaults(generation_id: str, db: Session) -> None:
    """Clear the is_default flag on all versions for a generation."""
    (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id, is_default=True)
        .update({"is_default": False})
    )
    # Flush (not commit) so callers can set a new default in the same txn.
    db.flush()
|
||||
Reference in New Issue
Block a user