"""
|
|
Shared utilities for TTS/STT backend implementations.
|
|
|
|
Eliminates duplication of cache checking, device detection,
|
|
voice prompt combination, and model loading progress tracking.
|
|
"""
|
|
|
|
import logging
|
|
import platform
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Callable, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
from ..utils.audio import normalize_audio, load_audio
|
|
from ..utils.progress import get_progress_manager
|
|
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
|
|
from ..utils.tasks import get_task_manager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def is_model_cached(
    hf_repo: str,
    *,
    weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
    required_files: Optional[list[str]] = None,
) -> bool:
    """
    Check if a HuggingFace model is fully cached locally.

    Args:
        hf_repo: HuggingFace repo ID (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-Base")
        weight_extensions: File extensions that count as model weights.
        required_files: If set, check that these specific filenames exist
            in snapshots instead of checking by extension.

    Returns:
        True if model is fully cached, False if missing or incomplete.
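
    Usage (sketch; pairs with model_load_progress below, reusing the example
    repo ID from the Args section):
        cached = is_model_cached("Qwen/Qwen3-TTS-12Hz-1.7B-Base")
        with model_load_progress("qwen-tts-1.7B", cached):
            ...  # model loading happens here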
    """
    try:
        from huggingface_hub import constants as hf_constants

        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))

        if not repo_cache.exists():
            return False

        # Incomplete blobs mean a download is still in progress
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            logger.debug(f"Found .incomplete files for {hf_repo}")
            return False

        snapshots_dir = repo_cache / "snapshots"
        if not snapshots_dir.exists():
            return False

        if required_files:
            # Check that every required filename exists somewhere in snapshots
            for fname in required_files:
                if not any(snapshots_dir.rglob(fname)):
                    return False
            return True

        # Check that at least one weight file exists
        for ext in weight_extensions:
            if any(snapshots_dir.rglob(f"*{ext}")):
                return True

        logger.debug(f"No model weights found for {hf_repo}")
        return False

    except Exception as e:
        logger.warning(f"Error checking cache for {hf_repo}: {e}")
        return False


def get_torch_device(
    *,
    allow_xpu: bool = False,
    allow_directml: bool = False,
    allow_mps: bool = False,
    force_cpu_on_mac: bool = False,
) -> str:
    """
    Detect the best available torch device.

    Args:
        allow_xpu: Check for Intel XPU (IPEX) support.
        allow_directml: Check for DirectML (Windows) support.
        allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU.
        force_cpu_on_mac: Force CPU on macOS regardless of GPU availability.
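
    Usage (sketch; which flags a backend enables is its own choice, shown
    here only as an illustration):
        device = get_torch_device(allow_mps=True, allow_xpu=True)
        model.to(device)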
    """
    if force_cpu_on_mac and platform.system() == "Darwin":
        return "cpu"

    import torch

    if torch.cuda.is_available():
        return "cuda"

    if allow_xpu:
        try:
            import intel_extension_for_pytorch  # noqa: F401

            if hasattr(torch, "xpu") and torch.xpu.is_available():
                return "xpu"
        except ImportError:
            pass

    if allow_directml:
        try:
            import torch_directml

            if torch_directml.device_count() > 0:
                return torch_directml.device(0)
        except ImportError:
            pass

    if allow_mps:
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

    return "cpu"


def check_cuda_compatibility() -> tuple[bool, str | None]:
    """Check if the installed PyTorch supports the current GPU's compute capability.

    Returns:
        (compatible, warning_message) — compatible is True if OK or no CUDA GPU,
        warning_message is a human-readable string if there's a problem.
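
    Usage (sketch; how the warning is surfaced is up to the caller):
        ok, warning = check_cuda_compatibility()
        if not ok:
            logger.warning(warning)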
    """
    import torch

    if not torch.cuda.is_available():
        return True, None

    major, minor = torch.cuda.get_device_capability(0)
    capability = f"{major}.{minor}"
    device_name = torch.cuda.get_device_name(0)
    sm_tag = f"sm_{major}{minor}"

    # torch.cuda._get_arch_list() returns the SM architectures this build
    # was compiled for (e.g. ["sm_50", "sm_60", ..., "sm_90"]).
    try:
        arch_list = torch.cuda._get_arch_list()
        if arch_list:
            # Check for both sm_XX and compute_XX (JIT-compiled) entries
            compute_tag = f"compute_{major}{minor}"
            if sm_tag not in arch_list and compute_tag not in arch_list:
                return False, (
                    f"{device_name} (compute capability {capability} / {sm_tag}) "
                    f"is not supported by this PyTorch build. "
                    f"Supported architectures: {', '.join(arch_list)}. "
                    f"Install PyTorch nightly (cu128) for newer GPU support: "
                    f"pip install torch --index-url https://download.pytorch.org/whl/nightly/cu128"
                )
    except AttributeError:
        pass

    return True, None


def empty_device_cache(device: str) -> None:
    """
    Free cached memory on the given device (CUDA or XPU).

    Backends should call this after unloading models so VRAM is returned
    to the OS.
    """
    import torch

    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.empty_cache()


def manual_seed(seed: int, device: str) -> None:
    """
    Set the random seed on both CPU and the active accelerator.

    Covers CUDA and Intel XPU so that generation is reproducible
    regardless of which GPU backend is in use.
    """
    import torch

    torch.manual_seed(seed)
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.manual_seed(seed)


async def combine_voice_prompts(
    audio_paths: List[str],
    reference_texts: List[str],
    *,
    sample_rate: Optional[int] = None,
) -> Tuple[np.ndarray, str]:
    """
    Combine multiple reference audio samples into one.

    Loads each audio file, normalizes, concatenates, and joins texts.

    Args:
        audio_paths: Paths to reference audio files.
        reference_texts: Corresponding transcripts.
        sample_rate: If set, resample audio to this rate during loading.
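
    Usage (sketch; the paths, transcripts, and 24 kHz rate are placeholders):
        audio, text = await combine_voice_prompts(
            ["ref_a.wav", "ref_b.wav"],
            ["First transcript.", "Second transcript."],
            sample_rate=24000,
        )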
    """
    combined_audio = []

    for path in audio_paths:
        kwargs = {"sample_rate": sample_rate} if sample_rate else {}
        audio, _sr = load_audio(path, **kwargs)
        audio = normalize_audio(audio)
        combined_audio.append(audio)

    mixed = np.concatenate(combined_audio)
    mixed = normalize_audio(mixed)
    combined_text = " ".join(reference_texts)

    return mixed, combined_text


@contextmanager
def model_load_progress(
    model_name: str,
    is_cached: bool,
    filter_non_downloads: Optional[bool] = None,
):
    """
    Context manager for model loading with HF download progress tracking.

    Handles the tqdm patching, progress_manager/task_manager lifecycle,
    and error reporting that every backend duplicates.

    Args:
        model_name: Progress tracking key (e.g. "qwen-tts-1.7B", "whisper-base").
        is_cached: Whether the model is already downloaded.
        filter_non_downloads: Whether to filter non-download tqdm bars.
            Defaults to `is_cached`.

    Yields:
        The tracker context (already entered). The caller loads the model
        inside the `with` block. The tqdm patch is torn down on exit.

    Usage:
        with model_load_progress("qwen-tts-1.7B", is_cached) as ctx:
            self.model = SomeModel.from_pretrained(...)
    """
    if filter_non_downloads is None:
        filter_non_downloads = is_cached

    progress_manager = get_progress_manager()
    task_manager = get_task_manager()

    progress_callback = create_hf_progress_callback(model_name, progress_manager)
    tracker = HFProgressTracker(progress_callback, filter_non_downloads=filter_non_downloads)

    tracker_context = tracker.patch_download()
    tracker_context.__enter__()

    if not is_cached:
        task_manager.start_download(model_name)
        progress_manager.update_progress(
            model_name=model_name,
            current=0,
            total=0,
            filename="Connecting to HuggingFace...",
            status="downloading",
        )

    try:
        yield tracker_context
    except Exception as e:
        # Report error to both managers
        progress_manager.mark_error(model_name, str(e))
        task_manager.error_download(model_name, str(e))
        raise
    else:
        # Only mark complete if we were tracking a download
        if not is_cached:
            progress_manager.mark_complete(model_name)
            task_manager.complete_download(model_name)
    finally:
        tracker_context.__exit__(None, None, None)


def patch_chatterbox_f32(model) -> None:
    """
    Patch float64 -> float32 dtype mismatches in upstream chatterbox.

    librosa.load returns float64 numpy arrays. Multiple upstream code paths
    convert these to torch tensors via torch.from_numpy() without casting,
    then matmul against float32 model weights. This patches the two known
    entry points:

    1. S3Tokenizer.log_mel_spectrogram — audio tensor hits _mel_filters (f32)
    2. VoiceEncoder.forward — float64 mel spectrograms hit LSTM weights (f32)
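
    Usage (sketch; `load_chatterbox_model` is a stand-in for however the
    backend constructs the upstream model):
        model = load_chatterbox_model()
        patch_chatterbox_f32(model)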
    """
    import types

    # Patch S3Tokenizer
    _tokzr = model.s3gen.tokenizer
    _orig_log_mel = _tokzr.log_mel_spectrogram.__func__

    def _f32_log_mel(self_tokzr, audio, padding=0):
        import torch as _torch

        if _torch.is_tensor(audio):
            audio = audio.float()
        return _orig_log_mel(self_tokzr, audio, padding)

    _tokzr.log_mel_spectrogram = types.MethodType(_f32_log_mel, _tokzr)

    # Patch VoiceEncoder
    _ve = model.ve
    _orig_ve_forward = _ve.forward.__func__

    def _f32_ve_forward(self_ve, mels):
        return _orig_ve_forward(self_ve, mels.float())

    _ve.forward = types.MethodType(_f32_ve_forward, _ve)