207 lines
6.4 KiB
Python
207 lines
6.4 KiB
Python
"""
|
|
Chatterbox Turbo TTS backend implementation.
|
|
|
|
Wraps ChatterboxTurboTTS from chatterbox-tts for fast, English-only
|
|
voice cloning with paralinguistic tag support ([laugh], [cough], etc.).
|
|
Forces CPU on macOS due to known MPS tensor issues.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import ClassVar, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from . import TTSBackend
|
|
from .base import (
|
|
is_model_cached,
|
|
get_torch_device,
|
|
empty_device_cache,
|
|
manual_seed,
|
|
combine_voice_prompts as _combine_voice_prompts,
|
|
model_load_progress,
|
|
patch_chatterbox_f32,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Hugging Face Hub repository that hosts the Chatterbox Turbo checkpoint.
CHATTERBOX_TURBO_HF_REPO = "ResembleAI/chatterbox-turbo"

# Files that must be present for the turbo model
_TURBO_WEIGHT_FILES = [
    "t3_turbo_v1.safetensors",
    "s3gen_meanflow.safetensors",
    "ve.safetensors",
]
|
|
|
|
|
|
class ChatterboxTurboTTSBackend:
    """Chatterbox Turbo TTS backend — fast, English-only, with paralinguistic tags."""

    # Class-level lock for torch.load monkey-patching. It is shared by every
    # instance so that only one thread at a time swaps the global
    # ``torch.load`` in and out.
    _load_lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self):
        """Create an unloaded backend; the model is loaded on first use."""
        self.model = None
        self.model_size = "default"
        # Device string the model was loaded on; None while unloaded.
        self._device = None
        # Serializes concurrent load_model() coroutines on this instance.
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        """Return the torch device to use (CPU is forced on macOS, XPU allowed)."""
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        """Return True when a model instance is currently held."""
        return self.model is not None

    def _get_model_path(self, model_size: str = "default") -> str:
        """Return the Hugging Face repo id; ``model_size`` is ignored."""
        return CHATTERBOX_TURBO_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        """Return whether all required weight files are reported as cached.

        ``model_size`` is ignored; there is a single Turbo checkpoint.
        """
        return is_model_cached(
            CHATTERBOX_TURBO_HF_REPO,
            required_files=_TURBO_WEIGHT_FILES,
        )

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Chatterbox Turbo model.

        Returns immediately when a model is already loaded. Otherwise the
        blocking load runs in a worker thread while holding the per-instance
        asyncio lock, so concurrent callers trigger only one load.
        ``model_size`` is accepted but not used.
        """
        if self.model is not None:
            return

        async with self._model_load_lock:
            # Re-check: another coroutine may have finished loading while
            # this one was waiting for the lock.
            if self.model is not None:
                return
            await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Synchronous model loading.

        Downloads (or reuses) the repository snapshot, builds the model on
        the selected device, applies ``patch_chatterbox_f32`` and stores the
        result on ``self.model``.
        """
        model_name = "chatterbox-turbo"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            device = self._get_device()
            self._device = device
            logger.info("Loading Chatterbox Turbo TTS on %s...", device)

            # Heavy third-party imports are deferred until a load is requested.
            import torch
            from huggingface_hub import snapshot_download
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            local_path = snapshot_download(
                repo_id=CHATTERBOX_TURBO_HF_REPO,
                token=None,
                allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
            )

            if device == "cpu":
                # While loading on CPU, make every torch.load() call issued by
                # from_local() default to map_location="cpu".
                with ChatterboxTurboTTSBackend._load_lock:
                    # Capture the original *inside* the lock. Capturing it
                    # before acquiring the lock would let a second thread pick
                    # up another thread's patched function and later "restore"
                    # it, leaving torch.load patched permanently.
                    original_torch_load = torch.load

                    def _patched_load(*args, **kwargs):
                        kwargs.setdefault("map_location", "cpu")
                        return original_torch_load(*args, **kwargs)

                    torch.load = _patched_load
                    try:
                        model = ChatterboxTurboTTS.from_local(local_path, device)
                    finally:
                        torch.load = original_torch_load
            else:
                model = ChatterboxTurboTTS.from_local(local_path, device)

            patch_chatterbox_f32(model)
            self.model = model

        logger.info("Chatterbox Turbo TTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory.

        Does nothing when no model is loaded. Otherwise drops the model
        reference, clears the stored device and asks ``empty_device_cache``
        to release cached memory for the device the model was on.
        """
        if self.model is None:
            return

        device = self._device
        # Drop the reference before clearing the device cache.
        self.model = None
        self._device = None
        empty_device_cache(device)
        logger.info("Chatterbox Turbo unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Chatterbox Turbo processes reference audio at generation time, so the
        prompt just stores the file path.

        Args:
            audio_path: Path to the reference audio file.
            reference_text: Transcript stored alongside the path.
            use_cache: Accepted but not used by this backend.

        Returns:
            Tuple of (voice_prompt dict with ``ref_audio`` and ``ref_text``
            keys, False). The second element is always False here —
            presumably a "served from cache" flag; confirm against callers.
        """
        voice_prompt = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }
        return voice_prompt, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        """Combine several reference clips by delegating to the shared helper."""
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio using Chatterbox Turbo TTS.

        Supports paralinguistic tags in text: [laugh], [cough], [chuckle], etc.

        Args:
            text: Text to synthesize (may include paralinguistic tags)
            voice_prompt: Dict with ref_audio path
            language: Ignored (Turbo is English-only)
            seed: Random seed for reproducibility
            instruct: Unused (protocol compatibility)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        ref_audio = voice_prompt.get("ref_audio")
        if ref_audio and not Path(ref_audio).exists():
            # A missing reference is not fatal: warn and generate without it.
            logger.warning("Reference audio not found: %s", ref_audio)
            ref_audio = None

        # Inference is blocking; keep it off the event loop.
        return await asyncio.to_thread(self._synthesize_sync, text, ref_audio, seed)

    def _synthesize_sync(
        self,
        text: str,
        ref_audio: Optional[str],
        seed: Optional[int],
    ) -> Tuple[np.ndarray, int]:
        """Run blocking inference and return (float32 audio array, sample rate)."""
        import torch

        if seed is not None:
            manual_seed(seed, self._device)

        logger.info("[Chatterbox Turbo] Generating (English)")

        wav = self.model.generate(
            text,
            audio_prompt_path=ref_audio,
            temperature=0.8,
            top_k=1000,
            top_p=0.95,
            repetition_penalty=1.2,
        )

        # Convert tensor -> numpy
        if isinstance(wav, torch.Tensor):
            # detach() first: Tensor.numpy() refuses tensors that require grad.
            audio = wav.detach().squeeze().cpu().numpy().astype(np.float32)
        else:
            audio = np.asarray(wav, dtype=np.float32)

        # Prefer the model's own rate ("sr", then "sample_rate"); fall back to 24 kHz.
        sample_rate = getattr(self.model, "sr", None) or getattr(
            self.model, "sample_rate", 24000
        )

        return audio, sample_rate
|