Initial commit
This commit is contained in:
367
backend/backends/mlx_backend.py
Normal file
367
backend/backends/mlx_backend.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
MLX backend implementation for TTS and STT using mlx-audio.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||||
|
||||
# PATCH: Import and apply the offline patch BEFORE any huggingface_hub usage.
|
||||
# This prevents mlx_audio from making network requests when models are cached.
|
||||
from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached
|
||||
|
||||
# Both calls below run as import-time side effects of this module.
patch_huggingface_hub_offline()
|
||||
# NOTE(review): presumably makes sure the original Qwen config is present in
# the local cache -- confirm against utils.hf_offline_patch.
ensure_original_qwen_config_cached()
|
||||
|
||||
from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
|
||||
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
|
||||
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
|
||||
|
||||
|
||||
class MLXTTSBackend:
    """MLX-based TTS backend using mlx-audio.

    The model is loaded lazily: ``create_voice_prompt`` and ``generate`` both
    call ``load_model_async`` first. Blocking work (model load, inference) is
    pushed to a worker thread with ``asyncio.to_thread``.
    """

    # Sample rate returned when the model yields no result to read it from.
    _DEFAULT_SAMPLE_RATE = 24000

    # Supported model sizes mapped to their HuggingFace Hub model IDs.
    _MLX_MODEL_MAP = {
        "1.7B": "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16",
        "0.6B": "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16",
    }

    def __init__(self, model_size: str = "1.7B"):
        # Loaded model object, or None while nothing is loaded.
        self.model = None
        # Size used when a caller does not request one explicitly.
        self.model_size = model_size
        # Size of the model currently held in ``self.model`` (None if unloaded).
        self._current_model_size = None

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _get_model_path(self, model_size: str) -> str:
        """
        Get the MLX model path.

        Args:
            model_size: Model size (1.7B or 0.6B)

        Returns:
            HuggingFace Hub model ID for MLX

        Raises:
            ValueError: If ``model_size`` is not one of the supported sizes.
        """
        if model_size not in self._MLX_MODEL_MAP:
            raise ValueError(f"Unknown model size: {model_size}")

        hf_model_id = self._MLX_MODEL_MAP[model_size]
        logger.info("Will download MLX model from HuggingFace Hub: %s", hf_model_id)

        return hf_model_id

    def _is_model_cached(self, model_size: str) -> bool:
        """Ask ``is_model_cached`` whether weights for ``model_size`` are cached."""
        return is_model_cached(
            self._get_model_path(model_size),
            weight_extensions=(".safetensors", ".bin", ".npz"),
        )

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX TTS model.

        Does nothing when a model of the requested size is already loaded.
        When a different size is loaded, it is unloaded before the new one is
        loaded.

        Args:
            model_size: Model size to load (1.7B or 0.6B). Defaults to
                ``self.model_size`` when None.
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None:
            # If already loaded with correct size, return
            if self._current_model_size == model_size:
                return
            # Unload existing model if different size requested
            self.unload_model()

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        model_path = self._get_model_path(model_size)
        model_name = f"qwen-tts-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(model_name, is_cached):
            # Imported here so mlx_audio is only required once a load is requested.
            from mlx_audio.tts import load

            logger.info("Loading MLX TTS model %s...", model_size)

            self.model = load(model_path)

        self._current_model_size = model_size
        self.model_size = model_size
        logger.info("MLX TTS model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is None:
            return

        # Dropping the reference is sufficient; no separate ``del`` is needed.
        self.model = None
        self._current_model_size = None
        logger.info("MLX TTS model unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        MLX backend stores voice prompt as a dict with audio path and text.
        The actual voice prompt processing happens during generation.

        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)

        cache_key = None
        if use_cache:
            # Computed once and reused for both the lookup and the store below.
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)

            # Only dict-shaped cache entries are usable by this backend.
            if isinstance(cached_prompt, dict):
                cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
                # Validate that the cached audio file still exists
                if cached_audio_path and Path(cached_audio_path).exists():
                    return cached_prompt, True

                # Cached file no longer exists: fall through and rebuild the prompt,
                # which overwrites the stale cache entry below.
                logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)

        # MLX voice prompt format - store audio path and text.
        # The model will process this during generation.
        voice_prompt_items = {
            "ref_audio": str(audio_path),
            "ref_text": reference_text,
        }

        if use_cache:
            cache_voice_prompt(cache_key, voice_prompt_items)

        return voice_prompt_items, False

    async def combine_voice_prompts(self, audio_paths, reference_texts):
        """Delegate to the shared ``combine_voice_prompts`` helper from ``.base``."""
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        If generation raises, one more attempt is made without the voice
        prompt. Chunks produced by the failed attempt are discarded so the
        returned audio never mixes output from both attempts.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary with ref_audio and ref_text
            language: Language code, mapped through LANGUAGE_CODE_TO_NAME;
                unknown codes fall back to "auto"
            seed: Random seed for reproducibility
            instruct: Natural language instruction. Accepted for interface
                compatibility; this backend does not pass it to the model.

        Returns:
            Tuple of (audio_array, sample_rate). The array is float32 and
            empty when the model yields nothing.
        """
        await self.load_model_async(None)

        logger.info("Generating audio for text: %s", text)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")

            # Seed both numpy and MLX random state when a seed is provided.
            if seed is not None:
                import mlx.core as mx

                np.random.seed(seed)
                mx.random.seed(seed)

            # Extract voice prompt info
            ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
            ref_text = voice_prompt.get("ref_text", "")

            # Validate that the audio file exists
            if ref_audio and not Path(ref_audio).exists():
                logger.warning("Audio file not found: %s", ref_audio)
                logger.warning("This may be due to a cached voice prompt referencing a deleted temp file.")
                logger.warning("Regenerating without voice prompt.")
                ref_audio = None

            def _collect(**extra_kwargs):
                """Drain one model.generate() call into (chunks, sample_rate)."""
                # MLX generate() returns a generator yielding GenerationResult objects
                chunks = []
                rate = self._DEFAULT_SAMPLE_RATE
                for result in self.model.generate(text, lang_code=lang, **extra_kwargs):
                    chunks.append(np.array(result.audio))
                    rate = result.sample_rate
                return chunks, rate

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (previously used to avoid lazy
            # mlx_audio lookups hanging when the network drops mid-inference,
            # issue #462) regressed online users because libraries make
            # legitimate metadata calls during generation.
            try:
                clone_kwargs = {}
                if ref_audio:
                    import inspect

                    # Only pass the reference when generate() accepts it;
                    # otherwise generate without voice cloning.
                    if "ref_audio" in inspect.signature(self.model.generate).parameters:
                        clone_kwargs = {"ref_audio": ref_audio, "ref_text": ref_text}
                audio_chunks, sample_rate = _collect(**clone_kwargs)
            except Exception as e:
                # Best-effort retry without the voice prompt. _collect() starts
                # from a fresh list, so partial output of the failed attempt
                # is not prepended to the fallback audio.
                logger.warning("Voice cloning failed, generating without voice prompt: %s", e)
                audio_chunks, sample_rate = _collect()

            # Concatenate all chunks
            if audio_chunks:
                audio = np.concatenate([np.asarray(chunk, dtype=np.float32) for chunk in audio_chunks])
            else:
                # Fallback: empty audio
                audio = np.array([], dtype=np.float32)

            return audio, sample_rate

        # Run blocking inference in thread pool
        return await asyncio.to_thread(_generate_sync)
|
||||
|
||||
|
||||
class MLXSTTBackend:
    """MLX-based STT backend using mlx-audio Whisper.

    The model is loaded lazily on the first ``transcribe`` call. Blocking
    work (model load, transcription) runs in a worker thread via
    ``asyncio.to_thread``.
    """

    def __init__(self, model_size: str = "base"):
        # Loaded model object, or None while nothing is loaded.
        self.model = None
        # Size used when a caller does not request one; updated after each load.
        self.model_size = model_size

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    @staticmethod
    def _get_hf_repo(model_size: str) -> str:
        """Return the repo ID for ``model_size``.

        Looks the size up in WHISPER_HF_REPOS and falls back to
        ``openai/whisper-<model_size>`` when it is not listed there.
        """
        return WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")

    @staticmethod
    def _extract_text(result) -> str:
        """Pull the transcript out of a model result and strip whitespace.

        Handles a plain string, a dict with a "text" key, an object with a
        ``text`` attribute, and otherwise falls back to ``str(result)``.
        A missing or None text yields an empty string instead of raising.
        """
        if isinstance(result, str):
            text = result
        elif isinstance(result, dict):
            text = result.get("text", "")
        elif hasattr(result, "text"):
            text = result.text
        else:
            text = str(result)

        if text is None:
            return ""
        return str(text).strip()

    def _is_model_cached(self, model_size: str) -> bool:
        """Ask ``is_model_cached`` whether weights for ``model_size`` are cached."""
        return is_model_cached(
            self._get_hf_repo(model_size),
            weight_extensions=(".safetensors", ".bin", ".npz"),
        )

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the MLX Whisper model.

        Does nothing when a model of the requested size is already loaded.
        When a different size is loaded, it is unloaded first (same policy as
        MLXTTSBackend) so two models are not held at once during the load.

        Args:
            model_size: Model size (tiny, base, small, medium, large).
                Defaults to ``self.model_size`` when None.
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None:
            if self.model_size == model_size:
                return
            # Release the previous model before loading the replacement.
            self.unload_model()

        # Run blocking load in thread pool
        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        progress_model_name = f"whisper-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(progress_model_name, is_cached):
            # Imported here so mlx_audio is only required once a load is requested.
            from mlx_audio.stt import load

            model_name = self._get_hf_repo(model_size)
            logger.info("Loading MLX Whisper model %s...", model_size)

            self.model = load(model_name)

        self.model_size = model_size
        logger.info("MLX Whisper model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is None:
            return

        # Dropping the reference is sufficient; no separate ``del`` is needed.
        self.model = None
        logger.info("MLX Whisper model unloaded")

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint, forwarded to the model as the
                ``language`` keyword only when truthy
            model_size: Optional model size override

        Returns:
            Transcribed text with surrounding whitespace stripped; an empty
            string when the model result carries no text.
        """
        await self.load_model_async(model_size)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # MLX Whisper transcription using generate method.
            # The generate method accepts audio path directly.
            decode_options = {"language": language} if language else {}

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state — see the comment in MLXTTSBackend.generate for the
            # regression this revert fixes (issue #462).
            result = self.model.generate(str(audio_path), **decode_options)

            return self._extract_text(result)

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
|
||||
Reference in New Issue
Block a user