Initial commit
backend/backends/base.py (new file, +327 lines)
@@ -0,0 +1,327 @@
"""
Shared utilities for TTS/STT backend implementations.

Eliminates duplication of cache checking, device detection,
voice prompt combination, and model loading progress tracking.
"""

import logging
import platform
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import numpy as np

from ..utils.audio import normalize_audio, load_audio
from ..utils.progress import get_progress_manager
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager

logger = logging.getLogger(__name__)


def is_model_cached(
    hf_repo: str,
    *,
    weight_extensions: tuple[str, ...] = (".safetensors", ".bin"),
    required_files: Optional[list[str]] = None,
) -> bool:
    """
    Check if a HuggingFace model is fully cached locally.

    Args:
        hf_repo: HuggingFace repo ID (e.g. "Qwen/Qwen3-TTS-12Hz-1.7B-Base")
        weight_extensions: File extensions that count as model weights.
        required_files: If set, check that these specific filenames exist
            in snapshots instead of checking by extension.

    Returns:
        True if model is fully cached, False if missing or incomplete.
    """
    try:
        from huggingface_hub import constants as hf_constants

        repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + hf_repo.replace("/", "--"))

        if not repo_cache.exists():
            return False

        # Incomplete blobs mean a download is still in progress
        blobs_dir = repo_cache / "blobs"
        if blobs_dir.exists() and any(blobs_dir.glob("*.incomplete")):
            logger.debug(f"Found .incomplete files for {hf_repo}")
            return False

        snapshots_dir = repo_cache / "snapshots"
        if not snapshots_dir.exists():
            return False

        if required_files:
            # Check that every required filename exists somewhere in snapshots
            for fname in required_files:
                if not any(snapshots_dir.rglob(fname)):
                    return False
            return True

        # Check that at least one weight file exists
        for ext in weight_extensions:
            if any(snapshots_dir.rglob(f"*{ext}")):
                return True

        logger.debug(f"No model weights found for {hf_repo}")
        return False

    except Exception as e:
        logger.warning(f"Error checking cache for {hf_repo}: {e}")
        return False
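
# Illustrative usage of is_model_cached (not part of the original module; the
# repo id below is a made-up placeholder):
#
#     if is_model_cached("some-org/some-tts-model"):
#         logger.info("Weights already cached, loading will not hit the network")
#
#     # Or, when the repo's exact filenames are known:
#     ready = is_model_cached(
#         "some-org/some-tts-model",
#         required_files=["config.json", "model.safetensors"],
#     )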


def get_torch_device(
    *,
    allow_xpu: bool = False,
    allow_directml: bool = False,
    allow_mps: bool = False,
    force_cpu_on_mac: bool = False,
) -> str:
    """
    Detect the best available torch device.

    Args:
        allow_xpu: Check for Intel XPU (IPEX) support.
        allow_directml: Check for DirectML (Windows) support.
        allow_mps: Allow MPS (Apple Silicon). If False, MPS falls back to CPU.
        force_cpu_on_mac: Force CPU on macOS regardless of GPU availability.
    """
    if force_cpu_on_mac and platform.system() == "Darwin":
        return "cpu"

    import torch

    if torch.cuda.is_available():
        return "cuda"

    if allow_xpu:
        try:
            import intel_extension_for_pytorch  # noqa: F401

            if hasattr(torch, "xpu") and torch.xpu.is_available():
                return "xpu"
        except ImportError:
            pass

    if allow_directml:
        try:
            import torch_directml

            if torch_directml.device_count() > 0:
                # NOTE: torch_directml.device(0) is a device object, not a plain
                # string, so the `-> str` annotation only holds for the other branches.
                return torch_directml.device(0)
        except ImportError:
            pass

    if allow_mps:
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"

    return "cpu"
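
# Illustrative usage of get_torch_device (not part of the original module).
# A backend that supports Apple Silicon but nothing more exotic might do:
#
#     device = get_torch_device(allow_mps=True)
#     model = model.to(device)
#
# while a backend known to misbehave on MPS can force CPU on macOS:
#
#     device = get_torch_device(force_cpu_on_mac=True)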


def check_cuda_compatibility() -> tuple[bool, str | None]:
    """Check if the installed PyTorch supports the current GPU's compute capability.

    Returns:
        (compatible, warning_message): compatible is True if OK or no CUDA GPU;
        warning_message is a human-readable string if there's a problem.
    """
    import torch

    if not torch.cuda.is_available():
        return True, None

    major, minor = torch.cuda.get_device_capability(0)
    capability = f"{major}.{minor}"
    device_name = torch.cuda.get_device_name(0)
    sm_tag = f"sm_{major}{minor}"

    # torch.cuda._get_arch_list() returns the SM architectures this build
    # was compiled for (e.g. ["sm_50", "sm_60", ..., "sm_90"]).
    try:
        arch_list = torch.cuda._get_arch_list()
        if arch_list:
            # Check for both sm_XX and compute_XX (JIT-compiled) entries
            compute_tag = f"compute_{major}{minor}"
            if sm_tag not in arch_list and compute_tag not in arch_list:
                return False, (
                    f"{device_name} (compute capability {capability} / {sm_tag}) "
                    f"is not supported by this PyTorch build. "
                    f"Supported architectures: {', '.join(arch_list)}. "
                    f"Install PyTorch nightly (cu128) for newer GPU support: "
                    f"pip install torch --index-url https://download.pytorch.org/whl/nightly/cu128"
                )
    except AttributeError:
        pass

    return True, None
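
# Illustrative usage of check_cuda_compatibility (not part of the original
# module): run it once at backend start-up so the user sees a clear warning
# instead of an opaque CUDA kernel error at generation time.
#
#     compatible, warning = check_cuda_compatibility()
#     if not compatible:
#         logger.warning(warning)
#         device = "cpu"  # fall back rather than crash on an unsupported GPU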


def empty_device_cache(device: str) -> None:
    """
    Free cached memory on the given device (CUDA or XPU).

    Backends should call this after unloading models so VRAM is returned
    to the OS.
    """
    import torch

    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.empty_cache()


def manual_seed(seed: int, device: str) -> None:
    """
    Set the random seed on both CPU and the active accelerator.

    Covers CUDA and Intel XPU so that generation is reproducible
    regardless of which GPU backend is in use.
    """
    import torch

    torch.manual_seed(seed)
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.manual_seed(seed)
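
# Illustrative generate/unload flow using manual_seed and empty_device_cache
# (not part of the original module; `model.generate(text)` is a hypothetical
# backend-specific call):
#
#     manual_seed(42, device)        # reproducible sampling on CPU and GPU
#     audio = model.generate(text)   # hypothetical backend call
#     del model
#     empty_device_cache(device)     # hand cached VRAM back after unloading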


async def combine_voice_prompts(
    audio_paths: List[str],
    reference_texts: List[str],
    *,
    sample_rate: Optional[int] = None,
) -> Tuple[np.ndarray, str]:
    """
    Combine multiple reference audio samples into one.

    Loads each audio file, normalizes, concatenates, and joins texts.

    Args:
        audio_paths: Paths to reference audio files.
        reference_texts: Corresponding transcripts.
        sample_rate: If set, resample audio to this rate during loading.
    """
    combined_audio = []

    for path in audio_paths:
        kwargs = {"sample_rate": sample_rate} if sample_rate else {}
        audio, _sr = load_audio(path, **kwargs)
        audio = normalize_audio(audio)
        combined_audio.append(audio)

    mixed = np.concatenate(combined_audio)
    mixed = normalize_audio(mixed)
    combined_text = " ".join(reference_texts)

    return mixed, combined_text
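
# Illustrative usage of combine_voice_prompts (not part of the original module;
# the paths and transcripts are placeholders). The coroutine must be awaited:
#
#     ref_audio, ref_text = await combine_voice_prompts(
#         ["voice_a.wav", "voice_b.wav"],
#         ["Transcript of sample A.", "Transcript of sample B."],
#         sample_rate=16000,
#     )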


@contextmanager
def model_load_progress(
    model_name: str,
    is_cached: bool,
    filter_non_downloads: Optional[bool] = None,
):
    """
    Context manager for model loading with HF download progress tracking.

    Handles the tqdm patching, progress_manager/task_manager lifecycle,
    and error reporting that every backend duplicates.

    Args:
        model_name: Progress tracking key (e.g. "qwen-tts-1.7B", "whisper-base").
        is_cached: Whether the model is already downloaded.
        filter_non_downloads: Whether to filter non-download tqdm bars.
            Defaults to `is_cached`.

    Yields:
        The tracker context (already entered). The caller loads the model
        inside the `with` block. The tqdm patch is torn down on exit.

    Usage:
        with model_load_progress("qwen-tts-1.7B", is_cached) as ctx:
            self.model = SomeModel.from_pretrained(...)
    """
    if filter_non_downloads is None:
        filter_non_downloads = is_cached

    progress_manager = get_progress_manager()
    task_manager = get_task_manager()

    progress_callback = create_hf_progress_callback(model_name, progress_manager)
    tracker = HFProgressTracker(progress_callback, filter_non_downloads=filter_non_downloads)

    tracker_context = tracker.patch_download()
    tracker_context.__enter__()

    if not is_cached:
        task_manager.start_download(model_name)
        progress_manager.update_progress(
            model_name=model_name,
            current=0,
            total=0,
            filename="Connecting to HuggingFace...",
            status="downloading",
        )

    try:
        yield tracker_context
    except Exception as e:
        # Report error to both managers
        progress_manager.mark_error(model_name, str(e))
        task_manager.error_download(model_name, str(e))
        raise
    else:
        # Only mark complete if we were tracking a download
        if not is_cached:
            progress_manager.mark_complete(model_name)
            task_manager.complete_download(model_name)
    finally:
        tracker_context.__exit__(None, None, None)
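
# Illustrative backend load path combining is_model_cached and
# model_load_progress (not part of the original module; SomeModel and the repo
# id are placeholders):
#
#     cached = is_model_cached("some-org/some-tts-model")
#     with model_load_progress("some-tts", cached):
#         self.model = SomeModel.from_pretrained("some-org/some-tts-model")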


def patch_chatterbox_f32(model) -> None:
    """
    Patch float64 -> float32 dtype mismatches in upstream chatterbox.

    librosa.load returns float64 numpy arrays. Multiple upstream code paths
    convert these to torch tensors via torch.from_numpy() without casting,
    then matmul against float32 model weights. This patches the two known
    entry points:

    1. S3Tokenizer.log_mel_spectrogram - audio tensor hits _mel_filters (f32)
    2. VoiceEncoder.forward - float64 mel spectrograms hit LSTM weights (f32)
    """
    import types

    # Patch S3Tokenizer
    _tokzr = model.s3gen.tokenizer
    _orig_log_mel = _tokzr.log_mel_spectrogram.__func__

    def _f32_log_mel(self_tokzr, audio, padding=0):
        import torch as _torch

        if _torch.is_tensor(audio):
            audio = audio.float()
        return _orig_log_mel(self_tokzr, audio, padding)

    _tokzr.log_mel_spectrogram = types.MethodType(_f32_log_mel, _tokzr)

    # Patch VoiceEncoder
    _ve = model.ve
    _orig_ve_forward = _ve.forward.__func__

    def _f32_ve_forward(self_ve, mels):
        return _orig_ve_forward(self_ve, mels.float())

    _ve.forward = types.MethodType(_f32_ve_forward, _ve)
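
# Illustrative usage of patch_chatterbox_f32 (not part of the original module;
# the exact ChatterboxTTS loading call is an assumption about the upstream
# chatterbox API):
#
#     model = ChatterboxTTS.from_pretrained(device=device)
#     patch_chatterbox_f32(model)  # avoid float64/float32 matmul errors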