Initial commit

This commit is contained in:
2026-04-24 19:18:15 +08:00
commit fbcbe08696
555 changed files with 96692 additions and 0 deletions

View File

@@ -0,0 +1,220 @@
# End-to-End Model Generation Test — Design
## Goal
A single script, runnable on macOS and Windows, that exercises every TTS model against the **frozen PyInstaller binary** (not the dev server), captures per-model pass/fail and error messages, and exits non-zero if any model fails. Generation is strictly sequential — one model loaded at a time.
## Test matrix (10 runs)
Derived from `backend/backends/__init__.py:185-316`. Each row maps to one `POST /generate` call.
| # | engine | model_size | profile kind | notes |
|---|-----------------------|------------|--------------|-------|
| 1 | `qwen` | `1.7B` | cloned | reference audio required |
| 2 | `qwen` | `0.6B` | cloned | |
| 3 | `qwen_custom_voice` | `1.7B` | preset | `preset_voice_id="Ryan"` |
| 4 | `qwen_custom_voice` | `0.6B` | preset | `preset_voice_id="Ryan"` |
| 5 | `luxtts` | — | cloned | English only |
| 6 | `chatterbox` | — | cloned | |
| 7 | `chatterbox_turbo` | — | cloned | English only |
| 8 | `tada` | `1B` | cloned | tada-1b, English only |
| 9 | `tada` | `3B` | cloned | tada-3b-ml, multilingual |
| 10| `kokoro` | — | preset | `preset_voice_id="af_heart"` |
Cloned engines (1, 2, 5, 6, 7, 8, 9) share **one** profile created once with the reference WAV. Preset profiles are created separately, one for kokoro and one for qwen_custom_voice.
Language for every run: `en` (covers every engine's supported set).
## End-to-end flow
```
1. Resolve paths → find binary, build if missing
2. Launch binary → spawn with --port --data-dir --parent-pid
3. Wait for /health → poll until status=="healthy" or 120s timeout
4. Create profiles → 1 cloned + 2 preset, via /profiles (+ /samples)
5. For each (engine, model_size) in matrix:
a. Check cache → GET /models/status → cached? short timeout : long
b. POST /generate → get generation_id
c. Stream /status → consume SSE until completed/failed/timeout
d. Record result → {engine, model_size, status, duration, error, elapsed}
6. Write results → JSON + Markdown table to ./results/
7. Shutdown binary → SIGTERM, fall back to kill, verify port freed
8. Exit code → 0 if all passed, 1 otherwise
```
## Binary resolution
Search order — **first hit wins**:
| Platform | Path | Build type |
|----------|------|------------|
| macOS | `backend/dist/voicebox-server-cuda/voicebox-server-cuda` | onedir (CUDA, rarely on Mac) |
| macOS | `backend/dist/voicebox-server` | onefile (CPU) |
| Windows | `backend\dist\voicebox-server-cuda\voicebox-server-cuda.exe` | onedir (CUDA) |
| Windows | `backend\dist\voicebox-server.exe` | onefile (CPU) |
If none exist, run `python backend/build_binary.py` and wait for it to finish (can take 5-20 min). Fail with a clear error if the build itself fails. `--skip-build` flag forces "error out if no binary" instead of building.
## Spawn command
Mirrors Tauri's launch in `tauri/src-tauri/src/main.rs:369-388`:
```
<binary> --host 127.0.0.1 --port <free-port> --data-dir <tempdir> --parent-pid <test-pid>
```
- **Port**: bind to `0` first in Python to grab a free port, then pass that number.
- **Data dir**: `tempfile.mkdtemp(prefix="voicebox-e2e-")`. Deleted after the run unless `--keep-data-dir`. Profiles and generated WAVs land here.
- **Parent PID**: current Python PID — ensures the backend dies if the test crashes (watchdog in `server.py:102-224`).
- **stdout/stderr**: tee to both a log file in `./results/server-<timestamp>.log` and a rolling in-memory buffer. On model failure, last 100 lines of the buffer are attached to that model's error record.
## Profile setup
One cloned profile shared across all cloning engines:
```http
POST /profiles
{
"name": "e2e-cloned",
"voice_type": "cloned",
"language": "en"
}
```
Then:
```http
POST /profiles/{id}/samples (multipart)
file: <reference WAV>
reference_text: <exact transcription>
```
Two preset profiles:
```http
POST /profiles
{ "name": "e2e-kokoro", "voice_type": "preset", "language": "en",
"preset_engine": "kokoro", "preset_voice_id": "af_heart" }
POST /profiles
{ "name": "e2e-qwen-cv", "voice_type": "preset", "language": "en",
"preset_engine": "qwen_custom_voice", "preset_voice_id": "Ryan" }
```
## Generation request (per matrix row)
```http
POST /generate
{
"profile_id": "<appropriate profile>",
"text": "The quick brown fox jumps over the lazy dog.",
"language": "en",
"engine": "<engine>",
"model_size": "<size or omitted>",
"seed": 42,
"normalize": true
}
```
Response `id` feeds into the SSE status loop (`GET /generate/{id}/status`, `routes/generations.py:190-227`). Loop reads lines until a payload with `status in ("completed", "failed")` arrives, then breaks.
## Timeout strategy (split)
Check `GET /models/status` for the target model **before** generation:
| Cached? | Per-model timeout | Rationale |
|---------|-------------------|-----------|
| Yes | **3 minutes** | Inference only; generous for CPU builds |
| No | **20 minutes** | First-run HF download up to 8 GB (tada-3b-ml) |
On timeout: cancel the SSE stream, mark the row `timeout`, and continue to the next row. Don't abort the whole run on one timeout.
## Result format
`./results/e2e-<platform>-<arch>-<timestamp>.json`:
```json
{
"platform": "darwin-arm64",
"binary": "/abs/path/voicebox-server",
"binary_size_mb": 612,
"started_at": "2026-04-16T12:34:56Z",
"finished_at": "...",
"results": [
{
"engine": "qwen",
"model_size": "1.7B",
"status": "passed|failed|timeout",
"generation_id": "...",
"was_cached": true,
"elapsed_seconds": 12.4,
"audio_duration": 3.1,
"audio_path": "/tmp/.../gen.wav",
"error": null,
"server_log_tail": null
}
]
}
```
Companion `./results/e2e-<...>.md`:
```
# Voicebox E2E — darwin-arm64 — 2026-04-16 12:34
| Engine | Size | Status | Elapsed | Error |
|---------------------|------|--------|---------|-------|
| qwen | 1.7B | PASS | 12.4s | |
| qwen | 0.6B | FAIL | 4.1s | CUDA OOM: ... |
...
```
## CLI flags
```
python -m backend.tests.test_all_models_e2e [flags]
--binary PATH Use this binary instead of auto-detecting
--skip-build Error if no binary found (no auto-build)
--reference-wav PATH Reference audio (default: backend/tests/fixtures/reference_voice.wav)
--reference-text STR Transcription (default: read from fixtures/reference_voice.txt)
--only ENGINE[,...] Run only these engines (e.g. kokoro,qwen)
--skip ENGINE[,...] Skip these engines
--keep-data-dir Don't delete tempdir after run
--timeout-cached SEC Override 180
--timeout-download SEC Override 1200
--port N Override auto-picked port
--output-dir PATH Default: backend/tests/results/
```
## File layout
```
backend/tests/
├── E2E_MODEL_TEST_DESIGN.md (this file)
├── test_all_models_e2e.py (main script, ~400-500 LoC)
├── fixtures/
│ ├── reference_voice.wav (user-provided, ~5-15s clean speech)
│ └── reference_voice.txt (exact transcription)
└── results/ (gitignored)
├── e2e-darwin-arm64-<ts>.json
├── e2e-darwin-arm64-<ts>.md
└── server-<ts>.log
```
The script uses only stdlib + `httpx` (or `requests`) + `sseclient-py` — all already in `backend/requirements.txt`. No pytest to keep it invocable as a single command on fresh checkouts.
## Safety & cleanup
- Always kill the spawned binary in a `try/finally`. On Windows, `taskkill /F /T` the whole tree (Tauri does the same).
- Verify the port is free on shutdown (Tauri port-reuse check in `main.rs:114-186` could otherwise pick up a ghost).
- Don't touch the user's HF cache by default — let the server use `HF_HUB_CACHE` / `VOICEBOX_MODELS_DIR`. Passing `--isolated-cache` would point both env vars at the tempdir for a true cold-start run (opt-in only; would re-download every time).
## Non-goals
- Not validating audio quality (no WER, no waveform comparison). Pass = "endpoint returned `completed` and produced a non-empty WAV".
- Not testing STT (Whisper), effects chains, channels, or streaming endpoints.
- Not running on CI today — human-invoked on dev machines. CI integration is a follow-up once the script is stable.
- No model unload between runs — models stay loaded; server manages its own eviction.
- No version-drift check on the binary.
- No `instruct` parameter exercised on qwen_custom_voice runs.

58
backend/tests/README.md Normal file
View File

@@ -0,0 +1,58 @@
# Backend Tests
Manual test scripts for debugging and validating backend functionality.
## Test Files
### `test_generation_progress.py`
Tests TTS generation with SSE progress monitoring to identify UX issues where users see download progress even when the model is already cached.
**Usage:**
```bash
cd backend
python tests/test_generation_progress.py
```
**Prerequisites:**
- Server must be running (`python main.py`)
- At least one voice profile must exist
### `test_real_download.py`
Tests real model download with SSE progress monitoring.
**Usage:**
```bash
cd backend
# Delete cache first to force fresh download
rm -rf ~/.cache/huggingface/hub/models--openai--whisper-base
python tests/test_real_download.py
```
**Prerequisites:**
- Server must be running (`python main.py`)
### `test_progress.py`
Unit tests for ProgressManager and HFProgressTracker functionality.
**Usage:**
```bash
cd backend
python tests/test_progress.py
```
### `test_check_progress_state.py`
Debugging script to inspect the internal state of ProgressManager and TaskManager.
**Usage:**
```bash
cd backend
python tests/test_check_progress_state.py
```
## Notes
These are manual test scripts, not automated unit tests. They're designed for:
- Debugging progress tracking issues
- Validating SSE event streams
- Monitoring real-time download behavior
- Inspecting internal state during development

View File

@@ -0,0 +1,6 @@
"""
Test suite for Voicebox backend.
This directory contains manual test scripts for debugging and validating
progress tracking, model downloads, and generation functionality.
"""

16
backend/tests/fixtures/README.md vendored Normal file
View File

@@ -0,0 +1,16 @@
# E2E Test Fixtures
Place two files here before running `test_all_models_e2e.py`:
- `reference_voice.wav` — a clean speech sample, mono, 1624 kHz, ~515 seconds.
- `reference_voice.txt` — the **exact** transcription of the WAV (single line, no trailing newline required).
These are used to create a cloned voice profile for every cloning-capable engine (qwen, luxtts, chatterbox, chatterbox_turbo, tada). Keep them out of version control if they contain personal audio — this directory is not gitignored by default, so add them to `.gitignore` locally if needed.
You can point the test at different files with:
```
python backend/tests/test_all_models_e2e.py \
--reference-wav /path/to/your.wav \
--reference-text "exact transcription here"
```

View File

@@ -0,0 +1,630 @@
"""
End-to-end model generation test.
Exercises every TTS model against the frozen PyInstaller binary, captures
per-model pass/fail, and writes a JSON + Markdown report.
Usage:
python backend/tests/test_all_models_e2e.py [flags]
See E2E_MODEL_TEST_DESIGN.md for the full design.
"""
from __future__ import annotations
import argparse
import json
import os
import platform
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import threading
import time
from collections import deque
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import httpx
REPO_ROOT = Path(__file__).resolve().parents[2]
BACKEND_DIR = REPO_ROOT / "backend"
DIST_DIR = BACKEND_DIR / "dist"
FIXTURES_DIR = Path(__file__).resolve().parent / "fixtures"
RESULTS_DIR = Path(__file__).resolve().parent / "results"
# ── Test matrix ──────────────────────────────────────────────────────
@dataclass(frozen=True)
class MatrixRow:
label: str # human-readable (appears in report)
engine: str # /generate engine
model_size: Optional[str] # /generate model_size (None = omit)
profile_kind: str # "cloned" | "preset_kokoro" | "preset_qwen_cv"
model_name: str # /models/status key for cache lookup
MATRIX: list[MatrixRow] = [
MatrixRow("qwen 1.7B", "qwen", "1.7B", "cloned", "qwen-tts-1.7B"),
MatrixRow("qwen 0.6B", "qwen", "0.6B", "cloned", "qwen-tts-0.6B"),
MatrixRow("qwen_custom_voice 1.7B", "qwen_custom_voice", "1.7B", "preset_qwen_cv", "qwen-custom-voice-1.7B"),
MatrixRow("qwen_custom_voice 0.6B", "qwen_custom_voice", "0.6B", "preset_qwen_cv", "qwen-custom-voice-0.6B"),
MatrixRow("luxtts", "luxtts", None, "cloned", "luxtts"),
MatrixRow("chatterbox", "chatterbox", None, "cloned", "chatterbox-tts"),
MatrixRow("chatterbox_turbo", "chatterbox_turbo", None, "cloned", "chatterbox-turbo"),
MatrixRow("tada 1B", "tada", "1B", "cloned", "tada-1b"),
MatrixRow("tada 3B", "tada", "3B", "cloned", "tada-3b-ml"),
MatrixRow("kokoro", "kokoro", None, "preset_kokoro", "kokoro"),
]
TEXT = "The quick brown fox jumps over the lazy dog."
DEFAULT_TIMEOUT_CACHED = 180
DEFAULT_TIMEOUT_DOWNLOAD = 1200
HEALTH_TIMEOUT = 120
# ── Result record ────────────────────────────────────────────────────
@dataclass
class ModelResult:
label: str
engine: str
model_size: Optional[str]
status: str # "passed" | "failed" | "timeout"
was_cached: Optional[bool] = None
generation_id: Optional[str] = None
elapsed_seconds: float = 0.0
audio_duration: Optional[float] = None
audio_path: Optional[str] = None
audio_bytes: Optional[int] = None
error: Optional[str] = None
http_status: Optional[int] = None
server_log_tail: Optional[list[str]] = None
# ── Binary resolution ────────────────────────────────────────────────
def find_binary() -> Optional[Path]:
"""Return the first existing binary in priority order, or None."""
is_win = platform.system() == "Windows"
exe = ".exe" if is_win else ""
candidates = [
DIST_DIR / "voicebox-server-cuda" / f"voicebox-server-cuda{exe}",
DIST_DIR / f"voicebox-server{exe}",
]
for c in candidates:
if c.exists() and c.is_file():
return c
return None
def build_binary() -> Path:
"""Invoke build_binary.py and return the resulting binary path."""
print("[build] No frozen binary found — invoking build_binary.py (this may take 5-20 minutes)...", flush=True)
script = BACKEND_DIR / "build_binary.py"
result = subprocess.run(
[sys.executable, str(script)],
cwd=str(BACKEND_DIR),
)
if result.returncode != 0:
raise RuntimeError(f"build_binary.py exited with code {result.returncode}")
found = find_binary()
if found is None:
raise RuntimeError("build_binary.py finished but no binary was found in backend/dist/")
return found
# ── Server spawn + log capture ───────────────────────────────────────
class ServerProcess:
def __init__(self, binary: Path, port: int, data_dir: Path, log_path: Path):
self.binary = binary
self.port = port
self.data_dir = data_dir
self.log_path = log_path
self.proc: Optional[subprocess.Popen] = None
self._log_buffer: deque[str] = deque(maxlen=500)
self._reader_thread: Optional[threading.Thread] = None
def start(self) -> None:
args = [
str(self.binary),
"--host", "127.0.0.1",
"--port", str(self.port),
"--data-dir", str(self.data_dir),
"--parent-pid", str(os.getpid()),
]
print(f"[spawn] {' '.join(args)}", flush=True)
self._log_fh = open(self.log_path, "w", encoding="utf-8", errors="replace")
# Combine stderr into stdout so we get a single ordered stream.
self.proc = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
bufsize=1,
text=True,
errors="replace",
)
self._reader_thread = threading.Thread(target=self._pump_logs, daemon=True)
self._reader_thread.start()
def _pump_logs(self) -> None:
assert self.proc is not None and self.proc.stdout is not None
for line in self.proc.stdout:
self._log_buffer.append(line.rstrip("\n"))
self._log_fh.write(line)
self._log_fh.flush()
def log_tail(self, n: int = 100) -> list[str]:
tail = list(self._log_buffer)[-n:]
return tail
def is_alive(self) -> bool:
return self.proc is not None and self.proc.poll() is None
def stop(self) -> None:
if self.proc is None:
return
if self.proc.poll() is not None:
return
try:
if platform.system() == "Windows":
subprocess.run(
["taskkill", "/F", "/T", "/PID", str(self.proc.pid)],
capture_output=True,
)
else:
self.proc.send_signal(signal.SIGTERM)
except Exception as e:
print(f"[shutdown] signal failed: {e}", flush=True)
try:
self.proc.wait(timeout=10)
except subprocess.TimeoutExpired:
print("[shutdown] server didn't exit cleanly, killing", flush=True)
self.proc.kill()
try:
self.proc.wait(timeout=5)
except subprocess.TimeoutExpired:
pass
if self._reader_thread is not None:
self._reader_thread.join(timeout=2)
try:
self._log_fh.close()
except Exception:
pass
def pick_free_port() -> int:
s = socket.socket()
s.bind(("127.0.0.1", 0))
port = s.getsockname()[1]
s.close()
return port
# ── HTTP helpers ─────────────────────────────────────────────────────
def wait_for_health(base_url: str, server: ServerProcess, timeout: int) -> None:
deadline = time.time() + timeout
with httpx.Client(timeout=5.0) as client:
while time.time() < deadline:
if not server.is_alive():
raise RuntimeError("Server process exited before becoming healthy")
try:
r = client.get(f"{base_url}/health")
if r.status_code == 200 and r.json().get("status") == "healthy":
return
except httpx.HTTPError:
pass
time.sleep(1.0)
raise TimeoutError(f"Server did not become healthy within {timeout}s")
def get_model_cached(client: httpx.Client, base_url: str, model_name: str) -> Optional[bool]:
try:
r = client.get(f"{base_url}/models/status", timeout=30.0)
r.raise_for_status()
for m in r.json().get("models", []):
if m.get("model_name") == model_name:
return bool(m.get("downloaded"))
except httpx.HTTPError:
return None
return None
def create_cloned_profile(client: httpx.Client, base_url: str, wav_path: Path, reference_text: str) -> str:
r = client.post(f"{base_url}/profiles", json={
"name": "e2e-cloned",
"voice_type": "cloned",
"language": "en",
})
r.raise_for_status()
profile_id = r.json()["id"]
with open(wav_path, "rb") as f:
r = client.post(
f"{base_url}/profiles/{profile_id}/samples",
files={"file": (wav_path.name, f, "audio/wav")},
data={"reference_text": reference_text},
timeout=120.0,
)
r.raise_for_status()
return profile_id
def create_preset_profile(client: httpx.Client, base_url: str, name: str, engine: str, voice_id: str) -> str:
r = client.post(f"{base_url}/profiles", json={
"name": name,
"voice_type": "preset",
"language": "en",
"preset_engine": engine,
"preset_voice_id": voice_id,
})
r.raise_for_status()
return r.json()["id"]
def run_one_generation(
client: httpx.Client,
base_url: str,
row: MatrixRow,
profile_id: str,
timeout_s: int,
) -> tuple[str, dict]:
"""Start a generation and stream its status until done/failed/timeout.
Returns (status, payload) where status is "completed" | "failed" | "timeout".
"""
body = {
"profile_id": profile_id,
"text": TEXT,
"language": "en",
"engine": row.engine,
"seed": 42,
"normalize": True,
}
if row.model_size is not None:
body["model_size"] = row.model_size
r = client.post(f"{base_url}/generate", json=body, timeout=30.0)
r.raise_for_status()
gen = r.json()
gen_id = gen["id"]
deadline = time.time() + timeout_s
last_payload: dict = gen
status_url = f"{base_url}/generate/{gen_id}/status"
while time.time() < deadline:
remaining = max(1.0, deadline - time.time())
try:
with client.stream("GET", status_url, timeout=httpx.Timeout(remaining + 5, read=remaining + 5)) as resp:
resp.raise_for_status()
for line in resp.iter_lines():
if not line or not line.startswith("data: "):
continue
try:
payload = json.loads(line[6:])
except json.JSONDecodeError:
continue
last_payload = payload
status = payload.get("status")
if status == "not_found":
return "failed", {"error": "generation not found", **payload}
if status in ("completed", "failed"):
return status, payload
if time.time() >= deadline:
break
except httpx.HTTPError:
time.sleep(1.0)
continue
return "timeout", last_payload
def fetch_audio_info(
client: httpx.Client, base_url: str, generation_id: str, data_dir: Path
) -> tuple[Optional[str], Optional[int]]:
"""Return (audio_path, audio_bytes) for a completed generation.
Server stores audio_path relative to data_dir; resolve it to get a size.
"""
try:
r = client.get(f"{base_url}/history/{generation_id}", timeout=10.0)
if r.status_code != 200:
return None, None
data = r.json()
audio_path = data.get("audio_path")
if not audio_path:
return None, None
p = Path(audio_path)
if not p.is_absolute():
p = data_dir / p
if p.exists():
return str(p), p.stat().st_size
return audio_path, None
except httpx.HTTPError:
return None, None
# ── Report writers ───────────────────────────────────────────────────
def write_reports(
output_dir: Path,
binary: Path,
started_at: datetime,
finished_at: datetime,
results: list[ModelResult],
) -> tuple[Path, Path]:
output_dir.mkdir(parents=True, exist_ok=True)
plat = f"{platform.system().lower()}-{platform.machine().lower()}"
ts = started_at.strftime("%Y%m%d-%H%M%S")
json_path = output_dir / f"e2e-{plat}-{ts}.json"
md_path = output_dir / f"e2e-{plat}-{ts}.md"
doc = {
"platform": plat,
"binary": str(binary),
"binary_size_mb": round(binary.stat().st_size / (1024 * 1024), 1) if binary.exists() else None,
"started_at": started_at.isoformat(),
"finished_at": finished_at.isoformat(),
"elapsed_seconds": (finished_at - started_at).total_seconds(),
"results": [asdict(r) for r in results],
}
json_path.write_text(json.dumps(doc, indent=2))
lines = [
f"# Voicebox E2E — {plat}{started_at.strftime('%Y-%m-%d %H:%M UTC')}",
"",
f"Binary: `{binary}` ",
f"Elapsed: {doc['elapsed_seconds']:.1f}s",
"",
"| Model | Status | Cached | Elapsed | Audio | Error |",
"|-------|--------|--------|---------|-------|-------|",
]
for r in results:
status_icon = {"passed": "PASS", "failed": "FAIL", "timeout": "TIMEOUT"}.get(r.status, r.status.upper())
cached = "yes" if r.was_cached else ("no" if r.was_cached is False else "?")
audio_col = f"{r.audio_duration:.2f}s" if r.audio_duration else ("" if r.status != "passed" else "?")
error_col = (r.error or "").replace("\n", " ")[:120]
lines.append(f"| {r.label} | {status_icon} | {cached} | {r.elapsed_seconds:.1f}s | {audio_col} | {error_col} |")
failed_rows = [r for r in results if r.status != "passed"]
if failed_rows:
lines.append("")
lines.append("## Failures")
for r in failed_rows:
lines.append("")
lines.append(f"### {r.label}{r.status}")
if r.error:
lines.append("")
lines.append("```")
lines.append(r.error)
lines.append("```")
if r.server_log_tail:
lines.append("")
lines.append("<details><summary>server log (last lines)</summary>")
lines.append("")
lines.append("```")
lines.extend(r.server_log_tail)
lines.append("```")
lines.append("</details>")
md_path.write_text("\n".join(lines) + "\n")
return json_path, md_path
# ── Main ─────────────────────────────────────────────────────────────
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Voicebox E2E model generation test")
p.add_argument("--binary", type=Path, help="Path to voicebox-server binary (overrides auto-detect)")
p.add_argument("--skip-build", action="store_true", help="Error if binary missing instead of building")
p.add_argument(
"--reference-wav",
type=Path,
default=FIXTURES_DIR / "reference_voice.wav",
help="Reference audio for cloning engines",
)
p.add_argument(
"--reference-text",
help="Transcription of reference-wav (default: read from fixtures/reference_voice.txt)",
)
p.add_argument("--only", help="Comma-separated engines to run (e.g. kokoro,qwen)")
p.add_argument("--skip", help="Comma-separated engines to skip")
p.add_argument("--keep-data-dir", action="store_true", help="Don't delete tempdir after run")
p.add_argument("--timeout-cached", type=int, default=DEFAULT_TIMEOUT_CACHED)
p.add_argument("--timeout-download", type=int, default=DEFAULT_TIMEOUT_DOWNLOAD)
p.add_argument("--port", type=int, help="Override auto-picked port")
p.add_argument("--output-dir", type=Path, default=RESULTS_DIR)
return p.parse_args()
def filter_matrix(args: argparse.Namespace) -> list[MatrixRow]:
only = set(x.strip() for x in args.only.split(",")) if args.only else None
skip = set(x.strip() for x in args.skip.split(",")) if args.skip else set()
rows = []
for r in MATRIX:
if only is not None and r.engine not in only:
continue
if r.engine in skip:
continue
rows.append(r)
return rows
def resolve_reference(args: argparse.Namespace) -> tuple[Path, str]:
wav = args.reference_wav
if not wav.exists():
raise FileNotFoundError(
f"Reference WAV not found: {wav}\n"
f"Place a sample at {FIXTURES_DIR / 'reference_voice.wav'} or pass --reference-wav.\n"
f"See backend/tests/fixtures/README.md."
)
if args.reference_text:
text = args.reference_text
else:
txt_path = wav.with_suffix(".txt")
if not txt_path.exists():
raise FileNotFoundError(
f"Reference transcription not found: {txt_path}\n"
f"Create it next to the WAV, or pass --reference-text."
)
text = txt_path.read_text().strip()
if not text:
raise ValueError("Reference transcription is empty")
return wav, text
def main() -> int:
args = parse_args()
rows = filter_matrix(args)
if not rows:
print("No rows selected after --only/--skip filtering", file=sys.stderr)
return 2
# Binary
binary = args.binary or find_binary()
if binary is None:
if args.skip_build:
print("No frozen binary found and --skip-build set. Run: python backend/build_binary.py", file=sys.stderr)
return 2
binary = build_binary()
if not binary.exists():
print(f"Binary path does not exist: {binary}", file=sys.stderr)
return 2
print(f"[binary] {binary}", flush=True)
# Reference audio (only required if any cloning row is in the matrix)
needs_reference = any(r.profile_kind == "cloned" for r in rows)
ref_wav: Optional[Path] = None
ref_text: Optional[str] = None
if needs_reference:
try:
ref_wav, ref_text = resolve_reference(args)
except (FileNotFoundError, ValueError) as e:
print(f"[fixture] {e}", file=sys.stderr)
return 2
print(f"[fixture] reference WAV: {ref_wav}", flush=True)
print(f"[fixture] reference text: {ref_text!r}", flush=True)
# Tempdir + log path
data_dir = Path(tempfile.mkdtemp(prefix="voicebox-e2e-"))
args.output_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
log_path = args.output_dir / f"server-{ts}.log"
port = args.port or pick_free_port()
base_url = f"http://127.0.0.1:{port}"
server = ServerProcess(binary=binary, port=port, data_dir=data_dir, log_path=log_path)
started_at = datetime.now(timezone.utc)
results: list[ModelResult] = []
try:
server.start()
print(f"[health] waiting for {base_url}/health ...", flush=True)
wait_for_health(base_url, server, HEALTH_TIMEOUT)
print("[health] ready", flush=True)
with httpx.Client(timeout=30.0) as client:
# Profile setup (only create what's needed)
cloned_profile_id: Optional[str] = None
kokoro_profile_id: Optional[str] = None
qwen_cv_profile_id: Optional[str] = None
needed_kinds = {r.profile_kind for r in rows}
if "cloned" in needed_kinds:
assert ref_wav is not None and ref_text is not None
print("[profile] creating cloned profile...", flush=True)
cloned_profile_id = create_cloned_profile(client, base_url, ref_wav, ref_text)
if "preset_kokoro" in needed_kinds:
print("[profile] creating kokoro preset...", flush=True)
kokoro_profile_id = create_preset_profile(client, base_url, "e2e-kokoro", "kokoro", "af_heart")
if "preset_qwen_cv" in needed_kinds:
print("[profile] creating qwen_custom_voice preset...", flush=True)
qwen_cv_profile_id = create_preset_profile(client, base_url, "e2e-qwen-cv", "qwen_custom_voice", "Ryan")
profile_lookup = {
"cloned": cloned_profile_id,
"preset_kokoro": kokoro_profile_id,
"preset_qwen_cv": qwen_cv_profile_id,
}
# Matrix loop
for row in rows:
print(f"\n[run] {row.label} (engine={row.engine}, size={row.model_size})", flush=True)
profile_id = profile_lookup[row.profile_kind]
assert profile_id is not None
was_cached = get_model_cached(client, base_url, row.model_name)
timeout_s = args.timeout_cached if was_cached else args.timeout_download
print(f"[run] cached={was_cached} timeout={timeout_s}s", flush=True)
t0 = time.time()
result = ModelResult(
label=row.label,
engine=row.engine,
model_size=row.model_size,
status="failed",
was_cached=was_cached,
)
try:
status, payload = run_one_generation(client, base_url, row, profile_id, timeout_s)
result.status = "passed" if status == "completed" else status
result.generation_id = payload.get("id")
result.audio_duration = payload.get("duration")
result.error = payload.get("error")
if status == "completed" and result.generation_id:
audio_path, audio_bytes = fetch_audio_info(
client, base_url, result.generation_id, data_dir
)
result.audio_path = audio_path
result.audio_bytes = audio_bytes
if audio_bytes is not None and audio_bytes == 0:
result.status = "failed"
result.error = (result.error or "") + " (audio file is empty)"
except httpx.HTTPStatusError as e:
result.status = "failed"
result.http_status = e.response.status_code
try:
detail = e.response.json().get("detail")
except Exception:
detail = e.response.text
result.error = f"HTTP {e.response.status_code}: {detail}"
except Exception as e:
result.status = "failed"
result.error = f"{type(e).__name__}: {e}"
result.elapsed_seconds = round(time.time() - t0, 2)
if result.status != "passed":
result.server_log_tail = server.log_tail(100)
print(f"[run] {row.label}{result.status} in {result.elapsed_seconds}s"
+ (f" ({result.error})" if result.error else ""), flush=True)
results.append(result)
finally:
finished_at = datetime.now(timezone.utc)
server.stop()
if not args.keep_data_dir:
shutil.rmtree(data_dir, ignore_errors=True)
else:
print(f"[cleanup] keeping data dir: {data_dir}", flush=True)
json_path, md_path = write_reports(args.output_dir, binary, started_at, finished_at, results)
print(f"\n[report] {json_path}")
print(f"[report] {md_path}")
print(f"[report] server log: {log_path}")
passed = sum(1 for r in results if r.status == "passed")
failed = len(results) - passed
print(f"\n== {passed} passed, {failed} failed ==")
return 0 if failed == 0 else 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,112 @@
"""
Unit tests for reference-audio preprocessing.
Covers :func:`backend.utils.audio.preprocess_reference_audio` and
:func:`backend.utils.audio.validate_and_load_reference_audio`.
"""
import sys
from pathlib import Path
import numpy as np
import pytest
import soundfile as sf
sys.path.insert(0, str(Path(__file__).parent.parent))
from utils.audio import ( # noqa: E402
preprocess_reference_audio,
validate_and_load_reference_audio,
)
SR = 24000
def _tone(duration_s: float, amp: float = 0.3, freq: float = 220.0) -> np.ndarray:
n = int(duration_s * SR)
t = np.arange(n, dtype=np.float32) / SR
return (amp * np.sin(2 * np.pi * freq * t)).astype(np.float32)
def test_peak_cap_scales_hot_input():
audio = _tone(3.0, amp=0.99)
out = preprocess_reference_audio(audio, SR)
assert np.abs(out).max() <= 0.951
def test_peak_cap_leaves_moderate_input_untouched():
audio = _tone(3.0, amp=0.5)
out = preprocess_reference_audio(audio, SR)
assert np.isclose(np.abs(out).max(), 0.5, atol=1e-3)
def test_dc_offset_removed():
audio = _tone(3.0, amp=0.3) + 0.1
out = preprocess_reference_audio(audio, SR)
assert abs(float(np.mean(out))) < 1e-3
def test_silence_is_trimmed_with_padding_kept():
silence = np.zeros(int(SR * 1.0), dtype=np.float32)
speech = _tone(3.0, amp=0.3)
audio = np.concatenate([silence, speech, silence])
out = preprocess_reference_audio(audio, SR)
# Most of the 2s of leading/trailing silence should be gone, but the
# 3s of speech plus ~200ms of padding should remain.
assert len(audio) - len(out) >= SR, "expected >=1s of silence trimmed"
assert len(out) >= int(3.0 * SR), "speech body should be preserved"
def test_clean_audio_is_not_padded_past_original_length():
# Well-recorded audio with no edge silence shouldn't get longer after
# preprocessing — otherwise a 29.9 s upload could be pushed past the
# 30 s max_duration ceiling downstream.
audio = _tone(3.0, amp=0.3)
out = preprocess_reference_audio(audio, SR)
assert len(out) <= len(audio)
def test_empty_input_returns_empty():
out = preprocess_reference_audio(np.zeros(0, dtype=np.float32), SR)
assert out.size == 0
def test_validate_accepts_previously_rejected_hot_file(tmp_path):
audio = _tone(3.0, amp=0.995)
path = tmp_path / "hot.wav"
sf.write(str(path), audio, SR)
ok, err, out_audio, out_sr = validate_and_load_reference_audio(str(path))
assert ok, f"expected pass, got error: {err}"
assert out_audio is not None
assert out_sr == SR
assert np.abs(out_audio).max() <= 0.951
def test_validate_still_rejects_silent_input(tmp_path):
audio = np.zeros(int(SR * 3.0), dtype=np.float32)
path = tmp_path / "silent.wav"
sf.write(str(path), audio, SR)
ok, err, _, _ = validate_and_load_reference_audio(str(path))
assert not ok
assert err is not None
assert "too short" in err.lower() or "quiet" in err.lower()
def test_validate_rejects_too_short(tmp_path):
audio = _tone(0.5, amp=0.3)
path = tmp_path / "short.wav"
sf.write(str(path), audio, SR)
ok, err, _, _ = validate_and_load_reference_audio(str(path))
assert not ok
assert "too short" in (err or "").lower()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

162
backend/tests/test_cors.py Normal file
View File

@@ -0,0 +1,162 @@
"""
Tests for CORS origin restrictions.
Validates that the CORS middleware only allows known local origins
and respects the VOICEBOX_CORS_ORIGINS environment variable.
Uses a minimal FastAPI app that mirrors the exact CORS configuration
from backend/main.py, so tests run without heavy ML dependencies.
Usage:
pip install httpx pytest fastapi starlette
python -m pytest backend/tests/test_cors.py -v
"""
import os
import pytest
from unittest.mock import patch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.testclient import TestClient
def _build_app(env_origins: str = "") -> FastAPI:
"""
Build a minimal FastAPI app with the same CORS logic as backend/main.py.
This mirrors the exact code in main.py so the test validates the real
configuration without needing torch/numpy/transformers installed.
"""
app = FastAPI()
_default_origins = [
"http://localhost:5173",
"http://127.0.0.1:5173",
"http://localhost:17493",
"http://127.0.0.1:17493",
"tauri://localhost",
"https://tauri.localhost",
]
_cors_origins = _default_origins + [o.strip() for o in env_origins.split(",") if o.strip()]
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health():
return {"status": "ok"}
return app
@pytest.fixture()
def client():
return TestClient(_build_app())
@pytest.fixture()
def client_with_custom_origins():
return TestClient(_build_app("https://custom.example.com,https://other.example.com"))
def _get_with_origin(client: TestClient, origin: str) -> dict:
"""Send a GET with Origin header, return response headers."""
response = client.get("/health", headers={"Origin": origin})
return dict(response.headers)
def _preflight(client: TestClient, origin: str) -> dict:
"""Send CORS preflight OPTIONS request, return response headers."""
response = client.options(
"/health",
headers={
"Origin": origin,
"Access-Control-Request-Method": "GET",
},
)
return dict(response.headers)
class TestCORSDefaultOrigins:
"""CORS should allow known local origins and block everything else."""
@pytest.mark.parametrize("origin", [
"http://localhost:5173",
"http://127.0.0.1:5173",
"http://localhost:17493",
"http://127.0.0.1:17493",
"tauri://localhost",
"https://tauri.localhost",
])
def test_allowed_origins(self, client, origin):
headers = _get_with_origin(client, origin)
assert headers.get("access-control-allow-origin") == origin
@pytest.mark.parametrize("origin", [
"http://evil.com",
"http://localhost:9999",
"https://attacker.example.com",
"null",
])
def test_blocked_origins(self, client, origin):
headers = _get_with_origin(client, origin)
assert "access-control-allow-origin" not in headers
def test_preflight_allowed(self, client):
headers = _preflight(client, "http://localhost:5173")
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
def test_preflight_blocked(self, client):
headers = _preflight(client, "http://evil.com")
assert "access-control-allow-origin" not in headers
def test_credentials_header_present(self, client):
headers = _get_with_origin(client, "http://localhost:5173")
assert headers.get("access-control-allow-credentials") == "true"
class TestCORSCustomOrigins:
"""VOICEBOX_CORS_ORIGINS env var should extend the allowlist."""
def test_custom_origin_allowed(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "https://custom.example.com")
assert headers.get("access-control-allow-origin") == "https://custom.example.com"
def test_other_custom_origin_allowed(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "https://other.example.com")
assert headers.get("access-control-allow-origin") == "https://other.example.com"
def test_default_origins_still_work(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "http://localhost:5173")
assert headers.get("access-control-allow-origin") == "http://localhost:5173"
def test_unlisted_origin_still_blocked(self, client_with_custom_origins):
headers = _get_with_origin(client_with_custom_origins, "http://evil.com")
assert "access-control-allow-origin" not in headers
class TestCORSEnvVarParsing:
"""Edge cases for VOICEBOX_CORS_ORIGINS parsing."""
def test_empty_env_var(self):
app = _build_app("")
client = TestClient(app)
headers = _get_with_origin(client, "http://evil.com")
assert "access-control-allow-origin" not in headers
def test_whitespace_trimmed(self):
app = _build_app(" https://spaced.example.com ")
client = TestClient(app)
headers = _get_with_origin(client, "https://spaced.example.com")
assert headers.get("access-control-allow-origin") == "https://spaced.example.com"
def test_trailing_comma_ignored(self):
app = _build_app("https://one.example.com,")
client = TestClient(app)
headers = _get_with_origin(client, "https://one.example.com")
assert headers.get("access-control-allow-origin") == "https://one.example.com"

View File

@@ -0,0 +1,305 @@
"""
Test TTS generation with SSE progress monitoring.
This test captures the exact SSE events triggered during generation
to identify UX issues where users see download progress even when
the model is already cached.
"""
import asyncio
import json
import httpx
from typing import List, Dict, Optional
from datetime import datetime
async def monitor_sse_stream(model_name: str, timeout: int = 120):
"""Monitor SSE stream for a model during generation."""
events: List[Dict] = []
url = f"http://localhost:8000/models/progress/{model_name}"
print(f"[{_timestamp()}] Connecting to SSE endpoint: {url}")
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", url) as response:
print(f"[{_timestamp()}] SSE connected, status: {response.status_code}")
if response.status_code != 200:
print(f"[{_timestamp()}] Error: SSE endpoint returned {response.status_code}")
return events
async for line in response.aiter_lines():
if not line:
continue
timestamp = _timestamp()
if line.startswith("data: "):
try:
data = json.loads(line[6:])
print(
f"[{timestamp}] → SSE Event: {data['status']:12} {data.get('progress', 0):6.1f}% {data.get('filename', '')}"
)
events.append({**data, "_timestamp": timestamp})
# Stop if complete or error
if data.get("status") in ("complete", "error"):
print(f"[{timestamp}] → Model {data['status']}!")
break
except json.JSONDecodeError as e:
print(f"[{timestamp}] Error parsing JSON: {e}")
print(f" Line was: {line}")
elif line.startswith(": heartbeat"):
print(f"[{timestamp}] ♥ heartbeat")
except asyncio.TimeoutError:
print(f"[{_timestamp()}] SSE monitoring timed out")
except Exception as e:
print(f"[{_timestamp()}] SSE error: {e}")
return events
async def trigger_generation(profile_id: str, text: str, model_size: str = "1.7B"):
"""Trigger TTS generation via the API."""
url = "http://localhost:8000/generate"
print(f"\n[{_timestamp()}] Triggering generation...")
print(f" Profile: {profile_id}")
print(f" Text: {text[:50]}...")
print(f" Model: {model_size}")
try:
async with httpx.AsyncClient(timeout=120) as client:
response = await client.post(
url,
json={
"profile_id": profile_id,
"text": text,
"language": "en",
"model_size": model_size,
},
)
print(f"[{_timestamp()}] Response: {response.status_code}")
if response.status_code == 200:
result = response.json()
print(f"[{_timestamp()}] ✓ Generation successful!")
print(f" Generation ID: {result.get('id')}")
print(f" Duration: {result.get('duration', 0):.2f}s")
return True, result
elif response.status_code == 202:
# Model is being downloaded
result = response.json()
print(f"[{_timestamp()}] → Model download in progress")
print(f" Detail: {result}")
return False, result
else:
print(f"[{_timestamp()}] ✗ Error: {response.text}")
return False, None
except Exception as e:
print(f"[{_timestamp()}] ✗ Exception: {e}")
return False, None
async def get_first_profile():
"""Get the first available voice profile."""
url = "http://localhost:8000/profiles"
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url)
if response.status_code == 200:
profiles = response.json()
if profiles:
return profiles[0]["id"]
except Exception as e:
print(f"Error getting profiles: {e}")
return None
async def check_server():
"""Check if the server is running."""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get("http://localhost:8000/health")
return response.status_code == 200
except Exception as e:
print(f"Server not running: {e}")
return False
def _timestamp():
"""Get current timestamp for logging."""
return datetime.now().strftime("%H:%M:%S.%f")[:-3]
async def test_generation_with_cached_model():
"""
Test Case 1: Generation when model is already cached.
This should NOT show any download progress events.
If it does, that's the UX bug we're trying to fix.
"""
print("\n" + "=" * 80)
print("TEST CASE 1: Generation with Cached Model")
print("=" * 80)
print("Expected: No download progress events (or minimal/instant completion)")
print("Actual UX Issue: Users see 'started' and 'finished' events even for cached models")
print("=" * 80)
model_size = "1.7B"
model_name = f"qwen-tts-{model_size}"
# Get a profile
profile_id = await get_first_profile()
if not profile_id:
print("✗ No voice profiles found. Please create a profile first.")
return False
print(f"\nUsing profile: {profile_id}")
# Start SSE monitor BEFORE triggering generation
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=30))
# Wait for SSE to connect
await asyncio.sleep(1)
# Trigger generation
test_text = "Hello, this is a test of the voice generation system."
success, result = await trigger_generation(profile_id, test_text, model_size)
if not success and result and result.get("downloading"):
print("\n⚠ Model is being downloaded. Waiting for download to complete...")
# Wait for SSE monitor to capture download events
events = await monitor_task
return events
# Wait a bit more to catch any progress events
await asyncio.sleep(3)
# Cancel SSE monitor
monitor_task.cancel()
try:
events = await monitor_task
except asyncio.CancelledError:
events = []
return events
async def test_generation_with_fresh_download():
"""
Test Case 2: Generation when model needs to be downloaded.
This SHOULD show download progress events.
"""
print("\n" + "=" * 80)
print("TEST CASE 2: Generation with Model Download")
print("=" * 80)
print("Expected: Download progress events from 0% to 100%")
print("=" * 80)
# Use a different model size to force download
model_size = "0.6B" # Smaller model for faster testing
model_name = f"qwen-tts-{model_size}"
# Get a profile
profile_id = await get_first_profile()
if not profile_id:
print("✗ No voice profiles found. Please create a profile first.")
return False
print(f"\nUsing profile: {profile_id}")
print("Note: This will download the model if not cached")
# Start SSE monitor BEFORE triggering generation
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=300))
# Wait for SSE to connect
await asyncio.sleep(1)
# Trigger generation
test_text = "This should trigger a model download if the model is not cached."
success, result = await trigger_generation(profile_id, test_text, model_size)
if not success and result and result.get("downloading"):
print("\n→ Model download initiated. Monitoring progress...")
# Wait for download to complete
events = await monitor_task
# Try generation again
print(f"\n[{_timestamp()}] Retrying generation after download...")
await asyncio.sleep(2)
success, result = await trigger_generation(profile_id, test_text, model_size)
if success:
print("✓ Generation successful after download")
return events
# If model was already cached
await asyncio.sleep(3)
monitor_task.cancel()
try:
events = await monitor_task
except asyncio.CancelledError:
events = []
return events
async def main():
print("=" * 80)
print("TTS Generation Progress Test")
print("=" * 80)
print("Purpose: Capture exact SSE events during generation to identify UX issues")
print("=" * 80)
# Check if server is running
print(f"\n[{_timestamp()}] Checking if server is running...")
if not await check_server():
print("✗ Server is not running on http://localhost:8000")
print("\nPlease start the server first:")
print(" cd backend && python main.py")
return False
print("✓ Server is running")
# Test Case 1: Cached model
print("\n" + "🧪 " * 20)
events_cached = await test_generation_with_cached_model()
# Results for Test Case 1
print("\n" + "=" * 80)
print("TEST CASE 1 RESULTS: Generation with Cached Model")
print("=" * 80)
if not events_cached:
print("✓ GOOD: No SSE progress events received")
print(" This is the expected behavior for a cached model.")
else:
print(f"⚠ ISSUE FOUND: Received {len(events_cached)} SSE events:")
print("\nEvent Timeline:")
for i, event in enumerate(events_cached, 1):
timestamp = event.pop("_timestamp", "??:??:??.???")
print(f" {i}. [{timestamp}] {event}")
print("\n⚠ This explains the UX issue!")
print(" Users see progress events even when the model is already cached,")
print(" making them think the model is downloading again.")
print("\n" + "=" * 80)
print("Test Complete!")
print("=" * 80)
return True
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,118 @@
"""
Unit tests for the ``force_offline_if_cached`` helper.
Verifies that the helper mutates the cached module constants in
``huggingface_hub.constants`` and ``transformers.utils.hub`` — not just
``os.environ`` — and that concurrent users are refcount-coordinated so
one thread's exit can't strip another thread's offline protection.
NOTE: These tests mutate process-global state in ``huggingface_hub.constants``
and ``transformers.utils.hub``. They are not safe under cross-process
parallelism (e.g. ``pytest-xdist`` with ``--dist=loadfile``/``loadscope``);
run this file serially.
"""
import os
import sys
import threading
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from utils.hf_offline_patch import force_offline_if_cached # noqa: E402
def _hf_const():
import huggingface_hub.constants as hf_const
return hf_const
def _tf_hub():
import transformers.utils.hub as tf_hub
return tf_hub
def test_mutates_cached_huggingface_hub_constant():
original = _hf_const().HF_HUB_OFFLINE
with force_offline_if_cached(True, "t"):
assert _hf_const().HF_HUB_OFFLINE is True
assert original == _hf_const().HF_HUB_OFFLINE
def test_mutates_cached_transformers_constant():
original = _tf_hub()._is_offline_mode
with force_offline_if_cached(True, "t"):
assert _tf_hub()._is_offline_mode is True
assert original == _tf_hub()._is_offline_mode
def test_sets_env_variable():
original = os.environ.get("HF_HUB_OFFLINE")
with force_offline_if_cached(True, "t"):
assert "1" == os.environ.get("HF_HUB_OFFLINE")
assert original == os.environ.get("HF_HUB_OFFLINE")
def test_noop_when_not_cached():
before = _hf_const().HF_HUB_OFFLINE
with force_offline_if_cached(False, "t"):
assert before == _hf_const().HF_HUB_OFFLINE
def test_nested_contexts_respect_refcount():
original = _hf_const().HF_HUB_OFFLINE
with force_offline_if_cached(True, "outer"):
assert _hf_const().HF_HUB_OFFLINE is True
with force_offline_if_cached(True, "inner"):
assert _hf_const().HF_HUB_OFFLINE is True
# inner exit must not restore while outer is still active
assert _hf_const().HF_HUB_OFFLINE is True
assert original == _hf_const().HF_HUB_OFFLINE
def test_concurrent_threads_share_offline_window():
"""A slow thread must keep seeing offline mode even if a peer exits first."""
original = _hf_const().HF_HUB_OFFLINE
observations: list[bool] = []
errors: list[Exception] = []
barrier = threading.Barrier(2)
fast_exited = threading.Event()
def slow():
try:
with force_offline_if_cached(True, "slow"):
barrier.wait(timeout=5)
assert fast_exited.wait(timeout=5), "fast thread did not exit"
observations.append(_hf_const().HF_HUB_OFFLINE)
except Exception as exc: # noqa: BLE001
errors.append(exc)
def fast():
try:
with force_offline_if_cached(True, "fast"):
barrier.wait(timeout=5)
except Exception as exc: # noqa: BLE001
errors.append(exc)
finally:
fast_exited.set()
t_slow = threading.Thread(target=slow)
t_fast = threading.Thread(target=fast)
t_slow.start()
t_fast.start()
t_slow.join(timeout=5)
t_fast.join(timeout=5)
assert not t_slow.is_alive(), "slow thread did not finish"
assert not t_fast.is_alive(), "fast thread did not finish"
assert not errors, errors
assert observations == [True], "slow thread lost offline protection"
assert original == _hf_const().HF_HUB_OFFLINE
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,113 @@
"""
Unit tests for ``patch_transformers_mistral_regex``.
Verifies that our wrapper around
``transformers.PreTrainedTokenizerBase._patch_mistral_regex`` catches
exceptions from the unconditional ``huggingface_hub.model_info()`` lookup
and returns the tokenizer unchanged — matching the success-path behavior
for non-Mistral repos (transformers 4.57.3, ``tokenization_utils_base.py:2503``).
NOTE: These tests mutate ``transformers.PreTrainedTokenizerBase`` globally;
run serially, not under ``pytest-xdist`` with per-worker process isolation.
"""
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from huggingface_hub.errors import OfflineModeIsEnabled # noqa: E402
from transformers.tokenization_utils_base import PreTrainedTokenizerBase # noqa: E402
import utils.hf_offline_patch as hf_offline_patch # noqa: E402
@pytest.fixture(autouse=True)
def restore_mistral_regex():
"""Snapshot the current ``_patch_mistral_regex`` and restore after each test."""
saved = PreTrainedTokenizerBase.__dict__.get("_patch_mistral_regex")
saved_flag = hf_offline_patch._mistral_regex_patched
try:
yield
finally:
if saved is not None:
PreTrainedTokenizerBase._patch_mistral_regex = saved
hf_offline_patch._mistral_regex_patched = saved_flag
def _apply_patch():
hf_offline_patch._mistral_regex_patched = False
hf_offline_patch.patch_transformers_mistral_regex()
def test_suppresses_offline_mode_is_enabled(monkeypatch):
_apply_patch()
import huggingface_hub
def raise_offline(*_args, **_kwargs):
raise OfflineModeIsEnabled("offline")
monkeypatch.setattr(huggingface_hub, "model_info", raise_offline)
sentinel = object()
result = PreTrainedTokenizerBase._patch_mistral_regex(
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
)
assert result is sentinel
def test_suppresses_connection_errors(monkeypatch):
_apply_patch()
import huggingface_hub
def raise_connection(*_args, **_kwargs):
raise ConnectionError("network unreachable")
monkeypatch.setattr(huggingface_hub, "model_info", raise_connection)
sentinel = object()
result = PreTrainedTokenizerBase._patch_mistral_regex(
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
)
assert result is sentinel
def test_passthrough_on_success(monkeypatch):
"""When model_info returns non-Mistral tags the original falls through and returns the tokenizer unchanged."""
_apply_patch()
import huggingface_hub
class FakeInfo:
tags = ["model-type:qwen", "language:en"]
monkeypatch.setattr(huggingface_hub, "model_info", lambda *_a, **_kw: FakeInfo())
sentinel = object()
result = PreTrainedTokenizerBase._patch_mistral_regex(
sentinel, "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
)
assert result is sentinel
def test_idempotent():
_apply_patch()
first = PreTrainedTokenizerBase._patch_mistral_regex
hf_offline_patch.patch_transformers_mistral_regex()
second = PreTrainedTokenizerBase._patch_mistral_regex
assert first.__func__ is second.__func__
def test_missing_method_is_noop(monkeypatch):
monkeypatch.delattr(PreTrainedTokenizerBase, "_patch_mistral_regex", raising=False)
hf_offline_patch._mistral_regex_patched = False
hf_offline_patch.patch_transformers_mistral_regex()
assert hf_offline_patch._mistral_regex_patched is False
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,217 @@
"""
Tests for profile duplicate name validation.
This test suite verifies that the application correctly handles
duplicate profile names and provides user-friendly error messages.
"""
import pytest
import tempfile
import shutil
from pathlib import Path
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Add parent directory to path to import backend modules
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from database import Base, VoiceProfile as DBVoiceProfile
from models import VoiceProfileCreate
from profiles import create_profile, update_profile
@pytest.fixture
def test_db():
"""Create a temporary test database."""
# Create temporary directory for test database
temp_dir = tempfile.mkdtemp()
db_path = Path(temp_dir) / "test.db"
# Create engine and session
engine = create_engine(f"sqlite:///{db_path}")
Base.metadata.create_all(bind=engine)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()
yield db
# Cleanup
db.close()
shutil.rmtree(temp_dir)
@pytest.fixture
def mock_profiles_dir(monkeypatch, tmp_path):
"""Mock the profiles directory to use a temporary path."""
from backend import config
monkeypatch.setattr(config, 'get_profiles_dir', lambda: tmp_path)
return tmp_path
@pytest.mark.asyncio
async def test_create_profile_duplicate_name_raises_error(test_db, mock_profiles_dir):
"""Test that creating a profile with a duplicate name raises a ValueError."""
# Create first profile
profile_data_1 = VoiceProfileCreate(
name="Test Profile",
description="First profile",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
assert profile_1.name == "Test Profile"
# Try to create second profile with same name
profile_data_2 = VoiceProfileCreate(
name="Test Profile",
description="Second profile",
language="en"
)
with pytest.raises(ValueError) as exc_info:
await create_profile(profile_data_2, test_db)
# Verify error message is user-friendly
assert "already exists" in str(exc_info.value)
assert "Test Profile" in str(exc_info.value)
assert "choose a different name" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_create_profile_different_names_succeeds(test_db, mock_profiles_dir):
"""Test that creating profiles with different names succeeds."""
# Create first profile
profile_data_1 = VoiceProfileCreate(
name="Profile One",
description="First profile",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
assert profile_1.name == "Profile One"
# Create second profile with different name
profile_data_2 = VoiceProfileCreate(
name="Profile Two",
description="Second profile",
language="en"
)
profile_2 = await create_profile(profile_data_2, test_db)
assert profile_2.name == "Profile Two"
# Verify both profiles exist
assert profile_1.id != profile_2.id
@pytest.mark.asyncio
async def test_update_profile_to_duplicate_name_raises_error(test_db, mock_profiles_dir):
"""Test that updating a profile to a duplicate name raises a ValueError."""
# Create two profiles with different names
profile_data_1 = VoiceProfileCreate(
name="Profile A",
description="First profile",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
profile_data_2 = VoiceProfileCreate(
name="Profile B",
description="Second profile",
language="en"
)
profile_2 = await create_profile(profile_data_2, test_db)
# Try to update profile_2 to use profile_1's name
update_data = VoiceProfileCreate(
name="Profile A", # Duplicate name
description="Updated description",
language="en"
)
with pytest.raises(ValueError) as exc_info:
await update_profile(profile_2.id, update_data, test_db)
# Verify error message is user-friendly
assert "already exists" in str(exc_info.value)
assert "Profile A" in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_profile_keep_same_name_succeeds(test_db, mock_profiles_dir):
"""Test that updating a profile while keeping the same name succeeds."""
# Create profile
profile_data = VoiceProfileCreate(
name="My Profile",
description="Original description",
language="en"
)
profile = await create_profile(profile_data, test_db)
# Update profile with same name but different description
update_data = VoiceProfileCreate(
name="My Profile", # Same name
description="Updated description",
language="en"
)
updated_profile = await update_profile(profile.id, update_data, test_db)
# Verify update succeeded
assert updated_profile is not None
assert updated_profile.id == profile.id
assert updated_profile.name == "My Profile"
assert updated_profile.description == "Updated description"
@pytest.mark.asyncio
async def test_update_profile_to_new_unique_name_succeeds(test_db, mock_profiles_dir):
"""Test that updating a profile to a new unique name succeeds."""
# Create profile
profile_data = VoiceProfileCreate(
name="Original Name",
description="Profile description",
language="en"
)
profile = await create_profile(profile_data, test_db)
# Update profile with new unique name
update_data = VoiceProfileCreate(
name="New Unique Name",
description="Updated description",
language="en"
)
updated_profile = await update_profile(profile.id, update_data, test_db)
# Verify update succeeded
assert updated_profile is not None
assert updated_profile.id == profile.id
assert updated_profile.name == "New Unique Name"
@pytest.mark.asyncio
async def test_case_sensitive_names_allowed(test_db, mock_profiles_dir):
"""Test that profile names are case-sensitive (e.g., 'Test' and 'test' are different)."""
# Create profile with lowercase name
profile_data_1 = VoiceProfileCreate(
name="test profile",
description="Lowercase",
language="en"
)
profile_1 = await create_profile(profile_data_1, test_db)
# Create profile with different case
profile_data_2 = VoiceProfileCreate(
name="Test Profile",
description="Title case",
language="en"
)
profile_2 = await create_profile(profile_data_2, test_db)
# Both should succeed since SQLite unique constraint is case-sensitive by default
assert profile_1.name == "test profile"
assert profile_2.name == "Test Profile"
assert profile_1.id != profile_2.id

View File

@@ -0,0 +1,313 @@
"""
Test script to debug model download progress tracking.
"""
import asyncio
import json
import time
from typing import List, Dict
import logging
# Set up logging to see what's happening
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
from utils.progress import ProgressManager, get_progress_manager
from utils.hf_progress import HFProgressTracker, create_hf_progress_callback
def test_progress_manager_basic():
"""Test 1: Basic ProgressManager functionality."""
print("\n" + "=" * 60)
print("Test 1: ProgressManager Basic Operations")
print("=" * 60)
pm = ProgressManager()
# Test update_progress
pm.update_progress(
model_name="test-model",
current=50,
total=100,
filename="test.bin",
status="downloading"
)
# Test get_progress
progress = pm.get_progress("test-model")
print(f"✓ Progress stored: {progress}")
assert progress is not None
assert progress["progress"] == 50.0
assert progress["filename"] == "test.bin"
assert progress["status"] == "downloading"
# Test mark_complete
pm.mark_complete("test-model")
progress = pm.get_progress("test-model")
print(f"✓ Marked complete: {progress}")
assert progress["status"] == "complete"
assert progress["progress"] == 100.0
print("✓ Test 1 PASSED\n")
return True
async def test_progress_manager_sse():
"""Test 2: ProgressManager SSE streaming."""
print("\n" + "=" * 60)
print("Test 2: ProgressManager SSE Streaming")
print("=" * 60)
pm = ProgressManager()
collected_events: List[Dict] = []
# Simulate SSE client
async def sse_client():
"""Simulates a frontend SSE connection."""
print(" SSE client: Subscribing to test-model-sse...")
async for event in pm.subscribe("test-model-sse"):
# Parse SSE event
if event.startswith("data: "):
data = json.loads(event[6:])
print(f" SSE client: Received event: {data['status']} - {data.get('progress', 0):.1f}%")
collected_events.append(data)
# Stop when complete
if data.get("status") in ("complete", "error"):
break
elif event.startswith(": heartbeat"):
print(" SSE client: Received heartbeat")
# Simulate download progress updates (from backend thread)
async def simulate_download():
"""Simulates backend sending progress updates."""
print(" Backend: Starting simulated download...")
await asyncio.sleep(0.2) # Let SSE client subscribe first
# Send progress updates
for i in range(0, 101, 20):
print(f" Backend: Updating progress to {i}%")
pm.update_progress(
model_name="test-model-sse",
current=i,
total=100,
filename=f"file_{i}.bin",
status="downloading" if i < 100 else "downloading"
)
await asyncio.sleep(0.1)
# Mark complete
print(" Backend: Marking download complete")
pm.mark_complete("test-model-sse")
# Run SSE client and download simulation concurrently
await asyncio.gather(
sse_client(),
simulate_download()
)
# Verify we got events
print(f"\n Collected {len(collected_events)} events")
assert len(collected_events) > 0, "Should have received at least one event"
assert collected_events[-1]["status"] == "complete", "Last event should be 'complete'"
print("✓ Test 2 PASSED\n")
return True
def test_hf_progress_tracker():
"""Test 3: HFProgressTracker tqdm patching."""
print("\n" + "=" * 60)
print("Test 3: HFProgressTracker tqdm Patching")
print("=" * 60)
captured_progress: List[tuple] = []
def progress_callback(downloaded: int, total: int, filename: str):
"""Capture progress updates."""
captured_progress.append((downloaded, total, filename))
print(f" Progress callback: {downloaded}/{total} bytes ({filename})")
tracker = HFProgressTracker(progress_callback)
# Simulate a download with tqdm
with tracker.patch_download():
try:
from tqdm import tqdm
# Simulate downloading a file
print(" Simulating download with tqdm...")
total_size = 1000
with tqdm(total=total_size, desc="model.bin", unit="B", unit_scale=True) as pbar:
for chunk in range(0, total_size, 100):
pbar.update(100)
time.sleep(0.01)
print(f" Captured {len(captured_progress)} progress updates")
assert len(captured_progress) > 0, "Should have captured progress updates"
# Verify progress increases
last_downloaded = 0
for downloaded, total, filename in captured_progress:
assert downloaded >= last_downloaded, "Downloaded bytes should increase"
assert total == total_size, "Total should be consistent"
last_downloaded = downloaded
print("✓ Test 3 PASSED\n")
return True
except ImportError:
print("✗ tqdm not available, skipping test\n")
return None
async def test_full_integration():
"""Test 4: Full integration test."""
print("\n" + "=" * 60)
print("Test 4: Full Integration (ProgressManager + HFProgressTracker)")
print("=" * 60)
pm = get_progress_manager()
collected_events: List[Dict] = []
# SSE client
async def sse_client():
print(" SSE client: Subscribing...")
async for event in pm.subscribe("integration-test"):
if event.startswith("data: "):
data = json.loads(event[6:])
print(f" SSE client: {data['status']} - {data.get('progress', 0):.1f}% - {data.get('filename', '')}")
collected_events.append(data)
if data.get("status") in ("complete", "error"):
break
# Simulate backend download with HFProgressTracker
async def simulate_real_download():
await asyncio.sleep(0.2) # Let SSE subscribe
print(" Backend: Starting download with HFProgressTracker...")
# Set up tracking (like the real backend does)
progress_callback = create_hf_progress_callback("integration-test", pm)
tracker = HFProgressTracker(progress_callback)
# Initialize progress
pm.update_progress(
model_name="integration-test",
current=0,
total=1,
filename="",
status="downloading"
)
# Simulate download with tqdm patching
with tracker.patch_download():
try:
from tqdm import tqdm
# Simulate multi-file download (like HuggingFace does)
files = [
("model.safetensors", 5000),
("config.json", 1000),
("tokenizer.json", 500),
]
for filename, size in files:
print(f" Backend: Downloading {filename}...")
with tqdm(total=size, desc=filename, unit="B") as pbar:
for chunk in range(0, size, 500):
chunk_size = min(500, size - chunk)
pbar.update(chunk_size)
await asyncio.sleep(0.05)
# Mark complete
print(" Backend: Download complete")
pm.mark_complete("integration-test")
except ImportError:
print(" ✗ tqdm not available")
pm.mark_error("integration-test", "tqdm not available")
# Run both
await asyncio.gather(
sse_client(),
simulate_real_download()
)
# Verify
print(f"\n Collected {len(collected_events)} events")
if len(collected_events) > 0:
print(f" First event: {collected_events[0]}")
print(f" Last event: {collected_events[-1]}")
assert collected_events[-1]["status"] == "complete", "Should end with 'complete'"
print("✓ Test 4 PASSED\n")
return True
else:
print("✗ Test 4 FAILED - No events received\n")
return False
async def main():
"""Run all tests."""
print("\n" + "=" * 60)
print("Voicebox Progress Tracking Test Suite")
print("=" * 60)
results = []
# Test 1: Basic operations
try:
results.append(("Basic Operations", test_progress_manager_basic()))
except Exception as e:
print(f"✗ Test 1 FAILED: {e}\n")
results.append(("Basic Operations", False))
# Test 2: SSE streaming
try:
results.append(("SSE Streaming", await test_progress_manager_sse()))
except Exception as e:
print(f"✗ Test 2 FAILED: {e}\n")
results.append(("SSE Streaming", False))
# Test 3: tqdm patching
try:
results.append(("tqdm Patching", test_hf_progress_tracker()))
except Exception as e:
print(f"✗ Test 3 FAILED: {e}\n")
results.append(("tqdm Patching", False))
# Test 4: Full integration
try:
results.append(("Full Integration", await test_full_integration()))
except Exception as e:
print(f"✗ Test 4 FAILED: {e}\n")
results.append(("Full Integration", False))
# Summary
print("\n" + "=" * 60)
print("Test Results Summary")
print("=" * 60)
for name, result in results:
status = "✓ PASS" if result else ("⊘ SKIP" if result is None else "✗ FAIL")
print(f" {status:8} {name}")
passed = sum(1 for _, r in results if r is True)
failed = sum(1 for _, r in results if r is False)
skipped = sum(1 for _, r in results if r is None)
print()
print(f" Total: {len(results)} tests")
print(f" Passed: {passed}")
print(f" Failed: {failed}")
print(f" Skipped: {skipped}")
print("=" * 60 + "\n")
return failed == 0
if __name__ == "__main__":
success = asyncio.run(main())
exit(0 if success else 1)

View File

@@ -0,0 +1,317 @@
"""
Test Qwen TTS model download with SSE progress monitoring.
This specifically tests the MLX TTS backend download progress tracking,
which requires tqdm to be patched BEFORE mlx_audio is imported.
Usage:
cd backend && python -m tests.test_qwen_download
Prerequisites:
- Server must be running: cd backend && python main.py
- Delete model first for fresh download test:
curl -X DELETE http://localhost:8000/models/qwen-tts-0.6B
"""
import asyncio
import json
import httpx
import time
from typing import List, Dict, Optional
async def monitor_sse_stream(model_name: str, timeout: int = 600) -> List[Dict]:
"""
Monitor SSE stream for a model download.
Args:
model_name: Name of the model to monitor
timeout: Maximum time to wait for download (seconds)
Returns:
List of SSE events received
"""
events: List[Dict] = []
url = f"http://localhost:8000/models/progress/{model_name}"
last_progress = -1
print(f"\n📡 Connecting to SSE endpoint: {url}")
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", url) as response:
print(f" SSE connected, status: {response.status_code}")
if response.status_code != 200:
print(f" ❌ Error: SSE endpoint returned {response.status_code}")
return events
async for line in response.aiter_lines():
if not line:
continue
if line.startswith("data: "):
try:
data = json.loads(line[6:])
events.append(data)
# Print progress (only when it changes significantly)
progress = data.get('progress', 0)
status = data.get('status', 'unknown')
filename = data.get('filename', '')
current = data.get('current', 0)
total = data.get('total', 0)
# Print every 5% change or status change
if abs(progress - last_progress) >= 5 or status in ('complete', 'error'):
current_mb = current / (1024 * 1024)
total_mb = total / (1024 * 1024)
print(f" 📊 {status:12} {progress:6.1f}% ({current_mb:.1f}MB / {total_mb:.1f}MB) {filename[:50]}")
last_progress = progress
# Stop if complete or error
if status in ("complete", "error"):
if status == "complete":
print(f" ✅ Download complete!")
else:
print(f" ❌ Download error: {data.get('error', 'unknown')}")
break
except json.JSONDecodeError as e:
print(f" ⚠️ Error parsing JSON: {e}")
elif line.startswith(": heartbeat"):
# Heartbeat every 1 second, don't spam
pass
except asyncio.CancelledError:
print(" ⏹️ SSE monitor cancelled")
except Exception as e:
print(f" ❌ SSE error: {e}")
return events
async def trigger_download(model_name: str) -> bool:
"""Trigger a model download via the API."""
url = "http://localhost:8000/models/download"
print(f"\n🚀 Triggering download for: {model_name}")
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(url, json={"model_name": model_name})
result = response.json()
print(f" Response: {response.status_code} - {result}")
return response.status_code == 200
except Exception as e:
print(f" ❌ Error triggering download: {e}")
return False
async def delete_model(model_name: str) -> bool:
"""Delete a model from cache."""
url = f"http://localhost:8000/models/{model_name}"
print(f"\n🗑️ Deleting model: {model_name}")
try:
async with httpx.AsyncClient(timeout=30) as client:
response = await client.delete(url)
if response.status_code == 200:
print(f" ✅ Model deleted")
return True
elif response.status_code == 404:
print(f" Model not found (already deleted)")
return True
else:
print(f" ⚠️ Delete response: {response.status_code} - {response.text}")
return False
except Exception as e:
print(f" ❌ Error deleting model: {e}")
return False
async def check_model_status(model_name: str) -> Optional[Dict]:
"""Check the status of a model."""
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get("http://localhost:8000/models/status")
if response.status_code == 200:
data = response.json()
for model in data.get("models", []):
if model["model_name"] == model_name:
return model
except Exception as e:
print(f" ⚠️ Error checking model status: {e}")
return None
async def check_server() -> bool:
"""Check if the server is running."""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get("http://localhost:8000/health")
return response.status_code == 200
except Exception:
return False
async def main():
print("=" * 70)
print("🧪 Qwen TTS Model Download Progress Test")
print("=" * 70)
print("\nThis test verifies that MLX TTS download progress tracking works.")
print("It specifically tests the tqdm patching for mlx_audio.tts imports.")
# Check if server is running
print("\n📡 Checking if server is running...")
if not await check_server():
print(" ❌ Server is not running on http://localhost:8000")
print("\n Please start the server first:")
print(" cd backend && python main.py")
return False
print(" ✅ Server is running")
# Test model
model_name = "qwen-tts-0.6B"
# Check current status
print(f"\n📊 Checking status of {model_name}...")
status = await check_model_status(model_name)
if status:
print(f" Downloaded: {status.get('downloaded', False)}")
print(f" Downloading: {status.get('downloading', False)}")
print(f" Loaded: {status.get('loaded', False)}")
if status.get('size_mb'):
print(f" Size: {status['size_mb']:.1f} MB")
else:
print(" ⚠️ Could not get model status")
# Ask if user wants to delete first
print("\n" + "-" * 70)
if status and status.get('downloaded'):
print("⚠️ Model is already downloaded. Delete it for a fresh download test?")
print(" [y] Yes, delete and download fresh")
print(" [n] No, just test SSE connection")
print(" [q] Quit")
choice = input("\nChoice [y/n/q]: ").strip().lower()
if choice == 'q':
print("Exiting...")
return True
if choice == 'y':
if not await delete_model(model_name):
print("Failed to delete model. Continue anyway? [y/n]")
if input().strip().lower() != 'y':
return False
else:
print("Model not downloaded. Will perform fresh download test.")
input("Press Enter to continue...")
# Run the test
print("\n" + "=" * 70)
print("🏃 Starting Download Test")
print("=" * 70)
async def run_test():
# Start SSE monitor in background FIRST
monitor_task = asyncio.create_task(monitor_sse_stream(model_name, timeout=600))
# Wait for SSE to connect
await asyncio.sleep(1)
# Trigger download
success = await trigger_download(model_name)
if not success:
print(" ❌ Failed to trigger download")
monitor_task.cancel()
try:
await monitor_task
except asyncio.CancelledError:
pass
return []
# Wait for SSE monitor to complete
print("\n⏳ Waiting for download to complete (this may take several minutes)...")
events = await monitor_task
return events
start_time = time.time()
events = await run_test()
elapsed = time.time() - start_time
# Results
print("\n" + "=" * 70)
print("📋 Test Results")
print("=" * 70)
print(f"\n⏱️ Elapsed time: {elapsed:.1f} seconds")
print(f"📨 Total SSE events received: {len(events)}")
if not events:
print("\n❌ FAILED - No SSE events received!")
print("\nPossible causes:")
print(" 1. SSE endpoint not working")
print(" 2. tqdm not patched before mlx_audio import")
print(" 3. Progress callbacks not firing")
print(" 4. Model already fully downloaded")
print("\nDebug steps:")
print(" 1. Check server logs for [DEBUG] messages")
print(" 2. Look for 'tqdm patched' before 'mlx_audio.tts import'")
print(f" 3. Delete model: curl -X DELETE http://localhost:8000/models/{model_name}")
return False
# Analyze events
first_event = events[0]
last_event = events[-1]
print(f"\n📊 First event:")
print(f" Status: {first_event.get('status')}")
print(f" Progress: {first_event.get('progress', 0):.1f}%")
print(f"\n📊 Last event:")
print(f" Status: {last_event.get('status')}")
print(f" Progress: {last_event.get('progress', 0):.1f}%")
# Check for expected behaviors
has_progress_updates = len(events) > 2
has_increasing_progress = False
has_complete = any(e.get('status') == 'complete' for e in events)
has_100_percent = any(e.get('progress', 0) >= 100 for e in events)
# Check if progress increased over time
if len(events) >= 2:
progress_values = [e.get('progress', 0) for e in events]
has_increasing_progress = progress_values[-1] > progress_values[0]
print("\n📋 Checks:")
print(f" {'' if has_progress_updates else ''} Multiple progress updates received ({len(events)} events)")
print(f" {'' if has_increasing_progress else ''} Progress increased over time")
print(f" {'' if has_100_percent else ''} Reached 100% progress")
print(f" {'' if has_complete else ''} Received 'complete' status")
# Overall result
success = has_progress_updates and has_complete
if success:
print("\n" + "=" * 70)
print("✅ TEST PASSED - Qwen TTS download progress tracking works!")
print("=" * 70)
else:
print("\n" + "=" * 70)
print("❌ TEST FAILED - Progress tracking has issues")
print("=" * 70)
print("\nCheck the server logs for debug output.")
return success
if __name__ == "__main__":
result = asyncio.run(main())
exit(0 if result else 1)

View File

@@ -0,0 +1,54 @@
import asyncio
import pytest
from backend.services import task_queue
@pytest.mark.asyncio
async def test_cancel_queued_generation_skips_execution():
task_queue.init_queue(force=True)
running_started = asyncio.Event()
release_running = asyncio.Event()
queued_ran = asyncio.Event()
async def running_job():
running_started.set()
await release_running.wait()
async def queued_job():
queued_ran.set()
task_queue.enqueue_generation("gen-running", running_job())
await asyncio.wait_for(running_started.wait(), timeout=1)
task_queue.enqueue_generation("gen-queued", queued_job())
assert task_queue.cancel_generation("gen-queued") == "queued"
release_running.set()
await asyncio.sleep(0.1)
assert not queued_ran.is_set()
@pytest.mark.asyncio
async def test_cancel_running_generation_cancels_task():
task_queue.init_queue(force=True)
running_started = asyncio.Event()
running_cancelled = asyncio.Event()
async def running_job():
running_started.set()
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
running_cancelled.set()
raise
task_queue.enqueue_generation("gen-running", running_job())
await asyncio.wait_for(running_started.wait(), timeout=1)
assert task_queue.cancel_generation("gen-running") == "running"
await asyncio.wait_for(running_cancelled.wait(), timeout=1)

View File

@@ -0,0 +1,178 @@
"""
Test real model download with SSE progress monitoring.
"""
import asyncio
import json
import httpx
import time
from typing import List, Dict
async def monitor_sse_stream(model_name: str, timeout: int = 300):
"""Monitor SSE stream for a model download."""
events: List[Dict] = []
url = f"http://localhost:8000/models/progress/{model_name}"
print(f"Connecting to SSE endpoint: {url}")
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", url) as response:
print(f"SSE connected, status: {response.status_code}")
if response.status_code != 200:
print(f"Error: SSE endpoint returned {response.status_code}")
return events
async for line in response.aiter_lines():
if not line:
continue
print(f" Raw SSE: {line[:100]}...") # Print first 100 chars
if line.startswith("data: "):
try:
data = json.loads(line[6:])
print(f"{data['status']:12} {data.get('progress', 0):6.1f}% {data.get('filename', '')}")
events.append(data)
# Stop if complete or error
if data.get("status") in ("complete", "error"):
print(f" Download {data['status']}!")
break
except json.JSONDecodeError as e:
print(f" Error parsing JSON: {e}")
print(f" Line was: {line}")
elif line.startswith(": heartbeat"):
print(" ♥ heartbeat")
return events
async def trigger_download(model_name: str):
"""Trigger a model download via the API."""
url = "http://localhost:8000/models/download"
print(f"\nTriggering download for: {model_name}")
async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json={"model_name": model_name})
print(f"Response: {response.status_code} - {response.json()}")
return response.status_code == 200
async def check_server():
"""Check if the server is running."""
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get("http://localhost:8000/health")
return response.status_code == 200
except Exception as e:
print(f"Server not running: {e}")
return False
async def main():
print("=" * 60)
print("Real Model Download Progress Test")
print("=" * 60)
# Check if server is running
print("\nChecking if server is running...")
if not await check_server():
print("✗ Server is not running on http://localhost:8000")
print("\nPlease start the server first:")
print(" cd backend && python main.py")
return False
print("✓ Server is running")
# Choose a small model for testing
model_name = "whisper-base" # ~150MB, faster to download
print(f"\nUsing model: {model_name}")
# Option to delete model first if it exists
print("\nDo you want to delete the model first to force a fresh download? (y/n)")
# For automated testing, skip deletion prompt
# delete_first = input().strip().lower() == 'y'
delete_first = False
if delete_first:
print(f"Deleting {model_name}...")
async with httpx.AsyncClient(timeout=30) as client:
response = await client.delete(f"http://localhost:8000/models/{model_name}")
print(f"Delete response: {response.status_code}")
print("\n" + "=" * 60)
print("Starting Test")
print("=" * 60)
# Start monitoring SSE stream BEFORE triggering download
async def run_test():
# Start SSE monitor in background
monitor_task = asyncio.create_task(monitor_sse_stream(model_name))
# Wait a bit to ensure SSE is connected
await asyncio.sleep(1)
# Trigger download
success = await trigger_download(model_name)
if not success:
print("✗ Failed to trigger download")
monitor_task.cancel()
return False
# Wait for SSE monitor to complete
events = await monitor_task
return events
events = await run_test()
# Results
print("\n" + "=" * 60)
print("Test Results")
print("=" * 60)
if not events:
print("✗ FAILED - No SSE events received!")
print("\nPossible causes:")
print(" 1. SSE endpoint not working")
print(" 2. Progress updates not being sent")
print(" 3. Model already downloaded (no progress to report)")
print("\nTry deleting the model first to force a fresh download:")
print(f" curl -X DELETE http://localhost:8000/models/{model_name}")
return False
print(f"✓ Received {len(events)} SSE events")
print(f"\nFirst event: {events[0]}")
print(f"Last event: {events[-1]}")
# Check if we got meaningful progress
has_progress = any(e.get('progress', 0) > 0 for e in events)
has_complete = any(e.get('status') == 'complete' for e in events)
if has_progress:
print("✓ Progress updates received")
else:
print("✗ No progress updates (might be already downloaded)")
if has_complete:
print("✓ Download completed successfully")
else:
print("✗ Download did not complete")
success = has_progress and has_complete
if success:
print("\n✓ TEST PASSED - Progress tracking works!")
else:
print("\n⊘ TEST INCONCLUSIVE - Try with a fresh download")
return success
if __name__ == "__main__":
asyncio.run(main())