Initial commit

2026-04-24 19:18:15 +08:00
commit fbcbe08696
555 changed files with 96692 additions and 0 deletions

backend/backends/__init__.py

@@ -0,0 +1,621 @@
"""
Backend abstraction layer for TTS and STT.
Provides a unified interface for MLX and PyTorch backends,
and a model config registry that eliminates per-engine dispatch maps.
"""
# Install HF compatibility patches before any backend imports transformers /
# huggingface_hub. The module runs ``patch_transformers_mistral_regex`` at
# import time, wrapping transformers' tokenizer loading to guard the
# unconditional HuggingFace metadata call that otherwise raises when
# HF_HUB_OFFLINE=1 is set or the network is unavailable.
from ..utils import hf_offline_patch # noqa: F401
import threading
from dataclasses import dataclass, field
from typing import Protocol, Optional, Tuple, List, runtime_checkable
import numpy as np
from ..utils.platform_detect import get_backend_type
LANGUAGE_CODE_TO_NAME = {
"zh": "chinese",
"en": "english",
"ja": "japanese",
"ko": "korean",
"de": "german",
"fr": "french",
"ru": "russian",
"pt": "portuguese",
"es": "spanish",
"it": "italian",
}
WHISPER_HF_REPOS = {
"base": "openai/whisper-base",
"small": "openai/whisper-small",
"medium": "openai/whisper-medium",
"large": "openai/whisper-large-v3",
"turbo": "openai/whisper-large-v3-turbo",
}
@dataclass
class ModelConfig:
"""Declarative config for a downloadable model variant."""
model_name: str # e.g. "luxtts", "chatterbox-tts"
display_name: str # e.g. "LuxTTS (Fast, CPU-friendly)"
engine: str # e.g. "luxtts", "chatterbox"
hf_repo_id: str # e.g. "YatharthS/LuxTTS"
model_size: str = "default"
size_mb: int = 0
needs_trim: bool = False
supports_instruct: bool = False
languages: list[str] = field(default_factory=lambda: ["en"])
@runtime_checkable
class TTSBackend(Protocol):
"""Protocol for TTS backend implementations."""
# Each backend class should define MODEL_CONFIGS as a class variable:
# MODEL_CONFIGS: list[ModelConfig]
async def load_model(self, model_size: str) -> None:
"""Load TTS model."""
...
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Returns:
Tuple of (voice_prompt_dict, was_cached)
"""
...
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
"""
Combine multiple voice prompts.
Returns:
Tuple of (combined_audio_array, combined_text)
"""
...
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text.
Returns:
Tuple of (audio_array, sample_rate)
"""
...
def unload_model(self) -> None:
"""Unload model to free memory."""
...
def is_loaded(self) -> bool:
"""Check if model is loaded."""
...
def _get_model_path(self, model_size: str) -> str:
"""
Get model path for a given size.
Returns:
Model path or HuggingFace Hub ID
"""
...
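# Example (illustrative): since the protocol is @runtime_checkable, callers can
# sanity-check a backend instance at runtime. Note that isinstance() only
# verifies method presence, not signatures:
#
#     backend = get_tts_backend_for_engine("luxtts")
#     assert isinstance(backend, TTSBackend)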
@runtime_checkable
class STTBackend(Protocol):
"""Protocol for STT (Speech-to-Text) backend implementations."""
async def load_model(self, model_size: str) -> None:
"""Load STT model."""
...
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
model_size: Optional[str] = None,
) -> str:
"""
Transcribe audio to text.
Returns:
Transcribed text
"""
...
def unload_model(self) -> None:
"""Unload model to free memory."""
...
def is_loaded(self) -> bool:
"""Check if model is loaded."""
...
# Global backend instances
_tts_backend: Optional[TTSBackend] = None
_tts_backends: dict[str, TTSBackend] = {}
_tts_backends_lock = threading.Lock()
_stt_backend: Optional[STTBackend] = None
# Supported TTS engines — keyed by engine name, value is the human-readable
# display name. The factory (get_tts_backend_for_engine) dispatches on the
# engine name; the model configs live on the backend classes.
TTS_ENGINES = {
"qwen": "Qwen TTS",
"qwen_custom_voice": "Qwen CustomVoice",
"luxtts": "LuxTTS",
"chatterbox": "Chatterbox TTS",
"chatterbox_turbo": "Chatterbox Turbo",
"tada": "TADA",
"kokoro": "Kokoro",
}
def _get_qwen_model_configs() -> list[ModelConfig]:
"""Return Qwen model configs with backend-aware HF repo IDs."""
backend_type = get_backend_type()
if backend_type == "mlx":
repo_1_7b = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
repo_0_6b = "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16"
else:
repo_1_7b = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
repo_0_6b = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
return [
ModelConfig(
model_name="qwen-tts-1.7B",
display_name="Qwen TTS 1.7B",
engine="qwen",
hf_repo_id=repo_1_7b,
model_size="1.7B",
size_mb=3500,
supports_instruct=False, # Base model drops instruct silently
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
ModelConfig(
model_name="qwen-tts-0.6B",
display_name="Qwen TTS 0.6B",
engine="qwen",
hf_repo_id=repo_0_6b,
model_size="0.6B",
size_mb=1200,
supports_instruct=False,
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
]
def _get_qwen_custom_voice_configs() -> list[ModelConfig]:
"""Return Qwen CustomVoice model configs."""
return [
ModelConfig(
model_name="qwen-custom-voice-1.7B",
display_name="Qwen CustomVoice 1.7B",
engine="qwen_custom_voice",
hf_repo_id="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
model_size="1.7B",
size_mb=3500,
supports_instruct=True,
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
ModelConfig(
model_name="qwen-custom-voice-0.6B",
display_name="Qwen CustomVoice 0.6B",
engine="qwen_custom_voice",
hf_repo_id="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
model_size="0.6B",
size_mb=1200,
supports_instruct=True,
languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
),
]
def _get_non_qwen_tts_configs() -> list[ModelConfig]:
"""Return model configs for non-Qwen TTS engines.
These are static — no backend-type branching needed.
"""
return [
ModelConfig(
model_name="luxtts",
display_name="LuxTTS (Fast, CPU-friendly)",
engine="luxtts",
hf_repo_id="YatharthS/LuxTTS",
size_mb=300,
languages=["en"],
),
ModelConfig(
model_name="chatterbox-tts",
display_name="Chatterbox TTS (Multilingual)",
engine="chatterbox",
hf_repo_id="ResembleAI/chatterbox",
size_mb=3200,
needs_trim=True,
languages=[
"zh",
"en",
"ja",
"ko",
"de",
"fr",
"ru",
"pt",
"es",
"it",
"he",
"ar",
"da",
"el",
"fi",
"hi",
"ms",
"nl",
"no",
"pl",
"sv",
"sw",
"tr",
],
),
ModelConfig(
model_name="chatterbox-turbo",
display_name="Chatterbox Turbo (English, Tags)",
engine="chatterbox_turbo",
hf_repo_id="ResembleAI/chatterbox-turbo",
size_mb=1500,
needs_trim=True,
languages=["en"],
),
ModelConfig(
model_name="tada-1b",
display_name="TADA 1B (English)",
engine="tada",
hf_repo_id="HumeAI/tada-1b",
model_size="1B",
size_mb=4000,
languages=["en"],
),
ModelConfig(
model_name="tada-3b-ml",
display_name="TADA 3B Multilingual",
engine="tada",
hf_repo_id="HumeAI/tada-3b-ml",
model_size="3B",
size_mb=8000,
languages=["en", "ar", "zh", "de", "es", "fr", "it", "ja", "pl", "pt"],
),
ModelConfig(
model_name="kokoro",
display_name="Kokoro 82M",
engine="kokoro",
hf_repo_id="hexgrad/Kokoro-82M",
size_mb=350,
languages=["en", "es", "fr", "hi", "it", "pt", "ja", "zh"],
),
]
def _get_whisper_configs() -> list[ModelConfig]:
"""Return Whisper STT model configs."""
return [
ModelConfig(
model_name="whisper-base",
display_name="Whisper Base",
engine="whisper",
hf_repo_id="openai/whisper-base",
model_size="base",
),
ModelConfig(
model_name="whisper-small",
display_name="Whisper Small",
engine="whisper",
hf_repo_id="openai/whisper-small",
model_size="small",
),
ModelConfig(
model_name="whisper-medium",
display_name="Whisper Medium",
engine="whisper",
hf_repo_id="openai/whisper-medium",
model_size="medium",
),
ModelConfig(
model_name="whisper-large",
display_name="Whisper Large",
engine="whisper",
hf_repo_id="openai/whisper-large-v3",
model_size="large",
),
ModelConfig(
model_name="whisper-turbo",
display_name="Whisper Turbo",
engine="whisper",
hf_repo_id="openai/whisper-large-v3-turbo",
model_size="turbo",
),
]
def get_all_model_configs() -> list[ModelConfig]:
    """Return the full list of model configs (TTS + STT)."""
    return (
        _get_qwen_model_configs()
        + _get_qwen_custom_voice_configs()
        + _get_non_qwen_tts_configs()
        + _get_whisper_configs()
    )
def get_tts_model_configs() -> list[ModelConfig]:
    """Return only TTS model configs."""
    return (
        _get_qwen_model_configs()
        + _get_qwen_custom_voice_configs()
        + _get_non_qwen_tts_configs()
    )
# Lookup helpers — these replace the if/elif chains in main.py
def get_model_config(model_name: str) -> Optional[ModelConfig]:
"""Look up a model config by model_name."""
for cfg in get_all_model_configs():
if cfg.model_name == model_name:
return cfg
return None
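# Example (hedged): the registry lookup that replaces a per-engine dispatch
# map. `trim_tts_output` is referenced below but defined elsewhere (assumed to
# live in the utils package):
#
#     cfg = get_model_config("chatterbox-tts")
#     if cfg is not None and cfg.needs_trim:
#         audio = trim_tts_output(audio)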
def engine_needs_trim(engine: str) -> bool:
"""Whether this engine's output should be run through trim_tts_output."""
for cfg in get_tts_model_configs():
if cfg.engine == engine:
return cfg.needs_trim
return False
def engine_has_model_sizes(engine: str) -> bool:
"""Whether this engine supports multiple model sizes (only Qwen currently)."""
configs = [c for c in get_tts_model_configs() if c.engine == engine]
return len(configs) > 1
async def load_engine_model(engine: str, model_size: str = "default") -> None:
"""Load a model for the given engine, handling engines with multiple model sizes."""
backend = get_tts_backend_for_engine(engine)
if engine in ("qwen", "qwen_custom_voice"):
await backend.load_model_async(model_size)
elif engine == "tada":
await backend.load_model(model_size)
else:
await backend.load_model()
async def ensure_model_cached_or_raise(engine: str, model_size: str = "default") -> None:
"""Check if a model is cached, raise HTTPException if not. Used by streaming endpoint."""
from fastapi import HTTPException
backend = get_tts_backend_for_engine(engine)
cfg = None
for c in get_tts_model_configs():
if c.engine == engine and c.model_size == model_size:
cfg = c
break
if engine in ("qwen", "qwen_custom_voice", "tada"):
if not backend._is_model_cached(model_size):
raise HTTPException(
status_code=400,
detail=f"Model {model_size} is not downloaded yet. Use /generate to trigger a download.",
)
else:
if not backend._is_model_cached():
display = cfg.display_name if cfg else engine
raise HTTPException(
status_code=400,
detail=f"{display} model is not downloaded yet. Use /generate to trigger a download.",
)
def unload_model_by_config(config: ModelConfig) -> bool:
"""Unload a model given its config. Returns True if it was loaded, False otherwise."""
from . import get_tts_backend_for_engine
from ..services import tts, transcribe
if config.engine == "whisper":
whisper_model = transcribe.get_whisper_model()
if whisper_model.is_loaded() and whisper_model.model_size == config.model_size:
transcribe.unload_whisper_model()
return True
return False
if config.engine == "qwen":
tts_model = tts.get_tts_model()
loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
if tts_model.is_loaded() and loaded_size == config.model_size:
tts.unload_tts_model()
return True
return False
if config.engine == "qwen_custom_voice":
backend = get_tts_backend_for_engine(config.engine)
loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
if backend.is_loaded() and loaded_size == config.model_size:
backend.unload_model()
return True
return False
# All other TTS engines
backend = get_tts_backend_for_engine(config.engine)
if backend.is_loaded():
backend.unload_model()
return True
return False
def check_model_loaded(config: ModelConfig) -> bool:
"""Check if a model is currently loaded."""
from . import get_tts_backend_for_engine
from ..services import tts, transcribe
try:
if config.engine == "whisper":
whisper_model = transcribe.get_whisper_model()
return whisper_model.is_loaded() and getattr(whisper_model, "model_size", None) == config.model_size
if config.engine == "qwen":
tts_model = tts.get_tts_model()
loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
return tts_model.is_loaded() and loaded_size == config.model_size
if config.engine == "qwen_custom_voice":
backend = get_tts_backend_for_engine(config.engine)
loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
return backend.is_loaded() and loaded_size == config.model_size
backend = get_tts_backend_for_engine(config.engine)
return backend.is_loaded()
except Exception:
return False
def get_model_load_func(config: ModelConfig):
"""Return a callable that loads/downloads the model."""
from . import get_tts_backend_for_engine
from ..services import tts, transcribe
if config.engine == "whisper":
return lambda: transcribe.get_whisper_model().load_model(config.model_size)
if config.engine == "qwen":
return lambda: tts.get_tts_model().load_model(config.model_size)
if config.engine == "qwen_custom_voice":
return lambda: get_tts_backend_for_engine(config.engine).load_model(config.model_size)
return lambda: get_tts_backend_for_engine(config.engine).load_model()
def get_tts_backend() -> TTSBackend:
"""
Get or create the default (Qwen) TTS backend instance based on platform.
Returns:
TTS backend instance (MLX or PyTorch)
"""
return get_tts_backend_for_engine("qwen")
def get_tts_backend_for_engine(engine: str) -> TTSBackend:
"""
Get or create a TTS backend for the given engine.
Args:
engine: Engine name (e.g. "qwen", "luxtts", "chatterbox", "chatterbox_turbo")
Returns:
TTS backend instance
"""
global _tts_backends
# Fast path: check without lock
if engine in _tts_backends:
return _tts_backends[engine]
# Slow path: create with lock to avoid duplicate instantiation
with _tts_backends_lock:
# Double-check after acquiring lock
if engine in _tts_backends:
return _tts_backends[engine]
if engine == "qwen":
backend_type = get_backend_type()
if backend_type == "mlx":
from .mlx_backend import MLXTTSBackend
backend = MLXTTSBackend()
else:
from .pytorch_backend import PyTorchTTSBackend
backend = PyTorchTTSBackend()
elif engine == "luxtts":
from .luxtts_backend import LuxTTSBackend
backend = LuxTTSBackend()
elif engine == "chatterbox":
from .chatterbox_backend import ChatterboxTTSBackend
backend = ChatterboxTTSBackend()
elif engine == "chatterbox_turbo":
from .chatterbox_turbo_backend import ChatterboxTurboTTSBackend
backend = ChatterboxTurboTTSBackend()
elif engine == "tada":
from .hume_backend import HumeTadaBackend
backend = HumeTadaBackend()
elif engine == "kokoro":
from .kokoro_backend import KokoroTTSBackend
backend = KokoroTTSBackend()
elif engine == "qwen_custom_voice":
from .qwen_custom_voice_backend import QwenCustomVoiceBackend
backend = QwenCustomVoiceBackend()
else:
raise ValueError(f"Unknown TTS engine: {engine}. Supported: {list(TTS_ENGINES.keys())}")
_tts_backends[engine] = backend
return backend
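# Example (assumption — the real call sites live in the API layer): typical
# endpoint flow for a preset-voice engine:
#
#     backend = get_tts_backend_for_engine("kokoro")
#     await load_engine_model("kokoro")
#     audio, sr = await backend.generate("Hello there", voice_prompt={})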
def get_stt_backend() -> STTBackend:
"""
Get or create STT backend instance based on platform.
Returns:
STT backend instance (MLX or PyTorch)
"""
global _stt_backend
if _stt_backend is None:
backend_type = get_backend_type()
if backend_type == "mlx":
from .mlx_backend import MLXSTTBackend
_stt_backend = MLXSTTBackend()
else:
from .pytorch_backend import PyTorchSTTBackend
_stt_backend = PyTorchSTTBackend()
return _stt_backend
def reset_backends():
"""Reset backend instances (useful for testing)."""
global _tts_backend, _tts_backends, _stt_backend
_tts_backend = None
_tts_backends.clear()
_stt_backend = None
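# Example (illustrative): a pytest fixture that isolates the backend singletons
# between tests. The fixture name is an assumption:
#
#     @pytest.fixture(autouse=True)
#     def _fresh_backends():
#         yield
#         reset_backends()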

backend/backends/base.py

@@ -0,0 +1,327 @@
"""
Shared utilities for TTS/STT backend implementations.
Eliminates duplication of cache checking, device detection,
voice prompt combination, and model loading progress tracking.
"""
import logging
import platform
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import numpy as np
from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager
logger = logging.getLogger(__name__)
def is_model_cached(
hf_repo: str,
*,
weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
required_files: Optional[list[str]] = None,
) -> bool:
"""
Check if a HuggingFace model is fully cached locally.
Args:
hf_repo: HuggingFace repo ID (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-Base")
weight_extensions: File extensions that count as model weights.
required_files: If set, check that these specific filenames exist
in snapshots instead of checking by extension.
Returns:
True if model is fully cached, False if missing or incomplete.
"""
try:
from huggingface_hub import constants as hf_constants
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))
if not repo_cache.exists():
return False
# Incomplete blobs mean a download is still in progress
blobs_dir = repo_cache / "blobs"
if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
logger.debug(f"Found .incomplete files for {hf_repo}")
return False
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.exists():
return False
if required_files:
# Check that every required filename exists somewhere in snapshots
for fname in required_files:
if not any(snapshots_dir.rglob(fname)):
return False
return True
# Check that at least one weight file exists
for ext in weight_extensions:
if any(snapshots_dir.rglob(f"*{ext}")):
return True
logger.debug(f"No model weights found for {hf_repo}")
return False
except Exception as e:
logger.warning(f"Error checking cache for {hf_repo}: {e}")
return False
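# Illustrative hub cache layout that is_model_cached() walks (default
# HF_HUB_CACHE location assumed):
#
#     ~/.cache/huggingface/hub/models--openai--whisper-base/
#         blobs/                      <- "*.incomplete" => download in flight
#         snapshots/<revision>/
#             config.json
#             model.safetensors       <- weight file matched by extension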
def get_torch_device(
*,
allow_xpu: bool = False,
allow_directml: bool = False,
allow_mps: bool = False,
force_cpu_on_mac: bool = False,
) -> str:
"""
Detect the best available torch device.
Args:
allow_xpu: Check for Intel XPU (IPEX) support.
allow_directml: Check for DirectML (Windows) support.
allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU.
force_cpu_on_mac: Force CPU on macOS regardless of GPU availability.
"""
if force_cpu_on_mac and platform.system() == "Darwin":
return "cpu"
import torch
if torch.cuda.is_available():
return "cuda"
if allow_xpu:
try:
import intel_extension_for_pytorch # noqa: F401
if hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
except ImportError:
pass
if allow_directml:
try:
import torch_directml
if torch_directml.device_count() > 0:
return torch_directml.device(0)
except ImportError:
pass
if allow_mps:
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
def check_cuda_compatibility() -> tuple[bool, str | None]:
"""Check if the installed PyTorch supports the current GPU's compute capability.
Returns:
(compatible, warning_message) — compatible is True if OK or no CUDA GPU,
warning_message is a human-readable string if there's a problem.
"""
import torch
if not torch.cuda.is_available():
return True, None
major, minor = torch.cuda.get_device_capability(0)
capability = f"{major}.{minor}"
device_name = torch.cuda.get_device_name(0)
sm_tag = f"sm_{major}{minor}"
# torch.cuda._get_arch_list() returns the SM architectures this build
# was compiled for (e.g. ["sm_50", "sm_60", ..., "sm_90"]).
try:
arch_list = torch.cuda._get_arch_list()
if arch_list:
# Check for both sm_XX and compute_XX (JIT-compiled) entries
compute_tag = f"compute_{major}{minor}"
if sm_tag not in arch_list and compute_tag not in arch_list:
return False, (
f"{device_name} (compute capability {capability} / {sm_tag}) "
f"is not supported by this PyTorch build. "
f"Supported architectures: {', '.join(arch_list)}. "
f"Install PyTorch nightly (cu128) for newer GPU support: "
f"pip install torch --index-url https://download.pytorch.org/whl/nightly/cu128"
)
except AttributeError:
pass
return True, None
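# Sketch of the intended call site (assumption — the real wiring belongs in
# server startup, not this module):
#
#     ok, warning = check_cuda_compatibility()
#     if not ok and warning:
#         logger.warning(warning)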
def empty_device_cache(device: str) -> None:
"""
Free cached memory on the given device (CUDA or XPU).
Backends should call this after unloading models so VRAM is returned
to the OS.
"""
import torch
if device == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif device == "xpu" and hasattr(torch, "xpu"):
torch.xpu.empty_cache()
def manual_seed(seed: int, device: str) -> None:
"""
Set the random seed on both CPU and the active accelerator.
Covers CUDA and Intel XPU so that generation is reproducible
regardless of which GPU backend is in use.
"""
import torch
torch.manual_seed(seed)
if device == "cuda" and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
elif device == "xpu" and hasattr(torch, "xpu"):
torch.xpu.manual_seed(seed)
async def combine_voice_prompts(
audio_paths: List[str],
reference_texts: List[str],
*,
sample_rate: Optional[int] = None,
) -> Tuple[np.ndarray, str]:
"""
Combine multiple reference audio samples into one.
Loads each audio file, normalizes, concatenates, and joins texts.
Args:
audio_paths: Paths to reference audio files.
reference_texts: Corresponding transcripts.
sample_rate: If set, resample audio to this rate during loading.
"""
combined_audio = []
for path in audio_paths:
kwargs = {"sample_rate": sample_rate} if sample_rate else {}
audio, _sr = load_audio(path, **kwargs)
audio = normalize_audio(audio)
combined_audio.append(audio)
mixed = np.concatenate(combined_audio)
mixed = normalize_audio(mixed)
combined_text = " ".join(reference_texts)
return mixed, combined_text
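# Example (hedged): combining two reference clips into a single prompt source:
#
#     audio, text = await combine_voice_prompts(
#         ["a.wav", "b.wav"], ["Hello.", "World."], sample_rate=24000
#     )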
@contextmanager
def model_load_progress(
model_name: str,
is_cached: bool,
filter_non_downloads: Optional[bool] = None,
):
"""
Context manager for model loading with HF download progress tracking.
Handles the tqdm patching, progress_manager/task_manager lifecycle,
and error reporting that every backend duplicates.
Args:
model_name: Progress tracking key (e.g. "qwen-tts-1.7B", "whisper-base").
is_cached: Whether the model is already downloaded.
filter_non_downloads: Whether to filter non-download tqdm bars.
Defaults to `is_cached`.
Yields:
The tracker context (already entered). The caller loads the model
inside the `with` block. The tqdm patch is torn down on exit.
Usage:
with model_load_progress("qwen-tts-1.7B", is_cached) as ctx:
self.model = SomeModel.from_pretrained(...)
"""
if filter_non_downloads is None:
filter_non_downloads = is_cached
progress_manager = get_progress_manager()
task_manager = get_task_manager()
progress_callback = create_hf_progress_callback(model_name, progress_manager)
tracker = HFProgressTracker(progress_callback, filter_non_downloads=filter_non_downloads)
tracker_context = tracker.patch_download()
tracker_context.__enter__()
if not is_cached:
task_manager.start_download(model_name)
progress_manager.update_progress(
model_name=model_name,
current=0,
total=0,
filename="Connecting to HuggingFace...",
status="downloading",
)
try:
yield tracker_context
except Exception as e:
# Report error to both managers
progress_manager.mark_error(model_name, str(e))
task_manager.error_download(model_name, str(e))
raise
else:
# Only mark complete if we were tracking a download
if not is_cached:
progress_manager.mark_complete(model_name)
task_manager.complete_download(model_name)
finally:
tracker_context.__exit__(None, None, None)
def patch_chatterbox_f32(model) -> None:
"""
Patch float64 -> float32 dtype mismatches in upstream chatterbox.
librosa.load returns float64 numpy arrays. Multiple upstream code paths
convert these to torch tensors via torch.from_numpy() without casting,
then matmul against float32 model weights. This patches the two known
entry points:
1. S3Tokenizer.log_mel_spectrogram — audio tensor hits _mel_filters (f32)
2. VoiceEncoder.forward — float64 mel spectrograms hit LSTM weights (f32)
"""
import types
# Patch S3Tokenizer
_tokzr = model.s3gen.tokenizer
_orig_log_mel = _tokzr.log_mel_spectrogram.__func__
def _f32_log_mel(self_tokzr, audio, padding=0):
import torch as _torch
if _torch.is_tensor(audio):
audio = audio.float()
return _orig_log_mel(self_tokzr, audio, padding)
_tokzr.log_mel_spectrogram = types.MethodType(_f32_log_mel, _tokzr)
# Patch VoiceEncoder
_ve = model.ve
_orig_ve_forward = _ve.forward.__func__
def _f32_ve_forward(self_ve, mels):
return _orig_ve_forward(self_ve, mels.float())
_ve.forward = types.MethodType(_f32_ve_forward, _ve)

backend/backends/chatterbox_backend.py

@@ -0,0 +1,226 @@
"""
Chatterbox TTS backend implementation.
Wraps ChatterboxMultilingualTTS from chatterbox-tts for zero-shot
voice cloning. Supports 23 languages including Hebrew. Forces CPU
on macOS due to known MPS tensor issues.
"""
import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
patch_chatterbox_f32,
)
logger = logging.getLogger(__name__)
CHATTERBOX_HF_REPO = "ResembleAI/chatterbox"
# Files that must be present for the multilingual model
_MTL_WEIGHT_FILES = [
"t3_mtl23ls_v2.safetensors",
"s3gen.pt",
"ve.pt",
]
class ChatterboxTTSBackend:
"""Chatterbox Multilingual TTS backend for voice cloning."""
# Class-level lock for torch.load monkey-patching
_load_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self):
self.model = None
self.model_size = "default"
self._device = None
self._model_load_lock = asyncio.Lock()
def _get_device(self) -> str:
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str = "default") -> str:
return CHATTERBOX_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
return is_model_cached(CHATTERBOX_HF_REPO, required_files=_MTL_WEIGHT_FILES)
async def load_model(self, model_size: str = "default") -> None:
"""Load the Chatterbox multilingual model."""
if self.model is not None:
return
async with self._model_load_lock:
if self.model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
"""Synchronous model loading."""
model_name = "chatterbox-tts"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
device = self._get_device()
self._device = device
logger.info(f"Loading Chatterbox Multilingual TTS on {device}...")
import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
if device == "cpu":
_orig_torch_load = torch.load
def _patched_load(*args, **kwargs):
kwargs.setdefault("map_location", "cpu")
return _orig_torch_load(*args, **kwargs)
with ChatterboxTTSBackend._load_lock:
torch.load = _patched_load
try:
model = ChatterboxMultilingualTTS.from_pretrained(device=device)
finally:
torch.load = _orig_torch_load
else:
model = ChatterboxMultilingualTTS.from_pretrained(device=device)
# Fix sdpa attention for output_attentions support
t3_tfmr = model.t3.tfmr
if hasattr(t3_tfmr, "config") and hasattr(t3_tfmr.config, "_attn_implementation"):
t3_tfmr.config._attn_implementation = "eager"
for layer in getattr(t3_tfmr, "layers", []):
if hasattr(layer, "self_attn"):
layer.self_attn._attn_implementation = "eager"
patch_chatterbox_f32(model)
self.model = model
logger.info("Chatterbox Multilingual TTS loaded successfully")
def unload_model(self) -> None:
"""Unload model to free memory."""
if self.model is not None:
device = self._device
del self.model
self.model = None
self._device = None
empty_device_cache(device)
logger.info("Chatterbox unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Chatterbox processes reference audio at generation time, so the
prompt just stores the file path. The actual audio is loaded by
model.generate() via audio_prompt_path.
"""
voice_prompt = {
"ref_audio": str(audio_path),
"ref_text": reference_text,
}
return voice_prompt, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
# Per-language generation defaults. Lower temp + higher cfg = clearer speech.
_LANG_DEFAULTS: ClassVar[dict] = {
"he": {
"exaggeration": 0.4,
"cfg_weight": 0.7,
"temperature": 0.65,
"repetition_penalty": 2.5,
},
}
_GLOBAL_DEFAULTS: ClassVar[dict] = {
"exaggeration": 0.5,
"cfg_weight": 0.5,
"temperature": 0.8,
"repetition_penalty": 2.0,
}
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio using Chatterbox Multilingual TTS.
Args:
text: Text to synthesize
voice_prompt: Dict with ref_audio path
language: BCP-47 language code
seed: Random seed for reproducibility
instruct: Unused (protocol compatibility)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
ref_audio = voice_prompt.get("ref_audio")
if ref_audio and not Path(ref_audio).exists():
logger.warning(f"Reference audio not found: {ref_audio}")
ref_audio = None
        # Merge language-specific overrides over the global defaults so a
        # partial per-language entry still inherits the remaining keys
        lang_defaults = {**self._GLOBAL_DEFAULTS, **self._LANG_DEFAULTS.get(language, {})}
def _generate_sync():
import torch
if seed is not None:
manual_seed(seed, self._device)
logger.info(f"[Chatterbox] Generating: lang={language}")
wav = self.model.generate(
text,
language_id=language,
audio_prompt_path=ref_audio,
exaggeration=lang_defaults["exaggeration"],
cfg_weight=lang_defaults["cfg_weight"],
temperature=lang_defaults["temperature"],
repetition_penalty=lang_defaults["repetition_penalty"],
)
# Convert tensor -> numpy
if isinstance(wav, torch.Tensor):
audio = wav.squeeze().cpu().numpy().astype(np.float32)
else:
audio = np.asarray(wav, dtype=np.float32)
sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)
return audio, sample_rate
return await asyncio.to_thread(_generate_sync)

backend/backends/chatterbox_turbo_backend.py

@@ -0,0 +1,206 @@
"""
Chatterbox Turbo TTS backend implementation.
Wraps ChatterboxTurboTTS from chatterbox-tts for fast, English-only
voice cloning with paralinguistic tag support ([laugh], [cough], etc.).
Forces CPU on macOS due to known MPS tensor issues.
"""
import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
patch_chatterbox_f32,
)
logger = logging.getLogger(__name__)
CHATTERBOX_TURBO_HF_REPO = "ResembleAI/chatterbox-turbo"
# Files that must be present for the turbo model
_TURBO_WEIGHT_FILES = [
"t3_turbo_v1.safetensors",
"s3gen_meanflow.safetensors",
"ve.safetensors",
]
class ChatterboxTurboTTSBackend:
"""Chatterbox Turbo TTS backend — fast, English-only, with paralinguistic tags."""
# Class-level lock for torch.load monkey-patching
_load_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self):
self.model = None
self.model_size = "default"
self._device = None
self._model_load_lock = asyncio.Lock()
def _get_device(self) -> str:
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str = "default") -> str:
return CHATTERBOX_TURBO_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
return is_model_cached(CHATTERBOX_TURBO_HF_REPO, required_files=_TURBO_WEIGHT_FILES)
async def load_model(self, model_size: str = "default") -> None:
"""Load the Chatterbox Turbo model."""
if self.model is not None:
return
async with self._model_load_lock:
if self.model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
"""Synchronous model loading."""
model_name = "chatterbox-turbo"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
device = self._get_device()
self._device = device
logger.info(f"Loading Chatterbox Turbo TTS on {device}...")
import torch
from huggingface_hub import snapshot_download
from chatterbox.tts_turbo import ChatterboxTurboTTS
local_path = snapshot_download(
repo_id=CHATTERBOX_TURBO_HF_REPO,
token=None,
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
)
if device == "cpu":
_orig_torch_load = torch.load
def _patched_load(*args, **kwargs):
kwargs.setdefault("map_location", "cpu")
return _orig_torch_load(*args, **kwargs)
with ChatterboxTurboTTSBackend._load_lock:
torch.load = _patched_load
try:
model = ChatterboxTurboTTS.from_local(local_path, device)
finally:
torch.load = _orig_torch_load
else:
model = ChatterboxTurboTTS.from_local(local_path, device)
patch_chatterbox_f32(model)
self.model = model
logger.info("Chatterbox Turbo TTS loaded successfully")
def unload_model(self) -> None:
"""Unload model to free memory."""
if self.model is not None:
device = self._device
del self.model
self.model = None
self._device = None
empty_device_cache(device)
logger.info("Chatterbox Turbo unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Chatterbox Turbo processes reference audio at generation time, so the
prompt just stores the file path.
"""
voice_prompt = {
"ref_audio": str(audio_path),
"ref_text": reference_text,
}
return voice_prompt, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio using Chatterbox Turbo TTS.
Supports paralinguistic tags in text: [laugh], [cough], [chuckle], etc.
Args:
text: Text to synthesize (may include paralinguistic tags)
voice_prompt: Dict with ref_audio path
language: Ignored (Turbo is English-only)
seed: Random seed for reproducibility
instruct: Unused (protocol compatibility)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
ref_audio = voice_prompt.get("ref_audio")
if ref_audio and not Path(ref_audio).exists():
logger.warning(f"Reference audio not found: {ref_audio}")
ref_audio = None
def _generate_sync():
import torch
if seed is not None:
manual_seed(seed, self._device)
logger.info("[Chatterbox Turbo] Generating (English)")
wav = self.model.generate(
text,
audio_prompt_path=ref_audio,
temperature=0.8,
top_k=1000,
top_p=0.95,
repetition_penalty=1.2,
)
# Convert tensor -> numpy
if isinstance(wav, torch.Tensor):
audio = wav.squeeze().cpu().numpy().astype(np.float32)
else:
audio = np.asarray(wav, dtype=np.float32)
sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)
return audio, sample_rate
return await asyncio.to_thread(_generate_sync)

backend/backends/hume_backend.py

@@ -0,0 +1,346 @@
"""
HumeAI TADA TTS backend implementation.
Wraps HumeAI's TADA (Text-Acoustic Dual Alignment) model for
high-quality voice cloning. Two model variants:
- tada-1b: English-only, ~2B params (Llama 3.2 1B base)
- tada-3b-ml: Multilingual, ~4B params (Llama 3.2 3B base)
Both use a shared encoder/codec (HumeAI/tada-codec). The encoder
produces 1:1 aligned token embeddings from reference audio, and the
causal LM generates speech via flow-matching diffusion.
24kHz output, bf16 inference on CUDA, fp32 on CPU.
"""
import asyncio
import logging
import threading
from typing import ClassVar, List, Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
logger = logging.getLogger(__name__)
# HuggingFace repos
TADA_CODEC_REPO = "HumeAI/tada-codec"
TADA_1B_REPO = "HumeAI/tada-1b"
TADA_3B_ML_REPO = "HumeAI/tada-3b-ml"
TADA_MODEL_REPOS = {
"1B": TADA_1B_REPO,
"3B": TADA_3B_ML_REPO,
}
# Key weight files for cache detection
_TADA_MODEL_WEIGHT_FILES = [
"model.safetensors",
]
_TADA_CODEC_WEIGHT_FILES = [
"encoder/model.safetensors",
]
class HumeTadaBackend:
"""HumeAI TADA TTS backend for high-quality voice cloning."""
_load_lock: ClassVar[threading.Lock] = threading.Lock()
def __init__(self):
self.model = None
self.encoder = None
self.model_size = "1B" # default to 1B
self._device = None
self._model_load_lock = asyncio.Lock()
def _get_device(self) -> str:
# Force CPU on macOS — MPS has issues with flow matching
# and large vocab lm_head (>65536 output channels)
return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str = "1B") -> str:
return TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
def _is_model_cached(self, model_size: str = "1B") -> bool:
repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
model_cached = is_model_cached(repo, required_files=_TADA_MODEL_WEIGHT_FILES)
codec_cached = is_model_cached(TADA_CODEC_REPO, required_files=_TADA_CODEC_WEIGHT_FILES)
return model_cached and codec_cached
async def load_model(self, model_size: str = "1B") -> None:
"""Load the TADA model and encoder."""
if self.model is not None and self.model_size == model_size:
return
async with self._model_load_lock:
if self.model is not None and self.model_size == model_size:
return
# Unload existing model if switching sizes
if self.model is not None:
self.unload_model()
self.model_size = model_size
await asyncio.to_thread(self._load_model_sync, model_size)
def _load_model_sync(self, model_size: str = "1B"):
"""Synchronous model loading with progress tracking."""
model_name = f"tada-{model_size.lower()}"
is_cached = self._is_model_cached(model_size)
repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
with model_load_progress(model_name, is_cached):
# Install DAC shim before importing tada — tada's encoder/decoder
# import dac.nn.layers.Snake1d which requires the descript-audio-codec
# package. The real package pulls in onnx/tensorboard/matplotlib via
# descript-audiotools, so we use a lightweight shim instead.
from ..utils.dac_shim import install_dac_shim
install_dac_shim()
import torch
from huggingface_hub import snapshot_download
device = self._get_device()
self._device = device
logger.info(f"Loading HumeAI TADA {model_size} on {device}...")
# Download codec (encoder + decoder) if not cached
logger.info("Downloading TADA codec...")
snapshot_download(
repo_id=TADA_CODEC_REPO,
token=None,
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin"],
)
# Download model weights if not cached
logger.info(f"Downloading TADA {model_size} model...")
snapshot_download(
repo_id=repo,
token=None,
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin", "*.model"],
)
# TADA hardcodes "meta-llama/Llama-3.2-1B" as the tokenizer
# source in its Aligner and TadaForCausalLM.from_pretrained().
# That repo is gated (requires Meta license acceptance).
# Download the tokenizer from an ungated mirror and get its
# local cache path so we can point TADA at it directly.
logger.info("Downloading Llama tokenizer (ungated mirror)...")
tokenizer_path = snapshot_download(
repo_id="unsloth/Llama-3.2-1B",
token=None,
allow_patterns=["tokenizer*", "special_tokens*"],
)
# Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings
if device == "cuda" and torch.cuda.is_bf16_supported():
model_dtype = torch.bfloat16
elif device == "xpu":
# Intel Arc (Alchemist+) supports bf16 natively
model_dtype = torch.bfloat16
else:
model_dtype = torch.float32
# Patch the Aligner config class to use the local tokenizer
# path instead of the gated "meta-llama/Llama-3.2-1B" default.
# This avoids monkey-patching AutoTokenizer.from_pretrained
# which corrupts the classmethod descriptor for other engines.
from tada.modules.aligner import AlignerConfig
AlignerConfig.tokenizer_name = tokenizer_path
# Load encoder (only needed for voice prompt encoding)
from tada.modules.encoder import Encoder
logger.info("Loading TADA encoder...")
self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device)
self.encoder.eval()
# Load the causal LM (includes decoder for wav generation).
# TadaForCausalLM.from_pretrained() calls
# getattr(config, "tokenizer_name", "meta-llama/Llama-3.2-1B")
# which hits the gated repo. Pre-load the config from HF,
# inject the local tokenizer path, then pass it in.
from tada.modules.tada import TadaForCausalLM, TadaConfig
logger.info(f"Loading TADA {model_size} model...")
config = TadaConfig.from_pretrained(repo)
config.tokenizer_name = tokenizer_path
self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device)
self.model.eval()
logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}")
def unload_model(self) -> None:
"""Unload model and encoder to free memory."""
if self.model is not None:
del self.model
self.model = None
if self.encoder is not None:
del self.encoder
self.encoder = None
device = self._device
self._device = None
if device:
empty_device_cache(device)
logger.info("HumeAI TADA unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio using TADA's encoder.
TADA's encoder performs forced alignment between audio and text tokens,
producing an EncoderOutput with 1:1 token-audio alignment. If no
reference_text is provided, the encoder uses built-in ASR (English only).
We serialize the EncoderOutput to a dict for caching.
"""
await self.load_model(self.model_size)
cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None
if cache_key:
cached = get_cached_voice_prompt(cache_key)
if cached is not None and isinstance(cached, dict):
return cached, True
def _encode_sync():
import torch
import soundfile as sf
device = self._device
# Load audio with soundfile (torchaudio 2.10+ requires torchcodec)
audio_np, sr = sf.read(str(audio_path), dtype="float32")
audio = torch.from_numpy(audio_np).float()
if audio.ndim == 1:
audio = audio.unsqueeze(0) # (samples,) -> (1, samples)
else:
audio = audio.T # (samples, channels) -> (channels, samples)
audio = audio.to(device)
# Encode with forced alignment
text_arg = [reference_text] if reference_text else None
prompt = self.encoder(audio, text=text_arg, sample_rate=sr)
            # Serialize EncoderOutput to a dict, detaching tensors to CPU so
            # the prompt can be cached and later restored on any device
            prompt_dict = {}
            for field_name in prompt.__dataclass_fields__:
                val = getattr(prompt, field_name)
                if isinstance(val, torch.Tensor):
                    prompt_dict[field_name] = val.detach().cpu()
                else:
                    prompt_dict[field_name] = val
            return prompt_dict
encoded = await asyncio.to_thread(_encode_sync)
if cache_key:
cache_voice_prompt(cache_key, encoded)
return encoded, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using HumeAI TADA.
Args:
text: Text to synthesize
voice_prompt: Serialized EncoderOutput dict from create_voice_prompt()
language: Language code (en, ar, de, es, fr, it, ja, pl, pt, zh)
seed: Random seed for reproducibility
instruct: Not supported by TADA (ignored)
Returns:
Tuple of (audio_array, sample_rate=24000)
"""
await self.load_model(self.model_size)
def _generate_sync():
import torch
from tada.modules.encoder import EncoderOutput
if seed is not None:
manual_seed(seed, self._device)
device = self._device
# Reconstruct EncoderOutput from the cached dict
restored = {}
for k, v in voice_prompt.items():
if isinstance(v, torch.Tensor):
# Move to device and match model dtype for float tensors
if v.is_floating_point():
model_dtype = next(self.model.parameters()).dtype
restored[k] = v.to(device=device, dtype=model_dtype)
else:
restored[k] = v.to(device=device)
else:
restored[k] = v
prompt = EncoderOutput(**restored)
# For non-English with the 3B-ML model, we could reload the
# encoder with the language-specific aligner. However, the
# generation itself is language-agnostic — only the encoder's
# aligner changes. Since we encode at create_voice_prompt time,
# the language is already baked in. For simplicity, we don't
# reload the encoder here.
logger.info(f"[TADA] Generating ({language}), text length: {len(text)}")
output = self.model.generate(
prompt=prompt,
text=text,
)
# output.audio is a list of tensors (one per batch item)
if output.audio and output.audio[0] is not None:
audio_tensor = output.audio[0]
audio = audio_tensor.detach().cpu().numpy().squeeze().astype(np.float32)
else:
logger.warning("[TADA] Generation produced no audio")
audio = np.zeros(24000, dtype=np.float32)
return audio, 24000
return await asyncio.to_thread(_generate_sync)

backend/backends/kokoro_backend.py

@@ -0,0 +1,288 @@
"""
Kokoro TTS backend implementation.
Wraps the Kokoro-82M model for fast, lightweight text-to-speech.
82M parameters, CPU realtime, 24kHz output, Apache 2.0 license.
Kokoro uses pre-built voice style vectors (not traditional zero-shot cloning
from arbitrary audio). Voice prompts are stored as deferred references to
HF-hosted voice .pt files.
Languages supported (via misaki G2P):
- American English (a), British English (b)
- Spanish (e), French (f), Hindi (h), Italian (i), Portuguese (p)
- Japanese (j) — requires misaki[ja]
- Chinese (z) — requires misaki[zh]
"""
import asyncio
import logging
from typing import Optional
import numpy as np
from . import TTSBackend
from .base import (
get_torch_device,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
logger = logging.getLogger(__name__)
# HuggingFace repo for model + voice detection
KOKORO_HF_REPO = "hexgrad/Kokoro-82M"
KOKORO_SAMPLE_RATE = 24000
# Default voice if none specified
KOKORO_DEFAULT_VOICE = "af_heart"
# All available Kokoro voices: (voice_id, display_name, gender, lang_code)
KOKORO_VOICES = [
# American English female
("af_alloy", "Alloy", "female", "en"),
("af_aoede", "Aoede", "female", "en"),
("af_bella", "Bella", "female", "en"),
("af_heart", "Heart", "female", "en"),
("af_jessica", "Jessica", "female", "en"),
("af_kore", "Kore", "female", "en"),
("af_nicole", "Nicole", "female", "en"),
("af_nova", "Nova", "female", "en"),
("af_river", "River", "female", "en"),
("af_sarah", "Sarah", "female", "en"),
("af_sky", "Sky", "female", "en"),
# American English male
("am_adam", "Adam", "male", "en"),
("am_echo", "Echo", "male", "en"),
("am_eric", "Eric", "male", "en"),
("am_fenrir", "Fenrir", "male", "en"),
("am_liam", "Liam", "male", "en"),
("am_michael", "Michael", "male", "en"),
("am_onyx", "Onyx", "male", "en"),
("am_puck", "Puck", "male", "en"),
("am_santa", "Santa", "male", "en"),
# British English female
("bf_alice", "Alice", "female", "en"),
("bf_emma", "Emma", "female", "en"),
("bf_isabella", "Isabella", "female", "en"),
("bf_lily", "Lily", "female", "en"),
# British English male
("bm_daniel", "Daniel", "male", "en"),
("bm_fable", "Fable", "male", "en"),
("bm_george", "George", "male", "en"),
("bm_lewis", "Lewis", "male", "en"),
# Spanish
("ef_dora", "Dora", "female", "es"),
("em_alex", "Alex", "male", "es"),
("em_santa", "Santa", "male", "es"),
# French
("ff_siwis", "Siwis", "female", "fr"),
# Hindi
("hf_alpha", "Alpha", "female", "hi"),
("hf_beta", "Beta", "female", "hi"),
("hm_omega", "Omega", "male", "hi"),
("hm_psi", "Psi", "male", "hi"),
# Italian
("if_sara", "Sara", "female", "it"),
("im_nicola", "Nicola", "male", "it"),
# Japanese
("jf_alpha", "Alpha", "female", "ja"),
("jf_gongitsune", "Gongitsune", "female", "ja"),
("jf_nezumi", "Nezumi", "female", "ja"),
("jf_tebukuro", "Tebukuro", "female", "ja"),
("jm_kumo", "Kumo", "male", "ja"),
# Portuguese
("pf_dora", "Dora", "female", "pt"),
("pm_alex", "Alex", "male", "pt"),
("pm_santa", "Santa", "male", "pt"),
# Chinese
("zf_xiaobei", "Xiaobei", "female", "zh"),
("zf_xiaoni", "Xiaoni", "female", "zh"),
("zf_xiaoxiao", "Xiaoxiao", "female", "zh"),
("zf_xiaoyi", "Xiaoyi", "female", "zh"),
]
# Map our ISO language codes to Kokoro lang_code characters
LANG_CODE_MAP = {
"en": "a", # American English
"es": "e",
"fr": "f",
"hi": "h",
"it": "i",
"pt": "p",
"ja": "j",
"zh": "z",
}
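# Example (illustrative): resolving a request's language and voice. The
# single-letter code feeds KPipeline; the voice id prefix ("af_", "jf_", ...)
# already encodes accent and gender:
#
#     kokoro_lang = LANG_CODE_MAP.get("ja", "a")   # -> "j"
#     voice_id = "jf_alpha"                        # from KOKORO_VOICES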
class KokoroTTSBackend:
"""Kokoro-82M TTS backend — tiny, fast, CPU-friendly."""
def __init__(self):
self._model = None
self._pipelines: dict = {} # lang_code -> KPipeline
self._device: Optional[str] = None
self.model_size = "default"
def _get_device(self) -> str:
"""Select device. Kokoro supports CUDA and CPU. MPS needs fallback env var."""
device = get_torch_device(allow_mps=False)
# Kokoro can use MPS but requires PYTORCH_ENABLE_MPS_FALLBACK=1
# For now, skip MPS to avoid user confusion — CPU is already realtime
return device
@property
def device(self) -> str:
if self._device is None:
self._device = self._get_device()
return self._device
def is_loaded(self) -> bool:
return self._model is not None
def _get_model_path(self, model_size: str) -> str:
return KOKORO_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
"""Check if Kokoro model files are cached locally."""
from .base import is_model_cached
return is_model_cached(
KOKORO_HF_REPO,
required_files=["config.json", "kokoro-v1_0.pth"],
)
async def load_model(self, model_size: str = "default") -> None:
"""Load the Kokoro model."""
if self._model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
"""Synchronous model loading."""
model_name = "kokoro"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
from kokoro import KModel
device = self.device
logger.info(f"Loading Kokoro-82M on {device}...")
self._model = KModel(repo_id=KOKORO_HF_REPO).to(device).eval()
logger.info("Kokoro-82M loaded successfully")
def _get_pipeline(self, lang_code: str):
"""Get or create a KPipeline for the given language code."""
kokoro_lang = LANG_CODE_MAP.get(lang_code, "a")
if kokoro_lang not in self._pipelines:
from kokoro import KPipeline
# Create pipeline with our existing model (no redundant model loading)
self._pipelines[kokoro_lang] = KPipeline(
lang_code=kokoro_lang,
repo_id=KOKORO_HF_REPO,
model=self._model,
)
return self._pipelines[kokoro_lang]
def unload_model(self) -> None:
"""Unload model to free memory."""
if self._model is not None:
del self._model
self._model = None
self._pipelines.clear()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Kokoro unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> tuple[dict, bool]:
"""
Create voice prompt for Kokoro.
Kokoro doesn't do traditional voice cloning from arbitrary audio.
When called for a cloned profile (fallback), uses the default voice.
For preset profiles, the voice_prompt dict is built by the profile
service and bypasses this method entirely.
"""
return {
"voice_type": "preset",
"preset_engine": "kokoro",
"preset_voice_id": KOKORO_DEFAULT_VOICE,
}, False
async def combine_voice_prompts(
self,
audio_paths: list[str],
reference_texts: list[str],
) -> tuple[np.ndarray, str]:
"""Combine voice prompts — uses base implementation for audio concatenation."""
return await _combine_voice_prompts(
audio_paths, reference_texts, sample_rate=KOKORO_SAMPLE_RATE
)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> tuple[np.ndarray, int]:
"""
Generate audio from text using Kokoro.
Args:
text: Text to synthesize
voice_prompt: Dict with kokoro_voice key
language: Language code
seed: Random seed for reproducibility
instruct: Not supported by Kokoro (ignored)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
voice_name = voice_prompt.get("preset_voice_id") or voice_prompt.get("kokoro_voice") or KOKORO_DEFAULT_VOICE
def _generate_sync():
import torch
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
pipeline = self._get_pipeline(language)
# Generate all chunks and concatenate
audio_chunks = []
for result in pipeline(text, voice=voice_name, speed=1.0):
if result.audio is not None:
chunk = result.audio
if isinstance(chunk, torch.Tensor):
chunk = chunk.detach().cpu().numpy()
audio_chunks.append(chunk.squeeze())
if not audio_chunks:
# Return 1 second of silence as fallback
return np.zeros(KOKORO_SAMPLE_RATE, dtype=np.float32), KOKORO_SAMPLE_RATE
audio = np.concatenate(audio_chunks)
return audio.astype(np.float32), KOKORO_SAMPLE_RATE
return await asyncio.to_thread(_generate_sync)

backend/backends/luxtts_backend.py

@@ -0,0 +1,184 @@
"""
LuxTTS backend implementation.
Wraps the LuxTTS (ZipVoice) model for zero-shot voice cloning.
~1GB VRAM, 48kHz output, 150x realtime on CPU.
"""
import asyncio
import logging
from typing import Optional, Tuple
import numpy as np
from . import TTSBackend
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
logger = logging.getLogger(__name__)
# HuggingFace repo for model weight detection
LUXTTS_HF_REPO = "YatharthS/LuxTTS"
class LuxTTSBackend:
"""LuxTTS backend for zero-shot voice cloning."""
def __init__(self):
self.model = None
self.model_size = "default" # LuxTTS has only one model size
self._device = None
def _get_device(self) -> str:
return get_torch_device(allow_mps=True, allow_xpu=True)
def is_loaded(self) -> bool:
return self.model is not None
@property
def device(self) -> str:
if self._device is None:
self._device = self._get_device()
return self._device
def _get_model_path(self, model_size: str) -> str:
return LUXTTS_HF_REPO
def _is_model_cached(self, model_size: str = "default") -> bool:
return is_model_cached(
LUXTTS_HF_REPO,
weight_extensions=(".pt", ".safetensors", ".onnx", ".bin"),
)
async def load_model(self, model_size: str = "default") -> None:
"""Load the LuxTTS model."""
if self.model is not None:
return
await asyncio.to_thread(self._load_model_sync)
def _load_model_sync(self):
model_name = "luxtts"
is_cached = self._is_model_cached()
with model_load_progress(model_name, is_cached):
from zipvoice.luxvoice import LuxTTS
device = self.device
logger.info(f"Loading LuxTTS on {device}...")
if device == "cpu":
import os
threads = os.cpu_count() or 4
self.model = LuxTTS(
model_path=LUXTTS_HF_REPO,
device="cpu",
threads=min(threads, 8),
)
else:
self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device)
logger.info("LuxTTS loaded successfully")
def unload_model(self) -> None:
"""Unload model to free memory."""
if self.model is not None:
device = self.device
del self.model
self.model = None
self._device = None
empty_device_cache(device)
logger.info("LuxTTS unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
LuxTTS uses its own encode_prompt() which runs Whisper ASR internally
to transcribe the reference. The reference_text parameter is not used
by LuxTTS itself, but we include it in the cache key for consistency.
"""
await self.load_model()
# Compute cache key once for both lookup and storage
cache_key = ("luxtts_" + get_cache_key(audio_path, reference_text)) if use_cache else None
if cache_key:
cached = get_cached_voice_prompt(cache_key)
if cached is not None and isinstance(cached, dict):
return cached, True
def _encode_sync():
return self.model.encode_prompt(
prompt_audio=str(audio_path),
duration=5,
rms=0.01,
)
encoded = await asyncio.to_thread(_encode_sync)
if cache_key:
cache_voice_prompt(cache_key, encoded)
return encoded, False
async def combine_voice_prompts(self, audio_paths, reference_texts):
return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using LuxTTS.
Args:
text: Text to synthesize
voice_prompt: Encoded prompt dict from encode_prompt()
language: Language code (LuxTTS is English-focused)
seed: Random seed for reproducibility
instruct: Not supported by LuxTTS (ignored)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model()
def _generate_sync():
if seed is not None:
manual_seed(seed, self.device)
wav = self.model.generate_speech(
text=text,
encode_dict=voice_prompt,
num_steps=4,
guidance_scale=3.0,
t_shift=0.5,
speed=1.0,
return_smooth=False, # 48kHz output
)
# LuxTTS returns a tensor (may be on GPU/MPS), move to CPU first
audio = wav.detach().cpu().numpy().squeeze()
return audio, 48000
return await asyncio.to_thread(_generate_sync)
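# ── Illustrative usage (not part of the app's call path) ─────────────
# A minimal sketch of the async flow above, assuming a reference clip at
# "ref.wav" plus its transcript, and that `soundfile` is installed (this
# module does not otherwise depend on it).
async def _example_luxtts_clone():
    import soundfile as sf

    backend = LuxTTSBackend()
    # encode_prompt() runs Whisper ASR internally; the transcript here
    # only feeds the cache key.
    prompt, _was_cached = await backend.create_voice_prompt(
        "ref.wav", "Transcript of the reference clip."
    )
    audio, sample_rate = await backend.generate(
        "Hello from LuxTTS.", prompt, seed=42
    )
    sf.write("out.wav", audio, sample_rate)  # 48 kHz float array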

View File

@@ -0,0 +1,367 @@
"""
MLX backend implementation for TTS and STT using mlx-audio.
"""
from typing import Optional, List, Tuple
import asyncio
import logging
import numpy as np
from pathlib import Path
logger = logging.getLogger(__name__)
# PATCH: Import and apply offline patch BEFORE any huggingface_hub usage
# This prevents mlx_audio from making network requests when models are cached
from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached
patch_huggingface_hub_offline()
ensure_original_qwen_config_cached()
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
class MLXTTSBackend:
"""MLX-based TTS backend using mlx-audio."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self._current_model_size = None
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
"""
Get the MLX model path.
Args:
model_size: Model size (1.7B or 0.6B)
Returns:
HuggingFace Hub model ID for MLX
"""
mlx_model_map = {
"1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
"0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16",
}
if model_size not in mlx_model_map:
raise ValueError(f"Unknown model size: {model_size}")
hf_model_id = mlx_model_map[model_size]
logger.info("Will download MLX model from HuggingFace Hub: %s", hf_model_id)
return hf_model_id
def _is_model_cached(self, model_size: str) -> bool:
return is_model_cached(
self._get_model_path(model_size),
weight_extensions=(".safetensors", ".bin", ".npz"),
)
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the MLX TTS model.
Args:
model_size: Model size to load (1.7B or 0.6B)
"""
if model_size is None:
model_size = self.model_size
# If already loaded with correct size, return
if self.model is not None and self._current_model_size == model_size:
return
# Unload existing model if different size requested
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
model_path = self._get_model_path(model_size)
model_name = f"qwen-tts-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(model_name, is_cached):
from mlx_audio.tts import load
logger.info("Loading MLX TTS model %s...", model_size)
self.model = load(model_path)
self._current_model_size = model_size
self.model_size = model_size
logger.info("MLX TTS model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
self.model = None
self._current_model_size = None
logger.info("MLX TTS model unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
MLX backend stores voice prompt as a dict with audio path and text.
The actual voice prompt processing happens during generation.
Args:
audio_path: Path to reference audio file
reference_text: Transcript of reference audio
use_cache: Whether to use cached prompt if available
Returns:
Tuple of (voice_prompt_dict, was_cached)
"""
await self.load_model_async(None)
# Check cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cached_prompt = get_cached_voice_prompt(cache_key)
if cached_prompt is not None:
# Return cached prompt (should be dict format)
if isinstance(cached_prompt, dict):
# Validate that the cached audio file still exists
cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
if cached_audio_path and Path(cached_audio_path).exists():
return cached_prompt, True
else:
# Cached file no longer exists, invalidate cache
logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)
# MLX voice prompt format - store audio path and text
# The model will process this during generation
voice_prompt_items = {
"ref_audio": str(audio_path),
"ref_text": reference_text,
}
# Cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cache_voice_prompt(cache_key, voice_prompt_items)
return voice_prompt_items, False
async def combine_voice_prompts(self, audio_paths, reference_texts):
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using voice prompt.
Args:
text: Text to synthesize
voice_prompt: Voice prompt dictionary with ref_audio and ref_text
language: Language code (en or zh) - may not be fully supported by MLX
seed: Random seed for reproducibility
instruct: Natural language instruction (may not be supported by MLX)
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model_async(None)
logger.info("Generating audio for text: %s", text)
def _generate_sync():
"""Run synchronous generation in thread pool."""
# MLX generate() returns a generator yielding GenerationResult objects
audio_chunks = []
sample_rate = 24000
lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")
# Set seed if provided (MLX uses numpy random)
if seed is not None:
import mlx.core as mx
np.random.seed(seed)
mx.random.seed(seed)
# Extract voice prompt info
ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
ref_text = voice_prompt.get("ref_text", "")
# Validate that the audio file exists
if ref_audio and not Path(ref_audio).exists():
logger.warning("Audio file not found: %s", ref_audio)
logger.warning("This may be due to a cached voice prompt referencing a deleted temp file.")
logger.warning("Regenerating without voice prompt.")
ref_audio = None
# Inference runs with the process's default HF_HUB_OFFLINE
# state. Forcing offline here (previously used to avoid lazy
# mlx_audio lookups hanging when the network drops mid-inference,
# issue #462) regressed online users because libraries make
# legitimate metadata calls during generation.
try:
if ref_audio:
# Check if generate accepts ref_audio parameter
import inspect
sig = inspect.signature(self.model.generate)
if "ref_audio" in sig.parameters:
# Generate with voice cloning
for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
else:
# Fallback: generate without voice cloning
for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
else:
# No voice prompt, generate normally
for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
except Exception as e:
# If voice cloning fails, try without it
logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
# Concatenate all chunks
if audio_chunks:
audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
else:
# Fallback: empty audio
audio = np.array([], dtype=np.float32)
return audio, sample_rate
# Run blocking inference in thread pool
audio, sample_rate = await asyncio.to_thread(_generate_sync)
return audio, sample_rate
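# ── Illustrative usage (not part of the app's call path) ─────────────
# For this backend the voice prompt is just a path/transcript dict, so a
# caller can build it by hand instead of going through
# create_voice_prompt(). A sketch with placeholder paths.
async def _example_mlx_tts():
    backend = MLXTTSBackend(model_size="0.6B")
    prompt = {"ref_audio": "ref.wav", "ref_text": "Reference transcript."}
    # generate() lazy-loads the model itself, so no explicit load call
    # is needed here.
    audio, sample_rate = await backend.generate("Hello from MLX.", prompt)
    return audio, sample_rate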
class MLXSTTBackend:
"""MLX-based STT backend using mlx-audio Whisper."""
def __init__(self, model_size: str = "base"):
self.model = None
self.model_size = model_size
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _is_model_cached(self, model_size: str) -> bool:
hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
return is_model_cached(hf_repo, weight_extensions=(".safetensors", ".bin", ".npz"))
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the MLX Whisper model.
Args:
model_size: Model size (tiny, base, small, medium, large)
"""
if model_size is None:
model_size = self.model_size
if self.model is not None and self.model_size == model_size:
return
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
progress_model_name = f"whisper-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(progress_model_name, is_cached):
from mlx_audio.stt import load
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
logger.info("Loading MLX Whisper model %s...", model_size)
self.model = load(model_name)
self.model_size = model_size
logger.info("MLX Whisper model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
self.model = None
logger.info("MLX Whisper model unloaded")
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
model_size: Optional[str] = None,
) -> str:
"""
Transcribe audio to text.
Args:
audio_path: Path to audio file
language: Optional language hint
model_size: Optional model size override
Returns:
Transcribed text
"""
await self.load_model_async(model_size)
def _transcribe_sync():
"""Run synchronous transcription in thread pool."""
# MLX Whisper transcription using generate method
# The generate method accepts audio path directly
decode_options = {}
if language:
decode_options["language"] = language
# Inference runs with the process's default HF_HUB_OFFLINE
# state — see the comment in MLXTTSBackend.generate for the
# regression this revert fixes (issue #462).
result = self.model.generate(str(audio_path), **decode_options)
# Extract text from result
if isinstance(result, str):
return result.strip()
elif isinstance(result, dict):
return result.get("text", "").strip()
elif hasattr(result, "text"):
return result.text.strip()
else:
return str(result).strip()
# Run blocking transcription in thread pool
return await asyncio.to_thread(_transcribe_sync)
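# ── Illustrative usage (not part of the app's call path) ─────────────
# One-shot transcription with a language hint; "clip.wav" is a
# placeholder path.
async def _example_mlx_stt():
    stt = MLXSTTBackend(model_size="base")
    text = await stt.transcribe("clip.wav", language="en")
    stt.unload_model()
    return text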

View File

@@ -0,0 +1,378 @@
"""
PyTorch backend implementation for TTS and STT.
"""
from typing import Optional, List, Tuple
import asyncio
import logging
import torch
import numpy as np
logger = logging.getLogger(__name__)
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import (
is_model_cached,
get_torch_device,
empty_device_cache,
manual_seed,
combine_voice_prompts as _combine_voice_prompts,
model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.audio import load_audio
class PyTorchTTSBackend:
"""PyTorch-based TTS backend using Qwen3-TTS."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self.device = self._get_device()
self._current_model_size = None
def _get_device(self) -> str:
"""Get the best available device."""
return get_torch_device(allow_xpu=True, allow_directml=True)
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
"""
Get the HuggingFace Hub model ID.
Args:
model_size: Model size (1.7B or 0.6B)
Returns:
HuggingFace Hub model ID
"""
hf_model_map = {
"1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
"0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
}
if model_size not in hf_model_map:
raise ValueError(f"Unknown model size: {model_size}")
return hf_model_map[model_size]
def _is_model_cached(self, model_size: str) -> bool:
return is_model_cached(self._get_model_path(model_size))
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the TTS model with automatic downloading from HuggingFace Hub.
Args:
model_size: Model size to load (1.7B or 0.6B)
"""
if model_size is None:
model_size = self.model_size
# If already loaded with correct size, return
if self.model is not None and self._current_model_size == model_size:
return
# Unload existing model if different size requested
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
# Run blocking load in thread pool
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
model_name = f"qwen-tts-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(model_name, is_cached):
from qwen_tts import Qwen3TTSModel
model_path = self._get_model_path(model_size)
logger.info("Loading TTS model %s on %s...", model_size, self.device)
# Route both HF Hub and Transformers through a single cache root.
# On Windows local setups, model assets can otherwise split between
# .hf-cache/hub and .hf-cache/transformers, causing speech_tokenizer
# and preprocessor_config.json to fail to resolve during load.
from huggingface_hub import constants as hf_constants
tts_cache_dir = hf_constants.HF_HUB_CACHE
if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
model_path,
cache_dir=tts_cache_dir,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
)
else:
self.model = Qwen3TTSModel.from_pretrained(
model_path,
cache_dir=tts_cache_dir,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
self._current_model_size = model_size
self.model_size = model_size
logger.info("TTS model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
self.model = None
self._current_model_size = None
empty_device_cache(self.device)
logger.info("TTS model unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
"""
Create voice prompt from reference audio.
Args:
audio_path: Path to reference audio file
reference_text: Transcript of reference audio
use_cache: Whether to use cached prompt if available
Returns:
Tuple of (voice_prompt_dict, was_cached)
"""
await self.load_model_async(None)
# Check cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                if isinstance(cached_prompt, dict):
                    # Current cache format: a dict of tensors (no file
                    # paths), safe to return directly.
                    return cached_prompt, True
                elif isinstance(cached_prompt, torch.Tensor):
                    # Legacy cache format stored a bare tensor; wrap it
                    # so callers always receive a dict.
                    return {"prompt": cached_prompt}, True
def _create_prompt_sync():
"""Run synchronous voice prompt creation in thread pool."""
# Inference runs with the process's default HF_HUB_OFFLINE
# state. Forcing offline here (issue #462) regressed online
# users whose libraries issue legitimate metadata lookups
# during voice-prompt creation.
return self.model.create_voice_clone_prompt(
ref_audio=str(audio_path),
ref_text=reference_text,
x_vector_only_mode=False,
)
# Run blocking operation in thread pool
voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)
# Cache if enabled
if use_cache:
cache_key = get_cache_key(audio_path, reference_text)
cache_voice_prompt(cache_key, voice_prompt_items)
return voice_prompt_items, False
async def combine_voice_prompts(
self,
audio_paths: List[str],
reference_texts: List[str],
) -> Tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""
Generate audio from text using voice prompt.
Args:
text: Text to synthesize
voice_prompt: Voice prompt dictionary from create_voice_prompt
language: Language code (en or zh)
seed: Random seed for reproducibility
instruct: Natural language instruction for speech delivery control
Returns:
Tuple of (audio_array, sample_rate)
"""
# Load model
await self.load_model_async(None)
def _generate_sync():
"""Run synchronous generation in thread pool."""
# Set seed if provided
if seed is not None:
manual_seed(seed, self.device)
# See _create_prompt_sync comment — inference runs with the
# process's default HF_HUB_OFFLINE state (issue #462).
wavs, sample_rate = self.model.generate_voice_clone(
text=text,
voice_clone_prompt=voice_prompt,
language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
instruct=instruct,
)
return wavs[0], sample_rate
# Run blocking inference in thread pool to avoid blocking event loop
audio, sample_rate = await asyncio.to_thread(_generate_sync)
return audio, sample_rate
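# ── Illustrative usage (not part of the app's call path) ─────────────
# A sketch of the two knobs this backend exposes beyond the MLX one: a
# fixed seed for reproducible sampling and a free-form `instruct`
# string. Paths and texts are placeholders.
async def _example_pytorch_tts():
    backend = PyTorchTTSBackend(model_size="1.7B")
    prompt, _ = await backend.create_voice_prompt(
        "ref.wav", "Reference transcript."
    )
    audio, sample_rate = await backend.generate(
        "Hello there.",
        prompt,
        language="en",
        seed=1234,
        instruct="Speak slowly, in a calm tone.",
    )
    return audio, sample_rate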
class PyTorchSTTBackend:
"""PyTorch-based STT backend using Whisper."""
def __init__(self, model_size: str = "base"):
self.model = None
self.processor = None
self.model_size = model_size
self.device = self._get_device()
def _get_device(self) -> str:
"""Get the best available device."""
return get_torch_device(allow_xpu=True, allow_directml=True)
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self.model is not None
def _is_model_cached(self, model_size: str) -> bool:
hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
return is_model_cached(hf_repo)
async def load_model_async(self, model_size: Optional[str] = None):
"""
Lazy load the Whisper model.
Args:
model_size: Model size (tiny, base, small, medium, large)
"""
if model_size is None:
model_size = self.model_size
if self.model is not None and self.model_size == model_size:
return
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility
load_model = load_model_async
def _load_model_sync(self, model_size: str):
"""Synchronous model loading."""
progress_model_name = f"whisper-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(progress_model_name, is_cached):
from transformers import WhisperProcessor, WhisperForConditionalGeneration
model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
logger.info("Loading Whisper model %s on %s...", model_size, self.device)
self.processor = WhisperProcessor.from_pretrained(model_name)
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
self.model.to(self.device)
self.model_size = model_size
logger.info("Whisper model %s loaded successfully", model_size)
def unload_model(self):
"""Unload the model to free memory."""
if self.model is not None:
del self.model
del self.processor
self.model = None
self.processor = None
empty_device_cache(self.device)
logger.info("Whisper model unloaded")
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
model_size: Optional[str] = None,
) -> str:
"""
Transcribe audio to text.
Args:
audio_path: Path to audio file
language: Optional language hint
model_size: Optional model size override
Returns:
Transcribed text
"""
await self.load_model_async(model_size)
def _transcribe_sync():
"""Run synchronous transcription in thread pool."""
# Load audio
audio, _sr = load_audio(audio_path, sample_rate=16000)
# Inference runs with the process's default HF_HUB_OFFLINE
# state — forcing offline here (issue #462) broke online users
# whose `get_decoder_prompt_ids` / tokenizer calls issue
# legitimate metadata lookups.
# Process audio
inputs = self.processor(
audio,
sampling_rate=16000,
return_tensors="pt",
)
inputs = inputs.to(self.device)
# Generate transcription
# If language is provided, force it; otherwise let Whisper auto-detect
generate_kwargs = {}
if language:
forced_decoder_ids = self.processor.get_decoder_prompt_ids(
language=language,
task="transcribe",
)
generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
with torch.no_grad():
predicted_ids = self.model.generate(
inputs["input_features"],
**generate_kwargs,
)
# Decode
transcription = self.processor.batch_decode(
predicted_ids,
skip_special_tokens=True,
)[0]
return transcription.strip()
# Run blocking transcription in thread pool
return await asyncio.to_thread(_transcribe_sync)

View File

@@ -0,0 +1,214 @@
"""
Qwen3-TTS CustomVoice backend implementation.
Wraps the Qwen3-TTS-12Hz CustomVoice model for preset-speaker TTS with
instruction-based style control. Uses the same qwen_tts library as the
Base model (pytorch_backend.py) but loads a different checkpoint and
calls generate_custom_voice() instead of generate_voice_clone().
Key differences from the Base engine:
- Uses preset speakers (9 built-in voices) instead of zero-shot cloning
- Supports instruct parameter for tone/emotion/prosody control
- Two model sizes: 1.7B and 0.6B
Languages supported: zh, en, ja, ko, de, fr, ru, pt, es, it
"""
import asyncio
import logging
from typing import Optional
import numpy as np
import torch
from . import TTSBackend, LANGUAGE_CODE_TO_NAME
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)
logger = logging.getLogger(__name__)
# ── Preset speakers ──────────────────────────────────────────────────
# (speaker_id, display_name, gender, native_language_code, description)
QWEN_CUSTOM_VOICES = [
("Vivian", "Vivian", "female", "zh", "Bright, slightly edgy young female voice"),
("Serena", "Serena", "female", "zh", "Warm, gentle young female voice"),
("Uncle_Fu", "Uncle Fu", "male", "zh", "Seasoned male voice with a low, mellow timbre"),
("Dylan", "Dylan", "male", "zh", "Youthful Beijing male voice with a clear, natural timbre"),
("Eric", "Eric", "male", "zh", "Lively Chengdu male voice with a slightly husky brightness"),
("Ryan", "Ryan", "male", "en", "Dynamic male voice with strong rhythmic drive"),
("Aiden", "Aiden", "male", "en", "Sunny American male voice with a clear midrange"),
("Ono_Anna", "Ono Anna", "female", "ja", "Playful Japanese female voice with a light, nimble timbre"),
("Sohee", "Sohee", "female", "ko", "Warm Korean female voice with rich emotion"),
]
QWEN_CV_DEFAULT_SPEAKER = "Ryan"
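# Illustrative helper (assumed, not referenced elsewhere): choose the
# first preset whose native language matches a code from
# LANGUAGE_CODE_TO_NAME, falling back to the default speaker.
def _preset_for_language(lang_code: str) -> str:
    for speaker_id, _display, _gender, native_lang, _desc in QWEN_CUSTOM_VOICES:
        if native_lang == lang_code:
            return speaker_id
    return QWEN_CV_DEFAULT_SPEAKER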
# HuggingFace repo IDs per model size
QWEN_CV_HF_REPOS = {
"1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
"0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
}
class QwenCustomVoiceBackend:
"""Qwen3-TTS CustomVoice backend — preset speakers with instruct control."""
def __init__(self, model_size: str = "1.7B"):
self.model = None
self.model_size = model_size
self.device = self._get_device()
self._current_model_size: Optional[str] = None
def _get_device(self) -> str:
return get_torch_device(allow_xpu=True, allow_directml=True)
def is_loaded(self) -> bool:
return self.model is not None
def _get_model_path(self, model_size: str) -> str:
if model_size not in QWEN_CV_HF_REPOS:
raise ValueError(f"Unknown model size: {model_size}")
return QWEN_CV_HF_REPOS[model_size]
def _is_model_cached(self, model_size: Optional[str] = None) -> bool:
size = model_size or self.model_size
return is_model_cached(self._get_model_path(size))
async def load_model_async(self, model_size: Optional[str] = None) -> None:
if model_size is None:
model_size = self.model_size
if self.model is not None and self._current_model_size == model_size:
return
if self.model is not None and self._current_model_size != model_size:
self.unload_model()
await asyncio.to_thread(self._load_model_sync, model_size)
# Alias for compatibility with the TTSBackend protocol
load_model = load_model_async
def _load_model_sync(self, model_size: str) -> None:
model_name = f"qwen-custom-voice-{model_size}"
is_cached = self._is_model_cached(model_size)
with model_load_progress(model_name, is_cached):
from qwen_tts import Qwen3TTSModel
model_path = self._get_model_path(model_size)
logger.info("Loading Qwen CustomVoice %s on %s...", model_size, self.device)
if self.device == "cpu":
self.model = Qwen3TTSModel.from_pretrained(
model_path,
torch_dtype=torch.float32,
low_cpu_mem_usage=False,
)
else:
self.model = Qwen3TTSModel.from_pretrained(
model_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
self._current_model_size = model_size
self.model_size = model_size
logger.info("Qwen CustomVoice %s loaded successfully", model_size)
def unload_model(self) -> None:
if self.model is not None:
del self.model
self.model = None
self._current_model_size = None
            # Free accelerator memory on whichever device was used
            # (CUDA, MPS, XPU), matching the other torch backends.
            empty_device_cache(self.device)
logger.info("Qwen CustomVoice unloaded")
async def create_voice_prompt(
self,
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> tuple[dict, bool]:
"""
Create voice prompt for CustomVoice.
CustomVoice doesn't use reference audio — it uses preset speakers.
When called for a cloned profile (fallback), uses the default speaker.
For preset profiles, the voice_prompt dict is built by the profile
service and bypasses this method entirely.
"""
return {
"voice_type": "preset",
"preset_engine": "qwen_custom_voice",
"preset_voice_id": QWEN_CV_DEFAULT_SPEAKER,
}, False
async def combine_voice_prompts(
self,
audio_paths: list[str],
reference_texts: list[str],
) -> tuple[np.ndarray, str]:
return await _combine_voice_prompts(audio_paths, reference_texts)
async def generate(
self,
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> tuple[np.ndarray, int]:
"""
Generate audio using Qwen CustomVoice.
Args:
text: Text to synthesize
voice_prompt: Dict with preset_voice_id (speaker name)
language: Language code (zh, en, ja, ko, etc.)
seed: Random seed for reproducibility
instruct: Natural language instruction for style control
(e.g. "Speak in an angry tone", "Very happy")
Returns:
Tuple of (audio_array, sample_rate)
"""
await self.load_model_async(None)
speaker = voice_prompt.get("preset_voice_id") or QWEN_CV_DEFAULT_SPEAKER
        def _generate_sync():
            if seed is not None:
                # Seed via the shared helper so non-CUDA devices are
                # covered too, matching the Base backend.
                manual_seed(seed, self.device)
lang_name = LANGUAGE_CODE_TO_NAME.get(language, "auto")
kwargs = {
"text": text,
"language": lang_name.capitalize() if lang_name != "auto" else "Auto",
"speaker": speaker,
}
# Only pass instruct if non-empty
if instruct:
kwargs["instruct"] = instruct
# Inference runs with the process's default HF_HUB_OFFLINE
# state. Forcing offline here (issue #462) regressed online
# users whose libraries issue legitimate metadata lookups
# during generation.
wavs, sample_rate = self.model.generate_custom_voice(**kwargs)
return wavs[0], sample_rate
audio, sample_rate = await asyncio.to_thread(_generate_sync)
return audio, sample_rate
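# ── Illustrative usage (not part of the app's call path) ─────────────
# Preset profiles normally arrive pre-built from the profile service;
# this sketch assembles the same dict shape by hand for the "Sohee"
# preset and requests a happier delivery via `instruct`.
async def _example_custom_voice():
    backend = QwenCustomVoiceBackend(model_size="0.6B")
    prompt = {
        "voice_type": "preset",
        "preset_engine": "qwen_custom_voice",
        "preset_voice_id": "Sohee",
    }
    audio, sample_rate = await backend.generate(
        "안녕하세요!", prompt, language="ko", instruct="Very happy"
    )
    return audio, sample_rate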