Initial commit
This commit is contained in:
135 backend/README.md Normal file
@@ -0,0 +1,135 @@
# Voicebox Backend

FastAPI server powering voice cloning, speech generation, and audio processing. Runs locally as a Tauri sidecar or standalone via `python -m backend.main`.

## Running

```bash
# Via justfile (recommended)
just dev:server

# Standalone
python -m backend.main --host 127.0.0.1 --port 17493

# With custom data directory
python -m backend.main --data-dir /path/to/data
```

The server auto-initializes the SQLite database on first startup. Models are downloaded from HuggingFace on first use.

## Architecture

```
backend/
  app.py       # FastAPI app factory, CORS, lifecycle events
  main.py      # Entry point (imports app, runs uvicorn)
  config.py    # Data directory paths and configuration
  models.py    # Pydantic request/response schemas
  server.py    # Tauri sidecar launcher, parent-pid watchdog

  routes/      # Thin HTTP handlers -- validation, delegation, response formatting
  services/    # Business logic, CRUD, orchestration
  backends/    # TTS/STT engine implementations (MLX, PyTorch, etc.)
  database/    # ORM models, session management, migrations, seed data
  utils/       # Shared utilities (audio, effects, caching, progress tracking)
```

### Request flow

```
HTTP request
  -> routes/    (validate input, parse params)
  -> services/  (business logic, database queries, orchestration)
  -> backends/  (TTS/STT inference)
  -> utils/     (audio processing, effects, caching)
```

Route handlers are intentionally thin. They validate input, delegate to a service function, and format the response. All business logic lives in `services/`.
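A minimal sketch of the pattern, mirroring the style guide's examples (the exact route shape and `profiles.get_profile` call site are illustrative, not copied from the codebase):

```python
import asyncio

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from ..database import get_db
from ..services import profiles

router = APIRouter(prefix="/profiles")


@router.get("/{profile_id}")
async def read_profile(profile_id: str, db: Session = Depends(get_db)):
    """Thin handler: validate, delegate to services/, format the response."""
    profile = await asyncio.to_thread(profiles.get_profile, profile_id, db)
    if profile is None:
        raise HTTPException(status_code=404, detail="Profile not found")
    return profile
```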
### Key modules

**services/generation.py** -- Single `run_generation()` function that handles all three generation modes (generate, retry, regenerate). Manages model loading, voice prompt creation, chunked inference, normalization, effects, and version persistence.

**services/task_queue.py** -- Serial generation queue. Ensures only one GPU inference runs at a time. Background tasks are tracked to prevent garbage collection.
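A rough sketch of the serialization idea -- `init_queue()` and `create_background_task()` are the real names imported by `app.py`, but the bodies here are illustrative and omit the status tracking the actual module does:

```python
import asyncio

_generation_queue: asyncio.Queue = None  # type: ignore[assignment]  # initialized at startup
_background_tasks: set[asyncio.Task] = set()


def create_background_task(coro) -> asyncio.Task:
    """Track the task so it isn't garbage-collected mid-flight."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def _worker() -> None:
    # A single worker coroutine drains the queue, so at most one
    # GPU inference job runs at any time.
    while True:
        job = await _generation_queue.get()
        try:
            await job()
        finally:
            _generation_queue.task_done()


def init_queue() -> None:
    """Create the queue and start the single worker (called at startup)."""
    global _generation_queue
    _generation_queue = asyncio.Queue()
    create_background_task(_worker())
```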
**backends/__init__.py** -- Protocol definitions (`TTSBackend`, `STTBackend`), model config registry, and factory functions. Adding a new engine means implementing the protocol and registering a config entry.
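Sketched against the `TTSBackend` protocol and `ModelConfig` dataclass defined in `backends/__init__.py` (the engine itself is hypothetical; a real addition would also register an `elif` branch in the factory):

```python
import numpy as np

from . import ModelConfig


class MyTTSBackend:
    """Hypothetical engine. The protocol is structural, so no inheritance is needed."""

    MODEL_CONFIGS = [
        ModelConfig(
            model_name="my-tts",
            display_name="My TTS (example)",
            engine="my_tts",
            hf_repo_id="example/my-tts",  # hypothetical repo
            size_mb=500,
            languages=["en"],
        ),
    ]

    def __init__(self) -> None:
        self._model = None

    async def load_model(self, model_size: str = "default") -> None:
        ...  # download weights, move to device

    async def generate(self, text, voice_prompt, language="en", seed=None, instruct=None) -> tuple[np.ndarray, int]:
        ...  # return (audio_array, sample_rate)

    def unload_model(self) -> None:
        self._model = None

    def is_loaded(self) -> bool:
        return self._model is not None
```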
**backends/base.py** -- Shared utilities used across all engine implementations: HuggingFace cache checks, device detection, voice prompt combination, progress tracking.

**database/** -- SQLAlchemy ORM models with a re-exporting `__init__.py` for backward compatibility. Migrations run automatically on startup.

### Backend selection

The server detects the best inference backend at startup:

| Platform | Backend | Acceleration |
|----------|---------|--------------|
| macOS (Apple Silicon) | MLX | Metal / Neural Engine |
| Windows / Linux (NVIDIA) | PyTorch | CUDA |
| Linux (AMD) | PyTorch | ROCm |
| Intel Arc | PyTorch | IPEX / XPU |
| Windows (any GPU) | PyTorch | DirectML |
| Any | PyTorch | CPU fallback |

Detection is handled by `utils/platform_detect.py`. Both backends implement the same `TTSBackend` protocol, so the API layer is engine-agnostic.
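For example, a caller never branches on platform (a sketch -- the surrounding call site is illustrative, but the factory and `generate()` signature come from `backends/__init__.py`):

```python
from backend.backends import get_tts_backend_for_engine


async def synthesize(text: str, voice_prompt: dict, engine: str = "qwen"):
    # Same protocol whether this resolves to the MLX or the PyTorch backend.
    backend = get_tts_backend_for_engine(engine)
    audio, sample_rate = await backend.generate(text, voice_prompt, language="en")
    return audio, sample_rate
```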
## API

90 endpoints organized by domain. Full interactive documentation available at `http://localhost:17493/docs` when the server is running.

| Domain | Prefix | Description |
|--------|--------|-------------|
| Health | `/`, `/health` | Server status, GPU info, filesystem checks |
| Profiles | `/profiles` | Voice profile CRUD, samples, avatars, import/export |
| Channels | `/channels` | Audio channel management and voice assignment |
| Generation | `/generate` | TTS generation, retry, regenerate, status SSE |
| History | `/history` | Generation history, search, favorites, export |
| Transcription | `/transcribe` | Whisper-based audio-to-text |
| Stories | `/stories` | Multi-track timeline editor, audio export |
| Effects | `/effects` | Effect presets, preview, version management |
| Audio | `/audio`, `/samples` | Audio file serving |
| Models | `/models` | Load, unload, download, migrate, status |
| Tasks | `/tasks`, `/cache` | Active task tracking, cache management |
| CUDA | `/backend/cuda-*` | CUDA binary download and management |

### Quick examples

```bash
# Generate speech
curl -X POST http://localhost:17493/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "Hello world", "profile_id": "...", "language": "en"}'

# List profiles
curl http://localhost:17493/profiles

# Stream generation status (SSE)
curl http://localhost:17493/generate/{id}/status
```

## Data directory

```
{data_dir}/
  voicebox.db        # SQLite database
  profiles/{id}/     # Voice samples per profile
  generations/       # Generated audio files
  cache/             # Voice prompt cache (memory + disk)
  backends/          # Downloaded CUDA binary (if applicable)
```

Default location is the OS-specific app data directory. Override with `--data-dir` or the `VOICEBOX_DATA_DIR` environment variable.

## Code quality

Linting and formatting are enforced by [ruff](https://docs.astral.sh/ruff/), configured in `pyproject.toml`. See `STYLE_GUIDE.md` for conventions.

```bash
just check-python   # lint + format check
just fix-python     # auto-fix lint issues + reformat
just test           # run pytest
```

## Dependencies

Runtime dependencies are in `requirements.txt`. macOS-only MLX dependencies are in `requirements-mlx.txt`. Dev tools (ruff, pytest) are installed automatically by `just setup-python`.
404 backend/STYLE_GUIDE.md Normal file
@@ -0,0 +1,404 @@
# Python Style Guide

Target: **Python 3.12+** | Formatter/Linter: **Ruff** | Config: `backend/pyproject.toml`

This guide codifies the conventions used across the backend and prescribes the target style for code written during the refactor (Phases 3-6). Existing code should be migrated incrementally -- don't reformat entire files in unrelated PRs.

---

## Formatting

Enforced by `ruff format` (Black-compatible).

- **Line length**: 120 characters.
- **Indent**: 4 spaces. No tabs.
- **Trailing commas**: Required on multi-line function signatures, arguments, and collections.
- **Quotes**: Double quotes (`"`) for strings. Single quotes are acceptable in f-string expressions and dict keys inside f-strings where avoiding escapes improves readability.

Run: `ruff format backend/`

---

## Imports

Enforced by ruff's `isort` rules (rule set `I`).

**Grouping** -- three blocks separated by a blank line:

```python
import asyncio  # 1. stdlib
from pathlib import Path

import numpy as np  # 2. third-party
from fastapi import APIRouter, HTTPException
from sqlalchemy.orm import Session

from backend.config import get_data_dir  # 3. local (absolute)
from .database import get_db  # or relative
```

**Rules:**
- Within the `backend` package, use **relative imports** for sibling/child modules: `from .database import get_db`, `from ..utils.audio import load_audio`.
- Absolute imports are fine for top-level references from entry points (`main.py`, `server.py`).
- Never use wildcard imports (`from module import *`).
- One import per line for `from X import Y` when there are 4+ names; below that, comma-separated is fine.
- **Lazy imports** are acceptable for heavy dependencies (torch, transformers, mlx) inside functions to reduce startup time. Add a comment: `# lazy: heavy import`. See the sketch after this list.
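A minimal sketch of the lazy-import pattern (the function body is illustrative):

```python
def detect_device() -> str:
    """Pick the best available torch device."""
    import torch  # lazy: heavy import

    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
```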
---

## Type Annotations

Python 3.12 means we use **built-in generics and union syntax natively**. No `from __future__ import annotations`, no `typing.List`/`typing.Dict`.

```python
# Yes
def process(items: list[str], config: dict[str, int] | None = None) -> tuple[int, str]: ...

# No
from typing import List, Dict, Optional, Tuple
def process(items: List[str], config: Optional[Dict[str, int]] = None) -> Tuple[int, str]: ...
```

**What to annotate:**
- All public function signatures (parameters + return type).
- Private functions: parameters at minimum; return type encouraged.
- Module-level variables: only when the type isn't obvious from the assignment.
- Route handlers: parameters are annotated via FastAPI's dependency injection. Add explicit `-> SomeResponse` return types when the route doesn't use `response_model`.

**Imports from `typing` that are still needed** (no built-in equivalent):
`Literal`, `TypeAlias`, `Protocol`, `runtime_checkable`, `Callable`, `Any`, `ClassVar`, `TypeVar`, `overload`, `TYPE_CHECKING`.

Use `collections.abc` for abstract types: `Sequence`, `Mapping`, `Iterable`, `Iterator`, `Generator`.

---

## Naming

| Thing | Convention | Example |
|-------|-----------|---------|
| Module | `snake_case` | `task_queue.py` |
| Class | `PascalCase` | `ProgressManager` |
| Function / method | `snake_case` | `create_profile` |
| Variable | `snake_case` | `sample_rate` |
| Constant | `UPPER_SNAKE_CASE` | `DEFAULT_SAMPLE_RATE` |
| Private | `_leading_underscore` | `_generation_queue` |
| Type alias | `PascalCase` | `EffectChain = list[dict[str, Any]]` |

**Specific conventions:**
- Database ORM models imported with `DB` prefix alias: `from .database import VoiceProfile as DBVoiceProfile`.
- Pydantic models use descriptive suffixes: `VoiceProfileCreate`, `VoiceProfileResponse`, `GenerationRequest`.
- Backend classes use engine-name prefix: `MLXTTSBackend`, `PyTorchSTTBackend`.

---

## Docstrings

**Google style**. Required on all public functions, classes, and modules.

```python
def combine_voice_prompts(
    profile_dir: Path,
    *,
    target_sr: int = 24000,
) -> tuple[np.ndarray, int]:
    """Load and concatenate all voice prompt files for a profile.

    Reads .wav/.mp3/.flac files from the profile directory, resamples to
    the target sample rate, normalizes, and concatenates into a single array.

    Args:
        profile_dir: Path to the voice profile directory containing audio files.
        target_sr: Target sample rate for the output. Defaults to 24000.

    Returns:
        Tuple of (concatenated audio array, sample rate).

    Raises:
        FileNotFoundError: If profile_dir does not exist.
        ValueError: If no valid audio files are found.
    """
```

**Short form** is fine for simple functions:

```python
def get_db_path() -> Path:
    """Get the path to the SQLite database file."""
```

**When to skip**: Private helpers under ~5 lines where the name and signature make intent obvious.

**Module docstrings**: A single sentence at the top of every file describing its purpose.

```python
"""Voice profile CRUD operations."""
```

---

## Comments

Comments explain **why**, not **what**. If the code needs a comment to explain what it does, the code should be rewritten to be clearer. The exceptions are non-obvious performance choices, external constraints, and concurrency/race-condition reasoning -- those always deserve a comment.

### No section dividers

Do not use ASCII dividers to create visual sections in files:

```python
# No -- any of these:
# ============================================
# GENERATION ENDPOINTS
# ============================================

# ---------------------------------------------------------------------------
# Device detection
# ---------------------------------------------------------------------------

# --- Load model --------------------------------------------------
```

If a file needs section dividers to be navigable, the file is too long. Split it into modules. Within a function, if you need labeled sections to follow the logic, extract those sections into named functions.

### Inline comments

Inline comments (end-of-line) are fine when they add information the code can't express:

```python
# Yes -- explains a non-obvious constraint or gives context:
audio, sr = load_audio(path, sr=24000)  # Qwen expects 24kHz mono
_generation_queue: asyncio.Queue = None  # type: ignore[assignment]  # initialized at startup
"tauri://localhost",  # Tauri webview (macOS)

# No -- restates the code:
# Check if profile name already exists
existing = db.query(DBVoiceProfile).filter_by(name=data.name).first()

# Delete from database
db.delete(sample)

# Update fields
profile.name = data.name
```

Delete comments that narrate what the next line of code obviously does. If the function name, variable name, or method call already communicates intent, the comment is noise.

### Block comments

Use block comments for **why** explanations -- constraints, workarounds, non-obvious decisions:

```python
# PyInstaller + multiprocessing: child processes re-execute the frozen binary
# with internal arguments. freeze_support() handles this and exits early.
multiprocessing.freeze_support()

# Mark any stale "generating" records as failed -- these are leftovers
# from a previous process that was killed mid-generation.
db.query(Generation).filter_by(status="generating").update({"status": "failed"})
```

Keep block comments tight. Two to three lines is normal. If you need a paragraph, it probably belongs in the docstring or a design doc.

### Linter/type-checker suppression

Always add a reason after `noqa` and `type: ignore`:

```python
import intel_extension_for_pytorch  # noqa: F401 -- side-effect import enables XPU

_queue: asyncio.Queue = None  # type: ignore[assignment]  # initialized at startup
```

Bare `# noqa` or `# type: ignore` with no explanation are not allowed.

### TODO / FIXME

Use sparingly. Every `TODO` must include a brief description of what needs doing. Don't use them as a substitute for tracking work properly:

```python
# TODO: replace with async SQLAlchemy once CRUD modules are migrated (Phase 5)
result = await asyncio.to_thread(profiles.get_profile, profile_id, db)
```

Never commit `HACK`, `XXX`, or `FIXME` -- fix the problem or file an issue.

### Commented-out code

Delete it. That's what git is for. If you need to document that something was intentionally removed, a short tombstone comment is acceptable:

```python
# Removed config.json-only check -- too lenient, doesn't confirm weights exist.
```

---

## Error Handling

The refactor is standardizing on a **two-layer pattern**:

### 1. Domain layer -- raise plain exceptions

CRUD modules and services raise `ValueError`, `FileNotFoundError`, or (post-refactor) custom exceptions defined in `backend/errors.py`:

```python
# backend/errors.py (to be created in Phase 4)
class NotFoundError(Exception):
    """Raised when a requested resource does not exist."""


class ConflictError(Exception):
    """Raised on uniqueness constraint violations."""
```

```python
# In a service or CRUD module:
raise NotFoundError(f"Profile {profile_id} not found")
```

### 2. Route layer -- translate to HTTPException

Route handlers catch domain exceptions and convert:

```python
@router.post("/profiles")
async def create_profile(data: VoiceProfileCreate, db: Session = Depends(get_db)):
    try:
        return await profiles.create_profile(data, db)
    except ConflictError as e:
        raise HTTPException(status_code=409, detail=str(e))
```

**Background tasks** catch `Exception` broadly, log with `logger.exception()`, and update the task status to `"failed"`.

**Never**: silently swallow exceptions, use bare `except:`, or catch `BaseException`.

---

## Async

### Rules for the refactor

1. **Don't declare `async def` unless the function awaits something.** Several service modules still declare `async def` without awaiting -- these should be migrated to sync functions with `asyncio.to_thread()` at the route layer, or to real async SQLAlchemy. See the sketch after this list.
2. **CPU-bound work** (audio processing, numpy operations) goes through `asyncio.to_thread()`:
   ```python
   audio, sr = await asyncio.to_thread(load_audio, source_path)
   ```
3. **GPU-bound TTS inference** is serialized through the generation queue (`services/task_queue.py`). Never call a backend's `generate()` directly from a route handler.
4. **Fire-and-forget tasks**: use `asyncio.create_task()` and track the task reference to prevent garbage collection:
   ```python
   task = asyncio.create_task(some_coro())
   _background_tasks.add(task)
   task.add_done_callback(_background_tasks.discard)
   ```
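A sketch of the rule-1 migration (the service function and route are illustrative):

```python
# Before: "async" in name only -- nothing is awaited, and the blocking
# DB query stalls the event loop.
async def get_profile(profile_id: str, db: Session):
    return db.query(DBVoiceProfile).get(profile_id)


# After: a plain sync function in services/ ...
def get_profile(profile_id: str, db: Session):
    return db.query(DBVoiceProfile).get(profile_id)


# ... moved off the event loop at the route layer.
@router.get("/profiles/{profile_id}")
async def read_profile(profile_id: str, db: Session = Depends(get_db)):
    return await asyncio.to_thread(profiles.get_profile, profile_id, db)
```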
---

## Logging

Use the `logging` module. Not `print()`.

```python
import logging

logger = logging.getLogger(__name__)

logger.info("Loading model %s on %s", model_name, device)
logger.warning("Cache miss for %s, downloading", repo_id)
logger.exception("Generation %s failed", generation_id)  # logs traceback automatically
```

**Rules:**
- Use `%s`-style placeholders in log calls (not f-strings). This avoids formatting the string if the log level is filtered out.
- Use `logger.exception()` inside `except` blocks -- it captures the traceback.
- Logger name should be `__name__` (yields `backend.utils.audio`, etc.).
- Existing `print()` calls should be migrated to logging as files are touched during the refactor.

---

## Constants

- Define at **module level** in the file where they're primarily used.
- Use `UPPER_SNAKE_CASE`.
- Shared/cross-cutting constants (sample rates, file size limits, CORS origins) go in `backend/config.py` after Phase 6 consolidation.
- Magic numbers in function bodies should be extracted to named constants:
  ```python
  # No
  if len(audio) > 24000 * 60 * 10:

  # Yes
  MAX_AUDIO_DURATION_SAMPLES = SAMPLE_RATE * 60 * 10
  if len(audio) > MAX_AUDIO_DURATION_SAMPLES:
  ```

---

## Function Signatures

- **Keyword-only arguments** (after `*`) for functions with 3+ parameters, especially when several share the same type:
  ```python
  def is_model_cached(
      hf_repo: str,
      *,
      weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
      required_files: list[str] | None = None,
  ) -> bool:
  ```
- Parameters on **separate lines** when the signature exceeds ~100 characters or has 3+ params.
- **Trailing comma** after the last parameter in multi-line signatures.
- Default values inline with the parameter.

---

## String Formatting

- **f-strings** for runtime string construction.
- **`%s`-style** for `logging` calls (lazy evaluation).
- **`.format()`**: avoid; f-strings are preferred.

---

## Testing

Framework: **pytest** with `pytest-asyncio`.

- Test files: `test_<module>.py` in `backend/tests/`.
- Use `conftest.py` for shared fixtures (db sessions, test client, mock backends).
- Group related tests in classes: `class TestProfileCRUD:`.
- Use `@pytest.mark.asyncio` for async tests.
- Use `@pytest.mark.parametrize` to reduce repetition.
- Manual integration scripts stay in `tests/` but are clearly marked (filename prefix `manual_` or documented in `tests/README.md`). A short sketch putting these conventions together follows this list.
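A minimal sketch combining these conventions (the module under test, its functions, and the `db_session` fixture are hypothetical):

```python
import pytest

from backend.services import profiles  # hypothetical import for illustration


class TestProfileCRUD:
    @pytest.mark.parametrize("name", ["Alice", "Bob"])
    def test_create_assigns_id(self, db_session, name):
        profile = profiles.create_profile_sync(db_session, name=name)  # hypothetical helper
        assert profile.id is not None

    @pytest.mark.asyncio
    async def test_get_missing_profile_raises(self, db_session):
        with pytest.raises(Exception):  # NotFoundError post-refactor
            await profiles.get_profile_async(db_session, "missing-id")  # hypothetical helper
```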
---

## Project Layout

```
backend/
  app.py       # FastAPI app factory, CORS, lifecycle events
  main.py      # Entry point (imports app, runs uvicorn)
  config.py    # Data directory paths
  models.py    # Pydantic request/response schemas
  server.py    # Tauri sidecar launcher, parent-pid watchdog
  routes/      # Thin HTTP handlers (validation, delegation, response formatting)
  services/    # Business logic, CRUD, orchestration
  backends/    # TTS/STT engine implementations
  database/    # ORM models, session management, migrations, seeds
  utils/       # Shared utilities (audio, effects, caching, progress)
  tests/       # pytest suite
```

---

## Ruff Adoption

`pyproject.toml` configures ruff for linting and formatting. Run:

```bash
# Lint (check)
ruff check backend/

# Lint (auto-fix)
ruff check backend/ --fix

# Format
ruff format backend/
```

Introduce ruff fixes file-by-file as you touch them. Don't run `--fix` across the entire codebase in one shot -- that creates unreviewable diffs.
3 backend/__init__.py Normal file
@@ -0,0 +1,3 @@
# Backend package

__version__ = "0.4.5"
281 backend/app.py Normal file
@@ -0,0 +1,281 @@
"""FastAPI application factory, middleware, and lifecycle events."""

import asyncio
import logging
import os
import sys
from pathlib import Path


class ColoredFormatter(logging.Formatter):
    """Custom formatter to add colors matching uvicorn's style."""

    COLORS = {
        "DEBUG": "\033[36m",  # Cyan
        "INFO": "\033[32m",  # Green
        "WARNING": "\033[33m",  # Yellow
        "ERROR": "\033[31m",  # Red
        "CRITICAL": "\033[35m",  # Magenta
    }
    RESET = "\033[0m"

    def format(self, record):
        log_color = self.COLORS.get(record.levelname, self.RESET)
        record.levelname = f"{log_color}{record.levelname}{self.RESET}"
        return super().format(record)


# Configure logging to match uvicorn's format with colors
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(ColoredFormatter("%(levelname)s: %(message)s"))
logging.basicConfig(
    level=logging.INFO,
    handlers=[handler],
)

logger = logging.getLogger(__name__)

# AMD GPU environment variables must be set before torch import
if not os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
    os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
if not os.environ.get("MIOPEN_LOG_LEVEL"):
    os.environ["MIOPEN_LOG_LEVEL"] = "4"

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from urllib.parse import quote

from . import __version__, config, database
from .services import tts, transcribe
from .database import get_db
from .utils.platform_detect import get_backend_type
from .utils.progress import get_progress_manager
from .services.task_queue import create_background_task, init_queue
from .routes import register_routers


def safe_content_disposition(disposition_type: str, filename: str) -> str:
    """Build a Content-Disposition header safe for non-ASCII filenames.

    Uses RFC 5987 ``filename*`` parameter so browsers can decode UTF-8
    filenames while the ``filename`` fallback stays ASCII-only.
    """
    ascii_name = "".join(c for c in filename if c.isascii() and (c.isalnum() or c in " -_.")).strip() or "download"
    utf8_name = quote(filename, safe="")
    return f"{disposition_type}; filename=\"{ascii_name}\"; filename*=UTF-8''{utf8_name}"


def create_app() -> FastAPI:
    """Create and configure the FastAPI application."""
    application = FastAPI(
        title="voicebox API",
        description="Production-quality Qwen3-TTS voice cloning API",
        version=__version__,
    )

    _configure_cors(application)
    register_routers(application)
    _register_lifecycle(application)
    _mount_frontend(application)

    return application


def _configure_cors(application: FastAPI) -> None:
    """Set up CORS middleware with local-first defaults."""
    default_origins = [
        "http://localhost:5173",  # Vite dev server
        "http://127.0.0.1:5173",
        "http://localhost:17493",
        "http://127.0.0.1:17493",
        "tauri://localhost",  # Tauri webview (macOS)
        "https://tauri.localhost",  # Tauri webview (Windows/Linux)
        "http://tauri.localhost",  # Tauri webview (Windows, some builds)
    ]
    env_origins = os.environ.get("VOICEBOX_CORS_ORIGINS", "")
    all_origins = default_origins + [o.strip() for o in env_origins.split(",") if o.strip()]

    application.add_middleware(
        CORSMiddleware,
        allow_origins=all_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )


def _mount_frontend(application: FastAPI) -> None:
    """Serve the built web frontend when present (Docker / web deployment).

    The Dockerfile copies the Vite build output to ``/app/frontend/``. When
    that directory exists we mount static assets and add a catch-all route so
    the React SPA handles client-side routing. In dev or API-only mode the
    directory is absent and this function is a no-op.
    """
    frontend_dir = Path(__file__).resolve().parent.parent / "frontend"
    if not frontend_dir.is_dir():
        return

    from fastapi.staticfiles import StaticFiles
    from fastapi.responses import FileResponse

    # Mount hashed assets (JS, CSS, images) that Vite places under /assets
    assets_dir = frontend_dir / "assets"
    if assets_dir.is_dir():
        application.mount(
            "/assets",
            StaticFiles(directory=str(assets_dir)),
            name="frontend-assets",
        )

    # SPA catch-all: serve files if they exist, otherwise index.html for
    # client-side routes like /voices, /stories, /models, etc.
    @application.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        file_path = (frontend_dir / full_path).resolve()
        # Guard against path traversal -- only serve files inside frontend_dir
        if full_path and file_path.is_file() and file_path.is_relative_to(frontend_dir):
            return FileResponse(file_path)
        return FileResponse(frontend_dir / "index.html", media_type="text/html")

    logger.info("Frontend: serving SPA from %s", frontend_dir)


def _get_gpu_status() -> str:
    """Return a human-readable string describing GPU availability."""
    backend_type = get_backend_type()
    if torch.cuda.is_available():
        from .backends.base import check_cuda_compatibility

        device_name = torch.cuda.get_device_name(0)
        compatible, _warning = check_cuda_compatibility()
        is_rocm = hasattr(torch.version, "hip") and torch.version.hip is not None
        if is_rocm:
            label = f"ROCm ({device_name})"
        else:
            label = f"CUDA ({device_name})"
        if not compatible:
            label += " [UNSUPPORTED - see logs]"
        return label
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "MPS (Apple Silicon)"
    elif backend_type == "mlx":
        return "Metal (Apple Silicon via MLX)"

    # Intel XPU (Arc / Data Center) via IPEX
    try:
        import intel_extension_for_pytorch  # noqa: F401 -- side-effect import enables XPU

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            try:
                xpu_name = torch.xpu.get_device_name(0)
            except Exception:
                xpu_name = "Intel GPU"
            return f"XPU ({xpu_name})"
    except ImportError:
        pass

    return "None (CPU only)"


def _register_lifecycle(application: FastAPI) -> None:
    """Attach startup and shutdown event handlers."""

    @application.on_event("startup")
    async def startup_event():
        import platform
        import sys

        logger.info("Voicebox v%s starting up", __version__)
        logger.info(
            "Python %s on %s %s (%s)",
            sys.version.split()[0],
            platform.system(),
            platform.release(),
            platform.machine(),
        )

        database.init_db()

        from .database.session import _db_path

        logger.info("Database: %s", _db_path)
        logger.info("Data directory: %s", config.get_data_dir())

        init_queue()

        # Mark stale "generating" records as failed -- leftovers from a killed process
        from sqlalchemy import text as sa_text

        db = next(get_db())
        try:
            result = db.execute(
                sa_text(
                    "UPDATE generations SET status = 'failed', "
                    "error = 'Server was shut down during generation' "
                    "WHERE status IN ('generating', 'loading_model')"
                )
            )
            if result.rowcount > 0:
                logger.info("Marked %d stale generation(s) as failed", result.rowcount)

            from .database import VoiceProfile as DBVoiceProfile, Generation as DBGeneration

            profile_count = db.query(DBVoiceProfile).count()
            generation_count = db.query(DBGeneration).count()
            logger.info("Profiles: %d, Generations: %d", profile_count, generation_count)

            db.commit()
        except Exception as e:
            db.rollback()
            logger.warning("Could not clean up stale generations: %s", e)
        finally:
            db.close()

        backend_type = get_backend_type()
        logger.info("Backend: %s", backend_type.upper())
        logger.info("GPU: %s", _get_gpu_status())

        # Warn if GPU architecture is not supported by this PyTorch build
        from .backends.base import check_cuda_compatibility

        _compatible, _cuda_warning = check_cuda_compatibility()
        if not _compatible:
            logger.warning("GPU COMPATIBILITY: %s", _cuda_warning)

        from .services.cuda import check_and_update_cuda_binary

        create_background_task(check_and_update_cuda_binary())

        try:
            progress_manager = get_progress_manager()
            progress_manager._set_main_loop(asyncio.get_running_loop())
        except Exception as e:
            logger.warning("Could not initialize progress manager event loop: %s", e)

        try:
            from huggingface_hub import constants as hf_constants

            cache_dir = Path(hf_constants.HF_HUB_CACHE)
            cache_dir.mkdir(parents=True, exist_ok=True)
            logger.info("Model cache: %s", cache_dir)
        except Exception as e:
            logger.warning("Could not create HuggingFace cache directory: %s", e)

        logger.info("Ready")

    @application.on_event("shutdown")
    async def shutdown_event():
        logger.info("Voicebox server shutting down...")
        try:
            tts.unload_tts_model()
        except Exception:
            logger.exception("Failed to unload TTS model")
        try:
            transcribe.unload_whisper_model()
        except Exception:
            logger.exception("Failed to unload Whisper model")


app = create_app()
621 backend/backends/__init__.py Normal file
@@ -0,0 +1,621 @@
"""
Backend abstraction layer for TTS and STT.

Provides a unified interface for MLX and PyTorch backends,
and a model config registry that eliminates per-engine dispatch maps.
"""

# Install HF compatibility patches before any backend imports transformers /
# huggingface_hub. The module runs ``patch_transformers_mistral_regex`` at
# import time, which wraps transformers' tokenizer load against the
# unconditional HuggingFace metadata call that otherwise raises on
# HF_HUB_OFFLINE=1 and on network failures.
from ..utils import hf_offline_patch  # noqa: F401 -- side-effect import installs patches

import threading
from dataclasses import dataclass, field
from typing import Protocol, Optional, Tuple, List, runtime_checkable

import numpy as np

from ..utils.platform_detect import get_backend_type

LANGUAGE_CODE_TO_NAME = {
    "zh": "chinese",
    "en": "english",
    "ja": "japanese",
    "ko": "korean",
    "de": "german",
    "fr": "french",
    "ru": "russian",
    "pt": "portuguese",
    "es": "spanish",
    "it": "italian",
}

WHISPER_HF_REPOS = {
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
    "medium": "openai/whisper-medium",
    "large": "openai/whisper-large-v3",
    "turbo": "openai/whisper-large-v3-turbo",
}


@dataclass
class ModelConfig:
    """Declarative config for a downloadable model variant."""

    model_name: str  # e.g. "luxtts", "chatterbox-tts"
    display_name: str  # e.g. "LuxTTS (Fast, CPU-friendly)"
    engine: str  # e.g. "luxtts", "chatterbox"
    hf_repo_id: str  # e.g. "YatharthS/LuxTTS"
    model_size: str = "default"
    size_mb: int = 0
    needs_trim: bool = False
    supports_instruct: bool = False
    languages: list[str] = field(default_factory=lambda: ["en"])


@runtime_checkable
class TTSBackend(Protocol):
    """Protocol for TTS backend implementations."""

    # Each backend class should define MODEL_CONFIGS as a class variable:
    # MODEL_CONFIGS: list[ModelConfig]

    async def load_model(self, model_size: str) -> None:
        """Load TTS model."""
        ...

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        ...

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        """
        Combine multiple voice prompts.

        Returns:
            Tuple of (combined_audio_array, combined_text)
        """
        ...

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text.

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        ...

    def unload_model(self) -> None:
        """Unload model to free memory."""
        ...

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        ...

    def _get_model_path(self, model_size: str) -> str:
        """
        Get model path for a given size.

        Returns:
            Model path or HuggingFace Hub ID
        """
        ...


@runtime_checkable
class STTBackend(Protocol):
    """Protocol for STT (Speech-to-Text) backend implementations."""

    async def load_model(self, model_size: str) -> None:
        """Load STT model."""
        ...

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Returns:
            Transcribed text
        """
        ...

    def unload_model(self) -> None:
        """Unload model to free memory."""
        ...

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        ...


# Global backend instances
_tts_backend: Optional[TTSBackend] = None
_tts_backends: dict[str, TTSBackend] = {}
_tts_backends_lock = threading.Lock()
_stt_backend: Optional[STTBackend] = None

# Supported TTS engines -- keyed by engine name, value is the human-readable display name.
# The factory function uses this for the if/elif chain; the model configs live on the backend classes.
TTS_ENGINES = {
    "qwen": "Qwen TTS",
    "qwen_custom_voice": "Qwen CustomVoice",
    "luxtts": "LuxTTS",
    "chatterbox": "Chatterbox TTS",
    "chatterbox_turbo": "Chatterbox Turbo",
    "tada": "TADA",
    "kokoro": "Kokoro",
}


def _get_qwen_model_configs() -> list[ModelConfig]:
    """Return Qwen model configs with backend-aware HF repo IDs."""
    backend_type = get_backend_type()
    if backend_type == "mlx":
        repo_1_7b = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
        repo_0_6b = "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16"
    else:
        repo_1_7b = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
        repo_0_6b = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"

    return [
        ModelConfig(
            model_name="qwen-tts-1.7B",
            display_name="Qwen TTS 1.7B",
            engine="qwen",
            hf_repo_id=repo_1_7b,
            model_size="1.7B",
            size_mb=3500,
            supports_instruct=False,  # Base model drops instruct silently
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
        ModelConfig(
            model_name="qwen-tts-0.6B",
            display_name="Qwen TTS 0.6B",
            engine="qwen",
            hf_repo_id=repo_0_6b,
            model_size="0.6B",
            size_mb=1200,
            supports_instruct=False,
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
    ]


def _get_qwen_custom_voice_configs() -> list[ModelConfig]:
    """Return Qwen CustomVoice model configs."""
    return [
        ModelConfig(
            model_name="qwen-custom-voice-1.7B",
            display_name="Qwen CustomVoice 1.7B",
            engine="qwen_custom_voice",
            hf_repo_id="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
            model_size="1.7B",
            size_mb=3500,
            supports_instruct=True,
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
        ModelConfig(
            model_name="qwen-custom-voice-0.6B",
            display_name="Qwen CustomVoice 0.6B",
            engine="qwen_custom_voice",
            hf_repo_id="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
            model_size="0.6B",
            size_mb=1200,
            supports_instruct=True,
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
    ]


def _get_non_qwen_tts_configs() -> list[ModelConfig]:
    """Return model configs for non-Qwen TTS engines.

    These are static -- no backend-type branching needed.
    """
    return [
        ModelConfig(
            model_name="luxtts",
            display_name="LuxTTS (Fast, CPU-friendly)",
            engine="luxtts",
            hf_repo_id="YatharthS/LuxTTS",
            size_mb=300,
            languages=["en"],
        ),
        ModelConfig(
            model_name="chatterbox-tts",
            display_name="Chatterbox TTS (Multilingual)",
            engine="chatterbox",
            hf_repo_id="ResembleAI/chatterbox",
            size_mb=3200,
            needs_trim=True,
            languages=[
                "zh",
                "en",
                "ja",
                "ko",
                "de",
                "fr",
                "ru",
                "pt",
                "es",
                "it",
                "he",
                "ar",
                "da",
                "el",
                "fi",
                "hi",
                "ms",
                "nl",
                "no",
                "pl",
                "sv",
                "sw",
                "tr",
            ],
        ),
        ModelConfig(
            model_name="chatterbox-turbo",
            display_name="Chatterbox Turbo (English, Tags)",
            engine="chatterbox_turbo",
            hf_repo_id="ResembleAI/chatterbox-turbo",
            size_mb=1500,
            needs_trim=True,
            languages=["en"],
        ),
        ModelConfig(
            model_name="tada-1b",
            display_name="TADA 1B (English)",
            engine="tada",
            hf_repo_id="HumeAI/tada-1b",
            model_size="1B",
            size_mb=4000,
            languages=["en"],
        ),
        ModelConfig(
            model_name="tada-3b-ml",
            display_name="TADA 3B Multilingual",
            engine="tada",
            hf_repo_id="HumeAI/tada-3b-ml",
            model_size="3B",
            size_mb=8000,
            languages=["en", "ar", "zh", "de", "es", "fr", "it", "ja", "pl", "pt"],
        ),
        ModelConfig(
            model_name="kokoro",
            display_name="Kokoro 82M",
            engine="kokoro",
            hf_repo_id="hexgrad/Kokoro-82M",
            size_mb=350,
            languages=["en", "es", "fr", "hi", "it", "pt", "ja", "zh"],
        ),
    ]


def _get_whisper_configs() -> list[ModelConfig]:
    """Return Whisper STT model configs."""
    return [
        ModelConfig(
            model_name="whisper-base",
            display_name="Whisper Base",
            engine="whisper",
            hf_repo_id="openai/whisper-base",
            model_size="base",
        ),
        ModelConfig(
            model_name="whisper-small",
            display_name="Whisper Small",
            engine="whisper",
            hf_repo_id="openai/whisper-small",
            model_size="small",
        ),
        ModelConfig(
            model_name="whisper-medium",
            display_name="Whisper Medium",
            engine="whisper",
            hf_repo_id="openai/whisper-medium",
            model_size="medium",
        ),
        ModelConfig(
            model_name="whisper-large",
            display_name="Whisper Large",
            engine="whisper",
            hf_repo_id="openai/whisper-large-v3",
            model_size="large",
        ),
        ModelConfig(
            model_name="whisper-turbo",
            display_name="Whisper Turbo",
            engine="whisper",
            hf_repo_id="openai/whisper-large-v3-turbo",
            model_size="turbo",
        ),
    ]


def get_all_model_configs() -> list[ModelConfig]:
    """Return the full list of model configs (TTS + STT)."""
    return (
        _get_qwen_model_configs()
        + _get_qwen_custom_voice_configs()
        + _get_non_qwen_tts_configs()
        + _get_whisper_configs()
    )


def get_tts_model_configs() -> list[ModelConfig]:
    """Return only TTS model configs."""
    return _get_qwen_model_configs() + _get_qwen_custom_voice_configs() + _get_non_qwen_tts_configs()


# Lookup helpers -- these replace the if/elif chains in main.py


def get_model_config(model_name: str) -> Optional[ModelConfig]:
    """Look up a model config by model_name."""
    for cfg in get_all_model_configs():
        if cfg.model_name == model_name:
            return cfg
    return None


def engine_needs_trim(engine: str) -> bool:
    """Whether this engine's output should be run through trim_tts_output."""
    for cfg in get_tts_model_configs():
        if cfg.engine == engine:
            return cfg.needs_trim
    return False


def engine_has_model_sizes(engine: str) -> bool:
    """Whether this engine supports multiple model sizes (only Qwen currently)."""
    configs = [c for c in get_tts_model_configs() if c.engine == engine]
    return len(configs) > 1


async def load_engine_model(engine: str, model_size: str = "default") -> None:
    """Load a model for the given engine, handling engines with multiple model sizes."""
    backend = get_tts_backend_for_engine(engine)
    if engine in ("qwen", "qwen_custom_voice"):
        await backend.load_model_async(model_size)
    elif engine == "tada":
        await backend.load_model(model_size)
    else:
        await backend.load_model()


async def ensure_model_cached_or_raise(engine: str, model_size: str = "default") -> None:
    """Check if a model is cached, raise HTTPException if not. Used by streaming endpoint."""
    from fastapi import HTTPException

    backend = get_tts_backend_for_engine(engine)
    cfg = None
    for c in get_tts_model_configs():
        if c.engine == engine and c.model_size == model_size:
            cfg = c
            break

    if engine in ("qwen", "qwen_custom_voice", "tada"):
        if not backend._is_model_cached(model_size):
            raise HTTPException(
                status_code=400,
                detail=f"Model {model_size} is not downloaded yet. Use /generate to trigger a download.",
            )
    else:
        if not backend._is_model_cached():
            display = cfg.display_name if cfg else engine
            raise HTTPException(
                status_code=400,
                detail=f"{display} model is not downloaded yet. Use /generate to trigger a download.",
            )


def unload_model_by_config(config: ModelConfig) -> bool:
    """Unload a model given its config. Returns True if it was loaded, False otherwise."""
    from . import get_tts_backend_for_engine
    from ..services import tts, transcribe

    if config.engine == "whisper":
        whisper_model = transcribe.get_whisper_model()
        if whisper_model.is_loaded() and whisper_model.model_size == config.model_size:
            transcribe.unload_whisper_model()
            return True
        return False

    if config.engine == "qwen":
        tts_model = tts.get_tts_model()
        loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
        if tts_model.is_loaded() and loaded_size == config.model_size:
            tts.unload_tts_model()
            return True
        return False

    if config.engine == "qwen_custom_voice":
        backend = get_tts_backend_for_engine(config.engine)
        loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
        if backend.is_loaded() and loaded_size == config.model_size:
            backend.unload_model()
            return True
        return False

    # All other TTS engines
    backend = get_tts_backend_for_engine(config.engine)
    if backend.is_loaded():
        backend.unload_model()
        return True
    return False


def check_model_loaded(config: ModelConfig) -> bool:
    """Check if a model is currently loaded."""
    from . import get_tts_backend_for_engine
    from ..services import tts, transcribe

    try:
        if config.engine == "whisper":
            whisper_model = transcribe.get_whisper_model()
            return whisper_model.is_loaded() and getattr(whisper_model, "model_size", None) == config.model_size

        if config.engine == "qwen":
            tts_model = tts.get_tts_model()
            loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
            return tts_model.is_loaded() and loaded_size == config.model_size

        if config.engine == "qwen_custom_voice":
            backend = get_tts_backend_for_engine(config.engine)
            loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
            return backend.is_loaded() and loaded_size == config.model_size

        backend = get_tts_backend_for_engine(config.engine)
        return backend.is_loaded()
    except Exception:
        return False


def get_model_load_func(config: ModelConfig):
    """Return a callable that loads/downloads the model."""
    from . import get_tts_backend_for_engine
    from ..services import tts, transcribe

    if config.engine == "whisper":
        return lambda: transcribe.get_whisper_model().load_model(config.model_size)

    if config.engine == "qwen":
        return lambda: tts.get_tts_model().load_model(config.model_size)

    if config.engine == "qwen_custom_voice":
        return lambda: get_tts_backend_for_engine(config.engine).load_model(config.model_size)

    return lambda: get_tts_backend_for_engine(config.engine).load_model()


def get_tts_backend() -> TTSBackend:
    """
    Get or create the default (Qwen) TTS backend instance based on platform.

    Returns:
        TTS backend instance (MLX or PyTorch)
    """
    return get_tts_backend_for_engine("qwen")


def get_tts_backend_for_engine(engine: str) -> TTSBackend:
    """
    Get or create a TTS backend for the given engine.

    Args:
        engine: Engine name (e.g. "qwen", "luxtts", "chatterbox", "chatterbox_turbo")

    Returns:
        TTS backend instance
    """
    global _tts_backends

    # Fast path: check without lock
    if engine in _tts_backends:
        return _tts_backends[engine]

    # Slow path: create with lock to avoid duplicate instantiation
    with _tts_backends_lock:
        # Double-check after acquiring lock
        if engine in _tts_backends:
            return _tts_backends[engine]

        if engine == "qwen":
            backend_type = get_backend_type()
            if backend_type == "mlx":
                from .mlx_backend import MLXTTSBackend

                backend = MLXTTSBackend()
            else:
                from .pytorch_backend import PyTorchTTSBackend

                backend = PyTorchTTSBackend()
        elif engine == "luxtts":
            from .luxtts_backend import LuxTTSBackend

            backend = LuxTTSBackend()
        elif engine == "chatterbox":
            from .chatterbox_backend import ChatterboxTTSBackend

            backend = ChatterboxTTSBackend()
        elif engine == "chatterbox_turbo":
            from .chatterbox_turbo_backend import ChatterboxTurboTTSBackend

            backend = ChatterboxTurboTTSBackend()
        elif engine == "tada":
            from .hume_backend import HumeTadaBackend

            backend = HumeTadaBackend()
        elif engine == "kokoro":
            from .kokoro_backend import KokoroTTSBackend

            backend = KokoroTTSBackend()
        elif engine == "qwen_custom_voice":
            from .qwen_custom_voice_backend import QwenCustomVoiceBackend

            backend = QwenCustomVoiceBackend()
        else:
            raise ValueError(f"Unknown TTS engine: {engine}. Supported: {list(TTS_ENGINES.keys())}")

        _tts_backends[engine] = backend
        return backend


def get_stt_backend() -> STTBackend:
    """
    Get or create STT backend instance based on platform.

    Returns:
        STT backend instance (MLX or PyTorch)
    """
    global _stt_backend

    if _stt_backend is None:
        backend_type = get_backend_type()

        if backend_type == "mlx":
            from .mlx_backend import MLXSTTBackend

            _stt_backend = MLXSTTBackend()
        else:
            from .pytorch_backend import PyTorchSTTBackend

            _stt_backend = PyTorchSTTBackend()

    return _stt_backend


def reset_backends():
    """Reset backend instances (useful for testing)."""
    global _tts_backend, _tts_backends, _stt_backend
    _tts_backend = None
    _tts_backends.clear()
    _stt_backend = None
327 backend/backends/base.py Normal file
@@ -0,0 +1,327 @@
"""
Shared utilities for TTS/STT backend implementations.

Eliminates duplication of cache checking, device detection,
voice prompt combination, and model loading progress tracking.
"""

import logging
import platform
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import numpy as np

from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager

logger = logging.getLogger(__name__)


def is_model_cached(
    hf_repo: str,
    *,
    weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
    required_files: Optional[list[str]] = None,
) -> bool:
    """
    Check if a HuggingFace model is fully cached locally.

    Args:
        hf_repo: HuggingFace repo ID (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-Base")
        weight_extensions: File extensions that count as model weights.
        required_files: If set, check that these specific filenames exist
            in snapshots instead of checking by extension.

    Returns:
        True if model is fully cached, False if missing or incomplete.
    """
    try:
        from huggingface_hub import constants as hf_constants

        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))

        if not repo_cache.exists():
            return False

        # Incomplete blobs mean a download is still in progress
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            logger.debug("Found .incomplete files for %s", hf_repo)
            return False

        snapshots_dir = repo_cache / "snapshots"
        if not snapshots_dir.exists():
            return False

        if required_files:
            # Check that every required filename exists somewhere in snapshots
            for fname in required_files:
                if not any(snapshots_dir.rglob(fname)):
                    return False
            return True

        # Check that at least one weight file exists
        for ext in weight_extensions:
            if any(snapshots_dir.rglob(f"*{ext}")):
                return True

        logger.debug("No model weights found for %s", hf_repo)
        return False

    except Exception as e:
        logger.warning("Error checking cache for %s: %s", hf_repo, e)
        return False


def get_torch_device(
    *,
    allow_xpu: bool = False,
    allow_directml: bool = False,
    allow_mps: bool = False,
    force_cpu_on_mac: bool = False,
) -> str:
    """
    Detect the best available torch device.

    Args:
        allow_xpu: Check for Intel XPU (IPEX) support.
        allow_directml: Check for DirectML (Windows) support.
        allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU.
        force_cpu_on_mac: Force CPU on macOS regardless of GPU availability.
    """
    if force_cpu_on_mac and platform.system() == "Darwin":
        return "cpu"

    import torch  # lazy: heavy import

    if torch.cuda.is_available():
        return "cuda"

    if allow_xpu:
        try:
            import intel_extension_for_pytorch  # noqa: F401 -- side-effect import enables XPU

            if hasattr(torch, "xpu") and torch.xpu.is_available():
                return "xpu"
        except ImportError:
            pass

    if allow_directml:
        try:
            import torch_directml

            if torch_directml.device_count() > 0:
                # Unlike the other branches, this returns a torch.device
                # object rather than a string.
                return torch_directml.device(0)
        except ImportError:
            pass

    if allow_mps:
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

    return "cpu"


def check_cuda_compatibility() -> tuple[bool, str | None]:
    """Check if the installed PyTorch supports the current GPU's compute capability.

    Returns:
        (compatible, warning_message) -- compatible is True if OK or no CUDA GPU,
        warning_message is a human-readable string if there's a problem.
    """
    import torch

    if not torch.cuda.is_available():
        return True, None

    major, minor = torch.cuda.get_device_capability(0)
    capability = f"{major}.{minor}"
    device_name = torch.cuda.get_device_name(0)
    sm_tag = f"sm_{major}{minor}"

    # torch.cuda._get_arch_list() returns the SM architectures this build
    # was compiled for (e.g. ["sm_50", "sm_60", ..., "sm_90"]).
    try:
        arch_list = torch.cuda._get_arch_list()
        if arch_list:
            # Check for both sm_XX and compute_XX (JIT-compiled) entries
            compute_tag = f"compute_{major}{minor}"
            if sm_tag not in arch_list and compute_tag not in arch_list:
                return False, (
                    f"{device_name} (compute capability {capability} / {sm_tag}) "
                    f"is not supported by this PyTorch build. "
                    f"Supported architectures: {', '.join(arch_list)}. "
                    f"Install PyTorch nightly (cu128) for newer GPU support: "
                    f"pip install torch --index-url https://download.pytorch.org/whl/nightly/cu128"
                )
    except AttributeError:
        pass

    return True, None


def empty_device_cache(device: str) -> None:
    """
    Free cached memory on the given device (CUDA or XPU).

    Backends should call this after unloading models so VRAM is returned
    to the OS.
    """
    import torch

    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.empty_cache()


def manual_seed(seed: int, device: str) -> None:
    """
    Set the random seed on both CPU and the active accelerator.

    Covers CUDA and Intel XPU so that generation is reproducible
    regardless of which GPU backend is in use.
    """
    import torch

    torch.manual_seed(seed)
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.manual_seed(seed)


async def combine_voice_prompts(
    audio_paths: List[str],
    reference_texts: List[str],
    *,
    sample_rate: Optional[int] = None,
) -> Tuple[np.ndarray, str]:
    """
    Combine multiple reference audio samples into one.

    Loads each audio file, normalizes, concatenates, and joins texts.

    Args:
        audio_paths: Paths to reference audio files.
        reference_texts: Corresponding transcripts.
        sample_rate: If set, resample audio to this rate during loading.
    """
    combined_audio = []

    for path in audio_paths:
        kwargs = {"sample_rate": sample_rate} if sample_rate else {}
        audio, _sr = load_audio(path, **kwargs)
        audio = normalize_audio(audio)
        combined_audio.append(audio)

    mixed = np.concatenate(combined_audio)
    mixed = normalize_audio(mixed)
|
||||
combined_text = " ".join(reference_texts)
|
||||
|
||||
return mixed, combined_text
|
||||
|
||||
|
||||
@contextmanager
|
||||
def model_load_progress(
|
||||
model_name: str,
|
||||
is_cached: bool,
|
||||
filter_non_downloads: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Context manager for model loading with HF download progress tracking.
|
||||
|
||||
Handles the tqdm patching, progress_manager/task_manager lifecycle,
|
||||
and error reporting that every backend duplicates.
|
||||
|
||||
Args:
|
||||
model_name: Progress tracking key (e.g. "qwen-tts-1.7B", "whisper-base").
|
||||
is_cached: Whether the model is already downloaded.
|
||||
filter_non_downloads: Whether to filter non-download tqdm bars.
|
||||
Defaults to `is_cached`.
|
||||
|
||||
Yields:
|
||||
The tracker context (already entered). The caller loads the model
|
||||
inside the `with` block. The tqdm patch is torn down on exit.
|
||||
|
||||
Usage:
|
||||
with model_load_progress("qwen-tts-1.7B", is_cached) as ctx:
|
||||
self.model = SomeModel.from_pretrained(...)
|
||||
"""
|
||||
if filter_non_downloads is None:
|
||||
filter_non_downloads = is_cached
|
||||
|
||||
progress_manager = get_progress_manager()
|
||||
task_manager = get_task_manager()
|
||||
|
||||
progress_callback = create_hf_progress_callback(model_name, progress_manager)
|
||||
tracker = HFProgressTracker(progress_callback, filter_non_downloads=filter_non_downloads)
|
||||
|
||||
tracker_context = tracker.patch_download()
|
||||
tracker_context.__enter__()
|
||||
|
||||
if not is_cached:
|
||||
task_manager.start_download(model_name)
|
||||
progress_manager.update_progress(
|
||||
model_name=model_name,
|
||||
current=0,
|
||||
total=0,
|
||||
filename="Connecting to HuggingFace...",
|
||||
status="downloading",
|
||||
)
|
||||
|
||||
try:
|
||||
yield tracker_context
|
||||
except Exception as e:
|
||||
# Report error to both managers
|
||||
progress_manager.mark_error(model_name, str(e))
|
||||
task_manager.error_download(model_name, str(e))
|
||||
raise
|
||||
else:
|
||||
# Only mark complete if we were tracking a download
|
||||
if not is_cached:
|
||||
progress_manager.mark_complete(model_name)
|
||||
task_manager.complete_download(model_name)
|
||||
finally:
|
||||
tracker_context.__exit__(None, None, None)
|
||||
|
||||
|
||||
def patch_chatterbox_f32(model) -> None:
|
||||
"""
|
||||
Patch float64 -> float32 dtype mismatches in upstream chatterbox.
|
||||
|
||||
librosa.load returns float64 numpy arrays. Multiple upstream code paths
|
||||
convert these to torch tensors via torch.from_numpy() without casting,
|
||||
then matmul against float32 model weights. This patches the two known
|
||||
entry points:
|
||||
|
||||
1. S3Tokenizer.log_mel_spectrogram — audio tensor hits _mel_filters (f32)
|
||||
2. VoiceEncoder.forward — float64 mel spectrograms hit LSTM weights (f32)
|
||||
"""
|
||||
import types
|
||||
|
||||
# Patch S3Tokenizer
|
||||
_tokzr = model.s3gen.tokenizer
|
||||
_orig_log_mel = _tokzr.log_mel_spectrogram.__func__
|
||||
|
||||
def _f32_log_mel(self_tokzr, audio, padding=0):
|
||||
import torch as _torch
|
||||
|
||||
if _torch.is_tensor(audio):
|
||||
audio = audio.float()
|
||||
return _orig_log_mel(self_tokzr, audio, padding)
|
||||
|
||||
_tokzr.log_mel_spectrogram = types.MethodType(_f32_log_mel, _tokzr)
|
||||
|
||||
# Patch VoiceEncoder
|
||||
_ve = model.ve
|
||||
_orig_ve_forward = _ve.forward.__func__
|
||||
|
||||
def _f32_ve_forward(self_ve, mels):
|
||||
return _orig_ve_forward(self_ve, mels.float())
|
||||
|
||||
_ve.forward = types.MethodType(_f32_ve_forward, _ve)
|
||||
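The helpers above are designed to be composed by every engine. As a hedged illustration (the `FooTTSBackend` class, `FooTTS` model, and repo ID below are placeholders, not part of this codebase, and the import path assumes the package is importable as `backend`), a new engine would typically wire them together like this:

```python
# Hypothetical sketch of a new engine built on the shared base helpers.
# FooTTSBackend, FOO_HF_REPO, and the commented-out FooTTS call are
# placeholders for illustration only.
import asyncio
import logging

from backend.backends.base import (
    get_torch_device,
    is_model_cached,
    model_load_progress,
)

logger = logging.getLogger(__name__)

FOO_HF_REPO = "example-org/foo-tts"  # placeholder repo ID


class FooTTSBackend:
    def __init__(self):
        self.model = None
        self._device = None

    async def load_model(self) -> None:
        # Blocking work runs in a thread so the event loop stays responsive.
        if self.model is None:
            await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self) -> None:
        # model_load_progress handles tqdm patching, download progress,
        # and error reporting, exactly as the concrete backends in this
        # commit do.
        with model_load_progress("foo-tts", is_model_cached(FOO_HF_REPO)):
            self._device = get_torch_device(allow_mps=True)
            logger.info(f"Loading FooTTS on {self._device}...")
            # self.model = FooTTS.from_pretrained(FOO_HF_REPO)  # placeholder
```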
226
backend/backends/chatterbox_backend.py
Normal file
@@ -0,0 +1,226 @@
"""
Chatterbox TTS backend implementation.

Wraps ChatterboxMultilingualTTS from chatterbox-tts for zero-shot
voice cloning. Supports 23 languages including Hebrew. Forces CPU
on macOS due to known MPS tensor issues.
"""

import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
    patch_chatterbox_f32,
)

logger = logging.getLogger(__name__)

CHATTERBOX_HF_REPO = "ResembleAI/chatterbox"

# Files that must be present for the multilingual model
_MTL_WEIGHT_FILES = [
    "t3_mtl23ls_v2.safetensors",
    "s3gen.pt",
    "ve.pt",
]


class ChatterboxTTSBackend:
    """Chatterbox Multilingual TTS backend for voice cloning."""

    # Class-level lock for torch.load monkey-patching
    _load_lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self):
        self.model = None
        self.model_size = "default"
        self._device = None
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str = "default") -> str:
        return CHATTERBOX_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        return is_model_cached(CHATTERBOX_HF_REPO, required_files=_MTL_WEIGHT_FILES)

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Chatterbox multilingual model."""
        if self.model is not None:
            return
        async with self._model_load_lock:
            if self.model is not None:
                return
            await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Synchronous model loading."""
        model_name = "chatterbox-tts"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            device = self._get_device()
            self._device = device
            logger.info(f"Loading Chatterbox Multilingual TTS on {device}...")

            import torch
            from chatterbox.mtl_tts import ChatterboxMultilingualTTS

            if device == "cpu":
                _orig_torch_load = torch.load

                def _patched_load(*args, **kwargs):
                    kwargs.setdefault("map_location", "cpu")
                    return _orig_torch_load(*args, **kwargs)

                with ChatterboxTTSBackend._load_lock:
                    torch.load = _patched_load
                    try:
                        model = ChatterboxMultilingualTTS.from_pretrained(device=device)
                    finally:
                        torch.load = _orig_torch_load
            else:
                model = ChatterboxMultilingualTTS.from_pretrained(device=device)

            # Fix sdpa attention for output_attentions support
            t3_tfmr = model.t3.tfmr
            if hasattr(t3_tfmr, "config") and hasattr(t3_tfmr.config, "_attn_implementation"):
                t3_tfmr.config._attn_implementation = "eager"
            for layer in getattr(t3_tfmr, "layers", []):
                if hasattr(layer, "self_attn"):
                    layer.self_attn._attn_implementation = "eager"

            patch_chatterbox_f32(model)
            self.model = model

            logger.info("Chatterbox Multilingual TTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            device = self._device
            del self.model
            self.model = None
            self._device = None
            empty_device_cache(device)
            logger.info("Chatterbox unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Chatterbox processes reference audio at generation time, so the
        prompt just stores the file path. The actual audio is loaded by
        model.generate() via audio_prompt_path.
        """
        voice_prompt = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }
        return voice_prompt, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    # Per-language generation defaults. Lower temp + higher cfg = clearer speech.
    _LANG_DEFAULTS: ClassVar[dict] = {
        "he": {
            "exaggeration": 0.4,
            "cfg_weight": 0.7,
            "temperature": 0.65,
            "repetition_penalty": 2.5,
        },
    }
    _GLOBAL_DEFAULTS: ClassVar[dict] = {
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "temperature": 0.8,
        "repetition_penalty": 2.0,
    }

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio using Chatterbox Multilingual TTS.

        Args:
            text: Text to synthesize
            voice_prompt: Dict with ref_audio path
            language: BCP-47 language code
            seed: Random seed for reproducibility
            instruct: Unused (protocol compatibility)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        ref_audio = voice_prompt.get("ref_audio")
        if ref_audio and not Path(ref_audio).exists():
            logger.warning(f"Reference audio not found: {ref_audio}")
            ref_audio = None

        # Select language-specific defaults, falling back to the global defaults
        lang_defaults = self._LANG_DEFAULTS.get(language, self._GLOBAL_DEFAULTS)

        def _generate_sync():
            import torch

            if seed is not None:
                manual_seed(seed, self._device)

            logger.info(f"[Chatterbox] Generating: lang={language}")

            wav = self.model.generate(
                text,
                language_id=language,
                audio_prompt_path=ref_audio,
                exaggeration=lang_defaults["exaggeration"],
                cfg_weight=lang_defaults["cfg_weight"],
                temperature=lang_defaults["temperature"],
                repetition_penalty=lang_defaults["repetition_penalty"],
            )

            # Convert tensor -> numpy
            if isinstance(wav, torch.Tensor):
                audio = wav.squeeze().cpu().numpy().astype(np.float32)
            else:
                audio = np.asarray(wav, dtype=np.float32)

            sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)

            return audio, sample_rate

        return await asyncio.to_thread(_generate_sync)
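For orientation, a minimal sketch of the intended call sequence against this backend; the file paths, transcript, and output name are placeholders, and the import path assumes the package is importable as `backend`:

```python
# Minimal usage sketch for ChatterboxTTSBackend; paths are placeholders.
import asyncio

import soundfile as sf

from backend.backends.chatterbox_backend import ChatterboxTTSBackend


async def main() -> None:
    backend = ChatterboxTTSBackend()
    # The prompt is just a path reference; audio is read at generation time.
    prompt, _was_cached = await backend.create_voice_prompt(
        audio_path="reference.wav",            # placeholder
        reference_text="A short transcript.",  # placeholder
    )
    audio, sr = await backend.generate(
        "Hello from Chatterbox.", prompt, language="en", seed=42
    )
    sf.write("out.wav", audio, sr)
    backend.unload_model()


asyncio.run(main())
```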
206
backend/backends/chatterbox_turbo_backend.py
Normal file
@@ -0,0 +1,206 @@
"""
Chatterbox Turbo TTS backend implementation.

Wraps ChatterboxTurboTTS from chatterbox-tts for fast, English-only
voice cloning with paralinguistic tag support ([laugh], [cough], etc.).
Forces CPU on macOS due to known MPS tensor issues.
"""

import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
    patch_chatterbox_f32,
)

logger = logging.getLogger(__name__)

CHATTERBOX_TURBO_HF_REPO = "ResembleAI/chatterbox-turbo"

# Files that must be present for the turbo model
_TURBO_WEIGHT_FILES = [
    "t3_turbo_v1.safetensors",
    "s3gen_meanflow.safetensors",
    "ve.safetensors",
]


class ChatterboxTurboTTSBackend:
    """Chatterbox Turbo TTS backend — fast, English-only, with paralinguistic tags."""

    # Class-level lock for torch.load monkey-patching
    _load_lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self):
        self.model = None
        self.model_size = "default"
        self._device = None
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str = "default") -> str:
        return CHATTERBOX_TURBO_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        return is_model_cached(CHATTERBOX_TURBO_HF_REPO, required_files=_TURBO_WEIGHT_FILES)

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Chatterbox Turbo model."""
        if self.model is not None:
            return
        async with self._model_load_lock:
            if self.model is not None:
                return
            await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Synchronous model loading."""
        model_name = "chatterbox-turbo"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            device = self._get_device()
            self._device = device
            logger.info(f"Loading Chatterbox Turbo TTS on {device}...")

            import torch
            from huggingface_hub import snapshot_download
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            local_path = snapshot_download(
                repo_id=CHATTERBOX_TURBO_HF_REPO,
                token=None,
                allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
            )

            if device == "cpu":
                _orig_torch_load = torch.load

                def _patched_load(*args, **kwargs):
                    kwargs.setdefault("map_location", "cpu")
                    return _orig_torch_load(*args, **kwargs)

                with ChatterboxTurboTTSBackend._load_lock:
                    torch.load = _patched_load
                    try:
                        model = ChatterboxTurboTTS.from_local(local_path, device)
                    finally:
                        torch.load = _orig_torch_load
            else:
                model = ChatterboxTurboTTS.from_local(local_path, device)

            patch_chatterbox_f32(model)
            self.model = model

            logger.info("Chatterbox Turbo TTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            device = self._device
            del self.model
            self.model = None
            self._device = None
            empty_device_cache(device)
            logger.info("Chatterbox Turbo unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Chatterbox Turbo processes reference audio at generation time, so the
        prompt just stores the file path.
        """
        voice_prompt = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }
        return voice_prompt, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio using Chatterbox Turbo TTS.

        Supports paralinguistic tags in text: [laugh], [cough], [chuckle], etc.

        Args:
            text: Text to synthesize (may include paralinguistic tags)
            voice_prompt: Dict with ref_audio path
            language: Ignored (Turbo is English-only)
            seed: Random seed for reproducibility
            instruct: Unused (protocol compatibility)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        ref_audio = voice_prompt.get("ref_audio")
        if ref_audio and not Path(ref_audio).exists():
            logger.warning(f"Reference audio not found: {ref_audio}")
            ref_audio = None

        def _generate_sync():
            import torch

            if seed is not None:
                manual_seed(seed, self._device)

            logger.info("[Chatterbox Turbo] Generating (English)")

            wav = self.model.generate(
                text,
                audio_prompt_path=ref_audio,
                temperature=0.8,
                top_k=1000,
                top_p=0.95,
                repetition_penalty=1.2,
            )

            # Convert tensor -> numpy
            if isinstance(wav, torch.Tensor):
                audio = wav.squeeze().cpu().numpy().astype(np.float32)
            else:
                audio = np.asarray(wav, dtype=np.float32)

            sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)

            return audio, sample_rate

        return await asyncio.to_thread(_generate_sync)
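Since Turbo reads paralinguistic tags inline, a hedged sketch of tagged input (the reference path and transcript are placeholders, and the import path assumes the package is importable as `backend`):

```python
# Sketch: paralinguistic tags are embedded directly in the text.
import asyncio

from backend.backends.chatterbox_turbo_backend import ChatterboxTurboTTSBackend


async def demo():
    backend = ChatterboxTurboTTSBackend()
    prompt, _ = await backend.create_voice_prompt("reference.wav", "A transcript.")
    # Tags like [laugh] render as non-speech vocalizations in the output.
    audio, sr = await backend.generate(
        "That was unexpected [laugh], but let's keep going.", prompt, seed=7
    )
    return audio, sr


audio, sr = asyncio.run(demo())
```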
346
backend/backends/hume_backend.py
Normal file
@@ -0,0 +1,346 @@
"""
HumeAI TADA TTS backend implementation.

Wraps HumeAI's TADA (Text-Acoustic Dual Alignment) model for
high-quality voice cloning. Two model variants:
- tada-1b: English-only, ~2B params (Llama 3.2 1B base)
- tada-3b-ml: Multilingual, ~4B params (Llama 3.2 3B base)

Both use a shared encoder/codec (HumeAI/tada-codec). The encoder
produces 1:1 aligned token embeddings from reference audio, and the
causal LM generates speech via flow-matching diffusion.

24kHz output, bf16 inference on CUDA, fp32 on CPU.
"""

import asyncio
import logging
import threading
from typing import ClassVar, List, Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt

logger = logging.getLogger(__name__)

# HuggingFace repos
TADA_CODEC_REPO = "HumeAI/tada-codec"
TADA_1B_REPO = "HumeAI/tada-1b"
TADA_3B_ML_REPO = "HumeAI/tada-3b-ml"

TADA_MODEL_REPOS = {
    "1B": TADA_1B_REPO,
    "3B": TADA_3B_ML_REPO,
}

# Key weight files for cache detection
_TADA_MODEL_WEIGHT_FILES = [
    "model.safetensors",
]

_TADA_CODEC_WEIGHT_FILES = [
    "encoder/model.safetensors",
]


class HumeTadaBackend:
    """HumeAI TADA TTS backend for high-quality voice cloning."""

    _load_lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self):
        self.model = None
        self.encoder = None
        self.model_size = "1B"  # default to 1B
        self._device = None
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        # Force CPU on macOS — MPS has issues with flow matching
        # and large vocab lm_head (>65536 output channels)
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str = "1B") -> str:
        return TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)

    def _is_model_cached(self, model_size: str = "1B") -> bool:
        repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
        model_cached = is_model_cached(repo, required_files=_TADA_MODEL_WEIGHT_FILES)
        codec_cached = is_model_cached(TADA_CODEC_REPO, required_files=_TADA_CODEC_WEIGHT_FILES)
        return model_cached and codec_cached

    async def load_model(self, model_size: str = "1B") -> None:
        """Load the TADA model and encoder."""
        if self.model is not None and self.model_size == model_size:
            return
        async with self._model_load_lock:
            if self.model is not None and self.model_size == model_size:
                return
            # Unload existing model if switching sizes
            if self.model is not None:
                self.unload_model()
            self.model_size = model_size
            await asyncio.to_thread(self._load_model_sync, model_size)

    def _load_model_sync(self, model_size: str = "1B"):
        """Synchronous model loading with progress tracking."""
        model_name = f"tada-{model_size.lower()}"
        is_cached = self._is_model_cached(model_size)
        repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)

        with model_load_progress(model_name, is_cached):
            # Install DAC shim before importing tada — tada's encoder/decoder
            # import dac.nn.layers.Snake1d which requires the descript-audio-codec
            # package. The real package pulls in onnx/tensorboard/matplotlib via
            # descript-audiotools, so we use a lightweight shim instead.
            from ..utils.dac_shim import install_dac_shim

            install_dac_shim()

            import torch
            from huggingface_hub import snapshot_download

            device = self._get_device()
            self._device = device
            logger.info(f"Loading HumeAI TADA {model_size} on {device}...")

            # Download codec (encoder + decoder) if not cached
            logger.info("Downloading TADA codec...")
            snapshot_download(
                repo_id=TADA_CODEC_REPO,
                token=None,
                allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin"],
            )

            # Download model weights if not cached
            logger.info(f"Downloading TADA {model_size} model...")
            snapshot_download(
                repo_id=repo,
                token=None,
                allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin", "*.model"],
            )

            # TADA hardcodes "meta-llama/Llama-3.2-1B" as the tokenizer
            # source in its Aligner and TadaForCausalLM.from_pretrained().
            # That repo is gated (requires Meta license acceptance).
            # Download the tokenizer from an ungated mirror and get its
            # local cache path so we can point TADA at it directly.
            logger.info("Downloading Llama tokenizer (ungated mirror)...")
            tokenizer_path = snapshot_download(
                repo_id="unsloth/Llama-3.2-1B",
                token=None,
                allow_patterns=["tokenizer*", "special_tokens*"],
            )

            # Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings
            if device == "cuda" and torch.cuda.is_bf16_supported():
                model_dtype = torch.bfloat16
            elif device == "xpu":
                # Intel Arc (Alchemist+) supports bf16 natively
                model_dtype = torch.bfloat16
            else:
                model_dtype = torch.float32

            # Patch the Aligner config class to use the local tokenizer
            # path instead of the gated "meta-llama/Llama-3.2-1B" default.
            # This avoids monkey-patching AutoTokenizer.from_pretrained
            # which corrupts the classmethod descriptor for other engines.
            from tada.modules.aligner import AlignerConfig

            AlignerConfig.tokenizer_name = tokenizer_path

            # Load encoder (only needed for voice prompt encoding)
            from tada.modules.encoder import Encoder

            logger.info("Loading TADA encoder...")
            self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device)
            self.encoder.eval()

            # Load the causal LM (includes decoder for wav generation).
            # TadaForCausalLM.from_pretrained() calls
            # getattr(config, "tokenizer_name", "meta-llama/Llama-3.2-1B")
            # which hits the gated repo. Pre-load the config from HF,
            # inject the local tokenizer path, then pass it in.
            from tada.modules.tada import TadaForCausalLM, TadaConfig

            logger.info(f"Loading TADA {model_size} model...")
            config = TadaConfig.from_pretrained(repo)
            config.tokenizer_name = tokenizer_path
            self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device)
            self.model.eval()

            logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}")

    def unload_model(self) -> None:
        """Unload model and encoder to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.encoder is not None:
            del self.encoder
            self.encoder = None

        device = self._device
        self._device = None

        if device:
            empty_device_cache(device)

        logger.info("HumeAI TADA unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio using TADA's encoder.

        TADA's encoder performs forced alignment between audio and text tokens,
        producing an EncoderOutput with 1:1 token-audio alignment. If no
        reference_text is provided, the encoder uses built-in ASR (English only).

        We serialize the EncoderOutput to a dict for caching.
        """
        await self.load_model(self.model_size)

        cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None

        if cache_key:
            cached = get_cached_voice_prompt(cache_key)
            if cached is not None and isinstance(cached, dict):
                return cached, True

        def _encode_sync():
            import torch
            import soundfile as sf

            device = self._device

            # Load audio with soundfile (torchaudio 2.10+ requires torchcodec)
            audio_np, sr = sf.read(str(audio_path), dtype="float32")
            audio = torch.from_numpy(audio_np).float()
            if audio.ndim == 1:
                audio = audio.unsqueeze(0)  # (samples,) -> (1, samples)
            else:
                audio = audio.T  # (samples, channels) -> (channels, samples)
            audio = audio.to(device)

            # Encode with forced alignment
            text_arg = [reference_text] if reference_text else None
            prompt = self.encoder(audio, text=text_arg, sample_rate=sr)

            # Serialize EncoderOutput to a dict for caching: tensors are
            # detached and moved to CPU, everything else passes through.
            prompt_dict = {}
            for field_name in prompt.__dataclass_fields__:
                val = getattr(prompt, field_name)
                if isinstance(val, torch.Tensor):
                    prompt_dict[field_name] = val.detach().cpu()
                else:
                    prompt_dict[field_name] = val
            return prompt_dict

        encoded = await asyncio.to_thread(_encode_sync)

        if cache_key:
            cache_voice_prompt(cache_key, encoded)

        return encoded, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using HumeAI TADA.

        Args:
            text: Text to synthesize
            voice_prompt: Serialized EncoderOutput dict from create_voice_prompt()
            language: Language code (en, ar, de, es, fr, it, ja, pl, pt, zh)
            seed: Random seed for reproducibility
            instruct: Not supported by TADA (ignored)

        Returns:
            Tuple of (audio_array, sample_rate=24000)
        """
        await self.load_model(self.model_size)

        def _generate_sync():
            import torch
            from tada.modules.encoder import EncoderOutput

            if seed is not None:
                manual_seed(seed, self._device)

            device = self._device

            # Reconstruct EncoderOutput from the cached dict
            restored = {}
            for k, v in voice_prompt.items():
                if isinstance(v, torch.Tensor):
                    # Move to device and match model dtype for float tensors
                    if v.is_floating_point():
                        model_dtype = next(self.model.parameters()).dtype
                        restored[k] = v.to(device=device, dtype=model_dtype)
                    else:
                        restored[k] = v.to(device=device)
                else:
                    restored[k] = v

            prompt = EncoderOutput(**restored)

            # For non-English with the 3B-ML model, we could reload the
            # encoder with the language-specific aligner. However, the
            # generation itself is language-agnostic — only the encoder's
            # aligner changes. Since we encode at create_voice_prompt time,
            # the language is already baked in. For simplicity, we don't
            # reload the encoder here.

            logger.info(f"[TADA] Generating ({language}), text length: {len(text)}")

            output = self.model.generate(
                prompt=prompt,
                text=text,
            )

            # output.audio is a list of tensors (one per batch item)
            if output.audio and output.audio[0] is not None:
                audio_tensor = output.audio[0]
                audio = audio_tensor.detach().cpu().numpy().squeeze().astype(np.float32)
            else:
                logger.warning("[TADA] Generation produced no audio")
                audio = np.zeros(24000, dtype=np.float32)

            return audio, 24000

        return await asyncio.to_thread(_generate_sync)
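The prompt round-trip above (dataclass to a dict of CPU tensors, then back onto the device at generation time) is the part most worth internalizing. A self-contained sketch with a stand-in dataclass instead of TADA's real `EncoderOutput` (the `FakeOutput` fields are invented for illustration):

```python
# Illustration of the serialize/restore pattern used for TADA voice prompts.
# FakeOutput stands in for tada's EncoderOutput; field names are invented.
from dataclasses import dataclass, fields

import torch


@dataclass
class FakeOutput:
    embeddings: torch.Tensor
    token_ids: torch.Tensor
    texts: list


def to_cache_dict(out: FakeOutput) -> dict:
    # Detach and move tensors to CPU so the dict is safe to cache/pickle.
    d = {}
    for f in fields(out):
        v = getattr(out, f.name)
        d[f.name] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
    return d


def from_cache_dict(d: dict, device: str, model_dtype: torch.dtype) -> FakeOutput:
    restored = {}
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            # Float tensors must match the model dtype (e.g. bf16 on CUDA);
            # integer tensors such as token ids keep their dtype.
            restored[k] = (
                v.to(device=device, dtype=model_dtype)
                if v.is_floating_point()
                else v.to(device)
            )
        else:
            restored[k] = v
    return FakeOutput(**restored)


out = FakeOutput(torch.randn(1, 8), torch.arange(8), ["hello"])
roundtripped = from_cache_dict(to_cache_dict(out), "cpu", torch.float32)
assert roundtripped.token_ids.dtype == torch.int64  # ids keep integer dtype
```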
288
backend/backends/kokoro_backend.py
Normal file
@@ -0,0 +1,288 @@
"""
Kokoro TTS backend implementation.

Wraps the Kokoro-82M model for fast, lightweight text-to-speech.
82M parameters, CPU realtime, 24kHz output, Apache 2.0 license.

Kokoro uses pre-built voice style vectors (not traditional zero-shot cloning
from arbitrary audio). Voice prompts are stored as deferred references to
HF-hosted voice .pt files.

Languages supported (via misaki G2P):
- American English (a), British English (b)
- Spanish (e), French (f), Hindi (h), Italian (i), Portuguese (p)
- Japanese (j) — requires misaki[ja]
- Chinese (z) — requires misaki[zh]
"""

import asyncio
import logging
import os
from typing import Optional

import numpy as np

from . import TTSBackend
from .base import (
    get_torch_device,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)

logger = logging.getLogger(__name__)

# HuggingFace repo for model + voice detection
KOKORO_HF_REPO = "hexgrad/Kokoro-82M"
KOKORO_SAMPLE_RATE = 24000

# Default voice if none specified
KOKORO_DEFAULT_VOICE = "af_heart"

# All available Kokoro voices: (voice_id, display_name, gender, lang_code)
KOKORO_VOICES = [
    # American English female
    ("af_alloy", "Alloy", "female", "en"),
    ("af_aoede", "Aoede", "female", "en"),
    ("af_bella", "Bella", "female", "en"),
    ("af_heart", "Heart", "female", "en"),
    ("af_jessica", "Jessica", "female", "en"),
    ("af_kore", "Kore", "female", "en"),
    ("af_nicole", "Nicole", "female", "en"),
    ("af_nova", "Nova", "female", "en"),
    ("af_river", "River", "female", "en"),
    ("af_sarah", "Sarah", "female", "en"),
    ("af_sky", "Sky", "female", "en"),
    # American English male
    ("am_adam", "Adam", "male", "en"),
    ("am_echo", "Echo", "male", "en"),
    ("am_eric", "Eric", "male", "en"),
    ("am_fenrir", "Fenrir", "male", "en"),
    ("am_liam", "Liam", "male", "en"),
    ("am_michael", "Michael", "male", "en"),
    ("am_onyx", "Onyx", "male", "en"),
    ("am_puck", "Puck", "male", "en"),
    ("am_santa", "Santa", "male", "en"),
    # British English female
    ("bf_alice", "Alice", "female", "en"),
    ("bf_emma", "Emma", "female", "en"),
    ("bf_isabella", "Isabella", "female", "en"),
    ("bf_lily", "Lily", "female", "en"),
    # British English male
    ("bm_daniel", "Daniel", "male", "en"),
    ("bm_fable", "Fable", "male", "en"),
    ("bm_george", "George", "male", "en"),
    ("bm_lewis", "Lewis", "male", "en"),
    # Spanish
    ("ef_dora", "Dora", "female", "es"),
    ("em_alex", "Alex", "male", "es"),
    ("em_santa", "Santa", "male", "es"),
    # French
    ("ff_siwis", "Siwis", "female", "fr"),
    # Hindi
    ("hf_alpha", "Alpha", "female", "hi"),
    ("hf_beta", "Beta", "female", "hi"),
    ("hm_omega", "Omega", "male", "hi"),
    ("hm_psi", "Psi", "male", "hi"),
    # Italian
    ("if_sara", "Sara", "female", "it"),
    ("im_nicola", "Nicola", "male", "it"),
    # Japanese
    ("jf_alpha", "Alpha", "female", "ja"),
    ("jf_gongitsune", "Gongitsune", "female", "ja"),
    ("jf_nezumi", "Nezumi", "female", "ja"),
    ("jf_tebukuro", "Tebukuro", "female", "ja"),
    ("jm_kumo", "Kumo", "male", "ja"),
    # Portuguese
    ("pf_dora", "Dora", "female", "pt"),
    ("pm_alex", "Alex", "male", "pt"),
    ("pm_santa", "Santa", "male", "pt"),
    # Chinese
    ("zf_xiaobei", "Xiaobei", "female", "zh"),
    ("zf_xiaoni", "Xiaoni", "female", "zh"),
    ("zf_xiaoxiao", "Xiaoxiao", "female", "zh"),
    ("zf_xiaoyi", "Xiaoyi", "female", "zh"),
]

# Map our ISO language codes to Kokoro lang_code characters
LANG_CODE_MAP = {
    "en": "a",  # American English
    "es": "e",
    "fr": "f",
    "hi": "h",
    "it": "i",
    "pt": "p",
    "ja": "j",
    "zh": "z",
}


class KokoroTTSBackend:
    """Kokoro-82M TTS backend — tiny, fast, CPU-friendly."""

    def __init__(self):
        self._model = None
        self._pipelines: dict = {}  # lang_code -> KPipeline
        self._device: Optional[str] = None
        self.model_size = "default"

    def _get_device(self) -> str:
        """Select device. Kokoro supports CUDA and CPU. MPS needs a fallback env var."""
        device = get_torch_device(allow_mps=False)
        # Kokoro can use MPS but requires PYTORCH_ENABLE_MPS_FALLBACK=1.
        # For now, skip MPS to avoid user confusion — CPU is already realtime.
        return device

    @property
    def device(self) -> str:
        if self._device is None:
            self._device = self._get_device()
        return self._device

    def is_loaded(self) -> bool:
        return self._model is not None

    def _get_model_path(self, model_size: str) -> str:
        return KOKORO_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        """Check if Kokoro model files are cached locally."""
        from .base import is_model_cached

        return is_model_cached(
            KOKORO_HF_REPO,
            required_files=["config.json", "kokoro-v1_0.pth"],
        )

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Kokoro model."""
        if self._model is not None:
            return
        await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Synchronous model loading."""
        model_name = "kokoro"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            from kokoro import KModel

            device = self.device
            logger.info(f"Loading Kokoro-82M on {device}...")

            self._model = KModel(repo_id=KOKORO_HF_REPO).to(device).eval()

            logger.info("Kokoro-82M loaded successfully")

    def _get_pipeline(self, lang_code: str):
        """Get or create a KPipeline for the given language code."""
        kokoro_lang = LANG_CODE_MAP.get(lang_code, "a")

        if kokoro_lang not in self._pipelines:
            from kokoro import KPipeline

            # Create pipeline with our existing model (no redundant model loading)
            self._pipelines[kokoro_lang] = KPipeline(
                lang_code=kokoro_lang,
                repo_id=KOKORO_HF_REPO,
                model=self._model,
            )

        return self._pipelines[kokoro_lang]

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self._model is not None:
            del self._model
            self._model = None
            self._pipelines.clear()

            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Kokoro unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> tuple[dict, bool]:
        """
        Create voice prompt for Kokoro.

        Kokoro doesn't do traditional voice cloning from arbitrary audio.
        When called for a cloned profile (fallback), uses the default voice.
        For preset profiles, the voice_prompt dict is built by the profile
        service and bypasses this method entirely.
        """
        return {
            "voice_type": "preset",
            "preset_engine": "kokoro",
            "preset_voice_id": KOKORO_DEFAULT_VOICE,
        }, False

    async def combine_voice_prompts(
        self,
        audio_paths: list[str],
        reference_texts: list[str],
    ) -> tuple[np.ndarray, str]:
        """Combine voice prompts — uses base implementation for audio concatenation."""
        return await _combine_voice_prompts(
            audio_paths, reference_texts, sample_rate=KOKORO_SAMPLE_RATE
        )

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> tuple[np.ndarray, int]:
        """
        Generate audio from text using Kokoro.

        Args:
            text: Text to synthesize
            voice_prompt: Dict with preset_voice_id (or legacy kokoro_voice) key
            language: Language code
            seed: Random seed for reproducibility
            instruct: Not supported by Kokoro (ignored)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        voice_name = voice_prompt.get("preset_voice_id") or voice_prompt.get("kokoro_voice") or KOKORO_DEFAULT_VOICE

        def _generate_sync():
            import torch

            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

            pipeline = self._get_pipeline(language)

            # Generate all chunks and concatenate
            audio_chunks = []
            for result in pipeline(text, voice=voice_name, speed=1.0):
                if result.audio is not None:
                    chunk = result.audio
                    if isinstance(chunk, torch.Tensor):
                        chunk = chunk.detach().cpu().numpy()
                    audio_chunks.append(chunk.squeeze())

            if not audio_chunks:
                # Return 1 second of silence as fallback
                return np.zeros(KOKORO_SAMPLE_RATE, dtype=np.float32), KOKORO_SAMPLE_RATE

            audio = np.concatenate(audio_chunks)
            return audio.astype(np.float32), KOKORO_SAMPLE_RATE

        return await asyncio.to_thread(_generate_sync)
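A hedged sketch of driving the preset-voice table above. Japanese output assumes misaki[ja] is installed, the output path is a placeholder, and the import path assumes the package is importable as `backend`:

```python
# Sketch: selecting a preset voice from the KOKORO_VOICES table.
import asyncio

import soundfile as sf

from backend.backends.kokoro_backend import KOKORO_VOICES, KokoroTTSBackend


async def demo() -> None:
    backend = KokoroTTSBackend()
    # Each entry is (voice_id, display_name, gender, lang_code).
    ja_voices = [v for v in KOKORO_VOICES if v[3] == "ja"]
    voice_id = ja_voices[0][0]  # "jf_alpha"
    audio, sr = await backend.generate(
        "こんにちは",
        voice_prompt={"preset_voice_id": voice_id},
        language="ja",
    )
    sf.write("kokoro_ja.wav", audio, sr)  # placeholder output path


asyncio.run(demo())
```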
184
backend/backends/luxtts_backend.py
Normal file
@@ -0,0 +1,184 @@
"""
LuxTTS backend implementation.

Wraps the LuxTTS (ZipVoice) model for zero-shot voice cloning.
~1GB VRAM, 48kHz output, 150x realtime on CPU.
"""

import asyncio
import logging
from typing import Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt

logger = logging.getLogger(__name__)

# HuggingFace repo for model weight detection
LUXTTS_HF_REPO = "YatharthS/LuxTTS"


class LuxTTSBackend:
    """LuxTTS backend for zero-shot voice cloning."""

    def __init__(self):
        self.model = None
        self.model_size = "default"  # LuxTTS has only one model size
        self._device = None

    def _get_device(self) -> str:
        return get_torch_device(allow_mps=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    @property
    def device(self) -> str:
        if self._device is None:
            self._device = self._get_device()
        return self._device

    def _get_model_path(self, model_size: str) -> str:
        return LUXTTS_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        return is_model_cached(
            LUXTTS_HF_REPO,
            weight_extensions=(".pt", ".safetensors", ".onnx", ".bin"),
        )

    async def load_model(self, model_size: str = "default") -> None:
        """Load the LuxTTS model."""
        if self.model is not None:
            return

        await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        model_name = "luxtts"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            from zipvoice.luxvoice import LuxTTS

            device = self.device
            logger.info(f"Loading LuxTTS on {device}...")

            if device == "cpu":
                import os

                threads = os.cpu_count() or 4
                self.model = LuxTTS(
                    model_path=LUXTTS_HF_REPO,
                    device="cpu",
                    threads=min(threads, 8),
                )
            else:
                self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device)

            logger.info("LuxTTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            device = self.device
            del self.model
            self.model = None
            self._device = None

            empty_device_cache(device)

            logger.info("LuxTTS unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        LuxTTS uses its own encode_prompt() which runs Whisper ASR internally
        to transcribe the reference. The reference_text parameter is not used
        by LuxTTS itself, but we include it in the cache key for consistency.
        """
        await self.load_model()

        # Compute cache key once for both lookup and storage
        cache_key = ("luxtts_" + get_cache_key(audio_path, reference_text)) if use_cache else None

        if cache_key:
            cached = get_cached_voice_prompt(cache_key)
            if cached is not None and isinstance(cached, dict):
                return cached, True

        def _encode_sync():
            return self.model.encode_prompt(
                prompt_audio=str(audio_path),
                duration=5,
                rms=0.01,
            )

        encoded = await asyncio.to_thread(_encode_sync)

        if cache_key:
            cache_voice_prompt(cache_key, encoded)

        return encoded, False

    async def combine_voice_prompts(self, audio_paths, reference_texts):
        return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using LuxTTS.

        Args:
            text: Text to synthesize
            voice_prompt: Encoded prompt dict from encode_prompt()
            language: Language code (LuxTTS is English-focused)
            seed: Random seed for reproducibility
            instruct: Not supported by LuxTTS (ignored)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        def _generate_sync():
            if seed is not None:
                manual_seed(seed, self.device)

            wav = self.model.generate_speech(
                text=text,
                encode_dict=voice_prompt,
                num_steps=4,
                guidance_scale=3.0,
                t_shift=0.5,
                speed=1.0,
                return_smooth=False,  # 48kHz output
            )

            # LuxTTS returns a tensor (may be on GPU/MPS), move to CPU first
            audio = wav.detach().cpu().numpy().squeeze()
            return audio, 48000

        return await asyncio.to_thread(_generate_sync)
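One behavior worth noting: prompt caching keys on the (audio path, text) pair, so a repeated `create_voice_prompt` call should be a cache hit even though LuxTTS ignores the text during encoding. A sketch under that assumption (paths are placeholders, cache persistence depends on the utils.cache backend, and the import path assumes the package is importable as `backend`):

```python
# Sketch: a second create_voice_prompt call with identical inputs hits the cache.
import asyncio

from backend.backends.luxtts_backend import LuxTTSBackend


async def demo():
    backend = LuxTTSBackend()
    prompt, was_cached = await backend.create_voice_prompt("ref.wav", "transcript")
    assert was_cached is False   # first call runs encode_prompt()
    prompt, was_cached = await backend.create_voice_prompt("ref.wav", "transcript")
    assert was_cached is True    # same (path, text) pair -> cache hit
    audio, sr = await backend.generate("Cached prompt reuse.", prompt, seed=0)
    return audio, sr


audio, sr = asyncio.run(demo())
```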
367
backend/backends/mlx_backend.py
Normal file
@@ -0,0 +1,367 @@
"""
MLX backend implementation for TTS and STT using mlx-audio.
"""

from typing import Optional, List, Tuple
import asyncio
import logging
import numpy as np
from pathlib import Path

logger = logging.getLogger(__name__)

# PATCH: Import and apply offline patch BEFORE any huggingface_hub usage
# This prevents mlx_audio from making network requests when models are cached
from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached

patch_huggingface_hub_offline()
ensure_original_qwen_config_cached()

from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt


class MLXTTSBackend:
    """MLX-based TTS backend using mlx-audio."""

    def __init__(self, model_size: str = "1.7B"):
        self.model = None
        self.model_size = model_size
        self._current_model_size = None

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _get_model_path(self, model_size: str) -> str:
        """
        Get the MLX model path.

        Args:
            model_size: Model size (1.7B or 0.6B)

        Returns:
            HuggingFace Hub model ID for MLX
        """
        mlx_model_map = {
            "1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
            "0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16",
        }

        if model_size not in mlx_model_map:
            raise ValueError(f"Unknown model size: {model_size}")

        hf_model_id = mlx_model_map[model_size]
        logger.info("Will download MLX model from HuggingFace Hub: %s", hf_model_id)

        return hf_model_id

    def _is_model_cached(self, model_size: str) -> bool:
        return is_model_cached(
            self._get_model_path(model_size),
            weight_extensions=(".safetensors", ".bin", ".npz"),
        )

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX TTS model.

        Args:
            model_size: Model size to load (1.7B or 0.6B)
        """
        if model_size is None:
            model_size = self.model_size

        # If already loaded with correct size, return
        if self.model is not None and self._current_model_size == model_size:
            return

        # Unload existing model if different size requested
        if self.model is not None and self._current_model_size != model_size:
            self.unload_model()

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        model_path = self._get_model_path(model_size)
        model_name = f"qwen-tts-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(model_name, is_cached):
            from mlx_audio.tts import load

            logger.info("Loading MLX TTS model %s...", model_size)

            self.model = load(model_path)

        self._current_model_size = model_size
        self.model_size = model_size
        logger.info("MLX TTS model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None
            logger.info("MLX TTS model unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        MLX backend stores voice prompt as a dict with audio path and text.
        The actual voice prompt processing happens during generation.

        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)

        # Check cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                # Return cached prompt (should be dict format)
                if isinstance(cached_prompt, dict):
                    # Validate that the cached audio file still exists
                    cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
                    if cached_audio_path and Path(cached_audio_path).exists():
                        return cached_prompt, True
                    else:
                        # Cached file no longer exists, invalidate cache
                        logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)

        # MLX voice prompt format - store audio path and text
        # The model will process this during generation
        voice_prompt_items = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }

        # Cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cache_voice_prompt(cache_key, voice_prompt_items)

        return voice_prompt_items, False

    async def combine_voice_prompts(self, audio_paths, reference_texts):
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary with ref_audio and ref_text
            language: Language code (en or zh) - may not be fully supported by MLX
            seed: Random seed for reproducibility
            instruct: Natural language instruction (may not be supported by MLX)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model_async(None)

        logger.info("Generating audio for text: %s", text)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # MLX generate() returns a generator yielding GenerationResult objects
            audio_chunks = []
            sample_rate = 24000
            lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")

            # Set seed if provided (MLX uses numpy random)
            if seed is not None:
                import mlx.core as mx

                np.random.seed(seed)
                mx.random.seed(seed)

            # Extract voice prompt info
            ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
            ref_text = voice_prompt.get("ref_text", "")

            # Validate that the audio file exists
            if ref_audio and not Path(ref_audio).exists():
                logger.warning("Audio file not found: %s", ref_audio)
                logger.warning("This may be due to a cached voice prompt referencing a deleted temp file.")
                logger.warning("Regenerating without voice prompt.")
                ref_audio = None

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (previously used to avoid lazy
            # mlx_audio lookups hanging when the network drops mid-inference,
            # issue #462) regressed online users because libraries make
            # legitimate metadata calls during generation.
            try:
                if ref_audio:
                    # Check if generate accepts ref_audio parameter
                    import inspect

                    sig = inspect.signature(self.model.generate)
                    if "ref_audio" in sig.parameters:
                        # Generate with voice cloning
                        for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                    else:
                        # Fallback: generate without voice cloning
                        for result in self.model.generate(text, lang_code=lang):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                else:
                    # No voice prompt, generate normally
                    for result in self.model.generate(text, lang_code=lang):
                        audio_chunks.append(np.array(result.audio))
                        sample_rate = result.sample_rate
            except Exception as e:
                # If voice cloning fails, try without it
                logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
                for result in self.model.generate(text, lang_code=lang):
                    audio_chunks.append(np.array(result.audio))
                    sample_rate = result.sample_rate

            # Concatenate all chunks
            if audio_chunks:
                audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
            else:
                # Fallback: empty audio
                audio = np.array([], dtype=np.float32)

            return audio, sample_rate

        # Run blocking inference in thread pool
        audio, sample_rate = await asyncio.to_thread(_generate_sync)

        return audio, sample_rate


class MLXSTTBackend:
    """MLX-based STT backend using mlx-audio Whisper."""

    def __init__(self, model_size: str = "base"):
        self.model = None
        self.model_size = model_size

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _is_model_cached(self, model_size: str) -> bool:
        hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
        return is_model_cached(hf_repo, weight_extensions=(".safetensors", ".bin", ".npz"))

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self.model_size == model_size:
            return

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)
||||
# Alias for compatibility
|
||||
load_model = load_model_async
|
||||
|
||||
def _load_model_sync(self, model_size: str):
|
||||
"""Synchronous model loading."""
|
||||
progress_model_name = f"whisper-{model_size}"
|
||||
is_cached = self._is_model_cached(model_size)
|
||||
|
||||
with model_load_progress(progress_model_name, is_cached):
|
||||
from mlx_audio.stt import load
|
||||
|
||||
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
|
||||
logger.info("Loading MLX Whisper model %s...", model_size)
|
||||
|
||||
self.model = load(model_name)
|
||||
|
||||
self.model_size = model_size
|
||||
logger.info("MLX Whisper model %s loaded successfully", model_size)
|
||||
|
||||
def unload_model(self):
|
||||
"""Unload the model to free memory."""
|
||||
if self.model is not None:
|
||||
del self.model
|
||||
self.model = None
|
||||
logger.info("MLX Whisper model unloaded")
|
||||
|
||||
async def transcribe(
|
||||
self,
|
||||
audio_path: str,
|
||||
language: Optional[str] = None,
|
||||
model_size: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Transcribe audio to text.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
language: Optional language hint
|
||||
model_size: Optional model size override
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
await self.load_model_async(model_size)
|
||||
|
||||
def _transcribe_sync():
|
||||
"""Run synchronous transcription in thread pool."""
|
||||
# MLX Whisper transcription using generate method
|
||||
# The generate method accepts audio path directly
|
||||
decode_options = {}
|
||||
if language:
|
||||
decode_options["language"] = language
|
||||
|
||||
# Inference runs with the process's default HF_HUB_OFFLINE
|
||||
# state — see the comment in MLXTTSBackend.generate for the
|
||||
# regression this revert fixes (issue #462).
|
||||
result = self.model.generate(str(audio_path), **decode_options)
|
||||
|
||||
# Extract text from result
|
||||
if isinstance(result, str):
|
||||
return result.strip()
|
||||
elif isinstance(result, dict):
|
||||
return result.get("text", "").strip()
|
||||
elif hasattr(result, "text"):
|
||||
return result.text.strip()
|
||||
else:
|
||||
return str(result).strip()
|
||||
|
||||
# Run blocking transcription in thread pool
|
||||
return await asyncio.to_thread(_transcribe_sync)
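
# Illustrative usage of the STT backend (a sketch; "clip.wav" is an example
# path, not an app constant):
#
#     stt = MLXSTTBackend(model_size="base")
#     text = await stt.transcribe("clip.wav", language="en")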
378
backend/backends/pytorch_backend.py
Normal file
@@ -0,0 +1,378 @@
"""
PyTorch backend implementation for TTS and STT.
"""

from typing import Optional, List, Tuple
import asyncio
import logging

import torch
import numpy as np

from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import load_audio

logger = logging.getLogger(__name__)


class PyTorchTTSBackend:
    """PyTorch-based TTS backend using Qwen3-TTS."""

    def __init__(self, model_size: str = "1.7B"):
        self.model = None
        self.model_size = model_size
        self.device = self._get_device()
        self._current_model_size = None

    def _get_device(self) -> str:
        """Get the best available device."""
        return get_torch_device(allow_xpu=True, allow_directml=True)

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _get_model_path(self, model_size: str) -> str:
        """
        Get the HuggingFace Hub model ID.

        Args:
            model_size: Model size (1.7B or 0.6B)

        Returns:
            HuggingFace Hub model ID
        """
        hf_model_map = {
            "1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
            "0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
        }

        if model_size not in hf_model_map:
            raise ValueError(f"Unknown model size: {model_size}")

        return hf_model_map[model_size]

    def _is_model_cached(self, model_size: str) -> bool:
        return is_model_cached(self._get_model_path(model_size))

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the TTS model with automatic downloading from HuggingFace Hub.

        Args:
            model_size: Model size to load (1.7B or 0.6B)
        """
        if model_size is None:
            model_size = self.model_size

        # If already loaded with the requested size, return
        if self.model is not None and self._current_model_size == model_size:
            return

        # Unload the existing model if a different size is requested
        if self.model is not None and self._current_model_size != model_size:
            self.unload_model()

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        model_name = f"qwen-tts-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(model_name, is_cached):
            from qwen_tts import Qwen3TTSModel

            model_path = self._get_model_path(model_size)
            logger.info("Loading TTS model %s on %s...", model_size, self.device)

            # Route both HF Hub and Transformers through a single cache root.
            # On Windows local setups, model assets can otherwise split between
            # .hf-cache/hub and .hf-cache/transformers, causing speech_tokenizer
            # and preprocessor_config.json to fail to resolve during load.
            from huggingface_hub import constants as hf_constants

            tts_cache_dir = hf_constants.HF_HUB_CACHE

            if self.device == "cpu":
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    cache_dir=tts_cache_dir,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=False,
                )
            else:
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    cache_dir=tts_cache_dir,
                    device_map=self.device,
                    torch_dtype=torch.bfloat16,
                )

        self._current_model_size = model_size
        self.model_size = model_size
        logger.info("TTS model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None

            empty_device_cache(self.device)

            logger.info("TTS model unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)

        # Check cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                if isinstance(cached_prompt, dict):
                    # For the PyTorch backend the prompt dict contains tensors,
                    # not file paths, so it can be returned directly.
                    return cached_prompt, True
                elif isinstance(cached_prompt, torch.Tensor):
                    # Legacy cache format - convert to dict. This shouldn't
                    # happen in practice, but handle it anyway.
                    return {"prompt": cached_prompt}, True

        def _create_prompt_sync():
            """Run synchronous voice prompt creation in thread pool."""
            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (issue #462) regressed online
            # users whose libraries issue legitimate metadata lookups
            # during voice-prompt creation.
            return self.model.create_voice_clone_prompt(
                ref_audio=str(audio_path),
                ref_text=reference_text,
                x_vector_only_mode=False,
            )

        # Run blocking operation in thread pool
        voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)

        # Cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cache_voice_prompt(cache_key, voice_prompt_items)

        return voice_prompt_items, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary from create_voice_prompt
            language: Language code (en or zh)
            seed: Random seed for reproducibility
            instruct: Natural language instruction for speech delivery control

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        # Load model
        await self.load_model_async(None)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # Set seed if provided
            if seed is not None:
                manual_seed(seed, self.device)

            # See _create_prompt_sync comment — inference runs with the
            # process's default HF_HUB_OFFLINE state (issue #462).
            wavs, sample_rate = self.model.generate_voice_clone(
                text=text,
                voice_clone_prompt=voice_prompt,
                language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
                instruct=instruct,
            )
            return wavs[0], sample_rate

        # Run blocking inference in thread pool to avoid blocking the event loop
        audio, sample_rate = await asyncio.to_thread(_generate_sync)

        return audio, sample_rate
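
# Illustrative end-to-end flow for this backend (a sketch; the file name and
# transcript below are example values, not app constants):
#
#     backend = PyTorchTTSBackend(model_size="0.6B")
#     prompt, was_cached = await backend.create_voice_prompt(
#         "reference.wav", "Transcript of the reference clip."
#     )
#     audio, sr = await backend.generate("Hello!", voice_prompt=prompt)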


class PyTorchSTTBackend:
    """PyTorch-based STT backend using Whisper."""

    def __init__(self, model_size: str = "base"):
        self.model = None
        self.processor = None
        self.model_size = model_size
        self.device = self._get_device()

    def _get_device(self) -> str:
        """Get the best available device."""
        return get_torch_device(allow_xpu=True, allow_directml=True)

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _is_model_cached(self, model_size: str) -> bool:
        hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
        return is_model_cached(hf_repo)

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self.model_size == model_size:
            return

        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        progress_model_name = f"whisper-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(progress_model_name, is_cached):
            from transformers import WhisperProcessor, WhisperForConditionalGeneration

            model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
            logger.info("Loading Whisper model %s on %s...", model_size, self.device)

            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)

        self.model.to(self.device)
        self.model_size = model_size
        logger.info("Whisper model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            del self.processor
            self.model = None
            self.processor = None

            empty_device_cache(self.device)

            logger.info("Whisper model unloaded")

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint
            model_size: Optional model size override

        Returns:
            Transcribed text
        """
        await self.load_model_async(model_size)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # Load audio
            audio, _sr = load_audio(audio_path, sample_rate=16000)

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state — forcing offline here (issue #462) broke online users
            # whose `get_decoder_prompt_ids` / tokenizer calls issue
            # legitimate metadata lookups.

            # Process audio
            inputs = self.processor(
                audio,
                sampling_rate=16000,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)

            # Generate transcription.
            # If language is provided, force it; otherwise let Whisper auto-detect.
            generate_kwargs = {}
            if language:
                forced_decoder_ids = self.processor.get_decoder_prompt_ids(
                    language=language,
                    task="transcribe",
                )
                generate_kwargs["forced_decoder_ids"] = forced_decoder_ids

            with torch.no_grad():
                predicted_ids = self.model.generate(
                    inputs["input_features"],
                    **generate_kwargs,
                )

            # Decode
            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True,
            )[0]

            return transcription.strip()

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
214
backend/backends/qwen_custom_voice_backend.py
Normal file
@@ -0,0 +1,214 @@
"""
Qwen3-TTS CustomVoice backend implementation.

Wraps the Qwen3-TTS-12Hz CustomVoice model for preset-speaker TTS with
instruction-based style control. Uses the same qwen_tts library as the
Base model (pytorch_backend.py) but loads a different checkpoint and
calls generate_custom_voice() instead of generate_voice_clone().

Key differences from the Base engine:
- Uses preset speakers (9 built-in voices) instead of zero-shot cloning
- Supports instruct parameter for tone/emotion/prosody control
- Two model sizes: 1.7B and 0.6B

Languages supported: zh, en, ja, ko, de, fr, ru, pt, es, it
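
A minimal usage sketch (the speaker name and instruct string are example
values; see QWEN_CUSTOM_VOICES below for the full speaker list):

    backend = QwenCustomVoiceBackend(model_size="0.6B")
    audio, sr = await backend.generate(
        "Hello there!",
        voice_prompt={"preset_voice_id": "Ryan"},
        language="en",
        instruct="Speak in a cheerful tone",
    )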
"""

import asyncio
import logging
from typing import Optional

import numpy as np
import torch

from . import TTSBackend, LANGUAGE_CODE_TO_NAME
from .base import (
    is_model_cached,
    get_torch_device,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)

logger = logging.getLogger(__name__)

# ── Preset speakers ──────────────────────────────────────────────────

# (speaker_id, display_name, gender, native_language_code, description)
QWEN_CUSTOM_VOICES = [
    ("Vivian", "Vivian", "female", "zh", "Bright, slightly edgy young female voice"),
    ("Serena", "Serena", "female", "zh", "Warm, gentle young female voice"),
    ("Uncle_Fu", "Uncle Fu", "male", "zh", "Seasoned male voice with a low, mellow timbre"),
    ("Dylan", "Dylan", "male", "zh", "Youthful Beijing male voice with a clear, natural timbre"),
    ("Eric", "Eric", "male", "zh", "Lively Chengdu male voice with a slightly husky brightness"),
    ("Ryan", "Ryan", "male", "en", "Dynamic male voice with strong rhythmic drive"),
    ("Aiden", "Aiden", "male", "en", "Sunny American male voice with a clear midrange"),
    ("Ono_Anna", "Ono Anna", "female", "ja", "Playful Japanese female voice with a light, nimble timbre"),
    ("Sohee", "Sohee", "female", "ko", "Warm Korean female voice with rich emotion"),
]

QWEN_CV_DEFAULT_SPEAKER = "Ryan"

# HuggingFace repo IDs per model size
QWEN_CV_HF_REPOS = {
    "1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
    "0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
}


class QwenCustomVoiceBackend:
    """Qwen3-TTS CustomVoice backend — preset speakers with instruct control."""

    def __init__(self, model_size: str = "1.7B"):
        self.model = None
        self.model_size = model_size
        self.device = self._get_device()
        self._current_model_size: Optional[str] = None

    def _get_device(self) -> str:
        return get_torch_device(allow_xpu=True, allow_directml=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str) -> str:
        if model_size not in QWEN_CV_HF_REPOS:
            raise ValueError(f"Unknown model size: {model_size}")
        return QWEN_CV_HF_REPOS[model_size]

    def _is_model_cached(self, model_size: Optional[str] = None) -> bool:
        size = model_size or self.model_size
        return is_model_cached(self._get_model_path(size))

    async def load_model_async(self, model_size: Optional[str] = None) -> None:
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self._current_model_size == model_size:
            return

        if self.model is not None and self._current_model_size != model_size:
            self.unload_model()

        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility with the TTSBackend protocol
    load_model = load_model_async

    def _load_model_sync(self, model_size: str) -> None:
        model_name = f"qwen-custom-voice-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(model_name, is_cached):
            from qwen_tts import Qwen3TTSModel

            model_path = self._get_model_path(model_size)
            logger.info("Loading Qwen CustomVoice %s on %s...", model_size, self.device)

            if self.device == "cpu":
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=False,
                )
            else:
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    device_map=self.device,
                    torch_dtype=torch.bfloat16,
                )

        self._current_model_size = model_size
        self.model_size = model_size
        logger.info("Qwen CustomVoice %s loaded successfully", model_size)

    def unload_model(self) -> None:
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Qwen CustomVoice unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> tuple[dict, bool]:
        """
        Create voice prompt for CustomVoice.

        CustomVoice doesn't use reference audio — it uses preset speakers.
        When called for a cloned profile (fallback), uses the default speaker.
        For preset profiles, the voice_prompt dict is built by the profile
        service and bypasses this method entirely.
        """
        return {
            "voice_type": "preset",
            "preset_engine": "qwen_custom_voice",
            "preset_voice_id": QWEN_CV_DEFAULT_SPEAKER,
        }, False

    async def combine_voice_prompts(
        self,
        audio_paths: list[str],
        reference_texts: list[str],
    ) -> tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> tuple[np.ndarray, int]:
        """
        Generate audio using Qwen CustomVoice.

        Args:
            text: Text to synthesize
            voice_prompt: Dict with preset_voice_id (speaker name)
            language: Language code (zh, en, ja, ko, etc.)
            seed: Random seed for reproducibility
            instruct: Natural language instruction for style control
                (e.g. "Speak in an angry tone", "Very happy")

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model_async(None)

        speaker = voice_prompt.get("preset_voice_id") or QWEN_CV_DEFAULT_SPEAKER

        def _generate_sync():
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

            lang_name = LANGUAGE_CODE_TO_NAME.get(language, "auto")

            kwargs = {
                "text": text,
                "language": lang_name.capitalize() if lang_name != "auto" else "Auto",
                "speaker": speaker,
            }

            # Only pass instruct if non-empty
            if instruct:
                kwargs["instruct"] = instruct

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (issue #462) regressed online
            # users whose libraries issue legitimate metadata lookups
            # during generation.
            wavs, sample_rate = self.model.generate_custom_voice(**kwargs)
            return wavs[0], sample_rate

        audio, sample_rate = await asyncio.to_thread(_generate_sync)
        return audio, sample_rate
458
backend/build_binary.py
Normal file
@@ -0,0 +1,458 @@
"""
PyInstaller build script for creating standalone Python server binary.

Usage:
    python build_binary.py          # Build default (CPU) server binary
    python build_binary.py --cuda   # Build CUDA-enabled server binary
"""

import argparse
import logging
import os
import platform
import sys
from pathlib import Path

import PyInstaller.__main__

logger = logging.getLogger(__name__)


def is_apple_silicon():
    """Check if running on Apple Silicon."""
    return platform.system() == "Darwin" and platform.machine() == "arm64"


def build_server(cuda=False):
    """Build Python server as standalone binary.

    Args:
        cuda: If True, build with CUDA support and name the binary
            voicebox-server-cuda instead of voicebox-server.
    """
    backend_dir = Path(__file__).parent

    binary_name = "voicebox-server-cuda" if cuda else "voicebox-server"

    # PyInstaller arguments.
    # CUDA builds use --onedir so we can split the output into two archives:
    #   1. Server core (~200-400MB) — versioned with the app
    #   2. CUDA libs (~2GB) — versioned independently (only redownloaded on
    #      CUDA toolkit / torch major version changes)
    # CPU builds remain --onefile for simplicity.
    pack_mode = "--onedir" if cuda else "--onefile"
    args = [
        "server.py",  # Use server.py as entry point instead of main.py
        pack_mode,
        "--name",
        binary_name,
    ]

    # Hide console window on Windows only. On macOS/Linux the sidecar needs
    # stdout/stderr for Tauri to capture logs.
    if platform.system() == "Windows":
        args.append("--noconsole")

    # numpy 2.x / torch ABI mismatch fix: install memmove fallback for
    # torch.from_numpy() before the app starts. Runtime hooks run after
    # FrozenImporter is registered so frozen torch/numpy are importable.
    # Paths are passed relative to backend_dir because os.chdir(backend_dir)
    # runs before PyInstaller. Absolute paths would get baked into the
    # generated .spec, breaking reproducible builds on other machines / CI.
    args.extend(
        [
            "--runtime-hook",
            "pyi_rth_numpy_compat.py",
            # Stub torch.compiler.disable before transformers imports
            # flex_attention, which otherwise triggers torch._dynamo →
            # torch._numpy._ufuncs and crashes at module load under
            # PyInstaller. See pyi_rth_torch_compiler_disable.py.
            "--runtime-hook",
            "pyi_rth_torch_compiler_disable.py",
            # Per-module collection overrides (e.g. forcing scipy.stats._distn_infrastructure
            # to bundle .py source alongside .pyc so the runtime hook can source-patch it).
            "--additional-hooks-dir",
            "pyi_hooks",
        ]
    )

    # Add local qwen_tts path if specified (for editable installs)
    qwen_tts_path = os.getenv("QWEN_TTS_PATH")
    if qwen_tts_path and Path(qwen_tts_path).exists():
        args.extend(["--paths", str(qwen_tts_path)])
        logger.info("Using local qwen_tts source from: %s", qwen_tts_path)

    # Add common hidden imports
    args.extend(
        [
            "--hidden-import",
            "backend",
            "--hidden-import",
            "backend.main",
            "--hidden-import",
            "backend.config",
            "--hidden-import",
            "backend.database",
            "--hidden-import",
            "backend.models",
            "--hidden-import",
            "backend.services.profiles",
            "--hidden-import",
            "backend.services.history",
            "--hidden-import",
            "backend.services.tts",
            "--hidden-import",
            "backend.services.transcribe",
            "--hidden-import",
            "backend.utils.platform_detect",
            "--hidden-import",
            "backend.backends",
            "--hidden-import",
            "backend.backends.pytorch_backend",
            "--hidden-import",
            "backend.backends.qwen_custom_voice_backend",
            "--hidden-import",
            "backend.utils.audio",
            "--hidden-import",
            "backend.utils.cache",
            "--hidden-import",
            "backend.utils.progress",
            "--hidden-import",
            "backend.utils.hf_progress",
            "--hidden-import",
            "backend.services.cuda",
            "--hidden-import",
            "backend.services.effects",
            "--hidden-import",
            "backend.utils.effects",
            "--hidden-import",
            "backend.services.versions",
            "--hidden-import",
            "pedalboard",
            "--hidden-import",
            "chatterbox",
            "--hidden-import",
            "chatterbox.tts_turbo",
            "--hidden-import",
            "chatterbox.mtl_tts",
            "--hidden-import",
            "backend.backends.chatterbox_backend",
            "--hidden-import",
            "backend.backends.chatterbox_turbo_backend",
            # chatterbox multilingual uses spacy_pkuseg for Chinese word
            # segmentation, which ships pickled dict files (dicts/default.pkl)
            # and native .so extensions that --hidden-import alone won't bundle.
            "--collect-all",
            "spacy_pkuseg",
            "--hidden-import",
            "backend.backends.luxtts_backend",
            "--hidden-import",
            "zipvoice",
            "--hidden-import",
            "zipvoice.luxvoice",
            "--collect-all",
            "zipvoice",
            "--collect-all",
            "linacodec",
            "--hidden-import",
            "torch",
            "--hidden-import",
            "transformers",
            "--hidden-import",
            "fastapi",
            "--hidden-import",
            "uvicorn",
            "--hidden-import",
            "sqlalchemy",
            # librosa uses lazy_loader which generates .pyi stub files at
            # install time and reads them at runtime to discover submodules.
            # --hidden-import alone doesn't bundle the stubs, causing
            # "Cannot load imports from non-existent stub" at runtime.
            "--collect-all",
            "lazy_loader",
            "--collect-all",
            "librosa",
            "--hidden-import",
            "soundfile",
            "--hidden-import",
            "qwen_tts",
            "--hidden-import",
            "qwen_tts.inference",
            "--hidden-import",
            "qwen_tts.inference.qwen3_tts_model",
            "--hidden-import",
            "qwen_tts.inference.qwen3_tts_tokenizer",
            "--hidden-import",
            "qwen_tts.core",
            "--hidden-import",
            "qwen_tts.cli",
            "--copy-metadata",
            "qwen-tts",
            "--copy-metadata",
            "requests",
            "--copy-metadata",
            "transformers",
            "--copy-metadata",
            "huggingface-hub",
            "--copy-metadata",
            "tokenizers",
            "--copy-metadata",
            "safetensors",
            "--copy-metadata",
            "tqdm",
            "--hidden-import",
            "requests",
            # qwen_tts uses inspect.getsource() at runtime to locate
            # modeling_qwen3_tts.py — needs physical .py source files bundled
            "--collect-all",
            "qwen_tts",
            # Fix for pkg_resources and jaraco namespace packages
            "--hidden-import",
            "pkg_resources.extern",
            "--collect-submodules",
            "jaraco",
            # inflect uses typeguard @typechecked which calls inspect.getsource()
            # at import time — needs .py source files, not just .pyc bytecode
            "--collect-all",
            "inflect",
            # perth ships pretrained watermark model files (hparams.yaml, .pth.tar)
            # in perth/perth_net/pretrained/ — needed by chatterbox at runtime
            "--collect-all",
            "perth",
            # piper_phonemize ships espeak-ng-data/ (phoneme tables, language dicts)
            # needed by LuxTTS for text-to-phoneme conversion
            "--collect-all",
            "piper_phonemize",
            # HumeAI TADA — speech-language model using Llama + flow matching
            "--hidden-import",
            "backend.backends.hume_backend",
            "--hidden-import",
            "tada",
            "--hidden-import",
            "tada.modules",
            "--hidden-import",
            "tada.modules.tada",
            "--hidden-import",
            "tada.modules.encoder",
            "--hidden-import",
            "tada.modules.decoder",
            "--hidden-import",
            "tada.modules.aligner",
            "--hidden-import",
            "tada.modules.acoustic_spkr_verf",
            "--hidden-import",
            "tada.nn",
            "--hidden-import",
            "tada.nn.vibevoice",
            "--hidden-import",
            "tada.utils",
            "--hidden-import",
            "tada.utils.gray_code",
            "--hidden-import",
            "tada.utils.text",
            # DAC shim — provides dac.nn.layers.Snake1d without the real
            # descript-audio-codec package (which pulls onnx/tensorboard via
            # descript-audiotools). The shim is in backend/utils/dac_shim.py.
            "--hidden-import",
            "backend.utils.dac_shim",
            "--hidden-import",
            "torchaudio",
            "--collect-submodules",
            "tada",
            # Kokoro 82M — lightweight TTS engine using misaki G2P.
            # collect-all is required because transformers introspects .py source
            # files at runtime (e.g. _can_set_attn_implementation opens the class
            # file); hidden-import alone only bundles bytecode.
            "--hidden-import",
            "backend.backends.kokoro_backend",
            "--collect-all",
            "kokoro",
            # misaki ships G2P data files (dictionaries, phoneme tables)
            # that must be bundled for espeak/en/ja/zh G2P to work
            "--collect-all",
            "misaki",
            # language_tags ships JSON data files (index.json etc.) loaded at
            # runtime via: misaki → phonemizer → segments → csvw → language_tags
            "--collect-all",
            "language_tags",
            # espeakng_loader ships the entire espeak-ng-data directory (369 files)
            # loaded at import time by misaki.espeak via get_data_path()
            "--collect-all",
            "espeakng_loader",
            # spacy en_core_web_sm model — misaki.en tries to spacy.cli.download()
            # at runtime if not found, which calls pip as a subprocess and crashes
            # the frozen binary. Bundle the model so spacy.util.is_package() passes.
            "--collect-all",
            "en_core_web_sm",
            "--copy-metadata",
            "en_core_web_sm",
            "--hidden-import",
            "en_core_web_sm",
            # unidic-lite ships the MeCab dictionary used by fugashi (pulled in
            # by misaki[ja]). The dict lives in unidic_lite/dicdir/ and is
            # discovered via the package's DICDIR constant, so the data files
            # must be collected or Japanese Kokoro voices crash at runtime.
            "--collect-all",
            "unidic_lite",
            "--hidden-import",
            "loguru",
        ]
    )

    # Add CUDA-specific hidden imports
    if cuda:
        logger.info("Building with CUDA support")
        args.extend(
            [
                "--hidden-import",
                "torch.cuda",
                "--hidden-import",
                "torch.backends.cudnn",
            ]
        )
    else:
        # Exclude NVIDIA CUDA packages from CPU-only builds to keep the binary
        # small. When building from a venv with CUDA torch installed,
        # PyInstaller would bundle ~3GB of NVIDIA shared libraries. We exclude
        # both the Python modules and the binary DLLs.
        nvidia_packages = [
            "nvidia",
            "nvidia.cublas",
            "nvidia.cuda_cupti",
            "nvidia.cuda_nvrtc",
            "nvidia.cuda_runtime",
            "nvidia.cudnn",
            "nvidia.cufft",
            "nvidia.curand",
            "nvidia.cusolver",
            "nvidia.cusparse",
            "nvidia.nccl",
            "nvidia.nvjitlink",
            "nvidia.nvtx",
        ]
        for pkg in nvidia_packages:
            args.extend(["--exclude-module", pkg])

    # Add MLX-specific imports if building on Apple Silicon (never for CUDA builds)
    if is_apple_silicon() and not cuda:
        logger.info("Building for Apple Silicon - including MLX dependencies")
        args.extend(
            [
                "--hidden-import",
                "backend.backends.mlx_backend",
                "--hidden-import",
                "mlx",
                "--hidden-import",
                "mlx.core",
                "--hidden-import",
                "mlx.nn",
                "--hidden-import",
                "mlx_audio",
                "--hidden-import",
                "mlx_audio.tts",
                "--hidden-import",
                "mlx_audio.stt",
                "--collect-submodules",
                "mlx",
                "--collect-submodules",
                "mlx_audio",
                # Use --collect-all so PyInstaller bundles both data files AND
                # native shared libraries (.dylib, .metallib) for MLX.
                # Previously only --collect-data was used, which caused MLX to
                # raise OSError at runtime inside the bundled binary because
                # the Metal shader libraries were missing.
                "--collect-all",
                "mlx",
                "--collect-all",
                "mlx_audio",
            ]
        )
    elif not cuda:
        logger.info("Building for non-Apple Silicon platform - PyTorch only")

    dist_dir = str(backend_dir / "dist")
    build_dir = str(backend_dir / "build")

    args.extend(
        [
            "--distpath",
            dist_dir,
            "--workpath",
            build_dir,
            "--noconfirm",
            "--clean",
        ]
    )

    # Change to backend directory
    os.chdir(backend_dir)

    # For CPU builds on Windows, ensure we're using CPU-only torch.
    # If CUDA torch is installed (local dev), swap to CPU torch before building,
    # then restore CUDA torch after. This prevents PyInstaller from bundling
    # ~3GB of CUDA DLLs into the CPU binary.
    restore_cuda = False
    if not cuda and platform.system() == "Windows":
        import subprocess

        result = subprocess.run(
            [sys.executable, "-c", "import torch; print(torch.version.cuda or '')"],
            capture_output=True,
            text=True,
        )
        has_cuda_torch = bool(result.stdout.strip())
        if has_cuda_torch:
            logger.info("CUDA torch detected — installing CPU torch for CPU build...")
            subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "torch",
                    "torchvision",
                    "torchaudio",
                    "--index-url",
                    "https://download.pytorch.org/whl/cpu",
                    "--force-reinstall",
                    "-q",
                ],
                check=True,
            )
            restore_cuda = True

    # Run PyInstaller
    try:
        PyInstaller.__main__.run(args)
    finally:
        # Restore CUDA torch if we swapped it out (even on build failure)
        if restore_cuda:
            logger.info("Restoring CUDA torch...")
            import subprocess

            subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "torch",
                    "torchvision",
                    "torchaudio",
                    "--index-url",
                    "https://download.pytorch.org/whl/cu128",
                    "--force-reinstall",
                    "-q",
                ],
                check=True,
            )

    logger.info("Binary built at %s", backend_dir / "dist" / binary_name)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)  # make build progress visible on stdout
    parser = argparse.ArgumentParser(description="Build voicebox-server binary")
    parser.add_argument(
        "--cuda",
        action="store_true",
        help="Build CUDA-enabled binary (voicebox-server-cuda)",
    )
    cli_args = parser.parse_args()
    build_server(cuda=cli_args.cuda)
133
backend/config.py
Normal file
@@ -0,0 +1,133 @@
"""
Configuration module for voicebox backend.

Handles data directory configuration for production bundling.
"""

import logging
import os
from pathlib import Path

logger = logging.getLogger(__name__)

# Allow users to override the HuggingFace model download directory.
# Set VOICEBOX_MODELS_DIR to an absolute path before starting the server.
# This sets HF_HUB_CACHE so all huggingface_hub downloads go to that path.
_custom_models_dir = os.environ.get("VOICEBOX_MODELS_DIR")
if _custom_models_dir:
    os.environ["HF_HUB_CACHE"] = _custom_models_dir
    logger.info("Model download path set to: %s", _custom_models_dir)

# Default data directory (used in development)
_data_dir = Path("data").resolve()


def _path_relative_to_any_data_dir(path: Path) -> Path | None:
    """Extract the path within a data dir from an absolute or relative path."""
    parts = path.parts
    for idx, part in enumerate(parts):
        if part != "data":
            continue

        tail = parts[idx + 1 :]
        if tail:
            return Path(*tail)
        return Path()

    return None


def set_data_dir(path: str | Path):
    """
    Set the data directory path.

    Args:
        path: Path to the data directory
    """
    global _data_dir
    _data_dir = Path(path).resolve()
    _data_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Data directory set to: %s", _data_dir)


def get_data_dir() -> Path:
    """
    Get the data directory path.

    Returns:
        Path to the data directory
    """
    return _data_dir


def to_storage_path(path: str | Path) -> str:
    """Convert a filesystem path to a DB-safe path relative to the data dir."""
    resolved_path = Path(path).resolve()

    relative_to_any_data_dir = _path_relative_to_any_data_dir(resolved_path)
    if relative_to_any_data_dir is not None:
        return str(relative_to_any_data_dir)

    try:
        return str(resolved_path.relative_to(_data_dir))
    except ValueError:
        return str(resolved_path)


def resolve_storage_path(path: str | Path | None) -> Path | None:
    """Resolve a DB-stored path against the configured data dir."""
    if path is None:
        return None

    stored_path = Path(path)
    if stored_path.is_absolute():
        rebased_path = _path_relative_to_any_data_dir(stored_path)
        if rebased_path is not None:
            candidate = (_data_dir / rebased_path).resolve()
            if candidate.exists() or not stored_path.exists():
                return candidate

        return stored_path

    # 0.3.0 records sometimes stored relative paths with the data-dir name
    # baked in (e.g. "data/profiles/..."). Joining those directly with
    # _data_dir produces a spurious "<data_dir>/data/profiles/..." nest.
    if stored_path.parts and stored_path.parts[0] == "data":
        stored_path = (
            Path(*stored_path.parts[1:]) if len(stored_path.parts) > 1 else Path()
        )

    return (_data_dir / stored_path).resolve()
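
# Illustrative round-trip (assuming the data dir is /app/data; paths are
# example values):
#
#     to_storage_path("/app/data/profiles/a.wav")  -> "profiles/a.wav"
#     resolve_storage_path("profiles/a.wav")       -> /app/data/profiles/a.wav
#     resolve_storage_path("data/profiles/a.wav")  -> /app/data/profiles/a.wav
#                                                     (legacy "data/" prefix stripped)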


def get_db_path() -> Path:
    """Get database file path."""
    return _data_dir / "voicebox.db"


def get_profiles_dir() -> Path:
    """Get profiles directory path."""
    path = _data_dir / "profiles"
    path.mkdir(parents=True, exist_ok=True)
    return path


def get_generations_dir() -> Path:
    """Get generations directory path."""
    path = _data_dir / "generations"
    path.mkdir(parents=True, exist_ok=True)
    return path


def get_cache_dir() -> Path:
    """Get cache directory path."""
    path = _data_dir / "cache"
    path.mkdir(parents=True, exist_ok=True)
    return path


def get_models_dir() -> Path:
    """Get models directory path."""
    path = _data_dir / "models"
    path.mkdir(parents=True, exist_ok=True)
    return path
44
backend/database/__init__.py
Normal file
@@ -0,0 +1,44 @@
"""Database package — ORM models, session management, and migrations.

Re-exports all public symbols so that ``from .database import get_db``
and ``from .database import Generation as DBGeneration`` continue to work
without changing any importers.
"""

from .models import (
    Base,
    AudioChannel,
    ChannelDeviceMapping,
    EffectPreset,
    Generation,
    GenerationVersion,
    ProfileChannelMapping,
    ProfileSample,
    Project,
    Story,
    StoryItem,
    VoiceProfile,
)
from .session import engine, SessionLocal, _db_path, init_db, get_db

__all__ = [
    # Models
    "Base",
    "AudioChannel",
    "ChannelDeviceMapping",
    "EffectPreset",
    "Generation",
    "GenerationVersion",
    "ProfileChannelMapping",
    "ProfileSample",
    "Project",
    "Story",
    "StoryItem",
    "VoiceProfile",
    # Session
    "engine",
    "SessionLocal",
    "_db_path",
    "init_db",
    "get_db",
]
226
backend/database/migrations.py
Normal file
@@ -0,0 +1,226 @@
"""Column-level migrations for the voicebox SQLite database.

Why not Alembic? voicebox is a single-user desktop app shipping as a
PyInstaller binary. Every user has exactly one SQLite file. Alembic's
strengths -- migration tracking across environments, rollback, team
coordination -- don't apply here and would add bundling complexity
(alembic.ini, env.py, versions/ directory all need to survive
PyInstaller). The column-existence checks below are idempotent, run in
<50 ms on startup, and have worked reliably across 12 schema changes.
If the project ever moves to a server-based deployment or Postgres, this
decision should be revisited.

Adding a new migration:
1. Append a new ``_migrate_*`` helper at the bottom of this file.
2. Call it from ``run_migrations()`` in the appropriate spot.
3. The helper should check column/table existence before acting
   (idempotent) and log a short message when it does real work.
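
Illustrative template for step 1 (the ``nickname`` column is a made-up
example, not a real migration in this file)::

    def _migrate_profiles_nickname(engine, inspector, tables: set[str]) -> None:
        if "profiles" not in tables:
            return
        if "nickname" not in _get_columns(inspector, "profiles"):
            _add_column(engine, "profiles", "nickname VARCHAR", "nickname")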
"""

import logging

from sqlalchemy import inspect, text

logger = logging.getLogger(__name__)


def run_migrations(engine) -> None:
    """Run all schema migrations. Safe to call on every startup."""
    inspector = inspect(engine)
    tables = set(inspector.get_table_names())

    _migrate_story_items(engine, inspector, tables)
    _migrate_profiles(engine, inspector, tables)
    _migrate_generations(engine, inspector, tables)
    _migrate_effect_presets(engine, inspector, tables)
    _migrate_generation_versions(engine, inspector, tables)
    _normalize_storage_paths(engine, tables)


# -- helpers ---------------------------------------------------------------


def _get_columns(inspector, table: str) -> set[str]:
    return {col["name"] for col in inspector.get_columns(table)}


def _add_column(engine, table: str, column_sql: str, label: str) -> None:
    """Add a column; callers check that it doesn't already exist."""
    with engine.connect() as conn:
        conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column_sql}"))
        conn.commit()
        logger.info("Added %s column to %s", label, table)


# -- per-table migrations --------------------------------------------------


def _migrate_story_items(engine, inspector, tables: set[str]) -> None:
    if "story_items" not in tables:
        return

    columns = _get_columns(inspector, "story_items")

    # Replace position-based ordering with absolute timecodes
    if "position" in columns:
        logger.info("Migrating story_items: removing position column, using start_time_ms")
        with engine.connect() as conn:
            if "start_time_ms" not in columns:
                conn.execute(text(
                    "ALTER TABLE story_items ADD COLUMN start_time_ms INTEGER DEFAULT 0"
                ))
            result = conn.execute(text("""
                SELECT si.id, si.story_id, si.position, g.duration
                FROM story_items si
                JOIN generations g ON si.generation_id = g.id
                ORDER BY si.story_id, si.position
            """))
            current_story_id = None
            current_time_ms = 0
            for item_id, story_id, _position, duration in result.fetchall():
                if story_id != current_story_id:
                    current_story_id = story_id
                    current_time_ms = 0
                conn.execute(
                    text("UPDATE story_items SET start_time_ms = :time WHERE id = :id"),
                    {"time": current_time_ms, "id": item_id},
                )
                current_time_ms += int((duration or 0) * 1000) + 200
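                # Worked example: three 1.0 s items in one story land at
                # 0 ms, 1200 ms and 2400 ms; each item advances the cursor
                # by its duration plus a fixed 200 ms gap.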
|
||||
conn.commit()
|
||||
|
||||
# Recreate table without the position column (SQLite lacks DROP COLUMN)
|
||||
conn.execute(text("""
|
||||
CREATE TABLE story_items_new (
|
||||
id VARCHAR PRIMARY KEY,
|
||||
story_id VARCHAR NOT NULL,
|
||||
generation_id VARCHAR NOT NULL,
|
||||
start_time_ms INTEGER NOT NULL DEFAULT 0,
|
||||
track INTEGER NOT NULL DEFAULT 0,
|
||||
trim_start_ms INTEGER NOT NULL DEFAULT 0,
|
||||
trim_end_ms INTEGER NOT NULL DEFAULT 0,
|
||||
version_id VARCHAR,
|
||||
created_at DATETIME,
|
||||
FOREIGN KEY (story_id) REFERENCES stories(id),
|
||||
FOREIGN KEY (generation_id) REFERENCES generations(id)
|
||||
)
|
||||
"""))
|
||||
conn.execute(text("""
|
||||
INSERT INTO story_items_new (id, story_id, generation_id, start_time_ms, track, trim_start_ms, trim_end_ms, version_id, created_at)
|
||||
SELECT id, story_id, generation_id, start_time_ms,
|
||||
COALESCE(track, 0), COALESCE(trim_start_ms, 0), COALESCE(trim_end_ms, 0), version_id, created_at
|
||||
FROM story_items
|
||||
"""))
|
||||
conn.execute(text("DROP TABLE story_items"))
|
||||
conn.execute(text("ALTER TABLE story_items_new RENAME TO story_items"))
|
||||
conn.commit()
|
||||
|
||||
# Re-read after table recreation
|
||||
columns = _get_columns(inspector, "story_items")
|
||||
|
||||
if "track" not in columns:
|
||||
_add_column(engine, "story_items", "track INTEGER NOT NULL DEFAULT 0", "track")
|
||||
# Re-read so subsequent checks see new columns
|
||||
columns = _get_columns(inspector, "story_items")
|
||||
if "trim_start_ms" not in columns:
|
||||
_add_column(engine, "story_items", "trim_start_ms INTEGER NOT NULL DEFAULT 0", "trim_start_ms")
|
||||
if "trim_end_ms" not in columns:
|
||||
_add_column(engine, "story_items", "trim_end_ms INTEGER NOT NULL DEFAULT 0", "trim_end_ms")
|
||||
if "version_id" not in columns:
|
||||
_add_column(engine, "story_items", "version_id VARCHAR", "version_id")
|
||||
|
||||
|
||||
def _migrate_profiles(engine, inspector, tables: set[str]) -> None:
|
||||
if "profiles" not in tables:
|
||||
return
|
||||
    columns = _get_columns(inspector, "profiles")
    if "avatar_path" not in columns:
        _add_column(engine, "profiles", "avatar_path VARCHAR", "avatar_path")
    if "effects_chain" not in columns:
        _add_column(engine, "profiles", "effects_chain TEXT", "effects_chain")
    # Voice type system — v0.3.x
    if "voice_type" not in columns:
        _add_column(engine, "profiles", "voice_type VARCHAR DEFAULT 'cloned'", "voice_type")
    if "preset_engine" not in columns:
        _add_column(engine, "profiles", "preset_engine VARCHAR", "preset_engine")
    if "preset_voice_id" not in columns:
        _add_column(engine, "profiles", "preset_voice_id VARCHAR", "preset_voice_id")
    if "design_prompt" not in columns:
        _add_column(engine, "profiles", "design_prompt TEXT", "design_prompt")
    if "default_engine" not in columns:
        _add_column(engine, "profiles", "default_engine VARCHAR", "default_engine")


def _migrate_generations(engine, inspector, tables: set[str]) -> None:
    if "generations" not in tables:
        return
    columns = _get_columns(inspector, "generations")
    if "status" not in columns:
        _add_column(engine, "generations", "status VARCHAR DEFAULT 'completed'", "status")
    if "error" not in columns:
        _add_column(engine, "generations", "error TEXT", "error")
    if "engine" not in columns:
        _add_column(engine, "generations", "engine VARCHAR DEFAULT 'qwen'", "engine")
    # Re-read after engine column (variable name shadows outer scope in old code)
    columns = _get_columns(inspector, "generations")
    if "model_size" not in columns:
        _add_column(engine, "generations", "model_size VARCHAR", "model_size")
    if "is_favorited" not in columns:
        _add_column(engine, "generations", "is_favorited BOOLEAN DEFAULT 0", "is_favorited")


def _migrate_effect_presets(engine, inspector, tables: set[str]) -> None:
    if "effect_presets" not in tables:
        return
    columns = _get_columns(inspector, "effect_presets")
    if "sort_order" not in columns:
        _add_column(engine, "effect_presets", "sort_order INTEGER DEFAULT 100", "sort_order")


def _migrate_generation_versions(engine, inspector, tables: set[str]) -> None:
    if "generation_versions" not in tables:
        return
    columns = _get_columns(inspector, "generation_versions")
    if "source_version_id" not in columns:
        _add_column(engine, "generation_versions", "source_version_id VARCHAR", "source_version_id")


def _normalize_storage_paths(engine, tables: set[str]) -> None:
    """Normalize stored file paths to be relative to the configured data dir."""
    from pathlib import Path

    from ..config import get_data_dir, to_storage_path, resolve_storage_path

    data_dir = get_data_dir()

    path_columns = [
        ("generations", "audio_path"),
        ("generation_versions", "audio_path"),
        ("profile_samples", "audio_path"),
        ("profiles", "avatar_path"),
    ]

    total_fixed = 0
    with engine.connect() as conn:
        for table, column in path_columns:
            if table not in tables:
                continue
            rows = conn.execute(
                text(f"SELECT id, {column} FROM {table} WHERE {column} IS NOT NULL")
            ).fetchall()
            for row_id, path_val in rows:
                if not path_val:
                    continue
                p = Path(path_val)
                resolved = resolve_storage_path(p)
                if resolved is None:
                    continue

                normalized = to_storage_path(resolved)

                if normalized != path_val:
                    conn.execute(
                        text(f"UPDATE {table} SET {column} = :path WHERE id = :id"),
                        {"path": normalized, "id": row_id},
                    )
                    total_fixed += 1
        if total_fixed > 0:
            conn.commit()
            logger.info("Normalized %d stored file paths", total_fixed)
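The helpers imported from `backend/config.py` are not part of this diff, so here is a minimal sketch of the contract `_normalize_storage_paths` assumes of them. `DATA_DIR` and both function bodies below are illustrative assumptions, not the real implementation:

```python
# Hypothetical stand-ins for backend.config.resolve_storage_path /
# to_storage_path, modelling only the behaviour the migration relies on.
from pathlib import Path

DATA_DIR = Path("/home/user/.voicebox")  # assumed data dir

def resolve_storage_path(p) -> Path | None:
    """Map a stored value (absolute or relative) to an absolute path."""
    p = Path(p)
    return p if p.is_absolute() else DATA_DIR / p

def to_storage_path(p: Path) -> str:
    """Map an absolute path back to the canonical stored form."""
    try:
        return str(p.relative_to(DATA_DIR))  # relative when under the data dir
    except ValueError:
        return str(p)  # outside the data dir: keep absolute

# A legacy absolute path is rewritten to the relative form...
old = "/home/user/.voicebox/generations/a.wav"
assert to_storage_path(resolve_storage_path(old)) == "generations/a.wav"
# ...while an already-normalized value round-trips unchanged, so the
# UPDATE in _normalize_storage_paths is skipped for it.
assert to_storage_path(resolve_storage_path("generations/a.wav")) == "generations/a.wav"
```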
169
backend/database/models.py
Normal file
@@ -0,0 +1,169 @@
"""ORM model definitions for the voicebox SQLite database."""

from datetime import datetime
import uuid

from sqlalchemy import Column, String, Integer, Float, DateTime, Text, ForeignKey, Boolean
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class VoiceProfile(Base):
    """Voice profile.

    voice_type discriminates three flavours:
    - "cloned" — traditional reference-audio profiles (all cloning engines)
    - "preset" — engine-specific pre-built voice (e.g. Kokoro voices)
    - "designed" — text-described voice (e.g. Qwen CustomVoice, future)
    """

    __tablename__ = "profiles"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String, unique=True, nullable=False)
    description = Column(Text)
    language = Column(String, default="en")
    avatar_path = Column(String, nullable=True)
    effects_chain = Column(Text, nullable=True)

    # Voice type system — added v0.3.x
    voice_type = Column(String, default="cloned")  # "cloned" | "preset" | "designed"
    preset_engine = Column(String, nullable=True)  # e.g. "kokoro" — only for preset
    preset_voice_id = Column(String, nullable=True)  # e.g. "am_adam" — only for preset
    design_prompt = Column(Text, nullable=True)  # text description — only for designed
    default_engine = Column(String, nullable=True)  # auto-selected engine, locked for preset

    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class ProfileSample(Base):
    """Audio sample attached to a voice profile."""

    __tablename__ = "profile_samples"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    profile_id = Column(String, ForeignKey("profiles.id"), nullable=False)
    audio_path = Column(String, nullable=False)
    reference_text = Column(Text, nullable=False)


class Generation(Base):
    """A single TTS generation."""

    __tablename__ = "generations"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    profile_id = Column(String, ForeignKey("profiles.id"), nullable=False)
    text = Column(Text, nullable=False)
    language = Column(String, default="en")
    audio_path = Column(String, nullable=True)
    duration = Column(Float, nullable=True)
    seed = Column(Integer)
    instruct = Column(Text)
    engine = Column(String, default="qwen")
    model_size = Column(String, nullable=True)
    status = Column(String, default="completed")
    error = Column(Text, nullable=True)
    is_favorited = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)


class Story(Base):
    """A story that sequences multiple generations."""

    __tablename__ = "stories"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String, nullable=False)
    description = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class StoryItem(Base):
    """Links a generation to a story at a specific timecode."""

    __tablename__ = "story_items"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    story_id = Column(String, ForeignKey("stories.id"), nullable=False)
    generation_id = Column(String, ForeignKey("generations.id"), nullable=False)
    version_id = Column(String, ForeignKey("generation_versions.id"), nullable=True)
    start_time_ms = Column(Integer, nullable=False, default=0)
    track = Column(Integer, nullable=False, default=0)
    trim_start_ms = Column(Integer, nullable=False, default=0)
    trim_end_ms = Column(Integer, nullable=False, default=0)
    created_at = Column(DateTime, default=datetime.utcnow)


class Project(Base):
    """Audio studio project (JSON blob)."""

    __tablename__ = "projects"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String, nullable=False)
    data = Column(Text)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class GenerationVersion(Base):
    """A version of a generation's audio (original, processed, alternate takes)."""

    __tablename__ = "generation_versions"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    generation_id = Column(String, ForeignKey("generations.id"), nullable=False)
    label = Column(String, nullable=False)
    audio_path = Column(String, nullable=False)
    effects_chain = Column(Text, nullable=True)
    source_version_id = Column(String, ForeignKey("generation_versions.id"), nullable=True)
    is_default = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)


class EffectPreset(Base):
    """Saved effect chain preset."""

    __tablename__ = "effect_presets"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String, unique=True, nullable=False)
    description = Column(Text, nullable=True)
    effects_chain = Column(Text, nullable=False)
    is_builtin = Column(Boolean, default=False)
    sort_order = Column(Integer, default=100)
    created_at = Column(DateTime, default=datetime.utcnow)


class AudioChannel(Base):
    """Audio output channel (bus)."""

    __tablename__ = "audio_channels"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String, nullable=False)
    is_default = Column(Boolean, default=False)
    created_at = Column(DateTime, default=datetime.utcnow)


class ChannelDeviceMapping(Base):
    """Mapping between a channel and an OS audio device."""

    __tablename__ = "channel_device_mappings"

    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    channel_id = Column(String, ForeignKey("audio_channels.id"), nullable=False)
    device_id = Column(String, nullable=False)


class ProfileChannelMapping(Base):
    """Many-to-many mapping between voice profiles and audio channels."""

    __tablename__ = "profile_channel_mappings"

    profile_id = Column(String, ForeignKey("profiles.id"), primary_key=True)
    channel_id = Column(String, ForeignKey("audio_channels.id"), primary_key=True)
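As a usage sketch (not part of the diff): all three flavours live in the same `profiles` table, and only the discriminator plus its flavour-specific columns differ. Values below are illustrative; `"kokoro"` and `"am_adam"` are the examples from the docstring above.

```python
# Illustrative rows for the three voice_type flavours.
cloned = VoiceProfile(
    name="Narrator",
    voice_type="cloned",  # reference audio arrives as ProfileSample rows
)
preset = VoiceProfile(
    name="Adam",
    voice_type="preset",
    preset_engine="kokoro",     # engine that owns the pre-built voice
    preset_voice_id="am_adam",  # engine-specific voice identifier
)
designed = VoiceProfile(
    name="Captain",
    voice_type="designed",
    design_prompt="a gravelly, weathered sea captain",  # invented example
)
```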
73
backend/database/seed.py
Normal file
@@ -0,0 +1,73 @@
"""Post-migration data seeding and backfills."""

import json
import logging
import uuid

from .. import config

logger = logging.getLogger(__name__)


def backfill_generation_versions(SessionLocal, Generation, GenerationVersion) -> None:
    """Create 'clean' version entries for generations that predate the versions feature."""
    db = SessionLocal()
    try:
        existing_version_gen_ids = {
            row[0] for row in db.query(GenerationVersion.generation_id).all()
        }
        generations = db.query(Generation).filter(
            Generation.status == "completed",
            Generation.audio_path.isnot(None),
            Generation.audio_path != "",
        ).all()

        count = 0
        for gen in generations:
            if gen.id in existing_version_gen_ids:
                continue
            resolved_audio_path = config.resolve_storage_path(gen.audio_path)
            if resolved_audio_path is None or not resolved_audio_path.exists():
                continue
            version = GenerationVersion(
                id=str(uuid.uuid4()),
                generation_id=gen.id,
                label="clean",
                audio_path=gen.audio_path,
                effects_chain=None,
                is_default=True,
            )
            db.add(version)
            count += 1

        if count > 0:
            db.commit()
            logger.info("Backfilled %d generation version entries", count)
    finally:
        db.close()


def seed_builtin_presets(SessionLocal, EffectPreset) -> None:
    """Ensure built-in effect presets exist in the database."""
    from ..utils.effects import BUILTIN_PRESETS

    db = SessionLocal()
    try:
        for idx, (_key, preset_data) in enumerate(BUILTIN_PRESETS.items()):
            sort_order = preset_data.get("sort_order", idx)
            existing = db.query(EffectPreset).filter_by(name=preset_data["name"]).first()
            if not existing:
                preset = EffectPreset(
                    id=str(uuid.uuid4()),
                    name=preset_data["name"],
                    description=preset_data.get("description"),
                    effects_chain=json.dumps(preset_data["effects_chain"]),
                    is_builtin=True,
                    sort_order=sort_order,
                )
                db.add(preset)
            elif existing.sort_order != sort_order:
                existing.sort_order = sort_order
        db.commit()
    finally:
        db.close()
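`seed_builtin_presets()` assumes only a small contract from `BUILTIN_PRESETS` (defined in `backend/utils/effects.py`, which is not in this part of the diff): a dict of entries with a required `name` and `effects_chain`, and optional `description` and `sort_order`. A hypothetical entry satisfying that contract, with invented values:

```python
# Hypothetical BUILTIN_PRESETS entry — shape inferred from the reads above.
BUILTIN_PRESETS = {
    "radio": {
        "name": "Radio",  # unique; used for the existence check
        "description": "Band-limited broadcast sound",
        "effects_chain": [  # stored via json.dumps()
            {"type": "eq", "enabled": True, "params": {"low_cut_hz": 300.0}},
        ],
        "sort_order": 10,  # optional; enumeration index is the fallback
    },
}
```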
78
backend/database/session.py
Normal file
@@ -0,0 +1,78 @@
"""Engine creation, initialization, and session management."""

import logging
import uuid

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from .. import config
from .models import (
    Base,
    AudioChannel,
    EffectPreset,
    Generation,
    GenerationVersion,
    ProfileChannelMapping,
    VoiceProfile,
)
from .migrations import run_migrations
from .seed import backfill_generation_versions, seed_builtin_presets

logger = logging.getLogger(__name__)

# Initialized by init_db()
engine = None
SessionLocal = None
_db_path = None


def init_db() -> None:
    """Initialize the database engine, run migrations, create tables, and seed data."""
    global engine, SessionLocal, _db_path

    _db_path = config.get_db_path()
    _db_path.parent.mkdir(parents=True, exist_ok=True)

    engine = create_engine(
        f"sqlite:///{_db_path}",
        connect_args={"check_same_thread": False},
    )

    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    run_migrations(engine)
    Base.metadata.create_all(bind=engine)

    # Create default audio channel if it doesn't exist
    db = SessionLocal()
    try:
        default_channel = db.query(AudioChannel).filter(AudioChannel.is_default == True).first()
        if not default_channel:
            default_channel = AudioChannel(
                id=str(uuid.uuid4()),
                name="Default",
                is_default=True,
            )
            db.add(default_channel)

            for profile in db.query(VoiceProfile).all():
                db.add(ProfileChannelMapping(
                    profile_id=profile.id,
                    channel_id=default_channel.id,
                ))
            db.commit()
    finally:
        db.close()

    backfill_generation_versions(SessionLocal, Generation, GenerationVersion)
    seed_builtin_presets(SessionLocal, EffectPreset)


def get_db():
    """Yield a database session (FastAPI dependency)."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
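A sketch of how route modules are expected to consume `get_db()` (the router and path below are illustrative, not from this diff):

```python
# Minimal FastAPI route using get_db() as a dependency. FastAPI drives the
# generator: the session is yielded to the handler and closed in get_db()'s
# finally block once the response has been produced.
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from backend.database.models import VoiceProfile
from backend.database.session import get_db

router = APIRouter()


@router.get("/profiles/count")
def count_profiles(db: Session = Depends(get_db)) -> dict:
    return {"count": db.query(VoiceProfile).count()}
```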
45
backend/main.py
Normal file
@@ -0,0 +1,45 @@
"""Entry point for the voicebox backend.

Imports the configured FastAPI app and provides a ``python -m backend.main``
entry point for development.
"""

import argparse
import uvicorn

from .app import app  # noqa: F401 -- re-export for uvicorn "backend.main:app"
from . import config, database

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="voicebox backend server")
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Host to bind to (use 0.0.0.0 for remote access)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default=None,
        help="Data directory for database, profiles, and generated audio",
    )
    args = parser.parse_args()

    if args.data_dir:
        config.set_data_dir(args.data_dir)

    database.init_db()

    uvicorn.run(
        "backend.main:app",
        host=args.host,
        port=args.port,
        reload=False,
    )
521
backend/models.py
Normal file
@@ -0,0 +1,521 @@
"""
Pydantic models for request/response validation.
"""

from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime


class VoiceProfileCreate(BaseModel):
    """Request model for creating a voice profile."""

    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
    language: str = Field(
        default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it|he|ar|da|el|fi|hi|ms|nl|no|pl|sv|sw|tr)$"
    )
    voice_type: Optional[str] = Field(default="cloned", pattern="^(cloned|preset|designed)$")
    preset_engine: Optional[str] = Field(None, max_length=50)
    preset_voice_id: Optional[str] = Field(None, max_length=100)
    design_prompt: Optional[str] = Field(None, max_length=2000)
    default_engine: Optional[str] = Field(None, max_length=50)


class VoiceProfileResponse(BaseModel):
    """Response model for voice profile."""

    id: str
    name: str
    description: Optional[str]
    language: str
    avatar_path: Optional[str] = None
    effects_chain: Optional[List["EffectConfig"]] = None
    voice_type: str = "cloned"
    preset_engine: Optional[str] = None
    preset_voice_id: Optional[str] = None
    design_prompt: Optional[str] = None
    default_engine: Optional[str] = None
    generation_count: int = 0
    sample_count: int = 0
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True


class ProfileSampleCreate(BaseModel):
    """Request model for adding a sample to a profile."""

    reference_text: str = Field(..., min_length=1, max_length=1000)


class ProfileSampleUpdate(BaseModel):
    """Request model for updating a profile sample."""

    reference_text: str = Field(..., min_length=1, max_length=1000)


class ProfileSampleResponse(BaseModel):
    """Response model for profile sample."""

    id: str
    profile_id: str
    audio_path: str
    reference_text: str

    class Config:
        from_attributes = True


class GenerationRequest(BaseModel):
    """Request model for voice generation."""

    profile_id: str
    text: str = Field(..., min_length=1, max_length=50000)
    language: str = Field(default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it|he|ar|da|el|fi|hi|ms|nl|no|pl|sv|sw|tr)$")
    seed: Optional[int] = Field(None, ge=0)
    model_size: Optional[str] = Field(default="1.7B", pattern="^(1\\.7B|0\\.6B|1B|3B)$")
    instruct: Optional[str] = Field(None, max_length=500)
    engine: Optional[str] = Field(default="qwen", pattern="^(qwen|qwen_custom_voice|luxtts|chatterbox|chatterbox_turbo|tada|kokoro)$")
    max_chunk_chars: int = Field(
        default=800, ge=100, le=5000, description="Max characters per chunk for long text splitting"
    )
    crossfade_ms: int = Field(
        default=50, ge=0, le=500, description="Crossfade duration in ms between chunks (0 for hard cut)"
    )
    normalize: bool = Field(default=True, description="Normalize output audio volume")
    effects_chain: Optional[List["EffectConfig"]] = Field(
        None, description="Effects chain to apply after generation (overrides profile default)"
    )


class GenerationResponse(BaseModel):
    """Response model for voice generation."""

    id: str
    profile_id: str
    text: str
    language: str
    audio_path: Optional[str] = None
    duration: Optional[float] = None
    seed: Optional[int] = None
    instruct: Optional[str] = None
    engine: Optional[str] = "qwen"
    model_size: Optional[str] = None
    status: str = "completed"
    error: Optional[str] = None
    is_favorited: bool = False
    created_at: datetime
    versions: Optional[List["GenerationVersionResponse"]] = None
    active_version_id: Optional[str] = None

    class Config:
        from_attributes = True


class HistoryQuery(BaseModel):
    """Query model for generation history."""

    profile_id: Optional[str] = None
    search: Optional[str] = None
    limit: int = Field(default=50, ge=1, le=100)
    offset: int = Field(default=0, ge=0)


class HistoryResponse(BaseModel):
    """Response model for history entry (includes profile name)."""

    id: str
    profile_id: str
    profile_name: str
    text: str
    language: str
    audio_path: Optional[str] = None
    duration: Optional[float] = None
    seed: Optional[int] = None
    instruct: Optional[str] = None
    engine: Optional[str] = "qwen"
    model_size: Optional[str] = None
    status: str = "completed"
    error: Optional[str] = None
    is_favorited: bool = False
    created_at: datetime
    versions: Optional[List["GenerationVersionResponse"]] = None
    active_version_id: Optional[str] = None

    class Config:
        from_attributes = True


class HistoryListResponse(BaseModel):
    """Response model for history list."""

    items: List[HistoryResponse]
    total: int


class TranscriptionRequest(BaseModel):
    """Request model for audio transcription."""

    language: Optional[str] = Field(None, pattern="^(en|zh|ja|ko|de|fr|ru|pt|es|it)$")
    model: Optional[str] = Field(None, pattern="^(base|small|medium|large|turbo)$")


class TranscriptionResponse(BaseModel):
    """Response model for transcription."""

    text: str
    duration: float


class HealthResponse(BaseModel):
    """Response model for health check."""

    status: str
    model_loaded: bool
    model_downloaded: Optional[bool] = None  # Whether model is cached/downloaded
    model_size: Optional[str] = None  # Current model size if loaded
    gpu_available: bool
    gpu_type: Optional[str] = None  # GPU type (CUDA, MPS, or None)
    vram_used_mb: Optional[float] = None
    backend_type: Optional[str] = None  # Backend type (mlx or pytorch)
    backend_variant: Optional[str] = None  # Binary variant (cpu or cuda)
    gpu_compatibility_warning: Optional[str] = None  # Warning if GPU arch unsupported


class DirectoryCheck(BaseModel):
    """Health status for a single directory."""

    path: str
    exists: bool
    writable: bool
    error: Optional[str] = None


class FilesystemHealthResponse(BaseModel):
    """Response model for filesystem health check."""

    healthy: bool
    disk_free_mb: Optional[float] = None
    disk_total_mb: Optional[float] = None
    directories: List[DirectoryCheck]


class ModelStatus(BaseModel):
    """Response model for model status."""

    model_name: str
    display_name: str
    hf_repo_id: Optional[str] = None  # HuggingFace repository ID
    downloaded: bool
    downloading: bool = False  # True if download is in progress
    size_mb: Optional[float] = None
    loaded: bool = False


class ModelStatusListResponse(BaseModel):
    """Response model for model status list."""

    models: List[ModelStatus]


class ModelDownloadRequest(BaseModel):
    """Request model for triggering model download."""

    model_name: str


class ModelMigrateRequest(BaseModel):
    """Request model for migrating models to a new directory."""

    destination: str


class ActiveDownloadTask(BaseModel):
    """Response model for active download task."""

    model_name: str
    status: str
    started_at: datetime
    error: Optional[str] = None
    progress: Optional[float] = None  # 0-100 percentage
    current: Optional[int] = None  # bytes downloaded
    total: Optional[int] = None  # total bytes
    filename: Optional[str] = None  # current file being downloaded


class ActiveGenerationTask(BaseModel):
    """Response model for active generation task."""

    task_id: str
    profile_id: str
    text_preview: str
    started_at: datetime


class ActiveTasksResponse(BaseModel):
    """Response model for active tasks."""

    downloads: List[ActiveDownloadTask]
    generations: List[ActiveGenerationTask]


class AudioChannelCreate(BaseModel):
    """Request model for creating an audio channel."""

    name: str = Field(..., min_length=1, max_length=100)
    device_ids: List[str] = Field(default_factory=list)


class AudioChannelUpdate(BaseModel):
    """Request model for updating an audio channel."""

    name: Optional[str] = Field(None, min_length=1, max_length=100)
    device_ids: Optional[List[str]] = None


class AudioChannelResponse(BaseModel):
    """Response model for audio channel."""

    id: str
    name: str
    is_default: bool
    device_ids: List[str]
    created_at: datetime

    class Config:
        from_attributes = True


class ChannelVoiceAssignment(BaseModel):
    """Request model for assigning voices to a channel."""

    profile_ids: List[str]


class ProfileChannelAssignment(BaseModel):
    """Request model for assigning channels to a profile."""

    channel_ids: List[str]


class StoryCreate(BaseModel):
    """Request model for creating a story."""

    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)


class StoryResponse(BaseModel):
    """Response model for story (list view)."""

    id: str
    name: str
    description: Optional[str]
    created_at: datetime
    updated_at: datetime
    item_count: int = 0

    class Config:
        from_attributes = True


class StoryItemDetail(BaseModel):
    """Detail model for story item with generation info."""

    id: str
    story_id: str
    generation_id: str
    version_id: Optional[str] = None
    start_time_ms: int
    track: int = 0
    trim_start_ms: int = 0
    trim_end_ms: int = 0
    created_at: datetime
    # Generation details
    profile_id: str
    profile_name: str
    text: str
    language: str
    audio_path: str
    duration: float
    seed: Optional[int]
    instruct: Optional[str]
    generation_created_at: datetime
    # Versions available for this generation
    versions: Optional[List["GenerationVersionResponse"]] = None
    active_version_id: Optional[str] = None

    class Config:
        from_attributes = True


class StoryDetailResponse(BaseModel):
    """Response model for story with items."""

    id: str
    name: str
    description: Optional[str]
    created_at: datetime
    updated_at: datetime
    items: List[StoryItemDetail] = []

    class Config:
        from_attributes = True


class StoryItemCreate(BaseModel):
    """Request model for adding a generation to a story."""

    generation_id: str
    start_time_ms: Optional[int] = None  # If not provided, will be calculated automatically
    track: Optional[int] = 0  # Track number (0 = main track)


class StoryItemUpdateTime(BaseModel):
    """Request model for updating a story item's timecode."""

    generation_id: str
    start_time_ms: int = Field(..., ge=0)


class StoryItemBatchUpdate(BaseModel):
    """Request model for batch updating story item timecodes."""

    updates: List[StoryItemUpdateTime]


class StoryItemReorder(BaseModel):
    """Request model for reordering story items."""

    generation_ids: List[str] = Field(..., min_length=1)


class StoryItemMove(BaseModel):
    """Request model for moving a story item (position and/or track)."""

    start_time_ms: int = Field(..., ge=0)
    track: int = 0


class StoryItemTrim(BaseModel):
    """Request model for trimming a story item."""

    trim_start_ms: int = Field(..., ge=0)
    trim_end_ms: int = Field(..., ge=0)


class StoryItemSplit(BaseModel):
    """Request model for splitting a story item."""

    split_time_ms: int = Field(..., ge=0)  # Time within the clip to split at (relative to clip start)


class StoryItemVersionUpdate(BaseModel):
    """Request model for setting a story item's pinned version."""

    version_id: Optional[str] = None  # null = use generation default


class EffectConfig(BaseModel):
    """A single effect in an effects chain."""

    type: str
    enabled: bool = True
    params: dict = Field(default_factory=dict)


class EffectsChain(BaseModel):
    """An ordered list of effects to apply."""

    effects: List[EffectConfig] = Field(default_factory=list)


class EffectPresetCreate(BaseModel):
    """Request model for creating an effect preset."""

    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
    effects_chain: List[EffectConfig]


class EffectPresetUpdate(BaseModel):
    """Request model for updating an effect preset."""

    name: Optional[str] = Field(None, min_length=1, max_length=100)
    description: Optional[str] = None
    effects_chain: Optional[List[EffectConfig]] = None


class EffectPresetResponse(BaseModel):
    """Response model for effect preset."""

    id: str
    name: str
    description: Optional[str] = None
    effects_chain: List[EffectConfig]
    is_builtin: bool = False
    created_at: datetime

    class Config:
        from_attributes = True


class GenerationVersionResponse(BaseModel):
    """Response model for a generation version."""

    id: str
    generation_id: str
    label: str
    audio_path: str
    effects_chain: Optional[List[EffectConfig]] = None
    source_version_id: Optional[str] = None
    is_default: bool
    created_at: datetime

    class Config:
        from_attributes = True


class ApplyEffectsRequest(BaseModel):
    """Request to apply effects to an existing generation."""

    effects_chain: List[EffectConfig]
    source_version_id: Optional[str] = Field(
        None, description="Version to use as source audio (defaults to clean/original)"
    )
    label: Optional[str] = Field(None, max_length=100, description="Label for this version (auto-generated if omitted)")
    set_as_default: bool = Field(default=True, description="Set this version as the default")


class ProfileEffectsUpdate(BaseModel):
    """Request to update the default effects chain on a profile."""

    effects_chain: Optional[List[EffectConfig]] = Field(None, description="Effects chain (null to remove)")


class AvailableEffectParam(BaseModel):
    """Description of a single effect parameter."""

    default: float
    min: float
    max: float
    step: float
    description: str


class AvailableEffect(BaseModel):
    """Description of an available effect type."""

    type: str
    label: str
    description: str
    params: dict  # param_name -> AvailableEffectParam


class AvailableEffectsResponse(BaseModel):
    """Response listing all available effect types."""

    effects: List[AvailableEffect]
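A quick illustration of what these schemas buy the thin route handlers (values are made up; the behaviour shown is standard Pydantic v2):

```python
from pydantic import ValidationError

# Defaults fill in everything the client omits.
req = GenerationRequest(profile_id="abc123", text="Hello there")
assert req.engine == "qwen" and req.max_chunk_chars == 800 and req.normalize

# Out-of-vocabulary engine names are rejected by the pattern regex before
# any service code runs.
try:
    GenerationRequest(profile_id="abc123", text="Hi", engine="not-an-engine")
except ValidationError as err:
    assert err.errors()[0]["loc"] == ("engine",)
```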
12
backend/pyi_hooks/hook-scipy.stats._distn_infrastructure.py
Normal file
@@ -0,0 +1,12 @@
"""
Force scipy.stats._distn_infrastructure to be bundled with its .py source file
alongside the .pyc bytecode.

The runtime hook in backend/pyi_rth_torch_compiler_disable.py patches this
module's source at load time (the module has a `del obj` at line 369 that
raises NameError under PyInstaller's frozen importer). That patch reads the
source via loader.get_source(), which only works if the .py file was
actually collected into the bundle.
"""

module_collection_mode = "pyz+py"
11
backend/pyi_hooks/hook-transformers.masking_utils.py
Normal file
@@ -0,0 +1,11 @@
"""
Force transformers.masking_utils to be bundled with its .py source alongside
the .pyc bytecode so the runtime hook in
backend/pyi_rth_torch_compiler_disable.py can source-patch it.

The patch forces the torch<2.6 code path, bypassing `with TransformGetItemToIndex()`
which our torch._dynamo no-op stub can't implement for real — the real context
manager uses dynamo graph transforms to avoid `.item()` calls inside vmap.
"""

module_collection_mode = "pyz+py"
95
backend/pyi_rth_numpy_compat.py
Normal file
@@ -0,0 +1,95 @@
"""
PyInstaller runtime hook: numpy 2.x / torch ABI mismatch fix.

Problem
-------
torch is compiled against numpy 1.x headers. numpy 2.x changed the version
number returned by PyArray_GetNDArrayCVersion() (0x01000009 → 0x02000000),
so torch's is_numpy_available() returns False and every torch.from_numpy()
call raises:

    RuntimeError: Numpy is not available

This surfaces as:

    ValueError: Unable to create tensor, you should probably activate
    padding with 'padding=True'

during TTS generation (EncodecFeatureExtractor → BatchFeature.convert_to_tensors).

Fix
---
Runtime hooks execute after PyInstaller's FrozenImporter is registered, so
frozen torch/numpy are importable here. We start a background thread that
waits for torch to finish loading, then wraps torch.from_numpy with a ctypes
memmove fallback that bypasses the C-level numpy ABI check entirely.

This approach works with any numpy version and is safer than binary-patching
libtorch_python.dylib (which risks PyArray_Descr struct layout mismatches).
"""

import sys
import threading


def _patch_torch_from_numpy():
    import time

    for _ in range(7200):  # poll up to 360 s at 50 ms intervals
        time.sleep(0.05)
        torch = sys.modules.get("torch")
        if torch is None or not hasattr(torch, "from_numpy"):
            continue
        if getattr(torch, "_vb_from_numpy_patched", False):
            return
        try:
            import ctypes
            import numpy as np

            _orig = torch.from_numpy

            # Explicit numpy → torch dtype map. Silent fallback to float32 on
            # unknown dtypes would reinterpret the memcpy'd bytes as fp32 and
            # silently corrupt data (e.g. fp16 tensors from some TTS engines),
            # so we raise instead.
            dtype_map = {
                "float16": torch.float16,
                "float32": torch.float32,
                "float64": torch.float64,
                "int8": torch.int8,
                "int16": torch.int16,
                "int32": torch.int32,
                "int64": torch.int64,
                "uint8": torch.uint8,
                "bool": torch.bool,
                "complex64": torch.complex64,
                "complex128": torch.complex128,
            }

            def _safe_from_numpy(
                arr, _orig=_orig, _c=ctypes, _np=np, _t=torch, _map=dtype_map
            ):
                try:
                    return _orig(arr)
                except RuntimeError:
                    a = _np.ascontiguousarray(arr)
                    key = str(a.dtype)
                    if key not in _map:
                        raise TypeError(
                            f"pyi_rth_numpy_compat: unsupported numpy dtype "
                            f"{key!r} in torch.from_numpy fallback; add an "
                            f"explicit mapping rather than silently copying "
                            f"bytes into the wrong dtype."
                        )
                    out = _t.empty(list(a.shape), dtype=_map[key])
                    _c.memmove(out.data_ptr(), a.ctypes.data, a.nbytes)
                    return out

            torch.from_numpy = _safe_from_numpy
            torch._vb_from_numpy_patched = True
        except Exception:
            pass
        return


threading.Thread(target=_patch_torch_from_numpy, daemon=True).start()
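The core trick is easy to verify outside the bundle: byte-copying a contiguous numpy array into a freshly allocated tensor of the matching dtype reproduces `torch.from_numpy`'s values. One behavioural difference worth noting: the real `from_numpy` shares memory with the array, while the fallback returns a copy.

```python
# Standalone sketch of the memmove fallback (no PyInstaller involved).
import ctypes

import numpy as np
import torch

a = np.arange(6, dtype=np.float32).reshape(2, 3)
out = torch.empty(list(a.shape), dtype=torch.float32)
ctypes.memmove(out.data_ptr(), a.ctypes.data, a.nbytes)

assert torch.equal(out, torch.from_numpy(a))  # same values, separate storage
```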
540
backend/pyi_rth_torch_compiler_disable.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
PyInstaller runtime hook: stub torch._dynamo to a no-op module.
|
||||
|
||||
Problem
|
||||
-------
|
||||
transformers triggers torch._dynamo import at module-load time (not just
|
||||
when torch.compile is called) via class-body decorators:
|
||||
|
||||
transformers/modeling_utils.py:1984
|
||||
@torch._dynamo.allow_in_graph
|
||||
class PreTrainedModel(...)
|
||||
|
||||
transformers/integrations/flex_attention.py:61
|
||||
@torch.compiler.disable(recursive=False)
|
||||
class WrappedFlexAttention...
|
||||
|
||||
The attribute access triggers torch.__getattr__ -> importlib.import_module
|
||||
-> torch._dynamo -> torch._dynamo.utils imports torch._numpy ->
|
||||
torch._numpy._ndarray imports torch._numpy._ufuncs, which crashes under
|
||||
PyInstaller with:
|
||||
|
||||
File "torch/_numpy/_ufuncs.py", line 235, in <module>
|
||||
vars()[name] = deco_binary_ufunc(ufunc)
|
||||
NameError: name 'name' is not defined
|
||||
|
||||
(The module-level `for name in _binary: vars()[name] = ...` pattern works
|
||||
in a regular venv but fails in the PyInstaller bundle. Root cause is in
|
||||
PyInstaller's importer / bytecode pipeline and not easily fixed upstream.)
|
||||
|
||||
Surfaces as Kokoro failing to load when `from transformers import AlbertModel`
|
||||
trips the decorator chain.
|
||||
|
||||
Fix
|
||||
---
|
||||
voicebox never uses torch.compile / torch._dynamo for inference, so we
|
||||
replace torch._dynamo with a no-op stub module before transformers is
|
||||
imported. Any attribute access on the stub returns a pass-through callable,
|
||||
so `@torch._dynamo.allow_in_graph`, `torch._dynamo.is_compiling()`,
|
||||
`torch._dynamo.mark_static_address(...)`, etc. all work.
|
||||
|
||||
This hook is pure sys.modules manipulation — we deliberately do NOT import
|
||||
torch here. Runtime hooks run before the app starts and before
|
||||
pyi_rth_numpy_compat has had a chance to patch torch.from_numpy (it runs
|
||||
in a background thread, waiting for torch to appear in sys.modules).
|
||||
Eager-importing torch at hook time would trip the numpy ABI issue and
|
||||
kill the server process at startup.
|
||||
|
||||
torch.compiler.disable does not need a separate stub: its implementation
|
||||
is effectively `import torch._dynamo; return torch._dynamo.disable(...)`,
|
||||
and since our stub is in sys.modules, that call resolves to our no-op
|
||||
_NoopDecorator pass-through.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
|
||||
|
||||
# Diagnostics — log hook activity to a file alongside the bundle so we can
|
||||
# see what's happening when the server is run as a sidecar (no stdout for
|
||||
# runtime hook prints). Safe no-op if the file can't be written.
|
||||
_DIAG_PATH = os.path.join(tempfile.gettempdir(), "voicebox_rt_hook.log")
|
||||
|
||||
|
||||
def _diag(msg: str) -> None:
|
||||
try:
|
||||
with open(_DIAG_PATH, "a", encoding="utf-8") as f:
|
||||
f.write(msg + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
_HOOK_VERSION = "v6-masking-utils-finder"
|
||||
_diag(f"=== runtime hook load @ pid={os.getpid()} version={_HOOK_VERSION} ===")
|
||||
|
||||
|
||||
class _NoopDecorator:
|
||||
"""Multi-role no-op: decorator, falsey predicate, and context manager.
|
||||
|
||||
Returned from calls like `torch._dynamo.disable()` (decorator),
|
||||
`torch._dynamo.is_compiling()` (predicate used in `if not ...`), and
|
||||
`with torch._dynamo._trace_wrapped_higher_order_op.TransformGetItemToIndex():`
|
||||
(context manager used to scope an fx graph transformation).
|
||||
|
||||
By implementing __call__, __bool__, __enter__, __exit__, and __iter__ we
|
||||
cover every use pattern we've seen transformers/torch use on a stubbed
|
||||
object. Anything we haven't covered will raise a clearer error than a
|
||||
silent wrong-result.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __call__(self, fn=None, *args, **kwargs):
|
||||
return fn
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
return False # don't suppress exceptions
|
||||
|
||||
def __iter__(self):
|
||||
return iter(())
|
||||
|
||||
|
||||
_noop_decorator_singleton = _NoopDecorator()
|
||||
|
||||
|
||||
def _noop_callable(*args, **kwargs):
|
||||
# Direct-decorator use: @torch._dynamo.foo (no parens) — fn is positional
|
||||
if len(args) == 1 and callable(args[0]) and not kwargs:
|
||||
return args[0]
|
||||
# Side-effect call with non-callable arg(s), e.g. mark_static_address(tensor)
|
||||
return _noop_decorator_singleton
|
||||
|
||||
|
||||
class _NoopDynamoModule(types.ModuleType):
|
||||
"""Permissive stub: every attribute is a pass-through callable.
|
||||
|
||||
Covers attributes transformers hits at import time (allow_in_graph) and
|
||||
runtime (is_compiling, mark_static_address, reset, disable, ...).
|
||||
|
||||
Dunder attributes (__file__, __spec__, __loader__, ...) raise
|
||||
AttributeError so probes like inspect.getmodule() — which does
|
||||
`hasattr(m, '__file__')` then `os.path.normpath(m.__file__)` — see the
|
||||
module as having no source file and fall through to its normal
|
||||
handling, instead of receiving a function and blowing up.
|
||||
"""
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name.startswith("__") and name.endswith("__"):
|
||||
raise AttributeError(name)
|
||||
return _noop_callable
|
||||
|
||||
|
||||
class _DynamoLoader:
|
||||
"""Loader used by _DynamoMetaPathFinder to materialise stub submodules."""
|
||||
|
||||
def create_module(self, spec):
|
||||
return _NoopDynamoModule(spec.name)
|
||||
|
||||
def exec_module(self, module):
|
||||
# Mark every stub submodule as a package so deeper submodule imports
|
||||
# (`from torch._dynamo.X.Y import Z`) keep working.
|
||||
module.__path__ = []
|
||||
|
||||
|
||||
class _DynamoMetaPathFinder:
|
||||
"""Resolve any `torch._dynamo.X[.Y...]` import to a no-op stub module.
|
||||
|
||||
Without this, `from torch._dynamo._trace_wrapped_higher_order_op import X`
|
||||
fails even with torch._dynamo pre-populated in sys.modules — Python's
|
||||
import machinery checks the parent's __path__ and then looks up the
|
||||
child, and we need to provide both.
|
||||
"""
|
||||
|
||||
def find_spec(self, fullname, path=None, target=None):
|
||||
if fullname == "torch._dynamo":
|
||||
return None # handled by the pre-populated sys.modules entry
|
||||
if not fullname.startswith("torch._dynamo."):
|
||||
return None
|
||||
from importlib.machinery import ModuleSpec
|
||||
|
||||
return ModuleSpec(fullname, _DynamoLoader(), is_package=True)
|
||||
|
||||
|
||||
class _TransformersStubFinder:
|
||||
"""Replace specific transformers submodules with no-op stubs.
|
||||
|
||||
Two modules are targeted:
|
||||
|
||||
1. transformers.utils.auto_docstring
|
||||
The real @auto_docstring decorator loads
|
||||
transformers.models.auto.modeling_auto just to build example docstrings,
|
||||
which drags in GenerationMixin -> candidate_generator -> sklearn.metrics
|
||||
-> scipy.stats._distn_infrastructure and trips (2) below. Docstrings
|
||||
aren't functional for inference, so a pass-through decorator is safe.
|
||||
|
||||
2. transformers.generation.candidate_generator
|
||||
Imported at module scope by transformers.generation.utils. It does
|
||||
`from sklearn.metrics import roc_curve` at module load, which triggers:
|
||||
|
||||
File "scipy/stats/_distn_infrastructure.py", line 369, in <module>
|
||||
NameError: name 'obj' is not defined
|
||||
|
||||
This is a PyInstaller-specific module-load bug (same class as the
|
||||
torch._numpy._ufuncs crash) where a module-level `for obj in [s for s
|
||||
in dir() if ...]` loop evaluates to empty in the bundle, leaving `obj`
|
||||
unbound before `del obj`.
|
||||
|
||||
The exports (AssistedCandidateGenerator, EarlyExitCandidateGenerator,
|
||||
etc.) are speculative-decoding helpers voicebox's TTS engines do not
|
||||
use; a no-op stub module satisfies the imports.
|
||||
"""
|
||||
|
||||
_STUBBED_MODULES = frozenset(
|
||||
{
|
||||
"transformers.utils.auto_docstring",
|
||||
"transformers.generation.candidate_generator",
|
||||
}
|
||||
)
|
||||
|
||||
def find_spec(self, fullname, path=None, target=None):
|
||||
if fullname not in self._STUBBED_MODULES:
|
||||
return None
|
||||
from importlib.machinery import ModuleSpec
|
||||
|
||||
return ModuleSpec(fullname, _NoopStubLoader(), is_package=False)
|
||||
|
||||
|
||||
class _NoopStubLoader:
|
||||
def create_module(self, spec):
|
||||
return _NoopDynamoModule(spec.name)
|
||||
|
||||
def exec_module(self, module):
|
||||
# _NoopDynamoModule.__getattr__ already answers every non-dunder
|
||||
# attribute with a pass-through callable, which satisfies
|
||||
# `from stubbed_module import X` for any X.
|
||||
pass
|
||||
|
||||
|
||||
def _patch_scipy_distn_source(source: str) -> str:
|
||||
"""Replace the unsafe `del obj` with a no-op that survives when obj is unbound.
|
||||
|
||||
Returns the input unchanged if the target line isn't found (e.g. scipy
|
||||
version has changed).
|
||||
"""
|
||||
target = "\ndel obj\n"
|
||||
replacement = "\nglobals().pop('obj', None)\n"
|
||||
if target in source:
|
||||
return source.replace(target, replacement, 1)
|
||||
return source
|
||||
|
||||
|
||||
def _patch_masking_utils_source(source: str) -> str:
|
||||
"""Force torch<2.6 code path in transformers.masking_utils.
|
||||
|
||||
The torch>=2.6 path uses `with TransformGetItemToIndex():` to allow
|
||||
`.item()` calls inside vmap. That context manager is implemented via
|
||||
torch._dynamo graph transforms, which our stub doesn't reproduce — it's
|
||||
a no-op. The inner `_vmap_for_bhqkv` then crashes with:
|
||||
|
||||
RuntimeError: vmap: It looks like you're calling .item() on a Tensor.
|
||||
|
||||
Forcing the torch<2.6 flag off selects sdpa_mask_older_torch which uses
|
||||
a different vmap pattern that does not hit .item() and does not need
|
||||
TransformGetItemToIndex.
|
||||
"""
|
||||
target = 'is_torch_greater_or_equal("2.6", accept_dev=True)'
|
||||
# Find the specific line that assigns _is_torch_greater_or_equal_than_2_6
|
||||
if "_is_torch_greater_or_equal_than_2_6 = " + target in source:
|
||||
return source.replace(
|
||||
"_is_torch_greater_or_equal_than_2_6 = " + target,
|
||||
"_is_torch_greater_or_equal_than_2_6 = False",
|
||||
1,
|
||||
)
|
||||
return source
|
||||
|
||||
|
||||
class _SourcePatchingFinder:
|
||||
"""Generic delegate-and-wrap meta-path finder that patches a module's
|
||||
source before exec'ing.
|
||||
|
||||
Subclasses declare `target` (module fullname) and `patch` (str->str).
|
||||
Requires the target module's .py source to be bundled (use a PyInstaller
|
||||
hook setting module_collection_mode = "pyz+py").
|
||||
"""
|
||||
|
||||
target: str
|
||||
patch_fn: callable = None
|
||||
|
||||
def find_spec(self, fullname, path=None, target=None):
|
||||
if fullname != self.target:
|
||||
return None
|
||||
for finder in sys.meta_path:
|
||||
if finder is self:
|
||||
continue
|
||||
find = getattr(finder, "find_spec", None)
|
||||
if find is None:
|
||||
continue
|
||||
try:
|
||||
real_spec = find(fullname, path, target)
|
||||
except Exception:
|
||||
continue
|
||||
if real_spec is None or real_spec.loader is None:
|
||||
continue
|
||||
real_spec.loader = _SourcePatchLoader(real_spec.loader, self.patch_fn)
|
||||
return real_spec
|
||||
return None
|
||||
|
||||
|
||||
class _SourcePatchLoader:
|
||||
"""Delegate loader that reads source via get_source, applies a patch, and
|
||||
compile/exec's the patched text into module.__dict__.
|
||||
"""
|
||||
|
||||
def __init__(self, inner, patch_fn):
|
||||
self._inner = inner
|
||||
self._patch_fn = patch_fn
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._inner, name)
|
||||
|
||||
def create_module(self, spec):
|
||||
return self._inner.create_module(spec)
|
||||
|
||||
def exec_module(self, module):
|
||||
source = None
|
||||
try:
|
||||
source = self._inner.get_source(module.__name__)
|
||||
except Exception as e:
|
||||
_diag(f"[source-patch] get_source({module.__name__}) failed: {e!r}")
|
||||
|
||||
if not source:
|
||||
_diag(
|
||||
f"[source-patch] no source for {module.__name__}; "
|
||||
"falling back to inner exec_module (patch NOT applied)"
|
||||
)
|
||||
self._inner.exec_module(module)
|
||||
return
|
||||
|
||||
patched = self._patch_fn(source)
|
||||
_diag(
|
||||
f"[source-patch] {module.__name__}: "
|
||||
f"patched={patched is not source}, len={len(patched)}"
|
||||
)
|
||||
spec = module.__spec__
|
||||
if spec is not None and spec.submodule_search_locations is not None:
|
||||
module.__path__ = spec.submodule_search_locations
|
||||
filename = getattr(self._inner, "path", module.__name__)
|
||||
exec(compile(patched, filename, "exec"), module.__dict__)
|
||||
_diag(f"[source-patch] {module.__name__} OK")
|
||||
|
||||
|
||||
class _MaskingUtilsFinder(_SourcePatchingFinder):
|
||||
target = "transformers.masking_utils"
|
||||
patch_fn = staticmethod(_patch_masking_utils_source)
|
||||
|
||||
|
||||
class _ScipyDistnPatchingFinder:
|
||||
"""Delegate-and-wrap finder for scipy.stats._distn_infrastructure.
|
||||
|
||||
That module ends with:
|
||||
|
||||
for obj in [s for s in dir() if s.startswith('_doc_')]:
|
||||
exec('del ' + obj)
|
||||
del obj
|
||||
|
||||
In the PyInstaller bundle the list comprehension evaluates to empty
|
||||
(module-level dir() under the frozen importer returns a different scope
|
||||
than CPython's normal module-exec path — same class of bug as the
|
||||
torch._numpy._ufuncs crash). The for loop body doesn't run, `obj` is
|
||||
never bound, and the trailing `del obj` raises NameError at module load.
|
||||
|
||||
This kills every downstream module: librosa (needed by nearly every TTS
|
||||
engine for mel filters) -> scipy.signal -> scipy.stats -> here.
|
||||
|
||||
Workaround: delegate to the real loader, but pre-bind `obj = None` in the
|
||||
module namespace before its bytecode runs. If the for loop executes, each
|
||||
iteration overwrites the sentinel via STORE_NAME (normal behaviour). If it
|
||||
doesn't, `del obj` removes the sentinel and module load succeeds. The
|
||||
`_doc_*` cleanup this line was meant to do is purely cosmetic — those vars
|
||||
stay in the module namespace but nothing references them after this point.
|
||||
"""
|
||||
|
||||
_TARGET = "scipy.stats._distn_infrastructure"
|
||||
|
||||
def find_spec(self, fullname, path=None, target=None):
|
||||
if fullname != self._TARGET:
|
||||
return None
|
||||
_diag(f"[scipy-finder] match: {fullname}, path={path!r}")
|
||||
# Delegate to the other finders to locate the real spec
|
||||
for finder in sys.meta_path:
|
||||
if finder is self:
|
||||
continue
|
||||
find = getattr(finder, "find_spec", None)
|
||||
if find is None:
|
||||
continue
|
||||
try:
|
||||
real_spec = find(fullname, path, target)
|
||||
except Exception as e:
|
||||
_diag(f"[scipy-finder] inner finder {type(finder).__name__} raised: {e}")
|
||||
continue
|
||||
if real_spec is None:
|
||||
continue
|
||||
if real_spec.loader is None:
|
||||
_diag(f"[scipy-finder] {type(finder).__name__} returned spec with loader=None")
|
||||
continue
|
||||
_diag(
|
||||
f"[scipy-finder] wrapped loader from "
|
||||
f"{type(finder).__name__} -> {type(real_spec.loader).__name__}"
|
||||
)
|
||||
real_spec.loader = _ScipyDistnPrebindLoader(real_spec.loader)
|
||||
return real_spec
|
||||
_diag("[scipy-finder] NO inner finder returned a spec")
|
||||
return None
|
||||
|
||||
|
||||
class _ScipyDistnPrebindLoader:
|
||||
"""Thin wrapper that pre-binds `obj = None` before delegating to the
|
||||
real PyInstaller loader.
|
||||
|
||||
Every other attribute/method delegates to the inner loader — PyiFrozenLoader
|
||||
is a rich FileLoader/ExecutionLoader with get_code/get_source/get_filename/
|
||||
is_package/get_resource_reader/etc., any of which Python's import machinery
|
||||
or 3rd-party code may call on spec.loader. Forwarding via __getattr__
|
||||
avoids breaking any of those paths (and preserves @_check_name contracts
|
||||
because the decorated methods run on the inner instance where self.name
|
||||
matches spec.name).
|
||||
"""
|
||||
|
||||
def __init__(self, inner):
|
||||
self._inner = inner
|
||||
|
||||
def __getattr__(self, name):
|
||||
# __getattr__ fires only for attrs not already on self, so delegate
|
||||
# everything that isn't create_module/exec_module (or __getattr__/init).
|
||||
return getattr(self._inner, name)
|
||||
|
||||
def create_module(self, spec):
|
||||
return self._inner.create_module(spec)
|
||||
|
||||
def exec_module(self, module):
|
||||
# Compile scipy's module source with the problematic line patched.
|
||||
#
|
||||
# The real module ends with:
|
||||
# for obj in [s for s in dir() if s.startswith('_doc_')]:
|
||||
# exec('del ' + obj)
|
||||
# del obj
|
||||
#
|
||||
# Under PyInstaller's frozen importer, `del obj` raises NameError
|
||||
# even when we pre-populate module.__dict__['obj'] — the pre-compiled
|
||||
# .pyc bytecode interacts with the frame setup differently than a
|
||||
# fresh compile() from source. Easiest robust fix: read the source
|
||||
# and replace `del obj` with a safe variant before compiling.
|
||||
#
|
||||
# Requires the .py source to be bundled alongside the .pyc — see
|
||||
# backend/pyi_hooks/hook-scipy.stats._distn_infrastructure.py.
|
||||
source = None
|
||||
try:
|
||||
source = self._inner.get_source(module.__name__)
|
||||
except Exception as e:
|
||||
_diag(f"[scipy-loader] get_source failed: {e!r}")
|
||||
|
||||
if source:
|
||||
patched = _patch_scipy_distn_source(source)
|
||||
_diag(
|
||||
f"[scipy-loader] source-patch path: patched={patched is not source}, "
|
||||
f"len={len(patched)}"
|
||||
)
|
||||
spec = module.__spec__
|
||||
if spec is not None and spec.submodule_search_locations is not None:
|
||||
module.__path__ = spec.submodule_search_locations
|
            filename = getattr(self._inner, "path", module.__name__)
            bytecode = compile(patched, filename, "exec")
            try:
                exec(bytecode, module.__dict__)
            except Exception as e:
                _diag(f"[scipy-loader] patched exec raised {type(e).__name__}: {e!r}")
                raise
            _diag(f"[scipy-loader] exec_module {module.__name__} OK (source-patched)")
            return

        # No source available — fall back to the pre-bind approach. This is
        # best-effort; if the frozen .pyc really does see a different `obj`
        # slot, this will still crash, but we've done all we can without
        # source.
        _diag("[scipy-loader] no source available; falling back to pre-bind")
        module.__dict__["obj"] = None
        self._inner.exec_module(module)


def _install_dynamo_stub() -> None:
    stub = _NoopDynamoModule("torch._dynamo")
    # Mark as a package so `from torch._dynamo.X import Y` imports work
    # (Python's import machinery checks parent.__path__ before looking up
    # the child).
    stub.__path__ = []
    # torch._dynamo.config is accessed as a nested attribute namespace
    # (e.g. `torch._dynamo.config.capture_scalar_outputs = True`), so use
    # a permissive module so any attr read returns a no-op and sets succeed.
    stub.config = _NoopDynamoModule("torch._dynamo.config")
    stub.config.__path__ = []
    sys.modules["torch._dynamo"] = stub
    sys.modules["torch._dynamo.config"] = stub.config

    # Finders:
    # - torch._dynamo.* submodules -> no-op stubs
    # - transformers.utils.auto_docstring and
    #   transformers.generation.candidate_generator -> no-op stubs (both
    #   paths reach sklearn -> scipy.stats which trips a separate crash)
    # - scipy.stats._distn_infrastructure -> real load with `obj` pre-bound,
    #   so librosa -> scipy.signal -> scipy.stats loads cleanly
    for _FinderCls in (
        _DynamoMetaPathFinder,
        _TransformersStubFinder,
        _ScipyDistnPatchingFinder,
        _MaskingUtilsFinder,
    ):
        try:
            sys.meta_path.insert(0, _FinderCls())
            _diag(f"installed finder: {_FinderCls.__name__}")
        except Exception as e:
            _diag(f"FAILED to install {_FinderCls.__name__}: {e!r}")
    _diag(
        "final sys.meta_path head: "
        + ", ".join(type(f).__name__ for f in sys.meta_path[:6])
    )

    # If torch is already imported, also set the attribute on the package so
    # `torch._dynamo` resolves to our stub without triggering torch.__getattr__
    # (which would lazy-import the real module and crash).
    torch_mod = sys.modules.get("torch")
    if torch_mod is not None:
        torch_mod._dynamo = stub


try:
    _install_dynamo_stub()
except Exception as _e:
    # Best effort. If this fails the original NameError will surface when
    # transformers imports — no worse than not patching at all.
    _diag(f"_install_dynamo_stub FAILED: {_e!r}")

# NOTE: we deliberately do NOT import torch or torch.compiler here.
# Runtime hooks run before the app starts and before pyi_rth_numpy_compat
# has had a chance to patch torch.from_numpy (it runs in a background
# thread, waiting for torch to appear in sys.modules). Importing torch
# eagerly at hook time would trip the numpy ABI issue and kill the
# server process at startup.
#
# torch.compiler.disable does not need an explicit stub: its
# implementation is effectively `import torch._dynamo; return
# torch._dynamo.disable(fn, recursive, reason=reason)`, and since our
# stub is installed in sys.modules, that call resolves to our no-op
# _NoopDecorator pass-through.
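The sys.modules pre-seeding is the load-bearing trick here: Python's import machinery returns a cached entry before it consults any loader, so the crashing module body never executes. Below is a minimal, self-contained sketch of that mechanism; the `_NoopDynamoModule` shape and the fake `torch` parent are assumptions made purely so the sketch runs standalone (the real class, defined earlier in this hook, is richer).

```python
import sys
import types


class _NoopDynamoModule(types.ModuleType):
    """Illustrative stand-in: every missing attribute is a pass-through."""

    def __getattr__(self, name):
        def _noop(*args, **kwargs):
            # decorator-style calls (disable(fn)) hand the function back
            return args[0] if args and callable(args[0]) else None

        return _noop


stub = _NoopDynamoModule("torch._dynamo")
fake_torch = types.ModuleType("torch")  # stand-in so the sketch runs without torch
fake_torch._dynamo = stub
sys.modules["torch"] = fake_torch
sys.modules["torch._dynamo"] = stub

# This import is now a pure sys.modules cache hit: the real (crashing)
# torch._dynamo module body is never executed.
import torch._dynamo

assert torch._dynamo.disable(print) is print
```

In the hook's case the only surface that matters is the `disable` pass-through described in the comment above; anything heavier would need the full stub class.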
83
backend/pyproject.toml
Normal file
@@ -0,0 +1,83 @@
[project]
name = "voicebox-backend"
version = "0.2.3"
requires-python = ">=3.12"

# ---------------------------------------------------------------------------
# Ruff – linter + formatter
# ---------------------------------------------------------------------------

[tool.ruff]
target-version = "py312"
line-length = 120
src = ["."]

# Files/dirs to skip entirely.
extend-exclude = [
    "voicebox-server.spec",
    "build_binary.py",
]

[tool.ruff.lint]
select = [
    "F",   # pyflakes
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "I",   # isort
    "N",   # pep8-naming
    "UP",  # pyupgrade (modernize syntax for 3.12)
    "B",   # flake8-bugbear
    "A",   # flake8-builtins (shadowing built-in names)
    "SIM", # flake8-simplify
    "T20", # flake8-print (flag print() calls)
    "RET", # flake8-return
    "PIE", # misc lints
    "PT",  # flake8-pytest-style
    "RUF", # ruff-specific rules
    "ERA", # commented-out code detection
    "FIX", # flag TODO/FIXME/HACK/XXX for review
]

ignore = [
    # Allow print() in existing code -- remove items from this list as files
    # are migrated to logging during the refactor.
    "T201",   # print() found

    # These conflict with the formatter or are too noisy during migration:
    "E501",   # line too long (formatter handles this)
    "RET504", # unnecessary assignment before return
    "SIM108", # use ternary operator (sometimes less readable)
    "B008",   # function call in default argument (FastAPI Depends() pattern)
    "UP007",  # use X | Y for union (auto-fixed by UP, but noisy on big diffs)
]

# Per-file rule overrides.
[tool.ruff.lint.per-file-ignores]
# Tests can use assert, print, and magic values freely.
"tests/**" = ["S101", "T201", "PLR2004", "ERA001"]
# __init__.py re-exports are expected to have unused imports.
"**/__init__.py" = ["F401"]
# Entry points and scripts legitimately use print.
"server.py" = ["T201"]
"main.py" = ["T201"]
# AMD GPU env vars must be set before torch import.
"app.py" = ["E402"]

[tool.ruff.lint.isort]
known-first-party = ["backend"]
# Group "from backend.*" imports into the first-party section.
force-single-line = false
combine-as-imports = true

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
docstring-code-format = true

# ---------------------------------------------------------------------------
# pytest
# ---------------------------------------------------------------------------

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
22
backend/requirements-mlx.txt
Normal file
@@ -0,0 +1,22 @@
# MLX-specific dependencies (Apple Silicon only)
# These should only be installed on aarch64-apple-darwin platforms

mlx>=0.30.0

# miniaudio is a runtime dep of mlx-audio's STT path (mlx_audio.stt).
# mlx-audio itself is installed --no-deps (see comment below), so we
# must list miniaudio explicitly here or transcription fails on fresh
# M1 installs with `ModuleNotFoundError: miniaudio` (issue #505).
miniaudio>=1.59

# NOTE: mlx-audio is intentionally not listed here. From 0.3.1 onward it
# declares `transformers==5.0.0rc3` / `>=5.0.0`, which conflicts with the
# `transformers<=4.57.6` cap in requirements.txt and breaks CI's clean
# resolver. The mlx-audio API surface we use (mlx_audio.tts.load,
# mlx_audio.stt.load) works fine on transformers 4.57.x in practice.
#
# Install it via `pip install --no-deps mlx-audio==0.4.1` after this file
# (see .github/workflows/release.yml). Most other mlx-audio runtime deps
# (huggingface_hub, librosa, mlx-lm, numba, numpy, protobuf, pyloudnorm,
# sounddevice, tqdm) are already in requirements.txt or pulled in by
# other engines.
67
backend/requirements.txt
Normal file
@@ -0,0 +1,67 @@
# FastAPI and server
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
pydantic>=2.5.0

# Database
sqlalchemy>=2.0.0
alembic>=1.13.0

# ML models
torch>=2.2.0
transformers>=4.36.0,<=4.57.6
accelerate>=0.26.0
huggingface_hub>=0.20.0
qwen-tts>=0.0.5

# LuxTTS (voice cloning engine)
# piper-phonemize needs custom index (no PyPI wheels)
--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
# linacodec is a git-only dep of Zipvoice (uv-only source, pip can't resolve it)
linacodec @ git+https://github.com/ysharma3501/LinaCodec.git
Zipvoice @ git+https://github.com/ysharma3501/LuxTTS.git

# Chatterbox TTS sub-dependencies (chatterbox-tts itself is installed
# --no-deps in the setup script because it pins numpy<1.26 / torch==2.6
# which are incompatible with Python 3.12+)
conformer>=0.3.2
diffusers>=0.29.0
omegaconf
pykakasi
resemble-perth>=1.0.1
s3tokenizer
spacy-pkuseg
pyloudnorm

# HumeAI TADA sub-dependencies (hume-tada itself is installed
# --no-deps in the setup script because it pins torch>=2.7,<2.8.
# descript-audio-codec is NOT installed — it pulls onnx/tensorboard
# via descript-audiotools. A lightweight shim in utils/dac_shim.py
# provides the only class TADA uses: Snake1d.)
torchaudio

# Kokoro TTS (lightweight 82M-param engine)
kokoro>=0.9.4
misaki[en,ja,zh]>=0.9.4
# spacy model for misaki English G2P — must be pre-installed or misaki
# tries spacy.cli.download() at runtime which crashes frozen builds
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
# fugashi (pulled in by misaki[ja]) needs a MeCab dictionary on disk.
# unidic-lite ships one inside the wheel (~50MB); the full `unidic` package
# requires `python -m unidic download` (~526MB) which breaks frozen builds
# for the same reason en_core_web_sm does.
unidic-lite>=1.0.8

# Audio processing
librosa>=0.10.0
soundfile>=0.12.0
numpy>=1.24.0,<2.0
numba>=0.60.0,<0.61.0
pedalboard>=0.9.0

# HTTP client (for CUDA backend download)
httpx>=0.27.0

# Utilities
python-multipart>=0.0.6
Pillow>=10.0.0
32
backend/routes/__init__.py
Normal file
@@ -0,0 +1,32 @@
"""Route registration for the voicebox API."""

from fastapi import FastAPI


def register_routers(app: FastAPI) -> None:
    """Include all domain routers on the application."""
    from .health import router as health_router
    from .profiles import router as profiles_router
    from .channels import router as channels_router
    from .generations import router as generations_router
    from .history import router as history_router
    from .transcription import router as transcription_router
    from .stories import router as stories_router
    from .effects import router as effects_router
    from .audio import router as audio_router
    from .models import router as models_router
    from .tasks import router as tasks_router
    from .cuda import router as cuda_router

    app.include_router(health_router)
    app.include_router(profiles_router)
    app.include_router(channels_router)
    app.include_router(generations_router)
    app.include_router(history_router)
    app.include_router(transcription_router)
    app.include_router(stories_router)
    app.include_router(effects_router)
    app.include_router(audio_router)
    app.include_router(models_router)
    app.include_router(tasks_router)
    app.include_router(cuda_router)
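For orientation, a sketch of how this registration hook is consumed; the `create_app` shape below is an assumption for illustration (the real factory in `app.py` also wires CORS and lifecycle events).

```python
from fastapi import FastAPI

from backend.routes import register_routers


def create_app() -> FastAPI:
    # Hypothetical factory: a single call attaches every domain router, so
    # importing backend.routes itself stays cheap and side-effect free.
    app = FastAPI(title="voicebox")
    register_routers(app)
    return app
```

Deferring the per-router imports into the function body presumably keeps heavyweight import side effects (torch, model registries) out of a bare `import backend.routes`.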
69
backend/routes/audio.py
Normal file
@@ -0,0 +1,69 @@
"""Audio file serving endpoints."""

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session

from .. import config
from ..services import history
from ..database import get_db

router = APIRouter()


@router.get("/audio/version/{version_id}")
async def get_version_audio(version_id: str, db: Session = Depends(get_db)):
    """Serve audio for a specific version."""
    from ..services import versions as versions_mod

    version = versions_mod.get_version(version_id, db)
    if not version:
        raise HTTPException(status_code=404, detail="Version not found")

    audio_path = config.resolve_storage_path(version.audio_path)
    if audio_path is None or not audio_path.exists():
        raise HTTPException(status_code=404, detail="Audio file not found")

    return FileResponse(
        audio_path,
        media_type="audio/wav",
        filename=f"generation_{version.generation_id}_{version.label}.wav",
    )


@router.get("/audio/{generation_id}")
async def get_audio(generation_id: str, db: Session = Depends(get_db)):
    """Serve generated audio file (serves the default version)."""
    generation = await history.get_generation(generation_id, db)
    if not generation:
        raise HTTPException(status_code=404, detail="Generation not found")

    audio_path = config.resolve_storage_path(generation.audio_path)
    if audio_path is None or not audio_path.exists():
        raise HTTPException(status_code=404, detail="Audio file not found")

    return FileResponse(
        audio_path,
        media_type="audio/wav",
        filename=f"generation_{generation_id}.wav",
    )


@router.get("/samples/{sample_id}")
async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
    """Serve profile sample audio file."""
    from ..database import ProfileSample as DBProfileSample

    sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
    if not sample:
        raise HTTPException(status_code=404, detail="Sample not found")

    audio_path = config.resolve_storage_path(sample.audio_path)
    if audio_path is None or not audio_path.exists():
        raise HTTPException(status_code=404, detail="Audio file not found")

    return FileResponse(
        audio_path,
        media_type="audio/wav",
        filename=f"sample_{sample_id}.wav",
    )
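A quick client-side check of these endpoints (illustrative only; assumes the standalone server from the README is listening on 17493 and that the generation ID exists):

```python
import httpx

BASE = "http://127.0.0.1:17493"

# Fetch the default version of a generation and save it locally.
resp = httpx.get(f"{BASE}/audio/some-generation-id")  # placeholder ID
resp.raise_for_status()
with open("generation.wav", "wb") as f:
    f.write(resp.content)
```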
98
backend/routes/channels.py
Normal file
@@ -0,0 +1,98 @@
"""Audio channel endpoints."""

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from .. import models
from ..services import channels
from ..database import get_db

router = APIRouter()


@router.get("/channels", response_model=list[models.AudioChannelResponse])
async def list_channels(db: Session = Depends(get_db)):
    """List all audio channels."""
    return await channels.list_channels(db)


@router.post("/channels", response_model=models.AudioChannelResponse)
async def create_channel(
    data: models.AudioChannelCreate,
    db: Session = Depends(get_db),
):
    """Create a new audio channel."""
    try:
        return await channels.create_channel(data, db)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.get("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def get_channel(
    channel_id: str,
    db: Session = Depends(get_db),
):
    """Get an audio channel by ID."""
    channel = await channels.get_channel(channel_id, db)
    if not channel:
        raise HTTPException(status_code=404, detail="Channel not found")
    return channel


@router.put("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def update_channel(
    channel_id: str,
    data: models.AudioChannelUpdate,
    db: Session = Depends(get_db),
):
    """Update an audio channel."""
    try:
        channel = await channels.update_channel(channel_id, data, db)
        if not channel:
            raise HTTPException(status_code=404, detail="Channel not found")
        return channel
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.delete("/channels/{channel_id}")
async def delete_channel(
    channel_id: str,
    db: Session = Depends(get_db),
):
    """Delete an audio channel."""
    try:
        success = await channels.delete_channel(channel_id, db)
        if not success:
            raise HTTPException(status_code=404, detail="Channel not found")
        return {"message": "Channel deleted successfully"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.get("/channels/{channel_id}/voices")
async def get_channel_voices(
    channel_id: str,
    db: Session = Depends(get_db),
):
    """Get list of profile IDs assigned to a channel."""
    try:
        profile_ids = await channels.get_channel_voices(channel_id, db)
        return {"profile_ids": profile_ids}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.put("/channels/{channel_id}/voices")
async def set_channel_voices(
    channel_id: str,
    data: models.ChannelVoiceAssignment,
    db: Session = Depends(get_db),
):
    """Set which voices are assigned to a channel."""
    try:
        await channels.set_channel_voices(channel_id, data, db)
        return {"message": "Channel voices updated successfully"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
82
backend/routes/cuda.py
Normal file
@@ -0,0 +1,82 @@
"""CUDA backend management endpoints."""

import logging

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from ..services.task_queue import create_background_task
from ..utils.progress import get_progress_manager

router = APIRouter()

logger = logging.getLogger(__name__)


@router.get("/backend/cuda-status")
async def get_cuda_status():
    """Get CUDA backend download/availability status."""
    from ..services import cuda

    return cuda.get_cuda_status()


@router.post("/backend/download-cuda")
async def download_cuda_backend():
    """Download the CUDA backend binary."""
    from ..services import cuda

    if cuda.get_cuda_binary_path() is not None:
        raise HTTPException(status_code=409, detail="CUDA backend already downloaded")

    progress_manager = get_progress_manager()
    existing = progress_manager.get_progress(cuda.PROGRESS_KEY)
    if existing and existing.get("status") == "downloading":
        raise HTTPException(status_code=409, detail="CUDA backend download already in progress")

    async def _download():
        try:
            await cuda.download_cuda_binary()
        except Exception as e:
            logger.error("CUDA download failed: %s", e)

    create_background_task(_download())
    return {"message": "CUDA backend download started", "progress_key": "cuda-backend"}


@router.delete("/backend/cuda")
async def delete_cuda_backend():
    """Delete the downloaded CUDA backend binary."""
    from ..services import cuda

    if cuda.is_cuda_active():
        raise HTTPException(
            status_code=409,
            detail="Cannot delete CUDA backend while it is active. Switch to CPU first.",
        )

    deleted = await cuda.delete_cuda_binary()
    if not deleted:
        raise HTTPException(status_code=404, detail="No CUDA backend found to delete")

    return {"message": "CUDA backend deleted"}


@router.get("/backend/cuda-progress")
async def get_cuda_download_progress():
    """Get CUDA backend download progress via Server-Sent Events."""
    progress_manager = get_progress_manager()

    async def event_generator():
        async for event in progress_manager.subscribe("cuda-backend"):
            yield event

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
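The progress endpoint speaks plain Server-Sent Events, so any line-oriented client works. A hedged sketch of consuming it with httpx follows; the fields inside each JSON payload are whatever the progress manager emits, so treat them as assumptions.

```python
import httpx

# Stream CUDA download progress until the server closes the feed.
with httpx.stream(
    "GET", "http://127.0.0.1:17493/backend/cuda-progress", timeout=None
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])  # one JSON progress snapshot per event
```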
262
backend/routes/effects.py
Normal file
@@ -0,0 +1,262 @@
"""Effects presets and generation version endpoints."""

import asyncio
import io
import uuid

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from .. import config, models
from ..services import history
from ..database import Generation as DBGeneration, get_db

router = APIRouter()


@router.post("/effects/preview/{generation_id}")
async def preview_effects(
    generation_id: str,
    data: models.ApplyEffectsRequest,
    db: Session = Depends(get_db),
):
    """Apply effects to a generation's clean audio and stream back without saving."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if (gen.status or "completed") != "completed":
        raise HTTPException(status_code=400, detail="Generation is not completed")

    from ..services import versions as versions_mod
    from ..utils.effects import apply_effects, validate_effects_chain
    from ..utils.audio import load_audio

    chain_dicts = [e.model_dump() for e in data.effects_chain]
    error = validate_effects_chain(chain_dicts)
    if error:
        raise HTTPException(status_code=400, detail=error)

    all_versions = versions_mod.list_versions(generation_id, db)
    clean_version = next((v for v in all_versions if v.effects_chain is None), None)
    source_path = clean_version.audio_path if clean_version else gen.audio_path
    resolved_source_path = config.resolve_storage_path(source_path)
    if resolved_source_path is None or not resolved_source_path.exists():
        raise HTTPException(status_code=404, detail="Source audio file not found")

    audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
    processed = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)

    import soundfile as sf

    buf = io.BytesIO()
    await asyncio.to_thread(lambda: sf.write(buf, processed, sample_rate, format="WAV"))
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={
            "Content-Disposition": f'inline; filename="preview_{generation_id}.wav"',
            "Cache-Control": "no-cache, no-store",
        },
    )


@router.get("/effects/available", response_model=models.AvailableEffectsResponse)
async def get_available_effects():
    """List all available effect types with parameter definitions."""
    from ..utils.effects import get_available_effects as _get_effects

    return models.AvailableEffectsResponse(effects=[models.AvailableEffect(**e) for e in _get_effects()])


@router.get("/effects/presets", response_model=list[models.EffectPresetResponse])
async def list_effect_presets(db: Session = Depends(get_db)):
    """List all effect presets (built-in + user-created)."""
    from ..services import effects as effects_mod

    return effects_mod.list_presets(db)


@router.get("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
async def get_effect_preset(preset_id: str, db: Session = Depends(get_db)):
    """Get a specific effect preset."""
    from ..services import effects as effects_mod

    preset = effects_mod.get_preset(preset_id, db)
    if not preset:
        raise HTTPException(status_code=404, detail="Preset not found")
    return preset


@router.post("/effects/presets", response_model=models.EffectPresetResponse)
async def create_effect_preset(
    data: models.EffectPresetCreate,
    db: Session = Depends(get_db),
):
    """Create a new effect preset."""
    from ..services import effects as effects_mod

    try:
        return effects_mod.create_preset(data, db)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.put("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
async def update_effect_preset(
    preset_id: str,
    data: models.EffectPresetUpdate,
    db: Session = Depends(get_db),
):
    """Update an effect preset."""
    from ..services import effects as effects_mod

    try:
        result = effects_mod.update_preset(preset_id, data, db)
        if not result:
            raise HTTPException(status_code=404, detail="Preset not found")
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.delete("/effects/presets/{preset_id}")
async def delete_effect_preset(preset_id: str, db: Session = Depends(get_db)):
    """Delete a user effect preset."""
    from ..services import effects as effects_mod

    try:
        if not effects_mod.delete_preset(preset_id, db):
            raise HTTPException(status_code=404, detail="Preset not found")
        return {"status": "deleted"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.get(
    "/generations/{generation_id}/versions",
    response_model=list[models.GenerationVersionResponse],
)
async def list_generation_versions(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """List all versions for a generation."""
    gen = await history.get_generation(generation_id, db)
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")

    from ..services import versions as versions_mod

    return versions_mod.list_versions(generation_id, db)


@router.post(
    "/generations/{generation_id}/versions/apply-effects",
    response_model=models.GenerationVersionResponse,
)
async def apply_effects_to_generation(
    generation_id: str,
    data: models.ApplyEffectsRequest,
    db: Session = Depends(get_db),
):
    """Apply an effects chain to an existing generation, creating a new version."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if (gen.status or "completed") != "completed":
        raise HTTPException(status_code=400, detail="Generation is not completed")

    from ..services import versions as versions_mod
    from ..utils.effects import apply_effects, validate_effects_chain
    from ..utils.audio import load_audio, save_audio

    chain_dicts = [e.model_dump() for e in data.effects_chain]
    error = validate_effects_chain(chain_dicts)
    if error:
        raise HTTPException(status_code=400, detail=error)

    all_versions = versions_mod.list_versions(generation_id, db)
    source_version_id = data.source_version_id
    if source_version_id:
        source_version = next((v for v in all_versions if v.id == source_version_id), None)
        if not source_version:
            raise HTTPException(status_code=404, detail="Source version not found")
        source_path = source_version.audio_path
    else:
        clean_version = next((v for v in all_versions if v.effects_chain is None), None)
        if not clean_version:
            source_path = gen.audio_path
        else:
            source_path = clean_version.audio_path
            source_version_id = clean_version.id

    resolved_source_path = config.resolve_storage_path(source_path)
    if resolved_source_path is None or not resolved_source_path.exists():
        raise HTTPException(status_code=404, detail="Source audio file not found")

    audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
    processed_audio = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)

    version_id = str(uuid.uuid4())
    processed_path = config.get_generations_dir() / f"{generation_id}_{version_id[:8]}.wav"
    await asyncio.to_thread(save_audio, processed_audio, str(processed_path), sample_rate)

    label = data.label or f"version-{len(all_versions) + 1}"

    version = versions_mod.create_version(
        generation_id=generation_id,
        label=label,
        audio_path=config.to_storage_path(processed_path),
        db=db,
        effects_chain=chain_dicts,
        is_default=data.set_as_default,
        source_version_id=source_version_id,
    )

    return version


@router.put(
    "/generations/{generation_id}/versions/{version_id}/set-default",
    response_model=models.GenerationVersionResponse,
)
async def set_default_version(
    generation_id: str,
    version_id: str,
    db: Session = Depends(get_db),
):
    """Set a specific version as the default for a generation."""
    from ..services import versions as versions_mod

    version = versions_mod.get_version(version_id, db)
    if not version or version.generation_id != generation_id:
        raise HTTPException(status_code=404, detail="Version not found")

    result = versions_mod.set_default_version(version_id, db)
    if not result:
        raise HTTPException(status_code=404, detail="Version not found")
    return result


@router.delete("/generations/{generation_id}/versions/{version_id}")
async def delete_generation_version(
    generation_id: str,
    version_id: str,
    db: Session = Depends(get_db),
):
    """Delete a version. Cannot delete the last remaining version."""
    from ..services import versions as versions_mod

    version = versions_mod.get_version(version_id, db)
    if not version or version.generation_id != generation_id:
        raise HTTPException(status_code=404, detail="Version not found")

    if not versions_mod.delete_version(version_id, db):
        raise HTTPException(
            status_code=400,
            detail="Cannot delete the last remaining version",
        )
    return {"status": "deleted"}
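For reference, an illustrative call to the preview endpoint. The effect type and parameter names used here (`reverb`, `wet`) are placeholders, not the real schema; the authoritative shapes live in `models.ApplyEffectsRequest` and `utils/effects.py`.

```python
import httpx

# Placeholder payload shape -- consult /effects/available for real definitions.
payload = {"effects_chain": [{"type": "reverb", "params": {"wet": 0.3}}]}

resp = httpx.post(
    "http://127.0.0.1:17493/effects/preview/some-generation-id",  # placeholder ID
    json=payload,
)
resp.raise_for_status()
with open("preview.wav", "wb") as f:
    f.write(resp.content)  # processed WAV; nothing is persisted server-side
```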
345
backend/routes/generations.py
Normal file
@@ -0,0 +1,345 @@
"""TTS generation endpoints."""

import asyncio
import logging
import uuid

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from .. import models
from ..services import history, profiles, tts
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db
from ..services.generation import run_generation
from ..services.task_queue import cancel_generation as cancel_generation_job, enqueue_generation
from ..utils.tasks import get_task_manager

logger = logging.getLogger(__name__)

router = APIRouter()


def _resolve_generation_engine(data: models.GenerationRequest, profile) -> str:
    return data.engine or getattr(profile, "default_engine", None) or getattr(profile, "preset_engine", None) or "qwen"


@router.post("/generate", response_model=models.GenerationResponse)
async def generate_speech(
    data: models.GenerationRequest,
    db: Session = Depends(get_db),
):
    """Generate speech from text using a voice profile."""
    task_manager = get_task_manager()
    generation_id = str(uuid.uuid4())

    profile = await profiles.get_profile(data.profile_id, db)
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")

    from ..backends import engine_has_model_sizes

    engine = _resolve_generation_engine(data, profile)
    try:
        profiles.validate_profile_engine(profile, engine)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    model_size = (data.model_size or "1.7B") if engine_has_model_sizes(engine) else None

    generation = await history.create_generation(
        profile_id=data.profile_id,
        text=data.text,
        language=data.language,
        audio_path="",
        duration=0,
        seed=data.seed,
        db=db,
        instruct=data.instruct,
        generation_id=generation_id,
        status="generating",
        engine=engine,
        model_size=model_size if engine_has_model_sizes(engine) else None,
    )

    task_manager.start_generation(
        task_id=generation_id,
        profile_id=data.profile_id,
        text=data.text,
    )

    effects_chain_config = None
    if data.effects_chain is not None:
        effects_chain_config = [e.model_dump() for e in data.effects_chain]
    else:
        import json as _json

        profile_obj = db.query(DBVoiceProfile).filter_by(id=data.profile_id).first()
        if profile_obj and profile_obj.effects_chain:
            try:
                effects_chain_config = _json.loads(profile_obj.effects_chain)
            except Exception:
                pass

    enqueue_generation(
        generation_id,
        run_generation(
            generation_id=generation_id,
            profile_id=data.profile_id,
            text=data.text,
            language=data.language,
            engine=engine,
            model_size=model_size,
            seed=data.seed,
            normalize=data.normalize,
            effects_chain=effects_chain_config,
            instruct=data.instruct,
            mode="generate",
            max_chunk_chars=data.max_chunk_chars,
            crossfade_ms=data.crossfade_ms,
        )
    )

    return generation


@router.post("/generate/{generation_id}/retry", response_model=models.GenerationResponse)
async def retry_generation(generation_id: str, db: Session = Depends(get_db)):
    """Retry a failed generation using the same parameters."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")

    if (gen.status or "completed") != "failed":
        raise HTTPException(status_code=400, detail="Only failed generations can be retried")

    gen.status = "generating"
    gen.error = None
    gen.audio_path = ""
    gen.duration = 0
    db.commit()
    db.refresh(gen)

    task_manager = get_task_manager()
    task_manager.start_generation(
        task_id=generation_id,
        profile_id=gen.profile_id,
        text=gen.text,
    )

    enqueue_generation(
        generation_id,
        run_generation(
            generation_id=generation_id,
            profile_id=gen.profile_id,
            text=gen.text,
            language=gen.language,
            engine=gen.engine or "qwen",
            model_size=gen.model_size or "1.7B",
            seed=gen.seed,
            instruct=gen.instruct,
            mode="retry",
        )
    )

    return models.GenerationResponse.model_validate(gen)


@router.post(
    "/generate/{generation_id}/regenerate",
    response_model=models.GenerationResponse,
)
async def regenerate_generation(generation_id: str, db: Session = Depends(get_db)):
    """Re-run TTS with the same parameters and save the result as a new version."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if (gen.status or "completed") != "completed":
        raise HTTPException(status_code=400, detail="Generation must be completed to regenerate")

    gen.status = "generating"
    gen.error = None
    db.commit()
    db.refresh(gen)

    task_manager = get_task_manager()
    task_manager.start_generation(
        task_id=generation_id,
        profile_id=gen.profile_id,
        text=gen.text,
    )

    version_id = str(uuid.uuid4())

    enqueue_generation(
        generation_id,
        run_generation(
            generation_id=generation_id,
            profile_id=gen.profile_id,
            text=gen.text,
            language=gen.language,
            engine=gen.engine or "qwen",
            model_size=gen.model_size or "1.7B",
            seed=gen.seed,
            instruct=gen.instruct,
            mode="regenerate",
            version_id=version_id,
        )
    )

    return models.GenerationResponse.model_validate(gen)


@router.post("/generate/{generation_id}/cancel")
async def cancel_generation(generation_id: str, db: Session = Depends(get_db)):
    """Cancel a queued or running generation."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")

    if (gen.status or "completed") not in ("loading_model", "generating"):
        raise HTTPException(status_code=400, detail="Only active generations can be cancelled")

    cancellation_state = cancel_generation_job(generation_id)
    if cancellation_state is None:
        raise HTTPException(status_code=409, detail="Generation is no longer cancellable")

    if cancellation_state == "queued":
        task_manager = get_task_manager()
        task_manager.complete_generation(generation_id)
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=db,
            error="Generation cancelled",
        )
        return {"message": "Queued generation cancelled"}

    return {"message": "Generation cancellation requested"}


@router.get("/generate/{generation_id}/status")
async def get_generation_status(generation_id: str, db: Session = Depends(get_db)):
    """SSE endpoint that streams generation status updates."""
    import json

    async def event_stream():
        try:
            while True:
                db.expire_all()
                gen = db.query(DBGeneration).filter_by(id=generation_id).first()
                if not gen:
                    yield f"data: {json.dumps({'status': 'not_found', 'id': generation_id})}\n\n"
                    return

                payload = {
                    "id": gen.id,
                    "status": gen.status or "completed",
                    "duration": gen.duration,
                    "error": gen.error,
                }
                yield f"data: {json.dumps(payload)}\n\n"

                if (gen.status or "completed") in ("completed", "failed"):
                    return

                await asyncio.sleep(1)
        except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
            logger.debug("SSE client disconnected for generation %s", generation_id)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.post("/generate/stream")
async def stream_speech(
    data: models.GenerationRequest,
    db: Session = Depends(get_db),
):
    """Generate speech and stream the WAV audio directly without saving to disk."""
    from ..backends import get_tts_backend_for_engine, ensure_model_cached_or_raise, load_engine_model, engine_needs_trim

    profile = await profiles.get_profile(data.profile_id, db)
    if not profile:
        raise HTTPException(status_code=404, detail="Profile not found")

    engine = _resolve_generation_engine(data, profile)
    try:
        profiles.validate_profile_engine(profile, engine)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    tts_model = get_tts_backend_for_engine(engine)
    model_size = data.model_size or "1.7B"

    await ensure_model_cached_or_raise(engine, model_size)
    await load_engine_model(engine, model_size)

    voice_prompt = await profiles.create_voice_prompt_for_profile(
        data.profile_id,
        db,
        engine=engine,
    )

    from ..utils.chunked_tts import generate_chunked

    trim_fn = None
    if engine_needs_trim(engine):
        from ..utils.audio import trim_tts_output

        trim_fn = trim_tts_output

    audio, sample_rate = await generate_chunked(
        tts_model,
        data.text,
        voice_prompt,
        language=data.language,
        seed=data.seed,
        instruct=data.instruct,
        max_chunk_chars=data.max_chunk_chars,
        crossfade_ms=data.crossfade_ms,
        trim_fn=trim_fn,
    )

    effects_chain_config = None
    if data.effects_chain is not None:
        effects_chain_config = [e.model_dump() for e in data.effects_chain]
    elif profile.effects_chain:
        import json as _json

        try:
            effects_chain_config = _json.loads(profile.effects_chain)
        except Exception:
            effects_chain_config = None

    if effects_chain_config:
        from ..utils.effects import apply_effects

        audio = apply_effects(audio, sample_rate, effects_chain_config)

    if data.normalize:
        from ..utils.audio import normalize_audio

        audio = normalize_audio(audio)

    wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate)

    async def _wav_stream():
        try:
            chunk_size = 64 * 1024
            for i in range(0, len(wav_bytes), chunk_size):
                yield wav_bytes[i : i + chunk_size]
        except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
            logger.debug("Client disconnected during audio stream")

    return StreamingResponse(
        _wav_stream(),
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
    )
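Putting the queue and the SSE status feed together, a hedged end-to-end sketch: the profile ID is a placeholder, and the payload fields mirror those built in `event_stream` above.

```python
import json

import httpx

BASE = "http://127.0.0.1:17493"

# Enqueue a generation; the handler returns immediately with its record.
gen = httpx.post(f"{BASE}/generate", json={"profile_id": "PROFILE_ID", "text": "Hello"}).json()

# Follow the status stream until the generation settles.
with httpx.stream("GET", f"{BASE}/generate/{gen['id']}/status", timeout=None) as r:
    for line in r.iter_lines():
        if not line.startswith("data: "):
            continue
        status = json.loads(line[len("data: "):])
        print(status["status"])
        if status["status"] in ("completed", "failed"):
            break
```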
248
backend/routes/health.py
Normal file
@@ -0,0 +1,248 @@
"""Health and infrastructure endpoints."""

import asyncio
import os
import signal
from pathlib import Path

import torch
from fastapi import APIRouter
from fastapi.responses import FileResponse

from .. import config, models
from ..services import tts
from ..utils.platform_detect import get_backend_type

router = APIRouter()

# Frontend build directory — present in Docker, absent in dev/API-only mode
_frontend_dir = Path(__file__).resolve().parent.parent.parent / "frontend"


@router.get("/")
async def root():
    """Root endpoint — serves SPA index.html in Docker, JSON otherwise."""
    from .. import __version__

    index = _frontend_dir / "index.html"
    if index.is_file():
        return FileResponse(index, media_type="text/html")
    return {"message": "voicebox API", "version": __version__}


@router.post("/shutdown")
async def shutdown():
    """Gracefully shutdown the server."""

    async def shutdown_async():
        await asyncio.sleep(0.1)
        os.kill(os.getpid(), signal.SIGTERM)

    asyncio.create_task(shutdown_async())
    return {"message": "Shutting down..."}


@router.post("/watchdog/disable")
async def watchdog_disable():
    """Disable the parent process watchdog so the server keeps running."""
    from backend.server import disable_watchdog

    disable_watchdog()
    return {"message": "Watchdog disabled"}


@router.get("/health", response_model=models.HealthResponse)
async def health():
    """Health check endpoint."""
    from huggingface_hub import constants as hf_constants

    tts_model = tts.get_tts_model()
    backend_type = get_backend_type()

    has_cuda = torch.cuda.is_available()
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

    has_xpu = False
    xpu_name = None
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401 -- side-effect import enables XPU

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            has_xpu = True
            try:
                xpu_name = torch.xpu.get_device_name(0)
            except Exception:
                xpu_name = "Intel GPU"
    except ImportError:
        pass

    has_directml = False
    directml_name = None
    try:
        import torch_directml

        if torch_directml.device_count() > 0:
            has_directml = True
            try:
                directml_name = torch_directml.device_name(0)
            except Exception:
                directml_name = "DirectML GPU"
    except ImportError:
        pass

    gpu_compat_warning = None
    if has_cuda:
        from ..backends.base import check_cuda_compatibility

        _compatible, gpu_compat_warning = check_cuda_compatibility()

    gpu_available = has_cuda or has_mps or has_xpu or has_directml or backend_type == "mlx"

    gpu_type = None
    if has_cuda:
        gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})"
    elif has_mps:
        gpu_type = "MPS (Apple Silicon)"
    elif backend_type == "mlx":
        gpu_type = "Metal (Apple Silicon via MLX)"
    elif has_xpu:
        gpu_type = f"XPU ({xpu_name})"
    elif has_directml:
        gpu_type = f"DirectML ({directml_name})"

    vram_used = None
    if has_cuda:
        vram_used = torch.cuda.memory_allocated() / 1024 / 1024
    elif has_xpu:
        try:
            vram_used = torch.xpu.memory_allocated() / 1024 / 1024
        except Exception:
            pass  # memory_allocated() may not be available on all IPEX versions

    model_loaded = False
    model_size = None
    try:
        if tts_model.is_loaded():
            model_loaded = True
            model_size = getattr(tts_model, "_current_model_size", None)
            if not model_size:
                model_size = getattr(tts_model, "model_size", None)
    except Exception:
        model_loaded = False
        model_size = None

    model_downloaded = None
    try:
        from ..backends import get_model_config

        default_config = get_model_config("qwen-tts-1.7B")
        default_model_id = default_config.hf_repo_id if default_config else "Qwen/Qwen3-TTS-12Hz-1.7B-Base"

        try:
            from huggingface_hub import scan_cache_dir

            cache_info = scan_cache_dir()
            for repo in cache_info.repos:
                if repo.repo_id == default_model_id:
                    model_downloaded = True
                    break
        except Exception:
            cache_dir = hf_constants.HF_HUB_CACHE
            repo_cache = Path(cache_dir) / ("models--" + default_model_id.replace("/", "--"))
            if repo_cache.exists():
                has_model_files = (
                    any(repo_cache.rglob("*.bin"))
                    or any(repo_cache.rglob("*.safetensors"))
                    or any(repo_cache.rglob("*.pt"))
                    or any(repo_cache.rglob("*.pth"))
                    or any(repo_cache.rglob("*.npz"))
                )
                model_downloaded = has_model_files
    except Exception:
        pass

    return models.HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        model_downloaded=model_downloaded,
        model_size=model_size,
        gpu_available=gpu_available,
        gpu_type=gpu_type,
        vram_used_mb=vram_used,
        backend_type=backend_type,
        backend_variant=os.environ.get(
            "VOICEBOX_BACKEND_VARIANT",
            "cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"),
        ),
        gpu_compatibility_warning=gpu_compat_warning,
    )


@router.get("/health/filesystem", response_model=models.FilesystemHealthResponse)
async def filesystem_health():
    """Check filesystem health: directory existence, write permissions, and disk space."""
    import shutil

    dirs_to_check = {
        "generations": config.get_generations_dir(),
        "profiles": config.get_profiles_dir(),
        "data": config.get_data_dir(),
    }

    checks: list[models.DirectoryCheck] = []
    all_ok = True

    for _label, dir_path in dirs_to_check.items():
        exists = dir_path.exists()
        writable = False
        error = None
        if exists:
            probe = dir_path / ".voicebox_probe"
            try:
                probe.write_text("ok")
                probe.unlink()
                writable = True
            except PermissionError:
                error = "Permission denied"
            except OSError as e:
                error = str(e)
            finally:
                try:
                    probe.unlink(missing_ok=True)
                except Exception:
                    pass
        else:
            error = "Directory does not exist"

        if not exists or not writable:
            all_ok = False

        checks.append(
            models.DirectoryCheck(
                path=str(dir_path.resolve()),
                exists=exists,
                writable=writable,
                error=error,
            )
        )

    disk_free_mb = None
    disk_total_mb = None
    try:
        usage = shutil.disk_usage(str(config.get_data_dir()))
        disk_free_mb = round(usage.free / (1024 * 1024), 1)
        disk_total_mb = round(usage.total / (1024 * 1024), 1)
        if disk_free_mb < 500:
            all_ok = False
    except OSError:
        all_ok = False

    return models.FilesystemHealthResponse(
        healthy=all_ok,
        disk_free_mb=disk_free_mb,
        disk_total_mb=disk_total_mb,
        directories=checks,
    )
189
backend/routes/history.py
Normal file
@@ -0,0 +1,189 @@
"""Generation history endpoints."""

import io

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session

from .. import config, models
from ..services import export_import, history
from ..app import safe_content_disposition
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db

router = APIRouter()


@router.get("/history", response_model=models.HistoryListResponse)
async def list_history(
    profile_id: str | None = None,
    search: str | None = None,
    limit: int = 50,
    offset: int = 0,
    db: Session = Depends(get_db),
):
    """List generation history with optional filters."""
    query = models.HistoryQuery(
        profile_id=profile_id,
        search=search,
        limit=limit,
        offset=offset,
    )
    return await history.list_generations(query, db)


@router.get("/history/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Get generation statistics."""
    return await history.get_generation_stats(db)


@router.post("/history/import")
async def import_generation(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """Import a generation from a ZIP archive."""
    MAX_FILE_SIZE = 50 * 1024 * 1024

    content = await file.read()

    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE / (1024 * 1024):.0f}MB"
        )

    try:
        result = await export_import.import_generation_from_zip(content, db)
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.delete("/history/failed")
async def clear_failed_generations(db: Session = Depends(get_db)):
    """Delete every generation with status='failed'. Used by the UI's 'Clear failed' button (#410)."""
    count = await history.delete_failed_generations(db)
    return {"deleted": count}


@router.get("/history/{generation_id}", response_model=models.HistoryResponse)
async def get_generation(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Get a generation by ID."""
    result = (
        db.query(DBGeneration, DBVoiceProfile.name.label("profile_name"))
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
        .filter(DBGeneration.id == generation_id)
        .first()
    )

    if not result:
        raise HTTPException(status_code=404, detail="Generation not found")

    gen, profile_name = result
    return models.HistoryResponse(
        id=gen.id,
        profile_id=gen.profile_id,
        profile_name=profile_name,
        text=gen.text,
        language=gen.language,
        audio_path=gen.audio_path,
        duration=gen.duration,
        seed=gen.seed,
        instruct=gen.instruct,
        engine=gen.engine or "qwen",
        model_size=gen.model_size,
        status=gen.status or "completed",
        error=gen.error,
        is_favorited=bool(gen.is_favorited),
        created_at=gen.created_at,
    )


@router.post("/history/{generation_id}/favorite")
async def toggle_favorite(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Toggle the favorite status of a generation."""
    gen = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    gen.is_favorited = not gen.is_favorited
    db.commit()
    return {"is_favorited": gen.is_favorited}


@router.delete("/history/{generation_id}")
async def delete_generation(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Delete a generation."""
    success = await history.delete_generation(generation_id, db)
    if not success:
        raise HTTPException(status_code=404, detail="Generation not found")
    return {"message": "Generation deleted successfully"}


@router.get("/history/{generation_id}/export")
async def export_generation(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Export a generation as a ZIP archive."""
    generation = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not generation:
        raise HTTPException(status_code=404, detail="Generation not found")

    try:
        zip_bytes = export_import.export_generation_to_zip(generation_id, db)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
    if not safe_text:
        safe_text = "generation"
    filename = f"generation-{safe_text}.voicebox.zip"

    return StreamingResponse(
        io.BytesIO(zip_bytes),
        media_type="application/zip",
        headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
    )


@router.get("/history/{generation_id}/export-audio")
async def export_generation_audio(
    generation_id: str,
    db: Session = Depends(get_db),
):
    """Export only the audio file from a generation."""
    generation = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not generation:
        raise HTTPException(status_code=404, detail="Generation not found")

    if not generation.audio_path:
        raise HTTPException(status_code=404, detail="Generation has no audio file")

    audio_path = config.resolve_storage_path(generation.audio_path)
    if audio_path is None or not audio_path.is_file():
        raise HTTPException(status_code=404, detail="Audio file not found")

    safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
    if not safe_text:
        safe_text = "generation"
    filename = f"{safe_text}.wav"

    return FileResponse(
        audio_path,
        media_type="audio/wav",
        headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
    )
475
backend/routes/models.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""Model management endpoints."""
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import models
|
||||
from ..utils.platform_detect import get_backend_type
|
||||
from ..services.task_queue import create_background_task
|
||||
from ..utils.progress import get_progress_manager
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _get_dir_size(path: Path) -> int:
|
||||
"""Get total size of a directory in bytes."""
|
||||
total = 0
|
||||
for f in path.rglob("*"):
|
||||
if f.is_file():
|
||||
total += f.stat().st_size
|
||||
return total
|
||||
|
||||
|
||||
def _copy_with_progress(src: Path, dst: Path, progress_manager, copied_so_far: int, total_bytes: int) -> int:
|
||||
"""Copy a directory tree with byte-level progress tracking."""
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for item in src.iterdir():
|
||||
dest_item = dst / item.name
|
||||
if item.is_dir():
|
||||
copied_so_far = _copy_with_progress(item, dest_item, progress_manager, copied_so_far, total_bytes)
|
||||
else:
|
||||
size = item.stat().st_size
|
||||
shutil.copy2(str(item), str(dest_item))
|
||||
copied_so_far += size
|
||||
progress_manager.update_progress(
|
||||
"migration",
|
||||
copied_so_far,
|
||||
total_bytes,
|
||||
filename=item.name,
|
||||
status="downloading",
|
||||
)
|
||||
return copied_so_far
|
||||
|
||||
|
||||
@router.post("/models/load")
|
||||
async def load_model(model_size: str = "1.7B"):
|
||||
"""Manually load TTS model."""
|
||||
from ..services import tts
|
||||
|
||||
try:
|
||||
tts_model = tts.get_tts_model()
|
||||
await tts_model.load_model_async(model_size)
|
||||
return {"message": f"Model {model_size} loaded successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/models/unload")
|
||||
async def unload_model():
|
||||
"""Unload the default Qwen TTS model to free memory."""
|
||||
from ..services import tts
|
||||
|
||||
try:
|
||||
tts.unload_tts_model()
|
||||
return {"message": "Model unloaded successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/models/{model_name}/unload")
|
||||
async def unload_model_by_name(model_name: str):
|
||||
"""Unload a specific model from memory without deleting it from disk."""
|
||||
from ..backends import get_model_config, unload_model_by_config
|
||||
|
||||
config = get_model_config(model_name)
|
||||
if not config:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
|
||||
|
||||
try:
|
||||
was_loaded = unload_model_by_config(config)
|
||||
if not was_loaded:
|
||||
return {"message": f"Model {model_name} is not loaded"}
|
||||
return {"message": f"Model {model_name} unloaded successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/models/progress/{model_name}")
|
||||
async def get_model_progress(model_name: str):
|
||||
"""Get model download progress via Server-Sent Events."""
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
async def event_generator():
|
||||
async for event in progress_manager.subscribe(model_name):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models/cache-dir")
|
||||
async def get_models_cache_dir():
|
||||
"""Get the path to the HuggingFace model cache directory."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
|
||||
return {"path": str(Path(hf_constants.HF_HUB_CACHE))}
|
||||
|
||||
|
||||
@router.post("/models/migrate")
|
||||
async def migrate_models(request: models.ModelMigrateRequest):
|
||||
"""Move all downloaded models to a new directory with byte-level progress via SSE."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
|
||||
source = Path(hf_constants.HF_HUB_CACHE)
|
||||
destination = Path(request.destination)
|
||||
|
||||
if not source.exists():
|
||||
raise HTTPException(status_code=404, detail="Current model cache directory not found")
|
||||
|
||||
if source.resolve() == destination.resolve():
|
||||
raise HTTPException(status_code=400, detail="Source and destination are the same directory")
|
||||
|
||||
if destination.resolve().is_relative_to(source.resolve()):
|
||||
raise HTTPException(status_code=400, detail="Destination cannot be inside the current cache directory")
|
||||
|
||||
progress_manager = get_progress_manager()
|
||||
model_dirs = [d for d in source.iterdir() if d.name.startswith("models--") and d.is_dir()]
|
||||
if not model_dirs:
|
||||
progress_manager.update_progress("migration", 1, 1, status="complete")
|
||||
progress_manager.mark_complete("migration")
|
||||
return {"moved": 0, "errors": [], "source": str(source), "destination": str(destination)}
|
||||
|
||||
destination.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
same_fs = False
|
||||
try:
|
||||
same_fs = source.stat().st_dev == destination.stat().st_dev
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def migrate_background():
|
||||
moved = 0
|
||||
errors = []
|
||||
try:
|
||||
if same_fs:
|
||||
total = len(model_dirs)
|
||||
for i, item in enumerate(model_dirs):
|
||||
dest_item = destination / item.name
|
||||
try:
|
||||
if dest_item.exists():
|
||||
shutil.rmtree(dest_item)
|
||||
shutil.move(str(item), str(dest_item))
|
||||
moved += 1
|
||||
progress_manager.update_progress(
|
||||
"migration",
|
||||
i + 1,
|
||||
total,
|
||||
filename=item.name,
|
||||
status="downloading",
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
else:
|
||||
total_bytes = sum(_get_dir_size(d) for d in model_dirs)
|
||||
progress_manager.update_progress(
|
||||
"migration", 0, total_bytes, filename="Calculating...", status="downloading"
|
||||
)
|
||||
|
||||
copied = 0
|
||||
for item in model_dirs:
|
||||
dest_item = destination / item.name
|
||||
try:
|
||||
if dest_item.exists():
|
||||
shutil.rmtree(dest_item)
|
||||
copied = await asyncio.to_thread(
|
||||
_copy_with_progress, item, dest_item, progress_manager, copied, total_bytes
|
||||
)
|
||||
await asyncio.to_thread(shutil.rmtree, str(item))
|
||||
moved += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
|
||||
progress_manager.update_progress("migration", 1, 1, status="complete")
|
||||
progress_manager.mark_complete("migration")
|
||||
except Exception as e:
|
||||
progress_manager.update_progress("migration", 0, 0, status="error")
|
||||
progress_manager.mark_error("migration", str(e))
|
||||
|
||||
create_background_task(migrate_background())
|
||||
|
||||
return {"source": str(source), "destination": str(destination)}
|
||||
|
||||
|
||||
@router.get("/models/migrate/progress")
|
||||
async def get_migration_progress():
|
||||
"""Get model migration progress via Server-Sent Events."""
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
async def event_generator():
|
||||
async for event in progress_manager.subscribe("migration"):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models/status", response_model=models.ModelStatusListResponse)
|
||||
async def get_model_status():
|
||||
"""Get status of all available models."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
|
||||
backend_type = get_backend_type()
|
||||
task_manager = get_task_manager()
|
||||
|
||||
active_download_names = {task.model_name for task in task_manager.get_active_downloads()}
|
||||
|
||||
try:
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
use_scan_cache = True
|
||||
except ImportError:
|
||||
use_scan_cache = False
|
||||
|
||||
from ..backends import get_all_model_configs, check_model_loaded
|
||||
|
||||
registry_configs = get_all_model_configs()
|
||||
model_configs = [
|
||||
{
|
||||
"model_name": cfg.model_name,
|
||||
"display_name": cfg.display_name,
|
||||
"hf_repo_id": cfg.hf_repo_id,
|
||||
"model_size": cfg.model_size,
|
||||
"check_loaded": lambda c=cfg: check_model_loaded(c),
|
||||
}
|
||||
for cfg in registry_configs
|
||||
]
|
||||
|
||||
model_to_repo = {cfg["model_name"]: cfg["hf_repo_id"] for cfg in model_configs}
|
||||
active_download_repos = {model_to_repo.get(name) for name in active_download_names if name in model_to_repo}
|
||||
|
||||
cache_info = None
|
||||
if use_scan_cache:
|
||||
try:
|
||||
cache_info = scan_cache_dir()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
statuses = []
|
||||
|
||||
for config in model_configs:
|
||||
try:
|
||||
downloaded = False
|
||||
size_mb = None
|
||||
loaded = False
|
||||
|
||||
if cache_info:
|
||||
repo_id = config["hf_repo_id"]
|
||||
for repo in cache_info.repos:
|
||||
if repo.repo_id == repo_id:
|
||||
has_model_weights = False
|
||||
for rev in repo.revisions:
|
||||
for f in rev.files:
|
||||
fname = f.file_name.lower()
|
||||
if fname.endswith((".safetensors", ".bin", ".pt", ".pth", ".npz")):
|
||||
has_model_weights = True
|
||||
break
|
||||
if has_model_weights:
|
||||
break
|
||||
|
||||
has_incomplete = False
|
||||
try:
|
||||
cache_dir = hf_constants.HF_HUB_CACHE
|
||||
blobs_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--")) / "blobs"
|
||||
if blobs_dir.exists():
|
||||
has_incomplete = any(blobs_dir.glob("*.incomplete"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_model_weights and not has_incomplete:
|
||||
downloaded = True
|
||||
try:
|
||||
total_size = sum(revision.size_on_disk for revision in repo.revisions)
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
|
||||
if not downloaded:
|
||||
try:
|
||||
cache_dir = hf_constants.HF_HUB_CACHE
|
||||
repo_cache = Path(cache_dir) / ("models--" + config["hf_repo_id"].replace("/", "--"))
|
||||
|
||||
if repo_cache.exists():
|
||||
blobs_dir = repo_cache / "blobs"
|
||||
has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))
|
||||
|
||||
if not has_incomplete:
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
has_model_files = False
|
||||
if snapshots_dir.exists():
|
||||
has_model_files = (
|
||||
any(snapshots_dir.rglob("*.bin"))
|
||||
or any(snapshots_dir.rglob("*.safetensors"))
|
||||
or any(snapshots_dir.rglob("*.pt"))
|
||||
or any(snapshots_dir.rglob("*.pth"))
|
||||
or any(snapshots_dir.rglob("*.npz"))
|
||||
)
|
||||
|
||||
if has_model_files:
|
||||
downloaded = True
|
||||
try:
|
||||
total_size = sum(
|
||||
f.stat().st_size
|
||||
for f in repo_cache.rglob("*")
|
||||
if f.is_file() and not f.name.endswith(".incomplete")
|
||||
)
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
loaded = config["check_loaded"]()
|
||||
except Exception:
|
||||
loaded = False
|
||||
|
||||
is_downloading = config["hf_repo_id"] in active_download_repos
|
||||
|
||||
if is_downloading:
|
||||
downloaded = False
|
||||
size_mb = None
|
||||
|
||||
statuses.append(
|
||||
models.ModelStatus(
|
||||
model_name=config["model_name"],
|
||||
display_name=config["display_name"],
|
||||
hf_repo_id=config["hf_repo_id"],
|
||||
downloaded=downloaded,
|
||||
downloading=is_downloading,
|
||||
size_mb=size_mb,
|
||||
loaded=loaded,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
loaded = config["check_loaded"]()
|
||||
except Exception:
|
||||
loaded = False
|
||||
|
||||
is_downloading = config["hf_repo_id"] in active_download_repos
|
||||
|
||||
statuses.append(
|
||||
models.ModelStatus(
|
||||
model_name=config["model_name"],
|
||||
display_name=config["display_name"],
|
||||
hf_repo_id=config["hf_repo_id"],
|
||||
downloaded=False,
|
||||
downloading=is_downloading,
|
||||
size_mb=None,
|
||||
loaded=loaded,
|
||||
)
|
||||
)
|
||||
|
||||
return models.ModelStatusListResponse(models=statuses)
|
||||
|
||||
|
||||
@router.post("/models/download")
|
||||
async def trigger_model_download(request: models.ModelDownloadRequest):
|
||||
"""Trigger download of a specific model."""
|
||||
from ..backends import get_model_config, get_model_load_func
|
||||
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
config = get_model_config(request.model_name)
|
||||
if not config:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}")
|
||||
|
||||
load_func = get_model_load_func(config)
|
||||
|
||||
async def download_in_background():
|
||||
try:
|
||||
result = load_func()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
task_manager.complete_download(request.model_name)
|
||||
except Exception as e:
|
||||
task_manager.error_download(request.model_name, str(e))
|
||||
|
||||
task_manager.start_download(request.model_name)
|
||||
|
||||
progress_manager.update_progress(
|
||||
model_name=request.model_name,
|
||||
current=0,
|
||||
total=0,
|
||||
filename="Connecting to HuggingFace...",
|
||||
status="downloading",
|
||||
)
|
||||
|
||||
create_background_task(download_in_background())
|
||||
|
||||
return {"message": f"Model {request.model_name} download started"}
|
||||
|
||||
|
||||
@router.post("/models/download/cancel")
|
||||
async def cancel_model_download(request: models.ModelDownloadRequest):
|
||||
"""Cancel or dismiss an errored/stale download task."""
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
removed = task_manager.cancel_download(request.model_name)
|
||||
|
||||
progress_removed = False
|
||||
with progress_manager._lock:
|
||||
if request.model_name in progress_manager._progress:
|
||||
del progress_manager._progress[request.model_name]
|
||||
progress_removed = True
|
||||
|
||||
if removed or progress_removed:
|
||||
return {"message": f"Download task for {request.model_name} cancelled"}
|
||||
return {"message": f"No active task found for {request.model_name}"}
|
||||
|
||||
|
||||
@router.delete("/models/{model_name}")
|
||||
async def delete_model(model_name: str):
|
||||
"""Delete a downloaded model from the HuggingFace cache."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
from ..backends import get_model_config, unload_model_by_config
|
||||
|
||||
config = get_model_config(model_name)
|
||||
if not config:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
|
||||
|
||||
hf_repo_id = config.hf_repo_id
|
||||
|
||||
try:
|
||||
unload_model_by_config(config)
|
||||
|
||||
cache_dir = hf_constants.HF_HUB_CACHE
|
||||
repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))
|
||||
|
||||
if not repo_cache_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")
|
||||
|
||||
try:
|
||||
shutil.rmtree(repo_cache_dir)
|
||||
except OSError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete model cache directory: {str(e)}")
|
||||
|
||||
return {"message": f"Model {model_name} deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
|
||||
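A sketch of consuming the SSE progress stream exposed by `/models/progress/{model_name}`. The model name is illustrative, and the payload keys are assumed to mirror what `update_progress()` records (current, total, filename, status); only the `data:` line prefix is standard SSE framing.

```python
import json

import httpx  # illustrative client choice

BASE = "http://127.0.0.1:17493"  # assumed local server address

# Stream events until the download settles. timeout=None keeps the
# long-lived SSE connection open.
with httpx.stream("GET", f"{BASE}/models/progress/qwen-1.7B", timeout=None) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data:"):
            continue  # skip SSE comments and keepalives
        payload = json.loads(line[len("data:"):].strip())
        print(payload.get("status"), payload.get("current"), payload.get("total"))
        if payload.get("status") in ("complete", "error"):
            break
```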
363
backend/routes/profiles.py
Normal file
@@ -0,0 +1,363 @@
"""Voice profile endpoints."""
|
||||
|
||||
import io
|
||||
import json as _json
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config, models
|
||||
from ..app import safe_content_disposition
|
||||
from ..database import VoiceProfile as DBVoiceProfile, get_db
|
||||
from ..services import channels, export_import, profiles
|
||||
from ..services.profiles import _profile_to_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/profiles", response_model=models.VoiceProfileResponse)
|
||||
async def create_profile(
|
||||
data: models.VoiceProfileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a new voice profile."""
|
||||
try:
|
||||
return await profiles.create_profile(data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/profiles", response_model=list[models.VoiceProfileResponse])
|
||||
async def list_profiles(db: Session = Depends(get_db)):
|
||||
"""List all voice profiles."""
|
||||
return await profiles.list_profiles(db)
|
||||
|
||||
|
||||
@router.post("/profiles/import", response_model=models.VoiceProfileResponse)
|
||||
async def import_profile(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Import a voice profile from a ZIP archive."""
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024
|
||||
|
||||
content = await file.read()
|
||||
|
||||
if len(content) > MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE / (1024 * 1024)}MB"
|
||||
)
|
||||
|
||||
try:
|
||||
profile = await export_import.import_profile_from_zip(content, db)
|
||||
return profile
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ── Preset Voice Endpoints ───────────────────────────────────────────
|
||||
# These MUST be declared before /profiles/{profile_id} to avoid the
|
||||
# wildcard swallowing "presets" as a profile_id.
|
||||
|
||||
|
||||
@router.get("/profiles/presets/{engine}")
|
||||
async def list_preset_voices(engine: str):
|
||||
"""List available preset voices for an engine."""
|
||||
if engine == "kokoro":
|
||||
from ..backends.kokoro_backend import KOKORO_VOICES
|
||||
|
||||
return {
|
||||
"engine": engine,
|
||||
"voices": [
|
||||
{
|
||||
"voice_id": vid,
|
||||
"name": name,
|
||||
"gender": gender,
|
||||
"language": lang,
|
||||
}
|
||||
for vid, name, gender, lang in KOKORO_VOICES
|
||||
],
|
||||
}
|
||||
if engine == "qwen_custom_voice":
|
||||
from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES
|
||||
|
||||
return {
|
||||
"engine": engine,
|
||||
"voices": [
|
||||
{
|
||||
"voice_id": speaker_id,
|
||||
"name": display_name,
|
||||
"gender": gender,
|
||||
"language": lang,
|
||||
}
|
||||
for speaker_id, display_name, gender, lang, _desc in QWEN_CUSTOM_VOICES
|
||||
],
|
||||
}
|
||||
return {"engine": engine, "voices": []}
|
||||
|
||||
@router.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
|
||||
async def get_profile(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get a voice profile by ID."""
|
||||
profile = await profiles.get_profile(profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return profile
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
|
||||
async def update_profile(
|
||||
profile_id: str,
|
||||
data: models.VoiceProfileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update a voice profile."""
|
||||
try:
|
||||
profile = await profiles.update_profile(profile_id, data, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return profile
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}")
|
||||
async def delete_profile(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a voice profile."""
|
||||
success = await profiles.delete_profile(profile_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return {"message": "Profile deleted successfully"}
|
||||
|
||||
|
||||
SAMPLE_MAX_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||
SAMPLE_UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1 MB
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/samples", response_model=models.ProfileSampleResponse)
|
||||
async def add_profile_sample(
|
||||
profile_id: str,
|
||||
file: UploadFile = File(...),
|
||||
reference_text: str = Form(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Add a sample to a voice profile."""
|
||||
_allowed_audio_exts = {".wav", ".mp3", ".m4a", ".ogg", ".flac", ".aac", ".webm", ".opus"}
|
||||
_uploaded_ext = Path(file.filename or "").suffix.lower()
|
||||
file_suffix = _uploaded_ext if _uploaded_ext in _allowed_audio_exts else ".wav"
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp:
|
||||
total_size = 0
|
||||
while chunk := await file.read(SAMPLE_UPLOAD_CHUNK_SIZE):
|
||||
total_size += len(chunk)
|
||||
if total_size > SAMPLE_MAX_FILE_SIZE:
|
||||
Path(tmp.name).unlink(missing_ok=True)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File too large (max {SAMPLE_MAX_FILE_SIZE // (1024 * 1024)} MB)",
|
||||
)
|
||||
tmp.write(chunk)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
sample = await profiles.add_profile_sample(
|
||||
profile_id,
|
||||
tmp_path,
|
||||
reference_text,
|
||||
db,
|
||||
)
|
||||
return sample
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to process audio file: {str(e)}")
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/samples", response_model=list[models.ProfileSampleResponse])
|
||||
async def get_profile_samples(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all samples for a profile."""
|
||||
return await profiles.get_profile_samples(profile_id, db)
|
||||
|
||||
|
||||
@router.delete("/profiles/samples/{sample_id}")
|
||||
async def delete_profile_sample(
|
||||
sample_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a profile sample."""
|
||||
success = await profiles.delete_profile_sample(sample_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Sample not found")
|
||||
return {"message": "Sample deleted successfully"}
|
||||
|
||||
|
||||
@router.put("/profiles/samples/{sample_id}", response_model=models.ProfileSampleResponse)
|
||||
async def update_profile_sample(
|
||||
sample_id: str,
|
||||
data: models.ProfileSampleUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update a profile sample's reference text."""
|
||||
sample = await profiles.update_profile_sample(sample_id, data.reference_text, db)
|
||||
if not sample:
|
||||
raise HTTPException(status_code=404, detail="Sample not found")
|
||||
return sample
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/avatar", response_model=models.VoiceProfileResponse)
|
||||
async def upload_profile_avatar(
|
||||
profile_id: str,
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Upload or update avatar image for a profile."""
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
profile = await profiles.upload_avatar(profile_id, tmp_path, db)
|
||||
return profile
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/avatar")
|
||||
async def get_profile_avatar(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get avatar image for a profile."""
|
||||
profile = await profiles.get_profile(profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
if not profile.avatar_path:
|
||||
raise HTTPException(status_code=404, detail="No avatar found for this profile")
|
||||
|
||||
avatar_path = config.resolve_storage_path(profile.avatar_path)
|
||||
if avatar_path is None or not avatar_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Avatar file not found")
|
||||
|
||||
return FileResponse(avatar_path)
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}/avatar")
|
||||
async def delete_profile_avatar(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete avatar image for a profile."""
|
||||
success = await profiles.delete_avatar(profile_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Profile not found or no avatar to delete")
|
||||
return {"message": "Avatar deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/export")
|
||||
async def export_profile(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Export a voice profile as a ZIP archive."""
|
||||
try:
|
||||
profile = await profiles.get_profile(profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
zip_bytes = export_import.export_profile_to_zip(profile_id, db)
|
||||
|
||||
safe_name = "".join(c for c in profile.name if c.isalnum() or c in (" ", "-", "_")).strip()
|
||||
if not safe_name:
|
||||
safe_name = "profile"
|
||||
filename = f"profile-{safe_name}.voicebox.zip"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(zip_bytes),
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/channels")
|
||||
async def get_profile_channels(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get list of channel IDs assigned to a profile."""
|
||||
try:
|
||||
channel_ids = await channels.get_profile_channels(profile_id, db)
|
||||
return {"channel_ids": channel_ids}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}/channels")
|
||||
async def set_profile_channels(
|
||||
profile_id: str,
|
||||
data: models.ProfileChannelAssignment,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Set which channels a profile is assigned to."""
|
||||
try:
|
||||
await channels.set_profile_channels(profile_id, data, db)
|
||||
return {"message": "Profile channels updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}/effects", response_model=models.VoiceProfileResponse)
|
||||
async def update_profile_effects(
|
||||
profile_id: str,
|
||||
data: models.ProfileEffectsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Set or clear the default effects chain for a voice profile."""
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
if data.effects_chain is not None:
|
||||
from ..utils.effects import validate_effects_chain
|
||||
|
||||
chain_dicts = [e.model_dump() for e in data.effects_chain]
|
||||
error = validate_effects_chain(chain_dicts)
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
profile.effects_chain = _json.dumps(chain_dicts)
|
||||
else:
|
||||
profile.effects_chain = None
|
||||
|
||||
profile.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(profile)
|
||||
|
||||
return _profile_to_response(profile)
|
||||
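A sketch of exercising the sample-upload endpoint with multipart form data. The field names (`file`, `reference_text`) match the route signature above; the server address and file path are illustrative.

```python
import httpx  # illustrative client choice

BASE = "http://127.0.0.1:17493"  # assumed local server address
profile_id = "..."  # an existing profile ID

# Multipart upload: the audio goes in `file`, its transcript in
# `reference_text`, exactly as add_profile_sample() declares.
with open("sample.wav", "rb") as f:
    resp = httpx.post(
        f"{BASE}/profiles/{profile_id}/samples",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"reference_text": "Transcript of the sample audio."},
        timeout=60.0,
    )
resp.raise_for_status()
print(resp.json())
```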
223
backend/routes/stories.py
Normal file
@@ -0,0 +1,223 @@
"""Story endpoints."""
|
||||
|
||||
import io
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import database, models
|
||||
from ..services import stories
|
||||
from ..app import safe_content_disposition
|
||||
from ..database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/stories", response_model=list[models.StoryResponse])
|
||||
async def list_stories(db: Session = Depends(get_db)):
|
||||
"""List all stories."""
|
||||
return await stories.list_stories(db)
|
||||
|
||||
|
||||
@router.post("/stories", response_model=models.StoryResponse)
|
||||
async def create_story(
|
||||
data: models.StoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a new story."""
|
||||
try:
|
||||
return await stories.create_story(data, db)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}", response_model=models.StoryDetailResponse)
|
||||
async def get_story(
|
||||
story_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get a story with all its items."""
|
||||
story = await stories.get_story(story_id, db)
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return story
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}", response_model=models.StoryResponse)
|
||||
async def update_story(
|
||||
story_id: str,
|
||||
data: models.StoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update a story."""
|
||||
story = await stories.update_story(story_id, data, db)
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return story
|
||||
|
||||
|
||||
@router.delete("/stories/{story_id}")
|
||||
async def delete_story(
|
||||
story_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a story."""
|
||||
success = await stories.delete_story(story_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return {"message": "Story deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/stories/{story_id}/items", response_model=models.StoryItemDetail)
|
||||
async def add_story_item(
|
||||
story_id: str,
|
||||
data: models.StoryItemCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Add a generation to a story."""
|
||||
item = await stories.add_item_to_story(story_id, data, db)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Story or generation not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.delete("/stories/{story_id}/items/{item_id}")
|
||||
async def remove_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Remove a story item from a story."""
|
||||
success = await stories.remove_item_from_story(story_id, item_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Story item not found")
|
||||
return {"message": "Item removed successfully"}
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/times")
|
||||
async def update_story_item_times(
|
||||
story_id: str,
|
||||
data: models.StoryItemBatchUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update story item timecodes."""
|
||||
success = await stories.update_story_item_times(story_id, data, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Invalid timecode update request")
|
||||
return {"message": "Item timecodes updated successfully"}
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/reorder", response_model=list[models.StoryItemDetail])
|
||||
async def reorder_story_items(
|
||||
story_id: str,
|
||||
data: models.StoryItemReorder,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Reorder story items and recalculate timecodes."""
|
||||
items = await stories.reorder_story_items(story_id, data.generation_ids, db)
|
||||
if items is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid reorder request - ensure all generation IDs belong to this story"
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/{item_id}/move", response_model=models.StoryItemDetail)
|
||||
async def move_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemMove,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Move a story item (update position and/or track)."""
|
||||
item = await stories.move_story_item(story_id, item_id, data, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/{item_id}/trim", response_model=models.StoryItemDetail)
|
||||
async def trim_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemTrim,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Trim a story item."""
|
||||
item = await stories.trim_story_item(story_id, item_id, data, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found or invalid trim values")
|
||||
return item
|
||||
|
||||
|
||||
@router.post("/stories/{story_id}/items/{item_id}/split", response_model=list[models.StoryItemDetail])
|
||||
async def split_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemSplit,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Split a story item at a given time, creating two clips."""
|
||||
items = await stories.split_story_item(story_id, item_id, data, db)
|
||||
if items is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found or invalid split point")
|
||||
return items
|
||||
|
||||
|
||||
@router.post("/stories/{story_id}/items/{item_id}/duplicate", response_model=models.StoryItemDetail)
|
||||
async def duplicate_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Duplicate a story item."""
|
||||
item = await stories.duplicate_story_item(story_id, item_id, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/{item_id}/version", response_model=models.StoryItemDetail)
|
||||
async def set_story_item_version(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemVersionUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Pin a story item to a specific generation version."""
|
||||
item = await stories.set_story_item_version(story_id, item_id, data, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item or version not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}/export-audio")
|
||||
async def export_story_audio(
|
||||
story_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Export story as single mixed audio file."""
|
||||
try:
|
||||
story = db.query(database.Story).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
|
||||
audio_bytes = await stories.export_story_audio(story_id, db)
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=400, detail="Story has no audio items")
|
||||
|
||||
safe_name = "".join(c for c in story.name if c.isalnum() or c in (" ", "-", "_")).strip()
|
||||
if not safe_name:
|
||||
safe_name = "story"
|
||||
filename = f"{safe_name}.wav"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(audio_bytes),
|
||||
media_type="audio/wav",
|
||||
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
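A sketch of a timeline operation: reorder expects the full list of generation IDs for the story (the `generation_ids` field matches `StoryItemReorder` above). The response field names read from the story detail (`items`, `generation_id`) are assumptions about the Pydantic schema, which is defined elsewhere.

```python
import httpx  # illustrative client choice

BASE = "http://127.0.0.1:17493"  # assumed local server address
story_id = "..."  # an existing story ID

# Fetch the story, reverse its items, and push the new order back.
story = httpx.get(f"{BASE}/stories/{story_id}").json()
ids = [item["generation_id"] for item in reversed(story["items"])]  # field names assumed

resp = httpx.put(f"{BASE}/stories/{story_id}/items/reorder", json={"generation_ids": ids})
resp.raise_for_status()  # 400 means an ID didn't belong to this story
```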
125
backend/routes/tasks.py
Normal file
@@ -0,0 +1,125 @@
"""Task and cache management endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .. import models
|
||||
from ..utils.cache import clear_voice_prompt_cache
|
||||
from ..utils.progress import get_progress_manager
|
||||
from ..utils.tasks import get_task_manager
|
||||
from fastapi import HTTPException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/tasks/clear")
|
||||
async def clear_all_tasks():
|
||||
"""Clear all download tasks and progress state."""
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
task_manager.clear_all()
|
||||
|
||||
with progress_manager._lock:
|
||||
progress_manager._progress.clear()
|
||||
progress_manager._last_notify_time.clear()
|
||||
progress_manager._last_notify_progress.clear()
|
||||
|
||||
return {"message": "All task state cleared"}
|
||||
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache():
|
||||
"""Clear all voice prompt caches (memory and disk)."""
|
||||
try:
|
||||
deleted_count = clear_voice_prompt_cache()
|
||||
return {
|
||||
"message": "Voice prompt cache cleared successfully",
|
||||
"files_deleted": deleted_count,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tasks/active", response_model=models.ActiveTasksResponse)
|
||||
async def get_active_tasks():
|
||||
"""Return all currently active downloads and generations."""
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
active_downloads = []
|
||||
task_manager_downloads = task_manager.get_active_downloads()
|
||||
progress_active = progress_manager.get_all_active()
|
||||
|
||||
download_map = {task.model_name: task for task in task_manager_downloads}
|
||||
progress_map = {p["model_name"]: p for p in progress_active}
|
||||
|
||||
all_model_names = set(download_map.keys()) | set(progress_map.keys())
|
||||
for model_name in all_model_names:
|
||||
task = download_map.get(model_name)
|
||||
progress = progress_map.get(model_name)
|
||||
|
||||
if task:
|
||||
error = task.error
|
||||
if not error:
|
||||
with progress_manager._lock:
|
||||
pm_data = progress_manager._progress.get(model_name)
|
||||
if pm_data:
|
||||
error = pm_data.get("error")
|
||||
prog = progress or {}
|
||||
if not prog:
|
||||
with progress_manager._lock:
|
||||
pm_data = progress_manager._progress.get(model_name)
|
||||
if pm_data:
|
||||
prog = pm_data
|
||||
active_downloads.append(
|
||||
models.ActiveDownloadTask(
|
||||
model_name=model_name,
|
||||
status=task.status,
|
||||
started_at=task.started_at,
|
||||
error=error,
|
||||
progress=prog.get("progress"),
|
||||
current=prog.get("current"),
|
||||
total=prog.get("total"),
|
||||
filename=prog.get("filename"),
|
||||
)
|
||||
)
|
||||
elif progress:
|
||||
timestamp_str = progress.get("timestamp")
|
||||
if timestamp_str:
|
||||
try:
|
||||
started_at = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
started_at = datetime.utcnow()
|
||||
else:
|
||||
started_at = datetime.utcnow()
|
||||
|
||||
active_downloads.append(
|
||||
models.ActiveDownloadTask(
|
||||
model_name=model_name,
|
||||
status=progress.get("status", "downloading"),
|
||||
started_at=started_at,
|
||||
error=progress.get("error"),
|
||||
progress=progress.get("progress"),
|
||||
current=progress.get("current"),
|
||||
total=progress.get("total"),
|
||||
filename=progress.get("filename"),
|
||||
)
|
||||
)
|
||||
|
||||
active_generations = []
|
||||
for gen_task in task_manager.get_active_generations():
|
||||
active_generations.append(
|
||||
models.ActiveGenerationTask(
|
||||
task_id=gen_task.task_id,
|
||||
profile_id=gen_task.profile_id,
|
||||
text_preview=gen_task.text_preview,
|
||||
started_at=gen_task.started_at,
|
||||
)
|
||||
)
|
||||
|
||||
return models.ActiveTasksResponse(
|
||||
downloads=active_downloads,
|
||||
generations=active_generations,
|
||||
)
|
||||
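A sketch of polling `/tasks/active` until all downloads settle, which a client might do instead of (or after losing) an SSE subscription. The response fields (`downloads`, `model_name`, `status`, `progress`) follow from how `ActiveTasksResponse` and `ActiveDownloadTask` are built above; the server address is illustrative.

```python
import time

import httpx  # illustrative client choice

BASE = "http://127.0.0.1:17493"  # assumed local server address

while True:
    active = httpx.get(f"{BASE}/tasks/active").json()
    if not active["downloads"]:
        break  # nothing in flight
    for d in active["downloads"]:
        print(d["model_name"], d["status"], d.get("progress"))
    time.sleep(2)
```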
84
backend/routes/transcription.py
Normal file
@@ -0,0 +1,84 @@
"""Transcription endpoints."""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
|
||||
from .. import models
|
||||
from ..services import transcribe
|
||||
from ..services.task_queue import create_background_task
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB
|
||||
|
||||
|
||||
@router.post("/transcribe", response_model=models.TranscriptionResponse)
|
||||
async def transcribe_audio(
|
||||
file: UploadFile = File(...),
|
||||
language: str | None = Form(None),
|
||||
model: str | None = Form(None),
|
||||
):
|
||||
"""Transcribe audio file to text."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
|
||||
tmp.write(chunk)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
from ..utils.audio import load_audio
|
||||
from ..backends import WHISPER_HF_REPOS
|
||||
|
||||
audio, sr = await asyncio.to_thread(load_audio, tmp_path)
|
||||
duration = len(audio) / sr
|
||||
|
||||
whisper_model = transcribe.get_whisper_model()
|
||||
model_size = model if model else whisper_model.model_size
|
||||
|
||||
valid_sizes = list(WHISPER_HF_REPOS.keys())
|
||||
if model_size not in valid_sizes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid model size '{model_size}'. Must be one of: {', '.join(valid_sizes)}",
|
||||
)
|
||||
|
||||
already_loaded = whisper_model.is_loaded() and whisper_model.model_size == model_size
|
||||
if not already_loaded and not whisper_model._is_model_cached(model_size):
|
||||
progress_model_name = f"whisper-{model_size}"
|
||||
task_manager = get_task_manager()
|
||||
|
||||
async def download_whisper_background():
|
||||
try:
|
||||
await whisper_model.load_model_async(model_size)
|
||||
task_manager.complete_download(progress_model_name)
|
||||
except Exception as e:
|
||||
task_manager.error_download(progress_model_name, str(e))
|
||||
|
||||
task_manager.start_download(progress_model_name)
|
||||
create_background_task(download_whisper_background())
|
||||
|
||||
raise HTTPException(
|
||||
status_code=202,
|
||||
detail={
|
||||
"message": f"Whisper model {model_size} is being downloaded. Please wait and try again.",
|
||||
"model_name": progress_model_name,
|
||||
"downloading": True,
|
||||
},
|
||||
)
|
||||
|
||||
text = await whisper_model.transcribe(tmp_path, language, model_size)
|
||||
|
||||
return models.TranscriptionResponse(
|
||||
text=text,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
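A sketch of a client handling the 202 "model still downloading" response above by waiting and retrying. The server address, file path, and the `"base"` model size are illustrative (valid sizes come from `WHISPER_HF_REPOS`).

```python
import time

import httpx  # illustrative client choice

BASE = "http://127.0.0.1:17493"  # assumed local server address


def transcribe(path: str, model: str = "base") -> str:
    while True:
        with open(path, "rb") as f:
            resp = httpx.post(
                f"{BASE}/transcribe",
                files={"file": (path, f, "audio/wav")},
                data={"model": model},
                timeout=300.0,
            )
        if resp.status_code == 202:
            time.sleep(5)  # Whisper weights still downloading; retry
            continue
        resp.raise_for_status()
        return resp.json()["text"]


print(transcribe("sample.wav"))
```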
303
backend/server.py
Normal file
@@ -0,0 +1,303 @@
"""
|
||||
Entry point for PyInstaller-bundled voicebox server.
|
||||
|
||||
This module provides an entry point that works with PyInstaller by using
|
||||
absolute imports instead of relative imports.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# On Windows with --noconsole (PyInstaller), sys.stdout/stderr are None.
|
||||
# They can also be broken file objects in some edge cases.
|
||||
# Redirect to devnull to prevent crashes from print()/tqdm/logging.
|
||||
def _is_writable(stream):
|
||||
"""Check if a stream is usable for writing."""
|
||||
if stream is None:
|
||||
return False
|
||||
try:
|
||||
stream.write("")
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if not _is_writable(sys.stdout):
|
||||
sys.stdout = open(os.devnull, 'w')
|
||||
if not _is_writable(sys.stderr):
|
||||
sys.stderr = open(os.devnull, 'w')
|
||||
|
||||
# PyInstaller + multiprocessing: child processes re-execute the frozen binary
|
||||
# with internal arguments. freeze_support() handles this and exits early.
|
||||
import multiprocessing
|
||||
multiprocessing.freeze_support()
|
||||
|
||||
# In frozen builds, piper_phonemize's espeak-ng C library falls back to
|
||||
# /usr/share/espeak-ng-data/ which doesn't exist. Point it at the bundled
|
||||
# data directory instead.
|
||||
if getattr(sys, 'frozen', False):
|
||||
_meipass = getattr(sys, '_MEIPASS', os.path.dirname(sys.executable))
|
||||
_espeak_data = os.path.join(_meipass, 'piper_phonemize', 'espeak-ng-data')
|
||||
if os.path.isdir(_espeak_data):
|
||||
os.environ.setdefault('ESPEAK_DATA_PATH', _espeak_data)
|
||||
|
||||
# Fast path: handle --version before any heavy imports so the Rust
|
||||
# version check doesn't block for 30+ seconds loading torch etc.
|
||||
if "--version" in sys.argv:
|
||||
from backend import __version__
|
||||
print(f"voicebox-server {__version__}")
|
||||
sys.exit(0)
|
||||
|
||||
import logging
|
||||
|
||||
# Set up logging FIRST, before any imports that might fail
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
stream=sys.stderr, # Log to stderr so it's captured by Tauri
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Log startup immediately to confirm binary execution
|
||||
logger.info("=" * 60)
|
||||
logger.info("voicebox-server starting up...")
|
||||
logger.info(f"Python version: {sys.version}")
|
||||
logger.info(f"Executable: {sys.executable}")
|
||||
logger.info(f"Arguments: {sys.argv}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
logger.info("Importing argparse...")
|
||||
import argparse
|
||||
logger.info("Importing uvicorn...")
|
||||
import uvicorn
|
||||
logger.info("Standard library imports successful")
|
||||
|
||||
# Import the FastAPI app from the backend package
|
||||
logger.info("Importing backend.config...")
|
||||
from backend import config
|
||||
logger.info("Importing backend.database...")
|
||||
from backend import database
|
||||
logger.info("Importing backend.main (this may take a while due to torch/transformers)...")
|
||||
from backend.main import app
|
||||
logger.info("Backend imports successful")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import required modules: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
_watchdog_disabled = False
|
||||
|
||||
|
||||
def disable_watchdog():
|
||||
"""Disable the parent watchdog so the server keeps running after parent exits."""
|
||||
global _watchdog_disabled
|
||||
_watchdog_disabled = True
|
||||
# Ignore SIGHUP so the server survives when the parent Tauri process exits.
|
||||
# On Unix, child processes receive SIGHUP when the parent's session leader
|
||||
# exits, which would kill the server even though we want it to persist.
|
||||
if sys.platform != "win32":
|
||||
import signal
|
||||
signal.signal(signal.SIGHUP, signal.SIG_IGN)
|
||||
|
||||
|
||||
def _start_parent_watchdog(parent_pid, data_dir=None):
|
||||
"""Monitor parent process and exit if it dies.
|
||||
|
||||
This is the clean shutdown mechanism: instead of the Tauri app trying to
|
||||
forcefully kill the server (which spawns console windows on Windows),
|
||||
the server monitors its parent and shuts itself down gracefully.
|
||||
|
||||
The Tauri app writes a .keep-running sentinel file to data_dir before
|
||||
exiting when "remain running after close" is enabled. This is a reliable
|
||||
fallback for the HTTP /watchdog/disable request, which can race with
|
||||
process exit on Windows.
|
||||
"""
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
|
||||
# Set up a file logger so we can debug in production
|
||||
watchdog_logger = logging.getLogger("watchdog")
|
||||
if data_dir:
|
||||
try:
|
||||
log_dir = os.path.join(data_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
fh = logging.FileHandler(os.path.join(log_dir, "watchdog.log"))
|
||||
fh.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
|
||||
watchdog_logger.addHandler(fh)
|
||||
except Exception:
|
||||
pass
|
||||
watchdog_logger.setLevel(logging.INFO)
|
||||
|
||||
def _is_pid_alive(pid):
|
||||
"""Check if a process with the given PID exists (cross-platform)."""
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
import ctypes
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
|
||||
handle = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, False, pid)
|
||||
if handle:
|
||||
# Check if process has actually exited
|
||||
STILL_ACTIVE = 259
|
||||
exit_code = ctypes.c_ulong()
|
||||
result = kernel32.GetExitCodeProcess(handle, ctypes.byref(exit_code))
|
||||
kernel32.CloseHandle(handle)
|
||||
if result and exit_code.value == STILL_ACTIVE:
|
||||
return True
|
||||
watchdog_logger.info(f"PID {pid}: exited with code {exit_code.value}")
|
||||
return False
|
||||
# OpenProcess failed — check if it's an access error (process exists
|
||||
# but we can't open it) vs process not found
|
||||
error = ctypes.GetLastError()
|
||||
ACCESS_DENIED = 5
|
||||
if error == ACCESS_DENIED:
|
||||
return True # process exists, we just can't open it
|
||||
watchdog_logger.info(f"PID {pid}: OpenProcess failed, error={error}")
|
||||
return False
|
||||
else:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except (OSError, PermissionError):
|
||||
return False
|
||||
|
||||
def _watch():
|
||||
watchdog_logger.info(f"Parent watchdog started, monitoring PID {parent_pid}, server PID {os.getpid()}")
|
||||
# Verify parent is alive before starting the loop
|
||||
alive = _is_pid_alive(parent_pid)
|
||||
watchdog_logger.info(f"Parent PID {parent_pid} initial check: alive={alive}")
|
||||
if not alive:
|
||||
watchdog_logger.warning(f"Parent PID {parent_pid} not found on first check — disabling watchdog")
|
||||
return
|
||||
# Clear any stale .keep-running sentinel from a previous session. The
|
||||
# sentinel is only removed by the watchdog when it's consumed during a
|
||||
# grace period; if the HTTP /watchdog/disable path wins the race on a
|
||||
# "keep running" exit, the sentinel is left on disk. Wipe it here so a
|
||||
# future session can't inherit that stale signal.
|
||||
if data_dir:
|
||||
stale = os.path.join(data_dir, ".keep-running")
|
||||
if os.path.exists(stale):
|
||||
try:
|
||||
os.remove(stale)
|
||||
watchdog_logger.info("Removed stale .keep-running sentinel from previous session")
|
||||
except OSError as e:
|
||||
watchdog_logger.warning(f"Failed to remove stale sentinel: {e}")
|
||||
while True:
|
||||
if _watchdog_disabled:
|
||||
watchdog_logger.info("Watchdog disabled (keep server running), stopping monitor")
|
||||
return
|
||||
if not _is_pid_alive(parent_pid):
|
||||
# Parent is gone. Before shutting down, give the app a moment
|
||||
# to send /watchdog/disable — there is a race where the Tauri
|
||||
# RunEvent::Exit handler sends the disable request while we are
|
||||
# mid-iteration (already past the _watchdog_disabled check above).
|
||||
watchdog_logger.info(f"Parent process {parent_pid} gone, waiting for possible disable request...")
|
||||
time.sleep(1)
|
||||
if _watchdog_disabled:
|
||||
watchdog_logger.info("Watchdog was disabled during grace period, keeping server alive")
|
||||
return
|
||||
# Check for sentinel file written by Tauri before exit.
|
||||
# This catches the case where the HTTP disable request
|
||||
# didn't arrive before the parent process died (common
|
||||
# on Windows where process teardown is fast).
|
||||
sentinel = os.path.join(data_dir, ".keep-running") if data_dir else None
|
||||
if sentinel and os.path.exists(sentinel):
|
||||
watchdog_logger.info("Found .keep-running sentinel file, keeping server alive")
|
||||
try:
|
||||
os.remove(sentinel)
|
||||
except OSError:
|
||||
pass
|
||||
return
|
||||
watchdog_logger.info("Watchdog still enabled after grace period, shutting down server...")
|
||||
if sys.platform == "win32":
|
||||
# sys.exit triggers SystemExit, allowing uvicorn to run
|
||||
# shutdown handlers. os.kill(SIGTERM) on Windows calls
|
||||
# TerminateProcess which hard-kills without cleanup.
|
||||
os._exit(0)
|
||||
else:
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
t = threading.Thread(target=_watch, daemon=True)
|
||||
t.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
parser = argparse.ArgumentParser(description="voicebox backend server")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
help="Host to bind to (use 0.0.0.0 for remote access)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to bind to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Data directory for database, profiles, and generated audio",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--parent-pid",
|
||||
type=int,
|
||||
default=None,
|
||||
help="PID of parent process to monitor; server exits when parent dies",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
action="store_true",
|
||||
help="Print version and exit (handled above, kept for argparse help)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.parent_pid is not None and args.parent_pid <= 0:
|
||||
parser.error("--parent-pid must be a positive integer")
|
||||
|
||||
# Detect backend variant from binary name
|
||||
# voicebox-server-cuda → sets VOICEBOX_BACKEND_VARIANT=cuda
|
||||
import os
|
||||
binary_name = os.path.basename(sys.executable).lower()
|
||||
if "cuda" in binary_name:
|
||||
os.environ["VOICEBOX_BACKEND_VARIANT"] = "cuda"
|
||||
logger.info("Backend variant: CUDA")
|
||||
else:
|
||||
os.environ["VOICEBOX_BACKEND_VARIANT"] = "cpu"
|
||||
logger.info("Backend variant: CPU")
|
||||
|
||||
# Register parent watchdog to start after server is fully ready
|
||||
if args.parent_pid is not None:
|
||||
_parent_pid = args.parent_pid
|
||||
_data_dir = args.data_dir
|
||||
@app.on_event("startup")
|
||||
async def _on_startup():
|
||||
_start_parent_watchdog(_parent_pid, _data_dir)
|
||||
|
||||
logger.info(f"Parsed arguments: host={args.host}, port={args.port}, data_dir={args.data_dir}")
|
||||
|
||||
# Set data directory if provided
|
||||
if args.data_dir:
|
||||
logger.info(f"Setting data directory to: {args.data_dir}")
|
||||
config.set_data_dir(args.data_dir)
|
||||
|
||||
# Initialize database after data directory is set
|
||||
logger.info("Initializing database...")
|
||||
database.init_db()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
logger.info(f"Starting uvicorn server on {args.host}:{args.port}...")
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="info",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Server startup failed: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
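A sketch of how a parent process might launch the bundled server with the watchdog contract described above. The binary path, port, and data directory are illustrative; the only load-bearing detail is passing the parent's own PID via `--parent-pid`.

```python
import os
import subprocess

# Launch the PyInstaller-built sidecar; it will monitor our PID and shut
# itself down when we exit (unless the watchdog is disabled first).
proc = subprocess.Popen(
    [
        "./voicebox-server",  # path to the built binary (illustrative)
        "--host", "127.0.0.1",
        "--port", "17493",
        "--data-dir", os.path.expanduser("~/.voicebox"),
        "--parent-pid", str(os.getpid()),
    ]
)

# ... talk to the server over HTTP ...

proc.terminate()  # or simply exit and let the watchdog shut it down
```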
1
backend/services/__init__.py
Normal file
@@ -0,0 +1 @@
# Services layer — generation orchestration and background task management.
263
backend/services/channels.py
Normal file
@@ -0,0 +1,263 @@
"""
|
||||
Audio channel management module.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..models import (
|
||||
AudioChannelCreate,
|
||||
AudioChannelUpdate,
|
||||
AudioChannelResponse,
|
||||
ChannelVoiceAssignment,
|
||||
ProfileChannelAssignment,
|
||||
)
|
||||
from ..database import (
|
||||
AudioChannel as DBAudioChannel,
|
||||
ChannelDeviceMapping as DBChannelDeviceMapping,
|
||||
ProfileChannelMapping as DBProfileChannelMapping,
|
||||
VoiceProfile as DBVoiceProfile,
|
||||
)
|
||||
|
||||
|
||||
async def list_channels(db: Session) -> List[AudioChannelResponse]:
|
||||
"""List all audio channels."""
|
||||
channels = db.query(DBAudioChannel).all()
|
||||
result = []
|
||||
|
||||
for channel in channels:
|
||||
# Get device IDs for this channel
|
||||
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
|
||||
channel_id=channel.id
|
||||
).all()
|
||||
device_ids = [m.device_id for m in device_mappings]
|
||||
|
||||
result.append(AudioChannelResponse(
|
||||
id=channel.id,
|
||||
name=channel.name,
|
||||
is_default=channel.is_default,
|
||||
device_ids=device_ids,
|
||||
created_at=channel.created_at,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def get_channel(channel_id: str, db: Session) -> Optional[AudioChannelResponse]:
|
||||
"""Get a channel by ID."""
|
||||
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
|
||||
if not channel:
|
||||
return None
|
||||
|
||||
# Get device IDs
|
||||
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
|
||||
channel_id=channel.id
|
||||
).all()
|
||||
device_ids = [m.device_id for m in device_mappings]
|
||||
|
||||
return AudioChannelResponse(
|
||||
id=channel.id,
|
||||
name=channel.name,
|
||||
is_default=channel.is_default,
|
||||
device_ids=device_ids,
|
||||
created_at=channel.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def create_channel(
|
||||
data: AudioChannelCreate,
|
||||
db: Session,
|
||||
) -> AudioChannelResponse:
|
||||
"""Create a new audio channel."""
|
||||
# Check if name already exists
|
||||
existing = db.query(DBAudioChannel).filter_by(name=data.name).first()
|
||||
if existing:
|
||||
raise ValueError(f"Channel with name '{data.name}' already exists")
|
||||
|
||||
# Create channel
|
||||
channel = DBAudioChannel(
|
||||
id=str(uuid.uuid4()),
|
||||
name=data.name,
|
||||
is_default=False,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(channel)
|
||||
db.flush()
|
||||
|
||||
# Add device mappings
|
||||
for device_id in data.device_ids:
|
||||
mapping = DBChannelDeviceMapping(
|
||||
id=str(uuid.uuid4()),
|
||||
channel_id=channel.id,
|
||||
device_id=device_id,
|
||||
)
|
||||
db.add(mapping)
|
||||
|
||||
db.commit()
|
||||
db.refresh(channel)
|
||||
|
||||
return AudioChannelResponse(
|
||||
id=channel.id,
|
||||
name=channel.name,
|
||||
is_default=channel.is_default,
|
||||
device_ids=data.device_ids,
|
||||
created_at=channel.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def update_channel(
|
||||
channel_id: str,
|
||||
data: AudioChannelUpdate,
|
||||
db: Session,
|
||||
) -> Optional[AudioChannelResponse]:
|
||||
"""Update an audio channel."""
|
||||
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
|
||||
if not channel:
|
||||
return None
|
||||
|
||||
if channel.is_default:
|
||||
raise ValueError("Cannot modify the default channel")
|
||||
|
||||
# Update name if provided
|
||||
if data.name is not None:
|
||||
# Check if name already exists (excluding current channel)
|
||||
existing = db.query(DBAudioChannel).filter(
|
||||
DBAudioChannel.name == data.name,
|
||||
DBAudioChannel.id != channel_id
|
||||
).first()
|
||||
if existing:
|
||||
raise ValueError(f"Channel with name '{data.name}' already exists")
|
||||
channel.name = data.name
|
||||
|
||||
# Update device mappings if provided
|
||||
if data.device_ids is not None:
|
||||
# Delete existing mappings
|
||||
db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()
|
||||
|
||||
# Add new mappings
|
||||
for device_id in data.device_ids:
|
||||
mapping = DBChannelDeviceMapping(
|
||||
id=str(uuid.uuid4()),
|
||||
channel_id=channel.id,
|
||||
device_id=device_id,
|
||||
)
|
||||
db.add(mapping)
|
||||
|
||||
db.commit()
|
||||
db.refresh(channel)
|
||||
|
||||
# Get updated device IDs
|
||||
device_mappings = db.query(DBChannelDeviceMapping).filter_by(
|
||||
channel_id=channel.id
|
||||
).all()
|
||||
device_ids = [m.device_id for m in device_mappings]
|
||||
|
||||
return AudioChannelResponse(
|
||||
id=channel.id,
|
||||
name=channel.name,
|
||||
is_default=channel.is_default,
|
||||
device_ids=device_ids,
|
||||
created_at=channel.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def delete_channel(channel_id: str, db: Session) -> bool:
|
||||
"""Delete an audio channel."""
|
||||
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
|
||||
if not channel:
|
||||
return False
|
||||
|
||||
if channel.is_default:
|
||||
raise ValueError("Cannot delete the default channel")
|
||||
|
||||
# Delete device mappings
|
||||
db.query(DBChannelDeviceMapping).filter_by(channel_id=channel_id).delete()
|
||||
|
||||
# Delete profile-channel mappings
|
||||
db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
|
||||
|
||||
# Delete channel
|
||||
db.delete(channel)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def get_channel_voices(channel_id: str, db: Session) -> List[str]:
|
||||
"""Get list of profile IDs assigned to a channel."""
|
||||
mappings = db.query(DBProfileChannelMapping).filter_by(
|
||||
channel_id=channel_id
|
||||
).all()
|
||||
return [m.profile_id for m in mappings]
|
||||
|
||||
|
||||
async def set_channel_voices(
|
||||
channel_id: str,
|
||||
data: ChannelVoiceAssignment,
|
||||
db: Session,
|
||||
) -> None:
|
||||
"""Set which voices are assigned to a channel."""
|
||||
# Verify channel exists
|
||||
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
|
||||
if not channel:
|
||||
raise ValueError(f"Channel {channel_id} not found")
|
||||
|
||||
# Verify all profiles exist
|
||||
for profile_id in data.profile_ids:
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
|
||||
if not profile:
|
||||
raise ValueError(f"Profile {profile_id} not found")
|
||||
|
||||
# Delete existing mappings for this channel
|
||||
db.query(DBProfileChannelMapping).filter_by(channel_id=channel_id).delete()
|
||||
|
||||
# Add new mappings
|
||||
for profile_id in data.profile_ids:
|
||||
mapping = DBProfileChannelMapping(
|
||||
profile_id=profile_id,
|
||||
channel_id=channel_id,
|
||||
)
|
||||
db.add(mapping)
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
async def get_profile_channels(profile_id: str, db: Session) -> List[str]:
|
||||
"""Get list of channel IDs assigned to a profile."""
|
||||
mappings = db.query(DBProfileChannelMapping).filter_by(
|
||||
profile_id=profile_id
|
||||
).all()
|
||||
return [m.channel_id for m in mappings]
|
||||
|
||||
|
||||
async def set_profile_channels(
|
||||
profile_id: str,
|
||||
data: ProfileChannelAssignment,
|
||||
db: Session,
|
||||
) -> None:
|
||||
"""Set which channels a profile is assigned to."""
|
||||
# Verify profile exists
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
|
||||
if not profile:
|
||||
raise ValueError(f"Profile {profile_id} not found")
|
||||
|
||||
# Verify all channels exist
|
||||
for channel_id in data.channel_ids:
|
||||
channel = db.query(DBAudioChannel).filter_by(id=channel_id).first()
|
||||
if not channel:
|
||||
raise ValueError(f"Channel {channel_id} not found")
|
||||
|
||||
# Delete existing mappings for this profile
|
||||
db.query(DBProfileChannelMapping).filter_by(profile_id=profile_id).delete()
|
||||
|
||||
# Add new mappings
|
||||
for channel_id in data.channel_ids:
|
||||
mapping = DBProfileChannelMapping(
|
||||
profile_id=profile_id,
|
||||
channel_id=channel_id,
|
||||
)
|
||||
db.add(mapping)
|
||||
|
||||
db.commit()
|
||||
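All of these helpers raise `ValueError` for business-rule violations (duplicate names, edits to the default channel) and leave HTTP mapping to the route layer. A hypothetical route wiring; the handler name, path, status code, and import paths are illustrative:

```python
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

from backend.database import get_db          # assumed import paths
from backend.models import AudioChannelCreate
from backend.services import channels

router = APIRouter()


@router.post("/channels")
async def create_channel_route(data: AudioChannelCreate, db: Session = Depends(get_db)):
    try:
        return await channels.create_channel(data, db)
    except ValueError as e:
        # Duplicate channel names surface as 409 Conflict at the HTTP boundary
        raise HTTPException(status_code=409, detail=str(e))
```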
422
backend/services/cuda.py
Normal file
@@ -0,0 +1,422 @@
"""
CUDA backend download, assembly, and verification.

Downloads two archives from GitHub Releases:
1. Server core (voicebox-server-cuda.tar.gz) — the exe + non-NVIDIA deps,
   versioned with the app.
2. CUDA libs (cuda-libs-{version}.tar.gz) — NVIDIA runtime libraries,
   versioned independently (only redownloaded on CUDA toolkit bump).

Both archives are extracted into {data_dir}/backends/cuda/ which forms the
complete PyInstaller --onedir directory structure that torch expects.
"""

import asyncio
import hashlib
import json
import logging
import os
import sys
import tarfile
from pathlib import Path
from typing import Optional

from ..config import get_data_dir
from ..utils.progress import get_progress_manager
from .. import __version__

logger = logging.getLogger(__name__)

GITHUB_RELEASES_URL = "https://github.com/jamiepine/voicebox/releases/download"

PROGRESS_KEY = "cuda-backend"

# The current expected CUDA libs version. Bump this when we change the
# CUDA toolkit version or torch's CUDA dependency changes (e.g. cu126 -> cu128).
CUDA_LIBS_VERSION = "cu128-v1"

# Prevents concurrent download_cuda_binary() calls from racing on the same
# temp file. The auto-update background task and the manual HTTP endpoint
# can both invoke download_cuda_binary(); without this lock the progress-
# manager status check is a TOCTOU race.
_download_lock = asyncio.Lock()


def get_backends_dir() -> Path:
    """Directory where downloaded backend binaries are stored."""
    d = get_data_dir() / "backends"
    d.mkdir(parents=True, exist_ok=True)
    return d


def get_cuda_dir() -> Path:
    """Directory where the CUDA backend (onedir) is extracted."""
    d = get_backends_dir() / "cuda"
    d.mkdir(parents=True, exist_ok=True)
    return d


def get_cuda_exe_name() -> str:
    """Platform-specific CUDA executable filename."""
    if sys.platform == "win32":
        return "voicebox-server-cuda.exe"
    return "voicebox-server-cuda"


def get_cuda_binary_path() -> Optional[Path]:
    """Return path to the CUDA executable if it exists inside the onedir."""
    p = get_cuda_dir() / get_cuda_exe_name()
    if p.exists():
        return p
    return None


def get_cuda_libs_manifest_path() -> Path:
    """Path to the cuda-libs.json manifest inside the CUDA dir."""
    return get_cuda_dir() / "cuda-libs.json"


def get_installed_cuda_libs_version() -> Optional[str]:
    """Read the installed CUDA libs version from cuda-libs.json, or None."""
    manifest_path = get_cuda_libs_manifest_path()
    if not manifest_path.exists():
        return None
    try:
        data = json.loads(manifest_path.read_text())
        return data.get("version")
    except Exception as e:
        logger.warning(f"Could not read cuda-libs.json: {e}")
        return None


def is_cuda_active() -> bool:
    """Check if the current process is the CUDA binary.

    The CUDA binary sets this env var on startup (see server.py).
    """
    return os.environ.get("VOICEBOX_BACKEND_VARIANT") == "cuda"


def get_cuda_status() -> dict:
    """Get current CUDA backend status for the API."""
    progress_manager = get_progress_manager()
    cuda_path = get_cuda_binary_path()
    progress = progress_manager.get_progress(PROGRESS_KEY)
    cuda_libs_version = get_installed_cuda_libs_version()

    return {
        "available": cuda_path is not None,
        "active": is_cuda_active(),
        "binary_path": str(cuda_path) if cuda_path else None,
        "cuda_libs_version": cuda_libs_version,
        "downloading": progress is not None and progress.get("status") == "downloading",
        "download_progress": progress,
    }


def _needs_server_download(version: Optional[str] = None) -> bool:
    """Check if the server core archive needs to be (re)downloaded."""
    cuda_path = get_cuda_binary_path()
    if not cuda_path:
        return True
    # Check if the binary version matches the expected app version
    installed = get_cuda_binary_version()
    expected = version or __version__
    if expected.startswith("v"):
        expected = expected[1:]
    return installed != expected


def _needs_cuda_libs_download() -> bool:
    """Check if the CUDA libs archive needs to be (re)downloaded."""
    installed = get_installed_cuda_libs_version()
    if installed is None:
        return True
    return installed != CUDA_LIBS_VERSION


async def _download_and_extract_archive(
    client,
    url: str,
    sha256_url: Optional[str],
    dest_dir: Path,
    label: str,
    progress_offset: int,
    total_size: int,
):
    """Download a .tar.gz archive and extract it into dest_dir.

    Args:
        client: httpx.AsyncClient
        url: URL of the .tar.gz archive
        sha256_url: URL of the .sha256 checksum file (optional)
        dest_dir: Directory to extract into
        label: Human-readable label for progress updates
        progress_offset: Byte offset for progress reporting (when downloading
            multiple archives sequentially)
        total_size: Total bytes across all downloads (for progress bar)
    """
    progress = get_progress_manager()
    temp_path = dest_dir / f".download-{label.replace(' ', '-')}.tmp"

    # Clean up leftover partial download
    if temp_path.exists():
        temp_path.unlink()

    # Fetch expected checksum (fail-fast: never extract an unverified archive)
    expected_sha = None
    if sha256_url:
        try:
            sha_resp = await client.get(sha256_url)
            sha_resp.raise_for_status()
            expected_sha = sha_resp.text.strip().split()[0]
            logger.info(f"{label}: expected SHA-256: {expected_sha[:16]}...")
        except Exception as e:
            raise RuntimeError(f"{label}: failed to fetch checksum from {sha256_url}") from e

    # Stream download, verify, and extract — always clean up temp file
    downloaded = 0
    try:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            with open(temp_path, "wb") as f:
                async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                    f.write(chunk)
                    downloaded += len(chunk)
                    progress.update_progress(
                        PROGRESS_KEY,
                        current=progress_offset + downloaded,
                        total=total_size,
                        filename=f"Downloading {label}",
                        status="downloading",
                    )

        # Verify integrity
        if expected_sha:
            progress.update_progress(
                PROGRESS_KEY,
                current=progress_offset + downloaded,
                total=total_size,
                filename=f"Verifying {label}...",
                status="downloading",
            )
            sha256 = hashlib.sha256()
            with open(temp_path, "rb") as f:
                while True:
                    data = f.read(1024 * 1024)
                    if not data:
                        break
                    sha256.update(data)
            actual = sha256.hexdigest()
            if actual != expected_sha:
                raise ValueError(
                    f"{label} integrity check failed: expected {expected_sha[:16]}..., got {actual[:16]}..."
                )
            logger.info(f"{label}: integrity verified")

        # Extract (use data filter for path traversal protection on Python 3.12+)
        progress.update_progress(
            PROGRESS_KEY,
            current=progress_offset + downloaded,
            total=total_size,
            filename=f"Extracting {label}...",
            status="downloading",
        )
        with tarfile.open(temp_path, "r:gz") as tar:
            if sys.version_info >= (3, 12):
                tar.extractall(path=dest_dir, filter="data")
            else:
                tar.extractall(path=dest_dir)

        logger.info(f"{label}: extracted to {dest_dir}")
    finally:
        if temp_path.exists():
            temp_path.unlink()
    return downloaded


async def download_cuda_binary(version: Optional[str] = None):
    """Download the CUDA backend (server core + CUDA libs if needed).

    Downloads both archives from GitHub Releases, extracts them into
    {data_dir}/backends/cuda/, and writes the cuda-libs.json manifest.

    Only downloads what's needed:
    - Server core: always redownloaded (versioned with app)
    - CUDA libs: only if missing or version mismatch

    Args:
        version: Version tag (e.g. "v0.3.0"). Defaults to current app version.
    """
    if _download_lock.locked():
        logger.info("CUDA download already in progress, skipping duplicate request")
        return
    async with _download_lock:
        await _download_cuda_binary_locked(version)


async def _download_cuda_binary_locked(version: Optional[str] = None):
    """Inner implementation of download_cuda_binary, called under _download_lock."""
    import httpx

    if version is None:
        version = f"v{__version__}"

    progress = get_progress_manager()
    cuda_dir = get_cuda_dir()

    need_server = _needs_server_download(version)
    need_libs = _needs_cuda_libs_download()

    if not need_server and not need_libs:
        logger.info("CUDA backend is up to date, nothing to download")
        return

    logger.info(
        f"Starting CUDA backend download for {version} "
        f"(server={'yes' if need_server else 'cached'}, "
        f"libs={'yes' if need_libs else 'cached'})"
    )
    progress.update_progress(
        PROGRESS_KEY,
        current=0,
        total=0,
        filename="Preparing download...",
        status="downloading",
    )

    base_url = f"{GITHUB_RELEASES_URL}/{version}"
    server_archive = "voicebox-server-cuda.tar.gz"
    libs_archive = f"cuda-libs-{CUDA_LIBS_VERSION}.tar.gz"

    try:
        async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client:
            # Estimate total download size
            total_size = 0
            if need_server:
                try:
                    head = await client.head(f"{base_url}/{server_archive}")
                    total_size += int(head.headers.get("content-length", 0))
                except Exception:
                    pass
            if need_libs:
                try:
                    head = await client.head(f"{base_url}/{libs_archive}")
                    total_size += int(head.headers.get("content-length", 0))
                except Exception:
                    pass

            logger.info(f"Total download size: {total_size / 1024 / 1024:.1f} MB")

            offset = 0

            # Download server core
            if need_server:
                server_downloaded = await _download_and_extract_archive(
                    client,
                    url=f"{base_url}/{server_archive}",
                    sha256_url=f"{base_url}/{server_archive}.sha256",
                    dest_dir=cuda_dir,
                    label="CUDA server",
                    progress_offset=offset,
                    total_size=total_size,
                )
                offset += server_downloaded

                # Make executable on Unix
                exe_path = cuda_dir / get_cuda_exe_name()
                if sys.platform != "win32" and exe_path.exists():
                    exe_path.chmod(0o755)

            # Download CUDA libs
            if need_libs:
                await _download_and_extract_archive(
                    client,
                    url=f"{base_url}/{libs_archive}",
                    sha256_url=f"{base_url}/{libs_archive}.sha256",
                    dest_dir=cuda_dir,
                    label="CUDA libraries",
                    progress_offset=offset,
                    total_size=total_size,
                )

                # Write local cuda-libs.json manifest
                manifest = {"version": CUDA_LIBS_VERSION}
                get_cuda_libs_manifest_path().write_text(json.dumps(manifest, indent=2) + "\n")

        logger.info(f"CUDA backend ready at {cuda_dir}")
        progress.mark_complete(PROGRESS_KEY)

    except Exception as e:
        logger.error(f"CUDA backend download failed: {e}")
        progress.mark_error(PROGRESS_KEY, str(e))
        raise


def get_cuda_binary_version() -> Optional[str]:
    """Get the version of the installed CUDA binary, or None if not installed."""
    import subprocess

    cuda_path = get_cuda_binary_path()
    if not cuda_path:
        return None
    try:
        result = subprocess.run(
            [str(cuda_path), "--version"],
            capture_output=True,
            text=True,
            timeout=30,
            cwd=str(cuda_path.parent),  # Run from the onedir directory
        )
        # Output format: "voicebox-server 0.3.0"
        for line in result.stdout.strip().splitlines():
            if "voicebox-server" in line:
                return line.split()[-1]
    except Exception as e:
        logger.warning(f"Could not get CUDA binary version: {e}")
    return None


async def check_and_update_cuda_binary():
    """Check if the CUDA binary is outdated and auto-download if so.

    Called on server startup. Checks both server version and CUDA libs
    version. Downloads only what's needed.
    """
    cuda_path = get_cuda_binary_path()
    if not cuda_path:
        return  # No CUDA binary installed, nothing to update

    need_server = _needs_server_download()
    need_libs = _needs_cuda_libs_download()

    if not need_server and not need_libs:
        logger.info(f"CUDA binary is up to date (server=v{__version__}, libs={get_installed_cuda_libs_version()})")
        return

    reasons = []
    if need_server:
        cuda_version = get_cuda_binary_version()
        reasons.append(f"server v{cuda_version} != v{__version__}")
    if need_libs:
        installed_libs = get_installed_cuda_libs_version()
        reasons.append(f"libs {installed_libs} != {CUDA_LIBS_VERSION}")

    logger.info(f"CUDA backend needs update ({', '.join(reasons)}). Auto-downloading...")

    try:
        await download_cuda_binary()
    except Exception as e:
        logger.error(f"Auto-update of CUDA binary failed: {e}")


async def delete_cuda_binary() -> bool:
    """Delete the downloaded CUDA backend directory. Returns True if deleted."""
    import shutil

    cuda_dir = get_cuda_dir()
    if cuda_dir.exists() and any(cuda_dir.iterdir()):
        shutil.rmtree(cuda_dir)
        logger.info(f"Deleted CUDA backend directory: {cuda_dir}")
        return True
    return False
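After a successful download, the on-disk layout looks roughly like the sketch below (the exact onedir contents depend on the PyInstaller build). The manifest round-trip is what lets `_needs_cuda_libs_download()` skip the multi-gigabyte libs archive on ordinary app updates:

```python
# {data_dir}/backends/cuda/            <- get_cuda_dir()
#     voicebox-server-cuda(.exe)       <- server-core archive; chmod 0o755 on Unix
#     cuda-libs.json                   <- {"version": "cu128-v1"}
#     ...                              <- remaining onedir payload (torch, NVIDIA libs)
#
# Minimal manifest round-trip, mirroring the module above (path is illustrative):
import json
from pathlib import Path

manifest_path = Path("backends/cuda/cuda-libs.json")
manifest_path.parent.mkdir(parents=True, exist_ok=True)
manifest_path.write_text(json.dumps({"version": "cu128-v1"}, indent=2) + "\n")
assert json.loads(manifest_path.read_text())["version"] == "cu128-v1"
```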
120
backend/services/effects.py
Normal file
@@ -0,0 +1,120 @@
"""
Effect presets CRUD operations.
"""

from __future__ import annotations

import json
import uuid
from typing import List, Optional

from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError

from ..utils.effects import validate_effects_chain

from ..database import EffectPreset as DBEffectPreset
from ..models import EffectPresetResponse, EffectPresetCreate, EffectPresetUpdate, EffectConfig


def _preset_response(p: DBEffectPreset) -> EffectPresetResponse:
    """Convert a DB preset row to a Pydantic response."""
    effects_chain = [EffectConfig(**e) for e in json.loads(p.effects_chain)]
    return EffectPresetResponse(
        id=p.id,
        name=p.name,
        description=p.description,
        effects_chain=effects_chain,
        is_builtin=p.is_builtin or False,
        created_at=p.created_at,
    )


def list_presets(db: Session) -> List[EffectPresetResponse]:
    """List all effect presets (built-in + user-created)."""
    presets = db.query(DBEffectPreset).order_by(DBEffectPreset.sort_order, DBEffectPreset.name).all()
    return [_preset_response(p) for p in presets]


def get_preset(preset_id: str, db: Session) -> Optional[EffectPresetResponse]:
    """Get a preset by ID."""
    p = db.query(DBEffectPreset).filter_by(id=preset_id).first()
    if not p:
        return None
    return _preset_response(p)


def get_preset_by_name(name: str, db: Session) -> Optional[EffectPresetResponse]:
    """Get a preset by name."""
    p = db.query(DBEffectPreset).filter_by(name=name).first()
    if not p:
        return None
    return _preset_response(p)


def create_preset(data: EffectPresetCreate, db: Session) -> EffectPresetResponse:
    """Create a new user effect preset."""
    chain_dicts = [e.model_dump() for e in data.effects_chain]
    error = validate_effects_chain(chain_dicts)
    if error:
        raise ValueError(error)

    # Check for duplicate name before insert
    existing = db.query(DBEffectPreset).filter_by(name=data.name).first()
    if existing:
        raise ValueError(f"A preset named '{data.name}' already exists")

    preset = DBEffectPreset(
        id=str(uuid.uuid4()),
        name=data.name,
        description=data.description,
        effects_chain=json.dumps(chain_dicts),
        is_builtin=False,
    )
    db.add(preset)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise ValueError(f"A preset named '{data.name}' already exists")
    db.refresh(preset)
    return _preset_response(preset)


def update_preset(preset_id: str, data: EffectPresetUpdate, db: Session) -> Optional[EffectPresetResponse]:
    """Update a user effect preset. Cannot modify built-in presets."""
    preset = db.query(DBEffectPreset).filter_by(id=preset_id).first()
    if not preset:
        return None
    if preset.is_builtin:
        raise ValueError("Cannot modify built-in presets")

    if data.name is not None:
        preset.name = data.name
    if data.description is not None:
        preset.description = data.description
    if data.effects_chain is not None:
        chain_dicts = [e.model_dump() for e in data.effects_chain]
        error = validate_effects_chain(chain_dicts)
        if error:
            raise ValueError(error)
        preset.effects_chain = json.dumps(chain_dicts)

    db.commit()
    db.refresh(preset)
    return _preset_response(preset)


def delete_preset(preset_id: str, db: Session) -> bool:
    """Delete a user effect preset. Cannot delete built-in presets."""
    preset = db.query(DBEffectPreset).filter_by(id=preset_id).first()
    if not preset:
        return False
    if preset.is_builtin:
        raise ValueError("Cannot delete built-in presets")

    db.delete(preset)
    db.commit()
    return True
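Creating a preset from application code means building the Pydantic payload and passing a session. A hedged sketch; the field names on `EffectConfig` are placeholders, since the real schema lives in `models.py`:

```python
# Hypothetical payload; EffectConfig's exact fields are defined in models.py.
payload = EffectPresetCreate(
    name="Radio",
    description="Band-passed, lightly compressed",
    effects_chain=[
        EffectConfig(type="highpass", enabled=True, params={"cutoff_hz": 300}),
        EffectConfig(type="lowpass", enabled=True, params={"cutoff_hz": 3400}),
    ],
)

db = next(get_db())  # get_db() as re-exported from backend.database
try:
    preset = create_preset(payload, db)  # raises ValueError on duplicate name
finally:
    db.close()
```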
461
backend/services/export_import.py
Normal file
@@ -0,0 +1,461 @@
"""
Voice profile export/import module.

Handles exporting profiles to ZIP archives and importing them back.
Also handles exporting individual generations.
"""

import json
import zipfile
import io
from pathlib import Path
from typing import Optional
from sqlalchemy.orm import Session

from ..models import VoiceProfileResponse
from ..database import VoiceProfile as DBVoiceProfile, ProfileSample as DBProfileSample, Generation as DBGeneration, GenerationVersion as DBGenerationVersion
from .profiles import create_profile, add_profile_sample
from ..models import VoiceProfileCreate
from .. import config


def _get_unique_profile_name(name: str, db: Session) -> str:
    """
    Get a unique profile name by appending a number if needed.

    Args:
        name: Original profile name
        db: Database session

    Returns:
        Unique profile name
    """
    base_name = name
    counter = 1

    while True:
        existing = db.query(DBVoiceProfile).filter_by(name=name).first()
        if not existing:
            return name

        name = f"{base_name} ({counter})"
        counter += 1


def export_profile_to_zip(profile_id: str, db: Session) -> bytes:
    """
    Export a voice profile to a ZIP archive.

    Args:
        profile_id: Profile ID to export
        db: Database session

    Returns:
        ZIP file contents as bytes

    Raises:
        ValueError: If profile not found or has no samples
    """
    # Get profile
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile {profile_id} not found")

    # Get all samples
    samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
    if not samples:
        raise ValueError(f"Profile {profile_id} has no samples")

    # Create ZIP in memory
    zip_buffer = io.BytesIO()

    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        # Check if profile has avatar
        has_avatar = False
        if profile.avatar_path:
            avatar_path = config.resolve_storage_path(profile.avatar_path)
            if avatar_path is not None and avatar_path.exists():
                has_avatar = True
                # Add avatar to ZIP root with original extension
                avatar_ext = avatar_path.suffix
                zip_file.write(avatar_path, f"avatar{avatar_ext}")

        # Create manifest.json
        manifest = {
            "version": "1.0",
            "profile": {
                "name": profile.name,
                "description": profile.description,
                "language": profile.language,
            },
            "has_avatar": has_avatar,
        }
        zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))

        # Create samples.json mapping
        samples_data = {}
        profile_dir = config.get_profiles_dir() / profile_id

        for sample in samples:
            # Get filename from audio_path (should be {sample_id}.wav)
            audio_path = config.resolve_storage_path(sample.audio_path)
            if audio_path is None:
                raise ValueError(f"Audio file not found: {sample.audio_path}")
            filename = audio_path.name

            # Read audio file
            if not audio_path.exists():
                raise ValueError(f"Audio file not found: {audio_path}")

            # Add to samples directory in ZIP
            zip_path = f"samples/{filename}"
            zip_file.write(audio_path, zip_path)

            # Map filename to reference text
            samples_data[filename] = sample.reference_text

        zip_file.writestr("samples.json", json.dumps(samples_data, indent=2))

    zip_buffer.seek(0)
    return zip_buffer.read()


async def import_profile_from_zip(file_bytes: bytes, db: Session) -> VoiceProfileResponse:
    """
    Import a voice profile from a ZIP archive.

    Args:
        file_bytes: ZIP file contents
        db: Database session

    Returns:
        Created profile

    Raises:
        ValueError: If ZIP is invalid or missing required files
    """
    zip_buffer = io.BytesIO(file_bytes)

    try:
        with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
            # Validate ZIP structure
            namelist = zip_file.namelist()

            if "manifest.json" not in namelist:
                raise ValueError("ZIP archive missing manifest.json")

            if "samples.json" not in namelist:
                raise ValueError("ZIP archive missing samples.json")

            # Read manifest
            manifest_data = json.loads(zip_file.read("manifest.json"))

            if "version" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing version")

            if "profile" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing profile")

            profile_data = manifest_data["profile"]

            # Read samples mapping
            samples_data = json.loads(zip_file.read("samples.json"))

            if not isinstance(samples_data, dict):
                raise ValueError("Invalid samples.json: must be a dictionary")

            # Get unique profile name
            original_name = profile_data.get("name", "Imported Profile")
            unique_name = _get_unique_profile_name(original_name, db)

            # Create profile
            profile_create = VoiceProfileCreate(
                name=unique_name,
                description=profile_data.get("description"),
                language=profile_data.get("language", "en"),
            )

            profile = await create_profile(profile_create, db)

            # Extract and add samples
            profile_dir = config.get_profiles_dir() / profile.id
            profile_dir.mkdir(parents=True, exist_ok=True)

            # Handle avatar if present
            avatar_files = [f for f in namelist if f.startswith("avatar.")]
            if avatar_files:
                try:
                    avatar_file = avatar_files[0]
                    # Extract to temporary file
                    import tempfile
                    with tempfile.NamedTemporaryFile(suffix=Path(avatar_file).suffix, delete=False) as tmp:
                        tmp.write(zip_file.read(avatar_file))
                        tmp_path = tmp.name

                    try:
                        from .profiles import upload_avatar
                        await upload_avatar(profile.id, tmp_path, db)
                    finally:
                        Path(tmp_path).unlink(missing_ok=True)
                except Exception:
                    # Avatar import is optional - continue even if it fails
                    pass

            for filename, reference_text in samples_data.items():
                # Validate filename
                if not filename.endswith('.wav'):
                    raise ValueError(f"Invalid sample filename: {filename} (must be .wav)")

                # Extract audio file to temp location
                zip_path = f"samples/{filename}"

                if zip_path not in namelist:
                    raise ValueError(f"Sample file not found in ZIP: {zip_path}")

                # Extract to temporary file
                import tempfile
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    tmp.write(zip_file.read(zip_path))
                    tmp_path = tmp.name

                try:
                    # Add sample to profile
                    await add_profile_sample(
                        profile.id,
                        tmp_path,
                        reference_text,
                        db,
                    )
                finally:
                    # Clean up temp file
                    Path(tmp_path).unlink(missing_ok=True)

            return profile

    except zipfile.BadZipFile:
        raise ValueError("Invalid ZIP file")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in archive: {e}")
    except Exception as e:
        if isinstance(e, ValueError):
            raise
        raise ValueError(f"Error importing profile: {str(e)}")


def export_generation_to_zip(generation_id: str, db: Session) -> bytes:
    """
    Export a generation to a ZIP archive.

    Args:
        generation_id: Generation ID to export
        db: Database session

    Returns:
        ZIP file contents as bytes

    Raises:
        ValueError: If generation not found
    """
    # Get generation
    generation = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not generation:
        raise ValueError(f"Generation {generation_id} not found")

    # Get profile info
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
    if not profile:
        raise ValueError(f"Profile {generation.profile_id} not found")

    # Get all versions for this generation
    versions = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .order_by(DBGenerationVersion.created_at)
        .all()
    )

    # Create ZIP in memory
    zip_buffer = io.BytesIO()

    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        # Build version manifest entries
        version_entries = []
        for v in versions:
            v_path = config.resolve_storage_path(v.audio_path)
            effects_chain = None
            if v.effects_chain:
                effects_chain = json.loads(v.effects_chain)
            version_entries.append({
                "id": v.id,
                "label": v.label,
                "is_default": v.is_default,
                "effects_chain": effects_chain,
                "filename": v_path.name if v_path is not None else None,
            })

        manifest = {
            "version": "1.0",
            "generation": {
                "id": generation.id,
                "text": generation.text,
                "language": generation.language,
                "duration": generation.duration,
                "seed": generation.seed,
                "instruct": generation.instruct,
                "created_at": generation.created_at.isoformat(),
            },
            "profile": {
                "id": profile.id,
                "name": profile.name,
                "description": profile.description,
                "language": profile.language,
            },
            "versions": version_entries,
        }
        zip_file.writestr("manifest.json", json.dumps(manifest, indent=2))

        # Add all version audio files
        for v in versions:
            v_path = config.resolve_storage_path(v.audio_path)
            if v_path is not None and v_path.exists():
                zip_file.write(v_path, f"audio/{v_path.name}")

        # Fallback: if no versions exist, include the generation's main audio
        if not versions:
            audio_path = config.resolve_storage_path(generation.audio_path)
            if audio_path is not None and audio_path.exists():
                zip_file.write(audio_path, f"audio/{audio_path.name}")

    zip_buffer.seek(0)
    return zip_buffer.read()


async def import_generation_from_zip(file_bytes: bytes, db: Session) -> dict:
    """
    Import a generation from a ZIP archive.

    Args:
        file_bytes: ZIP file contents
        db: Database session

    Returns:
        Dictionary with generation ID and profile info

    Raises:
        ValueError: If ZIP is invalid or missing required files
    """
    import tempfile
    import shutil
    import uuid
    from datetime import datetime

    zip_buffer = io.BytesIO(file_bytes)

    try:
        with zipfile.ZipFile(zip_buffer, 'r') as zip_file:
            # Validate ZIP structure
            namelist = zip_file.namelist()

            if "manifest.json" not in namelist:
                raise ValueError("ZIP archive missing manifest.json")

            # Read manifest
            manifest_data = json.loads(zip_file.read("manifest.json"))

            if "version" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing version")

            if "generation" not in manifest_data:
                raise ValueError("Invalid manifest.json: missing generation data")

            generation_data = manifest_data["generation"]
            profile_data = manifest_data.get("profile", {})

            # Validate required fields
            required_fields = ["text", "language", "duration"]
            for field in required_fields:
                if field not in generation_data:
                    raise ValueError(f"Invalid manifest.json: missing generation.{field}")

            # Find audio file in archive
            audio_files = [f for f in namelist if f.startswith("audio/") and f.endswith(".wav")]
            if not audio_files:
                raise ValueError("No audio file found in ZIP archive")

            audio_file_path = audio_files[0]

            # Check if we should match an existing profile or create metadata
            profile_id = None
            profile_name = profile_data.get("name", "Unknown Profile")

            # Try to find matching profile by name
            if profile_name and profile_name != "Unknown Profile":
                existing_profile = db.query(DBVoiceProfile).filter_by(name=profile_name).first()
                if existing_profile:
                    profile_id = existing_profile.id

            # If no matching profile, use a placeholder or the first available profile
            if not profile_id:
                # Get any profile, or None if no profiles exist
                any_profile = db.query(DBVoiceProfile).first()
                if any_profile:
                    profile_id = any_profile.id
                    profile_name = any_profile.name
                else:
                    raise ValueError("No voice profiles found. Please create a profile before importing generations.")

            # Extract audio file to temporary location
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(zip_file.read(audio_file_path))
                tmp_path = tmp.name

            try:
                # Create generations directory
                generations_dir = config.get_generations_dir()
                generations_dir.mkdir(parents=True, exist_ok=True)

                # Generate new ID for this generation
                new_generation_id = str(uuid.uuid4())

                # Copy audio to generations directory
                audio_dest = generations_dir / f"{new_generation_id}.wav"
                shutil.copy(tmp_path, audio_dest)

                # Create generation record
                db_generation = DBGeneration(
                    id=new_generation_id,
                    profile_id=profile_id,
                    text=generation_data["text"],
                    language=generation_data["language"],
                    audio_path=config.to_storage_path(audio_dest),
                    duration=generation_data["duration"],
                    seed=generation_data.get("seed"),
                    instruct=generation_data.get("instruct"),
                    created_at=datetime.utcnow(),
                )

                db.add(db_generation)
                db.commit()
                db.refresh(db_generation)

                return {
                    "id": db_generation.id,
                    "profile_id": profile_id,
                    "profile_name": profile_name,
                    "text": db_generation.text,
                    "message": f"Generation imported successfully (assigned to profile: {profile_name})"
                }

            finally:
                # Clean up temp file
                Path(tmp_path).unlink(missing_ok=True)

    except zipfile.BadZipFile:
        raise ValueError("Invalid ZIP file")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in archive: {e}")
    except Exception as e:
        if isinstance(e, ValueError):
            raise
        raise ValueError(f"Error importing generation: {str(e)}")
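Because `export_profile_to_zip()` returns raw bytes, a route handler can pass them straight to the client. A hedged sketch; the route path, filename scheme, and import paths are illustrative:

```python
from fastapi import APIRouter, Depends
from fastapi.responses import Response
from sqlalchemy.orm import Session

from backend.database import get_db            # assumed import paths
from backend.services import export_import

router = APIRouter()


@router.get("/profiles/{profile_id}/export")
async def export_profile(profile_id: str, db: Session = Depends(get_db)):
    data = export_import.export_profile_to_zip(profile_id, db)
    return Response(
        content=data,
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{profile_id}.zip"'},
    )
```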
263
backend/services/generation.py
Normal file
@@ -0,0 +1,263 @@
"""
Unified TTS generation orchestration.

Replaces the three near-identical closures (_run_generation, _run_retry,
_run_regenerate) that lived in main.py with a single ``run_generation()``
function parameterized by *mode*.

Mode differences:
- "generate"   : full pipeline -- save clean version, optionally apply
                 effects and create a processed version.
- "retry"      : re-runs a failed generation with the same seed.
                 No effects, no version creation.
- "regenerate" : re-runs with seed=None for variation. Creates a new
                 version with an auto-incremented "take-N" label.
"""

from __future__ import annotations

import asyncio
import traceback
from typing import Literal, Optional

from .. import config
from . import history, profiles
from ..database import get_db
from ..utils.tasks import get_task_manager


async def run_generation(
    *,
    generation_id: str,
    profile_id: str,
    text: str,
    language: str,
    engine: str,
    model_size: str,
    seed: Optional[int],
    normalize: bool = False,
    effects_chain: Optional[list] = None,
    instruct: Optional[str] = None,
    mode: Literal["generate", "retry", "regenerate"],
    max_chunk_chars: Optional[int] = None,
    crossfade_ms: Optional[int] = None,
    version_id: Optional[str] = None,
) -> None:
    """Execute TTS inference and persist the result.

    This is the single entry point for all background generation work.
    It is designed to be enqueued via ``services.task_queue.enqueue_generation``.
    """
    from ..backends import load_engine_model, get_tts_backend_for_engine, engine_needs_trim
    from ..utils.chunked_tts import generate_chunked
    from ..utils.audio import normalize_audio, save_audio, trim_tts_output

    task_manager = get_task_manager()
    bg_db = next(get_db())

    try:
        tts_model = get_tts_backend_for_engine(engine)

        if not tts_model.is_loaded():
            await history.update_generation_status(generation_id, "loading_model", bg_db)

        await load_engine_model(engine, model_size)

        voice_prompt = await profiles.create_voice_prompt_for_profile(
            profile_id,
            bg_db,
            use_cache=True,
            engine=engine,
        )

        await history.update_generation_status(generation_id, "generating", bg_db)
        trim_fn = trim_tts_output if engine_needs_trim(engine) else None

        gen_kwargs: dict = dict(
            language=language,
            seed=seed if mode != "regenerate" else None,
            instruct=instruct,
            trim_fn=trim_fn,
        )
        if max_chunk_chars is not None:
            gen_kwargs["max_chunk_chars"] = max_chunk_chars
        if crossfade_ms is not None:
            gen_kwargs["crossfade_ms"] = crossfade_ms

        audio, sample_rate = await generate_chunked(tts_model, text, voice_prompt, **gen_kwargs)

        # --- Normalize (always for regenerate; otherwise only when requested)
        if normalize or mode == "regenerate":
            audio = normalize_audio(audio)

        duration = len(audio) / sample_rate

        # --- Persist audio and update status -----------------------------
        if mode == "generate":
            final_path = _save_generate(
                generation_id=generation_id,
                audio=audio,
                sample_rate=sample_rate,
                effects_chain=effects_chain,
                save_audio=save_audio,
                db=bg_db,
            )
        elif mode == "retry":
            final_path = _save_retry(
                generation_id=generation_id,
                audio=audio,
                sample_rate=sample_rate,
                save_audio=save_audio,
            )
        elif mode == "regenerate":
            final_path = _save_regenerate(
                generation_id=generation_id,
                version_id=version_id,
                audio=audio,
                sample_rate=sample_rate,
                save_audio=save_audio,
                db=bg_db,
            )

        await history.update_generation_status(
            generation_id=generation_id,
            status="completed",
            db=bg_db,
            audio_path=final_path,
            duration=duration,
        )

    except asyncio.CancelledError:
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=bg_db,
            error="Generation cancelled",
        )
    except Exception as e:
        traceback.print_exc()
        await history.update_generation_status(
            generation_id=generation_id,
            status="failed",
            db=bg_db,
            error=str(e),
        )
    finally:
        task_manager.complete_generation(generation_id)
        bg_db.close()


def _save_generate(
    *,
    generation_id: str,
    audio,
    sample_rate: int,
    effects_chain: Optional[list],
    save_audio,
    db,
) -> str:
    """Save clean version and optionally an effects-processed version.

    Returns the final audio path (processed if effects were applied,
    otherwise clean).
    """
    from . import versions as versions_mod

    clean_audio_path = config.get_generations_dir() / f"{generation_id}.wav"
    save_audio(audio, str(clean_audio_path), sample_rate)

    has_effects = effects_chain and any(e.get("enabled", True) for e in effects_chain)

    versions_mod.create_version(
        generation_id=generation_id,
        label="original",
        audio_path=config.to_storage_path(clean_audio_path),
        db=db,
        effects_chain=None,
        is_default=not has_effects,
    )

    final_audio_path = str(clean_audio_path)

    if has_effects:
        from ..utils.effects import apply_effects, validate_effects_chain

        assert effects_chain is not None

        error_msg = validate_effects_chain(effects_chain)
        if error_msg:
            import logging
            logging.getLogger(__name__).warning("invalid effects chain, skipping: %s", error_msg)
            versions_mod.set_default_version(
                versions_mod.list_versions(generation_id, db)[0].id, db
            )
        else:
            processed_audio = apply_effects(audio, sample_rate, effects_chain)
            processed_path = config.get_generations_dir() / f"{generation_id}_processed.wav"
            save_audio(processed_audio, str(processed_path), sample_rate)
            final_audio_path = str(processed_path)
            versions_mod.create_version(
                generation_id=generation_id,
                label="version-2",
                audio_path=config.to_storage_path(processed_path),
                db=db,
                effects_chain=effects_chain,
                is_default=True,
            )

    return config.to_storage_path(final_audio_path)


def _save_retry(
    *,
    generation_id: str,
    audio,
    sample_rate: int,
    save_audio,
) -> str:
    """Save retry output -- single file, no versions.

    Returns the audio path.
    """
    audio_path = config.get_generations_dir() / f"{generation_id}.wav"
    save_audio(audio, str(audio_path), sample_rate)
    return config.to_storage_path(audio_path)


def _save_regenerate(
    *,
    generation_id: str,
    version_id: Optional[str],
    audio,
    sample_rate: int,
    save_audio,
    db,
) -> str:
    """Save regeneration output as a new version with auto-label.

    Returns the audio path.
    """
    from . import versions as versions_mod

    import uuid as _uuid

    suffix = _uuid.uuid4().hex[:8]
    audio_path = config.get_generations_dir() / f"{generation_id}_{suffix}.wav"
    save_audio(audio, str(audio_path), sample_rate)

    # Count via DB query rather than list length to avoid TOCTOU race
    from ..database import GenerationVersion as DBGenerationVersion

    count = db.query(DBGenerationVersion).filter_by(generation_id=generation_id).count()
    label = f"take-{count + 1}"

    versions_mod.create_version(
        generation_id=generation_id,
        label=label,
        audio_path=config.to_storage_path(audio_path),
        db=db,
        effects_chain=None,
        is_default=True,
    )

    return config.to_storage_path(audio_path)
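The docstring names `services.task_queue.enqueue_generation` as the intended entry point. Below is a sketch of how a route might hand work to the queue; the enqueue signature is an assumption, and only the `run_generation` kwargs are taken from the function above:

```python
# Sketch only: the exact enqueue_generation API lives in services/task_queue.py.
import uuid

from backend.services.generation import run_generation
from backend.services.task_queue import enqueue_generation  # assumed API

gen_id = str(uuid.uuid4())
await enqueue_generation(
    generation_id=gen_id,
    coro=run_generation(
        generation_id=gen_id,
        profile_id="some-profile-id",  # placeholder value
        text="Hello there!",
        language="en",
        engine="qwen",
        model_size="1.7B",
        seed=None,
        normalize=True,
        mode="generate",
    ),
)
```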
368
backend/services/history.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Generation history management module.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_
|
||||
|
||||
from ..models import GenerationRequest, GenerationResponse, HistoryQuery, HistoryResponse, HistoryListResponse, GenerationVersionResponse, EffectConfig
|
||||
from ..database import Generation as DBGeneration, GenerationVersion as DBGenerationVersion, VoiceProfile as DBVoiceProfile
|
||||
from .. import config
|
||||
|
||||
|
||||
def _get_versions_for_generation(generation_id: str, db: Session) -> tuple:
|
||||
"""Get versions list and active version ID for a generation."""
|
||||
import json
|
||||
versions_rows = (
|
||||
db.query(DBGenerationVersion)
|
||||
.filter_by(generation_id=generation_id)
|
||||
.order_by(DBGenerationVersion.created_at)
|
||||
.all()
|
||||
)
|
||||
if not versions_rows:
|
||||
return None, None
|
||||
|
||||
versions = []
|
||||
active_version_id = None
|
||||
for v in versions_rows:
|
||||
effects_chain = None
|
||||
if v.effects_chain:
|
||||
try:
|
||||
raw = json.loads(v.effects_chain)
|
||||
effects_chain = [EffectConfig(**e) for e in raw]
|
||||
except Exception:
|
||||
pass
|
||||
versions.append(GenerationVersionResponse(
|
||||
id=v.id,
|
||||
generation_id=v.generation_id,
|
||||
label=v.label,
|
||||
audio_path=v.audio_path,
|
||||
effects_chain=effects_chain,
|
||||
is_default=v.is_default,
|
||||
created_at=v.created_at,
|
||||
))
|
||||
if v.is_default:
|
||||
active_version_id = v.id
|
||||
|
||||
return versions, active_version_id
|
||||
|
||||
|
||||
async def create_generation(
|
||||
profile_id: str,
|
||||
text: str,
|
||||
language: str,
|
||||
audio_path: str,
|
||||
duration: float,
|
||||
seed: Optional[int],
|
||||
db: Session,
|
||||
instruct: Optional[str] = None,
|
||||
generation_id: Optional[str] = None,
|
||||
status: str = "completed",
|
||||
engine: Optional[str] = "qwen",
|
||||
model_size: Optional[str] = None,
|
||||
) -> GenerationResponse:
|
||||
"""
|
||||
Create a new generation history entry.
|
||||
|
||||
Args:
|
||||
profile_id: Profile ID used for generation
|
||||
text: Generated text
|
||||
language: Language code
|
||||
audio_path: Path where audio was saved
|
||||
duration: Audio duration in seconds
|
||||
seed: Random seed used (if any)
|
||||
db: Database session
|
||||
instruct: Natural language instruction used (if any)
|
||||
generation_id: Pre-assigned ID (for async generation flow)
|
||||
status: Generation status (generating, completed, failed)
|
||||
engine: TTS engine used (qwen, luxtts, chatterbox, chatterbox_turbo)
|
||||
model_size: Model size variant (1.7B, 0.6B) — only relevant for qwen
|
||||
|
||||
Returns:
|
||||
Created generation entry
|
||||
"""
|
||||
db_generation = DBGeneration(
|
||||
id=generation_id or str(uuid.uuid4()),
|
||||
profile_id=profile_id,
|
||||
text=text,
|
||||
language=language,
|
||||
audio_path=audio_path,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
instruct=instruct,
|
||||
engine=engine,
|
||||
model_size=model_size,
|
||||
status=status,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(db_generation)
|
||||
db.commit()
|
||||
db.refresh(db_generation)
|
||||
|
||||
return GenerationResponse.model_validate(db_generation)
|
||||
|
||||
|
||||
async def update_generation_status(
|
||||
generation_id: str,
|
||||
status: str,
|
||||
db: Session,
|
||||
audio_path: Optional[str] = None,
|
||||
duration: Optional[float] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> Optional[GenerationResponse]:
|
||||
"""Update the status of a generation (used by async generation flow)."""
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
generation.status = status
|
||||
if audio_path is not None:
|
||||
generation.audio_path = audio_path
|
||||
if duration is not None:
|
||||
generation.duration = duration
|
||||
if error is not None:
|
||||
generation.error = error
|
||||
|
||||
db.commit()
|
||||
db.refresh(generation)
|
||||
return GenerationResponse.model_validate(generation)
|
||||
|
||||
|
||||
async def get_generation(
|
||||
generation_id: str,
|
||||
db: Session,
|
||||
) -> Optional[GenerationResponse]:
|
||||
"""
|
||||
Get a generation by ID.
|
||||
|
||||
Args:
|
||||
generation_id: Generation ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Generation or None if not found
|
||||
"""
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
return None
|
||||
|
||||
return GenerationResponse.model_validate(generation)
|
||||
|
||||
|
||||
async def list_generations(
|
||||
query: HistoryQuery,
|
||||
db: Session,
|
||||
) -> HistoryListResponse:
|
||||
"""
|
||||
List generations with optional filters.
|
||||
|
||||
Args:
|
||||
query: Query parameters (filters, pagination)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
HistoryListResponse with items and total count
|
||||
"""
|
||||
# Build base query with join to get profile name
|
||||
q = db.query(
|
||||
DBGeneration,
|
||||
DBVoiceProfile.name.label('profile_name')
|
||||
).join(
|
||||
DBVoiceProfile,
|
||||
DBGeneration.profile_id == DBVoiceProfile.id
|
||||
)
|
||||
|
||||
# Apply profile filter
|
||||
if query.profile_id:
|
||||
q = q.filter(DBGeneration.profile_id == query.profile_id)
|
||||
|
||||
# Apply search filter (searches in text content)
|
||||
if query.search:
|
||||
search_pattern = f"%{query.search}%"
|
||||
q = q.filter(DBGeneration.text.like(search_pattern))
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = q.count()
|
||||
|
||||
# Apply ordering (newest first)
|
||||
q = q.order_by(DBGeneration.created_at.desc())
|
||||
|
||||
# Apply pagination
|
||||
q = q.offset(query.offset).limit(query.limit)
|
||||
|
||||
# Execute query
|
||||
results = q.all()
|
||||
|
||||
# Convert to HistoryResponse with profile_name
|
||||
items = []
|
||||
for generation, profile_name in results:
|
||||
versions, active_version_id = _get_versions_for_generation(generation.id, db)
|
||||
items.append(HistoryResponse(
|
||||
id=generation.id,
|
||||
profile_id=generation.profile_id,
|
||||
profile_name=profile_name,
|
||||
text=generation.text,
|
||||
language=generation.language,
|
||||
audio_path=generation.audio_path,
|
||||
duration=generation.duration,
|
||||
seed=generation.seed,
|
||||
instruct=generation.instruct,
|
||||
engine=generation.engine or "qwen",
|
||||
model_size=generation.model_size,
|
||||
status=generation.status or "completed",
|
||||
error=generation.error,
|
||||
is_favorited=bool(generation.is_favorited),
|
||||
created_at=generation.created_at,
|
||||
versions=versions,
|
||||
active_version_id=active_version_id,
|
||||
))
|
||||
|
||||
return HistoryListResponse(
|
||||
items=items,
|
||||
total=total_count,
|
||||
)
|
||||


async def delete_generation(
    generation_id: str,
    db: Session,
) -> bool:
    """
    Delete a generation.

    Args:
        generation_id: Generation ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    generation = db.query(DBGeneration).filter_by(id=generation_id).first()
    if not generation:
        return False

    # Delete all version files and records
    from . import versions as versions_mod
    versions_mod.delete_versions_for_generation(generation_id, db)

    # Delete main audio file (if not already removed by version cleanup)
    if generation.audio_path:
        audio_path = config.resolve_storage_path(generation.audio_path)
        if audio_path is not None and audio_path.exists():
            audio_path.unlink()

    # Delete from database
    db.delete(generation)
    db.commit()

    return True


async def delete_failed_generations(db: Session) -> int:
    """
    Delete every generation whose status is 'failed'.

    Used by the "Clear failed" action in the UI so users can tidy up
    history after the model wasn't loaded, the app was closed mid-run,
    or a generation otherwise errored out (see issue #410).

    Returns:
        Number of generations deleted.
    """
    from . import versions as versions_mod

    failed = db.query(DBGeneration).filter(DBGeneration.status == "failed").all()
    count = 0
    for generation in failed:
        # Clean up version files/rows first.
        versions_mod.delete_versions_for_generation(generation.id, db)

        # Remove the main audio file if it somehow made it to disk.
        if generation.audio_path:
            audio_path = config.resolve_storage_path(generation.audio_path)
            if audio_path is not None and audio_path.exists():
                try:
                    audio_path.unlink()
                except OSError:
                    # Best-effort cleanup — don't abort the whole sweep
                    # if a single file can't be removed.
                    pass

        db.delete(generation)
        count += 1

    db.commit()
    return count


async def delete_generations_by_profile(
    profile_id: str,
    db: Session,
) -> int:
    """
    Delete all generations for a profile.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        Number of generations deleted
    """
    generations = db.query(DBGeneration).filter_by(profile_id=profile_id).all()

    from . import versions as versions_mod

    count = 0
    for generation in generations:
        # Delete associated version files and rows first
        versions_mod.delete_versions_for_generation(generation.id, db)

        # Delete audio file
        audio_path = config.resolve_storage_path(generation.audio_path)
        if audio_path is not None and audio_path.exists():
            audio_path.unlink()

        # Delete from database
        db.delete(generation)
        count += 1

    db.commit()

    return count


async def get_generation_stats(db: Session) -> dict:
    """
    Get generation statistics.

    Args:
        db: Database session

    Returns:
        Statistics dictionary
    """
    from sqlalchemy import func

    total = db.query(func.count(DBGeneration.id)).scalar()

    total_duration = db.query(func.sum(DBGeneration.duration)).scalar() or 0

    # Get generations by profile
    by_profile = db.query(
        DBGeneration.profile_id,
        func.count(DBGeneration.id).label('count')
    ).group_by(DBGeneration.profile_id).all()

    return {
        "total_generations": total,
        "total_duration_seconds": total_duration,
        "generations_by_profile": {
            profile_id: count for profile_id, count in by_profile
        },
    }
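
# Illustrative return value (keys match the dict constructed above; the numbers
# are made up):
#   {"total_generations": 42,
#    "total_duration_seconds": 317.5,
#    "generations_by_profile": {"<profile-id>": 12, ...}}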
686
backend/services/profiles.py
Normal file
@@ -0,0 +1,686 @@
"""Voice profile management module."""

import json as _json
import logging
import shutil
import uuid
from datetime import datetime

from sqlalchemy import func
from sqlalchemy.orm import Session

from .. import config
from ..database import Generation as DBGeneration, ProfileSample as DBProfileSample, VoiceProfile as DBVoiceProfile
from ..models import (
    EffectConfig,
    ProfileSampleResponse,
    VoiceProfileCreate,
    VoiceProfileResponse,
)
from ..utils.audio import save_audio, validate_and_load_reference_audio
from ..utils.cache import _get_cache_dir, clear_profile_cache
from ..utils.images import process_avatar, validate_image

logger = logging.getLogger(__name__)

CLONING_ENGINES = {"qwen", "luxtts", "chatterbox", "chatterbox_turbo", "tada"}


def _profile_to_response(
    profile: DBVoiceProfile,
    generation_count: int = 0,
    sample_count: int = 0,
) -> VoiceProfileResponse:
    """Convert a DB profile to a VoiceProfileResponse, deserializing effects_chain."""
    effects_chain = None
    if profile.effects_chain:
        try:
            raw = _json.loads(profile.effects_chain)
            effects_chain = [EffectConfig(**e) for e in raw]
        except Exception as e:
            logger.warning(f"Failed to parse effects_chain for profile {profile.id}: {e}")
    return VoiceProfileResponse(
        id=profile.id,
        name=profile.name,
        description=profile.description,
        language=profile.language,
        avatar_path=profile.avatar_path,
        effects_chain=effects_chain,
        voice_type=getattr(profile, "voice_type", None) or "cloned",
        preset_engine=getattr(profile, "preset_engine", None),
        preset_voice_id=getattr(profile, "preset_voice_id", None),
        design_prompt=getattr(profile, "design_prompt", None),
        default_engine=getattr(profile, "default_engine", None),
        generation_count=generation_count,
        sample_count=sample_count,
        created_at=profile.created_at,
        updated_at=profile.updated_at,
    )
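
# Illustrative stored value for effects_chain (the exact EffectConfig fields are
# defined in models.py; this payload shape is an assumption):
#   profile.effects_chain == '[{"type": "reverb", "params": {"wet": 0.2}}]'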


def _get_preset_voice_ids(engine: str) -> set[str]:
    if engine == "kokoro":
        from ..backends.kokoro_backend import KOKORO_VOICES

        return {voice_id for voice_id, _name, _gender, _lang in KOKORO_VOICES}

    if engine == "qwen_custom_voice":
        from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES

        return {voice_id for voice_id, _name, _gender, _lang, _desc in QWEN_CUSTOM_VOICES}

    return set()


def _validate_profile_fields(
    *,
    voice_type: str,
    preset_engine: str | None,
    preset_voice_id: str | None,
    design_prompt: str | None,
    default_engine: str | None,
) -> str | None:
    if voice_type == "preset":
        if not preset_engine or not preset_voice_id:
            return "Preset profiles require both preset_engine and preset_voice_id"
        if default_engine and default_engine != preset_engine:
            return "Preset profiles must use their preset_engine as default_engine"

        available_voice_ids = _get_preset_voice_ids(preset_engine)
        if available_voice_ids and preset_voice_id not in available_voice_ids:
            return f"Preset voice '{preset_voice_id}' is not valid for engine '{preset_engine}'"
        return None

    if voice_type == "designed":
        if not design_prompt or not design_prompt.strip():
            return "Designed profiles require a design_prompt"
        if preset_engine or preset_voice_id:
            return "Designed profiles cannot set preset_engine or preset_voice_id"
        return None

    if preset_engine or preset_voice_id:
        return "Cloned profiles cannot set preset_engine or preset_voice_id"
    if design_prompt:
        return "Cloned profiles cannot set design_prompt"
    if default_engine and default_engine not in CLONING_ENGINES:
        return f"Cloned profiles cannot use default engine '{default_engine}'"
    return None
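
# Validation summary (returns an error string, or None when the combination is allowed):
#   voice_type="preset"   -> requires preset_engine + preset_voice_id; default_engine,
#                            if set, must equal preset_engine
#   voice_type="designed" -> requires a non-empty design_prompt; no preset fields
#   voice_type="cloned"   -> no preset/design fields; default_engine, if set, must
#                            be one of CLONING_ENGINES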


def validate_profile_engine(profile, engine: str) -> None:
    voice_type = getattr(profile, "voice_type", None) or "cloned"

    if voice_type == "preset":
        preset_engine = getattr(profile, "preset_engine", None)
        preset_voice_id = getattr(profile, "preset_voice_id", None)
        if not preset_engine or not preset_voice_id:
            raise ValueError(f"Preset profile {profile.id} is missing preset engine metadata")
        if preset_engine != engine:
            raise ValueError(
                f"Preset profile {profile.id} only supports engine '{preset_engine}', not '{engine}'"
            )
        return

    if voice_type == "designed":
        design_prompt = getattr(profile, "design_prompt", None)
        if not design_prompt or not design_prompt.strip():
            raise ValueError(f"Designed profile {profile.id} is missing design_prompt")
        return

    if engine not in CLONING_ENGINES:
        raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")


async def create_profile(
    data: VoiceProfileCreate,
    db: Session,
) -> VoiceProfileResponse:
    """
    Create a new voice profile.

    Args:
        data: Profile creation data
        db: Database session

    Returns:
        Created profile

    Raises:
        ValueError: If a profile with the same name already exists
    """
    existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
    if existing_profile:
        raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")

    # Auto-set default_engine for preset profiles
    default_engine = data.default_engine
    voice_type = data.voice_type or "cloned"
    if voice_type == "preset" and data.preset_engine and not default_engine:
        default_engine = data.preset_engine

    validation_error = _validate_profile_fields(
        voice_type=voice_type,
        preset_engine=data.preset_engine,
        preset_voice_id=data.preset_voice_id,
        design_prompt=data.design_prompt,
        default_engine=default_engine,
    )
    if validation_error:
        raise ValueError(validation_error)

    db_profile = DBVoiceProfile(
        id=str(uuid.uuid4()),
        name=data.name,
        description=data.description,
        language=data.language,
        voice_type=voice_type,
        preset_engine=data.preset_engine,
        preset_voice_id=data.preset_voice_id,
        design_prompt=data.design_prompt,
        default_engine=default_engine,
        created_at=datetime.utcnow(),
        updated_at=datetime.utcnow(),
    )

    db.add(db_profile)
    db.commit()
    db.refresh(db_profile)

    profile_dir = config.get_profiles_dir() / db_profile.id
    profile_dir.mkdir(parents=True, exist_ok=True)

    return _profile_to_response(db_profile)


async def add_profile_sample(
    profile_id: str,
    audio_path: str,
    reference_text: str,
    db: Session,
) -> ProfileSampleResponse:
    """
    Add a sample to a voice profile.

    Args:
        profile_id: Profile ID
        audio_path: Path to temporary audio file
        reference_text: Transcript of the audio
        db: Database session

    Returns:
        Created sample
    """
    import asyncio

    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile {profile_id} not found")

    # Validate and load audio in a single pass, off the event loop
    is_valid, error_msg, audio, sr = await asyncio.to_thread(
        validate_and_load_reference_audio, audio_path
    )
    if not is_valid:
        raise ValueError(f"Invalid reference audio: {error_msg}")

    sample_id = str(uuid.uuid4())
    profile_dir = config.get_profiles_dir() / profile_id
    profile_dir.mkdir(parents=True, exist_ok=True)

    dest_path = profile_dir / f"{sample_id}.wav"
    await asyncio.to_thread(save_audio, audio, str(dest_path), sr)

    db_sample = DBProfileSample(
        id=sample_id,
        profile_id=profile_id,
        audio_path=config.to_storage_path(dest_path),
        reference_text=reference_text,
    )

    db.add(db_sample)

    profile.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(db_sample)

    # Invalidate combined audio cache for this profile
    # Since a new sample was added, any cached combined audio is now stale
    clear_profile_cache(profile_id)

    return ProfileSampleResponse.model_validate(db_sample)


async def get_profile(
    profile_id: str,
    db: Session,
) -> VoiceProfileResponse | None:
    """
    Get a voice profile by ID.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        Profile or None if not found
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        return None

    return _profile_to_response(profile)


async def get_profile_samples(
    profile_id: str,
    db: Session,
) -> list[ProfileSampleResponse]:
    """
    Get all samples for a profile.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        List of samples
    """
    samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()
    return [ProfileSampleResponse.model_validate(s) for s in samples]


async def list_profiles(db: Session) -> list[VoiceProfileResponse]:
    """
    List all voice profiles with generation and sample counts.

    Args:
        db: Database session

    Returns:
        List of profiles
    """
    profiles = db.query(DBVoiceProfile).order_by(DBVoiceProfile.created_at.desc()).all()

    if not profiles:
        return []

    # Batch-fetch generation counts
    gen_counts_rows = (
        db.query(DBGeneration.profile_id, func.count(DBGeneration.id)).group_by(DBGeneration.profile_id).all()
    )
    gen_counts = {row[0]: row[1] for row in gen_counts_rows}

    # Batch-fetch sample counts
    sample_counts_rows = (
        db.query(DBProfileSample.profile_id, func.count(DBProfileSample.id)).group_by(DBProfileSample.profile_id).all()
    )
    sample_counts = {row[0]: row[1] for row in sample_counts_rows}

    return [
        _profile_to_response(
            p,
            generation_count=gen_counts.get(p.id, 0),
            sample_count=sample_counts.get(p.id, 0),
        )
        for p in profiles
    ]


async def update_profile(
    profile_id: str,
    data: VoiceProfileCreate,
    db: Session,
) -> VoiceProfileResponse | None:
    """
    Update a voice profile.

    Args:
        profile_id: Profile ID
        data: Updated profile data
        db: Database session

    Returns:
        Updated profile or None if not found

    Raises:
        ValueError: If a different profile with the same name already exists
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        return None

    if profile.name != data.name:
        existing_profile = db.query(DBVoiceProfile).filter_by(name=data.name).first()
        if existing_profile:
            raise ValueError(f"A profile with the name '{data.name}' already exists. Please choose a different name.")

    voice_type = getattr(profile, "voice_type", None) or "cloned"
    preset_engine = getattr(profile, "preset_engine", None)
    preset_voice_id = getattr(profile, "preset_voice_id", None)
    design_prompt = getattr(profile, "design_prompt", None)
    default_engine = data.default_engine if data.default_engine is not None else getattr(profile, "default_engine", None)

    validation_error = _validate_profile_fields(
        voice_type=voice_type,
        preset_engine=preset_engine,
        preset_voice_id=preset_voice_id,
        design_prompt=design_prompt,
        default_engine=default_engine,
    )
    if validation_error:
        raise ValueError(validation_error)

    profile.name = data.name
    profile.description = data.description
    profile.language = data.language
    if data.default_engine is not None:
        profile.default_engine = data.default_engine or None  # empty string → NULL
    profile.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(profile)

    return _profile_to_response(profile)


async def delete_profile(
    profile_id: str,
    db: Session,
) -> bool:
    """
    Delete a voice profile and all associated data.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        return False

    db.query(DBProfileSample).filter_by(profile_id=profile_id).delete()

    db.delete(profile)
    db.commit()

    profile_dir = config.get_profiles_dir() / profile_id
    if profile_dir.exists():
        shutil.rmtree(profile_dir)

    # Clean up combined audio cache files for this profile
    clear_profile_cache(profile_id)

    return True


async def delete_profile_sample(
    sample_id: str,
    db: Session,
) -> bool:
    """
    Delete a profile sample.

    Args:
        sample_id: Sample ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
    if not sample:
        return False

    # Store profile_id before deleting
    profile_id = sample.profile_id

    audio_path = config.resolve_storage_path(sample.audio_path)
    if audio_path is not None and audio_path.exists():
        audio_path.unlink()

    db.delete(sample)
    db.commit()

    # Invalidate combined audio cache for this profile
    # Since the sample set changed, any cached combined audio is now stale
    clear_profile_cache(profile_id)

    return True


async def update_profile_sample(
    sample_id: str,
    reference_text: str,
    db: Session,
) -> ProfileSampleResponse | None:
    """
    Update a profile sample's reference text.

    Args:
        sample_id: Sample ID
        reference_text: Updated reference text
        db: Database session

    Returns:
        Updated sample or None if not found
    """
    sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
    if not sample:
        return None

    # Store profile_id before updating
    profile_id = sample.profile_id

    sample.reference_text = reference_text
    db.commit()
    db.refresh(sample)

    # Invalidate combined audio cache for this profile
    # Since the reference text changed, cache keys and combined text are now stale
    clear_profile_cache(profile_id)

    return ProfileSampleResponse.model_validate(sample)


async def create_voice_prompt_for_profile(
    profile_id: str,
    db: Session,
    use_cache: bool = True,
    engine: str = "qwen",
) -> dict:
    """
    Create a voice prompt from a profile.

    For cloned profiles: combines all audio samples into a voice prompt.
    For preset profiles: returns the engine-specific preset voice reference.
    For designed profiles: returns the text design prompt (future).

    Args:
        profile_id: Profile ID
        db: Database session
        use_cache: Whether to use cached prompts
        engine: TTS engine to create the prompt for

    Returns:
        Voice prompt dictionary
    """
    from ..backends import get_tts_backend_for_engine

    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile not found: {profile_id}")

    voice_type = getattr(profile, "voice_type", None) or "cloned"
    validate_profile_engine(profile, engine)

    # ── Preset profiles: return engine-specific voice reference ──
    if voice_type == "preset":
        if not profile.preset_engine or not profile.preset_voice_id:
            raise ValueError(f"Preset profile {profile_id} is missing preset engine metadata")
        if profile.preset_engine != engine:
            raise ValueError(
                f"Preset profile {profile_id} only supports engine '{profile.preset_engine}', not '{engine}'"
            )
        return {
            "voice_type": "preset",
            "preset_engine": profile.preset_engine,
            "preset_voice_id": profile.preset_voice_id,
        }

    # ── Designed profiles: return text description (future) ──
    if voice_type == "designed":
        if not profile.design_prompt or not profile.design_prompt.strip():
            raise ValueError(f"Designed profile {profile_id} is missing design_prompt")
        return {
            "voice_type": "designed",
            "design_prompt": profile.design_prompt,
        }

    if engine not in CLONING_ENGINES:
        raise ValueError(f"Engine '{engine}' does not support cloned voice profiles")

    # ── Cloned profiles: create from audio samples ──
    samples = db.query(DBProfileSample).filter_by(profile_id=profile_id).all()

    if not samples:
        raise ValueError(f"No samples found for profile {profile_id}")

    tts_model = get_tts_backend_for_engine(engine)

    if len(samples) == 1:
        sample = samples[0]
        sample_audio_path = config.resolve_storage_path(sample.audio_path)
        if sample_audio_path is None:
            raise ValueError(f"Sample audio not found for profile {profile_id}")
        voice_prompt, _ = await tts_model.create_voice_prompt(
            str(sample_audio_path),
            sample.reference_text,
            use_cache=use_cache,
        )
        return voice_prompt

    audio_paths = []
    for sample in samples:
        sample_audio_path = config.resolve_storage_path(sample.audio_path)
        if sample_audio_path is None:
            raise ValueError(f"Sample audio not found for profile {profile_id}")
        audio_paths.append(str(sample_audio_path))
    reference_texts = [s.reference_text for s in samples]

    combined_audio, combined_text = await tts_model.combine_voice_prompts(
        audio_paths,
        reference_texts,
    )

    # Save combined audio to cache directory (persistent)
    # Create a hash of sample IDs to identify this specific combination
    import hashlib

    sample_ids_str = "-".join(sorted([s.id for s in samples]))
    combination_hash = hashlib.md5(sample_ids_str.encode()).hexdigest()[:12]

    cache_dir = _get_cache_dir()
    cache_dir.mkdir(parents=True, exist_ok=True)
    combined_path = cache_dir / f"combined_{profile_id}_{combination_hash}.wav"

    save_audio(combined_audio, str(combined_path), 24000)

    voice_prompt, _ = await tts_model.create_voice_prompt(
        str(combined_path),
        combined_text,
        use_cache=use_cache,
    )
    return voice_prompt
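
# Return shapes, as constructed above:
#   preset:   {"voice_type": "preset", "preset_engine": ..., "preset_voice_id": ...}
#   designed: {"voice_type": "designed", "design_prompt": ...}
#   cloned:   the engine's create_voice_prompt() result for the single or
#             combined reference audio (engine-specific dict)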


async def upload_avatar(
    profile_id: str,
    image_path: str,
    db: Session,
) -> VoiceProfileResponse:
    """
    Upload and process an avatar image for a profile.

    Args:
        profile_id: Profile ID
        image_path: Path to uploaded image file
        db: Database session

    Returns:
        Updated profile
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile:
        raise ValueError(f"Profile {profile_id} not found")

    is_valid, error_msg = validate_image(image_path)
    if not is_valid:
        raise ValueError(error_msg)

    if profile.avatar_path:
        old_avatar = config.resolve_storage_path(profile.avatar_path)
        if old_avatar is not None and old_avatar.exists():
            old_avatar.unlink()

    # Determine file extension from uploaded file
    from PIL import Image

    with Image.open(image_path) as img:
        # Normalize JPEG variants (MPO is a multi-picture format from some cameras)
        img_format = img.format
        if img_format in ("MPO", "JPG"):
            img_format = "JPEG"

        ext_map = {"PNG": ".png", "JPEG": ".jpg", "WEBP": ".webp"}
        ext = ext_map.get(img_format, ".png")

    profile_dir = config.get_profiles_dir() / profile_id
    profile_dir.mkdir(parents=True, exist_ok=True)
    output_path = profile_dir / f"avatar{ext}"

    process_avatar(image_path, str(output_path))

    profile.avatar_path = config.to_storage_path(output_path)
    profile.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(profile)

    return _profile_to_response(profile)


async def delete_avatar(
    profile_id: str,
    db: Session,
) -> bool:
    """
    Delete the avatar image for a profile.

    Args:
        profile_id: Profile ID
        db: Database session

    Returns:
        True if deleted, False if not found or no avatar
    """
    profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
    if not profile or not profile.avatar_path:
        return False

    avatar_path = config.resolve_storage_path(profile.avatar_path)
    if avatar_path is not None and avatar_path.exists():
        avatar_path.unlink()

    profile.avatar_path = None
    profile.updated_at = datetime.utcnow()

    db.commit()

    return True
925
backend/services/stories.py
Normal file
@@ -0,0 +1,925 @@
"""
Story management module.
"""

import tempfile
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Optional

import numpy as np
from sqlalchemy import func
from sqlalchemy.orm import Session

from .. import config
from ..models import (
    StoryCreate,
    StoryResponse,
    StoryDetailResponse,
    StoryItemDetail,
    StoryItemCreate,
    StoryItemBatchUpdate,
    StoryItemMove,
    StoryItemTrim,
    StoryItemSplit,
    StoryItemVersionUpdate,
)
from ..database import (
    Story as DBStory,
    StoryItem as DBStoryItem,
    Generation as DBGeneration,
    VoiceProfile as DBVoiceProfile,
)
from .history import _get_versions_for_generation
from ..utils.audio import load_audio, save_audio


def _build_item_detail(
    item: DBStoryItem,
    generation: DBGeneration,
    profile_name: str,
    db: Session,
) -> StoryItemDetail:
    """Build a StoryItemDetail with version info from a story item and its generation."""
    versions, active_version_id = _get_versions_for_generation(generation.id, db)

    # Resolve the audio path: if version_id is set, use that version's audio
    audio_path = generation.audio_path
    if item.version_id and versions:
        for v in versions:
            if v.id == item.version_id:
                audio_path = v.audio_path
                break

    return StoryItemDetail(
        id=item.id,
        story_id=item.story_id,
        generation_id=item.generation_id,
        version_id=getattr(item, "version_id", None),
        start_time_ms=item.start_time_ms,
        track=item.track,
        trim_start_ms=getattr(item, "trim_start_ms", 0),
        trim_end_ms=getattr(item, "trim_end_ms", 0),
        created_at=item.created_at,
        profile_id=generation.profile_id,
        profile_name=profile_name,
        text=generation.text,
        language=generation.language,
        audio_path=audio_path,
        duration=generation.duration,
        seed=generation.seed,
        instruct=generation.instruct,
        generation_created_at=generation.created_at,
        versions=versions,
        active_version_id=active_version_id,
    )


async def create_story(
    data: StoryCreate,
    db: Session,
) -> StoryResponse:
    """
    Create a new story.

    Args:
        data: Story creation data
        db: Database session

    Returns:
        Created story
    """
    db_story = DBStory(
        id=str(uuid.uuid4()),
        name=data.name,
        description=data.description,
        created_at=datetime.utcnow(),
        updated_at=datetime.utcnow(),
    )

    db.add(db_story)
    db.commit()
    db.refresh(db_story)

    item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == db_story.id).scalar()

    response = StoryResponse.model_validate(db_story)
    response.item_count = item_count
    return response


async def list_stories(
    db: Session,
) -> List[StoryResponse]:
    """
    List all stories.

    Args:
        db: Database session

    Returns:
        List of stories with item counts
    """
    stories = db.query(DBStory).order_by(DBStory.updated_at.desc()).all()

    result = []
    for story in stories:
        item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()

        response = StoryResponse.model_validate(story)
        response.item_count = item_count
        result.append(response)

    return result


async def get_story(
    story_id: str,
    db: Session,
) -> Optional[StoryDetailResponse]:
    """
    Get a story with all its items.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        Story with items or None if not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    items = (
        db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
        .filter(DBStoryItem.story_id == story_id)
        .order_by(DBStoryItem.start_time_ms)
        .all()
    )

    item_details = []
    for item, generation, profile_name in items:
        item_details.append(_build_item_detail(item, generation, profile_name, db))

    response = StoryDetailResponse.model_validate(story)
    response.items = item_details
    return response


async def update_story(
    story_id: str,
    data: StoryCreate,
    db: Session,
) -> Optional[StoryResponse]:
    """
    Update a story.

    Args:
        story_id: Story ID
        data: Update data
        db: Database session

    Returns:
        Updated story or None if not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    story.name = data.name
    story.description = data.description
    story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(story)

    item_count = db.query(func.count(DBStoryItem.id)).filter(DBStoryItem.story_id == story.id).scalar()

    response = StoryResponse.model_validate(story)
    response.item_count = item_count
    return response


async def delete_story(
    story_id: str,
    db: Session,
) -> bool:
    """
    Delete a story and all its items.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        True if deleted, False if not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return False

    # Delete all items
    db.query(DBStoryItem).filter_by(story_id=story_id).delete()

    # Delete story
    db.delete(story)
    db.commit()

    return True


async def add_item_to_story(
    story_id: str,
    data: StoryItemCreate,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Add a generation to a story.

    Args:
        story_id: Story ID
        data: Item creation data
        db: Database session

    Returns:
        Created item detail or None if story/generation not found
    """
    # Verify story exists
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # Verify generation exists
    generation = db.query(DBGeneration).filter_by(id=data.generation_id).first()
    if not generation:
        return None

    # Check if generation is already in story
    existing = db.query(DBStoryItem).filter_by(story_id=story_id, generation_id=data.generation_id).first()
    if existing:
        # Return existing item
        profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
        return _build_item_detail(existing, generation, profile.name if profile else "Unknown", db)

    # Get track from data or default to 0
    track = data.track if data.track is not None else 0

    # Calculate start_time_ms if not provided
    if data.start_time_ms is not None:
        start_time_ms = data.start_time_ms
    else:
        existing_items = (
            db.query(DBStoryItem, DBGeneration)
            .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
            .filter(
                DBStoryItem.story_id == story_id,
                DBStoryItem.track == track,
            )
            .all()
        )

        if not existing_items:
            start_time_ms = 0
        else:
            max_end_time_ms = 0
            for item, gen in existing_items:
                item_end_ms = item.start_time_ms + int(gen.duration * 1000)
                max_end_time_ms = max(max_end_time_ms, item_end_ms)

            # Add a 200 ms gap after the last item
            start_time_ms = max_end_time_ms + 200

    # Create item
    item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=data.generation_id,
        start_time_ms=start_time_ms,
        track=track,
        created_at=datetime.utcnow(),
    )

    db.add(item)

    # Update story updated_at
    story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)

    # Get profile name
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()

    return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)
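
# Auto-placement example: with two clips on track 0 ending at 4_000 ms and
# 9_500 ms, a clip added without an explicit start_time_ms lands at
# 9_500 + 200 = 9_700 ms (a 200 ms gap after the latest end on that track).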


async def move_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemMove,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Move a story item (update position and/or track).

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: New position and track data
        db: Database session

    Returns:
        Updated item detail or None if not found
    """
    # Get the item
    item = (
        db.query(DBStoryItem)
        .filter_by(
            id=item_id,
            story_id=story_id,
        )
        .first()
    )
    if not item:
        return None

    # Get the generation
    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    # Update position and track
    item.start_time_ms = data.start_time_ms
    item.track = data.track

    # Update story updated_at
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)

    # Get profile name
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()

    return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)


async def remove_item_from_story(
    story_id: str,
    item_id: str,
    db: Session,
) -> bool:
    """
    Remove a story item from a story.

    Args:
        story_id: Story ID
        item_id: Story item ID to remove
        db: Database session

    Returns:
        True if removed, False if not found
    """
    item = (
        db.query(DBStoryItem)
        .filter_by(
            id=item_id,
            story_id=story_id,
        )
        .first()
    )
    if not item:
        return False

    # Delete item
    db.delete(item)

    # Update story updated_at
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()

    db.commit()
    return True


async def trim_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemTrim,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Trim a story item (update trim_start_ms and trim_end_ms).

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: Trim data (trim_start_ms, trim_end_ms)
        db: Database session

    Returns:
        Updated item detail or None if not found
    """
    # Get the item
    item = (
        db.query(DBStoryItem)
        .filter_by(
            id=item_id,
            story_id=story_id,
        )
        .first()
    )
    if not item:
        return None

    # Get the generation
    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    # Validate trim values don't exceed duration
    max_duration_ms = int(generation.duration * 1000)
    if data.trim_start_ms + data.trim_end_ms >= max_duration_ms:
        return None  # Invalid trim - would result in zero or negative duration

    # Update trim values
    item.trim_start_ms = data.trim_start_ms
    item.trim_end_ms = data.trim_end_ms

    # Update story updated_at
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)

    # Get profile name
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()

    return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)


async def split_story_item(
    story_id: str,
    item_id: str,
    data: StoryItemSplit,
    db: Session,
) -> Optional[List[StoryItemDetail]]:
    """
    Split a story item at a given time, creating two clips.

    Args:
        story_id: Story ID
        item_id: Story item ID to split
        data: Split data (split_time_ms - time within the clip to split at)
        db: Database session

    Returns:
        List of two updated item details (original and new) or None if not found/invalid
    """
    # Get the item with a row lock to prevent concurrent splits on the
    # same clip (e.g. from rapid double-clicks racing each other).
    item = (
        db.query(DBStoryItem)
        .filter_by(
            id=item_id,
            story_id=story_id,
        )
        .with_for_update()
        .first()
    )
    if not item:
        return None

    # Get the generation
    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    # Calculate effective duration and validate split point
    current_trim_start = getattr(item, "trim_start_ms", 0)
    current_trim_end = getattr(item, "trim_end_ms", 0)
    original_duration_ms = int(generation.duration * 1000)
    effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end

    # Validate split_time_ms is within the effective duration
    if data.split_time_ms <= 0 or data.split_time_ms >= effective_duration_ms:
        return None  # Invalid split point

    # Calculate the absolute time in the original audio where we're splitting
    absolute_split_ms = current_trim_start + data.split_time_ms

    # Update original clip: trim from the end
    item.trim_end_ms = original_duration_ms - absolute_split_ms

    # Create new clip: starts after the split, trimmed from the start
    new_item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=item.generation_id,  # Same generation, different trim
        version_id=getattr(item, "version_id", None),  # Preserve pinned version
        start_time_ms=item.start_time_ms + data.split_time_ms,
        track=item.track,
        trim_start_ms=absolute_split_ms,
        trim_end_ms=current_trim_end,
        created_at=datetime.utcnow(),
    )

    db.add(new_item)

    # Update story updated_at
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)
    db.refresh(new_item)

    # Get profile name
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()
    profile_name = profile.name if profile else "Unknown"

    return [
        _build_item_detail(item, generation, profile_name, db),
        _build_item_detail(new_item, generation, profile_name, db),
    ]
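
# Worked example: a 10_000 ms generation with trim_start_ms=500, trim_end_ms=0
# has an effective span of 9_500 ms. Splitting at split_time_ms=3_000 gives
# absolute_split_ms = 500 + 3_000 = 3_500, so the original clip keeps
# trim_end_ms = 10_000 - 3_500 = 6_500 (plays 500..3_500) and the new clip gets
# trim_start_ms = 3_500 (plays 3_500..10_000), starting 3_000 ms later on the
# timeline.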


async def duplicate_story_item(
    story_id: str,
    item_id: str,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Duplicate a story item, creating a copy with all properties.

    Args:
        story_id: Story ID
        item_id: Story item ID to duplicate
        db: Database session

    Returns:
        New item detail or None if not found
    """
    # Get the original item
    original_item = (
        db.query(DBStoryItem)
        .filter_by(
            id=item_id,
            story_id=story_id,
        )
        .first()
    )
    if not original_item:
        return None

    # Get the generation
    generation = db.query(DBGeneration).filter_by(id=original_item.generation_id).first()
    if not generation:
        return None

    # Calculate effective duration
    current_trim_start = getattr(original_item, "trim_start_ms", 0)
    current_trim_end = getattr(original_item, "trim_end_ms", 0)
    original_duration_ms = int(generation.duration * 1000)
    effective_duration_ms = original_duration_ms - current_trim_start - current_trim_end

    # Create duplicate item - place it right after the original
    new_item = DBStoryItem(
        id=str(uuid.uuid4()),
        story_id=story_id,
        generation_id=original_item.generation_id,  # Same generation as original
        version_id=getattr(original_item, "version_id", None),  # Preserve pinned version
        start_time_ms=original_item.start_time_ms + effective_duration_ms + 200,  # 200 ms gap
        track=original_item.track,
        trim_start_ms=current_trim_start,
        trim_end_ms=current_trim_end,
        created_at=datetime.utcnow(),
    )

    db.add(new_item)

    # Update story updated_at
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(new_item)

    # Get profile name
    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()

    return _build_item_detail(new_item, generation, profile.name if profile else "Unknown", db)


async def update_story_item_times(
    story_id: str,
    data: StoryItemBatchUpdate,
    db: Session,
) -> bool:
    """
    Update story item timecodes.

    Args:
        story_id: Story ID
        data: Batch update data with timecodes
        db: Database session

    Returns:
        True if updated, False if story not found or invalid
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return False

    # Get all items for this story
    items = db.query(DBStoryItem).filter_by(story_id=story_id).all()
    item_map = {item.generation_id: item for item in items}

    # Verify all generation IDs belong to this story and update timecodes
    for update in data.updates:
        if update.generation_id not in item_map:
            return False
        item_map[update.generation_id].start_time_ms = update.start_time_ms

    # Update story updated_at
    story.updated_at = datetime.utcnow()

    db.commit()
    return True


async def reorder_story_items(
    story_id: str,
    generation_ids: List[str],
    db: Session,
    gap_ms: int = 200,
) -> Optional[List[StoryItemDetail]]:
    """
    Reorder story items and recalculate timecodes.

    Args:
        story_id: Story ID
        generation_ids: List of generation IDs in the desired order
        db: Database session
        gap_ms: Gap in milliseconds between items (default 200 ms)

    Returns:
        Updated list of story items with new timecodes, or None if invalid
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # Get all items for this story with their generation data
    items_with_gen = (
        db.query(DBStoryItem, DBGeneration, DBVoiceProfile.name.label("profile_name"))
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
        .filter(DBStoryItem.story_id == story_id)
        .all()
    )

    # Create maps for quick lookup
    item_map = {item.generation_id: (item, gen, profile_name) for item, gen, profile_name in items_with_gen}

    # Verify all generation IDs belong to this story
    if set(generation_ids) != set(item_map.keys()):
        return None

    # Recalculate timecodes based on new order
    current_time_ms = 0
    updated_items = []

    for gen_id in generation_ids:
        item, generation, profile_name = item_map[gen_id]

        # Update the item's start time
        item.start_time_ms = current_time_ms

        # Calculate the duration in ms
        duration_ms = int(generation.duration * 1000)

        # Move to next position (current end + gap)
        current_time_ms += duration_ms + gap_ms

        # Build the response item
        updated_items.append(_build_item_detail(item, generation, profile_name, db))

    # Update story updated_at
    story.updated_at = datetime.utcnow()

    db.commit()
    return updated_items


async def set_story_item_version(
    story_id: str,
    item_id: str,
    data: StoryItemVersionUpdate,
    db: Session,
) -> Optional[StoryItemDetail]:
    """
    Pin a story item to a specific generation version.

    Args:
        story_id: Story ID
        item_id: Story item ID
        data: Version update data (version_id, or null for the default)
        db: Database session

    Returns:
        Updated item detail or None if not found
    """
    item = (
        db.query(DBStoryItem)
        .filter_by(
            id=item_id,
            story_id=story_id,
        )
        .first()
    )
    if not item:
        return None

    generation = db.query(DBGeneration).filter_by(id=item.generation_id).first()
    if not generation:
        return None

    # Validate version_id belongs to this generation if provided
    if data.version_id:
        from ..database import GenerationVersion as DBGenerationVersion

        version = (
            db.query(DBGenerationVersion)
            .filter_by(
                id=data.version_id,
                generation_id=item.generation_id,
            )
            .first()
        )
        if not version:
            return None

    item.version_id = data.version_id

    # Update story updated_at
    story = db.query(DBStory).filter_by(id=story_id).first()
    if story:
        story.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(item)

    profile = db.query(DBVoiceProfile).filter_by(id=generation.profile_id).first()

    return _build_item_detail(item, generation, profile.name if profile else "Unknown", db)


async def export_story_audio(
    story_id: str,
    db: Session,
) -> Optional[bytes]:
    """
    Export a story as a single mixed audio file with timecode-based mixing.

    Args:
        story_id: Story ID
        db: Database session

    Returns:
        Audio file bytes or None if story not found
    """
    story = db.query(DBStory).filter_by(id=story_id).first()
    if not story:
        return None

    # Get all items ordered by start_time_ms
    items = (
        db.query(DBStoryItem, DBGeneration)
        .join(DBGeneration, DBStoryItem.generation_id == DBGeneration.id)
        .filter(DBStoryItem.story_id == story_id)
        .order_by(DBStoryItem.start_time_ms)
        .all()
    )

    if not items:
        return None

    # Load all audio files and calculate total duration
    audio_data = []
    sample_rate = 24000  # Default sample rate

    for item, generation in items:
        # Resolve audio path: use pinned version if set, otherwise generation default
        resolved_audio_path = generation.audio_path
        if getattr(item, "version_id", None):
            from ..database import GenerationVersion as DBGenerationVersion

            version = db.query(DBGenerationVersion).filter_by(id=item.version_id).first()
            if version:
                resolved_audio_path = version.audio_path

        audio_path = config.resolve_storage_path(resolved_audio_path)
        if audio_path is None or not audio_path.exists():
            continue

        try:
            audio, sr = load_audio(str(audio_path), sample_rate=sample_rate)
            sample_rate = sr  # Track the sample rate the loader actually returned

            # Get trim values
            trim_start_ms = getattr(item, "trim_start_ms", 0)
            trim_end_ms = getattr(item, "trim_end_ms", 0)

            # Calculate effective duration
            original_duration_ms = int(generation.duration * 1000)
            effective_duration_ms = original_duration_ms - trim_start_ms - trim_end_ms

            # Slice audio based on trim values
            trim_start_sample = int((trim_start_ms / 1000.0) * sample_rate)
            trim_end_sample = int((trim_end_ms / 1000.0) * sample_rate)

            # Extract the trimmed portion
            if trim_end_ms > 0:
                trimmed_audio = (
                    audio[trim_start_sample:-trim_end_sample] if trim_end_sample > 0 else audio[trim_start_sample:]
                )
            else:
                trimmed_audio = audio[trim_start_sample:]

            # Store audio with its timecode info
            start_time_ms = item.start_time_ms

            audio_data.append(
                {
                    "audio": trimmed_audio,
                    "start_time_ms": start_time_ms,
                    "duration_ms": effective_duration_ms,
                }
            )
        except Exception:
            # Skip files that can't be loaded
            continue

    if not audio_data:
        return None

    # Calculate total duration: max(start_time_ms + duration_ms)
    max_end_time_ms = max((data["start_time_ms"] + data["duration_ms"] for data in audio_data), default=0)

    # Convert to samples
    total_samples = int((max_end_time_ms / 1000.0) * sample_rate)

    # Create output buffer initialized to zeros
    final_audio = np.zeros(total_samples, dtype=np.float32)

    # Mix each audio segment at its timecode position
    for data in audio_data:
        audio = data["audio"]
        start_time_ms = data["start_time_ms"]

        # Calculate start sample index
        start_sample = int((start_time_ms / 1000.0) * sample_rate)

        # Ensure we don't exceed buffer bounds
        audio_length = len(audio)
        end_sample = min(start_sample + audio_length, total_samples)

        if start_sample < total_samples:
            # Trim audio if it extends beyond buffer
            audio_to_mix = audio[: end_sample - start_sample]

            # Mix: add audio into the buffer (overlapping segments sum;
            # clipping is handled by the peak normalization below)
            final_audio[start_sample:end_sample] += audio_to_mix

    # Normalize to prevent clipping
    max_val = np.abs(final_audio).max()
    if max_val > 1.0:
        final_audio = final_audio / max_val

    # Save to temporary file
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name

    try:
        save_audio(final_audio, tmp_path, sample_rate)

        # Read file bytes
        with open(tmp_path, "rb") as f:
            audio_bytes = f.read()

        return audio_bytes
    finally:
        # Clean up temp file
        Path(tmp_path).unlink(missing_ok=True)
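
# Sample math used above: at sample_rate = 24_000, a clip with
# start_time_ms = 9_700 begins at int(9.7 * 24_000) = 232_800 samples into the
# output buffer, whose length is the maximum clip end time converted the same way.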
108
backend/services/task_queue.py
Normal file
@@ -0,0 +1,108 @@
"""
Serial generation queue — ensures only one TTS inference runs at a time
to avoid GPU contention.
"""

import asyncio
import traceback
from dataclasses import dataclass
from typing import Coroutine, Literal

# Keep references to fire-and-forget background tasks to prevent GC
_background_tasks: set = set()


@dataclass
class GenerationJob:
    """Queued generation work plus the generation ID it belongs to."""

    generation_id: str
    coro: Coroutine


# Generation queue — serializes TTS inference to avoid GPU contention
_generation_queue: asyncio.Queue = None  # type: ignore  # initialized at startup
_generation_worker_task: asyncio.Task | None = None
_queued_generation_ids: set[str] = set()
_running_generation_tasks: dict[str, asyncio.Task] = {}
_cancelled_generation_ids: set[str] = set()


def create_background_task(coro) -> asyncio.Task:
    """Create a background task and prevent it from being garbage collected."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def _generation_worker():
    """Worker that processes generation tasks one at a time."""
    while True:
        job = await _generation_queue.get()
        try:
            if job.generation_id in _cancelled_generation_ids:
                _cancelled_generation_ids.discard(job.generation_id)
                job.coro.close()
                continue

            task = asyncio.create_task(job.coro)
            _running_generation_tasks[job.generation_id] = task
            _queued_generation_ids.discard(job.generation_id)
            try:
                await task
            except asyncio.CancelledError:
                if not task.cancelled():
                    raise
            except Exception:
                traceback.print_exc()
        finally:
            _running_generation_tasks.pop(job.generation_id, None)
            _queued_generation_ids.discard(job.generation_id)
            _generation_queue.task_done()


def enqueue_generation(generation_id: str, coro):
    """Add a generation coroutine to the serial queue."""
    if _generation_queue is None:
        raise RuntimeError("Generation queue has not been initialized")

    _queued_generation_ids.add(generation_id)
    _generation_queue.put_nowait(GenerationJob(generation_id=generation_id, coro=coro))


def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None:
    """Cancel a queued or running generation if it is still active."""
    running_task = _running_generation_tasks.get(generation_id)
    if running_task is not None:
        running_task.cancel()
        return "running"

    if generation_id in _queued_generation_ids:
        _queued_generation_ids.discard(generation_id)
        _cancelled_generation_ids.add(generation_id)
        return "queued"

    return None


def init_queue(force: bool = False):
    """Initialize the generation queue and start the worker.

    Must be called once during application startup (inside a running event loop).
    """
    global _generation_queue, _generation_worker_task
    global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids

    if _generation_worker_task is not None and not _generation_worker_task.done():
        if not force:
            return
        _generation_worker_task.cancel()
        for task in list(_running_generation_tasks.values()):
            task.cancel()

    _generation_queue = asyncio.Queue()
    _queued_generation_ids = set()
    _running_generation_tasks = {}
    _cancelled_generation_ids = set()
    _generation_worker_task = create_background_task(_generation_worker())
22
backend/services/transcribe.py
Normal file
@@ -0,0 +1,22 @@
"""
STT (Speech-to-Text) module - delegates to backend abstraction layer.
"""

from ..backends import get_stt_backend, STTBackend


def get_whisper_model() -> STTBackend:
    """
    Get STT backend instance (MLX or PyTorch based on platform).

    Returns:
        STT backend instance
    """
    return get_stt_backend()


def unload_whisper_model():
    """Unload Whisper model to free memory."""
    backend = get_stt_backend()
    backend.unload_model()
34
backend/services/tts.py
Normal file
@@ -0,0 +1,34 @@
"""
TTS inference module - delegates to backend abstraction layer.
"""

import io

import numpy as np
import soundfile as sf

from ..backends import get_tts_backend, TTSBackend


def get_tts_model() -> TTSBackend:
    """
    Get TTS backend instance (MLX or PyTorch based on platform).

    Returns:
        TTS backend instance
    """
    return get_tts_backend()


def unload_tts_model():
    """Unload TTS model to free memory."""
    backend = get_tts_backend()
    backend.unload_model()


def audio_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes:
    """Convert audio array to WAV bytes."""
    buffer = io.BytesIO()
    sf.write(buffer, audio, sample_rate, format="WAV")
    buffer.seek(0)
    return buffer.read()
211
backend/services/versions.py
Normal file
@@ -0,0 +1,211 @@
"""
Generation versions management module.

Each generation can have multiple audio versions: a clean (unprocessed)
version and any number of processed versions with different effects chains.
"""
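
# Typical lifecycle (illustrative sketch; the IDs, paths, and effects dict
# shape are made up here -- see models.EffectConfig for the real schema):
#
#   clean = create_version(gen_id, "clean", "generations/abc.wav", db, is_default=True)
#   fx = create_version(gen_id, "processed", "generations/abc_fx.wav", db,
#                       effects_chain=[{...}], source_version_id=clean.id)
#   set_default_version(fx.id, db)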

from __future__ import annotations

import json
import uuid
from typing import List, Optional

from sqlalchemy.orm import Session

from ..database import (
    GenerationVersion as DBGenerationVersion,
    Generation as DBGeneration,
)
from ..models import GenerationVersionResponse, EffectConfig
from .. import config


def _version_response(v: DBGenerationVersion) -> GenerationVersionResponse:
    """Convert a DB version row to a Pydantic response."""
    effects_chain = None
    if v.effects_chain:
        raw = json.loads(v.effects_chain)
        effects_chain = [EffectConfig(**e) for e in raw]
    return GenerationVersionResponse(
        id=v.id,
        generation_id=v.generation_id,
        label=v.label,
        audio_path=v.audio_path,
        effects_chain=effects_chain,
        source_version_id=v.source_version_id,
        is_default=v.is_default,
        created_at=v.created_at,
    )


def list_versions(generation_id: str, db: Session) -> List[GenerationVersionResponse]:
    """List all versions for a generation."""
    versions = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .order_by(DBGenerationVersion.created_at)
        .all()
    )
    return [_version_response(v) for v in versions]


def get_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
    """Get a specific version by ID."""
    v = db.query(DBGenerationVersion).filter_by(id=version_id).first()
    if not v:
        return None
    return _version_response(v)


def get_default_version(generation_id: str, db: Session) -> Optional[GenerationVersionResponse]:
    """Get the default version for a generation."""
    v = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id, is_default=True)
        .first()
    )
    if not v:
        # Fallback: return the first version
        v = (
            db.query(DBGenerationVersion)
            .filter_by(generation_id=generation_id)
            .order_by(DBGenerationVersion.created_at)
            .first()
        )
    if not v:
        return None
    return _version_response(v)


def create_version(
    generation_id: str,
    label: str,
    audio_path: str,
    db: Session,
    effects_chain: Optional[List[dict]] = None,
    is_default: bool = False,
    source_version_id: Optional[str] = None,
) -> GenerationVersionResponse:
    """Create a new version for a generation.

    If ``is_default`` is True, all other versions for this generation
    are un-defaulted first.
    """
    if is_default:
        _clear_defaults(generation_id, db)

    version = DBGenerationVersion(
        id=str(uuid.uuid4()),
        generation_id=generation_id,
        label=label,
        audio_path=audio_path,
        effects_chain=json.dumps(effects_chain) if effects_chain else None,
        source_version_id=source_version_id,
        is_default=is_default,
    )
    db.add(version)
    db.commit()
    db.refresh(version)

    # If this version is the default, update the generation's audio_path
    if is_default:
        gen = db.query(DBGeneration).filter_by(id=generation_id).first()
        if gen:
            gen.audio_path = audio_path
            db.commit()

    return _version_response(version)


def set_default_version(version_id: str, db: Session) -> Optional[GenerationVersionResponse]:
    """Set a version as the default for its generation."""
    version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
    if not version:
        return None

    _clear_defaults(version.generation_id, db)
    version.is_default = True
    db.commit()
    db.refresh(version)

    # Update generation's audio_path to point to this version
    gen = db.query(DBGeneration).filter_by(id=version.generation_id).first()
    if gen:
        gen.audio_path = version.audio_path
        db.commit()

    return _version_response(version)


def delete_version(version_id: str, db: Session) -> bool:
    """Delete a version. Cannot delete the last remaining version."""
    version = db.query(DBGenerationVersion).filter_by(id=version_id).first()
    if not version:
        return False

    # Don't allow deleting the last version
    count = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=version.generation_id)
        .count()
    )
    if count <= 1:
        return False

    was_default = version.is_default
    gen_id = version.generation_id

    # Delete audio file
    audio_path = config.resolve_storage_path(version.audio_path)
    if audio_path is not None and audio_path.exists():
        audio_path.unlink()

    db.delete(version)
    db.commit()

    # If this was the default, promote the first remaining version
    if was_default:
        first = (
            db.query(DBGenerationVersion)
            .filter_by(generation_id=gen_id)
            .order_by(DBGenerationVersion.created_at)
            .first()
        )
        if first:
            first.is_default = True
            db.commit()
            gen = db.query(DBGeneration).filter_by(id=gen_id).first()
            if gen:
                gen.audio_path = first.audio_path
                db.commit()

    return True


def delete_versions_for_generation(generation_id: str, db: Session) -> int:
    """Delete all versions for a generation (used when deleting a generation)."""
    versions = (
        db.query(DBGenerationVersion)
        .filter_by(generation_id=generation_id)
        .all()
    )
    count = 0
    for v in versions:
        audio_path = config.resolve_storage_path(v.audio_path)
        if audio_path is not None and audio_path.exists():
            audio_path.unlink()
        db.delete(v)
        count += 1
    if count > 0:
        db.commit()
    return count


def _clear_defaults(generation_id: str, db: Session) -> None:
    """Clear the is_default flag on all versions for a generation."""
    db.query(DBGenerationVersion).filter_by(
        generation_id=generation_id, is_default=True
    ).update({"is_default": False})
    db.flush()
220
backend/tests/E2E_MODEL_TEST_DESIGN.md
Normal file
@@ -0,0 +1,220 @@
# End-to-End Model Generation Test — Design

## Goal

A single script, runnable on macOS and Windows, that exercises every TTS model against the **frozen PyInstaller binary** (not the dev server), captures per-model pass/fail and error messages, and exits non-zero if any model fails. Generation is strictly sequential — one model loaded at a time.

## Test matrix (10 runs)

Derived from `backend/backends/__init__.py:185-316`. Each row maps to one `POST /generate` call.

| # | engine | model_size | profile kind | notes |
|---|-----------------------|------------|--------------|-------|
| 1 | `qwen` | `1.7B` | cloned | reference audio required |
| 2 | `qwen` | `0.6B` | cloned | |
| 3 | `qwen_custom_voice` | `1.7B` | preset | `preset_voice_id="Ryan"` |
| 4 | `qwen_custom_voice` | `0.6B` | preset | `preset_voice_id="Ryan"` |
| 5 | `luxtts` | — | cloned | English only |
| 6 | `chatterbox` | — | cloned | |
| 7 | `chatterbox_turbo` | — | cloned | English only |
| 8 | `tada` | `1B` | cloned | tada-1b, English only |
| 9 | `tada` | `3B` | cloned | tada-3b-ml, multilingual |
| 10 | `kokoro` | — | preset | `preset_voice_id="af_heart"` |

Cloned engines (1, 2, 5, 6, 7, 8, 9) share **one** profile created once with the reference WAV. Preset profiles are created separately, one for kokoro and one for qwen_custom_voice.

Language for every run: `en` (covers every engine's supported set).

## End-to-end flow

```
1. Resolve paths     → find binary, build if missing
2. Launch binary     → spawn with --port --data-dir --parent-pid
3. Wait for /health  → poll until status=="healthy" or 120s timeout
4. Create profiles   → 1 cloned + 2 preset, via /profiles (+ /samples)
5. For each (engine, model_size) in matrix:
   a. Check cache    → GET /models/status → cached? short timeout : long
   b. POST /generate → get generation_id
   c. Stream /status → consume SSE until completed/failed/timeout
   d. Record result  → {engine, model_size, status, duration, error, elapsed}
6. Write results     → JSON + Markdown table to ./results/
7. Shutdown binary   → SIGTERM, fall back to kill, verify port freed
8. Exit code         → 0 if all passed, 1 otherwise
```

## Binary resolution

Search order — **first hit wins**:

| Platform | Path | Build type |
|----------|------|------------|
| macOS | `backend/dist/voicebox-server-cuda/voicebox-server-cuda` | onedir (CUDA, rarely on Mac) |
| macOS | `backend/dist/voicebox-server` | onefile (CPU) |
| Windows | `backend\dist\voicebox-server-cuda\voicebox-server-cuda.exe` | onedir (CUDA) |
| Windows | `backend\dist\voicebox-server.exe` | onefile (CPU) |

If none exist, run `python backend/build_binary.py` and wait for it to finish (can take 5-20 min). Fail with a clear error if the build itself fails. The `--skip-build` flag forces "error out if no binary" instead of building.
## Spawn command

Mirrors Tauri's launch in `tauri/src-tauri/src/main.rs:369-388`:

```
<binary> --host 127.0.0.1 --port <free-port> --data-dir <tempdir> --parent-pid <test-pid>
```

- **Port**: bind to `0` first in Python to grab a free port, then pass that number (see the sketch after this list).
- **Data dir**: `tempfile.mkdtemp(prefix="voicebox-e2e-")`. Deleted after the run unless `--keep-data-dir`. Profiles and generated WAVs land here.
- **Parent PID**: current Python PID — ensures the backend dies if the test crashes (watchdog in `server.py:102-224`).
- **stdout/stderr**: tee to both a log file in `./results/server-<timestamp>.log` and a rolling in-memory buffer. On model failure, last 100 lines of the buffer are attached to that model's error record.
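
A minimal sketch of the free-port grab (this mirrors `pick_free_port()` in the script; there is a small race where another process could claim the port between `close()` and the spawn, which is acceptable for a local test harness):

```python
import socket

def pick_free_port() -> int:
    s = socket.socket()
    s.bind(("127.0.0.1", 0))  # port 0 = let the OS pick a free port
    port = s.getsockname()[1]
    s.close()
    return port
```
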
## Profile setup

One cloned profile shared across all cloning engines:

```http
POST /profiles
{
  "name": "e2e-cloned",
  "voice_type": "cloned",
  "language": "en"
}
```

Then:

```http
POST /profiles/{id}/samples (multipart)
  file: <reference WAV>
  reference_text: <exact transcription>
```

Two preset profiles:

```http
POST /profiles
{ "name": "e2e-kokoro", "voice_type": "preset", "language": "en",
  "preset_engine": "kokoro", "preset_voice_id": "af_heart" }

POST /profiles
{ "name": "e2e-qwen-cv", "voice_type": "preset", "language": "en",
  "preset_engine": "qwen_custom_voice", "preset_voice_id": "Ryan" }
```

## Generation request (per matrix row)

```http
POST /generate
{
  "profile_id": "<appropriate profile>",
  "text": "The quick brown fox jumps over the lazy dog.",
  "language": "en",
  "engine": "<engine>",
  "model_size": "<size or omitted>",
  "seed": 42,
  "normalize": true
}
```

Response `id` feeds into the SSE status loop (`GET /generate/{id}/status`, `routes/generations.py:190-227`). The loop reads lines until a payload with `status in ("completed", "failed")` arrives, then breaks.

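A condensed sketch of that loop (the real implementation is `run_one_generation()` in `test_all_models_e2e.py`; `client` is an `httpx.Client`):

```python
import json

with client.stream("GET", f"{base_url}/generate/{gen_id}/status") as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue  # skip blank keep-alive lines and comments
        payload = json.loads(line[6:])
        if payload.get("status") in ("completed", "failed"):
            break
```
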
## Timeout strategy (split)

Check `GET /models/status` for the target model **before** generation:

| Cached? | Per-model timeout | Rationale |
|---------|-------------------|-----------|
| Yes | **3 minutes** | Inference only; generous for CPU builds |
| No | **20 minutes** | First-run HF download up to 8 GB (tada-3b-ml) |

On timeout: cancel the SSE stream, mark the row `timeout`, and continue to the next row. Don't abort the whole run on one timeout.
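
In code this is a two-line decision (as implemented in the script's matrix loop; `get_model_cached()` returns `None` when `/models/status` cannot be read, which falls through to the long download timeout):

```python
was_cached = get_model_cached(client, base_url, row.model_name)
timeout_s = args.timeout_cached if was_cached else args.timeout_download
```
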
## Result format

`./results/e2e-<platform>-<arch>-<timestamp>.json`:

```json
{
  "platform": "darwin-arm64",
  "binary": "/abs/path/voicebox-server",
  "binary_size_mb": 612,
  "started_at": "2026-04-16T12:34:56Z",
  "finished_at": "...",
  "results": [
    {
      "engine": "qwen",
      "model_size": "1.7B",
      "status": "passed|failed|timeout",
      "generation_id": "...",
      "was_cached": true,
      "elapsed_seconds": 12.4,
      "audio_duration": 3.1,
      "audio_path": "/tmp/.../gen.wav",
      "error": null,
      "server_log_tail": null
    }
  ]
}
```

Companion `./results/e2e-<...>.md`:

```
# Voicebox E2E — darwin-arm64 — 2026-04-16 12:34

| Engine | Size | Status | Elapsed | Error |
|---------------------|------|--------|---------|-------|
| qwen | 1.7B | PASS | 12.4s | |
| qwen | 0.6B | FAIL | 4.1s | CUDA OOM: ... |
...
```

## CLI flags

```
python -m backend.tests.test_all_models_e2e [flags]

--binary PATH          Use this binary instead of auto-detecting
--skip-build           Error if no binary found (no auto-build)
--reference-wav PATH   Reference audio (default: backend/tests/fixtures/reference_voice.wav)
--reference-text STR   Transcription (default: read from fixtures/reference_voice.txt)
--only ENGINE[,...]    Run only these engines (e.g. kokoro,qwen)
--skip ENGINE[,...]    Skip these engines
--keep-data-dir        Don't delete tempdir after run
--timeout-cached SEC   Override 180
--timeout-download SEC Override 1200
--port N               Override auto-picked port
--output-dir PATH      Default: backend/tests/results/
```

## File layout

```
backend/tests/
├── E2E_MODEL_TEST_DESIGN.md   (this file)
├── test_all_models_e2e.py     (main script, ~400-500 LoC)
├── fixtures/
│   ├── reference_voice.wav    (user-provided, ~5-15s clean speech)
│   └── reference_voice.txt    (exact transcription)
└── results/                   (gitignored)
    ├── e2e-darwin-arm64-<ts>.json
    ├── e2e-darwin-arm64-<ts>.md
    └── server-<ts>.log
```

The script needs only the stdlib plus `httpx` (already in `backend/requirements.txt`); SSE status lines are parsed by hand, so there is no `sseclient-py` dependency. No pytest, to keep it invocable as a single command on fresh checkouts.

## Safety & cleanup

- Always kill the spawned binary in a `try/finally`. On Windows, `taskkill /F /T` the whole tree (Tauri does the same).
- Verify the port is free on shutdown (Tauri port-reuse check in `main.rs:114-186` could otherwise pick up a ghost).
- Don't touch the user's HF cache by default — let the server use `HF_HUB_CACHE` / `VOICEBOX_MODELS_DIR`. Passing `--isolated-cache` would point both env vars at the tempdir for a true cold-start run (opt-in only; it would re-download every model on each run).

## Non-goals

- Not validating audio quality (no WER, no waveform comparison). Pass = "endpoint returned `completed` and produced a non-empty WAV".
- Not testing STT (Whisper), effects chains, channels, or streaming endpoints.
- Not running on CI today — human-invoked on dev machines. CI integration is a follow-up once the script is stable.
- No model unload between runs — models stay loaded; the server manages its own eviction.
- No version-drift check on the binary.
- No `instruct` parameter exercised on qwen_custom_voice runs.
58
backend/tests/README.md
Normal file
@@ -0,0 +1,58 @@
# Backend Tests

Manual test scripts for debugging and validating backend functionality.

## Test Files

### `test_generation_progress.py`
Tests TTS generation with SSE progress monitoring to identify UX issues where users see download progress even when the model is already cached.

**Usage:**
```bash
cd backend
python tests/test_generation_progress.py
```

**Prerequisites:**
- Server must be running (`python main.py`)
- At least one voice profile must exist

### `test_real_download.py`
Tests real model download with SSE progress monitoring.

**Usage:**
```bash
cd backend
# Delete cache first to force fresh download
rm -rf ~/.cache/huggingface/hub/models--openai--whisper-base
python tests/test_real_download.py
```

**Prerequisites:**
- Server must be running (`python main.py`)

### `test_progress.py`
Unit tests for ProgressManager and HFProgressTracker functionality.

**Usage:**
```bash
cd backend
python tests/test_progress.py
```

### `test_check_progress_state.py`
Debugging script to inspect the internal state of ProgressManager and TaskManager.

**Usage:**
```bash
cd backend
python tests/test_check_progress_state.py
```

## Notes

These are manual test scripts, not automated unit tests. They're designed for:
- Debugging progress tracking issues
- Validating SSE event streams
- Monitoring real-time download behavior
- Inspecting internal state during development
6
backend/tests/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
Test suite for Voicebox backend.

This directory contains manual test scripts for debugging and validating
progress tracking, model downloads, and generation functionality.
"""
16
backend/tests/fixtures/README.md
vendored
Normal file
@@ -0,0 +1,16 @@
# E2E Test Fixtures

Place two files here before running `test_all_models_e2e.py`:

- `reference_voice.wav` — a clean speech sample, mono, 16–24 kHz, ~5–15 seconds.
- `reference_voice.txt` — the **exact** transcription of the WAV (single line, no trailing newline required).

These are used to create a cloned voice profile for every cloning-capable engine (qwen, luxtts, chatterbox, chatterbox_turbo, tada). Keep them out of version control if they contain personal audio — this directory is not gitignored by default, so add them to `.gitignore` locally if needed.

You can point the test at different files with:

```
python backend/tests/test_all_models_e2e.py \
    --reference-wav /path/to/your.wav \
    --reference-text "exact transcription here"
```
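
If you need to convert an existing recording into a compliant sample, one option (assuming `ffmpeg` is installed; the input filename is just an example) is:

```
ffmpeg -i input.m4a -ac 1 -ar 24000 -t 15 reference_voice.wav
```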
630
backend/tests/test_all_models_e2e.py
Normal file
@@ -0,0 +1,630 @@
"""
End-to-end model generation test.

Exercises every TTS model against the frozen PyInstaller binary, captures
per-model pass/fail, and writes a JSON + Markdown report.

Usage:
    python backend/tests/test_all_models_e2e.py [flags]

See E2E_MODEL_TEST_DESIGN.md for the full design.
"""

from __future__ import annotations

import argparse
import json
import os
import platform
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import threading
import time
from collections import deque
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import httpx


REPO_ROOT = Path(__file__).resolve().parents[2]
BACKEND_DIR = REPO_ROOT / "backend"
DIST_DIR = BACKEND_DIR / "dist"
FIXTURES_DIR = Path(__file__).resolve().parent / "fixtures"
RESULTS_DIR = Path(__file__).resolve().parent / "results"


# ── Test matrix ──────────────────────────────────────────────────────

@dataclass(frozen=True)
class MatrixRow:
    label: str                 # human-readable (appears in report)
    engine: str                # /generate engine
    model_size: Optional[str]  # /generate model_size (None = omit)
    profile_kind: str          # "cloned" | "preset_kokoro" | "preset_qwen_cv"
    model_name: str            # /models/status key for cache lookup


MATRIX: list[MatrixRow] = [
    MatrixRow("qwen 1.7B", "qwen", "1.7B", "cloned", "qwen-tts-1.7B"),
    MatrixRow("qwen 0.6B", "qwen", "0.6B", "cloned", "qwen-tts-0.6B"),
    MatrixRow("qwen_custom_voice 1.7B", "qwen_custom_voice", "1.7B", "preset_qwen_cv", "qwen-custom-voice-1.7B"),
    MatrixRow("qwen_custom_voice 0.6B", "qwen_custom_voice", "0.6B", "preset_qwen_cv", "qwen-custom-voice-0.6B"),
    MatrixRow("luxtts", "luxtts", None, "cloned", "luxtts"),
    MatrixRow("chatterbox", "chatterbox", None, "cloned", "chatterbox-tts"),
    MatrixRow("chatterbox_turbo", "chatterbox_turbo", None, "cloned", "chatterbox-turbo"),
    MatrixRow("tada 1B", "tada", "1B", "cloned", "tada-1b"),
    MatrixRow("tada 3B", "tada", "3B", "cloned", "tada-3b-ml"),
    MatrixRow("kokoro", "kokoro", None, "preset_kokoro", "kokoro"),
]

TEXT = "The quick brown fox jumps over the lazy dog."
DEFAULT_TIMEOUT_CACHED = 180
DEFAULT_TIMEOUT_DOWNLOAD = 1200
HEALTH_TIMEOUT = 120


# ── Result record ────────────────────────────────────────────────────

@dataclass
class ModelResult:
    label: str
    engine: str
    model_size: Optional[str]
    status: str  # "passed" | "failed" | "timeout"
    was_cached: Optional[bool] = None
    generation_id: Optional[str] = None
    elapsed_seconds: float = 0.0
    audio_duration: Optional[float] = None
    audio_path: Optional[str] = None
    audio_bytes: Optional[int] = None
    error: Optional[str] = None
    http_status: Optional[int] = None
    server_log_tail: Optional[list[str]] = None


# ── Binary resolution ────────────────────────────────────────────────

def find_binary() -> Optional[Path]:
    """Return the first existing binary in priority order, or None."""
    is_win = platform.system() == "Windows"
    exe = ".exe" if is_win else ""
    candidates = [
        DIST_DIR / "voicebox-server-cuda" / f"voicebox-server-cuda{exe}",
        DIST_DIR / f"voicebox-server{exe}",
    ]
    for c in candidates:
        if c.exists() and c.is_file():
            return c
    return None


def build_binary() -> Path:
    """Invoke build_binary.py and return the resulting binary path."""
    print("[build] No frozen binary found — invoking build_binary.py (this may take 5-20 minutes)...", flush=True)
    script = BACKEND_DIR / "build_binary.py"
    result = subprocess.run(
        [sys.executable, str(script)],
        cwd=str(BACKEND_DIR),
    )
    if result.returncode != 0:
        raise RuntimeError(f"build_binary.py exited with code {result.returncode}")
    found = find_binary()
    if found is None:
        raise RuntimeError("build_binary.py finished but no binary was found in backend/dist/")
    return found


# ── Server spawn + log capture ───────────────────────────────────────

class ServerProcess:
    def __init__(self, binary: Path, port: int, data_dir: Path, log_path: Path):
        self.binary = binary
        self.port = port
        self.data_dir = data_dir
        self.log_path = log_path
        self.proc: Optional[subprocess.Popen] = None
        self._log_buffer: deque[str] = deque(maxlen=500)
        self._reader_thread: Optional[threading.Thread] = None
        self._log_fh = None  # opened in start()

    def start(self) -> None:
        args = [
            str(self.binary),
            "--host", "127.0.0.1",
            "--port", str(self.port),
            "--data-dir", str(self.data_dir),
            "--parent-pid", str(os.getpid()),
        ]
        print(f"[spawn] {' '.join(args)}", flush=True)
        self._log_fh = open(self.log_path, "w", encoding="utf-8", errors="replace")
        # Combine stderr into stdout so we get a single ordered stream.
        self.proc = subprocess.Popen(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            bufsize=1,
            text=True,
            errors="replace",
        )
        self._reader_thread = threading.Thread(target=self._pump_logs, daemon=True)
        self._reader_thread.start()

    def _pump_logs(self) -> None:
        assert self.proc is not None and self.proc.stdout is not None
        for line in self.proc.stdout:
            self._log_buffer.append(line.rstrip("\n"))
            self._log_fh.write(line)
            self._log_fh.flush()

    def log_tail(self, n: int = 100) -> list[str]:
        tail = list(self._log_buffer)[-n:]
        return tail

    def is_alive(self) -> bool:
        return self.proc is not None and self.proc.poll() is None

    def stop(self) -> None:
        if self.proc is None:
            return
        if self.proc.poll() is not None:
            return
        try:
            if platform.system() == "Windows":
                subprocess.run(
                    ["taskkill", "/F", "/T", "/PID", str(self.proc.pid)],
                    capture_output=True,
                )
            else:
                self.proc.send_signal(signal.SIGTERM)
        except Exception as e:
            print(f"[shutdown] signal failed: {e}", flush=True)
        try:
            self.proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            print("[shutdown] server didn't exit cleanly, killing", flush=True)
            self.proc.kill()
            try:
                self.proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                pass
        if self._reader_thread is not None:
            self._reader_thread.join(timeout=2)
        try:
            self._log_fh.close()
        except Exception:
            pass


def pick_free_port() -> int:
    s = socket.socket()
    s.bind(("127.0.0.1", 0))
    port = s.getsockname()[1]
    s.close()
    return port


# ── HTTP helpers ─────────────────────────────────────────────────────

def wait_for_health(base_url: str, server: ServerProcess, timeout: int) -> None:
    deadline = time.time() + timeout
    with httpx.Client(timeout=5.0) as client:
        while time.time() < deadline:
            if not server.is_alive():
                raise RuntimeError("Server process exited before becoming healthy")
            try:
                r = client.get(f"{base_url}/health")
                if r.status_code == 200 and r.json().get("status") == "healthy":
                    return
            except httpx.HTTPError:
                pass
            time.sleep(1.0)
    raise TimeoutError(f"Server did not become healthy within {timeout}s")


def get_model_cached(client: httpx.Client, base_url: str, model_name: str) -> Optional[bool]:
    try:
        r = client.get(f"{base_url}/models/status", timeout=30.0)
        r.raise_for_status()
        for m in r.json().get("models", []):
            if m.get("model_name") == model_name:
                return bool(m.get("downloaded"))
    except httpx.HTTPError:
        return None
    return None


def create_cloned_profile(client: httpx.Client, base_url: str, wav_path: Path, reference_text: str) -> str:
    r = client.post(f"{base_url}/profiles", json={
        "name": "e2e-cloned",
        "voice_type": "cloned",
        "language": "en",
    })
    r.raise_for_status()
    profile_id = r.json()["id"]

    with open(wav_path, "rb") as f:
        r = client.post(
            f"{base_url}/profiles/{profile_id}/samples",
            files={"file": (wav_path.name, f, "audio/wav")},
            data={"reference_text": reference_text},
            timeout=120.0,
        )
    r.raise_for_status()
    return profile_id


def create_preset_profile(client: httpx.Client, base_url: str, name: str, engine: str, voice_id: str) -> str:
    r = client.post(f"{base_url}/profiles", json={
        "name": name,
        "voice_type": "preset",
        "language": "en",
        "preset_engine": engine,
        "preset_voice_id": voice_id,
    })
    r.raise_for_status()
    return r.json()["id"]


def run_one_generation(
    client: httpx.Client,
    base_url: str,
    row: MatrixRow,
    profile_id: str,
    timeout_s: int,
) -> tuple[str, dict]:
    """Start a generation and stream its status until done/failed/timeout.

    Returns (status, payload) where status is "completed" | "failed" | "timeout".
    """
    body = {
        "profile_id": profile_id,
        "text": TEXT,
        "language": "en",
        "engine": row.engine,
        "seed": 42,
        "normalize": True,
    }
    if row.model_size is not None:
        body["model_size"] = row.model_size

    r = client.post(f"{base_url}/generate", json=body, timeout=30.0)
    r.raise_for_status()
    gen = r.json()
    gen_id = gen["id"]

    deadline = time.time() + timeout_s
    last_payload: dict = gen
    status_url = f"{base_url}/generate/{gen_id}/status"

    while time.time() < deadline:
        remaining = max(1.0, deadline - time.time())
        try:
            with client.stream("GET", status_url, timeout=httpx.Timeout(remaining + 5, read=remaining + 5)) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines():
                    if not line or not line.startswith("data: "):
                        continue
                    try:
                        payload = json.loads(line[6:])
                    except json.JSONDecodeError:
                        continue
                    last_payload = payload
                    status = payload.get("status")
                    if status == "not_found":
                        return "failed", {"error": "generation not found", **payload}
                    if status in ("completed", "failed"):
                        return status, payload
                    if time.time() >= deadline:
                        break
        except httpx.HTTPError:
            time.sleep(1.0)
            continue

    return "timeout", last_payload


def fetch_audio_info(
    client: httpx.Client, base_url: str, generation_id: str, data_dir: Path
) -> tuple[Optional[str], Optional[int]]:
    """Return (audio_path, audio_bytes) for a completed generation.

    Server stores audio_path relative to data_dir; resolve it to get a size.
    """
    try:
        r = client.get(f"{base_url}/history/{generation_id}", timeout=10.0)
        if r.status_code != 200:
            return None, None
        data = r.json()
        audio_path = data.get("audio_path")
        if not audio_path:
            return None, None
        p = Path(audio_path)
        if not p.is_absolute():
            p = data_dir / p
        if p.exists():
            return str(p), p.stat().st_size
        return audio_path, None
    except httpx.HTTPError:
        return None, None


# ── Report writers ───────────────────────────────────────────────────

def write_reports(
    output_dir: Path,
    binary: Path,
    started_at: datetime,
    finished_at: datetime,
    results: list[ModelResult],
) -> tuple[Path, Path]:
    output_dir.mkdir(parents=True, exist_ok=True)
    plat = f"{platform.system().lower()}-{platform.machine().lower()}"
    ts = started_at.strftime("%Y%m%d-%H%M%S")
    json_path = output_dir / f"e2e-{plat}-{ts}.json"
    md_path = output_dir / f"e2e-{plat}-{ts}.md"

    doc = {
        "platform": plat,
        "binary": str(binary),
        "binary_size_mb": round(binary.stat().st_size / (1024 * 1024), 1) if binary.exists() else None,
        "started_at": started_at.isoformat(),
        "finished_at": finished_at.isoformat(),
        "elapsed_seconds": (finished_at - started_at).total_seconds(),
        "results": [asdict(r) for r in results],
    }
    json_path.write_text(json.dumps(doc, indent=2))

    lines = [
        f"# Voicebox E2E — {plat} — {started_at.strftime('%Y-%m-%d %H:%M UTC')}",
        "",
        f"Binary: `{binary}`  ",
        f"Elapsed: {doc['elapsed_seconds']:.1f}s",
        "",
        "| Model | Status | Cached | Elapsed | Audio | Error |",
        "|-------|--------|--------|---------|-------|-------|",
    ]
    for r in results:
        status_icon = {"passed": "PASS", "failed": "FAIL", "timeout": "TIMEOUT"}.get(r.status, r.status.upper())
        cached = "yes" if r.was_cached else ("no" if r.was_cached is False else "?")
        audio_col = f"{r.audio_duration:.2f}s" if r.audio_duration else ("—" if r.status != "passed" else "?")
        error_col = (r.error or "").replace("\n", " ")[:120]
        lines.append(f"| {r.label} | {status_icon} | {cached} | {r.elapsed_seconds:.1f}s | {audio_col} | {error_col} |")

    failed_rows = [r for r in results if r.status != "passed"]
    if failed_rows:
        lines.append("")
        lines.append("## Failures")
        for r in failed_rows:
            lines.append("")
            lines.append(f"### {r.label} — {r.status}")
            if r.error:
                lines.append("")
                lines.append("```")
                lines.append(r.error)
                lines.append("```")
            if r.server_log_tail:
                lines.append("")
                lines.append("<details><summary>server log (last lines)</summary>")
                lines.append("")
                lines.append("```")
                lines.extend(r.server_log_tail)
                lines.append("```")
                lines.append("</details>")

    md_path.write_text("\n".join(lines) + "\n")
    return json_path, md_path


# ── Main ─────────────────────────────────────────────────────────────

def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Voicebox E2E model generation test")
    p.add_argument("--binary", type=Path, help="Path to voicebox-server binary (overrides auto-detect)")
    p.add_argument("--skip-build", action="store_true", help="Error if binary missing instead of building")
    p.add_argument(
        "--reference-wav",
        type=Path,
        default=FIXTURES_DIR / "reference_voice.wav",
        help="Reference audio for cloning engines",
    )
    p.add_argument(
        "--reference-text",
        help="Transcription of reference-wav (default: read from fixtures/reference_voice.txt)",
    )
    p.add_argument("--only", help="Comma-separated engines to run (e.g. kokoro,qwen)")
    p.add_argument("--skip", help="Comma-separated engines to skip")
    p.add_argument("--keep-data-dir", action="store_true", help="Don't delete tempdir after run")
    p.add_argument("--timeout-cached", type=int, default=DEFAULT_TIMEOUT_CACHED)
    p.add_argument("--timeout-download", type=int, default=DEFAULT_TIMEOUT_DOWNLOAD)
    p.add_argument("--port", type=int, help="Override auto-picked port")
    p.add_argument("--output-dir", type=Path, default=RESULTS_DIR)
    return p.parse_args()


def filter_matrix(args: argparse.Namespace) -> list[MatrixRow]:
    only = set(x.strip() for x in args.only.split(",")) if args.only else None
    skip = set(x.strip() for x in args.skip.split(",")) if args.skip else set()
    rows = []
    for r in MATRIX:
        if only is not None and r.engine not in only:
            continue
        if r.engine in skip:
            continue
        rows.append(r)
    return rows


def resolve_reference(args: argparse.Namespace) -> tuple[Path, str]:
    wav = args.reference_wav
    if not wav.exists():
        raise FileNotFoundError(
            f"Reference WAV not found: {wav}\n"
            f"Place a sample at {FIXTURES_DIR / 'reference_voice.wav'} or pass --reference-wav.\n"
            f"See backend/tests/fixtures/README.md."
        )
    if args.reference_text:
        text = args.reference_text
    else:
        txt_path = wav.with_suffix(".txt")
        if not txt_path.exists():
            raise FileNotFoundError(
                f"Reference transcription not found: {txt_path}\n"
                f"Create it next to the WAV, or pass --reference-text."
            )
        text = txt_path.read_text().strip()
    if not text:
        raise ValueError("Reference transcription is empty")
    return wav, text


def main() -> int:
    args = parse_args()
    rows = filter_matrix(args)
    if not rows:
        print("No rows selected after --only/--skip filtering", file=sys.stderr)
        return 2

    # Binary
    binary = args.binary or find_binary()
    if binary is None:
        if args.skip_build:
            print("No frozen binary found and --skip-build set. Run: python backend/build_binary.py", file=sys.stderr)
            return 2
        binary = build_binary()
    if not binary.exists():
        print(f"Binary path does not exist: {binary}", file=sys.stderr)
        return 2
    print(f"[binary] {binary}", flush=True)

    # Reference audio (only required if any cloning row is in the matrix)
    needs_reference = any(r.profile_kind == "cloned" for r in rows)
    ref_wav: Optional[Path] = None
    ref_text: Optional[str] = None
    if needs_reference:
        try:
            ref_wav, ref_text = resolve_reference(args)
        except (FileNotFoundError, ValueError) as e:
            print(f"[fixture] {e}", file=sys.stderr)
            return 2
        print(f"[fixture] reference WAV: {ref_wav}", flush=True)
        print(f"[fixture] reference text: {ref_text!r}", flush=True)

    # Tempdir + log path
    data_dir = Path(tempfile.mkdtemp(prefix="voicebox-e2e-"))
    args.output_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    log_path = args.output_dir / f"server-{ts}.log"

    port = args.port or pick_free_port()
    base_url = f"http://127.0.0.1:{port}"

    server = ServerProcess(binary=binary, port=port, data_dir=data_dir, log_path=log_path)
    started_at = datetime.now(timezone.utc)
    results: list[ModelResult] = []

    try:
        server.start()
        print(f"[health] waiting for {base_url}/health ...", flush=True)
        wait_for_health(base_url, server, HEALTH_TIMEOUT)
        print("[health] ready", flush=True)

        with httpx.Client(timeout=30.0) as client:
            # Profile setup (only create what's needed)
            cloned_profile_id: Optional[str] = None
            kokoro_profile_id: Optional[str] = None
            qwen_cv_profile_id: Optional[str] = None
            needed_kinds = {r.profile_kind for r in rows}
            if "cloned" in needed_kinds:
                assert ref_wav is not None and ref_text is not None
                print("[profile] creating cloned profile...", flush=True)
                cloned_profile_id = create_cloned_profile(client, base_url, ref_wav, ref_text)
            if "preset_kokoro" in needed_kinds:
                print("[profile] creating kokoro preset...", flush=True)
                kokoro_profile_id = create_preset_profile(client, base_url, "e2e-kokoro", "kokoro", "af_heart")
            if "preset_qwen_cv" in needed_kinds:
                print("[profile] creating qwen_custom_voice preset...", flush=True)
                qwen_cv_profile_id = create_preset_profile(client, base_url, "e2e-qwen-cv", "qwen_custom_voice", "Ryan")

            profile_lookup = {
                "cloned": cloned_profile_id,
                "preset_kokoro": kokoro_profile_id,
                "preset_qwen_cv": qwen_cv_profile_id,
            }

            # Matrix loop
            for row in rows:
                print(f"\n[run] {row.label} (engine={row.engine}, size={row.model_size})", flush=True)
                profile_id = profile_lookup[row.profile_kind]
                assert profile_id is not None
                was_cached = get_model_cached(client, base_url, row.model_name)
                timeout_s = args.timeout_cached if was_cached else args.timeout_download
                print(f"[run] cached={was_cached} timeout={timeout_s}s", flush=True)

                t0 = time.time()
                result = ModelResult(
                    label=row.label,
                    engine=row.engine,
                    model_size=row.model_size,
                    status="failed",
                    was_cached=was_cached,
                )
                try:
                    status, payload = run_one_generation(client, base_url, row, profile_id, timeout_s)
                    result.status = "passed" if status == "completed" else status
                    result.generation_id = payload.get("id")
                    result.audio_duration = payload.get("duration")
                    result.error = payload.get("error")
                    if status == "completed" and result.generation_id:
                        audio_path, audio_bytes = fetch_audio_info(
                            client, base_url, result.generation_id, data_dir
                        )
                        result.audio_path = audio_path
                        result.audio_bytes = audio_bytes
                        if audio_bytes is not None and audio_bytes == 0:
                            result.status = "failed"
                            result.error = (result.error or "") + " (audio file is empty)"
                except httpx.HTTPStatusError as e:
                    result.status = "failed"
                    result.http_status = e.response.status_code
                    try:
                        detail = e.response.json().get("detail")
                    except Exception:
                        detail = e.response.text
                    result.error = f"HTTP {e.response.status_code}: {detail}"
                except Exception as e:
                    result.status = "failed"
                    result.error = f"{type(e).__name__}: {e}"

                result.elapsed_seconds = round(time.time() - t0, 2)
                if result.status != "passed":
                    result.server_log_tail = server.log_tail(100)
                print(f"[run] {row.label} → {result.status} in {result.elapsed_seconds}s"
                      + (f" ({result.error})" if result.error else ""), flush=True)
                results.append(result)
    finally:
        finished_at = datetime.now(timezone.utc)
        server.stop()
        if not args.keep_data_dir:
            shutil.rmtree(data_dir, ignore_errors=True)
        else:
            print(f"[cleanup] keeping data dir: {data_dir}", flush=True)

    json_path, md_path = write_reports(args.output_dir, binary, started_at, finished_at, results)
    print(f"\n[report] {json_path}")
    print(f"[report] {md_path}")
    print(f"[report] server log: {log_path}")

    passed = sum(1 for r in results if r.status == "passed")
    failed = len(results) - passed
    print(f"\n== {passed} passed, {failed} failed ==")
    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
112
backend/tests/test_audio_preprocess.py
Normal file
@@ -0,0 +1,112 @@
"""
Unit tests for reference-audio preprocessing.

Covers :func:`backend.utils.audio.preprocess_reference_audio` and
:func:`backend.utils.audio.validate_and_load_reference_audio`.
"""

import sys
from pathlib import Path

import numpy as np
import pytest
import soundfile as sf

sys.path.insert(0, str(Path(__file__).parent.parent))

from utils.audio import (  # noqa: E402
    preprocess_reference_audio,
    validate_and_load_reference_audio,
)


SR = 24000


def _tone(duration_s: float, amp: float = 0.3, freq: float = 220.0) -> np.ndarray:
    n = int(duration_s * SR)
    t = np.arange(n, dtype=np.float32) / SR
    return (amp * np.sin(2 * np.pi * freq * t)).astype(np.float32)


def test_peak_cap_scales_hot_input():
    audio = _tone(3.0, amp=0.99)
    out = preprocess_reference_audio(audio, SR)
    assert np.abs(out).max() <= 0.951


def test_peak_cap_leaves_moderate_input_untouched():
    audio = _tone(3.0, amp=0.5)
    out = preprocess_reference_audio(audio, SR)
    assert np.isclose(np.abs(out).max(), 0.5, atol=1e-3)


def test_dc_offset_removed():
    audio = _tone(3.0, amp=0.3) + 0.1
    out = preprocess_reference_audio(audio, SR)
    assert abs(float(np.mean(out))) < 1e-3


def test_silence_is_trimmed_with_padding_kept():
    silence = np.zeros(int(SR * 1.0), dtype=np.float32)
    speech = _tone(3.0, amp=0.3)
    audio = np.concatenate([silence, speech, silence])
    out = preprocess_reference_audio(audio, SR)
    # Most of the 2s of leading/trailing silence should be gone, but the
    # 3s of speech plus ~200ms of padding should remain.
    assert len(audio) - len(out) >= SR, "expected >=1s of silence trimmed"
    assert len(out) >= int(3.0 * SR), "speech body should be preserved"


def test_clean_audio_is_not_padded_past_original_length():
    # Well-recorded audio with no edge silence shouldn't get longer after
    # preprocessing — otherwise a 29.9 s upload could be pushed past the
    # 30 s max_duration ceiling downstream.
    audio = _tone(3.0, amp=0.3)
    out = preprocess_reference_audio(audio, SR)
    assert len(out) <= len(audio)


def test_empty_input_returns_empty():
    out = preprocess_reference_audio(np.zeros(0, dtype=np.float32), SR)
    assert out.size == 0


def test_validate_accepts_previously_rejected_hot_file(tmp_path):
    audio = _tone(3.0, amp=0.995)
    path = tmp_path / "hot.wav"
    sf.write(str(path), audio, SR)

    ok, err, out_audio, out_sr = validate_and_load_reference_audio(str(path))

    assert ok, f"expected pass, got error: {err}"
    assert out_audio is not None
    assert out_sr == SR
    assert np.abs(out_audio).max() <= 0.951


def test_validate_still_rejects_silent_input(tmp_path):
    audio = np.zeros(int(SR * 3.0), dtype=np.float32)
    path = tmp_path / "silent.wav"
    sf.write(str(path), audio, SR)

    ok, err, _, _ = validate_and_load_reference_audio(str(path))

    assert not ok
    assert err is not None
    assert "too short" in err.lower() or "quiet" in err.lower()


def test_validate_rejects_too_short(tmp_path):
    audio = _tone(0.5, amp=0.3)
    path = tmp_path / "short.wav"
    sf.write(str(path), audio, SR)

    ok, err, _, _ = validate_and_load_reference_audio(str(path))

    assert not ok
    assert "too short" in (err or "").lower()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
162
backend/tests/test_cors.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Tests for CORS origin restrictions.
|
||||
|
||||
Validates that the CORS middleware only allows known local origins
|
||||
and respects the VOICEBOX_CORS_ORIGINS environment variable.
|
||||
|
||||
Uses a minimal FastAPI app that mirrors the exact CORS configuration
|
||||
from backend/main.py, so tests run without heavy ML dependencies.
|
||||
|
||||
Usage:
|
||||
pip install httpx pytest fastapi starlette
|
||||
python -m pytest backend/tests/test_cors.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def _build_app(env_origins: str = "") -> FastAPI:
|
||||
"""
|
||||
Build a minimal FastAPI app with the same CORS logic as backend/main.py.
|
||||
|
||||
This mirrors the exact code in main.py so the test validates the real
|
||||
configuration without needing torch/numpy/transformers installed.
|
||||
"""
|
||||
app = FastAPI()
|
||||
|
||||
_default_origins = [
|
||||
"http://localhost:5173",
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:17493",
|
||||
"http://127.0.0.1:17493",
|
||||
"tauri://localhost",
|
||||
"https://tauri.localhost",
|
||||
]
|
||||
_cors_origins = _default_origins + [o.strip() for o in env_origins.split(",") if o.strip()]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
return TestClient(_build_app())
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_with_custom_origins():
|
||||
return TestClient(_build_app("https://custom.example.com,https://other.example.com"))
|
||||
|
||||
|
||||
def _get_with_origin(client: TestClient, origin: str) -> dict:
|
||||
"""Send a GET with Origin header, return response headers."""
|
||||
response = client.get("/health", headers={"Origin": origin})
|
||||
return dict(response.headers)
|
||||
|
||||
|
||||
def _preflight(client: TestClient, origin: str) -> dict:
    """Send a CORS preflight OPTIONS request and return the response headers."""
    response = client.options(
        "/health",
        headers={
            "Origin": origin,
            "Access-Control-Request-Method": "GET",
        },
    )
    return dict(response.headers)


class TestCORSDefaultOrigins:
    """CORS should allow known local origins and block everything else."""

    @pytest.mark.parametrize("origin", [
        "http://localhost:5173",
        "http://127.0.0.1:5173",
        "http://localhost:17493",
        "http://127.0.0.1:17493",
        "tauri://localhost",
        "https://tauri.localhost",
    ])
    def test_allowed_origins(self, client, origin):
        headers = _get_with_origin(client, origin)
        assert headers.get("access-control-allow-origin") == origin

    @pytest.mark.parametrize("origin", [
        "http://evil.com",
        "http://localhost:9999",
        "https://attacker.example.com",
        "null",
    ])
    def test_blocked_origins(self, client, origin):
        headers = _get_with_origin(client, origin)
        assert "access-control-allow-origin" not in headers

    def test_preflight_allowed(self, client):
        headers = _preflight(client, "http://localhost:5173")
        assert headers.get("access-control-allow-origin") == "http://localhost:5173"

    def test_preflight_blocked(self, client):
        headers = _preflight(client, "http://evil.com")
        assert "access-control-allow-origin" not in headers

    def test_credentials_header_present(self, client):
        headers = _get_with_origin(client, "http://localhost:5173")
        assert headers.get("access-control-allow-credentials") == "true"


class TestCORSCustomOrigins:
    """VOICEBOX_CORS_ORIGINS env var should extend the allowlist."""

    def test_custom_origin_allowed(self, client_with_custom_origins):
        headers = _get_with_origin(client_with_custom_origins, "https://custom.example.com")
        assert headers.get("access-control-allow-origin") == "https://custom.example.com"

    def test_other_custom_origin_allowed(self, client_with_custom_origins):
        headers = _get_with_origin(client_with_custom_origins, "https://other.example.com")
        assert headers.get("access-control-allow-origin") == "https://other.example.com"

    def test_default_origins_still_work(self, client_with_custom_origins):
        headers = _get_with_origin(client_with_custom_origins, "http://localhost:5173")
        assert headers.get("access-control-allow-origin") == "http://localhost:5173"

    def test_unlisted_origin_still_blocked(self, client_with_custom_origins):
        headers = _get_with_origin(client_with_custom_origins, "http://evil.com")
        assert "access-control-allow-origin" not in headers


class TestCORSEnvVarParsing:
    """Edge cases for VOICEBOX_CORS_ORIGINS parsing."""

    def test_empty_env_var(self):
        app = _build_app("")
        client = TestClient(app)
        headers = _get_with_origin(client, "http://evil.com")
        assert "access-control-allow-origin" not in headers

    def test_whitespace_trimmed(self):
        app = _build_app(" https://spaced.example.com ")
        client = TestClient(app)
        headers = _get_with_origin(client, "https://spaced.example.com")
        assert headers.get("access-control-allow-origin") == "https://spaced.example.com"

    def test_trailing_comma_ignored(self):
        app = _build_app("https://one.example.com,")
        client = TestClient(app)
        headers = _get_with_origin(client, "https://one.example.com")
        assert headers.get("access-control-allow-origin") == "https://one.example.com"
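

# For reference, the parsing behavior pinned down above is consistent with a
# helper along these lines. This is a sketch: `DEFAULT_ORIGINS` and the exact
# split logic are assumptions inferred from these tests, not the confirmed
# app.py implementation.
#
#     def _parse_extra_origins(raw: str) -> list[str]:
#         """Split a comma-separated origin list, trimming whitespace and blanks."""
#         return [origin.strip() for origin in raw.split(",") if origin.strip()]
#
#     allow_origins = DEFAULT_ORIGINS + _parse_extra_origins(
#         os.environ.get("VOICEBOX_CORS_ORIGINS", "")
#     )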
305
backend/tests/test_generation_download.py
Normal file
@@ -0,0 +1,305 @@
"""
Test TTS generation with SSE progress monitoring.

This test captures the exact SSE events triggered during generation
to identify UX issues where users see download progress even when
the model is already cached.
"""

import asyncio
import json
from datetime import datetime
from typing import Dict, List

import httpx


async def monitor_sse_stream(model_name: str, timeout: int = 120):
    """Monitor the SSE stream for a model during generation."""
    events: List[Dict] = []
    url = f"http://localhost:8000/models/progress/{model_name}"

    print(f"[{_timestamp()}] Connecting to SSE endpoint: {url}")

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream("GET", url) as response:
                print(f"[{_timestamp()}] SSE connected, status: {response.status_code}")

                if response.status_code != 200:
                    print(f"[{_timestamp()}] Error: SSE endpoint returned {response.status_code}")
                    return events

                async for line in response.aiter_lines():
                    if not line:
                        continue

                    timestamp = _timestamp()

                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])
                            print(
                                f"[{timestamp}] → SSE Event: {data['status']:12} "
                                f"{data.get('progress', 0):6.1f}% {data.get('filename', '')}"
                            )
                            events.append({**data, "_timestamp": timestamp})

                            # Stop if complete or error
                            if data.get("status") in ("complete", "error"):
                                print(f"[{timestamp}] → Model {data['status']}!")
                                break

                        except json.JSONDecodeError as e:
                            print(f"[{timestamp}] Error parsing JSON: {e}")
                            print(f"  Line was: {line}")

                    elif line.startswith(": heartbeat"):
                        print(f"[{timestamp}] ♥ heartbeat")

    except asyncio.TimeoutError:
        print(f"[{_timestamp()}] SSE monitoring timed out")
    except Exception as e:
        print(f"[{_timestamp()}] SSE error: {e}")

    return events


async def trigger_generation(profile_id: str, text: str, model_size: str = "1.7B"):
    """Trigger TTS generation via the API."""
    url = "http://localhost:8000/generate"

    print(f"\n[{_timestamp()}] Triggering generation...")
    print(f"  Profile: {profile_id}")
    print(f"  Text: {text[:50]}...")
    print(f"  Model: {model_size}")

    try:
        async with httpx.AsyncClient(timeout=120) as client:
            response = await client.post(
                url,
                json={
                    "profile_id": profile_id,
                    "text": text,
                    "language": "en",
                    "model_size": model_size,
                },
            )

            print(f"[{_timestamp()}] Response: {response.status_code}")

            if response.status_code == 200:
                result = response.json()
                print(f"[{_timestamp()}] ✓ Generation successful!")
                print(f"  Generation ID: {result.get('id')}")
                print(f"  Duration: {result.get('duration', 0):.2f}s")
                return True, result
            elif response.status_code == 202:
                # Model is being downloaded
                result = response.json()
                print(f"[{_timestamp()}] → Model download in progress")
                print(f"  Detail: {result}")
                return False, result
            else:
                print(f"[{_timestamp()}] ✗ Error: {response.text}")
                return False, None

    except Exception as e:
        print(f"[{_timestamp()}] ✗ Exception: {e}")
        return False, None


async def get_first_profile():
    """Get the first available voice profile."""
    url = "http://localhost:8000/profiles"

    try:
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get(url)
            if response.status_code == 200:
                profiles = response.json()
                if profiles:
                    return profiles[0]["id"]
    except Exception as e:
        print(f"Error getting profiles: {e}")

    return None


async def check_server():
    """Check whether the server is running."""
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            response = await client.get("http://localhost:8000/health")
            return response.status_code == 200
    except Exception as e:
        print(f"Server not running: {e}")
        return False


def _timestamp():
    """Return the current time formatted for log output."""
    return datetime.now().strftime("%H:%M:%S.%f")[:-3]


async def test_generation_with_cached_model():
    """
    Test Case 1: Generation when the model is already cached.

    This should NOT show any download progress events.
    If it does, that's the UX bug we're trying to fix.
    """
    print("\n" + "=" * 80)
    print("TEST CASE 1: Generation with Cached Model")
    print("=" * 80)
    print("Expected: No download progress events (or minimal/instant completion)")
    print("Actual UX Issue: Users see 'started' and 'finished' events even for cached models")
    print("=" * 80)

    model_size = "1.7B"
    model_name = f"qwen-tts-{model_size}"

    # Get a profile
    profile_id = await get_first_profile()
    if not profile_id:
        print("✗ No voice profiles found. Please create a profile first.")
        return False

    print(f"\nUsing profile: {profile_id}")

    # Start the SSE monitor BEFORE triggering generation
    monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=30))

    # Wait for SSE to connect
    await asyncio.sleep(1)

    # Trigger generation
    test_text = "Hello, this is a test of the voice generation system."
    success, result = await trigger_generation(profile_id, test_text, model_size)

    if not success and result and result.get("downloading"):
        print("\n⚠ Model is being downloaded. Waiting for download to complete...")
        # Wait for the SSE monitor to capture download events
        events = await monitor_task
        return events

    # Wait a bit more to catch any progress events
    await asyncio.sleep(3)

    # Cancel the SSE monitor
    monitor_task.cancel()
    try:
        events = await monitor_task
    except asyncio.CancelledError:
        events = []

    return events


async def test_generation_with_fresh_download():
    """
    Test Case 2: Generation when the model needs to be downloaded.

    This SHOULD show download progress events.
    """
    print("\n" + "=" * 80)
    print("TEST CASE 2: Generation with Model Download")
    print("=" * 80)
    print("Expected: Download progress events from 0% to 100%")
    print("=" * 80)

    # Use a different model size to force a download
    model_size = "0.6B"  # Smaller model for faster testing
    model_name = f"qwen-tts-{model_size}"

    # Get a profile
    profile_id = await get_first_profile()
    if not profile_id:
        print("✗ No voice profiles found. Please create a profile first.")
        return False

    print(f"\nUsing profile: {profile_id}")
    print("Note: This will download the model if not cached")

    # Start the SSE monitor BEFORE triggering generation
    monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=300))

    # Wait for SSE to connect
    await asyncio.sleep(1)

    # Trigger generation
    test_text = "This should trigger a model download if the model is not cached."
    success, result = await trigger_generation(profile_id, test_text, model_size)

    if not success and result and result.get("downloading"):
        print("\n→ Model download initiated. Monitoring progress...")
        # Wait for the download to complete
        events = await monitor_task

        # Try generation again
        print(f"\n[{_timestamp()}] Retrying generation after download...")
        await asyncio.sleep(2)
        success, result = await trigger_generation(profile_id, test_text, model_size)

        if success:
            print("✓ Generation successful after download")

        return events

    # If the model was already cached
    await asyncio.sleep(3)
    monitor_task.cancel()
    try:
        events = await monitor_task
    except asyncio.CancelledError:
        events = []

    return events


async def main():
    print("=" * 80)
    print("TTS Generation Progress Test")
    print("=" * 80)
    print("Purpose: Capture exact SSE events during generation to identify UX issues")
    print("=" * 80)

    # Check if the server is running
    print(f"\n[{_timestamp()}] Checking if server is running...")
    if not await check_server():
        print("✗ Server is not running on http://localhost:8000")
        print("\nPlease start the server first:")
        print("  cd backend && python main.py")
        return False

    print("✓ Server is running")

    # Test Case 1: Cached model
    print("\n" + "🧪 " * 20)
    events_cached = await test_generation_with_cached_model()

    # Results for Test Case 1
    print("\n" + "=" * 80)
    print("TEST CASE 1 RESULTS: Generation with Cached Model")
    print("=" * 80)

    if not events_cached:
        print("✓ GOOD: No SSE progress events received")
        print("  This is the expected behavior for a cached model.")
    else:
        print(f"⚠ ISSUE FOUND: Received {len(events_cached)} SSE events:")
        print("\nEvent Timeline:")
        for i, event in enumerate(events_cached, 1):
            timestamp = event.pop("_timestamp", "??:??:??.???")
            print(f"  {i}. [{timestamp}] {event}")

        print("\n⚠ This explains the UX issue!")
        print("  Users see progress events even when the model is already cached,")
        print("  making them think the model is downloading again.")

    print("\n" + "=" * 80)
    print("Test Complete!")
    print("=" * 80)

    return True


if __name__ == "__main__":
    asyncio.run(main())
118
backend/tests/test_offline_guard.py
Normal file
@@ -0,0 +1,118 @@
"""
Unit tests for the ``force_offline_if_cached`` helper.

Verifies that the helper mutates the cached module constants in
``huggingface_hub.constants`` and ``transformers.utils.hub`` — not just
``os.environ`` — and that concurrent users are refcount-coordinated so
one thread's exit can't strip another thread's offline protection.

NOTE: These tests mutate process-global state in ``huggingface_hub.constants``
and ``transformers.utils.hub``. They are not safe under cross-process
parallelism (e.g. ``pytest-xdist`` with ``--dist=loadfile``/``loadscope``);
run this file serially.
"""

import os
import sys
import threading
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent))

from utils.hf_offline_patch import force_offline_if_cached  # noqa: E402


def _hf_const():
    import huggingface_hub.constants as hf_const

    return hf_const


def _tf_hub():
    import transformers.utils.hub as tf_hub

    return tf_hub


def test_mutates_cached_huggingface_hub_constant():
    original = _hf_const().HF_HUB_OFFLINE
    with force_offline_if_cached(True, "t"):
        assert _hf_const().HF_HUB_OFFLINE is True
    assert original == _hf_const().HF_HUB_OFFLINE


def test_mutates_cached_transformers_constant():
    original = _tf_hub()._is_offline_mode
    with force_offline_if_cached(True, "t"):
        assert _tf_hub()._is_offline_mode is True
    assert original == _tf_hub()._is_offline_mode


def test_sets_env_variable():
    original = os.environ.get("HF_HUB_OFFLINE")
    with force_offline_if_cached(True, "t"):
        assert "1" == os.environ.get("HF_HUB_OFFLINE")
    assert original == os.environ.get("HF_HUB_OFFLINE")


def test_noop_when_not_cached():
    before = _hf_const().HF_HUB_OFFLINE
    with force_offline_if_cached(False, "t"):
        assert before == _hf_const().HF_HUB_OFFLINE


def test_nested_contexts_respect_refcount():
    original = _hf_const().HF_HUB_OFFLINE
    with force_offline_if_cached(True, "outer"):
        assert _hf_const().HF_HUB_OFFLINE is True
        with force_offline_if_cached(True, "inner"):
            assert _hf_const().HF_HUB_OFFLINE is True
        # inner exit must not restore while outer is still active
        assert _hf_const().HF_HUB_OFFLINE is True
    assert original == _hf_const().HF_HUB_OFFLINE


def test_concurrent_threads_share_offline_window():
    """A slow thread must keep seeing offline mode even if a peer exits first."""
    original = _hf_const().HF_HUB_OFFLINE
    observations: list[bool] = []
    errors: list[Exception] = []
    barrier = threading.Barrier(2)
    fast_exited = threading.Event()

    def slow():
        try:
            with force_offline_if_cached(True, "slow"):
                barrier.wait(timeout=5)
                assert fast_exited.wait(timeout=5), "fast thread did not exit"
                observations.append(_hf_const().HF_HUB_OFFLINE)
        except Exception as exc:  # noqa: BLE001
            errors.append(exc)

    def fast():
        try:
            with force_offline_if_cached(True, "fast"):
                barrier.wait(timeout=5)
        except Exception as exc:  # noqa: BLE001
            errors.append(exc)
        finally:
            fast_exited.set()

    t_slow = threading.Thread(target=slow)
    t_fast = threading.Thread(target=fast)
    t_slow.start()
    t_fast.start()
    t_slow.join(timeout=5)
    t_fast.join(timeout=5)

    assert not t_slow.is_alive(), "slow thread did not finish"
    assert not t_fast.is_alive(), "fast thread did not finish"
    assert not errors, errors
    assert observations == [True], "slow thread lost offline protection"
    assert original == _hf_const().HF_HUB_OFFLINE
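

# The refcounting contract pinned down above can be summarized by this sketch.
# The real helper lives in utils/hf_offline_patch.py; the lock and counter
# names below are illustrative assumptions, not the actual implementation.
#
#     _lock = threading.Lock()
#     _active = 0
#
#     @contextmanager
#     def force_offline_if_cached(cached, label):
#         global _active
#         if not cached:
#             yield
#             return
#         with _lock:
#             _active += 1
#             if _active == 1:
#                 _flip_offline_constants_on()   # constants + os.environ
#         try:
#             yield
#         finally:
#             with _lock:
#                 _active -= 1
#                 if _active == 0:
#                     _restore_offline_constants()  # only the last user restores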


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
113
backend/tests/test_offline_patch.py
Normal file
@@ -0,0 +1,113 @@
"""
Unit tests for ``patch_transformers_mistral_regex``.

Verifies that our wrapper around
``transformers.PreTrainedTokenizerBase._patch_mistral_regex`` catches
exceptions from the unconditional ``huggingface_hub.model_info()`` lookup
and returns the tokenizer unchanged — matching the success-path behavior
for non-Mistral repos (transformers 4.57.3, ``tokenization_utils_base.py:2503``).

NOTE: These tests mutate ``transformers.PreTrainedTokenizerBase`` globally;
run serially, not under ``pytest-xdist`` with per-worker process isolation.
"""

import sys
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent))

from huggingface_hub.errors import OfflineModeIsEnabled  # noqa: E402
from transformers.tokenization_utils_base import PreTrainedTokenizerBase  # noqa: E402

import utils.hf_offline_patch as hf_offline_patch  # noqa: E402


@pytest.fixture(autouse=True)
def restore_mistral_regex():
    """Snapshot the current ``_patch_mistral_regex`` and restore it after each test."""
    saved = PreTrainedTokenizerBase.__dict__.get("_patch_mistral_regex")
    saved_flag = hf_offline_patch._mistral_regex_patched
    try:
        yield
    finally:
        if saved is not None:
            PreTrainedTokenizerBase._patch_mistral_regex = saved
        hf_offline_patch._mistral_regex_patched = saved_flag


def _apply_patch():
    hf_offline_patch._mistral_regex_patched = False
    hf_offline_patch.patch_transformers_mistral_regex()


def test_suppresses_offline_mode_is_enabled(monkeypatch):
    _apply_patch()

    import huggingface_hub

    def raise_offline(*_args, **_kwargs):
        raise OfflineModeIsEnabled("offline")

    monkeypatch.setattr(huggingface_hub, "model_info", raise_offline)

    sentinel = object()
    result = PreTrainedTokenizerBase._patch_mistral_regex(
        sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    )
    assert result is sentinel


def test_suppresses_connection_errors(monkeypatch):
    _apply_patch()

    import huggingface_hub

    def raise_connection(*_args, **_kwargs):
        raise ConnectionError("network unreachable")

    monkeypatch.setattr(huggingface_hub, "model_info", raise_connection)

    sentinel = object()
    result = PreTrainedTokenizerBase._patch_mistral_regex(
        sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    )
    assert result is sentinel


def test_passthrough_on_success(monkeypatch):
    """When model_info returns non-Mistral tags, the original falls through and returns the tokenizer unchanged."""
    _apply_patch()

    import huggingface_hub

    class FakeInfo:
        tags = ["model-type:qwen", "language:en"]

    monkeypatch.setattr(huggingface_hub, "model_info", lambda *_a, **_kw: FakeInfo())

    sentinel = object()
    result = PreTrainedTokenizerBase._patch_mistral_regex(
        sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    )
    assert result is sentinel


def test_idempotent():
    _apply_patch()
    first = PreTrainedTokenizerBase._patch_mistral_regex
    hf_offline_patch.patch_transformers_mistral_regex()
    second = PreTrainedTokenizerBase._patch_mistral_regex
    assert first.__func__ is second.__func__


def test_missing_method_is_noop(monkeypatch):
    monkeypatch.delattr(PreTrainedTokenizerBase, "_patch_mistral_regex", raising=False)
    hf_offline_patch._mistral_regex_patched = False
    hf_offline_patch.patch_transformers_mistral_regex()
    assert hf_offline_patch._mistral_regex_patched is False


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
217
backend/tests/test_profile_duplicate_names.py
Normal file
@@ -0,0 +1,217 @@
"""
Tests for profile duplicate name validation.

This test suite verifies that the application correctly handles
duplicate profile names and provides user-friendly error messages.
"""

import pytest
import tempfile
import shutil
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# Add parent directory to path to import backend modules
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))

from database import Base, VoiceProfile as DBVoiceProfile
from models import VoiceProfileCreate
from profiles import create_profile, update_profile


@pytest.fixture
def test_db():
    """Create a temporary test database."""
    # Create a temporary directory for the test database
    temp_dir = tempfile.mkdtemp()
    db_path = Path(temp_dir) / "test.db"

    # Create engine and session
    engine = create_engine(f"sqlite:///{db_path}")
    Base.metadata.create_all(bind=engine)
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    db = SessionLocal()

    yield db

    # Cleanup
    db.close()
    shutil.rmtree(temp_dir)


@pytest.fixture
def mock_profiles_dir(monkeypatch, tmp_path):
    """Mock the profiles directory to use a temporary path."""
    from backend import config
    monkeypatch.setattr(config, 'get_profiles_dir', lambda: tmp_path)
    return tmp_path


@pytest.mark.asyncio
async def test_create_profile_duplicate_name_raises_error(test_db, mock_profiles_dir):
    """Test that creating a profile with a duplicate name raises a ValueError."""
    # Create the first profile
    profile_data_1 = VoiceProfileCreate(
        name="Test Profile",
        description="First profile",
        language="en"
    )

    profile_1 = await create_profile(profile_data_1, test_db)
    assert profile_1.name == "Test Profile"

    # Try to create a second profile with the same name
    profile_data_2 = VoiceProfileCreate(
        name="Test Profile",
        description="Second profile",
        language="en"
    )

    with pytest.raises(ValueError) as exc_info:
        await create_profile(profile_data_2, test_db)

    # Verify the error message is user-friendly
    assert "already exists" in str(exc_info.value)
    assert "Test Profile" in str(exc_info.value)
    assert "choose a different name" in str(exc_info.value).lower()


@pytest.mark.asyncio
async def test_create_profile_different_names_succeeds(test_db, mock_profiles_dir):
    """Test that creating profiles with different names succeeds."""
    # Create the first profile
    profile_data_1 = VoiceProfileCreate(
        name="Profile One",
        description="First profile",
        language="en"
    )

    profile_1 = await create_profile(profile_data_1, test_db)
    assert profile_1.name == "Profile One"

    # Create a second profile with a different name
    profile_data_2 = VoiceProfileCreate(
        name="Profile Two",
        description="Second profile",
        language="en"
    )

    profile_2 = await create_profile(profile_data_2, test_db)
    assert profile_2.name == "Profile Two"

    # Verify both profiles exist
    assert profile_1.id != profile_2.id


@pytest.mark.asyncio
async def test_update_profile_to_duplicate_name_raises_error(test_db, mock_profiles_dir):
    """Test that updating a profile to a duplicate name raises a ValueError."""
    # Create two profiles with different names
    profile_data_1 = VoiceProfileCreate(
        name="Profile A",
        description="First profile",
        language="en"
    )
    profile_1 = await create_profile(profile_data_1, test_db)

    profile_data_2 = VoiceProfileCreate(
        name="Profile B",
        description="Second profile",
        language="en"
    )
    profile_2 = await create_profile(profile_data_2, test_db)

    # Try to update profile_2 to use profile_1's name
    update_data = VoiceProfileCreate(
        name="Profile A",  # Duplicate name
        description="Updated description",
        language="en"
    )

    with pytest.raises(ValueError) as exc_info:
        await update_profile(profile_2.id, update_data, test_db)

    # Verify the error message is user-friendly
    assert "already exists" in str(exc_info.value)
    assert "Profile A" in str(exc_info.value)


@pytest.mark.asyncio
async def test_update_profile_keep_same_name_succeeds(test_db, mock_profiles_dir):
    """Test that updating a profile while keeping the same name succeeds."""
    # Create a profile
    profile_data = VoiceProfileCreate(
        name="My Profile",
        description="Original description",
        language="en"
    )
    profile = await create_profile(profile_data, test_db)

    # Update the profile with the same name but a different description
    update_data = VoiceProfileCreate(
        name="My Profile",  # Same name
        description="Updated description",
        language="en"
    )

    updated_profile = await update_profile(profile.id, update_data, test_db)

    # Verify the update succeeded
    assert updated_profile is not None
    assert updated_profile.id == profile.id
    assert updated_profile.name == "My Profile"
    assert updated_profile.description == "Updated description"


@pytest.mark.asyncio
async def test_update_profile_to_new_unique_name_succeeds(test_db, mock_profiles_dir):
    """Test that updating a profile to a new unique name succeeds."""
    # Create a profile
    profile_data = VoiceProfileCreate(
        name="Original Name",
        description="Profile description",
        language="en"
    )
    profile = await create_profile(profile_data, test_db)

    # Update the profile with a new unique name
    update_data = VoiceProfileCreate(
        name="New Unique Name",
        description="Updated description",
        language="en"
    )

    updated_profile = await update_profile(profile.id, update_data, test_db)

    # Verify the update succeeded
    assert updated_profile is not None
    assert updated_profile.id == profile.id
    assert updated_profile.name == "New Unique Name"


@pytest.mark.asyncio
async def test_case_sensitive_names_allowed(test_db, mock_profiles_dir):
    """Test that profile names are case-sensitive (e.g., 'Test' and 'test' are different)."""
    # Create a profile with a lowercase name
    profile_data_1 = VoiceProfileCreate(
        name="test profile",
        description="Lowercase",
        language="en"
    )
    profile_1 = await create_profile(profile_data_1, test_db)

    # Create a profile with different casing
    profile_data_2 = VoiceProfileCreate(
        name="Test Profile",
        description="Title case",
        language="en"
    )
    profile_2 = await create_profile(profile_data_2, test_db)

    # Both should succeed since the SQLite unique constraint is case-sensitive by default
    assert profile_1.name == "test profile"
    assert profile_2.name == "Test Profile"
    assert profile_1.id != profile_2.id
313
backend/tests/test_progress.py
Normal file
@@ -0,0 +1,313 @@
"""
Test script to debug model download progress tracking.
"""

import asyncio
import json
import time
from typing import List, Dict
import logging

# Set up logging to see what's happening
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

from utils.progress import ProgressManager, get_progress_manager
from utils.hf_progress import HFProgressTracker, create_hf_progress_callback


def test_progress_manager_basic():
    """Test 1: Basic ProgressManager functionality."""
    print("\n" + "=" * 60)
    print("Test 1: ProgressManager Basic Operations")
    print("=" * 60)

    pm = ProgressManager()

    # Test update_progress
    pm.update_progress(
        model_name="test-model",
        current=50,
        total=100,
        filename="test.bin",
        status="downloading"
    )

    # Test get_progress
    progress = pm.get_progress("test-model")
    print(f"✓ Progress stored: {progress}")
    assert progress is not None
    assert progress["progress"] == 50.0
    assert progress["filename"] == "test.bin"
    assert progress["status"] == "downloading"

    # Test mark_complete
    pm.mark_complete("test-model")
    progress = pm.get_progress("test-model")
    print(f"✓ Marked complete: {progress}")
    assert progress["status"] == "complete"
    assert progress["progress"] == 100.0

    print("✓ Test 1 PASSED\n")
    return True


async def test_progress_manager_sse():
    """Test 2: ProgressManager SSE streaming."""
    print("\n" + "=" * 60)
    print("Test 2: ProgressManager SSE Streaming")
    print("=" * 60)

    pm = ProgressManager()
    collected_events: List[Dict] = []

    # Simulate an SSE client
    async def sse_client():
        """Simulates a frontend SSE connection."""
        print("  SSE client: Subscribing to test-model-sse...")
        async for event in pm.subscribe("test-model-sse"):
            # Parse the SSE event
            if event.startswith("data: "):
                data = json.loads(event[6:])
                print(f"  SSE client: Received event: {data['status']} - {data.get('progress', 0):.1f}%")
                collected_events.append(data)

                # Stop when complete
                if data.get("status") in ("complete", "error"):
                    break
            elif event.startswith(": heartbeat"):
                print("  SSE client: Received heartbeat")

    # Simulate download progress updates (from a backend thread)
    async def simulate_download():
        """Simulates the backend sending progress updates."""
        print("  Backend: Starting simulated download...")
        await asyncio.sleep(0.2)  # Let the SSE client subscribe first

        # Send progress updates
        for i in range(0, 101, 20):
            print(f"  Backend: Updating progress to {i}%")
            pm.update_progress(
                model_name="test-model-sse",
                current=i,
                total=100,
                filename=f"file_{i}.bin",
                status="downloading"
            )
            await asyncio.sleep(0.1)

        # Mark complete
        print("  Backend: Marking download complete")
        pm.mark_complete("test-model-sse")

    # Run the SSE client and download simulation concurrently
    await asyncio.gather(
        sse_client(),
        simulate_download()
    )

    # Verify we got events
    print(f"\n  Collected {len(collected_events)} events")
    assert len(collected_events) > 0, "Should have received at least one event"
    assert collected_events[-1]["status"] == "complete", "Last event should be 'complete'"

    print("✓ Test 2 PASSED\n")
    return True
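

# For reference, a representative SSE frame consumed by the client above looks
# like this (field names match the assertions in this file; any payload shape
# beyond these fields is an assumption):
#
#     data: {"status": "downloading", "progress": 40.0, "filename": "file_40.bin"}
#
# followed by a blank line, with ": heartbeat" comment lines interleaved to
# keep the connection alive.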


def test_hf_progress_tracker():
    """Test 3: HFProgressTracker tqdm patching."""
    print("\n" + "=" * 60)
    print("Test 3: HFProgressTracker tqdm Patching")
    print("=" * 60)

    captured_progress: List[tuple] = []

    def progress_callback(downloaded: int, total: int, filename: str):
        """Capture progress updates."""
        captured_progress.append((downloaded, total, filename))
        print(f"  Progress callback: {downloaded}/{total} bytes ({filename})")

    tracker = HFProgressTracker(progress_callback)

    # Simulate a download with tqdm
    with tracker.patch_download():
        try:
            from tqdm import tqdm

            # Simulate downloading a file
            print("  Simulating download with tqdm...")
            total_size = 1000
            with tqdm(total=total_size, desc="model.bin", unit="B", unit_scale=True) as pbar:
                for chunk in range(0, total_size, 100):
                    pbar.update(100)
                    time.sleep(0.01)

            print(f"  Captured {len(captured_progress)} progress updates")
            assert len(captured_progress) > 0, "Should have captured progress updates"

            # Verify progress increases
            last_downloaded = 0
            for downloaded, total, filename in captured_progress:
                assert downloaded >= last_downloaded, "Downloaded bytes should increase"
                assert total == total_size, "Total should be consistent"
                last_downloaded = downloaded

            print("✓ Test 3 PASSED\n")
            return True

        except ImportError:
            print("✗ tqdm not available, skipping test\n")
            return None


async def test_full_integration():
    """Test 4: Full integration test."""
    print("\n" + "=" * 60)
    print("Test 4: Full Integration (ProgressManager + HFProgressTracker)")
    print("=" * 60)

    pm = get_progress_manager()
    collected_events: List[Dict] = []

    # SSE client
    async def sse_client():
        print("  SSE client: Subscribing...")
        async for event in pm.subscribe("integration-test"):
            if event.startswith("data: "):
                data = json.loads(event[6:])
                print(f"  SSE client: {data['status']} - {data.get('progress', 0):.1f}% - {data.get('filename', '')}")
                collected_events.append(data)
                if data.get("status") in ("complete", "error"):
                    break

    # Simulate a backend download with HFProgressTracker
    async def simulate_real_download():
        await asyncio.sleep(0.2)  # Let the SSE client subscribe

        print("  Backend: Starting download with HFProgressTracker...")

        # Set up tracking (like the real backend does)
        progress_callback = create_hf_progress_callback("integration-test", pm)
        tracker = HFProgressTracker(progress_callback)

        # Initialize progress
        pm.update_progress(
            model_name="integration-test",
            current=0,
            total=1,
            filename="",
            status="downloading"
        )

        # Simulate the download with tqdm patching
        with tracker.patch_download():
            try:
                from tqdm import tqdm

                # Simulate a multi-file download (like HuggingFace does)
                files = [
                    ("model.safetensors", 5000),
                    ("config.json", 1000),
                    ("tokenizer.json", 500),
                ]

                for filename, size in files:
                    print(f"  Backend: Downloading {filename}...")
                    with tqdm(total=size, desc=filename, unit="B") as pbar:
                        for chunk in range(0, size, 500):
                            chunk_size = min(500, size - chunk)
                            pbar.update(chunk_size)
                            await asyncio.sleep(0.05)

                # Mark complete
                print("  Backend: Download complete")
                pm.mark_complete("integration-test")

            except ImportError:
                print("  ✗ tqdm not available")
                pm.mark_error("integration-test", "tqdm not available")

    # Run both
    await asyncio.gather(
        sse_client(),
        simulate_real_download()
    )

    # Verify
    print(f"\n  Collected {len(collected_events)} events")
    if len(collected_events) > 0:
        print(f"  First event: {collected_events[0]}")
        print(f"  Last event: {collected_events[-1]}")
        assert collected_events[-1]["status"] == "complete", "Should end with 'complete'"
        print("✓ Test 4 PASSED\n")
        return True
    else:
        print("✗ Test 4 FAILED - No events received\n")
        return False


async def main():
    """Run all tests."""
    print("\n" + "=" * 60)
    print("Voicebox Progress Tracking Test Suite")
    print("=" * 60)

    results = []

    # Test 1: Basic operations
    try:
        results.append(("Basic Operations", test_progress_manager_basic()))
    except Exception as e:
        print(f"✗ Test 1 FAILED: {e}\n")
        results.append(("Basic Operations", False))

    # Test 2: SSE streaming
    try:
        results.append(("SSE Streaming", await test_progress_manager_sse()))
    except Exception as e:
        print(f"✗ Test 2 FAILED: {e}\n")
        results.append(("SSE Streaming", False))

    # Test 3: tqdm patching
    try:
        results.append(("tqdm Patching", test_hf_progress_tracker()))
    except Exception as e:
        print(f"✗ Test 3 FAILED: {e}\n")
        results.append(("tqdm Patching", False))

    # Test 4: Full integration
    try:
        results.append(("Full Integration", await test_full_integration()))
    except Exception as e:
        print(f"✗ Test 4 FAILED: {e}\n")
        results.append(("Full Integration", False))

    # Summary
    print("\n" + "=" * 60)
    print("Test Results Summary")
    print("=" * 60)

    for name, result in results:
        status = "✓ PASS" if result else ("⊘ SKIP" if result is None else "✗ FAIL")
        print(f"  {status:8} {name}")

    passed = sum(1 for _, r in results if r is True)
    failed = sum(1 for _, r in results if r is False)
    skipped = sum(1 for _, r in results if r is None)

    print()
    print(f"  Total:   {len(results)} tests")
    print(f"  Passed:  {passed}")
    print(f"  Failed:  {failed}")
    print(f"  Skipped: {skipped}")
    print("=" * 60 + "\n")

    return failed == 0


if __name__ == "__main__":
    success = asyncio.run(main())
    exit(0 if success else 1)
317
backend/tests/test_qwen_download.py
Normal file
@@ -0,0 +1,317 @@
"""
Test Qwen TTS model download with SSE progress monitoring.

This specifically tests the MLX TTS backend download progress tracking,
which requires tqdm to be patched BEFORE mlx_audio is imported.

Usage:
    cd backend && python -m tests.test_qwen_download

Prerequisites:
    - Server must be running: cd backend && python main.py
    - Delete the model first for a fresh download test:
      curl -X DELETE http://localhost:8000/models/qwen-tts-0.6B
"""

import asyncio
import json
import time
from typing import List, Dict, Optional

import httpx


async def monitor_sse_stream(model_name: str, timeout: int = 600) -> List[Dict]:
    """
    Monitor the SSE stream for a model download.

    Args:
        model_name: Name of the model to monitor
        timeout: Maximum time to wait for the download (seconds)

    Returns:
        List of SSE events received
    """
    events: List[Dict] = []
    url = f"http://localhost:8000/models/progress/{model_name}"
    last_progress = -1

    print(f"\n📡 Connecting to SSE endpoint: {url}")

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream("GET", url) as response:
                print(f"   SSE connected, status: {response.status_code}")

                if response.status_code != 200:
                    print(f"   ❌ Error: SSE endpoint returned {response.status_code}")
                    return events

                async for line in response.aiter_lines():
                    if not line:
                        continue

                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])
                            events.append(data)

                            # Print progress (only when it changes significantly)
                            progress = data.get('progress', 0)
                            status = data.get('status', 'unknown')
                            filename = data.get('filename', '')
                            current = data.get('current', 0)
                            total = data.get('total', 0)

                            # Print on every 5% change or on a status change
                            if abs(progress - last_progress) >= 5 or status in ('complete', 'error'):
                                current_mb = current / (1024 * 1024)
                                total_mb = total / (1024 * 1024)
                                print(f"   📊 {status:12} {progress:6.1f}% ({current_mb:.1f}MB / {total_mb:.1f}MB) {filename[:50]}")
                                last_progress = progress

                            # Stop if complete or error
                            if status in ("complete", "error"):
                                if status == "complete":
                                    print("   ✅ Download complete!")
                                else:
                                    print(f"   ❌ Download error: {data.get('error', 'unknown')}")
                                break

                        except json.JSONDecodeError as e:
                            print(f"   ⚠️ Error parsing JSON: {e}")

                    elif line.startswith(": heartbeat"):
                        # Heartbeat arrives every second; don't spam the console
                        pass

    except asyncio.CancelledError:
        print("   ⏹️ SSE monitor cancelled")
    except Exception as e:
        print(f"   ❌ SSE error: {e}")

    return events


async def trigger_download(model_name: str) -> bool:
    """Trigger a model download via the API."""
    url = "http://localhost:8000/models/download"

    print(f"\n🚀 Triggering download for: {model_name}")

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(url, json={"model_name": model_name})
            result = response.json()
            print(f"   Response: {response.status_code} - {result}")
            return response.status_code == 200
    except Exception as e:
        print(f"   ❌ Error triggering download: {e}")
        return False


async def delete_model(model_name: str) -> bool:
    """Delete a model from the cache."""
    url = f"http://localhost:8000/models/{model_name}"

    print(f"\n🗑️ Deleting model: {model_name}")

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.delete(url)
            if response.status_code == 200:
                print("   ✅ Model deleted")
                return True
            elif response.status_code == 404:
                print("   ℹ️ Model not found (already deleted)")
                return True
            else:
                print(f"   ⚠️ Delete response: {response.status_code} - {response.text}")
                return False
    except Exception as e:
        print(f"   ❌ Error deleting model: {e}")
        return False


async def check_model_status(model_name: str) -> Optional[Dict]:
    """Check the status of a model."""
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get("http://localhost:8000/models/status")
            if response.status_code == 200:
                data = response.json()
                for model in data.get("models", []):
                    if model["model_name"] == model_name:
                        return model
    except Exception as e:
        print(f"   ⚠️ Error checking model status: {e}")
    return None


async def check_server() -> bool:
    """Check whether the server is running."""
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            response = await client.get("http://localhost:8000/health")
            return response.status_code == 200
    except Exception:
        return False


async def main():
    print("=" * 70)
    print("🧪 Qwen TTS Model Download Progress Test")
    print("=" * 70)
    print("\nThis test verifies that MLX TTS download progress tracking works.")
    print("It specifically tests the tqdm patching for mlx_audio.tts imports.")

    # Check if the server is running
    print("\n📡 Checking if server is running...")
    if not await check_server():
        print("   ❌ Server is not running on http://localhost:8000")
        print("\n   Please start the server first:")
        print("   cd backend && python main.py")
        return False

    print("   ✅ Server is running")

    # Test model
    model_name = "qwen-tts-0.6B"

    # Check current status
    print(f"\n📊 Checking status of {model_name}...")
    status = await check_model_status(model_name)
    if status:
        print(f"   Downloaded: {status.get('downloaded', False)}")
        print(f"   Downloading: {status.get('downloading', False)}")
        print(f"   Loaded: {status.get('loaded', False)}")
        if status.get('size_mb'):
            print(f"   Size: {status['size_mb']:.1f} MB")
    else:
        print("   ⚠️ Could not get model status")

    # Ask whether to delete the cached model first
    print("\n" + "-" * 70)
    if status and status.get('downloaded'):
        print("⚠️ Model is already downloaded. Delete it for a fresh download test?")
        print("   [y] Yes, delete and download fresh")
        print("   [n] No, just test SSE connection")
        print("   [q] Quit")

        choice = input("\nChoice [y/n/q]: ").strip().lower()

        if choice == 'q':
            print("Exiting...")
            return True

        if choice == 'y':
            if not await delete_model(model_name):
                print("Failed to delete model. Continue anyway? [y/n]")
                if input().strip().lower() != 'y':
                    return False
    else:
        print("Model not downloaded. Will perform a fresh download test.")
        input("Press Enter to continue...")

    # Run the test
    print("\n" + "=" * 70)
    print("🏃 Starting Download Test")
    print("=" * 70)

    async def run_test():
        # Start the SSE monitor in the background FIRST
        monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=600))

        # Wait for SSE to connect
        await asyncio.sleep(1)

        # Trigger the download
        success = await trigger_download(model_name)

        if not success:
            print("   ❌ Failed to trigger download")
            monitor_task.cancel()
            try:
                await monitor_task
            except asyncio.CancelledError:
                pass
            return []

        # Wait for the SSE monitor to complete
        print("\n⏳ Waiting for download to complete (this may take several minutes)...")
        events = await monitor_task

        return events

    start_time = time.time()
    events = await run_test()
    elapsed = time.time() - start_time

    # Results
    print("\n" + "=" * 70)
    print("📋 Test Results")
    print("=" * 70)

    print(f"\n⏱️ Elapsed time: {elapsed:.1f} seconds")
    print(f"📨 Total SSE events received: {len(events)}")

    if not events:
        print("\n❌ FAILED - No SSE events received!")
        print("\nPossible causes:")
        print("  1. SSE endpoint not working")
        print("  2. tqdm not patched before mlx_audio import")
        print("  3. Progress callbacks not firing")
        print("  4. Model already fully downloaded")
        print("\nDebug steps:")
        print("  1. Check server logs for [DEBUG] messages")
        print("  2. Look for 'tqdm patched' before 'mlx_audio.tts import'")
        print(f"  3. Delete model: curl -X DELETE http://localhost:8000/models/{model_name}")
        return False

    # Analyze the events
    first_event = events[0]
    last_event = events[-1]

    print("\n📊 First event:")
    print(f"   Status: {first_event.get('status')}")
    print(f"   Progress: {first_event.get('progress', 0):.1f}%")

    print("\n📊 Last event:")
    print(f"   Status: {last_event.get('status')}")
    print(f"   Progress: {last_event.get('progress', 0):.1f}%")

    # Check for expected behaviors
    has_progress_updates = len(events) > 2
    has_increasing_progress = False
    has_complete = any(e.get('status') == 'complete' for e in events)
    has_100_percent = any(e.get('progress', 0) >= 100 for e in events)

    # Check whether progress increased over time
    if len(events) >= 2:
        progress_values = [e.get('progress', 0) for e in events]
        has_increasing_progress = progress_values[-1] > progress_values[0]

    print("\n📋 Checks:")
    print(f"   {'✅' if has_progress_updates else '❌'} Multiple progress updates received ({len(events)} events)")
    print(f"   {'✅' if has_increasing_progress else '❌'} Progress increased over time")
    print(f"   {'✅' if has_100_percent else '❌'} Reached 100% progress")
    print(f"   {'✅' if has_complete else '❌'} Received 'complete' status")

    # Overall result
    success = has_progress_updates and has_complete

    if success:
        print("\n" + "=" * 70)
        print("✅ TEST PASSED - Qwen TTS download progress tracking works!")
        print("=" * 70)
    else:
        print("\n" + "=" * 70)
        print("❌ TEST FAILED - Progress tracking has issues")
        print("=" * 70)
        print("\nCheck the server logs for debug output.")

    return success


if __name__ == "__main__":
    result = asyncio.run(main())
    exit(0 if result else 1)
54
backend/tests/test_task_queue_cancellation.py
Normal file
@@ -0,0 +1,54 @@
import asyncio

import pytest

from backend.services import task_queue


@pytest.mark.asyncio
async def test_cancel_queued_generation_skips_execution():
    task_queue.init_queue(force=True)

    running_started = asyncio.Event()
    release_running = asyncio.Event()
    queued_ran = asyncio.Event()

    async def running_job():
        running_started.set()
        await release_running.wait()

    async def queued_job():
        queued_ran.set()

    task_queue.enqueue_generation("gen-running", running_job())
    await asyncio.wait_for(running_started.wait(), timeout=1)

    task_queue.enqueue_generation("gen-queued", queued_job())
    assert task_queue.cancel_generation("gen-queued") == "queued"

    release_running.set()
    await asyncio.sleep(0.1)

    assert not queued_ran.is_set()


@pytest.mark.asyncio
async def test_cancel_running_generation_cancels_task():
    task_queue.init_queue(force=True)

    running_started = asyncio.Event()
    running_cancelled = asyncio.Event()

    async def running_job():
        running_started.set()
        try:
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            running_cancelled.set()
            raise

    task_queue.enqueue_generation("gen-running", running_job())
    await asyncio.wait_for(running_started.wait(), timeout=1)

    assert task_queue.cancel_generation("gen-running") == "running"
    await asyncio.wait_for(running_cancelled.wait(), timeout=1)
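

# Taken together, these tests pin down the queue's cancellation contract:
# cancel_generation() returns "queued" when it removes a job that has not yet
# started and "running" when it cancels the in-flight task, and
# init_queue(force=True) rebuilds the queue so each test starts clean.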
178
backend/tests/test_whisper_download.py
Normal file
@@ -0,0 +1,178 @@
"""
Test real model download with SSE progress monitoring.
"""

import asyncio
import json
from typing import List, Dict

import httpx


async def monitor_sse_stream(model_name: str, timeout: int = 300):
    """Monitor the SSE stream for a model download."""
    events: List[Dict] = []
    url = f"http://localhost:8000/models/progress/{model_name}"

    print(f"Connecting to SSE endpoint: {url}")

    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("GET", url) as response:
            print(f"SSE connected, status: {response.status_code}")

            if response.status_code != 200:
                print(f"Error: SSE endpoint returned {response.status_code}")
                return events

            async for line in response.aiter_lines():
                if not line:
                    continue

                print(f"  Raw SSE: {line[:100]}...")  # Print the first 100 chars

                if line.startswith("data: "):
                    try:
                        data = json.loads(line[6:])
                        print(f"  → {data['status']:12} {data.get('progress', 0):6.1f}% {data.get('filename', '')}")
                        events.append(data)

                        # Stop if complete or error
                        if data.get("status") in ("complete", "error"):
                            print(f"  Download {data['status']}!")
                            break

                    except json.JSONDecodeError as e:
                        print(f"  Error parsing JSON: {e}")
                        print(f"  Line was: {line}")

                elif line.startswith(": heartbeat"):
                    print("  ♥ heartbeat")

    return events


async def trigger_download(model_name: str):
    """Trigger a model download via the API."""
    url = "http://localhost:8000/models/download"

    print(f"\nTriggering download for: {model_name}")

    async with httpx.AsyncClient(timeout=300) as client:
        response = await client.post(url, json={"model_name": model_name})
        print(f"Response: {response.status_code} - {response.json()}")
        return response.status_code == 200


async def check_server():
    """Check whether the server is running."""
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            response = await client.get("http://localhost:8000/health")
            return response.status_code == 200
    except Exception as e:
        print(f"Server not running: {e}")
        return False


async def main():
    print("=" * 60)
    print("Real Model Download Progress Test")
    print("=" * 60)

    # Check if the server is running
    print("\nChecking if server is running...")
    if not await check_server():
        print("✗ Server is not running on http://localhost:8000")
        print("\nPlease start the server first:")
        print("  cd backend && python main.py")
        return False

    print("✓ Server is running")

    # Choose a small model for testing
    model_name = "whisper-base"  # ~150MB, faster to download
    print(f"\nUsing model: {model_name}")

    # Option to delete the model first if it exists
    print("\nDo you want to delete the model first to force a fresh download? (y/n)")
    # For automated testing, skip the deletion prompt
    # delete_first = input().strip().lower() == 'y'
    delete_first = False

    if delete_first:
        print(f"Deleting {model_name}...")
        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.delete(f"http://localhost:8000/models/{model_name}")
            print(f"Delete response: {response.status_code}")

    print("\n" + "=" * 60)
    print("Starting Test")
    print("=" * 60)

    # Start monitoring the SSE stream BEFORE triggering the download
    async def run_test():
        # Start the SSE monitor in the background
        monitor_task = asyncio.create_task(monitor_sse_stream(model_name))

        # Wait a bit to ensure SSE is connected
        await asyncio.sleep(1)

        # Trigger the download
        success = await trigger_download(model_name)

        if not success:
            print("✗ Failed to trigger download")
            monitor_task.cancel()
            return False

        # Wait for the SSE monitor to complete
        events = await monitor_task

        return events

    events = await run_test()

    # Results
    print("\n" + "=" * 60)
    print("Test Results")
    print("=" * 60)

    if not events:
        print("✗ FAILED - No SSE events received!")
        print("\nPossible causes:")
        print("  1. SSE endpoint not working")
        print("  2. Progress updates not being sent")
        print("  3. Model already downloaded (no progress to report)")
        print("\nTry deleting the model first to force a fresh download:")
        print(f"  curl -X DELETE http://localhost:8000/models/{model_name}")
        return False

    print(f"✓ Received {len(events)} SSE events")
    print(f"\nFirst event: {events[0]}")
    print(f"Last event: {events[-1]}")

    # Check if we got meaningful progress
    has_progress = any(e.get('progress', 0) > 0 for e in events)
    has_complete = any(e.get('status') == 'complete' for e in events)

    if has_progress:
        print("✓ Progress updates received")
    else:
        print("✗ No progress updates (might be already downloaded)")

    if has_complete:
        print("✓ Download completed successfully")
    else:
        print("✗ Download did not complete")

    success = has_progress and has_complete

    if success:
        print("\n✓ TEST PASSED - Progress tracking works!")
    else:
        print("\n⊘ TEST INCONCLUSIVE - Try with a fresh download")

    return success


if __name__ == "__main__":
    asyncio.run(main())
1
backend/utils/__init__.py
Normal file
@@ -0,0 +1 @@
# Utils package
318
backend/utils/audio.py
Normal file
@@ -0,0 +1,318 @@
"""
Audio processing utilities.
"""

import numpy as np
import soundfile as sf
import librosa
from typing import Tuple, Optional


def normalize_audio(
    audio: np.ndarray,
    target_db: float = -20.0,
    peak_limit: float = 0.85,
) -> np.ndarray:
    """
    Normalize audio to a target loudness with peak limiting.

    Args:
        audio: Input audio array
        target_db: Target RMS level in dB
        peak_limit: Peak limit (0.0-1.0)

    Returns:
        Normalized audio array
    """
    # Convert to float32
    audio = audio.astype(np.float32)

    # Calculate current RMS
    rms = np.sqrt(np.mean(audio**2))

    # Calculate target RMS
    target_rms = 10**(target_db / 20)

    # Apply gain
    if rms > 0:
        gain = target_rms / rms
        audio = audio * gain

    # Peak limiting
    audio = np.clip(audio, -peak_limit, peak_limit)

    return audio
|
||||
|
||||
|
||||
def load_audio(
|
||||
path: str,
|
||||
sample_rate: int = 24000,
|
||||
mono: bool = True,
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Load audio file with normalization.
|
||||
|
||||
Args:
|
||||
path: Path to audio file
|
||||
sample_rate: Target sample rate
|
||||
mono: Convert to mono
|
||||
|
||||
Returns:
|
||||
Tuple of (audio_array, sample_rate)
|
||||
"""
|
||||
audio, sr = librosa.load(path, sr=sample_rate, mono=mono)
|
||||
return audio, sr
|
||||
|
||||
|
||||
def save_audio(
|
||||
audio: np.ndarray,
|
||||
path: str,
|
||||
sample_rate: int = 24000,
|
||||
) -> None:
|
||||
"""
|
||||
Save audio file with atomic write and error handling.
|
||||
|
||||
Writes to a temporary file first, then atomically renames to the
|
||||
target path. This prevents corrupted/partial WAV files if the
|
||||
process is interrupted mid-write.
|
||||
|
||||
Args:
|
||||
audio: Audio array
|
||||
path: Output path
|
||||
sample_rate: Sample rate
|
||||
|
||||
Raises:
|
||||
OSError: If file cannot be written
|
||||
"""
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
temp_path = f"{path}.tmp"
|
||||
try:
|
||||
# Ensure parent directory exists
|
||||
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write to temporary file first (explicit format since .tmp
|
||||
# extension is not recognised by soundfile)
|
||||
sf.write(temp_path, audio, sample_rate, format='WAV')
|
||||
|
||||
# Atomic rename to final path
|
||||
os.replace(temp_path, path)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up temp file on failure
|
||||
try:
|
||||
if Path(temp_path).exists():
|
||||
Path(temp_path).unlink()
|
||||
except Exception:
|
||||
pass # Best effort cleanup
|
||||
|
||||
raise OSError(f"Failed to save audio to {path}: {e}") from e
|
||||
|
||||
|
||||
def trim_tts_output(
|
||||
audio: np.ndarray,
|
||||
sample_rate: int = 24000,
|
||||
frame_ms: int = 20,
|
||||
silence_threshold_db: float = -40.0,
|
||||
min_silence_ms: int = 200,
|
||||
max_internal_silence_ms: int = 1000,
|
||||
fade_ms: int = 30,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Trim trailing silence and post-silence hallucination from TTS output.
|
||||
|
||||
Chatterbox sometimes produces ``[speech][silence][hallucinated noise]``.
|
||||
This detects internal silence gaps longer than *max_internal_silence_ms*
|
||||
and cuts the audio at that boundary, then trims trailing silence and
|
||||
applies a short cosine fade-out.
|
||||
|
||||
Args:
|
||||
audio: Input audio array (mono float32)
|
||||
sample_rate: Sample rate in Hz
|
||||
frame_ms: Frame size for RMS energy calculation
|
||||
silence_threshold_db: dB threshold below which a frame is silence
|
||||
min_silence_ms: Minimum trailing silence to keep
|
||||
max_internal_silence_ms: Cut after any silence gap longer than this
|
||||
fade_ms: Cosine fade-out duration in ms
|
||||
|
||||
Returns:
|
||||
Trimmed audio array
|
||||
"""
|
||||
frame_len = int(sample_rate * frame_ms / 1000)
|
||||
if frame_len == 0 or len(audio) < frame_len:
|
||||
return audio
|
||||
|
||||
n_frames = len(audio) // frame_len
|
||||
threshold_linear = 10 ** (silence_threshold_db / 20)
|
||||
|
||||
# Compute per-frame RMS
|
||||
rms = np.array(
|
||||
[
|
||||
np.sqrt(np.mean(audio[i * frame_len : (i + 1) * frame_len] ** 2))
|
||||
for i in range(n_frames)
|
||||
]
|
||||
)
|
||||
is_speech = rms >= threshold_linear
|
||||
|
||||
# Find first speech frame
|
||||
first_speech = 0
|
||||
for i, s in enumerate(is_speech):
|
||||
if s:
|
||||
first_speech = max(0, i - 1) # keep 1 frame padding
|
||||
break
|
||||
|
||||
# Walk forward from first speech; cut at long internal silence gaps
|
||||
max_silence_frames = int(max_internal_silence_ms / frame_ms)
|
||||
consecutive_silence = 0
|
||||
cut_frame = n_frames
|
||||
|
||||
for i in range(first_speech, n_frames):
|
||||
if is_speech[i]:
|
||||
consecutive_silence = 0
|
||||
else:
|
||||
consecutive_silence += 1
|
||||
if consecutive_silence >= max_silence_frames:
|
||||
cut_frame = i - consecutive_silence + 1
|
||||
break
|
||||
|
||||
# Trim trailing silence from the cut point
|
||||
min_silence_frames = int(min_silence_ms / frame_ms)
|
||||
end_frame = cut_frame
|
||||
while end_frame > first_speech and not is_speech[end_frame - 1]:
|
||||
end_frame -= 1
|
||||
# Keep a short tail
|
||||
end_frame = min(end_frame + min_silence_frames, cut_frame)
|
||||
|
||||
# Convert frames back to samples
|
||||
start_sample = first_speech * frame_len
|
||||
end_sample = min(end_frame * frame_len, len(audio))
|
||||
|
||||
trimmed = audio[start_sample:end_sample].copy()
|
||||
|
||||
# Cosine fade-out
|
||||
fade_samples = int(sample_rate * fade_ms / 1000)
|
||||
if fade_samples > 0 and len(trimmed) > fade_samples:
|
||||
fade = np.cos(np.linspace(0, np.pi / 2, fade_samples)) ** 2
|
||||
trimmed[-fade_samples:] *= fade
|
||||
|
||||
return trimmed
|
||||
|
||||
|
||||
def preprocess_reference_audio(
|
||||
audio: np.ndarray,
|
||||
sample_rate: int,
|
||||
peak_target: float = 0.95,
|
||||
trim_top_db: float = 40.0,
|
||||
edge_padding_ms: int = 100,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Clean up a reference-audio sample before validation/storage.
|
||||
|
||||
Removes DC offset, trims leading/trailing silence, and caps the peak so a
|
||||
slightly-hot recording doesn't get rejected downstream as "clipping". The
|
||||
goal is to accept reasonable real-world recordings — not to repair badly
|
||||
distorted ones. True clipping artifacts inside the waveform can't be
|
||||
recovered by peak scaling and will still sound bad.
|
||||
|
||||
Args:
|
||||
audio: Mono audio array.
|
||||
sample_rate: Sample rate of ``audio`` in Hz.
|
||||
peak_target: Peak amplitude cap in [0, 1]. Applied only if the input
|
||||
peak exceeds this value.
|
||||
trim_top_db: Silence threshold for edge trimming, in dB below peak.
|
||||
40 dB sits below normal speech dynamic range (≈30 dB) so soft
|
||||
trailing syllables are preserved, while still catching obvious
|
||||
leading/trailing silence. Lower values are more aggressive;
|
||||
librosa's own default is 60.
|
||||
edge_padding_ms: Milliseconds of padding to add back at each edge
|
||||
*only if* trimming shortened the waveform, so TTS engines have a
|
||||
brief silence to anchor on without ever making the output longer
|
||||
than the input.
|
||||
|
||||
Returns:
|
||||
Preprocessed audio array (float32).
|
||||
"""
|
||||
audio = audio.astype(np.float32, copy=False)
|
||||
|
||||
if audio.size == 0:
|
||||
return audio
|
||||
|
||||
audio = audio - float(np.mean(audio))
|
||||
|
||||
trimmed, _ = librosa.effects.trim(audio, top_db=trim_top_db)
|
||||
if 0 < trimmed.size < audio.size:
|
||||
pad_each = int(sample_rate * edge_padding_ms / 1000)
|
||||
# Never pad past the original length — for near-max-duration uploads
|
||||
# an unconditional pad would push them over the 30 s ceiling and
|
||||
# trigger a spurious "too long" rejection.
|
||||
headroom = (audio.size - trimmed.size) // 2
|
||||
pad = min(pad_each, max(headroom, 0))
|
||||
if pad > 0:
|
||||
trimmed = np.pad(trimmed, (pad, pad), mode="constant")
|
||||
audio = trimmed
|
||||
|
||||
peak = float(np.abs(audio).max())
|
||||
if peak > peak_target and peak > 0:
|
||||
audio = audio * (peak_target / peak)
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
def validate_reference_audio(
|
||||
audio_path: str,
|
||||
min_duration: float = 2.0,
|
||||
max_duration: float = 30.0,
|
||||
min_rms: float = 0.01,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate reference audio for voice cloning.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
min_duration: Minimum duration in seconds
|
||||
max_duration: Maximum duration in seconds
|
||||
min_rms: Minimum RMS level
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
result = validate_and_load_reference_audio(
|
||||
audio_path, min_duration, max_duration, min_rms
|
||||
)
|
||||
return (result[0], result[1])
|
||||
|
||||
|
||||
def validate_and_load_reference_audio(
|
||||
audio_path: str,
|
||||
min_duration: float = 2.0,
|
||||
max_duration: float = 30.0,
|
||||
min_rms: float = 0.01,
|
||||
) -> Tuple[bool, Optional[str], Optional[np.ndarray], Optional[int]]:
|
||||
"""
|
||||
Validate and load reference audio in a single pass.
|
||||
|
||||
Applies :func:`preprocess_reference_audio` before checks so that
|
||||
slightly-hot recordings aren't rejected as clipping. Duration and RMS
|
||||
checks run on the preprocessed waveform.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message, audio_array, sample_rate)
|
||||
"""
|
||||
try:
|
||||
audio, sr = load_audio(audio_path)
|
||||
audio = preprocess_reference_audio(audio, sr)
|
||||
duration = len(audio) / sr
|
||||
|
||||
if duration < min_duration:
|
||||
return False, f"Audio too short (minimum {min_duration} seconds)", None, None
|
||||
if duration > max_duration:
|
||||
return False, f"Audio too long (maximum {max_duration} seconds)", None, None
|
||||
|
||||
rms = np.sqrt(np.mean(audio**2))
|
||||
if rms < min_rms:
|
||||
return False, "Audio is too quiet or silent", None, None
|
||||
|
||||
return True, None, audio, sr
|
||||
except Exception as e:
|
||||
return False, f"Error validating audio: {str(e)}", None, None
|
||||
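
Taken together, the helpers above form the post-generation audio path. A minimal usage sketch (file paths are placeholders):

```python
# Illustrative only — paths are placeholders.
from backend.utils.audio import load_audio, normalize_audio, trim_tts_output, save_audio

audio, sr = load_audio("raw_tts_output.wav", sample_rate=24000)
audio = trim_tts_output(audio, sample_rate=sr)   # cut hallucinated tails
audio = normalize_audio(audio, target_db=-20.0)  # level to -20 dB RMS
save_audio(audio, "clean.wav", sample_rate=sr)   # atomic write
```
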
153
backend/utils/cache.py
Normal file
@@ -0,0 +1,153 @@
"""
Voice prompt caching utilities.
"""

import hashlib
import logging
import torch
from pathlib import Path
from typing import Optional, Union, Dict, Any

from .. import config

logger = logging.getLogger(__name__)


def _get_cache_dir() -> Path:
    """Get cache directory from config."""
    return config.get_cache_dir()


# In-memory cache - can store dict (voice prompt) or tensor (legacy)
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}


def get_cache_key(audio_path: str, reference_text: str) -> str:
    """
    Generate cache key from audio file and reference text.

    Args:
        audio_path: Path to audio file
        reference_text: Reference text

    Returns:
        Cache key (MD5 hash)
    """
    # Read audio file
    with open(audio_path, "rb") as f:
        audio_bytes = f.read()

    # Combine audio bytes and text
    combined = audio_bytes + reference_text.encode("utf-8")

    # Generate hash
    return hashlib.md5(combined).hexdigest()


def get_cached_voice_prompt(
    cache_key: str,
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
    """
    Get cached voice prompt if available.

    Args:
        cache_key: Cache key

    Returns:
        Cached voice prompt (dict or tensor) or None
    """
    # Check in-memory cache
    if cache_key in _memory_cache:
        return _memory_cache[cache_key]

    # Check disk cache
    cache_file = _get_cache_dir() / f"{cache_key}.prompt"
    if cache_file.exists():
        try:
            prompt = torch.load(cache_file, weights_only=True)
            _memory_cache[cache_key] = prompt
            return prompt
        except Exception:
            # Cache file corrupted, delete it
            cache_file.unlink()

    return None


def cache_voice_prompt(
    cache_key: str,
    voice_prompt: Union[torch.Tensor, Dict[str, Any]],
) -> None:
    """
    Cache voice prompt to memory and disk.

    Args:
        cache_key: Cache key
        voice_prompt: Voice prompt (dict or tensor)
    """
    # Store in memory
    _memory_cache[cache_key] = voice_prompt

    # Store on disk (torch.save can handle both dicts and tensors)
    cache_file = _get_cache_dir() / f"{cache_key}.prompt"
    torch.save(voice_prompt, cache_file)


def clear_voice_prompt_cache() -> int:
    """
    Clear all voice prompt caches (memory and disk).

    Returns:
        Number of cache files deleted
    """
    # Clear memory cache
    _memory_cache.clear()

    # Clear disk cache
    cache_dir = _get_cache_dir()
    deleted_count = 0

    if cache_dir.exists():
        # Delete prompt cache files
        for cache_file in cache_dir.glob("*.prompt"):
            try:
                cache_file.unlink()
                deleted_count += 1
            except Exception as e:
                logger.warning("Failed to delete cache file %s: %s", cache_file, e)

        # Delete combined audio files
        for audio_file in cache_dir.glob("combined_*.wav"):
            try:
                audio_file.unlink()
                deleted_count += 1
            except Exception as e:
                logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)

    return deleted_count


def clear_profile_cache(profile_id: str) -> int:
    """
    Clear cache files for a specific profile.

    Args:
        profile_id: Profile ID

    Returns:
        Number of cache files deleted
    """
    cache_dir = _get_cache_dir()
    deleted_count = 0

    if cache_dir.exists():
        # Delete combined audio files for this profile
        pattern = f"combined_{profile_id}_*.wav"
        for audio_file in cache_dir.glob(pattern):
            try:
                audio_file.unlink()
                deleted_count += 1
            except Exception as e:
                logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)

    return deleted_count
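
The cache is two-level: memory first, then disk. A typical round trip looks like the sketch below; the reference path and the prompt payload are placeholders, and `config.get_cache_dir()` is assumed to point at a writable directory.

```python
from backend.utils.cache import (
    get_cache_key, get_cached_voice_prompt, cache_voice_prompt,
)

key = get_cache_key("reference.wav", "The quick brown fox.")  # placeholder inputs
prompt = get_cached_voice_prompt(key)
if prompt is None:
    prompt = {"embedding": None}  # placeholder: built by the TTS backend (expensive)
    cache_voice_prompt(key, prompt)  # persists to memory and disk
```
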
299
backend/utils/chunked_tts.py
Normal file
@@ -0,0 +1,299 @@
"""
Chunked TTS generation utilities.

Splits long text into sentence-boundary chunks, generates audio per-chunk
via any TTSBackend, and concatenates with crossfade. All logic is
engine-agnostic — it wraps the standard ``TTSBackend.generate()`` interface.

Short text (≤ max_chunk_chars) uses the single-shot fast path with zero
overhead.
"""

import logging
import re
from typing import List, Tuple

import numpy as np

logger = logging.getLogger("voicebox.chunked-tts")

# Default chunk size in characters. Can be overridden per-request via
# the ``max_chunk_chars`` field on GenerationRequest.
DEFAULT_MAX_CHUNK_CHARS = 800

# Common abbreviations that should NOT be treated as sentence endings.
# Lowercase for case-insensitive matching.
_ABBREVIATIONS = frozenset(
    {
        "mr", "mrs", "ms", "dr", "prof", "sr", "jr", "st", "ave", "blvd",
        "inc", "ltd", "corp", "dept", "est", "approx", "vs", "etc",
        "e.g", "i.e", "a.m", "p.m", "u.s", "u.s.a", "u.k",
    }
)

# Paralinguistic tags used by Chatterbox Turbo. The splitter must never
# cut inside one of these.
_PARA_TAG_RE = re.compile(r"\[[^\]]*\]")


def split_text_into_chunks(text: str, max_chars: int = DEFAULT_MAX_CHUNK_CHARS) -> List[str]:
    """Split *text* at natural boundaries into chunks of at most *max_chars*.

    Priority: sentence-end (``.!?`` not preceded by an abbreviation and not
    inside brackets) → clause boundary (``;:,—``) → whitespace → hard cut.

    Paralinguistic tags like ``[laugh]`` are treated as atomic and will not
    be split across chunks.
    """
    text = text.strip()
    if not text:
        return []
    if len(text) <= max_chars:
        return [text]

    chunks: List[str] = []
    remaining = text

    while remaining:
        remaining = remaining.lstrip()
        if not remaining:
            break
        if len(remaining) <= max_chars:
            chunks.append(remaining)
            break

        segment = remaining[:max_chars]

        # Try to split at the last real sentence ending
        split_pos = _find_last_sentence_end(segment)
        if split_pos == -1:
            split_pos = _find_last_clause_boundary(segment)
        if split_pos == -1:
            split_pos = segment.rfind(" ")
        if split_pos == -1:
            # Absolute fallback: hard cut but avoid splitting inside a tag
            split_pos = _safe_hard_cut(segment, max_chars)

        chunk = remaining[: split_pos + 1].strip()
        if chunk:
            chunks.append(chunk)
        remaining = remaining[split_pos + 1 :]

    return chunks


def _find_last_sentence_end(text: str) -> int:
    """Return the index of the last sentence-ending punctuation in *text*.

    Skips periods that follow common abbreviations (``Dr.``, ``Mr.``, etc.)
    and periods inside bracket tags (``[laugh]``). Also handles CJK
    sentence-ending punctuation (``。！？``).
    """
    best = -1
    # ASCII sentence ends
    for m in re.finditer(r"[.!?](?:\s|$)", text):
        pos = m.start()
        char = text[pos]
        # Skip periods after abbreviations
        if char == ".":
            # Walk backwards to find the preceding word. Dots are included
            # in the walk so multi-part abbreviations like "e.g." and
            # "u.s.a" match their entries in _ABBREVIATIONS.
            word_start = pos - 1
            while word_start >= 0 and (text[word_start].isalpha() or text[word_start] == "."):
                word_start -= 1
            word = text[word_start + 1 : pos].lower()
            if word in _ABBREVIATIONS:
                continue
            # Skip decimal numbers (digit immediately before the period)
            if word_start >= 0 and text[word_start].isdigit():
                continue
        # Skip if we're inside a bracket tag
        if _inside_bracket_tag(text, pos):
            continue
        best = pos
    # CJK sentence-ending punctuation
    for m in re.finditer(r"[\u3002\uff01\uff1f]", text):
        if m.start() > best:
            best = m.start()
    return best


def _find_last_clause_boundary(text: str) -> int:
    """Return the index of the last clause-boundary punctuation."""
    best = -1
    for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
        pos = m.start()
        # Skip if inside a bracket tag
        if _inside_bracket_tag(text, pos):
            continue
        best = pos
    return best


def _inside_bracket_tag(text: str, pos: int) -> bool:
    """Return True if *pos* falls inside a ``[...]`` tag."""
    for m in _PARA_TAG_RE.finditer(text):
        if m.start() < pos < m.end():
            return True
    return False


def _safe_hard_cut(segment: str, max_chars: int) -> int:
    """Find a hard-cut position that doesn't split a ``[tag]``."""
    cut = max_chars - 1
    # Check if the cut falls inside a bracket tag; if so, move before it
    for m in _PARA_TAG_RE.finditer(segment):
        if m.start() < cut < m.end():
            return m.start() - 1 if m.start() > 0 else cut
    return cut


def concatenate_audio_chunks(
    chunks: List[np.ndarray],
    sample_rate: int,
    crossfade_ms: int = 50,
) -> np.ndarray:
    """Concatenate audio arrays with a short crossfade to eliminate clicks.

    Each chunk is expected to be a 1-D float32 ndarray at *sample_rate* Hz.
    """
    if not chunks:
        return np.array([], dtype=np.float32)
    if len(chunks) == 1:
        return chunks[0]

    crossfade_samples = int(sample_rate * crossfade_ms / 1000)
    result = np.array(chunks[0], dtype=np.float32, copy=True)

    for chunk in chunks[1:]:
        if len(chunk) == 0:
            continue
        overlap = min(crossfade_samples, len(result), len(chunk))
        if overlap > 0:
            fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
            fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
            result[-overlap:] = result[-overlap:] * fade_out + chunk[:overlap] * fade_in
            result = np.concatenate([result, chunk[overlap:]])
        else:
            result = np.concatenate([result, chunk])

    return result


async def generate_chunked(
    backend,
    text: str,
    voice_prompt: dict,
    language: str = "en",
    seed: int | None = None,
    instruct: str | None = None,
    max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
    crossfade_ms: int = 50,
    trim_fn=None,
) -> Tuple[np.ndarray, int]:
    """Generate audio with automatic chunking for long text.

    For text shorter than *max_chunk_chars* this is a thin wrapper around
    ``backend.generate()`` with zero overhead.

    For longer text the input is split at natural sentence boundaries,
    each chunk is generated independently, optionally trimmed (useful for
    Chatterbox engines that hallucinate trailing noise), and the results
    are concatenated with a crossfade (or hard cut if *crossfade_ms* is 0).

    Parameters
    ----------
    backend : TTSBackend
        Any backend implementing the ``generate()`` protocol.
    text : str
        Input text (may be arbitrarily long).
    voice_prompt, language, seed, instruct
        Forwarded to ``backend.generate()`` verbatim.
    max_chunk_chars : int
        Maximum characters per chunk (default 800).
    crossfade_ms : int
        Crossfade duration in milliseconds between chunks. 0 for a hard
        cut with no overlap (default 50).
    trim_fn : callable | None
        Optional ``(audio, sample_rate) -> audio`` post-processing
        function applied to each chunk before concatenation (e.g.
        ``trim_tts_output`` for Chatterbox engines).

    Returns
    -------
    (audio, sample_rate) : Tuple[np.ndarray, int]
    """
    chunks = split_text_into_chunks(text, max_chunk_chars)

    if len(chunks) <= 1:
        # Short text — single-shot fast path
        audio, sample_rate = await backend.generate(
            text,
            voice_prompt,
            language,
            seed,
            instruct,
        )
        if trim_fn is not None:
            audio = trim_fn(audio, sample_rate)
        return audio, sample_rate

    # Long text — chunked generation
    logger.info(
        "Splitting %d chars into %d chunks (max %d chars each)",
        len(text),
        len(chunks),
        max_chunk_chars,
    )
    audio_chunks: List[np.ndarray] = []
    sample_rate: int | None = None

    for i, chunk_text in enumerate(chunks):
        logger.info(
            "Generating chunk %d/%d (%d chars)",
            i + 1,
            len(chunks),
            len(chunk_text),
        )
        # Vary the seed per chunk to avoid correlated RNG artefacts,
        # but keep it deterministic so the same (text, seed) pair
        # always produces the same output.
        chunk_seed = (seed + i) if seed is not None else None

        chunk_audio, chunk_sr = await backend.generate(
            chunk_text,
            voice_prompt,
            language,
            chunk_seed,
            instruct,
        )
        if trim_fn is not None:
            chunk_audio = trim_fn(chunk_audio, chunk_sr)

        audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32))
        if sample_rate is None:
            sample_rate = chunk_sr

    audio = concatenate_audio_chunks(audio_chunks, sample_rate, crossfade_ms=crossfade_ms)
    return audio, sample_rate
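
The splitter is the piece that makes chunked generation safe, so its boundary rules are worth a quick illustration (a sketch, not part of the test suite):

```python
from backend.utils.chunked_tts import split_text_into_chunks

text = ("Dr. Smith arrived at 9 a.m. sharp. [laugh] The meeting ran long. ") * 30
chunks = split_text_into_chunks(text, max_chars=200)

# "Dr." and "a.m." are not treated as sentence ends (via the dot-aware
# abbreviation walk above), and the [laugh] tag is never cut in half.
for c in chunks:
    assert len(c) <= 200
```
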
95
backend/utils/dac_shim.py
Normal file
@@ -0,0 +1,95 @@
"""
Minimal shim for descript-audio-codec (DAC).

TADA only imports Snake1d from dac.nn.layers and dac.model.dac.
The real DAC package pulls in descript-audiotools which depends on
onnx, tensorboard, protobuf, matplotlib, pystoi, etc. — none of
which are needed for TADA's runtime use of Snake1d.

This shim provides the exact Snake1d implementation (MIT-licensed,
from https://github.com/descriptinc/descript-audio-codec) so we can
avoid the entire audiotools dependency chain.

If the real DAC package is installed, this module is never used —
Python's import system will find the site-packages version first.
Install this shim only when descript-audio-codec is NOT installed.
"""

import sys
import types

import torch
import torch.nn as nn


# ── Snake activation (from dac/nn/layers.py) ────────────────────────

# NOTE: The original DAC code uses @torch.jit.script here for a 1.4x
# speedup. We omit it because TorchScript calls inspect.getsource()
# which fails inside a PyInstaller frozen binary (no .py source files).
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return snake(x, self.alpha)


# ── Register as dac.nn.layers and dac.model.dac ─────────────────────

def install_dac_shim() -> None:
    """Register fake dac package modules in sys.modules.

    Only installs the shim if 'dac' is not already importable
    (i.e. the real descript-audio-codec is not installed).
    """
    try:
        import dac  # noqa: F401 — real package exists, do nothing
        return
    except ImportError:
        pass

    # Create the module tree: dac -> dac.nn -> dac.nn.layers
    #                             -> dac.model -> dac.model.dac
    dac_pkg = types.ModuleType("dac")
    dac_pkg.__path__ = []  # make it a package
    dac_pkg.__package__ = "dac"

    dac_nn = types.ModuleType("dac.nn")
    dac_nn.__path__ = []
    dac_nn.__package__ = "dac.nn"

    dac_nn_layers = types.ModuleType("dac.nn.layers")
    dac_nn_layers.__package__ = "dac.nn"
    dac_nn_layers.Snake1d = Snake1d
    dac_nn_layers.snake = snake

    dac_model = types.ModuleType("dac.model")
    dac_model.__path__ = []
    dac_model.__package__ = "dac.model"

    dac_model_dac = types.ModuleType("dac.model.dac")
    dac_model_dac.__package__ = "dac.model"
    dac_model_dac.Snake1d = Snake1d

    # Wire up submodules
    dac_pkg.nn = dac_nn
    dac_pkg.model = dac_model
    dac_nn.layers = dac_nn_layers
    dac_model.dac = dac_model_dac

    # Register in sys.modules
    sys.modules["dac"] = dac_pkg
    sys.modules["dac.nn"] = dac_nn
    sys.modules["dac.nn.layers"] = dac_nn_layers
    sys.modules["dac.model"] = dac_model
    sys.modules["dac.model.dac"] = dac_model_dac
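
The shim is opt-in: `install_dac_shim()` must run before anything imports `dac`. A short usage sketch:

```python
from backend.utils.dac_shim import install_dac_shim

install_dac_shim()  # no-op if the real descript-audio-codec is installed

# This import now resolves to the shim's module tree in sys.modules.
from dac.nn.layers import Snake1d

layer = Snake1d(channels=64)
```
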
373
backend/utils/effects.py
Normal file
@@ -0,0 +1,373 @@
"""
Audio post-processing effects engine.

Uses Spotify's pedalboard library to apply professional-grade DSP effects
to generated audio. Effects are described as a JSON-serializable chain
(list of effect dicts) so they can be stored in the database and sent
over the API.

Supported effect types:
- chorus (flanger-style with short delays)
- reverb (room reverb)
- delay (echo / delay line)
- compressor (dynamic range compression)
- gain (volume adjustment in dB)
- highpass (high-pass filter)
- lowpass (low-pass filter)
- pitch_shift (semitone pitch shifting)
"""

from __future__ import annotations

import numpy as np
from typing import Any, Dict, List, Optional

from pedalboard import (
    Pedalboard,
    Chorus,
    Reverb,
    Compressor,
    Gain,
    HighpassFilter,
    LowpassFilter,
    Delay,
    PitchShift,
)


# Each param definition: default, min, max, step, description
EFFECT_REGISTRY: Dict[str, Dict[str, Any]] = {
    "chorus": {
        "cls": Chorus,
        "label": "Chorus / Flanger",
        "description": "Modulated delay for flanging or chorus effects. Short centre_delay_ms (<10) gives flanger; longer gives chorus.",
        "params": {
            "rate_hz": {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01, "description": "LFO speed (Hz)"},
            "depth": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Modulation depth"},
            "feedback": {"default": 0.0, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
            "centre_delay_ms": {
                "default": 7.0,
                "min": 0.5,
                "max": 50.0,
                "step": 0.1,
                "description": "Centre delay (ms)",
            },
            "mix": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
        },
    },
    "reverb": {
        "cls": Reverb,
        "label": "Reverb",
        "description": "Room reverb effect.",
        "params": {
            "room_size": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Room size"},
            "damping": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "High frequency damping"},
            "wet_level": {"default": 0.33, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet level"},
            "dry_level": {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Dry level"},
            "width": {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Stereo width"},
        },
    },
    "delay": {
        "cls": Delay,
        "label": "Delay",
        "description": "Echo / delay line.",
        "params": {
            "delay_seconds": {
                "default": 0.3,
                "min": 0.01,
                "max": 2.0,
                "step": 0.01,
                "description": "Delay time (seconds)",
            },
            "feedback": {"default": 0.3, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
            "mix": {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
        },
    },
    "compressor": {
        "cls": Compressor,
        "label": "Compressor",
        "description": "Dynamic range compression for consistent loudness.",
        "params": {
            "threshold_db": {"default": -20.0, "min": -60.0, "max": 0.0, "step": 0.5, "description": "Threshold (dB)"},
            "ratio": {"default": 4.0, "min": 1.0, "max": 20.0, "step": 0.1, "description": "Compression ratio"},
            "attack_ms": {"default": 10.0, "min": 0.1, "max": 100.0, "step": 0.1, "description": "Attack time (ms)"},
            "release_ms": {
                "default": 100.0,
                "min": 10.0,
                "max": 1000.0,
                "step": 1.0,
                "description": "Release time (ms)",
            },
        },
    },
    "gain": {
        "cls": Gain,
        "label": "Gain",
        "description": "Volume adjustment in decibels.",
        "params": {
            "gain_db": {"default": 0.0, "min": -40.0, "max": 40.0, "step": 0.5, "description": "Gain (dB)"},
        },
    },
    "highpass": {
        "cls": HighpassFilter,
        "label": "High-Pass Filter",
        "description": "Removes frequencies below the cutoff.",
        "params": {
            "cutoff_frequency_hz": {
                "default": 80.0,
                "min": 20.0,
                "max": 8000.0,
                "step": 1.0,
                "description": "Cutoff frequency (Hz)",
            },
        },
    },
    "lowpass": {
        "cls": LowpassFilter,
        "label": "Low-Pass Filter",
        "description": "Removes frequencies above the cutoff.",
        "params": {
            "cutoff_frequency_hz": {
                "default": 8000.0,
                "min": 200.0,
                "max": 20000.0,
                "step": 1.0,
                "description": "Cutoff frequency (Hz)",
            },
        },
    },
    "pitch_shift": {
        "cls": PitchShift,
        "label": "Pitch Shift",
        "description": "Shift pitch up or down by semitones.",
        "params": {
            "semitones": {"default": 0.0, "min": -12.0, "max": 12.0, "step": 0.5, "description": "Semitones to shift"},
        },
    },
}


BUILTIN_PRESETS: Dict[str, Dict[str, Any]] = {
    "robotic": {
        "name": "Robotic",
        "sort_order": 0,
        "description": "Metallic robotic voice (flanger with slow LFO and high feedback)",
        "effects_chain": [
            {
                "type": "chorus",
                "enabled": True,
                "params": {
                    "rate_hz": 0.2,
                    "depth": 1.0,
                    "feedback": 0.35,
                    "centre_delay_ms": 7.0,
                    "mix": 0.5,
                },
            },
        ],
    },
    "radio": {
        "name": "Radio",
        "sort_order": 1,
        "description": "Thin AM-radio voice with band-pass filtering and light compression",
        "effects_chain": [
            {
                "type": "highpass",
                "enabled": True,
                "params": {"cutoff_frequency_hz": 300.0},
            },
            {
                "type": "lowpass",
                "enabled": True,
                "params": {"cutoff_frequency_hz": 3500.0},
            },
            {
                "type": "compressor",
                "enabled": True,
                "params": {
                    "threshold_db": -15.0,
                    "ratio": 6.0,
                    "attack_ms": 5.0,
                    "release_ms": 50.0,
                },
            },
            {
                "type": "gain",
                "enabled": True,
                "params": {"gain_db": 6.0},
            },
        ],
    },
    "echo_chamber": {
        "name": "Echo Chamber",
        "sort_order": 2,
        "description": "Spacious reverb with trailing echo",
        "effects_chain": [
            {
                "type": "reverb",
                "enabled": True,
                "params": {
                    "room_size": 0.85,
                    "damping": 0.3,
                    "wet_level": 0.45,
                    "dry_level": 0.55,
                    "width": 1.0,
                },
            },
            {
                "type": "delay",
                "enabled": True,
                "params": {
                    "delay_seconds": 0.25,
                    "feedback": 0.3,
                    "mix": 0.2,
                },
            },
        ],
    },
    "deep_voice": {
        "name": "Deep Voice",
        "sort_order": 99,
        "description": "Lower pitch with added warmth",
        "effects_chain": [
            {
                "type": "pitch_shift",
                "enabled": True,
                "params": {"semitones": -3.0},
            },
            {
                "type": "lowpass",
                "enabled": True,
                "params": {"cutoff_frequency_hz": 6000.0},
            },
            {
                "type": "compressor",
                "enabled": True,
                "params": {
                    "threshold_db": -18.0,
                    "ratio": 3.0,
                    "attack_ms": 10.0,
                    "release_ms": 150.0,
                },
            },
        ],
    },
}


def get_available_effects() -> List[Dict[str, Any]]:
    """Return the list of available effect types with their parameter definitions.

    Used by the frontend to build the effects chain editor UI.
    """
    result = []
    for effect_type, info in EFFECT_REGISTRY.items():
        result.append(
            {
                "type": effect_type,
                "label": info["label"],
                "description": info["description"],
                "params": {name: {k: v for k, v in pdef.items()} for name, pdef in info["params"].items()},
            }
        )
    return result


def get_builtin_presets() -> Dict[str, Dict[str, Any]]:
    """Return all built-in effect presets."""
    return BUILTIN_PRESETS


def validate_effects_chain(effects_chain: List[Dict[str, Any]]) -> Optional[str]:
    """Validate an effects chain configuration.

    Returns None if valid, or an error message string.
    """
    if not isinstance(effects_chain, list):
        return "effects_chain must be a list"

    for i, effect in enumerate(effects_chain):
        if not isinstance(effect, dict):
            return f"Effect at index {i} must be a dict"

        effect_type = effect.get("type")
        if effect_type not in EFFECT_REGISTRY:
            return f"Unknown effect type '{effect_type}' at index {i}. Available: {list(EFFECT_REGISTRY.keys())}"

        params = effect.get("params", {})
        if not isinstance(params, dict):
            return f"Effect '{effect_type}' at index {i}: params must be a dict"

        registry = EFFECT_REGISTRY[effect_type]
        for param_name, value in params.items():
            if param_name not in registry["params"]:
                return f"Effect '{effect_type}' at index {i}: unknown param '{param_name}'"

            pdef = registry["params"][param_name]
            if not isinstance(value, (int, float)):
                return f"Effect '{effect_type}' at index {i}: param '{param_name}' must be a number"
            if value < pdef["min"] or value > pdef["max"]:
                return (
                    f"Effect '{effect_type}' at index {i}: param '{param_name}' "
                    f"must be between {pdef['min']} and {pdef['max']} (got {value})"
                )

    return None


def build_pedalboard(effects_chain: List[Dict[str, Any]]) -> Pedalboard:
    """Build a Pedalboard instance from an effects chain config.

    Skips effects where ``enabled`` is ``False``.
    """
    plugins = []
    for effect in effects_chain:
        if not effect.get("enabled", True):
            continue

        effect_type = effect["type"]
        registry = EFFECT_REGISTRY[effect_type]
        cls = registry["cls"]

        # Merge defaults with provided params
        params = {}
        for pname, pdef in registry["params"].items():
            params[pname] = effect.get("params", {}).get(pname, pdef["default"])

        plugins.append(cls(**params))

    return Pedalboard(plugins)


def apply_effects(
    audio: np.ndarray,
    sample_rate: int,
    effects_chain: List[Dict[str, Any]],
) -> np.ndarray:
    """Apply an effects chain to audio data.

    Args:
        audio: Input audio array (1-D mono float32).
        sample_rate: Sample rate in Hz.
        effects_chain: List of effect configuration dicts.

    Returns:
        Processed audio array.
    """
    if not effects_chain:
        return audio

    board = build_pedalboard(effects_chain)

    # pedalboard expects shape (channels, samples)
    if audio.ndim == 1:
        audio_2d = audio[np.newaxis, :]
    else:
        audio_2d = audio

    processed = board(audio_2d.astype(np.float32), sample_rate)

    # Return same dimensionality as input
    if audio.ndim == 1:
        return processed[0]
    return processed
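
Since a chain is plain JSON-serializable data, building and applying one is short. A sketch using a second of placeholder silence as input (omitted params fall back to the registry defaults):

```python
import numpy as np
from backend.utils.effects import validate_effects_chain, apply_effects

chain = [
    {"type": "highpass", "enabled": True, "params": {"cutoff_frequency_hz": 120.0}},
    {"type": "reverb", "enabled": True, "params": {"room_size": 0.7}},
]
assert validate_effects_chain(chain) is None  # None means the chain is valid

audio = np.zeros(24000, dtype=np.float32)  # placeholder: 1 s of silence at 24 kHz
processed = apply_effects(audio, 24000, chain)
```
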
270
backend/utils/hf_offline_patch.py
Normal file
@@ -0,0 +1,270 @@
"""Monkey-patch huggingface_hub to force offline mode with cached models.

Prevents mlx_audio / transformers from making network requests when models
are already downloaded. Must be imported BEFORE mlx_audio.
"""

import logging
import os
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Union

logger = logging.getLogger(__name__)


# huggingface_hub reads ``HF_HUB_OFFLINE`` once at import time into
# ``huggingface_hub.constants.HF_HUB_OFFLINE``; transformers mirrors that into
# ``transformers.utils.hub._is_offline_mode`` at *its* import time. Toggling
# ``os.environ`` after either module is imported does not flip those cached
# bools, and the hot paths (``_http._default_backend_factory``,
# ``transformers.utils.hub.is_offline_mode``) read the bools — not the env.
# We mutate the cached constants directly, guarded by a refcount so
# concurrent inference threads share a single offline window safely.

_offline_lock = threading.RLock()
_offline_refcount = 0
_saved_env: Optional[str] = None
_saved_hf_const: Optional[bool] = None
_saved_transformers_const: Optional[bool] = None


@contextmanager
def force_offline_if_cached(is_cached: bool, model_label: str = ""):
    """Force offline mode for the duration of a cached-model operation.

    Flips ``HF_HUB_OFFLINE`` in the process env **and** in the cached bools
    inside ``huggingface_hub.constants`` / ``transformers.utils.hub`` so HTTP
    adapters and offline-mode checks actually see the change. Uses a refcount
    so multiple concurrent inference threads share a single offline window
    and the last one to exit restores state.

    If *is_cached* is ``False`` the block runs normally (network allowed).

    Args:
        is_cached: Whether the model weights are already on disk.
        model_label: Human-readable name used in log messages.
    """
    if not is_cached:
        yield
        return

    global _offline_refcount, _saved_env, _saved_hf_const, _saved_transformers_const

    with _offline_lock:
        if _offline_refcount == 0:
            # Snapshot prior state, apply new state, roll back on *any*
            # failure. Catching only ImportError here would let a partially
            # broken install (RuntimeError, AttributeError from a half-init
            # module, etc.) leave the cached HF constants mutated without
            # bumping the refcount — a persistent offline leak that outlives
            # the process and is miserable to debug.
            prev_env = os.environ.get("HF_HUB_OFFLINE")
            prev_hf: Optional[bool] = None
            prev_tf: Optional[bool] = None
            try:
                try:
                    import huggingface_hub.constants as hf_const

                    prev_hf = hf_const.HF_HUB_OFFLINE
                    hf_const.HF_HUB_OFFLINE = True
                except ImportError:
                    prev_hf = None

                try:
                    import transformers.utils.hub as tf_hub

                    prev_tf = tf_hub._is_offline_mode
                    tf_hub._is_offline_mode = True
                except ImportError:
                    prev_tf = None

                os.environ["HF_HUB_OFFLINE"] = "1"
            except BaseException:
                # Roll back whatever we already changed, then re-raise so
                # the caller sees the real failure.
                if prev_hf is not None:
                    try:
                        import huggingface_hub.constants as hf_const

                        hf_const.HF_HUB_OFFLINE = prev_hf
                    except ImportError:
                        pass
                if prev_tf is not None:
                    try:
                        import transformers.utils.hub as tf_hub

                        tf_hub._is_offline_mode = prev_tf
                    except ImportError:
                        pass
                if prev_env is not None:
                    os.environ["HF_HUB_OFFLINE"] = prev_env
                else:
                    os.environ.pop("HF_HUB_OFFLINE", None)
                raise

            _saved_env = prev_env
            _saved_hf_const = prev_hf
            _saved_transformers_const = prev_tf
            logger.info(
                "[offline-guard] %s is cached — forcing offline mode",
                model_label or "model",
            )
        _offline_refcount += 1

    try:
        yield
    finally:
        with _offline_lock:
            _offline_refcount -= 1
            if _offline_refcount == 0:
                if _saved_env is not None:
                    os.environ["HF_HUB_OFFLINE"] = _saved_env
                else:
                    os.environ.pop("HF_HUB_OFFLINE", None)
                if _saved_hf_const is not None:
                    try:
                        import huggingface_hub.constants as hf_const

                        hf_const.HF_HUB_OFFLINE = _saved_hf_const
                    except ImportError:
                        pass
                if _saved_transformers_const is not None:
                    try:
                        import transformers.utils.hub as tf_hub

                        tf_hub._is_offline_mode = _saved_transformers_const
                    except ImportError:
                        pass
                _saved_env = None
                _saved_hf_const = None
                _saved_transformers_const = None


_mistral_regex_patched = False


def patch_transformers_mistral_regex():
    """Make transformers' tokenizer load robust to HuggingFace metadata failures.

    transformers 4.57.x added ``PreTrainedTokenizerBase._patch_mistral_regex``
    which unconditionally calls ``huggingface_hub.model_info(repo_id)`` during
    every non-local tokenizer load to check whether the model is a Mistral
    variant. That call raises on ``HF_HUB_OFFLINE=1`` and on plain network
    failures, killing unrelated loads (Qwen TTS, TADA, etc.).

    Voicebox never loads Mistral models, so the rewrite the function would
    apply is a no-op for us anyway. Wrap the method so any exception from the
    metadata lookup returns the tokenizer unchanged — matching the success-path
    behavior for non-Mistral repos (transformers 4.57.3,
    ``tokenization_utils_base.py:2503``).
    """
    global _mistral_regex_patched
    if _mistral_regex_patched:
        return

    try:
        from transformers.tokenization_utils_base import PreTrainedTokenizerBase
    except ImportError:
        logger.debug("transformers not available, skipping mistral-regex patch")
        return

    original = getattr(PreTrainedTokenizerBase, "_patch_mistral_regex", None)
    if original is None:
        logger.debug(
            "transformers has no _patch_mistral_regex attribute, skipping patch",
        )
        return

    def safe_patch_mistral_regex(cls, tokenizer, pretrained_model_name_or_path, *args, **kwargs):
        try:
            return original(tokenizer, pretrained_model_name_or_path, *args, **kwargs)
        except Exception as exc:
            logger.debug(
                "[mistral-regex-patch] suppressed %s for %r, returning tokenizer as-is",
                type(exc).__name__,
                pretrained_model_name_or_path,
            )
            return tokenizer

    PreTrainedTokenizerBase._patch_mistral_regex = classmethod(safe_patch_mistral_regex)
    _mistral_regex_patched = True
    logger.debug("installed _patch_mistral_regex wrapper")


def patch_huggingface_hub_offline():
    """Monkey-patch huggingface_hub to force offline mode."""
    try:
        import huggingface_hub  # noqa: F401 -- need the package loaded
        from huggingface_hub import constants as hf_constants
        from huggingface_hub.file_download import _try_to_load_from_cache

        original_try_load = _try_to_load_from_cache

        def _patched_try_to_load_from_cache(
            repo_id: str,
            filename: str,
            cache_dir: Union[str, Path, None] = None,
            revision: Optional[str] = None,
            repo_type: Optional[str] = None,
        ):
            result = original_try_load(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                revision=revision,
                repo_type=repo_type,
            )

            if result is None:
                cache_path = Path(hf_constants.HF_HUB_CACHE) / f"models--{repo_id.replace('/', '--')}"
                logger.debug("file not cached: %s/%s (expected at %s)", repo_id, filename, cache_path)
            else:
                logger.debug("cache hit: %s/%s", repo_id, filename)

            return result

        import huggingface_hub.file_download as fd

        fd._try_to_load_from_cache = _patched_try_to_load_from_cache
        logger.debug("huggingface_hub patched for offline mode")

    except ImportError:
        logger.debug("huggingface_hub not available, skipping offline patch")
    except Exception:
        logger.exception("failed to patch huggingface_hub for offline mode")


def ensure_original_qwen_config_cached():
    """Symlink the original Qwen repo cache to the MLX community version.

    mlx_audio may try to fetch config from the original Qwen repo. If only
    the MLX community variant is cached, create a symlink so the cache lookup
    succeeds without a network request.
    """
    try:
        from huggingface_hub import constants as hf_constants
    except ImportError:
        return

    original_repo = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    mlx_repo = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"

    cache_dir = Path(hf_constants.HF_HUB_CACHE)
    original_path = cache_dir / f"models--{original_repo.replace('/', '--')}"
    mlx_path = cache_dir / f"models--{mlx_repo.replace('/', '--')}"

    if not original_path.exists() and mlx_path.exists():
        try:
            original_path.parent.mkdir(parents=True, exist_ok=True)
            original_path.symlink_to(mlx_path, target_is_directory=True)
            logger.info("created cache symlink: %s -> %s", original_repo, mlx_repo)
        except Exception:
            logger.warning("could not create cache symlink for %s", original_repo, exc_info=True)


if os.environ.get("VOICEBOX_OFFLINE_PATCH", "1") != "0":
    patch_huggingface_hub_offline()
    patch_transformers_mistral_regex()
    ensure_original_qwen_config_cached()
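
A hypothetical call site for the offline guard; `check_hf_cache` and `backend.load_model` are stand-in names for whatever cache check and loader the backends actually use, not confirmed API from this codebase:

```python
from backend.utils.hf_offline_patch import force_offline_if_cached

repo_id = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
with force_offline_if_cached(check_hf_cache(repo_id), model_label=repo_id):
    # Any huggingface_hub / transformers loading inside this block will
    # not touch the network while the guard is active.
    backend.load_model(repo_id)
```
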
383
backend/utils/hf_progress.py
Normal file
@@ -0,0 +1,383 @@
"""
HuggingFace Hub download progress tracking.
"""

from typing import Optional, Callable
from contextlib import contextmanager
import logging
import threading
import sys

logger = logging.getLogger(__name__)


class HFProgressTracker:
    """Tracks HuggingFace Hub download progress by intercepting tqdm."""

    def __init__(self, progress_callback: Optional[Callable] = None, filter_non_downloads: bool = False):
        self.progress_callback = progress_callback
        self.filter_non_downloads = filter_non_downloads  # Only filter if True
        self._original_tqdm_class = None
        self._lock = threading.Lock()
        self._total_downloaded = 0
        self._total_size = 0
        self._file_sizes = {}  # Track sizes of individual files
        self._file_downloaded = {}  # Track downloaded bytes per file
        self._current_filename = ""
        self._active_tqdms = {}  # Track active tqdm instances
        self._hf_tqdm_original_update = None  # For monkey-patching hf's tqdm

    def _create_tracked_tqdm_class(self):
        """Create a tqdm subclass that tracks progress."""
        tracker = self
        original_tqdm = self._original_tqdm_class

        class TrackedTqdm(original_tqdm):
            """A tqdm subclass that reports progress to our tracker."""

            def __init__(self, *args, **kwargs):
                # Extract filename from desc before passing to parent
                desc = kwargs.get("desc", "")
                if not desc and args:
                    first_arg = args[0]
                    if isinstance(first_arg, str):
                        desc = first_arg

                filename = ""
                if desc:
                    # Try to extract filename from description
                    # HuggingFace Hub uses format like "model.safetensors: 0%|..."
                    if ":" in desc:
                        filename = desc.split(":")[0].strip()
                    else:
                        filename = desc.strip()

                # Filter out non-standard kwargs that huggingface_hub might pass
                # These are custom kwargs that tqdm doesn't understand
                filtered_kwargs = {}
                # Known tqdm kwargs - pass these through
                tqdm_kwargs = {
                    "iterable", "desc", "total", "leave", "file", "ncols",
                    "mininterval", "maxinterval", "miniters", "ascii",
                    "disable", "unit", "unit_scale", "dynamic_ncols",
                    "smoothing", "bar_format", "initial", "position",
                    "postfix", "unit_divisor", "write_bytes", "lock_args",
                    "nrows", "colour", "color", "delay", "gui",
                    "disable_default", "pos",
                }
                for key, value in kwargs.items():
                    if key in tqdm_kwargs:
                        filtered_kwargs[key] = value

                # Force-enable the progress bar — we're tracking progress ourselves,
                # we don't need tqdm to render to a terminal, but we DO need
                # self.n to be updated when update() is called.
                filtered_kwargs["disable"] = False

                # Try to initialize with filtered kwargs, fall back to all kwargs if that fails
                try:
                    super().__init__(*args, **filtered_kwargs)
                except TypeError:
                    # If filtering failed, try with all kwargs (maybe tqdm version accepts them)
                    kwargs["disable"] = False
                    super().__init__(*args, **kwargs)

                self._tracker_filename = filename or "unknown"

                with tracker._lock:
                    if filename:
                        tracker._current_filename = filename
                    tracker._active_tqdms[id(self)] = {
                        "filename": self._tracker_filename,
                    }

            def update(self, n=1):
                result = super().update(n)

                # Report progress
                with tracker._lock:
                    if id(self) in tracker._active_tqdms:
                        filename = tracker._active_tqdms[id(self)]["filename"]
                        current = getattr(self, "n", 0)
                        total = getattr(self, "total", 0)

                        if total and total > 0:
                            # Always filter out non-byte progress bars (e.g., "Fetching 12 files")
                            # These cause crazy percentages because they're counting files, not bytes
                            if self._is_non_byte_progress(filename):
                                return result

                            # When model is cached, also filter out generation-related progress
                            if tracker.filter_non_downloads:
                                if not self._is_download_progress(filename):
                                    return result

                            # Update per-file tracking
                            tracker._file_sizes[filename] = total
                            tracker._file_downloaded[filename] = current

                            # Calculate totals across all files
                            tracker._total_size = sum(tracker._file_sizes.values())
                            tracker._total_downloaded = sum(tracker._file_downloaded.values())

                            # Only report progress once we have a meaningful total (at least 1MB)
                            # This avoids the "100% at 0MB" issue when small config
                            # files are counted before the real model files
                            MIN_TOTAL_BYTES = 1_000_000  # 1MB
                            if tracker._total_size < MIN_TOTAL_BYTES:
                                return result

                            # Call progress callback
                            if tracker.progress_callback:
                                tracker.progress_callback(tracker._total_downloaded, tracker._total_size, filename)

                return result

            def _is_non_byte_progress(self, filename: str) -> bool:
                """Check if this progress bar should be SKIPPED (returns True to skip).

                We want to track byte-based progress bars. This method identifies
                progress bars that count files/items instead of bytes, which would
                cause crazy percentages if mixed with our byte counting.

                Returns:
                    True = SKIP this bar (it's not byte-based)
                    False = TRACK this bar (it counts bytes)
                """
                if not filename:
                    return False

                filename_lower = filename.lower()

                # Skip "Fetching X files" - it counts files (total=12), not bytes
                # Don't skip "Downloading (incomplete total...)" - that IS byte-based
                skip_patterns = [
                    "fetching",  # "Fetching 12 files" has total=12 files, not bytes
                ]
                return any(pattern in filename_lower for pattern in skip_patterns)

            def _is_download_progress(self, filename: str) -> bool:
                """Check if this is a real file download progress bar vs internal processing."""
                if not filename or filename == "unknown":
                    return False

                # Real downloads have file extensions
                download_extensions = [
                    ".safetensors", ".bin", ".pt", ".pth",  # Model weights
                    ".json", ".txt", ".py",  # Config files
                    ".msgpack", ".h5",  # Other formats
                ]

                filename_lower = filename.lower()
                has_extension = any(filename_lower.endswith(ext) for ext in download_extensions)

                # Skip generation-related progress indicators
                skip_patterns = ["segment", "processing", "generating", "loading"]
                has_skip_pattern = any(pattern in filename_lower for pattern in skip_patterns)

                return has_extension and not has_skip_pattern

            def close(self):
                with tracker._lock:
                    if id(self) in tracker._active_tqdms:
                        del tracker._active_tqdms[id(self)]
                return super().close()

        return TrackedTqdm

    @contextmanager
    def patch_download(self):
        """Context manager to patch tqdm for progress tracking."""
        try:
            import tqdm as tqdm_module

            # Store original tqdm class
            self._original_tqdm_class = tqdm_module.tqdm

            # Reset totals
            with self._lock:
                self._total_downloaded = 0
                self._total_size = 0
                self._file_sizes = {}
                self._file_downloaded = {}
                self._current_filename = ""
                self._active_tqdms = {}

            # Create our tracked tqdm class
            tracked_tqdm = self._create_tracked_tqdm_class()

            # Patch tqdm.tqdm
            tqdm_module.tqdm = tracked_tqdm

            # Also patch tqdm.auto.tqdm if it exists (used by huggingface_hub)
            self._original_tqdm_auto = None
            if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
                self._original_tqdm_auto = tqdm_module.auto.tqdm
                tqdm_module.auto.tqdm = tracked_tqdm

            # Patch in sys.modules to catch already-imported references
            # huggingface_hub uses: from tqdm.auto import tqdm as base_tqdm
            # So we need to patch both 'tqdm' and 'base_tqdm' attributes
            self._patched_modules = {}
            tqdm_attr_names = ["tqdm", "base_tqdm", "old_tqdm"]  # Various names used

            patched_count = 0
            for module_name in list(sys.modules.keys()):
                if "huggingface" in module_name or module_name.startswith("tqdm"):
                    try:
                        module = sys.modules[module_name]
                        for attr_name in tqdm_attr_names:
                            if hasattr(module, attr_name):
                                attr = getattr(module, attr_name)
                                # Only patch if it's a tqdm class (not already patched)
                                is_tqdm_class = (
                                    attr is self._original_tqdm_class
                                    or (self._original_tqdm_auto and attr is self._original_tqdm_auto)
                                    or (
                                        hasattr(attr, "__name__")
                                        and attr.__name__ == "tqdm"
                                        and hasattr(attr, "update")
                                    )  # tqdm classes have update method
                                )
                                if is_tqdm_class:
                                    key = f"{module_name}.{attr_name}"
                                    self._patched_modules[key] = (module, attr_name, attr)
                                    setattr(module, attr_name, tracked_tqdm)
                                    patched_count += 1
                    except (AttributeError, TypeError):
                        pass

            # ALSO monkey-patch the update method on huggingface_hub's tqdm class
            # This is needed because the class was already defined at import time
            self._hf_tqdm_original_update = None
            try:
                from huggingface_hub.utils import tqdm as hf_tqdm_module

                if hasattr(hf_tqdm_module, "tqdm"):
                    hf_tqdm_class = hf_tqdm_module.tqdm
                    self._hf_tqdm_original_update = hf_tqdm_class.update

                    # Create a wrapper that calls our tracking
                    tracker = self  # Reference to HFProgressTracker instance

                    def patched_update(tqdm_self, n=1):
                        result = tracker._hf_tqdm_original_update(tqdm_self, n)

                        # Track this progress
                        with tracker._lock:
                            desc = getattr(tqdm_self, "desc", "") or ""
                            current = getattr(tqdm_self, "n", 0)
                            total = getattr(tqdm_self, "total", 0) or 0

                            # Skip non-byte progress bars
                            if "fetching" in desc.lower():
|
||||
return result
|
||||
|
||||
# Skip until we have a meaningful total (at least 1MB)
|
||||
# This avoids the "100% at 0MB" issue when small config
|
||||
# files are counted before the real model files
|
||||
MIN_TOTAL_BYTES = 1_000_000 # 1MB
|
||||
if total >= MIN_TOTAL_BYTES:
|
||||
tracker._total_downloaded = current
|
||||
tracker._total_size = total
|
||||
|
||||
if tracker.progress_callback:
|
||||
tracker.progress_callback(current, total, desc)
|
||||
|
||||
return result
|
||||
|
||||
hf_tqdm_class.update = patched_update
|
||||
patched_count += 1
|
||||
logger.debug("Monkey-patched huggingface_hub.utils.tqdm.tqdm.update")
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning("Could not monkey-patch hf_tqdm: %s", e)
|
||||
|
||||
logger.debug("Patched %d tqdm references", patched_count)
|
||||
|
||||
yield
|
||||
|
||||
except ImportError:
|
||||
# If tqdm not available, just yield without patching
|
||||
yield
|
||||
finally:
|
||||
# Restore original tqdm
|
||||
if self._original_tqdm_class:
|
||||
try:
|
||||
import tqdm as tqdm_module
|
||||
|
||||
tqdm_module.tqdm = self._original_tqdm_class
|
||||
|
||||
if self._original_tqdm_auto:
|
||||
tqdm_module.auto.tqdm = self._original_tqdm_auto
|
||||
|
||||
# Restore patched modules
|
||||
for key, (module, attr_name, original) in self._patched_modules.items():
|
||||
try:
|
||||
if module and original:
|
||||
setattr(module, attr_name, original)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
self._patched_modules = {}
|
||||
|
||||
# Restore hf_tqdm's original update method
|
||||
if self._hf_tqdm_original_update:
|
||||
try:
|
||||
from huggingface_hub.utils import tqdm as hf_tqdm_module
|
||||
|
||||
if hasattr(hf_tqdm_module, "tqdm"):
|
||||
hf_tqdm_module.tqdm.update = self._hf_tqdm_original_update
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
self._hf_tqdm_original_update = None
|
||||
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
def create_hf_progress_callback(model_name: str, progress_manager):
|
||||
"""Create a progress callback for HuggingFace downloads."""
|
||||
|
||||
def callback(downloaded: int, total: int, filename: str = ""):
|
||||
"""Progress callback.
|
||||
|
||||
Note: We send updates even when total=0 (unknown) to provide feedback
|
||||
during the "incomplete total" phase of huggingface_hub downloads.
|
||||
The frontend handles total=0 gracefully.
|
||||
"""
|
||||
progress_manager.update_progress(
|
||||
model_name=model_name,
|
||||
current=downloaded,
|
||||
total=total,
|
||||
filename=filename or "",
|
||||
status="downloading",
|
||||
)
|
||||
|
||||
return callback
|
||||
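

# Usage sketch (added for illustration; not part of the original file).
# The live block below only exercises the callback factory, whose signature is
# fully visible above. Composing it with the tracker would look like the
# commented sketch; `HFProgressTracker` is the class defined earlier in this
# file, but its constructor signature is an assumption here, and the repo id
# passed to huggingface_hub's snapshot_download is purely illustrative:
#
#     from huggingface_hub import snapshot_download
#     tracker = HFProgressTracker()
#     tracker.progress_callback = create_hf_progress_callback(
#         "example-model", get_progress_manager()
#     )
#     with tracker.patch_download():
#         snapshot_download("example-org/example-model")
if __name__ == "__main__":
    from backend.utils.progress import get_progress_manager

    # Simulate a half-finished download of one weights file
    cb = create_hf_progress_callback("example-model", get_progress_manager())
    cb(downloaded=512_000, total=1_024_000, filename="model.safetensors")
    print(get_progress_manager().get_progress("example-model"))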
114
backend/utils/images.py
Normal file
114
backend/utils/images.py
Normal file
@@ -0,0 +1,114 @@
"""Image processing utilities for avatar uploads."""

from pathlib import Path
from typing import Optional, Tuple

from PIL import Image

# JPEG can be reported as 'JPEG' or 'MPO' (multi-picture format from some cameras)
ALLOWED_FORMATS = {'PNG', 'JPEG', 'WEBP', 'MPO', 'JPG'}
MAX_SIZE = 512
MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB


def validate_image(file_path: str) -> Tuple[bool, Optional[str]]:
    """
    Validate image format and file size.

    Args:
        file_path: Path to image file

    Returns:
        Tuple of (is_valid, error_message)
    """
    path = Path(file_path)

    # Check file size
    if path.stat().st_size > MAX_FILE_SIZE:
        return False, f"File size exceeds maximum of {MAX_FILE_SIZE // (1024 * 1024)}MB"

    try:
        with Image.open(file_path) as img:
            # Verify the image can be loaded
            img.load()

            # Check format (normalize JPEG variants)
            img_format = img.format
            if img_format in ('MPO', 'JPG'):
                img_format = 'JPEG'

            if img_format not in {'PNG', 'JPEG', 'WEBP'}:
                return False, f"Invalid format '{img_format}'. Allowed formats: PNG, JPEG, WEBP"

        return True, None
    except Exception as e:
        return False, f"Invalid image file: {str(e)}"


def process_avatar(input_path: str, output_path: str, max_size: int = MAX_SIZE) -> None:
    """
    Process avatar image: resize and optimize.

    Resizes image to fit within max_size x max_size while maintaining aspect ratio.

    Args:
        input_path: Path to input image
        output_path: Path to save processed image
        max_size: Maximum width or height in pixels
    """
    with Image.open(input_path) as img:
        # Handle EXIF orientation for JPEG images
        try:
            from PIL import ExifTags

            for orientation in ExifTags.TAGS.keys():
                if ExifTags.TAGS[orientation] == 'Orientation':
                    break
            exif = img._getexif()
            if exif is not None:
                orientation_value = exif.get(orientation)
                if orientation_value == 3:
                    img = img.rotate(180, expand=True)
                elif orientation_value == 6:
                    img = img.rotate(270, expand=True)
                elif orientation_value == 8:
                    img = img.rotate(90, expand=True)
        except (AttributeError, KeyError, IndexError, TypeError):
            # No EXIF data or orientation tag
            pass

        # Convert to RGB if necessary (handles RGBA, P, CMYK, etc.)
        if img.mode not in ('RGB', 'L'):
            if img.mode == 'RGBA':
                # Composite onto a white background, using the alpha channel as mask
                background = Image.new('RGB', img.size, (255, 255, 255))
                background.paste(img, mask=img.split()[3])
                img = background
            else:
                # CMYK, palette ('P'), and any other mode convert directly
                img = img.convert('RGB')

        # Calculate new size maintaining aspect ratio
        img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        # Determine output format from extension
        output_ext = Path(output_path).suffix.lower()

        format_map = {
            '.png': 'PNG',
            '.jpeg': 'JPEG',
            '.jpg': 'JPEG',
            '.webp': 'WEBP',
        }

        output_format = format_map.get(output_ext, 'PNG')

        # Save with optimization
        save_kwargs = {'optimize': True}
        if output_format == 'JPEG':
            save_kwargs['quality'] = 90

        img.save(output_path, format=output_format, **save_kwargs)
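

# Usage sketch (added for illustration; not part of the original file):
# validate an upload before processing it. The file paths are placeholders
# and must exist for the block to run.
if __name__ == "__main__":
    src, dst = "upload.jpg", "avatar.webp"
    ok, error = validate_image(src)
    if ok:
        # Fits the image within 512x512 (MAX_SIZE) and re-encodes it based on
        # the output extension, WEBP here
        process_avatar(src, dst)
    else:
        print(f"Rejected upload: {error}")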
35
backend/utils/platform_detect.py
Normal file
35
backend/utils/platform_detect.py
Normal file
@@ -0,0 +1,35 @@
"""
Platform detection for backend selection.
"""

import platform
from typing import Literal


def is_apple_silicon() -> bool:
    """
    Check if running on Apple Silicon (arm64 macOS).

    Returns:
        True if on Apple Silicon, False otherwise
    """
    return platform.system() == "Darwin" and platform.machine() == "arm64"


def get_backend_type() -> Literal["mlx", "pytorch"]:
    """
    Detect the best backend for the current platform.

    Returns:
        "mlx" on Apple Silicon (if MLX is available and functional), "pytorch" otherwise
    """
    if is_apple_silicon():
        try:
            import mlx.core  # noqa: F401 (the import triggers native lib loading)

            return "mlx"
        except (ImportError, OSError, RuntimeError):
            # MLX not installed, or native libraries failed to load inside a
            # PyInstaller bundle (OSError on missing .dylib / .metallib).
            # Fall through to PyTorch.
            return "pytorch"
    return "pytorch"
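

# Usage sketch (added for illustration): callers treat the result as an opaque
# engine choice; the real consumer is the backend factory in backends/__init__.py.
if __name__ == "__main__":
    print(f"Selected inference backend: {get_backend_type()}")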
315
backend/utils/progress.py
Normal file
315
backend/utils/progress.py
Normal file
@@ -0,0 +1,315 @@
"""
Progress tracking for model downloads using Server-Sent Events.
"""

from typing import Optional, Callable, Dict, List
from fastapi.responses import StreamingResponse
import asyncio
import json
import threading
from datetime import datetime


class ProgressManager:
    """Manages download progress for multiple models.

    Thread-safe: can be called from background threads (e.g., via asyncio.to_thread).
    """

    # Throttle settings to prevent overwhelming SSE clients
    THROTTLE_INTERVAL_SECONDS = 0.5  # Minimum time between updates
    THROTTLE_PROGRESS_DELTA = 1.0  # Minimum progress change (%) to force an update

    def __init__(self):
        self._progress: Dict[str, Dict] = {}
        self._listeners: Dict[str, list] = {}
        self._lock = threading.Lock()  # Guards the progress dict
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None
        self._last_notify_time: Dict[str, float] = {}  # Last notification time per model
        self._last_notify_progress: Dict[str, float] = {}  # Last notified progress per model

    def _set_main_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the main event loop for thread-safe operations."""
        self._main_loop = loop

    def _notify_listeners_threadsafe(self, model_name: str, progress_data: Dict):
        """Notify listeners in a thread-safe manner."""
        import logging
        logger = logging.getLogger(__name__)

        if model_name not in self._listeners:
            return

        for queue in self._listeners[model_name]:
            try:
                try:
                    # Check if we're on the main event loop thread: in an async
                    # context we can use put_nowait directly
                    asyncio.get_running_loop()
                    queue.put_nowait(progress_data.copy())
                except RuntimeError:
                    # Not in an async context (running in a background thread);
                    # use call_soon_threadsafe to safely put on the queue
                    if self._main_loop and self._main_loop.is_running():
                        self._main_loop.call_soon_threadsafe(
                            lambda q=queue, d=progress_data.copy(): q.put_nowait(d) if not q.full() else None
                        )
                    else:
                        logger.debug(f"No main loop available for {model_name}, skipping notification")
            except asyncio.QueueFull:
                logger.warning(f"Queue full for {model_name}, dropping update")
            except Exception as e:
                logger.warning(f"Error notifying listener for {model_name}: {e}")

    def update_progress(
        self,
        model_name: str,
        current: int,
        total: int,
        filename: Optional[str] = None,
        status: str = "downloading",
    ):
        """
        Update progress for a model download.

        Thread-safe: can be called from background threads.

        Progress updates are throttled to prevent overwhelming SSE clients.
        Updates are sent at most every THROTTLE_INTERVAL_SECONDS, or when
        progress changes by at least THROTTLE_PROGRESS_DELTA percent.

        Args:
            model_name: Name of the model (e.g., "qwen-tts-1.7B", "whisper-base")
            current: Current bytes downloaded
            total: Total bytes to download
            filename: Current file being downloaded
            status: Status string (downloading, extracting, complete, error)
        """
        import logging
        import time
        logger = logging.getLogger(__name__)

        # Calculate the progress percentage, clamped to the 0-100 range.
        # This prevents crazy percentages from edge cases like:
        # - current > total temporarily during aggregation
        # - mixing file-count progress with byte-count progress
        if total > 0:
            progress_pct = min(100.0, max(0.0, current / total * 100))
        else:
            progress_pct = 0

        progress_data = {
            "model_name": model_name,
            "current": current,
            "total": total,
            "progress": progress_pct,
            "filename": filename,
            "status": status,
            "timestamp": datetime.now().isoformat(),
        }

        # Thread-safe update of the progress dict (always update internal state)
        with self._lock:
            self._progress[model_name] = progress_data

            # Check if we should notify listeners (throttling)
            current_time = time.time()
            last_time = self._last_notify_time.get(model_name, 0)
            last_progress = self._last_notify_progress.get(model_name, -100)

            time_delta = current_time - last_time
            progress_delta = abs(progress_pct - last_progress)

            # Always notify for complete/error status, or if throttle conditions are met
            should_notify = (
                status in ("complete", "error")
                or time_delta >= self.THROTTLE_INTERVAL_SECONDS
                or progress_delta >= self.THROTTLE_PROGRESS_DELTA
            )

            if not should_notify:
                return  # Skip this update (throttled)

            # Update throttle tracking
            self._last_notify_time[model_name] = current_time
            self._last_notify_progress[model_name] = progress_pct

        # Notify all listeners (thread-safe)
        listener_count = len(self._listeners.get(model_name, []))

        if listener_count > 0:
            logger.debug(f"Notifying {listener_count} listeners for {model_name}: {progress_pct:.1f}% ({filename})")
            self._notify_listeners_threadsafe(model_name, progress_data)
        else:
            logger.debug(f"No listeners for {model_name}, progress update stored: {progress_pct:.1f}%")

    def get_progress(self, model_name: str) -> Optional[Dict]:
        """Get current progress for a model. Thread-safe."""
        with self._lock:
            progress = self._progress.get(model_name)
            return progress.copy() if progress else None

    def get_all_active(self) -> List[Dict]:
        """Get all active downloads (status is 'downloading' or 'extracting'). Thread-safe."""
        active = []
        with self._lock:
            for model_name, progress in self._progress.items():
                status = progress.get("status", "")
                if status in ("downloading", "extracting"):
                    active.append(progress.copy())
        return active

    def create_progress_callback(self, model_name: str, filename: Optional[str] = None):
        """
        Create a progress callback function for HuggingFace downloads.

        Args:
            model_name: Name of the model
            filename: Optional filename filter

        Returns:
            Callback function
        """
        def callback(progress: Dict):
            """HuggingFace Hub progress callback."""
            if "total" in progress and "current" in progress:
                current = progress.get("current", 0)
                total = progress.get("total", 0)
                file_name = progress.get("filename", filename)

                self.update_progress(
                    model_name=model_name,
                    current=current,
                    total=total,
                    filename=file_name,
                    status="downloading",
                )

        return callback

    async def subscribe(self, model_name: str):
        """
        Subscribe to progress updates for a model.

        Yields progress updates as Server-Sent Events.
        """
        import logging
        logger = logging.getLogger(__name__)

        # Store the main event loop for thread-safe operations
        try:
            self._main_loop = asyncio.get_running_loop()
        except RuntimeError:
            pass

        queue = asyncio.Queue(maxsize=10)

        # Add to listeners
        if model_name not in self._listeners:
            self._listeners[model_name] = []
        self._listeners[model_name].append(queue)

        logger.info(f"SSE client subscribed to {model_name}, total listeners: {len(self._listeners[model_name])}")

        try:
            # Send initial progress if available and still in progress (thread-safe read)
            with self._lock:
                initial_progress = self._progress.get(model_name)
                if initial_progress:
                    initial_progress = initial_progress.copy()

            if initial_progress:
                status = initial_progress.get('status')
                # Only send initial progress if the download is actually in progress;
                # don't send stale 'complete' or 'error' status from previous downloads
                if status in ('downloading', 'extracting'):
                    logger.info(f"Sending initial progress for {model_name}: {status}")
                    yield f"data: {json.dumps(initial_progress)}\n\n"
                else:
                    logger.info(f"Skipping initial progress for {model_name} (status: {status})")
            else:
                logger.info(f"No initial progress available for {model_name}")

            # Stream updates
            while True:
                try:
                    # Wait for an update, with a timeout
                    progress = await asyncio.wait_for(queue.get(), timeout=1.0)
                    logger.debug(f"Sending progress update for {model_name}: {progress.get('status')} - {progress.get('progress', 0):.1f}%")
                    yield f"data: {json.dumps(progress)}\n\n"

                    # Stop if complete or error
                    if progress.get("status") in ("complete", "error"):
                        logger.info(f"Download {progress.get('status')} for {model_name}, closing SSE connection")
                        break
                except asyncio.TimeoutError:
                    # Send heartbeat
                    yield ": heartbeat\n\n"
                    continue
        except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
            logger.debug(f"SSE client disconnected from {model_name}")
        finally:
            # Remove from listeners
            if model_name in self._listeners:
                self._listeners[model_name].remove(queue)
                if not self._listeners[model_name]:
                    del self._listeners[model_name]
            logger.info(f"SSE client unsubscribed from {model_name}, remaining listeners: {len(self._listeners.get(model_name, []))}")

    def mark_complete(self, model_name: str):
        """Mark a model download as complete. Thread-safe."""
        import logging
        logger = logging.getLogger(__name__)

        with self._lock:
            if model_name in self._progress:
                self._progress[model_name]["status"] = "complete"
                self._progress[model_name]["progress"] = 100.0
                progress_data = self._progress[model_name].copy()
            else:
                logger.warning(f"Cannot mark {model_name} as complete: not found in progress")
                return

        logger.info(f"Marked {model_name} as complete")
        # Notify listeners (thread-safe)
        self._notify_listeners_threadsafe(model_name, progress_data)

    def mark_error(self, model_name: str, error: str):
        """Mark a model download as failed. Thread-safe."""
        import logging
        logger = logging.getLogger(__name__)

        with self._lock:
            if model_name in self._progress:
                self._progress[model_name]["status"] = "error"
                self._progress[model_name]["error"] = error
                progress_data = self._progress[model_name].copy()
            else:
                # Create a new progress entry for the error
                progress_data = {
                    "model_name": model_name,
                    "current": 0,
                    "total": 0,
                    "progress": 0,
                    "filename": None,
                    "status": "error",
                    "error": error,
                    "timestamp": datetime.now().isoformat(),
                }
                self._progress[model_name] = progress_data

        logger.error(f"Marked {model_name} as error: {error}")
        # Notify listeners (thread-safe)
        self._notify_listeners_threadsafe(model_name, progress_data)


# Global progress manager instance
_progress_manager: Optional[ProgressManager] = None


def get_progress_manager() -> ProgressManager:
    """Get or create the global progress manager."""
    global _progress_manager
    if _progress_manager is None:
        _progress_manager = ProgressManager()
    return _progress_manager
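

# Usage sketch (added for illustration; not part of the original module):
# exposing `subscribe()` as a Server-Sent Events endpoint. `StreamingResponse`
# is already imported above; the router, route path, and handler name below
# are illustrative assumptions, not the app's actual wiring (which lives in
# routes/).
from fastapi import APIRouter

_example_router = APIRouter()


@_example_router.get("/models/{model_name}/progress")
async def _example_model_progress(model_name: str):
    """Stream download progress for one model as SSE (illustrative only)."""
    return StreamingResponse(
        get_progress_manager().subscribe(model_name),
        media_type="text/event-stream",
    )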
102
backend/utils/tasks.py
Normal file
102
backend/utils/tasks.py
Normal file
@@ -0,0 +1,102 @@
"""
Task tracking for active downloads and generations.
"""

from typing import Optional, Dict, List
from datetime import datetime
from dataclasses import dataclass, field


@dataclass
class DownloadTask:
    """Represents an active download task."""
    model_name: str
    status: str = "downloading"  # downloading, extracting, complete, error
    started_at: datetime = field(default_factory=datetime.utcnow)
    error: Optional[str] = None


@dataclass
class GenerationTask:
    """Represents an active generation task."""
    task_id: str
    profile_id: str
    text_preview: str  # First 50 chars of text
    started_at: datetime = field(default_factory=datetime.utcnow)


class TaskManager:
    """Manages active downloads and generations."""

    def __init__(self):
        self._active_downloads: Dict[str, DownloadTask] = {}
        self._active_generations: Dict[str, GenerationTask] = {}

    def start_download(self, model_name: str) -> None:
        """Mark a download as started."""
        self._active_downloads[model_name] = DownloadTask(
            model_name=model_name,
            status="downloading",
        )

    def complete_download(self, model_name: str) -> None:
        """Mark a download as complete."""
        if model_name in self._active_downloads:
            del self._active_downloads[model_name]

    def error_download(self, model_name: str, error: str) -> None:
        """Mark a download as failed."""
        if model_name in self._active_downloads:
            self._active_downloads[model_name].status = "error"
            self._active_downloads[model_name].error = error

    def start_generation(self, task_id: str, profile_id: str, text: str) -> None:
        """Mark a generation as started."""
        text_preview = text[:50] + "..." if len(text) > 50 else text
        self._active_generations[task_id] = GenerationTask(
            task_id=task_id,
            profile_id=profile_id,
            text_preview=text_preview,
        )

    def complete_generation(self, task_id: str) -> None:
        """Mark a generation as complete."""
        if task_id in self._active_generations:
            del self._active_generations[task_id]

    def get_active_downloads(self) -> List[DownloadTask]:
        """Get all active downloads."""
        return list(self._active_downloads.values())

    def get_active_generations(self) -> List[GenerationTask]:
        """Get all active generations."""
        return list(self._active_generations.values())

    def cancel_download(self, model_name: str) -> bool:
        """Cancel/dismiss a download task (removes it from the active list)."""
        return self._active_downloads.pop(model_name, None) is not None

    def clear_all(self) -> None:
        """Clear all download and generation tasks."""
        self._active_downloads.clear()
        self._active_generations.clear()

    def is_download_active(self, model_name: str) -> bool:
        """Check if a download is active."""
        return model_name in self._active_downloads

    def is_generation_active(self, task_id: str) -> bool:
        """Check if a generation is active."""
        return task_id in self._active_generations


# Global task manager instance
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Get or create the global task manager."""
    global _task_manager
    if _task_manager is None:
        _task_manager = TaskManager()
    return _task_manager
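

# Usage sketch (added for illustration): tracking a generation's lifetime
# around the actual inference call. `run_inference` is a stand-in name, not a
# real helper in this codebase.
if __name__ == "__main__":
    manager = get_task_manager()
    manager.start_generation("task-123", "profile-abc", "Hello, world! " * 10)
    try:
        pass  # the real code would call run_inference(...) here
    finally:
        # Generations leave the active list when they finish, success or not
        manager.complete_generation("task-123")
    assert not manager.is_generation_active("task-123")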