Initial commit

2026-04-24 19:18:15 +08:00
commit fbcbe08696
555 changed files with 96692 additions and 0 deletions

135
backend/README.md Normal file

@@ -0,0 +1,135 @@
# Voicebox Backend
FastAPI server powering voice cloning, speech generation, and audio processing. Runs locally as a Tauri sidecar or standalone via `python -m backend.main`.
## Running
```bash
# Via justfile (recommended)
just dev:server
# Standalone
python -m backend.main --host 127.0.0.1 --port 17493
# With custom data directory
python -m backend.main --data-dir /path/to/data
```
The server auto-initializes the SQLite database on first startup. Models are downloaded from HuggingFace on first use.
## Architecture
```
backend/
app.py # FastAPI app factory, CORS, lifecycle events
main.py # Entry point (imports app, runs uvicorn)
config.py # Data directory paths and configuration
models.py # Pydantic request/response schemas
server.py # Tauri sidecar launcher, parent-pid watchdog
routes/ # Thin HTTP handlers — validation, delegation, response formatting
services/ # Business logic, CRUD, orchestration
backends/ # TTS/STT engine implementations (MLX, PyTorch, etc.)
database/ # ORM models, session management, migrations, seed data
utils/ # Shared utilities (audio, effects, caching, progress tracking)
```
### Request flow
```
HTTP request
-> routes/ (validate input, parse params)
-> services/ (business logic, database queries, orchestration)
-> backends/ (TTS/STT inference)
-> utils/ (audio processing, effects, caching)
```
Route handlers are intentionally thin. They validate input, delegate to a service function, and format the response. All business logic lives in `services/`.
### Key modules
**services/generation.py** -- Single `run_generation()` function that handles all three generation modes (generate, retry, regenerate). Manages model loading, voice prompt creation, chunked inference, normalization, effects, and version persistence.
**services/task_queue.py** -- Serial generation queue. Ensures only one GPU inference runs at a time. Background tasks are tracked to prevent garbage collection.
**backends/__init__.py** -- Protocol definitions (`TTSBackend`, `STTBackend`), model config registry, and factory functions. Adding a new engine means implementing the protocol and registering a config entry (see the sketch after this list).
**backends/base.py** -- Shared utilities used across all engine implementations: HuggingFace cache checks, device detection, voice prompt combination, progress tracking.
**database/** -- SQLAlchemy ORM models with a re-exporting `__init__.py` for backward compatibility. Migrations run automatically on startup.
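For illustration, a new engine in skeleton form (the engine name and class below are hypothetical; `ModelConfig`, `TTS_ENGINES`, and `get_tts_backend_for_engine()` are the real registry pieces in `backends/__init__.py`):
```python
# backends/myengine_backend.py -- hypothetical skeleton, shape only
class MyEngineTTSBackend:
    """Satisfies the TTSBackend protocol from backends/__init__.py."""

    def __init__(self):
        self.model = None

    async def load_model(self, model_size: str = "default") -> None: ...

    async def create_voice_prompt(self, audio_path: str, reference_text: str, use_cache: bool = True): ...

    async def generate(self, text: str, voice_prompt: dict, language: str = "en", seed=None, instruct=None): ...

    def unload_model(self) -> None:
        self.model = None

    def is_loaded(self) -> bool:
        return self.model is not None
```
Registration then amounts to a `ModelConfig` entry for the new engine, a `TTS_ENGINES` entry, and a branch in `get_tts_backend_for_engine()`.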
### Backend selection
The server detects the best inference backend at startup:
| Platform | Backend | Acceleration |
|----------|---------|-------------|
| macOS (Apple Silicon) | MLX | Metal / Neural Engine |
| Windows / Linux (NVIDIA) | PyTorch | CUDA |
| Linux (AMD) | PyTorch | ROCm |
| Intel Arc | PyTorch | IPEX / XPU |
| Windows (any GPU) | PyTorch | DirectML |
| Any | PyTorch | CPU fallback |
Detection is handled by `utils/platform_detect.py`. Both backends implement the same `TTSBackend` protocol, so the API layer is engine-agnostic.
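In practice that means callers only touch the factory and the protocol methods. A minimal sketch (in the server itself, generation is always routed through the task queue rather than calling `generate()` directly):
```python
from backend.backends import get_tts_backend_for_engine, load_engine_model


async def synthesize(text: str, voice_prompt: dict):
    # The factory returns the MLX backend on Apple Silicon and the PyTorch
    # backend elsewhere; the calling code is identical either way.
    await load_engine_model("qwen", "1.7B")
    backend = get_tts_backend_for_engine("qwen")
    return await backend.generate(text, voice_prompt, language="en")  # (audio, sample_rate)
```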
## API
90 endpoints organized by domain. Full interactive documentation is available at `http://localhost:17493/docs` when the server is running.
| Domain | Prefix | Description |
|--------|--------|-------------|
| Health | `/`, `/health` | Server status, GPU info, filesystem checks |
| Profiles | `/profiles` | Voice profile CRUD, samples, avatars, import/export |
| Channels | `/channels` | Audio channel management and voice assignment |
| Generation | `/generate` | TTS generation, retry, regenerate, status SSE |
| History | `/history` | Generation history, search, favorites, export |
| Transcription | `/transcribe` | Whisper-based audio-to-text |
| Stories | `/stories` | Multi-track timeline editor, audio export |
| Effects | `/effects` | Effect presets, preview, version management |
| Audio | `/audio`, `/samples` | Audio file serving |
| Models | `/models` | Load, unload, download, migrate, status |
| Tasks | `/tasks`, `/cache` | Active task tracking, cache management |
| CUDA | `/backend/cuda-*` | CUDA binary download and management |
### Quick examples
```bash
# Generate speech
curl -X POST http://localhost:17493/generate \
-H "Content-Type: application/json" \
-d '{"text": "Hello world", "profile_id": "...", "language": "en"}'
# List profiles
curl http://localhost:17493/profiles
# Stream generation status (SSE)
curl http://localhost:17493/generate/{id}/status
```
## Data directory
```
{data_dir}/
voicebox.db # SQLite database
profiles/{id}/ # Voice samples per profile
generations/ # Generated audio files
cache/ # Voice prompt cache (memory + disk)
backends/ # Downloaded CUDA binary (if applicable)
```
Default location is the OS-specific app data directory. Override with `--data-dir` or the `VOICEBOX_DATA_DIR` environment variable.
## Code quality
Linting and formatting are enforced by [ruff](https://docs.astral.sh/ruff/), configured in `pyproject.toml`. See `STYLE_GUIDE.md` for conventions.
```bash
just check-python # lint + format check
just fix-python # auto-fix lint issues + reformat
just test # run pytest
```
## Dependencies
Runtime dependencies are in `requirements.txt`. macOS-only MLX dependencies are in `requirements-mlx.txt`. Dev tools (ruff, pytest) are installed automatically by `just setup-python`.

404
backend/STYLE_GUIDE.md Normal file

@@ -0,0 +1,404 @@
# Python Style Guide
Target: **Python 3.12+** | Formatter/Linter: **Ruff** | Config: `backend/pyproject.toml`
This guide codifies the conventions used across the backend, and prescribes the target style for code written during the refactor (Phases 3-6). Existing code should be migrated incrementally -- don't reformat entire files in unrelated PRs.
---
## Formatting
Enforced by `ruff format` (Black-compatible).
- **Line length**: 120 characters.
- **Indent**: 4 spaces. No tabs.
- **Trailing commas**: Required on multi-line function signatures, arguments, collections.
- **Quotes**: Double quotes (`"`) for strings. Single quotes are acceptable in f-string expressions and dict keys inside f-strings where avoiding escapes improves readability.
Run: `ruff format backend/`
---
## Imports
Enforced by ruff's `isort` rules (rule set `I`).
**Grouping** -- three blocks separated by a blank line:
```python
import asyncio # 1. stdlib
from pathlib import Path
import numpy as np # 2. third-party
from fastapi import APIRouter, HTTPException
from sqlalchemy.orm import Session
from backend.config import get_data_dir # 3. local (absolute)
from .database import get_db # or relative
```
**Rules:**
- Within the `backend` package, use **relative imports** for sibling/child modules: `from .database import get_db`, `from ..utils.audio import load_audio`.
- Absolute imports are fine for top-level references from entry points (`main.py`, `server.py`).
- Never use wildcard imports (`from module import *`).
- One import per line for `from X import Y` when there are 4+ names; below that, comma-separated is fine.
- **Lazy imports** are acceptable for heavy dependencies (torch, transformers, mlx) inside functions to reduce startup time. Add a comment: `# lazy: heavy import`.
---
## Type Annotations
Python 3.12 means we use **built-in generics and union syntax natively**. No `from __future__ import annotations`, no `typing.List`/`typing.Dict`.
```python
# Yes
def process(items: list[str], config: dict[str, int] | None = None) -> tuple[int, str]: ...
# No
from typing import List, Dict, Optional, Tuple
def process(items: List[str], config: Optional[Dict[str, int]] = None) -> Tuple[int, str]: ...
```
**What to annotate:**
- All public function signatures (parameters + return type).
- Private functions: parameters at minimum; return type encouraged.
- Module-level variables: only when the type isn't obvious from the assignment.
- Route handlers: parameters are annotated via FastAPI's dependency injection. Add explicit `-> SomeResponse` return types when the route doesn't use `response_model`.
**Imports from `typing` that are still needed** (no built-in equivalent):
`Literal`, `TypeAlias`, `Protocol`, `runtime_checkable`, `Callable`, `Any`, `ClassVar`, `TypeVar`, `overload`, `TYPE_CHECKING`.
Use `collections.abc` for abstract types: `Sequence`, `Mapping`, `Iterable`, `Iterator`, `Generator`.
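For instance (illustrative signatures only, not functions from the codebase):
```python
from collections.abc import Iterable, Sequence
from typing import Any, Literal


def apply_effects(chain: Sequence[dict[str, Any]], mode: Literal["preview", "final"] = "preview") -> None: ...


def chunk_texts(texts: Iterable[str], max_chars: int = 400) -> list[list[str]]: ...
```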
---
## Naming
| Thing | Convention | Example |
|-------|-----------|---------|
| Module | `snake_case` | `task_queue.py` |
| Class | `PascalCase` | `ProgressManager` |
| Function / method | `snake_case` | `create_profile` |
| Variable | `snake_case` | `sample_rate` |
| Constant | `UPPER_SNAKE_CASE` | `DEFAULT_SAMPLE_RATE` |
| Private | `_leading_underscore` | `_generation_queue` |
| Type alias | `PascalCase` | `EffectChain = list[dict[str, Any]]` |
**Specific conventions:**
- Database ORM models imported with `DB` prefix alias: `from .database import VoiceProfile as DBVoiceProfile`.
- Pydantic models use descriptive suffixes: `VoiceProfileCreate`, `VoiceProfileResponse`, `GenerationRequest`.
- Backend classes use engine-name prefix: `MLXTTSBackend`, `PyTorchSTTBackend`.
---
## Docstrings
**Google style**. Required on all public functions, classes, and modules.
```python
def combine_voice_prompts(
profile_dir: Path,
*,
target_sr: int = 24000,
) -> tuple[np.ndarray, int]:
"""Load and concatenate all voice prompt files for a profile.
Reads .wav/.mp3/.flac files from the profile directory, resamples to
the target sample rate, normalizes, and concatenates into a single array.
Args:
profile_dir: Path to the voice profile directory containing audio files.
target_sr: Target sample rate for the output. Defaults to 24000.
Returns:
Tuple of (concatenated audio array, sample rate).
Raises:
FileNotFoundError: If profile_dir does not exist.
ValueError: If no valid audio files are found.
"""
```
**Short form** is fine for simple functions:
```python
def get_db_path() -> Path:
"""Get the path to the SQLite database file."""
```
**When to skip**: Private helpers under ~5 lines where the name and signature make intent obvious.
**Module docstrings**: A single sentence at the top of every file describing its purpose.
```python
"""Voice profile CRUD operations."""
```
---
## Comments
Comments explain **why**, not **what**. If the code needs a comment to explain what it does, the code should be rewritten to be clearer. The exceptions are non-obvious performance choices, external constraints, and concurrency/race-condition reasoning -- those always deserve a comment.
### No section dividers
Do not use ASCII dividers to create visual sections in files:
```python
# No -- any of these:
# ============================================
# GENERATION ENDPOINTS
# ============================================
# ---------------------------------------------------------------------------
# Device detection
# ---------------------------------------------------------------------------
# --- Load model --------------------------------------------------
```
If a file needs section dividers to be navigable, the file is too long. Split it into modules. Within a function, if you need labeled sections to follow the logic, extract those sections into named functions.
### Inline comments
Inline comments (end-of-line) are fine when they add information the code can't express:
```python
# Yes -- explains a non-obvious constraint or gives context:
audio, sr = load_audio(path, sr=24000) # Qwen expects 24kHz mono
_generation_queue: asyncio.Queue = None # type: ignore # initialized at startup
"tauri://localhost", # Tauri webview (macOS)
# No -- restates the code:
# Check if profile name already exists
existing = db.query(DBVoiceProfile).filter_by(name=data.name).first()
# Delete from database
db.delete(sample)
# Update fields
profile.name = data.name
```
Delete comments that narrate what the next line of code obviously does. If the function name, variable name, or method call already communicates intent, the comment is noise.
### Block comments
Use block comments for **why** explanations -- constraints, workarounds, non-obvious decisions:
```python
# PyInstaller + multiprocessing: child processes re-execute the frozen binary
# with internal arguments. freeze_support() handles this and exits early.
multiprocessing.freeze_support()
# Mark any stale "generating" records as failed -- these are leftovers
# from a previous process that was killed mid-generation.
db.query(Generation).filter_by(status="generating").update({"status": "failed"})
```
Keep block comments tight. Two to three lines is normal. If you need a paragraph, it probably belongs in the docstring or a design doc.
### Linter/type-checker suppression
Always add a reason after `noqa` and `type: ignore`:
```python
import intel_extension_for_pytorch # noqa: F401 -- side-effect import enables XPU
_queue: asyncio.Queue = None # type: ignore[assignment] # initialized at startup
```
Bare `# noqa` or `# type: ignore` with no explanation is not allowed.
### TODO / FIXME
Use sparingly. Every `TODO` must include a brief description of what needs doing. Don't use them as a substitute for tracking work properly:
```python
# TODO: replace with async SQLAlchemy once CRUD modules are migrated (Phase 5)
result = await asyncio.to_thread(profiles.get_profile, profile_id, db)
```
Never commit `HACK`, `XXX`, or `FIXME` -- fix the problem or file an issue.
### Commented-out code
Delete it. That's what git is for. If you need to document that something was intentionally removed, a short tombstone comment is acceptable:
```python
# Removed config.json-only check -- too lenient, doesn't confirm weights exist.
```
---
## Error Handling
The refactor is standardizing on a **two-layer pattern**:
### 1. Domain layer -- raise plain exceptions
CRUD modules and services raise `ValueError`, `FileNotFoundError`, or (post-refactor) custom exceptions defined in `backend/errors.py`:
```python
# backend/errors.py (to be created in Phase 4)
class NotFoundError(Exception):
"""Raised when a requested resource does not exist."""
class ConflictError(Exception):
"""Raised on uniqueness constraint violations."""
```
```python
# In a service or CRUD module:
raise NotFoundError(f"Profile {profile_id} not found")
```
### 2. Route layer -- translate to HTTPException
Route handlers catch domain exceptions and convert:
```python
@router.post("/profiles")
async def create_profile(data: VoiceProfileCreate, db: Session = Depends(get_db)):
try:
return await profiles.create_profile(data, db)
except ConflictError as e:
raise HTTPException(status_code=409, detail=str(e))
```
**Background tasks** catch `Exception` broadly, log with `logger.exception()`, and update the task status to `"failed"`.
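The shape, with placeholder names (the real status bookkeeping lives in the task/queue services):
```python
import logging

logger = logging.getLogger(__name__)


async def run_in_background(task_id: str, coro) -> None:
    try:
        await coro
    except Exception:
        logger.exception("Task %s failed", task_id)
        # ...mark the corresponding task / generation record as "failed" here
```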
**Never**: silently swallow exceptions, use bare `except:`, or catch `BaseException`.
---
## Async
### Rules for the refactor
1. **Don't declare `async def` unless the function awaits something.** Several service modules still declare `async def` without awaiting -- these should be migrated to sync functions with `asyncio.to_thread()` at the route layer, or to real async SQLAlchemy (see the sketch after this list).
2. **CPU-bound work** (audio processing, numpy operations) goes through `asyncio.to_thread()`:
```python
audio, sr = await asyncio.to_thread(load_audio, source_path)
```
3. **GPU-bound TTS inference** is serialized through the generation queue (`services/task_queue.py`). Never call a backend's `generate()` directly from a route handler.
4. **Fire-and-forget tasks**: use `asyncio.create_task()` and track the task reference to prevent garbage collection:
```python
task = asyncio.create_task(some_coro())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
```
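Putting rules 1 and 2 together, the target shape for today's fake-`async` service code is a plain sync function called through `asyncio.to_thread()` at the route layer (illustrative names; imports omitted as in the snippets above):
```python
# services/profiles.py -- plain sync function, no fake async
def get_profile(profile_id: str, db: Session) -> DBVoiceProfile | None:
    return db.query(DBVoiceProfile).filter_by(id=profile_id).first()

# routes/profiles.py -- the route awaits the threaded call
@router.get("/profiles/{profile_id}")
async def get_profile_route(profile_id: str, db: Session = Depends(get_db)):
    profile = await asyncio.to_thread(profiles.get_profile, profile_id, db)
    if profile is None:
        raise HTTPException(status_code=404, detail=f"Profile {profile_id} not found")
    return profile
```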
---
## Logging
Use the `logging` module. Not `print()`.
```python
import logging
logger = logging.getLogger(__name__)
logger.info("Loading model %s on %s", model_name, device)
logger.warning("Cache miss for %s, downloading", repo_id)
logger.exception("Generation %s failed") # logs traceback automatically
```
**Rules:**
- Use `%s`-style placeholders in log calls (not f-strings). This avoids formatting the string if the log level is filtered out.
- Use `logger.exception()` inside `except` blocks -- it captures the traceback.
- Logger name should be `__name__` (yields `backend.utils.audio`, etc.).
- Existing `print()` calls should be migrated to logging as files are touched during the refactor.
---
## Constants
- Define at **module level** in the file where they're primarily used.
- Use `UPPER_SNAKE_CASE`.
- Shared/cross-cutting constants (sample rates, file size limits, CORS origins) go in `backend/config.py` after Phase 6 consolidation.
- Magic numbers in function bodies should be extracted to named constants:
```python
# No
if len(audio) > 24000 * 60 * 10:
# Yes
MAX_AUDIO_DURATION_SAMPLES = SAMPLE_RATE * 60 * 10
if len(audio) > MAX_AUDIO_DURATION_SAMPLES:
```
---
## Function Signatures
- **Keyword-only arguments** (after `*`) for functions with 3+ parameters, especially when several share the same type:
```python
def is_model_cached(
hf_repo: str,
*,
weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
required_files: list[str] | None = None,
) -> bool:
```
- Parameters on **separate lines** when the signature exceeds ~100 characters or has 3+ params.
- **Trailing comma** after the last parameter in multi-line signatures.
- Default values inline with the parameter.
---
## String Formatting
- **f-strings** for runtime string construction.
- **`%s`-style** for `logging` calls (lazy evaluation).
- **`.format()`**: avoid; f-strings are preferred.
---
## Testing
Framework: **pytest** with `pytest-asyncio`; a short example follows the list below.
- Test files: `test_<module>.py` in `backend/tests/`.
- Use `conftest.py` for shared fixtures (db sessions, test client, mock backends).
- Group related tests in classes: `class TestProfileCRUD:`.
- Use `@pytest.mark.asyncio` for async tests.
- Use `@pytest.mark.parametrize` to reduce repetition.
- Manual integration scripts stay in `tests/` but are clearly marked (filename prefix `manual_` or documented in `tests/README.md`).
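A short example of that shape (module paths, the `db_session` fixture, and the schema fields are assumptions, not copied from the suite):
```python
import pytest

from backend.models import VoiceProfileCreate
from backend.services import profiles


class TestProfileCRUD:
    @pytest.mark.asyncio
    @pytest.mark.parametrize("name", ["Narrator", "Villain"])
    async def test_create_profile(self, db_session, name):
        created = await profiles.create_profile(VoiceProfileCreate(name=name), db_session)
        assert created.name == name
```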
---
## Project Layout
```
backend/
app.py # FastAPI app factory, CORS, lifecycle events
main.py # Entry point (imports app, runs uvicorn)
config.py # Data directory paths
models.py # Pydantic request/response schemas
server.py # Tauri sidecar launcher, parent-pid watchdog
routes/ # Thin HTTP handlers (validation, delegation, response formatting)
services/ # Business logic, CRUD, orchestration
backends/ # TTS/STT engine implementations
database/ # ORM models, session management, migrations, seeds
utils/ # Shared utilities (audio, effects, caching, progress)
tests/ # pytest suite
```
---
## Ruff Adoption
`pyproject.toml` configures ruff for linting and formatting. Run:
```bash
# Lint (check)
ruff check backend/
# Lint (auto-fix)
ruff check backend/ --fix
# Format
ruff format backend/
```
Introduce ruff fixes file-by-file as you touch them. Don't run `--fix` across the entire codebase in one shot -- that creates unreviewable diffs.

3
backend/__init__.py Normal file

@@ -0,0 +1,3 @@
# Backend package
__version__ = "0.4.5"

281
backend/app.py Normal file

@@ -0,0 +1,281 @@
"""FastAPI application factory, middleware, and lifecycle events."""
import asyncio
import logging
import os
import sys
from pathlib import Path
class ColoredFormatter(logging.Formatter):
"""Custom formatter to add colors matching uvicorn's style."""
COLORS = {
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
}
RESET = "\033[0m"
def format(self, record):
log_color = self.COLORS.get(record.levelname, self.RESET)
record.levelname = f"{log_color}{record.levelname}{self.RESET}"
return super().format(record)
# Configure logging to match uvicorn's format with colors
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(ColoredFormatter("%(levelname)s: %(message)s"))
logging.basicConfig(
level=logging.INFO,
handlers=[handler],
)
logger = logging.getLogger(__name__)
# AMD GPU environment variables must be set before torch import
if not os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
if not os.environ.get("MIOPEN_LOG_LEVEL"):
os.environ["MIOPEN_LOG_LEVEL"] = "4"
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from urllib.parse import quote
from . import __version__, config, database
from .services import tts, transcribe
from .database import get_db
from .utils.platform_detect import get_backend_type
from .utils.progress import get_progress_manager
from .services.task_queue import create_background_task, init_queue
from .routes import register_routers
def safe_content_disposition(disposition_type: str, filename: str) -> str:
"""Build a Content-Disposition header safe for non-ASCII filenames.
Uses RFC 5987 ``filename*`` parameter so browsers can decode UTF-8
filenames while the ``filename`` fallback stays ASCII-only.
"""
ascii_name = "".join(c for c in filename if c.isascii() and (c.isalnum() or c in " -_.")).strip() or "download"
utf8_name = quote(filename, safe="")
return f"{disposition_type}; filename=\"{ascii_name}\"; filename*=UTF-8''{utf8_name}"
def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
application = FastAPI(
title="voicebox API",
description="Production-quality Qwen3-TTS voice cloning API",
version=__version__,
)
_configure_cors(application)
register_routers(application)
_register_lifecycle(application)
_mount_frontend(application)
return application
def _configure_cors(application: FastAPI) -> None:
"""Set up CORS middleware with local-first defaults."""
default_origins = [
"http://localhost:5173", # Vite dev server
"http://127.0.0.1:5173",
"http://localhost:17493",
"http://127.0.0.1:17493",
"tauri://localhost", # Tauri webview (macOS)
"https://tauri.localhost", # Tauri webview (Windows/Linux)
"http://tauri.localhost", # Tauri webview (Windows, some builds)
]
env_origins = os.environ.get("VOICEBOX_CORS_ORIGINS", "")
all_origins = default_origins + [o.strip() for o in env_origins.split(",") if o.strip()]
application.add_middleware(
CORSMiddleware,
allow_origins=all_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def _mount_frontend(application: FastAPI) -> None:
"""Serve the built web frontend when present (Docker / web deployment).
The Dockerfile copies the Vite build output to ``/app/frontend/``. When
that directory exists we mount static assets and add a catch-all route so
the React SPA handles client-side routing. In dev or API-only mode the
directory is absent and this function is a no-op.
"""
frontend_dir = Path(__file__).resolve().parent.parent / "frontend"
if not frontend_dir.is_dir():
return
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# Mount hashed assets (JS, CSS, images) that Vite places under /assets
assets_dir = frontend_dir / "assets"
if assets_dir.is_dir():
application.mount(
"/assets",
StaticFiles(directory=str(assets_dir)),
name="frontend-assets",
)
# SPA catch-all: serve files if they exist, otherwise index.html for
# client-side routes like /voices, /stories, /models, etc.
@application.get("/{full_path:path}")
async def serve_spa(full_path: str):
file_path = (frontend_dir / full_path).resolve()
# Guard against path traversal — only serve files inside frontend_dir
if full_path and file_path.is_file() and file_path.is_relative_to(frontend_dir):
return FileResponse(file_path)
return FileResponse(frontend_dir / "index.html", media_type="text/html")
logger.info("Frontend: serving SPA from %s", frontend_dir)
def _get_gpu_status() -> str:
"""Return a human-readable string describing GPU availability."""
backend_type = get_backend_type()
if torch.cuda.is_available():
from .backends.base import check_cuda_compatibility
device_name = torch.cuda.get_device_name(0)
compatible, _warning = check_cuda_compatibility()
is_rocm = hasattr(torch.version, "hip") and torch.version.hip is not None
if is_rocm:
label = f"ROCm ({device_name})"
else:
label = f"CUDA ({device_name})"
if not compatible:
label += " [UNSUPPORTED - see logs]"
return label
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "MPS (Apple Silicon)"
elif backend_type == "mlx":
return "Metal (Apple Silicon via MLX)"
# Intel XPU (Arc / Data Center) via IPEX
try:
import intel_extension_for_pytorch # noqa: F401
if hasattr(torch, "xpu") and torch.xpu.is_available():
try:
xpu_name = torch.xpu.get_device_name(0)
except Exception:
xpu_name = "Intel GPU"
return f"XPU ({xpu_name})"
except ImportError:
pass
return "None (CPU only)"
def _register_lifecycle(application: FastAPI) -> None:
"""Attach startup and shutdown event handlers."""
@application.on_event("startup")
async def startup_event():
import platform
import sys
logger.info("Voicebox v%s starting up", __version__)
logger.info(
"Python %s on %s %s (%s)",
sys.version.split()[0],
platform.system(),
platform.release(),
platform.machine(),
)
database.init_db()
from .database.session import _db_path
logger.info("Database: %s", _db_path)
logger.info("Data directory: %s", config.get_data_dir())
init_queue()
# Mark stale "generating" records as failed -- leftovers from a killed process
from sqlalchemy import text as sa_text
db = next(get_db())
try:
result = db.execute(
sa_text(
"UPDATE generations SET status = 'failed', "
"error = 'Server was shut down during generation' "
"WHERE status IN ('generating', 'loading_model')"
)
)
if result.rowcount > 0:
logger.info("Marked %d stale generation(s) as failed", result.rowcount)
from .database import VoiceProfile as DBVoiceProfile, Generation as DBGeneration
profile_count = db.query(DBVoiceProfile).count()
generation_count = db.query(DBGeneration).count()
logger.info("Profiles: %d, Generations: %d", profile_count, generation_count)
db.commit()
except Exception as e:
db.rollback()
logger.warning("Could not clean up stale generations: %s", e)
finally:
db.close()
backend_type = get_backend_type()
logger.info("Backend: %s", backend_type.upper())
logger.info("GPU: %s", _get_gpu_status())
# Warn if GPU architecture is not supported by this PyTorch build
from .backends.base import check_cuda_compatibility
_compatible, _cuda_warning = check_cuda_compatibility()
if not _compatible:
logger.warning("GPU COMPATIBILITY: %s", _cuda_warning)
from .services.cuda import check_and_update_cuda_binary
create_background_task(check_and_update_cuda_binary())
try:
progress_manager = get_progress_manager()
progress_manager._set_main_loop(asyncio.get_running_loop())
except Exception as e:
logger.warning("Could not initialize progress manager event loop: %s", e)
try:
from huggingface_hub import constants as hf_constants
cache_dir = Path(hf_constants.HF_HUB_CACHE)
cache_dir.mkdir(parents=True, exist_ok=True)
logger.info("Model cache: %s", cache_dir)
except Exception as e:
logger.warning("Could not create HuggingFace cache directory: %s", e)
logger.info("Ready")
@application.on_event("shutdown")
async def shutdown_event():
logger.info("Voicebox server shutting down...")
try:
tts.unload_tts_model()
except Exception:
logger.exception("Failed to unload TTS model")
try:
transcribe.unload_whisper_model()
except Exception:
logger.exception("Failed to unload Whisper model")
app = create_app()

621
backend/backends/__init__.py Normal file

@@ -0,0 +1,621 @@
"""
Backend abstraction layer for TTS and STT.
Provides a unified interface for MLX and PyTorch backends,
and a model config registry that eliminates per-engine dispatch maps.
"""
# Install HF compatibility patches before any backend imports transformers /
# huggingface_hub. The module runs ``patch_transformers_mistral_regex`` at
# import time, which wraps transformers' tokenizer load against the
# unconditional HuggingFace metadata call that otherwise raises on
# HF_HUB_OFFLINE=1 and on network failures.
from ..utils import hf_offline_patch # noqa: F401
import threading
from dataclasses import dataclass, field
from typing import Protocol, Optional, Tuple, List, runtime_checkable
import numpy as np
from ..utils.platform_detect import get_backend_type
LANGUAGE_CODE_TO_NAME = {
"zh": "chinese",
"en": "english",
"ja": "japanese",
"ko": "korean",
"de": "german",
"fr": "french",
"ru": "russian",
"pt": "portuguese",
"es": "spanish",
"it": "italian",
}
WHISPER_HF_REPOS = {
"base": "openai/whisper-base",
"small": "openai/whisper-small",
"medium": "openai/whisper-medium",
"large": "openai/whisper-large-v3",
"turbo": "openai/whisper-large-v3-turbo",
}
@dataclass
class ModelConfig:
"""Declarative config for a downloadable model variant."""
model_name: str # e.g. "luxtts", "chatterbox-tts"
display_name: str # e.g. "LuxTTS (Fast, CPU-friendly)"
engine: str # e.g. "luxtts", "chatterbox"
hf_repo_id: str # e.g. "YatharthS/LuxTTS"
model_size: str = "default"
size_mb: int = 0
needs_trim: bool = False
supports_instruct: bool = False
languages: list[str] = field(default_factory=lambda: ["en"])
@runtime_checkable
class TTSBackend(Protocol):
"""Protocol for TTS backend implementations."""
# Each backend class should define MODEL_CONFIGS as a class variable:
# MODEL_CONFIGS: list[ModelConfig]
async def load_model(self, model_size: str) -> None:
"""Load TTS model."""
...
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Returns:
Tuple of (voice_prompt_dict, was_cached)
"""
...
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
"""
Combine multiple voice prompts.
Returns:
Tuple of (combined_audio_array, combined_text)
"""
...
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text.
Returns:
Tuple of (audio_array, sample_rate)
"""
...
def unload_model(self) -> None:
"""Unload model to free memory."""
...
def is_loaded(self) -> bool:
"""Check if model is loaded."""
...
def _get_model_path(self, model_size: str) -> str:
"""
Get model path for a given size.
Returns:
Model path or HuggingFace Hub ID
"""
...
@runtime_checkable
class STTBackend(Protocol):
"""Protocol for STT (Speech-to-Text) backend implementations."""
async def load_model(self, model_size: str) -> None:
"""Load STT model."""
...
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
model_size: Optional[str] = None,
) -> str:
"""
Transcribe audio to text.
Returns:
Transcribed text
"""
...
def unload_model(self) -> None:
"""Unload model to free memory."""
...
def is_loaded(self) -> bool:
"""Check if model is loaded."""
...
# Global backend instances
_tts_backend: Optional[TTSBackend] = None
_tts_backends: dict[str, TTSBackend] = {}
_tts_backends_lock = threading.Lock()
_stt_backend: Optional[STTBackend] = None
# Supported TTS engines, keyed by engine name; values are human-readable display names.
# The factory (get_tts_backend_for_engine) holds the per-engine if/elif chain and lists
# these keys in its error message; the model configs live in the _get_*_configs() helpers below.
TTS_ENGINES = {
"qwen": "Qwen TTS",
"qwen_custom_voice": "Qwen CustomVoice",
"luxtts": "LuxTTS",
"chatterbox": "Chatterbox TTS",
"chatterbox_turbo": "Chatterbox Turbo",
"tada": "TADA",
"kokoro": "Kokoro",
}
def _get_qwen_model_configs() -> list[ModelConfig]:
"""Return Qwen model configs with backend-aware HF repo IDs."""
backend_type = get_backend_type()
if backend_type == "mlx":
repo_1_7b = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
repo_0_6b = "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16"
else:
repo_1_7b = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
repo_0_6b = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
return [
ModelConfig(
model_name="qwen-tts-1.7B",
display_name="Qwen TTS 1.7B",
engine="qwen",
hf_repo_id=repo_1_7b,
model_size="1.7B",
size_mb=3500,
supports_instruct=False, # Base model drops instruct silently
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
ModelConfig(
model_name="qwen-tts-0.6B",
display_name="Qwen TTS 0.6B",
engine="qwen",
hf_repo_id=repo_0_6b,
model_size="0.6B",
size_mb=1200,
supports_instruct=False,
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
]
def _get_qwen_custom_voice_configs() -> list[ModelConfig]:
"""Return Qwen CustomVoice model configs."""
return [
ModelConfig(
model_name="qwen-custom-voice-1.7B",
display_name="Qwen CustomVoice 1.7B",
engine="qwen_custom_voice",
hf_repo_id="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
model_size="1.7B",
size_mb=3500,
supports_instruct=True,
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
ModelConfig(
model_name="qwen-custom-voice-0.6B",
display_name="Qwen CustomVoice 0.6B",
engine="qwen_custom_voice",
hf_repo_id="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
model_size="0.6B",
size_mb=1200,
supports_instruct=True,
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
]
def _get_non_qwen_tts_configs() -> list[ModelConfig]:
"""Return model configs for non-Qwen TTS engines.
These are static — no backend-type branching needed.
"""
return [
ModelConfig(
model_name="luxtts",
display_name="LuxTTS (Fast, CPU-friendly)",
engine="luxtts",
hf_repo_id="YatharthS/LuxTTS",
size_mb=300,
languages=["en"],
),
ModelConfig(
model_name="chatterbox-tts",
display_name="Chatterbox TTS (Multilingual)",
engine="chatterbox",
hf_repo_id="ResembleAI/chatterbox",
size_mb=3200,
needs_trim=True,
languages=[
"zh",
"en",
"ja",
"ko",
"de",
"fr",
"ru",
"pt",
"es",
"it",
"he",
"ar",
"da",
"el",
"fi",
"hi",
"ms",
"nl",
"no",
"pl",
"sv",
"sw",
"tr",
],
),
ModelConfig(
model_name="chatterbox-turbo",
display_name="Chatterbox Turbo (English, Tags)",
engine="chatterbox_turbo",
hf_repo_id="ResembleAI/chatterbox-turbo",
size_mb=1500,
needs_trim=True,
languages=["en"],
),
ModelConfig(
model_name="tada-1b",
display_name="TADA 1B (English)",
engine="tada",
hf_repo_id="HumeAI/tada-1b",
model_size="1B",
size_mb=4000,
languages=["en"],
),
ModelConfig(
model_name="tada-3b-ml",
display_name="TADA 3B Multilingual",
engine="tada",
hf_repo_id="HumeAI/tada-3b-ml",
model_size="3B",
size_mb=8000,
languages=["en", "ar", "zh", "de", "es", "fr", "it", "ja", "pl", "pt"],
),
ModelConfig(
model_name="kokoro",
display_name="Kokoro 82M",
engine="kokoro",
hf_repo_id="hexgrad/Kokoro-82M",
size_mb=350,
languages=["en", "es", "fr", "hi", "it", "pt", "ja", "zh"],
),
]
def _get_whisper_configs() -> list[ModelConfig]:
"""Return Whisper STT model configs."""
return [
ModelConfig(
model_name="whisper-base",
display_name="Whisper Base",
engine="whisper",
hf_repo_id="openai/whisper-base",
model_size="base",
),
ModelConfig(
model_name="whisper-small",
display_name="Whisper Small",
engine="whisper",
hf_repo_id="openai/whisper-small",
model_size="small",
),
ModelConfig(
model_name="whisper-medium",
display_name="Whisper Medium",
engine="whisper",
hf_repo_id="openai/whisper-medium",
model_size="medium",
),
ModelConfig(
model_name="whisper-large",
display_name="Whisper Large",
engine="whisper",
hf_repo_id="openai/whisper-large-v3",
model_size="large",
),
ModelConfig(
model_name="whisper-turbo",
display_name="Whisper Turbo",
engine="whisper",
hf_repo_id="openai/whisper-large-v3-turbo",
model_size="turbo",
),
]
def get_all_model_configs() -> list[ModelConfig]:
"""Return the full list of model configs (TTS + STT)."""
return _get_qwen_model_configs() + _get_qwen_custom_voice_configs() + _get_non_qwen_tts_configs() + _get_whisper_configs()
def get_tts_model_configs() -> list[ModelConfig]:
"""Return only TTS model configs."""
return _get_qwen_model_configs() + _get_qwen_custom_voice_configs() + _get_non_qwen_tts_configs()
# Lookup helpers — these replace the if/elif chains in main.py
def get_model_config(model_name: str) -> Optional[ModelConfig]:
"""Look up a model config by model_name."""
for cfg in get_all_model_configs():
if cfg.model_name == model_name:
return cfg
return None
def engine_needs_trim(engine: str) -> bool:
"""Whether this engine's output should be run through trim_tts_output."""
for cfg in get_tts_model_configs():
if cfg.engine == engine:
return cfg.needs_trim
return False
def engine_has_model_sizes(engine: str) -> bool:
"""Whether this engine supports multiple model sizes (only Qwen currently)."""
configs = [c for c in get_tts_model_configs() if c.engine == engine]
return len(configs) > 1
async def load_engine_model(engine: str, model_size: str = "default") -> None:
"""Load a model for the given engine, handling engines with multiple model sizes."""
backend = get_tts_backend_for_engine(engine)
if engine in ("qwen", "qwen_custom_voice"):
await backend.load_model_async(model_size)
elif engine == "tada":
await backend.load_model(model_size)
else:
await backend.load_model()
async def ensure_model_cached_or_raise(engine: str, model_size: str = "default") -> None:
"""Check if a model is cached, raise HTTPException if not. Used by streaming endpoint."""
from fastapi import HTTPException
backend = get_tts_backend_for_engine(engine)
cfg = None
for c in get_tts_model_configs():
if c.engine == engine and c.model_size == model_size:
cfg = c
break
if engine in ("qwen", "qwen_custom_voice", "tada"):
if not backend._is_model_cached(model_size):
raise HTTPException(
status_code=400,
detail=f"Model {model_size} is not downloaded yet. Use /generate to trigger a download.",
)
else:
if not backend._is_model_cached():
display = cfg.display_name if cfg else engine
raise HTTPException(
status_code=400,
detail=f"{display} model is not downloaded yet. Use /generate to trigger a download.",
)
def unload_model_by_config(config: ModelConfig) -> bool:
"""Unload a model given its config. Returns True if it was loaded, False otherwise."""
from . import get_tts_backend_for_engine
from ..services import tts, transcribe
if config.engine == "whisper":
whisper_model = transcribe.get_whisper_model()
if whisper_model.is_loaded() and whisper_model.model_size == config.model_size:
transcribe.unload_whisper_model()
return True
return False
if config.engine == "qwen":
tts_model = tts.get_tts_model()
loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
if tts_model.is_loaded() and loaded_size == config.model_size:
tts.unload_tts_model()
return True
return False
if config.engine == "qwen_custom_voice":
backend = get_tts_backend_for_engine(config.engine)
loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
if backend.is_loaded() and loaded_size == config.model_size:
backend.unload_model()
return True
return False
# All other TTS engines
backend = get_tts_backend_for_engine(config.engine)
if backend.is_loaded():
backend.unload_model()
return True
return False
def check_model_loaded(config: ModelConfig) -> bool:
"""Check if a model is currently loaded."""
from . import get_tts_backend_for_engine
from ..services import tts, transcribe
try:
if config.engine == "whisper":
whisper_model = transcribe.get_whisper_model()
return whisper_model.is_loaded() and getattr(whisper_model, "model_size", None) == config.model_size
if config.engine == "qwen":
tts_model = tts.get_tts_model()
loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
return tts_model.is_loaded() and loaded_size == config.model_size
if config.engine == "qwen_custom_voice":
backend = get_tts_backend_for_engine(config.engine)
loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
return backend.is_loaded() and loaded_size == config.model_size
backend = get_tts_backend_for_engine(config.engine)
return backend.is_loaded()
except Exception:
return False
def get_model_load_func(config: ModelConfig):
"""Return a callable that loads/downloads the model."""
from . import get_tts_backend_for_engine
from ..services import tts, transcribe
if config.engine == "whisper":
return lambda: transcribe.get_whisper_model().load_model(config.model_size)
if config.engine == "qwen":
return lambda: tts.get_tts_model().load_model(config.model_size)
if config.engine == "qwen_custom_voice":
return lambda: get_tts_backend_for_engine(config.engine).load_model(config.model_size)
return lambda: get_tts_backend_for_engine(config.engine).load_model()
def get_tts_backend() -> TTSBackend:
"""
Get or create the default (Qwen) TTS backend instance based on platform.
Returns:
TTS backend instance (MLX or PyTorch)
"""
return get_tts_backend_for_engine("qwen")
def get_tts_backend_for_engine(engine: str) -> TTSBackend:
"""
Get or create a TTS backend for the given engine.
Args:
engine: Engine name (e.g. "qwen", "luxtts", "chatterbox", "chatterbox_turbo")
Returns:
TTS backend instance
"""
global _tts_backends
# Fast path: check without lock
if engine in _tts_backends:
return _tts_backends[engine]
# Slow path: create with lock to avoid duplicate instantiation
with _tts_backends_lock:
# Double-check after acquiring lock
if engine in _tts_backends:
return _tts_backends[engine]
if engine == "qwen":
backend_type = get_backend_type()
if backend_type == "mlx":
from .mlx_backend import MLXTTSBackend
backend = MLXTTSBackend()
else:
from .pytorch_backend import PyTorchTTSBackend
backend = PyTorchTTSBackend()
elif engine == "luxtts":
from .luxtts_backend import LuxTTSBackend
backend = LuxTTSBackend()
elif engine == "chatterbox":
from .chatterbox_backend import ChatterboxTTSBackend
backend = ChatterboxTTSBackend()
elif engine == "chatterbox_turbo":
from .chatterbox_turbo_backend import ChatterboxTurboTTSBackend
backend = ChatterboxTurboTTSBackend()
elif engine == "tada":
from .hume_backend import HumeTadaBackend
backend = HumeTadaBackend()
elif engine == "kokoro":
from .kokoro_backend import KokoroTTSBackend
backend = KokoroTTSBackend()
elif engine == "qwen_custom_voice":
from .qwen_custom_voice_backend import QwenCustomVoiceBackend
backend = QwenCustomVoiceBackend()
else:
raise ValueError(f"Unknown TTS engine: {engine}. Supported: {list(TTS_ENGINES.keys())}")
_tts_backends[engine] = backend
return backend
def get_stt_backend() -> STTBackend:
"""
Get or create STT backend instance based on platform.
Returns:
STT backend instance (MLX or PyTorch)
"""
global _stt_backend
if _stt_backend is None:
backend_type = get_backend_type()
if backend_type == "mlx":
from .mlx_backend import MLXSTTBackend
_stt_backend = MLXSTTBackend()
else:
from .pytorch_backend import PyTorchSTTBackend
_stt_backend = PyTorchSTTBackend()
return _stt_backend
def reset_backends():
"""Reset backend instances (useful for testing)."""
global _tts_backend, _tts_backends, _stt_backend
_tts_backend = None
_tts_backends.clear()
_stt_backend = None

327
backend/backends/base.py Normal file

@@ -0,0 +1,327 @@
"""
Shared utilities for TTS/STT backend implementations.
Eliminates duplication of cache checking, device detection,
voice prompt combination, and model loading progress tracking.
"""
import logging
import platform
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import numpy as np
from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager
logger = logging.getLogger(__name__)
def is_model_cached(
hf_repo: str,
*,
weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
required_files: Optional[list[str]] = None,
) -> bool:
"""
Check if a HuggingFace model is fully cached locally.
Args:
hf_repo: HuggingFace repo ID (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-Base")
weight_extensions: File extensions that count as model weights.
required_files: If set, check that these specific filenames exist
in snapshots instead of checking by extension.
Returns:
True if model is fully cached, False if missing or incomplete.
"""
try:
from huggingface_hub import constants as hf_constants
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))
if not repo_cache.exists():
return False
# Incomplete blobs mean a download is still in progress
blobs_dir = repo_cache / "blobs"
if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
logger.debug(f"Found .incomplete files for {hf_repo}")
return False
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.exists():
return False
if required_files:
# Check that every required filename exists somewhere in snapshots
for fname in required_files:
if not any(snapshots_dir.rglob(fname)):
return False
return True
# Check that at least one weight file exists
for ext in weight_extensions:
if any(snapshots_dir.rglob(f"*{ext}")):
return True
logger.debug(f"No model weights found for {hf_repo}")
return False
except Exception as e:
logger.warning(f"Error checking cache for {hf_repo}: {e}")
return False
def get_torch_device(
*,
allow_xpu: bool = False,
allow_directml: bool = False,
allow_mps: bool = False,
force_cpu_on_mac: bool = False,
) -> str:
"""
Detect the best available torch device.
Args:
allow_xpu: Check for Intel XPU (IPEX) support.
allow_directml: Check for DirectML (Windows) support.
allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU.
force_cpu_on_mac: Force CPU on macOS regardless of GPU availability.
"""
if force_cpu_on_mac and platform.system() == "Darwin":
return "cpu"
import torch
if torch.cuda.is_available():
return "cuda"
if allow_xpu:
try:
import intel_extension_for_pytorch # noqa: F401
if hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
except ImportError:
pass
if allow_directml:
try:
import torch_directml
if torch_directml.device_count() > 0:
return torch_directml.device(0)
except ImportError:
pass
if allow_mps:
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
def check_cuda_compatibility() -> tuple[bool, str | None]:
"""Check if the installed PyTorch supports the current GPU's compute capability.
Returns:
(compatible, warning_message) — compatible is True if OK or no CUDA GPU,
warning_message is a human-readable string if there's a problem.
"""
import torch
if not torch.cuda.is_available():
return True, None
major, minor = torch.cuda.get_device_capability(0)
capability = f"{major}.{minor}"
device_name = torch.cuda.get_device_name(0)
sm_tag = f"sm_{major}{minor}"
# torch.cuda._get_arch_list() returns the SM architectures this build
# was compiled for (e.g. ["sm_50", "sm_60", ..., "sm_90"]).
try:
arch_list = torch.cuda._get_arch_list()
if arch_list:
# Check for both sm_XX and compute_XX (JIT-compiled) entries
compute_tag = f"compute_{major}{minor}"
if sm_tag not in arch_list and compute_tag not in arch_list:
return False, (
f"{device_name} (compute capability {capability} / {sm_tag}) "
f"is not supported by this PyTorch build. "
f"Supported architectures: {', '.join(arch_list)}. "
f"Install PyTorch nightly (cu128) for newer GPU support: "
f"pip install torch --index-url https://download.pytorch.org/whl/nightly/cu128"
)
except AttributeError:
pass
return True, None
def empty_device_cache(device: str) -> None:
"""
Free cached memory on the given device (CUDA or XPU).
Backends should call this after unloading models so VRAM is returned
to the OS.
"""
import torch
if device == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif device == "xpu" and hasattr(torch, "xpu"):
torch.xpu.empty_cache()
def manual_seed(seed: int, device: str) -> None:
"""
Set the random seed on both CPU and the active accelerator.
Covers CUDA and Intel XPU so that generation is reproducible
regardless of which GPU backend is in use.
"""
import torch
torch.manual_seed(seed)
if device == "cuda" and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
elif device == "xpu" and hasattr(torch, "xpu"):
torch.xpu.manual_seed(seed)
async def combine_voice_prompts(
audio_paths: List[str],
reference_texts: List[str],
*,
sample_rate: Optional[int] = None,
) -> Tuple[np.ndarray, str]:
"""
Combine multiple reference audio samples into one.
Loads each audio file, normalizes, concatenates, and joins texts.
Args:
audio_paths: Paths to reference audio files.
reference_texts: Corresponding transcripts.
sample_rate: If set, resample audio to this rate during loading.
"""
combined_audio = []
for path in audio_paths:
kwargs = {"sample_rate": sample_rate} if sample_rate else {}
audio, _sr = load_audio(path, **kwargs)
audio = normalize_audio(audio)
combined_audio.append(audio)
mixed = np.concatenate(combined_audio)
mixed = normalize_audio(mixed)
combined_text = " ".join(reference_texts)
return mixed, combined_text
@contextmanager
def model_load_progress(
model_name: str,
is_cached: bool,
filter_non_downloads: Optional[bool] = None,
):
"""
Context manager for model loading with HF download progress tracking.
Handles the tqdm patching, progress_manager/task_manager lifecycle,
and error reporting that every backend duplicates.
Args:
model_name: Progress tracking key (e.g. "qwen-tts-1.7B", "whisper-base").
is_cached: Whether the model is already downloaded.
filter_non_downloads: Whether to filter non-download tqdm bars.
Defaults to `is_cached`.
Yields:
The tracker context (already entered). The caller loads the model
inside the `with` block. The tqdm patch is torn down on exit.
Usage:
with model_load_progress("qwen-tts-1.7B", is_cached) as ctx:
self.model = SomeModel.from_pretrained(...)
"""
if filter_non_downloads is None:
filter_non_downloads = is_cached
progress_manager = get_progress_manager()
task_manager = get_task_manager()
progress_callback = create_hf_progress_callback(model_name, progress_manager)
tracker = HFProgressTracker(progress_callback, filter_non_downloads=filter_non_downloads)
tracker_context = tracker.patch_download()
tracker_context.__enter__()
if not is_cached:
task_manager.start_download(model_name)
progress_manager.update_progress(
model_name=model_name,
current=0,
total=0,
filename="Connecting to HuggingFace...",
status="downloading",
)
try:
yield tracker_context
except Exception as e:
# Report error to both managers
progress_manager.mark_error(model_name, str(e))
task_manager.error_download(model_name, str(e))
raise
else:
# Only mark complete if we were tracking a download
if not is_cached:
progress_manager.mark_complete(model_name)
task_manager.complete_download(model_name)
finally:
tracker_context.__exit__(None, None, None)
def patch_chatterbox_f32(model) -> None:
"""
Patch float64 -> float32 dtype mismatches in upstream chatterbox.
librosa.load returns float64 numpy arrays. Multiple upstream code paths
convert these to torch tensors via torch.from_numpy() without casting,
then matmul against float32 model weights. This patches the two known
entry points:
1. S3Tokenizer.log_mel_spectrogram — audio tensor hits _mel_filters (f32)
2. VoiceEncoder.forward — float64 mel spectrograms hit LSTM weights (f32)
"""
import types
# Patch S3Tokenizer
_tokzr = model.s3gen.tokenizer
_orig_log_mel = _tokzr.log_mel_spectrogram.__func__
def _f32_log_mel(self_tokzr, audio, padding=0):
import torch as _torch
if _torch.is_tensor(audio):
audio = audio.float()
return _orig_log_mel(self_tokzr, audio, padding)
_tokzr.log_mel_spectrogram = types.MethodType(_f32_log_mel, _tokzr)
# Patch VoiceEncoder
_ve = model.ve
_orig_ve_forward = _ve.forward.__func__
def _f32_ve_forward(self_ve, mels):
return _orig_ve_forward(self_ve, mels.float())
_ve.forward = types.MethodType(_f32_ve_forward, _ve)

226
backend/backends/chatterbox_backend.py Normal file

@@ -0,0 +1,226 @@
"""
Chatterbox TTS backend implementation.
Wraps ChatterboxMultilingualTTS from chatterbox-tts for zero-shot
voice cloning. Supports 23 languages including Hebrew. Forces CPU
on macOS due to known MPS tensor issues.
"""
import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
patch_chatterbox_f32,
)
logger = logging.getLogger(__name__)
CHATTERBOX_HF_REPO = "ResembleAI/chatterbox"
# Files that must be present for the multilingual model
_MTL_WEIGHT_FILES = [
"t3_mtl23ls_v2.safetensors",
"s3gen.pt",
"ve.pt",
]
class ChatterboxTTSBackend:
"""Chatterbox Multilingual TTS backend for voice cloning."""
# Class-level lock for torch.load monkey-patching
_load_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self):
self.model = None
self.model_size = "default"
self._device = None
self._model_load_lock = asyncio.Lock()
def _get_device(self) -> str:
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str = "default") -> str:
return CHATTERBOX_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
return is_model_cached(CHATTERBOX_HF_REPO, required_files=_MTL_WEIGHT_FILES)
async def load_model(self, model_size: str = "default") -> None:
"""Load the Chatterbox multilingual model."""
if self.model is not None:
return
async with self._model_load_lock:
if self.model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
"""Synchronous model loading."""
model_name = "chatterbox-tts"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
device = self._get_device()
self._device = device
logger.info(f"Loading Chatterbox Multilingual TTS on {device}...")
import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
if device == "cpu":
_orig_torch_load = torch.load
def _patched_load(*args, **kwargs):
kwargs.setdefault("map_location", "cpu")
return _orig_torch_load(*args, **kwargs)
with ChatterboxTTSBackend._load_lock:
torch.load = _patched_load
try:
model = ChatterboxMultilingualTTS.from_pretrained(device=device)
finally:
torch.load = _orig_torch_load
else:
model = ChatterboxMultilingualTTS.from_pretrained(device=device)
# Fix sdpa attention for output_attentions support
t3_tfmr = model.t3.tfmr
if hasattr(t3_tfmr, "config") and hasattr(t3_tfmr.config, "_attn_implementation"):
t3_tfmr.config._attn_implementation = "eager"
for layer in getattr(t3_tfmr, "layers", []):
if hasattr(layer, "self_attn"):
layer.self_attn._attn_implementation = "eager"
patch_chatterbox_f32(model)
self.model = model
logger.info("Chatterbox Multilingual TTS loaded successfully")
def unload_model(self) -> None:
"""Unload model to free memory."""
if self.model is not None:
device = self._device
del self.model
self.model = None
self._device = None
empty_device_cache(device)
logger.info("Chatterbox unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Chatterbox processes reference audio at generation time, so the
prompt just stores the file path. The actual audio is loaded by
model.generate() via audio_prompt_path.
"""
voice_prompt = {
"ref_audio": str(audio_path),
"ref_text": reference_text,
}
return voice_prompt, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
# Per-language generation defaults. Lower temp + higher cfg = clearer speech.
_LANG_DEFAULTS: ClassVar[dict] = {
"he": {
"exaggeration": 0.4,
"cfg_weight": 0.7,
"temperature": 0.65,
"repetition_penalty": 2.5,
},
}
_GLOBAL_DEFAULTS: ClassVar[dict] = {
"exaggeration": 0.5,
"cfg_weight": 0.5,
"temperature": 0.8,
"repetition_penalty": 2.0,
}
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio using Chatterbox Multilingual TTS.
Args:
text: Text to synthesize
voice_prompt: Dict with ref_audio path
language: BCP-47 language code
seed: Random seed for reproducibility
instruct: Unused (protocol compatibility)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
ref_audio = voice_prompt.get("ref_audio")
if ref_audio and not Path(ref_audio).exists():
logger.warning(f"Reference audio not found: {ref_audio}")
ref_audio = None
# Merge language-specific defaults with global defaults
lang_defaults = self._LANG_DEFAULTS.get(language, self._GLOBAL_DEFAULTS)
def _generate_sync():
import torch
if seed is not None:
manual_seed(seed, self._device)
logger.info(f"[Chatterbox] Generating: lang={language}")
wav = self.model.generate(
text,
language_id=language,
audio_prompt_path=ref_audio,
exaggeration=lang_defaults["exaggeration"],
cfg_weight=lang_defaults["cfg_weight"],
temperature=lang_defaults["temperature"],
repetition_penalty=lang_defaults["repetition_penalty"],
)
# Convert tensor -> numpy
if isinstance(wav, torch.Tensor):
audio = wav.squeeze().cpu().numpy().astype(np.float32)
else:
audio = np.asarray(wav, dtype=np.float32)
sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)
return audio, sample_rate
return await asyncio.to_thread(_generate_sync)
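# Illustrative usage sketch: a minimal, assumed example of driving this backend
# from an asyncio caller (the actual orchestration lives in the services layer).
# Paths and texts below are placeholders.
async def _example_chatterbox_usage() -> None:
    backend = ChatterboxTTSBackend()
    # The prompt only records the reference path; audio is read inside generate().
    prompt, _was_cached = await backend.create_voice_prompt("reference.wav", "Sample transcript.")
    # Hebrew ("he") picks up the stricter _LANG_DEFAULTS entry; any other
    # language falls back to _GLOBAL_DEFAULTS.
    audio, sample_rate = await backend.generate("Example text.", prompt, language="he", seed=42)
    backend.unload_model()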

View File

@@ -0,0 +1,206 @@
"""
Chatterbox Turbo TTS backend implementation.
Wraps ChatterboxTurboTTS from chatterbox-tts for fast, English-only
voice cloning with paralinguistic tag support ([laugh], [cough], etc.).
Forces CPU on macOS due to known MPS tensor issues.
"""
import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
patch_chatterbox_f32,
)
logger = logging.getLogger(__name__)
CHATTERBOX_TURBO_HF_REPO = "ResembleAI/chatterbox-turbo"
# Files that must be present for the turbo model
_TURBO_WEIGHT_FILES = [
"t3_turbo_v1.safetensors",
"s3gen_meanflow.safetensors",
"ve.safetensors",
]
class ChatterboxTurboTTSBackend:
"""Chatterbox Turbo TTS backend — fast, English-only, with paralinguistic tags."""
# Class-level lock for torch.load monkey-patching
_load_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self):
self.model = None
self.model_size = "default"
self._device = None
self._model_load_lock = asyncio.Lock()
def _get_device(self) -> str:
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str = "default") -> str:
return CHATTERBOX_TURBO_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
return is_model_cached(CHATTERBOX_TURBO_HF_REPO, required_files=_TURBO_WEIGHT_FILES)
async def load_model(self, model_size: str = "default") -> None:
"""Load the Chatterbox Turbo model."""
if self.model is not None:
return
async with self._model_load_lock:
if self.model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
"""Synchronous model loading."""
model_name = "chatterbox-turbo"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
device = self._get_device()
self._device = device
logger.info(f"Loading Chatterbox Turbo TTS on {device}...")
import torch
from huggingface_hub import snapshot_download
from chatterbox.tts_turbo import ChatterboxTurboTTS
local_path = snapshot_download(
repo_id=CHATTERBOX_TURBO_HF_REPO,
token=None,
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
)
if device == "cpu":
_orig_torch_load = torch.load
def _patched_load(*args, **kwargs):
kwargs.setdefault("map_location", "cpu")
return _orig_torch_load(*args, **kwargs)
with ChatterboxTurboTTSBackend._load_lock:
torch.load = _patched_load
try:
model = ChatterboxTurboTTS.from_local(local_path, device)
finally:
torch.load = _orig_torch_load
else:
model = ChatterboxTurboTTS.from_local(local_path, device)
patch_chatterbox_f32(model)
self.model = model
logger.info("Chatterbox Turbo TTS loaded successfully")
def unload_model(self) -> None:
"""Unload model to free memory."""
if self.model is not None:
device = self._device
del self.model
self.model = None
self._device = None
empty_device_cache(device)
logger.info("Chatterbox Turbo unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Chatterbox Turbo processes reference audio at generation time, so the
prompt just stores the file path.
"""
voice_prompt = {
"ref_audio": str(audio_path),
"ref_text": reference_text,
}
return voice_prompt, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio using Chatterbox Turbo TTS.
Supports paralinguistic tags in text: [laugh], [cough], [chuckle], etc.
Args:
text: Text to synthesize (may include paralinguistic tags)
voice_prompt: Dict with ref_audio path
language: Ignored (Turbo is English-only)
seed: Random seed for reproducibility
instruct: Unused (protocol compatibility)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
ref_audio = voice_prompt.get("ref_audio")
if ref_audio and not Path(ref_audio).exists():
logger.warning(f"Reference audio not found: {ref_audio}")
ref_audio = None
def _generate_sync():
import torch
if seed is not None:
manual_seed(seed, self._device)
logger.info("[Chatterbox Turbo] Generating (English)")
wav = self.model.generate(
text,
audio_prompt_path=ref_audio,
temperature=0.8,
top_k=1000,
top_p=0.95,
repetition_penalty=1.2,
)
# Convert tensor -> numpy
if isinstance(wav, torch.Tensor):
audio = wav.squeeze().cpu().numpy().astype(np.float32)
else:
audio = np.asarray(wav, dtype=np.float32)
sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)
return audio, sample_rate
return await asyncio.to_thread(_generate_sync)
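# Illustrative usage sketch: a minimal, assumed example of driving this backend
# from an asyncio caller; paths and texts are placeholders.
async def _example_chatterbox_turbo_usage() -> None:
    backend = ChatterboxTurboTTSBackend()
    prompt, _was_cached = await backend.create_voice_prompt("reference.wav", "Sample transcript.")
    # Paralinguistic tags ride along inline with the text; the language argument
    # is ignored because Turbo is English-only.
    audio, sample_rate = await backend.generate("That was close [laugh], let's try again.", prompt)
    backend.unload_model()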

View File

@@ -0,0 +1,346 @@
"""
HumeAI TADA TTS backend implementation.
Wraps HumeAI's TADA (Text-Acoustic Dual Alignment) model for
high-quality voice cloning. Two model variants:
- tada-1b: English-only, ~2B params (Llama 3.2 1B base)
- tada-3b-ml: Multilingual, ~4B params (Llama 3.2 3B base)
Both use a shared encoder/codec (HumeAI/tada-codec). The encoder
produces 1:1 aligned token embeddings from reference audio, and the
causal LM generates speech via flow-matching diffusion.
24kHz output, bf16 inference on CUDA, fp32 on CPU.
"""
import asyncio
import logging
import threading
from typing import ClassVar, List, Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
logger = logging.getLogger(__name__)
# HuggingFace repos
TADA_CODEC_REPO = "HumeAI/tada-codec"
TADA_1B_REPO = "HumeAI/tada-1b"
TADA_3B_ML_REPO = "HumeAI/tada-3b-ml"
TADA_MODEL_REPOS = {
"1B": TADA_1B_REPO,
"3B": TADA_3B_ML_REPO,
}
# Key weight files for cache detection
_TADA_MODEL_WEIGHT_FILES = [
"model.safetensors",
]
_TADA_CODEC_WEIGHT_FILES = [
"encoder/model.safetensors",
]
class HumeTadaBackend:
"""HumeAI TADA TTS backend for high-quality voice cloning."""
_load_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self):
self.model = None
self.encoder = None
self.model_size = "1B" # default to 1B
self._device = None
self._model_load_lock = asyncio.Lock()
def _get_device(self) -> str:
# Force CPU on macOS — MPS has issues with flow matching
# and large vocab lm_head (>65536 output channels)
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str = "1B") -> str:
return TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
def _is_model_cached(self, model_size: str = "1B") -> bool:
repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
model_cached = is_model_cached(repo, required_files=_TADA_MODEL_WEIGHT_FILES)
codec_cached = is_model_cached(TADA_CODEC_REPO, required_files=_TADA_CODEC_WEIGHT_FILES)
return model_cached and codec_cached
async def load_model(self, model_size: str = "1B") -> None:
"""Load the TADA model and encoder."""
if self.model is not None and self.model_size == model_size:
return
async with self._model_load_lock:
if self.model is not None and self.model_size == model_size:
return
# Unload existing model if switching sizes
if self.model is not None:
self.unload_model()
self.model_size = model_size
await asyncio.to_thread(self._load_model_sync, model_size)
def _load_model_sync(self, model_size: str = "1B"):
"""Synchronous model loading with progress tracking."""
model_name = f"tada-{model_size.lower()}"
is_cached = self._is_model_cached(model_size)
repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
with model_load_progress(model_name, is_cached):
# Install DAC shim before importing tada — tada's encoder/decoder
# import dac.nn.layers.Snake1d which requires the descript-audio-codec
# package. The real package pulls in onnx/tensorboard/matplotlib via
# descript-audiotools, so we use a lightweight shim instead.
from ..utils.dac_shim import install_dac_shim
install_dac_shim()
import torch
from huggingface_hub import snapshot_download
device = self._get_device()
self._device = device
logger.info(f"Loading HumeAI TADA {model_size} on {device}...")
# Download codec (encoder + decoder) if not cached
logger.info("Downloading TADA codec...")
snapshot_download(
repo_id=TADA_CODEC_REPO,
token=None,
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin"],
)
# Download model weights if not cached
logger.info(f"Downloading TADA {model_size} model...")
snapshot_download(
repo_id=repo,
token=None,
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin", "*.model"],
)
# TADA hardcodes "meta-llama/Llama-3.2-1B" as the tokenizer
# source in its Aligner and TadaForCausalLM.from_pretrained().
# That repo is gated (requires Meta license acceptance).
# Download the tokenizer from an ungated mirror and get its
# local cache path so we can point TADA at it directly.
logger.info("Downloading Llama tokenizer (ungated mirror)...")
tokenizer_path = snapshot_download(
repo_id="unsloth/Llama-3.2-1B",
token=None,
allow_patterns=["tokenizer*", "special_tokens*"],
)
# Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings
if device == "cuda" and torch.cuda.is_bf16_supported():
model_dtype = torch.bfloat16
elif device == "xpu":
# Intel Arc (Alchemist+) supports bf16 natively
model_dtype = torch.bfloat16
else:
model_dtype = torch.float32
# Patch the Aligner config class to use the local tokenizer
# path instead of the gated "meta-llama/Llama-3.2-1B" default.
# This avoids monkey-patching AutoTokenizer.from_pretrained
# which corrupts the classmethod descriptor for other engines.
from tada.modules.aligner import AlignerConfig
AlignerConfig.tokenizer_name = tokenizer_path
# Load encoder (only needed for voice prompt encoding)
from tada.modules.encoder import Encoder
logger.info("Loading TADA encoder...")
self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device)
self.encoder.eval()
# Load the causal LM (includes decoder for wav generation).
# TadaForCausalLM.from_pretrained() calls
# getattr(config, "tokenizer_name", "meta-llama/Llama-3.2-1B")
# which hits the gated repo. Pre-load the config from HF,
# inject the local tokenizer path, then pass it in.
from tada.modules.tada import TadaForCausalLM, TadaConfig
logger.info(f"Loading TADA {model_size} model...")
config = TadaConfig.from_pretrained(repo)
config.tokenizer_name = tokenizer_path
self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device)
self.model.eval()
logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}")
def unload_model(self) -> None:
"""Unload model and encoder to free memory."""
if self.model is not None:
del self.model
self.model = None
if self.encoder is not None:
del self.encoder
self.encoder = None
device = self._device
self._device = None
if device:
empty_device_cache(device)
logger.info("HumeAI TADA unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio using TADA's encoder.
TADA's encoder performs forced alignment between audio and text tokens,
producing an EncoderOutput with 1:1 token-audio alignment. If no
reference_text is provided, the encoder uses built-in ASR (English only).
We serialize the EncoderOutput to a dict for caching.
"""
await self.load_model(self.model_size)
cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None
if cache_key:
cached = get_cached_voice_prompt(cache_key)
if cached is not None and isinstance(cached, dict):
return cached, True
def _encode_sync():
import torch
import soundfile as sf
device = self._device
# Load audio with soundfile (torchaudio 2.10+ requires torchcodec)
audio_np, sr = sf.read(str(audio_path), dtype="float32")
audio = torch.from_numpy(audio_np).float()
if audio.ndim == 1:
audio = audio.unsqueeze(0) # (samples,) -> (1, samples)
else:
audio = audio.T # (samples, channels) -> (channels, samples)
audio = audio.to(device)
# Encode with forced alignment
text_arg = [reference_text] if reference_text else None
prompt = self.encoder(audio, text=text_arg, sample_rate=sr)
# Serialize EncoderOutput to a dict of CPU tensors for caching
prompt_dict = {}
for field_name in prompt.__dataclass_fields__:
val = getattr(prompt, field_name)
if isinstance(val, torch.Tensor):
prompt_dict[field_name] = val.detach().cpu()
elif isinstance(val, list):
prompt_dict[field_name] = val
elif isinstance(val, (int, float)):
prompt_dict[field_name] = val
else:
prompt_dict[field_name] = val
return prompt_dict
encoded = await asyncio.to_thread(_encode_sync)
if cache_key:
cache_voice_prompt(cache_key, encoded)
return encoded, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using HumeAI TADA.
Args:
text: Text to synthesize
voice_prompt: Serialized EncoderOutput dict from create_voice_prompt()
language: Language code (en, ar, de, es, fr, it, ja, pl, pt, zh)
seed: Random seed for reproducibility
instruct: Not supported by TADA (ignored)
Returns:
Tuple of (audio_array, sample_rate=24000)
"""
await self.load_model(self.model_size)
def _generate_sync():
import torch
from tada.modules.encoder import EncoderOutput
if seed is not None:
manual_seed(seed, self._device)
device = self._device
# Reconstruct EncoderOutput from the cached dict
restored = {}
for k, v in voice_prompt.items():
if isinstance(v, torch.Tensor):
# Move to device and match model dtype for float tensors
if v.is_floating_point():
model_dtype = next(self.model.parameters()).dtype
restored[k] = v.to(device=device, dtype=model_dtype)
else:
restored[k] = v.to(device=device)
else:
restored[k] = v
prompt = EncoderOutput(**restored)
# For non-English with the 3B-ML model, we could reload the
# encoder with the language-specific aligner. However, the
# generation itself is language-agnostic — only the encoder's
# aligner changes. Since we encode at create_voice_prompt time,
# the language is already baked in. For simplicity, we don't
# reload the encoder here.
logger.info(f"[TADA] Generating ({language}), text length: {len(text)}")
output = self.model.generate(
prompt=prompt,
text=text,
)
# output.audio is a list of tensors (one per batch item)
if output.audio and output.audio[0] is not None:
audio_tensor = output.audio[0]
audio = audio_tensor.detach().cpu().numpy().squeeze().astype(np.float32)
else:
logger.warning("[TADA] Generation produced no audio")
audio = np.zeros(24000, dtype=np.float32)
return audio, 24000
return await asyncio.to_thread(_generate_sync)
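# Illustrative usage sketch: a minimal, assumed example. The voice prompt is an
# EncoderOutput serialized to plain CPU tensors, so it survives caching and is
# moved back to the model's device (and dtype) inside generate().
async def _example_tada_usage() -> None:
    backend = HumeTadaBackend()
    await backend.load_model("1B")  # "3B" selects the multilingual variant
    prompt, _was_cached = await backend.create_voice_prompt("reference.wav", "Sample transcript.")
    audio, sample_rate = await backend.generate("Example text.", prompt, language="en", seed=7)
    assert sample_rate == 24000  # TADA always emits 24 kHz audio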

View File

@@ -0,0 +1,288 @@
"""
Kokoro TTS backend implementation.
Wraps the Kokoro-82M model for fast, lightweight text-to-speech.
82M parameters, CPU realtime, 24kHz output, Apache 2.0 license.
Kokoro uses pre-built voice style vectors (not traditional zero-shot cloning
from arbitrary audio). Voice prompts are stored as deferred references to
HF-hosted voice .pt files.
Languages supported (via misaki G2P):
- American English (a), British English (b)
- Spanish (e), French (f), Hindi (h), Italian (i), Portuguese (p)
- Japanese (j) — requires misaki[ja]
- Chinese (z) — requires misaki[zh]
"""
import asyncio
import logging
import os
from typing import Optional
import numpy as np
from . import TTSBackend
from .base import (
get_torch_device,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
logger = logging.getLogger(__name__)
# HuggingFace repo for model + voice detection
KOKORO_HF_REPO = "hexgrad/Kokoro-82M"
KOKORO_SAMPLE_RATE = 24000
# Default voice if none specified
KOKORO_DEFAULT_VOICE = "af_heart"
# All available Kokoro voices: (voice_id, display_name, gender, lang_code)
KOKORO_VOICES = [
# American English female
("af_alloy", "Alloy", "female", "en"),
("af_aoede", "Aoede", "female", "en"),
("af_bella", "Bella", "female", "en"),
("af_heart", "Heart", "female", "en"),
("af_jessica", "Jessica", "female", "en"),
("af_kore", "Kore", "female", "en"),
("af_nicole", "Nicole", "female", "en"),
("af_nova", "Nova", "female", "en"),
("af_river", "River", "female", "en"),
("af_sarah", "Sarah", "female", "en"),
("af_sky", "Sky", "female", "en"),
# American English male
("am_adam", "Adam", "male", "en"),
("am_echo", "Echo", "male", "en"),
("am_eric", "Eric", "male", "en"),
("am_fenrir", "Fenrir", "male", "en"),
("am_liam", "Liam", "male", "en"),
("am_michael", "Michael", "male", "en"),
("am_onyx", "Onyx", "male", "en"),
("am_puck", "Puck", "male", "en"),
("am_santa", "Santa", "male", "en"),
# British English female
("bf_alice", "Alice", "female", "en"),
("bf_emma", "Emma", "female", "en"),
("bf_isabella", "Isabella", "female", "en"),
("bf_lily", "Lily", "female", "en"),
# British English male
("bm_daniel", "Daniel", "male", "en"),
("bm_fable", "Fable", "male", "en"),
("bm_george", "George", "male", "en"),
("bm_lewis", "Lewis", "male", "en"),
# Spanish
("ef_dora", "Dora", "female", "es"),
("em_alex", "Alex", "male", "es"),
("em_santa", "Santa", "male", "es"),
# French
("ff_siwis", "Siwis", "female", "fr"),
# Hindi
("hf_alpha", "Alpha", "female", "hi"),
("hf_beta", "Beta", "female", "hi"),
("hm_omega", "Omega", "male", "hi"),
("hm_psi", "Psi", "male", "hi"),
# Italian
("if_sara", "Sara", "female", "it"),
("im_nicola", "Nicola", "male", "it"),
# Japanese
("jf_alpha", "Alpha", "female", "ja"),
("jf_gongitsune", "Gongitsune", "female", "ja"),
("jf_nezumi", "Nezumi", "female", "ja"),
("jf_tebukuro", "Tebukuro", "female", "ja"),
("jm_kumo", "Kumo", "male", "ja"),
# Portuguese
("pf_dora", "Dora", "female", "pt"),
("pm_alex", "Alex", "male", "pt"),
("pm_santa", "Santa", "male", "pt"),
# Chinese
("zf_xiaobei", "Xiaobei", "female", "zh"),
("zf_xiaoni", "Xiaoni", "female", "zh"),
("zf_xiaoxiao", "Xiaoxiao", "female", "zh"),
("zf_xiaoyi", "Xiaoyi", "female", "zh"),
]
# Map our ISO language codes to Kokoro lang_code characters
LANG_CODE_MAP = {
"en": "a", # American English
"es": "e",
"fr": "f",
"hi": "h",
"it": "i",
"pt": "p",
"ja": "j",
"zh": "z",
}
class KokoroTTSBackend:
"""Kokoro-82M TTS backend — tiny, fast, CPU-friendly."""
def __init__(self):
self._model = None
self._pipelines: dict = {} # lang_code -> KPipeline
self._device: Optional[str] = None
self.model_size = "default"
def _get_device(self) -> str:
"""Select device. Kokoro supports CUDA and CPU. MPS needs fallback env var."""
device = get_torch_device(allow_mps=False)
# Kokoro can use MPS but requires PYTORCH_ENABLE_MPS_FALLBACK=1
# For now, skip MPS to avoid user confusion — CPU is already realtime
return device
@property
def device(self) -> str:
if self._device is None:
self._device = self._get_device()
return self._device
def is_loaded(self) -> bool:
return self._model is not None
def _get_model_path(self, model_size: str) -> str:
return KOKORO_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
"""Check if Kokoro model files are cached locally."""
from .base import is_model_cached
return is_model_cached(
KOKORO_HF_REPO,
required_files=["config.json", "kokoro-v1_0.pth"],
)
async def load_model(self, model_size: str = "default") -> None:
"""Load the Kokoro model."""
if self._model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
"""Synchronous model loading."""
model_name = "kokoro"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
from kokoro import KModel
device = self.device
logger.info(f"Loading Kokoro-82M on {device}...")
self._model = KModel(repo_id=KOKORO_HF_REPO).to(device).eval()
logger.info("Kokoro-82M loaded successfully")
def _get_pipeline(self, lang_code: str):
"""Get or create a KPipeline for the given language code."""
kokoro_lang = LANG_CODE_MAP.get(lang_code, "a")
if kokoro_lang not in self._pipelines:
from kokoro import KPipeline
# Create pipeline with our existing model (no redundant model loading)
self._pipelines[kokoro_lang] = KPipeline(
lang_code=kokoro_lang,
repo_id=KOKORO_HF_REPO,
model=self._model,
)
return self._pipelines[kokoro_lang]
def unload_model(self) -> None:
"""Unload model to free memory."""
if self._model is not None:
del self._model
self._model = None
self._pipelines.clear()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Kokoro unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> tuple[dict, bool]:
"""
Create voice prompt for Kokoro.
Kokoro doesn't do traditional voice cloning from arbitrary audio.
When called for a cloned profile (the fallback path), it uses the default voice.
For preset profiles, the voice_prompt dict is built by the profile
service and bypasses this method entirely.
"""
return {
"voice_type": "preset",
"preset_engine": "kokoro",
"preset_voice_id": KOKORO_DEFAULT_VOICE,
}, False
async def combine_voice_prompts(
self,
audio_paths: list[str],
reference_texts: list[str],
) -> tuple[np.ndarray, str]:
"""Combine voice prompts — uses base implementation for audio concatenation."""
return await _combine_voice_prompts(
audio_paths, reference_texts, sample_rate=KOKORO_SAMPLE_RATE
)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> tuple[np.ndarray, int]:
"""
Generate audio from text using Kokoro.
Args:
text: Text to synthesize
voice_prompt: Dict with kokoro_voice key
language: Language code
seed: Random seed for reproducibility
instruct: Not supported by Kokoro (ignored)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
voice_name = voice_prompt.get("preset_voice_id") or voice_prompt.get("kokoro_voice") or KOKORO_DEFAULT_VOICE
def _generate_sync():
import torch
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
pipeline = self._get_pipeline(language)
# Generate all chunks and concatenate
audio_chunks = []
for result in pipeline(text, voice=voice_name, speed=1.0):
if result.audio is not None:
chunk = result.audio
if isinstance(chunk, torch.Tensor):
chunk = chunk.detach().cpu().numpy()
audio_chunks.append(chunk.squeeze())
if not audio_chunks:
# Return 1 second of silence as fallback
return np.zeros(KOKORO_SAMPLE_RATE, dtype=np.float32), KOKORO_SAMPLE_RATE
audio = np.concatenate(audio_chunks)
return audio.astype(np.float32), KOKORO_SAMPLE_RATE
return await asyncio.to_thread(_generate_sync)
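# Illustrative usage sketch: a minimal, assumed example. Preset profiles normally
# arrive with a voice_prompt built by the profile service; the dict below mirrors
# the shape create_voice_prompt() returns on the fallback path.
async def _example_kokoro_usage() -> None:
    backend = KokoroTTSBackend()
    voice_prompt = {
        "voice_type": "preset",
        "preset_engine": "kokoro",
        "preset_voice_id": "bf_emma",  # any voice_id from KOKORO_VOICES
    }
    audio, sample_rate = await backend.generate("Example text.", voice_prompt, language="en")
    assert sample_rate == KOKORO_SAMPLE_RATE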

View File

@@ -0,0 +1,184 @@
"""
LuxTTS backend implementation.
Wraps the LuxTTS (ZipVoice) model for zero-shot voice cloning.
~1GB VRAM, 48kHz output, 150x realtime on CPU.
"""
import asyncio
import logging
from typing import Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
logger = logging.getLogger(__name__)
# HuggingFace repo for model weight detection
LUXTTS_HF_REPO = "YatharthS/LuxTTS"
class LuxTTSBackend:
"""LuxTTS backend for zero-shot voice cloning."""
def __init__(self):
self.model = None
self.model_size = "default" # LuxTTS has only one model size
self._device = None
def _get_device(self) -> str:
return get_torch_device(allow_mps=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
@property
def device(self) -> str:
if self._device is None:
self._device = self._get_device()
return self._device
def _get_model_path(self, model_size: str) -> str:
return LUXTTS_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
return is_model_cached(
LUXTTS_HF_REPO,
weight_extensions=(".pt", ".safetensors", ".onnx", ".bin"),
)
async def load_model(self, model_size: str = "default") -> None:
"""Load the LuxTTS model."""
if self.model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
model_name = "luxtts"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
from zipvoice.luxvoice import LuxTTS
device = self.device
logger.info(f"Loading LuxTTS on {device}...")
if device == "cpu":
import os
threads = os.cpu_count() or 4
self.model = LuxTTS(
model_path=LUXTTS_HF_REPO,
device="cpu",
threads=min(threads, 8),
)
else:
self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device)
logger.info("LuxTTS loaded successfully")
def unload_model(self) -> None:
"""Unload model to free memory."""
if self.model is not None:
device = self.device
del self.model
self.model = None
self._device = None
empty_device_cache(device)
logger.info("LuxTTS unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
LuxTTS uses its own encode_prompt() which runs Whisper ASR internally
to transcribe the reference. The reference_text parameter is not used
by LuxTTS itself, but we include it in the cache key for consistency.
"""
await self.load_model()
# Compute cache key once for both lookup and storage
cache_key = ("luxtts_" + get_cache_key(audio_path, reference_text)) if use_cache else None
if cache_key:
cached = get_cached_voice_prompt(cache_key)
if cached is not None and isinstance(cached, dict):
return cached, True
def _encode_sync():
return self.model.encode_prompt(
prompt_audio=str(audio_path),
duration=5,
rms=0.01,
)
encoded = await asyncio.to_thread(_encode_sync)
if cache_key:
cache_voice_prompt(cache_key, encoded)
return encoded, False
async def combine_voice_prompts(self, audio_paths, reference_texts):
return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using LuxTTS.
Args:
text: Text to synthesize
voice_prompt: Encoded prompt dict from encode_prompt()
language: Language code (LuxTTS is English-focused)
seed: Random seed for reproducibility
instruct: Not supported by LuxTTS (ignored)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
def _generate_sync():
if seed is not None:
manual_seed(seed, self.device)
wav = self.model.generate_speech(
text=text,
encode_dict=voice_prompt,
num_steps=4,
guidance_scale=3.0,
t_shift=0.5,
speed=1.0,
return_smooth=False, # 48kHz output
)
# LuxTTS returns a tensor (may be on GPU/MPS), move to CPU first
audio = wav.detach().cpu().numpy().squeeze()
return audio, 48000
return await asyncio.to_thread(_generate_sync)
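# Illustrative usage sketch: a minimal, assumed example. LuxTTS transcribes the
# reference audio itself, so reference_text only influences the cache key.
async def _example_luxtts_usage() -> None:
    backend = LuxTTSBackend()
    prompt, _was_cached = await backend.create_voice_prompt("reference.wav", "Sample transcript.")
    audio, sample_rate = await backend.generate("Example text.", prompt, seed=1)
    assert sample_rate == 48000  # return_smooth=False yields 48 kHz output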

View File

@@ -0,0 +1,367 @@
"""
MLX backend implementation for TTS and STT using mlx-audio.
"""
from typing import Optional, List, Tuple
import asyncio
import logging
import numpy as np
from pathlib import Path
logger = logging.getLogger(__name__)
# PATCH: Import and apply offline patch BEFORE any huggingface_hub usage
# This prevents mlx_audio from making network requests when models are cached
from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached
patch_huggingface_hub_offline()
ensure_original_qwen_config_cached()
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
class MLXTTSBackend:
"""MLX-based TTS backend using mlx-audio."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self._current_model_size = None
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
"""
Get the MLX model path.
Args:
model_size: Model size (1.7B or 0.6B)
Returns:
HuggingFace Hub model ID for MLX
"""
mlx_model_map = {
"1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
"0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16",
}
if model_size not in mlx_model_map:
raise ValueError(f"Unknown model size: {model_size}")
hf_model_id = mlx_model_map[model_size]
logger.info("Will download MLX model from HuggingFace Hub: %s", hf_model_id)
return hf_model_id
def _is_model_cached(self, model_size: str) -> bool:
return is_model_cached(
self._get_model_path(model_size),
weight_extensions=(".safetensors", ".bin", ".npz"),
)
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the MLX TTS model.
Args:
model_size: Model size to load (1.7B or 0.6B)
"""
if model_size is None:
model_size = self.model_size
# If already loaded with correct size, return
if self.model is not None and self._current_model_size == model_size:
return
# Unload existing model if different size requested
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
model_path = self._get_model_path(model_size)
model_name = f"qwen-tts-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(model_name, is_cached):
from mlx_audio.tts import load
logger.info("Loading MLX TTS model %s...", model_size)
self.model = load(model_path)
self._current_model_size = model_size
self.model_size = model_size
logger.info("MLX TTS model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
self.model = None
self._current_model_size = None
logger.info("MLX TTS model unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
MLX backend stores voice prompt as a dict with audio path and text.
The actual voice prompt processing happens during generation.
Args:
audio_path: Path to reference audio file
reference_text: Transcript of reference audio
use_cache: Whether to use cached prompt if available
Returns:
Tuple of (voice_prompt_dict, was_cached)
"""
await self.load_model_async(None)
# Check cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cached_prompt = get_cached_voice_prompt(cache_key)
if cached_prompt is not None:
# Return cached prompt (should be dict format)
if isinstance(cached_prompt, dict):
# Validate that the cached audio file still exists
cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
if cached_audio_path and Path(cached_audio_path).exists():
return cached_prompt, True
else:
# Cached file no longer exists, invalidate cache
logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)
# MLX voice prompt format - store audio path and text
# The model will process this during generation
voice_prompt_items = {
"ref_audio": str(audio_path),
"ref_text": reference_text,
}
# Cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cache_voice_prompt(cache_key, voice_prompt_items)
return voice_prompt_items, False
async def combine_voice_prompts(self, audio_paths, reference_texts):
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using voice prompt.
Args:
text: Text to synthesize
voice_prompt: Voice prompt dictionary with ref_audio and ref_text
language: Language code (en or zh) - may not be fully supported by MLX
seed: Random seed for reproducibility
instruct: Natural language instruction (may not be supported by MLX)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model_async(None)
logger.info("Generating audio for text: %s", text)
def _generate_sync():
"""Run synchronous generation in thread pool."""
# MLX generate() returns a generator yielding GenerationResult objects
audio_chunks = []
sample_rate = 24000
lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")
# Set seed if provided (MLX uses numpy random)
if seed is not None:
import mlx.core as mx
np.random.seed(seed)
mx.random.seed(seed)
# Extract voice prompt info
ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
ref_text = voice_prompt.get("ref_text", "")
# Validate that the audio file exists
if ref_audio and not Path(ref_audio).exists():
logger.warning("Audio file not found: %s", ref_audio)
logger.warning("This may be due to a cached voice prompt referencing a deleted temp file.")
logger.warning("Regenerating without voice prompt.")
ref_audio = None
# Inference runs with the process's default HF_HUB_OFFLINE
# state. Forcing offline here (previously used to avoid lazy
# mlx_audio lookups hanging when the network drops mid-inference,
# issue #462) regressed online users because libraries make
# legitimate metadata calls during generation.
try:
if ref_audio:
# Check if generate accepts ref_audio parameter
import inspect
sig = inspect.signature(self.model.generate)
if "ref_audio" in sig.parameters:
# Generate with voice cloning
for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
else:
# Fallback: generate without voice cloning
for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
else:
# No voice prompt, generate normally
for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
except Exception as e:
# If voice cloning fails, try without it
logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
# Concatenate all chunks
if audio_chunks:
audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
else:
# Fallback: empty audio
audio = np.array([], dtype=np.float32)
return audio, sample_rate
# Run blocking inference in thread pool
audio, sample_rate = await asyncio.to_thread(_generate_sync)
return audio, sample_rate
class MLXSTTBackend:
"""MLX-based STT backend using mlx-audio Whisper."""
def __init__(self, model_size: str = "base"):
self.model = None
self.model_size = model_size
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _is_model_cached(self, model_size: str) -> bool:
hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
return is_model_cached(hf_repo, weight_extensions=(".safetensors", ".bin", ".npz"))
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the MLX Whisper model.
Args:
model_size: Model size (tiny, base, small, medium, large)
"""
if model_size is None:
model_size = self.model_size
if self.model is not None and self.model_size == model_size:
return
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
progress_model_name = f"whisper-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(progress_model_name, is_cached):
from mlx_audio.stt import load
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
logger.info("Loading MLX Whisper model %s...", model_size)
self.model = load(model_name)
self.model_size = model_size
logger.info("MLX Whisper model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
self.model = None
logger.info("MLX Whisper model unloaded")
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
model_size: Optional[str] = None,
) -> str:
"""
Transcribe audio to text.
Args:
audio_path: Path to audio file
language: Optional language hint
model_size: Optional model size override
Returns:
Transcribed text
"""
await self.load_model_async(model_size)
def _transcribe_sync():
"""Run synchronous transcription in thread pool."""
# MLX Whisper transcription using generate method
# The generate method accepts audio path directly
decode_options = {}
if language:
decode_options["language"] = language
# Inference runs with the process's default HF_HUB_OFFLINE
# state — see the comment in MLXTTSBackend.generate for the
# regression this revert fixes (issue #462).
result = self.model.generate(str(audio_path), **decode_options)
# Extract text from result
if isinstance(result, str):
return result.strip()
elif isinstance(result, dict):
return result.get("text", "").strip()
elif hasattr(result, "text"):
return result.text.strip()
else:
return str(result).strip()
# Run blocking transcription in thread pool
return await asyncio.to_thread(_transcribe_sync)
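# Illustrative usage sketch: a minimal, assumed example covering both MLX
# backends; paths and texts are placeholders.
async def _example_mlx_usage() -> None:
    tts = MLXTTSBackend(model_size="0.6B")
    prompt, _was_cached = await tts.create_voice_prompt("reference.wav", "Sample transcript.")
    audio, sample_rate = await tts.generate("Example text.", prompt, language="en", seed=3)
    stt = MLXSTTBackend(model_size="base")
    transcript = await stt.transcribe("reference.wav", language="en")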

View File

@@ -0,0 +1,378 @@
"""
PyTorch backend implementation for TTS and STT.
"""
from typing import Optional, List, Tuple
import asyncio
import logging
import torch
import numpy as np
logger = logging.getLogger(__name__)
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import load_audio
class PyTorchTTSBackend:
"""PyTorch-based TTS backend using Qwen3-TTS."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self.device = self._get_device()
self._current_model_size = None
def _get_device(self) -> str:
"""Get the best available device."""
return get_torch_device(allow_xpu=True, allow_directml=True)
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
"""
Get the HuggingFace Hub model ID.
Args:
model_size: Model size (1.7B or 0.6B)
Returns:
HuggingFace Hub model ID
"""
hf_model_map = {
"1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
"0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
}
if model_size not in hf_model_map:
raise ValueError(f"Unknown model size: {model_size}")
return hf_model_map[model_size]
def _is_model_cached(self, model_size: str) -> bool:
return is_model_cached(self._get_model_path(model_size))
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the TTS model with automatic downloading from HuggingFace Hub.
Args:
model_size: Model size to load (1.7B or 0.6B)
"""
if model_size is None:
model_size = self.model_size
# If already loaded with correct size, return
if self.model is not None and self._current_model_size == model_size:
return
# Unload existing model if different size requested
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
model_name = f"qwen-tts-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(model_name, is_cached):
from qwen_tts import Qwen3TTSModel
model_path = self._get_model_path(model_size)
logger.info("Loading TTS model %s on %s...", model_size, self.device)
# Route both HF Hub and Transformers through a single cache root.
# On Windows local setups, model assets can otherwise split between
# .hf-cache/hub and .hf-cache/transformers, causing speech_tokenizer
# and preprocessor_config.json to fail to resolve during load.
from huggingface_hub import constants as hf_constants
tts_cache_dir = hf_constants.HF_HUB_CACHE
if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
model_path,
cache_dir=tts_cache_dir,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
)
else:
self.model = Qwen3TTSModel.from_pretrained(
model_path,
cache_dir=tts_cache_dir,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
self._current_model_size = model_size
self.model_size = model_size
logger.info("TTS model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
self.model = None
self._current_model_size = None
empty_device_cache(self.device)
logger.info("TTS model unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Args:
audio_path: Path to reference audio file
reference_text: Transcript of reference audio
use_cache: Whether to use cached prompt if available
Returns:
Tuple of (voice_prompt_dict, was_cached)
"""
await self.load_model_async(None)
# Check cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cached_prompt = get_cached_voice_prompt(cache_key)
if cached_prompt is not None:
# Cache stores as torch.Tensor but actual prompt is dict
# Convert if needed
if isinstance(cached_prompt, dict):
# For PyTorch backend, the dict should contain tensors, not file paths
# So we can safely return it
return cached_prompt, True
elif isinstance(cached_prompt, torch.Tensor):
# Legacy cache format - convert to dict
# This shouldn't happen in practice, but handle it
return {"prompt": cached_prompt}, True
def _create_prompt_sync():
"""Run synchronous voice prompt creation in thread pool."""
# Inference runs with the process's default HF_HUB_OFFLINE
# state. Forcing offline here (issue #462) regressed online
# users whose libraries issue legitimate metadata lookups
# during voice-prompt creation.
return self.model.create_voice_clone_prompt(
ref_audio=str(audio_path),
ref_text=reference_text,
x_vector_only_mode=False,
)
# Run blocking operation in thread pool
voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)
# Cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cache_voice_prompt(cache_key, voice_prompt_items)
return voice_prompt_items, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using voice prompt.
Args:
text: Text to synthesize
voice_prompt: Voice prompt dictionary from create_voice_prompt
language: Language code (en or zh)
seed: Random seed for reproducibility
instruct: Natural language instruction for speech delivery control
Returns:
Tuple of (audio_array, sample_rate)
"""
# Load model
await self.load_model_async(None)
def _generate_sync():
"""Run synchronous generation in thread pool."""
# Set seed if provided
if seed is not None:
manual_seed(seed, self.device)
# See _create_prompt_sync comment — inference runs with the
# process's default HF_HUB_OFFLINE state (issue #462).
wavs, sample_rate = self.model.generate_voice_clone(
text=text,
voice_clone_prompt=voice_prompt,
language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
instruct=instruct,
)
return wavs[0], sample_rate
# Run blocking inference in thread pool to avoid blocking event loop
audio, sample_rate = await asyncio.to_thread(_generate_sync)
return audio, sample_rate
class PyTorchSTTBackend:
"""PyTorch-based STT backend using Whisper."""
def __init__(self, model_size: str = "base"):
self.model = None
self.processor = None
self.model_size = model_size
self.device = self._get_device()
def _get_device(self) -> str:
"""Get the best available device."""
return get_torch_device(allow_xpu=True, allow_directml=True)
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _is_model_cached(self, model_size: str) -> bool:
hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
return is_model_cached(hf_repo)
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the Whisper model.
Args:
model_size: Model size (tiny, base, small, medium, large)
"""
if model_size is None:
model_size = self.model_size
if self.model is not None and self.model_size == model_size:
return
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
progress_model_name = f"whisper-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(progress_model_name, is_cached):
from transformers import WhisperProcessor, WhisperForConditionalGeneration
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
logger.info("Loading Whisper model %s on %s...", model_size, self.device)
self.processor = WhisperProcessor.from_pretrained(model_name)
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
self.model.to(self.device)
self.model_size = model_size
logger.info("Whisper model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
del self.processor
self.model = None
self.processor = None
empty_device_cache(self.device)
logger.info("Whisper model unloaded")
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
model_size: Optional[str] = None,
) -> str:
"""
Transcribe audio to text.
Args:
audio_path: Path to audio file
language: Optional language hint
model_size: Optional model size override
Returns:
Transcribed text
"""
await self.load_model_async(model_size)
def _transcribe_sync():
"""Run synchronous transcription in thread pool."""
# Load audio
audio, _sr = load_audio(audio_path, sample_rate=16000)
# Inference runs with the process's default HF_HUB_OFFLINE
# state — forcing offline here (issue #462) broke online users
# whose `get_decoder_prompt_ids` / tokenizer calls issue
# legitimate metadata lookups.
# Process audio
inputs = self.processor(
audio,
sampling_rate=16000,
return_tensors="pt",
)
inputs = inputs.to(self.device)
# Generate transcription
# If language is provided, force it; otherwise let Whisper auto-detect
generate_kwargs = {}
if language:
forced_decoder_ids = self.processor.get_decoder_prompt_ids(
language=language,
task="transcribe",
)
generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
with torch.no_grad():
predicted_ids = self.model.generate(
inputs["input_features"],
**generate_kwargs,
)
# Decode
transcription = self.processor.batch_decode(
predicted_ids,
skip_special_tokens=True,
)[0]
return transcription.strip()
# Run blocking transcription in thread pool
return await asyncio.to_thread(_transcribe_sync)
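# Illustrative usage sketch: a minimal, assumed example covering both PyTorch
# backends; paths and texts are placeholders.
async def _example_pytorch_usage() -> None:
    tts = PyTorchTTSBackend(model_size="1.7B")
    prompt, _was_cached = await tts.create_voice_prompt("reference.wav", "Sample transcript.")
    audio, sample_rate = await tts.generate(
        "Example text.", prompt, language="en", seed=3, instruct="Speak slowly and calmly",
    )
    stt = PyTorchSTTBackend(model_size="base")
    transcript = await stt.transcribe("reference.wav", language="en")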

View File

@@ -0,0 +1,214 @@
"""
Qwen3-TTS CustomVoice backend implementation.
Wraps the Qwen3-TTS-12Hz CustomVoice model for preset-speaker TTS with
instruction-based style control. Uses the same qwen_tts library as the
Base model (pytorch_backend.py) but loads a different checkpoint and
calls generate_custom_voice() instead of generate_voice_clone().
Key differences from the Base engine:
- Uses preset speakers (9 built-in voices) instead of zero-shot cloning
- Supports instruct parameter for tone/emotion/prosody control
- Two model sizes: 1.7B and 0.6B
Languages supported: zh, en, ja, ko, de, fr, ru, pt, es, it
"""
import asyncio
import logging
from typing import Optional
import numpy as np
import torch
from . import TTSBackend, LANGUAGE_CODE_TO_NAME
from .base import (
is_model_cached,
get_torch_device,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
logger = logging.getLogger(__name__)
# ── Preset speakers ──────────────────────────────────────────────────
# (speaker_id, display_name, gender, native_language_code, description)
QWEN_CUSTOM_VOICES = [
("Vivian", "Vivian", "female", "zh", "Bright, slightly edgy young female voice"),
("Serena", "Serena", "female", "zh", "Warm, gentle young female voice"),
("Uncle_Fu", "Uncle Fu", "male", "zh", "Seasoned male voice with a low, mellow timbre"),
("Dylan", "Dylan", "male", "zh", "Youthful Beijing male voice with a clear, natural timbre"),
("Eric", "Eric", "male", "zh", "Lively Chengdu male voice with a slightly husky brightness"),
("Ryan", "Ryan", "male", "en", "Dynamic male voice with strong rhythmic drive"),
("Aiden", "Aiden", "male", "en", "Sunny American male voice with a clear midrange"),
("Ono_Anna", "Ono Anna", "female", "ja", "Playful Japanese female voice with a light, nimble timbre"),
("Sohee", "Sohee", "female", "ko", "Warm Korean female voice with rich emotion"),
]
QWEN_CV_DEFAULT_SPEAKER = "Ryan"
# HuggingFace repo IDs per model size
QWEN_CV_HF_REPOS = {
"1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
"0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
}
class QwenCustomVoiceBackend:
"""Qwen3-TTS CustomVoice backend — preset speakers with instruct control."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self.device = self._get_device()
self._current_model_size: Optional[str] = None
def _get_device(self) -> str:
return get_torch_device(allow_xpu=True, allow_directml=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
if model_size not in QWEN_CV_HF_REPOS:
raise ValueError(f"Unknown model size: {model_size}")
return QWEN_CV_HF_REPOS[model_size]
def _is_model_cached(self, model_size: Optional[str] = None) -> bool:
size = model_size or self.model_size
return is_model_cached(self._get_model_path(size))
async def load_model_async(self, model_size: Optional[str] = None) -> None:
if model_size is None:
model_size = self.model_size
if self.model is not None and self._current_model_size == model_size:
return
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility with the TTSBackend protocol
load_model = load_model_async
def _load_model_sync(self, model_size: str) -> None:
model_name = f"qwen-custom-voice-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(model_name, is_cached):
from qwen_tts import Qwen3TTSModel
model_path = self._get_model_path(model_size)
logger.info("Loading Qwen CustomVoice %s on %s...", model_size, self.device)
if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
model_path,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
)
else:
self.model = Qwen3TTSModel.from_pretrained(
model_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
self._current_model_size = model_size
self.model_size = model_size
logger.info("Qwen CustomVoice %s loaded successfully", model_size)
def unload_model(self) -> None:
if self.model is not None:
del self.model
self.model = None
self._current_model_size = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Qwen CustomVoice unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> tuple[dict, bool]:
"""
Create voice prompt for CustomVoice.
CustomVoice doesn't use reference audio — it uses preset speakers.
When called for a cloned profile (the fallback path), it uses the default speaker.
For preset profiles, the voice_prompt dict is built by the profile
service and bypasses this method entirely.
"""
return {
"voice_type": "preset",
"preset_engine": "qwen_custom_voice",
"preset_voice_id": QWEN_CV_DEFAULT_SPEAKER,
}, False
async def combine_voice_prompts(
self,
audio_paths: list[str],
reference_texts: list[str],
) -> tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> tuple[np.ndarray, int]:
"""
Generate audio using Qwen CustomVoice.
Args:
text: Text to synthesize
voice_prompt: Dict with preset_voice_id (speaker name)
language: Language code (zh, en, ja, ko, etc.)
seed: Random seed for reproducibility
instruct: Natural language instruction for style control
(e.g. "Speak in an angry tone", "Very happy")
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model_async(None)
speaker = voice_prompt.get("preset_voice_id") or QWEN_CV_DEFAULT_SPEAKER
def _generate_sync():
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
lang_name = LANGUAGE_CODE_TO_NAME.get(language, "auto")
kwargs = {
"text": text,
"language": lang_name.capitalize() if lang_name != "auto" else "Auto",
"speaker": speaker,
}
# Only pass instruct if non-empty
if instruct:
kwargs["instruct"] = instruct
# Inference runs with the process's default HF_HUB_OFFLINE
# state. Forcing offline here (issue #462) regressed online
# users whose libraries issue legitimate metadata lookups
# during generation.
wavs, sample_rate = self.model.generate_custom_voice(**kwargs)
return wavs[0], sample_rate
audio, sample_rate = await asyncio.to_thread(_generate_sync)
return audio, sample_rate
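# Illustrative usage sketch: a minimal, assumed example combining a preset
# speaker with an instruct string for style control.
async def _example_custom_voice_usage() -> None:
    backend = QwenCustomVoiceBackend(model_size="0.6B")
    voice_prompt = {
        "voice_type": "preset",
        "preset_engine": "qwen_custom_voice",
        "preset_voice_id": "Serena",  # any speaker_id from QWEN_CUSTOM_VOICES
    }
    audio, sample_rate = await backend.generate(
        "Example text.", voice_prompt, language="zh", instruct="Speak in a cheerful tone",
    )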

458
backend/build_binary.py Normal file
View File

@@ -0,0 +1,458 @@
"""
PyInstaller build script for creating standalone Python server binary.
Usage:
python build_binary.py # Build default (CPU) server binary
python build_binary.py --cuda # Build CUDA-enabled server binary
"""
import PyInstaller.__main__
import argparse
import logging
import os
import platform
import sys
from pathlib import Path
logger = logging.getLogger(__name__)
def is_apple_silicon():
"""Check if running on Apple Silicon."""
return platform.system() == "Darwin" and platform.machine() == "arm64"
def build_server(cuda=False):
"""Build Python server as standalone binary.
Args:
cuda: If True, build with CUDA support and name the binary
voicebox-server-cuda instead of voicebox-server.
"""
backend_dir = Path(__file__).parent
binary_name = "voicebox-server-cuda" if cuda else "voicebox-server"
# PyInstaller arguments
# CUDA builds use --onedir so we can split the output into two archives:
# 1. Server core (~200-400MB) — versioned with the app
# 2. CUDA libs (~2GB) — versioned independently (only redownloaded on
# CUDA toolkit / torch major version changes)
# CPU builds remain --onefile for simplicity.
pack_mode = "--onedir" if cuda else "--onefile"
args = [
"server.py", # Use server.py as entry point instead of main.py
pack_mode,
"--name",
binary_name,
]
# Hide console window on Windows only. On macOS/Linux the sidecar needs
# stdout/stderr for Tauri to capture logs.
if platform.system() == "Windows":
args.append("--noconsole")
# numpy 2.x / torch ABI mismatch fix: install memmove fallback for
# torch.from_numpy() before the app starts. Runtime hooks run after
# FrozenImporter is registered so frozen torch/numpy are importable.
# Paths are passed relative to backend_dir because os.chdir(backend_dir)
# runs before PyInstaller. Absolute paths would get baked into the
# generated .spec, breaking reproducible builds on other machines / CI.
args.extend(
[
"--runtime-hook",
"pyi_rth_numpy_compat.py",
# Stub torch.compiler.disable before transformers imports
# flex_attention, which otherwise triggers torch._dynamo →
# torch._numpy._ufuncs and crashes at module load under
# PyInstaller. See pyi_rth_torch_compiler_disable.py.
"--runtime-hook",
"pyi_rth_torch_compiler_disable.py",
# Per-module collection overrides (e.g. forcing scipy.stats._distn_infrastructure
# to bundle .py source alongside .pyc so the runtime hook can source-patch it).
"--additional-hooks-dir",
"pyi_hooks",
]
)
# Add local qwen_tts path if specified (for editable installs)
qwen_tts_path = os.getenv("QWEN_TTS_PATH")
if qwen_tts_path and Path(qwen_tts_path).exists():
args.extend(["--paths", str(qwen_tts_path)])
logger.info("Using local qwen_tts source from: %s", qwen_tts_path)
# Add common hidden imports
args.extend(
[
"--hidden-import",
"backend",
"--hidden-import",
"backend.main",
"--hidden-import",
"backend.config",
"--hidden-import",
"backend.database",
"--hidden-import",
"backend.models",
"--hidden-import",
"backend.services.profiles",
"--hidden-import",
"backend.services.history",
"--hidden-import",
"backend.services.tts",
"--hidden-import",
"backend.services.transcribe",
"--hidden-import",
"backend.utils.platform_detect",
"--hidden-import",
"backend.backends",
"--hidden-import",
"backend.backends.pytorch_backend",
"--hidden-import",
"backend.backends.qwen_custom_voice_backend",
"--hidden-import",
"backend.utils.audio",
"--hidden-import",
"backend.utils.cache",
"--hidden-import",
"backend.utils.progress",
"--hidden-import",
"backend.utils.hf_progress",
"--hidden-import",
"backend.services.cuda",
"--hidden-import",
"backend.services.effects",
"--hidden-import",
"backend.utils.effects",
"--hidden-import",
"backend.services.versions",
"--hidden-import",
"pedalboard",
"--hidden-import",
"chatterbox",
"--hidden-import",
"chatterbox.tts_turbo",
"--hidden-import",
"chatterbox.mtl_tts",
"--hidden-import",
"backend.backends.chatterbox_backend",
"--hidden-import",
"backend.backends.chatterbox_turbo_backend",
# chatterbox multilingual uses spacy_pkuseg for Chinese word
# segmentation, which ships pickled dict files (dicts/default.pkl)
# and native .so extensions that --hidden-import alone won't bundle.
"--collect-all",
"spacy_pkuseg",
"--hidden-import",
"backend.backends.luxtts_backend",
"--hidden-import",
"zipvoice",
"--hidden-import",
"zipvoice.luxvoice",
"--collect-all",
"zipvoice",
"--collect-all",
"linacodec",
"--hidden-import",
"torch",
"--hidden-import",
"transformers",
"--hidden-import",
"fastapi",
"--hidden-import",
"uvicorn",
"--hidden-import",
"sqlalchemy",
# librosa uses lazy_loader which generates .pyi stub files at
# install time and reads them at runtime to discover submodules.
# --hidden-import alone doesn't bundle the stubs, causing
# "Cannot load imports from non-existent stub" at runtime.
"--collect-all",
"lazy_loader",
"--collect-all",
"librosa",
"--hidden-import",
"soundfile",
"--hidden-import",
"qwen_tts",
"--hidden-import",
"qwen_tts.inference",
"--hidden-import",
"qwen_tts.inference.qwen3_tts_model",
"--hidden-import",
"qwen_tts.inference.qwen3_tts_tokenizer",
"--hidden-import",
"qwen_tts.core",
"--hidden-import",
"qwen_tts.cli",
"--copy-metadata",
"qwen-tts",
"--copy-metadata",
"requests",
"--copy-metadata",
"transformers",
"--copy-metadata",
"huggingface-hub",
"--copy-metadata",
"tokenizers",
"--copy-metadata",
"safetensors",
"--copy-metadata",
"tqdm",
"--hidden-import",
"requests",
# qwen_tts uses inspect.getsource() at runtime to locate
# modeling_qwen3_tts.py — needs physical .py source files bundled
"--collect-all",
"qwen_tts",
# Fix for pkg_resources and jaraco namespace packages
"--hidden-import",
"pkg_resources.extern",
"--collect-submodules",
"jaraco",
# inflect uses typeguard @typechecked which calls inspect.getsource()
# at import time — needs .py source files, not just .pyc bytecode
"--collect-all",
"inflect",
# perth ships pretrained watermark model files (hparams.yaml, .pth.tar)
# in perth/perth_net/pretrained/ — needed by chatterbox at runtime
"--collect-all",
"perth",
# piper_phonemize ships espeak-ng-data/ (phoneme tables, language dicts)
# needed by LuxTTS for text-to-phoneme conversion
"--collect-all",
"piper_phonemize",
# HumeAI TADA — speech-language model using Llama + flow matching
"--hidden-import",
"backend.backends.hume_backend",
"--hidden-import",
"tada",
"--hidden-import",
"tada.modules",
"--hidden-import",
"tada.modules.tada",
"--hidden-import",
"tada.modules.encoder",
"--hidden-import",
"tada.modules.decoder",
"--hidden-import",
"tada.modules.aligner",
"--hidden-import",
"tada.modules.acoustic_spkr_verf",
"--hidden-import",
"tada.nn",
"--hidden-import",
"tada.nn.vibevoice",
"--hidden-import",
"tada.utils",
"--hidden-import",
"tada.utils.gray_code",
"--hidden-import",
"tada.utils.text",
# DAC shim — provides dac.nn.layers.Snake1d without the real
# descript-audio-codec package (which pulls onnx/tensorboard via
# descript-audiotools). The shim is in backend/utils/dac_shim.py.
"--hidden-import",
"backend.utils.dac_shim",
"--hidden-import",
"torchaudio",
"--collect-submodules",
"tada",
# Kokoro 82M — lightweight TTS engine using misaki G2P
# collect-all is required because transformers introspects .py source
# files at runtime (e.g. _can_set_attn_implementation opens the class
# file); hidden-import alone only bundles bytecode.
"--hidden-import",
"backend.backends.kokoro_backend",
"--collect-all",
"kokoro",
# misaki ships G2P data files (dictionaries, phoneme tables)
# that must be bundled for espeak/en/ja/zh G2P to work
"--collect-all",
"misaki",
# language_tags ships JSON data files (index.json etc.) loaded at
# runtime via: misaki → phonemizer → segments → csvw → language_tags
"--collect-all",
"language_tags",
# espeakng_loader ships the entire espeak-ng-data directory (369 files)
# loaded at import time by misaki.espeak via get_data_path()
"--collect-all",
"espeakng_loader",
# spacy en_core_web_sm model — misaki.en tries to spacy.cli.download()
# at runtime if not found, which calls pip as a subprocess and crashes
# the frozen binary. Bundle the model so spacy.util.is_package() passes.
"--collect-all",
"en_core_web_sm",
"--copy-metadata",
"en_core_web_sm",
"--hidden-import",
"en_core_web_sm",
# unidic-lite ships the MeCab dictionary used by fugashi (pulled in
# by misaki[ja]). The dict lives in unidic_lite/dicdir/ and is
# discovered via the package's DICDIR constant, so the data files
# must be collected or Japanese Kokoro voices crash at runtime.
"--collect-all",
"unidic_lite",
"--hidden-import",
"loguru",
]
)
# Add CUDA-specific hidden imports
if cuda:
logger.info("Building with CUDA support")
args.extend(
[
"--hidden-import",
"torch.cuda",
"--hidden-import",
"torch.backends.cudnn",
]
)
else:
# Exclude NVIDIA CUDA packages from CPU-only builds to keep binary small.
# When building from a venv with CUDA torch installed, PyInstaller would
# bundle ~3GB of NVIDIA shared libraries. We exclude both the Python
# modules and the binary DLLs.
nvidia_packages = [
"nvidia",
"nvidia.cublas",
"nvidia.cuda_cupti",
"nvidia.cuda_nvrtc",
"nvidia.cuda_runtime",
"nvidia.cudnn",
"nvidia.cufft",
"nvidia.curand",
"nvidia.cusolver",
"nvidia.cusparse",
"nvidia.nccl",
"nvidia.nvjitlink",
"nvidia.nvtx",
]
for pkg in nvidia_packages:
args.extend(["--exclude-module", pkg])
# Add MLX-specific imports if building on Apple Silicon (never for CUDA builds)
if is_apple_silicon() and not cuda:
logger.info("Building for Apple Silicon - including MLX dependencies")
args.extend(
[
"--hidden-import",
"backend.backends.mlx_backend",
"--hidden-import",
"mlx",
"--hidden-import",
"mlx.core",
"--hidden-import",
"mlx.nn",
"--hidden-import",
"mlx_audio",
"--hidden-import",
"mlx_audio.tts",
"--hidden-import",
"mlx_audio.stt",
"--collect-submodules",
"mlx",
"--collect-submodules",
"mlx_audio",
# Use --collect-all so PyInstaller bundles both data files AND
# native shared libraries (.dylib, .metallib) for MLX.
# Previously only --collect-data was used, which caused MLX to
# raise OSError at runtime inside the bundled binary because
# the Metal shader libraries were missing.
"--collect-all",
"mlx",
"--collect-all",
"mlx_audio",
]
)
elif not cuda:
logger.info("Building for non-Apple Silicon platform - PyTorch only")
dist_dir = str(backend_dir / "dist")
build_dir = str(backend_dir / "build")
args.extend(
[
"--distpath",
dist_dir,
"--workpath",
build_dir,
"--noconfirm",
"--clean",
]
)
# Change to backend directory
os.chdir(backend_dir)
# For CPU builds on Windows, ensure we're using CPU-only torch.
# If CUDA torch is installed (local dev), swap to CPU torch before building,
# then restore CUDA torch after. This prevents PyInstaller from bundling
# ~3GB of CUDA DLLs into the CPU binary.
restore_cuda = False
if not cuda and platform.system() == "Windows":
import subprocess
result = subprocess.run(
[sys.executable, "-c", "import torch; print(torch.version.cuda or '')"], capture_output=True, text=True
)
has_cuda_torch = bool(result.stdout.strip())
if has_cuda_torch:
logger.info("CUDA torch detected — installing CPU torch for CPU build...")
subprocess.run(
[
sys.executable,
"-m",
"pip",
"install",
"torch",
"torchvision",
"torchaudio",
"--index-url",
"https://download.pytorch.org/whl/cpu",
"--force-reinstall",
"-q",
],
check=True,
)
restore_cuda = True
# Run PyInstaller
try:
PyInstaller.__main__.run(args)
finally:
# Restore CUDA torch if we swapped it out (even on build failure)
if restore_cuda:
logger.info("Restoring CUDA torch...")
import subprocess
subprocess.run(
[
sys.executable,
"-m",
"pip",
"install",
"torch",
"torchvision",
"torchaudio",
"--index-url",
"https://download.pytorch.org/whl/cu128",
"--force-reinstall",
"-q",
],
check=True,
)
logger.info("Binary built in %s", backend_dir / "dist" / binary_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Build voicebox-server binary")
parser.add_argument(
"--cuda",
action="store_true",
help="Build CUDA-enabled binary (voicebox-server-cuda)",
)
cli_args = parser.parse_args()
build_server(cuda=cli_args.cuda)

133
backend/config.py Normal file
View File

@@ -0,0 +1,133 @@
"""
Configuration module for voicebox backend.
Handles data directory configuration for production bundling.
"""
import logging
import os
from pathlib import Path
logger = logging.getLogger(__name__)
# Allow users to override the HuggingFace model download directory.
# Set VOICEBOX_MODELS_DIR to an absolute path before starting the server.
# This sets HF_HUB_CACHE so all huggingface_hub downloads go to that path.
_custom_models_dir = os.environ.get("VOICEBOX_MODELS_DIR")
if _custom_models_dir:
os.environ["HF_HUB_CACHE"] = _custom_models_dir
logger.info("Model download path set to: %s", _custom_models_dir)
# Default data directory (used in development)
_data_dir = Path("data").resolve()
def _path_relative_to_any_data_dir(path: Path) -> Path | None:
"""Extract the path within a data dir from an absolute or relative path."""
parts = path.parts
for idx, part in enumerate(parts):
if part != "data":
continue
tail = parts[idx + 1 :]
if tail:
return Path(*tail)
return Path()
return None
def set_data_dir(path: str | Path):
"""
Set the data directory path.
Args:
path: Path to the data directory
"""
global _data_dir
_data_dir = Path(path).resolve()
_data_dir.mkdir(parents=True, exist_ok=True)
logger.info("Data directory set to: %s", _data_dir)
def get_data_dir() -> Path:
"""
Get the data directory path.
Returns:
Path to the data directory
"""
return _data_dir
def to_storage_path(path: str | Path) -> str:
"""Convert a filesystem path to a DB-safe path relative to the data dir."""
resolved_path = Path(path).resolve()
relative_to_any_data_dir = _path_relative_to_any_data_dir(resolved_path)
if relative_to_any_data_dir is not None:
return str(relative_to_any_data_dir)
try:
return str(resolved_path.relative_to(_data_dir))
except ValueError:
return str(resolved_path)
def resolve_storage_path(path: str | Path | None) -> Path | None:
"""Resolve a DB-stored path against the configured data dir."""
if path is None:
return None
stored_path = Path(path)
if stored_path.is_absolute():
rebased_path = _path_relative_to_any_data_dir(stored_path)
if rebased_path is not None:
candidate = (_data_dir / rebased_path).resolve()
if candidate.exists() or not stored_path.exists():
return candidate
return stored_path
# 0.3.0 records sometimes stored relative paths with the data-dir name
# baked in (e.g. "data/profiles/..."). Joining those directly with
# _data_dir produces a spurious "<data_dir>/data/profiles/..." nest.
if stored_path.parts and stored_path.parts[0] == "data":
stored_path = (
Path(*stored_path.parts[1:]) if len(stored_path.parts) > 1 else Path()
)
return (_data_dir / stored_path).resolve()
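# Round-trip example (a sketch; the absolute paths are hypothetical and assume
# set_data_dir("/home/user/voicebox/data") was called):
#
#     stored = to_storage_path("/home/user/voicebox/data/profiles/abc/sample.wav")
#     # stored == "profiles/abc/sample.wav"   -> what gets written to the DB
#     resolve_storage_path(stored)
#     # == Path("/home/user/voicebox/data/profiles/abc/sample.wav")
#     resolve_storage_path("data/profiles/abc/sample.wav")
#     # legacy 0.3.0-style value -> same location, no "data/data" nesting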
def get_db_path() -> Path:
"""Get database file path."""
return _data_dir / "voicebox.db"
def get_profiles_dir() -> Path:
"""Get profiles directory path."""
path = _data_dir / "profiles"
path.mkdir(parents=True, exist_ok=True)
return path
def get_generations_dir() -> Path:
"""Get generations directory path."""
path = _data_dir / "generations"
path.mkdir(parents=True, exist_ok=True)
return path
def get_cache_dir() -> Path:
"""Get cache directory path."""
path = _data_dir / "cache"
path.mkdir(parents=True, exist_ok=True)
return path
def get_models_dir() -> Path:
"""Get models directory path."""
path = _data_dir / "models"
path.mkdir(parents=True, exist_ok=True)
return path

View File

@@ -0,0 +1,44 @@
"""Database package — ORM models, session management, and migrations.
Re-exports all public symbols so that ``from .database import get_db``
and ``from .database import Generation as DBGeneration`` continue to work
without changing any importers.
"""
from .models import (
Base,
AudioChannel,
ChannelDeviceMapping,
EffectPreset,
Generation,
GenerationVersion,
ProfileChannelMapping,
ProfileSample,
Project,
Story,
StoryItem,
VoiceProfile,
)
from .session import engine, SessionLocal, _db_path, init_db, get_db
__all__ = [
# Models
"Base",
"AudioChannel",
"ChannelDeviceMapping",
"EffectPreset",
"Generation",
"GenerationVersion",
"ProfileChannelMapping",
"ProfileSample",
"Project",
"Story",
"StoryItem",
"VoiceProfile",
# Session
"engine",
"SessionLocal",
"_db_path",
"init_db",
"get_db",
]

View File

@@ -0,0 +1,226 @@
"""Column-level migrations for the voicebox SQLite database.
Why not Alembic? voicebox is a single-user desktop app shipping as a
PyInstaller binary. Every user has exactly one SQLite file. Alembic's
strengths -- migration tracking across environments, rollback, team
coordination -- don't apply here and would add bundling complexity
(alembic.ini, env.py, versions/ directory all need to survive
PyInstaller). The column-existence checks below are idempotent, run in
<50 ms on startup, and have worked reliably across 12 schema changes.
If the project ever moves to a server-based deployment or Postgres, this
decision should be revisited.
Adding a new migration:
1. Append a new ``_migrate_*`` helper at the bottom of this file.
2. Call it from ``run_migrations()`` in the appropriate spot.
3. The helper should check column/table existence before acting
(idempotent) and print a short message when it does real work.
"""
import logging
from sqlalchemy import inspect, text
logger = logging.getLogger(__name__)
def run_migrations(engine) -> None:
"""Run all schema migrations. Safe to call on every startup."""
inspector = inspect(engine)
tables = set(inspector.get_table_names())
_migrate_story_items(engine, inspector, tables)
_migrate_profiles(engine, inspector, tables)
_migrate_generations(engine, inspector, tables)
_migrate_effect_presets(engine, inspector, tables)
_migrate_generation_versions(engine, inspector, tables)
_normalize_storage_paths(engine, tables)
# -- helpers ---------------------------------------------------------------
def _get_columns(inspector, table: str) -> set[str]:
return {col["name"] for col in inspector.get_columns(table)}
def _add_column(engine, table: str, column_sql: str, label: str) -> None:
"""Add a column if it doesn't already exist."""
with engine.connect() as conn:
conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column_sql}"))
conn.commit()
logger.info("Added %s column to %s", label, table)
# -- per-table migrations --------------------------------------------------
def _migrate_story_items(engine, inspector, tables: set[str]) -> None:
if "story_items" not in tables:
return
columns = _get_columns(inspector, "story_items")
# Replace position-based ordering with absolute timecodes
if "position" in columns:
logger.info("Migrating story_items: removing position column, using start_time_ms")
with engine.connect() as conn:
if "start_time_ms" not in columns:
conn.execute(text(
"ALTER TABLE story_items ADD COLUMN start_time_ms INTEGER DEFAULT 0"
))
result = conn.execute(text("""
SELECT si.id, si.story_id, si.position, g.duration
FROM story_items si
JOIN generations g ON si.generation_id = g.id
ORDER BY si.story_id, si.position
"""))
current_story_id = None
current_time_ms = 0
for item_id, story_id, _position, duration in result.fetchall():
if story_id != current_story_id:
current_story_id = story_id
current_time_ms = 0
conn.execute(
text("UPDATE story_items SET start_time_ms = :time WHERE id = :id"),
{"time": current_time_ms, "id": item_id},
)
current_time_ms += int((duration or 0) * 1000) + 200
conn.commit()
# Recreate table without the position column (ALTER TABLE ... DROP COLUMN needs SQLite 3.35+, which we can't assume on user systems)
conn.execute(text("""
CREATE TABLE story_items_new (
id VARCHAR PRIMARY KEY,
story_id VARCHAR NOT NULL,
generation_id VARCHAR NOT NULL,
start_time_ms INTEGER NOT NULL DEFAULT 0,
track INTEGER NOT NULL DEFAULT 0,
trim_start_ms INTEGER NOT NULL DEFAULT 0,
trim_end_ms INTEGER NOT NULL DEFAULT 0,
version_id VARCHAR,
created_at DATETIME,
FOREIGN KEY (story_id) REFERENCES stories(id),
FOREIGN KEY (generation_id) REFERENCES generations(id)
)
"""))
conn.execute(text("""
INSERT INTO story_items_new (id, story_id, generation_id, start_time_ms, track, trim_start_ms, trim_end_ms, version_id, created_at)
SELECT id, story_id, generation_id, start_time_ms,
COALESCE(track, 0), COALESCE(trim_start_ms, 0), COALESCE(trim_end_ms, 0), version_id, created_at
FROM story_items
"""))
conn.execute(text("DROP TABLE story_items"))
conn.execute(text("ALTER TABLE story_items_new RENAME TO story_items"))
conn.commit()
# Re-read after table recreation
columns = _get_columns(inspector, "story_items")
if "track" not in columns:
_add_column(engine, "story_items", "track INTEGER NOT NULL DEFAULT 0", "track")
# Re-read so subsequent checks see new columns
columns = _get_columns(inspector, "story_items")
if "trim_start_ms" not in columns:
_add_column(engine, "story_items", "trim_start_ms INTEGER NOT NULL DEFAULT 0", "trim_start_ms")
if "trim_end_ms" not in columns:
_add_column(engine, "story_items", "trim_end_ms INTEGER NOT NULL DEFAULT 0", "trim_end_ms")
if "version_id" not in columns:
_add_column(engine, "story_items", "version_id VARCHAR", "version_id")
def _migrate_profiles(engine, inspector, tables: set[str]) -> None:
if "profiles" not in tables:
return
columns = _get_columns(inspector, "profiles")
if "avatar_path" not in columns:
_add_column(engine, "profiles", "avatar_path VARCHAR", "avatar_path")
if "effects_chain" not in columns:
_add_column(engine, "profiles", "effects_chain TEXT", "effects_chain")
# Voice type system — v0.3.x
if "voice_type" not in columns:
_add_column(engine, "profiles", "voice_type VARCHAR DEFAULT 'cloned'", "voice_type")
if "preset_engine" not in columns:
_add_column(engine, "profiles", "preset_engine VARCHAR", "preset_engine")
if "preset_voice_id" not in columns:
_add_column(engine, "profiles", "preset_voice_id VARCHAR", "preset_voice_id")
if "design_prompt" not in columns:
_add_column(engine, "profiles", "design_prompt TEXT", "design_prompt")
if "default_engine" not in columns:
_add_column(engine, "profiles", "default_engine VARCHAR", "default_engine")
def _migrate_generations(engine, inspector, tables: set[str]) -> None:
if "generations" not in tables:
return
columns = _get_columns(inspector, "generations")
if "status" not in columns:
_add_column(engine, "generations", "status VARCHAR DEFAULT 'completed'", "status")
if "error" not in columns:
_add_column(engine, "generations", "error TEXT", "error")
if "engine" not in columns:
_add_column(engine, "generations", "engine VARCHAR DEFAULT 'qwen'", "engine")
# Re-read after engine column (variable name shadows outer scope in old code)
columns = _get_columns(inspector, "generations")
if "model_size" not in columns:
_add_column(engine, "generations", "model_size VARCHAR", "model_size")
if "is_favorited" not in columns:
_add_column(engine, "generations", "is_favorited BOOLEAN DEFAULT 0", "is_favorited")
def _migrate_effect_presets(engine, inspector, tables: set[str]) -> None:
if "effect_presets" not in tables:
return
columns = _get_columns(inspector, "effect_presets")
if "sort_order" not in columns:
_add_column(engine, "effect_presets", "sort_order INTEGER DEFAULT 100", "sort_order")
def _migrate_generation_versions(engine, inspector, tables: set[str]) -> None:
if "generation_versions" not in tables:
return
columns = _get_columns(inspector, "generation_versions")
if "source_version_id" not in columns:
_add_column(engine, "generation_versions", "source_version_id VARCHAR", "source_version_id")
def _normalize_storage_paths(engine, tables: set[str]) -> None:
"""Normalize stored file paths to be relative to the configured data dir."""
from pathlib import Path
from ..config import get_data_dir, to_storage_path, resolve_storage_path
data_dir = get_data_dir()
path_columns = [
("generations", "audio_path"),
("generation_versions", "audio_path"),
("profile_samples", "audio_path"),
("profiles", "avatar_path"),
]
total_fixed = 0
with engine.connect() as conn:
for table, column in path_columns:
if table not in tables:
continue
rows = conn.execute(
text(f"SELECT id, {column} FROM {table} WHERE {column} IS NOT NULL")
).fetchall()
for row_id, path_val in rows:
if not path_val:
continue
p = Path(path_val)
resolved = resolve_storage_path(p)
if resolved is None:
continue
normalized = to_storage_path(resolved)
if normalized != path_val:
conn.execute(
text(f"UPDATE {table} SET {column} = :path WHERE id = :id"),
{"path": normalized, "id": row_id},
)
total_fixed += 1
if total_fixed > 0:
conn.commit()
logger.info("Normalized %d stored file paths", total_fixed)

169
backend/database/models.py Normal file
View File

@@ -0,0 +1,169 @@
"""ORM model definitions for the voicebox SQLite database."""
from datetime import datetime
import uuid
from sqlalchemy import Column, String, Integer, Float, DateTime, Text, ForeignKey, Boolean
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class VoiceProfile(Base):
"""Voice profile.
voice_type discriminates three flavours:
- "cloned" — traditional reference-audio profiles (all cloning engines)
- "preset" — engine-specific pre-built voice (e.g. Kokoro voices)
- "designed" — text-described voice (e.g. Qwen CustomVoice, future)
"""
__tablename__ = "profiles"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String, unique=True, nullable=False)
description = Column(Text)
language = Column(String, default="en")
avatar_path = Column(String, nullable=True)
effects_chain = Column(Text, nullable=True)
# Voice type system — added v0.3.x
voice_type = Column(String, default="cloned") # "cloned" | "preset" | "designed"
preset_engine = Column(String, nullable=True) # e.g. "kokoro" — only for preset
preset_voice_id = Column(String, nullable=True) # e.g. "am_adam" — only for preset
design_prompt = Column(Text, nullable=True) # text description — only for designed
default_engine = Column(String, nullable=True) # auto-selected engine, locked for preset
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
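# Sketch of the three flavours as rows (names and IDs are illustrative;
# "kokoro" / "am_adam" echo the column comments above, the rest is made up):
#
#     VoiceProfile(name="Narrator", voice_type="cloned")
#     VoiceProfile(name="Adam", voice_type="preset",
#                  preset_engine="kokoro", preset_voice_id="am_adam",
#                  default_engine="kokoro")
#     VoiceProfile(name="Sea Captain", voice_type="designed",
#                  design_prompt="Gravelly old pirate captain",
#                  default_engine="qwen_custom_voice")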
class ProfileSample(Base):
"""Audio sample attached to a voice profile."""
__tablename__ = "profile_samples"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
profile_id = Column(String, ForeignKey("profiles.id"), nullable=False)
audio_path = Column(String, nullable=False)
reference_text = Column(Text, nullable=False)
class Generation(Base):
"""A single TTS generation."""
__tablename__ = "generations"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
profile_id = Column(String, ForeignKey("profiles.id"), nullable=False)
text = Column(Text, nullable=False)
language = Column(String, default="en")
audio_path = Column(String, nullable=True)
duration = Column(Float, nullable=True)
seed = Column(Integer)
instruct = Column(Text)
engine = Column(String, default="qwen")
model_size = Column(String, nullable=True)
status = Column(String, default="completed")
error = Column(Text, nullable=True)
is_favorited = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
class Story(Base):
"""A story that sequences multiple generations."""
__tablename__ = "stories"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String, nullable=False)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class StoryItem(Base):
"""Links a generation to a story at a specific timecode."""
__tablename__ = "story_items"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
story_id = Column(String, ForeignKey("stories.id"), nullable=False)
generation_id = Column(String, ForeignKey("generations.id"), nullable=False)
version_id = Column(String, ForeignKey("generation_versions.id"), nullable=True)
start_time_ms = Column(Integer, nullable=False, default=0)
track = Column(Integer, nullable=False, default=0)
trim_start_ms = Column(Integer, nullable=False, default=0)
trim_end_ms = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, default=datetime.utcnow)
class Project(Base):
"""Audio studio project (JSON blob)."""
__tablename__ = "projects"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String, nullable=False)
data = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
class GenerationVersion(Base):
"""A version of a generation's audio (original, processed, alternate takes)."""
__tablename__ = "generation_versions"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
generation_id = Column(String, ForeignKey("generations.id"), nullable=False)
label = Column(String, nullable=False)
audio_path = Column(String, nullable=False)
effects_chain = Column(Text, nullable=True)
source_version_id = Column(String, ForeignKey("generation_versions.id"), nullable=True)
is_default = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
class EffectPreset(Base):
"""Saved effect chain preset."""
__tablename__ = "effect_presets"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String, unique=True, nullable=False)
description = Column(Text, nullable=True)
effects_chain = Column(Text, nullable=False)
is_builtin = Column(Boolean, default=False)
sort_order = Column(Integer, default=100)
created_at = Column(DateTime, default=datetime.utcnow)
class AudioChannel(Base):
"""Audio output channel (bus)."""
__tablename__ = "audio_channels"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String, nullable=False)
is_default = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
class ChannelDeviceMapping(Base):
"""Mapping between a channel and an OS audio device."""
__tablename__ = "channel_device_mappings"
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
channel_id = Column(String, ForeignKey("audio_channels.id"), nullable=False)
device_id = Column(String, nullable=False)
class ProfileChannelMapping(Base):
"""Many-to-many mapping between voice profiles and audio channels."""
__tablename__ = "profile_channel_mappings"
profile_id = Column(String, ForeignKey("profiles.id"), primary_key=True)
channel_id = Column(String, ForeignKey("audio_channels.id"), primary_key=True)

73
backend/database/seed.py Normal file
View File

@@ -0,0 +1,73 @@
"""Post-migration data seeding and backfills."""
import json
import logging
import uuid
from .. import config
logger = logging.getLogger(__name__)
def backfill_generation_versions(SessionLocal, Generation, GenerationVersion) -> None:
"""Create 'clean' version entries for generations that predate the versions feature."""
db = SessionLocal()
try:
existing_version_gen_ids = {
row[0] for row in db.query(GenerationVersion.generation_id).all()
}
generations = db.query(Generation).filter(
Generation.status == "completed",
Generation.audio_path.isnot(None),
Generation.audio_path != "",
).all()
count = 0
for gen in generations:
if gen.id in existing_version_gen_ids:
continue
resolved_audio_path = config.resolve_storage_path(gen.audio_path)
if resolved_audio_path is None or not resolved_audio_path.exists():
continue
version = GenerationVersion(
id=str(uuid.uuid4()),
generation_id=gen.id,
label="clean",
audio_path=gen.audio_path,
effects_chain=None,
is_default=True,
)
db.add(version)
count += 1
if count > 0:
db.commit()
logger.info("Backfilled %d generation version entries", count)
finally:
db.close()
def seed_builtin_presets(SessionLocal, EffectPreset) -> None:
"""Ensure built-in effect presets exist in the database."""
from ..utils.effects import BUILTIN_PRESETS
db = SessionLocal()
try:
for idx, (_key, preset_data) in enumerate(BUILTIN_PRESETS.items()):
sort_order = preset_data.get("sort_order", idx)
existing = db.query(EffectPreset).filter_by(name=preset_data["name"]).first()
if not existing:
preset = EffectPreset(
id=str(uuid.uuid4()),
name=preset_data["name"],
description=preset_data.get("description"),
effects_chain=json.dumps(preset_data["effects_chain"]),
is_builtin=True,
sort_order=sort_order,
)
db.add(preset)
elif existing.sort_order != sort_order:
existing.sort_order = sort_order
db.commit()
finally:
db.close()
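# Shape that seed_builtin_presets() assumes for each BUILTIN_PRESETS entry in
# backend/utils/effects.py (inferred from the lookups above; the "Podcast"
# preset and its effect parameters are illustrative, not the real seed data):
#
#     BUILTIN_PRESETS = {
#         "podcast": {
#             "name": "Podcast",
#             "description": "Light compression for spoken word",
#             "sort_order": 0,
#             "effects_chain": [
#                 {"type": "compressor", "enabled": True, "params": {"threshold_db": -18.0}},
#             ],
#         },
#     }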

View File

@@ -0,0 +1,78 @@
"""Engine creation, initialization, and session management."""
import logging
import uuid
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .. import config
from .models import (
Base,
AudioChannel,
EffectPreset,
Generation,
GenerationVersion,
ProfileChannelMapping,
VoiceProfile,
)
from .migrations import run_migrations
from .seed import backfill_generation_versions, seed_builtin_presets
logger = logging.getLogger(__name__)
# Initialized by init_db()
engine = None
SessionLocal = None
_db_path = None
def init_db() -> None:
"""Initialize the database engine, run migrations, create tables, and seed data."""
global engine, SessionLocal, _db_path
_db_path = config.get_db_path()
_db_path.parent.mkdir(parents=True, exist_ok=True)
engine = create_engine(
f"sqlite:///{_db_path}",
connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
run_migrations(engine)
Base.metadata.create_all(bind=engine)
# Create default audio channel if it doesn't exist
db = SessionLocal()
try:
default_channel = db.query(AudioChannel).filter(AudioChannel.is_default == True).first()
if not default_channel:
default_channel = AudioChannel(
id=str(uuid.uuid4()),
name="Default",
is_default=True,
)
db.add(default_channel)
for profile in db.query(VoiceProfile).all():
db.add(ProfileChannelMapping(
profile_id=profile.id,
channel_id=default_channel.id,
))
db.commit()
finally:
db.close()
backfill_generation_versions(SessionLocal, Generation, GenerationVersion)
seed_builtin_presets(SessionLocal, EffectPreset)
def get_db():
"""Yield a database session (FastAPI dependency)."""
db = SessionLocal()
try:
yield db
finally:
db.close()
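# Typical consumption from a route module (a sketch; the router path and handler
# are hypothetical, only the Depends(get_db) wiring reflects how init_db()/get_db
# are meant to be used):
#
#     from fastapi import APIRouter, Depends
#     from sqlalchemy.orm import Session
#     from ..database import get_db, VoiceProfile
#
#     router = APIRouter()
#
#     @router.get("/profiles/count")
#     def count_profiles(db: Session = Depends(get_db)):
#         return {"count": db.query(VoiceProfile).count()}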

45
backend/main.py Normal file
View File

@@ -0,0 +1,45 @@
"""Entry point for the voicebox backend.
Imports the configured FastAPI app and provides a ``python -m backend.main``
entry point for development.
"""
import argparse
import uvicorn
from .app import app # noqa: F401 -- re-export for uvicorn "backend.main:app"
from . import config, database
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="voicebox backend server")
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Host to bind to (use 0.0.0.0 for remote access)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to bind to",
)
parser.add_argument(
"--data-dir",
type=str,
default=None,
help="Data directory for database, profiles, and generated audio",
)
args = parser.parse_args()
if args.data_dir:
config.set_data_dir(args.data_dir)
database.init_db()
uvicorn.run(
"backend.main:app",
host=args.host,
port=args.port,
reload=False,
)

521
backend/models.py Normal file
View File

@@ -0,0 +1,521 @@
"""
Pydantic models for request/response validation.
"""
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
class VoiceProfileCreate(BaseModel):
"""Request model for creating a voice profile."""
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
language: str = Field(
default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it|he|ar|da|el|fi|hi|ms|nl|no|pl|sv|sw|tr)$"
)
voice_type: Optional[str] = Field(default="cloned", pattern="^(cloned|preset|designed)$")
preset_engine: Optional[str] = Field(None, max_length=50)
preset_voice_id: Optional[str] = Field(None, max_length=100)
design_prompt: Optional[str] = Field(None, max_length=2000)
default_engine: Optional[str] = Field(None, max_length=50)
class VoiceProfileResponse(BaseModel):
"""Response model for voice profile."""
id: str
name: str
description: Optional[str]
language: str
avatar_path: Optional[str] = None
effects_chain: Optional[List["EffectConfig"]] = None
voice_type: str = "cloned"
preset_engine: Optional[str] = None
preset_voice_id: Optional[str] = None
design_prompt: Optional[str] = None
default_engine: Optional[str] = None
generation_count: int = 0
sample_count: int = 0
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ProfileSampleCreate(BaseModel):
"""Request model for adding a sample to a profile."""
reference_text: str = Field(..., min_length=1, max_length=1000)
class ProfileSampleUpdate(BaseModel):
"""Request model for updating a profile sample."""
reference_text: str = Field(..., min_length=1, max_length=1000)
class ProfileSampleResponse(BaseModel):
"""Response model for profile sample."""
id: str
profile_id: str
audio_path: str
reference_text: str
class Config:
from_attributes = True
class GenerationRequest(BaseModel):
"""Request model for voice generation."""
profile_id: str
text: str = Field(..., min_length=1, max_length=50000)
language: str = Field(default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it|he|ar|da|el|fi|hi|ms|nl|no|pl|sv|sw|tr)$")
seed: Optional[int] = Field(None, ge=0)
model_size: Optional[str] = Field(default="1.7B", pattern="^(1\\.7B|0\\.6B|1B|3B)$")
instruct: Optional[str] = Field(None, max_length=500)
engine: Optional[str] = Field(default="qwen", pattern="^(qwen|qwen_custom_voice|luxtts|chatterbox|chatterbox_turbo|tada|kokoro)$")
max_chunk_chars: int = Field(
default=800, ge=100, le=5000, description="Max characters per chunk for long text splitting"
)
crossfade_ms: int = Field(
default=50, ge=0, le=500, description="Crossfade duration in ms between chunks (0 for hard cut)"
)
normalize: bool = Field(default=True, description="Normalize output audio volume")
effects_chain: Optional[List["EffectConfig"]] = Field(
None, description="Effects chain to apply after generation (overrides profile default)"
)
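# A representative request (a sketch; the profile_id is a made-up placeholder UUID):
#
#     GenerationRequest(
#         profile_id="3f2b7c2e-9a1d-4e5f-8c6b-1d2e3f4a5b6c",
#         text="Welcome to the show.",
#         language="en",
#         engine="qwen",
#         model_size="1.7B",
#         instruct="Speak in a cheerful tone",
#     )
#
# Omitted fields keep the defaults above (max_chunk_chars=800, crossfade_ms=50,
# normalize=True, effects_chain=None, seed=None).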
class GenerationResponse(BaseModel):
"""Response model for voice generation."""
id: str
profile_id: str
text: str
language: str
audio_path: Optional[str] = None
duration: Optional[float] = None
seed: Optional[int] = None
instruct: Optional[str] = None
engine: Optional[str] = "qwen"
model_size: Optional[str] = None
status: str = "completed"
error: Optional[str] = None
is_favorited: bool = False
created_at: datetime
versions: Optional[List["GenerationVersionResponse"]] = None
active_version_id: Optional[str] = None
class Config:
from_attributes = True
class HistoryQuery(BaseModel):
"""Query model for generation history."""
profile_id: Optional[str] = None
search: Optional[str] = None
limit: int = Field(default=50, ge=1, le=100)
offset: int = Field(default=0, ge=0)
class HistoryResponse(BaseModel):
"""Response model for history entry (includes profile name)."""
id: str
profile_id: str
profile_name: str
text: str
language: str
audio_path: Optional[str] = None
duration: Optional[float] = None
seed: Optional[int] = None
instruct: Optional[str] = None
engine: Optional[str] = "qwen"
model_size: Optional[str] = None
status: str = "completed"
error: Optional[str] = None
is_favorited: bool = False
created_at: datetime
versions: Optional[List["GenerationVersionResponse"]] = None
active_version_id: Optional[str] = None
class Config:
from_attributes = True
class HistoryListResponse(BaseModel):
"""Response model for history list."""
items: List[HistoryResponse]
total: int
class TranscriptionRequest(BaseModel):
"""Request model for audio transcription."""
language: Optional[str] = Field(None, pattern="^(en|zh|ja|ko|de|fr|ru|pt|es|it)$")
model: Optional[str] = Field(None, pattern="^(base|small|medium|large|turbo)$")
class TranscriptionResponse(BaseModel):
"""Response model for transcription."""
text: str
duration: float
class HealthResponse(BaseModel):
"""Response model for health check."""
status: str
model_loaded: bool
model_downloaded: Optional[bool] = None # Whether model is cached/downloaded
model_size: Optional[str] = None # Current model size if loaded
gpu_available: bool
gpu_type: Optional[str] = None # GPU type (CUDA, MPS, or None)
vram_used_mb: Optional[float] = None
backend_type: Optional[str] = None # Backend type (mlx or pytorch)
backend_variant: Optional[str] = None # Binary variant (cpu or cuda)
gpu_compatibility_warning: Optional[str] = None # Warning if GPU arch unsupported
class DirectoryCheck(BaseModel):
"""Health status for a single directory."""
path: str
exists: bool
writable: bool
error: Optional[str] = None
class FilesystemHealthResponse(BaseModel):
"""Response model for filesystem health check."""
healthy: bool
disk_free_mb: Optional[float] = None
disk_total_mb: Optional[float] = None
directories: List[DirectoryCheck]
class ModelStatus(BaseModel):
"""Response model for model status."""
model_name: str
display_name: str
hf_repo_id: Optional[str] = None # HuggingFace repository ID
downloaded: bool
downloading: bool = False # True if download is in progress
size_mb: Optional[float] = None
loaded: bool = False
class ModelStatusListResponse(BaseModel):
"""Response model for model status list."""
models: List[ModelStatus]
class ModelDownloadRequest(BaseModel):
"""Request model for triggering model download."""
model_name: str
class ModelMigrateRequest(BaseModel):
"""Request model for migrating models to a new directory."""
destination: str
class ActiveDownloadTask(BaseModel):
"""Response model for active download task."""
model_name: str
status: str
started_at: datetime
error: Optional[str] = None
progress: Optional[float] = None # 0-100 percentage
current: Optional[int] = None # bytes downloaded
total: Optional[int] = None # total bytes
filename: Optional[str] = None # current file being downloaded
class ActiveGenerationTask(BaseModel):
"""Response model for active generation task."""
task_id: str
profile_id: str
text_preview: str
started_at: datetime
class ActiveTasksResponse(BaseModel):
"""Response model for active tasks."""
downloads: List[ActiveDownloadTask]
generations: List[ActiveGenerationTask]
class AudioChannelCreate(BaseModel):
"""Request model for creating an audio channel."""
name: str = Field(..., min_length=1, max_length=100)
device_ids: List[str] = Field(default_factory=list)
class AudioChannelUpdate(BaseModel):
"""Request model for updating an audio channel."""
name: Optional[str] = Field(None, min_length=1, max_length=100)
device_ids: Optional[List[str]] = None
class AudioChannelResponse(BaseModel):
"""Response model for audio channel."""
id: str
name: str
is_default: bool
device_ids: List[str]
created_at: datetime
class Config:
from_attributes = True
class ChannelVoiceAssignment(BaseModel):
"""Request model for assigning voices to a channel."""
profile_ids: List[str]
class ProfileChannelAssignment(BaseModel):
"""Request model for assigning channels to a profile."""
channel_ids: List[str]
class StoryCreate(BaseModel):
"""Request model for creating a story."""
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
class StoryResponse(BaseModel):
"""Response model for story (list view)."""
id: str
name: str
description: Optional[str]
created_at: datetime
updated_at: datetime
item_count: int = 0
class Config:
from_attributes = True
class StoryItemDetail(BaseModel):
"""Detail model for story item with generation info."""
id: str
story_id: str
generation_id: str
version_id: Optional[str] = None
start_time_ms: int
track: int = 0
trim_start_ms: int = 0
trim_end_ms: int = 0
created_at: datetime
# Generation details
profile_id: str
profile_name: str
text: str
language: str
audio_path: str
duration: float
seed: Optional[int]
instruct: Optional[str]
generation_created_at: datetime
# Versions available for this generation
versions: Optional[List["GenerationVersionResponse"]] = None
active_version_id: Optional[str] = None
class Config:
from_attributes = True
class StoryDetailResponse(BaseModel):
"""Response model for story with items."""
id: str
name: str
description: Optional[str]
created_at: datetime
updated_at: datetime
items: List[StoryItemDetail] = []
class Config:
from_attributes = True
class StoryItemCreate(BaseModel):
"""Request model for adding a generation to a story."""
generation_id: str
start_time_ms: Optional[int] = None # If not provided, will be calculated automatically
track: Optional[int] = 0 # Track number (0 = main track)
class StoryItemUpdateTime(BaseModel):
"""Request model for updating a story item's timecode."""
generation_id: str
start_time_ms: int = Field(..., ge=0)
class StoryItemBatchUpdate(BaseModel):
"""Request model for batch updating story item timecodes."""
updates: List[StoryItemUpdateTime]
class StoryItemReorder(BaseModel):
"""Request model for reordering story items."""
generation_ids: List[str] = Field(..., min_length=1)
class StoryItemMove(BaseModel):
"""Request model for moving a story item (position and/or track)."""
start_time_ms: int = Field(..., ge=0)
track: int = 0
class StoryItemTrim(BaseModel):
"""Request model for trimming a story item."""
trim_start_ms: int = Field(..., ge=0)
trim_end_ms: int = Field(..., ge=0)
class StoryItemSplit(BaseModel):
"""Request model for splitting a story item."""
split_time_ms: int = Field(..., ge=0) # Time within the clip to split at (relative to clip start)
class StoryItemVersionUpdate(BaseModel):
"""Request model for setting a story item's pinned version."""
version_id: Optional[str] = None # null = use generation default
class EffectConfig(BaseModel):
"""A single effect in an effects chain."""
type: str
enabled: bool = True
params: dict = Field(default_factory=dict)
class EffectsChain(BaseModel):
"""An ordered list of effects to apply."""
effects: List[EffectConfig] = Field(default_factory=list)
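# Minimal sketch of a chain as it travels in request bodies and in
# profiles.effects_chain (the "reverb" type and its parameter are assumptions;
# the authoritative list of effect types is described by AvailableEffectsResponse
# further down):
#
#     EffectsChain(effects=[
#         EffectConfig(type="reverb", enabled=True, params={"wet_level": 0.2}),
#     ])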
class EffectPresetCreate(BaseModel):
"""Request model for creating an effect preset."""
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
effects_chain: List[EffectConfig]
class EffectPresetUpdate(BaseModel):
"""Request model for updating an effect preset."""
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
effects_chain: Optional[List[EffectConfig]] = None
class EffectPresetResponse(BaseModel):
"""Response model for effect preset."""
id: str
name: str
description: Optional[str] = None
effects_chain: List[EffectConfig]
is_builtin: bool = False
created_at: datetime
class Config:
from_attributes = True
class GenerationVersionResponse(BaseModel):
"""Response model for a generation version."""
id: str
generation_id: str
label: str
audio_path: str
effects_chain: Optional[List[EffectConfig]] = None
source_version_id: Optional[str] = None
is_default: bool
created_at: datetime
class Config:
from_attributes = True
class ApplyEffectsRequest(BaseModel):
"""Request to apply effects to an existing generation."""
effects_chain: List[EffectConfig]
source_version_id: Optional[str] = Field(
None, description="Version to use as source audio (defaults to clean/original)"
)
label: Optional[str] = Field(None, max_length=100, description="Label for this version (auto-generated if omitted)")
set_as_default: bool = Field(default=True, description="Set this version as the default")
class ProfileEffectsUpdate(BaseModel):
"""Request to update the default effects chain on a profile."""
effects_chain: Optional[List[EffectConfig]] = Field(None, description="Effects chain (null to remove)")
class AvailableEffectParam(BaseModel):
"""Description of a single effect parameter."""
default: float
min: float
max: float
step: float
description: str
class AvailableEffect(BaseModel):
"""Description of an available effect type."""
type: str
label: str
description: str
params: dict # param_name -> AvailableEffectParam
class AvailableEffectsResponse(BaseModel):
"""Response listing all available effect types."""
effects: List[AvailableEffect]

View File

@@ -0,0 +1,12 @@
"""
Force scipy.stats._distn_infrastructure to be bundled with its .py source file
alongside the .pyc bytecode.
The runtime hook in backend/pyi_rth_torch_compiler_disable.py patches this
module's source at load time (the module has a `del obj` at line 369 that
raises NameError under PyInstaller's frozen importer). That patch reads the
source via loader.get_source(), which only works if the .py file was
actually collected into the bundle.
"""
module_collection_mode = "pyz+py"

View File

@@ -0,0 +1,11 @@
"""
Force transformers.masking_utils to be bundled with its .py source alongside
the .pyc bytecode so the runtime hook in
backend/pyi_rth_torch_compiler_disable.py can source-patch it.
The patch forces the torch<2.6 code path, bypassing `with TransformGetItemToIndex()`
which our torch._dynamo no-op stub can't implement for real — the real context
manager uses dynamo graph transforms to avoid `.item()` calls inside vmap.
"""
module_collection_mode = "pyz+py"

View File

@@ -0,0 +1,95 @@
"""
PyInstaller runtime hook: numpy 2.x / torch ABI mismatch fix.
Problem
-------
torch is compiled against numpy 1.x headers. numpy 2.x changed the version
number returned by PyArray_GetNDArrayCVersion() (0x01000009 → 0x02000000),
so torch's is_numpy_available() returns False and every torch.from_numpy()
call raises:
RuntimeError: Numpy is not available
This surfaces as:
ValueError: Unable to create tensor, you should probably activate
padding with 'padding=True'
during TTS generation (EncodecFeatureExtractor → BatchFeature.convert_to_tensors).
Fix
---
Runtime hooks execute after PyInstaller's FrozenImporter is registered, so
frozen torch/numpy are importable here. We start a background thread that
waits for torch to finish loading then wraps torch.from_numpy with a ctypes
memmove fallback that bypasses the C-level numpy ABI check entirely.
This approach works with any numpy version and is safer than binary-patching
libtorch_python.dylib (which risks PyArray_Descr struct layout mismatches).
"""
import sys
import threading
def _patch_torch_from_numpy():
import time
for _ in range(7200): # poll up to 360 s at 50 ms intervals
time.sleep(0.05)
torch = sys.modules.get("torch")
if torch is None or not hasattr(torch, "from_numpy"):
continue
if getattr(torch, "_vb_from_numpy_patched", False):
return
try:
import ctypes
import numpy as np
_orig = torch.from_numpy
# Explicit numpy → torch dtype map. Silent fallback to float32 on
# unknown dtypes would reinterpret the memcpy'd bytes as fp32 and
# silently corrupt data (e.g. fp16 tensors from some TTS engines),
# so we raise instead.
dtype_map = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"uint8": torch.uint8,
"bool": torch.bool,
"complex64": torch.complex64,
"complex128": torch.complex128,
}
def _safe_from_numpy(
arr, _orig=_orig, _c=ctypes, _np=np, _t=torch, _map=dtype_map
):
try:
return _orig(arr)
except RuntimeError:
a = _np.ascontiguousarray(arr)
key = str(a.dtype)
if key not in _map:
raise TypeError(
f"pyi_rth_numpy_compat: unsupported numpy dtype "
f"{key!r} in torch.from_numpy fallback; add an "
f"explicit mapping rather than silently copying "
f"bytes into the wrong dtype."
)
out = _t.empty(list(a.shape), dtype=_map[key])
_c.memmove(out.data_ptr(), a.ctypes.data, a.nbytes)
return out
torch.from_numpy = _safe_from_numpy
torch._vb_from_numpy_patched = True
except Exception:
pass
return
threading.Thread(target=_patch_torch_from_numpy, daemon=True).start()

View File

@@ -0,0 +1,540 @@
"""
PyInstaller runtime hook: stub torch._dynamo to a no-op module.
Problem
-------
transformers triggers torch._dynamo import at module-load time (not just
when torch.compile is called) via class-body decorators:
transformers/modeling_utils.py:1984
@torch._dynamo.allow_in_graph
class PreTrainedModel(...)
transformers/integrations/flex_attention.py:61
@torch.compiler.disable(recursive=False)
class WrappedFlexAttention...
The attribute access triggers torch.__getattr__ -> importlib.import_module
-> torch._dynamo -> torch._dynamo.utils imports torch._numpy ->
torch._numpy._ndarray imports torch._numpy._ufuncs, which crashes under
PyInstaller with:
File "torch/_numpy/_ufuncs.py", line 235, in <module>
vars()[name] = deco_binary_ufunc(ufunc)
NameError: name 'name' is not defined
(The module-level `for name in _binary: vars()[name] = ...` pattern works
in a regular venv but fails in the PyInstaller bundle. Root cause is in
PyInstaller's importer / bytecode pipeline and not easily fixed upstream.)
Surfaces as Kokoro failing to load when `from transformers import AlbertModel`
trips the decorator chain.
Fix
---
voicebox never uses torch.compile / torch._dynamo for inference, so we
replace torch._dynamo with a no-op stub module before transformers is
imported. Any attribute access on the stub returns a pass-through callable,
so `@torch._dynamo.allow_in_graph`, `torch._dynamo.is_compiling()`,
`torch._dynamo.mark_static_address(...)`, etc. all work.
This hook is pure sys.modules manipulation — we deliberately do NOT import
torch here. Runtime hooks run before the app starts and before
pyi_rth_numpy_compat has had a chance to patch torch.from_numpy (it runs
in a background thread, waiting for torch to appear in sys.modules).
Eager-importing torch at hook time would trip the numpy ABI issue and
kill the server process at startup.
torch.compiler.disable does not need a separate stub: its implementation
is effectively `import torch._dynamo; return torch._dynamo.disable(...)`,
and since our stub is in sys.modules, that call resolves to our no-op
_NoopDecorator pass-through.
"""
import os
import sys
import tempfile
import types
# Diagnostics — log hook activity to a file alongside the bundle so we can
# see what's happening when the server is run as a sidecar (no stdout for
# runtime hook prints). Safe no-op if the file can't be written.
_DIAG_PATH = os.path.join(tempfile.gettempdir(), "voicebox_rt_hook.log")
def _diag(msg: str) -> None:
try:
with open(_DIAG_PATH, "a", encoding="utf-8") as f:
f.write(msg + "\n")
except Exception:
pass
_HOOK_VERSION = "v6-masking-utils-finder"
_diag(f"=== runtime hook load @ pid={os.getpid()} version={_HOOK_VERSION} ===")
class _NoopDecorator:
"""Multi-role no-op: decorator, falsey predicate, and context manager.
Returned from calls like `torch._dynamo.disable()` (decorator),
`torch._dynamo.is_compiling()` (predicate used in `if not ...`), and
`with torch._dynamo._trace_wrapped_higher_order_op.TransformGetItemToIndex():`
(context manager used to scope an fx graph transformation).
By implementing __call__, __bool__, __enter__, __exit__, and __iter__ we
cover every use pattern we've seen transformers/torch use on a stubbed
object. Anything we haven't covered will raise a clearer error than a
silent wrong-result.
"""
__slots__ = ()
def __call__(self, fn=None, *args, **kwargs):
return fn
def __bool__(self) -> bool:
return False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return False # don't suppress exceptions
def __iter__(self):
return iter(())
_noop_decorator_singleton = _NoopDecorator()
def _noop_callable(*args, **kwargs):
# Direct-decorator use: @torch._dynamo.foo (no parens) — fn is positional
if len(args) == 1 and callable(args[0]) and not kwargs:
return args[0]
# Side-effect call with non-callable arg(s), e.g. mark_static_address(tensor)
return _noop_decorator_singleton
class _NoopDynamoModule(types.ModuleType):
"""Permissive stub: every attribute is a pass-through callable.
Covers attributes transformers hits at import time (allow_in_graph) and
runtime (is_compiling, mark_static_address, reset, disable, ...).
Dunder attributes (__file__, __spec__, __loader__, ...) raise
AttributeError so probes like inspect.getmodule() — which does
`hasattr(m, '__file__')` then `os.path.normpath(m.__file__)` — see the
module as having no source file and fall through to its normal
handling, instead of receiving a function and blowing up.
"""
def __getattr__(self, name: str):
if name.startswith("__") and name.endswith("__"):
raise AttributeError(name)
return _noop_callable
class _DynamoLoader:
"""Loader used by _DynamoMetaPathFinder to materialise stub submodules."""
def create_module(self, spec):
return _NoopDynamoModule(spec.name)
def exec_module(self, module):
# Mark every stub submodule as a package so deeper submodule imports
# (`from torch._dynamo.X.Y import Z`) keep working.
module.__path__ = []
class _DynamoMetaPathFinder:
"""Resolve any `torch._dynamo.X[.Y...]` import to a no-op stub module.
Without this, `from torch._dynamo._trace_wrapped_higher_order_op import X`
fails even with torch._dynamo pre-populated in sys.modules — Python's
import machinery checks the parent's __path__ and then looks up the
child, and we need to provide both.
"""
def find_spec(self, fullname, path=None, target=None):
if fullname == "torch._dynamo":
return None # handled by the pre-populated sys.modules entry
if not fullname.startswith("torch._dynamo."):
return None
from importlib.machinery import ModuleSpec
return ModuleSpec(fullname, _DynamoLoader(), is_package=True)
class _TransformersStubFinder:
"""Replace specific transformers submodules with no-op stubs.
Two modules are targeted:
1. transformers.utils.auto_docstring
The real @auto_docstring decorator loads
transformers.models.auto.modeling_auto just to build example docstrings,
which drags in GenerationMixin -> candidate_generator -> sklearn.metrics
-> scipy.stats._distn_infrastructure and trips (2) below. Docstrings
aren't functional for inference, so a pass-through decorator is safe.
2. transformers.generation.candidate_generator
Imported at module scope by transformers.generation.utils. It does
`from sklearn.metrics import roc_curve` at module load, which triggers:
File "scipy/stats/_distn_infrastructure.py", line 369, in <module>
NameError: name 'obj' is not defined
This is a PyInstaller-specific module-load bug (same class as the
torch._numpy._ufuncs crash) where a module-level `for obj in [s for s
in dir() if ...]` loop evaluates to empty in the bundle, leaving `obj`
unbound before `del obj`.
The exports (AssistedCandidateGenerator, EarlyExitCandidateGenerator,
etc.) are speculative-decoding helpers voicebox's TTS engines do not
use; a no-op stub module satisfies the imports.
"""
_STUBBED_MODULES = frozenset(
{
"transformers.utils.auto_docstring",
"transformers.generation.candidate_generator",
}
)
def find_spec(self, fullname, path=None, target=None):
if fullname not in self._STUBBED_MODULES:
return None
from importlib.machinery import ModuleSpec
return ModuleSpec(fullname, _NoopStubLoader(), is_package=False)
class _NoopStubLoader:
def create_module(self, spec):
return _NoopDynamoModule(spec.name)
def exec_module(self, module):
# _NoopDynamoModule.__getattr__ already answers every non-dunder
# attribute with a pass-through callable, which satisfies
# `from stubbed_module import X` for any X.
pass
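# Net effect (a sketch of the intended behaviour): with this finder on sys.meta_path,
#     from transformers.generation.candidate_generator import AssistedCandidateGenerator
# succeeds, and AssistedCandidateGenerator is the pass-through _noop_callable, so
# transformers.generation.utils imports cleanly without pulling in sklearn / scipy.stats.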
def _patch_scipy_distn_source(source: str) -> str:
"""Replace the unsafe `del obj` with a no-op that survives when obj is unbound.
Returns the input unchanged if the target line isn't found (e.g. scipy
version has changed).
"""
target = "\ndel obj\n"
replacement = "\nglobals().pop('obj', None)\n"
if target in source:
return source.replace(target, replacement, 1)
return source
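# Illustration of the rewrite (assumed excerpt; the real module is much longer):
#   before: ... exec('del ' + obj)\ndel obj\n
#   after:  ... exec('del ' + obj)\nglobals().pop('obj', None)\n
# globals().pop('obj', None) succeeds whether or not `obj` was ever bound,
# unlike the bare `del obj`.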
def _patch_masking_utils_source(source: str) -> str:
"""Force torch<2.6 code path in transformers.masking_utils.
The torch>=2.6 path uses `with TransformGetItemToIndex():` to allow
`.item()` calls inside vmap. That context manager is implemented via
torch._dynamo graph transforms, which our stub doesn't reproduce — it's
a no-op. The inner `_vmap_for_bhqkv` then crashes with:
RuntimeError: vmap: It looks like you're calling .item() on a Tensor.
Forcing _is_torch_greater_or_equal_than_2_6 to False selects
sdpa_mask_older_torch, which uses a different vmap pattern that does not
hit .item() and does not need TransformGetItemToIndex.
"""
target = 'is_torch_greater_or_equal("2.6", accept_dev=True)'
# Find the specific line that assigns _is_torch_greater_or_equal_than_2_6
if "_is_torch_greater_or_equal_than_2_6 = " + target in source:
return source.replace(
"_is_torch_greater_or_equal_than_2_6 = " + target,
"_is_torch_greater_or_equal_than_2_6 = False",
1,
)
return source
class _SourcePatchingFinder:
"""Generic delegate-and-wrap meta-path finder that patches a module's
source before exec'ing.
Subclasses declare `target` (module fullname) and `patch_fn` (str -> str).
Requires the target module's .py source to be bundled (use a PyInstaller
hook setting module_collection_mode = "pyz+py").
"""
target: str
patch_fn: callable = None
def find_spec(self, fullname, path=None, target=None):
if fullname != self.target:
return None
for finder in sys.meta_path:
if finder is self:
continue
find = getattr(finder, "find_spec", None)
if find is None:
continue
try:
real_spec = find(fullname, path, target)
except Exception:
continue
if real_spec is None or real_spec.loader is None:
continue
real_spec.loader = _SourcePatchLoader(real_spec.loader, self.patch_fn)
return real_spec
return None
class _SourcePatchLoader:
"""Delegate loader that reads source via get_source, applies a patch, and
compile/exec's the patched text into module.__dict__.
"""
def __init__(self, inner, patch_fn):
self._inner = inner
self._patch_fn = patch_fn
def __getattr__(self, name):
return getattr(self._inner, name)
def create_module(self, spec):
return self._inner.create_module(spec)
def exec_module(self, module):
source = None
try:
source = self._inner.get_source(module.__name__)
except Exception as e:
_diag(f"[source-patch] get_source({module.__name__}) failed: {e!r}")
if not source:
_diag(
f"[source-patch] no source for {module.__name__}; "
"falling back to inner exec_module (patch NOT applied)"
)
self._inner.exec_module(module)
return
patched = self._patch_fn(source)
_diag(
f"[source-patch] {module.__name__}: "
f"patched={patched is not source}, len={len(patched)}"
)
spec = module.__spec__
if spec is not None and spec.submodule_search_locations is not None:
module.__path__ = spec.submodule_search_locations
filename = getattr(self._inner, "path", module.__name__)
exec(compile(patched, filename, "exec"), module.__dict__)
_diag(f"[source-patch] {module.__name__} OK")
class _MaskingUtilsFinder(_SourcePatchingFinder):
target = "transformers.masking_utils"
patch_fn = staticmethod(_patch_masking_utils_source)
class _ScipyDistnPatchingFinder:
"""Delegate-and-wrap finder for scipy.stats._distn_infrastructure.
That module ends with:
for obj in [s for s in dir() if s.startswith('_doc_')]:
exec('del ' + obj)
del obj
In the PyInstaller bundle the list comprehension evaluates to empty
(module-level dir() under the frozen importer returns a different scope
than CPython's normal module-exec path — same class of bug as the
torch._numpy._ufuncs crash). The for loop body doesn't run, `obj` is
never bound, and the trailing `del obj` raises NameError at module load.
This kills every downstream module: librosa (needed by nearly every TTS
engine for mel filters) -> scipy.signal -> scipy.stats -> here.
Workaround: delegate to the real loader, but pre-bind `obj = None` in the
module namespace before its bytecode runs. If the for loop executes, each
iteration overwrites the sentinel via STORE_NAME (normal behaviour). If it
doesn't, `del obj` removes the sentinel and module load succeeds. The
`_doc_*` cleanup this line was meant to do is purely cosmetic — those vars
stay in the module namespace but nothing references them after this point.
"""
_TARGET = "scipy.stats._distn_infrastructure"
def find_spec(self, fullname, path=None, target=None):
if fullname != self._TARGET:
return None
_diag(f"[scipy-finder] match: {fullname}, path={path!r}")
# Delegate to the other finders to locate the real spec
for finder in sys.meta_path:
if finder is self:
continue
find = getattr(finder, "find_spec", None)
if find is None:
continue
try:
real_spec = find(fullname, path, target)
except Exception as e:
_diag(f"[scipy-finder] inner finder {type(finder).__name__} raised: {e}")
continue
if real_spec is None:
continue
if real_spec.loader is None:
_diag(f"[scipy-finder] {type(finder).__name__} returned spec with loader=None")
continue
_diag(
f"[scipy-finder] wrapped loader from "
f"{type(finder).__name__} -> {type(real_spec.loader).__name__}"
)
real_spec.loader = _ScipyDistnPrebindLoader(real_spec.loader)
return real_spec
_diag("[scipy-finder] NO inner finder returned a spec")
return None
class _ScipyDistnPrebindLoader:
"""Thin wrapper that pre-binds `obj = None` before delegating to the
real PyInstaller loader.
Every other attribute/method delegates to the inner loader — PyiFrozenLoader
is a rich FileLoader/ExecutionLoader with get_code/get_source/get_filename/
is_package/get_resource_reader/etc., any of which Python's import machinery
or 3rd-party code may call on spec.loader. Forwarding via __getattr__
avoids breaking any of those paths (and preserves @_check_name contracts
because the decorated methods run on the inner instance where self.name
matches spec.name).
"""
def __init__(self, inner):
self._inner = inner
def __getattr__(self, name):
# __getattr__ fires only for attrs not already on self, so delegate
# everything that isn't create_module/exec_module (or __getattr__/init).
return getattr(self._inner, name)
def create_module(self, spec):
return self._inner.create_module(spec)
def exec_module(self, module):
# Compile scipy's module source with the problematic line patched.
#
# The real module ends with:
# for obj in [s for s in dir() if s.startswith('_doc_')]:
# exec('del ' + obj)
# del obj
#
# Under PyInstaller's frozen importer, `del obj` raises NameError
# even when we pre-populate module.__dict__['obj'] — the pre-compiled
# .pyc bytecode interacts with the frame setup differently than a
# fresh compile() from source. Easiest robust fix: read the source
# and replace `del obj` with a safe variant before compiling.
#
# Requires the .py source to be bundled alongside the .pyc — see
# backend/pyi_hooks/hook-scipy.stats._distn_infrastructure.py.
source = None
try:
source = self._inner.get_source(module.__name__)
except Exception as e:
_diag(f"[scipy-loader] get_source failed: {e!r}")
if source:
patched = _patch_scipy_distn_source(source)
_diag(
f"[scipy-loader] source-patch path: patched={patched is not source}, "
f"len={len(patched)}"
)
spec = module.__spec__
if spec is not None and spec.submodule_search_locations is not None:
module.__path__ = spec.submodule_search_locations
filename = getattr(self._inner, "path", module.__name__)
bytecode = compile(patched, filename, "exec")
try:
exec(bytecode, module.__dict__)
except Exception as e:
_diag(f"[scipy-loader] patched exec raised {type(e).__name__}: {e!r}")
raise
_diag(f"[scipy-loader] exec_module {module.__name__} OK (source-patched)")
return
# No source available — fall back to the pre-bind approach. This is
# best-effort; if the frozen .pyc really does see a different `obj`
# slot, this will still crash, but we've done all we can without
# source.
_diag("[scipy-loader] no source available; falling back to pre-bind")
module.__dict__["obj"] = None
self._inner.exec_module(module)
def _install_dynamo_stub() -> None:
stub = _NoopDynamoModule("torch._dynamo")
# Mark as a package so `from torch._dynamo.X import Y` imports work
# (Python's import machinery checks parent.__path__ before looking up
# the child).
stub.__path__ = []
# torch._dynamo.config is accessed as a nested attribute namespace
# (e.g. `torch._dynamo.config.capture_scalar_outputs = True`), so use
# a permissive module so any attr read returns a no-op and sets succeed.
stub.config = _NoopDynamoModule("torch._dynamo.config")
stub.config.__path__ = []
sys.modules["torch._dynamo"] = stub
sys.modules["torch._dynamo.config"] = stub.config
# Finders:
# - torch._dynamo.* submodules -> no-op stubs
# - transformers.utils.auto_docstring and
# transformers.generation.candidate_generator -> no-op stubs (both
# paths reach sklearn -> scipy.stats which trips a separate crash)
# - scipy.stats._distn_infrastructure -> real load with `obj` pre-bound,
# so librosa -> scipy.signal -> scipy.stats loads cleanly
for _FinderCls in (
_DynamoMetaPathFinder,
_TransformersStubFinder,
_ScipyDistnPatchingFinder,
_MaskingUtilsFinder,
):
try:
sys.meta_path.insert(0, _FinderCls())
_diag(f"installed finder: {_FinderCls.__name__}")
except Exception as e:
_diag(f"FAILED to install {_FinderCls.__name__}: {e!r}")
_diag(
"final sys.meta_path head: "
+ ", ".join(type(f).__name__ for f in sys.meta_path[:6])
)
# If torch is already imported, also set the attribute on the package so
# `torch._dynamo` resolves to our stub without triggering torch.__getattr__
# (which would lazy-import the real module and crash).
torch_mod = sys.modules.get("torch")
if torch_mod is not None:
torch_mod._dynamo = stub
try:
_install_dynamo_stub()
except Exception as _e:
# Best effort. If this fails the original NameError will surface when
# transformers imports — no worse than not patching at all.
_diag(f"_install_dynamo_stub FAILED: {_e!r}")
# NOTE: we deliberately do NOT import torch or torch.compiler here.
# Runtime hooks run before the app starts and before pyi_rth_numpy_compat
# has had a chance to patch torch.from_numpy (it runs in a background
# thread, waiting for torch to appear in sys.modules). Importing torch
# eagerly at hook time would trip the numpy ABI issue and kill the
# server process at startup.
#
# torch.compiler.disable does not need an explicit stub: its
# implementation is effectively `import torch._dynamo; return
# torch._dynamo.disable(fn, recursive, reason=reason)`, and since our
# stub is installed in sys.modules, that call resolves to our no-op
# _NoopDecorator pass-through.
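# Post-hook smoke check (illustrative; assumes the frozen app environment and is
# deliberately NOT run here, per the note above about importing torch eagerly):
#   import torch._dynamo.eval_frame                     # no-op stub via _DynamoMetaPathFinder
#   torch._dynamo.config.capture_scalar_outputs = True  # attribute set on the permissive config stub
#   from scipy.stats import norm                        # _distn_infrastructure loads via patched source
#   import transformers.generation.candidate_generator  # no-op stub, no sklearn/scipy import chain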

83
backend/pyproject.toml Normal file
View File

@@ -0,0 +1,83 @@
[project]
name = "voicebox-backend"
version = "0.2.3"
requires-python = ">=3.12"
# ---------------------------------------------------------------------------
# Ruff linter + formatter
# ---------------------------------------------------------------------------
[tool.ruff]
target-version = "py312"
line-length = 120
src = ["."]
# Files/dirs to skip entirely.
extend-exclude = [
"voicebox-server.spec",
"build_binary.py",
]
[tool.ruff.lint]
select = [
"F", # pyflakes
"E", # pycodestyle errors
"W", # pycodestyle warnings
"I", # isort
"N", # pep8-naming
"UP", # pyupgrade (modernize syntax for 3.12)
"B", # flake8-bugbear
"A", # flake8-builtins (shadowing built-in names)
"SIM", # flake8-simplify
"T20", # flake8-print (flag print() calls)
"RET", # flake8-return
"PIE", # misc lints
"PT", # flake8-pytest-style
"RUF", # ruff-specific rules
"ERA", # commented-out code detection
"FIX", # flag TODO/FIXME/HACK/XXX for review
]
ignore = [
# Allow print() in existing code -- remove items from this list as files
# are migrated to logging during the refactor.
"T201", # print() found
# These conflict with the formatter or are too noisy during migration:
"E501", # line too long (formatter handles this)
"RET504", # unnecessary assignment before return
"SIM108", # use ternary operator (sometimes less readable)
"B008", # function call in default argument (FastAPI Depends() pattern)
"UP007", # use X | Y for union (auto-fixed by UP, but noisy on big diffs)
]
# Per-file rule overrides.
[tool.ruff.lint.per-file-ignores]
# Tests can use assert, print, and magic values freely.
"tests/**" = ["S101", "T201", "PLR2004", "ERA001"]
# __init__.py re-exports are expected to have unused imports.
"**/__init__.py" = ["F401"]
# Entry points and scripts legitimately use print.
"server.py" = ["T201"]
"main.py" = ["T201"]
# AMD GPU env vars must be set before torch import.
"app.py" = ["E402"]
[tool.ruff.lint.isort]
known-first-party = ["backend"]
# Group "from backend.*" imports into the first-party section.
force-single-line = false
combine-as-imports = true
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
docstring-code-format = true
# ---------------------------------------------------------------------------
# pytest
# ---------------------------------------------------------------------------
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"

View File

@@ -0,0 +1,22 @@
# MLX-specific dependencies (Apple Silicon only)
# These should only be installed on aarch64-apple-darwin platforms
mlx>=0.30.0
# miniaudio is a runtime dep of mlx-audio's STT path (mlx_audio.stt).
# mlx-audio itself is installed --no-deps (see comment below), so we
# must list miniaudio explicitly here or transcription fails on fresh
# M1 installs with `ModuleNotFoundError: miniaudio` (issue #505).
miniaudio>=1.59
# NOTE: mlx-audio is intentionally not listed here. From 0.3.1 onward it
# declares `transformers==5.0.0rc3` / `>=5.0.0`, which conflicts with the
# `transformers<=4.57.6` cap in requirements.txt and breaks CI's clean
# resolver. The mlx-audio API surface we use (mlx_audio.tts.load,
# mlx_audio.stt.load) works fine on transformers 4.57.x in practice.
#
# Install it via `pip install --no-deps mlx-audio==0.4.1` after this file
# (see .github/workflows/release.yml). Most other mlx-audio runtime deps
# (huggingface_hub, librosa, mlx-lm, numba, numpy, protobuf, pyloudnorm,
# sounddevice, tqdm) are already in requirements.txt or pulled in by
# other engines.

67
backend/requirements.txt Normal file
View File

@@ -0,0 +1,67 @@
# FastAPI and server
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
pydantic>=2.5.0
# Database
sqlalchemy>=2.0.0
alembic>=1.13.0
# ML models
torch>=2.2.0
transformers>=4.36.0,<=4.57.6
accelerate>=0.26.0
huggingface_hub>=0.20.0
qwen-tts>=0.0.5
# LuxTTS (voice cloning engine)
# piper-phonemize needs custom index (no PyPI wheels)
--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
# linacodec is a git-only dep of Zipvoice (uv-only source, pip can't resolve it)
linacodec @ git+https://github.com/ysharma3501/LinaCodec.git
Zipvoice @ git+https://github.com/ysharma3501/LuxTTS.git
# Chatterbox TTS sub-dependencies (chatterbox-tts itself is installed
# --no-deps in the setup script because it pins numpy<1.26 / torch==2.6
# which are incompatible with Python 3.12+)
conformer>=0.3.2
diffusers>=0.29.0
omegaconf
pykakasi
resemble-perth>=1.0.1
s3tokenizer
spacy-pkuseg
pyloudnorm
# HumeAI TADA sub-dependencies (hume-tada itself is installed
# --no-deps in the setup script because it pins torch>=2.7,<2.8.
# descript-audio-codec is NOT installed — it pulls onnx/tensorboard
# via descript-audiotools. A lightweight shim in utils/dac_shim.py
# provides the only class TADA uses: Snake1d.)
torchaudio
# Kokoro TTS (lightweight 82M-param engine)
kokoro>=0.9.4
misaki[en,ja,zh]>=0.9.4
# spacy model for misaki English G2P — must be pre-installed or misaki
# tries spacy.cli.download() at runtime which crashes frozen builds
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
# fugashi (pulled in by misaki[ja]) needs a MeCab dictionary on disk.
# unidic-lite ships one inside the wheel (~50MB); the full `unidic` package
# requires `python -m unidic download` (~526MB) which breaks frozen builds
# for the same reason en_core_web_sm does.
unidic-lite>=1.0.8
# Audio processing
librosa>=0.10.0
soundfile>=0.12.0
numpy>=1.24.0,<2.0
numba>=0.60.0,<0.61.0
pedalboard>=0.9.0
# HTTP client (for CUDA backend download)
httpx>=0.27.0
# Utilities
python-multipart>=0.0.6
Pillow>=10.0.0

View File

@@ -0,0 +1,32 @@
"""Route registration for the voicebox API."""
from fastapi import FastAPI
def register_routers(app: FastAPI) -> None:
"""Include all domain routers on the application."""
from .health import router as health_router
from .profiles import router as profiles_router
from .channels import router as channels_router
from .generations import router as generations_router
from .history import router as history_router
from .transcription import router as transcription_router
from .stories import router as stories_router
from .effects import router as effects_router
from .audio import router as audio_router
from .models import router as models_router
from .tasks import router as tasks_router
from .cuda import router as cuda_router
app.include_router(health_router)
app.include_router(profiles_router)
app.include_router(channels_router)
app.include_router(generations_router)
app.include_router(history_router)
app.include_router(transcription_router)
app.include_router(stories_router)
app.include_router(effects_router)
app.include_router(audio_router)
app.include_router(models_router)
app.include_router(tasks_router)
app.include_router(cuda_router)
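# Typical wiring (illustrative; the expected call site is the app factory in app.py):
#   from fastapi import FastAPI
#   app = FastAPI()
#   register_routers(app)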

69
backend/routes/audio.py Normal file
View File

@@ -0,0 +1,69 @@
"""Audio file serving endpoints."""
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import history
from ..database import get_db
router = APIRouter()
@router.get("/audio/version/{version_id}")
async def get_version_audio(version_id: str, db: Session = Depends(get_db)):
"""Serve audio for a specific version."""
from ..services import versions as versions_mod
version = versions_mod.get_version(version_id, db)
if not version:
raise HTTPException(status_code=404, detail="Version not found")
audio_path = config.resolve_storage_path(version.audio_path)
if audio_path is None or not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"generation_{version.generation_id}_{version.label}.wav",
)
@router.get("/audio/{generation_id}")
async def get_audio(generation_id: str, db: Session = Depends(get_db)):
"""Serve generated audio file (serves the default version)."""
generation = await history.get_generation(generation_id, db)
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is None or not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"generation_{generation_id}.wav",
)
@router.get("/samples/{sample_id}")
async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
"""Serve profile sample audio file."""
from ..database import ProfileSample as DBProfileSample
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
if not sample:
raise HTTPException(status_code=404, detail="Sample not found")
audio_path = config.resolve_storage_path(sample.audio_path)
if audio_path is None or not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"sample_{sample_id}.wav",
)

View File

@@ -0,0 +1,98 @@
"""Audio channel endpoints."""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from .. import models
from ..services import channels
from ..database import get_db
router = APIRouter()
@router.get("/channels", response_model=list[models.AudioChannelResponse])
async def list_channels(db: Session = Depends(get_db)):
"""List all audio channels."""
return await channels.list_channels(db)
@router.post("/channels", response_model=models.AudioChannelResponse)
async def create_channel(
data: models.AudioChannelCreate,
db: Session = Depends(get_db),
):
"""Create a new audio channel."""
try:
return await channels.create_channel(data, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def get_channel(
channel_id: str,
db: Session = Depends(get_db),
):
"""Get an audio channel by ID."""
channel = await channels.get_channel(channel_id, db)
if not channel:
raise HTTPException(status_code=404, detail="Channel not found")
return channel
@router.put("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def update_channel(
channel_id: str,
data: models.AudioChannelUpdate,
db: Session = Depends(get_db),
):
"""Update an audio channel."""
try:
channel = await channels.update_channel(channel_id, data, db)
if not channel:
raise HTTPException(status_code=404, detail="Channel not found")
return channel
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/channels/{channel_id}")
async def delete_channel(
channel_id: str,
db: Session = Depends(get_db),
):
"""Delete an audio channel."""
try:
success = await channels.delete_channel(channel_id, db)
if not success:
raise HTTPException(status_code=404, detail="Channel not found")
return {"message": "Channel deleted successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/channels/{channel_id}/voices")
async def get_channel_voices(
channel_id: str,
db: Session = Depends(get_db),
):
"""Get list of profile IDs assigned to a channel."""
try:
profile_ids = await channels.get_channel_voices(channel_id, db)
return {"profile_ids": profile_ids}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/channels/{channel_id}/voices")
async def set_channel_voices(
channel_id: str,
data: models.ChannelVoiceAssignment,
db: Session = Depends(get_db),
):
"""Set which voices are assigned to a channel."""
try:
await channels.set_channel_voices(channel_id, data, db)
return {"message": "Channel voices updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

82
backend/routes/cuda.py Normal file
View File

@@ -0,0 +1,82 @@
"""CUDA backend management endpoints."""
import logging
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from ..services.task_queue import create_background_task
from ..utils.progress import get_progress_manager
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/backend/cuda-status")
async def get_cuda_status():
"""Get CUDA backend download/availability status."""
from ..services import cuda
return cuda.get_cuda_status()
@router.post("/backend/download-cuda")
async def download_cuda_backend():
"""Download the CUDA backend binary."""
from ..services import cuda
if cuda.get_cuda_binary_path() is not None:
raise HTTPException(status_code=409, detail="CUDA backend already downloaded")
progress_manager = get_progress_manager()
existing = progress_manager.get_progress(cuda.PROGRESS_KEY)
if existing and existing.get("status") == "downloading":
raise HTTPException(status_code=409, detail="CUDA backend download already in progress")
async def _download():
try:
await cuda.download_cuda_binary()
except Exception as e:
logger.error("CUDA download failed: %s", e)
create_background_task(_download())
return {"message": "CUDA backend download started", "progress_key": "cuda-backend"}
@router.delete("/backend/cuda")
async def delete_cuda_backend():
"""Delete the downloaded CUDA backend binary."""
from ..services import cuda
if cuda.is_cuda_active():
raise HTTPException(
status_code=409,
detail="Cannot delete CUDA backend while it is active. Switch to CPU first.",
)
deleted = await cuda.delete_cuda_binary()
if not deleted:
raise HTTPException(status_code=404, detail="No CUDA backend found to delete")
return {"message": "CUDA backend deleted"}
@router.get("/backend/cuda-progress")
async def get_cuda_download_progress():
"""Get CUDA backend download progress via Server-Sent Events."""
progress_manager = get_progress_manager()
async def event_generator():
async for event in progress_manager.subscribe("cuda-backend"):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

262
backend/routes/effects.py Normal file
View File

@@ -0,0 +1,262 @@
"""Effects presets and generation version endpoints."""
import asyncio
import io
import uuid
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import history
from ..database import Generation as DBGeneration, get_db
router = APIRouter()
@router.post("/effects/preview/{generation_id}")
async def preview_effects(
generation_id: str,
data: models.ApplyEffectsRequest,
db: Session = Depends(get_db),
):
"""Apply effects to a generation's clean audio and stream back without saving."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "completed":
raise HTTPException(status_code=400, detail="Generation is not completed")
from ..services import versions as versions_mod
from ..utils.effects import apply_effects, validate_effects_chain
from ..utils.audio import load_audio
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise HTTPException(status_code=400, detail=error)
all_versions = versions_mod.list_versions(generation_id, db)
clean_version = next((v for v in all_versions if v.effects_chain is None), None)
source_path = clean_version.audio_path if clean_version else gen.audio_path
resolved_source_path = config.resolve_storage_path(source_path)
if resolved_source_path is None or not resolved_source_path.exists():
raise HTTPException(status_code=404, detail="Source audio file not found")
audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
processed = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)
import soundfile as sf
buf = io.BytesIO()
await asyncio.to_thread(lambda: sf.write(buf, processed, sample_rate, format="WAV"))
buf.seek(0)
return StreamingResponse(
buf,
media_type="audio/wav",
headers={
"Content-Disposition": f'inline; filename="preview_{generation_id}.wav"',
"Cache-Control": "no-cache, no-store",
},
)
@router.get("/effects/available", response_model=models.AvailableEffectsResponse)
async def get_available_effects():
"""List all available effect types with parameter definitions."""
from ..utils.effects import get_available_effects as _get_effects
return models.AvailableEffectsResponse(effects=[models.AvailableEffect(**e) for e in _get_effects()])
@router.get("/effects/presets", response_model=list[models.EffectPresetResponse])
async def list_effect_presets(db: Session = Depends(get_db)):
"""List all effect presets (built-in + user-created)."""
from ..services import effects as effects_mod
return effects_mod.list_presets(db)
@router.get("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
async def get_effect_preset(preset_id: str, db: Session = Depends(get_db)):
"""Get a specific effect preset."""
from ..services import effects as effects_mod
preset = effects_mod.get_preset(preset_id, db)
if not preset:
raise HTTPException(status_code=404, detail="Preset not found")
return preset
@router.post("/effects/presets", response_model=models.EffectPresetResponse)
async def create_effect_preset(
data: models.EffectPresetCreate,
db: Session = Depends(get_db),
):
"""Create a new effect preset."""
from ..services import effects as effects_mod
try:
return effects_mod.create_preset(data, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
async def update_effect_preset(
preset_id: str,
data: models.EffectPresetUpdate,
db: Session = Depends(get_db),
):
"""Update an effect preset."""
from ..services import effects as effects_mod
try:
result = effects_mod.update_preset(preset_id, data, db)
if not result:
raise HTTPException(status_code=404, detail="Preset not found")
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/effects/presets/{preset_id}")
async def delete_effect_preset(preset_id: str, db: Session = Depends(get_db)):
"""Delete a user effect preset."""
from ..services import effects as effects_mod
try:
if not effects_mod.delete_preset(preset_id, db):
raise HTTPException(status_code=404, detail="Preset not found")
return {"status": "deleted"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get(
"/generations/{generation_id}/versions",
response_model=list[models.GenerationVersionResponse],
)
async def list_generation_versions(
generation_id: str,
db: Session = Depends(get_db),
):
"""List all versions for a generation."""
gen = await history.get_generation(generation_id, db)
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
from ..services import versions as versions_mod
return versions_mod.list_versions(generation_id, db)
@router.post(
"/generations/{generation_id}/versions/apply-effects",
response_model=models.GenerationVersionResponse,
)
async def apply_effects_to_generation(
generation_id: str,
data: models.ApplyEffectsRequest,
db: Session = Depends(get_db),
):
"""Apply an effects chain to an existing generation, creating a new version."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "completed":
raise HTTPException(status_code=400, detail="Generation is not completed")
from ..services import versions as versions_mod
from ..utils.effects import apply_effects, validate_effects_chain
from ..utils.audio import load_audio, save_audio
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise HTTPException(status_code=400, detail=error)
all_versions = versions_mod.list_versions(generation_id, db)
source_version_id = data.source_version_id
if source_version_id:
source_version = next((v for v in all_versions if v.id == source_version_id), None)
if not source_version:
raise HTTPException(status_code=404, detail="Source version not found")
source_path = source_version.audio_path
else:
clean_version = next((v for v in all_versions if v.effects_chain is None), None)
if not clean_version:
source_path = gen.audio_path
else:
source_path = clean_version.audio_path
source_version_id = clean_version.id
resolved_source_path = config.resolve_storage_path(source_path)
if resolved_source_path is None or not resolved_source_path.exists():
raise HTTPException(status_code=404, detail="Source audio file not found")
audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
processed_audio = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)
version_id = str(uuid.uuid4())
processed_path = config.get_generations_dir() / f"{generation_id}_{version_id[:8]}.wav"
await asyncio.to_thread(save_audio, processed_audio, str(processed_path), sample_rate)
label = data.label or f"version-{len(all_versions) + 1}"
version = versions_mod.create_version(
generation_id=generation_id,
label=label,
audio_path=config.to_storage_path(processed_path),
db=db,
effects_chain=chain_dicts,
is_default=data.set_as_default,
source_version_id=source_version_id,
)
return version
@router.put(
"/generations/{generation_id}/versions/{version_id}/set-default",
response_model=models.GenerationVersionResponse,
)
async def set_default_version(
generation_id: str,
version_id: str,
db: Session = Depends(get_db),
):
"""Set a specific version as the default for a generation."""
from ..services import versions as versions_mod
version = versions_mod.get_version(version_id, db)
if not version or version.generation_id != generation_id:
raise HTTPException(status_code=404, detail="Version not found")
result = versions_mod.set_default_version(version_id, db)
if not result:
raise HTTPException(status_code=404, detail="Version not found")
return result
@router.delete("/generations/{generation_id}/versions/{version_id}")
async def delete_generation_version(
generation_id: str,
version_id: str,
db: Session = Depends(get_db),
):
"""Delete a version. Cannot delete the last remaining version."""
from ..services import versions as versions_mod
version = versions_mod.get_version(version_id, db)
if not version or version.generation_id != generation_id:
raise HTTPException(status_code=404, detail="Version not found")
if not versions_mod.delete_version(version_id, db):
raise HTTPException(
status_code=400,
detail="Cannot delete the last remaining version",
)
return {"status": "deleted"}

View File

@@ -0,0 +1,345 @@
"""TTS generation endpoints."""
import asyncio
import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
from .. import models
from ..services import history, profiles, tts
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db
from ..services.generation import run_generation
from ..services.task_queue import cancel_generation as cancel_generation_job, enqueue_generation
from ..utils.tasks import get_task_manager
router = APIRouter()
def _resolve_generation_engine(data: models.GenerationRequest, profile) -> str:
return data.engine or getattr(profile, "default_engine", None) or getattr(profile, "preset_engine", None) or "qwen"
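# Resolution precedence (as implemented above): the request's explicit engine,
# then the profile's default_engine, then its preset_engine, then "qwen".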
@router.post("/generate", response_model=models.GenerationResponse)
async def generate_speech(
data: models.GenerationRequest,
db: Session = Depends(get_db),
):
"""Generate speech from text using a voice profile."""
task_manager = get_task_manager()
generation_id = str(uuid.uuid4())
profile = await profiles.get_profile(data.profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
from ..backends import engine_has_model_sizes
engine = _resolve_generation_engine(data, profile)
try:
profiles.validate_profile_engine(profile, engine)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
model_size = (data.model_size or "1.7B") if engine_has_model_sizes(engine) else None
generation = await history.create_generation(
profile_id=data.profile_id,
text=data.text,
language=data.language,
audio_path="",
duration=0,
seed=data.seed,
db=db,
instruct=data.instruct,
generation_id=generation_id,
status="generating",
engine=engine,
model_size=model_size if engine_has_model_sizes(engine) else None,
)
task_manager.start_generation(
task_id=generation_id,
profile_id=data.profile_id,
text=data.text,
)
effects_chain_config = None
if data.effects_chain is not None:
effects_chain_config = [e.model_dump() for e in data.effects_chain]
else:
import json as _json
profile_obj = db.query(DBVoiceProfile).filter_by(id=data.profile_id).first()
if profile_obj and profile_obj.effects_chain:
try:
effects_chain_config = _json.loads(profile_obj.effects_chain)
except Exception:
pass
enqueue_generation(
generation_id,
run_generation(
generation_id=generation_id,
profile_id=data.profile_id,
text=data.text,
language=data.language,
engine=engine,
model_size=model_size,
seed=data.seed,
normalize=data.normalize,
effects_chain=effects_chain_config,
instruct=data.instruct,
mode="generate",
max_chunk_chars=data.max_chunk_chars,
crossfade_ms=data.crossfade_ms,
)
)
return generation
@router.post("/generate/{generation_id}/retry", response_model=models.GenerationResponse)
async def retry_generation(generation_id: str, db: Session = Depends(get_db)):
"""Retry a failed generation using the same parameters."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "failed":
raise HTTPException(status_code=400, detail="Only failed generations can be retried")
gen.status = "generating"
gen.error = None
gen.audio_path = ""
gen.duration = 0
db.commit()
db.refresh(gen)
task_manager = get_task_manager()
task_manager.start_generation(
task_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
)
enqueue_generation(
generation_id,
run_generation(
generation_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
language=gen.language,
engine=gen.engine or "qwen",
model_size=gen.model_size or "1.7B",
seed=gen.seed,
instruct=gen.instruct,
mode="retry",
)
)
return models.GenerationResponse.model_validate(gen)
@router.post(
"/generate/{generation_id}/regenerate",
response_model=models.GenerationResponse,
)
async def regenerate_generation(generation_id: str, db: Session = Depends(get_db)):
"""Re-run TTS with the same parameters and save the result as a new version."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "completed":
raise HTTPException(status_code=400, detail="Generation must be completed to regenerate")
gen.status = "generating"
gen.error = None
db.commit()
db.refresh(gen)
task_manager = get_task_manager()
task_manager.start_generation(
task_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
)
version_id = str(uuid.uuid4())
enqueue_generation(
generation_id,
run_generation(
generation_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
language=gen.language,
engine=gen.engine or "qwen",
model_size=gen.model_size or "1.7B",
seed=gen.seed,
instruct=gen.instruct,
mode="regenerate",
version_id=version_id,
)
)
return models.GenerationResponse.model_validate(gen)
@router.post("/generate/{generation_id}/cancel")
async def cancel_generation(generation_id: str, db: Session = Depends(get_db)):
"""Cancel a queued or running generation."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") not in ("loading_model", "generating"):
raise HTTPException(status_code=400, detail="Only active generations can be cancelled")
cancellation_state = cancel_generation_job(generation_id)
if cancellation_state is None:
raise HTTPException(status_code=409, detail="Generation is no longer cancellable")
if cancellation_state == "queued":
task_manager = get_task_manager()
task_manager.complete_generation(generation_id)
await history.update_generation_status(
generation_id=generation_id,
status="failed",
db=db,
error="Generation cancelled",
)
return {"message": "Queued generation cancelled"}
return {"message": "Generation cancellation requested"}
@router.get("/generate/{generation_id}/status")
async def get_generation_status(generation_id: str, db: Session = Depends(get_db)):
"""SSE endpoint that streams generation status updates."""
import json
async def event_stream():
try:
while True:
db.expire_all()
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
yield f"data: {json.dumps({'status': 'not_found', 'id': generation_id})}\n\n"
return
payload = {
"id": gen.id,
"status": gen.status or "completed",
"duration": gen.duration,
"error": gen.error,
}
yield f"data: {json.dumps(payload)}\n\n"
if (gen.status or "completed") in ("completed", "failed"):
return
await asyncio.sleep(1)
except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
logger.debug("SSE client disconnected for generation %s", generation_id)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
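# Example client usage (illustrative): stream events until status is completed/failed, e.g.
#   curl -N http://127.0.0.1:17493/generate/<generation_id>/status
# Each event is one line: data: {"id": ..., "status": ..., "duration": ..., "error": ...}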
@router.post("/generate/stream")
async def stream_speech(
data: models.GenerationRequest,
db: Session = Depends(get_db),
):
"""Generate speech and stream the WAV audio directly without saving to disk."""
from ..backends import get_tts_backend_for_engine, ensure_model_cached_or_raise, load_engine_model, engine_needs_trim
profile = await profiles.get_profile(data.profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
engine = _resolve_generation_engine(data, profile)
try:
profiles.validate_profile_engine(profile, engine)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
tts_model = get_tts_backend_for_engine(engine)
model_size = data.model_size or "1.7B"
await ensure_model_cached_or_raise(engine, model_size)
await load_engine_model(engine, model_size)
voice_prompt = await profiles.create_voice_prompt_for_profile(
data.profile_id,
db,
engine=engine,
)
from ..utils.chunked_tts import generate_chunked
trim_fn = None
if engine_needs_trim(engine):
from ..utils.audio import trim_tts_output
trim_fn = trim_tts_output
audio, sample_rate = await generate_chunked(
tts_model,
data.text,
voice_prompt,
language=data.language,
seed=data.seed,
instruct=data.instruct,
max_chunk_chars=data.max_chunk_chars,
crossfade_ms=data.crossfade_ms,
trim_fn=trim_fn,
)
effects_chain_config = None
if data.effects_chain is not None:
effects_chain_config = [e.model_dump() for e in data.effects_chain]
elif profile.effects_chain:
import json as _json
try:
effects_chain_config = _json.loads(profile.effects_chain)
except Exception:
effects_chain_config = None
if effects_chain_config:
from ..utils.effects import apply_effects
audio = apply_effects(audio, sample_rate, effects_chain_config)
if data.normalize:
from ..utils.audio import normalize_audio
audio = normalize_audio(audio)
wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate)
async def _wav_stream():
try:
chunk_size = 64 * 1024
for i in range(0, len(wav_bytes), chunk_size):
yield wav_bytes[i : i + chunk_size]
except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
logger.debug("Client disconnected during audio stream")
return StreamingResponse(
_wav_stream(),
media_type="audio/wav",
headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
)

248
backend/routes/health.py Normal file
View File

@@ -0,0 +1,248 @@
"""Health and infrastructure endpoints."""
import asyncio
import os
import signal
from pathlib import Path
import torch
from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import tts
from ..database import get_db
from ..utils.platform_detect import get_backend_type
router = APIRouter()
# Frontend build directory — present in Docker, absent in dev/API-only mode
_frontend_dir = Path(__file__).resolve().parent.parent.parent / "frontend"
@router.get("/")
async def root():
"""Root endpoint — serves SPA index.html in Docker, JSON otherwise."""
from .. import __version__
index = _frontend_dir / "index.html"
if index.is_file():
return FileResponse(index, media_type="text/html")
return {"message": "voicebox API", "version": __version__}
@router.post("/shutdown")
async def shutdown():
"""Gracefully shutdown the server."""
async def shutdown_async():
await asyncio.sleep(0.1)
os.kill(os.getpid(), signal.SIGTERM)
asyncio.create_task(shutdown_async())
return {"message": "Shutting down..."}
@router.post("/watchdog/disable")
async def watchdog_disable():
"""Disable the parent process watchdog so the server keeps running."""
from backend.server import disable_watchdog
disable_watchdog()
return {"message": "Watchdog disabled"}
@router.get("/health", response_model=models.HealthResponse)
async def health():
"""Health check endpoint."""
from huggingface_hub import constants as hf_constants
tts_model = tts.get_tts_model()
backend_type = get_backend_type()
has_cuda = torch.cuda.is_available()
has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
has_xpu = False
xpu_name = None
try:
import intel_extension_for_pytorch as ipex # noqa: F401 -- side-effect import enables XPU
if hasattr(torch, "xpu") and torch.xpu.is_available():
has_xpu = True
try:
xpu_name = torch.xpu.get_device_name(0)
except Exception:
xpu_name = "Intel GPU"
except ImportError:
pass
has_directml = False
directml_name = None
try:
import torch_directml
if torch_directml.device_count() > 0:
has_directml = True
try:
directml_name = torch_directml.device_name(0)
except Exception:
directml_name = "DirectML GPU"
except ImportError:
pass
gpu_compat_warning = None
if has_cuda:
from ..backends.base import check_cuda_compatibility
_compatible, gpu_compat_warning = check_cuda_compatibility()
gpu_available = has_cuda or has_mps or has_xpu or has_directml or backend_type == "mlx"
gpu_type = None
if has_cuda:
gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})"
elif has_mps:
gpu_type = "MPS (Apple Silicon)"
elif backend_type == "mlx":
gpu_type = "Metal (Apple Silicon via MLX)"
elif has_xpu:
gpu_type = f"XPU ({xpu_name})"
elif has_directml:
gpu_type = f"DirectML ({directml_name})"
vram_used = None
if has_cuda:
vram_used = torch.cuda.memory_allocated() / 1024 / 1024
elif has_xpu:
try:
vram_used = torch.xpu.memory_allocated() / 1024 / 1024
except Exception:
pass # memory_allocated() may not be available on all IPEX versions
model_loaded = False
model_size = None
try:
if tts_model.is_loaded():
model_loaded = True
model_size = getattr(tts_model, "_current_model_size", None)
if not model_size:
model_size = getattr(tts_model, "model_size", None)
except Exception:
model_loaded = False
model_size = None
model_downloaded = None
try:
from ..backends import get_model_config
default_config = get_model_config("qwen-tts-1.7B")
default_model_id = default_config.hf_repo_id if default_config else "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
try:
from huggingface_hub import scan_cache_dir
cache_info = scan_cache_dir()
for repo in cache_info.repos:
if repo.repo_id == default_model_id:
model_downloaded = True
break
except Exception:
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache = Path(cache_dir) / ("models--" + default_model_id.replace("/", "--"))
if repo_cache.exists():
has_model_files = (
any(repo_cache.rglob("*.bin"))
or any(repo_cache.rglob("*.safetensors"))
or any(repo_cache.rglob("*.pt"))
or any(repo_cache.rglob("*.pth"))
or any(repo_cache.rglob("*.npz"))
)
model_downloaded = has_model_files
except Exception:
pass
return models.HealthResponse(
status="healthy",
model_loaded=model_loaded,
model_downloaded=model_downloaded,
model_size=model_size,
gpu_available=gpu_available,
gpu_type=gpu_type,
vram_used_mb=vram_used,
backend_type=backend_type,
backend_variant=os.environ.get(
"VOICEBOX_BACKEND_VARIANT",
"cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"),
),
gpu_compatibility_warning=gpu_compat_warning,
)
@router.get("/health/filesystem", response_model=models.FilesystemHealthResponse)
async def filesystem_health():
"""Check filesystem health: directory existence, write permissions, and disk space."""
import shutil
dirs_to_check = {
"generations": config.get_generations_dir(),
"profiles": config.get_profiles_dir(),
"data": config.get_data_dir(),
}
checks: list[models.DirectoryCheck] = []
all_ok = True
for _label, dir_path in dirs_to_check.items():
exists = dir_path.exists()
writable = False
error = None
if exists:
probe = dir_path / ".voicebox_probe"
try:
probe.write_text("ok")
probe.unlink()
writable = True
except PermissionError:
error = "Permission denied"
except OSError as e:
error = str(e)
finally:
try:
probe.unlink(missing_ok=True)
except Exception:
pass
else:
error = "Directory does not exist"
if not exists or not writable:
all_ok = False
checks.append(
models.DirectoryCheck(
path=str(dir_path.resolve()),
exists=exists,
writable=writable,
error=error,
)
)
disk_free_mb = None
disk_total_mb = None
try:
usage = shutil.disk_usage(str(config.get_data_dir()))
disk_free_mb = round(usage.free / (1024 * 1024), 1)
disk_total_mb = round(usage.total / (1024 * 1024), 1)
if disk_free_mb < 500:
all_ok = False
except OSError:
all_ok = False
return models.FilesystemHealthResponse(
healthy=all_ok,
disk_free_mb=disk_free_mb,
disk_total_mb=disk_total_mb,
directories=checks,
)

189
backend/routes/history.py Normal file
View File

@@ -0,0 +1,189 @@
"""Generation history endpoints."""
import io
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import export_import, history
from ..app import safe_content_disposition
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db
router = APIRouter()
@router.get("/history", response_model=models.HistoryListResponse)
async def list_history(
profile_id: str | None = None,
search: str | None = None,
limit: int = 50,
offset: int = 0,
db: Session = Depends(get_db),
):
"""List generation history with optional filters."""
query = models.HistoryQuery(
profile_id=profile_id,
search=search,
limit=limit,
offset=offset,
)
return await history.list_generations(query, db)
@router.get("/history/stats")
async def get_stats(db: Session = Depends(get_db)):
"""Get generation statistics."""
return await history.get_generation_stats(db)
@router.post("/history/import")
async def import_generation(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import a generation from a ZIP archive."""
MAX_FILE_SIZE = 50 * 1024 * 1024
content = await file.read()
if len(content) > MAX_FILE_SIZE:
raise HTTPException(
status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE / (1024 * 1024)}MB"
)
try:
result = await export_import.import_generation_from_zip(content, db)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/history/failed")
async def clear_failed_generations(db: Session = Depends(get_db)):
"""Delete every generation with status='failed'. Used by the UI's 'Clear failed' button (#410)."""
count = await history.delete_failed_generations(db)
return {"deleted": count}
@router.get("/history/{generation_id}", response_model=models.HistoryResponse)
async def get_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Get a generation by ID."""
result = (
db.query(DBGeneration, DBVoiceProfile.name.label("profile_name"))
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
.filter(DBGeneration.id == generation_id)
.first()
)
if not result:
raise HTTPException(status_code=404, detail="Generation not found")
gen, profile_name = result
return models.HistoryResponse(
id=gen.id,
profile_id=gen.profile_id,
profile_name=profile_name,
text=gen.text,
language=gen.language,
audio_path=gen.audio_path,
duration=gen.duration,
seed=gen.seed,
instruct=gen.instruct,
engine=gen.engine or "qwen",
model_size=gen.model_size,
status=gen.status or "completed",
error=gen.error,
is_favorited=bool(gen.is_favorited),
created_at=gen.created_at,
)
@router.post("/history/{generation_id}/favorite")
async def toggle_favorite(
generation_id: str,
db: Session = Depends(get_db),
):
"""Toggle the favorite status of a generation."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
gen.is_favorited = not gen.is_favorited
db.commit()
return {"is_favorited": gen.is_favorited}
@router.delete("/history/{generation_id}")
async def delete_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Delete a generation."""
success = await history.delete_generation(generation_id, db)
if not success:
raise HTTPException(status_code=404, detail="Generation not found")
return {"message": "Generation deleted successfully"}
@router.get("/history/{generation_id}/export")
async def export_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Export a generation as a ZIP archive."""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
try:
zip_bytes = export_import.export_generation_to_zip(generation_id, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_text:
safe_text = "generation"
filename = f"generation-{safe_text}.voicebox.zip"
return StreamingResponse(
io.BytesIO(zip_bytes),
media_type="application/zip",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)
@router.get("/history/{generation_id}/export-audio")
async def export_generation_audio(
generation_id: str,
db: Session = Depends(get_db),
):
"""Export only the audio file from a generation."""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
if not generation.audio_path:
raise HTTPException(status_code=404, detail="Generation has no audio file")
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is None or not audio_path.is_file():
raise HTTPException(status_code=404, detail="Audio file not found")
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_text:
safe_text = "generation"
filename = f"{safe_text}.wav"
return FileResponse(
audio_path,
media_type="audio/wav",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)

475
backend/routes/models.py Normal file
View File

@@ -0,0 +1,475 @@
"""Model management endpoints."""
import asyncio
import shutil
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from .. import models
from ..utils.platform_detect import get_backend_type
from ..services.task_queue import create_background_task
from ..utils.progress import get_progress_manager
from ..utils.tasks import get_task_manager
router = APIRouter()
def _get_dir_size(path: Path) -> int:
"""Get total size of a directory in bytes."""
total = 0
for f in path.rglob("*"):
if f.is_file():
total += f.stat().st_size
return total
def _copy_with_progress(src: Path, dst: Path, progress_manager, copied_so_far: int, total_bytes: int) -> int:
"""Copy a directory tree with byte-level progress tracking."""
dst.mkdir(parents=True, exist_ok=True)
for item in src.iterdir():
dest_item = dst / item.name
if item.is_dir():
copied_so_far = _copy_with_progress(item, dest_item, progress_manager, copied_so_far, total_bytes)
else:
size = item.stat().st_size
shutil.copy2(str(item), str(dest_item))
copied_so_far += size
progress_manager.update_progress(
"migration",
copied_so_far,
total_bytes,
filename=item.name,
status="downloading",
)
return copied_so_far
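# Used by migrate_background() below for cross-filesystem migrations, where a
# plain shutil.move would degrade to copy+delete with no progress signal; copying
# file-by-file lets us report bytes copied against the precomputed total.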
@router.post("/models/load")
async def load_model(model_size: str = "1.7B"):
"""Manually load TTS model."""
from ..services import tts
try:
tts_model = tts.get_tts_model()
await tts_model.load_model_async(model_size)
return {"message": f"Model {model_size} loaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/models/unload")
async def unload_model():
"""Unload the default Qwen TTS model to free memory."""
from ..services import tts
try:
tts.unload_tts_model()
return {"message": "Model unloaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/models/{model_name}/unload")
async def unload_model_by_name(model_name: str):
"""Unload a specific model from memory without deleting it from disk."""
from ..backends import get_model_config, unload_model_by_config
config = get_model_config(model_name)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
try:
was_loaded = unload_model_by_config(config)
if not was_loaded:
return {"message": f"Model {model_name} is not loaded"}
return {"message": f"Model {model_name} unloaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/models/progress/{model_name}")
async def get_model_progress(model_name: str):
"""Get model download progress via Server-Sent Events."""
progress_manager = get_progress_manager()
async def event_generator():
async for event in progress_manager.subscribe(model_name):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.get("/models/cache-dir")
async def get_models_cache_dir():
"""Get the path to the HuggingFace model cache directory."""
from huggingface_hub import constants as hf_constants
return {"path": str(Path(hf_constants.HF_HUB_CACHE))}
@router.post("/models/migrate")
async def migrate_models(request: models.ModelMigrateRequest):
"""Move all downloaded models to a new directory with byte-level progress via SSE."""
from huggingface_hub import constants as hf_constants
source = Path(hf_constants.HF_HUB_CACHE)
destination = Path(request.destination)
if not source.exists():
raise HTTPException(status_code=404, detail="Current model cache directory not found")
if source.resolve() == destination.resolve():
raise HTTPException(status_code=400, detail="Source and destination are the same directory")
if destination.resolve().is_relative_to(source.resolve()):
raise HTTPException(status_code=400, detail="Destination cannot be inside the current cache directory")
progress_manager = get_progress_manager()
model_dirs = [d for d in source.iterdir() if d.name.startswith("models--") and d.is_dir()]
if not model_dirs:
progress_manager.update_progress("migration", 1, 1, status="complete")
progress_manager.mark_complete("migration")
return {"moved": 0, "errors": [], "source": str(source), "destination": str(destination)}
destination.mkdir(parents=True, exist_ok=True)
same_fs = False
try:
same_fs = source.stat().st_dev == destination.stat().st_dev
except OSError:
pass
async def migrate_background():
moved = 0
errors = []
try:
if same_fs:
total = len(model_dirs)
for i, item in enumerate(model_dirs):
dest_item = destination / item.name
try:
if dest_item.exists():
shutil.rmtree(dest_item)
shutil.move(str(item), str(dest_item))
moved += 1
progress_manager.update_progress(
"migration",
i + 1,
total,
filename=item.name,
status="downloading",
)
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
else:
total_bytes = sum(_get_dir_size(d) for d in model_dirs)
progress_manager.update_progress(
"migration", 0, total_bytes, filename="Calculating...", status="downloading"
)
copied = 0
for item in model_dirs:
dest_item = destination / item.name
try:
if dest_item.exists():
shutil.rmtree(dest_item)
copied = await asyncio.to_thread(
_copy_with_progress, item, dest_item, progress_manager, copied, total_bytes
)
await asyncio.to_thread(shutil.rmtree, str(item))
moved += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
progress_manager.update_progress("migration", 1, 1, status="complete")
progress_manager.mark_complete("migration")
except Exception as e:
progress_manager.update_progress("migration", 0, 0, status="error")
progress_manager.mark_error("migration", str(e))
create_background_task(migrate_background())
return {"source": str(source), "destination": str(destination)}
@router.get("/models/migrate/progress")
async def get_migration_progress():
"""Get model migration progress via Server-Sent Events."""
progress_manager = get_progress_manager()
async def event_generator():
async for event in progress_manager.subscribe("migration"):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.get("/models/status", response_model=models.ModelStatusListResponse)
async def get_model_status():
"""Get status of all available models."""
from huggingface_hub import constants as hf_constants
backend_type = get_backend_type()
task_manager = get_task_manager()
active_download_names = {task.model_name for task in task_manager.get_active_downloads()}
try:
from huggingface_hub import scan_cache_dir
use_scan_cache = True
except ImportError:
use_scan_cache = False
from ..backends import get_all_model_configs, check_model_loaded
registry_configs = get_all_model_configs()
model_configs = [
{
"model_name": cfg.model_name,
"display_name": cfg.display_name,
"hf_repo_id": cfg.hf_repo_id,
"model_size": cfg.model_size,
"check_loaded": lambda c=cfg: check_model_loaded(c),
}
for cfg in registry_configs
]
model_to_repo = {cfg["model_name"]: cfg["hf_repo_id"] for cfg in model_configs}
active_download_repos = {model_to_repo.get(name) for name in active_download_names if name in model_to_repo}
cache_info = None
if use_scan_cache:
try:
cache_info = scan_cache_dir()
except Exception:
pass
statuses = []
for config in model_configs:
try:
downloaded = False
size_mb = None
loaded = False
# Preferred path: use the huggingface_hub scan_cache_dir() result.
if cache_info:
repo_id = config["hf_repo_id"]
for repo in cache_info.repos:
if repo.repo_id == repo_id:
has_model_weights = False
for rev in repo.revisions:
for f in rev.files:
fname = f.file_name.lower()
if fname.endswith((".safetensors", ".bin", ".pt", ".pth", ".npz")):
has_model_weights = True
break
if has_model_weights:
break
has_incomplete = False
try:
cache_dir = hf_constants.HF_HUB_CACHE
blobs_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--")) / "blobs"
if blobs_dir.exists():
has_incomplete = any(blobs_dir.glob("*.incomplete"))
except Exception:
pass
if has_model_weights and not has_incomplete:
downloaded = True
try:
total_size = sum(revision.size_on_disk for revision in repo.revisions)
size_mb = total_size / (1024 * 1024)
except Exception:
pass
break
# Fallback: inspect the cache directory layout directly when scan_cache_dir
# is unavailable or did not report the repo as fully downloaded.
if not downloaded:
try:
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache = Path(cache_dir) / ("models--" + config["hf_repo_id"].replace("/", "--"))
if repo_cache.exists():
blobs_dir = repo_cache / "blobs"
has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))
if not has_incomplete:
snapshots_dir = repo_cache / "snapshots"
has_model_files = False
if snapshots_dir.exists():
has_model_files = (
any(snapshots_dir.rglob("*.bin"))
or any(snapshots_dir.rglob("*.safetensors"))
or any(snapshots_dir.rglob("*.pt"))
or any(snapshots_dir.rglob("*.pth"))
or any(snapshots_dir.rglob("*.npz"))
)
if has_model_files:
downloaded = True
try:
total_size = sum(
f.stat().st_size
for f in repo_cache.rglob("*")
if f.is_file() and not f.name.endswith(".incomplete")
)
size_mb = total_size / (1024 * 1024)
except Exception:
pass
except Exception:
pass
try:
loaded = config["check_loaded"]()
except Exception:
loaded = False
is_downloading = config["hf_repo_id"] in active_download_repos
if is_downloading:
downloaded = False
size_mb = None
statuses.append(
models.ModelStatus(
model_name=config["model_name"],
display_name=config["display_name"],
hf_repo_id=config["hf_repo_id"],
downloaded=downloaded,
downloading=is_downloading,
size_mb=size_mb,
loaded=loaded,
)
)
except Exception:
try:
loaded = config["check_loaded"]()
except Exception:
loaded = False
is_downloading = config["hf_repo_id"] in active_download_repos
statuses.append(
models.ModelStatus(
model_name=config["model_name"],
display_name=config["display_name"],
hf_repo_id=config["hf_repo_id"],
downloaded=False,
downloading=is_downloading,
size_mb=None,
loaded=loaded,
)
)
return models.ModelStatusListResponse(models=statuses)
@router.post("/models/download")
async def trigger_model_download(request: models.ModelDownloadRequest):
"""Trigger download of a specific model."""
from ..backends import get_model_config, get_model_load_func
task_manager = get_task_manager()
progress_manager = get_progress_manager()
config = get_model_config(request.model_name)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}")
load_func = get_model_load_func(config)
async def download_in_background():
try:
result = load_func()
if asyncio.iscoroutine(result):
await result
task_manager.complete_download(request.model_name)
except Exception as e:
task_manager.error_download(request.model_name, str(e))
task_manager.start_download(request.model_name)
progress_manager.update_progress(
model_name=request.model_name,
current=0,
total=0,
filename="Connecting to HuggingFace...",
status="downloading",
)
create_background_task(download_in_background())
return {"message": f"Model {request.model_name} download started"}
@router.post("/models/download/cancel")
async def cancel_model_download(request: models.ModelDownloadRequest):
"""Cancel or dismiss an errored/stale download task."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()
removed = task_manager.cancel_download(request.model_name)
progress_removed = False
with progress_manager._lock:
if request.model_name in progress_manager._progress:
del progress_manager._progress[request.model_name]
progress_removed = True
if removed or progress_removed:
return {"message": f"Download task for {request.model_name} cancelled"}
return {"message": f"No active task found for {request.model_name}"}
@router.delete("/models/{model_name}")
async def delete_model(model_name: str):
"""Delete a downloaded model from the HuggingFace cache."""
from huggingface_hub import constants as hf_constants
from ..backends import get_model_config, unload_model_by_config
config = get_model_config(model_name)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
hf_repo_id = config.hf_repo_id
try:
unload_model_by_config(config)
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))
if not repo_cache_dir.exists():
raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")
try:
shutil.rmtree(repo_cache_dir)
except OSError as e:
raise HTTPException(status_code=500, detail=f"Failed to delete model cache directory: {str(e)}")
return {"message": f"Model {model_name} deleted successfully"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")

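As a quick illustration of the endpoints above, a client can trigger a model download and follow the SSE progress stream. This is a minimal sketch, not part of the repository: it assumes the server is reachable on the default local port, that the router is mounted without an extra prefix, that `httpx` is installed, and that progress events arrive as standard `data:` SSE lines; the model name is a placeholder.

```python
# Hypothetical client sketch: start a model download, then follow progress.
import asyncio
import json

import httpx

BASE = "http://127.0.0.1:17493"
MODEL = "example-model"  # placeholder; use a name from GET /models/status


async def download_with_progress() -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        # Kick off the download; the server queues it as a background task.
        resp = await client.post(f"{BASE}/models/download", json={"model_name": MODEL})
        resp.raise_for_status()

        # Follow progress via Server-Sent Events (assumed "data: {...}" lines).
        async with client.stream("GET", f"{BASE}/models/progress/{MODEL}") as stream:
            async for line in stream.aiter_lines():
                if not line.startswith("data:"):
                    continue
                event = json.loads(line[len("data:"):].strip())
                print(event)
                if event.get("status") in ("complete", "error"):
                    break


if __name__ == "__main__":
    asyncio.run(download_with_progress())
```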
363
backend/routes/profiles.py Normal file
View File

@@ -0,0 +1,363 @@
"""Voice profile endpoints."""
import io
import json as _json
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..app import safe_content_disposition
from ..database import VoiceProfile as DBVoiceProfile, get_db
from ..services import channels, export_import, profiles
from ..services.profiles import _profile_to_response
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/profiles", response_model=models.VoiceProfileResponse)
async def create_profile(
data: models.VoiceProfileCreate,
db: Session = Depends(get_db),
):
"""Create a new voice profile."""
try:
return await profiles.create_profile(data, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/profiles", response_model=list[models.VoiceProfileResponse])
async def list_profiles(db: Session = Depends(get_db)):
"""List all voice profiles."""
return await profiles.list_profiles(db)
@router.post("/profiles/import", response_model=models.VoiceProfileResponse)
async def import_profile(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import a voice profile from a ZIP archive."""
MAX_FILE_SIZE = 100 * 1024 * 1024
content = await file.read()
if len(content) > MAX_FILE_SIZE:
raise HTTPException(
status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE / (1024 * 1024)}MB"
)
try:
profile = await export_import.import_profile_from_zip(content, db)
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ── Preset Voice Endpoints ───────────────────────────────────────────
# These MUST be declared before /profiles/{profile_id} to avoid the
# wildcard swallowing "presets" as a profile_id.
@router.get("/profiles/presets/{engine}")
async def list_preset_voices(engine: str):
"""List available preset voices for an engine."""
if engine == "kokoro":
from ..backends.kokoro_backend import KOKORO_VOICES
return {
"engine": engine,
"voices": [
{
"voice_id": vid,
"name": name,
"gender": gender,
"language": lang,
}
for vid, name, gender, lang in KOKORO_VOICES
],
}
if engine == "qwen_custom_voice":
from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES
return {
"engine": engine,
"voices": [
{
"voice_id": speaker_id,
"name": display_name,
"gender": gender,
"language": lang,
}
for speaker_id, display_name, gender, lang, _desc in QWEN_CUSTOM_VOICES
],
}
return {"engine": engine, "voices": []}
@router.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def get_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get a voice profile by ID."""
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
return profile
@router.put("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def update_profile(
profile_id: str,
data: models.VoiceProfileCreate,
db: Session = Depends(get_db),
):
"""Update a voice profile."""
try:
profile = await profiles.update_profile(profile_id, data, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/profiles/{profile_id}")
async def delete_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Delete a voice profile."""
success = await profiles.delete_profile(profile_id, db)
if not success:
raise HTTPException(status_code=404, detail="Profile not found")
return {"message": "Profile deleted successfully"}
SAMPLE_MAX_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
SAMPLE_UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1 MB
@router.post("/profiles/{profile_id}/samples", response_model=models.ProfileSampleResponse)
async def add_profile_sample(
profile_id: str,
file: UploadFile = File(...),
reference_text: str = Form(...),
db: Session = Depends(get_db),
):
"""Add a sample to a voice profile."""
_allowed_audio_exts = {".wav", ".mp3", ".m4a", ".ogg", ".flac", ".aac", ".webm", ".opus"}
_uploaded_ext = Path(file.filename or "").suffix.lower()
file_suffix = _uploaded_ext if _uploaded_ext in _allowed_audio_exts else ".wav"
with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp:
total_size = 0
too_large = False
while chunk := await file.read(SAMPLE_UPLOAD_CHUNK_SIZE):
total_size += len(chunk)
if total_size > SAMPLE_MAX_FILE_SIZE:
# Stop reading; delete the partial file only after the handle is
# closed, since unlinking an open file fails on Windows.
too_large = True
break
tmp.write(chunk)
tmp_path = tmp.name
if too_large:
Path(tmp_path).unlink(missing_ok=True)
raise HTTPException(
status_code=413,
detail=f"File too large (max {SAMPLE_MAX_FILE_SIZE // (1024 * 1024)} MB)",
)
try:
sample = await profiles.add_profile_sample(
profile_id,
tmp_path,
reference_text,
db,
)
return sample
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to process audio file: {str(e)}")
finally:
Path(tmp_path).unlink(missing_ok=True)
@router.get("/profiles/{profile_id}/samples", response_model=list[models.ProfileSampleResponse])
async def get_profile_samples(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get all samples for a profile."""
return await profiles.get_profile_samples(profile_id, db)
@router.delete("/profiles/samples/{sample_id}")
async def delete_profile_sample(
sample_id: str,
db: Session = Depends(get_db),
):
"""Delete a profile sample."""
success = await profiles.delete_profile_sample(sample_id, db)
if not success:
raise HTTPException(status_code=404, detail="Sample not found")
return {"message": "Sample deleted successfully"}
@router.put("/profiles/samples/{sample_id}", response_model=models.ProfileSampleResponse)
async def update_profile_sample(
sample_id: str,
data: models.ProfileSampleUpdate,
db: Session = Depends(get_db),
):
"""Update a profile sample's reference text."""
sample = await profiles.update_profile_sample(sample_id, data.reference_text, db)
if not sample:
raise HTTPException(status_code=404, detail="Sample not found")
return sample
@router.post("/profiles/{profile_id}/avatar", response_model=models.VoiceProfileResponse)
async def upload_profile_avatar(
profile_id: str,
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Upload or update avatar image for a profile."""
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename or "").suffix) as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
profile = await profiles.upload_avatar(profile_id, tmp_path, db)
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
finally:
Path(tmp_path).unlink(missing_ok=True)
@router.get("/profiles/{profile_id}/avatar")
async def get_profile_avatar(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get avatar image for a profile."""
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
if not profile.avatar_path:
raise HTTPException(status_code=404, detail="No avatar found for this profile")
avatar_path = config.resolve_storage_path(profile.avatar_path)
if avatar_path is None or not avatar_path.exists():
raise HTTPException(status_code=404, detail="Avatar file not found")
return FileResponse(avatar_path)
@router.delete("/profiles/{profile_id}/avatar")
async def delete_profile_avatar(
profile_id: str,
db: Session = Depends(get_db),
):
"""Delete avatar image for a profile."""
success = await profiles.delete_avatar(profile_id, db)
if not success:
raise HTTPException(status_code=404, detail="Profile not found or no avatar to delete")
return {"message": "Avatar deleted successfully"}
@router.get("/profiles/{profile_id}/export")
async def export_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Export a voice profile as a ZIP archive."""
try:
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
zip_bytes = export_import.export_profile_to_zip(profile_id, db)
safe_name = "".join(c for c in profile.name if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_name:
safe_name = "profile"
filename = f"profile-{safe_name}.voicebox.zip"
return StreamingResponse(
io.BytesIO(zip_bytes),
media_type="application/zip",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/profiles/{profile_id}/channels")
async def get_profile_channels(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get list of channel IDs assigned to a profile."""
try:
channel_ids = await channels.get_profile_channels(profile_id, db)
return {"channel_ids": channel_ids}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/profiles/{profile_id}/channels")
async def set_profile_channels(
profile_id: str,
data: models.ProfileChannelAssignment,
db: Session = Depends(get_db),
):
"""Set which channels a profile is assigned to."""
try:
await channels.set_profile_channels(profile_id, data, db)
return {"message": "Profile channels updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/profiles/{profile_id}/effects", response_model=models.VoiceProfileResponse)
async def update_profile_effects(
profile_id: str,
data: models.ProfileEffectsUpdate,
db: Session = Depends(get_db),
):
"""Set or clear the default effects chain for a voice profile."""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
if data.effects_chain is not None:
from ..utils.effects import validate_effects_chain
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise HTTPException(status_code=400, detail=error)
profile.effects_chain = _json.dumps(chain_dicts)
else:
profile.effects_chain = None
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(profile)
return _profile_to_response(profile)

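For reference, the sample endpoint above accepts a multipart upload with the audio file plus a `reference_text` form field. A minimal client sketch, not part of the repository, assuming `httpx`, the default local port, no extra route prefix, and a placeholder profile ID:

```python
# Hypothetical client sketch: add a reference sample to a voice profile.
from pathlib import Path

import httpx

BASE = "http://127.0.0.1:17493"
PROFILE_ID = "00000000-0000-0000-0000-000000000000"  # placeholder profile id

sample_path = Path("sample.wav")
with sample_path.open("rb") as f:
    resp = httpx.post(
        f"{BASE}/profiles/{PROFILE_ID}/samples",
        files={"file": (sample_path.name, f, "audio/wav")},
        data={"reference_text": "Text spoken in the reference clip."},
        timeout=120,
    )
resp.raise_for_status()
print(resp.json())
```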
223
backend/routes/stories.py Normal file
View File

@@ -0,0 +1,223 @@
"""Story endpoints."""
import io
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from .. import database, models
from ..services import stories
from ..app import safe_content_disposition
from ..database import get_db
router = APIRouter()
@router.get("/stories", response_model=list[models.StoryResponse])
async def list_stories(db: Session = Depends(get_db)):
"""List all stories."""
return await stories.list_stories(db)
@router.post("/stories", response_model=models.StoryResponse)
async def create_story(
data: models.StoryCreate,
db: Session = Depends(get_db),
):
"""Create a new story."""
try:
return await stories.create_story(data, db)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/stories/{story_id}", response_model=models.StoryDetailResponse)
async def get_story(
story_id: str,
db: Session = Depends(get_db),
):
"""Get a story with all its items."""
story = await stories.get_story(story_id, db)
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
@router.put("/stories/{story_id}", response_model=models.StoryResponse)
async def update_story(
story_id: str,
data: models.StoryCreate,
db: Session = Depends(get_db),
):
"""Update a story."""
story = await stories.update_story(story_id, data, db)
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
@router.delete("/stories/{story_id}")
async def delete_story(
story_id: str,
db: Session = Depends(get_db),
):
"""Delete a story."""
success = await stories.delete_story(story_id, db)
if not success:
raise HTTPException(status_code=404, detail="Story not found")
return {"message": "Story deleted successfully"}
@router.post("/stories/{story_id}/items", response_model=models.StoryItemDetail)
async def add_story_item(
story_id: str,
data: models.StoryItemCreate,
db: Session = Depends(get_db),
):
"""Add a generation to a story."""
item = await stories.add_item_to_story(story_id, data, db)
if not item:
raise HTTPException(status_code=404, detail="Story or generation not found")
return item
@router.delete("/stories/{story_id}/items/{item_id}")
async def remove_story_item(
story_id: str,
item_id: str,
db: Session = Depends(get_db),
):
"""Remove a story item from a story."""
success = await stories.remove_item_from_story(story_id, item_id, db)
if not success:
raise HTTPException(status_code=404, detail="Story item not found")
return {"message": "Item removed successfully"}
@router.put("/stories/{story_id}/items/times")
async def update_story_item_times(
story_id: str,
data: models.StoryItemBatchUpdate,
db: Session = Depends(get_db),
):
"""Update story item timecodes."""
success = await stories.update_story_item_times(story_id, data, db)
if not success:
raise HTTPException(status_code=400, detail="Invalid timecode update request")
return {"message": "Item timecodes updated successfully"}
@router.put("/stories/{story_id}/items/reorder", response_model=list[models.StoryItemDetail])
async def reorder_story_items(
story_id: str,
data: models.StoryItemReorder,
db: Session = Depends(get_db),
):
"""Reorder story items and recalculate timecodes."""
items = await stories.reorder_story_items(story_id, data.generation_ids, db)
if items is None:
raise HTTPException(
status_code=400, detail="Invalid reorder request - ensure all generation IDs belong to this story"
)
return items
@router.put("/stories/{story_id}/items/{item_id}/move", response_model=models.StoryItemDetail)
async def move_story_item(
story_id: str,
item_id: str,
data: models.StoryItemMove,
db: Session = Depends(get_db),
):
"""Move a story item (update position and/or track)."""
item = await stories.move_story_item(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found")
return item
@router.put("/stories/{story_id}/items/{item_id}/trim", response_model=models.StoryItemDetail)
async def trim_story_item(
story_id: str,
item_id: str,
data: models.StoryItemTrim,
db: Session = Depends(get_db),
):
"""Trim a story item."""
item = await stories.trim_story_item(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found or invalid trim values")
return item
@router.post("/stories/{story_id}/items/{item_id}/split", response_model=list[models.StoryItemDetail])
async def split_story_item(
story_id: str,
item_id: str,
data: models.StoryItemSplit,
db: Session = Depends(get_db),
):
"""Split a story item at a given time, creating two clips."""
items = await stories.split_story_item(story_id, item_id, data, db)
if items is None:
raise HTTPException(status_code=404, detail="Story item not found or invalid split point")
return items
@router.post("/stories/{story_id}/items/{item_id}/duplicate", response_model=models.StoryItemDetail)
async def duplicate_story_item(
story_id: str,
item_id: str,
db: Session = Depends(get_db),
):
"""Duplicate a story item."""
item = await stories.duplicate_story_item(story_id, item_id, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found")
return item
@router.put("/stories/{story_id}/items/{item_id}/version", response_model=models.StoryItemDetail)
async def set_story_item_version(
story_id: str,
item_id: str,
data: models.StoryItemVersionUpdate,
db: Session = Depends(get_db),
):
"""Pin a story item to a specific generation version."""
item = await stories.set_story_item_version(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item or version not found")
return item
@router.get("/stories/{story_id}/export-audio")
async def export_story_audio(
story_id: str,
db: Session = Depends(get_db),
):
"""Export story as single mixed audio file."""
try:
story = db.query(database.Story).filter_by(id=story_id).first()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
audio_bytes = await stories.export_story_audio(story_id, db)
if not audio_bytes:
raise HTTPException(status_code=400, detail="Story has no audio items")
safe_name = "".join(c for c in story.name if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_name:
safe_name = "story"
filename = f"{safe_name}.wav"
return StreamingResponse(
io.BytesIO(audio_bytes),
media_type="audio/wav",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

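The export endpoint above streams the mixed-down story as a WAV attachment, so a client only needs to save the response body. A minimal sketch, not part of the repository, assuming `httpx`, the default local port, and a placeholder story ID:

```python
# Hypothetical client sketch: export a story's mixed audio to a local file.
from pathlib import Path

import httpx

BASE = "http://127.0.0.1:17493"
STORY_ID = "00000000-0000-0000-0000-000000000000"  # placeholder story id

resp = httpx.get(f"{BASE}/stories/{STORY_ID}/export-audio", timeout=None)
resp.raise_for_status()
Path("story.wav").write_bytes(resp.content)
```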
125
backend/routes/tasks.py Normal file
View File

@@ -0,0 +1,125 @@
"""Task and cache management endpoints."""
from datetime import datetime
from fastapi import APIRouter, HTTPException
from .. import models
from ..utils.cache import clear_voice_prompt_cache
from ..utils.progress import get_progress_manager
from ..utils.tasks import get_task_manager
router = APIRouter()
@router.post("/tasks/clear")
async def clear_all_tasks():
"""Clear all download tasks and progress state."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()
task_manager.clear_all()
with progress_manager._lock:
progress_manager._progress.clear()
progress_manager._last_notify_time.clear()
progress_manager._last_notify_progress.clear()
return {"message": "All task state cleared"}
@router.post("/cache/clear")
async def clear_cache():
"""Clear all voice prompt caches (memory and disk)."""
try:
deleted_count = clear_voice_prompt_cache()
return {
"message": "Voice prompt cache cleared successfully",
"files_deleted": deleted_count,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
@router.get("/tasks/active", response_model=models.ActiveTasksResponse)
async def get_active_tasks():
"""Return all currently active downloads and generations."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()
active_downloads = []
task_manager_downloads = task_manager.get_active_downloads()
progress_active = progress_manager.get_all_active()
# Merge task-manager and progress-manager views so a download tracked by
# only one of them still shows up in the response.
download_map = {task.model_name: task for task in task_manager_downloads}
progress_map = {p["model_name"]: p for p in progress_active}
all_model_names = set(download_map.keys()) | set(progress_map.keys())
for model_name in all_model_names:
task = download_map.get(model_name)
progress = progress_map.get(model_name)
if task:
error = task.error
if not error:
with progress_manager._lock:
pm_data = progress_manager._progress.get(model_name)
if pm_data:
error = pm_data.get("error")
prog = progress or {}
if not prog:
with progress_manager._lock:
pm_data = progress_manager._progress.get(model_name)
if pm_data:
prog = pm_data
active_downloads.append(
models.ActiveDownloadTask(
model_name=model_name,
status=task.status,
started_at=task.started_at,
error=error,
progress=prog.get("progress"),
current=prog.get("current"),
total=prog.get("total"),
filename=prog.get("filename"),
)
)
elif progress:
timestamp_str = progress.get("timestamp")
if timestamp_str:
try:
started_at = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
except (ValueError, AttributeError):
started_at = datetime.utcnow()
else:
started_at = datetime.utcnow()
active_downloads.append(
models.ActiveDownloadTask(
model_name=model_name,
status=progress.get("status", "downloading"),
started_at=started_at,
error=progress.get("error"),
progress=progress.get("progress"),
current=progress.get("current"),
total=progress.get("total"),
filename=progress.get("filename"),
)
)
active_generations = []
for gen_task in task_manager.get_active_generations():
active_generations.append(
models.ActiveGenerationTask(
task_id=gen_task.task_id,
profile_id=gen_task.profile_id,
text_preview=gen_task.text_preview,
started_at=gen_task.started_at,
)
)
return models.ActiveTasksResponse(
downloads=active_downloads,
generations=active_generations,
)

View File

@@ -0,0 +1,84 @@
"""Transcription endpoints."""
import asyncio
import tempfile
from pathlib import Path
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from .. import models
from ..services import transcribe
from ..services.task_queue import create_background_task
from ..utils.tasks import get_task_manager
router = APIRouter()
UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB
@router.post("/transcribe", response_model=models.TranscriptionResponse)
async def transcribe_audio(
file: UploadFile = File(...),
language: str | None = Form(None),
model: str | None = Form(None),
):
"""Transcribe audio file to text."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
tmp.write(chunk)
tmp_path = tmp.name
try:
from ..utils.audio import load_audio
from ..backends import WHISPER_HF_REPOS
audio, sr = await asyncio.to_thread(load_audio, tmp_path)
duration = len(audio) / sr
whisper_model = transcribe.get_whisper_model()
model_size = model if model else whisper_model.model_size
valid_sizes = list(WHISPER_HF_REPOS.keys())
if model_size not in valid_sizes:
raise HTTPException(
status_code=400,
detail=f"Invalid model size '{model_size}'. Must be one of: {', '.join(valid_sizes)}",
)
already_loaded = whisper_model.is_loaded() and whisper_model.model_size == model_size
if not already_loaded and not whisper_model._is_model_cached(model_size):
progress_model_name = f"whisper-{model_size}"
task_manager = get_task_manager()
async def download_whisper_background():
try:
await whisper_model.load_model_async(model_size)
task_manager.complete_download(progress_model_name)
except Exception as e:
task_manager.error_download(progress_model_name, str(e))
task_manager.start_download(progress_model_name)
create_background_task(download_whisper_background())
raise HTTPException(
status_code=202,
detail={
"message": f"Whisper model {model_size} is being downloaded. Please wait and try again.",
"model_name": progress_model_name,
"downloading": True,
},
)
text = await whisper_model.transcribe(tmp_path, language, model_size)
return models.TranscriptionResponse(
text=text,
duration=duration,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
Path(tmp_path).unlink(missing_ok=True)

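Because the endpoint above answers 202 while the Whisper weights are still downloading, a client should retry after a delay rather than treat that status as a failure. A minimal sketch, not part of the repository, assuming `httpx`, the default local port, and a placeholder model size:

```python
# Hypothetical client sketch: transcribe a local file, retrying while the
# Whisper model is still being downloaded (202 responses).
import time
from pathlib import Path

import httpx

BASE = "http://127.0.0.1:17493"
audio = Path("clip.wav")

while True:
    with audio.open("rb") as f:
        resp = httpx.post(
            f"{BASE}/transcribe",
            files={"file": (audio.name, f, "audio/wav")},
            data={"model": "base"},  # placeholder; must be a valid Whisper size
            timeout=None,
        )
    if resp.status_code == 202:
        # Model download was kicked off in the background; wait and retry.
        time.sleep(5)
        continue
    resp.raise_for_status()
    print(resp.json()["text"])
    break
```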
303
backend/server.py Normal file
View File

@@ -0,0 +1,303 @@
"""
Entry point for PyInstaller-bundled voicebox server.
This module provides an entry point that works with PyInstaller by using
absolute imports instead of relative imports.
"""
import sys
import os
# On Windows with --noconsole (PyInstaller), sys.stdout/stderr are None.
# They can also be broken file objects in some edge cases.
# Redirect to devnull to prevent crashes from print()/tqdm/logging.
def _is_writable(stream):
"""Check if a stream is usable for writing."""
if stream is None:
return False
try:
stream.write("")
return True
except Exception:
return False
if not _is_writable(sys.stdout):
sys.stdout = open(os.devnull, 'w')
if not _is_writable(sys.stderr):
sys.stderr = open(os.devnull, 'w')
# PyInstaller + multiprocessing: child processes re-execute the frozen binary
# with internal arguments. freeze_support() handles this and exits early.
import multiprocessing
multiprocessing.freeze_support()
# In frozen builds, piper_phonemize's espeak-ng C library falls back to
# /usr/share/espeak-ng-data/ which doesn't exist. Point it at the bundled
# data directory instead.
if getattr(sys, 'frozen', False):
_meipass = getattr(sys, '_MEIPASS', os.path.dirname(sys.executable))
_espeak_data = os.path.join(_meipass, 'piper_phonemize', 'espeak-ng-data')
if os.path.isdir(_espeak_data):
os.environ.setdefault('ESPEAK_DATA_PATH', _espeak_data)
# Fast path: handle --version before any heavy imports so the Rust
# version check doesn't block for 30+ seconds loading torch etc.
if "--version" in sys.argv:
from backend import __version__
print(f"voicebox-server {__version__}")
sys.exit(0)
import logging
# Set up logging FIRST, before any imports that might fail
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
stream=sys.stderr, # Log to stderr so it's captured by Tauri
)
logger = logging.getLogger(__name__)
# Log startup immediately to confirm binary execution
logger.info("=" * 60)
logger.info("voicebox-server starting up...")
logger.info(f"Python version: {sys.version}")
logger.info(f"Executable: {sys.executable}")
logger.info(f"Arguments: {sys.argv}")
logger.info("=" * 60)
try:
logger.info("Importing argparse...")
import argparse
logger.info("Importing uvicorn...")
import uvicorn
logger.info("Standard library imports successful")
# Import the FastAPI app from the backend package
logger.info("Importing backend.config...")
from backend import config
logger.info("Importing backend.database...")
from backend import database
logger.info("Importing backend.main (this may take a while due to torch/transformers)...")
from backend.main import app
logger.info("Backend imports successful")
except Exception as e:
logger.error(f"Failed to import required modules: {e}", exc_info=True)
sys.exit(1)
_watchdog_disabled = False
def disable_watchdog():
"""Disable the parent watchdog so the server keeps running after parent exits."""
global _watchdog_disabled
_watchdog_disabled = True
# Ignore SIGHUP so the server survives when the parent Tauri process exits.
# On Unix, child processes receive SIGHUP when the parent's session leader
# exits, which would kill the server even though we want it to persist.
if sys.platform != "win32":
import signal
signal.signal(signal.SIGHUP, signal.SIG_IGN)
def _start_parent_watchdog(parent_pid, data_dir=None):
"""Monitor parent process and exit if it dies.
This is the clean shutdown mechanism: instead of the Tauri app trying to
forcefully kill the server (which spawns console windows on Windows),
the server monitors its parent and shuts itself down gracefully.
The Tauri app writes a .keep-running sentinel file to data_dir before
exiting when "remain running after close" is enabled. This is a reliable
fallback for the HTTP /watchdog/disable request, which can race with
process exit on Windows.
"""
import os
import signal
import threading
import time
# Set up a file logger so we can debug in production
watchdog_logger = logging.getLogger("watchdog")
if data_dir:
try:
log_dir = os.path.join(data_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
fh = logging.FileHandler(os.path.join(log_dir, "watchdog.log"))
fh.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
watchdog_logger.addHandler(fh)
except Exception:
pass
watchdog_logger.setLevel(logging.INFO)
def _is_pid_alive(pid):
"""Check if a process with the given PID exists (cross-platform)."""
try:
if sys.platform == "win32":
import ctypes
kernel32 = ctypes.windll.kernel32
PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
handle = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, False, pid)
if handle:
# Check if process has actually exited
STILL_ACTIVE = 259
exit_code = ctypes.c_ulong()
result = kernel32.GetExitCodeProcess(handle, ctypes.byref(exit_code))
kernel32.CloseHandle(handle)
if result and exit_code.value == STILL_ACTIVE:
return True
watchdog_logger.info(f"PID {pid}: exited with code {exit_code.value}")
return False
# OpenProcess failed — check if it's an access error (process exists
# but we can't open it) vs process not found
error = ctypes.GetLastError()
ACCESS_DENIED = 5
if error == ACCESS_DENIED:
return True # process exists, we just can't open it
watchdog_logger.info(f"PID {pid}: OpenProcess failed, error={error}")
return False
else:
os.kill(pid, 0)
return True
except (OSError, PermissionError):
return False
def _watch():
watchdog_logger.info(f"Parent watchdog started, monitoring PID {parent_pid}, server PID {os.getpid()}")
# Verify parent is alive before starting the loop
alive = _is_pid_alive(parent_pid)
watchdog_logger.info(f"Parent PID {parent_pid} initial check: alive={alive}")
if not alive:
watchdog_logger.warning(f"Parent PID {parent_pid} not found on first check — disabling watchdog")
return
# Clear any stale .keep-running sentinel from a previous session. The
# sentinel is only removed by the watchdog when it's consumed during a
# grace period; if the HTTP /watchdog/disable path wins the race on a
# "keep running" exit, the sentinel is left on disk. Wipe it here so a
# future session can't inherit that stale signal.
if data_dir:
stale = os.path.join(data_dir, ".keep-running")
if os.path.exists(stale):
try:
os.remove(stale)
watchdog_logger.info("Removed stale .keep-running sentinel from previous session")
except OSError as e:
watchdog_logger.warning(f"Failed to remove stale sentinel: {e}")
while True:
if _watchdog_disabled:
watchdog_logger.info("Watchdog disabled (keep server running), stopping monitor")
return
if not _is_pid_alive(parent_pid):
# Parent is gone. Before shutting down, give the app a moment
# to send /watchdog/disable — there is a race where the Tauri
# RunEvent::Exit handler sends the disable request while we are
# mid-iteration (already past the _watchdog_disabled check above).
watchdog_logger.info(f"Parent process {parent_pid} gone, waiting for possible disable request...")
time.sleep(1)
if _watchdog_disabled:
watchdog_logger.info("Watchdog was disabled during grace period, keeping server alive")
return
# Check for sentinel file written by Tauri before exit.
# This catches the case where the HTTP disable request
# didn't arrive before the parent process died (common
# on Windows where process teardown is fast).
sentinel = os.path.join(data_dir, ".keep-running") if data_dir else None
if sentinel and os.path.exists(sentinel):
watchdog_logger.info("Found .keep-running sentinel file, keeping server alive")
try:
os.remove(sentinel)
except OSError:
pass
return
watchdog_logger.info("Watchdog still enabled after grace period, shutting down server...")
if sys.platform == "win32":
# sys.exit triggers SystemExit, allowing uvicorn to run
# shutdown handlers. os.kill(SIGTERM) on Windows calls
# TerminateProcess which hard-kills without cleanup.
os._exit(0)
else:
os.kill(os.getpid(), signal.SIGTERM)
return
time.sleep(2)
t = threading.Thread(target=_watch, daemon=True)
t.start()
if __name__ == "__main__":
try:
parser = argparse.ArgumentParser(description="voicebox backend server")
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Host to bind to (use 0.0.0.0 for remote access)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to bind to",
)
parser.add_argument(
"--data-dir",
type=str,
default=None,
help="Data directory for database, profiles, and generated audio",
)
parser.add_argument(
"--parent-pid",
type=int,
default=None,
help="PID of parent process to monitor; server exits when parent dies",
)
parser.add_argument(
"--version",
action="store_true",
help="Print version and exit (handled above, kept for argparse help)",
)
args = parser.parse_args()
if args.parent_pid is not None and args.parent_pid <= 0:
parser.error("--parent-pid must be a positive integer")
# Detect backend variant from binary name
# voicebox-server-cuda → sets VOICEBOX_BACKEND_VARIANT=cuda
binary_name = os.path.basename(sys.executable).lower()
if "cuda" in binary_name:
os.environ["VOICEBOX_BACKEND_VARIANT"] = "cuda"
logger.info("Backend variant: CUDA")
else:
os.environ["VOICEBOX_BACKEND_VARIANT"] = "cpu"
logger.info("Backend variant: CPU")
# Register parent watchdog to start after server is fully ready
if args.parent_pid is not None:
_parent_pid = args.parent_pid
_data_dir = args.data_dir
@app.on_event("startup")
async def _on_startup():
_start_parent_watchdog(_parent_pid, _data_dir)
logger.info(f"Parsed arguments: host={args.host}, port={args.port}, data_dir={args.data_dir}")
# Set data directory if provided
if args.data_dir:
logger.info(f"Setting data directory to: {args.data_dir}")
config.set_data_dir(args.data_dir)
# Initialize database after data directory is set
logger.info("Initializing database...")
database.init_db()
logger.info("Database initialized successfully")
logger.info(f"Starting uvicorn server on {args.host}:{args.port}...")
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="info",
)
except Exception as e:
logger.error(f"Server startup failed: {e}", exc_info=True)
sys.exit(1)

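To exercise the parent-pid watchdog described above, a parent process passes its own PID when spawning the server; when that parent exits without disabling the watchdog, the server shuts itself down. A minimal launcher sketch, not part of the repository, with a placeholder binary path and the CLI flags defined in the argparse setup above:

```python
# Hypothetical launcher sketch: start the bundled server the way a parent
# app would, so the watchdog stops the server when this process exits.
import os
import subprocess

server = subprocess.Popen(
    [
        "./voicebox-server",          # placeholder path to the PyInstaller binary
        "--host", "127.0.0.1",
        "--port", "17493",
        "--data-dir", os.path.expanduser("~/voicebox-data"),
        "--parent-pid", str(os.getpid()),
    ]
)
print(f"server pid: {server.pid}")
```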
View File

@@ -0,0 +1 @@
# Services layer — generation orchestration and background task management.

View File

@@ -0,0 +1,263 @@
"""
Audio channel management module.
"""
from typing import List, Optional
from datetime import datetime
import uuid
from sqlalchemy.orm import Session
from ..models import (
AudioChannelCreate,
AudioChannelUpdate,
AudioChannelResponse,
ChannelVoiceAssignment,
ProfileChannelAssignment,
)
from ..database import (
AudioChannel as DBAudioChannel,
ChannelDeviceMapping as DBChannelDeviceMapping,
ProfileChannelMapping as DBProfileChannelMapping,
VoiceProfile as DBVoiceProfile,
)
async def list_channels(db: Session) -> List[AudioChannelResponse]:
"""List all audio channels."""
channels = db.query(DBAudioChannel).all()
result = []
for channel in channels:
# Get device IDs for this channel
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
channel_id=channel.id
).all()
device_ids = [m.device_id for m in device_mappings]
result.append(AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=device_ids,
created_at=channel.created_at,
))
return result
async def get_channel(channel_id: str, db: Session) -> Optional[AudioChannelResponse]:
"""Get a channel by ID."""
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
return None
# Get device IDs
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
channel_id=channel.id
).all()
device_ids = [m.device_id for m in device_mappings]
return AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=device_ids,
created_at=channel.created_at,
)
async def create_channel(
data: AudioChannelCreate,
db: Session,
) -> AudioChannelResponse:
"""Create a new audio channel."""
# Check if name already exists
existing = db.query(DBAudioChannel).filter_by(name=data.name).first()
if existing:
raise ValueError(f"Channel with name '{data.name}' already exists")
# Create channel
channel = DBAudioChannel(
id=str(uuid.uuid4()),
name=data.name,
is_default=False,
created_at=datetime.utcnow(),
)
db.add(channel)
db.flush()
# Add device mappings
for device_id in data.device_ids:
mapping = DBChannelDeviceMapping(
id=str(uuid.uuid4()),
channel_id=channel.id,
device_id=device_id,
)
db.add(mapping)
db.commit()
db.refresh(channel)
return AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=data.device_ids,
created_at=channel.created_at,
)
async def update_channel(
channel_id: str,
data: AudioChannelUpdate,
db: Session,
) -> Optional[AudioChannelResponse]:
"""Update an audio channel."""
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
return None
if channel.is_default:
raise ValueError("Cannot modify the default channel")
# Update name if provided
if data.name is not None:
# Check if name already exists (excluding current channel)
existing = db.query(DBAudioChannel).filter(
DBAudioChannel.name == data.name,
DBAudioChannel.id != channel_id
).first()
if existing:
raise ValueError(f"Channel with name '{data.name}' already exists")
channel.name = data.name
# Update device mappings if provided
if data.device_ids is not None:
# Delete existing mappings
db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()
# Add new mappings
for device_id in data.device_ids:
mapping = DBChannelDeviceMapping(
id=str(uuid.uuid4()),
channel_id=channel.id,
device_id=device_id,
)
db.add(mapping)
db.commit()
db.refresh(channel)
# Get updated device IDs
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
channel_id=channel.id
).all()
device_ids = [m.device_id for m in device_mappings]
return AudioChannelResponse(
id=channel.id,
name=channel.name,
is_default=channel.is_default,
device_ids=device_ids,
created_at=channel.created_at,
)
async def delete_channel(channel_id: str, db: Session) -> bool:
"""Delete an audio channel."""
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
return False
if channel.is_default:
raise ValueError("Cannot delete the default channel")
# Delete device mappings
db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()
# Delete profile-channel mappings
db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
# Delete channel
db.delete(channel)
db.commit()
return True
async def get_channel_voices(channel_id: str, db: Session) -> List[str]:
"""Get list of profile IDs assigned to a channel."""
mappings = db.query(DBProfileChannelMapping).filter_by(
channel_id=channel_id
).all()
return [m.profile_id for m in mappings]
async def set_channel_voices(
channel_id: str,
data: ChannelVoiceAssignment,
db: Session,
) -> None:
"""Set which voices are assigned to a channel."""
# Verify channel exists
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
raise ValueError(f"Channel {channel_id} not found")
# Verify all profiles exist
for profile_id in data.profile_ids:
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Delete existing mappings for this channel
db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
# Add new mappings
for profile_id in data.profile_ids:
mapping = DBProfileChannelMapping(
profile_id=profile_id,
channel_id=channel_id,
)
db.add(mapping)
db.commit()
async def get_profile_channels(profile_id: str, db: Session) -> List[str]:
"""Get list of channel IDs assigned to a profile."""
mappings = db.query(DBProfileChannelMapping).filter_by(
profile_id=profile_id
).all()
return [m.channel_id for m in mappings]
async def set_profile_channels(
profile_id: str,
data: ProfileChannelAssignment,
db: Session,
) -> None:
"""Set which channels a profile is assigned to."""
# Verify profile exists
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Verify all channels exist
for channel_id in data.channel_ids:
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
if not channel:
raise ValueError(f"Channel {channel_id} not found")
# Delete existing mappings for this profile
db.query(DBProfileChannelMapping).filter_by(profile_id=profile_id).delete()
# Add new mappings
for channel_id in data.channel_ids:
mapping = DBProfileChannelMapping(
profile_id=profile_id,
channel_id=channel_id,
)
db.add(mapping)
db.commit()

422
backend/services/cuda.py Normal file
View File

@@ -0,0 +1,422 @@
"""
CUDA backend download, assembly, and verification.
Downloads two archives from GitHub Releases:
1. Server core (voicebox-server-cuda.tar.gz) — the exe + non-NVIDIA deps,
versioned with the app.
2. CUDA libs (cuda-libs-{version}.tar.gz) — NVIDIA runtime libraries,
versioned independently (only redownloaded on CUDA toolkit bump).
Both archives are extracted into {data_dir}/backends/cuda/ which forms the
complete PyInstaller --onedir directory structure that torch expects.
"""
import asyncio
import hashlib
import json
import logging
import os
import sys
import tarfile
from pathlib import Path
from typing import Optional
from ..config import get_data_dir
from ..utils.progress import get_progress_manager
from .. import __version__
logger = logging.getLogger(__name__)
GITHUB_RELEASES_URL = "https://github.com/jamiepine/voicebox/releases/download"
PROGRESS_KEY = "cuda-backend"
# The current expected CUDA libs version. Bump this when we change the
# CUDA toolkit version or torch's CUDA dependency changes (e.g. cu126 -> cu128).
CUDA_LIBS_VERSION = "cu128-v1"
# Prevents concurrent download_cuda_binary() calls from racing on the same
# temp file. The auto-update background task and the manual HTTP endpoint
# can both invoke download_cuda_binary(); without this lock the progress-
# manager status check is a TOCTOU race.
_download_lock = asyncio.Lock()
def get_backends_dir() -> Path:
"""Directory where downloaded backend binaries are stored."""
d = get_data_dir() / "backends"
d.mkdir(parents=True, exist_ok=True)
return d
def get_cuda_dir() -> Path:
"""Directory where the CUDA backend (onedir) is extracted."""
d = get_backends_dir() / "cuda"
d.mkdir(parents=True, exist_ok=True)
return d
def get_cuda_exe_name() -> str:
"""Platform-specific CUDA executable filename."""
if sys.platform == "win32":
return "voicebox-server-cuda.exe"
return "voicebox-server-cuda"
def get_cuda_binary_path() -> Optional[Path]:
"""Return path to the CUDA executable if it exists inside the onedir."""
p = get_cuda_dir() / get_cuda_exe_name()
if p.exists():
return p
return None
def get_cuda_libs_manifest_path() -> Path:
"""Path to the cuda-libs.json manifest inside the CUDA dir."""
return get_cuda_dir() / "cuda-libs.json"
def get_installed_cuda_libs_version() -> Optional[str]:
"""Read the installed CUDA libs version from cuda-libs.json, or None."""
manifest_path = get_cuda_libs_manifest_path()
if not manifest_path.exists():
return None
try:
data = json.loads(manifest_path.read_text())
return data.get("version")
except Exception as e:
logger.warning(f"Could not read cuda-libs.json: {e}")
return None
def is_cuda_active() -> bool:
"""Check if the current process is the CUDA binary.
The CUDA binary sets this env var on startup (see server.py).
"""
return os.environ.get("VOICEBOX_BACKEND_VARIANT") == "cuda"
def get_cuda_status() -> dict:
"""Get current CUDA backend status for the API."""
progress_manager = get_progress_manager()
cuda_path = get_cuda_binary_path()
progress = progress_manager.get_progress(PROGRESS_KEY)
cuda_libs_version = get_installed_cuda_libs_version()
return {
"available": cuda_path is not None,
"active": is_cuda_active(),
"binary_path": str(cuda_path) if cuda_path else None,
"cuda_libs_version": cuda_libs_version,
"downloading": progress is not None and progress.get("status") == "downloading",
"download_progress": progress,
}
def _needs_server_download(version: Optional[str] = None) -> bool:
"""Check if the server core archive needs to be (re)downloaded."""
cuda_path = get_cuda_binary_path()
if not cuda_path:
return True
# Check if the binary version matches the expected app version
installed = get_cuda_binary_version()
expected = version or __version__
if expected.startswith("v"):
expected = expected[1:]
return installed != expected
def _needs_cuda_libs_download() -> bool:
"""Check if the CUDA libs archive needs to be (re)downloaded."""
installed = get_installed_cuda_libs_version()
if installed is None:
return True
return installed != CUDA_LIBS_VERSION
async def _download_and_extract_archive(
client,
url: str,
sha256_url: Optional[str],
dest_dir: Path,
label: str,
progress_offset: int,
total_size: int,
):
"""Download a .tar.gz archive and extract it into dest_dir.
Args:
client: httpx.AsyncClient
url: URL of the .tar.gz archive
sha256_url: URL of the .sha256 checksum file (optional)
dest_dir: Directory to extract into
label: Human-readable label for progress updates
progress_offset: Byte offset for progress reporting (when downloading
multiple archives sequentially)
total_size: Total bytes across all downloads (for progress bar)
"""
progress = get_progress_manager()
temp_path = dest_dir / f".download-{label.replace(' ', '-')}.tmp"
# Clean up leftover partial download
if temp_path.exists():
temp_path.unlink()
# Fetch expected checksum (fail-fast: never extract an unverified archive)
expected_sha = None
if sha256_url:
try:
sha_resp = await client.get(sha256_url)
sha_resp.raise_for_status()
expected_sha = sha_resp.text.strip().split()[0]
logger.info(f"{label}: expected SHA-256: {expected_sha[:16]}...")
except Exception as e:
raise RuntimeError(f"{label}: failed to fetch checksum from {sha256_url}") from e
# Stream download, verify, and extract — always clean up temp file
downloaded = 0
try:
async with client.stream("GET", url) as response:
response.raise_for_status()
with open(temp_path, "wb") as f:
async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
f.write(chunk)
downloaded += len(chunk)
progress.update_progress(
PROGRESS_KEY,
current=progress_offset + downloaded,
total=total_size,
filename=f"Downloading {label}",
status="downloading",
)
# Verify integrity
if expected_sha:
progress.update_progress(
PROGRESS_KEY,
current=progress_offset + downloaded,
total=total_size,
filename=f"Verifying {label}...",
status="downloading",
)
sha256 = hashlib.sha256()
with open(temp_path, "rb") as f:
while True:
data = f.read(1024 * 1024)
if not data:
break
sha256.update(data)
actual = sha256.hexdigest()
if actual != expected_sha:
raise ValueError(
f"{label} integrity check failed: expected {expected_sha[:16]}..., got {actual[:16]}..."
)
logger.info(f"{label}: integrity verified")
# Extract (use data filter for path traversal protection on Python 3.12+)
progress.update_progress(
PROGRESS_KEY,
current=progress_offset + downloaded,
total=total_size,
filename=f"Extracting {label}...",
status="downloading",
)
with tarfile.open(temp_path, "r:gz") as tar:
if sys.version_info >= (3, 12):
tar.extractall(path=dest_dir, filter="data")
else:
tar.extractall(path=dest_dir)
logger.info(f"{label}: extracted to {dest_dir}")
finally:
if temp_path.exists():
temp_path.unlink()
return downloaded
async def download_cuda_binary(version: Optional[str] = None):
"""Download the CUDA backend (server core + CUDA libs if needed).
Downloads both archives from GitHub Releases, extracts them into
{data_dir}/backends/cuda/, and writes the cuda-libs.json manifest.
Only downloads what's needed:
- Server core: always redownloaded (versioned with app)
- CUDA libs: only if missing or version mismatch
Args:
version: Version tag (e.g. "v0.3.0"). Defaults to current app version.
"""
if _download_lock.locked():
logger.info("CUDA download already in progress, skipping duplicate request")
return
async with _download_lock:
await _download_cuda_binary_locked(version)
async def _download_cuda_binary_locked(version: Optional[str] = None):
"""Inner implementation of download_cuda_binary, called under _download_lock."""
import httpx
if version is None:
version = f"v{__version__}"
progress = get_progress_manager()
cuda_dir = get_cuda_dir()
need_server = _needs_server_download(version)
need_libs = _needs_cuda_libs_download()
if not need_server and not need_libs:
logger.info("CUDA backend is up to date, nothing to download")
return
logger.info(
f"Starting CUDA backend download for {version} "
f"(server={'yes' if need_server else 'cached'}, "
f"libs={'yes' if need_libs else 'cached'})"
)
progress.update_progress(
PROGRESS_KEY,
current=0,
total=0,
filename="Preparing download...",
status="downloading",
)
base_url = f"{GITHUB_RELEASES_URL}/{version}"
server_archive = "voicebox-server-cuda.tar.gz"
libs_archive = f"cuda-libs-{CUDA_LIBS_VERSION}.tar.gz"
try:
async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
# Estimate total download size
total_size = 0
if need_server:
try:
head = await client.head(f"{base_url}/{server_archive}")
total_size += int(head.headers.get("content-length", 0))
except Exception:
pass
if need_libs:
try:
head = await client.head(f"{base_url}/{libs_archive}")
total_size += int(head.headers.get("content-length", 0))
except Exception:
pass
logger.info(f"Total download size: {total_size / 1024 / 1024:.1f} MB")
offset = 0
# Download server core
if need_server:
server_downloaded = await _download_and_extract_archive(
client,
url=f"{base_url}/{server_archive}",
sha256_url=f"{base_url}/{server_archive}.sha256",
dest_dir=cuda_dir,
label="CUDA server",
progress_offset=offset,
total_size=total_size,
)
offset += server_downloaded
# Make executable on Unix
exe_path = cuda_dir / get_cuda_exe_name()
if sys.platform != "win32" and exe_path.exists():
exe_path.chmod(0o755)
# Download CUDA libs
if need_libs:
await _download_and_extract_archive(
client,
url=f"{base_url}/{libs_archive}",
sha256_url=f"{base_url}/{libs_archive}.sha256",
dest_dir=cuda_dir,
label="CUDA libraries",
progress_offset=offset,
total_size=total_size,
)
# Write local cuda-libs.json manifest
manifest = {"version": CUDA_LIBS_VERSION}
get_cuda_libs_manifest_path().write_text(json.dumps(manifest, indent=2) + "\n")
logger.info(f"CUDA backend ready at {cuda_dir}")
progress.mark_complete(PROGRESS_KEY)
except Exception as e:
logger.error(f"CUDA backend download failed: {e}")
progress.mark_error(PROGRESS_KEY, str(e))
raise
def get_cuda_binary_version() -> Optional[str]:
"""Get the version of the installed CUDA binary, or None if not installed."""
import subprocess
cuda_path = get_cuda_binary_path()
if not cuda_path:
return None
try:
result = subprocess.run(
[str(cuda_path), "--version"],
capture_output=True,
text=True,
timeout=30,
cwd=str(cuda_path.parent), # Run from the onedir directory
)
# Output format: "voicebox-server 0.3.0"
for line in result.stdout.strip().splitlines():
if "voicebox-server" in line:
return line.split()[-1]
except Exception as e:
logger.warning(f"Could not get CUDA binary version: {e}")
return None
async def check_and_update_cuda_binary():
"""Check if the CUDA binary is outdated and auto-download if so.
Called on server startup. Checks both server version and CUDA libs
version. Downloads only what's needed.
"""
cuda_path = get_cuda_binary_path()
if not cuda_path:
return # No CUDA binary installed, nothing to update
need_server = _needs_server_download()
need_libs = _needs_cuda_libs_download()
if not need_server and not need_libs:
logger.info(f"CUDA binary is up to date (server=v{__version__}, libs={get_installed_cuda_libs_version()})")
return
reasons = []
if need_server:
cuda_version = get_cuda_binary_version()
reasons.append(f"server v{cuda_version} != v{__version__}")
if need_libs:
installed_libs = get_installed_cuda_libs_version()
reasons.append(f"libs {installed_libs} != {CUDA_LIBS_VERSION}")
logger.info(f"CUDA backend needs update ({', '.join(reasons)}). Auto-downloading...")
try:
await download_cuda_binary()
except Exception as e:
logger.error(f"Auto-update of CUDA binary failed: {e}")
async def delete_cuda_binary() -> bool:
"""Delete the downloaded CUDA backend directory. Returns True if deleted."""
import shutil
cuda_dir = get_cuda_dir()
if cuda_dir.exists() and any(cuda_dir.iterdir()):
shutil.rmtree(cuda_dir)
logger.info(f"Deleted CUDA backend directory: {cuda_dir}")
return True
return False
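# Illustrative sketch (not part of the module): what a completed download is
# expected to leave in {data_dir}/backends/cuda/, based on the functions above.
# The extracted file names are assumptions; only cuda-libs.json is written
# explicitly by this module.
#
#   backends/cuda/
#     voicebox-server[.exe]     # server core from voicebox-server-cuda.tar.gz (assumed name)
#     cuda-libs.json            # {"version": "<CUDA_LIBS_VERSION>"}
#     <CUDA runtime libraries>  # from cuda-libs-<CUDA_LIBS_VERSION>.tar.gz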

120
backend/services/effects.py Normal file
View File

@@ -0,0 +1,120 @@
"""
Effect presets CRUD operations.
"""
from __future__ import annotations
import json
import uuid
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from ..utils.effects import validate_effects_chain
from ..database import EffectPreset as DBEffectPreset
from ..models import EffectPresetResponse, EffectPresetCreate, EffectPresetUpdate, EffectConfig
def _preset_response(p: DBEffectPreset) -> EffectPresetResponse:
"""Convert a DB preset row to a Pydantic response."""
effects_chain = [EffectConfig(**e) for e in json.loads(p.effects_chain)]
return EffectPresetResponse(
id=p.id,
name=p.name,
description=p.description,
effects_chain=effects_chain,
is_builtin=p.is_builtin or False,
created_at=p.created_at,
)
def list_presets(db: Session) -> List[EffectPresetResponse]:
"""List all effect presets (built-in + user-created)."""
presets = db.query(DBEffectPreset).order_by(DBEffectPreset.sort_order, DBEffectPreset.name).all()
return [_preset_response(p) for p in presets]
def get_preset(preset_id: str, db: Session) -> Optional[EffectPresetResponse]:
"""Get a preset by ID."""
p = db.query(DBEffectPreset).filter_by(id=preset_id).first()
if not p:
return None
return _preset_response(p)
def get_preset_by_name(name: str, db: Session) -> Optional[EffectPresetResponse]:
"""Get a preset by name."""
p = db.query(DBEffectPreset).filter_by(name=name).first()
if not p:
return None
return _preset_response(p)
def create_preset(data: EffectPresetCreate, db: Session) -> EffectPresetResponse:
"""Create a new user effect preset."""
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise ValueError(error)
# Check for duplicate name before insert
existing = db.query(DBEffectPreset).filter_by(name=data.name).first()
if existing:
raise ValueError(f"A preset named '{data.name}' already exists")
preset = DBEffectPreset(
id=str(uuid.uuid4()),
name=data.name,
description=data.description,
effects_chain=json.dumps(chain_dicts),
is_builtin=False,
)
db.add(preset)
try:
db.commit()
except IntegrityError:
db.rollback()
raise ValueError(f"A preset named '{data.name}' already exists")
db.refresh(preset)
return _preset_response(preset)
def update_preset(preset_id: str, data: EffectPresetUpdate, db: Session) -> Optional[EffectPresetResponse]:
"""Update a user effect preset. Cannot modify built-in presets."""
preset = db.query(DBEffectPreset).filter_by(id=preset_id).first()
if not preset:
return None
if preset.is_builtin:
raise ValueError("Cannot modify built-in presets")
if data.name is not None:
preset.name = data.name
if data.description is not None:
preset.description = data.description
if data.effects_chain is not None:
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise ValueError(error)
preset.effects_chain = json.dumps(chain_dicts)
db.commit()
db.refresh(preset)
return _preset_response(preset)
def delete_preset(preset_id: str, db: Session) -> bool:
"""Delete a user effect preset. Cannot delete built-in presets."""
preset = db.query(DBEffectPreset).filter_by(id=preset_id).first()
if not preset:
return False
if preset.is_builtin:
raise ValueError("Cannot delete built-in presets")
db.delete(preset)
db.commit()
return True
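# Illustrative usage (not part of the module): a minimal round trip through the
# helpers above. The EffectConfig fields shown ("type", "enabled") are
# assumptions for the sake of the example; validate_effects_chain() is the
# actual authority on what a chain may contain.
#
#   from backend.database import get_db
#   from backend.models import EffectPresetCreate, EffectConfig
#
#   db = next(get_db())
#   preset = create_preset(
#       EffectPresetCreate(
#           name="Phone call",
#           description="Narrow-band phone sound",
#           effects_chain=[EffectConfig(type="highpass", enabled=True)],  # hypothetical fields
#       ),
#       db,
#   )
#   assert get_preset(preset.id, db) is not None
#   delete_preset(preset.id, db)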

View File

@@ -0,0 +1,461 @@
"""
Voice profile export/import module.
Handles exporting profiles to ZIP archives and importing them back.
Also handles exporting individual generations.
"""
import json
import zipfile
import io
from pathlib import Path
from typing import Optional
from sqlalchemy.orm import Session
from ..models import VoiceProfileResponse
from ..database import VoiceProfile as DBVoiceProfile, ProfileSample as DBProfileSample, Generation as DBGeneration, GenerationVersion as DBGenerationVersion
from .profiles import create_profile, add_profile_sample
from ..models import VoiceProfileCreate
from .. import config
def _get_unique_profile_name(name: str, db: Session) -> str:
"""
Get a unique profile name by appending a number if needed.
Args:
name: Original profile name
db: Database session
Returns:
Unique profile name
"""
base_name = name
counter = 1
while True:
existing = db.query(DBVoiceProfile).filter_by(name=name).first()
if not existing:
return name
name = f"{base_name} ({counter})"
counter += 1
def export_profile_to_zip(profile_id: str, db: Session) -> bytes:
"""
Export a voice profile to a ZIP archive.
Args:
profile_id: Profile ID to export
db: Database session
Returns:
ZIP file contents as bytes
Raises:
ValueError: If profile not found or has no samples
"""
# Get profile
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Get all samples
samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
if not samples:
raise ValueError(f"Profile {profile_id} has no samples")
# Create ZIP in memory
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
# Check if profile has avatar
has_avatar = False
if profile.avatar_path:
avatar_path = config.resolve_storage_path(profile.avatar_path)
if avatar_path is not None and avatar_path.exists():
has_avatar = True
# Add avatar to ZIP root with original extension
avatar_ext = avatar_path.suffix
zip_file.write(avatar_path, f"avatar{avatar_ext}")
# Create manifest.json
manifest = {
"version": "1.0",
"profile": {
"name": profile.name,
"description": profile.description,
"language": profile.language,
},
"has_avatar": has_avatar,
}
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
# Create samples.json mapping
samples_data = {}
profile_dir = config.get_profiles_dir() / profile_id
for sample in samples:
# Get filename from audio_path (should be {sample_id}.wav)
audio_path = config.resolve_storage_path(sample.audio_path)
if audio_path is None:
raise ValueError(f"Audio file not found: {sample.audio_path}")
filename = audio_path.name
# Read audio file
if not audio_path.exists():
raise ValueError(f"Audio file not found: {audio_path}")
# Add to samples directory in ZIP
zip_path = f"samples/{filename}"
zip_file.write(audio_path, zip_path)
# Map filename to reference text
samples_data[filename] = sample.reference_text
zip_file.writestr("samples.json", json.dumps(samples_data, indent=2))
zip_buffer.seek(0)
return zip_buffer.read()
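# Illustrative sketch (not part of the module): archive layout produced by
# export_profile_to_zip(), reconstructed from the code above. Sample file
# names are the sample IDs on disk; the ones shown are hypothetical.
#
#   profile.zip
#   ├── manifest.json      # {"version": "1.0", "profile": {name, description, language}, "has_avatar": ...}
#   ├── samples.json       # {"<sample_id>.wav": "<reference text>", ...}
#   ├── avatar.<ext>       # present only when the profile has an avatar
#   └── samples/
#       └── <sample_id>.wav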
async def import_profile_from_zip(file_bytes: bytes, db: Session) -> VoiceProfileResponse:
"""
Import a voice profile from a ZIP archive.
Args:
file_bytes: ZIP file contents
db: Database session
Returns:
Created profile
Raises:
ValueError: If ZIP is invalid or missing required files
"""
zip_buffer = io.BytesIO(file_bytes)
try:
with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
# Validate ZIP structure
namelist = zip_file.namelist()
if "manifest.json" not in namelist:
raise ValueError("ZIP archive missing manifest.json")
if "samples.json" not in namelist:
raise ValueError("ZIP archive missing samples.json")
# Read manifest
manifest_data = json.loads(zip_file.read("manifest.json"))
if "version" not in manifest_data:
raise ValueError("Invalid manifest.json: missing version")
if "profile" not in manifest_data:
raise ValueError("Invalid manifest.json: missing profile")
profile_data = manifest_data["profile"]
# Read samples mapping
samples_data = json.loads(zip_file.read("samples.json"))
if not isinstance(samples_data, dict):
raise ValueError("Invalid samples.json: must be a dictionary")
# Get unique profile name
original_name = profile_data.get("name", "Imported Profile")
unique_name = _get_unique_profile_name(original_name, db)
# Create profile
profile_create = VoiceProfileCreate(
name=unique_name,
description=profile_data.get("description"),
language=profile_data.get("language", "en"),
)
profile = await create_profile(profile_create, db)
# Extract and add samples
profile_dir = config.get_profiles_dir() / profile.id
profile_dir.mkdir(parents=True, exist_ok=True)
# Handle avatar if present
avatar_files = [f for f in namelist if f.startswith("avatar.")]
if avatar_files:
try:
avatar_file = avatar_files[0]
# Extract to temporary file
import tempfile
with tempfile.NamedTemporaryFile(suffix=Path(avatar_file).suffix, delete=False) as tmp:
tmp.write(zip_file.read(avatar_file))
tmp_path = tmp.name
try:
from .profiles import upload_avatar
await upload_avatar(profile.id, tmp_path, db)
finally:
Path(tmp_path).unlink(missing_ok=True)
                except Exception:
                    # Avatar import is optional - continue even if it fails
                    pass
for filename, reference_text in samples_data.items():
# Validate filename
if not filename.endswith('.wav'):
raise ValueError(f"Invalid sample filename: {filename} (must be .wav)")
# Extract audio file to temp location
zip_path = f"samples/{filename}"
if zip_path not in namelist:
raise ValueError(f"Sample file not found in ZIP: {zip_path}")
# Extract to temporary file
import tempfile
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(zip_file.read(zip_path))
tmp_path = tmp.name
try:
# Add sample to profile
await add_profile_sample(
profile.id,
tmp_path,
reference_text,
db,
)
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)
return profile
except zipfile.BadZipFile:
raise ValueError("Invalid ZIP file")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in archive: {e}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error importing profile: {str(e)}")
def export_generation_to_zip(generation_id: str, db: Session) -> bytes:
"""
Export a generation to a ZIP archive.
Args:
generation_id: Generation ID to export
db: Database session
Returns:
ZIP file contents as bytes
Raises:
ValueError: If generation not found
"""
# Get generation
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise ValueError(f"Generation {generation_id} not found")
# Get profile info
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
if not profile:
raise ValueError(f"Profile {generation.profile_id} not found")
# Get all versions for this generation
versions = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.all()
)
# Create ZIP in memory
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
# Build version manifest entries
version_entries = []
for v in versions:
v_path = config.resolve_storage_path(v.audio_path)
effects_chain = None
if v.effects_chain:
effects_chain = json.loads(v.effects_chain)
version_entries.append({
"id": v.id,
"label": v.label,
"is_default": v.is_default,
"effects_chain": effects_chain,
"filename": v_path.name,
})
manifest = {
"version": "1.0",
"generation": {
"id": generation.id,
"text": generation.text,
"language": generation.language,
"duration": generation.duration,
"seed": generation.seed,
"instruct": generation.instruct,
"created_at": generation.created_at.isoformat(),
},
"profile": {
"id": profile.id,
"name": profile.name,
"description": profile.description,
"language": profile.language,
},
"versions": version_entries,
}
zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))
# Add all version audio files
for v in versions:
v_path = config.resolve_storage_path(v.audio_path)
if v_path is not None and v_path.exists():
zip_file.write(v_path, f"audio/{v_path.name}")
# Fallback: if no versions exist, include the generation's main audio
if not versions:
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
zip_file.write(audio_path, f"audio/{audio_path.name}")
zip_buffer.seek(0)
return zip_buffer.read()
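# Illustrative sketch (not part of the module): shape of the manifest.json
# written by export_generation_to_zip(), with hypothetical values.
#
#   {
#     "version": "1.0",
#     "generation": {"id": "...", "text": "...", "language": "en", "duration": 3.2,
#                    "seed": 42, "instruct": null, "created_at": "..."},
#     "profile": {"id": "...", "name": "Narrator", "description": null, "language": "en"},
#     "versions": [{"id": "...", "label": "original", "is_default": true,
#                   "effects_chain": null, "filename": "<generation_id>.wav"}]
#   }
#
# Version audio files are stored under audio/ in the archive; when a generation
# has no versions, its main audio file is included there instead.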
async def import_generation_from_zip(file_bytes: bytes, db: Session) -> dict:
"""
Import a generation from a ZIP archive.
Args:
file_bytes: ZIP file contents
db: Database session
Returns:
Dictionary with generation ID and profile info
Raises:
ValueError: If ZIP is invalid or missing required files
"""
from pathlib import Path
import tempfile
import shutil
from datetime import datetime
from .. import config
zip_buffer = io.BytesIO(file_bytes)
try:
with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
# Validate ZIP structure
namelist = zip_file.namelist()
if "manifest.json" not in namelist:
raise ValueError("ZIP archive missing manifest.json")
# Read manifest
manifest_data = json.loads(zip_file.read("manifest.json"))
if "version" not in manifest_data:
raise ValueError("Invalid manifest.json: missing version")
if "generation" not in manifest_data:
raise ValueError("Invalid manifest.json: missing generation data")
generation_data = manifest_data["generation"]
profile_data = manifest_data.get("profile", {})
# Validate required fields
required_fields = ["text", "language", "duration"]
for field in required_fields:
if field not in generation_data:
raise ValueError(f"Invalid manifest.json: missing generation.{field}")
# Find audio file in archive
audio_files = [f for f in namelist if f.startswith("audio/") and f.endswith(".wav")]
if not audio_files:
raise ValueError("No audio file found in ZIP archive")
audio_file_path = audio_files[0]
# Check if we should match an existing profile or create metadata
profile_id = None
profile_name = profile_data.get("name", "Unknown Profile")
# Try to find matching profile by name
if profile_name and profile_name != "Unknown Profile":
existing_profile = db.query(DBVoiceProfile).filter_by(name=profile_name).first()
if existing_profile:
profile_id = existing_profile.id
# If no matching profile, use a placeholder or the first available profile
if not profile_id:
# Get any profile, or None if no profiles exist
any_profile = db.query(DBVoiceProfile).first()
if any_profile:
profile_id = any_profile.id
profile_name = any_profile.name
else:
raise ValueError("No voice profiles found. Please create a profile before importing generations.")
# Extract audio file to temporary location
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(zip_file.read(audio_file_path))
tmp_path = tmp.name
try:
# Create generations directory
generations_dir = config.get_generations_dir()
generations_dir.mkdir(parents=True, exist_ok=True)
# Generate new ID for this generation
                import uuid
                new_generation_id = str(uuid.uuid4())
# Copy audio to generations directory
audio_dest = generations_dir / f"{new_generation_id}.wav"
shutil.copy(tmp_path, audio_dest)
# Create generation record
db_generation = DBGeneration(
id=new_generation_id,
profile_id=profile_id,
text=generation_data["text"],
language=generation_data["language"],
audio_path=config.to_storage_path(audio_dest),
duration=generation_data["duration"],
seed=generation_data.get("seed"),
instruct=generation_data.get("instruct"),
created_at=datetime.utcnow(),
)
db.add(db_generation)
db.commit()
db.refresh(db_generation)
return {
"id": db_generation.id,
"profile_id": profile_id,
"profile_name": profile_name,
"text": db_generation.text,
"message": f"Generation imported successfully (assigned to profile: {profile_name})"
}
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)
except zipfile.BadZipFile:
raise ValueError("Invalid ZIP file")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in archive: {e}")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error importing generation: {str(e)}")

263
backend/services/generation.py Normal file
View File

@@ -0,0 +1,263 @@
"""
Unified TTS generation orchestration.
Replaces the three near-identical closures (_run_generation, _run_retry,
_run_regenerate) that lived in main.py with a single ``run_generation()``
function parameterized by *mode*.
Mode differences:
- "generate" : full pipeline -- save clean version, optionally apply
effects and create a processed version.
- "retry" : re-runs a failed generation with the same seed.
No effects, no version creation.
- "regenerate" : re-runs with seed=None for variation. Creates a new
version with an auto-incremented "take-N" label.
"""
from __future__ import annotations
import asyncio
import traceback
from typing import Literal, Optional
from .. import config
from . import history, profiles
from ..database import get_db
from ..utils.tasks import get_task_manager
async def run_generation(
*,
generation_id: str,
profile_id: str,
text: str,
language: str,
engine: str,
model_size: str,
seed: Optional[int],
normalize: bool = False,
effects_chain: Optional[list] = None,
instruct: Optional[str] = None,
mode: Literal["generate", "retry", "regenerate"],
max_chunk_chars: Optional[int] = None,
crossfade_ms: Optional[int] = None,
version_id: Optional[str] = None,
) -> None:
"""Execute TTS inference and persist the result.
This is the single entry point for all background generation work.
It is designed to be enqueued via ``services.task_queue.enqueue_generation``.
"""
from ..backends import load_engine_model, get_tts_backend_for_engine, engine_needs_trim
from ..utils.chunked_tts import generate_chunked
from ..utils.audio import normalize_audio, save_audio, trim_tts_output
task_manager = get_task_manager()
bg_db = next(get_db())
try:
tts_model = get_tts_backend_for_engine(engine)
if not tts_model.is_loaded():
await history.update_generation_status(generation_id, "loading_model", bg_db)
await load_engine_model(engine, model_size)
voice_prompt = await profiles.create_voice_prompt_for_profile(
profile_id,
bg_db,
use_cache=True,
engine=engine,
)
await history.update_generation_status(generation_id, "generating", bg_db)
trim_fn = trim_tts_output if engine_needs_trim(engine) else None
gen_kwargs: dict = dict(
language=language,
seed=seed if mode != "regenerate" else None,
instruct=instruct,
trim_fn=trim_fn,
)
if max_chunk_chars is not None:
gen_kwargs["max_chunk_chars"] = max_chunk_chars
if crossfade_ms is not None:
gen_kwargs["crossfade_ms"] = crossfade_ms
audio, sample_rate = await generate_chunked(tts_model, text, voice_prompt, **gen_kwargs)
        # --- Normalize (regenerate always; otherwise only when requested) -----
if normalize or mode == "regenerate":
audio = normalize_audio(audio)
duration = len(audio) / sample_rate
# --- Persist audio and update status -----------------------------
if mode == "generate":
final_path = _save_generate(
generation_id=generation_id,
audio=audio,
sample_rate=sample_rate,
effects_chain=effects_chain,
save_audio=save_audio,
db=bg_db,
)
elif mode == "retry":
final_path = _save_retry(
generation_id=generation_id,
audio=audio,
sample_rate=sample_rate,
save_audio=save_audio,
)
elif mode == "regenerate":
final_path = _save_regenerate(
generation_id=generation_id,
version_id=version_id,
audio=audio,
sample_rate=sample_rate,
save_audio=save_audio,
db=bg_db,
)
await history.update_generation_status(
generation_id=generation_id,
status="completed",
db=bg_db,
audio_path=final_path,
duration=duration,
)
except asyncio.CancelledError:
await history.update_generation_status(
generation_id=generation_id,
status="failed",
db=bg_db,
error="Generation cancelled",
)
except Exception as e:
traceback.print_exc()
await history.update_generation_status(
generation_id=generation_id,
status="failed",
db=bg_db,
error=str(e),
)
finally:
task_manager.complete_generation(generation_id)
bg_db.close()
def _save_generate(
*,
generation_id: str,
audio,
sample_rate: int,
effects_chain: Optional[list],
save_audio,
db,
) -> str:
"""Save clean version and optionally an effects-processed version.
Returns the final audio path (processed if effects were applied,
otherwise clean).
"""
from . import versions as versions_mod
clean_audio_path = config.get_generations_dir() / f"{generation_id}.wav"
save_audio(audio, str(clean_audio_path), sample_rate)
has_effects = effects_chain and any(e.get("enabled", True) for e in effects_chain)
versions_mod.create_version(
generation_id=generation_id,
label="original",
audio_path=config.to_storage_path(clean_audio_path),
db=db,
effects_chain=None,
is_default=not has_effects,
)
final_audio_path = str(clean_audio_path)
if has_effects:
from ..utils.effects import apply_effects, validate_effects_chain
assert effects_chain is not None
error_msg = validate_effects_chain(effects_chain)
if error_msg:
import logging
logging.getLogger(__name__).warning("invalid effects chain, skipping: %s", error_msg)
versions_mod.set_default_version(
versions_mod.list_versions(generation_id, db)[0].id, db
)
else:
processed_audio = apply_effects(audio, sample_rate, effects_chain)
processed_path = config.get_generations_dir() / f"{generation_id}_processed.wav"
save_audio(processed_audio, str(processed_path), sample_rate)
final_audio_path = str(processed_path)
versions_mod.create_version(
generation_id=generation_id,
label="version-2",
audio_path=config.to_storage_path(processed_path),
db=db,
effects_chain=effects_chain,
is_default=True,
)
return config.to_storage_path(final_audio_path)
def _save_retry(
*,
generation_id: str,
audio,
sample_rate: int,
save_audio,
) -> str:
"""Save retry output -- single file, no versions.
Returns the audio path.
"""
audio_path = config.get_generations_dir() / f"{generation_id}.wav"
save_audio(audio, str(audio_path), sample_rate)
return config.to_storage_path(audio_path)
def _save_regenerate(
*,
generation_id: str,
version_id: Optional[str],
audio,
sample_rate: int,
save_audio,
db,
) -> str:
"""Save regeneration output as a new version with auto-label.
Returns the audio path.
"""
from . import versions as versions_mod
import uuid as _uuid
suffix = _uuid.uuid4().hex[:8]
audio_path = config.get_generations_dir() / f"{generation_id}_{suffix}.wav"
save_audio(audio, str(audio_path), sample_rate)
# Count via DB query rather than list length to avoid TOCTOU race
from ..database import GenerationVersion as DBGenerationVersion
count = db.query(DBGenerationVersion).filter_by(generation_id=generation_id).count()
label = f"take-{count + 1}"
versions_mod.create_version(
generation_id=generation_id,
label=label,
audio_path=config.to_storage_path(audio_path),
db=db,
effects_chain=None,
is_default=True,
)
return config.to_storage_path(audio_path)
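# Illustrative usage (not part of the module): how a caller might kick off a
# standard generation. In the app this coroutine is enqueued via
# services.task_queue.enqueue_generation rather than awaited directly; the
# direct await below is a simplified sketch and the argument values are examples.
#
#   await run_generation(
#       generation_id=gen_id,      # pre-created history row
#       profile_id=profile_id,
#       text="Hello from Voicebox.",
#       language="en",
#       engine="qwen",
#       model_size="1.7B",
#       seed=None,
#       normalize=True,
#       effects_chain=None,
#       instruct=None,
#       mode="generate",
#   )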

368
backend/services/history.py Normal file
View File

@@ -0,0 +1,368 @@
"""
Generation history management module.
"""
from typing import List, Optional, Tuple
from datetime import datetime
import uuid
import shutil
from pathlib import Path
from sqlalchemy.orm import Session
from sqlalchemy import or_
from ..models import GenerationRequest, GenerationResponse, HistoryQuery, HistoryResponse, HistoryListResponse, GenerationVersionResponse, EffectConfig
from ..database import Generation as DBGeneration, GenerationVersion as DBGenerationVersion, VoiceProfile as DBVoiceProfile
from .. import config
def _get_versions_for_generation(generation_id: str, db: Session) -> tuple:
"""Get versions list and active version ID for a generation."""
import json
versions_rows = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.all()
)
if not versions_rows:
return None, None
versions = []
active_version_id = None
for v in versions_rows:
effects_chain = None
if v.effects_chain:
try:
raw = json.loads(v.effects_chain)
effects_chain = [EffectConfig(**e) for e in raw]
except Exception:
pass
versions.append(GenerationVersionResponse(
id=v.id,
generation_id=v.generation_id,
label=v.label,
audio_path=v.audio_path,
effects_chain=effects_chain,
is_default=v.is_default,
created_at=v.created_at,
))
if v.is_default:
active_version_id = v.id
return versions, active_version_id
async def create_generation(
profile_id: str,
text: str,
language: str,
audio_path: str,
duration: float,
seed: Optional[int],
db: Session,
instruct: Optional[str] = None,
generation_id: Optional[str] = None,
status: str = "completed",
engine: Optional[str] = "qwen",
model_size: Optional[str] = None,
) -> GenerationResponse:
"""
Create a new generation history entry.
Args:
profile_id: Profile ID used for generation
text: Generated text
language: Language code
audio_path: Path where audio was saved
duration: Audio duration in seconds
seed: Random seed used (if any)
db: Database session
instruct: Natural language instruction used (if any)
generation_id: Pre-assigned ID (for async generation flow)
status: Generation status (generating, completed, failed)
engine: TTS engine used (qwen, luxtts, chatterbox, chatterbox_turbo)
model_size: Model size variant (1.7B, 0.6B) — only relevant for qwen
Returns:
Created generation entry
"""
db_generation = DBGeneration(
id=generation_id or str(uuid.uuid4()),
profile_id=profile_id,
text=text,
language=language,
audio_path=audio_path,
duration=duration,
seed=seed,
instruct=instruct,
engine=engine,
model_size=model_size,
status=status,
created_at=datetime.utcnow(),
)
db.add(db_generation)
db.commit()
db.refresh(db_generation)
return GenerationResponse.model_validate(db_generation)
async def update_generation_status(
generation_id: str,
status: str,
db: Session,
audio_path: Optional[str] = None,
duration: Optional[float] = None,
error: Optional[str] = None,
) -> Optional[GenerationResponse]:
"""Update the status of a generation (used by async generation flow)."""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
return None
generation.status = status
if audio_path is not None:
generation.audio_path = audio_path
if duration is not None:
generation.duration = duration
if error is not None:
generation.error = error
db.commit()
db.refresh(generation)
return GenerationResponse.model_validate(generation)
async def get_generation(
generation_id: str,
db: Session,
) -> Optional[GenerationResponse]:
"""
Get a generation by ID.
Args:
generation_id: Generation ID
db: Database session
Returns:
Generation or None if not found
"""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
return None
return GenerationResponse.model_validate(generation)
async def list_generations(
query: HistoryQuery,
db: Session,
) -> HistoryListResponse:
"""
List generations with optional filters.
Args:
query: Query parameters (filters, pagination)
db: Database session
Returns:
HistoryListResponse with items and total count
"""
# Build base query with join to get profile name
q = db.query(
DBGeneration,
DBVoiceProfile.name.label('profile_name')
).join(
DBVoiceProfile,
DBGeneration.profile_id == DBVoiceProfile.id
)
# Apply profile filter
if query.profile_id:
q = q.filter(DBGeneration.profile_id == query.profile_id)
# Apply search filter (searches in text content)
if query.search:
search_pattern = f"%{query.search}%"
q = q.filter(DBGeneration.text.like(search_pattern))
# Get total count before pagination
total_count = q.count()
# Apply ordering (newest first)
q = q.order_by(DBGeneration.created_at.desc())
# Apply pagination
q = q.offset(query.offset).limit(query.limit)
# Execute query
results = q.all()
# Convert to HistoryResponse with profile_name
items = []
for generation, profile_name in results:
versions, active_version_id = _get_versions_for_generation(generation.id, db)
items.append(HistoryResponse(
id=generation.id,
profile_id=generation.profile_id,
profile_name=profile_name,
text=generation.text,
language=generation.language,
audio_path=generation.audio_path,
duration=generation.duration,
seed=generation.seed,
instruct=generation.instruct,
engine=generation.engine or "qwen",
model_size=generation.model_size,
status=generation.status or "completed",
error=generation.error,
is_favorited=bool(generation.is_favorited),
created_at=generation.created_at,
versions=versions,
active_version_id=active_version_id,
))
return HistoryListResponse(
items=items,
total=total_count,
)
async def delete_generation(
generation_id: str,
db: Session,
) -> bool:
"""
Delete a generation.
Args:
generation_id: Generation ID
db: Database session
Returns:
True if deleted, False if not found
"""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
return False
# Delete all version files and records
from . import versions as versions_mod
versions_mod.delete_versions_for_generation(generation_id, db)
# Delete main audio file (if not already removed by version cleanup)
if generation.audio_path:
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
# Delete from database
db.delete(generation)
db.commit()
return True
async def delete_failed_generations(db: Session) -> int:
"""
Delete every generation whose status is 'failed'.
Used by the "Clear failed" action in the UI so users can tidy up
history after the model wasn't loaded, the app was closed mid-run,
or a generation otherwise errored out (see issue #410).
Returns:
Number of generations deleted.
"""
from . import versions as versions_mod
failed = db.query(DBGeneration).filter(DBGeneration.status == "failed").all()
count = 0
for generation in failed:
# Clean up version files/rows first.
versions_mod.delete_versions_for_generation(generation.id, db)
# Remove the main audio file if it somehow made it to disk.
if generation.audio_path:
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
try:
audio_path.unlink()
except OSError:
# Best-effort cleanup — don't abort the whole sweep
# if a single file can't be removed.
pass
db.delete(generation)
count += 1
db.commit()
return count
async def delete_generations_by_profile(
profile_id: str,
db: Session,
) -> int:
"""
Delete all generations for a profile.
Args:
profile_id: Profile ID
db: Database session
Returns:
Number of generations deleted
"""
generations = db.query(DBGeneration).filter_by(profile_id=profile_id).all()
count = 0
for generation in generations:
# Delete associated version files and rows first
from . import versions as versions_mod
versions_mod.delete_versions_for_generation(generation.id, db)
# Delete audio file
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
# Delete from database
db.delete(generation)
count += 1
db.commit()
return count
async def get_generation_stats(db: Session) -> dict:
"""
Get generation statistics.
Args:
db: Database session
Returns:
Statistics dictionary
"""
from sqlalchemy import func
total = db.query(func.count(DBGeneration.id)).scalar()
total_duration = db.query(func.sum(DBGeneration.duration)).scalar() or 0
# Get generations by profile
by_profile = db.query(
DBGeneration.profile_id,
func.count(DBGeneration.id).label('count')
).group_by(DBGeneration.profile_id).all()
return {
"total_generations": total,
"total_duration_seconds": total_duration,
"generations_by_profile": {
profile_id: count for profile_id, count in by_profile
},
}
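# Illustrative usage (not part of the module): paging through history for one
# profile. The HistoryQuery fields shown (profile_id, search, offset, limit)
# are the ones read by list_generations() above; values are examples.
#
#   from backend.models import HistoryQuery
#
#   page = await list_generations(
#       HistoryQuery(profile_id=profile_id, search="hello", offset=0, limit=20),
#       db,
#   )
#   print(page.total, [item.id for item in page.items])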

686
backend/services/profiles.py Normal file
View File

@@ -0,0 +1,686 @@
"""Voice profile management module."""
import json as _json
import logging
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from sqlalchemy import func
from sqlalchemy.orm import Session
from .. import config
from ..database import Generation as DBGeneration, ProfileSample as DBProfileSample, VoiceProfile as DBVoiceProfile
from ..models import (
EffectConfig,
ProfileSampleResponse,
VoiceProfileCreate,
VoiceProfileResponse,
)
from ..utils.audio import save_audio, validate_and_load_reference_audio
from ..utils.cache import _get_cache_dir, clear_profile_cache
from ..utils.images import process_avatar, validate_image
logger = logging.getLogger(__name__)
CLONING_ENGINES = {"qwen", "luxtts", "chatterbox", "chatterbox_turbo", "tada"}
def _profile_to_response(
profile: DBVoiceProfile,
generation_count: int = 0,
sample_count: int = 0,
) -> VoiceProfileResponse:
"""Convert a DB profile to a VoiceProfileResponse, deserializing effects_chain."""
effects_chain = None
if profile.effects_chain:
try:
raw = _json.loads(profile.effects_chain)
effects_chain = [EffectConfig(**e) for e in raw]
        except Exception as e:
            logger.warning(f"Failed to parse effects_chain for profile {profile.id}: {e}")
return VoiceProfileResponse(
id=profile.id,
name=profile.name,
description=profile.description,
language=profile.language,
avatar_path=profile.avatar_path,
effects_chain=effects_chain,
voice_type=getattr(profile, "voice_type", None) or "cloned",
preset_engine=getattr(profile, "preset_engine", None),
preset_voice_id=getattr(profile, "preset_voice_id", None),
design_prompt=getattr(profile, "design_prompt", None),
default_engine=getattr(profile, "default_engine", None),
generation_count=generation_count,
sample_count=sample_count,
created_at=profile.created_at,
updated_at=profile.updated_at,
)
def _get_preset_voice_ids(engine: str) -> set[str]:
if engine == "kokoro":
from ..backends.kokoro_backend import KOKORO_VOICES
return {voice_id for voice_id, _name, _gender, _lang in KOKORO_VOICES}
if engine == "qwen_custom_voice":
from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES
return {voice_id for voice_id, _name, _gender, _lang, _desc in QWEN_CUSTOM_VOICES}
return set()
def _validate_profile_fields(
*,
voice_type: str,
preset_engine: str | None,
preset_voice_id: str | None,
design_prompt: str | None,
default_engine: str | None,
) -> str | None:
if voice_type == "preset":
if not preset_engine or not preset_voice_id:
return "Preset profiles require both preset_engine and preset_voice_id"
if default_engine and default_engine != preset_engine:
return "Preset profiles must use their preset_engine as default_engine"
available_voice_ids = _get_preset_voice_ids(preset_engine)
if available_voice_ids and preset_voice_id not in available_voice_ids:
return f"Preset voice '{preset_voice_id}' is not valid for engine '{preset_engine}'"
return None
if voice_type == "designed":
if not design_prompt or not design_prompt.strip():
return "Designed profiles require a design_prompt"
if preset_engine or preset_voice_id:
return "Designed profiles cannot set preset_engine or preset_voice_id"
return None
if preset_engine or preset_voice_id:
return "Cloned profiles cannot set preset_engine or preset_voice_id"
if design_prompt:
return "Cloned profiles cannot set design_prompt"
if default_engine and default_engine not in CLONING_ENGINES:
return f"Cloned profiles cannot use default engine '{default_engine}'"
return None
def validate_profile_engine(profile, engine: str) -> None:
voice_type = getattr(profile, "voice_type", None) or "cloned"
if voice_type == "preset":
preset_engine = getattr(profile, "preset_engine", None)
preset_voice_id = getattr(profile, "preset_voice_id", None)
if not preset_engine or not preset_voice_id:
raise ValueError(f"Preset profile {profile.id} is missing preset engine metadata")
if preset_engine != engine:
raise ValueError(
f"Preset profile {profile.id} only supports engine '{preset_engine}', not '{engine}'"
)
return
if voice_type == "designed":
design_prompt = getattr(profile, "design_prompt", None)
if not design_prompt or not design_prompt.strip():
raise ValueError(f"Designed profile {profile.id} is missing design_prompt")
return
if engine not in CLONING_ENGINES:
raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")
async def create_profile(
data: VoiceProfileCreate,
db: Session,
) -> VoiceProfileResponse:
"""
Create a new voice profile.
Args:
data: Profile creation data
db: Database session
Returns:
Created profile
Raises:
ValueError: If a profile with the same name already exists
"""
existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
if existing_profile:
raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")
# Auto-set default_engine for preset profiles
default_engine = data.default_engine
voice_type = data.voice_type or "cloned"
if voice_type == "preset" and data.preset_engine and not default_engine:
default_engine = data.preset_engine
validation_error = _validate_profile_fields(
voice_type=voice_type,
preset_engine=data.preset_engine,
preset_voice_id=data.preset_voice_id,
design_prompt=data.design_prompt,
default_engine=default_engine,
)
if validation_error:
raise ValueError(validation_error)
db_profile = DBVoiceProfile(
id=str(uuid.uuid4()),
name=data.name,
description=data.description,
language=data.language,
voice_type=voice_type,
preset_engine=data.preset_engine,
preset_voice_id=data.preset_voice_id,
design_prompt=data.design_prompt,
default_engine=default_engine,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
db.add(db_profile)
db.commit()
db.refresh(db_profile)
profile_dir = config.get_profiles_dir() / db_profile.id
profile_dir.mkdir(parents=True, exist_ok=True)
return _profile_to_response(db_profile)
async def add_profile_sample(
profile_id: str,
audio_path: str,
reference_text: str,
db: Session,
) -> ProfileSampleResponse:
"""
Add a sample to a voice profile.
Args:
profile_id: Profile ID
audio_path: Path to temporary audio file
reference_text: Transcript of audio
db: Database session
Returns:
Created sample
"""
import asyncio
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
# Validate and load audio in a single pass, off the event loop
is_valid, error_msg, audio, sr = await asyncio.to_thread(
validate_and_load_reference_audio, audio_path
)
if not is_valid:
raise ValueError(f"Invalid reference audio: {error_msg}")
sample_id = str(uuid.uuid4())
profile_dir = config.get_profiles_dir() / profile_id
profile_dir.mkdir(parents=True, exist_ok=True)
dest_path = profile_dir / f"{sample_id}.wav"
await asyncio.to_thread(save_audio, audio, str(dest_path), sr)
db_sample = DBProfileSample(
id=sample_id,
profile_id=profile_id,
audio_path=config.to_storage_path(dest_path),
reference_text=reference_text,
)
db.add(db_sample)
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(db_sample)
# Invalidate combined audio cache for this profile
# Since a new sample was added, any cached combined audio is now stale
clear_profile_cache(profile_id)
return ProfileSampleResponse.model_validate(db_sample)
async def get_profile(
profile_id: str,
db: Session,
) -> VoiceProfileResponse | None:
"""
Get a voice profile by ID.
Args:
profile_id: Profile ID
db: Database session
Returns:
Profile or None if not found
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
return None
return _profile_to_response(profile)
async def get_profile_samples(
profile_id: str,
db: Session,
) -> list[ProfileSampleResponse]:
"""
Get all samples for a profile.
Args:
profile_id: Profile ID
db: Database session
Returns:
List of samples
"""
samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
return [ProfileSampleResponse.model_validate(s) for s in samples]
async def list_profiles(db: Session) -> list[VoiceProfileResponse]:
"""
List all voice profiles with generation and sample counts.
Args:
db: Database session
Returns:
List of profiles
"""
profiles = db.query(DBVoiceProfile).order_by(DBVoiceProfile.created_at.desc()).all()
if not profiles:
return []
# Batch-fetch generation counts
gen_counts_rows = (
db.query(DBGeneration.profile_id, func.count(DBGeneration.id)).group_by(DBGeneration.profile_id).all()
)
gen_counts = {row[0]: row[1] for row in gen_counts_rows}
# Batch-fetch sample counts
sample_counts_rows = (
db.query(DBProfileSample.profile_id, func.count(DBProfileSample.id)).group_by(DBProfileSample.profile_id).all()
)
sample_counts = {row[0]: row[1] for row in sample_counts_rows}
return [
_profile_to_response(
p,
generation_count=gen_counts.get(p.id, 0),
sample_count=sample_counts.get(p.id, 0),
)
for p in profiles
]
async def update_profile(
profile_id: str,
data: VoiceProfileCreate,
db: Session,
) -> VoiceProfileResponse | None:
"""
Update a voice profile.
Args:
profile_id: Profile ID
data: Updated profile data
db: Database session
Returns:
Updated profile or None if not found
Raises:
ValueError: If a profile with the same name already exists (different profile)
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
return None
if profile.name != data.name:
existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
if existing_profile:
raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")
voice_type = getattr(profile, "voice_type", None) or "cloned"
preset_engine = getattr(profile, "preset_engine", None)
preset_voice_id = getattr(profile, "preset_voice_id", None)
design_prompt = getattr(profile, "design_prompt", None)
default_engine = data.default_engine if data.default_engine is not None else getattr(profile, "default_engine", None)
validation_error = _validate_profile_fields(
voice_type=voice_type,
preset_engine=preset_engine,
preset_voice_id=preset_voice_id,
design_prompt=design_prompt,
default_engine=default_engine,
)
if validation_error:
raise ValueError(validation_error)
profile.name = data.name
profile.description = data.description
profile.language = data.language
if data.default_engine is not None:
profile.default_engine = data.default_engine or None # empty string → NULL
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(profile)
return _profile_to_response(profile)
async def delete_profile(
profile_id: str,
db: Session,
) -> bool:
"""
Delete a voice profile and all associated data.
Args:
profile_id: Profile ID
db: Database session
Returns:
True if deleted, False if not found
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
return False
db.query(DBProfileSample).filter_by(profile_id=profile_id).delete()
db.delete(profile)
db.commit()
profile_dir = config.get_profiles_dir() / profile_id
if profile_dir.exists():
shutil.rmtree(profile_dir)
# Clean up combined audio cache files for this profile
clear_profile_cache(profile_id)
return True
async def delete_profile_sample(
sample_id: str,
db: Session,
) -> bool:
"""
Delete a profile sample.
Args:
sample_id: Sample ID
db: Database session
Returns:
True if deleted, False if not found
"""
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
if not sample:
return False
# Store profile_id before deleting
profile_id = sample.profile_id
audio_path = config.resolve_storage_path(sample.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
db.delete(sample)
db.commit()
# Invalidate combined audio cache for this profile
# Since the sample set changed, any cached combined audio is now stale
clear_profile_cache(profile_id)
return True
async def update_profile_sample(
sample_id: str,
reference_text: str,
db: Session,
) -> ProfileSampleResponse | None:
"""
Update a profile sample's reference text.
Args:
sample_id: Sample ID
reference_text: Updated reference text
db: Database session
Returns:
Updated sample or None if not found
"""
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
if not sample:
return None
# Store profile_id before updating
profile_id = sample.profile_id
sample.reference_text = reference_text
db.commit()
db.refresh(sample)
# Invalidate combined audio cache for this profile
# Since the reference text changed, cache keys and combined text are now stale
clear_profile_cache(profile_id)
return ProfileSampleResponse.model_validate(sample)
async def create_voice_prompt_for_profile(
profile_id: str,
db: Session,
use_cache: bool = True,
engine: str = "qwen",
) -> dict:
"""
Create a voice prompt from a profile.
For cloned profiles: combines all audio samples into a voice prompt.
For preset profiles: returns the engine-specific preset voice reference.
For designed profiles: returns the text design prompt (future).
Args:
profile_id: Profile ID
db: Database session
use_cache: Whether to use cached prompts
engine: TTS engine to create prompt for
Returns:
Voice prompt dictionary
"""
from ..backends import get_tts_backend_for_engine
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile not found: {profile_id}")
voice_type = getattr(profile, "voice_type", None) or "cloned"
validate_profile_engine(profile, engine)
# ── Preset profiles: return engine-specific voice reference ──
if voice_type == "preset":
if not profile.preset_engine or not profile.preset_voice_id:
raise ValueError(f"Preset profile {profile_id} is missing preset engine metadata")
if profile.preset_engine != engine:
raise ValueError(
f"Preset profile {profile_id} only supports engine '{profile.preset_engine}', not '{engine}'"
)
return {
"voice_type": "preset",
"preset_engine": profile.preset_engine,
"preset_voice_id": profile.preset_voice_id,
}
# ── Designed profiles: return text description (future) ──
if voice_type == "designed":
if not profile.design_prompt or not profile.design_prompt.strip():
raise ValueError(f"Designed profile {profile_id} is missing design_prompt")
return {
"voice_type": "designed",
"design_prompt": profile.design_prompt,
}
if engine not in CLONING_ENGINES:
raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")
# ── Cloned profiles: create from audio samples ──
samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
if not samples:
raise ValueError(f"No samples found for profile {profile_id}")
tts_model = get_tts_backend_for_engine(engine)
if len(samples) == 1:
sample = samples[0]
sample_audio_path = config.resolve_storage_path(sample.audio_path)
if sample_audio_path is None:
raise ValueError(f"Sample audio not found for profile {profile_id}")
voice_prompt, _ = await tts_model.create_voice_prompt(
str(sample_audio_path),
sample.reference_text,
use_cache=use_cache,
)
return voice_prompt
audio_paths = []
for sample in samples:
sample_audio_path = config.resolve_storage_path(sample.audio_path)
if sample_audio_path is None:
raise ValueError(f"Sample audio not found for profile {profile_id}")
audio_paths.append(str(sample_audio_path))
reference_texts = [s.reference_text for s in samples]
combined_audio, combined_text = await tts_model.combine_voice_prompts(
audio_paths,
reference_texts,
)
# Save combined audio to cache directory (persistent)
# Create a hash of sample IDs to identify this specific combination
import hashlib
sample_ids_str = "-".join(sorted([s.id for s in samples]))
combination_hash = hashlib.md5(sample_ids_str.encode()).hexdigest()[:12]
cache_dir = _get_cache_dir()
cache_dir.mkdir(parents=True, exist_ok=True)
combined_path = cache_dir / f"combined_{profile_id}_{combination_hash}.wav"
save_audio(combined_audio, str(combined_path), 24000)
voice_prompt, _ = await tts_model.create_voice_prompt(
str(combined_path),
combined_text,
use_cache=use_cache,
)
return voice_prompt
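# Illustrative note (not part of the module): shapes returned above for the
# non-cloned branches, as constructed in this function ("kokoro" is just an
# example preset engine).
#
#   preset   -> {"voice_type": "preset", "preset_engine": "kokoro", "preset_voice_id": "<voice id>"}
#   designed -> {"voice_type": "designed", "design_prompt": "<text prompt>"}
#
# Cloned profiles return whatever the engine backend's create_voice_prompt()
# produces; that structure is engine-specific and not shown here.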
async def upload_avatar(
profile_id: str,
image_path: str,
db: Session,
) -> VoiceProfileResponse:
"""
Upload and process avatar image for a profile.
Args:
profile_id: Profile ID
image_path: Path to uploaded image file
db: Database session
Returns:
Updated profile
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise ValueError(f"Profile {profile_id} not found")
is_valid, error_msg = validate_image(image_path)
if not is_valid:
raise ValueError(error_msg)
if profile.avatar_path:
old_avatar = config.resolve_storage_path(profile.avatar_path)
if old_avatar is not None and old_avatar.exists():
old_avatar.unlink()
# Determine file extension from uploaded file
from PIL import Image
with Image.open(image_path) as img:
# Normalize JPEG variants (MPO is multi-picture format from some cameras)
img_format = img.format
if img_format in ("MPO", "JPG"):
img_format = "JPEG"
ext_map = {"PNG": ".png", "JPEG": ".jpg", "WEBP": ".webp"}
ext = ext_map.get(img_format, ".png")
profile_dir = config.get_profiles_dir() / profile_id
profile_dir.mkdir(parents=True, exist_ok=True)
output_path = profile_dir / f"avatar{ext}"
process_avatar(image_path, str(output_path))
profile.avatar_path = config.to_storage_path(output_path)
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(profile)
return _profile_to_response(profile)
async def delete_avatar(
profile_id: str,
db: Session,
) -> bool:
"""
Delete avatar image for a profile.
Args:
profile_id: Profile ID
db: Database session
Returns:
True if deleted, False if not found or no avatar
"""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile or not profile.avatar_path:
return False
avatar_path = config.resolve_storage_path(profile.avatar_path)
if avatar_path is not None and avatar_path.exists():
avatar_path.unlink()
profile.avatar_path = None
profile.updated_at = datetime.utcnow()
db.commit()
return True

925
backend/services/stories.py Normal file
View File

@@ -0,0 +1,925 @@
"""
Story management module.
"""
from typing import List, Optional
from datetime import datetime
import uuid
import tempfile
from pathlib import Path
from sqlalchemy.orm import Session
from sqlalchemy import func
from .. import config
from ..models import (
StoryCreate,
StoryResponse,
StoryDetailResponse,
StoryItemDetail,
StoryItemCreate,
StoryItemBatchUpdate,
StoryItemMove,
StoryItemTrim,
StoryItemSplit,
StoryItemVersionUpdate,
)
from ..database import (
Story as DBStory,
StoryItem as DBStoryItem,
Generation as DBGeneration,
VoiceProfile as DBVoiceProfile,
)
from .history import _get_versions_for_generation
from ..utils.audio import load_audio, save_audio
import numpy as np
def _build_item_detail(
item: DBStoryItem,
generation: DBGeneration,
profile_name: str,
db: Session,
) -> StoryItemDetail:
"""Build a StoryItemDetail with version info from a story item and its generation."""
versions, active_version_id = _get_versions_for_generation(generation.id, db)
# Resolve the audio path: if version_id is set, use that version's audio
audio_path = generation.audio_path
if item.version_id and versions:
for v in versions:
if v.id == item.version_id:
audio_path = v.audio_path
break
return StoryItemDetail(
id=item.id,
story_id=item.story_id,
generation_id=item.generation_id,
version_id=getattr(item, "version_id", None),
start_time_ms=item.start_time_ms,
track=item.track,
trim_start_ms=getattr(item, "trim_start_ms", 0),
trim_end_ms=getattr(item, "trim_end_ms", 0),
created_at=item.created_at,
profile_id=generation.profile_id,
profile_name=profile_name,
text=generation.text,
language=generation.language,
audio_path=audio_path,
duration=generation.duration,
seed=generation.seed,
instruct=generation.instruct,
generation_created_at=generation.created_at,
versions=versions,
active_version_id=active_version_id,
)
async def create_story(
data: StoryCreate,
db: Session,
) -> StoryResponse:
"""
Create a new story.
Args:
data: Story creation data
db: Database session
Returns:
Created story
"""
db_story = DBStory(
id=str(uuid.uuid4()),
name=data.name,
description=data.description,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
db.add(db_story)
db.commit()
db.refresh(db_story)
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == db_story.id).scalar()
response = StoryResponse.model_validate(db_story)
response.item_count = item_count
return response
async def list_stories(
db: Session,
) -> List[StoryResponse]:
"""
List all stories.
Args:
db: Database session
Returns:
List of stories with item counts
"""
stories = db.query(DBStory).order_by(DBStory.updated_at.desc()).all()
result = []
for story in stories:
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()
response = StoryResponse.model_validate(story)
response.item_count = item_count
result.append(response)
return result
async def get_story(
story_id: str,
db: Session,
) -> Optional[StoryDetailResponse]:
"""
Get a story with all its items.
Args:
story_id: Story ID
db: Database session
Returns:
Story with items or None if not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
items = (
db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
.filter(DBStoryItem.story_id == story_id)
.order_by(DBStoryItem.start_time_ms)
.all()
)
item_details = []
for item, generation, profile_name in items:
item_details.append(_build_item_detail(item, generation, profile_name, db))
response = StoryDetailResponse.model_validate(story)
response.items = item_details
return response
async def update_story(
story_id: str,
data: StoryCreate,
db: Session,
) -> Optional[StoryResponse]:
"""
Update a story.
Args:
story_id: Story ID
data: Update data
db: Database session
Returns:
Updated story or None if not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
story.name = data.name
story.description = data.description
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(story)
item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()
response = StoryResponse.model_validate(story)
response.item_count = item_count
return response
async def delete_story(
story_id: str,
db: Session,
) -> bool:
"""
Delete a story and all its items.
Args:
story_id: Story ID
db: Database session
Returns:
True if deleted, False if not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return False
# Delete all items
db.query(DBStoryItem).filter_by(story_id=story_id).delete()
# Delete story
db.delete(story)
db.commit()
return True
async def add_item_to_story(
story_id: str,
data: StoryItemCreate,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Add a generation to a story.
Args:
story_id: Story ID
data: Item creation data
db: Database session
Returns:
Created item detail or None if story/generation not found
"""
# Verify story exists
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
# Verify generation exists
generation = db.query(DBGeneration).filter_by(id=data.generation_id).first()
if not generation:
return None
# Check if generation is already in story
existing = db.query(DBStoryItem).filter_by(story_id=story_id, generation_id=data.generation_id).first()
if existing:
# Return existing item
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(existing, generation, profile.name if profile else "Unknown", db)
# Get track from data or default to 0
track = data.track if data.track is not None else 0
# Calculate start_time_ms if not provided
if data.start_time_ms is not None:
start_time_ms = data.start_time_ms
else:
existing_items = (
db.query(DBStoryItem, DBGeneration)
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.filter(
DBStoryItem.story_id == story_id,
DBStoryItem.track == track,
)
.all()
)
if not existing_items:
start_time_ms = 0
else:
max_end_time_ms = 0
for item, gen in existing_items:
item_end_ms = item.start_time_ms + int(gen.duration * 1000)
max_end_time_ms = max(max_end_time_ms, item_end_ms)
# Add 200ms gap after the last item
start_time_ms = max_end_time_ms + 200
# Create item
item = DBStoryItem(
id=str(uuid.uuid4()),
story_id=story_id,
generation_id=data.generation_id,
start_time_ms=start_time_ms,
track=track,
created_at=datetime.utcnow(),
)
db.add(item)
# Update story updated_at
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def move_story_item(
story_id: str,
item_id: str,
data: StoryItemMove,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Move a story item (update position and/or track).
Args:
story_id: Story ID
item_id: Story item ID
data: New position and track data
db: Database session
Returns:
Updated item detail or None if not found
"""
# Get the item
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Update position and track
item.start_time_ms = data.start_time_ms
item.track = data.track
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def remove_item_from_story(
story_id: str,
item_id: str,
db: Session,
) -> bool:
"""
Remove a story item from a story.
Args:
story_id: Story ID
item_id: Story item ID to remove
db: Database session
Returns:
True if removed, False if not found
"""
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return False
# Delete item
db.delete(item)
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
return True
async def trim_story_item(
story_id: str,
item_id: str,
data: StoryItemTrim,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Trim a story item (update trim_start_ms and trim_end_ms).
Args:
story_id: Story ID
item_id: Story item ID
data: Trim data (trim_start_ms, trim_end_ms)
db: Database session
Returns:
Updated item detail or None if not found
"""
# Get the item
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Validate trim values don't exceed duration
max_duration_ms = int(generation.duration * 1000)
if data.trim_start_ms + data.trim_end_ms >= max_duration_ms:
return None # Invalid trim - would result in zero or negative duration
# Update trim values
item.trim_start_ms = data.trim_start_ms
item.trim_end_ms = data.trim_end_ms
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def split_story_item(
story_id: str,
item_id: str,
data: StoryItemSplit,
db: Session,
) -> Optional[List[StoryItemDetail]]:
"""
Split a story item at a given time, creating two clips.
Args:
story_id: Story ID
item_id: Story item ID to split
data: Split data (split_time_ms - time within clip to split at)
db: Database session
Returns:
List of two updated item details (original and new) or None if not found/invalid
"""
# Get the item with a row lock to prevent concurrent splits on the
# same clip (e.g. from rapid double-clicks racing each other).
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.with_for_update()
.first()
)
if not item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Calculate effective duration and validate split point
current_trim_start = getattr(item, "trim_start_ms", 0)
current_trim_end = getattr(item, "trim_end_ms", 0)
original_duration_ms = int(generation.duration * 1000)
effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end
# Validate split_time_ms is within the effective duration
if data.split_time_ms <= 0 or data.split_time_ms >= effective_duration_ms:
return None # Invalid split point
# Calculate the absolute time in the original audio where we're splitting
absolute_split_ms = current_trim_start + data.split_time_ms
# Update original clip: trim from the end
item.trim_end_ms = original_duration_ms - absolute_split_ms
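    # Example: a 10000 ms clip with trim_start_ms=1000, split 4000 ms into the
    # visible audio -> absolute_split_ms = 5000; the original clip now plays
    # 1000-5000 ms and the new clip plays from 5000 ms to the end.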
# Create new clip: starts after the split, trimmed from the start
new_item = DBStoryItem(
id=str(uuid.uuid4()),
story_id=story_id,
generation_id=item.generation_id, # Same generation, different trim
version_id=getattr(item, "version_id", None), # Preserve pinned version
start_time_ms=item.start_time_ms + data.split_time_ms,
track=item.track,
trim_start_ms=absolute_split_ms,
trim_end_ms=current_trim_end,
created_at=datetime.utcnow(),
)
db.add(new_item)
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
db.refresh(new_item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
profile_name = profile.name if profile else "Unknown"
return [
_build_item_detail(item, generation, profile_name, db),
_build_item_detail(new_item, generation, profile_name, db),
]
async def duplicate_story_item(
story_id: str,
item_id: str,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Duplicate a story item, creating a copy with all properties.
Args:
story_id: Story ID
item_id: Story item ID to duplicate
db: Database session
Returns:
New item detail or None if not found
"""
# Get the original item
original_item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not original_item:
return None
# Get the generation
generation = db.query(DBGeneration).filter_by(id=original_item.generation_id).first()
if not generation:
return None
# Calculate effective duration
current_trim_start = getattr(original_item, "trim_start_ms", 0)
current_trim_end = getattr(original_item, "trim_end_ms", 0)
original_duration_ms = int(generation.duration * 1000)
effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end
# Create duplicate item - place it right after the original
new_item = DBStoryItem(
id=str(uuid.uuid4()),
story_id=story_id,
generation_id=original_item.generation_id, # Same generation as original
version_id=getattr(original_item, "version_id", None), # Preserve pinned version
start_time_ms=original_item.start_time_ms + effective_duration_ms + 200, # 200ms gap
track=original_item.track,
trim_start_ms=current_trim_start,
trim_end_ms=current_trim_end,
created_at=datetime.utcnow(),
)
db.add(new_item)
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(new_item)
# Get profile name
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(new_item, generation, profile.name if profile else "Unknown", db)
async def update_story_item_times(
story_id: str,
data: StoryItemBatchUpdate,
db: Session,
) -> bool:
"""
Update story item timecodes.
Args:
story_id: Story ID
data: Batch update data with timecodes
db: Database session
Returns:
True if updated, False if story not found or invalid
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return False
# Get all items for this story
items = db.query(DBStoryItem).filter_by(story_id=story_id).all()
item_map = {item.generation_id: item for item in items}
# Verify all generation IDs belong to this story and update timecodes
for update in data.updates:
if update.generation_id not in item_map:
return False
item_map[update.generation_id].start_time_ms = update.start_time_ms
# Update story updated_at
story.updated_at = datetime.utcnow()
db.commit()
return True
async def reorder_story_items(
story_id: str,
generation_ids: List[str],
db: Session,
gap_ms: int = 200,
) -> Optional[List[StoryItemDetail]]:
"""
Reorder story items and recalculate timecodes.
Args:
story_id: Story ID
generation_ids: List of generation IDs in the desired order
db: Database session
gap_ms: Gap in milliseconds between items (default 200ms)
Returns:
Updated list of story items with new timecodes, or None if invalid
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
# Get all items for this story with their generation data
items_with_gen = (
db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
.filter(DBStoryItem.story_id == story_id)
.all()
)
# Create maps for quick lookup
item_map = {item.generation_id: (item, gen, profile_name) for item, gen, profile_name in items_with_gen}
# Verify all generation IDs belong to this story
if set(generation_ids) != set(item_map.keys()):
return None
# Recalculate timecodes based on new order
current_time_ms = 0
updated_items = []
for gen_id in generation_ids:
item, generation, profile_name = item_map[gen_id]
# Update the item's start time
item.start_time_ms = current_time_ms
# Calculate the duration in ms
duration_ms = int(generation.duration * 1000)
# Move to next position (current end + gap)
current_time_ms += duration_ms + gap_ms
# Build the response item
updated_items.append(_build_item_detail(item, generation, profile_name, db))
# Update story updated_at
story.updated_at = datetime.utcnow()
db.commit()
return updated_items
async def set_story_item_version(
story_id: str,
item_id: str,
data: StoryItemVersionUpdate,
db: Session,
) -> Optional[StoryItemDetail]:
"""
Pin a story item to a specific generation version.
Args:
story_id: Story ID
item_id: Story item ID
data: Version update data (version_id or null for default)
db: Database session
Returns:
Updated item detail or None if not found
"""
item = (
db.query(DBStoryItem)
.filter_by(
id=item_id,
story_id=story_id,
)
.first()
)
if not item:
return None
generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
if not generation:
return None
# Validate version_id belongs to this generation if provided
if data.version_id:
from ..database import GenerationVersion as DBGenerationVersion
version = (
db.query(DBGenerationVersion)
.filter_by(
id=data.version_id,
generation_id=item.generation_id,
)
.first()
)
if not version:
return None
item.version_id = data.version_id
# Update story updated_at
story = db.query(DBStory).filter_by(id=story_id).first()
if story:
story.updated_at = datetime.utcnow()
db.commit()
db.refresh(item)
profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
async def export_story_audio(
story_id: str,
db: Session,
) -> Optional[bytes]:
"""
Export story as single mixed audio file with timecode-based mixing.
Args:
story_id: Story ID
db: Database session
Returns:
Audio file bytes or None if story not found
"""
story = db.query(DBStory).filter_by(id=story_id).first()
if not story:
return None
# Get all items ordered by start_time_ms
items = (
db.query(DBStoryItem, DBGeneration)
.join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
.filter(DBStoryItem.story_id == story_id)
.order_by(DBStoryItem.start_time_ms)
.all()
)
if not items:
return None
# Load all audio files and calculate total duration
audio_data = []
sample_rate = 24000 # Default sample rate
for item, generation in items:
# Resolve audio path: use pinned version if set, otherwise generation default
resolved_audio_path = generation.audio_path
if getattr(item, "version_id", None):
from ..database import GenerationVersion as DBGenerationVersion
version = db.query(DBGenerationVersion).filter_by(id=item.version_id).first()
if version:
resolved_audio_path = version.audio_path
audio_path = config.resolve_storage_path(resolved_audio_path)
if audio_path is None or not audio_path.exists():
continue
try:
audio, sr = load_audio(str(audio_path), sample_rate=sample_rate)
sample_rate = sr # Use actual sample rate from first file
# Get trim values
trim_start_ms = getattr(item, "trim_start_ms", 0)
trim_end_ms = getattr(item, "trim_end_ms", 0)
# Calculate effective duration
original_duration_ms = int(generation.duration * 1000)
effective_duration_ms = original_duration_ms - trim_start_ms - trim_end_ms
# Slice audio based on trim values
trim_start_sample = int((trim_start_ms / 1000.0) * sample_rate)
trim_end_sample = int((trim_end_ms / 1000.0) * sample_rate)
# Extract the trimmed portion
if trim_end_ms > 0:
trimmed_audio = (
audio[trim_start_sample:-trim_end_sample] if trim_end_sample > 0 else audio[trim_start_sample:]
)
else:
trimmed_audio = audio[trim_start_sample:]
# Store audio with its timecode info
start_time_ms = item.start_time_ms
audio_data.append(
{
"audio": trimmed_audio,
"start_time_ms": start_time_ms,
"duration_ms": effective_duration_ms,
}
)
except Exception:
# Skip files that can't be loaded
continue
if not audio_data:
return None
# Calculate total duration: max(start_time_ms + duration_ms)
max_end_time_ms = max((data["start_time_ms"] + data["duration_ms"] for data in audio_data), default=0)
# Convert to samples
total_samples = int((max_end_time_ms / 1000.0) * sample_rate)
# Create output buffer initialized to zeros
final_audio = np.zeros(total_samples, dtype=np.float32)
# Mix each audio segment at its timecode position
for data in audio_data:
audio = data["audio"]
start_time_ms = data["start_time_ms"]
# Calculate start sample index
start_sample = int((start_time_ms / 1000.0) * sample_rate)
# Ensure we don't exceed buffer bounds
audio_length = len(audio)
end_sample = min(start_sample + audio_length, total_samples)
if start_sample < total_samples:
# Trim audio if it extends beyond buffer
audio_to_mix = audio[: end_sample - start_sample]
            # Mix: add this clip into the buffer. Overlapping clips simply sum;
            # the combined mix is normalized below to prevent clipping.
final_audio[start_sample:end_sample] += audio_to_mix
# Normalize to prevent clipping
max_val = np.abs(final_audio).max()
if max_val > 1.0:
final_audio = final_audio / max_val
# Save to temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp_path = tmp.name
try:
save_audio(final_audio, tmp_path, sample_rate)
# Read file bytes
with open(tmp_path, "rb") as f:
audio_bytes = f.read()
return audio_bytes
finally:
# Clean up temp file
Path(tmp_path).unlink(missing_ok=True)

View File

@@ -0,0 +1,108 @@
"""
Serial generation queue — ensures only one TTS inference runs at a time
to avoid GPU contention.
"""
import asyncio
import traceback
from dataclasses import dataclass
from typing import Coroutine, Literal
# Keep references to fire-and-forget background tasks to prevent GC
_background_tasks: set = set()
@dataclass
class GenerationJob:
"""Queued generation work plus the generation ID it belongs to."""
generation_id: str
coro: Coroutine
# Generation queue — serializes TTS inference to avoid GPU contention
_generation_queue: asyncio.Queue = None # type: ignore # initialized at startup
_generation_worker_task: asyncio.Task | None = None
_queued_generation_ids: set[str] = set()
_running_generation_tasks: dict[str, asyncio.Task] = {}
_cancelled_generation_ids: set[str] = set()
def create_background_task(coro) -> asyncio.Task:
"""Create a background task and prevent it from being garbage collected."""
task = asyncio.create_task(coro)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return task
async def _generation_worker():
"""Worker that processes generation tasks one at a time."""
while True:
job = await _generation_queue.get()
try:
if job.generation_id in _cancelled_generation_ids:
_cancelled_generation_ids.discard(job.generation_id)
job.coro.close()
continue
task = asyncio.create_task(job.coro)
_running_generation_tasks[job.generation_id] = task
_queued_generation_ids.discard(job.generation_id)
try:
await task
except asyncio.CancelledError:
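                # A cancel of the job task (via cancel_generation) is swallowed so
                # the worker keeps draining the queue; if the worker task itself was
                # cancelled, task.cancelled() is False here and we re-raise.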
if not task.cancelled():
raise
except Exception:
traceback.print_exc()
finally:
_running_generation_tasks.pop(job.generation_id, None)
_queued_generation_ids.discard(job.generation_id)
_generation_queue.task_done()
def enqueue_generation(generation_id: str, coro):
"""Add a generation coroutine to the serial queue."""
if _generation_queue is None:
raise RuntimeError("Generation queue has not been initialized")
_queued_generation_ids.add(generation_id)
_generation_queue.put_nowait(GenerationJob(generation_id=generation_id, coro=coro))
def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None:
"""Cancel a queued or running generation if it is still active."""
running_task = _running_generation_tasks.get(generation_id)
if running_task is not None:
running_task.cancel()
return "running"
if generation_id in _queued_generation_ids:
_queued_generation_ids.discard(generation_id)
_cancelled_generation_ids.add(generation_id)
return "queued"
return None
def init_queue(force: bool = False):
"""Initialize the generation queue and start the worker.
Must be called once during application startup (inside a running event loop).
"""
global _generation_queue, _generation_worker_task
global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids
if _generation_worker_task is not None and not _generation_worker_task.done():
if not force:
return
_generation_worker_task.cancel()
for task in list(_running_generation_tasks.values()):
task.cancel()
_generation_queue = asyncio.Queue()
_queued_generation_ids = set()
_running_generation_tasks = {}
_cancelled_generation_ids = set()
_generation_worker_task = create_background_task(_generation_worker())

View File

@@ -0,0 +1,22 @@
"""
STT (Speech-to-Text) module - delegates to backend abstraction layer.
"""
from typing import Optional
from ..backends import get_stt_backend, STTBackend
def get_whisper_model() -> STTBackend:
"""
Get STT backend instance (MLX or PyTorch based on platform).
Returns:
STT backend instance
"""
return get_stt_backend()
def unload_whisper_model():
"""Unload Whisper model to free memory."""
backend = get_stt_backend()
backend.unload_model()

34
backend/services/tts.py Normal file
View File

@@ -0,0 +1,34 @@
"""
TTS inference module - delegates to backend abstraction layer.
"""
from typing import Optional
import numpy as np
import io
import soundfile as sf
from ..backends import get_tts_backend, TTSBackend
def get_tts_model() -> TTSBackend:
"""
Get TTS backend instance (MLX or PyTorch based on platform).
Returns:
TTS backend instance
"""
return get_tts_backend()
def unload_tts_model():
"""Unload TTS model to free memory."""
backend = get_tts_backend()
backend.unload_model()
def audio_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
"""Convert audio array to WAV bytes."""
buffer = io.BytesIO()
sf.write(buffer, audio, sample_rate, format="WAV")
buffer.seek(0)
return buffer.read()

View File

@@ -0,0 +1,211 @@
"""
Generation versions management module.
Each generation can have multiple audio versions: a clean (unprocessed)
version and any number of processed versions with different effects chains.
"""
from __future__ import annotations
import json
import uuid
from pathlib import Path
from typing import List, Optional
from sqlalchemy.orm import Session
from ..database import (
GenerationVersion as DBGenerationVersion,
Generation as DBGeneration,
)
from ..models import GenerationVersionResponse, EffectConfig
from .. import config
def _version_response(v: DBGenerationVersion) -> GenerationVersionResponse:
"""Convert a DB version row to a Pydantic response."""
effects_chain = None
if v.effects_chain:
raw = json.loads(v.effects_chain)
effects_chain = [EffectConfig(**e) for e in raw]
return GenerationVersionResponse(
id=v.id,
generation_id=v.generation_id,
label=v.label,
audio_path=v.audio_path,
effects_chain=effects_chain,
source_version_id=v.source_version_id,
is_default=v.is_default,
created_at=v.created_at,
)
def list_versions(generation_id: str, db: Session) -> List[GenerationVersionResponse]:
"""List all versions for a generation."""
versions = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.all()
)
return [_version_response(v) for v in versions]
def get_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
"""Get a specific version by ID."""
v = db.query(DBGenerationVersion).filter_by(id=version_id).first()
if not v:
return None
return _version_response(v)
def get_default_version(generation_id: str, db: Session) -> Optional[GenerationVersionResponse]:
"""Get the default version for a generation."""
v = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id, is_default=True)
.first()
)
if not v:
# Fallback: return the first version
v = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.order_by(DBGenerationVersion.created_at)
.first()
)
if not v:
return None
return _version_response(v)
def create_version(
generation_id: str,
label: str,
audio_path: str,
db: Session,
effects_chain: Optional[List[dict]] = None,
is_default: bool = False,
source_version_id: Optional[str] = None,
) -> GenerationVersionResponse:
"""Create a new version for a generation.
If ``is_default`` is True, all other versions for this generation
are un-defaulted first.
"""
if is_default:
_clear_defaults(generation_id, db)
version = DBGenerationVersion(
id=str(uuid.uuid4()),
generation_id=generation_id,
label=label,
audio_path=audio_path,
effects_chain=json.dumps(effects_chain) if effects_chain else None,
source_version_id=source_version_id,
is_default=is_default,
)
db.add(version)
db.commit()
db.refresh(version)
# If this version is the default, update the generation's audio_path
if is_default:
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if gen:
gen.audio_path = audio_path
db.commit()
return _version_response(version)
def set_default_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
"""Set a version as the default for its generation."""
version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
if not version:
return None
_clear_defaults(version.generation_id, db)
version.is_default = True
db.commit()
db.refresh(version)
# Update generation's audio_path to point to this version
gen = db.query(DBGeneration).filter_by(id=version.generation_id).first()
if gen:
gen.audio_path = version.audio_path
db.commit()
return _version_response(version)
def delete_version(version_id: str, db: Session) -> bool:
"""Delete a version. Cannot delete the last remaining version."""
version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
if not version:
return False
# Don't allow deleting the last version
count = (
db.query(DBGenerationVersion)
.filter_by(generation_id=version.generation_id)
.count()
)
if count <= 1:
return False
was_default = version.is_default
gen_id = version.generation_id
# Delete audio file
audio_path = config.resolve_storage_path(version.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
db.delete(version)
db.commit()
# If this was the default, promote the first remaining version
if was_default:
first = (
db.query(DBGenerationVersion)
.filter_by(generation_id=gen_id)
.order_by(DBGenerationVersion.created_at)
.first()
)
if first:
first.is_default = True
db.commit()
gen = db.query(DBGeneration).filter_by(id=gen_id).first()
if gen:
gen.audio_path = first.audio_path
db.commit()
return True
def delete_versions_for_generation(generation_id: str, db: Session) -> int:
"""Delete all versions for a generation (used when deleting a generation)."""
versions = (
db.query(DBGenerationVersion)
.filter_by(generation_id=generation_id)
.all()
)
count = 0
for v in versions:
audio_path = config.resolve_storage_path(v.audio_path)
if audio_path is not None and audio_path.exists():
audio_path.unlink()
db.delete(v)
count += 1
if count > 0:
db.commit()
return count
def _clear_defaults(generation_id: str, db: Session) -> None:
"""Clear the is_default flag on all versions for a generation."""
db.query(DBGenerationVersion).filter_by(
generation_id=generation_id, is_default=True
).update({"is_default": False})
db.flush()

View File

@@ -0,0 +1,220 @@
# End-to-End Model Generation Test — Design
## Goal
A single script, runnable on macOS and Windows, that exercises every TTS model against the **frozen PyInstaller binary** (not the dev server), captures per-model pass/fail and error messages, and exits non-zero if any model fails. Generation is strictly sequential — one model loaded at a time.
## Test matrix (10 runs)
Derived from `backend/backends/__init__.py:185-316`. Each row maps to one `POST /generate` call.
| # | engine | model_size | profile kind | notes |
|---|-----------------------|------------|--------------|-------|
| 1 | `qwen` | `1.7B` | cloned | reference audio required |
| 2 | `qwen` | `0.6B` | cloned | |
| 3 | `qwen_custom_voice` | `1.7B` | preset | `preset_voice_id="Ryan"` |
| 4 | `qwen_custom_voice` | `0.6B` | preset | `preset_voice_id="Ryan"` |
| 5 | `luxtts` | — | cloned | English only |
| 6 | `chatterbox` | — | cloned | |
| 7 | `chatterbox_turbo` | — | cloned | English only |
| 8 | `tada` | `1B` | cloned | tada-1b, English only |
| 9 | `tada` | `3B` | cloned | tada-3b-ml, multilingual |
| 10| `kokoro` | — | preset | `preset_voice_id="af_heart"` |
Cloned engines (1, 2, 5, 6, 7, 8, 9) share **one** profile created once with the reference WAV. Preset profiles are created separately, one for kokoro and one for qwen_custom_voice.
Language for every run: `en` (covers every engine's supported set).
## End-to-end flow
```
1. Resolve paths → find binary, build if missing
2. Launch binary → spawn with --port --data-dir --parent-pid
3. Wait for /health → poll until status=="healthy" or 120s timeout
4. Create profiles → 1 cloned + 2 preset, via /profiles (+ /samples)
5. For each (engine, model_size) in matrix:
a. Check cache → GET /models/status → cached? short timeout : long
b. POST /generate → get generation_id
c. Stream /status → consume SSE until completed/failed/timeout
d. Record result → {engine, model_size, status, duration, error, elapsed}
6. Write results → JSON + Markdown table to ./results/
7. Shutdown binary → SIGTERM, fall back to kill, verify port freed
8. Exit code → 0 if all passed, 1 otherwise
```
## Binary resolution
Search order — **first hit wins**:
| Platform | Path | Build type |
|----------|------|------------|
| macOS | `backend/dist/voicebox-server-cuda/voicebox-server-cuda` | onedir (CUDA, rarely on Mac) |
| macOS | `backend/dist/voicebox-server` | onefile (CPU) |
| Windows | `backend\dist\voicebox-server-cuda\voicebox-server-cuda.exe` | onedir (CUDA) |
| Windows | `backend\dist\voicebox-server.exe` | onefile (CPU) |
If none exist, run `python backend/build_binary.py` and wait for it to finish (this can take 5-20 min). Fail with a clear error if the build itself fails. The `--skip-build` flag makes the script error out when no binary is found instead of building one.
## Spawn command
Mirrors Tauri's launch in `tauri/src-tauri/src/main.rs:369-388`:
```
<binary> --host 127.0.0.1 --port <free-port> --data-dir <tempdir> --parent-pid <test-pid>
```
- **Port**: bind to `0` first in Python to grab a free port, then pass that number (see the sketch after this list).
- **Data dir**: `tempfile.mkdtemp(prefix="voicebox-e2e-")`. Deleted after the run unless `--keep-data-dir`. Profiles and generated WAVs land here.
- **Parent PID**: current Python PID — ensures the backend dies if the test crashes (watchdog in `server.py:102-224`).
- **stdout/stderr**: tee to both a log file in `./results/server-<timestamp>.log` and a rolling in-memory buffer. On model failure, last 100 lines of the buffer are attached to that model's error record.
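A minimal sketch of the port grab and spawn, assuming only the flags listed above; the binary path is a placeholder, not the script's real resolution logic:
```
import os
import socket
import subprocess
import tempfile

# Ask the OS for a free port by binding to port 0, then release it immediately.
s = socket.socket()
s.bind(("127.0.0.1", 0))
port = s.getsockname()[1]
s.close()

data_dir = tempfile.mkdtemp(prefix="voicebox-e2e-")
proc = subprocess.Popen([
    "backend/dist/voicebox-server",  # placeholder; see the search order above
    "--host", "127.0.0.1",
    "--port", str(port),
    "--data-dir", data_dir,
    "--parent-pid", str(os.getpid()),
])
```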
## Profile setup
One cloned profile shared across all cloning engines:
```http
POST /profiles
{
"name": "e2e-cloned",
"voice_type": "cloned",
"language": "en"
}
```
Then:
```http
POST /profiles/{id}/samples (multipart)
file: <reference WAV>
reference_text: <exact transcription>
```
Two preset profiles:
```http
POST /profiles
{ "name": "e2e-kokoro", "voice_type": "preset", "language": "en",
"preset_engine": "kokoro", "preset_voice_id": "af_heart" }
POST /profiles
{ "name": "e2e-qwen-cv", "voice_type": "preset", "language": "en",
"preset_engine": "qwen_custom_voice", "preset_voice_id": "Ryan" }
```
## Generation request (per matrix row)
```http
POST /generate
{
"profile_id": "<appropriate profile>",
"text": "The quick brown fox jumps over the lazy dog.",
"language": "en",
"engine": "<engine>",
"model_size": "<size or omitted>",
"seed": 42,
"normalize": true
}
```
Response `id` feeds into the SSE status loop (`GET /generate/{id}/status`, `routes/generations.py:190-227`). Loop reads lines until a payload with `status in ("completed", "failed")` arrives, then breaks.
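A minimal sketch of that loop using `httpx` (the `data: ` prefix and payload fields follow the assumptions above; `base_url`, `gen_id`, and the timeout handling are placeholders):
```
import json
import httpx

def wait_for_generation(base_url: str, gen_id: str, timeout_s: float) -> dict:
    # Stream SSE lines from the status endpoint until a terminal payload arrives.
    with httpx.Client(timeout=httpx.Timeout(timeout_s, read=timeout_s)) as client:
        with client.stream("GET", f"{base_url}/generate/{gen_id}/status") as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line.startswith("data: "):
                    continue
                payload = json.loads(line[len("data: "):])
                if payload.get("status") in ("completed", "failed"):
                    return payload
    raise TimeoutError(f"generation {gen_id} did not finish within {timeout_s}s")
```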
## Timeout strategy (split)
Check `GET /models/status` for the target model **before** generation:
| Cached? | Per-model timeout | Rationale |
|---------|-------------------|-----------|
| Yes | **3 minutes** | Inference only; generous for CPU builds |
| No | **20 minutes** | First-run HF download up to 8 GB (tada-3b-ml) |
On timeout: cancel the SSE stream, mark the row `timeout`, and continue to the next row. Don't abort the whole run on one timeout.
## Result format
`./results/e2e-<platform>-<arch>-<timestamp>.json`:
```json
{
"platform": "darwin-arm64",
"binary": "/abs/path/voicebox-server",
"binary_size_mb": 612,
"started_at": "2026-04-16T12:34:56Z",
"finished_at": "...",
"results": [
{
"engine": "qwen",
"model_size": "1.7B",
"status": "passed|failed|timeout",
"generation_id": "...",
"was_cached": true,
"elapsed_seconds": 12.4,
"audio_duration": 3.1,
"audio_path": "/tmp/.../gen.wav",
"error": null,
"server_log_tail": null
}
]
}
```
Companion `./results/e2e-<...>.md`:
```
# Voicebox E2E — darwin-arm64 — 2026-04-16 12:34
| Engine | Size | Status | Elapsed | Error |
|---------------------|------|--------|---------|-------|
| qwen | 1.7B | PASS | 12.4s | |
| qwen | 0.6B | FAIL | 4.1s | CUDA OOM: ... |
...
```
## CLI flags
```
python -m backend.tests.test_all_models_e2e [flags]
--binary PATH Use this binary instead of auto-detecting
--skip-build Error if no binary found (no auto-build)
--reference-wav PATH Reference audio (default: backend/tests/fixtures/reference_voice.wav)
--reference-text STR Transcription (default: read from fixtures/reference_voice.txt)
--only ENGINE[,...] Run only these engines (e.g. kokoro,qwen)
--skip ENGINE[,...] Skip these engines
--keep-data-dir Don't delete tempdir after run
--timeout-cached SEC Override 180
--timeout-download SEC Override 1200
--port N Override auto-picked port
--output-dir PATH Default: backend/tests/results/
```
## File layout
```
backend/tests/
├── E2E_MODEL_TEST_DESIGN.md (this file)
├── test_all_models_e2e.py (main script, ~400-500 LoC)
├── fixtures/
│ ├── reference_voice.wav (user-provided, ~5-15s clean speech)
│ └── reference_voice.txt (exact transcription)
└── results/ (gitignored)
├── e2e-darwin-arm64-<ts>.json
├── e2e-darwin-arm64-<ts>.md
└── server-<ts>.log
```
The script uses only stdlib + `httpx` (or `requests`) + `sseclient-py` — all already in `backend/requirements.txt`. No pytest, so it stays invocable as a single command on fresh checkouts.
## Safety & cleanup
- Always kill the spawned binary in a `try/finally`. On Windows, `taskkill /F /T` the whole tree (Tauri does the same).
- Verify the port is free on shutdown (Tauri's port-reuse check in `main.rs:114-186` could otherwise pick up a ghost); a sketch of the check follows this list.
- Don't touch the user's HF cache by default — let the server use `HF_HUB_CACHE` / `VOICEBOX_MODELS_DIR`. Passing `--isolated-cache` would point both env vars at the tempdir for a true cold-start run (opt-in only; would re-download every time).
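A minimal sketch of the port check (assumes the server process has already been asked to shut down):
```
import socket
import time

def wait_until_port_free(port: int, timeout_s: float = 15.0) -> bool:
    # connect_ex returns 0 while something is still listening on the port.
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        with socket.socket() as s:
            s.settimeout(0.5)
            if s.connect_ex(("127.0.0.1", port)) != 0:
                return True
        time.sleep(0.5)
    return False
```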
## Non-goals
- Not validating audio quality (no WER, no waveform comparison). Pass = "endpoint returned `completed` and produced a non-empty WAV".
- Not testing STT (Whisper), effects chains, channels, or streaming endpoints.
- Not running on CI today — human-invoked on dev machines. CI integration is a follow-up once the script is stable.
- No model unload between runs — models stay loaded; server manages its own eviction.
- No version-drift check on the binary.
- No `instruct` parameter exercised on qwen_custom_voice runs.

58
backend/tests/README.md Normal file
View File

@@ -0,0 +1,58 @@
# Backend Tests
Manual test scripts for debugging and validating backend functionality.
## Test Files
### `test_generation_progress.py`
Tests TTS generation with SSE progress monitoring to identify UX issues where users see download progress even when the model is already cached.
**Usage:**
```bash
cd backend
python tests/test_generation_progress.py
```
**Prerequisites:**
- Server must be running (`python main.py`)
- At least one voice profile must exist
### `test_real_download.py`
Tests real model download with SSE progress monitoring.
**Usage:**
```bash
cd backend
# Delete cache first to force fresh download
rm -rf ~/.cache/huggingface/hub/models--openai--whisper-base
python tests/test_real_download.py
```
**Prerequisites:**
- Server must be running (`python main.py`)
### `test_progress.py`
Unit tests for ProgressManager and HFProgressTracker functionality.
**Usage:**
```bash
cd backend
python tests/test_progress.py
```
### `test_check_progress_state.py`
Debugging script to inspect the internal state of ProgressManager and TaskManager.
**Usage:**
```bash
cd backend
python tests/test_check_progress_state.py
```
## Notes
These are manual test scripts, not automated unit tests. They're designed for:
- Debugging progress tracking issues
- Validating SSE event streams
- Monitoring real-time download behavior
- Inspecting internal state during development

View File

@@ -0,0 +1,6 @@
"""
Test suite for Voicebox backend.
This directory contains manual test scripts for debugging and validating
progress tracking, model downloads, and generation functionality.
"""

16
backend/tests/fixtures/README.md vendored Normal file
View File

@@ -0,0 +1,16 @@
# E2E Test Fixtures
Place two files here before running `test_all_models_e2e.py`:
- `reference_voice.wav` — a clean speech sample, mono, 16-24 kHz, ~5-15 seconds.
- `reference_voice.txt` — the **exact** transcription of the WAV (single line, no trailing newline required).
These are used to create a cloned voice profile for every cloning-capable engine (qwen, luxtts, chatterbox, chatterbox_turbo, tada). Keep them out of version control if they contain personal audio — this directory is not gitignored by default, so add them to `.gitignore` locally if needed.
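To sanity-check the fixture before a run, a quick inspection with `soundfile` (already used elsewhere in the backend) is enough. A minimal sketch, run from the repo root:
```
import soundfile as sf

info = sf.info("backend/tests/fixtures/reference_voice.wav")
print(f"{info.samplerate} Hz, {info.channels} channel(s), {info.duration:.1f} s")
```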
You can point the test at different files with:
```
python backend/tests/test_all_models_e2e.py \
--reference-wav /path/to/your.wav \
--reference-text "exact transcription here"
```

View File

@@ -0,0 +1,630 @@
"""
End-to-end model generation test.
Exercises every TTS model against the frozen PyInstaller binary, captures
per-model pass/fail, and writes a JSON + Markdown report.
Usage:
python backend/tests/test_all_models_e2e.py [flags]
See E2E_MODEL_TEST_DESIGN.md for the full design.
"""
from __future__ import annotations
import argparse
import json
import os
import platform
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import threading
import time
from collections import deque
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import httpx
REPO_ROOT = Path(__file__).resolve().parents[2]
BACKEND_DIR = REPO_ROOT / "backend"
DIST_DIR = BACKEND_DIR / "dist"
FIXTURES_DIR = Path(__file__).resolve().parent / "fixtures"
RESULTS_DIR = Path(__file__).resolve().parent / "results"
# ── Test matrix ──────────────────────────────────────────────────────
@dataclass(frozen=True)
class MatrixRow:
label: str # human-readable (appears in report)
engine: str # /generate engine
model_size: Optional[str] # /generate model_size (None = omit)
profile_kind: str # "cloned" | "preset_kokoro" | "preset_qwen_cv"
model_name: str # /models/status key for cache lookup
MATRIX: list[MatrixRow] = [
MatrixRow("qwen 1.7B", "qwen", "1.7B", "cloned", "qwen-tts-1.7B"),
MatrixRow("qwen 0.6B", "qwen", "0.6B", "cloned", "qwen-tts-0.6B"),
MatrixRow("qwen_custom_voice 1.7B", "qwen_custom_voice", "1.7B", "preset_qwen_cv", "qwen-custom-voice-1.7B"),
MatrixRow("qwen_custom_voice 0.6B", "qwen_custom_voice", "0.6B", "preset_qwen_cv", "qwen-custom-voice-0.6B"),
MatrixRow("luxtts", "luxtts", None, "cloned", "luxtts"),
MatrixRow("chatterbox", "chatterbox", None, "cloned", "chatterbox-tts"),
MatrixRow("chatterbox_turbo", "chatterbox_turbo", None, "cloned", "chatterbox-turbo"),
MatrixRow("tada 1B", "tada", "1B", "cloned", "tada-1b"),
MatrixRow("tada 3B", "tada", "3B", "cloned", "tada-3b-ml"),
MatrixRow("kokoro", "kokoro", None, "preset_kokoro", "kokoro"),
]
TEXT = "The quick brown fox jumps over the lazy dog."
DEFAULT_TIMEOUT_CACHED = 180
DEFAULT_TIMEOUT_DOWNLOAD = 1200
HEALTH_TIMEOUT = 120
# ── Result record ────────────────────────────────────────────────────
@dataclass
class ModelResult:
label: str
engine: str
model_size: Optional[str]
status: str # "passed" | "failed" | "timeout"
was_cached: Optional[bool] = None
generation_id: Optional[str] = None
elapsed_seconds: float = 0.0
audio_duration: Optional[float] = None
audio_path: Optional[str] = None
audio_bytes: Optional[int] = None
error: Optional[str] = None
http_status: Optional[int] = None
server_log_tail: Optional[list[str]] = None
# ── Binary resolution ────────────────────────────────────────────────
def find_binary() -> Optional[Path]:
"""Return the first existing binary in priority order, or None."""
is_win = platform.system() == "Windows"
exe = ".exe" if is_win else ""
candidates = [
DIST_DIR / "voicebox-server-cuda" / f"voicebox-server-cuda{exe}",
DIST_DIR / f"voicebox-server{exe}",
]
for c in candidates:
if c.exists() and c.is_file():
return c
return None
def build_binary() -> Path:
"""Invoke build_binary.py and return the resulting binary path."""
print("[build] No frozen binary found — invoking build_binary.py (this may take 5-20 minutes)...", flush=True)
script = BACKEND_DIR / "build_binary.py"
result = subprocess.run(
[sys.executable, str(script)],
cwd=str(BACKEND_DIR),
)
if result.returncode != 0:
raise RuntimeError(f"build_binary.py exited with code {result.returncode}")
found = find_binary()
if found is None:
raise RuntimeError("build_binary.py finished but no binary was found in backend/dist/")
return found
# ── Server spawn + log capture ───────────────────────────────────────
class ServerProcess:
def __init__(self, binary: Path, port: int, data_dir: Path, log_path: Path):
self.binary = binary
self.port = port
self.data_dir = data_dir
self.log_path = log_path
self.proc: Optional[subprocess.Popen] = None
self._log_buffer: deque[str] = deque(maxlen=500)
self._reader_thread: Optional[threading.Thread] = None
def start(self) -> None:
args = [
str(self.binary),
"--host", "127.0.0.1",
"--port", str(self.port),
"--data-dir", str(self.data_dir),
"--parent-pid", str(os.getpid()),
]
print(f"[spawn] {' '.join(args)}", flush=True)
self._log_fh = open(self.log_path, "w", encoding="utf-8", errors="replace")
# Combine stderr into stdout so we get a single ordered stream.
self.proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=1,
text=True,
errors="replace",
)
self._reader_thread = threading.Thread(target=self._pump_logs, daemon=True)
self._reader_thread.start()
def _pump_logs(self) -> None:
assert self.proc is not None and self.proc.stdout is not None
for line in self.proc.stdout:
self._log_buffer.append(line.rstrip("\n"))
self._log_fh.write(line)
self._log_fh.flush()
def log_tail(self, n: int = 100) -> list[str]:
tail = list(self._log_buffer)[-n:]
return tail
def is_alive(self) -> bool:
return self.proc is not None and self.proc.poll() is None
def stop(self) -> None:
if self.proc is None:
return
if self.proc.poll() is not None:
return
try:
if platform.system() == "Windows":
subprocess.run(
["taskkill", "/F", "/T", "/PID", str(self.proc.pid)],
capture_output=True,
)
else:
self.proc.send_signal(signal.SIGTERM)
except Exception as e:
print(f"[shutdown] signal failed: {e}", flush=True)
try:
self.proc.wait(timeout=10)
except subprocess.TimeoutExpired:
print("[shutdown] server didn't exit cleanly, killing", flush=True)
self.proc.kill()
try:
self.proc.wait(timeout=5)
except subprocess.TimeoutExpired:
pass
if self._reader_thread is not None:
self._reader_thread.join(timeout=2)
try:
self._log_fh.close()
except Exception:
pass
def pick_free_port() -> int:
s = socket.socket()
s.bind(("127.0.0.1", 0))
port = s.getsockname()[1]
s.close()
return port
# ── HTTP helpers ─────────────────────────────────────────────────────
def wait_for_health(base_url: str, server: ServerProcess, timeout: int) -> None:
deadline = time.time() + timeout
with httpx.Client(timeout=5.0) as client:
while time.time() < deadline:
if not server.is_alive():
raise RuntimeError("Server process exited before becoming healthy")
try:
r = client.get(f"{base_url}/health")
if r.status_code == 200 and r.json().get("status") == "healthy":
return
except httpx.HTTPError:
pass
time.sleep(1.0)
raise TimeoutError(f"Server did not become healthy within {timeout}s")
def get_model_cached(client: httpx.Client, base_url: str, model_name: str) -> Optional[bool]:
try:
r = client.get(f"{base_url}/models/status", timeout=30.0)
r.raise_for_status()
for m in r.json().get("models", []):
if m.get("model_name") == model_name:
return bool(m.get("downloaded"))
except httpx.HTTPError:
return None
return None
def create_cloned_profile(client: httpx.Client, base_url: str, wav_path: Path, reference_text: str) -> str:
r = client.post(f"{base_url}/profiles", json={
"name": "e2e-cloned",
"voice_type": "cloned",
"language": "en",
})
r.raise_for_status()
profile_id = r.json()["id"]
with open(wav_path, "rb") as f:
r = client.post(
f"{base_url}/profiles/{profile_id}/samples",
files={"file": (wav_path.name, f, "audio/wav")},
data={"reference_text": reference_text},
timeout=120.0,
)
r.raise_for_status()
return profile_id
def create_preset_profile(client: httpx.Client, base_url: str, name: str, engine: str, voice_id: str) -> str:
r = client.post(f"{base_url}/profiles", json={
"name": name,
"voice_type": "preset",
"language": "en",
"preset_engine": engine,
"preset_voice_id": voice_id,
})
r.raise_for_status()
return r.json()["id"]
def run_one_generation(
client: httpx.Client,
base_url: str,
row: MatrixRow,
profile_id: str,
timeout_s: int,
) -> tuple[str, dict]:
"""Start a generation and stream its status until done/failed/timeout.
Returns (status, payload) where status is "completed" | "failed" | "timeout".
"""
body = {
"profile_id": profile_id,
"text": TEXT,
"language": "en",
"engine": row.engine,
"seed": 42,
"normalize": True,
}
if row.model_size is not None:
body["model_size"] = row.model_size
r = client.post(f"{base_url}/generate", json=body, timeout=30.0)
r.raise_for_status()
gen = r.json()
gen_id = gen["id"]
deadline = time.time() + timeout_s
last_payload: dict = gen
status_url = f"{base_url}/generate/{gen_id}/status"
while time.time() < deadline:
remaining = max(1.0, deadline - time.time())
try:
with client.stream("GET", status_url, timeout=httpx.Timeout(remaining + 5, read=remaining + 5)) as resp:
resp.raise_for_status()
for line in resp.iter_lines():
if not line or not line.startswith("data: "):
continue
try:
payload = json.loads(line[6:])
except json.JSONDecodeError:
continue
last_payload = payload
status = payload.get("status")
if status == "not_found":
return "failed", {"error": "generation not found", **payload}
if status in ("completed", "failed"):
return status, payload
if time.time() >= deadline:
break
except httpx.HTTPError:
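            # Transient connection/stream failure: back off briefly and re-open
            # the status stream; the surrounding while loop enforces the deadline.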
time.sleep(1.0)
continue
return "timeout", last_payload
def fetch_audio_info(
client: httpx.Client, base_url: str, generation_id: str, data_dir: Path
) -> tuple[Optional[str], Optional[int]]:
"""Return (audio_path, audio_bytes) for a completed generation.
Server stores audio_path relative to data_dir; resolve it to get a size.
"""
try:
r = client.get(f"{base_url}/history/{generation_id}", timeout=10.0)
if r.status_code != 200:
return None, None
data = r.json()
audio_path = data.get("audio_path")
if not audio_path:
return None, None
p = Path(audio_path)
if not p.is_absolute():
p = data_dir / p
if p.exists():
return str(p), p.stat().st_size
return audio_path, None
except httpx.HTTPError:
return None, None
# ── Report writers ───────────────────────────────────────────────────
def write_reports(
output_dir: Path,
binary: Path,
started_at: datetime,
finished_at: datetime,
results: list[ModelResult],
) -> tuple[Path, Path]:
output_dir.mkdir(parents=True, exist_ok=True)
plat = f"{platform.system().lower()}-{platform.machine().lower()}"
ts = started_at.strftime("%Y%m%d-%H%M%S")
json_path = output_dir / f"e2e-{plat}-{ts}.json"
md_path = output_dir / f"e2e-{plat}-{ts}.md"
doc = {
"platform": plat,
"binary": str(binary),
"binary_size_mb": round(binary.stat().st_size / (1024 * 1024), 1) if binary.exists() else None,
"started_at": started_at.isoformat(),
"finished_at": finished_at.isoformat(),
"elapsed_seconds": (finished_at - started_at).total_seconds(),
"results": [asdict(r) for r in results],
}
json_path.write_text(json.dumps(doc, indent=2))
lines = [
f"# Voicebox E2E — {plat}{started_at.strftime('%Y-%m-%d %H:%M UTC')}",
"",
f"Binary: `{binary}` ",
f"Elapsed: {doc['elapsed_seconds']:.1f}s",
"",
"| Model | Status | Cached | Elapsed | Audio | Error |",
"|-------|--------|--------|---------|-------|-------|",
]
for r in results:
status_icon = {"passed": "PASS", "failed": "FAIL", "timeout": "TIMEOUT"}.get(r.status, r.status.upper())
cached = "yes" if r.was_cached else ("no" if r.was_cached is False else "?")
audio_col = f"{r.audio_duration:.2f}s" if r.audio_duration else ("" if r.status != "passed" else "?")
error_col = (r.error or "").replace("\n", " ")[:120]
lines.append(f"| {r.label} | {status_icon} | {cached} | {r.elapsed_seconds:.1f}s | {audio_col} | {error_col} |")
failed_rows = [r for r in results if r.status != "passed"]
if failed_rows:
lines.append("")
lines.append("## Failures")
for r in failed_rows:
lines.append("")
lines.append(f"### {r.label}{r.status}")
if r.error:
lines.append("")
lines.append("```")
lines.append(r.error)
lines.append("```")
if r.server_log_tail:
lines.append("")
lines.append("<details><summary>server log (last lines)</summary>")
lines.append("")
lines.append("```")
lines.extend(r.server_log_tail)
lines.append("```")
lines.append("</details>")
md_path.write_text("\n".join(lines) + "\n")
return json_path, md_path
# ── Main ─────────────────────────────────────────────────────────────
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Voicebox E2E model generation test")
p.add_argument("--binary", type=Path, help="Path to voicebox-server binary (overrides auto-detect)")
p.add_argument("--skip-build", action="store_true", help="Error if binary missing instead of building")
p.add_argument(
"--reference-wav",
type=Path,
default=FIXTURES_DIR / "reference_voice.wav",
help="Reference audio for cloning engines",
)
p.add_argument(
"--reference-text",
help="Transcription of reference-wav (default: read from fixtures/reference_voice.txt)",
)
p.add_argument("--only", help="Comma-separated engines to run (e.g. kokoro,qwen)")
p.add_argument("--skip", help="Comma-separated engines to skip")
p.add_argument("--keep-data-dir", action="store_true", help="Don't delete tempdir after run")
p.add_argument("--timeout-cached", type=int, default=DEFAULT_TIMEOUT_CACHED)
p.add_argument("--timeout-download", type=int, default=DEFAULT_TIMEOUT_DOWNLOAD)
p.add_argument("--port", type=int, help="Override auto-picked port")
p.add_argument("--output-dir", type=Path, default=RESULTS_DIR)
return p.parse_args()
def filter_matrix(args: argparse.Namespace) -> list[MatrixRow]:
only = set(x.strip() for x in args.only.split(",")) if args.only else None
skip = set(x.strip() for x in args.skip.split(",")) if args.skip else set()
rows = []
for r in MATRIX:
if only is not None and r.engine not in only:
continue
if r.engine in skip:
continue
rows.append(r)
return rows
def resolve_reference(args: argparse.Namespace) -> tuple[Path, str]:
wav = args.reference_wav
if not wav.exists():
raise FileNotFoundError(
f"Reference WAV not found: {wav}\n"
f"Place a sample at {FIXTURES_DIR / 'reference_voice.wav'} or pass --reference-wav.\n"
f"See backend/tests/fixtures/README.md."
)
if args.reference_text:
text = args.reference_text
else:
txt_path = wav.with_suffix(".txt")
if not txt_path.exists():
raise FileNotFoundError(
f"Reference transcription not found: {txt_path}\n"
f"Create it next to the WAV, or pass --reference-text."
)
text = txt_path.read_text().strip()
if not text:
raise ValueError("Reference transcription is empty")
return wav, text
def main() -> int:
args = parse_args()
rows = filter_matrix(args)
if not rows:
print("No rows selected after --only/--skip filtering", file=sys.stderr)
return 2
# Binary
binary = args.binary or find_binary()
if binary is None:
if args.skip_build:
print("No frozen binary found and --skip-build set. Run: python backend/build_binary.py", file=sys.stderr)
return 2
binary = build_binary()
if not binary.exists():
print(f"Binary path does not exist: {binary}", file=sys.stderr)
return 2
print(f"[binary] {binary}", flush=True)
# Reference audio (only required if any cloning row is in the matrix)
needs_reference = any(r.profile_kind == "cloned" for r in rows)
ref_wav: Optional[Path] = None
ref_text: Optional[str] = None
if needs_reference:
try:
ref_wav, ref_text = resolve_reference(args)
except (FileNotFoundError, ValueError) as e:
print(f"[fixture] {e}", file=sys.stderr)
return 2
print(f"[fixture] reference WAV: {ref_wav}", flush=True)
print(f"[fixture] reference text: {ref_text!r}", flush=True)
# Tempdir + log path
data_dir = Path(tempfile.mkdtemp(prefix="voicebox-e2e-"))
args.output_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
log_path = args.output_dir / f"server-{ts}.log"
port = args.port or pick_free_port()
base_url = f"http://127.0.0.1:{port}"
server = ServerProcess(binary=binary, port=port, data_dir=data_dir, log_path=log_path)
started_at = datetime.now(timezone.utc)
results: list[ModelResult] = []
try:
server.start()
print(f"[health] waiting for {base_url}/health ...", flush=True)
wait_for_health(base_url, server, HEALTH_TIMEOUT)
print("[health] ready", flush=True)
with httpx.Client(timeout=30.0) as client:
# Profile setup (only create what's needed)
cloned_profile_id: Optional[str] = None
kokoro_profile_id: Optional[str] = None
qwen_cv_profile_id: Optional[str] = None
needed_kinds = {r.profile_kind for r in rows}
if "cloned" in needed_kinds:
assert ref_wav is not None and ref_text is not None
print("[profile] creating cloned profile...", flush=True)
cloned_profile_id = create_cloned_profile(client, base_url, ref_wav, ref_text)
if "preset_kokoro" in needed_kinds:
print("[profile] creating kokoro preset...", flush=True)
kokoro_profile_id = create_preset_profile(client, base_url, "e2e-kokoro", "kokoro", "af_heart")
if "preset_qwen_cv" in needed_kinds:
print("[profile] creating qwen_custom_voice preset...", flush=True)
qwen_cv_profile_id = create_preset_profile(client, base_url, "e2e-qwen-cv", "qwen_custom_voice", "Ryan")
profile_lookup = {
"cloned": cloned_profile_id,
"preset_kokoro": kokoro_profile_id,
"preset_qwen_cv": qwen_cv_profile_id,
}
# Matrix loop
for row in rows:
print(f"\n[run] {row.label} (engine={row.engine}, size={row.model_size})", flush=True)
profile_id = profile_lookup[row.profile_kind]
assert profile_id is not None
was_cached = get_model_cached(client, base_url, row.model_name)
timeout_s = args.timeout_cached if was_cached else args.timeout_download
print(f"[run] cached={was_cached} timeout={timeout_s}s", flush=True)
t0 = time.time()
result = ModelResult(
label=row.label,
engine=row.engine,
model_size=row.model_size,
status="failed",
was_cached=was_cached,
)
try:
status, payload = run_one_generation(client, base_url, row, profile_id, timeout_s)
result.status = "passed" if status == "completed" else status
result.generation_id = payload.get("id")
result.audio_duration = payload.get("duration")
result.error = payload.get("error")
if status == "completed" and result.generation_id:
audio_path, audio_bytes = fetch_audio_info(
client, base_url, result.generation_id, data_dir
)
result.audio_path = audio_path
result.audio_bytes = audio_bytes
if audio_bytes is not None and audio_bytes == 0:
result.status = "failed"
result.error = (result.error or "") + " (audio file is empty)"
except httpx.HTTPStatusError as e:
result.status = "failed"
result.http_status = e.response.status_code
try:
detail = e.response.json().get("detail")
except Exception:
detail = e.response.text
result.error = f"HTTP {e.response.status_code}: {detail}"
except Exception as e:
result.status = "failed"
result.error = f"{type(e).__name__}: {e}"
result.elapsed_seconds = round(time.time() - t0, 2)
if result.status != "passed":
result.server_log_tail = server.log_tail(100)
print(f"[run] {row.label}{result.status} in {result.elapsed_seconds}s"
+ (f" ({result.error})" if result.error else ""), flush=True)
results.append(result)
finally:
finished_at = datetime.now(timezone.utc)
server.stop()
if not args.keep_data_dir:
shutil.rmtree(data_dir, ignore_errors=True)
else:
print(f"[cleanup] keeping data dir: {data_dir}", flush=True)
json_path, md_path = write_reports(args.output_dir, binary, started_at, finished_at, results)
print(f"\n[report] {json_path}")
print(f"[report] {md_path}")
print(f"[report] server log: {log_path}")
passed = sum(1 for r in results if r.status == "passed")
failed = len(results) - passed
print(f"\n== {passed} passed, {failed} failed ==")
return 0 if failed == 0 else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,112 @@
"""
Unit tests for reference-audio preprocessing.
Covers :func:`backend.utils.audio.preprocess_reference_audio` and
:func:`backend.utils.audio.validate_and_load_reference_audio`.
"""
import sys
from pathlib import Path
import numpy as np
import pytest
import soundfile as sf
sys.path.insert(0, str(Path(__file__).parent.parent))
from utils.audio import ( # noqa: E402
preprocess_reference_audio,
validate_and_load_reference_audio,
)
SR = 24000
def _tone(duration_s: float, amp: float = 0.3, freq: float = 220.0) -> np.ndarray:
n = int(duration_s * SR)
t = np.arange(n, dtype=np.float32) / SR
return (amp * np.sin(2 * np.pi * freq * t)).astype(np.float32)
def test_peak_cap_scales_hot_input():
audio = _tone(3.0, amp=0.99)
out = preprocess_reference_audio(audio, SR)
assert np.abs(out).max() <= 0.951
def test_peak_cap_leaves_moderate_input_untouched():
audio = _tone(3.0, amp=0.5)
out = preprocess_reference_audio(audio, SR)
assert np.isclose(np.abs(out).max(), 0.5, atol=1e-3)
def test_dc_offset_removed():
audio = _tone(3.0, amp=0.3) + 0.1
out = preprocess_reference_audio(audio, SR)
assert abs(float(np.mean(out))) < 1e-3
def test_silence_is_trimmed_with_padding_kept():
silence = np.zeros(int(SR * 1.0), dtype=np.float32)
speech = _tone(3.0, amp=0.3)
audio = np.concatenate([silence, speech, silence])
out = preprocess_reference_audio(audio, SR)
# Most of the 2s of leading/trailing silence should be gone, but the
# 3s of speech plus ~200ms of padding should remain.
assert len(audio) - len(out) >= SR, "expected >=1s of silence trimmed"
assert len(out) >= int(3.0 * SR), "speech body should be preserved"
def test_clean_audio_is_not_padded_past_original_length():
# Well-recorded audio with no edge silence shouldn't get longer after
# preprocessing — otherwise a 29.9 s upload could be pushed past the
# 30 s max_duration ceiling downstream.
audio = _tone(3.0, amp=0.3)
out = preprocess_reference_audio(audio, SR)
assert len(out) <= len(audio)
def test_empty_input_returns_empty():
out = preprocess_reference_audio(np.zeros(0, dtype=np.float32), SR)
assert out.size == 0
def test_validate_accepts_previously_rejected_hot_file(tmp_path):
audio = _tone(3.0, amp=0.995)
path = tmp_path / "hot.wav"
sf.write(str(path), audio, SR)
ok, err, out_audio, out_sr = validate_and_load_reference_audio(str(path))
assert ok, f"expected pass, got error: {err}"
assert out_audio is not None
assert out_sr == SR
assert np.abs(out_audio).max() <= 0.951
def test_validate_still_rejects_silent_input(tmp_path):
audio = np.zeros(int(SR * 3.0), dtype=np.float32)
path = tmp_path / "silent.wav"
sf.write(str(path), audio, SR)
ok, err, _, _ = validate_and_load_reference_audio(str(path))
assert not ok
assert err is not None
assert "too short" in err.lower() or "quiet" in err.lower()
def test_validate_rejects_too_short(tmp_path):
audio = _tone(0.5, amp=0.3)
path = tmp_path / "short.wav"
sf.write(str(path), audio, SR)
ok, err, _, _ = validate_and_load_reference_audio(str(path))
assert not ok
assert "too short" in (err or "").lower()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

162
backend/tests/test_cors.py Normal file
View File

@@ -0,0 +1,162 @@
"""
Tests for CORS origin restrictions.
Validates that the CORS middleware only allows known local origins
and respects the VOICEBOX_CORS_ORIGINS environment variable.
Uses a minimal FastAPI app that mirrors the exact CORS configuration
from backend/main.py, so tests run without heavy ML dependencies.
Usage:
pip install httpx pytest fastapi starlette
python -m pytest backend/tests/test_cors.py -v
"""
import os
import pytest
from unittest.mock import patch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.testclient import TestClient
def _build_app(env_origins: str = "") -> FastAPI:
"""
Build a minimal FastAPI app with the same CORS logic as backend/main.py.
This mirrors the exact code in main.py so the test validates the real
configuration without needing torch/numpy/transformers installed.
"""
app = FastAPI()
_default_origins = [
"http://localhost:5173",
"http://127.0.0.1:5173",
"http://localhost:17493",
"http://127.0.0.1:17493",
"tauri://localhost",
"https://tauri.localhost",
]
_cors_origins = _default_origins + [o.strip() for o in env_origins.split(",") if o.strip()]
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health():
return {"status": "ok"}
return app
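# Illustrative only: the real main.py presumably derives the extra origins from
# the environment with something like the helper below (the helper name is
# hypothetical); _build_app takes the raw string directly so these tests never
# need to touch os.environ.
def _origins_from_env() -> list:
    raw = os.environ.get("VOICEBOX_CORS_ORIGINS", "")
    return [o.strip() for o in raw.split(",") if o.strip()]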
@pytest.fixture()
def client():
return TestClient(_build_app())
@pytest.fixture()
def client_with_custom_origins():
return TestClient(_build_app("https://custom.example.com,https://other.example.com"))
def _get_with_origin(client: TestClient, origin: str) -> dict:
"""Send a GET with Origin header, return response headers."""
response = client.get("/health", headers={"Origin": origin})
return dict(response.headers)
def _preflight(client: TestClient, origin: str) -> dict:
"""Send CORS preflight OPTIONS request, return response headers."""
response = client.options(
"/health",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
return dict(response.headers)
class TestCORSDefaultOrigins:
"""CORS should allow known local origins and block everything else."""
@pytest.mark.parametrize("origin", [
"http://localhost:5173",
"http://127.0.0.1:5173",
"http://localhost:17493",
"http://127.0.0.1:17493",
"tauri://localhost",
"https://tauri.localhost",
])
def test_allowed_origins(self, client, origin):
headers = _get_with_origin(client, origin)
assert headers.get("access-control-allow-origin") == origin
@pytest.mark.parametrize("origin", [
"http://evil.com",
"http://localhost:9999",
"https://attacker.example.com",
"null",
])
def test_blocked_origins(self, client, origin):
headers = _get_with_origin(client, origin)
assert "access-control-allow-origin" not in headers
def test_preflight_allowed(self, client):
headers = _preflight(client, "http://localhost:5173")
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
def test_preflight_blocked(self, client):
headers = _preflight(client, "http://evil.com")
assert "access-control-allow-origin" not in headers
def test_credentials_header_present(self, client):
headers = _get_with_origin(client, "http://localhost:5173")
assert headers.get("access-control-allow-credentials") == "true"
class TestCORSCustomOrigins:
"""VOICEBOX_CORS_ORIGINS env var should extend the allowlist."""
def test_custom_origin_allowed(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "https://custom.example.com")
assert headers.get("access-control-allow-origin") == "https://custom.example.com"
def test_other_custom_origin_allowed(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "https://other.example.com")
assert headers.get("access-control-allow-origin") == "https://other.example.com"
def test_default_origins_still_work(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "http://localhost:5173")
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
def test_unlisted_origin_still_blocked(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "http://evil.com")
assert "access-control-allow-origin" not in headers
class TestCORSEnvVarParsing:
"""Edge cases for VOICEBOX_CORS_ORIGINS parsing."""
def test_empty_env_var(self):
app = _build_app("")
client = TestClient(app)
headers = _get_with_origin(client, "http://evil.com")
assert "access-control-allow-origin" not in headers
def test_whitespace_trimmed(self):
app = _build_app(" https://spaced.example.com ")
client = TestClient(app)
headers = _get_with_origin(client, "https://spaced.example.com")
assert headers.get("access-control-allow-origin") == "https://spaced.example.com"
def test_trailing_comma_ignored(self):
app = _build_app("https://one.example.com,")
client = TestClient(app)
headers = _get_with_origin(client, "https://one.example.com")
assert headers.get("access-control-allow-origin") == "https://one.example.com"

View File

@@ -0,0 +1,305 @@
"""
Test TTS generation with SSE progress monitoring.
This test captures the exact SSE events triggered during generation
to identify UX issues where users see download progress even when
the model is already cached.
"""
import asyncio
import json
import httpx
from typing import List, Dict, Optional
from datetime import datetime
async def monitor_sse_stream(model_name: str, timeout: int = 120):
"""Monitor SSE stream for a model during generation."""
events: List[Dict] = []
url = f"http://localhost:8000/models/progress/{model_name}"
print(f"[{_timestamp()}] Connecting to SSE endpoint: {url}")
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", url) as response:
print(f"[{_timestamp()}] SSE connected, status: {response.status_code}")
if response.status_code != 200:
print(f"[{_timestamp()}] Error: SSE endpoint returned {response.status_code}")
return events
async for line in response.aiter_lines():
if not line:
continue
timestamp = _timestamp()
if line.startswith("data: "):
try:
data = json.loads(line[6:])
print(
f"[{timestamp}] → SSE Event: {data['status']:12} {data.get('progress', 0):6.1f}% {data.get('filename', '')}"
)
events.append({**data, "_timestamp": timestamp})
# Stop if complete or error
if data.get("status") in ("complete", "error"):
print(f"[{timestamp}] → Model {data['status']}!")
break
except json.JSONDecodeError as e:
print(f"[{timestamp}] Error parsing JSON: {e}")
print(f" Line was: {line}")
elif line.startswith(": heartbeat"):
print(f"[{timestamp}] ♥ heartbeat")
except asyncio.TimeoutError:
print(f"[{_timestamp()}] SSE monitoring timed out")
except Exception as e:
print(f"[{_timestamp()}] SSE error: {e}")
return events
async def trigger_generation(profile_id: str, text: str, model_size: str = "1.7B"):
"""Trigger TTS generation via the API."""
url = "http://localhost:8000/generate"
print(f"\n[{_timestamp()}] Triggering generation...")
print(f" Profile: {profile_id}")
print(f" Text: {text[:50]}...")
print(f" Model: {model_size}")
try:
async with httpx.AsyncClient(timeout=120) as client:
response = await client.post(
url,
json={
"profile_id": profile_id,
"text": text,
"language": "en",
"model_size": model_size,
},
)
print(f"[{_timestamp()}] Response: {response.status_code}")
if response.status_code == 200:
result = response.json()
print(f"[{_timestamp()}] ✓ Generation successful!")
print(f" Generation ID: {result.get('id')}")
print(f" Duration: {result.get('duration', 0):.2f}s")
return True, result
elif response.status_code == 202:
# Model is being downloaded
result = response.json()
print(f"[{_timestamp()}] → Model download in progress")
print(f" Detail: {result}")
return False, result
else:
print(f"[{_timestamp()}] ✗ Error: {response.text}")
return False, None
except Exception as e:
print(f"[{_timestamp()}] ✗ Exception: {e}")
return False, None
async def get_first_profile():
"""Get the first available voice profile."""
url = "http://localhost:8000/profiles"
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url)
if response.status_code == 200:
profiles = response.json()
if profiles:
return profiles[0]["id"]
except Exception as e:
print(f"Error getting profiles: {e}")
return None
async def check_server():
"""Check if the server is running."""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get("http://localhost:8000/health")
return response.status_code == 200
except Exception as e:
print(f"Server not running: {e}")
return False
def _timestamp():
"""Get current timestamp for logging."""
return datetime.now().strftime("%H:%M:%S.%f")[:-3]
async def test_generation_with_cached_model():
"""
Test Case 1: Generation when model is already cached.
This should NOT show any download progress events.
If it does, that's the UX bug we're trying to fix.
"""
print("\n" + "=" * 80)
print("TEST CASE 1: Generation with Cached Model")
print("=" * 80)
print("Expected: No download progress events (or minimal/instant completion)")
print("Actual UX Issue: Users see 'started' and 'finished' events even for cached models")
print("=" * 80)
model_size = "1.7B"
model_name = f"qwen-tts-{model_size}"
# Get a profile
profile_id = await get_first_profile()
if not profile_id:
print("✗ No voice profiles found. Please create a profile first.")
return False
print(f"\nUsing profile: {profile_id}")
# Start SSE monitor BEFORE triggering generation
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=30))
# Wait for SSE to connect
await asyncio.sleep(1)
# Trigger generation
test_text = "Hello, this is a test of the voice generation system."
success, result = await trigger_generation(profile_id, test_text, model_size)
if not success and result and result.get("downloading"):
print("\n⚠ Model is being downloaded. Waiting for download to complete...")
# Wait for SSE monitor to capture download events
events = await monitor_task
return events
# Wait a bit more to catch any progress events
await asyncio.sleep(3)
# Cancel SSE monitor
monitor_task.cancel()
try:
events = await monitor_task
except asyncio.CancelledError:
events = []
return events
async def test_generation_with_fresh_download():
"""
Test Case 2: Generation when model needs to be downloaded.
This SHOULD show download progress events.
"""
print("\n" + "=" * 80)
print("TEST CASE 2: Generation with Model Download")
print("=" * 80)
print("Expected: Download progress events from 0% to 100%")
print("=" * 80)
# Use a different model size to force download
model_size = "0.6B" # Smaller model for faster testing
model_name = f"qwen-tts-{model_size}"
# Get a profile
profile_id = await get_first_profile()
if not profile_id:
print("✗ No voice profiles found. Please create a profile first.")
return False
print(f"\nUsing profile: {profile_id}")
print("Note: This will download the model if not cached")
# Start SSE monitor BEFORE triggering generation
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=300))
# Wait for SSE to connect
await asyncio.sleep(1)
# Trigger generation
test_text = "This should trigger a model download if the model is not cached."
success, result = await trigger_generation(profile_id, test_text, model_size)
if not success and result and result.get("downloading"):
print("\n→ Model download initiated. Monitoring progress...")
# Wait for download to complete
events = await monitor_task
# Try generation again
print(f"\n[{_timestamp()}] Retrying generation after download...")
await asyncio.sleep(2)
success, result = await trigger_generation(profile_id, test_text, model_size)
if success:
print("✓ Generation successful after download")
return events
# If model was already cached
await asyncio.sleep(3)
monitor_task.cancel()
try:
events = await monitor_task
except asyncio.CancelledError:
events = []
return events
async def main():
print("=" * 80)
print("TTS Generation Progress Test")
print("=" * 80)
print("Purpose: Capture exact SSE events during generation to identify UX issues")
print("=" * 80)
# Check if server is running
print(f"\n[{_timestamp()}] Checking if server is running...")
if not await check_server():
print("✗ Server is not running on http://localhost:8000")
print("\nPlease start the server first:")
print(" cd backend && python main.py")
return False
print("✓ Server is running")
# Test Case 1: Cached model
print("\n" + "🧪 " * 20)
events_cached = await test_generation_with_cached_model()
# Results for Test Case 1
print("\n" + "=" * 80)
print("TEST CASE 1 RESULTS: Generation with Cached Model")
print("=" * 80)
if not events_cached:
print("✓ GOOD: No SSE progress events received")
print(" This is the expected behavior for a cached model.")
else:
print(f"⚠ ISSUE FOUND: Received {len(events_cached)} SSE events:")
print("\nEvent Timeline:")
for i, event in enumerate(events_cached, 1):
timestamp = event.pop("_timestamp", "??:??:??.???")
print(f" {i}. [{timestamp}] {event}")
print("\n⚠ This explains the UX issue!")
print(" Users see progress events even when the model is already cached,")
print(" making them think the model is downloading again.")
print("\n" + "=" * 80)
print("Test Complete!")
print("=" * 80)
return True
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,118 @@
"""
Unit tests for the ``force_offline_if_cached`` helper.
Verifies that the helper mutates the cached module constants in
``huggingface_hub.constants`` and ``transformers.utils.hub`` — not just
``os.environ`` — and that concurrent users are refcount-coordinated so
one thread's exit can't strip another thread's offline protection.
NOTE: These tests mutate process-global state in ``huggingface_hub.constants``
and ``transformers.utils.hub``. They are not safe under cross-process
parallelism (e.g. ``pytest-xdist`` with ``--dist=loadfile``/``loadscope``);
run this file serially.
"""
import os
import sys
import threading
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from utils.hf_offline_patch import force_offline_if_cached # noqa: E402
def _hf_const():
import huggingface_hub.constants as hf_const
return hf_const
def _tf_hub():
import transformers.utils.hub as tf_hub
return tf_hub
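# Minimal sketch of the refcounting pattern these tests exercise. This is an
# assumption for illustration, not the actual utils.hf_offline_patch code: a
# shared counter flips the offline flags on first entry and restores them only
# when the last active user exits (the real helper restores the saved original
# values rather than hard-coding False).
import contextlib

_sketch_lock = threading.Lock()
_sketch_active = 0


@contextlib.contextmanager
def _sketch_force_offline(cached: bool):
    global _sketch_active
    if not cached:
        yield
        return
    with _sketch_lock:
        _sketch_active += 1
        if _sketch_active == 1:
            os.environ["HF_HUB_OFFLINE"] = "1"
            _hf_const().HF_HUB_OFFLINE = True
    try:
        yield
    finally:
        with _sketch_lock:
            _sketch_active -= 1
            if _sketch_active == 0:
                os.environ.pop("HF_HUB_OFFLINE", None)
                _hf_const().HF_HUB_OFFLINE = False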
def test_mutates_cached_huggingface_hub_constant():
original = _hf_const().HF_HUB_OFFLINE
with force_offline_if_cached(True, "t"):
assert _hf_const().HF_HUB_OFFLINE is True
assert original == _hf_const().HF_HUB_OFFLINE
def test_mutates_cached_transformers_constant():
original = _tf_hub()._is_offline_mode
with force_offline_if_cached(True, "t"):
assert _tf_hub()._is_offline_mode is True
assert original == _tf_hub()._is_offline_mode
def test_sets_env_variable():
original = os.environ.get("HF_HUB_OFFLINE")
with force_offline_if_cached(True, "t"):
assert "1" == os.environ.get("HF_HUB_OFFLINE")
assert original == os.environ.get("HF_HUB_OFFLINE")
def test_noop_when_not_cached():
before = _hf_const().HF_HUB_OFFLINE
with force_offline_if_cached(False, "t"):
assert before == _hf_const().HF_HUB_OFFLINE
def test_nested_contexts_respect_refcount():
original = _hf_const().HF_HUB_OFFLINE
with force_offline_if_cached(True, "outer"):
assert _hf_const().HF_HUB_OFFLINE is True
with force_offline_if_cached(True, "inner"):
assert _hf_const().HF_HUB_OFFLINE is True
# inner exit must not restore while outer is still active
assert _hf_const().HF_HUB_OFFLINE is True
assert original == _hf_const().HF_HUB_OFFLINE
def test_concurrent_threads_share_offline_window():
"""A slow thread must keep seeing offline mode even if a peer exits first."""
original = _hf_const().HF_HUB_OFFLINE
observations: list[bool] = []
errors: list[Exception] = []
barrier = threading.Barrier(2)
fast_exited = threading.Event()
def slow():
try:
with force_offline_if_cached(True, "slow"):
barrier.wait(timeout=5)
assert fast_exited.wait(timeout=5), "fast thread did not exit"
observations.append(_hf_const().HF_HUB_OFFLINE)
except Exception as exc: # noqa: BLE001
errors.append(exc)
def fast():
try:
with force_offline_if_cached(True, "fast"):
barrier.wait(timeout=5)
except Exception as exc: # noqa: BLE001
errors.append(exc)
finally:
fast_exited.set()
t_slow = threading.Thread(target=slow)
t_fast = threading.Thread(target=fast)
t_slow.start()
t_fast.start()
t_slow.join(timeout=5)
t_fast.join(timeout=5)
assert not t_slow.is_alive(), "slow thread did not finish"
assert not t_fast.is_alive(), "fast thread did not finish"
assert not errors, errors
assert observations == [True], "slow thread lost offline protection"
assert original == _hf_const().HF_HUB_OFFLINE
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,113 @@
"""
Unit tests for ``patch_transformers_mistral_regex``.
Verifies that our wrapper around
``transformers.PreTrainedTokenizerBase._patch_mistral_regex`` catches
exceptions from the unconditional ``huggingface_hub.model_info()`` lookup
and returns the tokenizer unchanged — matching the success-path behavior
for non-Mistral repos (transformers 4.57.3, ``tokenization_utils_base.py:2503``).
NOTE: These tests mutate ``transformers.PreTrainedTokenizerBase`` globally;
run serially, not under ``pytest-xdist`` with per-worker process isolation.
"""
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from huggingface_hub.errors import OfflineModeIsEnabled # noqa: E402
from transformers.tokenization_utils_base import PreTrainedTokenizerBase # noqa: E402
import utils.hf_offline_patch as hf_offline_patch # noqa: E402
@pytest.fixture(autouse=True)
def restore_mistral_regex():
"""Snapshot the current ``_patch_mistral_regex`` and restore after each test."""
saved = PreTrainedTokenizerBase.__dict__.get("_patch_mistral_regex")
saved_flag = hf_offline_patch._mistral_regex_patched
try:
yield
finally:
if saved is not None:
PreTrainedTokenizerBase._patch_mistral_regex = saved
hf_offline_patch._mistral_regex_patched = saved_flag
def _apply_patch():
hf_offline_patch._mistral_regex_patched = False
hf_offline_patch.patch_transformers_mistral_regex()
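# Illustrative sketch of the wrapping pattern exercised below (an assumption,
# not the actual utils.hf_offline_patch implementation): the original method is
# called inside a try/except so a failed hub lookup degrades to returning the
# tokenizer unchanged, matching the non-Mistral success path.
def _sketch_wrap(original):
    def patched(tokenizer, *args, **kwargs):
        try:
            return original(tokenizer, *args, **kwargs)
        except Exception:
            return tokenizer
    return patched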
def test_suppresses_offline_mode_is_enabled(monkeypatch):
_apply_patch()
import huggingface_hub
def raise_offline(*_args, **_kwargs):
raise OfflineModeIsEnabled("offline")
monkeypatch.setattr(huggingface_hub, "model_info", raise_offline)
sentinel = object()
result = PreTrainedTokenizerBase._patch_mistral_regex(
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
)
assert result is sentinel
def test_suppresses_connection_errors(monkeypatch):
_apply_patch()
import huggingface_hub
def raise_connection(*_args, **_kwargs):
raise ConnectionError("network unreachable")
monkeypatch.setattr(huggingface_hub, "model_info", raise_connection)
sentinel = object()
result = PreTrainedTokenizerBase._patch_mistral_regex(
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
)
assert result is sentinel
def test_passthrough_on_success(monkeypatch):
"""When model_info returns non-Mistral tags the original falls through and returns the tokenizer unchanged."""
_apply_patch()
import huggingface_hub
class FakeInfo:
tags = ["model-type:qwen", "language:en"]
monkeypatch.setattr(huggingface_hub, "model_info", lambda *_a, **_kw: FakeInfo())
sentinel = object()
result = PreTrainedTokenizerBase._patch_mistral_regex(
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
)
assert result is sentinel
def test_idempotent():
_apply_patch()
first = PreTrainedTokenizerBase._patch_mistral_regex
hf_offline_patch.patch_transformers_mistral_regex()
second = PreTrainedTokenizerBase._patch_mistral_regex
assert first.__func__ is second.__func__
def test_missing_method_is_noop(monkeypatch):
monkeypatch.delattr(PreTrainedTokenizerBase, "_patch_mistral_regex", raising=False)
hf_offline_patch._mistral_regex_patched = False
hf_offline_patch.patch_transformers_mistral_regex()
assert hf_offline_patch._mistral_regex_patched is False
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,217 @@
"""
Tests for profile duplicate name validation.
This test suite verifies that the application correctly handles
duplicate profile names and provides user-friendly error messages.
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Add parent directory to path to import backend modules
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from database import Base, VoiceProfile as DBVoiceProfile
from models import VoiceProfileCreate
from profiles import create_profile, update_profile
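# For context, a minimal sketch (assumed, not the actual profiles.create_profile
# logic) of the duplicate-name check these tests rely on: look up the name first
# and raise a user-friendly ValueError instead of letting the UNIQUE constraint
# surface a raw database error.
def _sketch_ensure_unique_name(db, name: str) -> None:
    existing = db.query(DBVoiceProfile).filter(DBVoiceProfile.name == name).first()
    if existing is not None:
        raise ValueError(
            f"A profile named '{name}' already exists. Please choose a different name."
        )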
@pytest.fixture
def test_db():
"""Create a temporary test database."""
# Create temporary directory for test database
temp_dir = tempfile.mkdtemp()
db_path = Path(temp_dir) / "test.db"
# Create engine and session
engine = create_engine(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=engine)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
yield db
# Cleanup
db.close()
shutil.rmtree(temp_dir)
@pytest.fixture
def mock_profiles_dir(monkeypatch, tmp_path):
"""Mock the profiles directory to use a temporary path."""
from backend import config
monkeypatch.setattr(config, 'get_profiles_dir', lambda: tmp_path)
return tmp_path
@pytest.mark.asyncio
async def test_create_profile_duplicate_name_raises_error(test_db, mock_profiles_dir):
"""Test that creating a profile with a duplicate name raises a ValueError."""
# Create first profile
profile_data_1 = VoiceProfileCreate(
name="Test Profile",
description="First profile",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
assert profile_1.name == "Test Profile"
# Try to create second profile with same name
profile_data_2 = VoiceProfileCreate(
name="Test Profile",
description="Second profile",
language="en"
)
with pytest.raises(ValueError) as exc_info:
await create_profile(profile_data_2, test_db)
# Verify error message is user-friendly
assert "already exists" in str(exc_info.value)
assert "Test Profile" in str(exc_info.value)
assert "choose a different name" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_create_profile_different_names_succeeds(test_db, mock_profiles_dir):
"""Test that creating profiles with different names succeeds."""
# Create first profile
profile_data_1 = VoiceProfileCreate(
name="Profile One",
description="First profile",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
assert profile_1.name == "Profile One"
# Create second profile with different name
profile_data_2 = VoiceProfileCreate(
name="Profile Two",
description="Second profile",
language="en"
)
profile_2 = await create_profile(profile_data_2, test_db)
assert profile_2.name == "Profile Two"
# Verify both profiles exist
assert profile_1.id != profile_2.id
@pytest.mark.asyncio
async def test_update_profile_to_duplicate_name_raises_error(test_db, mock_profiles_dir):
"""Test that updating a profile to a duplicate name raises a ValueError."""
# Create two profiles with different names
profile_data_1 = VoiceProfileCreate(
name="Profile A",
description="First profile",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
profile_data_2 = VoiceProfileCreate(
name="Profile B",
description="Second profile",
language="en"
)
profile_2 = await create_profile(profile_data_2, test_db)
# Try to update profile_2 to use profile_1's name
update_data = VoiceProfileCreate(
name="Profile A", # Duplicate name
description="Updated description",
language="en"
)
with pytest.raises(ValueError) as exc_info:
await update_profile(profile_2.id, update_data, test_db)
# Verify error message is user-friendly
assert "already exists" in str(exc_info.value)
assert "Profile A" in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_profile_keep_same_name_succeeds(test_db, mock_profiles_dir):
"""Test that updating a profile while keeping the same name succeeds."""
# Create profile
profile_data = VoiceProfileCreate(
name="My Profile",
description="Original description",
language="en"
)
profile = await create_profile(profile_data, test_db)
# Update profile with same name but different description
update_data = VoiceProfileCreate(
name="My Profile", # Same name
description="Updated description",
language="en"
)
updated_profile = await update_profile(profile.id, update_data, test_db)
# Verify update succeeded
assert updated_profile is not None
assert updated_profile.id == profile.id
assert updated_profile.name == "My Profile"
assert updated_profile.description == "Updated description"
@pytest.mark.asyncio
async def test_update_profile_to_new_unique_name_succeeds(test_db, mock_profiles_dir):
"""Test that updating a profile to a new unique name succeeds."""
# Create profile
profile_data = VoiceProfileCreate(
name="Original Name",
description="Profile description",
language="en"
)
profile = await create_profile(profile_data, test_db)
# Update profile with new unique name
update_data = VoiceProfileCreate(
name="New Unique Name",
description="Updated description",
language="en"
)
updated_profile = await update_profile(profile.id, update_data, test_db)
# Verify update succeeded
assert updated_profile is not None
assert updated_profile.id == profile.id
assert updated_profile.name == "New Unique Name"
@pytest.mark.asyncio
async def test_case_sensitive_names_allowed(test_db, mock_profiles_dir):
"""Test that profile names are case-sensitive (e.g., 'Test' and 'test' are different)."""
# Create profile with lowercase name
profile_data_1 = VoiceProfileCreate(
name="test profile",
description="Lowercase",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
# Create profile with different case
profile_data_2 = VoiceProfileCreate(
name="Test Profile",
description="Title case",
language="en"
)
profile_2 = await create_profile(profile_data_2, test_db)
# Both should succeed since SQLite unique constraint is case-sensitive by default
assert profile_1.name == "test profile"
assert profile_2.name == "Test Profile"
assert profile_1.id != profile_2.id

View File

@@ -0,0 +1,313 @@
"""
Test script to debug model download progress tracking.
"""
import asyncio
import json
import time
from typing import List, Dict
import logging
# Set up logging to see what's happening
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
from utils.progress import ProgressManager, get_progress_manager
from utils.hf_progress import HFProgressTracker, create_hf_progress_callback
def test_progress_manager_basic():
"""Test 1: Basic ProgressManager functionality."""
print("\n" + "=" * 60)
print("Test 1: ProgressManager Basic Operations")
print("=" * 60)
pm = ProgressManager()
# Test update_progress
pm.update_progress(
model_name="test-model",
current=50,
total=100,
filename="test.bin",
status="downloading"
)
# Test get_progress
progress = pm.get_progress("test-model")
print(f"✓ Progress stored: {progress}")
assert progress is not None
assert progress["progress"] == 50.0
assert progress["filename"] == "test.bin"
assert progress["status"] == "downloading"
# Test mark_complete
pm.mark_complete("test-model")
progress = pm.get_progress("test-model")
print(f"✓ Marked complete: {progress}")
assert progress["status"] == "complete"
assert progress["progress"] == 100.0
print("✓ Test 1 PASSED\n")
return True
async def test_progress_manager_sse():
"""Test 2: ProgressManager SSE streaming."""
print("\n" + "=" * 60)
print("Test 2: ProgressManager SSE Streaming")
print("=" * 60)
pm = ProgressManager()
collected_events: List[Dict] = []
# Simulate SSE client
async def sse_client():
"""Simulates a frontend SSE connection."""
print(" SSE client: Subscribing to test-model-sse...")
async for event in pm.subscribe("test-model-sse"):
# Parse SSE event
if event.startswith("data: "):
data = json.loads(event[6:])
print(f" SSE client: Received event: {data['status']} - {data.get('progress', 0):.1f}%")
collected_events.append(data)
# Stop when complete
if data.get("status") in ("complete", "error"):
break
elif event.startswith(": heartbeat"):
print(" SSE client: Received heartbeat")
# Simulate download progress updates (from backend thread)
async def simulate_download():
"""Simulates backend sending progress updates."""
print(" Backend: Starting simulated download...")
await asyncio.sleep(0.2) # Let SSE client subscribe first
# Send progress updates
for i in range(0, 101, 20):
print(f" Backend: Updating progress to {i}%")
pm.update_progress(
model_name="test-model-sse",
current=i,
total=100,
filename=f"file_{i}.bin",
status="downloading" if i < 100 else "downloading"
)
await asyncio.sleep(0.1)
# Mark complete
print(" Backend: Marking download complete")
pm.mark_complete("test-model-sse")
# Run SSE client and download simulation concurrently
await asyncio.gather(
sse_client(),
simulate_download()
)
# Verify we got events
print(f"\n Collected {len(collected_events)} events")
assert len(collected_events) > 0, "Should have received at least one event"
assert collected_events[-1]["status"] == "complete", "Last event should be 'complete'"
print("✓ Test 2 PASSED\n")
return True
def test_hf_progress_tracker():
"""Test 3: HFProgressTracker tqdm patching."""
print("\n" + "=" * 60)
print("Test 3: HFProgressTracker tqdm Patching")
print("=" * 60)
captured_progress: List[tuple] = []
def progress_callback(downloaded: int, total: int, filename: str):
"""Capture progress updates."""
captured_progress.append((downloaded, total, filename))
print(f" Progress callback: {downloaded}/{total} bytes ({filename})")
tracker = HFProgressTracker(progress_callback)
# Simulate a download with tqdm
with tracker.patch_download():
try:
from tqdm import tqdm
# Simulate downloading a file
print(" Simulating download with tqdm...")
total_size = 1000
with tqdm(total=total_size, desc="model.bin", unit="B", unit_scale=True) as pbar:
for chunk in range(0, total_size, 100):
pbar.update(100)
time.sleep(0.01)
print(f" Captured {len(captured_progress)} progress updates")
assert len(captured_progress) > 0, "Should have captured progress updates"
# Verify progress increases
last_downloaded = 0
for downloaded, total, filename in captured_progress:
assert downloaded >= last_downloaded, "Downloaded bytes should increase"
assert total == total_size, "Total should be consistent"
last_downloaded = downloaded
print("✓ Test 3 PASSED\n")
return True
except ImportError:
print("✗ tqdm not available, skipping test\n")
return None
async def test_full_integration():
"""Test 4: Full integration test."""
print("\n" + "=" * 60)
print("Test 4: Full Integration (ProgressManager + HFProgressTracker)")
print("=" * 60)
pm = get_progress_manager()
collected_events: List[Dict] = []
# SSE client
async def sse_client():
print(" SSE client: Subscribing...")
async for event in pm.subscribe("integration-test"):
if event.startswith("data: "):
data = json.loads(event[6:])
print(f" SSE client: {data['status']} - {data.get('progress', 0):.1f}% - {data.get('filename', '')}")
collected_events.append(data)
if data.get("status") in ("complete", "error"):
break
# Simulate backend download with HFProgressTracker
async def simulate_real_download():
await asyncio.sleep(0.2) # Let SSE subscribe
print(" Backend: Starting download with HFProgressTracker...")
# Set up tracking (like the real backend does)
progress_callback = create_hf_progress_callback("integration-test", pm)
tracker = HFProgressTracker(progress_callback)
# Initialize progress
pm.update_progress(
model_name="integration-test",
current=0,
total=1,
filename="",
status="downloading"
)
# Simulate download with tqdm patching
with tracker.patch_download():
try:
from tqdm import tqdm
# Simulate multi-file download (like HuggingFace does)
files = [
("model.safetensors", 5000),
("config.json", 1000),
("tokenizer.json", 500),
]
for filename, size in files:
print(f" Backend: Downloading {filename}...")
with tqdm(total=size, desc=filename, unit="B") as pbar:
for chunk in range(0, size, 500):
chunk_size = min(500, size - chunk)
pbar.update(chunk_size)
await asyncio.sleep(0.05)
# Mark complete
print(" Backend: Download complete")
pm.mark_complete("integration-test")
except ImportError:
print(" ✗ tqdm not available")
pm.mark_error("integration-test", "tqdm not available")
# Run both
await asyncio.gather(
sse_client(),
simulate_real_download()
)
# Verify
print(f"\n Collected {len(collected_events)} events")
if len(collected_events) > 0:
print(f" First event: {collected_events[0]}")
print(f" Last event: {collected_events[-1]}")
assert collected_events[-1]["status"] == "complete", "Should end with 'complete'"
print("✓ Test 4 PASSED\n")
return True
else:
print("✗ Test 4 FAILED - No events received\n")
return False
async def main():
"""Run all tests."""
print("\n" + "=" * 60)
print("Voicebox Progress Tracking Test Suite")
print("=" * 60)
results = []
# Test 1: Basic operations
try:
results.append(("Basic Operations", test_progress_manager_basic()))
except Exception as e:
print(f"✗ Test 1 FAILED: {e}\n")
results.append(("Basic Operations", False))
# Test 2: SSE streaming
try:
results.append(("SSE Streaming", await test_progress_manager_sse()))
except Exception as e:
print(f"✗ Test 2 FAILED: {e}\n")
results.append(("SSE Streaming", False))
# Test 3: tqdm patching
try:
results.append(("tqdm Patching", test_hf_progress_tracker()))
except Exception as e:
print(f"✗ Test 3 FAILED: {e}\n")
results.append(("tqdm Patching", False))
# Test 4: Full integration
try:
results.append(("Full Integration", await test_full_integration()))
except Exception as e:
print(f"✗ Test 4 FAILED: {e}\n")
results.append(("Full Integration", False))
# Summary
print("\n" + "=" * 60)
print("Test Results Summary")
print("=" * 60)
for name, result in results:
status = "✓ PASS" if result else ("⊘ SKIP" if result is None else "✗ FAIL")
print(f" {status:8} {name}")
passed = sum(1 for _, r in results if r is True)
failed = sum(1 for _, r in results if r is False)
skipped = sum(1 for _, r in results if r is None)
print()
print(f" Total: {len(results)} tests")
print(f" Passed: {passed}")
print(f" Failed: {failed}")
print(f" Skipped: {skipped}")
print("=" * 60 + "\n")
return failed == 0
if __name__ == "__main__":
success = asyncio.run(main())
exit(0 if success else 1)

View File

@@ -0,0 +1,317 @@
"""
Test Qwen TTS model download with SSE progress monitoring.
This specifically tests the MLX TTS backend download progress tracking,
which requires tqdm to be patched BEFORE mlx_audio is imported.
Usage:
cd backend && python -m tests.test_qwen_download
Prerequisites:
- Server must be running: cd backend && python main.py
- Delete model first for fresh download test:
curl -X DELETE http://localhost:8000/models/qwen-tts-0.6B
"""
import asyncio
import json
import httpx
import time
from typing import List, Dict, Optional
async def monitor_sse_stream(model_name: str, timeout: int = 600) -> List[Dict]:
"""
Monitor SSE stream for a model download.
Args:
model_name: Name of the model to monitor
timeout: Maximum time to wait for download (seconds)
Returns:
List of SSE events received
"""
events: List[Dict] = []
url = f"http://localhost:8000/models/progress/{model_name}"
last_progress = -1
print(f"\n📡 Connecting to SSE endpoint: {url}")
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", url) as response:
print(f" SSE connected, status: {response.status_code}")
if response.status_code != 200:
print(f" ❌ Error: SSE endpoint returned {response.status_code}")
return events
async for line in response.aiter_lines():
if not line:
continue
if line.startswith("data: "):
try:
data = json.loads(line[6:])
events.append(data)
# Print progress (only when it changes significantly)
progress = data.get('progress', 0)
status = data.get('status', 'unknown')
filename = data.get('filename', '')
current = data.get('current', 0)
total = data.get('total', 0)
# Print every 5% change or status change
if abs(progress - last_progress) >= 5 or status in ('complete', 'error'):
current_mb = current / (1024 * 1024)
total_mb = total / (1024 * 1024)
print(f" 📊 {status:12} {progress:6.1f}% ({current_mb:.1f}MB / {total_mb:.1f}MB) {filename[:50]}")
last_progress = progress
# Stop if complete or error
if status in ("complete", "error"):
if status == "complete":
print(f" ✅ Download complete!")
else:
print(f" ❌ Download error: {data.get('error', 'unknown')}")
break
except json.JSONDecodeError as e:
print(f" ⚠️ Error parsing JSON: {e}")
elif line.startswith(": heartbeat"):
# Heartbeat every 1 second, don't spam
pass
except asyncio.CancelledError:
print(" ⏹️ SSE monitor cancelled")
except Exception as e:
print(f" ❌ SSE error: {e}")
return events
async def trigger_download(model_name: str) -> bool:
"""Trigger a model download via the API."""
url = "http://localhost:8000/models/download"
print(f"\n🚀 Triggering download for: {model_name}")
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(url, json={"model_name": model_name})
result = response.json()
print(f" Response: {response.status_code} - {result}")
return response.status_code == 200
except Exception as e:
print(f" ❌ Error triggering download: {e}")
return False
async def delete_model(model_name: str) -> bool:
"""Delete a model from cache."""
url = f"http://localhost:8000/models/{model_name}"
print(f"\n🗑️ Deleting model: {model_name}")
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.delete(url)
if response.status_code == 200:
print(f" ✅ Model deleted")
return True
elif response.status_code == 404:
print(f" Model not found (already deleted)")
return True
else:
print(f" ⚠️ Delete response: {response.status_code} - {response.text}")
return False
except Exception as e:
print(f" ❌ Error deleting model: {e}")
return False
async def check_model_status(model_name: str) -> Optional[Dict]:
"""Check the status of a model."""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get("http://localhost:8000/models/status")
if response.status_code == 200:
data = response.json()
for model in data.get("models", []):
if model["model_name"] == model_name:
return model
except Exception as e:
print(f" ⚠️ Error checking model status: {e}")
return None
async def check_server() -> bool:
"""Check if the server is running."""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get("http://localhost:8000/health")
return response.status_code == 200
except Exception:
return False
async def main():
print("=" * 70)
print("🧪 Qwen TTS Model Download Progress Test")
print("=" * 70)
print("\nThis test verifies that MLX TTS download progress tracking works.")
print("It specifically tests the tqdm patching for mlx_audio.tts imports.")
# Check if server is running
print("\n📡 Checking if server is running...")
if not await check_server():
print(" ❌ Server is not running on http://localhost:8000")
print("\n Please start the server first:")
print(" cd backend && python main.py")
return False
print(" ✅ Server is running")
# Test model
model_name = "qwen-tts-0.6B"
# Check current status
print(f"\n📊 Checking status of {model_name}...")
status = await check_model_status(model_name)
if status:
print(f" Downloaded: {status.get('downloaded', False)}")
print(f" Downloading: {status.get('downloading', False)}")
print(f" Loaded: {status.get('loaded', False)}")
if status.get('size_mb'):
print(f" Size: {status['size_mb']:.1f} MB")
else:
print(" ⚠️ Could not get model status")
# Ask if user wants to delete first
print("\n" + "-" * 70)
if status and status.get('downloaded'):
print("⚠️ Model is already downloaded. Delete it for a fresh download test?")
print(" [y] Yes, delete and download fresh")
print(" [n] No, just test SSE connection")
print(" [q] Quit")
choice = input("\nChoice [y/n/q]: ").strip().lower()
if choice == 'q':
print("Exiting...")
return True
if choice == 'y':
if not await delete_model(model_name):
print("Failed to delete model. Continue anyway? [y/n]")
if input().strip().lower() != 'y':
return False
else:
print("Model not downloaded. Will perform fresh download test.")
input("Press Enter to continue...")
# Run the test
print("\n" + "=" * 70)
print("🏃 Starting Download Test")
print("=" * 70)
async def run_test():
# Start SSE monitor in background FIRST
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=600))
# Wait for SSE to connect
await asyncio.sleep(1)
# Trigger download
success = await trigger_download(model_name)
if not success:
print(" ❌ Failed to trigger download")
monitor_task.cancel()
try:
await monitor_task
except asyncio.CancelledError:
pass
return []
# Wait for SSE monitor to complete
print("\n⏳ Waiting for download to complete (this may take several minutes)...")
events = await monitor_task
return events
start_time = time.time()
events = await run_test()
elapsed = time.time() - start_time
# Results
print("\n" + "=" * 70)
print("📋 Test Results")
print("=" * 70)
print(f"\n⏱️ Elapsed time: {elapsed:.1f} seconds")
print(f"📨 Total SSE events received: {len(events)}")
if not events:
print("\n❌ FAILED - No SSE events received!")
print("\nPossible causes:")
print(" 1. SSE endpoint not working")
print(" 2. tqdm not patched before mlx_audio import")
print(" 3. Progress callbacks not firing")
print(" 4. Model already fully downloaded")
print("\nDebug steps:")
print(" 1. Check server logs for [DEBUG] messages")
print(" 2. Look for 'tqdm patched' before 'mlx_audio.tts import'")
print(f" 3. Delete model: curl -X DELETE http://localhost:8000/models/{model_name}")
return False
# Analyze events
first_event = events[0]
last_event = events[-1]
print(f"\n📊 First event:")
print(f" Status: {first_event.get('status')}")
print(f" Progress: {first_event.get('progress', 0):.1f}%")
print(f"\n📊 Last event:")
print(f" Status: {last_event.get('status')}")
print(f" Progress: {last_event.get('progress', 0):.1f}%")
# Check for expected behaviors
has_progress_updates = len(events) > 2
has_increasing_progress = False
has_complete = any(e.get('status') == 'complete' for e in events)
has_100_percent = any(e.get('progress', 0) >= 100 for e in events)
# Check if progress increased over time
if len(events) >= 2:
progress_values = [e.get('progress', 0) for e in events]
has_increasing_progress = progress_values[-1] > progress_values[0]
print("\n📋 Checks:")
print(f" {'' if has_progress_updates else ''} Multiple progress updates received ({len(events)} events)")
print(f" {'' if has_increasing_progress else ''} Progress increased over time")
print(f" {'' if has_100_percent else ''} Reached 100% progress")
print(f" {'' if has_complete else ''} Received 'complete' status")
# Overall result
success = has_progress_updates and has_complete
if success:
print("\n" + "=" * 70)
print("✅ TEST PASSED - Qwen TTS download progress tracking works!")
print("=" * 70)
else:
print("\n" + "=" * 70)
print("❌ TEST FAILED - Progress tracking has issues")
print("=" * 70)
print("\nCheck the server logs for debug output.")
return success
if __name__ == "__main__":
result = asyncio.run(main())
exit(0 if result else 1)

View File

@@ -0,0 +1,54 @@
import asyncio
import pytest
from backend.services import task_queue
@pytest.mark.asyncio
async def test_cancel_queued_generation_skips_execution():
task_queue.init_queue(force=True)
running_started = asyncio.Event()
release_running = asyncio.Event()
queued_ran = asyncio.Event()
async def running_job():
running_started.set()
await release_running.wait()
async def queued_job():
queued_ran.set()
task_queue.enqueue_generation("gen-running", running_job())
await asyncio.wait_for(running_started.wait(), timeout=1)
task_queue.enqueue_generation("gen-queued", queued_job())
assert task_queue.cancel_generation("gen-queued") == "queued"
release_running.set()
await asyncio.sleep(0.1)
assert not queued_ran.is_set()
@pytest.mark.asyncio
async def test_cancel_running_generation_cancels_task():
task_queue.init_queue(force=True)
running_started = asyncio.Event()
running_cancelled = asyncio.Event()
async def running_job():
running_started.set()
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
running_cancelled.set()
raise
task_queue.enqueue_generation("gen-running", running_job())
await asyncio.wait_for(running_started.wait(), timeout=1)
assert task_queue.cancel_generation("gen-running") == "running"
await asyncio.wait_for(running_cancelled.wait(), timeout=1)

View File

@@ -0,0 +1,178 @@
"""
Test real model download with SSE progress monitoring.
"""
import asyncio
import json
import httpx
import time
from typing import List, Dict
async def monitor_sse_stream(model_name: str, timeout: int = 300):
"""Monitor SSE stream for a model download."""
events: List[Dict] = []
url = f"http://localhost:8000/models/progress/{model_name}"
print(f"Connecting to SSE endpoint: {url}")
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", url) as response:
print(f"SSE connected, status: {response.status_code}")
if response.status_code != 200:
print(f"Error: SSE endpoint returned {response.status_code}")
return events
async for line in response.aiter_lines():
if not line:
continue
print(f" Raw SSE: {line[:100]}...") # Print first 100 chars
if line.startswith("data: "):
try:
data = json.loads(line[6:])
print(f"{data['status']:12} {data.get('progress', 0):6.1f}% {data.get('filename', '')}")
events.append(data)
# Stop if complete or error
if data.get("status") in ("complete", "error"):
print(f" Download {data['status']}!")
break
except json.JSONDecodeError as e:
print(f" Error parsing JSON: {e}")
print(f" Line was: {line}")
elif line.startswith(": heartbeat"):
print(" ♥ heartbeat")
return events
async def trigger_download(model_name: str):
"""Trigger a model download via the API."""
url = "http://localhost:8000/models/download"
print(f"\nTriggering download for: {model_name}")
async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json={"model_name": model_name})
print(f"Response: {response.status_code} - {response.json()}")
return response.status_code == 200
async def check_server():
"""Check if the server is running."""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get("http://localhost:8000/health")
return response.status_code == 200
except Exception as e:
print(f"Server not running: {e}")
return False
async def main():
print("=" * 60)
print("Real Model Download Progress Test")
print("=" * 60)
# Check if server is running
print("\nChecking if server is running...")
if not await check_server():
print("✗ Server is not running on http://localhost:8000")
print("\nPlease start the server first:")
print(" cd backend && python main.py")
return False
print("✓ Server is running")
# Choose a small model for testing
model_name = "whisper-base" # ~150MB, faster to download
print(f"\nUsing model: {model_name}")
# Option to delete model first if it exists
print("\nDo you want to delete the model first to force a fresh download? (y/n)")
# For automated testing, skip deletion prompt
# delete_first = input().strip().lower() == 'y'
delete_first = False
if delete_first:
print(f"Deleting {model_name}...")
async with httpx.AsyncClient(timeout=30) as client:
response = await client.delete(f"http://localhost:8000/models/{model_name}")
print(f"Delete response: {response.status_code}")
print("\n" + "=" * 60)
print("Starting Test")
print("=" * 60)
# Start monitoring SSE stream BEFORE triggering download
async def run_test():
# Start SSE monitor in background
monitor_task = asyncio.create_task(monitor_sse_stream(model_name))
# Wait a bit to ensure SSE is connected
await asyncio.sleep(1)
# Trigger download
success = await trigger_download(model_name)
if not success:
print("✗ Failed to trigger download")
monitor_task.cancel()
return False
# Wait for SSE monitor to complete
events = await monitor_task
return events
events = await run_test()
# Results
print("\n" + "=" * 60)
print("Test Results")
print("=" * 60)
if not events:
print("✗ FAILED - No SSE events received!")
print("\nPossible causes:")
print(" 1. SSE endpoint not working")
print(" 2. Progress updates not being sent")
print(" 3. Model already downloaded (no progress to report)")
print("\nTry deleting the model first to force a fresh download:")
print(f" curl -X DELETE http://localhost:8000/models/{model_name}")
return False
print(f"✓ Received {len(events)} SSE events")
print(f"\nFirst event: {events[0]}")
print(f"Last event: {events[-1]}")
# Check if we got meaningful progress
has_progress = any(e.get('progress', 0) > 0 for e in events)
has_complete = any(e.get('status') == 'complete' for e in events)
if has_progress:
print("✓ Progress updates received")
else:
print("✗ No progress updates (might be already downloaded)")
if has_complete:
print("✓ Download completed successfully")
else:
print("✗ Download did not complete")
success = has_progress and has_complete
if success:
print("\n✓ TEST PASSED - Progress tracking works!")
else:
print("\n⊘ TEST INCONCLUSIVE - Try with a fresh download")
return success
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1 @@
# Utils package

318
backend/utils/audio.py Normal file
View File

@@ -0,0 +1,318 @@
"""
Audio processing utilities.
"""
import numpy as np
import soundfile as sf
import librosa
from typing import Tuple, Optional
def normalize_audio(
audio: np.ndarray,
target_db: float = -20.0,
peak_limit: float = 0.85,
) -> np.ndarray:
"""
Normalize audio to target loudness with peak limiting.
Args:
audio: Input audio array
target_db: Target RMS level in dB
peak_limit: Peak limit (0.0-1.0)
Returns:
Normalized audio array
"""
# Convert to float32
audio = audio.astype(np.float32)
# Calculate current RMS
rms = np.sqrt(np.mean(audio**2))
# Calculate target RMS
target_rms = 10**(target_db / 20)
# Apply gain
if rms > 0:
gain = target_rms / rms
audio = audio * gain
# Peak limiting
audio = np.clip(audio, -peak_limit, peak_limit)
return audio
def load_audio(
path: str,
sample_rate: int = 24000,
mono: bool = True,
) -> Tuple[np.ndarray, int]:
"""
Load audio file with normalization.
Args:
path: Path to audio file
sample_rate: Target sample rate
mono: Convert to mono
Returns:
Tuple of (audio_array, sample_rate)
"""
audio, sr = librosa.load(path, sr=sample_rate, mono=mono)
return audio, sr
def save_audio(
audio: np.ndarray,
path: str,
sample_rate: int = 24000,
) -> None:
"""
Save audio file with atomic write and error handling.
Writes to a temporary file first, then atomically renames to the
target path. This prevents corrupted/partial WAV files if the
process is interrupted mid-write.
Args:
audio: Audio array
path: Output path
sample_rate: Sample rate
Raises:
OSError: If file cannot be written
"""
from pathlib import Path
import os
temp_path = f"{path}.tmp"
try:
# Ensure parent directory exists
Path(path).parent.mkdir(parents=True, exist_ok=True)
# Write to temporary file first (explicit format since .tmp
# extension is not recognised by soundfile)
sf.write(temp_path, audio, sample_rate, format='WAV')
# Atomic rename to final path
os.replace(temp_path, path)
except Exception as e:
# Clean up temp file on failure
try:
if Path(temp_path).exists():
Path(temp_path).unlink()
except Exception:
pass # Best effort cleanup
raise OSError(f"Failed to save audio to {path}: {e}") from e
def trim_tts_output(
audio: np.ndarray,
sample_rate: int = 24000,
frame_ms: int = 20,
silence_threshold_db: float = -40.0,
min_silence_ms: int = 200,
max_internal_silence_ms: int = 1000,
fade_ms: int = 30,
) -> np.ndarray:
"""
Trim trailing silence and post-silence hallucination from TTS output.
Chatterbox sometimes produces ``[speech][silence][hallucinated noise]``.
This detects internal silence gaps longer than *max_internal_silence_ms*
and cuts the audio at that boundary, then trims trailing silence and
applies a short cosine fade-out.
Args:
audio: Input audio array (mono float32)
sample_rate: Sample rate in Hz
frame_ms: Frame size for RMS energy calculation
silence_threshold_db: dB threshold below which a frame is silence
min_silence_ms: Minimum trailing silence to keep
max_internal_silence_ms: Cut after any silence gap longer than this
fade_ms: Cosine fade-out duration in ms
Returns:
Trimmed audio array
"""
frame_len = int(sample_rate * frame_ms / 1000)
if frame_len == 0 or len(audio) < frame_len:
return audio
n_frames = len(audio) // frame_len
threshold_linear = 10 ** (silence_threshold_db / 20)
# Compute per-frame RMS
rms = np.array(
[
np.sqrt(np.mean(audio[i * frame_len : (i + 1) * frame_len] ** 2))
for i in range(n_frames)
]
)
is_speech = rms >= threshold_linear
# Find first speech frame
first_speech = 0
for i, s in enumerate(is_speech):
if s:
first_speech = max(0, i - 1) # keep 1 frame padding
break
# Walk forward from first speech; cut at long internal silence gaps
max_silence_frames = int(max_internal_silence_ms / frame_ms)
consecutive_silence = 0
cut_frame = n_frames
for i in range(first_speech, n_frames):
if is_speech[i]:
consecutive_silence = 0
else:
consecutive_silence += 1
if consecutive_silence >= max_silence_frames:
cut_frame = i - consecutive_silence + 1
break
# Trim trailing silence from the cut point
min_silence_frames = int(min_silence_ms / frame_ms)
end_frame = cut_frame
while end_frame > first_speech and not is_speech[end_frame - 1]:
end_frame -= 1
# Keep a short tail
end_frame = min(end_frame + min_silence_frames, cut_frame)
# Convert frames back to samples
start_sample = first_speech * frame_len
end_sample = min(end_frame * frame_len, len(audio))
trimmed = audio[start_sample:end_sample].copy()
# Cosine fade-out
fade_samples = int(sample_rate * fade_ms / 1000)
if fade_samples > 0 and len(trimmed) > fade_samples:
fade = np.cos(np.linspace(0, np.pi / 2, fade_samples)) ** 2
trimmed[-fade_samples:] *= fade
return trimmed
def preprocess_reference_audio(
audio: np.ndarray,
sample_rate: int,
peak_target: float = 0.95,
trim_top_db: float = 40.0,
edge_padding_ms: int = 100,
) -> np.ndarray:
"""
Clean up a reference-audio sample before validation/storage.
Removes DC offset, trims leading/trailing silence, and caps the peak so a
slightly-hot recording doesn't get rejected downstream as "clipping". The
goal is to accept reasonable real-world recordings — not to repair badly
distorted ones. True clipping artifacts inside the waveform can't be
recovered by peak scaling and will still sound bad.
Args:
audio: Mono audio array.
sample_rate: Sample rate of ``audio`` in Hz.
peak_target: Peak amplitude cap in [0, 1]. Applied only if the input
peak exceeds this value.
trim_top_db: Silence threshold for edge trimming, in dB below peak.
40 dB sits below normal speech dynamic range (≈30 dB) so soft
trailing syllables are preserved, while still catching obvious
leading/trailing silence. Lower values are more aggressive;
librosa's own default is 60.
edge_padding_ms: Milliseconds of padding to add back at each edge
*only if* trimming shortened the waveform, so TTS engines have a
brief silence to anchor on without ever making the output longer
than the input.
Returns:
Preprocessed audio array (float32).
"""
audio = audio.astype(np.float32, copy=False)
if audio.size == 0:
return audio
audio = audio - float(np.mean(audio))
trimmed, _ = librosa.effects.trim(audio, top_db=trim_top_db)
if 0 < trimmed.size < audio.size:
pad_each = int(sample_rate * edge_padding_ms / 1000)
# Never pad past the original length — for near-max-duration uploads
# an unconditional pad would push them over the 30 s ceiling and
# trigger a spurious "too long" rejection.
headroom = (audio.size - trimmed.size) // 2
pad = min(pad_each, max(headroom, 0))
if pad > 0:
trimmed = np.pad(trimmed, (pad, pad), mode="constant")
audio = trimmed
peak = float(np.abs(audio).max())
if peak > peak_target and peak > 0:
audio = audio * (peak_target / peak)
return audio
def validate_reference_audio(
audio_path: str,
min_duration: float = 2.0,
max_duration: float = 30.0,
min_rms: float = 0.01,
) -> Tuple[bool, Optional[str]]:
"""
Validate reference audio for voice cloning.
Args:
audio_path: Path to audio file
min_duration: Minimum duration in seconds
max_duration: Maximum duration in seconds
min_rms: Minimum RMS level
Returns:
Tuple of (is_valid, error_message)
"""
result = validate_and_load_reference_audio(
audio_path, min_duration, max_duration, min_rms
)
return (result[0], result[1])
def validate_and_load_reference_audio(
audio_path: str,
min_duration: float = 2.0,
max_duration: float = 30.0,
min_rms: float = 0.01,
) -> Tuple[bool, Optional[str], Optional[np.ndarray], Optional[int]]:
"""
Validate and load reference audio in a single pass.
Applies :func:`preprocess_reference_audio` before checks so that
slightly-hot recordings aren't rejected as clipping. Duration and RMS
checks run on the preprocessed waveform.
Returns:
Tuple of (is_valid, error_message, audio_array, sample_rate)
"""
try:
audio, sr = load_audio(audio_path)
audio = preprocess_reference_audio(audio, sr)
duration = len(audio) / sr
if duration < min_duration:
return False, f"Audio too short (minimum {min_duration} seconds)", None, None
if duration > max_duration:
return False, f"Audio too long (maximum {max_duration} seconds)", None, None
rms = np.sqrt(np.mean(audio**2))
if rms < min_rms:
return False, "Audio is too quiet or silent", None, None
return True, None, audio, sr
except Exception as e:
return False, f"Error validating audio: {str(e)}", None, None

153
backend/utils/cache.py Normal file
View File

@@ -0,0 +1,153 @@
"""
Voice prompt caching utilities.
"""
import hashlib
import logging
import torch
from pathlib import Path
from typing import Optional, Union, Dict, Any
from .. import config
logger = logging.getLogger(__name__)
def _get_cache_dir() -> Path:
"""Get cache directory from config."""
return config.get_cache_dir()
# In-memory cache - can store dict (voice prompt) or tensor (legacy)
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}
def get_cache_key(audio_path: str, reference_text: str) -> str:
"""
Generate cache key from audio file and reference text.
Args:
audio_path: Path to audio file
reference_text: Reference text
Returns:
Cache key (MD5 hash)
"""
# Read audio file
with open(audio_path, "rb") as f:
audio_bytes = f.read()
# Combine audio bytes and text
combined = audio_bytes + reference_text.encode("utf-8")
# Generate hash
return hashlib.md5(combined).hexdigest()
def get_cached_voice_prompt(
cache_key: str,
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
"""
Get cached voice prompt if available.
Args:
cache_key: Cache key
Returns:
Cached voice prompt (dict or tensor) or None
"""
# Check in-memory cache
if cache_key in _memory_cache:
return _memory_cache[cache_key]
# Check disk cache
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
if cache_file.exists():
try:
prompt = torch.load(cache_file, weights_only=True)
_memory_cache[cache_key] = prompt
return prompt
except Exception:
# Cache file corrupted, delete it
cache_file.unlink()
return None
def cache_voice_prompt(
cache_key: str,
voice_prompt: Union[torch.Tensor, Dict[str, Any]],
) -> None:
"""
Cache voice prompt to memory and disk.
Args:
cache_key: Cache key
voice_prompt: Voice prompt (dict or tensor)
"""
# Store in memory
_memory_cache[cache_key] = voice_prompt
# Store on disk (torch.save can handle both dicts and tensors)
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
torch.save(voice_prompt, cache_file)
def clear_voice_prompt_cache() -> int:
"""
Clear all voice prompt caches (memory and disk).
Returns:
Number of cache files deleted
"""
# Clear memory cache
_memory_cache.clear()
# Clear disk cache
cache_dir = _get_cache_dir()
deleted_count = 0
if cache_dir.exists():
# Delete prompt cache files
for cache_file in cache_dir.glob("*.prompt"):
try:
cache_file.unlink()
deleted_count += 1
except Exception as e:
logger.warning("Failed to delete cache file %s: %s", cache_file, e)
# Delete combined audio files
for audio_file in cache_dir.glob("combined_*.wav"):
try:
audio_file.unlink()
deleted_count += 1
except Exception as e:
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
return deleted_count
def clear_profile_cache(profile_id: str) -> int:
"""
Clear cache files for a specific profile.
Args:
profile_id: Profile ID
Returns:
Number of cache files deleted
"""
cache_dir = _get_cache_dir()
deleted_count = 0
if cache_dir.exists():
# Delete combined audio files for this profile
pattern = f"combined_{profile_id}_*.wav"
for audio_file in cache_dir.glob(pattern):
try:
audio_file.unlink()
deleted_count += 1
except Exception as e:
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
return deleted_count
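
A short sketch of the intended call pattern. The key ties the prompt to the exact audio bytes plus transcript, so editing either invalidates the cache naturally; `backend.create_voice_prompt` is a hypothetical stand-in for whatever the engine exposes:

```python
def get_or_build_voice_prompt(audio_path: str, reference_text: str, backend):
    """Return a cached voice prompt, building and caching it on a miss."""
    key = get_cache_key(audio_path, reference_text)
    prompt = get_cached_voice_prompt(key)
    if prompt is None:
        # Hypothetical engine call — the real one lives in backends/.
        prompt = backend.create_voice_prompt(audio_path, reference_text)
        cache_voice_prompt(key, prompt)
    return prompt
```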

View File

@@ -0,0 +1,299 @@
"""
Chunked TTS generation utilities.
Splits long text into sentence-boundary chunks, generates audio per-chunk
via any TTSBackend, and concatenates with crossfade. All logic is
engine-agnostic — it wraps the standard ``TTSBackend.generate()`` interface.
Short text (≤ max_chunk_chars) uses the single-shot fast path with zero
overhead.
"""
import logging
import re
from typing import List, Tuple
import numpy as np
logger = logging.getLogger("voicebox.chunked-tts")
# Default chunk size in characters. Can be overridden per-request via
# the ``max_chunk_chars`` field on GenerationRequest.
DEFAULT_MAX_CHUNK_CHARS = 800
# Common abbreviations that should NOT be treated as sentence endings.
# Lowercase for case-insensitive matching.
_ABBREVIATIONS = frozenset(
{
"mr",
"mrs",
"ms",
"dr",
"prof",
"sr",
"jr",
"st",
"ave",
"blvd",
"inc",
"ltd",
"corp",
"dept",
"est",
"approx",
"vs",
"etc",
"e.g",
"i.e",
"a.m",
"p.m",
"u.s",
"u.s.a",
"u.k",
}
)
# Paralinguistic tags used by Chatterbox Turbo. The splitter must never
# cut inside one of these.
_PARA_TAG_RE = re.compile(r"\[[^\]]*\]")
def split_text_into_chunks(text: str, max_chars: int = DEFAULT_MAX_CHUNK_CHARS) -> List[str]:
"""Split *text* at natural boundaries into chunks of at most *max_chars*.
Priority: sentence-end (``.!?`` not preceded by an abbreviation and not
inside brackets) → clause boundary (``;:,—``) → whitespace → hard cut.
Paralinguistic tags like ``[laugh]`` are treated as atomic and will not
be split across chunks.
"""
text = text.strip()
if not text:
return []
if len(text) <= max_chars:
return [text]
chunks: List[str] = []
remaining = text
while remaining:
remaining = remaining.lstrip()
if not remaining:
break
if len(remaining) <= max_chars:
chunks.append(remaining)
break
segment = remaining[:max_chars]
# Try to split at the last real sentence ending
split_pos = _find_last_sentence_end(segment)
if split_pos == -1:
split_pos = _find_last_clause_boundary(segment)
if split_pos == -1:
split_pos = segment.rfind(" ")
if split_pos == -1:
# Absolute fallback: hard cut but avoid splitting inside a tag
split_pos = _safe_hard_cut(segment, max_chars)
chunk = remaining[: split_pos + 1].strip()
if chunk:
chunks.append(chunk)
remaining = remaining[split_pos + 1 :]
return chunks
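
To make the priority order concrete, a small sketch of the splitter on made-up text containing an abbreviation and a paralinguistic tag:

```python
text = (
    "Dr. Smith reviewed the results with the team this morning. "
    "[laugh] Honestly, the numbers looked better than expected; "
    "everyone agreed to ship the release the same afternoon."
)
for i, chunk in enumerate(split_text_into_chunks(text, max_chars=80)):
    print(i, repr(chunk))
# The period after "Dr." is not treated as a sentence end, and the
# "[laugh]" tag is never split across two chunks.
```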
def _find_last_sentence_end(text: str) -> int:
"""Return the index of the last sentence-ending punctuation in *text*.
Skips periods that follow common abbreviations (``Dr.``, ``Mr.``, etc.)
and periods inside bracket tags (``[laugh]``). Also handles CJK
sentence-ending punctuation (``。!?``).
"""
best = -1
# ASCII sentence ends
for m in re.finditer(r"[.!?](?:\s|$)", text):
pos = m.start()
char = text[pos]
# Skip periods after abbreviations
if char == ".":
# Walk backwards to find the preceding word. Include internal periods
# so multi-part abbreviations ("e.g.", "u.s.a.") match their entries in
# _ABBREVIATIONS instead of just their final letter.
word_start = pos - 1
while word_start >= 0 and (text[word_start].isalpha() or text[word_start] == "."):
    word_start -= 1
word = text[word_start + 1 : pos].lower().strip(".")
if word in _ABBREVIATIONS:
continue
# Skip decimal numbers (digit immediately before the period)
if word_start >= 0 and text[word_start].isdigit():
continue
# Skip if we're inside a bracket tag
if _inside_bracket_tag(text, pos):
continue
best = pos
# CJK sentence-ending punctuation
for m in re.finditer(r"[\u3002\uff01\uff1f]", text):
if m.start() > best:
best = m.start()
return best
def _find_last_clause_boundary(text: str) -> int:
"""Return the index of the last clause-boundary punctuation."""
best = -1
for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
pos = m.start()
# Skip if inside a bracket tag
if _inside_bracket_tag(text, pos):
continue
best = pos
return best
def _inside_bracket_tag(text: str, pos: int) -> bool:
"""Return True if *pos* falls inside a ``[...]`` tag."""
for m in _PARA_TAG_RE.finditer(text):
if m.start() < pos < m.end():
return True
return False
def _safe_hard_cut(segment: str, max_chars: int) -> int:
"""Find a hard-cut position that doesn't split a ``[tag]``."""
cut = max_chars - 1
# Check if the cut falls inside a bracket tag; if so, move before it
for m in _PARA_TAG_RE.finditer(segment):
if m.start() < cut < m.end():
return m.start() - 1 if m.start() > 0 else cut
return cut
def concatenate_audio_chunks(
chunks: List[np.ndarray],
sample_rate: int,
crossfade_ms: int = 50,
) -> np.ndarray:
"""Concatenate audio arrays with a short crossfade to eliminate clicks.
Each chunk is expected to be a 1-D float32 ndarray at *sample_rate* Hz.
"""
if not chunks:
return np.array([], dtype=np.float32)
if len(chunks) == 1:
return chunks[0]
crossfade_samples = int(sample_rate * crossfade_ms / 1000)
result = np.array(chunks[0], dtype=np.float32, copy=True)
for chunk in chunks[1:]:
if len(chunk) == 0:
continue
overlap = min(crossfade_samples, len(result), len(chunk))
if overlap > 0:
fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
result[-overlap:] = result[-overlap:] * fade_out + chunk[:overlap] * fade_in
result = np.concatenate([result, chunk[overlap:]])
else:
result = np.concatenate([result, chunk])
return result
async def generate_chunked(
backend,
text: str,
voice_prompt: dict,
language: str = "en",
seed: int | None = None,
instruct: str | None = None,
max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
crossfade_ms: int = 50,
trim_fn=None,
) -> Tuple[np.ndarray, int]:
"""Generate audio with automatic chunking for long text.
For text shorter than *max_chunk_chars* this is a thin wrapper around
``backend.generate()`` with zero overhead.
For longer text the input is split at natural sentence boundaries,
each chunk is generated independently, optionally trimmed (useful for
Chatterbox engines that hallucinate trailing noise), and the results
are concatenated with a crossfade (or hard cut if *crossfade_ms* is 0).
Parameters
----------
backend : TTSBackend
Any backend implementing the ``generate()`` protocol.
text : str
Input text (may be arbitrarily long).
voice_prompt, language, seed, instruct
Forwarded to ``backend.generate()`` verbatim.
max_chunk_chars : int
Maximum characters per chunk (default 800).
crossfade_ms : int
Crossfade duration in milliseconds between chunks. 0 for a hard
cut with no overlap (default 50).
trim_fn : callable | None
Optional ``(audio, sample_rate) -> audio`` post-processing
function applied to each chunk before concatenation (e.g.
``trim_tts_output`` for Chatterbox engines).
Returns
-------
(audio, sample_rate) : Tuple[np.ndarray, int]
"""
chunks = split_text_into_chunks(text, max_chunk_chars)
if len(chunks) <= 1:
# Short text — single-shot fast path
audio, sample_rate = await backend.generate(
text,
voice_prompt,
language,
seed,
instruct,
)
if trim_fn is not None:
audio = trim_fn(audio, sample_rate)
return audio, sample_rate
# Long text — chunked generation
logger.info(
"Splitting %d chars into %d chunks (max %d chars each)",
len(text),
len(chunks),
max_chunk_chars,
)
audio_chunks: List[np.ndarray] = []
sample_rate: int | None = None
for i, chunk_text in enumerate(chunks):
logger.info(
"Generating chunk %d/%d (%d chars)",
i + 1,
len(chunks),
len(chunk_text),
)
# Vary the seed per chunk to avoid correlated RNG artefacts,
# but keep it deterministic so the same (text, seed) pair
# always produces the same output.
chunk_seed = (seed + i) if seed is not None else None
chunk_audio, chunk_sr = await backend.generate(
chunk_text,
voice_prompt,
language,
chunk_seed,
instruct,
)
if trim_fn is not None:
chunk_audio = trim_fn(chunk_audio, chunk_sr)
audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32))
if sample_rate is None:
sample_rate = chunk_sr
audio = concatenate_audio_chunks(audio_chunks, sample_rate, crossfade_ms=crossfade_ms)
return audio, sample_rate
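
A hedged end-to-end sketch. The backend object, voice prompt dict, and output path are assumptions — in the real flow they come from the generation service:

```python
import asyncio
import soundfile as sf

async def synthesize_long_form(backend, voice_prompt: dict, text: str) -> str:
    audio, sr = await generate_chunked(
        backend,
        text,
        voice_prompt,
        language="en",
        seed=1234,
        max_chunk_chars=800,
        crossfade_ms=50,
        trim_fn=None,  # pass trim_tts_output here for Chatterbox engines
    )
    out_path = "long_form.wav"  # placeholder path
    sf.write(out_path, audio, sr)
    return out_path

# asyncio.run(synthesize_long_form(backend, voice_prompt, very_long_text))
```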

95
backend/utils/dac_shim.py Normal file
View File

@@ -0,0 +1,95 @@
"""
Minimal shim for descript-audio-codec (DAC).
TADA only imports Snake1d from dac.nn.layers and dac.model.dac.
The real DAC package pulls in descript-audiotools which depends on
onnx, tensorboard, protobuf, matplotlib, pystoi, etc. — none of
which are needed for TADA's runtime use of Snake1d.
This shim provides the exact Snake1d implementation (MIT-licensed,
from https://github.com/descriptinc/descript-audio-codec) so we can
avoid the entire audiotools dependency chain.
If the real DAC package is installed, this module is never used —
Python's import system will find the site-packages version first.
Install this shim only when descript-audio-codec is NOT installed.
"""
import sys
import types
import torch
import torch.nn as nn
# ── Snake activation (from dac/nn/layers.py) ────────────────────────
# NOTE: The original DAC code uses @torch.jit.script here for a 1.4x
# speedup. We omit it because TorchScript calls inspect.getsource()
# which fails inside a PyInstaller frozen binary (no .py source files).
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return snake(x, self.alpha)
# ── Register as dac.nn.layers and dac.model.dac ─────────────────────
def install_dac_shim() -> None:
"""Register fake dac package modules in sys.modules.
Only installs the shim if 'dac' is not already importable
(i.e. the real descript-audio-codec is not installed).
"""
try:
import dac # noqa: F401 — real package exists, do nothing
return
except ImportError:
pass
# Create the module tree: dac -> dac.nn -> dac.nn.layers
# -> dac.model -> dac.model.dac
dac_pkg = types.ModuleType("dac")
dac_pkg.__path__ = [] # make it a package
dac_pkg.__package__ = "dac"
dac_nn = types.ModuleType("dac.nn")
dac_nn.__path__ = []
dac_nn.__package__ = "dac.nn"
dac_nn_layers = types.ModuleType("dac.nn.layers")
dac_nn_layers.__package__ = "dac.nn"
dac_nn_layers.Snake1d = Snake1d
dac_nn_layers.snake = snake
dac_model = types.ModuleType("dac.model")
dac_model.__path__ = []
dac_model.__package__ = "dac.model"
dac_model_dac = types.ModuleType("dac.model.dac")
dac_model_dac.__package__ = "dac.model"
dac_model_dac.Snake1d = Snake1d
# Wire up submodules
dac_pkg.nn = dac_nn
dac_pkg.model = dac_model
dac_nn.layers = dac_nn_layers
dac_model.dac = dac_model_dac
# Register in sys.modules
sys.modules["dac"] = dac_pkg
sys.modules["dac.nn"] = dac_nn
sys.modules["dac.nn.layers"] = dac_nn_layers
sys.modules["dac.model"] = dac_model
sys.modules["dac.model.dac"] = dac_model_dac

373
backend/utils/effects.py Normal file
View File

@@ -0,0 +1,373 @@
"""
Audio post-processing effects engine.
Uses Spotify's pedalboard library to apply professional-grade DSP effects
to generated audio. Effects are described as a JSON-serializable chain
(list of effect dicts) so they can be stored in the database and sent
over the API.
Supported effect types:
- chorus (flanger-style with short delays)
- reverb (room reverb)
- delay (echo / delay line)
- compressor (dynamic range compression)
- gain (volume adjustment in dB)
- highpass (high-pass filter)
- lowpass (low-pass filter)
- pitch_shift (semitone pitch shifting)
"""
from __future__ import annotations
import numpy as np
from typing import Any, Dict, List, Optional
from pedalboard import (
Pedalboard,
Chorus,
Reverb,
Compressor,
Gain,
HighpassFilter,
LowpassFilter,
Delay,
PitchShift,
)
# Each param definition: (default, min, max, description)
EFFECT_REGISTRY: Dict[str, Dict[str, Any]] = {
"chorus": {
"cls": Chorus,
"label": "Chorus / Flanger",
"description": "Modulated delay for flanging or chorus effects. Short centre_delay_ms (<10) gives flanger; longer gives chorus.",
"params": {
"rate_hz": {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01, "description": "LFO speed (Hz)"},
"depth": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Modulation depth"},
"feedback": {"default": 0.0, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
"centre_delay_ms": {
"default": 7.0,
"min": 0.5,
"max": 50.0,
"step": 0.1,
"description": "Centre delay (ms)",
},
"mix": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
},
},
"reverb": {
"cls": Reverb,
"label": "Reverb",
"description": "Room reverb effect.",
"params": {
"room_size": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Room size"},
"damping": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "High frequency damping"},
"wet_level": {"default": 0.33, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet level"},
"dry_level": {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Dry level"},
"width": {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Stereo width"},
},
},
"delay": {
"cls": Delay,
"label": "Delay",
"description": "Echo / delay line.",
"params": {
"delay_seconds": {
"default": 0.3,
"min": 0.01,
"max": 2.0,
"step": 0.01,
"description": "Delay time (seconds)",
},
"feedback": {"default": 0.3, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
"mix": {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
},
},
"compressor": {
"cls": Compressor,
"label": "Compressor",
"description": "Dynamic range compression for consistent loudness.",
"params": {
"threshold_db": {"default": -20.0, "min": -60.0, "max": 0.0, "step": 0.5, "description": "Threshold (dB)"},
"ratio": {"default": 4.0, "min": 1.0, "max": 20.0, "step": 0.1, "description": "Compression ratio"},
"attack_ms": {"default": 10.0, "min": 0.1, "max": 100.0, "step": 0.1, "description": "Attack time (ms)"},
"release_ms": {
"default": 100.0,
"min": 10.0,
"max": 1000.0,
"step": 1.0,
"description": "Release time (ms)",
},
},
},
"gain": {
"cls": Gain,
"label": "Gain",
"description": "Volume adjustment in decibels.",
"params": {
"gain_db": {"default": 0.0, "min": -40.0, "max": 40.0, "step": 0.5, "description": "Gain (dB)"},
},
},
"highpass": {
"cls": HighpassFilter,
"label": "High-Pass Filter",
"description": "Removes frequencies below the cutoff.",
"params": {
"cutoff_frequency_hz": {
"default": 80.0,
"min": 20.0,
"max": 8000.0,
"step": 1.0,
"description": "Cutoff frequency (Hz)",
},
},
},
"lowpass": {
"cls": LowpassFilter,
"label": "Low-Pass Filter",
"description": "Removes frequencies above the cutoff.",
"params": {
"cutoff_frequency_hz": {
"default": 8000.0,
"min": 200.0,
"max": 20000.0,
"step": 1.0,
"description": "Cutoff frequency (Hz)",
},
},
},
"pitch_shift": {
"cls": PitchShift,
"label": "Pitch Shift",
"description": "Shift pitch up or down by semitones.",
"params": {
"semitones": {"default": 0.0, "min": -12.0, "max": 12.0, "step": 0.5, "description": "Semitones to shift"},
},
},
}
BUILTIN_PRESETS: Dict[str, Dict[str, Any]] = {
"robotic": {
"name": "Robotic",
"sort_order": 0,
"description": "Metallic robotic voice (flanger with slow LFO and high feedback)",
"effects_chain": [
{
"type": "chorus",
"enabled": True,
"params": {
"rate_hz": 0.2,
"depth": 1.0,
"feedback": 0.35,
"centre_delay_ms": 7.0,
"mix": 0.5,
},
},
],
},
"radio": {
"name": "Radio",
"sort_order": 1,
"description": "Thin AM-radio voice with band-pass filtering and light compression",
"effects_chain": [
{
"type": "highpass",
"enabled": True,
"params": {"cutoff_frequency_hz": 300.0},
},
{
"type": "lowpass",
"enabled": True,
"params": {"cutoff_frequency_hz": 3500.0},
},
{
"type": "compressor",
"enabled": True,
"params": {
"threshold_db": -15.0,
"ratio": 6.0,
"attack_ms": 5.0,
"release_ms": 50.0,
},
},
{
"type": "gain",
"enabled": True,
"params": {"gain_db": 6.0},
},
],
},
"echo_chamber": {
"name": "Echo Chamber",
"sort_order": 2,
"description": "Spacious reverb with trailing echo",
"effects_chain": [
{
"type": "reverb",
"enabled": True,
"params": {
"room_size": 0.85,
"damping": 0.3,
"wet_level": 0.45,
"dry_level": 0.55,
"width": 1.0,
},
},
{
"type": "delay",
"enabled": True,
"params": {
"delay_seconds": 0.25,
"feedback": 0.3,
"mix": 0.2,
},
},
],
},
"deep_voice": {
"name": "Deep Voice",
"sort_order": 99,
"description": "Lower pitch with added warmth",
"effects_chain": [
{
"type": "pitch_shift",
"enabled": True,
"params": {"semitones": -3.0},
},
{
"type": "lowpass",
"enabled": True,
"params": {"cutoff_frequency_hz": 6000.0},
},
{
"type": "compressor",
"enabled": True,
"params": {
"threshold_db": -18.0,
"ratio": 3.0,
"attack_ms": 10.0,
"release_ms": 150.0,
},
},
],
},
}
def get_available_effects() -> List[Dict[str, Any]]:
"""Return the list of available effect types with their parameter definitions.
Used by the frontend to build the effects chain editor UI.
"""
result = []
for effect_type, info in EFFECT_REGISTRY.items():
result.append(
{
"type": effect_type,
"label": info["label"],
"description": info["description"],
"params": {name: {k: v for k, v in pdef.items()} for name, pdef in info["params"].items()},
}
)
return result
def get_builtin_presets() -> Dict[str, Dict[str, Any]]:
"""Return all built-in effect presets."""
return BUILTIN_PRESETS
def validate_effects_chain(effects_chain: List[Dict[str, Any]]) -> Optional[str]:
"""Validate an effects chain configuration.
Returns None if valid, or an error message string.
"""
if not isinstance(effects_chain, list):
return "effects_chain must be a list"
for i, effect in enumerate(effects_chain):
if not isinstance(effect, dict):
return f"Effect at index {i} must be a dict"
effect_type = effect.get("type")
if effect_type not in EFFECT_REGISTRY:
return f"Unknown effect type '{effect_type}' at index {i}. Available: {list(EFFECT_REGISTRY.keys())}"
params = effect.get("params", {})
if not isinstance(params, dict):
return f"Effect '{effect_type}' at index {i}: params must be a dict"
registry = EFFECT_REGISTRY[effect_type]
for param_name, value in params.items():
if param_name not in registry["params"]:
return f"Effect '{effect_type}' at index {i}: unknown param '{param_name}'"
pdef = registry["params"][param_name]
if not isinstance(value, (int, float)):
return f"Effect '{effect_type}' at index {i}: param '{param_name}' must be a number"
if value < pdef["min"] or value > pdef["max"]:
return (
f"Effect '{effect_type}' at index {i}: param '{param_name}' "
f"must be between {pdef['min']} and {pdef['max']} (got {value})"
)
return None
def build_pedalboard(effects_chain: List[Dict[str, Any]]) -> Pedalboard:
"""Build a Pedalboard instance from an effects chain config.
Skips effects where ``enabled`` is ``False``.
"""
plugins = []
for effect in effects_chain:
if not effect.get("enabled", True):
continue
effect_type = effect["type"]
registry = EFFECT_REGISTRY[effect_type]
cls = registry["cls"]
# Merge defaults with provided params
params = {}
for pname, pdef in registry["params"].items():
params[pname] = effect.get("params", {}).get(pname, pdef["default"])
plugins.append(cls(**params))
return Pedalboard(plugins)
def apply_effects(
audio: np.ndarray,
sample_rate: int,
effects_chain: List[Dict[str, Any]],
) -> np.ndarray:
"""Apply an effects chain to audio data.
Args:
audio: Input audio array (1-D mono float32).
sample_rate: Sample rate in Hz.
effects_chain: List of effect configuration dicts.
Returns:
Processed audio array.
"""
if not effects_chain:
return audio
board = build_pedalboard(effects_chain)
# pedalboard expects shape (channels, samples)
if audio.ndim == 1:
audio_2d = audio[np.newaxis, :]
else:
audio_2d = audio
processed = board(audio_2d.astype(np.float32), sample_rate)
# Return same dimensionality as input
if audio.ndim == 1:
return processed[0]
return processed
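
A small sketch of the round trip — validate a chain, then apply it to a mono buffer. The chain mirrors a trimmed-down version of the built-in "radio" preset and the test tone is arbitrary:

```python
import numpy as np

chain = [
    {"type": "highpass", "enabled": True, "params": {"cutoff_frequency_hz": 300.0}},
    {"type": "lowpass", "enabled": True, "params": {"cutoff_frequency_hz": 3500.0}},
    {"type": "gain", "enabled": False, "params": {"gain_db": 6.0}},  # disabled -> skipped
]

error = validate_effects_chain(chain)
if error is not None:
    raise ValueError(error)

sr = 24000
dry = (0.2 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr)).astype(np.float32)
wet = apply_effects(dry, sr, chain)  # 1-D in, 1-D out
```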

View File

@@ -0,0 +1,270 @@
"""Monkey-patch huggingface_hub to force offline mode with cached models.
Prevents mlx_audio / transformers from making network requests when models
are already downloaded. Must be imported BEFORE mlx_audio.
"""
import logging
import os
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Union
logger = logging.getLogger(__name__)
# huggingface_hub reads ``HF_HUB_OFFLINE`` once at import time into
# ``huggingface_hub.constants.HF_HUB_OFFLINE``; transformers mirrors that into
# ``transformers.utils.hub._is_offline_mode`` at *its* import time. Toggling
# ``os.environ`` after either module is imported does not flip those cached
# bools, and the hot paths (``_http._default_backend_factory``,
# ``transformers.utils.hub.is_offline_mode``) read the bools — not the env.
# We mutate the cached constants directly, guarded by a refcount so
# concurrent inference threads share a single offline window safely.
_offline_lock = threading.RLock()
_offline_refcount = 0
_saved_env: Optional[str] = None
_saved_hf_const: Optional[bool] = None
_saved_transformers_const: Optional[bool] = None
@contextmanager
def force_offline_if_cached(is_cached: bool, model_label: str = ""):
"""Force offline mode for the duration of a cached-model operation.
Flips ``HF_HUB_OFFLINE`` in the process env **and** in the cached bools
inside ``huggingface_hub.constants`` / ``transformers.utils.hub`` so HTTP
adapters and offline-mode checks actually see the change. Uses a refcount
so multiple concurrent inference threads share a single offline window
and the last one to exit restores state.
If *is_cached* is ``False`` the block runs normally (network allowed).
Args:
is_cached: Whether the model weights are already on disk.
model_label: Human-readable name used in log messages.
"""
if not is_cached:
yield
return
global _offline_refcount, _saved_env, _saved_hf_const, _saved_transformers_const
with _offline_lock:
if _offline_refcount == 0:
# Snapshot prior state, apply new state, roll back on *any*
# failure. Catching only ImportError here would let a partially
# broken install (RuntimeError, AttributeError from a half-init
# module, etc.) leave the cached HF constants mutated without
# bumping the refcount — a persistent offline leak that lasts for the
# rest of the process and is miserable to debug.
prev_env = os.environ.get("HF_HUB_OFFLINE")
prev_hf: Optional[bool] = None
prev_tf: Optional[bool] = None
try:
try:
import huggingface_hub.constants as hf_const
prev_hf = hf_const.HF_HUB_OFFLINE
hf_const.HF_HUB_OFFLINE = True
except ImportError:
prev_hf = None
try:
import transformers.utils.hub as tf_hub
prev_tf = tf_hub._is_offline_mode
tf_hub._is_offline_mode = True
except ImportError:
prev_tf = None
os.environ["HF_HUB_OFFLINE"] = "1"
except BaseException:
# Roll back whatever we already changed, then re-raise so
# the caller sees the real failure.
if prev_hf is not None:
try:
import huggingface_hub.constants as hf_const
hf_const.HF_HUB_OFFLINE = prev_hf
except ImportError:
pass
if prev_tf is not None:
try:
import transformers.utils.hub as tf_hub
tf_hub._is_offline_mode = prev_tf
except ImportError:
pass
if prev_env is not None:
os.environ["HF_HUB_OFFLINE"] = prev_env
else:
os.environ.pop("HF_HUB_OFFLINE", None)
raise
_saved_env = prev_env
_saved_hf_const = prev_hf
_saved_transformers_const = prev_tf
logger.info(
"[offline-guard] %s is cached — forcing offline mode",
model_label or "model",
)
_offline_refcount += 1
try:
yield
finally:
with _offline_lock:
_offline_refcount -= 1
if _offline_refcount == 0:
if _saved_env is not None:
os.environ["HF_HUB_OFFLINE"] = _saved_env
else:
os.environ.pop("HF_HUB_OFFLINE", None)
if _saved_hf_const is not None:
try:
import huggingface_hub.constants as hf_const
hf_const.HF_HUB_OFFLINE = _saved_hf_const
except ImportError:
pass
if _saved_transformers_const is not None:
try:
import transformers.utils.hub as tf_hub
tf_hub._is_offline_mode = _saved_transformers_const
except ImportError:
pass
_saved_env = None
_saved_hf_const = None
_saved_transformers_const = None
_mistral_regex_patched = False
def patch_transformers_mistral_regex():
"""Make transformers' tokenizer load robust to HuggingFace metadata failures.
transformers 4.57.x added ``PreTrainedTokenizerBase._patch_mistral_regex``
which unconditionally calls ``huggingface_hub.model_info(repo_id)`` during
every non-local tokenizer load to check whether the model is a Mistral
variant. That call raises on ``HF_HUB_OFFLINE=1`` and on plain network
failures, killing unrelated loads (Qwen TTS, TADA, etc.).
Voicebox never loads Mistral models, so the rewrite the function would
apply is a no-op for us anyway. Wrap the method so any exception from the
metadata lookup returns the tokenizer unchanged — matching the success-path
behavior for non-Mistral repos (transformers 4.57.3,
``tokenization_utils_base.py:2503``).
"""
global _mistral_regex_patched
if _mistral_regex_patched:
return
try:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
except ImportError:
logger.debug("transformers not available, skipping mistral-regex patch")
return
original = getattr(PreTrainedTokenizerBase, "_patch_mistral_regex", None)
if original is None:
logger.debug(
"transformers has no _patch_mistral_regex attribute, skipping patch",
)
return
def safe_patch_mistral_regex(cls, tokenizer, pretrained_model_name_or_path, *args, **kwargs):
try:
return original(tokenizer, pretrained_model_name_or_path, *args, **kwargs)
except Exception as exc:
logger.debug(
"[mistral-regex-patch] suppressed %s for %r, returning tokenizer as-is",
type(exc).__name__,
pretrained_model_name_or_path,
)
return tokenizer
PreTrainedTokenizerBase._patch_mistral_regex = classmethod(safe_patch_mistral_regex)
_mistral_regex_patched = True
logger.debug("installed _patch_mistral_regex wrapper")
def patch_huggingface_hub_offline():
"""Monkey-patch huggingface_hub to force offline mode."""
try:
import huggingface_hub # noqa: F401 -- need the package loaded
from huggingface_hub import constants as hf_constants
from huggingface_hub.file_download import _try_to_load_from_cache
original_try_load = _try_to_load_from_cache
def _patched_try_to_load_from_cache(
repo_id: str,
filename: str,
cache_dir: Union[str, Path, None] = None,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
):
result = original_try_load(
repo_id=repo_id,
filename=filename,
cache_dir=cache_dir,
revision=revision,
repo_type=repo_type,
)
if result is None:
cache_path = Path(hf_constants.HF_HUB_CACHE) / f"models--{repo_id.replace('/', '--')}"
logger.debug("file not cached: %s/%s (expected at %s)", repo_id, filename, cache_path)
else:
logger.debug("cache hit: %s/%s", repo_id, filename)
return result
import huggingface_hub.file_download as fd
fd._try_to_load_from_cache = _patched_try_to_load_from_cache
logger.debug("huggingface_hub patched for offline mode")
except ImportError:
logger.debug("huggingface_hub not available, skipping offline patch")
except Exception:
logger.exception("failed to patch huggingface_hub for offline mode")
def ensure_original_qwen_config_cached():
"""Symlink the original Qwen repo cache to the MLX community version.
mlx_audio may try to fetch config from the original Qwen repo. If only
the MLX community variant is cached, create a symlink so the cache lookup
succeeds without a network request.
"""
try:
from huggingface_hub import constants as hf_constants
except ImportError:
return
original_repo = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
mlx_repo = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
cache_dir = Path(hf_constants.HF_HUB_CACHE)
original_path = cache_dir / f"models--{original_repo.replace('/', '--')}"
mlx_path = cache_dir / f"models--{mlx_repo.replace('/', '--')}"
if not original_path.exists() and mlx_path.exists():
try:
original_path.parent.mkdir(parents=True, exist_ok=True)
original_path.symlink_to(mlx_path, target_is_directory=True)
logger.info("created cache symlink: %s -> %s", original_repo, mlx_repo)
except Exception:
logger.warning("could not create cache symlink for %s", original_repo, exc_info=True)
if os.environ.get("VOICEBOX_OFFLINE_PATCH", "1") != "0":
patch_huggingface_hub_offline()
patch_transformers_mistral_regex()
ensure_original_qwen_config_cached()
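
A sketch of how a backend might wrap a cached-model load. Both `is_model_cached` and `build_model_from_hub` are hypothetical names here; in this codebase the cache check lives in `backends/base.py`:

```python
def load_tts_model(repo_id: str):
    cached = is_model_cached(repo_id)  # hypothetical helper
    with force_offline_if_cached(cached, model_label=repo_id):
        # Any hub/transformers call inside this block sees offline mode
        # when the weights are already on disk; otherwise it may download.
        return build_model_from_hub(repo_id)  # hypothetical loader
```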

View File

@@ -0,0 +1,383 @@
"""
HuggingFace Hub download progress tracking.
"""
from typing import Optional, Callable
from contextlib import contextmanager
import logging
import threading
import sys
logger = logging.getLogger(__name__)
class HFProgressTracker:
"""Tracks HuggingFace Hub download progress by intercepting tqdm."""
def __init__(self, progress_callback: Optional[Callable] = None, filter_non_downloads: bool = False):
self.progress_callback = progress_callback
self.filter_non_downloads = filter_non_downloads # Only filter if True
self._original_tqdm_class = None
self._lock = threading.Lock()
self._total_downloaded = 0
self._total_size = 0
self._file_sizes = {} # Track sizes of individual files
self._file_downloaded = {} # Track downloaded bytes per file
self._current_filename = ""
self._active_tqdms = {} # Track active tqdm instances
self._hf_tqdm_original_update = None # For monkey-patching hf's tqdm
def _create_tracked_tqdm_class(self):
"""Create a tqdm subclass that tracks progress."""
tracker = self
original_tqdm = self._original_tqdm_class
class TrackedTqdm(original_tqdm):
"""A tqdm subclass that reports progress to our tracker."""
def __init__(self, *args, **kwargs):
# Extract filename from desc before passing to parent
desc = kwargs.get("desc", "")
if not desc and args:
first_arg = args[0]
if isinstance(first_arg, str):
desc = first_arg
filename = ""
if desc:
# Try to extract filename from description
# HuggingFace Hub uses format like "model.safetensors: 0%|..."
if ":" in desc:
filename = desc.split(":")[0].strip()
else:
filename = desc.strip()
# Filter out non-standard kwargs that huggingface_hub might pass
# These are custom kwargs that tqdm doesn't understand
filtered_kwargs = {}
# Known tqdm kwargs - pass these through
tqdm_kwargs = {
"iterable",
"desc",
"total",
"leave",
"file",
"ncols",
"mininterval",
"maxinterval",
"miniters",
"ascii",
"disable",
"unit",
"unit_scale",
"dynamic_ncols",
"smoothing",
"bar_format",
"initial",
"position",
"postfix",
"unit_divisor",
"write_bytes",
"lock_args",
"nrows",
"colour",
"color",
"delay",
"gui",
"disable_default",
"pos",
}
for key, value in kwargs.items():
if key in tqdm_kwargs:
filtered_kwargs[key] = value
# Force-enable the progress bar — we're tracking progress ourselves,
# we don't need tqdm to render to a terminal, but we DO need
# self.n to be updated when update() is called.
filtered_kwargs["disable"] = False
# Try to initialize with filtered kwargs, fall back to all kwargs if that fails
try:
super().__init__(*args, **filtered_kwargs)
except TypeError:
# If filtering failed, try with all kwargs (maybe tqdm version accepts them)
kwargs["disable"] = False
super().__init__(*args, **kwargs)
self._tracker_filename = filename or "unknown"
with tracker._lock:
if filename:
tracker._current_filename = filename
tracker._active_tqdms[id(self)] = {
"filename": self._tracker_filename,
}
def update(self, n=1):
result = super().update(n)
# Report progress
with tracker._lock:
if id(self) in tracker._active_tqdms:
filename = tracker._active_tqdms[id(self)]["filename"]
current = getattr(self, "n", 0)
total = getattr(self, "total", 0)
if total and total > 0:
# Always filter out non-byte progress bars (e.g., "Fetching 12 files")
# These cause crazy percentages because they're counting files, not bytes
if self._is_non_byte_progress(filename):
return result
# When model is cached, also filter out generation-related progress
if tracker.filter_non_downloads:
if not self._is_download_progress(filename):
return result
# Update per-file tracking
tracker._file_sizes[filename] = total
tracker._file_downloaded[filename] = current
# Calculate totals across all files
tracker._total_size = sum(tracker._file_sizes.values())
tracker._total_downloaded = sum(tracker._file_downloaded.values())
# Only report progress once we have a meaningful total (at least 1MB)
# This avoids the "100% at 0MB" issue when small config
# files are counted before the real model files
MIN_TOTAL_BYTES = 1_000_000 # 1MB
if tracker._total_size < MIN_TOTAL_BYTES:
return result
# Call progress callback
if tracker.progress_callback:
tracker.progress_callback(tracker._total_downloaded, tracker._total_size, filename)
return result
def _is_non_byte_progress(self, filename: str) -> bool:
"""Check if this progress bar should be SKIPPED (returns True to skip).
We want to track byte-based progress bars. This method identifies
progress bars that count files/items instead of bytes, which would
cause crazy percentages if mixed with our byte counting.
Returns:
True = SKIP this bar (it's not byte-based)
False = TRACK this bar (it counts bytes)
"""
if not filename:
return False
filename_lower = filename.lower()
# Skip "Fetching X files" - it counts files (total=12), not bytes
# Don't skip "Downloading (incomplete total...)" - that IS byte-based
skip_patterns = [
"fetching", # "Fetching 12 files" has total=12 files, not bytes
]
return any(pattern in filename_lower for pattern in skip_patterns)
def _is_download_progress(self, filename: str) -> bool:
"""Check if this is a real file download progress bar vs internal processing."""
if not filename or filename == "unknown":
return False
# Real downloads have file extensions
download_extensions = [
".safetensors",
".bin",
".pt",
".pth", # Model weights
".json",
".txt",
".py", # Config files
".msgpack",
".h5", # Other formats
]
filename_lower = filename.lower()
has_extension = any(filename_lower.endswith(ext) for ext in download_extensions)
# Skip generation-related progress indicators
skip_patterns = ["segment", "processing", "generating", "loading"]
has_skip_pattern = any(pattern in filename_lower for pattern in skip_patterns)
return has_extension and not has_skip_pattern
def close(self):
with tracker._lock:
if id(self) in tracker._active_tqdms:
del tracker._active_tqdms[id(self)]
return super().close()
return TrackedTqdm
@contextmanager
def patch_download(self):
"""Context manager to patch tqdm for progress tracking."""
try:
import tqdm as tqdm_module
# Store original tqdm class
self._original_tqdm_class = tqdm_module.tqdm
# Reset totals
with self._lock:
self._total_downloaded = 0
self._total_size = 0
self._file_sizes = {}
self._file_downloaded = {}
self._current_filename = ""
self._active_tqdms = {}
# Create our tracked tqdm class
tracked_tqdm = self._create_tracked_tqdm_class()
# Patch tqdm.tqdm
tqdm_module.tqdm = tracked_tqdm
# Also patch tqdm.auto.tqdm if it exists (used by huggingface_hub)
self._original_tqdm_auto = None
if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
self._original_tqdm_auto = tqdm_module.auto.tqdm
tqdm_module.auto.tqdm = tracked_tqdm
# Patch in sys.modules to catch already-imported references
# huggingface_hub uses: from tqdm.auto import tqdm as base_tqdm
# So we need to patch both 'tqdm' and 'base_tqdm' attributes
self._patched_modules = {}
tqdm_attr_names = ["tqdm", "base_tqdm", "old_tqdm"] # Various names used
patched_count = 0
for module_name in list(sys.modules.keys()):
if "huggingface" in module_name or module_name.startswith("tqdm"):
try:
module = sys.modules[module_name]
for attr_name in tqdm_attr_names:
if hasattr(module, attr_name):
attr = getattr(module, attr_name)
# Only patch if it's a tqdm class (not already patched)
is_tqdm_class = (
attr is self._original_tqdm_class
or (self._original_tqdm_auto and attr is self._original_tqdm_auto)
or (
hasattr(attr, "__name__")
and attr.__name__ == "tqdm"
and hasattr(attr, "update")
) # tqdm classes have update method
)
if is_tqdm_class:
key = f"{module_name}.{attr_name}"
self._patched_modules[key] = (module, attr_name, attr)
setattr(module, attr_name, tracked_tqdm)
patched_count += 1
except (AttributeError, TypeError):
pass
# ALSO monkey-patch the update method on huggingface_hub's tqdm class
# This is needed because the class was already defined at import time
self._hf_tqdm_original_update = None
try:
from huggingface_hub.utils import tqdm as hf_tqdm_module
if hasattr(hf_tqdm_module, "tqdm"):
hf_tqdm_class = hf_tqdm_module.tqdm
self._hf_tqdm_original_update = hf_tqdm_class.update
# Create a wrapper that calls our tracking
tracker = self # Reference to HFProgressTracker instance
def patched_update(tqdm_self, n=1):
result = tracker._hf_tqdm_original_update(tqdm_self, n)
# Track this progress
with tracker._lock:
desc = getattr(tqdm_self, "desc", "") or ""
current = getattr(tqdm_self, "n", 0)
total = getattr(tqdm_self, "total", 0) or 0
# Skip non-byte progress bars
if "fetching" in desc.lower():
return result
# Skip until we have a meaningful total (at least 1MB)
# This avoids the "100% at 0MB" issue when small config
# files are counted before the real model files
MIN_TOTAL_BYTES = 1_000_000 # 1MB
if total >= MIN_TOTAL_BYTES:
tracker._total_downloaded = current
tracker._total_size = total
if tracker.progress_callback:
tracker.progress_callback(current, total, desc)
return result
hf_tqdm_class.update = patched_update
patched_count += 1
logger.debug("Monkey-patched huggingface_hub.utils.tqdm.tqdm.update")
except (ImportError, AttributeError) as e:
logger.warning("Could not monkey-patch hf_tqdm: %s", e)
logger.debug("Patched %d tqdm references", patched_count)
yield
except ImportError:
# If tqdm not available, just yield without patching
yield
finally:
# Restore original tqdm
if self._original_tqdm_class:
try:
import tqdm as tqdm_module
tqdm_module.tqdm = self._original_tqdm_class
if self._original_tqdm_auto:
tqdm_module.auto.tqdm = self._original_tqdm_auto
# Restore patched modules
for key, (module, attr_name, original) in self._patched_modules.items():
try:
if module and original:
setattr(module, attr_name, original)
except (AttributeError, TypeError):
pass
self._patched_modules = {}
# Restore hf_tqdm's original update method
if self._hf_tqdm_original_update:
try:
from huggingface_hub.utils import tqdm as hf_tqdm_module
if hasattr(hf_tqdm_module, "tqdm"):
hf_tqdm_module.tqdm.update = self._hf_tqdm_original_update
except (ImportError, AttributeError):
pass
self._hf_tqdm_original_update = None
except (ImportError, AttributeError):
pass
def create_hf_progress_callback(model_name: str, progress_manager):
"""Create a progress callback for HuggingFace downloads."""
def callback(downloaded: int, total: int, filename: str = ""):
"""Progress callback.
Note: We send updates even when total=0 (unknown) to provide feedback
during the "incomplete total" phase of huggingface_hub downloads.
The frontend handles total=0 gracefully.
"""
progress_manager.update_progress(
model_name=model_name,
current=downloaded,
total=total,
filename=filename or "",
status="downloading",
)
return callback
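
Putting the two pieces together — a hedged sketch of a download wrapped in the tracker. The model name and repo id are placeholders; `snapshot_download` is the standard huggingface_hub entry point:

```python
from huggingface_hub import snapshot_download

def download_with_progress(repo_id: str, progress_manager) -> str:
    callback = create_hf_progress_callback("qwen-tts-1.7B", progress_manager)
    tracker = HFProgressTracker(progress_callback=callback)
    with tracker.patch_download():
        # tqdm bars created by huggingface_hub inside this block report
        # byte-level progress through the callback.
        local_dir = snapshot_download(repo_id)
    progress_manager.mark_complete("qwen-tts-1.7B")
    return local_dir
```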

114
backend/utils/images.py Normal file
View File

@@ -0,0 +1,114 @@
"""Image processing utilities for avatar uploads."""
from pathlib import Path
from typing import Optional, Tuple
from PIL import Image
# JPEG can be reported as 'JPEG' or 'MPO' (for multi-picture format from some cameras)
ALLOWED_FORMATS = {'PNG', 'JPEG', 'WEBP', 'MPO', 'JPG'}
MAX_SIZE = 512
MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB
def validate_image(file_path: str) -> Tuple[bool, Optional[str]]:
"""
Validate image format and file size.
Args:
file_path: Path to image file
Returns:
Tuple of (is_valid, error_message)
"""
path = Path(file_path)
# Check file size
if path.stat().st_size > MAX_FILE_SIZE:
return False, f"File size exceeds maximum of {MAX_FILE_SIZE // (1024 * 1024)}MB"
try:
with Image.open(file_path) as img:
# Verify the image can be loaded
img.load()
# Check format (normalize JPEG variants)
img_format = img.format
if img_format in ('MPO', 'JPG'):
img_format = 'JPEG'
if img_format not in {'PNG', 'JPEG', 'WEBP'}:
return False, f"Invalid format '{img_format}'. Allowed formats: PNG, JPEG, WEBP"
return True, None
except Exception as e:
return False, f"Invalid image file: {str(e)}"
def process_avatar(input_path: str, output_path: str, max_size: int = MAX_SIZE) -> None:
"""
Process avatar image: resize and optimize.
Resizes image to fit within max_size x max_size while maintaining aspect ratio.
Args:
input_path: Path to input image
output_path: Path to save processed image
max_size: Maximum width or height in pixels
"""
with Image.open(input_path) as img:
# Handle EXIF orientation for JPEG images
try:
from PIL import ExifTags
for orientation in ExifTags.TAGS.keys():
if ExifTags.TAGS[orientation] == 'Orientation':
break
exif = img._getexif()
if exif is not None:
orientation_value = exif.get(orientation)
if orientation_value == 3:
img = img.rotate(180, expand=True)
elif orientation_value == 6:
img = img.rotate(270, expand=True)
elif orientation_value == 8:
img = img.rotate(90, expand=True)
except (AttributeError, KeyError, IndexError, TypeError):
# No EXIF data or orientation tag
pass
# Convert to RGB if necessary (handles RGBA, P, CMYK, etc.)
if img.mode not in ('RGB', 'L'):
if img.mode == 'RGBA':
# Create white background for RGBA images
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[3]) # Use alpha channel as mask
img = background
elif img.mode == 'CMYK':
# Convert CMYK to RGB
img = img.convert('RGB')
elif img.mode == 'P':
# Convert palette mode to RGB
img = img.convert('RGB')
else:
img = img.convert('RGB')
# Calculate new size maintaining aspect ratio
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
# Determine output format from extension
output_ext = Path(output_path).suffix.lower()
format_map = {
'.png': 'PNG',
'.jpeg': 'JPEG',
'.jpg': 'JPEG',
'.webp': 'WEBP'
}
output_format = format_map.get(output_ext, 'PNG')
# Save with optimization
save_kwargs = {'optimize': True}
if output_format == 'JPEG':
save_kwargs['quality'] = 90
img.save(output_path, format=output_format, **save_kwargs)
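
For reference, the expected two-step flow when an avatar is uploaded (the paths are placeholders):

```python
def save_avatar(upload_path: str, output_path: str) -> None:
    ok, error = validate_image(upload_path)
    if not ok:
        raise ValueError(error)
    # Resizes to fit 512x512, fixes EXIF rotation, flattens alpha onto white.
    process_avatar(upload_path, output_path)
```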

View File

@@ -0,0 +1,35 @@
"""
Platform detection for backend selection.
"""
import platform
from typing import Literal
def is_apple_silicon() -> bool:
"""
Check if running on Apple Silicon (arm64 macOS).
Returns:
True if on Apple Silicon, False otherwise
"""
return platform.system() == "Darwin" and platform.machine() == "arm64"
def get_backend_type() -> Literal["mlx", "pytorch"]:
"""
Detect the best backend for the current platform.
Returns:
"mlx" on Apple Silicon (if MLX is available and functional), "pytorch" otherwise
"""
if is_apple_silicon():
try:
import mlx.core # noqa: F401 — triggers native lib loading
return "mlx"
except (ImportError, OSError, RuntimeError):
# MLX not installed, or native libraries failed to load inside a
# PyInstaller bundle (OSError on missing .dylib / .metallib).
# Fall through to PyTorch.
return "pytorch"
return "pytorch"

315
backend/utils/progress.py Normal file
View File

@@ -0,0 +1,315 @@
"""
Progress tracking for model downloads using Server-Sent Events.
"""
from typing import Optional, Callable, Dict, List
from fastapi.responses import StreamingResponse
import asyncio
import json
import threading
from datetime import datetime
class ProgressManager:
"""Manages download progress for multiple models.
Thread-safe: can be called from background threads (e.g., via asyncio.to_thread).
"""
# Throttle settings to prevent overwhelming SSE clients
THROTTLE_INTERVAL_SECONDS = 0.5 # Minimum time between updates
THROTTLE_PROGRESS_DELTA = 1.0 # Minimum progress change (%) to force update
def __init__(self):
self._progress: Dict[str, Dict] = {}
self._listeners: Dict[str, list] = {}
self._lock = threading.Lock() # Thread-safe lock for progress dict
self._main_loop: Optional[asyncio.AbstractEventLoop] = None
self._last_notify_time: Dict[str, float] = {} # Last notification time per model
self._last_notify_progress: Dict[str, float] = {} # Last notified progress per model
def _set_main_loop(self, loop: asyncio.AbstractEventLoop):
"""Set the main event loop for thread-safe operations."""
self._main_loop = loop
def _notify_listeners_threadsafe(self, model_name: str, progress_data: Dict):
"""Notify listeners in a thread-safe manner."""
import logging
logger = logging.getLogger(__name__)
if model_name not in self._listeners:
return
for queue in self._listeners[model_name]:
try:
# Check if we're in the main event loop thread
try:
running_loop = asyncio.get_running_loop()
# We're in an async context, can use put_nowait directly
queue.put_nowait(progress_data.copy())
except RuntimeError:
# Not in async context (running in background thread)
# Use call_soon_threadsafe to safely put on queue
if self._main_loop and self._main_loop.is_running():
self._main_loop.call_soon_threadsafe(
lambda q=queue, d=progress_data.copy(): q.put_nowait(d) if not q.full() else None
)
else:
logger.debug(f"No main loop available for {model_name}, skipping notification")
except asyncio.QueueFull:
logger.warning(f"Queue full for {model_name}, dropping update")
except Exception as e:
logger.warning(f"Error notifying listener for {model_name}: {e}")
def update_progress(
self,
model_name: str,
current: int,
total: int,
filename: Optional[str] = None,
status: str = "downloading",
):
"""
Update progress for a model download.
Thread-safe: can be called from background threads.
Progress updates are throttled to prevent overwhelming SSE clients.
Updates are sent at most every THROTTLE_INTERVAL_SECONDS, or when
progress changes by at least THROTTLE_PROGRESS_DELTA percent.
Args:
model_name: Name of the model (e.g., "qwen-tts-1.7B", "whisper-base")
current: Current bytes downloaded
total: Total bytes to download
filename: Current file being downloaded
status: Status string (downloading, extracting, complete, error)
"""
import logging
import time
logger = logging.getLogger(__name__)
# Calculate progress percentage, clamped to 0-100 range
# This prevents crazy percentages from edge cases like:
# - current > total temporarily during aggregation
# - mixing file-count progress with byte-count progress
if total > 0:
progress_pct = min(100.0, max(0.0, (current / total * 100)))
else:
progress_pct = 0
progress_data = {
"model_name": model_name,
"current": current,
"total": total,
"progress": progress_pct,
"filename": filename,
"status": status,
"timestamp": datetime.now().isoformat(),
}
# Thread-safe update of progress dict (always update internal state)
with self._lock:
self._progress[model_name] = progress_data
# Check if we should notify listeners (throttling)
current_time = time.time()
last_time = self._last_notify_time.get(model_name, 0)
last_progress = self._last_notify_progress.get(model_name, -100)
time_delta = current_time - last_time
progress_delta = abs(progress_pct - last_progress)
# Always notify for complete/error status, or if throttle conditions are met
should_notify = (
status in ("complete", "error") or
time_delta >= self.THROTTLE_INTERVAL_SECONDS or
progress_delta >= self.THROTTLE_PROGRESS_DELTA
)
if not should_notify:
return # Skip this update (throttled)
# Update throttle tracking
self._last_notify_time[model_name] = current_time
self._last_notify_progress[model_name] = progress_pct
# Notify all listeners (thread-safe)
listener_count = len(self._listeners.get(model_name, []))
if listener_count > 0:
logger.debug(f"Notifying {listener_count} listeners for {model_name}: {progress_pct:.1f}% ({filename})")
self._notify_listeners_threadsafe(model_name, progress_data)
else:
logger.debug(f"No listeners for {model_name}, progress update stored: {progress_pct:.1f}%")
def get_progress(self, model_name: str) -> Optional[Dict]:
"""Get current progress for a model. Thread-safe."""
with self._lock:
progress = self._progress.get(model_name)
return progress.copy() if progress else None
def get_all_active(self) -> List[Dict]:
"""Get all active downloads (status is 'downloading' or 'extracting'). Thread-safe."""
active = []
with self._lock:
for model_name, progress in self._progress.items():
status = progress.get("status", "")
if status in ("downloading", "extracting"):
active.append(progress.copy())
return active
def create_progress_callback(self, model_name: str, filename: Optional[str] = None):
"""
Create a progress callback function for HuggingFace downloads.
Args:
model_name: Name of the model
filename: Optional filename filter
Returns:
Callback function
"""
def callback(progress: Dict):
"""HuggingFace Hub progress callback."""
if "total" in progress and "current" in progress:
current = progress.get("current", 0)
total = progress.get("total", 0)
file_name = progress.get("filename", filename)
self.update_progress(
model_name=model_name,
current=current,
total=total,
filename=file_name,
status="downloading",
)
return callback
async def subscribe(self, model_name: str):
"""
Subscribe to progress updates for a model.
Yields progress updates as Server-Sent Events.
"""
import logging
logger = logging.getLogger(__name__)
# Store the main event loop for thread-safe operations
try:
self._main_loop = asyncio.get_running_loop()
except RuntimeError:
pass
queue = asyncio.Queue(maxsize=10)
# Add to listeners
if model_name not in self._listeners:
self._listeners[model_name] = []
self._listeners[model_name].append(queue)
logger.info(f"SSE client subscribed to {model_name}, total listeners: {len(self._listeners[model_name])}")
try:
# Send initial progress if available and still in progress (thread-safe read)
with self._lock:
initial_progress = self._progress.get(model_name)
if initial_progress:
initial_progress = initial_progress.copy()
if initial_progress:
status = initial_progress.get('status')
# Only send initial progress if download is actually in progress
# Don't send old 'complete' or 'error' status from previous downloads
if status in ('downloading', 'extracting'):
logger.info(f"Sending initial progress for {model_name}: {status}")
yield f"data: {json.dumps(initial_progress)}\n\n"
else:
logger.info(f"Skipping initial progress for {model_name} (status: {status})")
else:
logger.info(f"No initial progress available for {model_name}")
# Stream updates
while True:
try:
# Wait for update with timeout
progress = await asyncio.wait_for(queue.get(), timeout=1.0)
logger.debug(f"Sending progress update for {model_name}: {progress.get('status')} - {progress.get('progress', 0):.1f}%")
yield f"data: {json.dumps(progress)}\n\n"
# Stop if complete or error
if progress.get("status") in ("complete", "error"):
logger.info(f"Download {progress.get('status')} for {model_name}, closing SSE connection")
break
except asyncio.TimeoutError:
# Send heartbeat
yield ": heartbeat\n\n"
continue
except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
logger.debug(f"SSE client disconnected from {model_name}")
finally:
# Remove from listeners
if model_name in self._listeners:
self._listeners[model_name].remove(queue)
if not self._listeners[model_name]:
del self._listeners[model_name]
logger.info(f"SSE client unsubscribed from {model_name}, remaining listeners: {len(self._listeners.get(model_name, []))}")
def mark_complete(self, model_name: str):
"""Mark a model download as complete. Thread-safe."""
import logging
logger = logging.getLogger(__name__)
with self._lock:
if model_name in self._progress:
self._progress[model_name]["status"] = "complete"
self._progress[model_name]["progress"] = 100.0
progress_data = self._progress[model_name].copy()
else:
logger.warning(f"Cannot mark {model_name} as complete: not found in progress")
return
logger.info(f"Marked {model_name} as complete")
# Notify listeners (thread-safe)
self._notify_listeners_threadsafe(model_name, progress_data)
def mark_error(self, model_name: str, error: str):
"""Mark a model download as failed. Thread-safe."""
import logging
logger = logging.getLogger(__name__)
with self._lock:
if model_name in self._progress:
self._progress[model_name]["status"] = "error"
self._progress[model_name]["error"] = error
progress_data = self._progress[model_name].copy()
else:
# Create new progress entry for error
progress_data = {
"model_name": model_name,
"current": 0,
"total": 0,
"progress": 0,
"filename": None,
"status": "error",
"error": error,
"timestamp": datetime.now().isoformat(),
}
self._progress[model_name] = progress_data
logger.error(f"Marked {model_name} as error: {error}")
# Notify listeners (thread-safe)
self._notify_listeners_threadsafe(model_name, progress_data)
# Global progress manager instance
_progress_manager: Optional[ProgressManager] = None
def get_progress_manager() -> ProgressManager:
"""Get or create the global progress manager."""
global _progress_manager
if _progress_manager is None:
_progress_manager = ProgressManager()
return _progress_manager
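# Hypothetical usage sketch (added for illustration, not part of the original module):
# a FastAPI route could stream subscribe() to the client as Server-Sent Events.
# The router object, route path, and handler name below are assumptions.
#
#   from fastapi.responses import StreamingResponse
#
#   @router.get("/models/{model_name}/download-progress")
#   async def download_progress(model_name: str):
#       manager = get_progress_manager()
#       return StreamingResponse(
#           manager.subscribe(model_name),
#           media_type="text/event-stream",
#       )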

102
backend/utils/tasks.py Normal file

@@ -0,0 +1,102 @@
"""
Task tracking for active downloads and generations.
"""
from typing import Optional, Dict, List
from datetime import datetime
from dataclasses import dataclass, field
@dataclass
class DownloadTask:
"""Represents an active download task."""
model_name: str
status: str = "downloading" # downloading, extracting, complete, error
started_at: datetime = field(default_factory=datetime.utcnow)
error: Optional[str] = None
@dataclass
class GenerationTask:
"""Represents an active generation task."""
task_id: str
profile_id: str
text_preview: str  # Truncated preview of the text (first 50 chars, "..." appended when longer)
started_at: datetime = field(default_factory=datetime.utcnow)
class TaskManager:
"""Manages active downloads and generations."""
def __init__(self):
self._active_downloads: Dict[str, DownloadTask] = {}
self._active_generations: Dict[str, GenerationTask] = {}
def start_download(self, model_name: str) -> None:
"""Mark a download as started."""
self._active_downloads[model_name] = DownloadTask(
model_name=model_name,
status="downloading",
)
def complete_download(self, model_name: str) -> None:
"""Mark a download as complete."""
if model_name in self._active_downloads:
del self._active_downloads[model_name]
def error_download(self, model_name: str, error: str) -> None:
"""Mark a download as failed."""
if model_name in self._active_downloads:
self._active_downloads[model_name].status = "error"
self._active_downloads[model_name].error = error
def start_generation(self, task_id: str, profile_id: str, text: str) -> None:
"""Mark a generation as started."""
text_preview = text[:50] + "..." if len(text) > 50 else text
self._active_generations[task_id] = GenerationTask(
task_id=task_id,
profile_id=profile_id,
text_preview=text_preview,
)
def complete_generation(self, task_id: str) -> None:
"""Mark a generation as complete."""
if task_id in self._active_generations:
del self._active_generations[task_id]
def get_active_downloads(self) -> List[DownloadTask]:
"""Get all active downloads."""
return list(self._active_downloads.values())
def get_active_generations(self) -> List[GenerationTask]:
"""Get all active generations."""
return list(self._active_generations.values())
def cancel_download(self, model_name: str) -> bool:
"""Cancel/dismiss a download task (removes it from active list)."""
return self._active_downloads.pop(model_name, None) is not None
def clear_all(self) -> None:
"""Clear all download and generation tasks."""
self._active_downloads.clear()
self._active_generations.clear()
def is_download_active(self, model_name: str) -> bool:
"""Check if a download is active."""
return model_name in self._active_downloads
def is_generation_active(self, task_id: str) -> bool:
"""Check if a generation is active."""
return task_id in self._active_generations
# Global task manager instance
_task_manager: Optional[TaskManager] = None
def get_task_manager() -> TaskManager:
"""Get or create the global task manager."""
global _task_manager
if _task_manager is None:
_task_manager = TaskManager()
return _task_manager
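# Illustrative lifecycle sketch (task and profile IDs below are made-up values):
#
#   tasks = get_task_manager()
#   tasks.start_generation("task-123", "profile-abc", "Some very long prompt text ...")
#   tasks.get_active_generations()          # -> [GenerationTask(task_id="task-123", ...)]
#   tasks.complete_generation("task-123")
#   tasks.is_generation_active("task-123")  # -> False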