Initial commit
This commit is contained in:
206
backend/backends/chatterbox_turbo_backend.py
Normal file
206
backend/backends/chatterbox_turbo_backend.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Chatterbox Turbo TTS backend implementation.
|
||||
|
||||
Wraps ChatterboxTurboTTS from chatterbox-tts for fast, English-only
|
||||
voice cloning with paralinguistic tag support ([laugh], [cough], etc.).
|
||||
Forces CPU on macOS due to known MPS tensor issues.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import TTSBackend
|
||||
from .base import (
|
||||
is_model_cached,
|
||||
get_torch_device,
|
||||
empty_device_cache,
|
||||
manual_seed,
|
||||
combine_voice_prompts as _combine_voice_prompts,
|
||||
model_load_progress,
|
||||
patch_chatterbox_f32,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CHATTERBOX_TURBO_HF_REPO = "ResembleAI/chatterbox-turbo"
|
||||
|
||||
# Files that must be present for the turbo model
|
||||
_TURBO_WEIGHT_FILES = [
|
||||
"t3_turbo_v1.safetensors",
|
||||
"s3gen_meanflow.safetensors",
|
||||
"ve.safetensors",
|
||||
]
|
||||
|
||||
|
||||
class ChatterboxTurboTTSBackend:
|
||||
"""Chatterbox Turbo TTS backend — fast, English-only, with paralinguistic tags."""
|
||||
|
||||
# Class-level lock for torch.load monkey-patching
|
||||
_load_lock: ClassVar[threading.Lock] = threading.Lock()
|
||||
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.model_size = "default"
|
||||
self._device = None
|
||||
self._model_load_lock = asyncio.Lock()
|
||||
|
||||
def _get_device(self) -> str:
|
||||
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
return self.model is not None
|
||||
|
||||
def _get_model_path(self, model_size: str = "default") -> str:
|
||||
return CHATTERBOX_TURBO_HF_REPO
|
||||
|
||||
def _is_model_cached(self, model_size: str = "default") -> bool:
|
||||
return is_model_cached(CHATTERBOX_TURBO_HF_REPO, required_files=_TURBO_WEIGHT_FILES)
|
||||
|
||||
async def load_model(self, model_size: str = "default") -> None:
|
||||
"""Load the Chatterbox Turbo model."""
|
||||
if self.model is not None:
|
||||
return
|
||||
async with self._model_load_lock:
|
||||
if self.model is not None:
|
||||
return
|
||||
await asyncio.to_thread(self._load_model_sync)
|
||||
|
||||
def _load_model_sync(self):
|
||||
"""Synchronous model loading."""
|
||||
model_name = "chatterbox-turbo"
|
||||
is_cached = self._is_model_cached()
|
||||
|
||||
with model_load_progress(model_name, is_cached):
|
||||
device = self._get_device()
|
||||
self._device = device
|
||||
logger.info(f"Loading Chatterbox Turbo TTS on {device}...")
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from chatterbox.tts_turbo import ChatterboxTurboTTS
|
||||
|
||||
local_path = snapshot_download(
|
||||
repo_id=CHATTERBOX_TURBO_HF_REPO,
|
||||
token=None,
|
||||
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
|
||||
)
|
||||
|
||||
if device == "cpu":
|
||||
_orig_torch_load = torch.load
|
||||
|
||||
def _patched_load(*args, **kwargs):
|
||||
kwargs.setdefault("map_location", "cpu")
|
||||
return _orig_torch_load(*args, **kwargs)
|
||||
|
||||
with ChatterboxTurboTTSBackend._load_lock:
|
||||
torch.load = _patched_load
|
||||
try:
|
||||
model = ChatterboxTurboTTS.from_local(local_path, device)
|
||||
finally:
|
||||
torch.load = _orig_torch_load
|
||||
else:
|
||||
model = ChatterboxTurboTTS.from_local(local_path, device)
|
||||
|
||||
patch_chatterbox_f32(model)
|
||||
self.model = model
|
||||
|
||||
logger.info("Chatterbox Turbo TTS loaded successfully")
|
||||
|
||||
def unload_model(self) -> None:
|
||||
"""Unload model to free memory."""
|
||||
if self.model is not None:
|
||||
device = self._device
|
||||
del self.model
|
||||
self.model = None
|
||||
self._device = None
|
||||
empty_device_cache(device)
|
||||
logger.info("Chatterbox Turbo unloaded")
|
||||
|
||||
async def create_voice_prompt(
|
||||
self,
|
||||
audio_path: str,
|
||||
reference_text: str,
|
||||
use_cache: bool = True,
|
||||
) -> Tuple[dict, bool]:
|
||||
"""
|
||||
Create voice prompt from reference audio.
|
||||
|
||||
Chatterbox Turbo processes reference audio at generation time, so the
|
||||
prompt just stores the file path.
|
||||
"""
|
||||
voice_prompt = {
|
||||
"ref_audio": str(audio_path),
|
||||
"ref_text": reference_text,
|
||||
}
|
||||
return voice_prompt, False
|
||||
|
||||
async def combine_voice_prompts(
|
||||
self,
|
||||
audio_paths: List[str],
|
||||
reference_texts: List[str],
|
||||
) -> Tuple[np.ndarray, str]:
|
||||
return await _combine_voice_prompts(audio_paths, reference_texts)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
text: str,
|
||||
voice_prompt: dict,
|
||||
language: str = "en",
|
||||
seed: Optional[int] = None,
|
||||
instruct: Optional[str] = None,
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Generate audio using Chatterbox Turbo TTS.
|
||||
|
||||
Supports paralinguistic tags in text: [laugh], [cough], [chuckle], etc.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize (may include paralinguistic tags)
|
||||
voice_prompt: Dict with ref_audio path
|
||||
language: Ignored (Turbo is English-only)
|
||||
seed: Random seed for reproducibility
|
||||
instruct: Unused (protocol compatibility)
|
||||
|
||||
Returns:
|
||||
Tuple of (audio_array, sample_rate)
|
||||
"""
|
||||
await self.load_model()
|
||||
|
||||
ref_audio = voice_prompt.get("ref_audio")
|
||||
if ref_audio and not Path(ref_audio).exists():
|
||||
logger.warning(f"Reference audio not found: {ref_audio}")
|
||||
ref_audio = None
|
||||
|
||||
def _generate_sync():
|
||||
import torch
|
||||
|
||||
if seed is not None:
|
||||
manual_seed(seed, self._device)
|
||||
|
||||
logger.info("[Chatterbox Turbo] Generating (English)")
|
||||
|
||||
wav = self.model.generate(
|
||||
text,
|
||||
audio_prompt_path=ref_audio,
|
||||
temperature=0.8,
|
||||
top_k=1000,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.2,
|
||||
)
|
||||
|
||||
# Convert tensor -> numpy
|
||||
if isinstance(wav, torch.Tensor):
|
||||
audio = wav.squeeze().cpu().numpy().astype(np.float32)
|
||||
else:
|
||||
audio = np.asarray(wav, dtype=np.float32)
|
||||
|
||||
sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)
|
||||
|
||||
return audio, sample_rate
|
||||
|
||||
return await asyncio.to_thread(_generate_sync)
|
||||
Reference in New Issue
Block a user