Initial commit
backend/backends/luxtts_backend.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""
LuxTTS backend implementation.

Wraps the LuxTTS (ZipVoice) model for zero-shot voice cloning.
~1GB VRAM, 48kHz output, 150x realtime on CPU.
"""

import asyncio
import logging
from typing import Optional, Tuple

import numpy as np

from . import TTSBackend
from .base import (
    is_model_cached,
    get_torch_device,
    empty_device_cache,
    manual_seed,
    combine_voice_prompts as _combine_voice_prompts,
    model_load_progress,
)
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt

logger = logging.getLogger(__name__)

# HuggingFace repo for model weight detection
LUXTTS_HF_REPO = "YatharthS/LuxTTS"


class LuxTTSBackend:
    """LuxTTS backend for zero-shot voice cloning."""

    def __init__(self):
        self.model = None
        self.model_size = "default"  # LuxTTS has only one model size
        self._device = None

    def _get_device(self) -> str:
        return get_torch_device(allow_mps=True, allow_xpu=True)

    def is_loaded(self) -> bool:
        return self.model is not None

    @property
    def device(self) -> str:
        if self._device is None:
            self._device = self._get_device()
        return self._device

    def _get_model_path(self, model_size: str) -> str:
        return LUXTTS_HF_REPO

    def _is_model_cached(self, model_size: str = "default") -> bool:
        return is_model_cached(
            LUXTTS_HF_REPO,
            weight_extensions=(".pt", ".safetensors", ".onnx", ".bin"),
        )

    async def load_model(self, model_size: str = "default") -> None:
        """Load the LuxTTS model."""
        if self.model is not None:
            return

        await asyncio.to_thread(self._load_model_sync)

    def _load_model_sync(self):
        model_name = "luxtts"
        is_cached = self._is_model_cached()

        with model_load_progress(model_name, is_cached):
            from zipvoice.luxvoice import LuxTTS

            device = self.device
            logger.info(f"Loading LuxTTS on {device}...")

            if device == "cpu":
                import os

                threads = os.cpu_count() or 4
                self.model = LuxTTS(
                    model_path=LUXTTS_HF_REPO,
                    device="cpu",
                    threads=min(threads, 8),
                )
            else:
                self.model = LuxTTS(model_path=LUXTTS_HF_REPO, device=device)

        logger.info("LuxTTS loaded successfully")

    def unload_model(self) -> None:
        """Unload model to free memory."""
        if self.model is not None:
            device = self.device
            del self.model
            self.model = None
            self._device = None

            empty_device_cache(device)

            logger.info("LuxTTS unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        LuxTTS uses its own encode_prompt() which runs Whisper ASR internally
        to transcribe the reference. The reference_text parameter is not used
        by LuxTTS itself, but we include it in the cache key for consistency.
        """
        await self.load_model()

        # Compute cache key once for both lookup and storage
        cache_key = ("luxtts_" + get_cache_key(audio_path, reference_text)) if use_cache else None

        if cache_key:
            cached = get_cached_voice_prompt(cache_key)
            if cached is not None and isinstance(cached, dict):
                return cached, True

        def _encode_sync():
            return self.model.encode_prompt(
                prompt_audio=str(audio_path),
                duration=5,
                rms=0.01,
            )

        encoded = await asyncio.to_thread(_encode_sync)

        if cache_key:
            cache_voice_prompt(cache_key, encoded)

        return encoded, False

    async def combine_voice_prompts(self, audio_paths, reference_texts):
        return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using LuxTTS.

        Args:
            text: Text to synthesize
            voice_prompt: Encoded prompt dict from encode_prompt()
            language: Language code (LuxTTS is English-focused)
            seed: Random seed for reproducibility
            instruct: Not supported by LuxTTS (ignored)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model()

        def _generate_sync():
            if seed is not None:
                manual_seed(seed, self.device)

            wav = self.model.generate_speech(
                text=text,
                encode_dict=voice_prompt,
                num_steps=4,
                guidance_scale=3.0,
                t_shift=0.5,
                speed=1.0,
                return_smooth=False,  # 48kHz output
            )

            # LuxTTS returns a tensor (may be on GPU/MPS), move to CPU first
            audio = wav.detach().cpu().numpy().squeeze()
            return audio, 48000

        return await asyncio.to_thread(_generate_sync)
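For orientation, a minimal usage sketch of the backend above follows. It is not part of this commit and exercises only the methods defined in the file (create_voice_prompt, generate, unload_model); the ref.wav path, the texts, and the import path (taken from the file location, which may differ from the actual package layout) are placeholders.

# Hypothetical driver script; paths, texts, and the import path are placeholders.
import asyncio

from backend.backends.luxtts_backend import LuxTTSBackend


async def main() -> None:
    backend = LuxTTSBackend()

    # Encode a reference clip into a voice prompt (served from cache on repeat calls).
    prompt, from_cache = await backend.create_voice_prompt(
        audio_path="ref.wav",
        reference_text="Transcript of the reference clip.",
    )
    print(f"voice prompt from cache: {from_cache}")

    # Synthesize speech in the cloned voice; returns (numpy array, sample rate).
    audio, sample_rate = await backend.generate(
        text="Hello from LuxTTS.",
        voice_prompt=prompt,
        seed=42,
    )
    print(f"{len(audio) / sample_rate:.2f}s of audio at {sample_rate} Hz")

    backend.unload_model()


if __name__ == "__main__":
    asyncio.run(main())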