"""
LuxTTS backend implementation.

Wraps the LuxTTS (ZipVoice) model for zero-shot voice cloning.

~1GB VRAM, 48kHz output, 150x realtime on CPU.
"""

import asyncio
import logging
from typing import Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt

logger = logging.getLogger(__name__)

# HuggingFace repo for model weight detection
LUXTTS_HF_REPO = "YatharthS/LuxTTS"

class LuxTTSBackend:
    """LuxTTS backend for zero-shot voice cloning.

    Lazily loads the LuxTTS (ZipVoice) model on first use and exposes
    async wrappers that off-load the blocking model calls to worker
    threads via ``asyncio.to_thread``.
    """

    def __init__(self):
        self.model = None  # loaded lazily by load_model()
        self.model_size = "default"  # LuxTTS has only one model size
        self._device = None  # resolved lazily via the `device` property
        # Serializes load_model() so concurrent callers trigger exactly
        # one load instead of racing into duplicate _load_model_sync runs.
        self._load_lock = asyncio.Lock()

    def _get_device(self) -> str:
        """Pick the torch device, allowing MPS and XPU in addition to CUDA/CPU."""
        return get_torch_device(allow_mps=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        """Return True once the model has been loaded."""
        return self.model is not None

    @property
    def device(self) -> str:
        """Device string, resolved once and cached until unload_model()."""
        if self._device is None:
            self._device = self._get_device()
        return self._device

    def _get_model_path(self, model_size: str) -> str:
        """Return the HF repo id.

        LuxTTS ships a single model, so ``model_size`` is accepted only
        for interface parity with other backends and is ignored.
        """
        return LUXTTS_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        """Check whether the model weights already exist in the local HF cache."""
        return is_model_cached(
            LUXTTS_HF_REPO,
            weight_extensions=(".pt", ".safetensors", ".onnx", ".bin"),
        )

    async def load_model(self, model_size: str = "default") -> None:
        """Load the LuxTTS model. Idempotent and safe under concurrent calls."""
        # Fast path: already loaded, no need to touch the lock.
        if self.model is not None:
            return
        async with self._load_lock:
            # Re-check inside the lock: another coroutine may have finished
            # loading while we were waiting.
            if self.model is None:
                await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        """Blocking model load; runs in a worker thread via load_model()."""
        model_name = "luxtts"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            # Imported lazily so this module can be imported without the
            # zipvoice dependency installed.
            from zipvoice.luxvoice import LuxTTS

            device = self.device
            logger.info(f"Loading LuxTTS on {device}...")

            if device == "cpu":
                import os

                threads = os.cpu_count() or 4
                # Cap at 8 threads — NOTE(review): presumably diminishing
                # returns beyond that for CPU inference; confirm upstream.
                self.model = LuxTTS(
                    model_path=LUXTTS_HF_REPO,
                    device="cpu",
                    threads=min(threads, 8),
                )
            else:
                self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device)

            logger.info("LuxTTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory (device cache is emptied too)."""
        if self.model is not None:
            device = self.device
            del self.model
            self.model = None
            # Reset cached device so it is re-detected on next load.
            self._device = None

            empty_device_cache(device)

            logger.info("LuxTTS unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        LuxTTS uses its own encode_prompt() which runs Whisper ASR internally
        to transcribe the reference. The reference_text parameter is not used
        by LuxTTS itself, but we include it in the cache key for consistency.

        Returns:
            Tuple of (encoded_prompt_dict, cache_hit_flag)
        """
        await self.load_model()

        # Compute cache key once for both lookup and storage
        cache_key = ("luxtts_" + get_cache_key(audio_path, reference_text)) if use_cache else None

        if cache_key:
            cached = get_cached_voice_prompt(cache_key)
            # isinstance() already rejects None, so a single check suffices;
            # non-dict entries (stale/foreign cache data) are ignored.
            if isinstance(cached, dict):
                return cached, True

        def _encode_sync():
            return self.model.encode_prompt(
                prompt_audio=str(audio_path),
                duration=5,
                rms=0.01,
            )

        encoded = await asyncio.to_thread(_encode_sync)

        if cache_key:
            cache_voice_prompt(cache_key, encoded)

        return encoded, False

    async def combine_voice_prompts(self, audio_paths, reference_texts):
        """Merge multiple reference clips into one prompt via the shared helper.

        NOTE(review): helper is invoked at 24 kHz while generate() emits
        48 kHz — presumably the prompt-encoding rate; confirm against base.
        """
        return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using LuxTTS.

        Args:
            text: Text to synthesize
            voice_prompt: Encoded prompt dict from encode_prompt()
            language: Language code (LuxTTS is English-focused)
            seed: Random seed for reproducibility
            instruct: Not supported by LuxTTS (ignored)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        def _generate_sync():
            if seed is not None:
                manual_seed(seed, self.device)

            wav = self.model.generate_speech(
                text=text,
                encode_dict=voice_prompt,
                num_steps=4,
                guidance_scale=3.0,
                t_shift=0.5,
                speed=1.0,
                return_smooth=False,  # 48kHz output
            )

            # LuxTTS returns a tensor (may be on GPU/MPS), move to CPU first
            audio = wav.detach().cpu().numpy().squeeze()
            return audio, 48000

        return await asyncio.to_thread(_generate_sync)
|