"""
|
|
MLX backend implementation for TTS and STT using mlx-audio.
|
|
"""
|
|
|
|
from typing import Optional, List, Tuple
|
|
import asyncio
|
|
import logging
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# PATCH: Import and apply offline patch BEFORE any huggingface_hub usage
|
|
# This prevents mlx_audio from making network requests when models are cached
|
|
from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached
|
|
|
|
patch_huggingface_hub_offline()
|
|
ensure_original_qwen_config_cached()
|
|
|
|
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
|
|
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
|
|
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
|
|
|
|
|
|
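
# Illustrative usage sketch (not executed here). File paths and texts below are
# placeholders, and the calls assume an async context; the signatures match the
# methods defined in this module.
#
#     tts = MLXTTSBackend(model_size="1.7B")
#     prompt, was_cached = await tts.create_voice_prompt("ref.wav", "Transcript of the clip.")
#     audio, sample_rate = await tts.generate("Hello from the MLX backend.", prompt, language="en")
#
#     stt = MLXSTTBackend(model_size="base")
#     text = await stt.transcribe("speech.wav", language="en")
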

class MLXTTSBackend:
    """MLX-based TTS backend using mlx-audio."""

    def __init__(self, model_size: str = "1.7B"):
        self.model = None
        self.model_size = model_size
        self._current_model_size = None

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _get_model_path(self, model_size: str) -> str:
        """
        Get the MLX model path (a HuggingFace Hub model ID).

        Args:
            model_size: Model size (1.7B or 0.6B)

        Returns:
            HuggingFace Hub model ID for MLX
        """
        mlx_model_map = {
            "1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
            "0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16",
        }

        if model_size not in mlx_model_map:
            raise ValueError(f"Unknown model size: {model_size}")

        hf_model_id = mlx_model_map[model_size]
        logger.info("Using MLX model from HuggingFace Hub: %s", hf_model_id)

        return hf_model_id

    def _is_model_cached(self, model_size: str) -> bool:
        return is_model_cached(
            self._get_model_path(model_size),
            weight_extensions=(".safetensors", ".bin", ".npz"),
        )

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX TTS model.

        Args:
            model_size: Model size to load (1.7B or 0.6B)
        """
        if model_size is None:
            model_size = self.model_size

        # If already loaded with correct size, return
        if self.model is not None and self._current_model_size == model_size:
            return

        # Unload existing model if different size requested
        if self.model is not None and self._current_model_size != model_size:
            self.unload_model()

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        model_path = self._get_model_path(model_size)
        model_name = f"qwen-tts-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(model_name, is_cached):
            from mlx_audio.tts import load

            logger.info("Loading MLX TTS model %s...", model_size)

            self.model = load(model_path)

        self._current_model_size = model_size
        self.model_size = model_size
        logger.info("MLX TTS model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None
            logger.info("MLX TTS model unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        The MLX backend stores the voice prompt as a dict with the audio path and text.
        The actual voice prompt processing happens during generation.

        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)

        # Check cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                # Return cached prompt (should be dict format)
                if isinstance(cached_prompt, dict):
                    # Validate that the cached audio file still exists
                    cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
                    if cached_audio_path and Path(cached_audio_path).exists():
                        return cached_prompt, True
                    else:
                        # Cached file no longer exists; fall through and rebuild the prompt
                        logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)

        # MLX voice prompt format - store audio path and text
        # The model will process this during generation
        voice_prompt_items = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }

        # Cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cache_voice_prompt(cache_key, voice_prompt_items)

        return voice_prompt_items, False

    async def combine_voice_prompts(self, audio_paths, reference_texts):
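        """Delegate to the shared combine_voice_prompts helper in .base."""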
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary with ref_audio and ref_text
            language: Language code (en or zh); may not be fully supported by MLX
            seed: Random seed for reproducibility
            instruct: Natural language instruction (may not be supported by MLX)

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        await self.load_model_async(None)

        logger.info("Generating audio for text: %s", text)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # MLX generate() returns a generator yielding GenerationResult objects
            audio_chunks = []
            sample_rate = 24000
            lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")

            # Set seed if provided (MLX uses numpy random)
            if seed is not None:
                import mlx.core as mx

                np.random.seed(seed)
                mx.random.seed(seed)

            # Extract voice prompt info
            ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
            ref_text = voice_prompt.get("ref_text", "")

            # Validate that the audio file exists
            if ref_audio and not Path(ref_audio).exists():
                logger.warning("Audio file not found: %s", ref_audio)
                logger.warning("This may be due to a cached voice prompt referencing a deleted temp file.")
                logger.warning("Regenerating without voice prompt.")
                ref_audio = None

            # Inference runs with the process's default HF_HUB_OFFLINE state.
            # Forcing offline here (previously used to avoid lazy mlx_audio
            # lookups hanging when the network drops mid-inference, issue #462)
            # regressed online users because libraries make legitimate metadata
            # calls during generation.
            try:
                if ref_audio:
                    # Check if generate accepts ref_audio parameter
                    import inspect

                    sig = inspect.signature(self.model.generate)
                    if "ref_audio" in sig.parameters:
                        # Generate with voice cloning
                        for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                    else:
                        # Fallback: generate without voice cloning
                        for result in self.model.generate(text, lang_code=lang):
                            audio_chunks.append(np.array(result.audio))
                            sample_rate = result.sample_rate
                else:
                    # No voice prompt, generate normally
                    for result in self.model.generate(text, lang_code=lang):
                        audio_chunks.append(np.array(result.audio))
                        sample_rate = result.sample_rate
            except Exception as e:
                # If voice cloning fails, try without it
                logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
                audio_chunks = []  # Discard any partial chunks from the failed attempt
                for result in self.model.generate(text, lang_code=lang):
                    audio_chunks.append(np.array(result.audio))
                    sample_rate = result.sample_rate

            # Concatenate all chunks
            if audio_chunks:
                audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
            else:
                # Fallback: empty audio
                audio = np.array([], dtype=np.float32)

            return audio, sample_rate

        # Run blocking inference in thread pool
        audio, sample_rate = await asyncio.to_thread(_generate_sync)

        return audio, sample_rate


class MLXSTTBackend:
    """MLX-based STT backend using mlx-audio Whisper."""

    def __init__(self, model_size: str = "base"):
        self.model = None
        self.model_size = model_size

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _is_model_cached(self, model_size: str) -> bool:
        hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
        return is_model_cached(hf_repo, weight_extensions=(".safetensors", ".bin", ".npz"))

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self.model_size == model_size:
            return

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        progress_model_name = f"whisper-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(progress_model_name, is_cached):
            from mlx_audio.stt import load

            model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
            logger.info("Loading MLX Whisper model %s...", model_size)

            self.model = load(model_name)

        self.model_size = model_size
        logger.info("MLX Whisper model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            logger.info("MLX Whisper model unloaded")

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint
            model_size: Optional model size override

        Returns:
            Transcribed text
        """
        await self.load_model_async(model_size)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # MLX Whisper transcription using the generate method
            # The generate method accepts the audio path directly
            decode_options = {}
            if language:
                decode_options["language"] = language

            # Inference runs with the process's default HF_HUB_OFFLINE state.
            # See the comment in MLXTTSBackend.generate for the regression
            # this revert fixes (issue #462).
            result = self.model.generate(str(audio_path), **decode_options)

            # Extract text from result
            if isinstance(result, str):
                return result.strip()
            elif isinstance(result, dict):
                return result.get("text", "").strip()
            elif hasattr(result, "text"):
                return result.text.strip()
            else:
                return str(result).strip()

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)