103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
"""Track in-progress model downloads and generation tasks in memory."""
|
|
|
|
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Optional
|
|
|
|
|
|
@dataclass
class DownloadTask:
    """Represents an active download task.

    ``started_at`` is a *naive* datetime holding the current UTC time at
    construction, the same kind of value ``datetime.utcnow()`` returned.
    """

    # Name of the model being downloaded (also the TaskManager registry key).
    model_name: str
    # One of: downloading, extracting, complete, error.
    status: str = "downloading"
    # Naive UTC timestamp. datetime.utcnow() is deprecated since Python 3.12,
    # so take an aware UTC "now" and drop tzinfo to keep the old naive value.
    started_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    # Failure message; stays None unless the download is marked as failed.
    error: Optional[str] = None
|
|
|
|
|
|
@dataclass
class GenerationTask:
    """Represents an active generation task.

    ``started_at`` is a *naive* datetime holding the current UTC time at
    construction, the same kind of value ``datetime.utcnow()`` returned.
    """

    # Identifier of the generation (also the TaskManager registry key).
    task_id: str
    # Profile the generation was requested for.
    profile_id: str
    # Shortened request text: TaskManager.start_generation stores the first
    # 50 characters, followed by "..." when the text is longer than that.
    text_preview: str
    # Naive UTC timestamp. datetime.utcnow() is deprecated since Python 3.12,
    # so take an aware UTC "now" and drop tzinfo to keep the old naive value.
    started_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
|
|
|
|
|
|
class TaskManager:
    """In-memory registries of in-flight downloads and generations.

    Downloads are keyed by model name and generations by task id. Finished
    work is removed from its registry. A failed download is kept with
    ``status="error"``, so it keeps appearing in ``get_active_downloads()``
    until it is cancelled or everything is cleared.

    NOTE(review): no locking is done here; if several threads mutate the
    registries concurrently, confirm that is safe for the callers involved.
    """

    def __init__(self):
        self._active_downloads: Dict[str, DownloadTask] = {}
        self._active_generations: Dict[str, GenerationTask] = {}

    # ------------------------------------------------------------ downloads

    def start_download(self, model_name: str) -> None:
        """Register a fresh "downloading" task for ``model_name``.

        Any entry already stored under that name (including a failed one)
        is overwritten.
        """
        fresh = DownloadTask(model_name=model_name, status="downloading")
        self._active_downloads[model_name] = fresh

    def complete_download(self, model_name: str) -> None:
        """Forget the download for ``model_name``; unknown names are ignored."""
        self._active_downloads.pop(model_name, None)

    def error_download(self, model_name: str, error: str) -> None:
        """Flag a tracked download as failed and remember ``error``.

        The entry is kept (not removed) so the failure stays visible.
        Untracked model names are silently ignored.
        """
        if model_name not in self._active_downloads:
            return
        failed = self._active_downloads[model_name]
        failed.status = "error"
        failed.error = error

    def cancel_download(self, model_name: str) -> bool:
        """Dismiss a download entry.

        Returns:
            True if an entry was removed, False if nothing was tracked.
        """
        removed = self._active_downloads.pop(model_name, None)
        return removed is not None

    def is_download_active(self, model_name: str) -> bool:
        """Tell whether ``model_name`` currently has a tracked download."""
        return model_name in self._active_downloads

    def get_active_downloads(self) -> List[DownloadTask]:
        """Return a snapshot list of every tracked download (failed ones too)."""
        return [*self._active_downloads.values()]

    # ---------------------------------------------------------- generations

    def start_generation(self, task_id: str, profile_id: str, text: str) -> None:
        """Register a generation, storing a shortened preview of ``text``.

        Texts longer than 50 characters are cut to 50 and suffixed "...".
        """
        if len(text) > 50:
            preview = text[:50] + "..."
        else:
            preview = text
        self._active_generations[task_id] = GenerationTask(
            task_id=task_id,
            profile_id=profile_id,
            text_preview=preview,
        )

    def complete_generation(self, task_id: str) -> None:
        """Forget the generation ``task_id``; unknown ids are ignored."""
        self._active_generations.pop(task_id, None)

    def is_generation_active(self, task_id: str) -> bool:
        """Tell whether ``task_id`` currently has a tracked generation."""
        return task_id in self._active_generations

    def get_active_generations(self) -> List[GenerationTask]:
        """Return a snapshot list of every tracked generation."""
        return [*self._active_generations.values()]

    # -------------------------------------------------------------- shared

    def clear_all(self) -> None:
        """Empty both registries (downloads first, then generations)."""
        for registry in (self._active_downloads, self._active_generations):
            registry.clear()
|
|
|
|
|
|
# Module-level singleton slot; filled in by the first get_task_manager() call.
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Return the module-wide TaskManager, creating it lazily.

    The first call builds the instance and caches it in ``_task_manager``.
    Every later call returns that same object.

    NOTE(review): creation is not lock-protected, so two threads racing on
    the very first call could each build an instance. Confirm the first call
    happens before any worker threads start.
    """
    global _task_manager
    if _task_manager is not None:
        return _task_manager
    _task_manager = TaskManager()
    return _task_manager
|