Initial commit
This commit is contained in:
108
backend/services/task_queue.py
Normal file
108
backend/services/task_queue.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Serial generation queue — ensures only one TTS inference runs at a time
|
||||
to avoid GPU contention.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Coroutine, Literal
|
||||
|
||||
# Strong references to fire-and-forget background tasks. Holding them here
# prevents a task from being garbage collected before it has finished.
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@dataclass
class GenerationJob:
    """One unit of queued work: a coroutine tagged with its generation ID."""

    # Identifier used to track and cancel this job.
    generation_id: str
    # The not-yet-started coroutine that performs the generation.
    coro: Coroutine
|
||||
|
||||
|
||||
# Generation queue — serializes TTS inference to avoid GPU contention
|
||||
_generation_queue: asyncio.Queue = None # type: ignore # initialized at startup
|
||||
_generation_worker_task: asyncio.Task | None = None
|
||||
_queued_generation_ids: set[str] = set()
|
||||
_running_generation_tasks: dict[str, asyncio.Task] = {}
|
||||
_cancelled_generation_ids: set[str] = set()
|
||||
|
||||
|
||||
def create_background_task(coro) -> asyncio.Task:
    """Schedule *coro* as a task and hold a reference to it until it is done.

    The task is stored in the module-level ``_background_tasks`` set and is
    removed from it automatically once it completes.

    Returns:
        The newly created task.
    """
    scheduled = asyncio.create_task(coro)
    _background_tasks.add(scheduled)
    # Drop the reference on completion so the set does not grow without bound.
    scheduled.add_done_callback(_background_tasks.discard)
    return scheduled
|
||||
|
||||
|
||||
async def _generation_worker():
    """Process queued generation jobs one at a time, forever.

    Each job pulled from the queue is either skipped (if it was cancelled
    while still queued) or run to completion as its own task before the next
    job is fetched. A failing job has its traceback printed and does not stop
    the worker.

    Raises:
        asyncio.CancelledError: when the worker task itself is cancelled. A
            job that is running at that moment is cancelled as well.
    """
    # Snapshot the state this worker serves. init_queue() rebinds the module
    # globals when it replaces the worker, and the cleanup below must act on
    # the queue the job was taken from, not on the replacement queue.
    queue = _generation_queue
    running = _running_generation_tasks
    queued = _queued_generation_ids
    cancelled = _cancelled_generation_ids

    while True:
        job = await queue.get()
        try:
            if job.generation_id in cancelled:
                # Cancelled before it started: dispose of the coroutine so it
                # does not emit a "never awaited" warning.
                cancelled.discard(job.generation_id)
                job.coro.close()
                continue

            task = asyncio.create_task(job.coro)
            running[job.generation_id] = task
            queued.discard(job.generation_id)
            try:
                # asyncio.wait() does not forward our cancellation to the job
                # and does not raise the job's outcome, so a CancelledError
                # here always means the worker itself is being cancelled.
                await asyncio.wait({task})
            except asyncio.CancelledError:
                task.cancel()
                raise
            if not task.cancelled():
                # Re-raise a job failure so it is reported below.
                task.result()
        except Exception:
            traceback.print_exc()
        finally:
            running.pop(job.generation_id, None)
            queued.discard(job.generation_id)
            queue.task_done()
|
||||
|
||||
|
||||
def enqueue_generation(generation_id: str, coro):
    """Append a generation coroutine to the serial queue.

    Args:
        generation_id: Identifier under which the job is tracked.
        coro: Coroutine that performs the generation when the worker runs it.

    Raises:
        RuntimeError: if the queue has not been initialized yet.
    """
    queue = _generation_queue
    if queue is None:
        raise RuntimeError("Generation queue has not been initialized")

    # Mark the ID as queued first so it can be cancelled before it starts.
    _queued_generation_ids.add(generation_id)
    job = GenerationJob(generation_id=generation_id, coro=coro)
    queue.put_nowait(job)
|
||||
|
||||
|
||||
def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None:
    """Cancel the generation with the given ID if it is queued or running.

    Returns:
        ``"running"`` if a running task was asked to cancel, ``"queued"`` if
        the job was still waiting and has been marked to be skipped, or
        ``None`` if the ID is not active.
    """
    active = _running_generation_tasks.get(generation_id)
    if active is not None:
        active.cancel()
        return "running"

    if generation_id not in _queued_generation_ids:
        return None

    # The job stays in the queue; the worker drops it when it is dequeued.
    _queued_generation_ids.remove(generation_id)
    _cancelled_generation_ids.add(generation_id)
    return "queued"
|
||||
|
||||
|
||||
def init_queue(force: bool = False):
    """Initialize the generation queue and start the worker.

    Must be called once during application startup (inside a running event
    loop, because it creates the worker task).

    Args:
        force: When a worker is already running, ``False`` (the default)
            leaves everything untouched and returns. ``True`` cancels the
            worker and every running generation, discards all queued jobs,
            and starts again from a clean state.
    """
    global _generation_queue, _generation_worker_task
    global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids

    worker = _generation_worker_task
    if worker is not None and not worker.done():
        if not force:
            return
        worker.cancel()
        for task in list(_running_generation_tasks.values()):
            task.cancel()

    # Jobs still waiting in the old queue will never run. Close their
    # coroutines so they are released instead of lingering un-awaited.
    stale_queue = _generation_queue
    if stale_queue is not None:
        while not stale_queue.empty():
            stale_queue.get_nowait().coro.close()

    _generation_queue = asyncio.Queue()
    _queued_generation_ids = set()
    _running_generation_tasks = {}
    _cancelled_generation_ids = set()
    _generation_worker_task = create_background_task(_generation_worker())
|
||||
Reference in New Issue
Block a user