Initial commit
This commit is contained in:
220
backend/tests/E2E_MODEL_TEST_DESIGN.md
Normal file
220
backend/tests/E2E_MODEL_TEST_DESIGN.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# End-to-End Model Generation Test — Design
|
||||
|
||||
## Goal
|
||||
|
||||
A single script, runnable on macOS and Windows, that exercises every TTS model against the **frozen PyInstaller binary** (not the dev server), captures per-model pass/fail and error messages, and exits non-zero if any model fails. Generation is strictly sequential — one model loaded at a time.
|
||||
|
||||
## Test matrix (10 runs)
|
||||
|
||||
Derived from `backend/backends/__init__.py:185-316`. Each row maps to one `POST /generate` call.
|
||||
|
||||
| # | engine | model_size | profile kind | notes |
|
||||
|---|-----------------------|------------|--------------|-------|
|
||||
| 1 | `qwen` | `1.7B` | cloned | reference audio required |
|
||||
| 2 | `qwen` | `0.6B` | cloned | |
|
||||
| 3 | `qwen_custom_voice` | `1.7B` | preset | `preset_voice_id="Ryan"` |
|
||||
| 4 | `qwen_custom_voice` | `0.6B` | preset | `preset_voice_id="Ryan"` |
|
||||
| 5 | `luxtts` | — | cloned | English only |
|
||||
| 6 | `chatterbox` | — | cloned | |
|
||||
| 7 | `chatterbox_turbo` | — | cloned | English only |
|
||||
| 8 | `tada` | `1B` | cloned | tada-1b, English only |
|
||||
| 9 | `tada` | `3B` | cloned | tada-3b-ml, multilingual |
|
||||
| 10| `kokoro` | — | preset | `preset_voice_id="af_heart"` |
|
||||
|
||||
Cloned engines (1, 2, 5, 6, 7, 8, 9) share **one** profile created once with the reference WAV. Preset profiles are created separately, one for kokoro and one for qwen_custom_voice.
|
||||
|
||||
Language for every run: `en` (covers every engine's supported set).
|
||||
|
||||
## End-to-end flow
|
||||
|
||||
```
|
||||
1. Resolve paths → find binary, build if missing
|
||||
2. Launch binary → spawn with --port --data-dir --parent-pid
|
||||
3. Wait for /health → poll until status=="healthy" or 120s timeout
|
||||
4. Create profiles → 1 cloned + 2 preset, via /profiles (+ /samples)
|
||||
5. For each (engine, model_size) in matrix:
|
||||
a. Check cache → GET /models/status → cached? short timeout : long
|
||||
b. POST /generate → get generation_id
|
||||
c. Stream /status → consume SSE until completed/failed/timeout
|
||||
d. Record result → {engine, model_size, status, duration, error, elapsed}
|
||||
6. Write results → JSON + Markdown table to ./results/
|
||||
7. Shutdown binary → SIGTERM, fall back to kill, verify port freed
|
||||
8. Exit code → 0 if all passed, 1 otherwise
|
||||
```
|
||||
|
||||
## Binary resolution
|
||||
|
||||
Search order — **first hit wins**:
|
||||
|
||||
| Platform | Path | Build type |
|
||||
|----------|------|------------|
|
||||
| macOS | `backend/dist/voicebox-server-cuda/voicebox-server-cuda` | onedir (CUDA, rarely on Mac) |
|
||||
| macOS | `backend/dist/voicebox-server` | onefile (CPU) |
|
||||
| Windows | `backend\dist\voicebox-server-cuda\voicebox-server-cuda.exe` | onedir (CUDA) |
|
||||
| Windows | `backend\dist\voicebox-server.exe` | onefile (CPU) |
|
||||
|
||||
If none exist, run `python backend/build_binary.py` and wait for it to finish (can take 5-20 min). Fail with a clear error if the build itself fails. `--skip-build` flag forces "error out if no binary" instead of building.
|
||||
|
||||
## Spawn command
|
||||
|
||||
Mirrors Tauri's launch in `tauri/src-tauri/src/main.rs:369-388`:
|
||||
|
||||
```
|
||||
<binary> --host 127.0.0.1 --port <free-port> --data-dir <tempdir> --parent-pid <test-pid>
|
||||
```
|
||||
|
||||
- **Port**: bind to `0` first in Python to grab a free port, then pass that number.
|
||||
- **Data dir**: `tempfile.mkdtemp(prefix="voicebox-e2e-")`. Deleted after the run unless `--keep-data-dir`. Profiles and generated WAVs land here.
|
||||
- **Parent PID**: current Python PID — ensures the backend dies if the test crashes (watchdog in `server.py:102-224`).
|
||||
- **stdout/stderr**: tee to both a log file in `./results/server-<timestamp>.log` and a rolling in-memory buffer. On model failure, last 100 lines of the buffer are attached to that model's error record.
|
||||
|
||||
## Profile setup
|
||||
|
||||
One cloned profile shared across all cloning engines:
|
||||
|
||||
```http
|
||||
POST /profiles
|
||||
{
|
||||
"name": "e2e-cloned",
|
||||
"voice_type": "cloned",
|
||||
"language": "en"
|
||||
}
|
||||
```
|
||||
|
||||
Then:
|
||||
|
||||
```http
|
||||
POST /profiles/{id}/samples (multipart)
|
||||
file: <reference WAV>
|
||||
reference_text: <exact transcription>
|
||||
```
|
||||
|
||||
Two preset profiles:
|
||||
|
||||
```http
|
||||
POST /profiles
|
||||
{ "name": "e2e-kokoro", "voice_type": "preset", "language": "en",
|
||||
"preset_engine": "kokoro", "preset_voice_id": "af_heart" }
|
||||
|
||||
POST /profiles
|
||||
{ "name": "e2e-qwen-cv", "voice_type": "preset", "language": "en",
|
||||
"preset_engine": "qwen_custom_voice", "preset_voice_id": "Ryan" }
|
||||
```
|
||||
|
||||
## Generation request (per matrix row)
|
||||
|
||||
```http
|
||||
POST /generate
|
||||
{
|
||||
"profile_id": "<appropriate profile>",
|
||||
"text": "The quick brown fox jumps over the lazy dog.",
|
||||
"language": "en",
|
||||
"engine": "<engine>",
|
||||
"model_size": "<size or omitted>",
|
||||
"seed": 42,
|
||||
"normalize": true
|
||||
}
|
||||
```
|
||||
|
||||
Response `id` feeds into the SSE status loop (`GET /generate/{id}/status`, `routes/generations.py:190-227`). Loop reads lines until a payload with `status in ("completed", "failed")` arrives, then breaks.
|
||||
|
||||
## Timeout strategy (split)
|
||||
|
||||
Check `GET /models/status` for the target model **before** generation:
|
||||
|
||||
| Cached? | Per-model timeout | Rationale |
|
||||
|---------|-------------------|-----------|
|
||||
| Yes | **3 minutes** | Inference only; generous for CPU builds |
|
||||
| No | **20 minutes** | First-run HF download up to 8 GB (tada-3b-ml) |
|
||||
|
||||
On timeout: cancel the SSE stream, mark the row `timeout`, and continue to the next row. Don't abort the whole run on one timeout.
|
||||
|
||||
## Result format
|
||||
|
||||
`./results/e2e-<platform>-<arch>-<timestamp>.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"platform": "darwin-arm64",
|
||||
"binary": "/abs/path/voicebox-server",
|
||||
"binary_size_mb": 612,
|
||||
"started_at": "2026-04-16T12:34:56Z",
|
||||
"finished_at": "...",
|
||||
"results": [
|
||||
{
|
||||
"engine": "qwen",
|
||||
"model_size": "1.7B",
|
||||
"status": "passed|failed|timeout",
|
||||
"generation_id": "...",
|
||||
"was_cached": true,
|
||||
"elapsed_seconds": 12.4,
|
||||
"audio_duration": 3.1,
|
||||
"audio_path": "/tmp/.../gen.wav",
|
||||
"error": null,
|
||||
"server_log_tail": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Companion `./results/e2e-<...>.md`:
|
||||
|
||||
```
|
||||
# Voicebox E2E — darwin-arm64 — 2026-04-16 12:34
|
||||
|
||||
| Engine | Size | Status | Elapsed | Error |
|
||||
|---------------------|------|--------|---------|-------|
|
||||
| qwen | 1.7B | PASS | 12.4s | |
|
||||
| qwen | 0.6B | FAIL | 4.1s | CUDA OOM: ... |
|
||||
...
|
||||
```
|
||||
|
||||
## CLI flags
|
||||
|
||||
```
|
||||
python -m backend.tests.test_all_models_e2e [flags]
|
||||
|
||||
--binary PATH Use this binary instead of auto-detecting
|
||||
--skip-build Error if no binary found (no auto-build)
|
||||
--reference-wav PATH Reference audio (default: backend/tests/fixtures/reference_voice.wav)
|
||||
--reference-text STR Transcription (default: read from fixtures/reference_voice.txt)
|
||||
--only ENGINE[,...] Run only these engines (e.g. kokoro,qwen)
|
||||
--skip ENGINE[,...] Skip these engines
|
||||
--keep-data-dir Don't delete tempdir after run
|
||||
--timeout-cached SEC Override 180
|
||||
--timeout-download SEC Override 1200
|
||||
--port N Override auto-picked port
|
||||
--output-dir PATH Default: backend/tests/results/
|
||||
```
|
||||
|
||||
## File layout
|
||||
|
||||
```
|
||||
backend/tests/
|
||||
├── E2E_MODEL_TEST_DESIGN.md (this file)
|
||||
├── test_all_models_e2e.py (main script, ~400-500 LoC)
|
||||
├── fixtures/
|
||||
│ ├── reference_voice.wav (user-provided, ~5-15s clean speech)
|
||||
│ └── reference_voice.txt (exact transcription)
|
||||
└── results/ (gitignored)
|
||||
├── e2e-darwin-arm64-<ts>.json
|
||||
├── e2e-darwin-arm64-<ts>.md
|
||||
└── server-<ts>.log
|
||||
```
|
||||
|
||||
The script uses only stdlib + `httpx` (or `requests`) + `sseclient-py` — all already in `backend/requirements.txt`. No pytest to keep it invocable as a single command on fresh checkouts.
|
||||
|
||||
## Safety & cleanup
|
||||
|
||||
- Always kill the spawned binary in a `try/finally`. On Windows, `taskkill /F /T` the whole tree (Tauri does the same).
|
||||
- Verify the port is free on shutdown (Tauri port-reuse check in `main.rs:114-186` could otherwise pick up a ghost).
|
||||
- Don't touch the user's HF cache by default — let the server use `HF_HUB_CACHE` / `VOICEBOX_MODELS_DIR`. Passing `--isolated-cache` would point both env vars at the tempdir for a true cold-start run (opt-in only; would re-download every time).
|
||||
|
||||
## Non-goals
|
||||
|
||||
- Not validating audio quality (no WER, no waveform comparison). Pass = "endpoint returned `completed` and produced a non-empty WAV".
|
||||
- Not testing STT (Whisper), effects chains, channels, or streaming endpoints.
|
||||
- Not running on CI today — human-invoked on dev machines. CI integration is a follow-up once the script is stable.
|
||||
- No model unload between runs — models stay loaded; server manages its own eviction.
|
||||
- No version-drift check on the binary.
|
||||
- No `instruct` parameter exercised on qwen_custom_voice runs.
|
||||
58
backend/tests/README.md
Normal file
58
backend/tests/README.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Backend Tests
|
||||
|
||||
Manual test scripts for debugging and validating backend functionality.
|
||||
|
||||
## Test Files
|
||||
|
||||
### `test_generation_progress.py`
|
||||
Tests TTS generation with SSE progress monitoring to identify UX issues where users see download progress even when the model is already cached.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
cd backend
|
||||
python tests/test_generation_progress.py
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
- Server must be running (`python main.py`)
|
||||
- At least one voice profile must exist
|
||||
|
||||
### `test_real_download.py`
|
||||
Tests real model download with SSE progress monitoring.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
cd backend
|
||||
# Delete cache first to force fresh download
|
||||
rm -rf ~/.cache/huggingface/hub/models--openai--whisper-base
|
||||
python tests/test_real_download.py
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
- Server must be running (`python main.py`)
|
||||
|
||||
### `test_progress.py`
|
||||
Unit tests for ProgressManager and HFProgressTracker functionality.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
cd backend
|
||||
python tests/test_progress.py
|
||||
```
|
||||
|
||||
### `test_check_progress_state.py`
|
||||
Debugging script to inspect the internal state of ProgressManager and TaskManager.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
cd backend
|
||||
python tests/test_check_progress_state.py
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
These are manual test scripts, not automated unit tests. They're designed for:
|
||||
- Debugging progress tracking issues
|
||||
- Validating SSE event streams
|
||||
- Monitoring real-time download behavior
|
||||
- Inspecting internal state during development
|
||||
6
backend/tests/__init__.py
Normal file
6
backend/tests/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Test suite for Voicebox backend.
|
||||
|
||||
This directory contains manual test scripts for debugging and validating
|
||||
progress tracking, model downloads, and generation functionality.
|
||||
"""
|
||||
16
backend/tests/fixtures/README.md
vendored
Normal file
16
backend/tests/fixtures/README.md
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
# E2E Test Fixtures
|
||||
|
||||
Place two files here before running `test_all_models_e2e.py`:
|
||||
|
||||
- `reference_voice.wav` — a clean speech sample, mono, 16–24 kHz, ~5–15 seconds.
|
||||
- `reference_voice.txt` — the **exact** transcription of the WAV (single line, no trailing newline required).
|
||||
|
||||
These are used to create a cloned voice profile for every cloning-capable engine (qwen, luxtts, chatterbox, chatterbox_turbo, tada). Keep them out of version control if they contain personal audio — this directory is not gitignored by default, so add them to `.gitignore` locally if needed.
|
||||
|
||||
You can point the test at different files with:
|
||||
|
||||
```
|
||||
python backend/tests/test_all_models_e2e.py \
|
||||
--reference-wav /path/to/your.wav \
|
||||
--reference-text "exact transcription here"
|
||||
```
|
||||
630
backend/tests/test_all_models_e2e.py
Normal file
630
backend/tests/test_all_models_e2e.py
Normal file
@@ -0,0 +1,630 @@
|
||||
"""
|
||||
End-to-end model generation test.
|
||||
|
||||
Exercises every TTS model against the frozen PyInstaller binary, captures
|
||||
per-model pass/fail, and writes a JSON + Markdown report.
|
||||
|
||||
Usage:
|
||||
python backend/tests/test_all_models_e2e.py [flags]
|
||||
|
||||
See E2E_MODEL_TEST_DESIGN.md for the full design.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
BACKEND_DIR = REPO_ROOT / "backend"
|
||||
DIST_DIR = BACKEND_DIR / "dist"
|
||||
FIXTURES_DIR = Path(__file__).resolve().parent / "fixtures"
|
||||
RESULTS_DIR = Path(__file__).resolve().parent / "results"
|
||||
|
||||
|
||||
# ── Test matrix ──────────────────────────────────────────────────────
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MatrixRow:
|
||||
label: str # human-readable (appears in report)
|
||||
engine: str # /generate engine
|
||||
model_size: Optional[str] # /generate model_size (None = omit)
|
||||
profile_kind: str # "cloned" | "preset_kokoro" | "preset_qwen_cv"
|
||||
model_name: str # /models/status key for cache lookup
|
||||
|
||||
|
||||
MATRIX: list[MatrixRow] = [
|
||||
MatrixRow("qwen 1.7B", "qwen", "1.7B", "cloned", "qwen-tts-1.7B"),
|
||||
MatrixRow("qwen 0.6B", "qwen", "0.6B", "cloned", "qwen-tts-0.6B"),
|
||||
MatrixRow("qwen_custom_voice 1.7B", "qwen_custom_voice", "1.7B", "preset_qwen_cv", "qwen-custom-voice-1.7B"),
|
||||
MatrixRow("qwen_custom_voice 0.6B", "qwen_custom_voice", "0.6B", "preset_qwen_cv", "qwen-custom-voice-0.6B"),
|
||||
MatrixRow("luxtts", "luxtts", None, "cloned", "luxtts"),
|
||||
MatrixRow("chatterbox", "chatterbox", None, "cloned", "chatterbox-tts"),
|
||||
MatrixRow("chatterbox_turbo", "chatterbox_turbo", None, "cloned", "chatterbox-turbo"),
|
||||
MatrixRow("tada 1B", "tada", "1B", "cloned", "tada-1b"),
|
||||
MatrixRow("tada 3B", "tada", "3B", "cloned", "tada-3b-ml"),
|
||||
MatrixRow("kokoro", "kokoro", None, "preset_kokoro", "kokoro"),
|
||||
]
|
||||
|
||||
TEXT = "The quick brown fox jumps over the lazy dog."
|
||||
DEFAULT_TIMEOUT_CACHED = 180
|
||||
DEFAULT_TIMEOUT_DOWNLOAD = 1200
|
||||
HEALTH_TIMEOUT = 120
|
||||
|
||||
|
||||
# ── Result record ────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class ModelResult:
|
||||
label: str
|
||||
engine: str
|
||||
model_size: Optional[str]
|
||||
status: str # "passed" | "failed" | "timeout"
|
||||
was_cached: Optional[bool] = None
|
||||
generation_id: Optional[str] = None
|
||||
elapsed_seconds: float = 0.0
|
||||
audio_duration: Optional[float] = None
|
||||
audio_path: Optional[str] = None
|
||||
audio_bytes: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
http_status: Optional[int] = None
|
||||
server_log_tail: Optional[list[str]] = None
|
||||
|
||||
|
||||
# ── Binary resolution ────────────────────────────────────────────────
|
||||
|
||||
def find_binary() -> Optional[Path]:
|
||||
"""Return the first existing binary in priority order, or None."""
|
||||
is_win = platform.system() == "Windows"
|
||||
exe = ".exe" if is_win else ""
|
||||
candidates = [
|
||||
DIST_DIR / "voicebox-server-cuda" / f"voicebox-server-cuda{exe}",
|
||||
DIST_DIR / f"voicebox-server{exe}",
|
||||
]
|
||||
for c in candidates:
|
||||
if c.exists() and c.is_file():
|
||||
return c
|
||||
return None
|
||||
|
||||
|
||||
def build_binary() -> Path:
|
||||
"""Invoke build_binary.py and return the resulting binary path."""
|
||||
print("[build] No frozen binary found — invoking build_binary.py (this may take 5-20 minutes)...", flush=True)
|
||||
script = BACKEND_DIR / "build_binary.py"
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(script)],
|
||||
cwd=str(BACKEND_DIR),
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"build_binary.py exited with code {result.returncode}")
|
||||
found = find_binary()
|
||||
if found is None:
|
||||
raise RuntimeError("build_binary.py finished but no binary was found in backend/dist/")
|
||||
return found
|
||||
|
||||
|
||||
# ── Server spawn + log capture ───────────────────────────────────────
|
||||
|
||||
class ServerProcess:
|
||||
def __init__(self, binary: Path, port: int, data_dir: Path, log_path: Path):
|
||||
self.binary = binary
|
||||
self.port = port
|
||||
self.data_dir = data_dir
|
||||
self.log_path = log_path
|
||||
self.proc: Optional[subprocess.Popen] = None
|
||||
self._log_buffer: deque[str] = deque(maxlen=500)
|
||||
self._reader_thread: Optional[threading.Thread] = None
|
||||
|
||||
def start(self) -> None:
|
||||
args = [
|
||||
str(self.binary),
|
||||
"--host", "127.0.0.1",
|
||||
"--port", str(self.port),
|
||||
"--data-dir", str(self.data_dir),
|
||||
"--parent-pid", str(os.getpid()),
|
||||
]
|
||||
print(f"[spawn] {' '.join(args)}", flush=True)
|
||||
self._log_fh = open(self.log_path, "w", encoding="utf-8", errors="replace")
|
||||
# Combine stderr into stdout so we get a single ordered stream.
|
||||
self.proc = subprocess.Popen(
|
||||
args,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
bufsize=1,
|
||||
text=True,
|
||||
errors="replace",
|
||||
)
|
||||
self._reader_thread = threading.Thread(target=self._pump_logs, daemon=True)
|
||||
self._reader_thread.start()
|
||||
|
||||
def _pump_logs(self) -> None:
|
||||
assert self.proc is not None and self.proc.stdout is not None
|
||||
for line in self.proc.stdout:
|
||||
self._log_buffer.append(line.rstrip("\n"))
|
||||
self._log_fh.write(line)
|
||||
self._log_fh.flush()
|
||||
|
||||
def log_tail(self, n: int = 100) -> list[str]:
|
||||
tail = list(self._log_buffer)[-n:]
|
||||
return tail
|
||||
|
||||
def is_alive(self) -> bool:
|
||||
return self.proc is not None and self.proc.poll() is None
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.proc is None:
|
||||
return
|
||||
if self.proc.poll() is not None:
|
||||
return
|
||||
try:
|
||||
if platform.system() == "Windows":
|
||||
subprocess.run(
|
||||
["taskkill", "/F", "/T", "/PID", str(self.proc.pid)],
|
||||
capture_output=True,
|
||||
)
|
||||
else:
|
||||
self.proc.send_signal(signal.SIGTERM)
|
||||
except Exception as e:
|
||||
print(f"[shutdown] signal failed: {e}", flush=True)
|
||||
try:
|
||||
self.proc.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
print("[shutdown] server didn't exit cleanly, killing", flush=True)
|
||||
self.proc.kill()
|
||||
try:
|
||||
self.proc.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
pass
|
||||
if self._reader_thread is not None:
|
||||
self._reader_thread.join(timeout=2)
|
||||
try:
|
||||
self._log_fh.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def pick_free_port() -> int:
|
||||
s = socket.socket()
|
||||
s.bind(("127.0.0.1", 0))
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
|
||||
|
||||
# ── HTTP helpers ─────────────────────────────────────────────────────
|
||||
|
||||
def wait_for_health(base_url: str, server: ServerProcess, timeout: int) -> None:
|
||||
deadline = time.time() + timeout
|
||||
with httpx.Client(timeout=5.0) as client:
|
||||
while time.time() < deadline:
|
||||
if not server.is_alive():
|
||||
raise RuntimeError("Server process exited before becoming healthy")
|
||||
try:
|
||||
r = client.get(f"{base_url}/health")
|
||||
if r.status_code == 200 and r.json().get("status") == "healthy":
|
||||
return
|
||||
except httpx.HTTPError:
|
||||
pass
|
||||
time.sleep(1.0)
|
||||
raise TimeoutError(f"Server did not become healthy within {timeout}s")
|
||||
|
||||
|
||||
def get_model_cached(client: httpx.Client, base_url: str, model_name: str) -> Optional[bool]:
|
||||
try:
|
||||
r = client.get(f"{base_url}/models/status", timeout=30.0)
|
||||
r.raise_for_status()
|
||||
for m in r.json().get("models", []):
|
||||
if m.get("model_name") == model_name:
|
||||
return bool(m.get("downloaded"))
|
||||
except httpx.HTTPError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def create_cloned_profile(client: httpx.Client, base_url: str, wav_path: Path, reference_text: str) -> str:
|
||||
r = client.post(f"{base_url}/profiles", json={
|
||||
"name": "e2e-cloned",
|
||||
"voice_type": "cloned",
|
||||
"language": "en",
|
||||
})
|
||||
r.raise_for_status()
|
||||
profile_id = r.json()["id"]
|
||||
|
||||
with open(wav_path, "rb") as f:
|
||||
r = client.post(
|
||||
f"{base_url}/profiles/{profile_id}/samples",
|
||||
files={"file": (wav_path.name, f, "audio/wav")},
|
||||
data={"reference_text": reference_text},
|
||||
timeout=120.0,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return profile_id
|
||||
|
||||
|
||||
def create_preset_profile(client: httpx.Client, base_url: str, name: str, engine: str, voice_id: str) -> str:
|
||||
r = client.post(f"{base_url}/profiles", json={
|
||||
"name": name,
|
||||
"voice_type": "preset",
|
||||
"language": "en",
|
||||
"preset_engine": engine,
|
||||
"preset_voice_id": voice_id,
|
||||
})
|
||||
r.raise_for_status()
|
||||
return r.json()["id"]
|
||||
|
||||
|
||||
def run_one_generation(
|
||||
client: httpx.Client,
|
||||
base_url: str,
|
||||
row: MatrixRow,
|
||||
profile_id: str,
|
||||
timeout_s: int,
|
||||
) -> tuple[str, dict]:
|
||||
"""Start a generation and stream its status until done/failed/timeout.
|
||||
|
||||
Returns (status, payload) where status is "completed" | "failed" | "timeout".
|
||||
"""
|
||||
body = {
|
||||
"profile_id": profile_id,
|
||||
"text": TEXT,
|
||||
"language": "en",
|
||||
"engine": row.engine,
|
||||
"seed": 42,
|
||||
"normalize": True,
|
||||
}
|
||||
if row.model_size is not None:
|
||||
body["model_size"] = row.model_size
|
||||
|
||||
r = client.post(f"{base_url}/generate", json=body, timeout=30.0)
|
||||
r.raise_for_status()
|
||||
gen = r.json()
|
||||
gen_id = gen["id"]
|
||||
|
||||
deadline = time.time() + timeout_s
|
||||
last_payload: dict = gen
|
||||
status_url = f"{base_url}/generate/{gen_id}/status"
|
||||
|
||||
while time.time() < deadline:
|
||||
remaining = max(1.0, deadline - time.time())
|
||||
try:
|
||||
with client.stream("GET", status_url, timeout=httpx.Timeout(remaining + 5, read=remaining + 5)) as resp:
|
||||
resp.raise_for_status()
|
||||
for line in resp.iter_lines():
|
||||
if not line or not line.startswith("data: "):
|
||||
continue
|
||||
try:
|
||||
payload = json.loads(line[6:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
last_payload = payload
|
||||
status = payload.get("status")
|
||||
if status == "not_found":
|
||||
return "failed", {"error": "generation not found", **payload}
|
||||
if status in ("completed", "failed"):
|
||||
return status, payload
|
||||
if time.time() >= deadline:
|
||||
break
|
||||
except httpx.HTTPError:
|
||||
time.sleep(1.0)
|
||||
continue
|
||||
|
||||
return "timeout", last_payload
|
||||
|
||||
|
||||
def fetch_audio_info(
|
||||
client: httpx.Client, base_url: str, generation_id: str, data_dir: Path
|
||||
) -> tuple[Optional[str], Optional[int]]:
|
||||
"""Return (audio_path, audio_bytes) for a completed generation.
|
||||
|
||||
Server stores audio_path relative to data_dir; resolve it to get a size.
|
||||
"""
|
||||
try:
|
||||
r = client.get(f"{base_url}/history/{generation_id}", timeout=10.0)
|
||||
if r.status_code != 200:
|
||||
return None, None
|
||||
data = r.json()
|
||||
audio_path = data.get("audio_path")
|
||||
if not audio_path:
|
||||
return None, None
|
||||
p = Path(audio_path)
|
||||
if not p.is_absolute():
|
||||
p = data_dir / p
|
||||
if p.exists():
|
||||
return str(p), p.stat().st_size
|
||||
return audio_path, None
|
||||
except httpx.HTTPError:
|
||||
return None, None
|
||||
|
||||
|
||||
# ── Report writers ───────────────────────────────────────────────────
|
||||
|
||||
def write_reports(
|
||||
output_dir: Path,
|
||||
binary: Path,
|
||||
started_at: datetime,
|
||||
finished_at: datetime,
|
||||
results: list[ModelResult],
|
||||
) -> tuple[Path, Path]:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
plat = f"{platform.system().lower()}-{platform.machine().lower()}"
|
||||
ts = started_at.strftime("%Y%m%d-%H%M%S")
|
||||
json_path = output_dir / f"e2e-{plat}-{ts}.json"
|
||||
md_path = output_dir / f"e2e-{plat}-{ts}.md"
|
||||
|
||||
doc = {
|
||||
"platform": plat,
|
||||
"binary": str(binary),
|
||||
"binary_size_mb": round(binary.stat().st_size / (1024 * 1024), 1) if binary.exists() else None,
|
||||
"started_at": started_at.isoformat(),
|
||||
"finished_at": finished_at.isoformat(),
|
||||
"elapsed_seconds": (finished_at - started_at).total_seconds(),
|
||||
"results": [asdict(r) for r in results],
|
||||
}
|
||||
json_path.write_text(json.dumps(doc, indent=2))
|
||||
|
||||
lines = [
|
||||
f"# Voicebox E2E — {plat} — {started_at.strftime('%Y-%m-%d %H:%M UTC')}",
|
||||
"",
|
||||
f"Binary: `{binary}` ",
|
||||
f"Elapsed: {doc['elapsed_seconds']:.1f}s",
|
||||
"",
|
||||
"| Model | Status | Cached | Elapsed | Audio | Error |",
|
||||
"|-------|--------|--------|---------|-------|-------|",
|
||||
]
|
||||
for r in results:
|
||||
status_icon = {"passed": "PASS", "failed": "FAIL", "timeout": "TIMEOUT"}.get(r.status, r.status.upper())
|
||||
cached = "yes" if r.was_cached else ("no" if r.was_cached is False else "?")
|
||||
audio_col = f"{r.audio_duration:.2f}s" if r.audio_duration else ("—" if r.status != "passed" else "?")
|
||||
error_col = (r.error or "").replace("\n", " ")[:120]
|
||||
lines.append(f"| {r.label} | {status_icon} | {cached} | {r.elapsed_seconds:.1f}s | {audio_col} | {error_col} |")
|
||||
|
||||
failed_rows = [r for r in results if r.status != "passed"]
|
||||
if failed_rows:
|
||||
lines.append("")
|
||||
lines.append("## Failures")
|
||||
for r in failed_rows:
|
||||
lines.append("")
|
||||
lines.append(f"### {r.label} — {r.status}")
|
||||
if r.error:
|
||||
lines.append("")
|
||||
lines.append("```")
|
||||
lines.append(r.error)
|
||||
lines.append("```")
|
||||
if r.server_log_tail:
|
||||
lines.append("")
|
||||
lines.append("<details><summary>server log (last lines)</summary>")
|
||||
lines.append("")
|
||||
lines.append("```")
|
||||
lines.extend(r.server_log_tail)
|
||||
lines.append("```")
|
||||
lines.append("</details>")
|
||||
|
||||
md_path.write_text("\n".join(lines) + "\n")
|
||||
return json_path, md_path
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Voicebox E2E model generation test")
|
||||
p.add_argument("--binary", type=Path, help="Path to voicebox-server binary (overrides auto-detect)")
|
||||
p.add_argument("--skip-build", action="store_true", help="Error if binary missing instead of building")
|
||||
p.add_argument(
|
||||
"--reference-wav",
|
||||
type=Path,
|
||||
default=FIXTURES_DIR / "reference_voice.wav",
|
||||
help="Reference audio for cloning engines",
|
||||
)
|
||||
p.add_argument(
|
||||
"--reference-text",
|
||||
help="Transcription of reference-wav (default: read from fixtures/reference_voice.txt)",
|
||||
)
|
||||
p.add_argument("--only", help="Comma-separated engines to run (e.g. kokoro,qwen)")
|
||||
p.add_argument("--skip", help="Comma-separated engines to skip")
|
||||
p.add_argument("--keep-data-dir", action="store_true", help="Don't delete tempdir after run")
|
||||
p.add_argument("--timeout-cached", type=int, default=DEFAULT_TIMEOUT_CACHED)
|
||||
p.add_argument("--timeout-download", type=int, default=DEFAULT_TIMEOUT_DOWNLOAD)
|
||||
p.add_argument("--port", type=int, help="Override auto-picked port")
|
||||
p.add_argument("--output-dir", type=Path, default=RESULTS_DIR)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def filter_matrix(args: argparse.Namespace) -> list[MatrixRow]:
|
||||
only = set(x.strip() for x in args.only.split(",")) if args.only else None
|
||||
skip = set(x.strip() for x in args.skip.split(",")) if args.skip else set()
|
||||
rows = []
|
||||
for r in MATRIX:
|
||||
if only is not None and r.engine not in only:
|
||||
continue
|
||||
if r.engine in skip:
|
||||
continue
|
||||
rows.append(r)
|
||||
return rows
|
||||
|
||||
|
||||
def resolve_reference(args: argparse.Namespace) -> tuple[Path, str]:
|
||||
wav = args.reference_wav
|
||||
if not wav.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Reference WAV not found: {wav}\n"
|
||||
f"Place a sample at {FIXTURES_DIR / 'reference_voice.wav'} or pass --reference-wav.\n"
|
||||
f"See backend/tests/fixtures/README.md."
|
||||
)
|
||||
if args.reference_text:
|
||||
text = args.reference_text
|
||||
else:
|
||||
txt_path = wav.with_suffix(".txt")
|
||||
if not txt_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Reference transcription not found: {txt_path}\n"
|
||||
f"Create it next to the WAV, or pass --reference-text."
|
||||
)
|
||||
text = txt_path.read_text().strip()
|
||||
if not text:
|
||||
raise ValueError("Reference transcription is empty")
|
||||
return wav, text
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
rows = filter_matrix(args)
|
||||
if not rows:
|
||||
print("No rows selected after --only/--skip filtering", file=sys.stderr)
|
||||
return 2
|
||||
|
||||
# Binary
|
||||
binary = args.binary or find_binary()
|
||||
if binary is None:
|
||||
if args.skip_build:
|
||||
print("No frozen binary found and --skip-build set. Run: python backend/build_binary.py", file=sys.stderr)
|
||||
return 2
|
||||
binary = build_binary()
|
||||
if not binary.exists():
|
||||
print(f"Binary path does not exist: {binary}", file=sys.stderr)
|
||||
return 2
|
||||
print(f"[binary] {binary}", flush=True)
|
||||
|
||||
# Reference audio (only required if any cloning row is in the matrix)
|
||||
needs_reference = any(r.profile_kind == "cloned" for r in rows)
|
||||
ref_wav: Optional[Path] = None
|
||||
ref_text: Optional[str] = None
|
||||
if needs_reference:
|
||||
try:
|
||||
ref_wav, ref_text = resolve_reference(args)
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
print(f"[fixture] {e}", file=sys.stderr)
|
||||
return 2
|
||||
print(f"[fixture] reference WAV: {ref_wav}", flush=True)
|
||||
print(f"[fixture] reference text: {ref_text!r}", flush=True)
|
||||
|
||||
# Tempdir + log path
|
||||
data_dir = Path(tempfile.mkdtemp(prefix="voicebox-e2e-"))
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
log_path = args.output_dir / f"server-{ts}.log"
|
||||
|
||||
port = args.port or pick_free_port()
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
|
||||
server = ServerProcess(binary=binary, port=port, data_dir=data_dir, log_path=log_path)
|
||||
started_at = datetime.now(timezone.utc)
|
||||
results: list[ModelResult] = []
|
||||
|
||||
try:
|
||||
server.start()
|
||||
print(f"[health] waiting for {base_url}/health ...", flush=True)
|
||||
wait_for_health(base_url, server, HEALTH_TIMEOUT)
|
||||
print("[health] ready", flush=True)
|
||||
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
# Profile setup (only create what's needed)
|
||||
cloned_profile_id: Optional[str] = None
|
||||
kokoro_profile_id: Optional[str] = None
|
||||
qwen_cv_profile_id: Optional[str] = None
|
||||
needed_kinds = {r.profile_kind for r in rows}
|
||||
if "cloned" in needed_kinds:
|
||||
assert ref_wav is not None and ref_text is not None
|
||||
print("[profile] creating cloned profile...", flush=True)
|
||||
cloned_profile_id = create_cloned_profile(client, base_url, ref_wav, ref_text)
|
||||
if "preset_kokoro" in needed_kinds:
|
||||
print("[profile] creating kokoro preset...", flush=True)
|
||||
kokoro_profile_id = create_preset_profile(client, base_url, "e2e-kokoro", "kokoro", "af_heart")
|
||||
if "preset_qwen_cv" in needed_kinds:
|
||||
print("[profile] creating qwen_custom_voice preset...", flush=True)
|
||||
qwen_cv_profile_id = create_preset_profile(client, base_url, "e2e-qwen-cv", "qwen_custom_voice", "Ryan")
|
||||
|
||||
profile_lookup = {
|
||||
"cloned": cloned_profile_id,
|
||||
"preset_kokoro": kokoro_profile_id,
|
||||
"preset_qwen_cv": qwen_cv_profile_id,
|
||||
}
|
||||
|
||||
# Matrix loop
|
||||
for row in rows:
|
||||
print(f"\n[run] {row.label} (engine={row.engine}, size={row.model_size})", flush=True)
|
||||
profile_id = profile_lookup[row.profile_kind]
|
||||
assert profile_id is not None
|
||||
was_cached = get_model_cached(client, base_url, row.model_name)
|
||||
timeout_s = args.timeout_cached if was_cached else args.timeout_download
|
||||
print(f"[run] cached={was_cached} timeout={timeout_s}s", flush=True)
|
||||
|
||||
t0 = time.time()
|
||||
result = ModelResult(
|
||||
label=row.label,
|
||||
engine=row.engine,
|
||||
model_size=row.model_size,
|
||||
status="failed",
|
||||
was_cached=was_cached,
|
||||
)
|
||||
try:
|
||||
status, payload = run_one_generation(client, base_url, row, profile_id, timeout_s)
|
||||
result.status = "passed" if status == "completed" else status
|
||||
result.generation_id = payload.get("id")
|
||||
result.audio_duration = payload.get("duration")
|
||||
result.error = payload.get("error")
|
||||
if status == "completed" and result.generation_id:
|
||||
audio_path, audio_bytes = fetch_audio_info(
|
||||
client, base_url, result.generation_id, data_dir
|
||||
)
|
||||
result.audio_path = audio_path
|
||||
result.audio_bytes = audio_bytes
|
||||
if audio_bytes is not None and audio_bytes == 0:
|
||||
result.status = "failed"
|
||||
result.error = (result.error or "") + " (audio file is empty)"
|
||||
except httpx.HTTPStatusError as e:
|
||||
result.status = "failed"
|
||||
result.http_status = e.response.status_code
|
||||
try:
|
||||
detail = e.response.json().get("detail")
|
||||
except Exception:
|
||||
detail = e.response.text
|
||||
result.error = f"HTTP {e.response.status_code}: {detail}"
|
||||
except Exception as e:
|
||||
result.status = "failed"
|
||||
result.error = f"{type(e).__name__}: {e}"
|
||||
|
||||
result.elapsed_seconds = round(time.time() - t0, 2)
|
||||
if result.status != "passed":
|
||||
result.server_log_tail = server.log_tail(100)
|
||||
print(f"[run] {row.label} → {result.status} in {result.elapsed_seconds}s"
|
||||
+ (f" ({result.error})" if result.error else ""), flush=True)
|
||||
results.append(result)
|
||||
finally:
|
||||
finished_at = datetime.now(timezone.utc)
|
||||
server.stop()
|
||||
if not args.keep_data_dir:
|
||||
shutil.rmtree(data_dir, ignore_errors=True)
|
||||
else:
|
||||
print(f"[cleanup] keeping data dir: {data_dir}", flush=True)
|
||||
|
||||
json_path, md_path = write_reports(args.output_dir, binary, started_at, finished_at, results)
|
||||
print(f"\n[report] {json_path}")
|
||||
print(f"[report] {md_path}")
|
||||
print(f"[report] server log: {log_path}")
|
||||
|
||||
passed = sum(1 for r in results if r.status == "passed")
|
||||
failed = len(results) - passed
|
||||
print(f"\n== {passed} passed, {failed} failed ==")
|
||||
return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
112
backend/tests/test_audio_preprocess.py
Normal file
112
backend/tests/test_audio_preprocess.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Unit tests for reference-audio preprocessing.
|
||||
|
||||
Covers :func:`backend.utils.audio.preprocess_reference_audio` and
|
||||
:func:`backend.utils.audio.validate_and_load_reference_audio`.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from utils.audio import ( # noqa: E402
|
||||
preprocess_reference_audio,
|
||||
validate_and_load_reference_audio,
|
||||
)
|
||||
|
||||
|
||||
SR = 24000
|
||||
|
||||
|
||||
def _tone(duration_s: float, amp: float = 0.3, freq: float = 220.0) -> np.ndarray:
|
||||
n = int(duration_s * SR)
|
||||
t = np.arange(n, dtype=np.float32) / SR
|
||||
return (amp * np.sin(2 * np.pi * freq * t)).astype(np.float32)
|
||||
|
||||
|
||||
def test_peak_cap_scales_hot_input():
|
||||
audio = _tone(3.0, amp=0.99)
|
||||
out = preprocess_reference_audio(audio, SR)
|
||||
assert np.abs(out).max() <= 0.951
|
||||
|
||||
|
||||
def test_peak_cap_leaves_moderate_input_untouched():
|
||||
audio = _tone(3.0, amp=0.5)
|
||||
out = preprocess_reference_audio(audio, SR)
|
||||
assert np.isclose(np.abs(out).max(), 0.5, atol=1e-3)
|
||||
|
||||
|
||||
def test_dc_offset_removed():
|
||||
audio = _tone(3.0, amp=0.3) + 0.1
|
||||
out = preprocess_reference_audio(audio, SR)
|
||||
assert abs(float(np.mean(out))) < 1e-3
|
||||
|
||||
|
||||
def test_silence_is_trimmed_with_padding_kept():
|
||||
silence = np.zeros(int(SR * 1.0), dtype=np.float32)
|
||||
speech = _tone(3.0, amp=0.3)
|
||||
audio = np.concatenate([silence, speech, silence])
|
||||
out = preprocess_reference_audio(audio, SR)
|
||||
# Most of the 2s of leading/trailing silence should be gone, but the
|
||||
# 3s of speech plus ~200ms of padding should remain.
|
||||
assert len(audio) - len(out) >= SR, "expected >=1s of silence trimmed"
|
||||
assert len(out) >= int(3.0 * SR), "speech body should be preserved"
|
||||
|
||||
|
||||
def test_clean_audio_is_not_padded_past_original_length():
|
||||
# Well-recorded audio with no edge silence shouldn't get longer after
|
||||
# preprocessing — otherwise a 29.9 s upload could be pushed past the
|
||||
# 30 s max_duration ceiling downstream.
|
||||
audio = _tone(3.0, amp=0.3)
|
||||
out = preprocess_reference_audio(audio, SR)
|
||||
assert len(out) <= len(audio)
|
||||
|
||||
|
||||
def test_empty_input_returns_empty():
|
||||
out = preprocess_reference_audio(np.zeros(0, dtype=np.float32), SR)
|
||||
assert out.size == 0
|
||||
|
||||
|
||||
def test_validate_accepts_previously_rejected_hot_file(tmp_path):
|
||||
audio = _tone(3.0, amp=0.995)
|
||||
path = tmp_path / "hot.wav"
|
||||
sf.write(str(path), audio, SR)
|
||||
|
||||
ok, err, out_audio, out_sr = validate_and_load_reference_audio(str(path))
|
||||
|
||||
assert ok, f"expected pass, got error: {err}"
|
||||
assert out_audio is not None
|
||||
assert out_sr == SR
|
||||
assert np.abs(out_audio).max() <= 0.951
|
||||
|
||||
|
||||
def test_validate_still_rejects_silent_input(tmp_path):
|
||||
audio = np.zeros(int(SR * 3.0), dtype=np.float32)
|
||||
path = tmp_path / "silent.wav"
|
||||
sf.write(str(path), audio, SR)
|
||||
|
||||
ok, err, _, _ = validate_and_load_reference_audio(str(path))
|
||||
|
||||
assert not ok
|
||||
assert err is not None
|
||||
assert "too short" in err.lower() or "quiet" in err.lower()
|
||||
|
||||
|
||||
def test_validate_rejects_too_short(tmp_path):
|
||||
audio = _tone(0.5, amp=0.3)
|
||||
path = tmp_path / "short.wav"
|
||||
sf.write(str(path), audio, SR)
|
||||
|
||||
ok, err, _, _ = validate_and_load_reference_audio(str(path))
|
||||
|
||||
assert not ok
|
||||
assert "too short" in (err or "").lower()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
162
backend/tests/test_cors.py
Normal file
162
backend/tests/test_cors.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
Tests for CORS origin restrictions.
|
||||
|
||||
Validates that the CORS middleware only allows known local origins
|
||||
and respects the VOICEBOX_CORS_ORIGINS environment variable.
|
||||
|
||||
Uses a minimal FastAPI app that mirrors the exact CORS configuration
|
||||
from backend/main.py, so tests run without heavy ML dependencies.
|
||||
|
||||
Usage:
|
||||
pip install httpx pytest fastapi starlette
|
||||
python -m pytest backend/tests/test_cors.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def _build_app(env_origins: str = "") -> FastAPI:
|
||||
"""
|
||||
Build a minimal FastAPI app with the same CORS logic as backend/main.py.
|
||||
|
||||
This mirrors the exact code in main.py so the test validates the real
|
||||
configuration without needing torch/numpy/transformers installed.
|
||||
"""
|
||||
app = FastAPI()
|
||||
|
||||
_default_origins = [
|
||||
"http://localhost:5173",
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:17493",
|
||||
"http://127.0.0.1:17493",
|
||||
"tauri://localhost",
|
||||
"https://tauri.localhost",
|
||||
]
|
||||
_cors_origins = _default_origins + [o.strip() for o in env_origins.split(",") if o.strip()]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client():
|
||||
return TestClient(_build_app())
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client_with_custom_origins():
|
||||
return TestClient(_build_app("https://custom.example.com,https://other.example.com"))
|
||||
|
||||
|
||||
def _get_with_origin(client: TestClient, origin: str) -> dict:
|
||||
"""Send a GET with Origin header, return response headers."""
|
||||
response = client.get("/health", headers={"Origin": origin})
|
||||
return dict(response.headers)
|
||||
|
||||
|
||||
def _preflight(client: TestClient, origin: str) -> dict:
|
||||
"""Send CORS preflight OPTIONS request, return response headers."""
|
||||
response = client.options(
|
||||
"/health",
|
||||
headers={
|
||||
"Origin": origin,
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
return dict(response.headers)
|
||||
|
||||
|
||||
class TestCORSDefaultOrigins:
|
||||
"""CORS should allow known local origins and block everything else."""
|
||||
|
||||
@pytest.mark.parametrize("origin", [
|
||||
"http://localhost:5173",
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:17493",
|
||||
"http://127.0.0.1:17493",
|
||||
"tauri://localhost",
|
||||
"https://tauri.localhost",
|
||||
])
|
||||
def test_allowed_origins(self, client, origin):
|
||||
headers = _get_with_origin(client, origin)
|
||||
assert headers.get("access-control-allow-origin") == origin
|
||||
|
||||
@pytest.mark.parametrize("origin", [
|
||||
"http://evil.com",
|
||||
"http://localhost:9999",
|
||||
"https://attacker.example.com",
|
||||
"null",
|
||||
])
|
||||
def test_blocked_origins(self, client, origin):
|
||||
headers = _get_with_origin(client, origin)
|
||||
assert "access-control-allow-origin" not in headers
|
||||
|
||||
def test_preflight_allowed(self, client):
|
||||
headers = _preflight(client, "http://localhost:5173")
|
||||
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
|
||||
|
||||
def test_preflight_blocked(self, client):
|
||||
headers = _preflight(client, "http://evil.com")
|
||||
assert "access-control-allow-origin" not in headers
|
||||
|
||||
def test_credentials_header_present(self, client):
|
||||
headers = _get_with_origin(client, "http://localhost:5173")
|
||||
assert headers.get("access-control-allow-credentials") == "true"
|
||||
|
||||
|
||||
class TestCORSCustomOrigins:
|
||||
"""VOICEBOX_CORS_ORIGINS env var should extend the allowlist."""
|
||||
|
||||
def test_custom_origin_allowed(self, client_with_custom_origins):
|
||||
headers = _get_with_origin(client_with_custom_origins, "https://custom.example.com")
|
||||
assert headers.get("access-control-allow-origin") == "https://custom.example.com"
|
||||
|
||||
def test_other_custom_origin_allowed(self, client_with_custom_origins):
|
||||
headers = _get_with_origin(client_with_custom_origins, "https://other.example.com")
|
||||
assert headers.get("access-control-allow-origin") == "https://other.example.com"
|
||||
|
||||
def test_default_origins_still_work(self, client_with_custom_origins):
|
||||
headers = _get_with_origin(client_with_custom_origins, "http://localhost:5173")
|
||||
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
|
||||
|
||||
def test_unlisted_origin_still_blocked(self, client_with_custom_origins):
|
||||
headers = _get_with_origin(client_with_custom_origins, "http://evil.com")
|
||||
assert "access-control-allow-origin" not in headers
|
||||
|
||||
|
||||
class TestCORSEnvVarParsing:
|
||||
"""Edge cases for VOICEBOX_CORS_ORIGINS parsing."""
|
||||
|
||||
def test_empty_env_var(self):
|
||||
app = _build_app("")
|
||||
client = TestClient(app)
|
||||
headers = _get_with_origin(client, "http://evil.com")
|
||||
assert "access-control-allow-origin" not in headers
|
||||
|
||||
def test_whitespace_trimmed(self):
|
||||
app = _build_app(" https://spaced.example.com ")
|
||||
client = TestClient(app)
|
||||
headers = _get_with_origin(client, "https://spaced.example.com")
|
||||
assert headers.get("access-control-allow-origin") == "https://spaced.example.com"
|
||||
|
||||
def test_trailing_comma_ignored(self):
|
||||
app = _build_app("https://one.example.com,")
|
||||
client = TestClient(app)
|
||||
headers = _get_with_origin(client, "https://one.example.com")
|
||||
assert headers.get("access-control-allow-origin") == "https://one.example.com"
|
||||
305
backend/tests/test_generation_download.py
Normal file
305
backend/tests/test_generation_download.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Test TTS generation with SSE progress monitoring.
|
||||
This test captures the exact SSE events triggered during generation
|
||||
to identify UX issues where users see download progress even when
|
||||
the model is already cached.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import httpx
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
async def monitor_sse_stream(model_name: str, timeout: int = 120):
|
||||
"""Monitor SSE stream for a model during generation."""
|
||||
events: List[Dict] = []
|
||||
url = f"http://localhost:8000/models/progress/{model_name}"
|
||||
|
||||
print(f"[{_timestamp()}] Connecting to SSE endpoint: {url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
print(f"[{_timestamp()}] SSE connected, status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"[{_timestamp()}] Error: SSE endpoint returned {response.status_code}")
|
||||
return events
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
timestamp = _timestamp()
|
||||
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
print(
|
||||
f"[{timestamp}] → SSE Event: {data['status']:12} {data.get('progress', 0):6.1f}% {data.get('filename', '')}"
|
||||
)
|
||||
events.append({**data, "_timestamp": timestamp})
|
||||
|
||||
# Stop if complete or error
|
||||
if data.get("status") in ("complete", "error"):
|
||||
print(f"[{timestamp}] → Model {data['status']}!")
|
||||
break
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"[{timestamp}] Error parsing JSON: {e}")
|
||||
print(f" Line was: {line}")
|
||||
|
||||
elif line.startswith(": heartbeat"):
|
||||
print(f"[{timestamp}] ♥ heartbeat")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(f"[{_timestamp()}] SSE monitoring timed out")
|
||||
except Exception as e:
|
||||
print(f"[{_timestamp()}] SSE error: {e}")
|
||||
|
||||
return events
|
||||
|
||||
|
||||
async def trigger_generation(profile_id: str, text: str, model_size: str = "1.7B"):
|
||||
"""Trigger TTS generation via the API."""
|
||||
url = "http://localhost:8000/generate"
|
||||
|
||||
print(f"\n[{_timestamp()}] Triggering generation...")
|
||||
print(f" Profile: {profile_id}")
|
||||
print(f" Text: {text[:50]}...")
|
||||
print(f" Model: {model_size}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
response = await client.post(
|
||||
url,
|
||||
json={
|
||||
"profile_id": profile_id,
|
||||
"text": text,
|
||||
"language": "en",
|
||||
"model_size": model_size,
|
||||
},
|
||||
)
|
||||
|
||||
print(f"[{_timestamp()}] Response: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print(f"[{_timestamp()}] ✓ Generation successful!")
|
||||
print(f" Generation ID: {result.get('id')}")
|
||||
print(f" Duration: {result.get('duration', 0):.2f}s")
|
||||
return True, result
|
||||
elif response.status_code == 202:
|
||||
# Model is being downloaded
|
||||
result = response.json()
|
||||
print(f"[{_timestamp()}] → Model download in progress")
|
||||
print(f" Detail: {result}")
|
||||
return False, result
|
||||
else:
|
||||
print(f"[{_timestamp()}] ✗ Error: {response.text}")
|
||||
return False, None
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{_timestamp()}] ✗ Exception: {e}")
|
||||
return False, None
|
||||
|
||||
|
||||
async def get_first_profile():
|
||||
"""Get the first available voice profile."""
|
||||
url = "http://localhost:8000/profiles"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(url)
|
||||
if response.status_code == 200:
|
||||
profiles = response.json()
|
||||
if profiles:
|
||||
return profiles[0]["id"]
|
||||
except Exception as e:
|
||||
print(f"Error getting profiles: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def check_server():
|
||||
"""Check if the server is running."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
response = await client.get("http://localhost:8000/health")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
print(f"Server not running: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _timestamp():
|
||||
"""Get current timestamp for logging."""
|
||||
return datetime.now().strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
|
||||
async def test_generation_with_cached_model():
|
||||
"""
|
||||
Test Case 1: Generation when model is already cached.
|
||||
|
||||
This should NOT show any download progress events.
|
||||
If it does, that's the UX bug we're trying to fix.
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST CASE 1: Generation with Cached Model")
|
||||
print("=" * 80)
|
||||
print("Expected: No download progress events (or minimal/instant completion)")
|
||||
print("Actual UX Issue: Users see 'started' and 'finished' events even for cached models")
|
||||
print("=" * 80)
|
||||
|
||||
model_size = "1.7B"
|
||||
model_name = f"qwen-tts-{model_size}"
|
||||
|
||||
# Get a profile
|
||||
profile_id = await get_first_profile()
|
||||
if not profile_id:
|
||||
print("✗ No voice profiles found. Please create a profile first.")
|
||||
return False
|
||||
|
||||
print(f"\nUsing profile: {profile_id}")
|
||||
|
||||
# Start SSE monitor BEFORE triggering generation
|
||||
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=30))
|
||||
|
||||
# Wait for SSE to connect
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Trigger generation
|
||||
test_text = "Hello, this is a test of the voice generation system."
|
||||
success, result = await trigger_generation(profile_id, test_text, model_size)
|
||||
|
||||
if not success and result and result.get("downloading"):
|
||||
print("\n⚠ Model is being downloaded. Waiting for download to complete...")
|
||||
# Wait for SSE monitor to capture download events
|
||||
events = await monitor_task
|
||||
return events
|
||||
|
||||
# Wait a bit more to catch any progress events
|
||||
await asyncio.sleep(3)
|
||||
|
||||
# Cancel SSE monitor
|
||||
monitor_task.cancel()
|
||||
try:
|
||||
events = await monitor_task
|
||||
except asyncio.CancelledError:
|
||||
events = []
|
||||
|
||||
return events
|
||||
|
||||
|
||||
async def test_generation_with_fresh_download():
|
||||
"""
|
||||
Test Case 2: Generation when model needs to be downloaded.
|
||||
|
||||
This SHOULD show download progress events.
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST CASE 2: Generation with Model Download")
|
||||
print("=" * 80)
|
||||
print("Expected: Download progress events from 0% to 100%")
|
||||
print("=" * 80)
|
||||
|
||||
# Use a different model size to force download
|
||||
model_size = "0.6B" # Smaller model for faster testing
|
||||
model_name = f"qwen-tts-{model_size}"
|
||||
|
||||
# Get a profile
|
||||
profile_id = await get_first_profile()
|
||||
if not profile_id:
|
||||
print("✗ No voice profiles found. Please create a profile first.")
|
||||
return False
|
||||
|
||||
print(f"\nUsing profile: {profile_id}")
|
||||
print("Note: This will download the model if not cached")
|
||||
|
||||
# Start SSE monitor BEFORE triggering generation
|
||||
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=300))
|
||||
|
||||
# Wait for SSE to connect
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Trigger generation
|
||||
test_text = "This should trigger a model download if the model is not cached."
|
||||
success, result = await trigger_generation(profile_id, test_text, model_size)
|
||||
|
||||
if not success and result and result.get("downloading"):
|
||||
print("\n→ Model download initiated. Monitoring progress...")
|
||||
# Wait for download to complete
|
||||
events = await monitor_task
|
||||
|
||||
# Try generation again
|
||||
print(f"\n[{_timestamp()}] Retrying generation after download...")
|
||||
await asyncio.sleep(2)
|
||||
success, result = await trigger_generation(profile_id, test_text, model_size)
|
||||
|
||||
if success:
|
||||
print("✓ Generation successful after download")
|
||||
|
||||
return events
|
||||
|
||||
# If model was already cached
|
||||
await asyncio.sleep(3)
|
||||
monitor_task.cancel()
|
||||
try:
|
||||
events = await monitor_task
|
||||
except asyncio.CancelledError:
|
||||
events = []
|
||||
|
||||
return events
|
||||
|
||||
|
||||
async def main():
|
||||
print("=" * 80)
|
||||
print("TTS Generation Progress Test")
|
||||
print("=" * 80)
|
||||
print("Purpose: Capture exact SSE events during generation to identify UX issues")
|
||||
print("=" * 80)
|
||||
|
||||
# Check if server is running
|
||||
print(f"\n[{_timestamp()}] Checking if server is running...")
|
||||
if not await check_server():
|
||||
print("✗ Server is not running on http://localhost:8000")
|
||||
print("\nPlease start the server first:")
|
||||
print(" cd backend && python main.py")
|
||||
return False
|
||||
|
||||
print("✓ Server is running")
|
||||
|
||||
# Test Case 1: Cached model
|
||||
print("\n" + "🧪 " * 20)
|
||||
events_cached = await test_generation_with_cached_model()
|
||||
|
||||
# Results for Test Case 1
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST CASE 1 RESULTS: Generation with Cached Model")
|
||||
print("=" * 80)
|
||||
|
||||
if not events_cached:
|
||||
print("✓ GOOD: No SSE progress events received")
|
||||
print(" This is the expected behavior for a cached model.")
|
||||
else:
|
||||
print(f"⚠ ISSUE FOUND: Received {len(events_cached)} SSE events:")
|
||||
print("\nEvent Timeline:")
|
||||
for i, event in enumerate(events_cached, 1):
|
||||
timestamp = event.pop("_timestamp", "??:??:??.???")
|
||||
print(f" {i}. [{timestamp}] {event}")
|
||||
|
||||
print("\n⚠ This explains the UX issue!")
|
||||
print(" Users see progress events even when the model is already cached,")
|
||||
print(" making them think the model is downloading again.")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Test Complete!")
|
||||
print("=" * 80)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
118
backend/tests/test_offline_guard.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Unit tests for the ``force_offline_if_cached`` helper.
|
||||
|
||||
Verifies that the helper mutates the cached module constants in
|
||||
``huggingface_hub.constants`` and ``transformers.utils.hub`` — not just
|
||||
``os.environ`` — and that concurrent users are refcount-coordinated so
|
||||
one thread's exit can't strip another thread's offline protection.
|
||||
|
||||
NOTE: These tests mutate process-global state in ``huggingface_hub.constants``
|
||||
and ``transformers.utils.hub``. They are not safe under cross-process
|
||||
parallelism (e.g. ``pytest-xdist`` with ``--dist=loadfile``/``loadscope``);
|
||||
run this file serially.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from utils.hf_offline_patch import force_offline_if_cached # noqa: E402
|
||||
|
||||
|
||||
def _hf_const():
|
||||
import huggingface_hub.constants as hf_const
|
||||
|
||||
return hf_const
|
||||
|
||||
|
||||
def _tf_hub():
|
||||
import transformers.utils.hub as tf_hub
|
||||
|
||||
return tf_hub
|
||||
|
||||
|
||||
def test_mutates_cached_huggingface_hub_constant():
|
||||
original = _hf_const().HF_HUB_OFFLINE
|
||||
with force_offline_if_cached(True, "t"):
|
||||
assert _hf_const().HF_HUB_OFFLINE is True
|
||||
assert original == _hf_const().HF_HUB_OFFLINE
|
||||
|
||||
|
||||
def test_mutates_cached_transformers_constant():
|
||||
original = _tf_hub()._is_offline_mode
|
||||
with force_offline_if_cached(True, "t"):
|
||||
assert _tf_hub()._is_offline_mode is True
|
||||
assert original == _tf_hub()._is_offline_mode
|
||||
|
||||
|
||||
def test_sets_env_variable():
|
||||
original = os.environ.get("HF_HUB_OFFLINE")
|
||||
with force_offline_if_cached(True, "t"):
|
||||
assert "1" == os.environ.get("HF_HUB_OFFLINE")
|
||||
assert original == os.environ.get("HF_HUB_OFFLINE")
|
||||
|
||||
|
||||
def test_noop_when_not_cached():
|
||||
before = _hf_const().HF_HUB_OFFLINE
|
||||
with force_offline_if_cached(False, "t"):
|
||||
assert before == _hf_const().HF_HUB_OFFLINE
|
||||
|
||||
|
||||
def test_nested_contexts_respect_refcount():
|
||||
original = _hf_const().HF_HUB_OFFLINE
|
||||
with force_offline_if_cached(True, "outer"):
|
||||
assert _hf_const().HF_HUB_OFFLINE is True
|
||||
with force_offline_if_cached(True, "inner"):
|
||||
assert _hf_const().HF_HUB_OFFLINE is True
|
||||
# inner exit must not restore while outer is still active
|
||||
assert _hf_const().HF_HUB_OFFLINE is True
|
||||
assert original == _hf_const().HF_HUB_OFFLINE
|
||||
|
||||
|
||||
def test_concurrent_threads_share_offline_window():
|
||||
"""A slow thread must keep seeing offline mode even if a peer exits first."""
|
||||
original = _hf_const().HF_HUB_OFFLINE
|
||||
observations: list[bool] = []
|
||||
errors: list[Exception] = []
|
||||
barrier = threading.Barrier(2)
|
||||
fast_exited = threading.Event()
|
||||
|
||||
def slow():
|
||||
try:
|
||||
with force_offline_if_cached(True, "slow"):
|
||||
barrier.wait(timeout=5)
|
||||
assert fast_exited.wait(timeout=5), "fast thread did not exit"
|
||||
observations.append(_hf_const().HF_HUB_OFFLINE)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
|
||||
def fast():
|
||||
try:
|
||||
with force_offline_if_cached(True, "fast"):
|
||||
barrier.wait(timeout=5)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
errors.append(exc)
|
||||
finally:
|
||||
fast_exited.set()
|
||||
|
||||
t_slow = threading.Thread(target=slow)
|
||||
t_fast = threading.Thread(target=fast)
|
||||
t_slow.start()
|
||||
t_fast.start()
|
||||
t_slow.join(timeout=5)
|
||||
t_fast.join(timeout=5)
|
||||
|
||||
assert not t_slow.is_alive(), "slow thread did not finish"
|
||||
assert not t_fast.is_alive(), "fast thread did not finish"
|
||||
assert not errors, errors
|
||||
assert observations == [True], "slow thread lost offline protection"
|
||||
assert original == _hf_const().HF_HUB_OFFLINE
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
113
backend/tests/test_offline_patch.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Unit tests for ``patch_transformers_mistral_regex``.
|
||||
|
||||
Verifies that our wrapper around
|
||||
``transformers.PreTrainedTokenizerBase._patch_mistral_regex`` catches
|
||||
exceptions from the unconditional ``huggingface_hub.model_info()`` lookup
|
||||
and returns the tokenizer unchanged — matching the success-path behavior
|
||||
for non-Mistral repos (transformers 4.57.3, ``tokenization_utils_base.py:2503``).
|
||||
|
||||
NOTE: These tests mutate ``transformers.PreTrainedTokenizerBase`` globally;
|
||||
run serially, not under ``pytest-xdist`` with per-worker process isolation.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled # noqa: E402
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase # noqa: E402
|
||||
|
||||
import utils.hf_offline_patch as hf_offline_patch # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def restore_mistral_regex():
|
||||
"""Snapshot the current ``_patch_mistral_regex`` and restore after each test."""
|
||||
saved = PreTrainedTokenizerBase.__dict__.get("_patch_mistral_regex")
|
||||
saved_flag = hf_offline_patch._mistral_regex_patched
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if saved is not None:
|
||||
PreTrainedTokenizerBase._patch_mistral_regex = saved
|
||||
hf_offline_patch._mistral_regex_patched = saved_flag
|
||||
|
||||
|
||||
def _apply_patch():
|
||||
hf_offline_patch._mistral_regex_patched = False
|
||||
hf_offline_patch.patch_transformers_mistral_regex()
|
||||
|
||||
|
||||
def test_suppresses_offline_mode_is_enabled(monkeypatch):
|
||||
_apply_patch()
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
def raise_offline(*_args, **_kwargs):
|
||||
raise OfflineModeIsEnabled("offline")
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "model_info", raise_offline)
|
||||
|
||||
sentinel = object()
|
||||
result = PreTrainedTokenizerBase._patch_mistral_regex(
|
||||
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||
)
|
||||
assert result is sentinel
|
||||
|
||||
|
||||
def test_suppresses_connection_errors(monkeypatch):
|
||||
_apply_patch()
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
def raise_connection(*_args, **_kwargs):
|
||||
raise ConnectionError("network unreachable")
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "model_info", raise_connection)
|
||||
|
||||
sentinel = object()
|
||||
result = PreTrainedTokenizerBase._patch_mistral_regex(
|
||||
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||
)
|
||||
assert result is sentinel
|
||||
|
||||
|
||||
def test_passthrough_on_success(monkeypatch):
|
||||
"""When model_info returns non-Mistral tags the original falls through and returns the tokenizer unchanged."""
|
||||
_apply_patch()
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
class FakeInfo:
|
||||
tags = ["model-type:qwen", "language:en"]
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "model_info", lambda *_a, **_kw: FakeInfo())
|
||||
|
||||
sentinel = object()
|
||||
result = PreTrainedTokenizerBase._patch_mistral_regex(
|
||||
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||
)
|
||||
assert result is sentinel
|
||||
|
||||
|
||||
def test_idempotent():
|
||||
_apply_patch()
|
||||
first = PreTrainedTokenizerBase._patch_mistral_regex
|
||||
hf_offline_patch.patch_transformers_mistral_regex()
|
||||
second = PreTrainedTokenizerBase._patch_mistral_regex
|
||||
assert first.__func__ is second.__func__
|
||||
|
||||
|
||||
def test_missing_method_is_noop(monkeypatch):
|
||||
monkeypatch.delattr(PreTrainedTokenizerBase, "_patch_mistral_regex", raising=False)
|
||||
hf_offline_patch._mistral_regex_patched = False
|
||||
hf_offline_patch.patch_transformers_mistral_regex()
|
||||
assert hf_offline_patch._mistral_regex_patched is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
217
backend/tests/test_profile_duplicate_names.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Tests for profile duplicate name validation.
|
||||
|
||||
This test suite verifies that the application correctly handles
|
||||
duplicate profile names and provides user-friendly error messages.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Add parent directory to path to import backend modules
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from database import Base, VoiceProfile as DBVoiceProfile
|
||||
from models import VoiceProfileCreate
|
||||
from profiles import create_profile, update_profile
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db():
|
||||
"""Create a temporary test database."""
|
||||
# Create temporary directory for test database
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = Path(temp_dir) / "test.db"
|
||||
|
||||
# Create engine and session
|
||||
engine = create_engine(f"sqlite:///{db_path}")
|
||||
Base.metadata.create_all(bind=engine)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
yield db
|
||||
|
||||
# Cleanup
|
||||
db.close()
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_profiles_dir(monkeypatch, tmp_path):
|
||||
"""Mock the profiles directory to use a temporary path."""
|
||||
from backend import config
|
||||
monkeypatch.setattr(config, 'get_profiles_dir', lambda: tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_profile_duplicate_name_raises_error(test_db, mock_profiles_dir):
|
||||
"""Test that creating a profile with a duplicate name raises a ValueError."""
|
||||
# Create first profile
|
||||
profile_data_1 = VoiceProfileCreate(
|
||||
name="Test Profile",
|
||||
description="First profile",
|
||||
language="en"
|
||||
)
|
||||
|
||||
profile_1 = await create_profile(profile_data_1, test_db)
|
||||
assert profile_1.name == "Test Profile"
|
||||
|
||||
# Try to create second profile with same name
|
||||
profile_data_2 = VoiceProfileCreate(
|
||||
name="Test Profile",
|
||||
description="Second profile",
|
||||
language="en"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await create_profile(profile_data_2, test_db)
|
||||
|
||||
# Verify error message is user-friendly
|
||||
assert "already exists" in str(exc_info.value)
|
||||
assert "Test Profile" in str(exc_info.value)
|
||||
assert "choose a different name" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_profile_different_names_succeeds(test_db, mock_profiles_dir):
|
||||
"""Test that creating profiles with different names succeeds."""
|
||||
# Create first profile
|
||||
profile_data_1 = VoiceProfileCreate(
|
||||
name="Profile One",
|
||||
description="First profile",
|
||||
language="en"
|
||||
)
|
||||
|
||||
profile_1 = await create_profile(profile_data_1, test_db)
|
||||
assert profile_1.name == "Profile One"
|
||||
|
||||
# Create second profile with different name
|
||||
profile_data_2 = VoiceProfileCreate(
|
||||
name="Profile Two",
|
||||
description="Second profile",
|
||||
language="en"
|
||||
)
|
||||
|
||||
profile_2 = await create_profile(profile_data_2, test_db)
|
||||
assert profile_2.name == "Profile Two"
|
||||
|
||||
# Verify both profiles exist
|
||||
assert profile_1.id != profile_2.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_to_duplicate_name_raises_error(test_db, mock_profiles_dir):
|
||||
"""Test that updating a profile to a duplicate name raises a ValueError."""
|
||||
# Create two profiles with different names
|
||||
profile_data_1 = VoiceProfileCreate(
|
||||
name="Profile A",
|
||||
description="First profile",
|
||||
language="en"
|
||||
)
|
||||
profile_1 = await create_profile(profile_data_1, test_db)
|
||||
|
||||
profile_data_2 = VoiceProfileCreate(
|
||||
name="Profile B",
|
||||
description="Second profile",
|
||||
language="en"
|
||||
)
|
||||
profile_2 = await create_profile(profile_data_2, test_db)
|
||||
|
||||
# Try to update profile_2 to use profile_1's name
|
||||
update_data = VoiceProfileCreate(
|
||||
name="Profile A", # Duplicate name
|
||||
description="Updated description",
|
||||
language="en"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await update_profile(profile_2.id, update_data, test_db)
|
||||
|
||||
# Verify error message is user-friendly
|
||||
assert "already exists" in str(exc_info.value)
|
||||
assert "Profile A" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_keep_same_name_succeeds(test_db, mock_profiles_dir):
|
||||
"""Test that updating a profile while keeping the same name succeeds."""
|
||||
# Create profile
|
||||
profile_data = VoiceProfileCreate(
|
||||
name="My Profile",
|
||||
description="Original description",
|
||||
language="en"
|
||||
)
|
||||
profile = await create_profile(profile_data, test_db)
|
||||
|
||||
# Update profile with same name but different description
|
||||
update_data = VoiceProfileCreate(
|
||||
name="My Profile", # Same name
|
||||
description="Updated description",
|
||||
language="en"
|
||||
)
|
||||
|
||||
updated_profile = await update_profile(profile.id, update_data, test_db)
|
||||
|
||||
# Verify update succeeded
|
||||
assert updated_profile is not None
|
||||
assert updated_profile.id == profile.id
|
||||
assert updated_profile.name == "My Profile"
|
||||
assert updated_profile.description == "Updated description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_profile_to_new_unique_name_succeeds(test_db, mock_profiles_dir):
|
||||
"""Test that updating a profile to a new unique name succeeds."""
|
||||
# Create profile
|
||||
profile_data = VoiceProfileCreate(
|
||||
name="Original Name",
|
||||
description="Profile description",
|
||||
language="en"
|
||||
)
|
||||
profile = await create_profile(profile_data, test_db)
|
||||
|
||||
# Update profile with new unique name
|
||||
update_data = VoiceProfileCreate(
|
||||
name="New Unique Name",
|
||||
description="Updated description",
|
||||
language="en"
|
||||
)
|
||||
|
||||
updated_profile = await update_profile(profile.id, update_data, test_db)
|
||||
|
||||
# Verify update succeeded
|
||||
assert updated_profile is not None
|
||||
assert updated_profile.id == profile.id
|
||||
assert updated_profile.name == "New Unique Name"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_case_sensitive_names_allowed(test_db, mock_profiles_dir):
|
||||
"""Test that profile names are case-sensitive (e.g., 'Test' and 'test' are different)."""
|
||||
# Create profile with lowercase name
|
||||
profile_data_1 = VoiceProfileCreate(
|
||||
name="test profile",
|
||||
description="Lowercase",
|
||||
language="en"
|
||||
)
|
||||
profile_1 = await create_profile(profile_data_1, test_db)
|
||||
|
||||
# Create profile with different case
|
||||
profile_data_2 = VoiceProfileCreate(
|
||||
name="Test Profile",
|
||||
description="Title case",
|
||||
language="en"
|
||||
)
|
||||
profile_2 = await create_profile(profile_data_2, test_db)
|
||||
|
||||
# Both should succeed since SQLite unique constraint is case-sensitive by default
|
||||
assert profile_1.name == "test profile"
|
||||
assert profile_2.name == "Test Profile"
|
||||
assert profile_1.id != profile_2.id
|
||||
313
backend/tests/test_progress.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Test script to debug model download progress tracking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict
|
||||
import logging
|
||||
|
||||
# Set up logging to see what's happening
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
# Make backend modules importable when run from the repo root (matches the other test scripts)
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from utils.progress import ProgressManager, get_progress_manager
|
||||
from utils.hf_progress import HFProgressTracker, create_hf_progress_callback
|
||||
|
||||
|
||||
def test_progress_manager_basic():
|
||||
"""Test 1: Basic ProgressManager functionality."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Test 1: ProgressManager Basic Operations")
|
||||
print("=" * 60)
|
||||
|
||||
pm = ProgressManager()
|
||||
|
||||
# Test update_progress
|
||||
pm.update_progress(
|
||||
model_name="test-model",
|
||||
current=50,
|
||||
total=100,
|
||||
filename="test.bin",
|
||||
status="downloading"
|
||||
)
|
||||
|
||||
# Test get_progress
|
||||
progress = pm.get_progress("test-model")
|
||||
print(f"✓ Progress stored: {progress}")
|
||||
assert progress is not None
|
||||
assert progress["progress"] == 50.0
|
||||
assert progress["filename"] == "test.bin"
|
||||
assert progress["status"] == "downloading"
|
||||
|
||||
# Test mark_complete
|
||||
pm.mark_complete("test-model")
|
||||
progress = pm.get_progress("test-model")
|
||||
print(f"✓ Marked complete: {progress}")
|
||||
assert progress["status"] == "complete"
|
||||
assert progress["progress"] == 100.0
|
||||
|
||||
print("✓ Test 1 PASSED\n")
|
||||
return True
|
||||
|
||||
|
||||
async def test_progress_manager_sse():
|
||||
"""Test 2: ProgressManager SSE streaming."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Test 2: ProgressManager SSE Streaming")
|
||||
print("=" * 60)
|
||||
|
||||
pm = ProgressManager()
|
||||
collected_events: List[Dict] = []
|
||||
|
||||
# Simulate SSE client
|
||||
async def sse_client():
|
||||
"""Simulates a frontend SSE connection."""
|
||||
print(" SSE client: Subscribing to test-model-sse...")
|
||||
async for event in pm.subscribe("test-model-sse"):
|
||||
# Parse SSE event
|
||||
if event.startswith("data: "):
|
||||
data = json.loads(event[6:])
|
||||
print(f" SSE client: Received event: {data['status']} - {data.get('progress', 0):.1f}%")
|
||||
collected_events.append(data)
|
||||
|
||||
# Stop when complete
|
||||
if data.get("status") in ("complete", "error"):
|
||||
break
|
||||
elif event.startswith(": heartbeat"):
|
||||
print(" SSE client: Received heartbeat")
|
||||
|
||||
# Simulate download progress updates (from backend thread)
|
||||
async def simulate_download():
|
||||
"""Simulates backend sending progress updates."""
|
||||
print(" Backend: Starting simulated download...")
|
||||
await asyncio.sleep(0.2) # Let SSE client subscribe first
|
||||
|
||||
# Send progress updates
|
||||
for i in range(0, 101, 20):
|
||||
print(f" Backend: Updating progress to {i}%")
|
||||
pm.update_progress(
|
||||
model_name="test-model-sse",
|
||||
current=i,
|
||||
total=100,
|
||||
filename=f"file_{i}.bin",
|
||||
status="downloading" if i < 100 else "downloading"
|
||||
)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Mark complete
|
||||
print(" Backend: Marking download complete")
|
||||
pm.mark_complete("test-model-sse")
|
||||
|
||||
# Run SSE client and download simulation concurrently
|
||||
await asyncio.gather(
|
||||
sse_client(),
|
||||
simulate_download()
|
||||
)
|
||||
|
||||
# Verify we got events
|
||||
print(f"\n Collected {len(collected_events)} events")
|
||||
assert len(collected_events) > 0, "Should have received at least one event"
|
||||
assert collected_events[-1]["status"] == "complete", "Last event should be 'complete'"
|
||||
|
||||
print("✓ Test 2 PASSED\n")
|
||||
return True
|
||||
|
||||
|
||||
def test_hf_progress_tracker():
|
||||
"""Test 3: HFProgressTracker tqdm patching."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Test 3: HFProgressTracker tqdm Patching")
|
||||
print("=" * 60)
|
||||
|
||||
captured_progress: List[tuple] = []
|
||||
|
||||
def progress_callback(downloaded: int, total: int, filename: str):
|
||||
"""Capture progress updates."""
|
||||
captured_progress.append((downloaded, total, filename))
|
||||
print(f" Progress callback: {downloaded}/{total} bytes ({filename})")
|
||||
|
||||
tracker = HFProgressTracker(progress_callback)
|
||||
|
||||
# Simulate a download with tqdm
|
||||
with tracker.patch_download():
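        # Per this test's expectations, the patch context intercepts tqdm progress bars created
        # inside it and forwards their updates to progress_callback as (downloaded, total, filename).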
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
# Simulate downloading a file
|
||||
print(" Simulating download with tqdm...")
|
||||
total_size = 1000
|
||||
with tqdm(total=total_size, desc="model.bin", unit="B", unit_scale=True) as pbar:
|
||||
for chunk in range(0, total_size, 100):
|
||||
pbar.update(100)
|
||||
time.sleep(0.01)
|
||||
|
||||
print(f" Captured {len(captured_progress)} progress updates")
|
||||
assert len(captured_progress) > 0, "Should have captured progress updates"
|
||||
|
||||
# Verify progress increases
|
||||
last_downloaded = 0
|
||||
for downloaded, total, filename in captured_progress:
|
||||
assert downloaded >= last_downloaded, "Downloaded bytes should increase"
|
||||
assert total == total_size, "Total should be consistent"
|
||||
last_downloaded = downloaded
|
||||
|
||||
print("✓ Test 3 PASSED\n")
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
print("✗ tqdm not available, skipping test\n")
|
||||
return None
|
||||
|
||||
|
||||
async def test_full_integration():
|
||||
"""Test 4: Full integration test."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Test 4: Full Integration (ProgressManager + HFProgressTracker)")
|
||||
print("=" * 60)
|
||||
|
||||
pm = get_progress_manager()
|
||||
collected_events: List[Dict] = []
|
||||
|
||||
# SSE client
|
||||
async def sse_client():
|
||||
print(" SSE client: Subscribing...")
|
||||
async for event in pm.subscribe("integration-test"):
|
||||
if event.startswith("data: "):
|
||||
data = json.loads(event[6:])
|
||||
print(f" SSE client: {data['status']} - {data.get('progress', 0):.1f}% - {data.get('filename', '')}")
|
||||
collected_events.append(data)
|
||||
if data.get("status") in ("complete", "error"):
|
||||
break
|
||||
|
||||
# Simulate backend download with HFProgressTracker
|
||||
async def simulate_real_download():
|
||||
await asyncio.sleep(0.2) # Let SSE subscribe
|
||||
|
||||
print(" Backend: Starting download with HFProgressTracker...")
|
||||
|
||||
# Set up tracking (like the real backend does)
|
||||
progress_callback = create_hf_progress_callback("integration-test", pm)
|
||||
tracker = HFProgressTracker(progress_callback)
|
||||
|
||||
# Initialize progress
|
||||
pm.update_progress(
|
||||
model_name="integration-test",
|
||||
current=0,
|
||||
total=1,
|
||||
filename="",
|
||||
status="downloading"
|
||||
)
|
||||
|
||||
# Simulate download with tqdm patching
|
||||
with tracker.patch_download():
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
|
||||
# Simulate multi-file download (like HuggingFace does)
|
||||
files = [
|
||||
("model.safetensors", 5000),
|
||||
("config.json", 1000),
|
||||
("tokenizer.json", 500),
|
||||
]
|
||||
|
||||
for filename, size in files:
|
||||
print(f" Backend: Downloading {filename}...")
|
||||
with tqdm(total=size, desc=filename, unit="B") as pbar:
|
||||
for chunk in range(0, size, 500):
|
||||
chunk_size = min(500, size - chunk)
|
||||
pbar.update(chunk_size)
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Mark complete
|
||||
print(" Backend: Download complete")
|
||||
pm.mark_complete("integration-test")
|
||||
|
||||
except ImportError:
|
||||
print(" ✗ tqdm not available")
|
||||
pm.mark_error("integration-test", "tqdm not available")
|
||||
|
||||
# Run both
|
||||
await asyncio.gather(
|
||||
sse_client(),
|
||||
simulate_real_download()
|
||||
)
|
||||
|
||||
# Verify
|
||||
print(f"\n Collected {len(collected_events)} events")
|
||||
if len(collected_events) > 0:
|
||||
print(f" First event: {collected_events[0]}")
|
||||
print(f" Last event: {collected_events[-1]}")
|
||||
assert collected_events[-1]["status"] == "complete", "Should end with 'complete'"
|
||||
print("✓ Test 4 PASSED\n")
|
||||
return True
|
||||
else:
|
||||
print("✗ Test 4 FAILED - No events received\n")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Voicebox Progress Tracking Test Suite")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
# Test 1: Basic operations
|
||||
try:
|
||||
results.append(("Basic Operations", test_progress_manager_basic()))
|
||||
except Exception as e:
|
||||
print(f"✗ Test 1 FAILED: {e}\n")
|
||||
results.append(("Basic Operations", False))
|
||||
|
||||
# Test 2: SSE streaming
|
||||
try:
|
||||
results.append(("SSE Streaming", await test_progress_manager_sse()))
|
||||
except Exception as e:
|
||||
print(f"✗ Test 2 FAILED: {e}\n")
|
||||
results.append(("SSE Streaming", False))
|
||||
|
||||
# Test 3: tqdm patching
|
||||
try:
|
||||
results.append(("tqdm Patching", test_hf_progress_tracker()))
|
||||
except Exception as e:
|
||||
print(f"✗ Test 3 FAILED: {e}\n")
|
||||
results.append(("tqdm Patching", False))
|
||||
|
||||
# Test 4: Full integration
|
||||
try:
|
||||
results.append(("Full Integration", await test_full_integration()))
|
||||
except Exception as e:
|
||||
print(f"✗ Test 4 FAILED: {e}\n")
|
||||
results.append(("Full Integration", False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Test Results Summary")
|
||||
print("=" * 60)
|
||||
|
||||
for name, result in results:
|
||||
status = "✓ PASS" if result else ("⊘ SKIP" if result is None else "✗ FAIL")
|
||||
print(f" {status:8} {name}")
|
||||
|
||||
passed = sum(1 for _, r in results if r is True)
|
||||
failed = sum(1 for _, r in results if r is False)
|
||||
skipped = sum(1 for _, r in results if r is None)
|
||||
|
||||
print()
|
||||
print(f" Total: {len(results)} tests")
|
||||
print(f" Passed: {passed}")
|
||||
print(f" Failed: {failed}")
|
||||
print(f" Skipped: {skipped}")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
return failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = asyncio.run(main())
|
||||
exit(0 if success else 1)
|
||||
317
backend/tests/test_qwen_download.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Test Qwen TTS model download with SSE progress monitoring.
|
||||
|
||||
This specifically tests the MLX TTS backend download progress tracking,
|
||||
which requires tqdm to be patched BEFORE mlx_audio is imported.
|
||||
|
||||
Usage:
|
||||
cd backend && python -m tests.test_qwen_download
|
||||
|
||||
Prerequisites:
|
||||
- Server must be running: cd backend && python main.py
|
||||
- Delete model first for fresh download test:
|
||||
curl -X DELETE http://localhost:8000/models/qwen-tts-0.6B
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import httpx
|
||||
import time
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
|
||||
async def monitor_sse_stream(model_name: str, timeout: int = 600) -> List[Dict]:
|
||||
"""
|
||||
Monitor SSE stream for a model download.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to monitor
|
||||
timeout: Maximum time to wait for download (seconds)
|
||||
|
||||
Returns:
|
||||
List of SSE events received
|
||||
"""
|
||||
events: List[Dict] = []
|
||||
url = f"http://localhost:8000/models/progress/{model_name}"
|
||||
last_progress = -1
|
||||
|
||||
print(f"\n📡 Connecting to SSE endpoint: {url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
print(f" SSE connected, status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f" ❌ Error: SSE endpoint returned {response.status_code}")
|
||||
return events
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
events.append(data)
|
||||
|
||||
# Print progress (only when it changes significantly)
|
||||
progress = data.get('progress', 0)
|
||||
status = data.get('status', 'unknown')
|
||||
filename = data.get('filename', '')
|
||||
current = data.get('current', 0)
|
||||
total = data.get('total', 0)
|
||||
|
||||
# Print every 5% change or status change
|
||||
if abs(progress - last_progress) >= 5 or status in ('complete', 'error'):
|
||||
current_mb = current / (1024 * 1024)
|
||||
total_mb = total / (1024 * 1024)
|
||||
print(f" 📊 {status:12} {progress:6.1f}% ({current_mb:.1f}MB / {total_mb:.1f}MB) {filename[:50]}")
|
||||
last_progress = progress
|
||||
|
||||
# Stop if complete or error
|
||||
if status in ("complete", "error"):
|
||||
if status == "complete":
|
||||
print(f" ✅ Download complete!")
|
||||
else:
|
||||
print(f" ❌ Download error: {data.get('error', 'unknown')}")
|
||||
break
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" ⚠️ Error parsing JSON: {e}")
|
||||
|
||||
elif line.startswith(": heartbeat"):
|
||||
# Heartbeat every 1 second, don't spam
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError:
|
||||
print(" ⏹️ SSE monitor cancelled")
|
||||
except Exception as e:
|
||||
print(f" ❌ SSE error: {e}")
|
||||
|
||||
return events
|
||||
|
||||
|
||||
async def trigger_download(model_name: str) -> bool:
|
||||
"""Trigger a model download via the API."""
|
||||
url = "http://localhost:8000/models/download"
|
||||
|
||||
print(f"\n🚀 Triggering download for: {model_name}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.post(url, json={"model_name": model_name})
|
||||
result = response.json()
|
||||
print(f" Response: {response.status_code} - {result}")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
print(f" ❌ Error triggering download: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def delete_model(model_name: str) -> bool:
|
||||
"""Delete a model from cache."""
|
||||
url = f"http://localhost:8000/models/{model_name}"
|
||||
|
||||
print(f"\n🗑️ Deleting model: {model_name}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.delete(url)
|
||||
if response.status_code == 200:
|
||||
print(f" ✅ Model deleted")
|
||||
return True
|
||||
elif response.status_code == 404:
|
||||
print(f" ℹ️ Model not found (already deleted)")
|
||||
return True
|
||||
else:
|
||||
print(f" ⚠️ Delete response: {response.status_code} - {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" ❌ Error deleting model: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_model_status(model_name: str) -> Optional[Dict]:
|
||||
"""Check the status of a model."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get("http://localhost:8000/models/status")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
for model in data.get("models", []):
|
||||
if model["model_name"] == model_name:
|
||||
return model
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Error checking model status: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def check_server() -> bool:
|
||||
"""Check if the server is running."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
response = await client.get("http://localhost:8000/health")
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
print("=" * 70)
|
||||
print("🧪 Qwen TTS Model Download Progress Test")
|
||||
print("=" * 70)
|
||||
print("\nThis test verifies that MLX TTS download progress tracking works.")
|
||||
print("It specifically tests the tqdm patching for mlx_audio.tts imports.")
|
||||
|
||||
# Check if server is running
|
||||
print("\n📡 Checking if server is running...")
|
||||
if not await check_server():
|
||||
print(" ❌ Server is not running on http://localhost:8000")
|
||||
print("\n Please start the server first:")
|
||||
print(" cd backend && python main.py")
|
||||
return False
|
||||
|
||||
print(" ✅ Server is running")
|
||||
|
||||
# Test model
|
||||
model_name = "qwen-tts-0.6B"
|
||||
|
||||
# Check current status
|
||||
print(f"\n📊 Checking status of {model_name}...")
|
||||
status = await check_model_status(model_name)
|
||||
if status:
|
||||
print(f" Downloaded: {status.get('downloaded', False)}")
|
||||
print(f" Downloading: {status.get('downloading', False)}")
|
||||
print(f" Loaded: {status.get('loaded', False)}")
|
||||
if status.get('size_mb'):
|
||||
print(f" Size: {status['size_mb']:.1f} MB")
|
||||
else:
|
||||
print(" ⚠️ Could not get model status")
|
||||
|
||||
# Ask if user wants to delete first
|
||||
print("\n" + "-" * 70)
|
||||
if status and status.get('downloaded'):
|
||||
print("⚠️ Model is already downloaded. Delete it for a fresh download test?")
|
||||
print(" [y] Yes, delete and download fresh")
|
||||
print(" [n] No, just test SSE connection")
|
||||
print(" [q] Quit")
|
||||
|
||||
choice = input("\nChoice [y/n/q]: ").strip().lower()
|
||||
|
||||
if choice == 'q':
|
||||
print("Exiting...")
|
||||
return True
|
||||
|
||||
if choice == 'y':
|
||||
if not await delete_model(model_name):
|
||||
print("Failed to delete model. Continue anyway? [y/n]")
|
||||
if input().strip().lower() != 'y':
|
||||
return False
|
||||
else:
|
||||
print("Model not downloaded. Will perform fresh download test.")
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Run the test
|
||||
print("\n" + "=" * 70)
|
||||
print("🏃 Starting Download Test")
|
||||
print("=" * 70)
|
||||
|
||||
async def run_test():
|
||||
# Start SSE monitor in background FIRST
|
||||
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=600))
|
||||
|
||||
# Wait for SSE to connect
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Trigger download
|
||||
success = await trigger_download(model_name)
|
||||
|
||||
if not success:
|
||||
print(" ❌ Failed to trigger download")
|
||||
monitor_task.cancel()
|
||||
try:
|
||||
await monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
return []
|
||||
|
||||
# Wait for SSE monitor to complete
|
||||
print("\n⏳ Waiting for download to complete (this may take several minutes)...")
|
||||
events = await monitor_task
|
||||
|
||||
return events
|
||||
|
||||
start_time = time.time()
|
||||
events = await run_test()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Results
|
||||
print("\n" + "=" * 70)
|
||||
print("📋 Test Results")
|
||||
print("=" * 70)
|
||||
|
||||
print(f"\n⏱️ Elapsed time: {elapsed:.1f} seconds")
|
||||
print(f"📨 Total SSE events received: {len(events)}")
|
||||
|
||||
if not events:
|
||||
print("\n❌ FAILED - No SSE events received!")
|
||||
print("\nPossible causes:")
|
||||
print(" 1. SSE endpoint not working")
|
||||
print(" 2. tqdm not patched before mlx_audio import")
|
||||
print(" 3. Progress callbacks not firing")
|
||||
print(" 4. Model already fully downloaded")
|
||||
print("\nDebug steps:")
|
||||
print(" 1. Check server logs for [DEBUG] messages")
|
||||
print(" 2. Look for 'tqdm patched' before 'mlx_audio.tts import'")
|
||||
print(f" 3. Delete model: curl -X DELETE http://localhost:8000/models/{model_name}")
|
||||
return False
|
||||
|
||||
# Analyze events
|
||||
first_event = events[0]
|
||||
last_event = events[-1]
|
||||
|
||||
print(f"\n📊 First event:")
|
||||
print(f" Status: {first_event.get('status')}")
|
||||
print(f" Progress: {first_event.get('progress', 0):.1f}%")
|
||||
|
||||
print(f"\n📊 Last event:")
|
||||
print(f" Status: {last_event.get('status')}")
|
||||
print(f" Progress: {last_event.get('progress', 0):.1f}%")
|
||||
|
||||
# Check for expected behaviors
|
||||
has_progress_updates = len(events) > 2
|
||||
has_increasing_progress = False
|
||||
has_complete = any(e.get('status') == 'complete' for e in events)
|
||||
has_100_percent = any(e.get('progress', 0) >= 100 for e in events)
|
||||
|
||||
# Check if progress increased over time
|
||||
if len(events) >= 2:
|
||||
progress_values = [e.get('progress', 0) for e in events]
|
||||
has_increasing_progress = progress_values[-1] > progress_values[0]
|
||||
|
||||
print("\n📋 Checks:")
|
||||
print(f" {'✅' if has_progress_updates else '❌'} Multiple progress updates received ({len(events)} events)")
|
||||
print(f" {'✅' if has_increasing_progress else '❌'} Progress increased over time")
|
||||
print(f" {'✅' if has_100_percent else '❌'} Reached 100% progress")
|
||||
print(f" {'✅' if has_complete else '❌'} Received 'complete' status")
|
||||
|
||||
# Overall result
|
||||
success = has_progress_updates and has_complete
|
||||
|
||||
if success:
|
||||
print("\n" + "=" * 70)
|
||||
print("✅ TEST PASSED - Qwen TTS download progress tracking works!")
|
||||
print("=" * 70)
|
||||
else:
|
||||
print("\n" + "=" * 70)
|
||||
print("❌ TEST FAILED - Progress tracking has issues")
|
||||
print("=" * 70)
|
||||
print("\nCheck the server logs for debug output.")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = asyncio.run(main())
|
||||
exit(0 if result else 1)
|
||||
54
backend/tests/test_task_queue_cancellation.py
Normal file
@@ -0,0 +1,54 @@
import asyncio

import pytest

from backend.services import task_queue


@pytest.mark.asyncio
async def test_cancel_queued_generation_skips_execution():
    task_queue.init_queue(force=True)

    running_started = asyncio.Event()
    release_running = asyncio.Event()
    queued_ran = asyncio.Event()

    async def running_job():
        running_started.set()
        await release_running.wait()

    async def queued_job():
        queued_ran.set()

    task_queue.enqueue_generation("gen-running", running_job())
    await asyncio.wait_for(running_started.wait(), timeout=1)

    task_queue.enqueue_generation("gen-queued", queued_job())
    assert task_queue.cancel_generation("gen-queued") == "queued"

    release_running.set()
    await asyncio.sleep(0.1)

    assert not queued_ran.is_set()


@pytest.mark.asyncio
async def test_cancel_running_generation_cancels_task():
    task_queue.init_queue(force=True)

    running_started = asyncio.Event()
    running_cancelled = asyncio.Event()

    async def running_job():
        running_started.set()
        try:
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            running_cancelled.set()
            raise

    task_queue.enqueue_generation("gen-running", running_job())
    await asyncio.wait_for(running_started.wait(), timeout=1)

    assert task_queue.cancel_generation("gen-running") == "running"
    await asyncio.wait_for(running_cancelled.wait(), timeout=1)
178
backend/tests/test_whisper_download.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Test real model download with SSE progress monitoring.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import httpx
|
||||
import time
|
||||
from typing import List, Dict
|
||||
|
||||
async def monitor_sse_stream(model_name: str, timeout: int = 300):
|
||||
"""Monitor SSE stream for a model download."""
|
||||
events: List[Dict] = []
|
||||
url = f"http://localhost:8000/models/progress/{model_name}"
|
||||
|
||||
print(f"Connecting to SSE endpoint: {url}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
print(f"SSE connected, status: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"Error: SSE endpoint returned {response.status_code}")
|
||||
return events
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
print(f" Raw SSE: {line[:100]}...") # Print first 100 chars
|
||||
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
print(f" → {data['status']:12} {data.get('progress', 0):6.1f}% {data.get('filename', '')}")
|
||||
events.append(data)
|
||||
|
||||
# Stop if complete or error
|
||||
if data.get("status") in ("complete", "error"):
|
||||
print(f" Download {data['status']}!")
|
||||
break
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" Error parsing JSON: {e}")
|
||||
print(f" Line was: {line}")
|
||||
|
||||
elif line.startswith(": heartbeat"):
|
||||
print(" ♥ heartbeat")
|
||||
|
||||
return events
|
||||
|
||||
|
||||
async def trigger_download(model_name: str):
|
||||
"""Trigger a model download via the API."""
|
||||
url = "http://localhost:8000/models/download"
|
||||
|
||||
print(f"\nTriggering download for: {model_name}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=300) as client:
|
||||
response = await client.post(url, json={"model_name": model_name})
|
||||
print(f"Response: {response.status_code} - {response.json()}")
|
||||
return response.status_code == 200
|
||||
|
||||
|
||||
async def check_server():
|
||||
"""Check if the server is running."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
response = await client.get("http://localhost:8000/health")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
print(f"Server not running: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
print("=" * 60)
|
||||
print("Real Model Download Progress Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Check if server is running
|
||||
print("\nChecking if server is running...")
|
||||
if not await check_server():
|
||||
print("✗ Server is not running on http://localhost:8000")
|
||||
print("\nPlease start the server first:")
|
||||
print(" cd backend && python main.py")
|
||||
return False
|
||||
|
||||
print("✓ Server is running")
|
||||
|
||||
# Choose a small model for testing
|
||||
model_name = "whisper-base" # ~150MB, faster to download
|
||||
print(f"\nUsing model: {model_name}")
|
||||
|
||||
# Option to delete model first if it exists
|
||||
print("\nDo you want to delete the model first to force a fresh download? (y/n)")
|
||||
# For automated testing, skip deletion prompt
|
||||
# delete_first = input().strip().lower() == 'y'
|
||||
delete_first = False
|
||||
|
||||
if delete_first:
|
||||
print(f"Deleting {model_name}...")
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.delete(f"http://localhost:8000/models/{model_name}")
|
||||
print(f"Delete response: {response.status_code}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Start monitoring SSE stream BEFORE triggering download
|
||||
async def run_test():
|
||||
# Start SSE monitor in background
|
||||
monitor_task = asyncio.create_task(monitor_sse_stream(model_name))
|
||||
|
||||
# Wait a bit to ensure SSE is connected
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Trigger download
|
||||
success = await trigger_download(model_name)
|
||||
|
||||
if not success:
|
||||
print("✗ Failed to trigger download")
|
||||
monitor_task.cancel()
|
||||
return False
|
||||
|
||||
# Wait for SSE monitor to complete
|
||||
events = await monitor_task
|
||||
|
||||
return events
|
||||
|
||||
events = await run_test()
|
||||
|
||||
# Results
|
||||
print("\n" + "=" * 60)
|
||||
print("Test Results")
|
||||
print("=" * 60)
|
||||
|
||||
if not events:
|
||||
print("✗ FAILED - No SSE events received!")
|
||||
print("\nPossible causes:")
|
||||
print(" 1. SSE endpoint not working")
|
||||
print(" 2. Progress updates not being sent")
|
||||
print(" 3. Model already downloaded (no progress to report)")
|
||||
print("\nTry deleting the model first to force a fresh download:")
|
||||
print(f" curl -X DELETE http://localhost:8000/models/{model_name}")
|
||||
return False
|
||||
|
||||
print(f"✓ Received {len(events)} SSE events")
|
||||
print(f"\nFirst event: {events[0]}")
|
||||
print(f"Last event: {events[-1]}")
|
||||
|
||||
# Check if we got meaningful progress
|
||||
has_progress = any(e.get('progress', 0) > 0 for e in events)
|
||||
has_complete = any(e.get('status') == 'complete' for e in events)
|
||||
|
||||
if has_progress:
|
||||
print("✓ Progress updates received")
|
||||
else:
|
||||
print("✗ No progress updates (might be already downloaded)")
|
||||
|
||||
if has_complete:
|
||||
print("✓ Download completed successfully")
|
||||
else:
|
||||
print("✗ Download did not complete")
|
||||
|
||||
success = has_progress and has_complete
|
||||
|
||||
if success:
|
||||
print("\n✓ TEST PASSED - Progress tracking works!")
|
||||
else:
|
||||
print("\n⊘ TEST INCONCLUSIVE - Try with a fresh download")
|
||||
|
||||
return success
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|