# Listing metadata: 316 lines, 13 KiB, Python
"""
|
|
Progress tracking for model downloads using Server-Sent Events.
|
|
"""
|
|
|
|
import asyncio
import json
import logging
import threading
import time
from datetime import datetime
from typing import Optional, Callable, Dict, List

from fastapi.responses import StreamingResponse
|
|
|
|
|
|
logger = logging.getLogger(__name__)


class ProgressManager:
    """Tracks download progress per model and fans updates out to SSE subscribers.

    Thread-safe: the update/mark methods can be called from background threads
    (e.g., via asyncio.to_thread) while subscribers run on the event loop.
    Every shared dictionary is guarded by ``self._lock``. Subscriber queues are
    ``asyncio.Queue`` objects, which are not thread-safe, so background threads
    hand their updates to the subscribing event loop instead of writing to the
    queues directly.
    """

    # Throttle settings to prevent overwhelming SSE clients
    THROTTLE_INTERVAL_SECONDS = 0.5  # Minimum time between updates
    THROTTLE_PROGRESS_DELTA = 1.0  # Minimum progress change (%) to force update

    LISTENER_QUEUE_SIZE = 10  # Pending updates buffered per SSE subscriber
    HEARTBEAT_INTERVAL_SECONDS = 1.0  # Idle time before a heartbeat is sent

    _ACTIVE_STATUSES = ("downloading", "extracting")
    _TERMINAL_STATUSES = ("complete", "error")

    def __init__(self):
        self._progress: Dict[str, Dict] = {}
        self._listeners: Dict[str, list] = {}
        # Guards every dictionary on this instance, including throttle state.
        self._lock = threading.Lock()
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None
        # Monotonic time / percentage of the last notification per model.
        self._last_notify_time: Dict[str, float] = {}
        self._last_notify_progress: Dict[str, float] = {}

    def _set_main_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the main event loop for thread-safe operations."""
        with self._lock:
            self._main_loop = loop

    def _clear_throttle_locked(self, model_name: str) -> None:
        """Forget throttle state for a model. Caller must hold ``self._lock``."""
        self._last_notify_time.pop(model_name, None)
        self._last_notify_progress.pop(model_name, None)

    def _enqueue(self, queue: asyncio.Queue, payload: Dict, model_name: str) -> None:
        """Put ``payload`` on a subscriber queue without blocking.

        Must run on the event loop that consumes ``queue``. Intermediate
        updates are dropped when the queue is full. A terminal update
        (complete/error) instead evicts the oldest pending update, because the
        subscriber only stops streaming once it has seen a terminal status.
        """
        if queue.full():
            if payload.get("status") not in self._TERMINAL_STATUSES:
                logger.warning("Queue full for %s, dropping update", model_name)
                return
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                # Consumer drained the queue in the meantime; room exists now.
                pass
        try:
            queue.put_nowait(payload)
        except asyncio.QueueFull:
            logger.warning("Queue full for %s, dropping update", model_name)

    def _notify_listeners_threadsafe(self, model_name: str, progress_data: Dict):
        """Deliver a copy of ``progress_data`` to every subscriber of a model.

        Safe to call from the event loop or from a background thread. Each
        subscriber receives its own copy of the dict.
        """
        # Snapshot under the lock: subscribe() may add/remove queues concurrently.
        with self._lock:
            queues = list(self._listeners.get(model_name, ()))
            main_loop = self._main_loop
        if not queues:
            return

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None  # Called from a plain (non-async) thread.

        # Hand over to the subscribers' loop whenever we are not already on it.
        hand_over = (
            main_loop is not None
            and main_loop is not running_loop
            and main_loop.is_running()
        )

        for queue in queues:
            payload = progress_data.copy()
            try:
                if hand_over:
                    main_loop.call_soon_threadsafe(
                        self._enqueue, queue, payload, model_name
                    )
                elif running_loop is not None:
                    self._enqueue(queue, payload, model_name)
                else:
                    logger.debug(
                        "No main loop available for %s, skipping notification",
                        model_name,
                    )
            except Exception as e:
                # Best effort: one broken listener must not block the others
                # (e.g. the loop was closed between the check and the call).
                logger.warning("Error notifying listener for %s: %s", model_name, e)

    def update_progress(
        self,
        model_name: str,
        current: int,
        total: int,
        filename: Optional[str] = None,
        status: str = "downloading",
    ):
        """
        Update progress for a model download.

        Thread-safe: can be called from background threads.

        The latest state is always stored. Listener notifications are
        throttled: one is sent when at least THROTTLE_INTERVAL_SECONDS have
        passed since the previous one, when progress moved by at least
        THROTTLE_PROGRESS_DELTA percent, or when the status is
        complete/error.

        Args:
            model_name: Name of the model (e.g., "qwen-tts-1.7B", "whisper-base")
            current: Current bytes downloaded
            total: Total bytes to download
            filename: Current file being downloaded
            status: Status string (downloading, extracting, complete, error)
        """
        # Clamp to 0-100 so transient inconsistencies (current > total during
        # aggregation, mixed file-count/byte-count totals) never produce
        # nonsensical percentages. Missing/zero totals report 0%.
        if total and total > 0:
            progress_pct = min(100.0, max(0.0, (current or 0) / total * 100))
        else:
            progress_pct = 0.0

        progress_data = {
            "model_name": model_name,
            "current": current,
            "total": total,
            "progress": progress_pct,
            "filename": filename,
            "status": status,
            "timestamp": datetime.now().isoformat(),
        }

        is_terminal = status in self._TERMINAL_STATUSES
        # Monotonic clock: interval measurement must survive wall-clock jumps.
        now = time.monotonic()

        with self._lock:
            # Store a private copy so in-place edits by mark_complete/mark_error
            # never alias the dict handed to listeners below.
            self._progress[model_name] = progress_data.copy()

            last_time = self._last_notify_time.get(model_name)
            last_progress = self._last_notify_progress.get(model_name)
            should_notify = (
                is_terminal
                or last_time is None
                or last_progress is None
                or now - last_time >= self.THROTTLE_INTERVAL_SECONDS
                or abs(progress_pct - last_progress) >= self.THROTTLE_PROGRESS_DELTA
            )
            if not should_notify:
                return  # Skip this update (throttled)

            if is_terminal:
                # The download is over; the next one starts unthrottled.
                self._clear_throttle_locked(model_name)
            else:
                self._last_notify_time[model_name] = now
                self._last_notify_progress[model_name] = progress_pct
            listener_count = len(self._listeners.get(model_name, ()))

        if listener_count > 0:
            logger.debug(
                "Notifying %d listeners for %s: %.1f%% (%s)",
                listener_count,
                model_name,
                progress_pct,
                filename,
            )
            self._notify_listeners_threadsafe(model_name, progress_data)
        else:
            logger.debug(
                "No listeners for %s, progress update stored: %.1f%%",
                model_name,
                progress_pct,
            )

    def get_progress(self, model_name: str) -> Optional[Dict]:
        """Return a copy of the latest progress for a model, or None. Thread-safe."""
        with self._lock:
            progress = self._progress.get(model_name)
            return progress.copy() if progress else None

    def get_all_active(self) -> List[Dict]:
        """Get all active downloads (status is 'downloading' or 'extracting'). Thread-safe."""
        with self._lock:
            return [
                progress.copy()
                for progress in self._progress.values()
                if progress.get("status", "") in self._ACTIVE_STATUSES
            ]

    def create_progress_callback(
        self, model_name: str, filename: Optional[str] = None
    ) -> Callable[[Dict], None]:
        """
        Create a progress callback function for HuggingFace downloads.

        Args:
            model_name: Name of the model
            filename: Filename reported when the progress dict carries none

        Returns:
            Callback taking a progress dict. Dicts lacking either the
            "current" or the "total" key are ignored.
        """

        def callback(progress: Dict):
            """Forward one progress dict to update_progress()."""
            if "total" not in progress or "current" not in progress:
                return
            self.update_progress(
                model_name=model_name,
                # Keys may be present but None before sizes are known.
                current=progress.get("current") or 0,
                total=progress.get("total") or 0,
                filename=progress.get("filename") or filename,
                status="downloading",
            )

        return callback

    async def subscribe(self, model_name: str):
        """
        Subscribe to progress updates for a model.

        Async generator yielding Server-Sent Events strings: ``data: <json>``
        frames for progress updates and ``: heartbeat`` comment frames while
        idle. The stream ends after a complete/error update has been sent.
        """
        queue: asyncio.Queue = asyncio.Queue(maxsize=self.LISTENER_QUEUE_SIZE)

        # Register the queue and read the current state in one critical
        # section so no update can fall between the two steps.
        with self._lock:
            # Remember the consumer loop so background threads can reach it.
            self._main_loop = asyncio.get_running_loop()
            listeners = self._listeners.setdefault(model_name, [])
            listeners.append(queue)
            listener_total = len(listeners)
            initial_progress = self._progress.get(model_name)
            if initial_progress:
                initial_progress = initial_progress.copy()

        logger.info(
            "SSE client subscribed to %s, total listeners: %d",
            model_name,
            listener_total,
        )

        try:
            if initial_progress:
                status = initial_progress.get("status")
                # Only replay state for a download that is still running;
                # a stale 'complete'/'error' belongs to a previous download.
                if status in self._ACTIVE_STATUSES:
                    logger.info("Sending initial progress for %s: %s", model_name, status)
                    yield f"data: {json.dumps(initial_progress)}\n\n"
                else:
                    logger.info(
                        "Skipping initial progress for %s (status: %s)",
                        model_name,
                        status,
                    )
            else:
                logger.info("No initial progress available for %s", model_name)

            while True:
                try:
                    progress = await asyncio.wait_for(
                        queue.get(), timeout=self.HEARTBEAT_INTERVAL_SECONDS
                    )
                except asyncio.TimeoutError:
                    # Nothing new: emit an SSE comment as a heartbeat.
                    yield ": heartbeat\n\n"
                    continue

                logger.debug(
                    "Sending progress update for %s: %s - %.1f%%",
                    model_name,
                    progress.get("status"),
                    progress.get("progress", 0),
                )
                yield f"data: {json.dumps(progress)}\n\n"

                if progress.get("status") in self._TERMINAL_STATUSES:
                    logger.info(
                        "Download %s for %s, closing SSE connection",
                        progress.get("status"),
                        model_name,
                    )
                    break
        except (BrokenPipeError, ConnectionResetError):
            logger.debug("SSE client disconnected from %s", model_name)
        except asyncio.CancelledError:
            logger.debug("SSE client disconnected from %s", model_name)
            # Cancellation must propagate to the task that is being cancelled.
            raise
        finally:
            with self._lock:
                listeners = self._listeners.get(model_name)
                if listeners is not None:
                    if queue in listeners:
                        listeners.remove(queue)
                    if not listeners:
                        del self._listeners[model_name]
                remaining = len(self._listeners.get(model_name, ()))
            logger.info(
                "SSE client unsubscribed from %s, remaining listeners: %d",
                model_name,
                remaining,
            )

    def mark_complete(self, model_name: str):
        """Mark a known model download as complete and notify listeners. Thread-safe.

        Logs a warning and does nothing when the model has no progress entry.
        """
        with self._lock:
            entry = self._progress.get(model_name)
            if entry is not None:
                entry["status"] = "complete"
                entry["progress"] = 100.0
                progress_data = entry.copy()
                self._clear_throttle_locked(model_name)

        if entry is None:
            logger.warning(
                "Cannot mark %s as complete: not found in progress", model_name
            )
            return

        logger.info("Marked %s as complete", model_name)
        self._notify_listeners_threadsafe(model_name, progress_data)

    def mark_error(self, model_name: str, error: str):
        """Mark a model download as failed and notify listeners. Thread-safe.

        Creates a fresh progress entry when the model has none, so a failure
        before the first progress update is still visible.
        """
        with self._lock:
            entry = self._progress.get(model_name)
            if entry is not None:
                entry["status"] = "error"
                entry["error"] = error
            else:
                entry = {
                    "model_name": model_name,
                    "current": 0,
                    "total": 0,
                    "progress": 0.0,
                    "filename": None,
                    "status": "error",
                    "error": error,
                    "timestamp": datetime.now().isoformat(),
                }
                self._progress[model_name] = entry
            progress_data = entry.copy()
            self._clear_throttle_locked(model_name)

        logger.error("Marked %s as error: %s", model_name, error)
        self._notify_listeners_threadsafe(model_name, progress_data)
|
|
|
|
|
|
# Global progress manager instance, created lazily by get_progress_manager().
_progress_manager: Optional[ProgressManager] = None
# Serialises first-time creation so concurrent callers share one instance.
_progress_manager_lock = threading.Lock()


def get_progress_manager() -> ProgressManager:
    """Get or create the global progress manager.

    Safe to call from several threads at once: creation happens under a lock,
    so every caller gets the same instance and listeners and producers cannot
    end up on different managers.

    Returns:
        The process-wide ProgressManager.
    """
    global _progress_manager
    if _progress_manager is None:
        with _progress_manager_lock:
            # Re-check: another thread may have created it while we waited.
            if _progress_manager is None:
                _progress_manager = ProgressManager()
    return _progress_manager
|