"""Transcription endpoints."""
|
|
|
|
import asyncio
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
|
|
|
from .. import models
|
|
from ..services import transcribe
|
|
from ..services.task_queue import create_background_task
|
|
from ..utils.tasks import get_task_manager
|
|
|
|
# Router mounted by the application; all transcription routes attach here.
router = APIRouter()

# Read uploads in 1 MiB chunks so large audio files are streamed to disk
# instead of being buffered fully in memory.
UPLOAD_CHUNK_SIZE = 1024 * 1024  # 1MB
|
|
|
|
|
|
@router.post("/transcribe", response_model=models.TranscriptionResponse)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str | None = Form(None),
    model: str | None = Form(None),
):
    """Transcribe an uploaded audio file to text.

    Args:
        file: Audio upload; streamed to a temporary WAV file in chunks.
        language: Optional language hint forwarded to the Whisper model.
        model: Optional Whisper model size; defaults to the currently
            configured model size.

    Returns:
        models.TranscriptionResponse with the transcribed text and the
        audio duration in seconds.

    Raises:
        HTTPException: 400 for an unknown model size; 202 while the
            requested model is still being downloaded in the background;
            500 for any other failure during transcription.
    """
    # Spool the upload to disk in bounded chunks so arbitrarily large
    # files never have to fit in memory at once.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        while chunk := await file.read(UPLOAD_CHUNK_SIZE):
            tmp.write(chunk)
        tmp_path = tmp.name

    try:
        # Imported lazily so the heavy audio/backends machinery is not
        # paid for at application import time.
        from ..utils.audio import load_audio
        from ..backends import WHISPER_HF_REPOS

        # Audio decoding is blocking work; run it off the event loop.
        audio, sr = await asyncio.to_thread(load_audio, tmp_path)
        duration = len(audio) / sr

        whisper_model = transcribe.get_whisper_model()
        model_size = model if model else whisper_model.model_size

        valid_sizes = list(WHISPER_HF_REPOS.keys())
        if model_size not in valid_sizes:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model size '{model_size}'. Must be one of: {', '.join(valid_sizes)}",
            )

        # If the requested model is neither loaded nor cached on disk,
        # kick off a background download and answer 202 so the client
        # can poll progress and retry.
        # NOTE(review): `_is_model_cached` is a private member of the
        # model wrapper — consider exposing a public accessor.
        already_loaded = whisper_model.is_loaded() and whisper_model.model_size == model_size
        if not already_loaded and not whisper_model._is_model_cached(model_size):
            progress_model_name = f"whisper-{model_size}"
            task_manager = get_task_manager()

            async def download_whisper_background():
                # Report completion or failure through the task manager so
                # the client-visible download status stays accurate.
                try:
                    await whisper_model.load_model_async(model_size)
                    task_manager.complete_download(progress_model_name)
                except Exception as e:
                    task_manager.error_download(progress_model_name, str(e))

            task_manager.start_download(progress_model_name)
            create_background_task(download_whisper_background())

            raise HTTPException(
                status_code=202,
                detail={
                    "message": f"Whisper model {model_size} is being downloaded. Please wait and try again.",
                    "model_name": progress_model_name,
                    "downloading": True,
                },
            )

        text = await whisper_model.transcribe(tmp_path, language, model_size)

        return models.TranscriptionResponse(
            text=text,
            duration=duration,
        )

    except HTTPException:
        # Deliberate status codes (400/202) pass through untouched.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the spooled upload, on success and failure alike.
        Path(tmp_path).unlink(missing_ok=True)
|