314 lines
10 KiB
Python
314 lines
10 KiB
Python
"""
Test script for debugging model download progress tracking
(ProgressManager, its SSE subscription stream, and HFProgressTracker's
tqdm patching).
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
from typing import List, Dict
|
|
import logging
|
|
|
|
# Turn on DEBUG-level logging with timestamped, logger-named records so the
# progress-tracking internals are visible while the tests run.
# NOTE(review): this runs before the `utils.*` imports below, presumably so
# that anything logged at import time is captured too - confirm before moving.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
|
|
|
|
from utils.progress import ProgressManager, get_progress_manager
|
|
from utils.hf_progress import HFProgressTracker, create_hf_progress_callback
|
|
|
|
|
|
def test_progress_manager_basic():
    """Test 1: store, read back and complete a progress record.

    Creates a fresh ProgressManager, records a 50/100 update for
    "test-model", checks the record returned by ``get_progress``, then calls
    ``mark_complete`` and checks the record again.

    Returns:
        True when every assertion holds; an AssertionError propagates to the
        caller otherwise.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Test 1: ProgressManager Basic Operations")
    print(rule)

    manager = ProgressManager()
    model = "test-model"

    # Record a half-way update ...
    manager.update_progress(
        model_name=model,
        current=50,
        total=100,
        filename="test.bin",
        status="downloading",
    )

    # ... and verify what the manager hands back for it.
    snapshot = manager.get_progress(model)
    print(f"✓ Progress stored: {snapshot}")
    assert snapshot is not None
    stored_fields = (
        ("progress", 50.0),
        ("filename", "test.bin"),
        ("status", "downloading"),
    )
    for field, wanted in stored_fields:
        assert snapshot[field] == wanted

    # Completing the model should flip the status and pin progress at 100.
    manager.mark_complete(model)
    snapshot = manager.get_progress(model)
    print(f"✓ Marked complete: {snapshot}")
    for field, wanted in (("status", "complete"), ("progress", 100.0)):
        assert snapshot[field] == wanted

    print("✓ Test 1 PASSED\n")
    return True
|
|
|
|
|
|
async def test_progress_manager_sse():
    """Test 2: ProgressManager SSE streaming.

    Runs two coroutines concurrently against one fresh ProgressManager:

    * a simulated frontend that iterates ``pm.subscribe("test-model-sse")``,
      decodes every ``data: `` line as JSON and stops on a terminal status
      ("complete" or "error");
    * a simulated backend that waits 0.2 s, publishes progress 0..100 in
      steps of 20 (0.1 s apart) and finally calls ``mark_complete``.

    Returns:
        True when at least one event was collected and the last one has
        status "complete".

    Raises:
        AssertionError: if no event arrived, the last event is not
            "complete", or the stream did not finish within the timeout.
    """
    print("\n" + "=" * 60)
    print("Test 2: ProgressManager SSE Streaming")
    print("=" * 60)

    pm = ProgressManager()
    collected_events: List[Dict] = []

    # The simulated download sleeps for roughly 0.8 s in total, so this bound
    # only triggers when the terminal event never reaches the subscriber.
    # Without it such a failure would hang the whole suite instead of failing.
    stream_timeout = 30.0

    async def sse_client():
        """Simulates a frontend SSE connection."""
        print(" SSE client: Subscribing to test-model-sse...")
        async for event in pm.subscribe("test-model-sse"):
            if event.startswith("data: "):
                # Strip the 6-character "data: " prefix before decoding.
                data = json.loads(event[6:])
                # .get() keeps the client alive on a payload without a
                # "status" key, matching the lookup used for the stop check.
                print(f" SSE client: Received event: {data.get('status')} - {data.get('progress', 0):.1f}%")
                collected_events.append(data)

                # Stop on a terminal status.
                if data.get("status") in ("complete", "error"):
                    break
            elif event.startswith(": heartbeat"):
                print(" SSE client: Received heartbeat")

    async def simulate_download():
        """Simulates backend sending progress updates."""
        print(" Backend: Starting simulated download...")
        await asyncio.sleep(0.2)  # Let SSE client subscribe first

        for i in range(0, 101, 20):
            print(f" Backend: Updating progress to {i}%")
            pm.update_progress(
                model_name="test-model-sse",
                current=i,
                total=100,
                filename=f"file_{i}.bin",
                # Every update, including the 100% one, is published as
                # "downloading"; completion is signalled by mark_complete().
                status="downloading",
            )
            await asyncio.sleep(0.1)

        print(" Backend: Marking download complete")
        pm.mark_complete("test-model-sse")

    # Run SSE client and download simulation concurrently, bounded in time.
    try:
        await asyncio.wait_for(
            asyncio.gather(sse_client(), simulate_download()),
            timeout=stream_timeout,
        )
    except asyncio.TimeoutError:
        raise AssertionError(
            f"SSE stream did not deliver a terminal event within {stream_timeout:.0f}s"
        ) from None

    # Verify we got events
    print(f"\n Collected {len(collected_events)} events")
    assert collected_events, "Should have received at least one event"
    assert collected_events[-1]["status"] == "complete", "Last event should be 'complete'"

    print("✓ Test 2 PASSED\n")
    return True
|
|
|
|
|
|
def test_hf_progress_tracker():
    """Test 3: HFProgressTracker tqdm patching.

    Wraps a recording callback in an HFProgressTracker, drives a tqdm bar of
    1000 units in ten steps of 100 inside ``tracker.patch_download()`` and
    checks that the callback was invoked with non-decreasing byte counts and
    a constant total.

    Returns:
        True on success, or None when tqdm cannot be imported (test skipped).
        An AssertionError propagates when a check fails.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Test 3: HFProgressTracker tqdm Patching")
    print(rule)

    updates: List[tuple] = []

    def record_update(downloaded: int, total: int, filename: str):
        """Remember and echo each progress notification."""
        updates.append((downloaded, total, filename))
        print(f" Progress callback: {downloaded}/{total} bytes ({filename})")

    tracker = HFProgressTracker(record_update)

    with tracker.patch_download():
        try:
            # Deliberately imported while the patch is active, exactly as the
            # original flow does - do not hoist this out of the context.
            from tqdm import tqdm

            print(" Simulating download with tqdm...")
            total_size = 1000
            with tqdm(total=total_size, desc="model.bin", unit="B", unit_scale=True) as bar:
                for _ in range(0, total_size, 100):
                    bar.update(100)
                    time.sleep(0.01)

            print(f" Captured {len(updates)} progress updates")
            assert len(updates) > 0, "Should have captured progress updates"

            # Byte counts must never go backwards and the total must not drift.
            previous = 0
            for downloaded, total, _name in updates:
                assert downloaded >= previous, "Downloaded bytes should increase"
                assert total == total_size, "Total should be consistent"
                previous = downloaded

            print("✓ Test 3 PASSED\n")
            return True

        except ImportError:
            print("✗ tqdm not available, skipping test\n")
            return None
|
|
|
|
|
|
async def test_full_integration():
    """Test 4: ProgressManager and HFProgressTracker working together.

    Uses the manager returned by ``get_progress_manager()``. A simulated SSE
    client subscribes to "integration-test" while a simulated backend pushes
    three tqdm-driven file downloads through an HFProgressTracker whose
    callback comes from ``create_hf_progress_callback``. The backend ends
    with ``mark_complete`` (or ``mark_error`` if tqdm is missing).

    Returns:
        True when events were received and the last one is "complete";
        False when no event was received at all. An AssertionError
        propagates when the last event has a different status.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Test 4: Full Integration (ProgressManager + HFProgressTracker)")
    print(rule)

    manager = get_progress_manager()
    events: List[Dict] = []
    model = "integration-test"

    async def sse_client():
        """Collect decoded ``data: `` events until a terminal status shows up."""
        print(" SSE client: Subscribing...")
        async for event in manager.subscribe(model):
            if not event.startswith("data: "):
                continue
            data = json.loads(event[6:])
            print(f" SSE client: {data['status']} - {data.get('progress', 0):.1f}% - {data.get('filename', '')}")
            events.append(data)
            if data.get("status") in ("complete", "error"):
                break

    async def simulate_real_download():
        """Publish progress for three files through the tqdm patch."""
        await asyncio.sleep(0.2)  # give the subscriber a head start

        print(" Backend: Starting download with HFProgressTracker...")

        # Same wiring the script's comments attribute to the real backend.
        callback = create_hf_progress_callback(model, manager)
        tracker = HFProgressTracker(callback)

        # Seed an initial record before any bytes are reported.
        manager.update_progress(
            model_name=model,
            current=0,
            total=1,
            filename="",
            status="downloading",
        )

        with tracker.patch_download():
            try:
                # Imported while the patch is active, as in the original flow.
                from tqdm import tqdm

                # (name, size) pairs standing in for a multi-file download.
                files = [
                    ("model.safetensors", 5000),
                    ("config.json", 1000),
                    ("tokenizer.json", 500),
                ]

                for name, size in files:
                    print(f" Backend: Downloading {name}...")
                    with tqdm(total=size, desc=name, unit="B") as bar:
                        for offset in range(0, size, 500):
                            # Last step may be shorter than 500.
                            step = min(500, size - offset)
                            bar.update(step)
                            await asyncio.sleep(0.05)

                print(" Backend: Download complete")
                manager.mark_complete(model)

            except ImportError:
                print(" ✗ tqdm not available")
                manager.mark_error(model, "tqdm not available")

    # Client and backend run side by side.
    await asyncio.gather(
        sse_client(),
        simulate_real_download(),
    )

    print(f"\n Collected {len(events)} events")
    if not events:
        print("✗ Test 4 FAILED - No events received\n")
        return False

    print(f" First event: {events[0]}")
    print(f" Last event: {events[-1]}")
    assert events[-1]["status"] == "complete", "Should end with 'complete'"
    print("✓ Test 4 PASSED\n")
    return True
|
|
|
|
|
|
async def main():
    """Run the four tests in order and print a summary table.

    Each test runs inside its own try/except so that one failure (any
    Exception, including AssertionError) is reported and recorded as False
    without stopping the remaining tests. A test returning None is counted
    as skipped.

    Returns:
        True when no test failed, False otherwise.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Voicebox Progress Tracking Test Suite")
    print(rule)

    # (label, test) in execution order; the position doubles as the test number.
    suite = (
        ("Basic Operations", test_progress_manager_basic),
        ("SSE Streaming", test_progress_manager_sse),
        ("tqdm Patching", test_hf_progress_tracker),
        ("Full Integration", test_full_integration),
    )

    results = []
    for number, (label, test_fn) in enumerate(suite, start=1):
        try:
            outcome = test_fn()
            # The async tests hand back a coroutine that still has to run.
            if asyncio.iscoroutine(outcome):
                outcome = await outcome
        except Exception as exc:
            print(f"✗ Test {number} FAILED: {exc}\n")
            outcome = False
        results.append((label, outcome))

    # Summary
    print("\n" + rule)
    print("Test Results Summary")
    print(rule)

    for label, outcome in results:
        if outcome:
            status = "✓ PASS"
        elif outcome is None:
            status = "⊘ SKIP"
        else:
            status = "✗ FAIL"
        print(f" {status:8} {label}")

    outcomes = [outcome for _, outcome in results]
    passed = sum(1 for outcome in outcomes if outcome is True)
    failed = sum(1 for outcome in outcomes if outcome is False)
    skipped = sum(1 for outcome in outcomes if outcome is None)

    print()
    print(f" Total: {len(results)} tests")
    print(f" Passed: {passed}")
    print(f" Failed: {failed}")
    print(f" Skipped: {skipped}")
    print(rule + "\n")

    return failed == 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the suite and turn its boolean result into a
    # process exit status (0 = no failures, 1 = at least one failure).
    success = asyncio.run(main())
    # Raise SystemExit directly rather than calling the bare `exit()` helper:
    # that name is injected by the `site` module for interactive use and is
    # missing under `python -S` or in frozen/embedded interpreters.
    raise SystemExit(0 if success else 1)
|