109 lines
3.6 KiB
Python
109 lines
3.6 KiB
Python
"""
Serial generation queue: ensures that only one TTS inference runs at a
time, avoiding GPU contention.
"""
|
|
|
|
import asyncio
|
|
import traceback
|
|
from dataclasses import dataclass
|
|
from typing import Coroutine, Literal
|
|
|
|
# Strong references to fire-and-forget tasks started via create_background_task.
# asyncio only keeps weak references to tasks, so without this set a task could
# be garbage collected before it finishes; done-callbacks remove finished tasks.
_background_tasks: set[asyncio.Task] = set()
|
|
|
|
|
|
@dataclass
class GenerationJob:
    """A unit of work sitting in the generation queue.

    Pairs a generation coroutine with the ID of the generation it belongs to,
    so the worker can track it by ID. The worker either runs the coroutine in
    a task or, if the generation was cancelled while queued, closes it unrun.
    """

    # ID used for tracking and cancellation lookups.
    generation_id: str
    # Coroutine object performing the generation; the worker wraps it in a
    # task when the job reaches the front of the queue.
    coro: Coroutine
|
|
|
|
|
|
# Generation queue — serializes TTS inference to avoid GPU contention
|
|
_generation_queue: asyncio.Queue = None # type: ignore # initialized at startup
|
|
_generation_worker_task: asyncio.Task | None = None
|
|
_queued_generation_ids: set[str] = set()
|
|
_running_generation_tasks: dict[str, asyncio.Task] = {}
|
|
_cancelled_generation_ids: set[str] = set()
|
|
|
|
|
|
def create_background_task(coro) -> asyncio.Task:
    """Schedule *coro* as a task that stays referenced until it completes.

    The new task is stored in ``_background_tasks`` and a done-callback
    removes it again, so fire-and-forget callers do not have to keep the
    returned task alive themselves. Requires a running event loop, as
    ``asyncio.create_task`` does.
    """
    new_task = asyncio.create_task(coro)
    _background_tasks.add(new_task)
    # Drop the strong reference once the task finishes.
    new_task.add_done_callback(_background_tasks.discard)
    return new_task
|
|
|
|
|
|
async def _generation_worker():
    """Run queued generation jobs one at a time, forever.

    For each job taken from the queue:
      * if it was cancelled while still queued, close its coroutine without
        running it;
      * otherwise run it as a task registered in ``_running_generation_tasks``
        (so ``cancel_generation`` can cancel it) and wait for it to finish.

    A generation cancelled via ``cancel_generation`` is absorbed and the
    worker moves on; any other exception raised by a job is printed and the
    worker keeps going. Cancelling the worker task itself (as
    ``init_queue(force=True)`` does) stops the worker.
    """
    # Bind the queue this worker serves. init_queue(force=True) rebinds the
    # module global to a fresh queue; a worker winding down must not read
    # from, or call task_done() on, the replacement queue.
    queue = _generation_queue
    while True:
        job = await queue.get()
        try:
            if job.generation_id in _cancelled_generation_ids:
                # Cancelled while queued: never start it, but close the
                # coroutine so it is not reported as "never awaited".
                _cancelled_generation_ids.discard(job.generation_id)
                job.coro.close()
                continue

            task = asyncio.create_task(job.coro)
            _running_generation_tasks[job.generation_id] = task
            _queued_generation_ids.discard(job.generation_id)
            try:
                await task
            except asyncio.CancelledError:
                # Cancelling this worker while it awaits the task also
                # cancels the task, so task.cancelled() alone cannot tell
                # "this generation was cancelled" from "the worker is being
                # shut down". Only absorb the former.
                if not task.cancelled() or _current_task_is_cancelling():
                    raise
            except Exception:
                traceback.print_exc()
        finally:
            _running_generation_tasks.pop(job.generation_id, None)
            _queued_generation_ids.discard(job.generation_id)
            queue.task_done()


def _current_task_is_cancelling() -> bool:
    """Return True if the currently running task has a pending cancel request.

    Relies on ``Task.cancelling()`` (Python 3.11+); on interpreters without
    it this always returns False, matching the previous behavior.
    """
    current = asyncio.current_task()
    cancelling = getattr(current, "cancelling", None)
    return cancelling is not None and cancelling() > 0
|
|
|
|
|
|
def enqueue_generation(generation_id: str, coro):
    """Submit *coro* to run on the serial generation worker.

    The generation is recorded as queued (which is what ``cancel_generation``
    checks for not-yet-started work) and appended to the FIFO queue, so it
    runs after every job already waiting.

    Raises:
        RuntimeError: if ``init_queue`` has not created the queue yet.
    """
    queue = _generation_queue
    if queue is None:
        raise RuntimeError("Generation queue has not been initialized")

    _queued_generation_ids.add(generation_id)
    queue.put_nowait(GenerationJob(generation_id, coro))
|
|
|
|
|
|
def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None:
    """Cancel a queued or running generation if it is still active.

    Returns:
        ``"running"`` if the generation's task was still executing and a
        cancellation was requested on it; ``"queued"`` if it had not started
        yet and is now marked so the worker will skip it; ``None`` if no
        active generation with this ID exists (including one whose task has
        already finished but has not yet been cleaned up by the worker).
    """
    running_task = _running_generation_tasks.get(generation_id)
    # Task.cancel() returns False when the task is already done. The worker
    # removes finished tasks from the dict only after it resumes, so a
    # just-completed generation may still be listed; it cannot be cancelled.
    if running_task is not None and running_task.cancel():
        return "running"

    if generation_id in _queued_generation_ids:
        _queued_generation_ids.discard(generation_id)
        # The worker checks this set when it dequeues the job and closes the
        # coroutine instead of running it.
        _cancelled_generation_ids.add(generation_id)
        return "queued"

    return None
|
|
|
|
|
|
def init_queue(force: bool = False):
    """Initialize the generation queue and start the worker.

    Must be called once during application startup (inside a running event loop).

    If a worker is already running, this is a no-op unless ``force`` is true;
    with ``force=True`` the current worker and any running generation tasks
    are cancelled first. Before fresh state is created, jobs still waiting in
    the previous queue are discarded and their coroutines closed, because no
    worker will ever run them.
    """
    global _generation_queue, _generation_worker_task
    global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids

    if _generation_worker_task is not None and not _generation_worker_task.done():
        if not force:
            return
        _generation_worker_task.cancel()
        for task in list(_running_generation_tasks.values()):
            task.cancel()

    if _generation_queue is not None:
        _discard_pending_jobs(_generation_queue)

    _generation_queue = asyncio.Queue()
    _queued_generation_ids = set()
    _running_generation_tasks = {}
    _cancelled_generation_ids = set()
    _generation_worker_task = create_background_task(_generation_worker())


def _discard_pending_jobs(queue: asyncio.Queue) -> None:
    """Drain *queue* without running its jobs, closing each job's coroutine."""
    while True:
        try:
            job = queue.get_nowait()
        except asyncio.QueueEmpty:
            return
        # Closing an unstarted coroutine avoids "never awaited" warnings and
        # releases whatever it captured.
        job.coro.close()
        queue.task_done()
|