Initial commit
This commit is contained in:
475
backend/routes/models.py
Normal file
475
backend/routes/models.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""Model management endpoints."""
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import models
|
||||
from ..utils.platform_detect import get_backend_type
|
||||
from ..services.task_queue import create_background_task
|
||||
from ..utils.progress import get_progress_manager
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _get_dir_size(path: Path) -> int:
|
||||
"""Get total size of a directory in bytes."""
|
||||
total = 0
|
||||
for f in path.rglob("*"):
|
||||
if f.is_file():
|
||||
total += f.stat().st_size
|
||||
return total
|
||||
|
||||
|
||||
def _copy_with_progress(src: Path, dst: Path, progress_manager, copied_so_far: int, total_bytes: int) -> int:
|
||||
"""Copy a directory tree with byte-level progress tracking."""
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for item in src.iterdir():
|
||||
dest_item = dst / item.name
|
||||
if item.is_dir():
|
||||
copied_so_far = _copy_with_progress(item, dest_item, progress_manager, copied_so_far, total_bytes)
|
||||
else:
|
||||
size = item.stat().st_size
|
||||
shutil.copy2(str(item), str(dest_item))
|
||||
copied_so_far += size
|
||||
progress_manager.update_progress(
|
||||
"migration",
|
||||
copied_so_far,
|
||||
total_bytes,
|
||||
filename=item.name,
|
||||
status="downloading",
|
||||
)
|
||||
return copied_so_far
|
||||
|
||||
|
||||
@router.post("/models/load")
async def load_model(model_size: str = "1.7B"):
    """Manually load a TTS model of the given size (default "1.7B").

    Raises:
        HTTPException: 500 with the underlying error message if loading fails.
    """
    # Local import avoids importing the TTS service at module load time.
    from ..services import tts

    try:
        tts_model = tts.get_tts_model()
        await tts_model.load_model_async(model_size)
        return {"message": f"Model {model_size} loaded successfully"}
    except Exception as e:
        # Chain the original exception (PEP 3134) so the real cause is
        # preserved in tracebacks, matching unload_model_by_name's style.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/models/unload")
async def unload_model():
    """Unload the default Qwen TTS model to free memory.

    Raises:
        HTTPException: 500 with the underlying error message if unloading fails.
    """
    # Local import avoids importing the TTS service at module load time.
    from ..services import tts

    try:
        tts.unload_tts_model()
        return {"message": "Model unloaded successfully"}
    except Exception as e:
        # Chain the original exception (PEP 3134) so the real cause is
        # preserved in tracebacks, matching unload_model_by_name's style.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post("/models/{model_name}/unload")
async def unload_model_by_name(model_name: str):
    """Unload a specific model from memory without deleting it from disk."""
    from ..backends import get_model_config, unload_model_by_config

    config = get_model_config(model_name)
    if not config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")

    try:
        if unload_model_by_config(config):
            return {"message": f"Model {model_name} unloaded successfully"}
        # The registry knows the model but it was never loaded into memory.
        return {"message": f"Model {model_name} is not loaded"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/models/progress/{model_name}")
async def get_model_progress(model_name: str):
    """Get model download progress via Server-Sent Events."""
    manager = get_progress_manager()

    async def stream():
        # Relay every event published for this model to the client.
        async for chunk in manager.subscribe(model_name):
            yield chunk

    # Standard SSE headers; X-Accel-Buffering disables nginx proxy buffering.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(stream(), media_type="text/event-stream", headers=sse_headers)
|
||||
|
||||
|
||||
@router.get("/models/cache-dir")
async def get_models_cache_dir():
    """Get the path to the HuggingFace model cache directory."""
    from huggingface_hub import constants as hf_constants

    cache_path = Path(hf_constants.HF_HUB_CACHE)
    return {"path": str(cache_path)}
|
||||
|
||||
|
||||
@router.post("/models/migrate")
async def migrate_models(request: models.ModelMigrateRequest):
    """Move all downloaded models to a new directory with byte-level progress via SSE.

    Validates the destination, then starts a background task that either
    moves (same filesystem) or copies-then-deletes (cross filesystem) every
    HF cache ``models--*`` directory.  Progress is published on the
    "migration" channel consumed by ``GET /models/migrate/progress``.
    """
    from huggingface_hub import constants as hf_constants

    source = Path(hf_constants.HF_HUB_CACHE)
    destination = Path(request.destination)

    if not source.exists():
        raise HTTPException(status_code=404, detail="Current model cache directory not found")

    if source.resolve() == destination.resolve():
        raise HTTPException(status_code=400, detail="Source and destination are the same directory")

    # Guard against recursive copies (destination nested inside source).
    if destination.resolve().is_relative_to(source.resolve()):
        raise HTTPException(status_code=400, detail="Destination cannot be inside the current cache directory")

    progress_manager = get_progress_manager()
    # The HF hub cache stores each repo under a "models--<org>--<name>" directory.
    model_dirs = [d for d in source.iterdir() if d.name.startswith("models--") and d.is_dir()]
    if not model_dirs:
        # Nothing to migrate: emit an instant "complete" so SSE subscribers close.
        progress_manager.update_progress("migration", 1, 1, status="complete")
        progress_manager.mark_complete("migration")
        return {"moved": 0, "errors": [], "source": str(source), "destination": str(destination)}

    destination.mkdir(parents=True, exist_ok=True)

    # Same-device check: on one filesystem shutil.move is a cheap rename,
    # while crossing filesystems requires a full byte-for-byte copy.
    same_fs = False
    try:
        same_fs = source.stat().st_dev == destination.stat().st_dev
    except OSError:
        pass

    async def migrate_background():
        # Runs via create_background_task: there is no caller to catch
        # exceptions, so failures are reported through the progress manager.
        moved = 0
        errors = []
        try:
            if same_fs:
                # Fast path: per-directory progress (moves are near-instant renames).
                total = len(model_dirs)
                for i, item in enumerate(model_dirs):
                    dest_item = destination / item.name
                    try:
                        # Replace any stale copy at the destination.
                        if dest_item.exists():
                            shutil.rmtree(dest_item)
                        shutil.move(str(item), str(dest_item))
                        moved += 1
                        progress_manager.update_progress(
                            "migration",
                            i + 1,
                            total,
                            filename=item.name,
                            status="downloading",
                        )
                    except Exception as e:
                        # Collect per-model failures and continue with the rest.
                        errors.append(f"{item.name}: {str(e)}")
            else:
                # Slow path: copy with byte-level progress, then delete the source.
                total_bytes = sum(_get_dir_size(d) for d in model_dirs)
                progress_manager.update_progress(
                    "migration", 0, total_bytes, filename="Calculating...", status="downloading"
                )

                copied = 0
                for item in model_dirs:
                    dest_item = destination / item.name
                    try:
                        if dest_item.exists():
                            shutil.rmtree(dest_item)
                        # Run blocking file I/O in worker threads so the
                        # event loop stays responsive for SSE subscribers.
                        copied = await asyncio.to_thread(
                            _copy_with_progress, item, dest_item, progress_manager, copied, total_bytes
                        )
                        await asyncio.to_thread(shutil.rmtree, str(item))
                        moved += 1
                    except Exception as e:
                        errors.append(f"{item.name}: {str(e)}")

            progress_manager.update_progress("migration", 1, 1, status="complete")
            progress_manager.mark_complete("migration")
        except Exception as e:
            progress_manager.update_progress("migration", 0, 0, status="error")
            progress_manager.mark_error("migration", str(e))

    create_background_task(migrate_background())

    # Return immediately; clients track completion via the SSE endpoint.
    return {"source": str(source), "destination": str(destination)}
|
||||
|
||||
|
||||
@router.get("/models/migrate/progress")
async def get_migration_progress():
    """Get model migration progress via Server-Sent Events."""
    manager = get_progress_manager()

    async def stream():
        # Relay every event published on the "migration" channel.
        async for chunk in manager.subscribe("migration"):
            yield chunk

    # Standard SSE headers; X-Accel-Buffering disables nginx proxy buffering.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(stream(), media_type="text/event-stream", headers=sse_headers)
|
||||
|
||||
|
||||
@router.get("/models/status", response_model=models.ModelStatusListResponse)
async def get_model_status():
    """Get status of all available models.

    For every model in the backend registry, reports whether it is
    downloaded (weights present in the HF cache with no partial blobs),
    currently downloading, its on-disk size in MB, and whether it is
    loaded in memory.
    """
    from huggingface_hub import constants as hf_constants

    # NOTE(review): backend_type is not used anywhere below — confirm
    # whether it can be dropped or was meant to filter the registry.
    backend_type = get_backend_type()
    task_manager = get_task_manager()

    active_download_names = {task.model_name for task in task_manager.get_active_downloads()}

    # scan_cache_dir is preferred for cache inspection but may be missing
    # in older huggingface_hub versions; fall back to manual scanning.
    try:
        from huggingface_hub import scan_cache_dir

        use_scan_cache = True
    except ImportError:
        use_scan_cache = False

    from ..backends import get_all_model_configs, check_model_loaded

    registry_configs = get_all_model_configs()
    # Flatten registry entries into dicts; the lambda default binds each
    # cfg at definition time (avoids the late-binding closure pitfall).
    model_configs = [
        {
            "model_name": cfg.model_name,
            "display_name": cfg.display_name,
            "hf_repo_id": cfg.hf_repo_id,
            "model_size": cfg.model_size,
            "check_loaded": lambda c=cfg: check_model_loaded(c),
        }
        for cfg in registry_configs
    ]

    # Active downloads are tracked by model name; map them to repo ids so
    # they can be matched against each config below.
    model_to_repo = {cfg["model_name"]: cfg["hf_repo_id"] for cfg in model_configs}
    active_download_repos = {model_to_repo.get(name) for name in active_download_names if name in model_to_repo}

    cache_info = None
    if use_scan_cache:
        try:
            cache_info = scan_cache_dir()
        except Exception:
            # Cache scan is best-effort; fall through to the manual path.
            pass

    statuses = []

    for config in model_configs:
        try:
            downloaded = False
            size_mb = None
            loaded = False

            # Primary detection path: use scan_cache_dir() results.
            if cache_info:
                repo_id = config["hf_repo_id"]
                for repo in cache_info.repos:
                    if repo.repo_id == repo_id:
                        # A repo only counts as downloaded if at least one
                        # revision contains actual model weight files.
                        has_model_weights = False
                        for rev in repo.revisions:
                            for f in rev.files:
                                fname = f.file_name.lower()
                                if fname.endswith((".safetensors", ".bin", ".pt", ".pth", ".npz")):
                                    has_model_weights = True
                                    break
                            if has_model_weights:
                                break

                        # A "*.incomplete" blob means a download was interrupted.
                        has_incomplete = False
                        try:
                            cache_dir = hf_constants.HF_HUB_CACHE
                            blobs_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--")) / "blobs"
                            if blobs_dir.exists():
                                has_incomplete = any(blobs_dir.glob("*.incomplete"))
                        except Exception:
                            pass

                        if has_model_weights and not has_incomplete:
                            downloaded = True
                            try:
                                total_size = sum(revision.size_on_disk for revision in repo.revisions)
                                size_mb = total_size / (1024 * 1024)
                            except Exception:
                                pass
                        break

            # Fallback detection path: inspect the cache layout directly
            # (used when scan_cache_dir is unavailable or found nothing).
            if not downloaded:
                try:
                    cache_dir = hf_constants.HF_HUB_CACHE
                    repo_cache = Path(cache_dir) / ("models--" + config["hf_repo_id"].replace("/", "--"))

                    if repo_cache.exists():
                        blobs_dir = repo_cache / "blobs"
                        has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))

                        if not has_incomplete:
                            snapshots_dir = repo_cache / "snapshots"
                            has_model_files = False
                            if snapshots_dir.exists():
                                # Same weight-file extensions as the primary path.
                                has_model_files = (
                                    any(snapshots_dir.rglob("*.bin"))
                                    or any(snapshots_dir.rglob("*.safetensors"))
                                    or any(snapshots_dir.rglob("*.pt"))
                                    or any(snapshots_dir.rglob("*.pth"))
                                    or any(snapshots_dir.rglob("*.npz"))
                                )

                            if has_model_files:
                                downloaded = True
                                try:
                                    # Size everything except interrupted blobs.
                                    total_size = sum(
                                        f.stat().st_size
                                        for f in repo_cache.rglob("*")
                                        if f.is_file() and not f.name.endswith(".incomplete")
                                    )
                                    size_mb = total_size / (1024 * 1024)
                                except Exception:
                                    pass
                except Exception:
                    pass

            try:
                loaded = config["check_loaded"]()
            except Exception:
                loaded = False

            is_downloading = config["hf_repo_id"] in active_download_repos

            # An in-flight download overrides any stale cache contents.
            if is_downloading:
                downloaded = False
                size_mb = None

            statuses.append(
                models.ModelStatus(
                    model_name=config["model_name"],
                    display_name=config["display_name"],
                    hf_repo_id=config["hf_repo_id"],
                    downloaded=downloaded,
                    downloading=is_downloading,
                    size_mb=size_mb,
                    loaded=loaded,
                )
            )
        except Exception:
            # Any unexpected failure for one model must not break the whole
            # listing: report it as not-downloaded with best-effort fields.
            try:
                loaded = config["check_loaded"]()
            except Exception:
                loaded = False

            is_downloading = config["hf_repo_id"] in active_download_repos

            statuses.append(
                models.ModelStatus(
                    model_name=config["model_name"],
                    display_name=config["display_name"],
                    hf_repo_id=config["hf_repo_id"],
                    downloaded=False,
                    downloading=is_downloading,
                    size_mb=None,
                    loaded=loaded,
                )
            )

    return models.ModelStatusListResponse(models=statuses)
|
||||
|
||||
|
||||
@router.post("/models/download")
async def trigger_model_download(request: models.ModelDownloadRequest):
    """Trigger download of a specific model."""
    from ..backends import get_model_config, get_model_load_func

    name = request.model_name
    tasks = get_task_manager()
    progress = get_progress_manager()

    config = get_model_config(name)
    if not config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {name}")

    loader = get_model_load_func(config)

    async def run_download():
        # Loader may be sync or async; await only when it returns a coroutine.
        try:
            maybe_coro = loader()
            if asyncio.iscoroutine(maybe_coro):
                await maybe_coro
            tasks.complete_download(name)
        except Exception as e:
            tasks.error_download(name, str(e))

    # Register the task before publishing the first progress event so SSE
    # subscribers never see progress for an unknown task.
    tasks.start_download(name)
    progress.update_progress(
        model_name=name,
        current=0,
        total=0,
        filename="Connecting to HuggingFace...",
        status="downloading",
    )
    create_background_task(run_download())
    return {"message": f"Model {name} download started"}
|
||||
|
||||
|
||||
@router.post("/models/download/cancel")
async def cancel_model_download(request: models.ModelDownloadRequest):
    """Cancel or dismiss an errored/stale download task."""
    tasks = get_task_manager()
    progress = get_progress_manager()

    task_removed = tasks.cancel_download(request.model_name)

    # NOTE(review): this reaches into the progress manager's private
    # _lock/_progress state; a public remove() API would be cleaner.
    entry_removed = False
    with progress._lock:
        if request.model_name in progress._progress:
            del progress._progress[request.model_name]
            entry_removed = True

    if task_removed or entry_removed:
        return {"message": f"Download task for {request.model_name} cancelled"}
    return {"message": f"No active task found for {request.model_name}"}
|
||||
|
||||
|
||||
@router.delete("/models/{model_name}")
async def delete_model(model_name: str):
    """Delete a downloaded model from the HuggingFace cache.

    Unloads the model from memory first, then removes its cache directory
    (``models--<org>--<name>``) from the HF hub cache.

    Raises:
        HTTPException: 400 for an unknown model name, 404 when the model
            is not present in the cache, 500 when deletion fails.
    """
    from huggingface_hub import constants as hf_constants
    from ..backends import get_model_config, unload_model_by_config

    config = get_model_config(model_name)
    if not config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")

    hf_repo_id = config.hf_repo_id

    try:
        # Release in-memory resources before touching files on disk.
        unload_model_by_config(config)

        cache_dir = hf_constants.HF_HUB_CACHE
        repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))

        if not repo_cache_dir.exists():
            raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")

        try:
            shutil.rmtree(repo_cache_dir)
        except OSError as e:
            # Chain the OSError (PEP 3134) so the original filesystem
            # failure survives into the traceback.
            raise HTTPException(status_code=500, detail=f"Failed to delete model cache directory: {str(e)}") from e

        return {"message": f"Model {model_name} deleted successfully"}

    except HTTPException:
        # Re-raise our own HTTP errors unchanged (don't wrap them as 500s).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") from e
|
||||
Reference in New Issue
Block a user