Initial commit

2026-04-24 19:18:15 +08:00
commit fbcbe08696
555 changed files with 96692 additions and 0 deletions

32
backend/routes/__init__.py Normal file

@@ -0,0 +1,32 @@
"""Route registration for the voicebox API."""
from fastapi import FastAPI
def register_routers(app: FastAPI) -> None:
"""Include all domain routers on the application."""
from .health import router as health_router
from .profiles import router as profiles_router
from .channels import router as channels_router
from .generations import router as generations_router
from .history import router as history_router
from .transcription import router as transcription_router
from .stories import router as stories_router
from .effects import router as effects_router
from .audio import router as audio_router
from .models import router as models_router
from .tasks import router as tasks_router
from .cuda import router as cuda_router
app.include_router(health_router)
app.include_router(profiles_router)
app.include_router(channels_router)
app.include_router(generations_router)
app.include_router(history_router)
app.include_router(transcription_router)
app.include_router(stories_router)
app.include_router(effects_router)
app.include_router(audio_router)
app.include_router(models_router)
app.include_router(tasks_router)
app.include_router(cuda_router)

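For orientation, a minimal sketch of how this hook is consumed, assuming this module is the backend.routes package init and the application factory lives elsewhere (create_app below is illustrative, not part of this commit):

from fastapi import FastAPI

from backend.routes import register_routers  # assumed package path

def create_app() -> FastAPI:
    app = FastAPI(title="voicebox")
    register_routers(app)  # all routers are included without a prefix
    return app
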
69
backend/routes/audio.py Normal file

@@ -0,0 +1,69 @@
"""Audio file serving endpoints."""
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import history
from ..database import get_db
router = APIRouter()
@router.get("/audio/version/{version_id}")
async def get_version_audio(version_id: str, db: Session = Depends(get_db)):
"""Serve audio for a specific version."""
from ..services import versions as versions_mod
version = versions_mod.get_version(version_id, db)
if not version:
raise HTTPException(status_code=404, detail="Version not found")
audio_path = config.resolve_storage_path(version.audio_path)
if audio_path is None or not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"generation_{version.generation_id}_{version.label}.wav",
)
@router.get("/audio/{generation_id}")
async def get_audio(generation_id: str, db: Session = Depends(get_db)):
"""Serve generated audio file (serves the default version)."""
generation = await history.get_generation(generation_id, db)
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is None or not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"generation_{generation_id}.wav",
)
@router.get("/samples/{sample_id}")
async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
"""Serve profile sample audio file."""
from ..database import ProfileSample as DBProfileSample
sample = db.query(DBProfileSample).filter_by(id=sample_id).first()
if not sample:
raise HTTPException(status_code=404, detail="Sample not found")
audio_path = config.resolve_storage_path(sample.audio_path)
if audio_path is None or not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse(
audio_path,
media_type="audio/wav",
filename=f"sample_{sample_id}.wav",
)

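A quick client-side sketch of these endpoints, assuming the API is served at http://localhost:8000 and that the generation ID exists; requests is used for brevity:

import requests

BASE = "http://localhost:8000"  # assumed dev address; adjust to your deployment
resp = requests.get(f"{BASE}/audio/some-generation-id")
if resp.ok:
    with open("generation.wav", "wb") as f:
        f.write(resp.content)  # the FileResponse body is the WAV file itself
else:
    print(resp.json()["detail"])  # "Generation not found" or "Audio file not found"
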
98
backend/routes/channels.py Normal file

@@ -0,0 +1,98 @@
"""Audio channel endpoints."""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from .. import models
from ..services import channels
from ..database import get_db
router = APIRouter()
@router.get("/channels", response_model=list[models.AudioChannelResponse])
async def list_channels(db: Session = Depends(get_db)):
"""List all audio channels."""
return await channels.list_channels(db)
@router.post("/channels", response_model=models.AudioChannelResponse)
async def create_channel(
data: models.AudioChannelCreate,
db: Session = Depends(get_db),
):
"""Create a new audio channel."""
try:
return await channels.create_channel(data, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def get_channel(
channel_id: str,
db: Session = Depends(get_db),
):
"""Get an audio channel by ID."""
channel = await channels.get_channel(channel_id, db)
if not channel:
raise HTTPException(status_code=404, detail="Channel not found")
return channel
@router.put("/channels/{channel_id}", response_model=models.AudioChannelResponse)
async def update_channel(
channel_id: str,
data: models.AudioChannelUpdate,
db: Session = Depends(get_db),
):
"""Update an audio channel."""
try:
channel = await channels.update_channel(channel_id, data, db)
if not channel:
raise HTTPException(status_code=404, detail="Channel not found")
return channel
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/channels/{channel_id}")
async def delete_channel(
channel_id: str,
db: Session = Depends(get_db),
):
"""Delete an audio channel."""
try:
success = await channels.delete_channel(channel_id, db)
if not success:
raise HTTPException(status_code=404, detail="Channel not found")
return {"message": "Channel deleted successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/channels/{channel_id}/voices")
async def get_channel_voices(
channel_id: str,
db: Session = Depends(get_db),
):
"""Get list of profile IDs assigned to a channel."""
try:
profile_ids = await channels.get_channel_voices(channel_id, db)
return {"profile_ids": profile_ids}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/channels/{channel_id}/voices")
async def set_channel_voices(
channel_id: str,
data: models.ChannelVoiceAssignment,
db: Session = Depends(get_db),
):
"""Set which voices are assigned to a channel."""
try:
await channels.set_channel_voices(channel_id, data, db)
return {"message": "Channel voices updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

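A hedged round-trip sketch of the channel CRUD surface; the request body shapes for AudioChannelCreate and ChannelVoiceAssignment are assumptions, since only the response of GET .../voices is visible above:

import requests

BASE = "http://localhost:8000"
channel = requests.post(f"{BASE}/channels", json={"name": "Podcast"}).json()  # "name" field assumed
cid = channel["id"]
requests.put(f"{BASE}/channels/{cid}/voices", json={"profile_ids": ["profile-1"]})  # payload shape assumed
print(requests.get(f"{BASE}/channels/{cid}/voices").json())  # -> {"profile_ids": [...]}
requests.delete(f"{BASE}/channels/{cid}")
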
82
backend/routes/cuda.py Normal file

@@ -0,0 +1,82 @@
"""CUDA backend management endpoints."""
import logging
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from ..services.task_queue import create_background_task
from ..utils.progress import get_progress_manager
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/backend/cuda-status")
async def get_cuda_status():
"""Get CUDA backend download/availability status."""
from ..services import cuda
return cuda.get_cuda_status()
@router.post("/backend/download-cuda")
async def download_cuda_backend():
"""Download the CUDA backend binary."""
from ..services import cuda
if cuda.get_cuda_binary_path() is not None:
raise HTTPException(status_code=409, detail="CUDA backend already downloaded")
progress_manager = get_progress_manager()
existing = progress_manager.get_progress(cuda.PROGRESS_KEY)
if existing and existing.get("status") == "downloading":
raise HTTPException(status_code=409, detail="CUDA backend download already in progress")
async def _download():
try:
await cuda.download_cuda_binary()
except Exception as e:
logger.error("CUDA download failed: %s", e)
create_background_task(_download())
return {"message": "CUDA backend download started", "progress_key": "cuda-backend"}
@router.delete("/backend/cuda")
async def delete_cuda_backend():
"""Delete the downloaded CUDA backend binary."""
from ..services import cuda
if cuda.is_cuda_active():
raise HTTPException(
status_code=409,
detail="Cannot delete CUDA backend while it is active. Switch to CPU first.",
)
deleted = await cuda.delete_cuda_binary()
if not deleted:
raise HTTPException(status_code=404, detail="No CUDA backend found to delete")
return {"message": "CUDA backend deleted"}
@router.get("/backend/cuda-progress")
async def get_cuda_download_progress():
"""Get CUDA backend download progress via Server-Sent Events."""
progress_manager = get_progress_manager()
    from ..services import cuda
    async def event_generator():
        async for event in progress_manager.subscribe(cuda.PROGRESS_KEY):
            yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)

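Consuming the download flow from a client might look like the sketch below; the SSE payload schema comes from the progress manager and is not shown in this file:

import requests

BASE = "http://localhost:8000"  # assumed
requests.post(f"{BASE}/backend/download-cuda", timeout=10)
with requests.get(f"{BASE}/backend/cuda-progress", stream=True, timeout=None) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            print(line[len(b"data: "):].decode())  # raw event payload
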
262
backend/routes/effects.py Normal file

@@ -0,0 +1,262 @@
"""Effects presets and generation version endpoints."""
import asyncio
import io
import uuid
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import history
from ..database import Generation as DBGeneration, get_db
router = APIRouter()
@router.post("/effects/preview/{generation_id}")
async def preview_effects(
generation_id: str,
data: models.ApplyEffectsRequest,
db: Session = Depends(get_db),
):
"""Apply effects to a generation's clean audio and stream back without saving."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "completed":
raise HTTPException(status_code=400, detail="Generation is not completed")
from ..services import versions as versions_mod
from ..utils.effects import apply_effects, validate_effects_chain
from ..utils.audio import load_audio
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise HTTPException(status_code=400, detail=error)
all_versions = versions_mod.list_versions(generation_id, db)
clean_version = next((v for v in all_versions if v.effects_chain is None), None)
source_path = clean_version.audio_path if clean_version else gen.audio_path
resolved_source_path = config.resolve_storage_path(source_path)
if resolved_source_path is None or not resolved_source_path.exists():
raise HTTPException(status_code=404, detail="Source audio file not found")
audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
processed = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)
import soundfile as sf
buf = io.BytesIO()
    await asyncio.to_thread(sf.write, buf, processed, sample_rate, format="WAV")
buf.seek(0)
return StreamingResponse(
buf,
media_type="audio/wav",
headers={
"Content-Disposition": f'inline; filename="preview_{generation_id}.wav"',
"Cache-Control": "no-cache, no-store",
},
)
@router.get("/effects/available", response_model=models.AvailableEffectsResponse)
async def get_available_effects():
"""List all available effect types with parameter definitions."""
from ..utils.effects import get_available_effects as _get_effects
return models.AvailableEffectsResponse(effects=[models.AvailableEffect(**e) for e in _get_effects()])
@router.get("/effects/presets", response_model=list[models.EffectPresetResponse])
async def list_effect_presets(db: Session = Depends(get_db)):
"""List all effect presets (built-in + user-created)."""
from ..services import effects as effects_mod
return effects_mod.list_presets(db)
@router.get("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
async def get_effect_preset(preset_id: str, db: Session = Depends(get_db)):
"""Get a specific effect preset."""
from ..services import effects as effects_mod
preset = effects_mod.get_preset(preset_id, db)
if not preset:
raise HTTPException(status_code=404, detail="Preset not found")
return preset
@router.post("/effects/presets", response_model=models.EffectPresetResponse)
async def create_effect_preset(
data: models.EffectPresetCreate,
db: Session = Depends(get_db),
):
"""Create a new effect preset."""
from ..services import effects as effects_mod
try:
return effects_mod.create_preset(data, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/effects/presets/{preset_id}", response_model=models.EffectPresetResponse)
async def update_effect_preset(
preset_id: str,
data: models.EffectPresetUpdate,
db: Session = Depends(get_db),
):
"""Update an effect preset."""
from ..services import effects as effects_mod
try:
result = effects_mod.update_preset(preset_id, data, db)
if not result:
raise HTTPException(status_code=404, detail="Preset not found")
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/effects/presets/{preset_id}")
async def delete_effect_preset(preset_id: str, db: Session = Depends(get_db)):
"""Delete a user effect preset."""
from ..services import effects as effects_mod
try:
if not effects_mod.delete_preset(preset_id, db):
raise HTTPException(status_code=404, detail="Preset not found")
return {"status": "deleted"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get(
"/generations/{generation_id}/versions",
response_model=list[models.GenerationVersionResponse],
)
async def list_generation_versions(
generation_id: str,
db: Session = Depends(get_db),
):
"""List all versions for a generation."""
gen = await history.get_generation(generation_id, db)
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
from ..services import versions as versions_mod
return versions_mod.list_versions(generation_id, db)
@router.post(
"/generations/{generation_id}/versions/apply-effects",
response_model=models.GenerationVersionResponse,
)
async def apply_effects_to_generation(
generation_id: str,
data: models.ApplyEffectsRequest,
db: Session = Depends(get_db),
):
"""Apply an effects chain to an existing generation, creating a new version."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "completed":
raise HTTPException(status_code=400, detail="Generation is not completed")
from ..services import versions as versions_mod
from ..utils.effects import apply_effects, validate_effects_chain
from ..utils.audio import load_audio, save_audio
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise HTTPException(status_code=400, detail=error)
all_versions = versions_mod.list_versions(generation_id, db)
source_version_id = data.source_version_id
if source_version_id:
source_version = next((v for v in all_versions if v.id == source_version_id), None)
if not source_version:
raise HTTPException(status_code=404, detail="Source version not found")
source_path = source_version.audio_path
else:
clean_version = next((v for v in all_versions if v.effects_chain is None), None)
if not clean_version:
source_path = gen.audio_path
else:
source_path = clean_version.audio_path
source_version_id = clean_version.id
resolved_source_path = config.resolve_storage_path(source_path)
if resolved_source_path is None or not resolved_source_path.exists():
raise HTTPException(status_code=404, detail="Source audio file not found")
audio, sample_rate = await asyncio.to_thread(load_audio, str(resolved_source_path))
processed_audio = await asyncio.to_thread(apply_effects, audio, sample_rate, chain_dicts)
version_id = str(uuid.uuid4())
processed_path = config.get_generations_dir() / f"{generation_id}_{version_id[:8]}.wav"
await asyncio.to_thread(save_audio, processed_audio, str(processed_path), sample_rate)
label = data.label or f"version-{len(all_versions) + 1}"
version = versions_mod.create_version(
generation_id=generation_id,
label=label,
audio_path=config.to_storage_path(processed_path),
db=db,
effects_chain=chain_dicts,
is_default=data.set_as_default,
source_version_id=source_version_id,
)
return version
@router.put(
"/generations/{generation_id}/versions/{version_id}/set-default",
response_model=models.GenerationVersionResponse,
)
async def set_default_version(
generation_id: str,
version_id: str,
db: Session = Depends(get_db),
):
"""Set a specific version as the default for a generation."""
from ..services import versions as versions_mod
version = versions_mod.get_version(version_id, db)
if not version or version.generation_id != generation_id:
raise HTTPException(status_code=404, detail="Version not found")
result = versions_mod.set_default_version(version_id, db)
if not result:
raise HTTPException(status_code=404, detail="Version not found")
return result
@router.delete("/generations/{generation_id}/versions/{version_id}")
async def delete_generation_version(
generation_id: str,
version_id: str,
db: Session = Depends(get_db),
):
"""Delete a version. Cannot delete the last remaining version."""
from ..services import versions as versions_mod
version = versions_mod.get_version(version_id, db)
if not version or version.generation_id != generation_id:
raise HTTPException(status_code=404, detail="Version not found")
if not versions_mod.delete_version(version_id, db):
raise HTTPException(
status_code=400,
detail="Cannot delete the last remaining version",
)
return {"status": "deleted"}

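A sketch of the preview-then-commit flow; the effect type and parameter names below are placeholders, since the real catalogue comes from GET /effects/available:

import requests

BASE = "http://localhost:8000"  # assumed
gen_id = "some-generation-id"  # placeholder
chain = {"effects_chain": [{"type": "reverb", "params": {"wet": 0.3}}]}  # effect dict shape assumed
preview = requests.post(f"{BASE}/effects/preview/{gen_id}", json=chain)
open("preview.wav", "wb").write(preview.content)  # streamed WAV, never persisted server-side
version = requests.post(
    f"{BASE}/generations/{gen_id}/versions/apply-effects",
    json={**chain, "label": "with-reverb", "set_as_default": True},
).json()  # creates and returns a new version record
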
345
backend/routes/generations.py Normal file

@@ -0,0 +1,345 @@
"""TTS generation endpoints."""
import asyncio
import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from .. import models
from ..services import history, profiles, tts
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db
from ..services.generation import run_generation
from ..services.task_queue import cancel_generation as cancel_generation_job, enqueue_generation
from ..utils.tasks import get_task_manager
logger = logging.getLogger(__name__)
router = APIRouter()
def _resolve_generation_engine(data: models.GenerationRequest, profile) -> str:
    return (
        data.engine
        or getattr(profile, "default_engine", None)
        or getattr(profile, "preset_engine", None)
        or "qwen"
    )
@router.post("/generate", response_model=models.GenerationResponse)
async def generate_speech(
data: models.GenerationRequest,
db: Session = Depends(get_db),
):
"""Generate speech from text using a voice profile."""
task_manager = get_task_manager()
generation_id = str(uuid.uuid4())
profile = await profiles.get_profile(data.profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
from ..backends import engine_has_model_sizes
engine = _resolve_generation_engine(data, profile)
try:
profiles.validate_profile_engine(profile, engine)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
model_size = (data.model_size or "1.7B") if engine_has_model_sizes(engine) else None
generation = await history.create_generation(
profile_id=data.profile_id,
text=data.text,
language=data.language,
audio_path="",
duration=0,
seed=data.seed,
db=db,
instruct=data.instruct,
generation_id=generation_id,
status="generating",
engine=engine,
        model_size=model_size,  # already None when the engine has no size variants
)
task_manager.start_generation(
task_id=generation_id,
profile_id=data.profile_id,
text=data.text,
)
effects_chain_config = None
if data.effects_chain is not None:
effects_chain_config = [e.model_dump() for e in data.effects_chain]
else:
import json as _json
profile_obj = db.query(DBVoiceProfile).filter_by(id=data.profile_id).first()
if profile_obj and profile_obj.effects_chain:
try:
effects_chain_config = _json.loads(profile_obj.effects_chain)
except Exception:
pass
enqueue_generation(
generation_id,
run_generation(
generation_id=generation_id,
profile_id=data.profile_id,
text=data.text,
language=data.language,
engine=engine,
model_size=model_size,
seed=data.seed,
normalize=data.normalize,
effects_chain=effects_chain_config,
instruct=data.instruct,
mode="generate",
max_chunk_chars=data.max_chunk_chars,
crossfade_ms=data.crossfade_ms,
)
)
return generation
@router.post("/generate/{generation_id}/retry", response_model=models.GenerationResponse)
async def retry_generation(generation_id: str, db: Session = Depends(get_db)):
"""Retry a failed generation using the same parameters."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "failed":
raise HTTPException(status_code=400, detail="Only failed generations can be retried")
gen.status = "generating"
gen.error = None
gen.audio_path = ""
gen.duration = 0
db.commit()
db.refresh(gen)
task_manager = get_task_manager()
task_manager.start_generation(
task_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
)
enqueue_generation(
generation_id,
run_generation(
generation_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
language=gen.language,
engine=gen.engine or "qwen",
model_size=gen.model_size or "1.7B",
seed=gen.seed,
instruct=gen.instruct,
mode="retry",
)
)
return models.GenerationResponse.model_validate(gen)
@router.post(
"/generate/{generation_id}/regenerate",
response_model=models.GenerationResponse,
)
async def regenerate_generation(generation_id: str, db: Session = Depends(get_db)):
"""Re-run TTS with the same parameters and save the result as a new version."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") != "completed":
raise HTTPException(status_code=400, detail="Generation must be completed to regenerate")
gen.status = "generating"
gen.error = None
db.commit()
db.refresh(gen)
task_manager = get_task_manager()
task_manager.start_generation(
task_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
)
version_id = str(uuid.uuid4())
enqueue_generation(
generation_id,
run_generation(
generation_id=generation_id,
profile_id=gen.profile_id,
text=gen.text,
language=gen.language,
engine=gen.engine or "qwen",
model_size=gen.model_size or "1.7B",
seed=gen.seed,
instruct=gen.instruct,
mode="regenerate",
version_id=version_id,
)
)
return models.GenerationResponse.model_validate(gen)
@router.post("/generate/{generation_id}/cancel")
async def cancel_generation(generation_id: str, db: Session = Depends(get_db)):
"""Cancel a queued or running generation."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if (gen.status or "completed") not in ("loading_model", "generating"):
raise HTTPException(status_code=400, detail="Only active generations can be cancelled")
cancellation_state = cancel_generation_job(generation_id)
if cancellation_state is None:
raise HTTPException(status_code=409, detail="Generation is no longer cancellable")
if cancellation_state == "queued":
task_manager = get_task_manager()
task_manager.complete_generation(generation_id)
await history.update_generation_status(
generation_id=generation_id,
status="failed",
db=db,
error="Generation cancelled",
)
return {"message": "Queued generation cancelled"}
return {"message": "Generation cancellation requested"}
@router.get("/generate/{generation_id}/status")
async def get_generation_status(generation_id: str, db: Session = Depends(get_db)):
"""SSE endpoint that streams generation status updates."""
import json
async def event_stream():
try:
while True:
db.expire_all()
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
yield f"data: {json.dumps({'status': 'not_found', 'id': generation_id})}\n\n"
return
payload = {
"id": gen.id,
"status": gen.status or "completed",
"duration": gen.duration,
"error": gen.error,
}
yield f"data: {json.dumps(payload)}\n\n"
if (gen.status or "completed") in ("completed", "failed"):
return
await asyncio.sleep(1)
except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
logger.debug("SSE client disconnected for generation %s", generation_id)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.post("/generate/stream")
async def stream_speech(
data: models.GenerationRequest,
db: Session = Depends(get_db),
):
"""Generate speech and stream the WAV audio directly without saving to disk."""
from ..backends import get_tts_backend_for_engine, ensure_model_cached_or_raise, load_engine_model, engine_needs_trim
profile = await profiles.get_profile(data.profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
engine = _resolve_generation_engine(data, profile)
try:
profiles.validate_profile_engine(profile, engine)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
tts_model = get_tts_backend_for_engine(engine)
model_size = data.model_size or "1.7B"
await ensure_model_cached_or_raise(engine, model_size)
await load_engine_model(engine, model_size)
voice_prompt = await profiles.create_voice_prompt_for_profile(
data.profile_id,
db,
engine=engine,
)
from ..utils.chunked_tts import generate_chunked
trim_fn = None
if engine_needs_trim(engine):
from ..utils.audio import trim_tts_output
trim_fn = trim_tts_output
audio, sample_rate = await generate_chunked(
tts_model,
data.text,
voice_prompt,
language=data.language,
seed=data.seed,
instruct=data.instruct,
max_chunk_chars=data.max_chunk_chars,
crossfade_ms=data.crossfade_ms,
trim_fn=trim_fn,
)
effects_chain_config = None
if data.effects_chain is not None:
effects_chain_config = [e.model_dump() for e in data.effects_chain]
elif profile.effects_chain:
import json as _json
try:
effects_chain_config = _json.loads(profile.effects_chain)
except Exception:
effects_chain_config = None
if effects_chain_config:
from ..utils.effects import apply_effects
audio = apply_effects(audio, sample_rate, effects_chain_config)
if data.normalize:
from ..utils.audio import normalize_audio
audio = normalize_audio(audio)
wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate)
async def _wav_stream():
try:
chunk_size = 64 * 1024
for i in range(0, len(wav_bytes), chunk_size):
yield wav_bytes[i : i + chunk_size]
except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
logger.debug("Client disconnected during audio stream")
return StreamingResponse(
_wav_stream(),
media_type="audio/wav",
headers={"Content-Disposition": 'attachment; filename="speech.wav"'},
)

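End to end, a client kicks off a generation and then follows the SSE status stream until it reaches a terminal state; a sketch, with the base URL and profile ID assumed:

import json
import requests

BASE = "http://localhost:8000"
gen = requests.post(
    f"{BASE}/generate",
    json={"profile_id": "profile-1", "text": "Hello there.", "language": "en"},  # minimal body; other fields optional
).json()
with requests.get(f"{BASE}/generate/{gen['id']}/status", stream=True, timeout=None) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue
        event = json.loads(line[len(b"data: "):])  # {"id", "status", "duration", "error"}
        print(event["status"])
        if event["status"] in ("completed", "failed"):
            break
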
248
backend/routes/health.py Normal file

@@ -0,0 +1,248 @@
"""Health and infrastructure endpoints."""
import asyncio
import os
import signal
from pathlib import Path
import torch
from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import tts
from ..database import get_db
from ..utils.platform_detect import get_backend_type
router = APIRouter()
# Frontend build directory — present in Docker, absent in dev/API-only mode
_frontend_dir = Path(__file__).resolve().parent.parent.parent / "frontend"
@router.get("/")
async def root():
"""Root endpoint — serves SPA index.html in Docker, JSON otherwise."""
from .. import __version__
index = _frontend_dir / "index.html"
if index.is_file():
return FileResponse(index, media_type="text/html")
return {"message": "voicebox API", "version": __version__}
@router.post("/shutdown")
async def shutdown():
"""Gracefully shutdown the server."""
async def shutdown_async():
await asyncio.sleep(0.1)
os.kill(os.getpid(), signal.SIGTERM)
asyncio.create_task(shutdown_async())
return {"message": "Shutting down..."}
@router.post("/watchdog/disable")
async def watchdog_disable():
"""Disable the parent process watchdog so the server keeps running."""
from backend.server import disable_watchdog
disable_watchdog()
return {"message": "Watchdog disabled"}
@router.get("/health", response_model=models.HealthResponse)
async def health():
"""Health check endpoint."""
    from huggingface_hub import constants as hf_constants
tts_model = tts.get_tts_model()
backend_type = get_backend_type()
has_cuda = torch.cuda.is_available()
has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
has_xpu = False
xpu_name = None
try:
import intel_extension_for_pytorch as ipex # noqa: F401 -- side-effect import enables XPU
if hasattr(torch, "xpu") and torch.xpu.is_available():
has_xpu = True
try:
xpu_name = torch.xpu.get_device_name(0)
except Exception:
xpu_name = "Intel GPU"
except ImportError:
pass
has_directml = False
directml_name = None
try:
import torch_directml
if torch_directml.device_count() > 0:
has_directml = True
try:
directml_name = torch_directml.device_name(0)
except Exception:
directml_name = "DirectML GPU"
except ImportError:
pass
gpu_compat_warning = None
if has_cuda:
from ..backends.base import check_cuda_compatibility
_compatible, gpu_compat_warning = check_cuda_compatibility()
gpu_available = has_cuda or has_mps or has_xpu or has_directml or backend_type == "mlx"
gpu_type = None
if has_cuda:
gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})"
elif has_mps:
gpu_type = "MPS (Apple Silicon)"
elif backend_type == "mlx":
gpu_type = "Metal (Apple Silicon via MLX)"
elif has_xpu:
gpu_type = f"XPU ({xpu_name})"
elif has_directml:
gpu_type = f"DirectML ({directml_name})"
vram_used = None
if has_cuda:
vram_used = torch.cuda.memory_allocated() / 1024 / 1024
elif has_xpu:
try:
vram_used = torch.xpu.memory_allocated() / 1024 / 1024
except Exception:
pass # memory_allocated() may not be available on all IPEX versions
model_loaded = False
model_size = None
try:
if tts_model.is_loaded():
model_loaded = True
model_size = getattr(tts_model, "_current_model_size", None)
if not model_size:
model_size = getattr(tts_model, "model_size", None)
except Exception:
model_loaded = False
model_size = None
model_downloaded = None
try:
from ..backends import get_model_config
default_config = get_model_config("qwen-tts-1.7B")
default_model_id = default_config.hf_repo_id if default_config else "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
try:
from huggingface_hub import scan_cache_dir
cache_info = scan_cache_dir()
for repo in cache_info.repos:
if repo.repo_id == default_model_id:
model_downloaded = True
break
        except Exception:  # scan_cache_dir unavailable or failed; probe the cache layout directly
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache = Path(cache_dir) / ("models--" + default_model_id.replace("/", "--"))
if repo_cache.exists():
has_model_files = (
any(repo_cache.rglob("*.bin"))
or any(repo_cache.rglob("*.safetensors"))
or any(repo_cache.rglob("*.pt"))
or any(repo_cache.rglob("*.pth"))
or any(repo_cache.rglob("*.npz"))
)
model_downloaded = has_model_files
except Exception:
pass
return models.HealthResponse(
status="healthy",
model_loaded=model_loaded,
model_downloaded=model_downloaded,
model_size=model_size,
gpu_available=gpu_available,
gpu_type=gpu_type,
vram_used_mb=vram_used,
backend_type=backend_type,
backend_variant=os.environ.get(
"VOICEBOX_BACKEND_VARIANT",
"cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"),
),
gpu_compatibility_warning=gpu_compat_warning,
)
@router.get("/health/filesystem", response_model=models.FilesystemHealthResponse)
async def filesystem_health():
"""Check filesystem health: directory existence, write permissions, and disk space."""
import shutil
dirs_to_check = {
"generations": config.get_generations_dir(),
"profiles": config.get_profiles_dir(),
"data": config.get_data_dir(),
}
checks: list[models.DirectoryCheck] = []
all_ok = True
for _label, dir_path in dirs_to_check.items():
exists = dir_path.exists()
writable = False
error = None
if exists:
probe = dir_path / ".voicebox_probe"
try:
probe.write_text("ok")
probe.unlink()
writable = True
except PermissionError:
error = "Permission denied"
except OSError as e:
error = str(e)
finally:
try:
probe.unlink(missing_ok=True)
except Exception:
pass
else:
error = "Directory does not exist"
if not exists or not writable:
all_ok = False
checks.append(
models.DirectoryCheck(
path=str(dir_path.resolve()),
exists=exists,
writable=writable,
error=error,
)
)
disk_free_mb = None
disk_total_mb = None
try:
usage = shutil.disk_usage(str(config.get_data_dir()))
disk_free_mb = round(usage.free / (1024 * 1024), 1)
disk_total_mb = round(usage.total / (1024 * 1024), 1)
        if disk_free_mb < 500:  # treat under 500 MB of free space as unhealthy
all_ok = False
except OSError:
all_ok = False
return models.FilesystemHealthResponse(
healthy=all_ok,
disk_free_mb=disk_free_mb,
disk_total_mb=disk_total_mb,
directories=checks,
)

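Both probes are plain GETs; a monitoring script might poll them like this (base URL assumed):

import requests

BASE = "http://localhost:8000"
health = requests.get(f"{BASE}/health").json()
print(health["status"], health["gpu_type"], health["backend_type"])
fs = requests.get(f"{BASE}/health/filesystem").json()
if not fs["healthy"]:
    for d in fs["directories"]:
        if d["error"]:
            print(d["path"], d["error"])
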
189
backend/routes/history.py Normal file

@@ -0,0 +1,189 @@
"""Generation history endpoints."""
import io
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..services import export_import, history
from ..app import safe_content_disposition
from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db
router = APIRouter()
@router.get("/history", response_model=models.HistoryListResponse)
async def list_history(
profile_id: str | None = None,
search: str | None = None,
limit: int = 50,
offset: int = 0,
db: Session = Depends(get_db),
):
"""List generation history with optional filters."""
query = models.HistoryQuery(
profile_id=profile_id,
search=search,
limit=limit,
offset=offset,
)
return await history.list_generations(query, db)
@router.get("/history/stats")
async def get_stats(db: Session = Depends(get_db)):
"""Get generation statistics."""
return await history.get_generation_stats(db)
@router.post("/history/import")
async def import_generation(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import a generation from a ZIP archive."""
MAX_FILE_SIZE = 50 * 1024 * 1024
content = await file.read()
if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024 * 1024)}MB"
        )
try:
result = await export_import.import_generation_from_zip(content, db)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/history/failed")
async def clear_failed_generations(db: Session = Depends(get_db)):
"""Delete every generation with status='failed'. Used by the UI's 'Clear failed' button (#410)."""
count = await history.delete_failed_generations(db)
return {"deleted": count}
@router.get("/history/{generation_id}", response_model=models.HistoryResponse)
async def get_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Get a generation by ID."""
result = (
db.query(DBGeneration, DBVoiceProfile.name.label("profile_name"))
.join(DBVoiceProfile, DBGeneration.profile_id == DBVoiceProfile.id)
.filter(DBGeneration.id == generation_id)
.first()
)
if not result:
raise HTTPException(status_code=404, detail="Generation not found")
gen, profile_name = result
return models.HistoryResponse(
id=gen.id,
profile_id=gen.profile_id,
profile_name=profile_name,
text=gen.text,
language=gen.language,
audio_path=gen.audio_path,
duration=gen.duration,
seed=gen.seed,
instruct=gen.instruct,
engine=gen.engine or "qwen",
model_size=gen.model_size,
status=gen.status or "completed",
error=gen.error,
is_favorited=bool(gen.is_favorited),
created_at=gen.created_at,
)
@router.post("/history/{generation_id}/favorite")
async def toggle_favorite(
generation_id: str,
db: Session = Depends(get_db),
):
"""Toggle the favorite status of a generation."""
gen = db.query(DBGeneration).filter_by(id=generation_id).first()
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
gen.is_favorited = not gen.is_favorited
db.commit()
return {"is_favorited": gen.is_favorited}
@router.delete("/history/{generation_id}")
async def delete_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Delete a generation."""
success = await history.delete_generation(generation_id, db)
if not success:
raise HTTPException(status_code=404, detail="Generation not found")
return {"message": "Generation deleted successfully"}
@router.get("/history/{generation_id}/export")
async def export_generation(
generation_id: str,
db: Session = Depends(get_db),
):
"""Export a generation as a ZIP archive."""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
try:
zip_bytes = export_import.export_generation_to_zip(generation_id, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_text:
safe_text = "generation"
filename = f"generation-{safe_text}.voicebox.zip"
return StreamingResponse(
io.BytesIO(zip_bytes),
media_type="application/zip",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)
@router.get("/history/{generation_id}/export-audio")
async def export_generation_audio(
generation_id: str,
db: Session = Depends(get_db),
):
"""Export only the audio file from a generation."""
generation = db.query(DBGeneration).filter_by(id=generation_id).first()
if not generation:
raise HTTPException(status_code=404, detail="Generation not found")
if not generation.audio_path:
raise HTTPException(status_code=404, detail="Generation has no audio file")
audio_path = config.resolve_storage_path(generation.audio_path)
if audio_path is None or not audio_path.is_file():
raise HTTPException(status_code=404, detail="Audio file not found")
safe_text = "".join(c for c in generation.text[:30] if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_text:
safe_text = "generation"
filename = f"{safe_text}.wav"
return FileResponse(
audio_path,
media_type="audio/wav",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)

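The export and import endpoints are designed to round-trip; a sketch, with the generation ID and base URL assumed:

import requests

BASE = "http://localhost:8000"
gen_id = "some-generation-id"  # placeholder
archive = requests.get(f"{BASE}/history/{gen_id}/export")
with open("generation.voicebox.zip", "wb") as f:
    f.write(archive.content)
with open("generation.voicebox.zip", "rb") as f:
    result = requests.post(f"{BASE}/history/import", files={"file": f})
print(result.json())
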
475
backend/routes/models.py Normal file

@@ -0,0 +1,475 @@
"""Model management endpoints."""
import asyncio
import shutil
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from .. import models
from ..utils.platform_detect import get_backend_type
from ..services.task_queue import create_background_task
from ..utils.progress import get_progress_manager
from ..utils.tasks import get_task_manager
router = APIRouter()
def _get_dir_size(path: Path) -> int:
"""Get total size of a directory in bytes."""
total = 0
for f in path.rglob("*"):
if f.is_file():
total += f.stat().st_size
return total
def _copy_with_progress(src: Path, dst: Path, progress_manager, copied_so_far: int, total_bytes: int) -> int:
"""Copy a directory tree with byte-level progress tracking."""
dst.mkdir(parents=True, exist_ok=True)
for item in src.iterdir():
dest_item = dst / item.name
if item.is_dir():
copied_so_far = _copy_with_progress(item, dest_item, progress_manager, copied_so_far, total_bytes)
else:
size = item.stat().st_size
shutil.copy2(str(item), str(dest_item))
copied_so_far += size
progress_manager.update_progress(
"migration",
copied_so_far,
total_bytes,
filename=item.name,
status="downloading",
)
return copied_so_far
@router.post("/models/load")
async def load_model(model_size: str = "1.7B"):
"""Manually load TTS model."""
from ..services import tts
try:
tts_model = tts.get_tts_model()
await tts_model.load_model_async(model_size)
return {"message": f"Model {model_size} loaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/models/unload")
async def unload_model():
"""Unload the default Qwen TTS model to free memory."""
from ..services import tts
try:
tts.unload_tts_model()
return {"message": "Model unloaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/models/{model_name}/unload")
async def unload_model_by_name(model_name: str):
"""Unload a specific model from memory without deleting it from disk."""
from ..backends import get_model_config, unload_model_by_config
config = get_model_config(model_name)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
try:
was_loaded = unload_model_by_config(config)
if not was_loaded:
return {"message": f"Model {model_name} is not loaded"}
return {"message": f"Model {model_name} unloaded successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/models/progress/{model_name}")
async def get_model_progress(model_name: str):
"""Get model download progress via Server-Sent Events."""
progress_manager = get_progress_manager()
async def event_generator():
async for event in progress_manager.subscribe(model_name):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.get("/models/cache-dir")
async def get_models_cache_dir():
"""Get the path to the HuggingFace model cache directory."""
from huggingface_hub import constants as hf_constants
return {"path": str(Path(hf_constants.HF_HUB_CACHE))}
@router.post("/models/migrate")
async def migrate_models(request: models.ModelMigrateRequest):
"""Move all downloaded models to a new directory with byte-level progress via SSE."""
from huggingface_hub import constants as hf_constants
source = Path(hf_constants.HF_HUB_CACHE)
destination = Path(request.destination)
if not source.exists():
raise HTTPException(status_code=404, detail="Current model cache directory not found")
if source.resolve() == destination.resolve():
raise HTTPException(status_code=400, detail="Source and destination are the same directory")
if destination.resolve().is_relative_to(source.resolve()):
raise HTTPException(status_code=400, detail="Destination cannot be inside the current cache directory")
progress_manager = get_progress_manager()
model_dirs = [d for d in source.iterdir() if d.name.startswith("models--") and d.is_dir()]
if not model_dirs:
progress_manager.update_progress("migration", 1, 1, status="complete")
progress_manager.mark_complete("migration")
return {"moved": 0, "errors": [], "source": str(source), "destination": str(destination)}
destination.mkdir(parents=True, exist_ok=True)
same_fs = False
try:
same_fs = source.stat().st_dev == destination.stat().st_dev
except OSError:
pass
async def migrate_background():
moved = 0
errors = []
try:
if same_fs:
total = len(model_dirs)
for i, item in enumerate(model_dirs):
dest_item = destination / item.name
try:
if dest_item.exists():
shutil.rmtree(dest_item)
shutil.move(str(item), str(dest_item))
moved += 1
progress_manager.update_progress(
"migration",
i + 1,
total,
filename=item.name,
status="downloading",
)
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
else:
total_bytes = sum(_get_dir_size(d) for d in model_dirs)
progress_manager.update_progress(
"migration", 0, total_bytes, filename="Calculating...", status="downloading"
)
copied = 0
for item in model_dirs:
dest_item = destination / item.name
try:
if dest_item.exists():
shutil.rmtree(dest_item)
copied = await asyncio.to_thread(
_copy_with_progress, item, dest_item, progress_manager, copied, total_bytes
)
await asyncio.to_thread(shutil.rmtree, str(item))
moved += 1
except Exception as e:
errors.append(f"{item.name}: {str(e)}")
progress_manager.update_progress("migration", 1, 1, status="complete")
progress_manager.mark_complete("migration")
except Exception as e:
progress_manager.update_progress("migration", 0, 0, status="error")
progress_manager.mark_error("migration", str(e))
create_background_task(migrate_background())
return {"source": str(source), "destination": str(destination)}
@router.get("/models/migrate/progress")
async def get_migration_progress():
"""Get model migration progress via Server-Sent Events."""
progress_manager = get_progress_manager()
async def event_generator():
async for event in progress_manager.subscribe("migration"):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.get("/models/status", response_model=models.ModelStatusListResponse)
async def get_model_status():
"""Get status of all available models."""
from huggingface_hub import constants as hf_constants
backend_type = get_backend_type()
task_manager = get_task_manager()
active_download_names = {task.model_name for task in task_manager.get_active_downloads()}
try:
from huggingface_hub import scan_cache_dir
use_scan_cache = True
except ImportError:
use_scan_cache = False
from ..backends import get_all_model_configs, check_model_loaded
registry_configs = get_all_model_configs()
model_configs = [
{
"model_name": cfg.model_name,
"display_name": cfg.display_name,
"hf_repo_id": cfg.hf_repo_id,
"model_size": cfg.model_size,
"check_loaded": lambda c=cfg: check_model_loaded(c),
}
for cfg in registry_configs
]
model_to_repo = {cfg["model_name"]: cfg["hf_repo_id"] for cfg in model_configs}
active_download_repos = {model_to_repo.get(name) for name in active_download_names if name in model_to_repo}
cache_info = None
if use_scan_cache:
try:
cache_info = scan_cache_dir()
except Exception:
pass
statuses = []
for config in model_configs:
try:
downloaded = False
size_mb = None
loaded = False
if cache_info:
repo_id = config["hf_repo_id"]
for repo in cache_info.repos:
if repo.repo_id == repo_id:
has_model_weights = False
for rev in repo.revisions:
for f in rev.files:
fname = f.file_name.lower()
if fname.endswith((".safetensors", ".bin", ".pt", ".pth", ".npz")):
has_model_weights = True
break
if has_model_weights:
break
has_incomplete = False
try:
cache_dir = hf_constants.HF_HUB_CACHE
blobs_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--")) / "blobs"
if blobs_dir.exists():
has_incomplete = any(blobs_dir.glob("*.incomplete"))
except Exception:
pass
if has_model_weights and not has_incomplete:
downloaded = True
try:
total_size = sum(revision.size_on_disk for revision in repo.revisions)
size_mb = total_size / (1024 * 1024)
except Exception:
pass
break
if not downloaded:
try:
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache = Path(cache_dir) / ("models--" + config["hf_repo_id"].replace("/", "--"))
if repo_cache.exists():
blobs_dir = repo_cache / "blobs"
has_incomplete = blobs_dir.exists() and any(blobs_dir.glob("*.incomplete"))
if not has_incomplete:
snapshots_dir = repo_cache / "snapshots"
has_model_files = False
if snapshots_dir.exists():
has_model_files = (
any(snapshots_dir.rglob("*.bin"))
or any(snapshots_dir.rglob("*.safetensors"))
or any(snapshots_dir.rglob("*.pt"))
or any(snapshots_dir.rglob("*.pth"))
or any(snapshots_dir.rglob("*.npz"))
)
if has_model_files:
downloaded = True
try:
total_size = sum(
f.stat().st_size
for f in repo_cache.rglob("*")
if f.is_file() and not f.name.endswith(".incomplete")
)
size_mb = total_size / (1024 * 1024)
except Exception:
pass
except Exception:
pass
try:
loaded = config["check_loaded"]()
except Exception:
loaded = False
is_downloading = config["hf_repo_id"] in active_download_repos
if is_downloading:
downloaded = False
size_mb = None
statuses.append(
models.ModelStatus(
model_name=config["model_name"],
display_name=config["display_name"],
hf_repo_id=config["hf_repo_id"],
downloaded=downloaded,
downloading=is_downloading,
size_mb=size_mb,
loaded=loaded,
)
)
except Exception:
try:
loaded = config["check_loaded"]()
except Exception:
loaded = False
is_downloading = config["hf_repo_id"] in active_download_repos
statuses.append(
models.ModelStatus(
model_name=config["model_name"],
display_name=config["display_name"],
hf_repo_id=config["hf_repo_id"],
downloaded=False,
downloading=is_downloading,
size_mb=None,
loaded=loaded,
)
)
return models.ModelStatusListResponse(models=statuses)
@router.post("/models/download")
async def trigger_model_download(request: models.ModelDownloadRequest):
"""Trigger download of a specific model."""
from ..backends import get_model_config, get_model_load_func
task_manager = get_task_manager()
progress_manager = get_progress_manager()
config = get_model_config(request.model_name)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {request.model_name}")
load_func = get_model_load_func(config)
async def download_in_background():
try:
result = load_func()
if asyncio.iscoroutine(result):
await result
task_manager.complete_download(request.model_name)
except Exception as e:
task_manager.error_download(request.model_name, str(e))
task_manager.start_download(request.model_name)
progress_manager.update_progress(
model_name=request.model_name,
current=0,
total=0,
filename="Connecting to HuggingFace...",
status="downloading",
)
create_background_task(download_in_background())
return {"message": f"Model {request.model_name} download started"}
@router.post("/models/download/cancel")
async def cancel_model_download(request: models.ModelDownloadRequest):
"""Cancel or dismiss an errored/stale download task."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()
removed = task_manager.cancel_download(request.model_name)
progress_removed = False
with progress_manager._lock:
if request.model_name in progress_manager._progress:
del progress_manager._progress[request.model_name]
progress_removed = True
if removed or progress_removed:
return {"message": f"Download task for {request.model_name} cancelled"}
return {"message": f"No active task found for {request.model_name}"}
@router.delete("/models/{model_name}")
async def delete_model(model_name: str):
"""Delete a downloaded model from the HuggingFace cache."""
from huggingface_hub import constants as hf_constants
from ..backends import get_model_config, unload_model_by_config
config = get_model_config(model_name)
if not config:
raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
hf_repo_id = config.hf_repo_id
try:
unload_model_by_config(config)
cache_dir = hf_constants.HF_HUB_CACHE
repo_cache_dir = Path(cache_dir) / ("models--" + hf_repo_id.replace("/", "--"))
if not repo_cache_dir.exists():
raise HTTPException(status_code=404, detail=f"Model {model_name} not found in cache")
try:
shutil.rmtree(repo_cache_dir)
except OSError as e:
raise HTTPException(status_code=500, detail=f"Failed to delete model cache directory: {str(e)}")
return {"message": f"Model {model_name} deleted successfully"}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")

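Download management follows the same fire-and-subscribe pattern as the CUDA routes; a sketch using the "qwen-tts-1.7B" name that appears in the health check above (other registry names may differ):

import requests

BASE = "http://localhost:8000"  # assumed
model = "qwen-tts-1.7B"
requests.post(f"{BASE}/models/download", json={"model_name": model})
with requests.get(f"{BASE}/models/progress/{model}", stream=True, timeout=None) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            print(line[len(b"data: "):].decode())  # raw progress event
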
363
backend/routes/profiles.py Normal file

@@ -0,0 +1,363 @@
"""Voice profile endpoints."""
import io
import json as _json
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from .. import config, models
from ..app import safe_content_disposition
from ..database import VoiceProfile as DBVoiceProfile, get_db
from ..services import channels, export_import, profiles
from ..services.profiles import _profile_to_response
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/profiles", response_model=models.VoiceProfileResponse)
async def create_profile(
data: models.VoiceProfileCreate,
db: Session = Depends(get_db),
):
"""Create a new voice profile."""
try:
return await profiles.create_profile(data, db)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/profiles", response_model=list[models.VoiceProfileResponse])
async def list_profiles(db: Session = Depends(get_db)):
"""List all voice profiles."""
return await profiles.list_profiles(db)
@router.post("/profiles/import", response_model=models.VoiceProfileResponse)
async def import_profile(
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Import a voice profile from a ZIP archive."""
MAX_FILE_SIZE = 100 * 1024 * 1024
content = await file.read()
if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400, detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024 * 1024)}MB"
        )
try:
profile = await export_import.import_profile_from_zip(content, db)
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# ── Preset Voice Endpoints ───────────────────────────────────────────
# These MUST be declared before /profiles/{profile_id} to avoid the
# wildcard swallowing "presets" as a profile_id.
@router.get("/profiles/presets/{engine}")
async def list_preset_voices(engine: str):
"""List available preset voices for an engine."""
if engine == "kokoro":
from ..backends.kokoro_backend import KOKORO_VOICES
return {
"engine": engine,
"voices": [
{
"voice_id": vid,
"name": name,
"gender": gender,
"language": lang,
}
for vid, name, gender, lang in KOKORO_VOICES
],
}
if engine == "qwen_custom_voice":
from ..backends.qwen_custom_voice_backend import QWEN_CUSTOM_VOICES
return {
"engine": engine,
"voices": [
{
"voice_id": speaker_id,
"name": display_name,
"gender": gender,
"language": lang,
}
for speaker_id, display_name, gender, lang, _desc in QWEN_CUSTOM_VOICES
],
}
return {"engine": engine, "voices": []}
@router.get("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def get_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get a voice profile by ID."""
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
return profile
@router.put("/profiles/{profile_id}", response_model=models.VoiceProfileResponse)
async def update_profile(
profile_id: str,
data: models.VoiceProfileCreate,
db: Session = Depends(get_db),
):
"""Update a voice profile."""
try:
profile = await profiles.update_profile(profile_id, data, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/profiles/{profile_id}")
async def delete_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Delete a voice profile."""
success = await profiles.delete_profile(profile_id, db)
if not success:
raise HTTPException(status_code=404, detail="Profile not found")
return {"message": "Profile deleted successfully"}
SAMPLE_MAX_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
SAMPLE_UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1 MB
@router.post("/profiles/{profile_id}/samples", response_model=models.ProfileSampleResponse)
async def add_profile_sample(
profile_id: str,
file: UploadFile = File(...),
reference_text: str = Form(...),
db: Session = Depends(get_db),
):
"""Add a sample to a voice profile."""
_allowed_audio_exts = {".wav", ".mp3", ".m4a", ".ogg", ".flac", ".aac", ".webm", ".opus"}
_uploaded_ext = Path(file.filename or "").suffix.lower()
file_suffix = _uploaded_ext if _uploaded_ext in _allowed_audio_exts else ".wav"
with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp:
total_size = 0
while chunk := await file.read(SAMPLE_UPLOAD_CHUNK_SIZE):
total_size += len(chunk)
if total_size > SAMPLE_MAX_FILE_SIZE:
Path(tmp.name).unlink(missing_ok=True)
raise HTTPException(
status_code=413,
detail=f"File too large (max {SAMPLE_MAX_FILE_SIZE // (1024 * 1024)} MB)",
)
tmp.write(chunk)
tmp_path = tmp.name
try:
sample = await profiles.add_profile_sample(
profile_id,
tmp_path,
reference_text,
db,
)
return sample
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to process audio file: {str(e)}")
finally:
Path(tmp_path).unlink(missing_ok=True)
@router.get("/profiles/{profile_id}/samples", response_model=list[models.ProfileSampleResponse])
async def get_profile_samples(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get all samples for a profile."""
return await profiles.get_profile_samples(profile_id, db)
@router.delete("/profiles/samples/{sample_id}")
async def delete_profile_sample(
sample_id: str,
db: Session = Depends(get_db),
):
"""Delete a profile sample."""
success = await profiles.delete_profile_sample(sample_id, db)
if not success:
raise HTTPException(status_code=404, detail="Sample not found")
return {"message": "Sample deleted successfully"}
@router.put("/profiles/samples/{sample_id}", response_model=models.ProfileSampleResponse)
async def update_profile_sample(
sample_id: str,
data: models.ProfileSampleUpdate,
db: Session = Depends(get_db),
):
"""Update a profile sample's reference text."""
sample = await profiles.update_profile_sample(sample_id, data.reference_text, db)
if not sample:
raise HTTPException(status_code=404, detail="Sample not found")
return sample
@router.post("/profiles/{profile_id}/avatar", response_model=models.VoiceProfileResponse)
async def upload_profile_avatar(
profile_id: str,
file: UploadFile = File(...),
db: Session = Depends(get_db),
):
"""Upload or update avatar image for a profile."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename or "").suffix) as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
profile = await profiles.upload_avatar(profile_id, tmp_path, db)
return profile
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
finally:
Path(tmp_path).unlink(missing_ok=True)
@router.get("/profiles/{profile_id}/avatar")
async def get_profile_avatar(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get avatar image for a profile."""
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
if not profile.avatar_path:
raise HTTPException(status_code=404, detail="No avatar found for this profile")
avatar_path = config.resolve_storage_path(profile.avatar_path)
if avatar_path is None or not avatar_path.exists():
raise HTTPException(status_code=404, detail="Avatar file not found")
return FileResponse(avatar_path)
@router.delete("/profiles/{profile_id}/avatar")
async def delete_profile_avatar(
profile_id: str,
db: Session = Depends(get_db),
):
"""Delete avatar image for a profile."""
success = await profiles.delete_avatar(profile_id, db)
if not success:
raise HTTPException(status_code=404, detail="Profile not found or no avatar to delete")
return {"message": "Avatar deleted successfully"}
@router.get("/profiles/{profile_id}/export")
async def export_profile(
profile_id: str,
db: Session = Depends(get_db),
):
"""Export a voice profile as a ZIP archive."""
try:
profile = await profiles.get_profile(profile_id, db)
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
zip_bytes = export_import.export_profile_to_zip(profile_id, db)
safe_name = "".join(c for c in profile.name if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_name:
safe_name = "profile"
filename = f"profile-{safe_name}.voicebox.zip"
return StreamingResponse(
io.BytesIO(zip_bytes),
media_type="application/zip",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        # Without the re-raise above, the 404 from the try block would be
        # masked as a 500 here.
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/profiles/{profile_id}/channels")
async def get_profile_channels(
profile_id: str,
db: Session = Depends(get_db),
):
"""Get list of channel IDs assigned to a profile."""
try:
channel_ids = await channels.get_profile_channels(profile_id, db)
return {"channel_ids": channel_ids}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/profiles/{profile_id}/channels")
async def set_profile_channels(
profile_id: str,
data: models.ProfileChannelAssignment,
db: Session = Depends(get_db),
):
"""Set which channels a profile is assigned to."""
try:
await channels.set_profile_channels(profile_id, data, db)
return {"message": "Profile channels updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.put("/profiles/{profile_id}/effects", response_model=models.VoiceProfileResponse)
async def update_profile_effects(
profile_id: str,
data: models.ProfileEffectsUpdate,
db: Session = Depends(get_db),
):
"""Set or clear the default effects chain for a voice profile."""
profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first()
if not profile:
raise HTTPException(status_code=404, detail="Profile not found")
if data.effects_chain is not None:
from ..utils.effects import validate_effects_chain
chain_dicts = [e.model_dump() for e in data.effects_chain]
error = validate_effects_chain(chain_dicts)
if error:
raise HTTPException(status_code=400, detail=error)
profile.effects_chain = _json.dumps(chain_dicts)
else:
profile.effects_chain = None
profile.updated_at = datetime.utcnow()
db.commit()
db.refresh(profile)
return _profile_to_response(profile)

223
backend/routes/stories.py Normal file
View File

@@ -0,0 +1,223 @@
"""Story endpoints."""
import io
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from .. import database, models
from ..services import stories
from ..app import safe_content_disposition
from ..database import get_db
router = APIRouter()
@router.get("/stories", response_model=list[models.StoryResponse])
async def list_stories(db: Session = Depends(get_db)):
"""List all stories."""
return await stories.list_stories(db)
@router.post("/stories", response_model=models.StoryResponse)
async def create_story(
data: models.StoryCreate,
db: Session = Depends(get_db),
):
"""Create a new story."""
try:
return await stories.create_story(data, db)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/stories/{story_id}", response_model=models.StoryDetailResponse)
async def get_story(
story_id: str,
db: Session = Depends(get_db),
):
"""Get a story with all its items."""
story = await stories.get_story(story_id, db)
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
@router.put("/stories/{story_id}", response_model=models.StoryResponse)
async def update_story(
story_id: str,
data: models.StoryCreate,
db: Session = Depends(get_db),
):
"""Update a story."""
story = await stories.update_story(story_id, data, db)
if not story:
raise HTTPException(status_code=404, detail="Story not found")
return story
@router.delete("/stories/{story_id}")
async def delete_story(
story_id: str,
db: Session = Depends(get_db),
):
"""Delete a story."""
success = await stories.delete_story(story_id, db)
if not success:
raise HTTPException(status_code=404, detail="Story not found")
return {"message": "Story deleted successfully"}
@router.post("/stories/{story_id}/items", response_model=models.StoryItemDetail)
async def add_story_item(
story_id: str,
data: models.StoryItemCreate,
db: Session = Depends(get_db),
):
"""Add a generation to a story."""
item = await stories.add_item_to_story(story_id, data, db)
if not item:
raise HTTPException(status_code=404, detail="Story or generation not found")
return item
@router.delete("/stories/{story_id}/items/{item_id}")
async def remove_story_item(
story_id: str,
item_id: str,
db: Session = Depends(get_db),
):
"""Remove a story item from a story."""
success = await stories.remove_item_from_story(story_id, item_id, db)
if not success:
raise HTTPException(status_code=404, detail="Story item not found")
return {"message": "Item removed successfully"}
@router.put("/stories/{story_id}/items/times")
async def update_story_item_times(
story_id: str,
data: models.StoryItemBatchUpdate,
db: Session = Depends(get_db),
):
"""Update story item timecodes."""
success = await stories.update_story_item_times(story_id, data, db)
if not success:
raise HTTPException(status_code=400, detail="Invalid timecode update request")
return {"message": "Item timecodes updated successfully"}
@router.put("/stories/{story_id}/items/reorder", response_model=list[models.StoryItemDetail])
async def reorder_story_items(
story_id: str,
data: models.StoryItemReorder,
db: Session = Depends(get_db),
):
"""Reorder story items and recalculate timecodes."""
items = await stories.reorder_story_items(story_id, data.generation_ids, db)
if items is None:
raise HTTPException(
status_code=400, detail="Invalid reorder request - ensure all generation IDs belong to this story"
)
return items
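# Example request body for the reorder endpoint above (field name as consumed
# by data.generation_ids; IDs are placeholders):
#   {"generation_ids": ["gen-a", "gen-b", "gen-c"]}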
@router.put("/stories/{story_id}/items/{item_id}/move", response_model=models.StoryItemDetail)
async def move_story_item(
story_id: str,
item_id: str,
data: models.StoryItemMove,
db: Session = Depends(get_db),
):
"""Move a story item (update position and/or track)."""
item = await stories.move_story_item(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found")
return item
@router.put("/stories/{story_id}/items/{item_id}/trim", response_model=models.StoryItemDetail)
async def trim_story_item(
story_id: str,
item_id: str,
data: models.StoryItemTrim,
db: Session = Depends(get_db),
):
"""Trim a story item."""
item = await stories.trim_story_item(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found or invalid trim values")
return item
@router.post("/stories/{story_id}/items/{item_id}/split", response_model=list[models.StoryItemDetail])
async def split_story_item(
story_id: str,
item_id: str,
data: models.StoryItemSplit,
db: Session = Depends(get_db),
):
"""Split a story item at a given time, creating two clips."""
items = await stories.split_story_item(story_id, item_id, data, db)
if items is None:
raise HTTPException(status_code=404, detail="Story item not found or invalid split point")
return items
@router.post("/stories/{story_id}/items/{item_id}/duplicate", response_model=models.StoryItemDetail)
async def duplicate_story_item(
story_id: str,
item_id: str,
db: Session = Depends(get_db),
):
"""Duplicate a story item."""
item = await stories.duplicate_story_item(story_id, item_id, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item not found")
return item
@router.put("/stories/{story_id}/items/{item_id}/version", response_model=models.StoryItemDetail)
async def set_story_item_version(
story_id: str,
item_id: str,
data: models.StoryItemVersionUpdate,
db: Session = Depends(get_db),
):
"""Pin a story item to a specific generation version."""
item = await stories.set_story_item_version(story_id, item_id, data, db)
if item is None:
raise HTTPException(status_code=404, detail="Story item or version not found")
return item
@router.get("/stories/{story_id}/export-audio")
async def export_story_audio(
story_id: str,
db: Session = Depends(get_db),
):
"""Export story as single mixed audio file."""
try:
story = db.query(database.Story).filter_by(id=story_id).first()
if not story:
raise HTTPException(status_code=404, detail="Story not found")
audio_bytes = await stories.export_story_audio(story_id, db)
if not audio_bytes:
raise HTTPException(status_code=400, detail="Story has no audio items")
safe_name = "".join(c for c in story.name if c.isalnum() or c in (" ", "-", "_")).strip()
if not safe_name:
safe_name = "story"
filename = f"{safe_name}.wav"
return StreamingResponse(
io.BytesIO(audio_bytes),
media_type="audio/wav",
headers={"Content-Disposition": safe_content_disposition("attachment", filename)},
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

125
backend/routes/tasks.py Normal file
View File

@@ -0,0 +1,125 @@
"""Task and cache management endpoints."""
from datetime import datetime
from fastapi import APIRouter
from .. import models
from ..utils.cache import clear_voice_prompt_cache
from ..utils.progress import get_progress_manager
from ..utils.tasks import get_task_manager
from fastapi import HTTPException
router = APIRouter()
@router.post("/tasks/clear")
async def clear_all_tasks():
"""Clear all download tasks and progress state."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()
task_manager.clear_all()
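    # Wipes the ProgressManager's private state directly (no separate
    # bulk-clear call is used here); all mutation happens under its own lock.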
with progress_manager._lock:
progress_manager._progress.clear()
progress_manager._last_notify_time.clear()
progress_manager._last_notify_progress.clear()
return {"message": "All task state cleared"}
@router.post("/cache/clear")
async def clear_cache():
"""Clear all voice prompt caches (memory and disk)."""
try:
deleted_count = clear_voice_prompt_cache()
return {
"message": "Voice prompt cache cleared successfully",
"files_deleted": deleted_count,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
@router.get("/tasks/active", response_model=models.ActiveTasksResponse)
async def get_active_tasks():
"""Return all currently active downloads and generations."""
task_manager = get_task_manager()
progress_manager = get_progress_manager()
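    # Downloads are tracked in two places: the task manager holds lifecycle
    # state (status, started_at, error) while the progress manager holds
    # byte-level progress. A download can appear in either or both, so both
    # maps are keyed by model name and their key sets unioned below.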
active_downloads = []
task_manager_downloads = task_manager.get_active_downloads()
progress_active = progress_manager.get_all_active()
download_map = {task.model_name: task for task in task_manager_downloads}
progress_map = {p["model_name"]: p for p in progress_active}
all_model_names = set(download_map.keys()) | set(progress_map.keys())
for model_name in all_model_names:
task = download_map.get(model_name)
progress = progress_map.get(model_name)
        if task:
            error = task.error
            prog = progress or {}
            if not error or not prog:
                # A single locked read covers both fallbacks: the progress
                # manager may hold an error or progress data that the task
                # record lacks.
                with progress_manager._lock:
                    pm_data = progress_manager._progress.get(model_name)
                    if pm_data:
                        if not error:
                            error = pm_data.get("error")
                        if not prog:
                            prog = pm_data
active_downloads.append(
models.ActiveDownloadTask(
model_name=model_name,
status=task.status,
started_at=task.started_at,
error=error,
progress=prog.get("progress"),
current=prog.get("current"),
total=prog.get("total"),
filename=prog.get("filename"),
)
)
elif progress:
timestamp_str = progress.get("timestamp")
if timestamp_str:
try:
started_at = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
except (ValueError, AttributeError):
started_at = datetime.utcnow()
else:
started_at = datetime.utcnow()
active_downloads.append(
models.ActiveDownloadTask(
model_name=model_name,
status=progress.get("status", "downloading"),
started_at=started_at,
error=progress.get("error"),
progress=progress.get("progress"),
current=progress.get("current"),
total=progress.get("total"),
filename=progress.get("filename"),
)
)
active_generations = []
for gen_task in task_manager.get_active_generations():
active_generations.append(
models.ActiveGenerationTask(
task_id=gen_task.task_id,
profile_id=gen_task.profile_id,
text_preview=gen_task.text_preview,
started_at=gen_task.started_at,
)
)
return models.ActiveTasksResponse(
downloads=active_downloads,
generations=active_generations,
)

84
backend/routes/transcription.py Normal file
View File

@@ -0,0 +1,84 @@
"""Transcription endpoints."""
import asyncio
import tempfile
from pathlib import Path
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from .. import models
from ..services import transcribe
from ..services.task_queue import create_background_task
from ..utils.tasks import get_task_manager
router = APIRouter()
UPLOAD_CHUNK_SIZE = 1024 * 1024  # 1 MB
@router.post("/transcribe", response_model=models.TranscriptionResponse)
async def transcribe_audio(
file: UploadFile = File(...),
language: str | None = Form(None),
model: str | None = Form(None),
):
"""Transcribe audio file to text."""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
tmp.write(chunk)
tmp_path = tmp.name
try:
from ..utils.audio import load_audio
from ..backends import WHISPER_HF_REPOS
audio, sr = await asyncio.to_thread(load_audio, tmp_path)
duration = len(audio) / sr
whisper_model = transcribe.get_whisper_model()
model_size = model if model else whisper_model.model_size
valid_sizes = list(WHISPER_HF_REPOS.keys())
if model_size not in valid_sizes:
raise HTTPException(
status_code=400,
detail=f"Invalid model size '{model_size}'. Must be one of: {', '.join(valid_sizes)}",
)
already_loaded = whisper_model.is_loaded() and whisper_model.model_size == model_size
if not already_loaded and not whisper_model._is_model_cached(model_size):
progress_model_name = f"whisper-{model_size}"
task_manager = get_task_manager()
async def download_whisper_background():
try:
await whisper_model.load_model_async(model_size)
task_manager.complete_download(progress_model_name)
except Exception as e:
task_manager.error_download(progress_model_name, str(e))
task_manager.start_download(progress_model_name)
create_background_task(download_whisper_background())
raise HTTPException(
status_code=202,
detail={
"message": f"Whisper model {model_size} is being downloaded. Please wait and try again.",
"model_name": progress_model_name,
"downloading": True,
},
)
text = await whisper_model.transcribe(tmp_path, language, model_size)
return models.TranscriptionResponse(
text=text,
duration=duration,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
Path(tmp_path).unlink(missing_ok=True)
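# Illustrative client handling of the 202 "model downloading" response above
# (a sketch only; the endpoint path matches this router, the 5-second poll
# interval is an arbitrary choice):
#
#     import time
#     import httpx
#     while True:
#         with open("clip.wav", "rb") as f:
#             resp = httpx.post(
#                 "http://localhost:8000/transcribe",
#                 files={"file": ("clip.wav", f, "audio/wav")},
#             )
#         if resp.status_code != 202:
#             break
#         time.sleep(5)  # Whisper model still downloading; poll again
#     resp.raise_for_status()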