""" PyTorch backend implementation for TTS and STT. """ from typing import Optional, List, Tuple import asyncio import logging import torch import numpy as np logger = logging.getLogger(__name__) from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS from .base import ( is_model_cached, get_torch_device, empty_device_cache, manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, ) from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt from ..utils.audio import load_audio class PyTorchTTSBackend: """PyTorch-based TTS backend using Qwen3-TTS.""" def __init__(self, model_size: str = "1.7B"): self.model = None self.model_size = model_size self.device = self._get_device() self._current_model_size = None def _get_device(self) -> str: """Get the best available device.""" return get_torch_device(allow_xpu=True, allow_directml=True) def is_loaded(self) -> bool: """Check if model is loaded.""" return self.model is not None def _get_model_path(self, model_size: str) -> str: """ Get the HuggingFace Hub model ID. Args: model_size: Model size (1.7B or 0.6B) Returns: HuggingFace Hub model ID """ hf_model_map = { "1.7B": "Qwen/Qwen3-TTS-12Hz-1.7B-Base", "0.6B": "Qwen/Qwen3-TTS-12Hz-0.6B-Base", } if model_size not in hf_model_map: raise ValueError(f"Unknown model size: {model_size}") return hf_model_map[model_size] def _is_model_cached(self, model_size: str) -> bool: return is_model_cached(self._get_model_path(model_size)) async def load_model_async(self, model_size: Optional[str] = None): """ Lazy load the TTS model with automatic downloading from HuggingFace Hub. Args: model_size: Model size to load (1.7B or 0.6B) """ if model_size is None: model_size = self.model_size # If already loaded with correct size, return if self.model is not None and self._current_model_size == model_size: return # Unload existing model if different size requested if self.model is not None and self._current_model_size != model_size: self.unload_model() # Run blocking load in thread pool await asyncio.to_thread(self._load_model_sync, model_size) # Alias for compatibility load_model = load_model_async def _load_model_sync(self, model_size: str): """Synchronous model loading.""" model_name = f"qwen-tts-{model_size}" is_cached = self._is_model_cached(model_size) with model_load_progress(model_name, is_cached): from qwen_tts import Qwen3TTSModel model_path = self._get_model_path(model_size) logger.info("Loading TTS model %s on %s...", model_size, self.device) # Route both HF Hub and Transformers through a single cache root. # On Windows local setups, model assets can otherwise split between # .hf-cache/hub and .hf-cache/transformers, causing speech_tokenizer # and preprocessor_config.json to fail to resolve during load. 
            from huggingface_hub import constants as hf_constants

            tts_cache_dir = hf_constants.HF_HUB_CACHE

            if self.device == "cpu":
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    cache_dir=tts_cache_dir,
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=False,
                )
            else:
                self.model = Qwen3TTSModel.from_pretrained(
                    model_path,
                    cache_dir=tts_cache_dir,
                    device_map=self.device,
                    torch_dtype=torch.bfloat16,
                )

            self._current_model_size = model_size
            self.model_size = model_size
            logger.info("TTS model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._current_model_size = None
            empty_device_cache(self.device)
            logger.info("TTS model unloaded")

    async def create_voice_prompt(
        self,
        audio_path: str,
        reference_text: str,
        use_cache: bool = True,
    ) -> Tuple[dict, bool]:
        """
        Create voice prompt from reference audio.

        Args:
            audio_path: Path to reference audio file
            reference_text: Transcript of reference audio
            use_cache: Whether to use cached prompt if available

        Returns:
            Tuple of (voice_prompt_dict, was_cached)
        """
        await self.load_model_async(None)

        # Check cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cached_prompt = get_cached_voice_prompt(cache_key)
            if cached_prompt is not None:
                # The cache may hold either the prompt dict or a legacy
                # torch.Tensor; normalize to a dict if needed.
                if isinstance(cached_prompt, dict):
                    # For the PyTorch backend the dict should contain tensors,
                    # not file paths, so we can safely return it.
                    return cached_prompt, True
                elif isinstance(cached_prompt, torch.Tensor):
                    # Legacy cache format - convert to dict.
                    # This shouldn't happen in practice, but handle it.
                    return {"prompt": cached_prompt}, True

        def _create_prompt_sync():
            """Run synchronous voice prompt creation in thread pool."""
            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (issue #462) regressed online
            # users whose libraries issue legitimate metadata lookups
            # during voice-prompt creation.
            return self.model.create_voice_clone_prompt(
                ref_audio=str(audio_path),
                ref_text=reference_text,
                x_vector_only_mode=False,
            )

        # Run blocking operation in thread pool
        voice_prompt_items = await asyncio.to_thread(_create_prompt_sync)

        # Cache if enabled
        if use_cache:
            cache_key = get_cache_key(audio_path, reference_text)
            cache_voice_prompt(cache_key, voice_prompt_items)

        return voice_prompt_items, False

    async def combine_voice_prompts(
        self,
        audio_paths: List[str],
        reference_texts: List[str],
    ) -> Tuple[np.ndarray, str]:
        """Combine multiple reference clips and transcripts via the shared base helper."""
        return await _combine_voice_prompts(audio_paths, reference_texts)

    async def generate(
        self,
        text: str,
        voice_prompt: dict,
        language: str = "en",
        seed: Optional[int] = None,
        instruct: Optional[str] = None,
    ) -> Tuple[np.ndarray, int]:
        """
        Generate audio from text using voice prompt.

        Args:
            text: Text to synthesize
            voice_prompt: Voice prompt dictionary from create_voice_prompt
            language: Language code (en or zh)
            seed: Random seed for reproducibility
            instruct: Natural language instruction for speech delivery control

        Returns:
            Tuple of (audio_array, sample_rate)
        """
        # Load model
        await self.load_model_async(None)

        def _generate_sync():
            """Run synchronous generation in thread pool."""
            # Set seed if provided
            if seed is not None:
                manual_seed(seed, self.device)

            # See _create_prompt_sync: inference runs with the process's
            # default HF_HUB_OFFLINE state (issue #462).
            wavs, sample_rate = self.model.generate_voice_clone(
                text=text,
                voice_clone_prompt=voice_prompt,
                language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
                instruct=instruct,
            )
            return wavs[0], sample_rate

        # Run blocking inference in thread pool to avoid blocking event loop
        audio, sample_rate = await asyncio.to_thread(_generate_sync)
        return audio, sample_rate


class PyTorchSTTBackend:
    """PyTorch-based STT backend using Whisper."""

    def __init__(self, model_size: str = "base"):
        self.model = None
        self.processor = None
        self.model_size = model_size
        self.device = self._get_device()

    def _get_device(self) -> str:
        """Get the best available device."""
        return get_torch_device(allow_xpu=True, allow_directml=True)

    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self.model is not None

    def _is_model_cached(self, model_size: str) -> bool:
        hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
        return is_model_cached(hf_repo)

    async def load_model_async(self, model_size: Optional[str] = None):
        """
        Lazy load the Whisper model.

        Args:
            model_size: Model size (tiny, base, small, medium, large)
        """
        if model_size is None:
            model_size = self.model_size

        if self.model is not None and self.model_size == model_size:
            return

        # Unload any previously loaded model of a different size before
        # loading the new one, mirroring the TTS backend.
        if self.model is not None and self.model_size != model_size:
            self.unload_model()

        await asyncio.to_thread(self._load_model_sync, model_size)

    # Alias for compatibility
    load_model = load_model_async

    def _load_model_sync(self, model_size: str):
        """Synchronous model loading."""
        progress_model_name = f"whisper-{model_size}"
        is_cached = self._is_model_cached(model_size)

        with model_load_progress(progress_model_name, is_cached):
            from transformers import WhisperProcessor, WhisperForConditionalGeneration

            model_name = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
            logger.info("Loading Whisper model %s on %s...", model_size, self.device)

            self.processor = WhisperProcessor.from_pretrained(model_name)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
            self.model.to(self.device)

            self.model_size = model_size
            logger.info("Whisper model %s loaded successfully", model_size)

    def unload_model(self):
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            del self.processor
            self.model = None
            self.processor = None
            empty_device_cache(self.device)
            logger.info("Whisper model unloaded")

    async def transcribe(
        self,
        audio_path: str,
        language: Optional[str] = None,
        model_size: Optional[str] = None,
    ) -> str:
        """
        Transcribe audio to text.

        Args:
            audio_path: Path to audio file
            language: Optional language hint
            model_size: Optional model size override

        Returns:
            Transcribed text
        """
        await self.load_model_async(model_size)

        def _transcribe_sync():
            """Run synchronous transcription in thread pool."""
            # Load audio resampled to Whisper's expected 16 kHz
            audio, _sr = load_audio(audio_path, sample_rate=16000)

            # Inference runs with the process's default HF_HUB_OFFLINE
            # state. Forcing offline here (issue #462) broke online users
            # whose `get_decoder_prompt_ids` / tokenizer calls issue
            # legitimate metadata lookups.
            # Extract log-mel input features with the Whisper processor
            inputs = self.processor(
                audio,
                sampling_rate=16000,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)

            # Generate transcription.
            # If a language is provided, force it; otherwise let Whisper auto-detect.
            generate_kwargs = {}
            if language:
                forced_decoder_ids = self.processor.get_decoder_prompt_ids(
                    language=language,
                    task="transcribe",
                )
                generate_kwargs["forced_decoder_ids"] = forced_decoder_ids

            with torch.no_grad():
                predicted_ids = self.model.generate(
                    inputs["input_features"],
                    **generate_kwargs,
                )

            # Decode token IDs back to text
            transcription = self.processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True,
            )[0]

            return transcription.strip()

        # Run blocking transcription in thread pool
        return await asyncio.to_thread(_transcribe_sync)
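

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; kept as comments so nothing runs on import).
# It shows how the two backends above are intended to be driven from async
# code. The file paths and texts are hypothetical placeholders, and running
# the equivalent code requires the qwen_tts and transformers model weights to
# be available locally or downloadable from HuggingFace Hub.
#
#   async def _demo():
#       tts = PyTorchTTSBackend(model_size="0.6B")
#       prompt, was_cached = await tts.create_voice_prompt(
#           audio_path="reference.wav",                       # placeholder clip
#           reference_text="Transcript of the reference clip.",
#       )
#       audio, sample_rate = await tts.generate(
#           text="Hello from the PyTorch backend.",
#           voice_prompt=prompt,
#           language="en",
#           seed=42,
#       )
#
#       stt = PyTorchSTTBackend(model_size="base")
#       text = await stt.transcribe("speech.wav", language="en")  # placeholder clip
#
#   asyncio.run(_demo())
# ---------------------------------------------------------------------------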