Initial commit
This commit is contained in:
32
backend/routes/__init__.py
Normal file
32
backend/routes/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Route registration for the voicebox API."""
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def register_routers(app: FastAPI) -> None:
|
||||
"""Include all domain routers on the application."""
|
||||
from .health import router as health_router
|
||||
from .profiles import router as profiles_router
|
||||
from .channels import router as channels_router
|
||||
from .generations import router as generations_router
|
||||
from .history import router as history_router
|
||||
from .transcription import router as transcription_router
|
||||
from .stories import router as stories_router
|
||||
from .effects import router as effects_router
|
||||
from .audio import router as audio_router
|
||||
from .models import router as models_router
|
||||
from .tasks import router as tasks_router
|
||||
from .cuda import router as cuda_router
|
||||
|
||||
app.include_router(health_router)
|
||||
app.include_router(profiles_router)
|
||||
app.include_router(channels_router)
|
||||
app.include_router(generations_router)
|
||||
app.include_router(history_router)
|
||||
app.include_router(transcription_router)
|
||||
app.include_router(stories_router)
|
||||
app.include_router(effects_router)
|
||||
app.include_router(audio_router)
|
||||
app.include_router(models_router)
|
||||
app.include_router(tasks_router)
|
||||
app.include_router(cuda_router)
|
||||
69
backend/routes/audio.py
Normal file
69
backend/routes/audio.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Audio file serving endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config, models
|
||||
from ..services import history
|
||||
from ..database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/audio/version/{version_id}")
|
||||
async def get_version_audio(version_id: str, db: Session = Depends(get_db)):
|
||||
"""Serve audio for a specific version."""
|
||||
from ..services import versions as versions_mod
|
||||
|
||||
version = versions_mod.get_version(version_id, db)
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
audio_path = config.resolve_storage_path(version.audio_path)
|
||||
if audio_path is None or not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
return FileResponse(
|
||||
audio_path,
|
||||
media_type="audio/wav",
|
||||
filename=f"generation_{version.generation_id}_{version.label}.wav",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/audio/{generation_id}")
|
||||
async def get_audio(generation_id: str, db: Session = Depends(get_db)):
|
||||
"""Serve generated audio file (serves the default version)."""
|
||||
generation = await history.get_generation(generation_id, db)
|
||||
if not generation:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
audio_path = config.resolve_storage_path(generation.audio_path)
|
||||
if audio_path is None or not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
return FileResponse(
|
||||
audio_path,
|
||||
media_type="audio/wav",
|
||||
filename=f"generation_{generation_id}.wav",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/samples/{sample_id}")
|
||||
async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
|
||||
"""Serve profile sample audio file."""
|
||||
from ..database import ProfileSample as DBProfileSample
|
||||
|
||||
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
|
||||
if not sample:
|
||||
raise HTTPException(status_code=404, detail="Sample not found")
|
||||
|
||||
audio_path = config.resolve_storage_path(sample.audio_path)
|
||||
if audio_path is None or not audio_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
return FileResponse(
|
||||
audio_path,
|
||||
media_type="audio/wav",
|
||||
filename=f"sample_{sample_id}.wav",
|
||||
)
|
||||
98
backend/routes/channels.py
Normal file
98
backend/routes/channels.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Audio channel endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import models
|
||||
from ..services import channels
|
||||
from ..database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/channels", response_model=list[models.AudioChannelResponse])
|
||||
async def list_channels(db: Session = Depends(get_db)):
|
||||
"""List all audio channels."""
|
||||
return await channels.list_channels(db)
|
||||
|
||||
|
||||
@router.post("/channels", response_model=models.AudioChannelResponse)
|
||||
async def create_channel(
|
||||
data: models.AudioChannelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a new audio channel."""
|
||||
try:
|
||||
return await channels.create_channel(data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/channels/{channel_id}", response_model=models.AudioChannelResponse)
|
||||
async def get_channel(
|
||||
channel_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get an audio channel by ID."""
|
||||
channel = await channels.get_channel(channel_id, db)
|
||||
if not channel:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return channel
|
||||
|
||||
|
||||
@router.put("/channels/{channel_id}", response_model=models.AudioChannelResponse)
|
||||
async def update_channel(
|
||||
channel_id: str,
|
||||
data: models.AudioChannelUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update an audio channel."""
|
||||
try:
|
||||
channel = await channels.update_channel(channel_id, data, db)
|
||||
if not channel:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return channel
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/channels/{channel_id}")
|
||||
async def delete_channel(
|
||||
channel_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete an audio channel."""
|
||||
try:
|
||||
success = await channels.delete_channel(channel_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Channel not found")
|
||||
return {"message": "Channel deleted successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/channels/{channel_id}/voices")
|
||||
async def get_channel_voices(
|
||||
channel_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get list of profile IDs assigned to a channel."""
|
||||
try:
|
||||
profile_ids = await channels.get_channel_voices(channel_id, db)
|
||||
return {"profile_ids": profile_ids}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/channels/{channel_id}/voices")
|
||||
async def set_channel_voices(
|
||||
channel_id: str,
|
||||
data: models.ChannelVoiceAssignment,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Set which voices are assigned to a channel."""
|
||||
try:
|
||||
await channels.set_channel_voices(channel_id, data, db)
|
||||
return {"message": "Channel voices updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
82
backend/routes/cuda.py
Normal file
82
backend/routes/cuda.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""CUDA backend management endpoints."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ..services.task_queue import create_background_task
|
||||
from ..utils.progress import get_progress_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/backend/cuda-status")
|
||||
async def get_cuda_status():
|
||||
"""Get CUDA backend download/availability status."""
|
||||
from ..services import cuda
|
||||
|
||||
return cuda.get_cuda_status()
|
||||
|
||||
|
||||
@router.post("/backend/download-cuda")
|
||||
async def download_cuda_backend():
|
||||
"""Download the CUDA backend binary."""
|
||||
from ..services import cuda
|
||||
|
||||
if cuda.get_cuda_binary_path() is not None:
|
||||
raise HTTPException(status_code=409, detail="CUDA backend already downloaded")
|
||||
|
||||
progress_manager = get_progress_manager()
|
||||
existing = progress_manager.get_progress(cuda.PROGRESS_KEY)
|
||||
if existing and existing.get("status") == "downloading":
|
||||
raise HTTPException(status_code=409, detail="CUDA backend download already in progress")
|
||||
|
||||
async def _download():
|
||||
try:
|
||||
await cuda.download_cuda_binary()
|
||||
except Exception as e:
|
||||
logger.error("CUDA download failed: %s", e)
|
||||
|
||||
create_background_task(_download())
|
||||
return {"message": "CUDA backend download started", "progress_key": "cuda-backend"}
|
||||
|
||||
|
||||
@router.delete("/backend/cuda")
|
||||
async def delete_cuda_backend():
|
||||
"""Delete the downloaded CUDA backend binary."""
|
||||
from ..services import cuda
|
||||
|
||||
if cuda.is_cuda_active():
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Cannot delete CUDA backend while it is active. Switch to CPU first.",
|
||||
)
|
||||
|
||||
deleted = await cuda.delete_cuda_binary()
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="No CUDA backend found to delete")
|
||||
|
||||
return {"message": "CUDA backend deleted"}
|
||||
|
||||
|
||||
@router.get("/backend/cuda-progress")
|
||||
async def get_cuda_download_progress():
|
||||
"""Get CUDA backend download progress via Server-Sent Events."""
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
async def event_generator():
|
||||
async for event in progress_manager.subscribe("cuda-backend"):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
262
backend/routes/effects.py
Normal file
262
backend/routes/effects.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Effects presets and generation version endpoints."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config, models
|
||||
from ..services import history
|
||||
from ..database import Generation as DBGeneration, get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/effects/preview/{generation_id}")
|
||||
async def preview_effects(
|
||||
generation_id: str,
|
||||
data: models.ApplyEffectsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Apply effects to a generation's clean audio and stream back without saving."""
|
||||
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
if (gen.status or "completed") != "completed":
|
||||
raise HTTPException(status_code=400, detail="Generation is not completed")
|
||||
|
||||
from ..services import versions as versions_mod
|
||||
from ..utils.effects import apply_effects, validate_effects_chain
|
||||
from ..utils.audio import load_audio
|
||||
|
||||
chain_dicts = [e.model_dump() for e in data.effects_chain]
|
||||
error = validate_effects_chain(chain_dicts)
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
|
||||
all_versions = versions_mod.list_versions(generation_id, db)
|
||||
clean_version = next((v for v in all_versions if v.effects_chain is None), None)
|
||||
source_path = clean_version.audio_path if clean_version else gen.audio_path
|
||||
resolved_source_path = config.resolve_storage_path(source_path)
|
||||
if resolved_source_path is None or not resolved_source_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Source audio file not found")
|
||||
|
||||
audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
|
||||
processed = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
buf = io.BytesIO()
|
||||
await asyncio.to_thread(lambda: sf.write(buf, processed, sample_rate, format="WAV"))
|
||||
buf.seek(0)
|
||||
|
||||
return StreamingResponse(
|
||||
buf,
|
||||
media_type="audio/wav",
|
||||
headers={
|
||||
"Content-Disposition": f'inline; filename="preview_{generation_id}.wav"',
|
||||
"Cache-Control": "no-cache, no-store",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/effects/available", response_model=models.AvailableEffectsResponse)
|
||||
async def get_available_effects():
|
||||
"""List all available effect types with parameter definitions."""
|
||||
from ..utils.effects import get_available_effects as _get_effects
|
||||
|
||||
return models.AvailableEffectsResponse(effects=[models.AvailableEffect(**e) for e in _get_effects()])
|
||||
|
||||
|
||||
@router.get("/effects/presets", response_model=list[models.EffectPresetResponse])
|
||||
async def list_effect_presets(db: Session = Depends(get_db)):
|
||||
"""List all effect presets (built-in + user-created)."""
|
||||
from ..services import effects as effects_mod
|
||||
|
||||
return effects_mod.list_presets(db)
|
||||
|
||||
|
||||
@router.get("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
|
||||
async def get_effect_preset(preset_id: str, db: Session = Depends(get_db)):
|
||||
"""Get a specific effect preset."""
|
||||
from ..services import effects as effects_mod
|
||||
|
||||
preset = effects_mod.get_preset(preset_id, db)
|
||||
if not preset:
|
||||
raise HTTPException(status_code=404, detail="Preset not found")
|
||||
return preset
|
||||
|
||||
|
||||
@router.post("/effects/presets", response_model=models.EffectPresetResponse)
|
||||
async def create_effect_preset(
|
||||
data: models.EffectPresetCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a new effect preset."""
|
||||
from ..services import effects as effects_mod
|
||||
|
||||
try:
|
||||
return effects_mod.create_preset(data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
|
||||
async def update_effect_preset(
|
||||
preset_id: str,
|
||||
data: models.EffectPresetUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update an effect preset."""
|
||||
from ..services import effects as effects_mod
|
||||
|
||||
try:
|
||||
result = effects_mod.update_preset(preset_id, data, db)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Preset not found")
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/effects/presets/{preset_id}")
|
||||
async def delete_effect_preset(preset_id: str, db: Session = Depends(get_db)):
|
||||
"""Delete a user effect preset."""
|
||||
from ..services import effects as effects_mod
|
||||
|
||||
try:
|
||||
if not effects_mod.delete_preset(preset_id, db):
|
||||
raise HTTPException(status_code=404, detail="Preset not found")
|
||||
return {"status": "deleted"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/generations/{generation_id}/versions",
|
||||
response_model=list[models.GenerationVersionResponse],
|
||||
)
|
||||
async def list_generation_versions(
|
||||
generation_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List all versions for a generation."""
|
||||
gen = await history.get_generation(generation_id, db)
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
from ..services import versions as versions_mod
|
||||
|
||||
return versions_mod.list_versions(generation_id, db)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/generations/{generation_id}/versions/apply-effects",
|
||||
response_model=models.GenerationVersionResponse,
|
||||
)
|
||||
async def apply_effects_to_generation(
|
||||
generation_id: str,
|
||||
data: models.ApplyEffectsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Apply an effects chain to an existing generation, creating a new version."""
|
||||
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
if (gen.status or "completed") != "completed":
|
||||
raise HTTPException(status_code=400, detail="Generation is not completed")
|
||||
|
||||
from ..services import versions as versions_mod
|
||||
from ..utils.effects import apply_effects, validate_effects_chain
|
||||
from ..utils.audio import load_audio, save_audio
|
||||
|
||||
chain_dicts = [e.model_dump() for e in data.effects_chain]
|
||||
error = validate_effects_chain(chain_dicts)
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
|
||||
all_versions = versions_mod.list_versions(generation_id, db)
|
||||
source_version_id = data.source_version_id
|
||||
if source_version_id:
|
||||
source_version = next((v for v in all_versions if v.id == source_version_id), None)
|
||||
if not source_version:
|
||||
raise HTTPException(status_code=404, detail="Source version not found")
|
||||
source_path = source_version.audio_path
|
||||
else:
|
||||
clean_version = next((v for v in all_versions if v.effects_chain is None), None)
|
||||
if not clean_version:
|
||||
source_path = gen.audio_path
|
||||
else:
|
||||
source_path = clean_version.audio_path
|
||||
source_version_id = clean_version.id
|
||||
|
||||
resolved_source_path = config.resolve_storage_path(source_path)
|
||||
if resolved_source_path is None or not resolved_source_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Source audio file not found")
|
||||
|
||||
audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
|
||||
processed_audio = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)
|
||||
|
||||
version_id = str(uuid.uuid4())
|
||||
processed_path = config.get_generations_dir() / f"{generation_id}_{version_id[:8]}.wav"
|
||||
await asyncio.to_thread(save_audio, processed_audio, str(processed_path), sample_rate)
|
||||
|
||||
label = data.label or f"version-{len(all_versions) + 1}"
|
||||
|
||||
version = versions_mod.create_version(
|
||||
generation_id=generation_id,
|
||||
label=label,
|
||||
audio_path=config.to_storage_path(processed_path),
|
||||
db=db,
|
||||
effects_chain=chain_dicts,
|
||||
is_default=data.set_as_default,
|
||||
source_version_id=source_version_id,
|
||||
)
|
||||
|
||||
return version
|
||||
|
||||
|
||||
@router.put(
|
||||
"/generations/{generation_id}/versions/{version_id}/set-default",
|
||||
response_model=models.GenerationVersionResponse,
|
||||
)
|
||||
async def set_default_version(
|
||||
generation_id: str,
|
||||
version_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Set a specific version as the default for a generation."""
|
||||
from ..services import versions as versions_mod
|
||||
|
||||
version = versions_mod.get_version(version_id, db)
|
||||
if not version or version.generation_id != generation_id:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
result = versions_mod.set_default_version(version_id, db)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/generations/{generation_id}/versions/{version_id}")
|
||||
async def delete_generation_version(
|
||||
generation_id: str,
|
||||
version_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a version. Cannot delete the last remaining version."""
|
||||
from ..services import versions as versions_mod
|
||||
|
||||
version = versions_mod.get_version(version_id, db)
|
||||
if not version or version.generation_id != generation_id:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
if not versions_mod.delete_version(version_id, db):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot delete the last remaining version",
|
||||
)
|
||||
return {"status": "deleted"}
|
||||
345
backend/routes/generations.py
Normal file
345
backend/routes/generations.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""TTS generation endpoints."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .. import models
|
||||
from ..services import history, profiles, tts
|
||||
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db
|
||||
from ..services.generation import run_generation
|
||||
from ..services.task_queue import cancel_generation as cancel_generation_job, enqueue_generation
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _resolve_generation_engine(data: models.GenerationRequest, profile) -> str:
|
||||
return data.engine or getattr(profile, "default_engine", None) or getattr(profile, "preset_engine", None) or "qwen"
|
||||
|
||||
|
||||
@router.post("/generate", response_model=models.GenerationResponse)
|
||||
async def generate_speech(
|
||||
data: models.GenerationRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Generate speech from text using a voice profile."""
|
||||
task_manager = get_task_manager()
|
||||
generation_id = str(uuid.uuid4())
|
||||
|
||||
profile = await profiles.get_profile(data.profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
from ..backends import engine_has_model_sizes
|
||||
|
||||
engine = _resolve_generation_engine(data, profile)
|
||||
try:
|
||||
profiles.validate_profile_engine(profile, engine)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
model_size = (data.model_size or "1.7B") if engine_has_model_sizes(engine) else None
|
||||
|
||||
generation = await history.create_generation(
|
||||
profile_id=data.profile_id,
|
||||
text=data.text,
|
||||
language=data.language,
|
||||
audio_path="",
|
||||
duration=0,
|
||||
seed=data.seed,
|
||||
db=db,
|
||||
instruct=data.instruct,
|
||||
generation_id=generation_id,
|
||||
status="generating",
|
||||
engine=engine,
|
||||
model_size=model_size if engine_has_model_sizes(engine) else None,
|
||||
)
|
||||
|
||||
task_manager.start_generation(
|
||||
task_id=generation_id,
|
||||
profile_id=data.profile_id,
|
||||
text=data.text,
|
||||
)
|
||||
|
||||
effects_chain_config = None
|
||||
if data.effects_chain is not None:
|
||||
effects_chain_config = [e.model_dump() for e in data.effects_chain]
|
||||
else:
|
||||
import json as _json
|
||||
|
||||
profile_obj = db.query(DBVoiceProfile).filter_by(id=data.profile_id).first()
|
||||
if profile_obj and profile_obj.effects_chain:
|
||||
try:
|
||||
effects_chain_config = _json.loads(profile_obj.effects_chain)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
enqueue_generation(
|
||||
generation_id,
|
||||
run_generation(
|
||||
generation_id=generation_id,
|
||||
profile_id=data.profile_id,
|
||||
text=data.text,
|
||||
language=data.language,
|
||||
engine=engine,
|
||||
model_size=model_size,
|
||||
seed=data.seed,
|
||||
normalize=data.normalize,
|
||||
effects_chain=effects_chain_config,
|
||||
instruct=data.instruct,
|
||||
mode="generate",
|
||||
max_chunk_chars=data.max_chunk_chars,
|
||||
crossfade_ms=data.crossfade_ms,
|
||||
)
|
||||
)
|
||||
|
||||
return generation
|
||||
|
||||
|
||||
@router.post("/generate/{generation_id}/retry", response_model=models.GenerationResponse)
|
||||
async def retry_generation(generation_id: str, db: Session = Depends(get_db)):
|
||||
"""Retry a failed generation using the same parameters."""
|
||||
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if (gen.status or "completed") != "failed":
|
||||
raise HTTPException(status_code=400, detail="Only failed generations can be retried")
|
||||
|
||||
gen.status = "generating"
|
||||
gen.error = None
|
||||
gen.audio_path = ""
|
||||
gen.duration = 0
|
||||
db.commit()
|
||||
db.refresh(gen)
|
||||
|
||||
task_manager = get_task_manager()
|
||||
task_manager.start_generation(
|
||||
task_id=generation_id,
|
||||
profile_id=gen.profile_id,
|
||||
text=gen.text,
|
||||
)
|
||||
|
||||
enqueue_generation(
|
||||
generation_id,
|
||||
run_generation(
|
||||
generation_id=generation_id,
|
||||
profile_id=gen.profile_id,
|
||||
text=gen.text,
|
||||
language=gen.language,
|
||||
engine=gen.engine or "qwen",
|
||||
model_size=gen.model_size or "1.7B",
|
||||
seed=gen.seed,
|
||||
instruct=gen.instruct,
|
||||
mode="retry",
|
||||
)
|
||||
)
|
||||
|
||||
return models.GenerationResponse.model_validate(gen)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/generate/{generation_id}/regenerate",
|
||||
response_model=models.GenerationResponse,
|
||||
)
|
||||
async def regenerate_generation(generation_id: str, db: Session = Depends(get_db)):
|
||||
"""Re-run TTS with the same parameters and save the result as a new version."""
|
||||
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
if (gen.status or "completed") != "completed":
|
||||
raise HTTPException(status_code=400, detail="Generation must be completed to regenerate")
|
||||
|
||||
gen.status = "generating"
|
||||
gen.error = None
|
||||
db.commit()
|
||||
db.refresh(gen)
|
||||
|
||||
task_manager = get_task_manager()
|
||||
task_manager.start_generation(
|
||||
task_id=generation_id,
|
||||
profile_id=gen.profile_id,
|
||||
text=gen.text,
|
||||
)
|
||||
|
||||
version_id = str(uuid.uuid4())
|
||||
|
||||
enqueue_generation(
|
||||
generation_id,
|
||||
run_generation(
|
||||
generation_id=generation_id,
|
||||
profile_id=gen.profile_id,
|
||||
text=gen.text,
|
||||
language=gen.language,
|
||||
engine=gen.engine or "qwen",
|
||||
model_size=gen.model_size or "1.7B",
|
||||
seed=gen.seed,
|
||||
instruct=gen.instruct,
|
||||
mode="regenerate",
|
||||
version_id=version_id,
|
||||
)
|
||||
)
|
||||
|
||||
return models.GenerationResponse.model_validate(gen)
|
||||
|
||||
|
||||
@router.post("/generate/{generation_id}/cancel")
|
||||
async def cancel_generation(generation_id: str, db: Session = Depends(get_db)):
|
||||
"""Cancel a queued or running generation."""
|
||||
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if (gen.status or "completed") not in ("loading_model", "generating"):
|
||||
raise HTTPException(status_code=400, detail="Only active generations can be cancelled")
|
||||
|
||||
cancellation_state = cancel_generation_job(generation_id)
|
||||
if cancellation_state is None:
|
||||
raise HTTPException(status_code=409, detail="Generation is no longer cancellable")
|
||||
|
||||
if cancellation_state == "queued":
|
||||
task_manager = get_task_manager()
|
||||
task_manager.complete_generation(generation_id)
|
||||
await history.update_generation_status(
|
||||
generation_id=generation_id,
|
||||
status="failed",
|
||||
db=db,
|
||||
error="Generation cancelled",
|
||||
)
|
||||
return {"message": "Queued generation cancelled"}
|
||||
|
||||
return {"message": "Generation cancellation requested"}
|
||||
|
||||
|
||||
@router.get("/generate/{generation_id}/status")
|
||||
async def get_generation_status(generation_id: str, db: Session = Depends(get_db)):
|
||||
"""SSE endpoint that streams generation status updates."""
|
||||
import json
|
||||
|
||||
async def event_stream():
|
||||
try:
|
||||
while True:
|
||||
db.expire_all()
|
||||
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not gen:
|
||||
yield f"data: {json.dumps({'status': 'not_found', 'id': generation_id})}\n\n"
|
||||
return
|
||||
|
||||
payload = {
|
||||
"id": gen.id,
|
||||
"status": gen.status or "completed",
|
||||
"duration": gen.duration,
|
||||
"error": gen.error,
|
||||
}
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
if (gen.status or "completed") in ("completed", "failed"):
|
||||
return
|
||||
|
||||
await asyncio.sleep(1)
|
||||
except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
|
||||
logger.debug("SSE client disconnected for generation %s", generation_id)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate/stream")
|
||||
async def stream_speech(
|
||||
data: models.GenerationRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Generate speech and stream the WAV audio directly without saving to disk."""
|
||||
from ..backends import get_tts_backend_for_engine, ensure_model_cached_or_raise, load_engine_model, engine_needs_trim
|
||||
|
||||
profile = await profiles.get_profile(data.profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
engine = _resolve_generation_engine(data, profile)
|
||||
try:
|
||||
profiles.validate_profile_engine(profile, engine)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
tts_model = get_tts_backend_for_engine(engine)
|
||||
model_size = data.model_size or "1.7B"
|
||||
|
||||
await ensure_model_cached_or_raise(engine, model_size)
|
||||
await load_engine_model(engine, model_size)
|
||||
|
||||
voice_prompt = await profiles.create_voice_prompt_for_profile(
|
||||
data.profile_id,
|
||||
db,
|
||||
engine=engine,
|
||||
)
|
||||
|
||||
from ..utils.chunked_tts import generate_chunked
|
||||
|
||||
trim_fn = None
|
||||
if engine_needs_trim(engine):
|
||||
from ..utils.audio import trim_tts_output
|
||||
|
||||
trim_fn = trim_tts_output
|
||||
|
||||
audio, sample_rate = await generate_chunked(
|
||||
tts_model,
|
||||
data.text,
|
||||
voice_prompt,
|
||||
language=data.language,
|
||||
seed=data.seed,
|
||||
instruct=data.instruct,
|
||||
max_chunk_chars=data.max_chunk_chars,
|
||||
crossfade_ms=data.crossfade_ms,
|
||||
trim_fn=trim_fn,
|
||||
)
|
||||
|
||||
effects_chain_config = None
|
||||
if data.effects_chain is not None:
|
||||
effects_chain_config = [e.model_dump() for e in data.effects_chain]
|
||||
elif profile.effects_chain:
|
||||
import json as _json
|
||||
|
||||
try:
|
||||
effects_chain_config = _json.loads(profile.effects_chain)
|
||||
except Exception:
|
||||
effects_chain_config = None
|
||||
|
||||
if effects_chain_config:
|
||||
from ..utils.effects import apply_effects
|
||||
|
||||
audio = apply_effects(audio, sample_rate, effects_chain_config)
|
||||
|
||||
if data.normalize:
|
||||
from ..utils.audio import normalize_audio
|
||||
|
||||
audio = normalize_audio(audio)
|
||||
|
||||
wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate)
|
||||
|
||||
async def _wav_stream():
|
||||
try:
|
||||
chunk_size = 64 * 1024
|
||||
for i in range(0, len(wav_bytes), chunk_size):
|
||||
yield wav_bytes[i : i + chunk_size]
|
||||
except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
|
||||
logger.debug("Client disconnected during audio stream")
|
||||
|
||||
return StreamingResponse(
|
||||
_wav_stream(),
|
||||
media_type="audio/wav",
|
||||
headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
|
||||
)
|
||||
248
backend/routes/health.py
Normal file
248
backend/routes/health.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Health and infrastructure endpoints."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config, models
|
||||
from ..services import tts
|
||||
from ..database import get_db
|
||||
from ..utils.platform_detect import get_backend_type
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Frontend build directory — present in Docker, absent in dev/API-only mode
|
||||
_frontend_dir = Path(__file__).resolve().parent.parent.parent / "frontend"
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
"""Root endpoint — serves SPA index.html in Docker, JSON otherwise."""
|
||||
from .. import __version__
|
||||
|
||||
index = _frontend_dir / "index.html"
|
||||
if index.is_file():
|
||||
return FileResponse(index, media_type="text/html")
|
||||
return {"message": "voicebox API", "version": __version__}
|
||||
|
||||
|
||||
@router.post("/shutdown")
|
||||
async def shutdown():
|
||||
"""Gracefully shutdown the server."""
|
||||
|
||||
async def shutdown_async():
|
||||
await asyncio.sleep(0.1)
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
||||
asyncio.create_task(shutdown_async())
|
||||
return {"message": "Shutting down..."}
|
||||
|
||||
|
||||
@router.post("/watchdog/disable")
|
||||
async def watchdog_disable():
|
||||
"""Disable the parent process watchdog so the server keeps running."""
|
||||
from backend.server import disable_watchdog
|
||||
|
||||
disable_watchdog()
|
||||
return {"message": "Watchdog disabled"}
|
||||
|
||||
|
||||
@router.get("/health", response_model=models.HealthResponse)
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
from pathlib import Path
|
||||
|
||||
tts_model = tts.get_tts_model()
|
||||
backend_type = get_backend_type()
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
|
||||
has_xpu = False
|
||||
xpu_name = None
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401 -- side-effect import enables XPU
|
||||
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
has_xpu = True
|
||||
try:
|
||||
xpu_name = torch.xpu.get_device_name(0)
|
||||
except Exception:
|
||||
xpu_name = "Intel GPU"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
has_directml = False
|
||||
directml_name = None
|
||||
try:
|
||||
import torch_directml
|
||||
|
||||
if torch_directml.device_count() > 0:
|
||||
has_directml = True
|
||||
try:
|
||||
directml_name = torch_directml.device_name(0)
|
||||
except Exception:
|
||||
directml_name = "DirectML GPU"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
gpu_compat_warning = None
|
||||
if has_cuda:
|
||||
from ..backends.base import check_cuda_compatibility
|
||||
|
||||
_compatible, gpu_compat_warning = check_cuda_compatibility()
|
||||
|
||||
gpu_available = has_cuda or has_mps or has_xpu or has_directml or backend_type == "mlx"
|
||||
|
||||
gpu_type = None
|
||||
if has_cuda:
|
||||
gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})"
|
||||
elif has_mps:
|
||||
gpu_type = "MPS (Apple Silicon)"
|
||||
elif backend_type == "mlx":
|
||||
gpu_type = "Metal (Apple Silicon via MLX)"
|
||||
elif has_xpu:
|
||||
gpu_type = f"XPU ({xpu_name})"
|
||||
elif has_directml:
|
||||
gpu_type = f"DirectML ({directml_name})"
|
||||
|
||||
vram_used = None
|
||||
if has_cuda:
|
||||
vram_used = torch.cuda.memory_allocated() / 1024 / 1024
|
||||
elif has_xpu:
|
||||
try:
|
||||
vram_used = torch.xpu.memory_allocated() / 1024 / 1024
|
||||
except Exception:
|
||||
pass # memory_allocated() may not be available on all IPEX versions
|
||||
|
||||
model_loaded = False
|
||||
model_size = None
|
||||
try:
|
||||
if tts_model.is_loaded():
|
||||
model_loaded = True
|
||||
model_size = getattr(tts_model, "_current_model_size", None)
|
||||
if not model_size:
|
||||
model_size = getattr(tts_model, "model_size", None)
|
||||
except Exception:
|
||||
model_loaded = False
|
||||
model_size = None
|
||||
|
||||
model_downloaded = None
|
||||
try:
|
||||
from ..backends import get_model_config
|
||||
|
||||
default_config = get_model_config("qwen-tts-1.7B")
|
||||
default_model_id = default_config.hf_repo_id if default_config else "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||
|
||||
try:
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
cache_info = scan_cache_dir()
|
||||
for repo in cache_info.repos:
|
||||
if repo.repo_id == default_model_id:
|
||||
model_downloaded = True
|
||||
break
|
||||
except (ImportError, Exception):
|
||||
cache_dir = hf_constants.HF_HUB_CACHE
|
||||
repo_cache = Path(cache_dir) / ("models--" + default_model_id.replace("/", "--"))
|
||||
if repo_cache.exists():
|
||||
has_model_files = (
|
||||
any(repo_cache.rglob("*.bin"))
|
||||
or any(repo_cache.rglob("*.safetensors"))
|
||||
or any(repo_cache.rglob("*.pt"))
|
||||
or any(repo_cache.rglob("*.pth"))
|
||||
or any(repo_cache.rglob("*.npz"))
|
||||
)
|
||||
model_downloaded = has_model_files
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return models.HealthResponse(
|
||||
status="healthy",
|
||||
model_loaded=model_loaded,
|
||||
model_downloaded=model_downloaded,
|
||||
model_size=model_size,
|
||||
gpu_available=gpu_available,
|
||||
gpu_type=gpu_type,
|
||||
vram_used_mb=vram_used,
|
||||
backend_type=backend_type,
|
||||
backend_variant=os.environ.get(
|
||||
"VOICEBOX_BACKEND_VARIANT",
|
||||
"cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"),
|
||||
),
|
||||
gpu_compatibility_warning=gpu_compat_warning,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health/filesystem", response_model=models.FilesystemHealthResponse)
|
||||
async def filesystem_health():
|
||||
"""Check filesystem health: directory existence, write permissions, and disk space."""
|
||||
import shutil
|
||||
|
||||
dirs_to_check = {
|
||||
"generations": config.get_generations_dir(),
|
||||
"profiles": config.get_profiles_dir(),
|
||||
"data": config.get_data_dir(),
|
||||
}
|
||||
|
||||
checks: list[models.DirectoryCheck] = []
|
||||
all_ok = True
|
||||
|
||||
for _label, dir_path in dirs_to_check.items():
|
||||
exists = dir_path.exists()
|
||||
writable = False
|
||||
error = None
|
||||
if exists:
|
||||
probe = dir_path / ".voicebox_probe"
|
||||
try:
|
||||
probe.write_text("ok")
|
||||
probe.unlink()
|
||||
writable = True
|
||||
except PermissionError:
|
||||
error = "Permission denied"
|
||||
except OSError as e:
|
||||
error = str(e)
|
||||
finally:
|
||||
try:
|
||||
probe.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
error = "Directory does not exist"
|
||||
|
||||
if not exists or not writable:
|
||||
all_ok = False
|
||||
|
||||
checks.append(
|
||||
models.DirectoryCheck(
|
||||
path=str(dir_path.resolve()),
|
||||
exists=exists,
|
||||
writable=writable,
|
||||
error=error,
|
||||
)
|
||||
)
|
||||
|
||||
disk_free_mb = None
|
||||
disk_total_mb = None
|
||||
try:
|
||||
usage = shutil.disk_usage(str(config.get_data_dir()))
|
||||
disk_free_mb = round(usage.free / (1024 * 1024), 1)
|
||||
disk_total_mb = round(usage.total / (1024 * 1024), 1)
|
||||
if disk_free_mb < 500:
|
||||
all_ok = False
|
||||
except OSError:
|
||||
all_ok = False
|
||||
|
||||
return models.FilesystemHealthResponse(
|
||||
healthy=all_ok,
|
||||
disk_free_mb=disk_free_mb,
|
||||
disk_total_mb=disk_total_mb,
|
||||
directories=checks,
|
||||
)
|
||||
189
backend/routes/history.py
Normal file
189
backend/routes/history.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Generation history endpoints."""
|
||||
|
||||
import io
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config, models
|
||||
from ..services import export_import, history
|
||||
from ..app import safe_content_disposition
|
||||
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/history", response_model=models.HistoryListResponse)
|
||||
async def list_history(
|
||||
profile_id: str | None = None,
|
||||
search: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""List generation history with optional filters."""
|
||||
query = models.HistoryQuery(
|
||||
profile_id=profile_id,
|
||||
search=search,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return await history.list_generations(query, db)
|
||||
|
||||
|
||||
@router.get("/history/stats")
|
||||
async def get_stats(db: Session = Depends(get_db)):
|
||||
"""Get generation statistics."""
|
||||
return await history.get_generation_stats(db)
|
||||
|
||||
|
||||
@router.post("/history/import")
|
||||
async def import_generation(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Import a generation from a ZIP archive."""
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024
|
||||
|
||||
content = await file.read()
|
||||
|
||||
if len(content) > MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE / (1024 * 1024)}MB"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await export_import.import_generation_from_zip(content, db)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/history/failed")
|
||||
async def clear_failed_generations(db: Session = Depends(get_db)):
|
||||
"""Delete every generation with status='failed'. Used by the UI's 'Clear failed' button (#410)."""
|
||||
count = await history.delete_failed_generations(db)
|
||||
return {"deleted": count}
|
||||
|
||||
|
||||
@router.get("/history/{generation_id}", response_model=models.HistoryResponse)
|
||||
async def get_generation(
|
||||
generation_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get a generation by ID."""
|
||||
result = (
|
||||
db.query(DBGeneration, DBVoiceProfile.name.label("profile_name"))
|
||||
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
|
||||
.filter(DBGeneration.id == generation_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
gen, profile_name = result
|
||||
return models.HistoryResponse(
|
||||
id=gen.id,
|
||||
profile_id=gen.profile_id,
|
||||
profile_name=profile_name,
|
||||
text=gen.text,
|
||||
language=gen.language,
|
||||
audio_path=gen.audio_path,
|
||||
duration=gen.duration,
|
||||
seed=gen.seed,
|
||||
instruct=gen.instruct,
|
||||
engine=gen.engine or "qwen",
|
||||
model_size=gen.model_size,
|
||||
status=gen.status or "completed",
|
||||
error=gen.error,
|
||||
is_favorited=bool(gen.is_favorited),
|
||||
created_at=gen.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/history/{generation_id}/favorite")
|
||||
async def toggle_favorite(
|
||||
generation_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Toggle the favorite status of a generation."""
|
||||
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
gen.is_favorited = not gen.is_favorited
|
||||
db.commit()
|
||||
return {"is_favorited": gen.is_favorited}
|
||||
|
||||
|
||||
@router.delete("/history/{generation_id}")
|
||||
async def delete_generation(
|
||||
generation_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a generation."""
|
||||
success = await history.delete_generation(generation_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return {"message": "Generation deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/history/{generation_id}/export")
|
||||
async def export_generation(
|
||||
generation_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Export a generation as a ZIP archive."""
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
try:
|
||||
zip_bytes = export_import.export_generation_to_zip(generation_id, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
|
||||
if not safe_text:
|
||||
safe_text = "generation"
|
||||
filename = f"generation-{safe_text}.voicebox.zip"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(zip_bytes),
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history/{generation_id}/export-audio")
|
||||
async def export_generation_audio(
|
||||
generation_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Export only the audio file from a generation."""
|
||||
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
|
||||
if not generation:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if not generation.audio_path:
|
||||
raise HTTPException(status_code=404, detail="Generation has no audio file")
|
||||
|
||||
audio_path = config.resolve_storage_path(generation.audio_path)
|
||||
if audio_path is None or not audio_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
|
||||
if not safe_text:
|
||||
safe_text = "generation"
|
||||
filename = f"{safe_text}.wav"
|
||||
|
||||
return FileResponse(
|
||||
audio_path,
|
||||
media_type="audio/wav",
|
||||
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
|
||||
)
|
||||
475
backend/routes/models.py
Normal file
475
backend/routes/models.py
Normal file
@@ -0,0 +1,475 @@
|
||||
"""Model management endpoints."""
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import models
|
||||
from ..utils.platform_detect import get_backend_type
|
||||
from ..services.task_queue import create_background_task
|
||||
from ..utils.progress import get_progress_manager
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _get_dir_size(path: Path) -> int:
|
||||
"""Get total size of a directory in bytes."""
|
||||
total = 0
|
||||
for f in path.rglob("*"):
|
||||
if f.is_file():
|
||||
total += f.stat().st_size
|
||||
return total
|
||||
|
||||
|
||||
def _copy_with_progress(src: Path, dst: Path, progress_manager, copied_so_far: int, total_bytes: int) -> int:
|
||||
"""Copy a directory tree with byte-level progress tracking."""
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for item in src.iterdir():
|
||||
dest_item = dst / item.name
|
||||
if item.is_dir():
|
||||
copied_so_far = _copy_with_progress(item, dest_item, progress_manager, copied_so_far, total_bytes)
|
||||
else:
|
||||
size = item.stat().st_size
|
||||
shutil.copy2(str(item), str(dest_item))
|
||||
copied_so_far += size
|
||||
progress_manager.update_progress(
|
||||
"migration",
|
||||
copied_so_far,
|
||||
total_bytes,
|
||||
filename=item.name,
|
||||
status="downloading",
|
||||
)
|
||||
return copied_so_far
|
||||
|
||||
|
||||
@router.post("/models/load")
|
||||
async def load_model(model_size: str = "1.7B"):
|
||||
"""Manually load TTS model."""
|
||||
from ..services import tts
|
||||
|
||||
try:
|
||||
tts_model = tts.get_tts_model()
|
||||
await tts_model.load_model_async(model_size)
|
||||
return {"message": f"Model {model_size} loaded successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/models/unload")
|
||||
async def unload_model():
|
||||
"""Unload the default Qwen TTS model to free memory."""
|
||||
from ..services import tts
|
||||
|
||||
try:
|
||||
tts.unload_tts_model()
|
||||
return {"message": "Model unloaded successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/models/{model_name}/unload")
|
||||
async def unload_model_by_name(model_name: str):
|
||||
"""Unload a specific model from memory without deleting it from disk."""
|
||||
from ..backends import get_model_config, unload_model_by_config
|
||||
|
||||
config = get_model_config(model_name)
|
||||
if not config:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
|
||||
|
||||
try:
|
||||
was_loaded = unload_model_by_config(config)
|
||||
if not was_loaded:
|
||||
return {"message": f"Model {model_name} is not loaded"}
|
||||
return {"message": f"Model {model_name} unloaded successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/models/progress/{model_name}")
|
||||
async def get_model_progress(model_name: str):
|
||||
"""Get model download progress via Server-Sent Events."""
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
async def event_generator():
|
||||
async for event in progress_manager.subscribe(model_name):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models/cache-dir")
|
||||
async def get_models_cache_dir():
|
||||
"""Get the path to the HuggingFace model cache directory."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
|
||||
return {"path": str(Path(hf_constants.HF_HUB_CACHE))}
|
||||
|
||||
|
||||
@router.post("/models/migrate")
|
||||
async def migrate_models(request: models.ModelMigrateRequest):
|
||||
"""Move all downloaded models to a new directory with byte-level progress via SSE."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
|
||||
source = Path(hf_constants.HF_HUB_CACHE)
|
||||
destination = Path(request.destination)
|
||||
|
||||
if not source.exists():
|
||||
raise HTTPException(status_code=404, detail="Current model cache directory not found")
|
||||
|
||||
if source.resolve() == destination.resolve():
|
||||
raise HTTPException(status_code=400, detail="Source and destination are the same directory")
|
||||
|
||||
if destination.resolve().is_relative_to(source.resolve()):
|
||||
raise HTTPException(status_code=400, detail="Destination cannot be inside the current cache directory")
|
||||
|
||||
progress_manager = get_progress_manager()
|
||||
model_dirs = [d for d in source.iterdir() if d.name.startswith("models--") and d.is_dir()]
|
||||
if not model_dirs:
|
||||
progress_manager.update_progress("migration", 1, 1, status="complete")
|
||||
progress_manager.mark_complete("migration")
|
||||
return {"moved": 0, "errors": [], "source": str(source), "destination": str(destination)}
|
||||
|
||||
destination.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
same_fs = False
|
||||
try:
|
||||
same_fs = source.stat().st_dev == destination.stat().st_dev
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def migrate_background():
|
||||
moved = 0
|
||||
errors = []
|
||||
try:
|
||||
if same_fs:
|
||||
total = len(model_dirs)
|
||||
for i, item in enumerate(model_dirs):
|
||||
dest_item = destination / item.name
|
||||
try:
|
||||
if dest_item.exists():
|
||||
shutil.rmtree(dest_item)
|
||||
shutil.move(str(item), str(dest_item))
|
||||
moved += 1
|
||||
progress_manager.update_progress(
|
||||
"migration",
|
||||
i + 1,
|
||||
total,
|
||||
filename=item.name,
|
||||
status="downloading",
|
||||
)
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
else:
|
||||
total_bytes = sum(_get_dir_size(d) for d in model_dirs)
|
||||
progress_manager.update_progress(
|
||||
"migration", 0, total_bytes, filename="Calculating...", status="downloading"
|
||||
)
|
||||
|
||||
copied = 0
|
||||
for item in model_dirs:
|
||||
dest_item = destination / item.name
|
||||
try:
|
||||
if dest_item.exists():
|
||||
shutil.rmtree(dest_item)
|
||||
copied = await asyncio.to_thread(
|
||||
_copy_with_progress, item, dest_item, progress_manager, copied, total_bytes
|
||||
)
|
||||
await asyncio.to_thread(shutil.rmtree, str(item))
|
||||
moved += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{item.name}: {str(e)}")
|
||||
|
||||
progress_manager.update_progress("migration", 1, 1, status="complete")
|
||||
progress_manager.mark_complete("migration")
|
||||
except Exception as e:
|
||||
progress_manager.update_progress("migration", 0, 0, status="error")
|
||||
progress_manager.mark_error("migration", str(e))
|
||||
|
||||
create_background_task(migrate_background())
|
||||
|
||||
return {"source": str(source), "destination": str(destination)}
|
||||
|
||||
|
||||
@router.get("/models/migrate/progress")
|
||||
async def get_migration_progress():
|
||||
"""Get model migration progress via Server-Sent Events."""
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
async def event_generator():
|
||||
async for event in progress_manager.subscribe("migration"):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/models/status", response_model=models.ModelStatusListResponse)
|
||||
async def get_model_status():
|
||||
"""Get status of all available models."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
|
||||
backend_type = get_backend_type()
|
||||
task_manager = get_task_manager()
|
||||
|
||||
active_download_names = {task.model_name for task in task_manager.get_active_downloads()}
|
||||
|
||||
try:
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
use_scan_cache = True
|
||||
except ImportError:
|
||||
use_scan_cache = False
|
||||
|
||||
from ..backends import get_all_model_configs, check_model_loaded
|
||||
|
||||
registry_configs = get_all_model_configs()
|
||||
model_configs = [
|
||||
{
|
||||
"model_name": cfg.model_name,
|
||||
"display_name": cfg.display_name,
|
||||
"hf_repo_id": cfg.hf_repo_id,
|
||||
"model_size": cfg.model_size,
|
||||
"check_loaded": lambda c=cfg: check_model_loaded(c),
|
||||
}
|
||||
for cfg in registry_configs
|
||||
]
|
||||
|
||||
model_to_repo = {cfg["model_name"]: cfg["hf_repo_id"] for cfg in model_configs}
|
||||
active_download_repos = {model_to_repo.get(name) for name in active_download_names if name in model_to_repo}
|
||||
|
||||
cache_info = None
|
||||
if use_scan_cache:
|
||||
try:
|
||||
cache_info = scan_cache_dir()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
statuses = []
|
||||
|
||||
for config in model_configs:
|
||||
try:
|
||||
downloaded = False
|
||||
size_mb = None
|
||||
loaded = False
|
||||
|
||||
if cache_info:
|
||||
repo_id = config["hf_repo_id"]
|
||||
for repo in cache_info.repos:
|
||||
if repo.repo_id == repo_id:
|
||||
has_model_weights = False
|
||||
for rev in repo.revisions:
|
||||
for f in rev.files:
|
||||
fname = f.file_name.lower()
|
||||
if fname.endswith((".safetensors", ".bin", ".pt", ".pth", ".npz")):
|
||||
has_model_weights = True
|
||||
break
|
||||
if has_model_weights:
|
||||
break
|
||||
|
||||
has_incomplete = False
|
||||
try:
|
||||
cache_dir = hf_constants.HF_HUB_CACHE
|
||||
blobs_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--")) / "blobs"
|
||||
if blobs_dir.exists():
|
||||
has_incomplete = any(blobs_dir.glob("*.incomplete"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_model_weights and not has_incomplete:
|
||||
downloaded = True
|
||||
try:
|
||||
total_size = sum(revision.size_on_disk for revision in repo.revisions)
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
|
||||
if not downloaded:
|
||||
try:
|
||||
cache_dir = hf_constants.HF_HUB_CACHE
|
||||
repo_cache = Path(cache_dir) / ("models--" + config["hf_repo_id"].replace("/", "--"))
|
||||
|
||||
if repo_cache.exists():
|
||||
blobs_dir = repo_cache / "blobs"
|
||||
has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))
|
||||
|
||||
if not has_incomplete:
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
has_model_files = False
|
||||
if snapshots_dir.exists():
|
||||
has_model_files = (
|
||||
any(snapshots_dir.rglob("*.bin"))
|
||||
or any(snapshots_dir.rglob("*.safetensors"))
|
||||
or any(snapshots_dir.rglob("*.pt"))
|
||||
or any(snapshots_dir.rglob("*.pth"))
|
||||
or any(snapshots_dir.rglob("*.npz"))
|
||||
)
|
||||
|
||||
if has_model_files:
|
||||
downloaded = True
|
||||
try:
|
||||
total_size = sum(
|
||||
f.stat().st_size
|
||||
for f in repo_cache.rglob("*")
|
||||
if f.is_file() and not f.name.endswith(".incomplete")
|
||||
)
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
loaded = config["check_loaded"]()
|
||||
except Exception:
|
||||
loaded = False
|
||||
|
||||
is_downloading = config["hf_repo_id"] in active_download_repos
|
||||
|
||||
if is_downloading:
|
||||
downloaded = False
|
||||
size_mb = None
|
||||
|
||||
statuses.append(
|
||||
models.ModelStatus(
|
||||
model_name=config["model_name"],
|
||||
display_name=config["display_name"],
|
||||
hf_repo_id=config["hf_repo_id"],
|
||||
downloaded=downloaded,
|
||||
downloading=is_downloading,
|
||||
size_mb=size_mb,
|
||||
loaded=loaded,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
loaded = config["check_loaded"]()
|
||||
except Exception:
|
||||
loaded = False
|
||||
|
||||
is_downloading = config["hf_repo_id"] in active_download_repos
|
||||
|
||||
statuses.append(
|
||||
models.ModelStatus(
|
||||
model_name=config["model_name"],
|
||||
display_name=config["display_name"],
|
||||
hf_repo_id=config["hf_repo_id"],
|
||||
downloaded=False,
|
||||
downloading=is_downloading,
|
||||
size_mb=None,
|
||||
loaded=loaded,
|
||||
)
|
||||
)
|
||||
|
||||
return models.ModelStatusListResponse(models=statuses)
|
||||
|
||||
|
||||
@router.post("/models/download")
|
||||
async def trigger_model_download(request: models.ModelDownloadRequest):
|
||||
"""Trigger download of a specific model."""
|
||||
from ..backends import get_model_config, get_model_load_func
|
||||
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
config = get_model_config(request.model_name)
|
||||
if not config:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}")
|
||||
|
||||
load_func = get_model_load_func(config)
|
||||
|
||||
async def download_in_background():
|
||||
try:
|
||||
result = load_func()
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
task_manager.complete_download(request.model_name)
|
||||
except Exception as e:
|
||||
task_manager.error_download(request.model_name, str(e))
|
||||
|
||||
task_manager.start_download(request.model_name)
|
||||
|
||||
progress_manager.update_progress(
|
||||
model_name=request.model_name,
|
||||
current=0,
|
||||
total=0,
|
||||
filename="Connecting to HuggingFace...",
|
||||
status="downloading",
|
||||
)
|
||||
|
||||
create_background_task(download_in_background())
|
||||
|
||||
return {"message": f"Model {request.model_name} download started"}
|
||||
|
||||
|
||||
@router.post("/models/download/cancel")
|
||||
async def cancel_model_download(request: models.ModelDownloadRequest):
|
||||
"""Cancel or dismiss an errored/stale download task."""
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
removed = task_manager.cancel_download(request.model_name)
|
||||
|
||||
progress_removed = False
|
||||
with progress_manager._lock:
|
||||
if request.model_name in progress_manager._progress:
|
||||
del progress_manager._progress[request.model_name]
|
||||
progress_removed = True
|
||||
|
||||
if removed or progress_removed:
|
||||
return {"message": f"Download task for {request.model_name} cancelled"}
|
||||
return {"message": f"No active task found for {request.model_name}"}
|
||||
|
||||
|
||||
@router.delete("/models/{model_name}")
|
||||
async def delete_model(model_name: str):
|
||||
"""Delete a downloaded model from the HuggingFace cache."""
|
||||
from huggingface_hub import constants as hf_constants
|
||||
from ..backends import get_model_config, unload_model_by_config
|
||||
|
||||
config = get_model_config(model_name)
|
||||
if not config:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
|
||||
|
||||
hf_repo_id = config.hf_repo_id
|
||||
|
||||
try:
|
||||
unload_model_by_config(config)
|
||||
|
||||
cache_dir = hf_constants.HF_HUB_CACHE
|
||||
repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))
|
||||
|
||||
if not repo_cache_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")
|
||||
|
||||
try:
|
||||
shutil.rmtree(repo_cache_dir)
|
||||
except OSError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete model cache directory: {str(e)}")
|
||||
|
||||
return {"message": f"Model {model_name} deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
|
||||
363
backend/routes/profiles.py
Normal file
363
backend/routes/profiles.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Voice profile endpoints."""
|
||||
|
||||
import io
|
||||
import json as _json
|
||||
import logging
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import config, models
|
||||
from ..app import safe_content_disposition
|
||||
from ..database import VoiceProfile as DBVoiceProfile, get_db
|
||||
from ..services import channels, export_import, profiles
|
||||
from ..services.profiles import _profile_to_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/profiles", response_model=models.VoiceProfileResponse)
|
||||
async def create_profile(
|
||||
data: models.VoiceProfileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a new voice profile."""
|
||||
try:
|
||||
return await profiles.create_profile(data, db)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/profiles", response_model=list[models.VoiceProfileResponse])
|
||||
async def list_profiles(db: Session = Depends(get_db)):
|
||||
"""List all voice profiles."""
|
||||
return await profiles.list_profiles(db)
|
||||
|
||||
|
||||
@router.post("/profiles/import", response_model=models.VoiceProfileResponse)
|
||||
async def import_profile(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Import a voice profile from a ZIP archive."""
|
||||
MAX_FILE_SIZE = 100 * 1024 * 1024
|
||||
|
||||
content = await file.read()
|
||||
|
||||
if len(content) > MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE / (1024 * 1024)}MB"
|
||||
)
|
||||
|
||||
try:
|
||||
profile = await export_import.import_profile_from_zip(content, db)
|
||||
return profile
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ── Preset Voice Endpoints ───────────────────────────────────────────
|
||||
# These MUST be declared before /profiles/{profile_id} to avoid the
|
||||
# wildcard swallowing "presets" as a profile_id.
|
||||
|
||||
|
||||
@router.get("/profiles/presets/{engine}")
|
||||
async def list_preset_voices(engine: str):
|
||||
"""List available preset voices for an engine."""
|
||||
if engine == "kokoro":
|
||||
from ..backends.kokoro_backend import KOKORO_VOICES
|
||||
|
||||
return {
|
||||
"engine": engine,
|
||||
"voices": [
|
||||
{
|
||||
"voice_id": vid,
|
||||
"name": name,
|
||||
"gender": gender,
|
||||
"language": lang,
|
||||
}
|
||||
for vid, name, gender, lang in KOKORO_VOICES
|
||||
],
|
||||
}
|
||||
if engine == "qwen_custom_voice":
|
||||
from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES
|
||||
|
||||
return {
|
||||
"engine": engine,
|
||||
"voices": [
|
||||
{
|
||||
"voice_id": speaker_id,
|
||||
"name": display_name,
|
||||
"gender": gender,
|
||||
"language": lang,
|
||||
}
|
||||
for speaker_id, display_name, gender, lang, _desc in QWEN_CUSTOM_VOICES
|
||||
],
|
||||
}
|
||||
return {"engine": engine, "voices": []}
|
||||
|
||||
@router.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
|
||||
async def get_profile(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get a voice profile by ID."""
|
||||
profile = await profiles.get_profile(profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return profile
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
|
||||
async def update_profile(
|
||||
profile_id: str,
|
||||
data: models.VoiceProfileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update a voice profile."""
|
||||
try:
|
||||
profile = await profiles.update_profile(profile_id, data, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return profile
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}")
|
||||
async def delete_profile(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a voice profile."""
|
||||
success = await profiles.delete_profile(profile_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
return {"message": "Profile deleted successfully"}
|
||||
|
||||
|
||||
SAMPLE_MAX_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||
SAMPLE_UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1 MB
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/samples", response_model=models.ProfileSampleResponse)
|
||||
async def add_profile_sample(
|
||||
profile_id: str,
|
||||
file: UploadFile = File(...),
|
||||
reference_text: str = Form(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Add a sample to a voice profile."""
|
||||
_allowed_audio_exts = {".wav", ".mp3", ".m4a", ".ogg", ".flac", ".aac", ".webm", ".opus"}
|
||||
_uploaded_ext = Path(file.filename or "").suffix.lower()
|
||||
file_suffix = _uploaded_ext if _uploaded_ext in _allowed_audio_exts else ".wav"
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp:
|
||||
total_size = 0
|
||||
while chunk := await file.read(SAMPLE_UPLOAD_CHUNK_SIZE):
|
||||
total_size += len(chunk)
|
||||
if total_size > SAMPLE_MAX_FILE_SIZE:
|
||||
Path(tmp.name).unlink(missing_ok=True)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File too large (max {SAMPLE_MAX_FILE_SIZE // (1024 * 1024)} MB)",
|
||||
)
|
||||
tmp.write(chunk)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
sample = await profiles.add_profile_sample(
|
||||
profile_id,
|
||||
tmp_path,
|
||||
reference_text,
|
||||
db,
|
||||
)
|
||||
return sample
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to process audio file: {str(e)}")
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/samples", response_model=list[models.ProfileSampleResponse])
|
||||
async def get_profile_samples(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all samples for a profile."""
|
||||
return await profiles.get_profile_samples(profile_id, db)
|
||||
|
||||
|
||||
@router.delete("/profiles/samples/{sample_id}")
|
||||
async def delete_profile_sample(
|
||||
sample_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a profile sample."""
|
||||
success = await profiles.delete_profile_sample(sample_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Sample not found")
|
||||
return {"message": "Sample deleted successfully"}
|
||||
|
||||
|
||||
@router.put("/profiles/samples/{sample_id}", response_model=models.ProfileSampleResponse)
|
||||
async def update_profile_sample(
|
||||
sample_id: str,
|
||||
data: models.ProfileSampleUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update a profile sample's reference text."""
|
||||
sample = await profiles.update_profile_sample(sample_id, data.reference_text, db)
|
||||
if not sample:
|
||||
raise HTTPException(status_code=404, detail="Sample not found")
|
||||
return sample
|
||||
|
||||
|
||||
@router.post("/profiles/{profile_id}/avatar", response_model=models.VoiceProfileResponse)
|
||||
async def upload_profile_avatar(
|
||||
profile_id: str,
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Upload or update avatar image for a profile."""
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp:
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
profile = await profiles.upload_avatar(profile_id, tmp_path, db)
|
||||
return profile
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/avatar")
|
||||
async def get_profile_avatar(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get avatar image for a profile."""
|
||||
profile = await profiles.get_profile(profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
if not profile.avatar_path:
|
||||
raise HTTPException(status_code=404, detail="No avatar found for this profile")
|
||||
|
||||
avatar_path = config.resolve_storage_path(profile.avatar_path)
|
||||
if avatar_path is None or not avatar_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Avatar file not found")
|
||||
|
||||
return FileResponse(avatar_path)
|
||||
|
||||
|
||||
@router.delete("/profiles/{profile_id}/avatar")
|
||||
async def delete_profile_avatar(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete avatar image for a profile."""
|
||||
success = await profiles.delete_avatar(profile_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Profile not found or no avatar to delete")
|
||||
return {"message": "Avatar deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/export")
|
||||
async def export_profile(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Export a voice profile as a ZIP archive."""
|
||||
try:
|
||||
profile = await profiles.get_profile(profile_id, db)
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
zip_bytes = export_import.export_profile_to_zip(profile_id, db)
|
||||
|
||||
safe_name = "".join(c for c in profile.name if c.isalnum() or c in (" ", "-", "_")).strip()
|
||||
if not safe_name:
|
||||
safe_name = "profile"
|
||||
filename = f"profile-{safe_name}.voicebox.zip"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(zip_bytes),
|
||||
media_type="application/zip",
|
||||
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/profiles/{profile_id}/channels")
|
||||
async def get_profile_channels(
|
||||
profile_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get list of channel IDs assigned to a profile."""
|
||||
try:
|
||||
channel_ids = await channels.get_profile_channels(profile_id, db)
|
||||
return {"channel_ids": channel_ids}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}/channels")
|
||||
async def set_profile_channels(
|
||||
profile_id: str,
|
||||
data: models.ProfileChannelAssignment,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Set which channels a profile is assigned to."""
|
||||
try:
|
||||
await channels.set_profile_channels(profile_id, data, db)
|
||||
return {"message": "Profile channels updated successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/profiles/{profile_id}/effects", response_model=models.VoiceProfileResponse)
|
||||
async def update_profile_effects(
|
||||
profile_id: str,
|
||||
data: models.ProfileEffectsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Set or clear the default effects chain for a voice profile."""
|
||||
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
|
||||
if not profile:
|
||||
raise HTTPException(status_code=404, detail="Profile not found")
|
||||
|
||||
if data.effects_chain is not None:
|
||||
from ..utils.effects import validate_effects_chain
|
||||
|
||||
chain_dicts = [e.model_dump() for e in data.effects_chain]
|
||||
error = validate_effects_chain(chain_dicts)
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
profile.effects_chain = _json.dumps(chain_dicts)
|
||||
else:
|
||||
profile.effects_chain = None
|
||||
|
||||
profile.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
db.refresh(profile)
|
||||
|
||||
return _profile_to_response(profile)
|
||||
223
backend/routes/stories.py
Normal file
223
backend/routes/stories.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Story endpoints."""
|
||||
|
||||
import io
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .. import database, models
|
||||
from ..services import stories
|
||||
from ..app import safe_content_disposition
|
||||
from ..database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/stories", response_model=list[models.StoryResponse])
|
||||
async def list_stories(db: Session = Depends(get_db)):
|
||||
"""List all stories."""
|
||||
return await stories.list_stories(db)
|
||||
|
||||
|
||||
@router.post("/stories", response_model=models.StoryResponse)
|
||||
async def create_story(
|
||||
data: models.StoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Create a new story."""
|
||||
try:
|
||||
return await stories.create_story(data, db)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}", response_model=models.StoryDetailResponse)
|
||||
async def get_story(
|
||||
story_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get a story with all its items."""
|
||||
story = await stories.get_story(story_id, db)
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return story
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}", response_model=models.StoryResponse)
|
||||
async def update_story(
|
||||
story_id: str,
|
||||
data: models.StoryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update a story."""
|
||||
story = await stories.update_story(story_id, data, db)
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return story
|
||||
|
||||
|
||||
@router.delete("/stories/{story_id}")
|
||||
async def delete_story(
|
||||
story_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Delete a story."""
|
||||
success = await stories.delete_story(story_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
return {"message": "Story deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/stories/{story_id}/items", response_model=models.StoryItemDetail)
|
||||
async def add_story_item(
|
||||
story_id: str,
|
||||
data: models.StoryItemCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Add a generation to a story."""
|
||||
item = await stories.add_item_to_story(story_id, data, db)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Story or generation not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.delete("/stories/{story_id}/items/{item_id}")
|
||||
async def remove_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Remove a story item from a story."""
|
||||
success = await stories.remove_item_from_story(story_id, item_id, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Story item not found")
|
||||
return {"message": "Item removed successfully"}
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/times")
|
||||
async def update_story_item_times(
|
||||
story_id: str,
|
||||
data: models.StoryItemBatchUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Update story item timecodes."""
|
||||
success = await stories.update_story_item_times(story_id, data, db)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Invalid timecode update request")
|
||||
return {"message": "Item timecodes updated successfully"}
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/reorder", response_model=list[models.StoryItemDetail])
|
||||
async def reorder_story_items(
|
||||
story_id: str,
|
||||
data: models.StoryItemReorder,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Reorder story items and recalculate timecodes."""
|
||||
items = await stories.reorder_story_items(story_id, data.generation_ids, db)
|
||||
if items is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid reorder request - ensure all generation IDs belong to this story"
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/{item_id}/move", response_model=models.StoryItemDetail)
|
||||
async def move_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemMove,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Move a story item (update position and/or track)."""
|
||||
item = await stories.move_story_item(story_id, item_id, data, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/{item_id}/trim", response_model=models.StoryItemDetail)
|
||||
async def trim_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemTrim,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Trim a story item."""
|
||||
item = await stories.trim_story_item(story_id, item_id, data, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found or invalid trim values")
|
||||
return item
|
||||
|
||||
|
||||
@router.post("/stories/{story_id}/items/{item_id}/split", response_model=list[models.StoryItemDetail])
|
||||
async def split_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemSplit,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Split a story item at a given time, creating two clips."""
|
||||
items = await stories.split_story_item(story_id, item_id, data, db)
|
||||
if items is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found or invalid split point")
|
||||
return items
|
||||
|
||||
|
||||
@router.post("/stories/{story_id}/items/{item_id}/duplicate", response_model=models.StoryItemDetail)
|
||||
async def duplicate_story_item(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Duplicate a story item."""
|
||||
item = await stories.duplicate_story_item(story_id, item_id, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.put("/stories/{story_id}/items/{item_id}/version", response_model=models.StoryItemDetail)
|
||||
async def set_story_item_version(
|
||||
story_id: str,
|
||||
item_id: str,
|
||||
data: models.StoryItemVersionUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Pin a story item to a specific generation version."""
|
||||
item = await stories.set_story_item_version(story_id, item_id, data, db)
|
||||
if item is None:
|
||||
raise HTTPException(status_code=404, detail="Story item or version not found")
|
||||
return item
|
||||
|
||||
|
||||
@router.get("/stories/{story_id}/export-audio")
|
||||
async def export_story_audio(
|
||||
story_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Export story as single mixed audio file."""
|
||||
try:
|
||||
story = db.query(database.Story).filter_by(id=story_id).first()
|
||||
if not story:
|
||||
raise HTTPException(status_code=404, detail="Story not found")
|
||||
|
||||
audio_bytes = await stories.export_story_audio(story_id, db)
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=400, detail="Story has no audio items")
|
||||
|
||||
safe_name = "".join(c for c in story.name if c.isalnum() or c in (" ", "-", "_")).strip()
|
||||
if not safe_name:
|
||||
safe_name = "story"
|
||||
filename = f"{safe_name}.wav"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(audio_bytes),
|
||||
media_type="audio/wav",
|
||||
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
125
backend/routes/tasks.py
Normal file
125
backend/routes/tasks.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Task and cache management endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .. import models
|
||||
from ..utils.cache import clear_voice_prompt_cache
|
||||
from ..utils.progress import get_progress_manager
|
||||
from ..utils.tasks import get_task_manager
|
||||
from fastapi import HTTPException
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/tasks/clear")
|
||||
async def clear_all_tasks():
|
||||
"""Clear all download tasks and progress state."""
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
task_manager.clear_all()
|
||||
|
||||
with progress_manager._lock:
|
||||
progress_manager._progress.clear()
|
||||
progress_manager._last_notify_time.clear()
|
||||
progress_manager._last_notify_progress.clear()
|
||||
|
||||
return {"message": "All task state cleared"}
|
||||
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache():
|
||||
"""Clear all voice prompt caches (memory and disk)."""
|
||||
try:
|
||||
deleted_count = clear_voice_prompt_cache()
|
||||
return {
|
||||
"message": "Voice prompt cache cleared successfully",
|
||||
"files_deleted": deleted_count,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tasks/active", response_model=models.ActiveTasksResponse)
|
||||
async def get_active_tasks():
|
||||
"""Return all currently active downloads and generations."""
|
||||
task_manager = get_task_manager()
|
||||
progress_manager = get_progress_manager()
|
||||
|
||||
active_downloads = []
|
||||
task_manager_downloads = task_manager.get_active_downloads()
|
||||
progress_active = progress_manager.get_all_active()
|
||||
|
||||
download_map = {task.model_name: task for task in task_manager_downloads}
|
||||
progress_map = {p["model_name"]: p for p in progress_active}
|
||||
|
||||
all_model_names = set(download_map.keys()) | set(progress_map.keys())
|
||||
for model_name in all_model_names:
|
||||
task = download_map.get(model_name)
|
||||
progress = progress_map.get(model_name)
|
||||
|
||||
if task:
|
||||
error = task.error
|
||||
if not error:
|
||||
with progress_manager._lock:
|
||||
pm_data = progress_manager._progress.get(model_name)
|
||||
if pm_data:
|
||||
error = pm_data.get("error")
|
||||
prog = progress or {}
|
||||
if not prog:
|
||||
with progress_manager._lock:
|
||||
pm_data = progress_manager._progress.get(model_name)
|
||||
if pm_data:
|
||||
prog = pm_data
|
||||
active_downloads.append(
|
||||
models.ActiveDownloadTask(
|
||||
model_name=model_name,
|
||||
status=task.status,
|
||||
started_at=task.started_at,
|
||||
error=error,
|
||||
progress=prog.get("progress"),
|
||||
current=prog.get("current"),
|
||||
total=prog.get("total"),
|
||||
filename=prog.get("filename"),
|
||||
)
|
||||
)
|
||||
elif progress:
|
||||
timestamp_str = progress.get("timestamp")
|
||||
if timestamp_str:
|
||||
try:
|
||||
started_at = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
started_at = datetime.utcnow()
|
||||
else:
|
||||
started_at = datetime.utcnow()
|
||||
|
||||
active_downloads.append(
|
||||
models.ActiveDownloadTask(
|
||||
model_name=model_name,
|
||||
status=progress.get("status", "downloading"),
|
||||
started_at=started_at,
|
||||
error=progress.get("error"),
|
||||
progress=progress.get("progress"),
|
||||
current=progress.get("current"),
|
||||
total=progress.get("total"),
|
||||
filename=progress.get("filename"),
|
||||
)
|
||||
)
|
||||
|
||||
active_generations = []
|
||||
for gen_task in task_manager.get_active_generations():
|
||||
active_generations.append(
|
||||
models.ActiveGenerationTask(
|
||||
task_id=gen_task.task_id,
|
||||
profile_id=gen_task.profile_id,
|
||||
text_preview=gen_task.text_preview,
|
||||
started_at=gen_task.started_at,
|
||||
)
|
||||
)
|
||||
|
||||
return models.ActiveTasksResponse(
|
||||
downloads=active_downloads,
|
||||
generations=active_generations,
|
||||
)
|
||||
84
backend/routes/transcription.py
Normal file
84
backend/routes/transcription.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Transcription endpoints."""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
|
||||
from .. import models
|
||||
from ..services import transcribe
|
||||
from ..services.task_queue import create_background_task
|
||||
from ..utils.tasks import get_task_manager
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB
|
||||
|
||||
|
||||
@router.post("/transcribe", response_model=models.TranscriptionResponse)
|
||||
async def transcribe_audio(
|
||||
file: UploadFile = File(...),
|
||||
language: str | None = Form(None),
|
||||
model: str | None = Form(None),
|
||||
):
|
||||
"""Transcribe audio file to text."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||||
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
|
||||
tmp.write(chunk)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
from ..utils.audio import load_audio
|
||||
from ..backends import WHISPER_HF_REPOS
|
||||
|
||||
audio, sr = await asyncio.to_thread(load_audio, tmp_path)
|
||||
duration = len(audio) / sr
|
||||
|
||||
whisper_model = transcribe.get_whisper_model()
|
||||
model_size = model if model else whisper_model.model_size
|
||||
|
||||
valid_sizes = list(WHISPER_HF_REPOS.keys())
|
||||
if model_size not in valid_sizes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid model size '{model_size}'. Must be one of: {', '.join(valid_sizes)}",
|
||||
)
|
||||
|
||||
already_loaded = whisper_model.is_loaded() and whisper_model.model_size == model_size
|
||||
if not already_loaded and not whisper_model._is_model_cached(model_size):
|
||||
progress_model_name = f"whisper-{model_size}"
|
||||
task_manager = get_task_manager()
|
||||
|
||||
async def download_whisper_background():
|
||||
try:
|
||||
await whisper_model.load_model_async(model_size)
|
||||
task_manager.complete_download(progress_model_name)
|
||||
except Exception as e:
|
||||
task_manager.error_download(progress_model_name, str(e))
|
||||
|
||||
task_manager.start_download(progress_model_name)
|
||||
create_background_task(download_whisper_background())
|
||||
|
||||
raise HTTPException(
|
||||
status_code=202,
|
||||
detail={
|
||||
"message": f"Whisper model {model_size} is being downloaded. Please wait and try again.",
|
||||
"model_name": progress_model_name,
|
||||
"downloading": True,
|
||||
},
|
||||
)
|
||||
|
||||
text = await whisper_model.transcribe(tmp_path, language, model_size)
|
||||
|
||||
return models.TranscriptionResponse(
|
||||
text=text,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
Reference in New Issue
Block a user