Initial commit

backend/backends/__init__.py (new file)
@@ -0,0 +1,621 @@
"""
Backend abstraction layer for TTS and STT.

Provides a unified interface for MLX and PyTorch backends,
and a model config registry that eliminates per-engine dispatch maps.
"""

# Install HF compatibility patches before any backend imports transformers /
# huggingface_hub. The module runs ``patch_transformers_mistral_regex`` at
# import time, which wraps transformers' tokenizer load against the
# unconditional HuggingFace metadata call that otherwise raises on
# HF_HUB_OFFLINE=1 and on network failures.
from ..utils import hf_offline_patch  # noqa: F401

import threading
from dataclasses import dataclass, field
from typing import Protocol, Optional, Tuple, List
from typing_extensions import runtime_checkable
import numpy as np

from ..utils.platform_detect import get_backend_type

LANGUAGE_CODE_TO_NAME = {
    "zh": "chinese",
    "en": "english",
    "ja": "japanese",
    "ko": "korean",
    "de": "german",
    "fr": "french",
    "ru": "russian",
    "pt": "portuguese",
    "es": "spanish",
    "it": "italian",
}

WHISPER_HF_REPOS = {
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
    "medium": "openai/whisper-medium",
    "large": "openai/whisper-large-v3",
    "turbo": "openai/whisper-large-v3-turbo",
}


@dataclass
class ModelConfig:
    """Declarative config for a downloadable model variant."""

    model_name: str  # e.g. "luxtts", "chatterbox-tts"
    display_name: str  # e.g. "LuxTTS (Fast, CPU-friendly)"
    engine: str  # e.g. "luxtts", "chatterbox"
    hf_repo_id: str  # e.g. "YatharthS/LuxTTS"
    model_size: str = "default"
    size_mb: int = 0
    needs_trim: bool = False
    supports_instruct: bool = False
    languages: list[str] = field(default_factory=lambda: ["en"])


@runtime_checkable
class TTSBackend(Protocol):
    """Protocol for TTS backend implementations."""

    # Each backend class should define MODEL_CONFIGS as a class variable:
    # MODEL_CONFIGS: list[ModelConfig]

    async def load_model(self, model_size: str) -> None:
        """Load TTS model."""
        ...

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        ...

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        """
        Combine multiple voice prompts.

        Returns:
            Tuple of (combined_audio_array, combined_text)
        """
        ...

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text.

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        ...

    def unload_model(self) -> None:
        """Unload model to free memory."""
        ...

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        ...

    def _get_model_path(self, model_size: str) -> str:
        """
        Get model path for a given size.

        Returns:
            Model path or HuggingFace Hub ID
        """
        ...


@runtime_checkable
class STTBackend(Protocol):
    """Protocol for STT (Speech-to-Text) backend implementations."""

    async def load_model(self, model_size: str) -> None:
        """Load STT model."""
        ...

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Returns:
            Transcribed text
        """
        ...

    def unload_model(self) -> None:
        """Unload model to free memory."""
        ...

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        ...


# Global backend instances
_tts_backend: Optional[TTSBackend] = None
_tts_backends: dict[str, TTSBackend] = {}
_tts_backends_lock = threading.Lock()
_stt_backend: Optional[STTBackend] = None

# Supported TTS engines — keyed by engine name; each value is a human-readable
# display name (the backend classes themselves are imported lazily inside the
# factory function). The factory dispatches on these keys with an if/elif
# chain; the model configs live on the backend classes.
TTS_ENGINES = {
    "qwen": "Qwen TTS",
    "qwen_custom_voice": "Qwen CustomVoice",
    "luxtts": "LuxTTS",
    "chatterbox": "Chatterbox TTS",
    "chatterbox_turbo": "Chatterbox Turbo",
    "tada": "TADA",
    "kokoro": "Kokoro",
}


def _get_qwen_model_configs() -> list[ModelConfig]:
    """Return Qwen model configs with backend-aware HF repo IDs."""
    backend_type = get_backend_type()
    if backend_type == "mlx":
        repo_1_7b = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
        repo_0_6b = "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16"
    else:
        repo_1_7b = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
        repo_0_6b = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"

    return [
        ModelConfig(
            model_name="qwen-tts-1.7B",
            display_name="Qwen TTS 1.7B",
            engine="qwen",
            hf_repo_id=repo_1_7b,
            model_size="1.7B",
            size_mb=3500,
            supports_instruct=False,  # Base model drops instruct silently
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
        ModelConfig(
            model_name="qwen-tts-0.6B",
            display_name="Qwen TTS 0.6B",
            engine="qwen",
            hf_repo_id=repo_0_6b,
            model_size="0.6B",
            size_mb=1200,
            supports_instruct=False,
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
    ]


def _get_qwen_custom_voice_configs() -> list[ModelConfig]:
    """Return Qwen CustomVoice model configs."""
    return [
        ModelConfig(
            model_name="qwen-custom-voice-1.7B",
            display_name="Qwen CustomVoice 1.7B",
            engine="qwen_custom_voice",
            hf_repo_id="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
            model_size="1.7B",
            size_mb=3500,
            supports_instruct=True,
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
        ModelConfig(
            model_name="qwen-custom-voice-0.6B",
            display_name="Qwen CustomVoice 0.6B",
            engine="qwen_custom_voice",
            hf_repo_id="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
            model_size="0.6B",
            size_mb=1200,
            supports_instruct=True,
            languages=["zh", "en", "ja", "ko", "de", "fr", "ru", "pt", "es", "it"],
        ),
    ]


def _get_non_qwen_tts_configs() -> list[ModelConfig]:
    """Return model configs for non-Qwen TTS engines.

    These are static — no backend-type branching needed.
    """
    return [
        ModelConfig(
            model_name="luxtts",
            display_name="LuxTTS (Fast, CPU-friendly)",
            engine="luxtts",
            hf_repo_id="YatharthS/LuxTTS",
            size_mb=300,
            languages=["en"],
        ),
        ModelConfig(
            model_name="chatterbox-tts",
            display_name="Chatterbox TTS (Multilingual)",
            engine="chatterbox",
            hf_repo_id="ResembleAI/chatterbox",
            size_mb=3200,
            needs_trim=True,
            languages=[
                "zh",
                "en",
                "ja",
                "ko",
                "de",
                "fr",
                "ru",
                "pt",
                "es",
                "it",
                "he",
                "ar",
                "da",
                "el",
                "fi",
                "hi",
                "ms",
                "nl",
                "no",
                "pl",
                "sv",
                "sw",
                "tr",
            ],
        ),
        ModelConfig(
            model_name="chatterbox-turbo",
            display_name="Chatterbox Turbo (English, Tags)",
            engine="chatterbox_turbo",
            hf_repo_id="ResembleAI/chatterbox-turbo",
            size_mb=1500,
            needs_trim=True,
            languages=["en"],
        ),
        ModelConfig(
            model_name="tada-1b",
            display_name="TADA 1B (English)",
            engine="tada",
            hf_repo_id="HumeAI/tada-1b",
            model_size="1B",
            size_mb=4000,
            languages=["en"],
        ),
        ModelConfig(
            model_name="tada-3b-ml",
            display_name="TADA 3B Multilingual",
            engine="tada",
            hf_repo_id="HumeAI/tada-3b-ml",
            model_size="3B",
            size_mb=8000,
            languages=["en", "ar", "zh", "de", "es", "fr", "it", "ja", "pl", "pt"],
        ),
        ModelConfig(
            model_name="kokoro",
            display_name="Kokoro 82M",
            engine="kokoro",
            hf_repo_id="hexgrad/Kokoro-82M",
            size_mb=350,
            languages=["en", "es", "fr", "hi", "it", "pt", "ja", "zh"],
        ),
    ]


def _get_whisper_configs() -> list[ModelConfig]:
    """Return Whisper STT model configs."""
    return [
        ModelConfig(
            model_name="whisper-base",
            display_name="Whisper Base",
            engine="whisper",
            hf_repo_id="openai/whisper-base",
            model_size="base",
        ),
        ModelConfig(
            model_name="whisper-small",
            display_name="Whisper Small",
            engine="whisper",
            hf_repo_id="openai/whisper-small",
            model_size="small",
        ),
        ModelConfig(
            model_name="whisper-medium",
            display_name="Whisper Medium",
            engine="whisper",
            hf_repo_id="openai/whisper-medium",
            model_size="medium",
        ),
        ModelConfig(
            model_name="whisper-large",
            display_name="Whisper Large",
            engine="whisper",
            hf_repo_id="openai/whisper-large-v3",
            model_size="large",
        ),
        ModelConfig(
            model_name="whisper-turbo",
            display_name="Whisper Turbo",
            engine="whisper",
            hf_repo_id="openai/whisper-large-v3-turbo",
            model_size="turbo",
        ),
    ]


def get_all_model_configs() -> list[ModelConfig]:
    """Return the full list of model configs (TTS + STT)."""
    return (
        _get_qwen_model_configs()
        + _get_qwen_custom_voice_configs()
        + _get_non_qwen_tts_configs()
        + _get_whisper_configs()
    )


def get_tts_model_configs() -> list[ModelConfig]:
    """Return only TTS model configs."""
    return (
        _get_qwen_model_configs()
        + _get_qwen_custom_voice_configs()
        + _get_non_qwen_tts_configs()
    )


# Lookup helpers — these replace the if/elif chains in main.py


def get_model_config(model_name: str) -> Optional[ModelConfig]:
    """Look up a model config by model_name."""
    for cfg in get_all_model_configs():
        if cfg.model_name == model_name:
            return cfg
    return None


def engine_needs_trim(engine: str) -> bool:
    """Whether this engine's output should be run through trim_tts_output."""
    for cfg in get_tts_model_configs():
        if cfg.engine == engine:
            return cfg.needs_trim
    return False


def engine_has_model_sizes(engine: str) -> bool:
    """Whether this engine supports multiple model sizes (only Qwen currently)."""
    configs = [c for c in get_tts_model_configs() if c.engine == engine]
    return len(configs) > 1
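
# Illustrative use of the lookup helpers above (comment-only sketch, not
# executed at import time; all names are defined in this module):
#
#     cfg = get_model_config("chatterbox-tts")
#     if cfg is not None:
#         assert cfg.engine == "chatterbox"
#         assert engine_needs_trim(cfg.engine)  # chatterbox output gets trimmed
#         assert not engine_has_model_sizes(cfg.engine)  # single variant only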


async def load_engine_model(engine: str, model_size: str = "default") -> None:
    """Load a model for the given engine, handling engines with multiple model sizes."""
    backend = get_tts_backend_for_engine(engine)
    if engine in ("qwen", "qwen_custom_voice"):
        await backend.load_model_async(model_size)
    elif engine == "tada":
        await backend.load_model(model_size)
    else:
        await backend.load_model()


async def ensure_model_cached_or_raise(engine: str, model_size: str = "default") -> None:
    """Check if a model is cached, raise HTTPException if not. Used by streaming endpoint."""
    from fastapi import HTTPException

    backend = get_tts_backend_for_engine(engine)
    cfg = None
    for c in get_tts_model_configs():
        if c.engine == engine and c.model_size == model_size:
            cfg = c
            break

    if engine in ("qwen", "qwen_custom_voice", "tada"):
        if not backend._is_model_cached(model_size):
            raise HTTPException(
                status_code=400,
                detail=f"Model {model_size} is not downloaded yet. Use /generate to trigger a download.",
            )
    else:
        if not backend._is_model_cached():
            display = cfg.display_name if cfg else engine
            raise HTTPException(
                status_code=400,
                detail=f"{display} model is not downloaded yet. Use /generate to trigger a download.",
            )


def unload_model_by_config(config: ModelConfig) -> bool:
    """Unload a model given its config. Returns True if it was loaded, False otherwise."""
    from . import get_tts_backend_for_engine
    from ..services import tts, transcribe

    if config.engine == "whisper":
        whisper_model = transcribe.get_whisper_model()
        if whisper_model.is_loaded() and whisper_model.model_size == config.model_size:
            transcribe.unload_whisper_model()
            return True
        return False

    if config.engine == "qwen":
        tts_model = tts.get_tts_model()
        loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
        if tts_model.is_loaded() and loaded_size == config.model_size:
            tts.unload_tts_model()
            return True
        return False

    if config.engine == "qwen_custom_voice":
        backend = get_tts_backend_for_engine(config.engine)
        loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
        if backend.is_loaded() and loaded_size == config.model_size:
            backend.unload_model()
            return True
        return False

    # All other TTS engines
    backend = get_tts_backend_for_engine(config.engine)
    if backend.is_loaded():
        backend.unload_model()
        return True
    return False


def check_model_loaded(config: ModelConfig) -> bool:
    """Check if a model is currently loaded."""
    from . import get_tts_backend_for_engine
    from ..services import tts, transcribe

    try:
        if config.engine == "whisper":
            whisper_model = transcribe.get_whisper_model()
            return whisper_model.is_loaded() and getattr(whisper_model, "model_size", None) == config.model_size

        if config.engine == "qwen":
            tts_model = tts.get_tts_model()
            loaded_size = getattr(tts_model, "_current_model_size", None) or getattr(tts_model, "model_size", None)
            return tts_model.is_loaded() and loaded_size == config.model_size

        if config.engine == "qwen_custom_voice":
            backend = get_tts_backend_for_engine(config.engine)
            loaded_size = getattr(backend, "_current_model_size", None) or getattr(backend, "model_size", None)
            return backend.is_loaded() and loaded_size == config.model_size

        backend = get_tts_backend_for_engine(config.engine)
        return backend.is_loaded()
    except Exception:
        return False


def get_model_load_func(config: ModelConfig):
    """Return a callable that loads/downloads the model."""
    from . import get_tts_backend_for_engine
    from ..services import tts, transcribe

    if config.engine == "whisper":
        return lambda: transcribe.get_whisper_model().load_model(config.model_size)

    if config.engine == "qwen":
        return lambda: tts.get_tts_model().load_model(config.model_size)

    if config.engine == "qwen_custom_voice":
        return lambda: get_tts_backend_for_engine(config.engine).load_model(config.model_size)

    return lambda: get_tts_backend_for_engine(config.engine).load_model()


def get_tts_backend() -> TTSBackend:
    """
    Get or create the default (Qwen) TTS backend instance based on platform.

    Returns:
        TTS backend instance (MLX or PyTorch)
    """
    return get_tts_backend_for_engine("qwen")


def get_tts_backend_for_engine(engine: str) -> TTSBackend:
    """
    Get or create a TTS backend for the given engine.

    Args:
        engine: Engine name (e.g. "qwen", "luxtts", "chatterbox", "chatterbox_turbo")

    Returns:
        TTS backend instance
    """
    global _tts_backends

    # Fast path: check without lock
    if engine in _tts_backends:
        return _tts_backends[engine]

    # Slow path: create with lock to avoid duplicate instantiation
    with _tts_backends_lock:
        # Double-check after acquiring lock
        if engine in _tts_backends:
            return _tts_backends[engine]

        if engine == "qwen":
            backend_type = get_backend_type()
            if backend_type == "mlx":
                from .mlx_backend import MLXTTSBackend

                backend = MLXTTSBackend()
            else:
                from .pytorch_backend import PyTorchTTSBackend

                backend = PyTorchTTSBackend()
        elif engine == "luxtts":
            from .luxtts_backend import LuxTTSBackend

            backend = LuxTTSBackend()
        elif engine == "chatterbox":
            from .chatterbox_backend import ChatterboxTTSBackend

            backend = ChatterboxTTSBackend()
        elif engine == "chatterbox_turbo":
            from .chatterbox_turbo_backend import ChatterboxTurboTTSBackend

            backend = ChatterboxTurboTTSBackend()
        elif engine == "tada":
            from .hume_backend import HumeTadaBackend

            backend = HumeTadaBackend()
        elif engine == "kokoro":
            from .kokoro_backend import KokoroTTSBackend

            backend = KokoroTTSBackend()
        elif engine == "qwen_custom_voice":
            from .qwen_custom_voice_backend import QwenCustomVoiceBackend

            backend = QwenCustomVoiceBackend()
        else:
            raise ValueError(f"Unknown TTS engine: {engine}. Supported: {list(TTS_ENGINES.keys())}")

        _tts_backends[engine] = backend
        return backend
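
# Comment-only factory sketch (illustrative; engine names come from
# TTS_ENGINES above, and repeated calls return the same cached instance):
#
#     backend = get_tts_backend_for_engine("kokoro")
#     assert backend is get_tts_backend_for_engine("kokoro")  # cached singleton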


def get_stt_backend() -> STTBackend:
    """
    Get or create STT backend instance based on platform.

    Returns:
        STT backend instance (MLX or PyTorch)
    """
    global _stt_backend

    if _stt_backend is None:
        backend_type = get_backend_type()

        if backend_type == "mlx":
            from .mlx_backend import MLXSTTBackend

            _stt_backend = MLXSTTBackend()
        else:
            from .pytorch_backend import PyTorchSTTBackend

            _stt_backend = PyTorchSTTBackend()

    return _stt_backend


def reset_backends():
    """Reset backend instances (useful for testing)."""
    global _tts_backend, _tts_backends, _stt_backend
    _tts_backend = None
    _tts_backends.clear()
    _stt_backend = None

backend/backends/base.py (new file)
@@ -0,0 +1,327 @@
"""
Shared utilities for TTS/STT backend implementations.

Eliminates duplication of cache checking, device detection,
voice prompt combination, and model loading progress tracking.
"""

import logging
import platform
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import numpy as np

from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager

logger = logging.getLogger(__name__)


def is_model_cached(
    hf_repo: str,
    *,
    weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
    required_files: Optional[list[str]] = None,
) -> bool:
    """
    Check if a HuggingFace model is fully cached locally.

    Args:
        hf_repo: HuggingFace repo ID (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-Base")
        weight_extensions: File extensions that count as model weights.
        required_files: If set, check that these specific filenames exist
            in snapshots instead of checking by extension.

    Returns:
        True if model is fully cached, False if missing or incomplete.
    """
    try:
        from huggingface_hub import constants as hf_constants

        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))

        if not repo_cache.exists():
            return False

        # Incomplete blobs mean a download is still in progress
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            logger.debug(f"Found .incomplete files for {hf_repo}")
            return False

        snapshots_dir = repo_cache / "snapshots"
        if not snapshots_dir.exists():
            return False

        if required_files:
            # Check that every required filename exists somewhere in snapshots
            for fname in required_files:
                if not any(snapshots_dir.rglob(fname)):
                    return False
            return True

        # Check that at least one weight file exists
        for ext in weight_extensions:
            if any(snapshots_dir.rglob(f"*{ext}")):
                return True

        logger.debug(f"No model weights found for {hf_repo}")
        return False

    except Exception as e:
        logger.warning(f"Error checking cache for {hf_repo}: {e}")
        return False
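
# Comment-only example: cache probes are cheap filesystem checks, e.g.
#
#     if not is_model_cached("openai/whisper-base"):
#         ...trigger a download before serving requests...
#     is_model_cached("ResembleAI/chatterbox", required_files=["ve.pt"])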


def get_torch_device(
    *,
    allow_xpu: bool = False,
    allow_directml: bool = False,
    allow_mps: bool = False,
    force_cpu_on_mac: bool = False,
) -> str:
    """
    Detect the best available torch device.

    Args:
        allow_xpu: Check for Intel XPU (IPEX) support.
        allow_directml: Check for DirectML (Windows) support.
        allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU.
        force_cpu_on_mac: Force CPU on macOS regardless of GPU availability.
    """
    if force_cpu_on_mac and platform.system() == "Darwin":
        return "cpu"

    import torch

    if torch.cuda.is_available():
        return "cuda"

    if allow_xpu:
        try:
            import intel_extension_for_pytorch  # noqa: F401

            if hasattr(torch, "xpu") and torch.xpu.is_available():
                return "xpu"
        except ImportError:
            pass

    if allow_directml:
        try:
            import torch_directml

            if torch_directml.device_count() > 0:
                return torch_directml.device(0)
        except ImportError:
            pass

    if allow_mps:
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

    return "cpu"


def check_cuda_compatibility() -> tuple[bool, str | None]:
    """Check if the installed PyTorch supports the current GPU's compute capability.

    Returns:
        (compatible, warning_message) — compatible is True if OK or no CUDA GPU,
        warning_message is a human-readable string if there's a problem.
    """
    import torch

    if not torch.cuda.is_available():
        return True, None

    major, minor = torch.cuda.get_device_capability(0)
    capability = f"{major}.{minor}"
    device_name = torch.cuda.get_device_name(0)
    sm_tag = f"sm_{major}{minor}"

    # torch.cuda._get_arch_list() returns the SM architectures this build
    # was compiled for (e.g. ["sm_50", "sm_60", ..., "sm_90"]).
    try:
        arch_list = torch.cuda._get_arch_list()
        if arch_list:
            # Check for both sm_XX and compute_XX (JIT-compiled) entries
            compute_tag = f"compute_{major}{minor}"
            if sm_tag not in arch_list and compute_tag not in arch_list:
                return False, (
                    f"{device_name} (compute capability {capability} / {sm_tag}) "
                    f"is not supported by this PyTorch build. "
                    f"Supported architectures: {', '.join(arch_list)}. "
                    f"Install PyTorch nightly (cu128) for newer GPU support: "
                    f"pip install torch --index-url https://download.pytorch.org/whl/nightly/cu128"
                )
    except AttributeError:
        pass

    return True, None
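
# Comment-only example: surface the mismatch at startup rather than failing
# deep inside a generate call (logger is this module's logger):
#
#     ok, warning = check_cuda_compatibility()
#     if not ok:
#         logger.warning(warning)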


def empty_device_cache(device: str) -> None:
    """
    Free cached memory on the given device (CUDA or XPU).

    Backends should call this after unloading models so VRAM is returned
    to the OS.
    """
    import torch

    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.empty_cache()


def manual_seed(seed: int, device: str) -> None:
    """
    Set the random seed on both CPU and the active accelerator.

    Covers CUDA and Intel XPU so that generation is reproducible
    regardless of which GPU backend is in use.
    """
    import torch

    torch.manual_seed(seed)
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.manual_seed(seed)


async def combine_voice_prompts(
    audio_paths: List[str],
    reference_texts: List[str],
    *,
    sample_rate: Optional[int] = None,
) -> Tuple[np.ndarray, str]:
    """
    Combine multiple reference audio samples into one.

    Loads each audio file, normalizes, concatenates, and joins texts.

    Args:
        audio_paths: Paths to reference audio files.
        reference_texts: Corresponding transcripts.
        sample_rate: If set, resample audio to this rate during loading.
    """
    combined_audio = []

    for path in audio_paths:
        kwargs = {"sample_rate": sample_rate} if sample_rate else {}
        audio, _sr = load_audio(path, **kwargs)
        audio = normalize_audio(audio)
        combined_audio.append(audio)

    mixed = np.concatenate(combined_audio)
    mixed = normalize_audio(mixed)
    combined_text = " ".join(reference_texts)

    return mixed, combined_text
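
# Comment-only example: two reference clips become one longer prompt whose
# transcript is the space-joined text (file paths here are hypothetical):
#
#     audio, text = await combine_voice_prompts(
#         ["ref_a.wav", "ref_b.wav"],
#         ["Hello there.", "How are you?"],
#         sample_rate=24000,
#     )
#     # text == "Hello there. How are you?"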


@contextmanager
def model_load_progress(
    model_name: str,
    is_cached: bool,
    filter_non_downloads: Optional[bool] = None,
):
    """
    Context manager for model loading with HF download progress tracking.

    Handles the tqdm patching, progress_manager/task_manager lifecycle,
    and error reporting that every backend duplicates.

    Args:
        model_name: Progress tracking key (e.g. "qwen-tts-1.7B", "whisper-base").
        is_cached: Whether the model is already downloaded.
        filter_non_downloads: Whether to filter non-download tqdm bars.
            Defaults to `is_cached`.

    Yields:
        The tracker context (already entered). The caller loads the model
        inside the `with` block. The tqdm patch is torn down on exit.

    Usage:
        with model_load_progress("qwen-tts-1.7B", is_cached) as ctx:
            self.model = SomeModel.from_pretrained(...)
    """
    if filter_non_downloads is None:
        filter_non_downloads = is_cached

    progress_manager = get_progress_manager()
    task_manager = get_task_manager()

    progress_callback = create_hf_progress_callback(model_name, progress_manager)
    tracker = HFProgressTracker(progress_callback, filter_non_downloads=filter_non_downloads)

    tracker_context = tracker.patch_download()
    tracker_context.__enter__()

    if not is_cached:
        task_manager.start_download(model_name)
        progress_manager.update_progress(
            model_name=model_name,
            current=0,
            total=0,
            filename="Connecting to HuggingFace...",
            status="downloading",
        )

    try:
        yield tracker_context
    except Exception as e:
        # Report error to both managers
        progress_manager.mark_error(model_name, str(e))
        task_manager.error_download(model_name, str(e))
        raise
    else:
        # Only mark complete if we were tracking a download
        if not is_cached:
            progress_manager.mark_complete(model_name)
            task_manager.complete_download(model_name)
    finally:
        tracker_context.__exit__(None, None, None)


def patch_chatterbox_f32(model) -> None:
    """
    Patch float64 -> float32 dtype mismatches in upstream chatterbox.

    librosa.load returns float64 numpy arrays. Multiple upstream code paths
    convert these to torch tensors via torch.from_numpy() without casting,
    then matmul against float32 model weights. This patches the two known
    entry points:

    1. S3Tokenizer.log_mel_spectrogram — audio tensor hits _mel_filters (f32)
    2. VoiceEncoder.forward — float64 mel spectrograms hit LSTM weights (f32)
    """
    import types

    # Patch S3Tokenizer
    _tokzr = model.s3gen.tokenizer
    _orig_log_mel = _tokzr.log_mel_spectrogram.__func__

    def _f32_log_mel(self_tokzr, audio, padding=0):
        import torch as _torch

        if _torch.is_tensor(audio):
            audio = audio.float()
        return _orig_log_mel(self_tokzr, audio, padding)

    _tokzr.log_mel_spectrogram = types.MethodType(_f32_log_mel, _tokzr)

    # Patch VoiceEncoder
    _ve = model.ve
    _orig_ve_forward = _ve.forward.__func__

    def _f32_ve_forward(self_ve, mels):
        return _orig_ve_forward(self_ve, mels.float())

    _ve.forward = types.MethodType(_f32_ve_forward, _ve)

backend/backends/chatterbox_backend.py (new file)
@@ -0,0 +1,226 @@
"""
Chatterbox TTS backend implementation.

Wraps ChatterboxMultilingualTTS from chatterbox-tts for zero-shot
voice cloning. Supports 23 languages including Hebrew. Forces CPU
on macOS due to known MPS tensor issues.
"""

import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
    patch_chatterbox_f32,
)

logger = logging.getLogger(__name__)

CHATTERBOX_HF_REPO = "ResembleAI/chatterbox"

# Files that must be present for the multilingual model
_MTL_WEIGHT_FILES = [
    "t3_mtl23ls_v2.safetensors",
    "s3gen.pt",
    "ve.pt",
]


class ChatterboxTTSBackend:
    """Chatterbox Multilingual TTS backend for voice cloning."""

    # Class-level lock for torch.load monkey-patching
    _load_lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self):
        self.model = None
        self.model_size = "default"
        self._device = None
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str = "default") -> str:
        return CHATTERBOX_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        return is_model_cached(CHATTERBOX_HF_REPO, required_files=_MTL_WEIGHT_FILES)

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Chatterbox multilingual model."""
        if self.model is not None:
            return
        async with self._model_load_lock:
            if self.model is not None:
                return
            await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Synchronous model loading."""
        model_name = "chatterbox-tts"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            device = self._get_device()
            self._device = device
            logger.info(f"Loading Chatterbox Multilingual TTS on {device}...")

            import torch
            from chatterbox.mtl_tts import ChatterboxMultilingualTTS

            if device == "cpu":
                _orig_torch_load = torch.load

                def _patched_load(*args, **kwargs):
                    kwargs.setdefault("map_location", "cpu")
                    return _orig_torch_load(*args, **kwargs)

                with ChatterboxTTSBackend._load_lock:
                    torch.load = _patched_load
                    try:
                        model = ChatterboxMultilingualTTS.from_pretrained(device=device)
                    finally:
                        torch.load = _orig_torch_load
            else:
                model = ChatterboxMultilingualTTS.from_pretrained(device=device)

            # Fix sdpa attention for output_attentions support
            t3_tfmr = model.t3.tfmr
            if hasattr(t3_tfmr, "config") and hasattr(t3_tfmr.config, "_attn_implementation"):
                t3_tfmr.config._attn_implementation = "eager"
            for layer in getattr(t3_tfmr, "layers", []):
                if hasattr(layer, "self_attn"):
                    layer.self_attn._attn_implementation = "eager"

            patch_chatterbox_f32(model)
            self.model = model

        logger.info("Chatterbox Multilingual TTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            device = self._device
            del self.model
            self.model = None
            self._device = None
            empty_device_cache(device)
            logger.info("Chatterbox unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Chatterbox processes reference audio at generation time, so the
        prompt just stores the file path. The actual audio is loaded by
        model.generate() via audio_prompt_path.
        """
        voice_prompt = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }
        return voice_prompt, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    # Per-language generation defaults. Lower temp + higher cfg = clearer speech.
    _LANG_DEFAULTS: ClassVar[dict] = {
        "he": {
            "exaggeration": 0.4,
            "cfg_weight": 0.7,
            "temperature": 0.65,
            "repetition_penalty": 2.5,
        },
    }
    _GLOBAL_DEFAULTS: ClassVar[dict] = {
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "temperature": 0.8,
        "repetition_penalty": 2.0,
    }

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio using Chatterbox Multilingual TTS.

        Args:
            text: Text to synthesize
            voice_prompt: Dict with ref_audio path
            language: BCP-47 language code
            seed: Random seed for reproducibility
            instruct: Unused (protocol compatibility)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        ref_audio = voice_prompt.get("ref_audio")
        if ref_audio and not Path(ref_audio).exists():
            logger.warning(f"Reference audio not found: {ref_audio}")
            ref_audio = None

        # Pick language-specific generation defaults, falling back to the
        # global defaults when the language has no dedicated entry
        lang_defaults = self._LANG_DEFAULTS.get(language, self._GLOBAL_DEFAULTS)

        def _generate_sync():
            import torch

            if seed is not None:
                manual_seed(seed, self._device)

            logger.info(f"[Chatterbox] Generating: lang={language}")

            wav = self.model.generate(
                text,
                language_id=language,
                audio_prompt_path=ref_audio,
                exaggeration=lang_defaults["exaggeration"],
                cfg_weight=lang_defaults["cfg_weight"],
                temperature=lang_defaults["temperature"],
                repetition_penalty=lang_defaults["repetition_penalty"],
            )

            # Convert tensor -> numpy
            if isinstance(wav, torch.Tensor):
                audio = wav.squeeze().cpu().numpy().astype(np.float32)
            else:
                audio = np.asarray(wav, dtype=np.float32)

            sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)

            return audio, sample_rate

        return await asyncio.to_thread(_generate_sync)

backend/backends/chatterbox_turbo_backend.py (new file)
@@ -0,0 +1,206 @@
"""
Chatterbox Turbo TTS backend implementation.

Wraps ChatterboxTurboTTS from chatterbox-tts for fast, English-only
voice cloning with paralinguistic tag support ([laugh], [cough], etc.).
Forces CPU on macOS due to known MPS tensor issues.
"""

import asyncio
import logging
import threading
from pathlib import Path
from typing import ClassVar, List, Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
    patch_chatterbox_f32,
)

logger = logging.getLogger(__name__)

CHATTERBOX_TURBO_HF_REPO = "ResembleAI/chatterbox-turbo"

# Files that must be present for the turbo model
_TURBO_WEIGHT_FILES = [
    "t3_turbo_v1.safetensors",
    "s3gen_meanflow.safetensors",
    "ve.safetensors",
]


class ChatterboxTurboTTSBackend:
    """Chatterbox Turbo TTS backend — fast, English-only, with paralinguistic tags."""

    # Class-level lock for torch.load monkey-patching
    _load_lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self):
        self.model = None
        self.model_size = "default"
        self._device = None
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str = "default") -> str:
        return CHATTERBOX_TURBO_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        return is_model_cached(CHATTERBOX_TURBO_HF_REPO, required_files=_TURBO_WEIGHT_FILES)

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Chatterbox Turbo model."""
        if self.model is not None:
            return
        async with self._model_load_lock:
            if self.model is not None:
                return
            await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Synchronous model loading."""
        model_name = "chatterbox-turbo"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            device = self._get_device()
            self._device = device
            logger.info(f"Loading Chatterbox Turbo TTS on {device}...")

            import torch
            from huggingface_hub import snapshot_download
            from chatterbox.tts_turbo import ChatterboxTurboTTS

            local_path = snapshot_download(
                repo_id=CHATTERBOX_TURBO_HF_REPO,
                token=None,
                allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"],
            )

            if device == "cpu":
                _orig_torch_load = torch.load

                def _patched_load(*args, **kwargs):
                    kwargs.setdefault("map_location", "cpu")
                    return _orig_torch_load(*args, **kwargs)

                with ChatterboxTurboTTSBackend._load_lock:
                    torch.load = _patched_load
                    try:
                        model = ChatterboxTurboTTS.from_local(local_path, device)
                    finally:
                        torch.load = _orig_torch_load
            else:
                model = ChatterboxTurboTTS.from_local(local_path, device)

            patch_chatterbox_f32(model)
            self.model = model

        logger.info("Chatterbox Turbo TTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            device = self._device
            del self.model
            self.model = None
            self._device = None
            empty_device_cache(device)
            logger.info("Chatterbox Turbo unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Chatterbox Turbo processes reference audio at generation time, so the
        prompt just stores the file path.
        """
        voice_prompt = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }
        return voice_prompt, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio using Chatterbox Turbo TTS.

        Supports paralinguistic tags in text: [laugh], [cough], [chuckle], etc.

        Args:
            text: Text to synthesize (may include paralinguistic tags)
            voice_prompt: Dict with ref_audio path
            language: Ignored (Turbo is English-only)
            seed: Random seed for reproducibility
            instruct: Unused (protocol compatibility)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        ref_audio = voice_prompt.get("ref_audio")
        if ref_audio and not Path(ref_audio).exists():
            logger.warning(f"Reference audio not found: {ref_audio}")
            ref_audio = None

        def _generate_sync():
            import torch

            if seed is not None:
                manual_seed(seed, self._device)

            logger.info("[Chatterbox Turbo] Generating (English)")

            wav = self.model.generate(
                text,
                audio_prompt_path=ref_audio,
                temperature=0.8,
                top_k=1000,
                top_p=0.95,
                repetition_penalty=1.2,
            )

            # Convert tensor -> numpy
            if isinstance(wav, torch.Tensor):
                audio = wav.squeeze().cpu().numpy().astype(np.float32)
            else:
                audio = np.asarray(wav, dtype=np.float32)

            sample_rate = getattr(self.model, "sr", None) or getattr(self.model, "sample_rate", 24000)

            return audio, sample_rate

        return await asyncio.to_thread(_generate_sync)

backend/backends/hume_backend.py (new file)
@@ -0,0 +1,346 @@
"""
HumeAI TADA TTS backend implementation.

Wraps HumeAI's TADA (Text-Acoustic Dual Alignment) model for
high-quality voice cloning. Two model variants:
- tada-1b: English-only, ~2B params (Llama 3.2 1B base)
- tada-3b-ml: Multilingual, ~4B params (Llama 3.2 3B base)

Both use a shared encoder/codec (HumeAI/tada-codec). The encoder
produces 1:1 aligned token embeddings from reference audio, and the
causal LM generates speech via flow-matching diffusion.

24kHz output, bf16 inference on CUDA, fp32 on CPU.
"""

import asyncio
import logging
import threading
from typing import ClassVar, List, Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt

logger = logging.getLogger(__name__)

# HuggingFace repos
TADA_CODEC_REPO = "HumeAI/tada-codec"
TADA_1B_REPO = "HumeAI/tada-1b"
TADA_3B_ML_REPO = "HumeAI/tada-3b-ml"

TADA_MODEL_REPOS = {
    "1B": TADA_1B_REPO,
    "3B": TADA_3B_ML_REPO,
}

# Key weight files for cache detection
_TADA_MODEL_WEIGHT_FILES = [
    "model.safetensors",
]

_TADA_CODEC_WEIGHT_FILES = [
    "encoder/model.safetensors",
]


class HumeTadaBackend:
    """HumeAI TADA TTS backend for high-quality voice cloning."""

    _load_lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self):
        self.model = None
        self.encoder = None
        self.model_size = "1B"  # default to 1B
        self._device = None
        self._model_load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        # Force CPU on macOS — MPS has issues with flow matching
        # and large vocab lm_head (>65536 output channels)
        return get_torch_device(force_cpu_on_mac=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str = "1B") -> str:
        return TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)

    def _is_model_cached(self, model_size: str = "1B") -> bool:
        repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)
        model_cached = is_model_cached(repo, required_files=_TADA_MODEL_WEIGHT_FILES)
        codec_cached = is_model_cached(TADA_CODEC_REPO, required_files=_TADA_CODEC_WEIGHT_FILES)
        return model_cached and codec_cached

    async def load_model(self, model_size: str = "1B") -> None:
        """Load the TADA model and encoder."""
        if self.model is not None and self.model_size == model_size:
            return
        async with self._model_load_lock:
            if self.model is not None and self.model_size == model_size:
                return
            # Unload existing model if switching sizes
            if self.model is not None:
                self.unload_model()
            self.model_size = model_size
            await asyncio.to_thread(self._load_model_sync, model_size)

    def _load_model_sync(self, model_size: str = "1B"):
        """Synchronous model loading with progress tracking."""
        model_name = f"tada-{model_size.lower()}"
        is_cached = self._is_model_cached(model_size)
        repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO)

        with model_load_progress(model_name, is_cached):
            # Install DAC shim before importing tada — tada's encoder/decoder
            # import dac.nn.layers.Snake1d which requires the descript-audio-codec
            # package. The real package pulls in onnx/tensorboard/matplotlib via
            # descript-audiotools, so we use a lightweight shim instead.
            from ..utils.dac_shim import install_dac_shim

            install_dac_shim()

            import torch
            from huggingface_hub import snapshot_download

            device = self._get_device()
            self._device = device
            logger.info(f"Loading HumeAI TADA {model_size} on {device}...")

            # Download codec (encoder + decoder) if not cached
            logger.info("Downloading TADA codec...")
            snapshot_download(
                repo_id=TADA_CODEC_REPO,
                token=None,
                allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin"],
            )

            # Download model weights if not cached
            logger.info(f"Downloading TADA {model_size} model...")
            snapshot_download(
                repo_id=repo,
                token=None,
                allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin", "*.model"],
            )

            # TADA hardcodes "meta-llama/Llama-3.2-1B" as the tokenizer
            # source in its Aligner and TadaForCausalLM.from_pretrained().
            # That repo is gated (requires Meta license acceptance).
            # Download the tokenizer from an ungated mirror and get its
            # local cache path so we can point TADA at it directly.
            logger.info("Downloading Llama tokenizer (ungated mirror)...")
            tokenizer_path = snapshot_download(
                repo_id="unsloth/Llama-3.2-1B",
                token=None,
                allow_patterns=["tokenizer*", "special_tokens*"],
            )

            # Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings
            if device == "cuda" and torch.cuda.is_bf16_supported():
                model_dtype = torch.bfloat16
            elif device == "xpu":
                # Intel Arc (Alchemist+) supports bf16 natively
                model_dtype = torch.bfloat16
            else:
                model_dtype = torch.float32

            # Patch the Aligner config class to use the local tokenizer
            # path instead of the gated "meta-llama/Llama-3.2-1B" default.
            # This avoids monkey-patching AutoTokenizer.from_pretrained
            # which corrupts the classmethod descriptor for other engines.
            from tada.modules.aligner import AlignerConfig

            AlignerConfig.tokenizer_name = tokenizer_path

            # Load encoder (only needed for voice prompt encoding)
            from tada.modules.encoder import Encoder

            logger.info("Loading TADA encoder...")
            self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device)
            self.encoder.eval()

            # Load the causal LM (includes decoder for wav generation).
            # TadaForCausalLM.from_pretrained() calls
            # getattr(config, "tokenizer_name", "meta-llama/Llama-3.2-1B")
            # which hits the gated repo. Pre-load the config from HF,
            # inject the local tokenizer path, then pass it in.
            from tada.modules.tada import TadaForCausalLM, TadaConfig

            logger.info(f"Loading TADA {model_size} model...")
            config = TadaConfig.from_pretrained(repo)
            config.tokenizer_name = tokenizer_path
            self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device)
            self.model.eval()

        logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}")

    def unload_model(self) -> None:
        """Unload model and encoder to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.encoder is not None:
            del self.encoder
            self.encoder = None

        device = self._device
        self._device = None

        if device:
            empty_device_cache(device)

        logger.info("HumeAI TADA unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio using TADA's encoder.

        TADA's encoder performs forced alignment between audio and text tokens,
        producing an EncoderOutput with 1:1 token-audio alignment. If no
        reference_text is provided, the encoder uses built-in ASR (English only).

        We serialize the EncoderOutput to a dict for caching.
        """
        await self.load_model(self.model_size)

        cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None

        if cache_key:
            cached = get_cached_voice_prompt(cache_key)
            if cached is not None and isinstance(cached, dict):
                return cached, True

        def _encode_sync():
            import torch
            import soundfile as sf

            device = self._device

            # Load audio with soundfile (torchaudio 2.10+ requires torchcodec)
            audio_np, sr = sf.read(str(audio_path), dtype="float32")
            audio = torch.from_numpy(audio_np).float()
            if audio.ndim == 1:
                audio = audio.unsqueeze(0)  # (samples,) -> (1, samples)
            else:
                audio = audio.T  # (samples, channels) -> (channels, samples)
            audio = audio.to(device)

            # Encode with forced alignment
            text_arg = [reference_text] if reference_text else None
            prompt = self.encoder(audio, text=text_arg, sample_rate=sr)

            # Serialize EncoderOutput to a dict for caching: tensors are
            # detached and moved to CPU; lists, scalars, and any other
            # field values pass through unchanged.
            prompt_dict = {}
            for field_name in prompt.__dataclass_fields__:
                val = getattr(prompt, field_name)
                if isinstance(val, torch.Tensor):
                    prompt_dict[field_name] = val.detach().cpu()
                else:
                    prompt_dict[field_name] = val
            return prompt_dict

        encoded = await asyncio.to_thread(_encode_sync)

        if cache_key:
            cache_voice_prompt(cache_key, encoded)

        return encoded, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using HumeAI TADA.

        Args:
            text: Text to synthesize
            voice_prompt: Serialized EncoderOutput dict from create_voice_prompt()
            language: Language code (en, ar, de, es, fr, it, ja, pl, pt, zh)
            seed: Random seed for reproducibility
            instruct: Not supported by TADA (ignored)

        Returns:
            Tuple of (audio_array, sample_rate=24000)
        """
        await self.load_model(self.model_size)

        def _generate_sync():
            import torch
            from tada.modules.encoder import EncoderOutput

            if seed is not None:
                manual_seed(seed, self._device)

            device = self._device

            # Reconstruct EncoderOutput from the cached dict
            restored = {}
            for k, v in voice_prompt.items():
                if isinstance(v, torch.Tensor):
                    # Move to device and match model dtype for float tensors
                    if v.is_floating_point():
                        model_dtype = next(self.model.parameters()).dtype
                        restored[k] = v.to(device=device, dtype=model_dtype)
                    else:
                        restored[k] = v.to(device=device)
                else:
                    restored[k] = v

            prompt = EncoderOutput(**restored)

            # For non-English with the 3B-ML model, we could reload the
            # encoder with the language-specific aligner. However, the
            # generation itself is language-agnostic — only the encoder's
            # aligner changes. Since we encode at create_voice_prompt time,
            # the language is already baked in. For simplicity, we don't
            # reload the encoder here.

            logger.info(f"[TADA] Generating ({language}), text length: {len(text)}")

            output = self.model.generate(
                prompt=prompt,
                text=text,
            )

            # output.audio is a list of tensors (one per batch item)
            if output.audio and output.audio[0] is not None:
                audio_tensor = output.audio[0]
                audio = audio_tensor.detach().cpu().numpy().squeeze().astype(np.float32)
            else:
                logger.warning("[TADA] Generation produced no audio")
                audio = np.zeros(24000, dtype=np.float32)

            return audio, 24000

        return await asyncio.to_thread(_generate_sync)
|
||||
288
backend/backends/kokoro_backend.py
Normal file
@@ -0,0 +1,288 @@
"""
|
||||
Kokoro TTS backend implementation.
|
||||
|
||||
Wraps the Kokoro-82M model for fast, lightweight text-to-speech.
|
||||
82M parameters, CPU realtime, 24kHz output, Apache 2.0 license.
|
||||
|
||||
Kokoro uses pre-built voice style vectors (not traditional zero-shot cloning
|
||||
from arbitrary audio). Voice prompts are stored as deferred references to
|
||||
HF-hosted voice .pt files.
|
||||
|
||||
Languages supported (via misaki G2P):
|
||||
- American English (a), British English (b)
|
||||
- Spanish (e), French (f), Hindi (h), Italian (i), Portuguese (p)
|
||||
- Japanese (j) — requires misaki[ja]
|
||||
- Chinese (z) — requires misaki[zh]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import TTSBackend
|
||||
from .base import (
|
||||
get_torch_device,
|
||||
combine_voice_prompts as _combine_voice_prompts,
|
||||
model_load_progress,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HuggingFace repo for model + voice detection
|
||||
KOKORO_HF_REPO = "hexgrad/Kokoro-82M"
|
||||
KOKORO_SAMPLE_RATE = 24000
|
||||
|
||||
# Default voice if none specified
|
||||
KOKORO_DEFAULT_VOICE = "af_heart"
|
||||
|
||||
# All available Kokoro voices: (voice_id, display_name, gender, lang_code)
|
||||
KOKORO_VOICES = [
|
||||
# American English female
|
||||
("af_alloy", "Alloy", "female", "en"),
|
||||
("af_aoede", "Aoede", "female", "en"),
|
||||
("af_bella", "Bella", "female", "en"),
|
||||
("af_heart", "Heart", "female", "en"),
|
||||
("af_jessica", "Jessica", "female", "en"),
|
||||
("af_kore", "Kore", "female", "en"),
|
||||
("af_nicole", "Nicole", "female", "en"),
|
||||
("af_nova", "Nova", "female", "en"),
|
||||
("af_river", "River", "female", "en"),
|
||||
("af_sarah", "Sarah", "female", "en"),
|
||||
("af_sky", "Sky", "female", "en"),
|
||||
# American English male
|
||||
("am_adam", "Adam", "male", "en"),
|
||||
("am_echo", "Echo", "male", "en"),
|
||||
("am_eric", "Eric", "male", "en"),
|
||||
("am_fenrir", "Fenrir", "male", "en"),
|
||||
("am_liam", "Liam", "male", "en"),
|
||||
("am_michael", "Michael", "male", "en"),
|
||||
("am_onyx", "Onyx", "male", "en"),
|
||||
("am_puck", "Puck", "male", "en"),
|
||||
("am_santa", "Santa", "male", "en"),
|
||||
# British English female
|
||||
("bf_alice", "Alice", "female", "en"),
|
||||
("bf_emma", "Emma", "female", "en"),
|
||||
("bf_isabella", "Isabella", "female", "en"),
|
||||
("bf_lily", "Lily", "female", "en"),
|
||||
# British English male
|
||||
("bm_daniel", "Daniel", "male", "en"),
|
||||
("bm_fable", "Fable", "male", "en"),
|
||||
("bm_george", "George", "male", "en"),
|
||||
("bm_lewis", "Lewis", "male", "en"),
|
||||
# Spanish
|
||||
("ef_dora", "Dora", "female", "es"),
|
||||
("em_alex", "Alex", "male", "es"),
|
||||
("em_santa", "Santa", "male", "es"),
|
||||
# French
|
||||
("ff_siwis", "Siwis", "female", "fr"),
|
||||
# Hindi
|
||||
("hf_alpha", "Alpha", "female", "hi"),
|
||||
("hf_beta", "Beta", "female", "hi"),
|
||||
("hm_omega", "Omega", "male", "hi"),
|
||||
("hm_psi", "Psi", "male", "hi"),
|
||||
# Italian
|
||||
("if_sara", "Sara", "female", "it"),
|
||||
("im_nicola", "Nicola", "male", "it"),
|
||||
# Japanese
|
||||
("jf_alpha", "Alpha", "female", "ja"),
|
||||
("jf_gongitsune", "Gongitsune", "female", "ja"),
|
||||
("jf_nezumi", "Nezumi", "female", "ja"),
|
||||
("jf_tebukuro", "Tebukuro", "female", "ja"),
|
||||
("jm_kumo", "Kumo", "male", "ja"),
|
||||
# Portuguese
|
||||
("pf_dora", "Dora", "female", "pt"),
|
||||
("pm_alex", "Alex", "male", "pt"),
|
||||
("pm_santa", "Santa", "male", "pt"),
|
||||
# Chinese
|
||||
("zf_xiaobei", "Xiaobei", "female", "zh"),
|
||||
("zf_xiaoni", "Xiaoni", "female", "zh"),
|
||||
("zf_xiaoxiao", "Xiaoxiao", "female", "zh"),
|
||||
("zf_xiaoyi", "Xiaoyi", "female", "zh"),
|
||||
]
|
||||
|
||||
# Map our ISO language codes to Kokoro lang_code characters
|
||||
LANG_CODE_MAP = {
|
||||
"en": "a", # American English
|
||||
"es": "e",
|
||||
"fr": "f",
|
||||
"hi": "h",
|
||||
"it": "i",
|
||||
"pt": "p",
|
||||
"ja": "j",
|
||||
"zh": "z",
|
||||
}
|
||||
|
||||
|
||||
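# Shape of a preset voice_prompt consumed by this backend (as produced by
# create_voice_prompt below, or by the profile service for preset profiles):
#     {
#         "voice_type": "preset",
#         "preset_engine": "kokoro",
#         "preset_voice_id": "af_heart",  # any voice_id from KOKORO_VOICES
#     }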
class KokoroTTSBackend:
    """Kokoro-82M TTS backend — tiny, fast, CPU-friendly."""

    def __init__(self):
        self._model = None
        self._pipelines: dict = {}  # lang_code -> KPipeline
        self._device: Optional[str] = None
        self.model_size = "default"

    def _get_device(self) -> str:
        """Select device. Kokoro supports CUDA and CPU. MPS needs fallback env var."""
        device = get_torch_device(allow_mps=False)
        # Kokoro can use MPS but requires PYTORCH_ENABLE_MPS_FALLBACK=1.
        # For now, skip MPS to avoid user confusion — CPU is already realtime.
        return device

    @property
    def device(self) -> str:
        if self._device is None:
            self._device = self._get_device()
        return self._device

    def is_loaded(self) -> bool:
        return self._model is not None

    def _get_model_path(self, model_size: str) -> str:
        return KOKORO_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        """Check if Kokoro model files are cached locally."""
        from .base import is_model_cached

        return is_model_cached(
            KOKORO_HF_REPO,
            required_files=["config.json", "kokoro-v1_0.pth"],
        )

    async def load_model(self, model_size: str = "default") -> None:
        """Load the Kokoro model."""
        if self._model is not None:
            return
        await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Synchronous model loading."""
        model_name = "kokoro"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            from kokoro import KModel

            device = self.device
            logger.info(f"Loading Kokoro-82M on {device}...")

            self._model = KModel(repo_id=KOKORO_HF_REPO).to(device).eval()

            logger.info("Kokoro-82M loaded successfully")

    def _get_pipeline(self, lang_code: str):
        """Get or create a KPipeline for the given language code."""
        kokoro_lang = LANG_CODE_MAP.get(lang_code, "a")

        if kokoro_lang not in self._pipelines:
            from kokoro import KPipeline

            # Create pipeline with our existing model (no redundant model loading)
            self._pipelines[kokoro_lang] = KPipeline(
                lang_code=kokoro_lang,
                repo_id=KOKORO_HF_REPO,
                model=self._model,
            )

        return self._pipelines[kokoro_lang]

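    # Example mapping (illustrative): language="fr" resolves via
    # LANG_CODE_MAP["fr"] == "f", so pipeline(text, voice="ff_siwis") runs the
    # French misaki G2P path; unknown codes fall back to American English ("a").
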
    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self._model is not None:
            del self._model
            self._model = None
            self._pipelines.clear()

            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Kokoro unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> tuple[dict, bool]:
        """
        Create voice prompt for Kokoro.

        Kokoro doesn't do traditional voice cloning from arbitrary audio.
        When called for a cloned profile (fallback), uses the default voice.
        For preset profiles, the voice_prompt dict is built by the profile
        service and bypasses this method entirely.
        """
        return {
            "voice_type": "preset",
            "preset_engine": "kokoro",
            "preset_voice_id": KOKORO_DEFAULT_VOICE,
        }, False

    async def combine_voice_prompts(
        self,
        audio_paths: list[str],
        reference_texts: list[str],
    ) -> tuple[np.ndarray, str]:
        """Combine voice prompts — uses base implementation for audio concatenation."""
        return await _combine_voice_prompts(
            audio_paths, reference_texts, sample_rate=KOKORO_SAMPLE_RATE
        )

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> tuple[np.ndarray, int]:
        """
        Generate audio from text using Kokoro.

        Args:
            text: Text to synthesize
            voice_prompt: Dict with a preset_voice_id (or legacy kokoro_voice) key
            language: Language code
            seed: Random seed for reproducibility
            instruct: Not supported by Kokoro (ignored)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        voice_name = (
            voice_prompt.get("preset_voice_id")
            or voice_prompt.get("kokoro_voice")
            or KOKORO_DEFAULT_VOICE
        )

        def _generate_sync():
            import torch

            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

            pipeline = self._get_pipeline(language)

            # Generate all chunks and concatenate
            audio_chunks = []
            for result in pipeline(text, voice=voice_name, speed=1.0):
                if result.audio is not None:
                    chunk = result.audio
                    if isinstance(chunk, torch.Tensor):
                        chunk = chunk.detach().cpu().numpy()
                    audio_chunks.append(chunk.squeeze())

            if not audio_chunks:
                # Return 1 second of silence as fallback
                return np.zeros(KOKORO_SAMPLE_RATE, dtype=np.float32), KOKORO_SAMPLE_RATE

            audio = np.concatenate(audio_chunks)
            return audio.astype(np.float32), KOKORO_SAMPLE_RATE

        return await asyncio.to_thread(_generate_sync)
184
backend/backends/luxtts_backend.py
Normal file
@@ -0,0 +1,184 @@
"""
|
||||
LuxTTS backend implementation.
|
||||
|
||||
Wraps the LuxTTS (ZipVoice) model for zero-shot voice cloning.
|
||||
~1GB VRAM, 48kHz output, 150x realtime on CPU.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import TTSBackend
|
||||
from .base import (
|
||||
is_model_cached,
|
||||
get_torch_device,
|
||||
empty_device_cache,
|
||||
manual_seed,
|
||||
combine_voice_prompts as _combine_voice_prompts,
|
||||
model_load_progress,
|
||||
)
|
||||
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HuggingFace repo for model weight detection
|
||||
LUXTTS_HF_REPO = "YatharthS/LuxTTS"
|
||||
|
||||
|
||||
class LuxTTSBackend:
|
||||
"""LuxTTS backend for zero-shot voice cloning."""
|
||||
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.model_size = "default" # LuxTTS has only one model size
|
||||
self._device = None
|
||||
|
||||
def _get_device(self) -> str:
|
||||
return get_torch_device(allow_mps=True, allow_xpu=True)
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
return self.model is not None
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
if self._device is None:
|
||||
self._device = self._get_device()
|
||||
return self._device
|
||||
|
||||
def _get_model_path(self, model_size: str) -> str:
|
||||
return LUXTTS_HF_REPO
|
||||
|
||||
def _is_model_cached(self, model_size: str = "default") -> bool:
|
||||
return is_model_cached(
|
||||
LUXTTS_HF_REPO,
|
||||
weight_extensions=(".pt", ".safetensors", ".onnx", ".bin"),
|
||||
)
|
||||
|
||||
async def load_model(self, model_size: str = "default") -> None:
|
||||
"""Load the LuxTTS model."""
|
||||
if self.model is not None:
|
||||
return
|
||||
|
||||
await asyncio.to_thread(self._load_model_sync)
|
||||
|
||||
def _load_model_sync(self):
|
||||
model_name = "luxtts"
|
||||
is_cached = self._is_model_cached()
|
||||
|
||||
with model_load_progress(model_name, is_cached):
|
||||
from zipvoice.luxvoice import LuxTTS
|
||||
|
||||
device = self.device
|
||||
logger.info(f"Loading LuxTTS on {device}...")
|
||||
|
||||
if device == "cpu":
|
||||
import os
|
||||
|
||||
threads = os.cpu_count() or 4
|
||||
self.model = LuxTTS(
|
||||
model_path=LUXTTS_HF_REPO,
|
||||
device="cpu",
|
||||
threads=min(threads, 8),
|
||||
)
|
||||
else:
|
||||
self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device)
|
||||
|
||||
logger.info("LuxTTS loaded successfully")
|
||||
|
||||
def unload_model(self) -> None:
|
||||
"""Unload model to free memory."""
|
||||
if self.model is not None:
|
||||
device = self.device
|
||||
del self.model
|
||||
self.model = None
|
||||
self._device = None
|
||||
|
||||
empty_device_cache(device)
|
||||
|
||||
logger.info("LuxTTS unloaded")
|
||||
|
||||
async def create_voice_prompt(
|
||||
self,
|
||||
audio_path: str,
|
||||
reference_text: str,
|
||||
use_cache: bool = True,
|
||||
) -> Tuple[dict, bool]:
|
||||
"""
|
||||
Create voice prompt from reference audio.
|
||||
|
||||
LuxTTS uses its own encode_prompt() which runs Whisper ASR internally
|
||||
to transcribe the reference. The reference_text parameter is not used
|
||||
by LuxTTS itself, but we include it in the cache key for consistency.
|
||||
"""
|
||||
await self.load_model()
|
||||
|
||||
# Compute cache key once for both lookup and storage
|
||||
cache_key = ("luxtts_" + get_cache_key(audio_path, reference_text)) if use_cache else None
|
||||
|
||||
if cache_key:
|
||||
cached = get_cached_voice_prompt(cache_key)
|
||||
if cached is not None and isinstance(cached, dict):
|
||||
return cached, True
|
||||
|
||||
def _encode_sync():
|
||||
return self.model.encode_prompt(
|
||||
prompt_audio=str(audio_path),
|
||||
duration=5,
|
||||
rms=0.01,
|
||||
)
|
||||
|
||||
encoded = await asyncio.to_thread(_encode_sync)
|
||||
|
||||
if cache_key:
|
||||
cache_voice_prompt(cache_key, encoded)
|
||||
|
||||
return encoded, False
|
||||
|
||||
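    # Illustrative cloning flow using this class's public methods (file path
    # assumed):
    #     backend = LuxTTSBackend()
    #     prompt, was_cached = await backend.create_voice_prompt(
    #         "reference.wav", reference_text="")  # transcribed internally
    #     audio, sr = await backend.generate("Text to speak.", prompt)  # sr == 48000
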
    async def combine_voice_prompts(self, audio_paths, reference_texts):
        return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using LuxTTS.

        Args:
            text: Text to synthesize
            voice_prompt: Encoded prompt dict from encode_prompt()
            language: Language code (LuxTTS is English-focused)
            seed: Random seed for reproducibility
            instruct: Not supported by LuxTTS (ignored)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        def _generate_sync():
            if seed is not None:
                manual_seed(seed, self.device)

            wav = self.model.generate_speech(
                text=text,
                encode_dict=voice_prompt,
                num_steps=4,
                guidance_scale=3.0,
                t_shift=0.5,
                speed=1.0,
                return_smooth=False,  # 48kHz output
            )

            # LuxTTS returns a tensor (may be on GPU/MPS), move to CPU first
            audio = wav.detach().cpu().numpy().squeeze()
            return audio, 48000

        return await asyncio.to_thread(_generate_sync)
367
backend/backends/mlx_backend.py
Normal file
@@ -0,0 +1,367 @@
"""
|
||||
MLX backend implementation for TTS and STT using mlx-audio.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# PATCH: Import and apply offline patch BEFORE any huggingface_hub usage
|
||||
# This prevents mlx_audio from making network requests when models are cached
|
||||
from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached
|
||||
|
||||
patch_huggingface_hub_offline()
|
||||
ensure_original_qwen_config_cached()
|
||||
|
||||
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
|
||||
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
|
||||
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
|
||||
|
||||
|
||||
class MLXTTSBackend:
|
||||
"""MLX-based TTS backend using mlx-audio."""
|
||||
|
||||
def __init__(self, model_size: str = "1.7B"):
|
||||
self.model = None
|
||||
self.model_size = model_size
|
||||
self._current_model_size = None
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self.model is not None
|
||||
|
||||
def _get_model_path(self, model_size: str) -> str:
|
||||
"""
|
||||
Get the MLX model path.
|
||||
|
||||
Args:
|
||||
model_size: Model size (1.7B or 0.6B)
|
||||
|
||||
Returns:
|
||||
HuggingFace Hub model ID for MLX
|
||||
"""
|
||||
mlx_model_map = {
|
||||
"1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
|
||||
"0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16",
|
||||
}
|
||||
|
||||
if model_size not in mlx_model_map:
|
||||
raise ValueError(f"Unknown model size: {model_size}")
|
||||
|
||||
hf_model_id = mlx_model_map[model_size]
|
||||
logger.info("Will download MLX model from HuggingFace Hub: %s", hf_model_id)
|
||||
|
||||
return hf_model_id
|
||||
|
||||
def _is_model_cached(self, model_size: str) -> bool:
|
||||
return is_model_cached(
|
||||
self._get_model_path(model_size),
|
||||
weight_extensions=(".safetensors", ".bin", ".npz"),
|
||||
)
|
||||
|
||||
async def load_model_async(self, model_size: Optional[str] = None):
|
||||
"""
|
||||
Lazy load the MLX TTS model.
|
||||
|
||||
Args:
|
||||
model_size: Model size to load (1.7B or 0.6B)
|
||||
"""
|
||||
if model_size is None:
|
||||
model_size = self.model_size
|
||||
|
||||
# If already loaded with correct size, return
|
||||
if self.model is not None and self._current_model_size == model_size:
|
||||
return
|
||||
|
||||
# Unload existing model if different size requested
|
||||
if self.model is not None and self._current_model_size != model_size:
|
||||
self.unload_model()
|
||||
|
||||
# Run blocking load in thread pool
|
||||
await asyncio.to_thread(self._load_model_sync, model_size)
|
||||
|
||||
# Alias for compatibility
|
||||
load_model = load_model_async
|
||||
|
||||
def _load_model_sync(self, model_size: str):
|
||||
"""Synchronous model loading."""
|
||||
model_path = self._get_model_path(model_size)
|
||||
model_name = f"qwen-tts-{model_size}"
|
||||
is_cached = self._is_model_cached(model_size)
|
||||
|
||||
with model_load_progress(model_name, is_cached):
|
||||
from mlx_audio.tts import load
|
||||
|
||||
logger.info("Loading MLX TTS model %s...", model_size)
|
||||
|
||||
self.model = load(model_path)
|
||||
|
||||
self._current_model_size = model_size
|
||||
self.model_size = model_size
|
||||
logger.info("MLX TTS model %s loaded successfully", model_size)
|
||||
|
||||
def unload_model(self):
|
||||
"""Unload the model to free memory."""
|
||||
if self.model is not None:
|
||||
del self.model
|
||||
self.model = None
|
||||
self._current_model_size = None
|
||||
logger.info("MLX TTS model unloaded")
|
||||
|
||||
async def create_voice_prompt(
|
||||
self,
|
||||
audio_path: str,
|
||||
reference_text: str,
|
||||
use_cache: bool = True,
|
||||
) -> Tuple[dict, bool]:
|
||||
"""
|
||||
Create voice prompt from reference audio.
|
||||
|
||||
MLX backend stores voice prompt as a dict with audio path and text.
|
||||
The actual voice prompt processing happens during generation.
|
||||
|
||||
Args:
|
||||
audio_path: Path to reference audio file
|
||||
reference_text: Transcript of reference audio
|
||||
use_cache: Whether to use cached prompt if available
|
||||
|
||||
Returns:
|
||||
Tuple of (voice_prompt_dict, was_cached)
|
||||
"""
|
||||
await self.load_model_async(None)
|
||||
|
||||
# Check cache if enabled
|
||||
if use_cache:
|
||||
cache_key = get_cache_key(audio_path, reference_text)
|
||||
cached_prompt = get_cached_voice_prompt(cache_key)
|
||||
if cached_prompt is not None:
|
||||
# Return cached prompt (should be dict format)
|
||||
if isinstance(cached_prompt, dict):
|
||||
# Validate that the cached audio file still exists
|
||||
cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
|
||||
if cached_audio_path and Path(cached_audio_path).exists():
|
||||
return cached_prompt, True
|
||||
else:
|
||||
# Cached file no longer exists, invalidate cache
|
||||
logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)
|
||||
|
||||
# MLX voice prompt format - store audio path and text
|
||||
# The model will process this during generation
|
||||
voice_prompt_items = {
|
||||
"ref_audio": str(audio_path),
|
||||
"ref_text": reference_text,
|
||||
}
|
||||
|
||||
# Cache if enabled
|
||||
if use_cache:
|
||||
cache_key = get_cache_key(audio_path, reference_text)
|
||||
cache_voice_prompt(cache_key, voice_prompt_items)
|
||||
|
||||
return voice_prompt_items, False
|
||||
|
||||
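    # Shape of the MLX voice prompt stored by create_voice_prompt above:
    #     {"ref_audio": "/path/to/reference.wav", "ref_text": "transcript"}
    # Because it references a file path rather than embedded tensors, both the
    # cache check above and generate() below re-verify that the path exists.
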
    async def combine_voice_prompts(self, audio_paths, reference_texts):
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary with ref_audio and ref_text
            language: Language code (en or zh) - may not be fully supported by MLX
            seed: Random seed for reproducibility
            instruct: Natural language instruction (may not be supported by MLX)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model_async(None)

        logger.info("Generating audio for text: %s", text)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # MLX generate() returns a generator yielding GenerationResult objects
            audio_chunks = []
            sample_rate = 24000
            lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")

            # Set seed if provided (MLX uses numpy random)
            if seed is not None:
                import mlx.core as mx

                np.random.seed(seed)
                mx.random.seed(seed)

            # Extract voice prompt info
            ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
            ref_text = voice_prompt.get("ref_text", "")

            # Validate that the audio file exists
            if ref_audio and not Path(ref_audio).exists():
                logger.warning("Audio file not found: %s", ref_audio)
                logger.warning("This may be due to a cached voice prompt referencing a deleted temp file.")
                logger.warning("Regenerating without voice prompt.")
                ref_audio = None

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (previously used to avoid lazy
            # mlx_audio lookups hanging when the network drops mid-inference,
            # issue #462) regressed online users because libraries make
            # legitimate metadata calls during generation.
            try:
                if ref_audio:
                    # Check if generate accepts ref_audio parameter
                    import inspect

                    sig = inspect.signature(self.model.generate)
                    if "ref_audio" in sig.parameters:
                        # Generate with voice cloning
                        for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                    else:
                        # Fallback: generate without voice cloning
                        for result in self.model.generate(text, lang_code=lang):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                else:
                    # No voice prompt, generate normally
                    for result in self.model.generate(text, lang_code=lang):
                        audio_chunks.append(np.array(result.audio))
                        sample_rate = result.sample_rate
            except Exception as e:
                # If voice cloning fails, try without it
                logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
                for result in self.model.generate(text, lang_code=lang):
                    audio_chunks.append(np.array(result.audio))
                    sample_rate = result.sample_rate

            # Concatenate all chunks
            if audio_chunks:
                audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
            else:
                # Fallback: empty audio
                audio = np.array([], dtype=np.float32)

            return audio, sample_rate

        # Run blocking inference in thread pool
        audio, sample_rate = await asyncio.to_thread(_generate_sync)

        return audio, sample_rate


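# Note on the signature probe in MLXTTSBackend._generate_sync: mlx_audio model
# classes vary in whether generate() accepts cloning arguments, so the backend
# feature-detects instead of version-checking. The same pattern in isolation
# (a sketch):
#     import inspect
#     can_clone = "ref_audio" in inspect.signature(model.generate).parameters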
class MLXSTTBackend:
    """MLX-based STT backend using mlx-audio Whisper."""

    def __init__(self, model_size: str = "base"):
        self.model = None
        self.model_size = model_size

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _is_model_cached(self, model_size: str) -> bool:
        hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
        return is_model_cached(hf_repo, weight_extensions=(".safetensors", ".bin", ".npz"))

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self.model_size == model_size:
            return

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        progress_model_name = f"whisper-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(progress_model_name, is_cached):
            from mlx_audio.stt import load

            model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
            logger.info("Loading MLX Whisper model %s...", model_size)

            self.model = load(model_name)

            self.model_size = model_size
            logger.info("MLX Whisper model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            logger.info("MLX Whisper model unloaded")

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint
            model_size: Optional model size override

        Returns:
            Transcribed text
        """
        await self.load_model_async(model_size)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # MLX Whisper transcription via the generate method,
            # which accepts an audio path directly.
            decode_options = {}
            if language:
                decode_options["language"] = language

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state — see the comment in MLXTTSBackend.generate for the
            # regression this revert fixes (issue #462).
            result = self.model.generate(str(audio_path), **decode_options)

            # Extract text from result, tolerating the different return
            # types (plain string, dict, or result object with .text)
            if isinstance(result, str):
                return result.strip()
            elif isinstance(result, dict):
                return result.get("text", "").strip()
            elif hasattr(result, "text"):
                return result.text.strip()
            else:
                return str(result).strip()

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
378
backend/backends/pytorch_backend.py
Normal file
@@ -0,0 +1,378 @@
"""
|
||||
PyTorch backend implementation for TTS and STT.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
import asyncio
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
|
||||
from .base import (
|
||||
is_model_cached,
|
||||
get_torch_device,
|
||||
empty_device_cache,
|
||||
manual_seed,
|
||||
combine_voice_prompts as _combine_voice_prompts,
|
||||
model_load_progress,
|
||||
)
|
||||
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
|
||||
from ..utils.audio import load_audio
|
||||
|
||||
|
||||
class PyTorchTTSBackend:
|
||||
"""PyTorch-based TTS backend using Qwen3-TTS."""
|
||||
|
||||
def __init__(self, model_size: str = "1.7B"):
|
||||
self.model = None
|
||||
self.model_size = model_size
|
||||
self.device = self._get_device()
|
||||
self._current_model_size = None
|
||||
|
||||
def _get_device(self) -> str:
|
||||
"""Get the best available device."""
|
||||
return get_torch_device(allow_xpu=True, allow_directml=True)
|
||||
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self.model is not None
|
||||
|
||||
def _get_model_path(self, model_size: str) -> str:
|
||||
"""
|
||||
Get the HuggingFace Hub model ID.
|
||||
|
||||
Args:
|
||||
model_size: Model size (1.7B or 0.6B)
|
||||
|
||||
Returns:
|
||||
HuggingFace Hub model ID
|
||||
"""
|
||||
hf_model_map = {
|
||||
"1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
|
||||
"0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
|
||||
}
|
||||
|
||||
if model_size not in hf_model_map:
|
||||
raise ValueError(f"Unknown model size: {model_size}")
|
||||
|
||||
return hf_model_map[model_size]
|
||||
|
||||
def _is_model_cached(self, model_size: str) -> bool:
|
||||
return is_model_cached(self._get_model_path(model_size))
|
||||
|
||||
async def load_model_async(self, model_size: Optional[str] = None):
|
||||
"""
|
||||
Lazy load the TTS model with automatic downloading from HuggingFace Hub.
|
||||
|
||||
Args:
|
||||
model_size: Model size to load (1.7B or 0.6B)
|
||||
"""
|
||||
if model_size is None:
|
||||
model_size = self.model_size
|
||||
|
||||
# If already loaded with correct size, return
|
||||
if self.model is not None and self._current_model_size == model_size:
|
||||
return
|
||||
|
||||
# Unload existing model if different size requested
|
||||
if self.model is not None and self._current_model_size != model_size:
|
||||
self.unload_model()
|
||||
|
||||
# Run blocking load in thread pool
|
||||
await asyncio.to_thread(self._load_model_sync, model_size)
|
||||
|
||||
# Alias for compatibility
|
||||
load_model = load_model_async
|
||||
|
||||
def _load_model_sync(self, model_size: str):
|
||||
"""Synchronous model loading."""
|
||||
model_name = f"qwen-tts-{model_size}"
|
||||
is_cached = self._is_model_cached(model_size)
|
||||
|
||||
with model_load_progress(model_name, is_cached):
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
model_path = self._get_model_path(model_size)
|
||||
logger.info("Loading TTS model %s on %s...", model_size, self.device)
|
||||
|
||||
# Route both HF Hub and Transformers through a single cache root.
|
||||
# On Windows local setups, model assets can otherwise split between
|
||||
# .hf-cache/hub and .hf-cache/transformers, causing speech_tokenizer
|
||||
# and preprocessor_config.json to fail to resolve during load.
|
||||
from huggingface_hub import constants as hf_constants
|
||||
tts_cache_dir = hf_constants.HF_HUB_CACHE
|
||||
|
||||
if self.device == "cpu":
|
||||
self.model = Qwen3TTSModel.from_pretrained(
|
||||
model_path,
|
||||
cache_dir=tts_cache_dir,
|
||||
torch_dtype=torch.float32,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
else:
|
||||
self.model = Qwen3TTSModel.from_pretrained(
|
||||
model_path,
|
||||
cache_dir=tts_cache_dir,
|
||||
device_map=self.device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
self._current_model_size = model_size
|
||||
self.model_size = model_size
|
||||
logger.info("TTS model %s loaded successfully", model_size)
|
||||
|
||||
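    # Cache-layout illustration (assumed defaults, not an exhaustive spec):
    # with HF_HOME=.hf-cache, huggingface_hub resolves HF_HUB_CACHE to
    # .hf-cache/hub, and passing that as cache_dir= above keeps transformers
    # from materializing a parallel .hf-cache/transformers tree.
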
    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None

            empty_device_cache(self.device)

            logger.info("TTS model unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)

        # Check cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                if isinstance(cached_prompt, dict):
                    # For the PyTorch backend the dict contains tensors,
                    # not file paths, so it can be returned directly.
                    return cached_prompt, True
                elif isinstance(cached_prompt, torch.Tensor):
                    # Legacy cache format - convert to dict. This shouldn't
                    # happen in practice, but handle it anyway.
                    return {"prompt": cached_prompt}, True

        def _create_prompt_sync():
            """Run synchronous voice prompt creation in thread pool."""
            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (issue #462) regressed online
            # users whose libraries issue legitimate metadata lookups
            # during voice-prompt creation.
            return self.model.create_voice_clone_prompt(
                ref_audio=str(audio_path),
                ref_text=reference_text,
                x_vector_only_mode=False,
            )

        # Run blocking operation in thread pool
        voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)

        # Cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cache_voice_prompt(cache_key, voice_prompt_items)

        return voice_prompt_items, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary from create_voice_prompt
            language: Language code (en or zh)
            seed: Random seed for reproducibility
            instruct: Natural language instruction for speech delivery control

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        # Load model
        await self.load_model_async(None)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # Set seed if provided
            if seed is not None:
                manual_seed(seed, self.device)

            # See _create_prompt_sync comment — inference runs with the
            # process's default HF_HUB_OFFLINE state (issue #462).
            wavs, sample_rate = self.model.generate_voice_clone(
                text=text,
                voice_clone_prompt=voice_prompt,
                language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
                instruct=instruct,
            )
            return wavs[0], sample_rate

        # Run blocking inference in thread pool to avoid blocking event loop
        audio, sample_rate = await asyncio.to_thread(_generate_sync)

        return audio, sample_rate


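# Determinism note for PyTorchTTSBackend: manual_seed(seed, self.device) from
# .base is used instead of bare torch.manual_seed so device-side RNG state is
# seeded as well (behavior assumed from its call sites in this package); the
# same seed, prompt, and text should then reproduce the same waveform on one
# machine.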
class PyTorchSTTBackend:
    """PyTorch-based STT backend using Whisper."""

    def __init__(self, model_size: str = "base"):
        self.model = None
        self.processor = None
        self.model_size = model_size
        self.device = self._get_device()

    def _get_device(self) -> str:
        """Get the best available device."""
        return get_torch_device(allow_xpu=True, allow_directml=True)

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _is_model_cached(self, model_size: str) -> bool:
        hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
        return is_model_cached(hf_repo)

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self.model_size == model_size:
            return

        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        progress_model_name = f"whisper-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(progress_model_name, is_cached):
            from transformers import WhisperProcessor, WhisperForConditionalGeneration

            model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
            logger.info("Loading Whisper model %s on %s...", model_size, self.device)

            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)

            self.model.to(self.device)
            self.model_size = model_size
            logger.info("Whisper model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            del self.processor
            self.model = None
            self.processor = None

            empty_device_cache(self.device)

            logger.info("Whisper model unloaded")

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint
            model_size: Optional model size override

        Returns:
            Transcribed text
        """
        await self.load_model_async(model_size)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # Load audio
            audio, _sr = load_audio(audio_path, sample_rate=16000)

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state — forcing offline here (issue #462) broke online users
            # whose `get_decoder_prompt_ids` / tokenizer calls issue
            # legitimate metadata lookups.

            # Process audio
            inputs = self.processor(
                audio,
                sampling_rate=16000,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)

            # Generate transcription.
            # If language is provided, force it; otherwise let Whisper auto-detect.
            generate_kwargs = {}
            if language:
                forced_decoder_ids = self.processor.get_decoder_prompt_ids(
                    language=language,
                    task="transcribe",
                )
                generate_kwargs["forced_decoder_ids"] = forced_decoder_ids
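                # e.g. language="fr" forces the <|fr|><|transcribe|> prompt
                # tokens, skipping Whisper's language auto-detection pass.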

            with torch.no_grad():
                predicted_ids = self.model.generate(
                    inputs["input_features"],
                    **generate_kwargs,
                )

            # Decode
            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True,
            )[0]

            return transcription.strip()

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
214
backend/backends/qwen_custom_voice_backend.py
Normal file
@@ -0,0 +1,214 @@
"""
|
||||
Qwen3-TTS CustomVoice backend implementation.
|
||||
|
||||
Wraps the Qwen3-TTS-12Hz CustomVoice model for preset-speaker TTS with
|
||||
instruction-based style control. Uses the same qwen_tts library as the
|
||||
Base model (pytorch_backend.py) but loads a different checkpoint and
|
||||
calls generate_custom_voice() instead of generate_voice_clone().
|
||||
|
||||
Key differences from the Base engine:
|
||||
- Uses preset speakers (9 built-in voices) instead of zero-shot cloning
|
||||
- Supports instruct parameter for tone/emotion/prosody control
|
||||
- Two model sizes: 1.7B and 0.6B
|
||||
|
||||
Languages supported: zh, en, ja, ko, de, fr, ru, pt, es, it
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import TTSBackend, LANGUAGE_CODE_TO_NAME
|
||||
from .base import (
|
||||
is_model_cached,
|
||||
get_torch_device,
|
||||
combine_voice_prompts as _combine_voice_prompts,
|
||||
model_load_progress,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Preset speakers ──────────────────────────────────────────────────
|
||||
|
||||
# (speaker_id, display_name, gender, native_language_code, description)
|
||||
QWEN_CUSTOM_VOICES = [
|
||||
("Vivian", "Vivian", "female", "zh", "Bright, slightly edgy young female voice"),
|
||||
("Serena", "Serena", "female", "zh", "Warm, gentle young female voice"),
|
||||
("Uncle_Fu", "Uncle Fu", "male", "zh", "Seasoned male voice with a low, mellow timbre"),
|
||||
("Dylan", "Dylan", "male", "zh", "Youthful Beijing male voice with a clear, natural timbre"),
|
||||
("Eric", "Eric", "male", "zh", "Lively Chengdu male voice with a slightly husky brightness"),
|
||||
("Ryan", "Ryan", "male", "en", "Dynamic male voice with strong rhythmic drive"),
|
||||
("Aiden", "Aiden", "male", "en", "Sunny American male voice with a clear midrange"),
|
||||
("Ono_Anna", "Ono Anna", "female", "ja", "Playful Japanese female voice with a light, nimble timbre"),
|
||||
("Sohee", "Sohee", "female", "ko", "Warm Korean female voice with rich emotion"),
|
||||
]
|
||||
|
||||
QWEN_CV_DEFAULT_SPEAKER = "Ryan"
|
||||
|
||||
# HuggingFace repo IDs per model size
|
||||
QWEN_CV_HF_REPOS = {
|
||||
"1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
|
||||
"0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
|
||||
}
|
||||
|
||||
|
||||
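# Illustrative preset + instruct call (a sketch; speaker and instruct values
# chosen from the tables above):
#     backend = QwenCustomVoiceBackend(model_size="0.6B")
#     audio, sr = await backend.generate(
#         "It's a beautiful morning!",
#         {"voice_type": "preset", "preset_engine": "qwen_custom_voice",
#          "preset_voice_id": "Sohee"},
#         language="ko",
#         instruct="Speak in a warm, gentle tone",
#     )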
class QwenCustomVoiceBackend:
    """Qwen3-TTS CustomVoice backend — preset speakers with instruct control."""

    def __init__(self, model_size: str = "1.7B"):
        self.model = None
        self.model_size = model_size
        self.device = self._get_device()
        self._current_model_size: Optional[str] = None

    def _get_device(self) -> str:
        return get_torch_device(allow_xpu=True, allow_directml=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    def _get_model_path(self, model_size: str) -> str:
        if model_size not in QWEN_CV_HF_REPOS:
            raise ValueError(f"Unknown model size: {model_size}")
        return QWEN_CV_HF_REPOS[model_size]

    def _is_model_cached(self, model_size: Optional[str] = None) -> bool:
        size = model_size or self.model_size
        return is_model_cached(self._get_model_path(size))

    async def load_model_async(self, model_size: Optional[str] = None) -> None:
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self._current_model_size == model_size:
            return

        if self.model is not None and self._current_model_size != model_size:
            self.unload_model()

        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility with the TTSBackend protocol
    load_model = load_model_async

    def _load_model_sync(self, model_size: str) -> None:
        model_name = f"qwen-custom-voice-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(model_name, is_cached):
            from qwen_tts import Qwen3TTSModel

            model_path = self._get_model_path(model_size)
            logger.info("Loading Qwen CustomVoice %s on %s...", model_size, self.device)

            if self.device == "cpu":
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=False,
                )
            else:
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    device_map=self.device,
                    torch_dtype=torch.bfloat16,
                )

            self._current_model_size = model_size
            self.model_size = model_size
            logger.info("Qwen CustomVoice %s loaded successfully", model_size)

    def unload_model(self) -> None:
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Qwen CustomVoice unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> tuple[dict, bool]:
        """
        Create voice prompt for CustomVoice.

        CustomVoice doesn't use reference audio — it uses preset speakers.
        When called for a cloned profile (fallback), uses the default speaker.
        For preset profiles, the voice_prompt dict is built by the profile
        service and bypasses this method entirely.
        """
        return {
            "voice_type": "preset",
            "preset_engine": "qwen_custom_voice",
            "preset_voice_id": QWEN_CV_DEFAULT_SPEAKER,
        }, False

    async def combine_voice_prompts(
        self,
        audio_paths: list[str],
        reference_texts: list[str],
    ) -> tuple[np.ndarray, str]:
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> tuple[np.ndarray, int]:
        """
        Generate audio using Qwen CustomVoice.

        Args:
            text: Text to synthesize
            voice_prompt: Dict with preset_voice_id (speaker name)
            language: Language code (zh, en, ja, ko, etc.)
            seed: Random seed for reproducibility
            instruct: Natural language instruction for style control
                (e.g. "Speak in an angry tone", "Very happy")

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model_async(None)

        speaker = voice_prompt.get("preset_voice_id") or QWEN_CV_DEFAULT_SPEAKER

        def _generate_sync():
            if seed is not None:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

            lang_name = LANGUAGE_CODE_TO_NAME.get(language, "auto")

            kwargs = {
                "text": text,
                "language": lang_name.capitalize() if lang_name != "auto" else "Auto",
                "speaker": speaker,
            }

            # Only pass instruct if non-empty
            if instruct:
                kwargs["instruct"] = instruct

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (issue #462) regressed online
            # users whose libraries issue legitimate metadata lookups
            # during generation.
            wavs, sample_rate = self.model.generate_custom_voice(**kwargs)
            return wavs[0], sample_rate

        audio, sample_rate = await asyncio.to_thread(_generate_sync)
        return audio, sample_rate