476 lines
17 KiB
Python
476 lines
17 KiB
Python
"""Model management endpoints."""
|
|
|
|
import asyncio
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .. import models
|
|
from ..utils.platform_detect import get_backend_type
|
|
from ..services.task_queue import create_background_task
|
|
from ..utils.progress import get_progress_manager
|
|
from ..utils.tasks import get_task_manager
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
def _get_dir_size(path: Path) -> int:
    """Return the combined size in bytes of every file found under ``path``.

    The tree is walked recursively with ``Path.rglob``; entries for which
    ``is_file()`` is false (directories and the like) contribute nothing.
    """
    return sum(entry.stat().st_size for entry in path.rglob("*") if entry.is_file())
|
|
|
|
|
|
def _copy_with_progress(src: Path, dst: Path, progress_manager, copied_so_far: int, total_bytes: int) -> int:
    """Recursively copy ``src`` into ``dst``, reporting progress after each file.

    Args:
        src: Directory whose contents are copied.
        dst: Target directory; created (with parents) when missing.
        progress_manager: Object exposing ``update_progress``; it is called
            once per copied file under the ``"migration"`` key.
        copied_so_far: Bytes already copied before this call.
        total_bytes: Overall byte count reported as the progress total.

    Returns:
        The running byte count after everything under ``src`` was copied.
    """
    dst.mkdir(parents=True, exist_ok=True)

    running_total = copied_so_far
    for entry in src.iterdir():
        target = dst / entry.name

        if entry.is_dir():
            # Descend first; the recursive call carries the running total along.
            running_total = _copy_with_progress(entry, target, progress_manager, running_total, total_bytes)
            continue

        # Size is read before the copy, exactly as the progress delta.
        entry_bytes = entry.stat().st_size
        shutil.copy2(str(entry), str(target))
        running_total += entry_bytes
        progress_manager.update_progress(
            "migration",
            running_total,
            total_bytes,
            filename=entry.name,
            status="downloading",
        )

    return running_total
|
|
|
|
|
|
@router.post("/models/load")
async def load_model(model_size: str = "1.7B"):
    """Manually load the TTS model.

    Args:
        model_size: Size identifier passed through to ``load_model_async``.

    Returns:
        A dict whose ``message`` confirms which model size was loaded.

    Raises:
        HTTPException: 500 carrying the underlying error text when obtaining
            or loading the model fails.
    """
    from ..services import tts

    try:
        tts_model = tts.get_tts_model()
        await tts_model.load_model_async(model_size)
    except Exception as e:
        # Chain the cause (as unload_model_by_name does) so the original
        # traceback is preserved alongside the HTTP error.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": f"Model {model_size} loaded successfully"}
|
|
|
|
|
|
@router.post("/models/unload")
async def unload_model():
    """Unload the default Qwen TTS model to free memory.

    Returns:
        A dict whose ``message`` confirms the unload.

    Raises:
        HTTPException: 500 carrying the underlying error text when
            ``unload_tts_model`` fails.
    """
    from ..services import tts

    try:
        tts.unload_tts_model()
    except Exception as e:
        # Chain the cause (as unload_model_by_name does) so the original
        # traceback is preserved alongside the HTTP error.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Model unloaded successfully"}
|
|
|
|
|
|
@router.post("/models/{model_name}/unload")
async def unload_model_by_name(model_name: str):
    """Release one named model from memory; nothing is removed from disk here.

    Raises:
        HTTPException: 400 when ``get_model_config`` has no entry for
            ``model_name``; 500 when the unload call itself fails.
    """
    from ..backends import get_model_config, unload_model_by_config

    model_config = get_model_config(model_name)
    if not model_config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")

    try:
        # A falsy result from the backend means there was nothing to unload.
        if unload_model_by_config(model_config):
            outcome = f"Model {model_name} unloaded successfully"
        else:
            outcome = f"Model {model_name} is not loaded"
        return {"message": outcome}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.get("/models/progress/{model_name}")
async def get_model_progress(model_name: str):
    """Stream download progress for ``model_name`` as Server-Sent Events."""
    manager = get_progress_manager()

    async def relay_events():
        # Forward every event the progress manager publishes for this model.
        async for sse_event in manager.subscribe(model_name):
            yield sse_event

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        relay_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
@router.get("/models/cache-dir")
async def get_models_cache_dir():
    """Report the HuggingFace hub cache directory as a ``{"path": ...}`` dict."""
    from huggingface_hub import constants as hf_constants

    cache_path = Path(hf_constants.HF_HUB_CACHE)
    return {"path": str(cache_path)}
|
|
|
|
|
|
@router.post("/models/migrate")
async def migrate_models(request: models.ModelMigrateRequest):
    """Move all downloaded models to a new directory with progress via SSE.

    Every ``models--*`` directory in the current HuggingFace hub cache is
    moved to ``request.destination``. The work runs in a background task and
    progress is published under the ``"migration"`` key of the progress
    manager.

    When source and destination report the same ``st_dev`` each model
    directory is moved with ``shutil.move`` and progress counts directories;
    otherwise directories are copied file by file (progress counts bytes) and
    the source is removed after a successful copy.

    Returns:
        ``{"source", "destination"}`` once the background task is scheduled,
        or a dict with ``moved`` = 0 and empty ``errors`` when there is
        nothing to migrate.

    Raises:
        HTTPException: 404 if the current cache directory does not exist;
            400 if the destination equals, or lies inside, the source.
    """
    from huggingface_hub import constants as hf_constants

    source = Path(hf_constants.HF_HUB_CACHE)
    destination = Path(request.destination)

    if not source.exists():
        raise HTTPException(status_code=404, detail="Current model cache directory not found")

    source_resolved = source.resolve()
    destination_resolved = destination.resolve()

    if source_resolved == destination_resolved:
        raise HTTPException(status_code=400, detail="Source and destination are the same directory")

    if destination_resolved.is_relative_to(source_resolved):
        raise HTTPException(status_code=400, detail="Destination cannot be inside the current cache directory")

    progress_manager = get_progress_manager()
    model_dirs = [d for d in source.iterdir() if d.name.startswith("models--") and d.is_dir()]
    if not model_dirs:
        progress_manager.update_progress("migration", 1, 1, status="complete")
        progress_manager.mark_complete("migration")
        return {"moved": 0, "errors": [], "source": str(source), "destination": str(destination)}

    destination.mkdir(parents=True, exist_ok=True)

    same_fs = False
    try:
        same_fs = source.stat().st_dev == destination.stat().st_dev
    except OSError:
        # Could not compare devices; keep same_fs False and use the copy path.
        pass

    def _remove_existing(target: Path) -> None:
        """Delete a directory already present at the destination."""
        if target.exists():
            shutil.rmtree(target)

    def _move_replacing(item: Path, target: Path) -> None:
        """Move ``item`` to ``target``, replacing whatever is there."""
        _remove_existing(target)
        shutil.move(str(item), str(target))

    def _total_size() -> int:
        """Sum the on-disk size of every model directory to migrate."""
        return sum(_get_dir_size(d) for d in model_dirs)

    async def migrate_background():
        errors = []
        try:
            if same_fs:
                total = len(model_dirs)
                for i, item in enumerate(model_dirs):
                    dest_item = destination / item.name
                    try:
                        # Filesystem work runs in a worker thread so the event
                        # loop (and the SSE progress stream) stays responsive.
                        await asyncio.to_thread(_move_replacing, item, dest_item)
                        progress_manager.update_progress(
                            "migration",
                            i + 1,
                            total,
                            filename=item.name,
                            status="downloading",
                        )
                    except Exception as e:
                        errors.append(f"{item.name}: {str(e)}")
            else:
                total_bytes = await asyncio.to_thread(_total_size)
                progress_manager.update_progress(
                    "migration", 0, total_bytes, filename="Calculating...", status="downloading"
                )

                copied = 0
                for item in model_dirs:
                    dest_item = destination / item.name
                    try:
                        await asyncio.to_thread(_remove_existing, dest_item)
                        copied = await asyncio.to_thread(
                            _copy_with_progress, item, dest_item, progress_manager, copied, total_bytes
                        )
                        # Only drop the source once its copy finished without error.
                        await asyncio.to_thread(shutil.rmtree, str(item))
                    except Exception as e:
                        errors.append(f"{item.name}: {str(e)}")

            if errors:
                # Surface per-model failures instead of reporting success.
                progress_manager.update_progress("migration", 0, 0, status="error")
                progress_manager.mark_error("migration", "; ".join(errors))
            else:
                progress_manager.update_progress("migration", 1, 1, status="complete")
                progress_manager.mark_complete("migration")
        except Exception as e:
            progress_manager.update_progress("migration", 0, 0, status="error")
            progress_manager.mark_error("migration", str(e))

    create_background_task(migrate_background())

    return {"source": str(source), "destination": str(destination)}
|
|
|
|
|
|
@router.get("/models/migrate/progress")
async def get_migration_progress():
    """Stream progress of the model migration as Server-Sent Events."""
    manager = get_progress_manager()

    async def relay_events():
        # Migration progress is published under the fixed "migration" key.
        async for sse_event in manager.subscribe("migration"):
            yield sse_event

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        relay_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
# File suffixes treated as model weight files when deciding "downloaded".
_MODEL_WEIGHT_SUFFIXES = (".safetensors", ".bin", ".pt", ".pth", ".npz")
_BYTES_PER_MB = 1024 * 1024


def _repo_cache_path(cache_dir, hf_repo_id: str) -> Path:
    """Build the ``models--<org>--<name>`` cache folder path for a repo id."""
    return Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))


def _has_incomplete_blobs(repo_cache: Path) -> bool:
    """Tell whether ``<repo_cache>/blobs`` holds any ``*.incomplete`` file."""
    blobs_dir = repo_cache / "blobs"
    return blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))


def _download_state_from_scan(cache_info, cache_dir, hf_repo_id: str):
    """Derive ``(downloaded, size_mb)`` from a ``scan_cache_dir`` result.

    Only the first repo whose ``repo_id`` matches is inspected. It counts as
    downloaded when some revision lists a weight file and the blobs folder
    has no ``*.incomplete`` entries.
    """
    for repo in cache_info.repos:
        if repo.repo_id != hf_repo_id:
            continue

        has_model_weights = any(
            f.file_name.lower().endswith(_MODEL_WEIGHT_SUFFIXES)
            for rev in repo.revisions
            for f in rev.files
        )

        try:
            has_incomplete = _has_incomplete_blobs(_repo_cache_path(cache_dir, hf_repo_id))
        except Exception:
            # Best effort: an unreadable blobs folder is treated as "no partials".
            has_incomplete = False

        if has_model_weights and not has_incomplete:
            size_mb = None
            try:
                size_mb = sum(rev.size_on_disk for rev in repo.revisions) / _BYTES_PER_MB
            except Exception:
                # Size is informational only; keep the model marked downloaded.
                pass
            return True, size_mb
        break

    return False, None


def _download_state_from_filesystem(cache_dir, hf_repo_id: str):
    """Derive ``(downloaded, size_mb)`` by inspecting the cache folder directly.

    The repo counts as downloaded when its folder exists, has no
    ``*.incomplete`` blobs, and ``snapshots/`` contains a weight file.
    """
    repo_cache = _repo_cache_path(cache_dir, hf_repo_id)
    if not repo_cache.exists() or _has_incomplete_blobs(repo_cache):
        return False, None

    snapshots_dir = repo_cache / "snapshots"
    if not snapshots_dir.exists():
        return False, None

    has_model_files = any(
        any(snapshots_dir.rglob("*" + suffix)) for suffix in _MODEL_WEIGHT_SUFFIXES
    )
    if not has_model_files:
        return False, None

    size_mb = None
    try:
        total_size = sum(
            f.stat().st_size
            for f in repo_cache.rglob("*")
            if f.is_file() and not f.name.endswith(".incomplete")
        )
        size_mb = total_size / _BYTES_PER_MB
    except Exception:
        # Size is informational only; keep the model marked downloaded.
        pass
    return True, size_mb


def _detect_download_state(cache_info, cache_dir, hf_repo_id: str):
    """Return ``(downloaded, size_mb)``, trying the cache scan before the disk walk."""
    downloaded, size_mb = False, None
    if cache_info:
        downloaded, size_mb = _download_state_from_scan(cache_info, cache_dir, hf_repo_id)

    if not downloaded:
        try:
            fs_downloaded, fs_size_mb = _download_state_from_filesystem(cache_dir, hf_repo_id)
        except Exception:
            # Filesystem probing is best effort; leave the model as not downloaded.
            fs_downloaded, fs_size_mb = False, None
        if fs_downloaded:
            downloaded, size_mb = True, fs_size_mb

    return downloaded, size_mb


@router.get("/models/status", response_model=models.ModelStatusListResponse)
async def get_model_status():
    """Get status of all available models.

    For every config returned by ``get_all_model_configs`` this reports
    whether the model is downloaded (with its size in MiB when known),
    currently downloading according to the task manager, and loaded
    according to ``check_model_loaded``. A model with an active download is
    always reported as not downloaded and without a size.

    Failures while probing a single model never fail the request: that model
    is reported as not downloaded instead.
    """
    from huggingface_hub import constants as hf_constants

    from ..backends import get_all_model_configs, check_model_loaded

    try:
        from huggingface_hub import scan_cache_dir
    except ImportError:
        # No scan helper available; rely on the filesystem walk only.
        scan_cache_dir = None

    task_manager = get_task_manager()
    active_download_names = {task.model_name for task in task_manager.get_active_downloads()}

    registry_configs = list(get_all_model_configs())
    model_to_repo = {cfg.model_name: cfg.hf_repo_id for cfg in registry_configs}
    active_download_repos = {
        model_to_repo[name] for name in active_download_names if name in model_to_repo
    }

    cache_info = None
    if scan_cache_dir is not None:
        try:
            cache_info = scan_cache_dir()
        except Exception:
            # Scan failure is not fatal; the filesystem walk below still runs.
            cache_info = None

    cache_dir = hf_constants.HF_HUB_CACHE
    statuses = []

    for cfg in registry_configs:
        is_downloading = cfg.hf_repo_id in active_download_repos

        try:
            loaded = check_model_loaded(cfg)
        except Exception:
            loaded = False

        downloaded, size_mb = False, None
        if not is_downloading:
            # An active download overrides the on-disk state, so skip probing then.
            try:
                downloaded, size_mb = _detect_download_state(cache_info, cache_dir, cfg.hf_repo_id)
            except Exception:
                downloaded, size_mb = False, None

        def build_status(is_downloaded, size):
            return models.ModelStatus(
                model_name=cfg.model_name,
                display_name=cfg.display_name,
                hf_repo_id=cfg.hf_repo_id,
                downloaded=is_downloaded,
                downloading=is_downloading,
                size_mb=size,
                loaded=loaded,
            )

        try:
            status = build_status(downloaded, size_mb)
        except Exception:
            # Fall back to the conservative "not downloaded" report.
            status = build_status(False, None)
        statuses.append(status)

    return models.ModelStatusListResponse(models=statuses)
|
|
|
|
|
|
@router.post("/models/download")
async def trigger_model_download(request: models.ModelDownloadRequest):
    """Start downloading the requested model in a background task.

    Raises:
        HTTPException: 400 when ``get_model_config`` has no entry for
            ``request.model_name``.
    """
    from ..backends import get_model_config, get_model_load_func

    task_manager = get_task_manager()
    progress_manager = get_progress_manager()

    model_config = get_model_config(request.model_name)
    if not model_config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}")

    loader = get_model_load_func(model_config)

    async def download_in_background():
        try:
            # The loader may be sync or async; await only when it hands back a coroutine.
            pending = loader()
            if asyncio.iscoroutine(pending):
                await pending
            task_manager.complete_download(request.model_name)
        except Exception as exc:
            task_manager.error_download(request.model_name, str(exc))

    # Register the task and publish an initial progress entry before scheduling.
    task_manager.start_download(request.model_name)
    progress_manager.update_progress(
        model_name=request.model_name,
        current=0,
        total=0,
        filename="Connecting to HuggingFace...",
        status="downloading",
    )

    create_background_task(download_in_background())

    return {"message": f"Model {request.model_name} download started"}
|
|
|
|
|
|
@router.post("/models/download/cancel")
async def cancel_model_download(request: models.ModelDownloadRequest):
    """Cancel a download task, or dismiss one that errored or went stale.

    Both the task manager entry and the progress manager entry for
    ``request.model_name`` are removed when present.
    """
    task_manager = get_task_manager()
    progress_manager = get_progress_manager()

    target = request.model_name
    task_removed = task_manager.cancel_download(target)

    # Drop the progress entry under the manager's own lock.
    with progress_manager._lock:
        progress_removed = target in progress_manager._progress
        if progress_removed:
            del progress_manager._progress[target]

    if not (task_removed or progress_removed):
        return {"message": f"No active task found for {request.model_name}"}
    return {"message": f"Download task for {request.model_name} cancelled"}
|
|
|
|
|
|
@router.delete("/models/{model_name}")
async def delete_model(model_name: str):
    """Delete a downloaded model from the HuggingFace cache.

    The model is unloaded first, then its ``models--<org>--<name>`` folder in
    the hub cache is removed.

    Returns:
        A dict whose ``message`` confirms the deletion.

    Raises:
        HTTPException: 400 for an unknown model name; 404 when the model has
            no folder in the cache; 500 when unloading or deleting fails.
    """
    from huggingface_hub import constants as hf_constants

    from ..backends import get_model_config, unload_model_by_config

    config = get_model_config(model_name)
    if not config:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")

    hf_repo_id = config.hf_repo_id

    try:
        unload_model_by_config(config)

        cache_dir = hf_constants.HF_HUB_CACHE
        repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))

        if not repo_cache_dir.exists():
            raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")

        try:
            # Removing a model folder can take a while; keep it off the event loop.
            await asyncio.to_thread(shutil.rmtree, repo_cache_dir)
        except OSError as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to delete model cache directory: {str(e)}"
            ) from e

        return {"message": f"Model {model_name} deleted successfully"}

    except HTTPException:
        # Already a well-formed HTTP error (404 / 500 above); pass it through.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") from e
|