Initial commit
backend/backends/chatterbox_backend.py (new file, 226 lines)
@@ -0,0 +1,226 @@
"""
Chatterbox TTS backend implementation.

Wraps ChatterboxMultilingualTTS from chatterbox-tts for zero-shot
voice cloning. Supports 23 languages including Hebrew. Forces CPU
on macOS due to known MPS tensor issues.
"""

import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
    patch_chatterbox_f32,
)

logger = logging.getLogger(__name__)

CHATTERBOX_HF_REPO = "ResembleAI/chatterbox"

# Files that must be present for the multilingual model
_MTL_WEIGHT_FILES = [
    "t3_mtl23ls_v2.safetensors",
    "s3gen.pt",
    "ve.pt",
]
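# For orientation (an assumption based on the upstream chatterbox-tts model
# layout, not anything defined in this repo): t3_* appears to be the
# text-to-speech-token transformer, s3gen the token-to-waveform generator,
# and ve the voice encoder used for speaker conditioning.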


class ChatterboxTTSBackend(TTSBackend):
    """Chatterbox Multilingual TTS backend for voice cloning."""

    # Class-level lock for torch.load monkey-patching: torch.load is
    # process-wide state, so concurrent loads must serialize while it is patched.
    _load_lock: ClassVar[threading.Lock] = threading.Lock()
    def __init__(self):
        self.model = None
        self.model_size = "default"
        self._device = None
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str = "default") -> str:
        # Single HF repo regardless of model_size; kept for protocol compatibility.
        return CHATTERBOX_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        return is_model_cached(CHATTERBOX_HF_REPO, required_files=_MTL_WEIGHT_FILES)

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Chatterbox multilingual model (idempotent)."""
        # Double-checked locking: the cheap check skips the lock on the hot
        # path; the re-check inside the lock closes the race between two
        # coroutines that both saw model is None.
        if self.model is not None:
            return
        async with self._model_load_lock:
            if self.model is not None:
                return
            await asyncio.to_thread(self._load_model_sync)
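    # Because of the double-checked lock above, concurrent callers are safe;
    # only the first performs the load. A minimal sketch (hypothetical caller):
    #
    #     backend = ChatterboxTTSBackend()
    #     await asyncio.gather(backend.load_model(), backend.load_model())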

    def _load_model_sync(self):
        """Synchronous model loading (runs in a worker thread)."""
        model_name = "chatterbox-tts"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            device = self._get_device()
            self._device = device
            logger.info(f"Loading Chatterbox Multilingual TTS on {device}...")

            # Lazy imports keep module import cheap when the backend is unused.
            import torch
            from chatterbox.mtl_tts import ChatterboxMultilingualTTS

            if device == "cpu":
                # Checkpoints saved from GPU tensors fail to load on CPU-only
                # machines unless map_location="cpu" is passed to torch.load.
                _orig_torch_load = torch.load

                def _patched_load(*args, **kwargs):
                    kwargs.setdefault("map_location", "cpu")
                    return _orig_torch_load(*args, **kwargs)

                # torch.load is global state: hold the class-level lock while patched.
                with ChatterboxTTSBackend._load_lock:
                    torch.load = _patched_load
                    try:
                        model = ChatterboxMultilingualTTS.from_pretrained(device=device)
                    finally:
                        torch.load = _orig_torch_load
            else:
                model = ChatterboxMultilingualTTS.from_pretrained(device=device)

            # SDPA attention cannot return attention weights; force eager so
            # output_attentions works.
            t3_tfmr = model.t3.tfmr
            if hasattr(t3_tfmr, "config") and hasattr(t3_tfmr.config, "_attn_implementation"):
                t3_tfmr.config._attn_implementation = "eager"
            for layer in getattr(t3_tfmr, "layers", []):
                if hasattr(layer, "self_attn"):
                    layer.self_attn._attn_implementation = "eager"

            patch_chatterbox_f32(model)
            self.model = model

        logger.info("Chatterbox Multilingual TTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            device = self._device
            del self.model
            self.model = None
            self._device = None
            empty_device_cache(device)
            logger.info("Chatterbox unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create a voice prompt from reference audio.

        Chatterbox processes reference audio at generation time, so the
        prompt just stores the file path. The actual audio is loaded by
        model.generate() via audio_prompt_path.
        """
        voice_prompt = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }
        return voice_prompt, False
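    # The resulting prompt is a plain dict, e.g. (illustrative paths only):
    #
    #     {"ref_audio": "/voices/alice.wav", "ref_text": "transcript of clip"}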

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    # Per-language generation defaults. Lower temperature plus a higher
    # cfg_weight tends to give clearer, more conservative speech.
    _LANG_DEFAULTS: ClassVar[dict] = {
        "he": {
            "exaggeration": 0.4,
            "cfg_weight": 0.7,
            "temperature": 0.65,
            "repetition_penalty": 2.5,
        },
    }
    _GLOBAL_DEFAULTS: ClassVar[dict] = {
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "temperature": 0.8,
        "repetition_penalty": 2.0,
    }
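    # Override semantics used in generate() below: language-specific values
    # win over the global ones, so language="he" gets temperature 0.65 while
    # an unlisted language such as "en" falls back to the global 0.8.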

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio using Chatterbox Multilingual TTS.

        Args:
            text: Text to synthesize
            voice_prompt: Dict with ref_audio path
            language: BCP-47 language code
            seed: Random seed for reproducibility
            instruct: Unused (protocol compatibility)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        ref_audio = voice_prompt.get("ref_audio")
        if ref_audio and not Path(ref_audio).exists():
            logger.warning(f"Reference audio not found: {ref_audio}")
            ref_audio = None

        # Merge language-specific overrides over the global defaults
        lang_defaults = {**self._GLOBAL_DEFAULTS, **self._LANG_DEFAULTS.get(language, {})}

        def _generate_sync():
            import torch

            if seed is not None:
                manual_seed(seed, self._device)

            logger.info(f"[Chatterbox] Generating: lang={language}")

            wav = self.model.generate(
                text,
                language_id=language,
                audio_prompt_path=ref_audio,
                exaggeration=lang_defaults["exaggeration"],
                cfg_weight=lang_defaults["cfg_weight"],
                temperature=lang_defaults["temperature"],
                repetition_penalty=lang_defaults["repetition_penalty"],
            )

            # Convert tensor -> numpy
            if isinstance(wav, torch.Tensor):
                audio = wav.squeeze().cpu().numpy().astype(np.float32)
            else:
                audio = np.asarray(wav, dtype=np.float32)

            sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)

            return audio, sample_rate

        # Blocking inference runs off the event loop in a worker thread.
        return await asyncio.to_thread(_generate_sync)
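
# Minimal end-to-end usage sketch (hypothetical caller code, not part of this
# module; assumes a running asyncio program and a real reference clip):
#
#     async def demo():
#         backend = ChatterboxTTSBackend()
#         prompt, _ = await backend.create_voice_prompt("alice.wav", "Hi there.")
#         audio, sr = await backend.generate("Shalom!", prompt, language="he", seed=42)
#         # audio is a float32 numpy array; sr falls back to 24000 if the
#         # model exposes neither .sr nor .sample_rate (see generate() above).
#
#     asyncio.run(demo())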