Initial commit
This commit is contained in:
315
backend/utils/progress.py
Normal file
315
backend/utils/progress.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
Progress tracking for model downloads using Server-Sent Events.
|
||||
"""
|
||||
|
||||
import asyncio
import json
import logging
import threading
import time
from datetime import datetime
from typing import Optional, Callable, Dict, List

from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
class ProgressManager:
    """Track download progress per model and fan it out to SSE subscribers.

    The latest progress snapshot for every model is kept in ``_progress``.
    Each subscriber (see ``subscribe``) owns a bounded ``asyncio.Queue`` that
    receives copies of the snapshots.

    Thread-safe: can be called from background threads (e.g., via
    asyncio.to_thread). ``_lock`` guards the progress snapshots, the listener
    registry and the throttle bookkeeping. Queues are only touched on the
    event loop that served the most recent ``subscribe`` call; other threads
    hand their updates over with ``call_soon_threadsafe``.
    """

    # Throttle settings to prevent overwhelming SSE clients
    THROTTLE_INTERVAL_SECONDS = 0.5  # Minimum time between updates
    THROTTLE_PROGRESS_DELTA = 1.0  # Minimum progress change (%) to force update

    # Per-subscriber buffering / keep-alive settings
    LISTENER_QUEUE_SIZE = 10  # Max queued updates per SSE client
    HEARTBEAT_INTERVAL_SECONDS = 1.0  # Idle time before a heartbeat comment is sent

    # Status groups used throughout the class
    ACTIVE_STATUSES = ("downloading", "extracting")
    TERMINAL_STATUSES = ("complete", "error")

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self._progress: Dict[str, Dict] = {}
        self._listeners: Dict[str, List[asyncio.Queue]] = {}
        # Guards _progress, _listeners and the throttle dicts below.
        # Not re-entrant: never call a method that takes it while holding it.
        self._lock = threading.Lock()
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None
        self._last_notify_time: Dict[str, float] = {}  # Last notification time (monotonic) per model
        self._last_notify_progress: Dict[str, float] = {}  # Last notified progress per model

    def _set_main_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the main event loop for thread-safe operations."""
        self._main_loop = loop

    def _enqueue_latest(self, queue: asyncio.Queue, payload: Dict, model_name: str) -> None:
        """Put ``payload`` on ``queue``, evicting the oldest entry if it is full.

        Must run on the event loop thread that owns ``queue``. Every payload
        is a full snapshot, so a newer one supersedes an older one. Evicting
        the oldest guarantees that the most recent update is never the one
        that gets lost; in particular a terminal "complete"/"error" event
        still reaches a slow subscriber.
        """
        if queue.full():
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            self._logger.warning("Queue full for %s, dropping oldest update", model_name)
        queue.put_nowait(payload)

    def _notify_listeners_threadsafe(self, model_name: str, progress_data: Dict):
        """Deliver a copy of ``progress_data`` to every listener of ``model_name``.

        Safe to call from the event loop thread or from a background thread.
        Failures while notifying one listener are logged and do not affect
        the remaining listeners or the caller.
        """
        # Iterate over a snapshot so concurrent subscribe/unsubscribe calls
        # cannot change the list while we walk it.
        with self._lock:
            queues = list(self._listeners.get(model_name, ()))
        if not queues:
            return

        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            # Not in async context (running in background thread)
            running_loop = None

        main_loop = self._main_loop
        # Touch the queues directly only when we are on the loop that owns them.
        on_owning_loop = running_loop is not None and (
            main_loop is None or running_loop is main_loop
        )

        for queue in queues:
            payload = progress_data.copy()
            try:
                if on_owning_loop:
                    self._enqueue_latest(queue, payload, model_name)
                elif main_loop is not None and main_loop.is_running():
                    # asyncio.Queue is not thread-safe: hand over to the loop.
                    main_loop.call_soon_threadsafe(
                        self._enqueue_latest, queue, payload, model_name
                    )
                else:
                    self._logger.debug(
                        "No main loop available for %s, skipping notification", model_name
                    )
            except Exception as e:
                # Best-effort delivery: a broken listener must not abort the download.
                self._logger.warning("Error notifying listener for %s: %s", model_name, e)

    def _should_notify(self, model_name: str, status: str, progress_pct: float) -> bool:
        """Decide whether an update passes the throttle and record it if so.

        The caller must hold ``self._lock``. An update passes when it is the
        first one for the model, carries a terminal status, arrives at least
        THROTTLE_INTERVAL_SECONDS after the last notified one, or moves the
        percentage by at least THROTTLE_PROGRESS_DELTA.
        """
        # Monotonic clock: the interval must not be affected by wall-clock jumps.
        now = time.monotonic()
        last_time = self._last_notify_time.get(model_name)
        last_pct = self._last_notify_progress.get(model_name)

        due = (
            status in self.TERMINAL_STATUSES
            or last_time is None
            or last_pct is None
            or now - last_time >= self.THROTTLE_INTERVAL_SECONDS
            or abs(progress_pct - last_pct) >= self.THROTTLE_PROGRESS_DELTA
        )
        if due:
            self._last_notify_time[model_name] = now
            self._last_notify_progress[model_name] = progress_pct
        return due

    def update_progress(
        self,
        model_name: str,
        current: int,
        total: int,
        filename: Optional[str] = None,
        status: str = "downloading",
    ):
        """
        Update progress for a model download.

        Thread-safe: can be called from background threads.

        The stored snapshot is always replaced. Listener notifications are
        throttled to prevent overwhelming SSE clients: they are sent at most
        every THROTTLE_INTERVAL_SECONDS, or when progress changes by at least
        THROTTLE_PROGRESS_DELTA percent. "complete" and "error" updates are
        never throttled.

        Args:
            model_name: Name of the model (e.g., "qwen-tts-1.7B", "whisper-base")
            current: Current bytes downloaded
            total: Total bytes to download
            filename: Current file being downloaded
            status: Status string (downloading, extracting, complete, error)
        """
        # Calculate progress percentage, clamped to 0-100 range.
        # This prevents crazy percentages from edge cases like:
        # - current > total temporarily during aggregation
        # - mixing file-count progress with byte-count progress
        if total > 0:
            progress_pct = min(100.0, max(0.0, current / total * 100))
        else:
            progress_pct = 0

        progress_data = {
            "model_name": model_name,
            "current": current,
            "total": total,
            "progress": progress_pct,
            "filename": filename,
            "status": status,
            "timestamp": datetime.now().isoformat(),
        }

        with self._lock:
            # Always update internal state, even when the notification is throttled.
            self._progress[model_name] = progress_data
            if not self._should_notify(model_name, status, progress_pct):
                return  # Skip this update (throttled)
            listener_count = len(self._listeners.get(model_name, ()))
            # Copy under the lock: mark_complete/mark_error mutate the stored dict.
            payload = progress_data.copy()

        if listener_count > 0:
            self._logger.debug(
                "Notifying %d listeners for %s: %.1f%% (%s)",
                listener_count, model_name, progress_pct, filename,
            )
            self._notify_listeners_threadsafe(model_name, payload)
        else:
            self._logger.debug(
                "No listeners for %s, progress update stored: %.1f%%",
                model_name, progress_pct,
            )

    def get_progress(self, model_name: str) -> Optional[Dict]:
        """Return a copy of the stored progress for a model, or None. Thread-safe."""
        with self._lock:
            progress = self._progress.get(model_name)
            return progress.copy() if progress else None

    def get_all_active(self) -> List[Dict]:
        """Get all active downloads (status is 'downloading' or 'extracting'). Thread-safe."""
        with self._lock:
            return [
                progress.copy()
                for progress in self._progress.values()
                if progress.get("status", "") in self.ACTIVE_STATUSES
            ]

    def create_progress_callback(
        self, model_name: str, filename: Optional[str] = None
    ) -> Callable[[Dict], None]:
        """
        Create a progress callback function for HuggingFace downloads.

        Args:
            model_name: Name of the model
            filename: Fallback filename used when the progress dict has no
                "filename" key

        Returns:
            Callback taking a progress dict. Dicts lacking "total" or
            "current" are ignored; others are forwarded to update_progress
            with status "downloading".
        """
        def callback(progress: Dict) -> None:
            """HuggingFace Hub progress callback."""
            if "total" not in progress or "current" not in progress:
                return
            self.update_progress(
                model_name=model_name,
                current=progress["current"],
                total=progress["total"],
                filename=progress.get("filename", filename),
                status="downloading",
            )

        return callback

    async def subscribe(self, model_name: str):
        """
        Subscribe to progress updates for a model.

        Async generator yielding Server-Sent Events strings: ``data: <json>``
        frames for progress updates and ``: heartbeat`` comments while idle.
        The stream ends after a "complete" or "error" update has been sent.
        The listener is always unregistered when the generator exits.

        Raises:
            asyncio.CancelledError: re-raised after cleanup when the
                consuming task is cancelled (e.g. on client disconnect).
        """
        logger = self._logger

        # Remember the loop that owns the listener queues so that background
        # threads can hand updates over with call_soon_threadsafe.
        self._main_loop = asyncio.get_running_loop()

        queue: asyncio.Queue = asyncio.Queue(maxsize=self.LISTENER_QUEUE_SIZE)

        # Register the listener and take the initial snapshot under the lock.
        with self._lock:
            listeners = self._listeners.setdefault(model_name, [])
            listeners.append(queue)
            listener_total = len(listeners)
            stored = self._progress.get(model_name)
            initial_progress = stored.copy() if stored else None

        logger.info(
            "SSE client subscribed to %s, total listeners: %d", model_name, listener_total
        )

        try:
            if initial_progress:
                status = initial_progress.get("status")
                # Only send initial progress if download is actually in progress.
                # Don't send old 'complete' or 'error' status from previous downloads.
                if status in self.ACTIVE_STATUSES:
                    logger.info("Sending initial progress for %s: %s", model_name, status)
                    yield f"data: {json.dumps(initial_progress)}\n\n"
                else:
                    logger.info(
                        "Skipping initial progress for %s (status: %s)", model_name, status
                    )
            else:
                logger.info("No initial progress available for %s", model_name)

            # Stream updates
            while True:
                try:
                    progress = await asyncio.wait_for(
                        queue.get(), timeout=self.HEARTBEAT_INTERVAL_SECONDS
                    )
                except asyncio.TimeoutError:
                    # Nothing new: send an SSE comment as heartbeat
                    yield ": heartbeat\n\n"
                    continue

                logger.debug(
                    "Sending progress update for %s: %s - %.1f%%",
                    model_name, progress.get("status"), progress.get("progress", 0),
                )
                yield f"data: {json.dumps(progress)}\n\n"

                # Stop if complete or error
                if progress.get("status") in self.TERMINAL_STATUSES:
                    logger.info(
                        "Download %s for %s, closing SSE connection",
                        progress.get("status"), model_name,
                    )
                    break
        except asyncio.CancelledError:
            logger.debug("SSE client disconnected from %s", model_name)
            # Cancellation must propagate; swallowing it would hide the
            # cancel request from the task that is consuming this generator.
            raise
        except (BrokenPipeError, ConnectionResetError):
            logger.debug("SSE client disconnected from %s", model_name)
        finally:
            # Remove from listeners
            with self._lock:
                registered = self._listeners.get(model_name)
                if registered is not None:
                    if queue in registered:
                        registered.remove(queue)
                    if not registered:
                        del self._listeners[model_name]
                remaining = len(self._listeners.get(model_name, ()))
            logger.info(
                "SSE client unsubscribed from %s, remaining listeners: %d",
                model_name, remaining,
            )

    def mark_complete(self, model_name: str):
        """Mark a model download as complete and notify listeners. Thread-safe.

        Does nothing (besides logging a warning) when no progress has been
        recorded for ``model_name``.
        """
        with self._lock:
            entry = self._progress.get(model_name)
            if entry is not None:
                entry["status"] = "complete"
                entry["progress"] = 100.0
                progress_data = entry.copy()
            else:
                progress_data = None

        if progress_data is None:
            self._logger.warning(
                "Cannot mark %s as complete: not found in progress", model_name
            )
            return

        self._logger.info("Marked %s as complete", model_name)
        # Terminal updates bypass the throttle and go straight to listeners.
        self._notify_listeners_threadsafe(model_name, progress_data)

    def mark_error(self, model_name: str, error: str):
        """Mark a model download as failed and notify listeners. Thread-safe.

        When no progress has been recorded for ``model_name`` a fresh entry
        with zeroed counters is created so the error can still be reported.
        """
        with self._lock:
            entry = self._progress.get(model_name)
            if entry is not None:
                entry["status"] = "error"
                entry["error"] = error
                progress_data = entry.copy()
            else:
                # Create new progress entry for error
                progress_data = {
                    "model_name": model_name,
                    "current": 0,
                    "total": 0,
                    "progress": 0,
                    "filename": None,
                    "status": "error",
                    "error": error,
                    "timestamp": datetime.now().isoformat(),
                }
                self._progress[model_name] = progress_data

        self._logger.error("Marked %s as error: %s", model_name, error)
        # Terminal updates bypass the throttle and go straight to listeners.
        self._notify_listeners_threadsafe(model_name, progress_data)
|
||||
|
||||
|
||||
# Process-wide ProgressManager singleton; created on first use by
# get_progress_manager().
_progress_manager: Optional[ProgressManager] = None


def get_progress_manager() -> ProgressManager:
    """Return the shared ProgressManager, instantiating it on the first call."""
    global _progress_manager
    manager = _progress_manager
    if manager is None:
        manager = ProgressManager()
        _progress_manager = manager
    return manager
|
||||
Reference in New Issue
Block a user