""" Serial generation queue — ensures only one TTS inference runs at a time to avoid GPU contention. """ import asyncio import traceback from dataclasses import dataclass from typing import Coroutine, Literal # Keep references to fire-and-forget background tasks to prevent GC _background_tasks: set = set() @dataclass class GenerationJob: """Queued generation work plus the generation ID it belongs to.""" generation_id: str coro: Coroutine # Generation queue — serializes TTS inference to avoid GPU contention _generation_queue: asyncio.Queue = None # type: ignore # initialized at startup _generation_worker_task: asyncio.Task | None = None _queued_generation_ids: set[str] = set() _running_generation_tasks: dict[str, asyncio.Task] = {} _cancelled_generation_ids: set[str] = set() def create_background_task(coro) -> asyncio.Task: """Create a background task and prevent it from being garbage collected.""" task = asyncio.create_task(coro) _background_tasks.add(task) task.add_done_callback(_background_tasks.discard) return task async def _generation_worker(): """Worker that processes generation tasks one at a time.""" while True: job = await _generation_queue.get() try: if job.generation_id in _cancelled_generation_ids: _cancelled_generation_ids.discard(job.generation_id) job.coro.close() continue task = asyncio.create_task(job.coro) _running_generation_tasks[job.generation_id] = task _queued_generation_ids.discard(job.generation_id) try: await task except asyncio.CancelledError: if not task.cancelled(): raise except Exception: traceback.print_exc() finally: _running_generation_tasks.pop(job.generation_id, None) _queued_generation_ids.discard(job.generation_id) _generation_queue.task_done() def enqueue_generation(generation_id: str, coro): """Add a generation coroutine to the serial queue.""" if _generation_queue is None: raise RuntimeError("Generation queue has not been initialized") _queued_generation_ids.add(generation_id) _generation_queue.put_nowait(GenerationJob(generation_id=generation_id, coro=coro)) def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None: """Cancel a queued or running generation if it is still active.""" running_task = _running_generation_tasks.get(generation_id) if running_task is not None: running_task.cancel() return "running" if generation_id in _queued_generation_ids: _queued_generation_ids.discard(generation_id) _cancelled_generation_ids.add(generation_id) return "queued" return None def init_queue(force: bool = False): """Initialize the generation queue and start the worker. Must be called once during application startup (inside a running event loop). """ global _generation_queue, _generation_worker_task global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids if _generation_worker_task is not None and not _generation_worker_task.done(): if not force: return _generation_worker_task.cancel() for task in list(_running_generation_tasks.values()): task.cancel() _generation_queue = asyncio.Queue() _queued_generation_ids = set() _running_generation_tasks = {} _cancelled_generation_ids = set() _generation_worker_task = create_background_task(_generation_worker())