306 lines
10 KiB
Python
306 lines
10 KiB
Python
"""
Test TTS generation with SSE progress monitoring.

This test captures the exact SSE events triggered during generation
to identify UX issues where users see download progress even when
the model is already cached.
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import httpx
|
|
from typing import List, Dict, Optional
|
|
from datetime import datetime
|
|
|
|
|
|
async def monitor_sse_stream(model_name: str, timeout: int = 120):
    """Monitor the SSE progress stream for a model during generation.

    Connects to ``/models/progress/{model_name}`` on the local server and
    logs every ``data:`` event and heartbeat comment until a terminal
    ``complete``/``error`` status arrives, the stream ends, an error occurs,
    or the surrounding task is cancelled.

    Args:
        model_name: Model identifier interpolated into the progress URL.
        timeout: Timeout in seconds handed to ``httpx.AsyncClient``.

    Returns:
        The list of decoded event payloads received so far. Each entry is the
        event's JSON object extended with a ``"_timestamp"`` key holding the
        local receive time. The list is returned even when monitoring stops
        early (non-200 response, timeout, error, or cancellation).
    """
    events: List[Dict] = []
    url = f"http://localhost:8000/models/progress/{model_name}"

    print(f"[{_timestamp()}] Connecting to SSE endpoint: {url}")

    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream("GET", url) as response:
                print(f"[{_timestamp()}] SSE connected, status: {response.status_code}")

                if response.status_code != 200:
                    print(f"[{_timestamp()}] Error: SSE endpoint returned {response.status_code}")
                    return events

                async for line in response.aiter_lines():
                    if not line:
                        continue

                    timestamp = _timestamp()

                    # Accept both "data: {...}" and "data:{...}"; json.loads
                    # ignores the optional leading space.
                    if line.startswith("data:"):
                        try:
                            data = json.loads(line[5:])
                        except json.JSONDecodeError as e:
                            print(f"[{timestamp}] Error parsing JSON: {e}")
                            print(f" Line was: {line}")
                            continue

                        if not isinstance(data, dict):
                            # A bare JSON scalar/array cannot be merged into
                            # an event record; skip it instead of aborting.
                            print(f"[{timestamp}] Ignoring non-object SSE payload: {data!r}")
                            continue

                        # Use tolerant lookups so one malformed event cannot
                        # raise and end the whole monitoring session.
                        status = data.get("status", "")
                        progress = _coerce_progress(data.get("progress", 0))
                        print(
                            f"[{timestamp}] → SSE Event: {str(status):12} {progress:6.1f}% {data.get('filename', '')}"
                        )
                        events.append({**data, "_timestamp": timestamp})

                        # Stop if complete or error
                        if status in ("complete", "error"):
                            print(f"[{timestamp}] → Model {status}!")
                            break

                    elif line.startswith(": heartbeat"):
                        print(f"[{timestamp}] ♥ heartbeat")

    except asyncio.CancelledError:
        # Callers stop the monitor with task.cancel() and then await it to
        # read the captured events. Re-raising would discard everything
        # collected so far, so cancellation is deliberately turned into a
        # normal return here.
        print(f"[{_timestamp()}] SSE monitoring cancelled")
    except (asyncio.TimeoutError, httpx.TimeoutException):
        # httpx reports its own timeouts as httpx.TimeoutException.
        print(f"[{_timestamp()}] SSE monitoring timed out")
    except Exception as e:
        print(f"[{_timestamp()}] SSE error: {e}")

    return events


def _coerce_progress(value) -> float:
    """Return *value* as a float percentage, or 0.0 when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0
|
|
|
|
|
|
async def trigger_generation(profile_id: str, text: str, model_size: str = "1.7B"):
    """Trigger TTS generation via the API.

    Posts the request to ``/generate`` on the local server and logs the
    outcome.

    Args:
        profile_id: Voice profile id sent as ``profile_id``.
        text: Text to synthesize; only the first 50 characters are logged.
        model_size: Model size sent as ``model_size``.

    Returns:
        A ``(success, result)`` tuple:
        ``(True, body)`` for HTTP 200, ``(False, body)`` for HTTP 202 (the
        code treats this as "model download in progress"), and
        ``(False, None)`` for any other status or when an exception occurs.
    """
    url = "http://localhost:8000/generate"

    print(f"\n[{_timestamp()}] Triggering generation...")
    print(f" Profile: {profile_id}")
    print(f" Text: {text[:50]}...")
    print(f" Model: {model_size}")

    try:
        async with httpx.AsyncClient(timeout=120) as client:
            response = await client.post(
                url,
                json={
                    "profile_id": profile_id,
                    "text": text,
                    "language": "en",
                    "model_size": model_size,
                },
            )

            print(f"[{_timestamp()}] Response: {response.status_code}")

            if response.status_code == 200:
                result = response.json()
                print(f"[{_timestamp()}] ✓ Generation successful!")
                # Logging must never turn a successful generation into a
                # failure, so the details are printed defensively.
                if isinstance(result, dict):
                    print(f" Generation ID: {result.get('id')}")
                    print(f" Duration: {_format_duration(result.get('duration', 0))}")
                return True, result
            elif response.status_code == 202:
                # Model is being downloaded
                result = response.json()
                print(f"[{_timestamp()}] → Model download in progress")
                print(f" Detail: {result}")
                return False, result
            else:
                print(f"[{_timestamp()}] ✗ Error: {response.text}")
                return False, None

    except Exception as e:
        print(f"[{_timestamp()}] ✗ Exception: {e}")
        return False, None


def _format_duration(value) -> str:
    """Format *value* as seconds with two decimals, or "unknown" if not numeric."""
    try:
        return f"{float(value):.2f}s"
    except (TypeError, ValueError):
        return "unknown"
|
|
|
|
|
|
async def get_first_profile():
    """Return the id of the first voice profile listed by the server.

    Returns ``None`` when the request does not answer with HTTP 200, when the
    profile list is empty, or when any exception occurs (the exception is
    printed).
    """
    try:
        async with httpx.AsyncClient(timeout=10) as http:
            reply = await http.get("http://localhost:8000/profiles")
            if reply.status_code != 200:
                return None
            available = reply.json()
            return available[0]["id"] if available else None
    except Exception as err:
        print(f"Error getting profiles: {err}")
        return None
|
|
|
|
|
|
async def check_server():
    """Report whether ``GET /health`` on the local server answers HTTP 200.

    Any exception raised while connecting is printed and reported as
    ``False``.
    """
    health_url = "http://localhost:8000/health"
    try:
        async with httpx.AsyncClient(timeout=5) as http:
            reply = await http.get(health_url)
            is_healthy = reply.status_code == 200
            return is_healthy
    except Exception as err:
        print(f"Server not running: {err}")
        return False
|
|
|
|
|
|
def _timestamp():
    """Return the current local wall-clock time as ``HH:MM:SS.mmm``."""
    current = datetime.now().time()
    # timespec="milliseconds" truncates (does not round) the microseconds,
    # so the result always carries exactly three fractional digits.
    return current.isoformat(timespec="milliseconds")
|
|
|
|
|
|
async def test_generation_with_cached_model():
    """
    Test Case 1: generation while the model is already cached.

    No download progress events are expected in this scenario; seeing any
    is the UX bug under investigation.

    Returns the events captured by the SSE monitor, or ``False`` when no
    voice profile is available.
    """
    rule = "=" * 80
    for banner_line in (
        "\n" + rule,
        "TEST CASE 1: Generation with Cached Model",
        rule,
        "Expected: No download progress events (or minimal/instant completion)",
        "Actual UX Issue: Users see 'started' and 'finished' events even for cached models",
        rule,
    ):
        print(banner_line)

    size = "1.7B"
    sse_model = f"qwen-tts-{size}"

    # A voice profile is required before anything can be generated.
    profile = await get_first_profile()
    if not profile:
        print("✗ No voice profiles found. Please create a profile first.")
        return False

    print(f"\nUsing profile: {profile}")

    # The monitor must be listening before generation is triggered.
    watcher = asyncio.create_task(monitor_sse_stream(sse_model, timeout=30))

    # Give the SSE connection a moment to be established.
    await asyncio.sleep(1)

    sample = "Hello, this is a test of the voice generation system."
    ok, payload = await trigger_generation(profile, sample, size)

    if ok or not payload or not payload.get("downloading"):
        # Linger briefly so late progress events are still observed.
        await asyncio.sleep(3)

        # Stop the monitor and collect whatever it returns.
        watcher.cancel()
        try:
            return await watcher
        except asyncio.CancelledError:
            return []

    print("\n⚠ Model is being downloaded. Waiting for download to complete...")
    # Let the monitor run to completion so it captures the download events.
    return await watcher
|
|
|
|
|
|
async def test_generation_with_fresh_download():
    """
    Test Case 2: generation when the model has to be downloaded first.

    Download progress events ARE expected in this scenario.

    Returns the events captured by the SSE monitor, or ``False`` when no
    voice profile is available.
    """
    rule = "=" * 80
    for banner_line in (
        "\n" + rule,
        "TEST CASE 2: Generation with Model Download",
        rule,
        "Expected: Download progress events from 0% to 100%",
        rule,
    ):
        print(banner_line)

    # A different (smaller) model size is used to force a download.
    size = "0.6B"
    sse_model = f"qwen-tts-{size}"

    # A voice profile is required before anything can be generated.
    profile = await get_first_profile()
    if not profile:
        print("✗ No voice profiles found. Please create a profile first.")
        return False

    print(f"\nUsing profile: {profile}")
    print("Note: This will download the model if not cached")

    # The monitor must be listening before generation is triggered.
    watcher = asyncio.create_task(monitor_sse_stream(sse_model, timeout=300))

    # Give the SSE connection a moment to be established.
    await asyncio.sleep(1)

    sample = "This should trigger a model download if the model is not cached."
    ok, payload = await trigger_generation(profile, sample, size)

    if ok or not payload or not payload.get("downloading"):
        # The model was already cached (or the request failed outright):
        # linger briefly, then stop the monitor and collect its result.
        await asyncio.sleep(3)
        watcher.cancel()
        try:
            return await watcher
        except asyncio.CancelledError:
            return []

    print("\n→ Model download initiated. Monitoring progress...")
    # The monitor returns once the download finishes.
    captured = await watcher

    # Generation is attempted a second time after the download.
    print(f"\n[{_timestamp()}] Retrying generation after download...")
    await asyncio.sleep(2)
    ok, payload = await trigger_generation(profile, sample, size)

    if ok:
        print("✓ Generation successful after download")

    return captured
|
|
|
|
|
|
async def main():
    """Run the cached-model scenario and report the SSE events it produced.

    Checks that the server answers its health endpoint, runs
    ``test_generation_with_cached_model`` and prints a timeline of any
    captured events.

    Returns:
        ``False`` when the server is not running or the scenario could not
        run (it returned ``False`` because no voice profile exists);
        ``True`` once the scenario has run and its results were reported.
    """
    separator = "=" * 80

    print(separator)
    print("TTS Generation Progress Test")
    print(separator)
    print("Purpose: Capture exact SSE events during generation to identify UX issues")
    print(separator)

    # Check if server is running
    print(f"\n[{_timestamp()}] Checking if server is running...")
    if not await check_server():
        print("✗ Server is not running on http://localhost:8000")
        print("\nPlease start the server first:")
        print(" cd backend && python main.py")
        return False

    print("✓ Server is running")

    # Test Case 1: Cached model
    print("\n" + "🧪 " * 20)
    events_cached = await test_generation_with_cached_model()

    # Results for Test Case 1
    print("\n" + separator)
    print("TEST CASE 1 RESULTS: Generation with Cached Model")
    print(separator)

    if events_cached is False:
        # The scenario signals "could not run" with False. That must not be
        # mistaken for an empty event list, which would be reported as a pass.
        print("✗ Test could not run (see messages above); no conclusion can be drawn.")
        return False

    if not events_cached:
        print("✓ GOOD: No SSE progress events received")
        print(" This is the expected behavior for a cached model.")
    else:
        print(f"⚠ ISSUE FOUND: Received {len(events_cached)} SSE events:")
        print("\nEvent Timeline:")
        for i, event in enumerate(events_cached, 1):
            # Build a display copy instead of popping, so the captured
            # events are left unmodified.
            timestamp = event.get("_timestamp", "??:??:??.???")
            details = {key: value for key, value in event.items() if key != "_timestamp"}
            print(f" {i}. [{timestamp}] {details}")

        print("\n⚠ This explains the UX issue!")
        print(" Users see progress events even when the model is already cached,")
        print(" making them think the model is downloading again.")

    print("\n" + separator)
    print("Test Complete!")
    print(separator)

    return True
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns a success flag; turn it into the process exit status so
    # shells and CI can detect a failed run (0 on success, 1 otherwise).
    raise SystemExit(0 if asyncio.run(main()) else 1)
|