""" HumeAI TADA TTS backend implementation. Wraps HumeAI's TADA (Text-Acoustic Dual Alignment) model for high-quality voice cloning. Two model variants: - tada-1b: English-only, ~2B params (Llama 3.2 1B base) - tada-3b-ml: Multilingual, ~4B params (Llama 3.2 3B base) Both use a shared encoder/codec (HumeAI/tada-codec). The encoder produces 1:1 aligned token embeddings from reference audio, and the causal LM generates speech via flow-matching diffusion. 24kHz output, bf16 inference on CUDA, fp32 on CPU. """ import asyncio import logging import threading from typing import ClassVar, List, Optional, Tuple import numpy as np from . import TTSBackend from .base import ( is_model_cached, get_torch_device, empty_device_cache, manual_seed, combine_voice_prompts as _combine_voice_prompts, model_load_progress, ) from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt logger = logging.getLogger(__name__) # HuggingFace repos TADA_CODEC_REPO = "HumeAI/tada-codec" TADA_1B_REPO = "HumeAI/tada-1b" TADA_3B_ML_REPO = "HumeAI/tada-3b-ml" TADA_MODEL_REPOS = { "1B": TADA_1B_REPO, "3B": TADA_3B_ML_REPO, } # Key weight files for cache detection _TADA_MODEL_WEIGHT_FILES = [ "model.safetensors", ] _TADA_CODEC_WEIGHT_FILES = [ "encoder/model.safetensors", ] class HumeTadaBackend: """HumeAI TADA TTS backend for high-quality voice cloning.""" _load_lock: ClassVar[threading.Lock] = threading.Lock() def __init__(self): self.model = None self.encoder = None self.model_size = "1B" # default to 1B self._device = None self._model_load_lock = asyncio.Lock() def _get_device(self) -> str: # Force CPU on macOS — MPS has issues with flow matching # and large vocab lm_head (>65536 output channels) return get_torch_device(force_cpu_on_mac=True, allow_xpu=True) def is_loaded(self) -> bool: return self.model is not None def _get_model_path(self, model_size: str = "1B") -> str: return TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO) def _is_model_cached(self, model_size: str = "1B") -> bool: repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO) model_cached = is_model_cached(repo, required_files=_TADA_MODEL_WEIGHT_FILES) codec_cached = is_model_cached(TADA_CODEC_REPO, required_files=_TADA_CODEC_WEIGHT_FILES) return model_cached and codec_cached async def load_model(self, model_size: str = "1B") -> None: """Load the TADA model and encoder.""" if self.model is not None and self.model_size == model_size: return async with self._model_load_lock: if self.model is not None and self.model_size == model_size: return # Unload existing model if switching sizes if self.model is not None: self.unload_model() self.model_size = model_size await asyncio.to_thread(self._load_model_sync, model_size) def _load_model_sync(self, model_size: str = "1B"): """Synchronous model loading with progress tracking.""" model_name = f"tada-{model_size.lower()}" is_cached = self._is_model_cached(model_size) repo = TADA_MODEL_REPOS.get(model_size, TADA_1B_REPO) with model_load_progress(model_name, is_cached): # Install DAC shim before importing tada — tada's encoder/decoder # import dac.nn.layers.Snake1d which requires the descript-audio-codec # package. The real package pulls in onnx/tensorboard/matplotlib via # descript-audiotools, so we use a lightweight shim instead. from ..utils.dac_shim import install_dac_shim install_dac_shim() import torch from huggingface_hub import snapshot_download device = self._get_device() self._device = device logger.info(f"Loading HumeAI TADA {model_size} on {device}...") # Download codec (encoder + decoder) if not cached logger.info("Downloading TADA codec...") snapshot_download( repo_id=TADA_CODEC_REPO, token=None, allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin"], ) # Download model weights if not cached logger.info(f"Downloading TADA {model_size} model...") snapshot_download( repo_id=repo, token=None, allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin", "*.model"], ) # TADA hardcodes "meta-llama/Llama-3.2-1B" as the tokenizer # source in its Aligner and TadaForCausalLM.from_pretrained(). # That repo is gated (requires Meta license acceptance). # Download the tokenizer from an ungated mirror and get its # local cache path so we can point TADA at it directly. logger.info("Downloading Llama tokenizer (ungated mirror)...") tokenizer_path = snapshot_download( repo_id="unsloth/Llama-3.2-1B", token=None, allow_patterns=["tokenizer*", "special_tokens*"], ) # Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings if device == "cuda" and torch.cuda.is_bf16_supported(): model_dtype = torch.bfloat16 elif device == "xpu": # Intel Arc (Alchemist+) supports bf16 natively model_dtype = torch.bfloat16 else: model_dtype = torch.float32 # Patch the Aligner config class to use the local tokenizer # path instead of the gated "meta-llama/Llama-3.2-1B" default. # This avoids monkey-patching AutoTokenizer.from_pretrained # which corrupts the classmethod descriptor for other engines. from tada.modules.aligner import AlignerConfig AlignerConfig.tokenizer_name = tokenizer_path # Load encoder (only needed for voice prompt encoding) from tada.modules.encoder import Encoder logger.info("Loading TADA encoder...") self.encoder = Encoder.from_pretrained(TADA_CODEC_REPO, subfolder="encoder").to(device) self.encoder.eval() # Load the causal LM (includes decoder for wav generation). # TadaForCausalLM.from_pretrained() calls # getattr(config, "tokenizer_name", "meta-llama/Llama-3.2-1B") # which hits the gated repo. Pre-load the config from HF, # inject the local tokenizer path, then pass it in. from tada.modules.tada import TadaForCausalLM, TadaConfig logger.info(f"Loading TADA {model_size} model...") config = TadaConfig.from_pretrained(repo) config.tokenizer_name = tokenizer_path self.model = TadaForCausalLM.from_pretrained(repo, config=config, torch_dtype=model_dtype).to(device) self.model.eval() logger.info(f"HumeAI TADA {model_size} loaded successfully on {device}") def unload_model(self) -> None: """Unload model and encoder to free memory.""" if self.model is not None: del self.model self.model = None if self.encoder is not None: del self.encoder self.encoder = None device = self._device self._device = None if device: empty_device_cache(device) logger.info("HumeAI TADA unloaded") async def create_voice_prompt( self, audio_path: str, reference_text: str, use_cache: bool = True, ) -> Tuple[dict, bool]: """ Create voice prompt from reference audio using TADA's encoder. TADA's encoder performs forced alignment between audio and text tokens, producing an EncoderOutput with 1:1 token-audio alignment. If no reference_text is provided, the encoder uses built-in ASR (English only). We serialize the EncoderOutput to a dict for caching. """ await self.load_model(self.model_size) cache_key = ("tada_" + get_cache_key(audio_path, reference_text)) if use_cache else None if cache_key: cached = get_cached_voice_prompt(cache_key) if cached is not None and isinstance(cached, dict): return cached, True def _encode_sync(): import torch import soundfile as sf device = self._device # Load audio with soundfile (torchaudio 2.10+ requires torchcodec) audio_np, sr = sf.read(str(audio_path), dtype="float32") audio = torch.from_numpy(audio_np).float() if audio.ndim == 1: audio = audio.unsqueeze(0) # (samples,) -> (1, samples) else: audio = audio.T # (samples, channels) -> (channels, samples) audio = audio.to(device) # Encode with forced alignment text_arg = [reference_text] if reference_text else None prompt = self.encoder(audio, text=text_arg, sample_rate=sr) # Serialize EncoderOutput to a dict of CPU tensors for caching prompt_dict = {} for field_name in prompt.__dataclass_fields__: val = getattr(prompt, field_name) if isinstance(val, torch.Tensor): prompt_dict[field_name] = val.detach().cpu() elif isinstance(val, list): prompt_dict[field_name] = val elif isinstance(val, (int, float)): prompt_dict[field_name] = val else: prompt_dict[field_name] = val return prompt_dict encoded = await asyncio.to_thread(_encode_sync) if cache_key: cache_voice_prompt(cache_key, encoded) return encoded, False async def combine_voice_prompts( self, audio_paths: List[str], reference_texts: List[str], ) -> Tuple[np.ndarray, str]: return await _combine_voice_prompts(audio_paths, reference_texts, sample_rate=24000) async def generate( self, text: str, voice_prompt: dict, language: str = "en", seed: Optional[int] = None, instruct: Optional[str] = None, ) -> Tuple[np.ndarray, int]: """ Generate audio from text using HumeAI TADA. Args: text: Text to synthesize voice_prompt: Serialized EncoderOutput dict from create_voice_prompt() language: Language code (en, ar, de, es, fr, it, ja, pl, pt, zh) seed: Random seed for reproducibility instruct: Not supported by TADA (ignored) Returns: Tuple of (audio_array, sample_rate=24000) """ await self.load_model(self.model_size) def _generate_sync(): import torch from tada.modules.encoder import EncoderOutput if seed is not None: manual_seed(seed, self._device) device = self._device # Reconstruct EncoderOutput from the cached dict restored = {} for k, v in voice_prompt.items(): if isinstance(v, torch.Tensor): # Move to device and match model dtype for float tensors if v.is_floating_point(): model_dtype = next(self.model.parameters()).dtype restored[k] = v.to(device=device, dtype=model_dtype) else: restored[k] = v.to(device=device) else: restored[k] = v prompt = EncoderOutput(**restored) # For non-English with the 3B-ML model, we could reload the # encoder with the language-specific aligner. However, the # generation itself is language-agnostic — only the encoder's # aligner changes. Since we encode at create_voice_prompt time, # the language is already baked in. For simplicity, we don't # reload the encoder here. logger.info(f"[TADA] Generating ({language}), text length: {len(text)}") output = self.model.generate( prompt=prompt, text=text, ) # output.audio is a list of tensors (one per batch item) if output.audio and output.audio[0] is not None: audio_tensor = output.audio[0] audio = audio_tensor.detach().cpu().numpy().squeeze().astype(np.float32) else: logger.warning("[TADA] Generation produced no audio") audio = np.zeros(24000, dtype=np.float32) return audio, 24000 return await asyncio.to_thread(_generate_sync)