Initial commit

commit fbcbe08696
2026-04-24 19:18:15 +08:00
555 changed files with 96692 additions and 0 deletions

backend/utils/__init__.py Normal file

@@ -0,0 +1 @@
# Utils package

backend/utils/audio.py Normal file

@@ -0,0 +1,318 @@
"""
Audio processing utilities.
"""
import numpy as np
import soundfile as sf
import librosa
from typing import Tuple, Optional
def normalize_audio(
audio: np.ndarray,
target_db: float = -20.0,
peak_limit: float = 0.85,
) -> np.ndarray:
"""
Normalize audio to target loudness with peak limiting.
Args:
audio: Input audio array
target_db: Target RMS level in dB
peak_limit: Peak limit (0.0-1.0)
Returns:
Normalized audio array
"""
# Convert to float32
audio = audio.astype(np.float32)
# Calculate current RMS
rms = np.sqrt(np.mean(audio**2))
# Calculate target RMS
target_rms = 10**(target_db / 20)
# Apply gain
if rms > 0:
gain = target_rms / rms
audio = audio * gain
# Peak limiting
audio = np.clip(audio, -peak_limit, peak_limit)
return audio
def load_audio(
path: str,
sample_rate: int = 24000,
mono: bool = True,
) -> Tuple[np.ndarray, int]:
"""
    Load an audio file, resampling to the target sample rate and optionally downmixing to mono.
Args:
path: Path to audio file
sample_rate: Target sample rate
mono: Convert to mono
Returns:
Tuple of (audio_array, sample_rate)
"""
audio, sr = librosa.load(path, sr=sample_rate, mono=mono)
return audio, sr
def save_audio(
audio: np.ndarray,
path: str,
sample_rate: int = 24000,
) -> None:
"""
Save audio file with atomic write and error handling.
Writes to a temporary file first, then atomically renames to the
target path. This prevents corrupted/partial WAV files if the
process is interrupted mid-write.
Args:
audio: Audio array
path: Output path
sample_rate: Sample rate
Raises:
OSError: If file cannot be written
"""
from pathlib import Path
import os
temp_path = f"{path}.tmp"
try:
# Ensure parent directory exists
Path(path).parent.mkdir(parents=True, exist_ok=True)
# Write to temporary file first (explicit format since .tmp
# extension is not recognised by soundfile)
sf.write(temp_path, audio, sample_rate, format='WAV')
# Atomic rename to final path
os.replace(temp_path, path)
except Exception as e:
# Clean up temp file on failure
try:
if Path(temp_path).exists():
Path(temp_path).unlink()
except Exception:
pass # Best effort cleanup
raise OSError(f"Failed to save audio to {path}: {e}") from e
def trim_tts_output(
audio: np.ndarray,
sample_rate: int = 24000,
frame_ms: int = 20,
silence_threshold_db: float = -40.0,
min_silence_ms: int = 200,
max_internal_silence_ms: int = 1000,
fade_ms: int = 30,
) -> np.ndarray:
"""
Trim trailing silence and post-silence hallucination from TTS output.
Chatterbox sometimes produces ``[speech][silence][hallucinated noise]``.
This detects internal silence gaps longer than *max_internal_silence_ms*
and cuts the audio at that boundary, then trims trailing silence and
applies a short cosine fade-out.
Args:
audio: Input audio array (mono float32)
sample_rate: Sample rate in Hz
frame_ms: Frame size for RMS energy calculation
silence_threshold_db: dB threshold below which a frame is silence
min_silence_ms: Minimum trailing silence to keep
max_internal_silence_ms: Cut after any silence gap longer than this
fade_ms: Cosine fade-out duration in ms
Returns:
Trimmed audio array
"""
frame_len = int(sample_rate * frame_ms / 1000)
if frame_len == 0 or len(audio) < frame_len:
return audio
n_frames = len(audio) // frame_len
threshold_linear = 10 ** (silence_threshold_db / 20)
# Compute per-frame RMS
rms = np.array(
[
np.sqrt(np.mean(audio[i * frame_len : (i + 1) * frame_len] ** 2))
for i in range(n_frames)
]
)
is_speech = rms >= threshold_linear
# Find first speech frame
first_speech = 0
for i, s in enumerate(is_speech):
if s:
first_speech = max(0, i - 1) # keep 1 frame padding
break
# Walk forward from first speech; cut at long internal silence gaps
max_silence_frames = int(max_internal_silence_ms / frame_ms)
consecutive_silence = 0
cut_frame = n_frames
for i in range(first_speech, n_frames):
if is_speech[i]:
consecutive_silence = 0
else:
consecutive_silence += 1
if consecutive_silence >= max_silence_frames:
cut_frame = i - consecutive_silence + 1
break
# Trim trailing silence from the cut point
min_silence_frames = int(min_silence_ms / frame_ms)
end_frame = cut_frame
while end_frame > first_speech and not is_speech[end_frame - 1]:
end_frame -= 1
# Keep a short tail
end_frame = min(end_frame + min_silence_frames, cut_frame)
# Convert frames back to samples
start_sample = first_speech * frame_len
end_sample = min(end_frame * frame_len, len(audio))
trimmed = audio[start_sample:end_sample].copy()
# Cosine fade-out
fade_samples = int(sample_rate * fade_ms / 1000)
if fade_samples > 0 and len(trimmed) > fade_samples:
fade = np.cos(np.linspace(0, np.pi / 2, fade_samples)) ** 2
trimmed[-fade_samples:] *= fade
return trimmed
def preprocess_reference_audio(
audio: np.ndarray,
sample_rate: int,
peak_target: float = 0.95,
trim_top_db: float = 40.0,
edge_padding_ms: int = 100,
) -> np.ndarray:
"""
Clean up a reference-audio sample before validation/storage.
Removes DC offset, trims leading/trailing silence, and caps the peak so a
slightly-hot recording doesn't get rejected downstream as "clipping". The
goal is to accept reasonable real-world recordings — not to repair badly
distorted ones. True clipping artifacts inside the waveform can't be
recovered by peak scaling and will still sound bad.
Args:
audio: Mono audio array.
sample_rate: Sample rate of ``audio`` in Hz.
peak_target: Peak amplitude cap in [0, 1]. Applied only if the input
peak exceeds this value.
trim_top_db: Silence threshold for edge trimming, in dB below peak.
40 dB sits below normal speech dynamic range (≈30 dB) so soft
trailing syllables are preserved, while still catching obvious
leading/trailing silence. Lower values are more aggressive;
librosa's own default is 60.
edge_padding_ms: Milliseconds of padding to add back at each edge
*only if* trimming shortened the waveform, so TTS engines have a
brief silence to anchor on without ever making the output longer
than the input.
Returns:
Preprocessed audio array (float32).
"""
audio = audio.astype(np.float32, copy=False)
if audio.size == 0:
return audio
audio = audio - float(np.mean(audio))
trimmed, _ = librosa.effects.trim(audio, top_db=trim_top_db)
if 0 < trimmed.size < audio.size:
pad_each = int(sample_rate * edge_padding_ms / 1000)
# Never pad past the original length — for near-max-duration uploads
# an unconditional pad would push them over the 30 s ceiling and
# trigger a spurious "too long" rejection.
headroom = (audio.size - trimmed.size) // 2
pad = min(pad_each, max(headroom, 0))
if pad > 0:
trimmed = np.pad(trimmed, (pad, pad), mode="constant")
audio = trimmed
peak = float(np.abs(audio).max())
if peak > peak_target and peak > 0:
audio = audio * (peak_target / peak)
return audio
def validate_reference_audio(
audio_path: str,
min_duration: float = 2.0,
max_duration: float = 30.0,
min_rms: float = 0.01,
) -> Tuple[bool, Optional[str]]:
"""
Validate reference audio for voice cloning.
Args:
audio_path: Path to audio file
min_duration: Minimum duration in seconds
max_duration: Maximum duration in seconds
min_rms: Minimum RMS level
Returns:
Tuple of (is_valid, error_message)
"""
result = validate_and_load_reference_audio(
audio_path, min_duration, max_duration, min_rms
)
return (result[0], result[1])
def validate_and_load_reference_audio(
audio_path: str,
min_duration: float = 2.0,
max_duration: float = 30.0,
min_rms: float = 0.01,
) -> Tuple[bool, Optional[str], Optional[np.ndarray], Optional[int]]:
"""
Validate and load reference audio in a single pass.
Applies :func:`preprocess_reference_audio` before checks so that
slightly-hot recordings aren't rejected as clipping. Duration and RMS
checks run on the preprocessed waveform.
Returns:
Tuple of (is_valid, error_message, audio_array, sample_rate)
"""
try:
audio, sr = load_audio(audio_path)
audio = preprocess_reference_audio(audio, sr)
duration = len(audio) / sr
if duration < min_duration:
return False, f"Audio too short (minimum {min_duration} seconds)", None, None
if duration > max_duration:
return False, f"Audio too long (maximum {max_duration} seconds)", None, None
rms = np.sqrt(np.mean(audio**2))
if rms < min_rms:
return False, "Audio is too quiet or silent", None, None
return True, None, audio, sr
except Exception as e:
return False, f"Error validating audio: {str(e)}", None, None

backend/utils/cache.py Normal file

@@ -0,0 +1,153 @@
"""
Voice prompt caching utilities.
"""
import hashlib
import logging
import torch
from pathlib import Path
from typing import Optional, Union, Dict, Any
from .. import config
logger = logging.getLogger(__name__)
def _get_cache_dir() -> Path:
"""Get cache directory from config."""
return config.get_cache_dir()
# In-memory cache - can store dict (voice prompt) or tensor (legacy)
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}
def get_cache_key(audio_path: str, reference_text: str) -> str:
"""
Generate cache key from audio file and reference text.
Args:
audio_path: Path to audio file
reference_text: Reference text
Returns:
Cache key (MD5 hash)
"""
# Read audio file
with open(audio_path, "rb") as f:
audio_bytes = f.read()
# Combine audio bytes and text
combined = audio_bytes + reference_text.encode("utf-8")
# Generate hash
return hashlib.md5(combined).hexdigest()
def get_cached_voice_prompt(
cache_key: str,
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
"""
Get cached voice prompt if available.
Args:
cache_key: Cache key
Returns:
Cached voice prompt (dict or tensor) or None
"""
# Check in-memory cache
if cache_key in _memory_cache:
return _memory_cache[cache_key]
# Check disk cache
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
if cache_file.exists():
try:
prompt = torch.load(cache_file, weights_only=True)
_memory_cache[cache_key] = prompt
return prompt
except Exception:
# Cache file corrupted, delete it
cache_file.unlink()
return None
def cache_voice_prompt(
cache_key: str,
voice_prompt: Union[torch.Tensor, Dict[str, Any]],
) -> None:
"""
Cache voice prompt to memory and disk.
Args:
cache_key: Cache key
voice_prompt: Voice prompt (dict or tensor)
"""
# Store in memory
_memory_cache[cache_key] = voice_prompt
# Store on disk (torch.save can handle both dicts and tensors)
cache_file = _get_cache_dir() / f"{cache_key}.prompt"
torch.save(voice_prompt, cache_file)
def clear_voice_prompt_cache() -> int:
"""
Clear all voice prompt caches (memory and disk).
Returns:
Number of cache files deleted
"""
# Clear memory cache
_memory_cache.clear()
# Clear disk cache
cache_dir = _get_cache_dir()
deleted_count = 0
if cache_dir.exists():
# Delete prompt cache files
for cache_file in cache_dir.glob("*.prompt"):
try:
cache_file.unlink()
deleted_count += 1
except Exception as e:
logger.warning("Failed to delete cache file %s: %s", cache_file, e)
# Delete combined audio files
for audio_file in cache_dir.glob("combined_*.wav"):
try:
audio_file.unlink()
deleted_count += 1
except Exception as e:
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
return deleted_count
def clear_profile_cache(profile_id: str) -> int:
"""
Clear cache files for a specific profile.
Args:
profile_id: Profile ID
Returns:
Number of cache files deleted
"""
cache_dir = _get_cache_dir()
deleted_count = 0
if cache_dir.exists():
# Delete combined audio files for this profile
pattern = f"combined_{profile_id}_*.wav"
for audio_file in cache_dir.glob(pattern):
try:
audio_file.unlink()
deleted_count += 1
except Exception as e:
logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)
return deleted_count
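
# A minimal usage sketch, assuming ``config.get_cache_dir()`` resolves to a
# writable directory; the audio path, reference text, and stand-in prompt
# dict are illustrative:
#
#     key = get_cache_key("voices/alice.wav", "Hello, my name is Alice.")
#     prompt = get_cached_voice_prompt(key)
#     if prompt is None:
#         prompt = {"embedding": torch.zeros(1, 256)}  # stand-in prompt
#         cache_voice_prompt(key, prompt)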


@@ -0,0 +1,299 @@
"""
Chunked TTS generation utilities.
Splits long text into sentence-boundary chunks, generates audio per-chunk
via any TTSBackend, and concatenates with crossfade. All logic is
engine-agnostic — it wraps the standard ``TTSBackend.generate()`` interface.
Short text (≤ max_chunk_chars) uses the single-shot fast path with zero
overhead.
"""
import logging
import re
from typing import List, Tuple
import numpy as np
logger = logging.getLogger("voicebox.chunked-tts")
# Default chunk size in characters. Can be overridden per-request via
# the ``max_chunk_chars`` field on GenerationRequest.
DEFAULT_MAX_CHUNK_CHARS = 800
# Common abbreviations that should NOT be treated as sentence endings.
# Lowercase for case-insensitive matching.
_ABBREVIATIONS = frozenset(
{
"mr",
"mrs",
"ms",
"dr",
"prof",
"sr",
"jr",
"st",
"ave",
"blvd",
"inc",
"ltd",
"corp",
"dept",
"est",
"approx",
"vs",
"etc",
"e.g",
"i.e",
"a.m",
"p.m",
"u.s",
"u.s.a",
"u.k",
}
)
# Paralinguistic tags used by Chatterbox Turbo. The splitter must never
# cut inside one of these.
_PARA_TAG_RE = re.compile(r"\[[^\]]*\]")
def split_text_into_chunks(text: str, max_chars: int = DEFAULT_MAX_CHUNK_CHARS) -> List[str]:
"""Split *text* at natural boundaries into chunks of at most *max_chars*.
Priority: sentence-end (``.!?`` not preceded by an abbreviation and not
inside brackets) → clause boundary (``;:,—``) → whitespace → hard cut.
Paralinguistic tags like ``[laugh]`` are treated as atomic and will not
be split across chunks.
"""
text = text.strip()
if not text:
return []
if len(text) <= max_chars:
return [text]
chunks: List[str] = []
remaining = text
while remaining:
remaining = remaining.lstrip()
if not remaining:
break
if len(remaining) <= max_chars:
chunks.append(remaining)
break
segment = remaining[:max_chars]
# Try to split at the last real sentence ending
split_pos = _find_last_sentence_end(segment)
if split_pos == -1:
split_pos = _find_last_clause_boundary(segment)
if split_pos == -1:
split_pos = segment.rfind(" ")
if split_pos == -1:
# Absolute fallback: hard cut but avoid splitting inside a tag
split_pos = _safe_hard_cut(segment, max_chars)
chunk = remaining[: split_pos + 1].strip()
if chunk:
chunks.append(chunk)
remaining = remaining[split_pos + 1 :]
return chunks
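
# Worked example of the splitting rules above: with ``max_chars=40`` the
# "Dr." period is skipped as an abbreviation, the cut lands after "late.",
# and the [laugh] tag stays intact in the next chunk:
#
#     split_text_into_chunks("Dr. Smith arrived. It was late. [laugh] Bye!", 40)
#     -> ["Dr. Smith arrived. It was late.", "[laugh] Bye!"]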
def _find_last_sentence_end(text: str) -> int:
"""Return the index of the last sentence-ending punctuation in *text*.
Skips periods that follow common abbreviations (``Dr.``, ``Mr.``, etc.)
and periods inside bracket tags (``[laugh]``). Also handles CJK
sentence-ending punctuation (``。!?``).
"""
best = -1
# ASCII sentence ends
for m in re.finditer(r"[.!?](?:\s|$)", text):
pos = m.start()
char = text[pos]
# Skip periods after abbreviations
if char == ".":
            # Walk backwards to find the preceding word. Dots are included
            # in the walk so dotted abbreviations like "e.g." and "u.s.a."
            # match their registry entries.
            word_start = pos - 1
            while word_start >= 0 and (text[word_start].isalpha() or text[word_start] == "."):
                word_start -= 1
            word = text[word_start + 1 : pos].lower().lstrip(".")
if word in _ABBREVIATIONS:
continue
# Skip decimal numbers (digit immediately before the period)
if word_start >= 0 and text[word_start].isdigit():
continue
# Skip if we're inside a bracket tag
if _inside_bracket_tag(text, pos):
continue
best = pos
# CJK sentence-ending punctuation
for m in re.finditer(r"[\u3002\uff01\uff1f]", text):
if m.start() > best:
best = m.start()
return best
def _find_last_clause_boundary(text: str) -> int:
"""Return the index of the last clause-boundary punctuation."""
best = -1
for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
pos = m.start()
# Skip if inside a bracket tag
if _inside_bracket_tag(text, pos):
continue
best = pos
return best
def _inside_bracket_tag(text: str, pos: int) -> bool:
"""Return True if *pos* falls inside a ``[...]`` tag."""
for m in _PARA_TAG_RE.finditer(text):
if m.start() < pos < m.end():
return True
return False
def _safe_hard_cut(segment: str, max_chars: int) -> int:
"""Find a hard-cut position that doesn't split a ``[tag]``."""
cut = max_chars - 1
# Check if the cut falls inside a bracket tag; if so, move before it
for m in _PARA_TAG_RE.finditer(segment):
if m.start() < cut < m.end():
return m.start() - 1 if m.start() > 0 else cut
return cut
def concatenate_audio_chunks(
chunks: List[np.ndarray],
sample_rate: int,
crossfade_ms: int = 50,
) -> np.ndarray:
"""Concatenate audio arrays with a short crossfade to eliminate clicks.
Each chunk is expected to be a 1-D float32 ndarray at *sample_rate* Hz.
"""
if not chunks:
return np.array([], dtype=np.float32)
if len(chunks) == 1:
return chunks[0]
crossfade_samples = int(sample_rate * crossfade_ms / 1000)
result = np.array(chunks[0], dtype=np.float32, copy=True)
for chunk in chunks[1:]:
if len(chunk) == 0:
continue
overlap = min(crossfade_samples, len(result), len(chunk))
if overlap > 0:
fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
result[-overlap:] = result[-overlap:] * fade_out + chunk[:overlap] * fade_in
result = np.concatenate([result, chunk[overlap:]])
else:
result = np.concatenate([result, chunk])
return result
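
# Worked example: two 1 s chunks at 24 kHz with the default 50 ms crossfade
# overlap by 1200 samples, so the result is 24000 + 24000 - 1200 = 46800
# samples rather than a click-prone 48000-sample hard concat.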
async def generate_chunked(
backend,
text: str,
voice_prompt: dict,
language: str = "en",
seed: int | None = None,
instruct: str | None = None,
max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
crossfade_ms: int = 50,
trim_fn=None,
) -> Tuple[np.ndarray, int]:
"""Generate audio with automatic chunking for long text.
For text shorter than *max_chunk_chars* this is a thin wrapper around
``backend.generate()`` with zero overhead.
For longer text the input is split at natural sentence boundaries,
each chunk is generated independently, optionally trimmed (useful for
Chatterbox engines that hallucinate trailing noise), and the results
are concatenated with a crossfade (or hard cut if *crossfade_ms* is 0).
Parameters
----------
backend : TTSBackend
Any backend implementing the ``generate()`` protocol.
text : str
Input text (may be arbitrarily long).
voice_prompt, language, seed, instruct
Forwarded to ``backend.generate()`` verbatim.
max_chunk_chars : int
Maximum characters per chunk (default 800).
crossfade_ms : int
Crossfade duration in milliseconds between chunks. 0 for a hard
cut with no overlap (default 50).
trim_fn : callable | None
Optional ``(audio, sample_rate) -> audio`` post-processing
function applied to each chunk before concatenation (e.g.
``trim_tts_output`` for Chatterbox engines).
Returns
-------
(audio, sample_rate) : Tuple[np.ndarray, int]
"""
chunks = split_text_into_chunks(text, max_chunk_chars)
if len(chunks) <= 1:
# Short text — single-shot fast path
audio, sample_rate = await backend.generate(
text,
voice_prompt,
language,
seed,
instruct,
)
if trim_fn is not None:
audio = trim_fn(audio, sample_rate)
return audio, sample_rate
# Long text — chunked generation
logger.info(
"Splitting %d chars into %d chunks (max %d chars each)",
len(text),
len(chunks),
max_chunk_chars,
)
audio_chunks: List[np.ndarray] = []
sample_rate: int | None = None
for i, chunk_text in enumerate(chunks):
logger.info(
"Generating chunk %d/%d (%d chars)",
i + 1,
len(chunks),
len(chunk_text),
)
# Vary the seed per chunk to avoid correlated RNG artefacts,
# but keep it deterministic so the same (text, seed) pair
# always produces the same output.
chunk_seed = (seed + i) if seed is not None else None
chunk_audio, chunk_sr = await backend.generate(
chunk_text,
voice_prompt,
language,
chunk_seed,
instruct,
)
if trim_fn is not None:
chunk_audio = trim_fn(chunk_audio, chunk_sr)
audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32))
if sample_rate is None:
sample_rate = chunk_sr
audio = concatenate_audio_chunks(audio_chunks, sample_rate, crossfade_ms=crossfade_ms)
return audio, sample_rate
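
# A self-contained smoke test using a stand-in backend (not part of the
# project) that emits 0.5 s of silence per chunk, so the split/concat path
# can be exercised without loading a real TTS model.
if __name__ == "__main__":
    import asyncio

    class _SilenceBackend:
        async def generate(self, text, voice_prompt, language, seed, instruct):
            # Half a second of silence at 24 kHz per chunk.
            return np.zeros(12000, dtype=np.float32), 24000

    async def _demo():
        audio, sr = await generate_chunked(
            _SilenceBackend(), "One. Two. Three. " * 200, voice_prompt={}
        )
        print(f"{len(audio)} samples at {sr} Hz")

    asyncio.run(_demo())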

backend/utils/dac_shim.py Normal file

@@ -0,0 +1,95 @@
"""
Minimal shim for descript-audio-codec (DAC).
TADA only imports Snake1d from dac.nn.layers and dac.model.dac.
The real DAC package pulls in descript-audiotools which depends on
onnx, tensorboard, protobuf, matplotlib, pystoi, etc. — none of
which are needed for TADA's runtime use of Snake1d.
This shim provides the exact Snake1d implementation (MIT-licensed,
from https://github.com/descriptinc/descript-audio-codec) so we can
avoid the entire audiotools dependency chain.
If the real DAC package is installed, this module is never used —
Python's import system will find the site-packages version first.
Install this shim only when descript-audio-codec is NOT installed.
"""
import sys
import types
import torch
import torch.nn as nn
# ── Snake activation (from dac/nn/layers.py) ────────────────────────
# NOTE: The original DAC code uses @torch.jit.script here for a 1.4x
# speedup. We omit it because TorchScript calls inspect.getsource()
# which fails inside a PyInstaller frozen binary (no .py source files).
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return snake(x, self.alpha)
# ── Register as dac.nn.layers and dac.model.dac ─────────────────────
def install_dac_shim() -> None:
"""Register fake dac package modules in sys.modules.
Only installs the shim if 'dac' is not already importable
(i.e. the real descript-audio-codec is not installed).
"""
try:
import dac # noqa: F401 — real package exists, do nothing
return
except ImportError:
pass
# Create the module tree: dac -> dac.nn -> dac.nn.layers
# -> dac.model -> dac.model.dac
dac_pkg = types.ModuleType("dac")
dac_pkg.__path__ = [] # make it a package
dac_pkg.__package__ = "dac"
dac_nn = types.ModuleType("dac.nn")
dac_nn.__path__ = []
dac_nn.__package__ = "dac.nn"
dac_nn_layers = types.ModuleType("dac.nn.layers")
dac_nn_layers.__package__ = "dac.nn"
dac_nn_layers.Snake1d = Snake1d
dac_nn_layers.snake = snake
dac_model = types.ModuleType("dac.model")
dac_model.__path__ = []
dac_model.__package__ = "dac.model"
dac_model_dac = types.ModuleType("dac.model.dac")
dac_model_dac.__package__ = "dac.model"
dac_model_dac.Snake1d = Snake1d
# Wire up submodules
dac_pkg.nn = dac_nn
dac_pkg.model = dac_model
dac_nn.layers = dac_nn_layers
dac_model.dac = dac_model_dac
# Register in sys.modules
sys.modules["dac"] = dac_pkg
sys.modules["dac.nn"] = dac_nn
sys.modules["dac.nn.layers"] = dac_nn_layers
sys.modules["dac.model"] = dac_model
sys.modules["dac.model.dac"] = dac_model_dac

backend/utils/effects.py Normal file

@@ -0,0 +1,373 @@
"""
Audio post-processing effects engine.
Uses Spotify's pedalboard library to apply professional-grade DSP effects
to generated audio. Effects are described as a JSON-serializable chain
(list of effect dicts) so they can be stored in the database and sent
over the API.
Supported effect types:
- chorus (flanger-style with short delays)
- reverb (room reverb)
- delay (echo / delay line)
- compressor (dynamic range compression)
- gain (volume adjustment in dB)
- highpass (high-pass filter)
- lowpass (low-pass filter)
- pitch_shift (semitone pitch shifting)
"""
from __future__ import annotations
import numpy as np
from typing import Any, Dict, List, Optional
from pedalboard import (
Pedalboard,
Chorus,
Reverb,
Compressor,
Gain,
HighpassFilter,
LowpassFilter,
Delay,
PitchShift,
)
# Each param definition: (default, min, max, description)
EFFECT_REGISTRY: Dict[str, Dict[str, Any]] = {
"chorus": {
"cls": Chorus,
"label": "Chorus / Flanger",
"description": "Modulated delay for flanging or chorus effects. Short centre_delay_ms (<10) gives flanger; longer gives chorus.",
"params": {
"rate_hz": {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01, "description": "LFO speed (Hz)"},
"depth": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Modulation depth"},
"feedback": {"default": 0.0, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
"centre_delay_ms": {
"default": 7.0,
"min": 0.5,
"max": 50.0,
"step": 0.1,
"description": "Centre delay (ms)",
},
"mix": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
},
},
"reverb": {
"cls": Reverb,
"label": "Reverb",
"description": "Room reverb effect.",
"params": {
"room_size": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Room size"},
"damping": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "High frequency damping"},
"wet_level": {"default": 0.33, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet level"},
"dry_level": {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Dry level"},
"width": {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Stereo width"},
},
},
"delay": {
"cls": Delay,
"label": "Delay",
"description": "Echo / delay line.",
"params": {
"delay_seconds": {
"default": 0.3,
"min": 0.01,
"max": 2.0,
"step": 0.01,
"description": "Delay time (seconds)",
},
"feedback": {"default": 0.3, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
"mix": {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
},
},
"compressor": {
"cls": Compressor,
"label": "Compressor",
"description": "Dynamic range compression for consistent loudness.",
"params": {
"threshold_db": {"default": -20.0, "min": -60.0, "max": 0.0, "step": 0.5, "description": "Threshold (dB)"},
"ratio": {"default": 4.0, "min": 1.0, "max": 20.0, "step": 0.1, "description": "Compression ratio"},
"attack_ms": {"default": 10.0, "min": 0.1, "max": 100.0, "step": 0.1, "description": "Attack time (ms)"},
"release_ms": {
"default": 100.0,
"min": 10.0,
"max": 1000.0,
"step": 1.0,
"description": "Release time (ms)",
},
},
},
"gain": {
"cls": Gain,
"label": "Gain",
"description": "Volume adjustment in decibels.",
"params": {
"gain_db": {"default": 0.0, "min": -40.0, "max": 40.0, "step": 0.5, "description": "Gain (dB)"},
},
},
"highpass": {
"cls": HighpassFilter,
"label": "High-Pass Filter",
"description": "Removes frequencies below the cutoff.",
"params": {
"cutoff_frequency_hz": {
"default": 80.0,
"min": 20.0,
"max": 8000.0,
"step": 1.0,
"description": "Cutoff frequency (Hz)",
},
},
},
"lowpass": {
"cls": LowpassFilter,
"label": "Low-Pass Filter",
"description": "Removes frequencies above the cutoff.",
"params": {
"cutoff_frequency_hz": {
"default": 8000.0,
"min": 200.0,
"max": 20000.0,
"step": 1.0,
"description": "Cutoff frequency (Hz)",
},
},
},
"pitch_shift": {
"cls": PitchShift,
"label": "Pitch Shift",
"description": "Shift pitch up or down by semitones.",
"params": {
"semitones": {"default": 0.0, "min": -12.0, "max": 12.0, "step": 0.5, "description": "Semitones to shift"},
},
},
}
BUILTIN_PRESETS: Dict[str, Dict[str, Any]] = {
"robotic": {
"name": "Robotic",
"sort_order": 0,
"description": "Metallic robotic voice (flanger with slow LFO and high feedback)",
"effects_chain": [
{
"type": "chorus",
"enabled": True,
"params": {
"rate_hz": 0.2,
"depth": 1.0,
"feedback": 0.35,
"centre_delay_ms": 7.0,
"mix": 0.5,
},
},
],
},
"radio": {
"name": "Radio",
"sort_order": 1,
"description": "Thin AM-radio voice with band-pass filtering and light compression",
"effects_chain": [
{
"type": "highpass",
"enabled": True,
"params": {"cutoff_frequency_hz": 300.0},
},
{
"type": "lowpass",
"enabled": True,
"params": {"cutoff_frequency_hz": 3500.0},
},
{
"type": "compressor",
"enabled": True,
"params": {
"threshold_db": -15.0,
"ratio": 6.0,
"attack_ms": 5.0,
"release_ms": 50.0,
},
},
{
"type": "gain",
"enabled": True,
"params": {"gain_db": 6.0},
},
],
},
"echo_chamber": {
"name": "Echo Chamber",
"sort_order": 2,
"description": "Spacious reverb with trailing echo",
"effects_chain": [
{
"type": "reverb",
"enabled": True,
"params": {
"room_size": 0.85,
"damping": 0.3,
"wet_level": 0.45,
"dry_level": 0.55,
"width": 1.0,
},
},
{
"type": "delay",
"enabled": True,
"params": {
"delay_seconds": 0.25,
"feedback": 0.3,
"mix": 0.2,
},
},
],
},
"deep_voice": {
"name": "Deep Voice",
"sort_order": 99,
"description": "Lower pitch with added warmth",
"effects_chain": [
{
"type": "pitch_shift",
"enabled": True,
"params": {"semitones": -3.0},
},
{
"type": "lowpass",
"enabled": True,
"params": {"cutoff_frequency_hz": 6000.0},
},
{
"type": "compressor",
"enabled": True,
"params": {
"threshold_db": -18.0,
"ratio": 3.0,
"attack_ms": 10.0,
"release_ms": 150.0,
},
},
],
},
}
def get_available_effects() -> List[Dict[str, Any]]:
"""Return the list of available effect types with their parameter definitions.
Used by the frontend to build the effects chain editor UI.
"""
result = []
for effect_type, info in EFFECT_REGISTRY.items():
result.append(
{
"type": effect_type,
"label": info["label"],
"description": info["description"],
"params": {name: {k: v for k, v in pdef.items()} for name, pdef in info["params"].items()},
}
)
return result
def get_builtin_presets() -> Dict[str, Dict[str, Any]]:
"""Return all built-in effect presets."""
return BUILTIN_PRESETS
def validate_effects_chain(effects_chain: List[Dict[str, Any]]) -> Optional[str]:
"""Validate an effects chain configuration.
Returns None if valid, or an error message string.
"""
if not isinstance(effects_chain, list):
return "effects_chain must be a list"
for i, effect in enumerate(effects_chain):
if not isinstance(effect, dict):
return f"Effect at index {i} must be a dict"
effect_type = effect.get("type")
if effect_type not in EFFECT_REGISTRY:
return f"Unknown effect type '{effect_type}' at index {i}. Available: {list(EFFECT_REGISTRY.keys())}"
params = effect.get("params", {})
if not isinstance(params, dict):
return f"Effect '{effect_type}' at index {i}: params must be a dict"
registry = EFFECT_REGISTRY[effect_type]
for param_name, value in params.items():
if param_name not in registry["params"]:
return f"Effect '{effect_type}' at index {i}: unknown param '{param_name}'"
pdef = registry["params"][param_name]
if not isinstance(value, (int, float)):
return f"Effect '{effect_type}' at index {i}: param '{param_name}' must be a number"
if value < pdef["min"] or value > pdef["max"]:
return (
f"Effect '{effect_type}' at index {i}: param '{param_name}' "
f"must be between {pdef['min']} and {pdef['max']} (got {value})"
)
return None
def build_pedalboard(effects_chain: List[Dict[str, Any]]) -> Pedalboard:
"""Build a Pedalboard instance from an effects chain config.
Skips effects where ``enabled`` is ``False``.
"""
plugins = []
for effect in effects_chain:
if not effect.get("enabled", True):
continue
effect_type = effect["type"]
registry = EFFECT_REGISTRY[effect_type]
cls = registry["cls"]
# Merge defaults with provided params
params = {}
for pname, pdef in registry["params"].items():
params[pname] = effect.get("params", {}).get(pname, pdef["default"])
plugins.append(cls(**params))
return Pedalboard(plugins)
def apply_effects(
audio: np.ndarray,
sample_rate: int,
effects_chain: List[Dict[str, Any]],
) -> np.ndarray:
"""Apply an effects chain to audio data.
Args:
audio: Input audio array (1-D mono float32).
sample_rate: Sample rate in Hz.
effects_chain: List of effect configuration dicts.
Returns:
Processed audio array.
"""
if not effects_chain:
return audio
board = build_pedalboard(effects_chain)
# pedalboard expects shape (channels, samples)
if audio.ndim == 1:
audio_2d = audio[np.newaxis, :]
else:
audio_2d = audio
processed = board(audio_2d.astype(np.float32), sample_rate)
# Return same dimensionality as input
if audio.ndim == 1:
return processed[0]
return processed
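
# A self-contained demo: validate the built-in "radio" preset, then run one
# second of quiet noise through it (requires pedalboard, as imported above).
if __name__ == "__main__":
    chain = BUILTIN_PRESETS["radio"]["effects_chain"]
    assert validate_effects_chain(chain) is None
    noise = np.random.default_rng(0).standard_normal(24000).astype(np.float32) * 0.1
    out = apply_effects(noise, 24000, chain)
    print(out.shape, float(np.abs(out).max()))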


@@ -0,0 +1,270 @@
"""Monkey-patch huggingface_hub to force offline mode with cached models.
Prevents mlx_audio / transformers from making network requests when models
are already downloaded. Must be imported BEFORE mlx_audio.
"""
import logging
import os
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Union
logger = logging.getLogger(__name__)
# huggingface_hub reads ``HF_HUB_OFFLINE`` once at import time into
# ``huggingface_hub.constants.HF_HUB_OFFLINE``; transformers mirrors that into
# ``transformers.utils.hub._is_offline_mode`` at *its* import time. Toggling
# ``os.environ`` after either module is imported does not flip those cached
# bools, and the hot paths (``_http._default_backend_factory``,
# ``transformers.utils.hub.is_offline_mode``) read the bools — not the env.
# We mutate the cached constants directly, guarded by a refcount so
# concurrent inference threads share a single offline window safely.
_offline_lock = threading.RLock()
_offline_refcount = 0
_saved_env: Optional[str] = None
_saved_hf_const: Optional[bool] = None
_saved_transformers_const: Optional[bool] = None
@contextmanager
def force_offline_if_cached(is_cached: bool, model_label: str = ""):
"""Force offline mode for the duration of a cached-model operation.
Flips ``HF_HUB_OFFLINE`` in the process env **and** in the cached bools
inside ``huggingface_hub.constants`` / ``transformers.utils.hub`` so HTTP
adapters and offline-mode checks actually see the change. Uses a refcount
so multiple concurrent inference threads share a single offline window
and the last one to exit restores state.
If *is_cached* is ``False`` the block runs normally (network allowed).
Args:
is_cached: Whether the model weights are already on disk.
model_label: Human-readable name used in log messages.
"""
if not is_cached:
yield
return
global _offline_refcount, _saved_env, _saved_hf_const, _saved_transformers_const
with _offline_lock:
if _offline_refcount == 0:
# Snapshot prior state, apply new state, roll back on *any*
# failure. Catching only ImportError here would let a partially
# broken install (RuntimeError, AttributeError from a half-init
# module, etc.) leave the cached HF constants mutated without
# bumping the refcount — a persistent offline leak that outlives
# the process and is miserable to debug.
prev_env = os.environ.get("HF_HUB_OFFLINE")
prev_hf: Optional[bool] = None
prev_tf: Optional[bool] = None
try:
try:
import huggingface_hub.constants as hf_const
prev_hf = hf_const.HF_HUB_OFFLINE
hf_const.HF_HUB_OFFLINE = True
except ImportError:
prev_hf = None
try:
import transformers.utils.hub as tf_hub
prev_tf = tf_hub._is_offline_mode
tf_hub._is_offline_mode = True
except ImportError:
prev_tf = None
os.environ["HF_HUB_OFFLINE"] = "1"
except BaseException:
# Roll back whatever we already changed, then re-raise so
# the caller sees the real failure.
if prev_hf is not None:
try:
import huggingface_hub.constants as hf_const
hf_const.HF_HUB_OFFLINE = prev_hf
except ImportError:
pass
if prev_tf is not None:
try:
import transformers.utils.hub as tf_hub
tf_hub._is_offline_mode = prev_tf
except ImportError:
pass
if prev_env is not None:
os.environ["HF_HUB_OFFLINE"] = prev_env
else:
os.environ.pop("HF_HUB_OFFLINE", None)
raise
_saved_env = prev_env
_saved_hf_const = prev_hf
_saved_transformers_const = prev_tf
logger.info(
"[offline-guard] %s is cached — forcing offline mode",
model_label or "model",
)
_offline_refcount += 1
try:
yield
finally:
with _offline_lock:
_offline_refcount -= 1
if _offline_refcount == 0:
if _saved_env is not None:
os.environ["HF_HUB_OFFLINE"] = _saved_env
else:
os.environ.pop("HF_HUB_OFFLINE", None)
if _saved_hf_const is not None:
try:
import huggingface_hub.constants as hf_const
hf_const.HF_HUB_OFFLINE = _saved_hf_const
except ImportError:
pass
if _saved_transformers_const is not None:
try:
import transformers.utils.hub as tf_hub
tf_hub._is_offline_mode = _saved_transformers_const
except ImportError:
pass
_saved_env = None
_saved_hf_const = None
_saved_transformers_const = None
_mistral_regex_patched = False
def patch_transformers_mistral_regex():
"""Make transformers' tokenizer load robust to HuggingFace metadata failures.
transformers 4.57.x added ``PreTrainedTokenizerBase._patch_mistral_regex``
which unconditionally calls ``huggingface_hub.model_info(repo_id)`` during
every non-local tokenizer load to check whether the model is a Mistral
variant. That call raises on ``HF_HUB_OFFLINE=1`` and on plain network
failures, killing unrelated loads (Qwen TTS, TADA, etc.).
Voicebox never loads Mistral models, so the rewrite the function would
apply is a no-op for us anyway. Wrap the method so any exception from the
metadata lookup returns the tokenizer unchanged — matching the success-path
behavior for non-Mistral repos (transformers 4.57.3,
``tokenization_utils_base.py:2503``).
"""
global _mistral_regex_patched
if _mistral_regex_patched:
return
try:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
except ImportError:
logger.debug("transformers not available, skipping mistral-regex patch")
return
original = getattr(PreTrainedTokenizerBase, "_patch_mistral_regex", None)
if original is None:
logger.debug(
"transformers has no _patch_mistral_regex attribute, skipping patch",
)
return
def safe_patch_mistral_regex(cls, tokenizer, pretrained_model_name_or_path, *args, **kwargs):
try:
return original(tokenizer, pretrained_model_name_or_path, *args, **kwargs)
except Exception as exc:
logger.debug(
"[mistral-regex-patch] suppressed %s for %r, returning tokenizer as-is",
type(exc).__name__,
pretrained_model_name_or_path,
)
return tokenizer
PreTrainedTokenizerBase._patch_mistral_regex = classmethod(safe_patch_mistral_regex)
_mistral_regex_patched = True
logger.debug("installed _patch_mistral_regex wrapper")
def patch_huggingface_hub_offline():
"""Monkey-patch huggingface_hub to force offline mode."""
try:
import huggingface_hub # noqa: F401 -- need the package loaded
from huggingface_hub import constants as hf_constants
from huggingface_hub.file_download import _try_to_load_from_cache
original_try_load = _try_to_load_from_cache
def _patched_try_to_load_from_cache(
repo_id: str,
filename: str,
cache_dir: Union[str, Path, None] = None,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
):
result = original_try_load(
repo_id=repo_id,
filename=filename,
cache_dir=cache_dir,
revision=revision,
repo_type=repo_type,
)
if result is None:
cache_path = Path(hf_constants.HF_HUB_CACHE) / f"models--{repo_id.replace('/', '--')}"
logger.debug("file not cached: %s/%s (expected at %s)", repo_id, filename, cache_path)
else:
logger.debug("cache hit: %s/%s", repo_id, filename)
return result
import huggingface_hub.file_download as fd
fd._try_to_load_from_cache = _patched_try_to_load_from_cache
logger.debug("huggingface_hub patched for offline mode")
except ImportError:
logger.debug("huggingface_hub not available, skipping offline patch")
except Exception:
logger.exception("failed to patch huggingface_hub for offline mode")
def ensure_original_qwen_config_cached():
"""Symlink the original Qwen repo cache to the MLX community version.
mlx_audio may try to fetch config from the original Qwen repo. If only
the MLX community variant is cached, create a symlink so the cache lookup
succeeds without a network request.
"""
try:
from huggingface_hub import constants as hf_constants
except ImportError:
return
original_repo = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
mlx_repo = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"
cache_dir = Path(hf_constants.HF_HUB_CACHE)
original_path = cache_dir / f"models--{original_repo.replace('/', '--')}"
mlx_path = cache_dir / f"models--{mlx_repo.replace('/', '--')}"
if not original_path.exists() and mlx_path.exists():
try:
original_path.parent.mkdir(parents=True, exist_ok=True)
original_path.symlink_to(mlx_path, target_is_directory=True)
logger.info("created cache symlink: %s -> %s", original_repo, mlx_repo)
except Exception:
logger.warning("could not create cache symlink for %s", original_repo, exc_info=True)
if os.environ.get("VOICEBOX_OFFLINE_PATCH", "1") != "0":
patch_huggingface_hub_offline()
patch_transformers_mistral_regex()
ensure_original_qwen_config_cached()
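
# Usage sketch for the context manager (``is_model_cached`` and
# ``load_tts_model`` are hypothetical call sites):
#
#     with force_offline_if_cached(is_model_cached("qwen-tts"), "qwen-tts"):
#         model = load_tts_model()  # resolves from the local HF cache only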


@@ -0,0 +1,383 @@
"""
HuggingFace Hub download progress tracking.
"""
from typing import Optional, Callable
from contextlib import contextmanager
import logging
import threading
import sys
logger = logging.getLogger(__name__)
class HFProgressTracker:
"""Tracks HuggingFace Hub download progress by intercepting tqdm."""
def __init__(self, progress_callback: Optional[Callable] = None, filter_non_downloads: bool = False):
self.progress_callback = progress_callback
self.filter_non_downloads = filter_non_downloads # Only filter if True
self._original_tqdm_class = None
self._lock = threading.Lock()
self._total_downloaded = 0
self._total_size = 0
self._file_sizes = {} # Track sizes of individual files
self._file_downloaded = {} # Track downloaded bytes per file
self._current_filename = ""
self._active_tqdms = {} # Track active tqdm instances
self._hf_tqdm_original_update = None # For monkey-patching hf's tqdm
def _create_tracked_tqdm_class(self):
"""Create a tqdm subclass that tracks progress."""
tracker = self
original_tqdm = self._original_tqdm_class
class TrackedTqdm(original_tqdm):
"""A tqdm subclass that reports progress to our tracker."""
def __init__(self, *args, **kwargs):
# Extract filename from desc before passing to parent
desc = kwargs.get("desc", "")
if not desc and args:
first_arg = args[0]
if isinstance(first_arg, str):
desc = first_arg
filename = ""
if desc:
# Try to extract filename from description
# HuggingFace Hub uses format like "model.safetensors: 0%|..."
if ":" in desc:
filename = desc.split(":")[0].strip()
else:
filename = desc.strip()
# Filter out non-standard kwargs that huggingface_hub might pass
# These are custom kwargs that tqdm doesn't understand
filtered_kwargs = {}
# Known tqdm kwargs - pass these through
tqdm_kwargs = {
"iterable",
"desc",
"total",
"leave",
"file",
"ncols",
"mininterval",
"maxinterval",
"miniters",
"ascii",
"disable",
"unit",
"unit_scale",
"dynamic_ncols",
"smoothing",
"bar_format",
"initial",
"position",
"postfix",
"unit_divisor",
"write_bytes",
"lock_args",
"nrows",
"colour",
"color",
"delay",
"gui",
"disable_default",
"pos",
}
for key, value in kwargs.items():
if key in tqdm_kwargs:
filtered_kwargs[key] = value
# Force-enable the progress bar — we're tracking progress ourselves,
# we don't need tqdm to render to a terminal, but we DO need
# self.n to be updated when update() is called.
filtered_kwargs["disable"] = False
# Try to initialize with filtered kwargs, fall back to all kwargs if that fails
try:
super().__init__(*args, **filtered_kwargs)
except TypeError:
# If filtering failed, try with all kwargs (maybe tqdm version accepts them)
kwargs["disable"] = False
super().__init__(*args, **kwargs)
self._tracker_filename = filename or "unknown"
with tracker._lock:
if filename:
tracker._current_filename = filename
tracker._active_tqdms[id(self)] = {
"filename": self._tracker_filename,
}
def update(self, n=1):
result = super().update(n)
# Report progress
with tracker._lock:
if id(self) in tracker._active_tqdms:
filename = tracker._active_tqdms[id(self)]["filename"]
current = getattr(self, "n", 0)
total = getattr(self, "total", 0)
if total and total > 0:
# Always filter out non-byte progress bars (e.g., "Fetching 12 files")
# These cause crazy percentages because they're counting files, not bytes
if self._is_non_byte_progress(filename):
return result
# When model is cached, also filter out generation-related progress
if tracker.filter_non_downloads:
if not self._is_download_progress(filename):
return result
# Update per-file tracking
tracker._file_sizes[filename] = total
tracker._file_downloaded[filename] = current
# Calculate totals across all files
tracker._total_size = sum(tracker._file_sizes.values())
tracker._total_downloaded = sum(tracker._file_downloaded.values())
# Only report progress once we have a meaningful total (at least 1MB)
# This avoids the "100% at 0MB" issue when small config
# files are counted before the real model files
MIN_TOTAL_BYTES = 1_000_000 # 1MB
if tracker._total_size < MIN_TOTAL_BYTES:
return result
# Call progress callback
if tracker.progress_callback:
tracker.progress_callback(tracker._total_downloaded, tracker._total_size, filename)
return result
def _is_non_byte_progress(self, filename: str) -> bool:
"""Check if this progress bar should be SKIPPED (returns True to skip).
We want to track byte-based progress bars. This method identifies
progress bars that count files/items instead of bytes, which would
cause crazy percentages if mixed with our byte counting.
Returns:
True = SKIP this bar (it's not byte-based)
False = TRACK this bar (it counts bytes)
"""
if not filename:
return False
filename_lower = filename.lower()
# Skip "Fetching X files" - it counts files (total=12), not bytes
# Don't skip "Downloading (incomplete total...)" - that IS byte-based
skip_patterns = [
"fetching", # "Fetching 12 files" has total=12 files, not bytes
]
return any(pattern in filename_lower for pattern in skip_patterns)
def _is_download_progress(self, filename: str) -> bool:
"""Check if this is a real file download progress bar vs internal processing."""
if not filename or filename == "unknown":
return False
# Real downloads have file extensions
download_extensions = [
".safetensors",
".bin",
".pt",
".pth", # Model weights
".json",
".txt",
".py", # Config files
".msgpack",
".h5", # Other formats
]
filename_lower = filename.lower()
has_extension = any(filename_lower.endswith(ext) for ext in download_extensions)
# Skip generation-related progress indicators
skip_patterns = ["segment", "processing", "generating", "loading"]
has_skip_pattern = any(pattern in filename_lower for pattern in skip_patterns)
return has_extension and not has_skip_pattern
def close(self):
with tracker._lock:
if id(self) in tracker._active_tqdms:
del tracker._active_tqdms[id(self)]
return super().close()
return TrackedTqdm
@contextmanager
def patch_download(self):
"""Context manager to patch tqdm for progress tracking."""
try:
import tqdm as tqdm_module
# Store original tqdm class
self._original_tqdm_class = tqdm_module.tqdm
# Reset totals
with self._lock:
self._total_downloaded = 0
self._total_size = 0
self._file_sizes = {}
self._file_downloaded = {}
self._current_filename = ""
self._active_tqdms = {}
# Create our tracked tqdm class
tracked_tqdm = self._create_tracked_tqdm_class()
# Patch tqdm.tqdm
tqdm_module.tqdm = tracked_tqdm
# Also patch tqdm.auto.tqdm if it exists (used by huggingface_hub)
self._original_tqdm_auto = None
if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
self._original_tqdm_auto = tqdm_module.auto.tqdm
tqdm_module.auto.tqdm = tracked_tqdm
# Patch in sys.modules to catch already-imported references
# huggingface_hub uses: from tqdm.auto import tqdm as base_tqdm
# So we need to patch both 'tqdm' and 'base_tqdm' attributes
self._patched_modules = {}
tqdm_attr_names = ["tqdm", "base_tqdm", "old_tqdm"] # Various names used
patched_count = 0
for module_name in list(sys.modules.keys()):
if "huggingface" in module_name or module_name.startswith("tqdm"):
try:
module = sys.modules[module_name]
for attr_name in tqdm_attr_names:
if hasattr(module, attr_name):
attr = getattr(module, attr_name)
# Only patch if it's a tqdm class (not already patched)
is_tqdm_class = (
attr is self._original_tqdm_class
or (self._original_tqdm_auto and attr is self._original_tqdm_auto)
or (
hasattr(attr, "__name__")
and attr.__name__ == "tqdm"
and hasattr(attr, "update")
) # tqdm classes have update method
)
if is_tqdm_class:
key = f"{module_name}.{attr_name}"
self._patched_modules[key] = (module, attr_name, attr)
setattr(module, attr_name, tracked_tqdm)
patched_count += 1
except (AttributeError, TypeError):
pass
# ALSO monkey-patch the update method on huggingface_hub's tqdm class
# This is needed because the class was already defined at import time
self._hf_tqdm_original_update = None
try:
from huggingface_hub.utils import tqdm as hf_tqdm_module
if hasattr(hf_tqdm_module, "tqdm"):
hf_tqdm_class = hf_tqdm_module.tqdm
self._hf_tqdm_original_update = hf_tqdm_class.update
# Create a wrapper that calls our tracking
tracker = self # Reference to HFProgressTracker instance
def patched_update(tqdm_self, n=1):
result = tracker._hf_tqdm_original_update(tqdm_self, n)
# Track this progress
with tracker._lock:
desc = getattr(tqdm_self, "desc", "") or ""
current = getattr(tqdm_self, "n", 0)
total = getattr(tqdm_self, "total", 0) or 0
# Skip non-byte progress bars
if "fetching" in desc.lower():
return result
# Skip until we have a meaningful total (at least 1MB)
# This avoids the "100% at 0MB" issue when small config
# files are counted before the real model files
MIN_TOTAL_BYTES = 1_000_000 # 1MB
if total >= MIN_TOTAL_BYTES:
tracker._total_downloaded = current
tracker._total_size = total
if tracker.progress_callback:
tracker.progress_callback(current, total, desc)
return result
hf_tqdm_class.update = patched_update
patched_count += 1
logger.debug("Monkey-patched huggingface_hub.utils.tqdm.tqdm.update")
except (ImportError, AttributeError) as e:
logger.warning("Could not monkey-patch hf_tqdm: %s", e)
logger.debug("Patched %d tqdm references", patched_count)
yield
except ImportError:
# If tqdm not available, just yield without patching
yield
finally:
# Restore original tqdm
if self._original_tqdm_class:
try:
import tqdm as tqdm_module
tqdm_module.tqdm = self._original_tqdm_class
if self._original_tqdm_auto:
tqdm_module.auto.tqdm = self._original_tqdm_auto
# Restore patched modules
for key, (module, attr_name, original) in self._patched_modules.items():
try:
if module and original:
setattr(module, attr_name, original)
except (AttributeError, TypeError):
pass
self._patched_modules = {}
# Restore hf_tqdm's original update method
if self._hf_tqdm_original_update:
try:
from huggingface_hub.utils import tqdm as hf_tqdm_module
if hasattr(hf_tqdm_module, "tqdm"):
hf_tqdm_module.tqdm.update = self._hf_tqdm_original_update
except (ImportError, AttributeError):
pass
self._hf_tqdm_original_update = None
except (ImportError, AttributeError):
pass
def create_hf_progress_callback(model_name: str, progress_manager):
"""Create a progress callback for HuggingFace downloads."""
def callback(downloaded: int, total: int, filename: str = ""):
"""Progress callback.
Note: We send updates even when total=0 (unknown) to provide feedback
during the "incomplete total" phase of huggingface_hub downloads.
The frontend handles total=0 gracefully.
"""
progress_manager.update_progress(
model_name=model_name,
current=downloaded,
total=total,
filename=filename or "",
status="downloading",
)
return callback
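
# Usage sketch: wrap a real huggingface_hub download so tqdm updates are
# routed to a callback. ``snapshot_download`` is huggingface_hub's API; the
# repo id is illustrative.
#
#     from huggingface_hub import snapshot_download
#
#     def on_progress(done: int, total: int, name: str = "") -> None:
#         print(f"{name}: {done}/{total} bytes")
#
#     tracker = HFProgressTracker(progress_callback=on_progress)
#     with tracker.patch_download():
#         snapshot_download("mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16")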

backend/utils/images.py Normal file

@@ -0,0 +1,114 @@
"""Image processing utilities for avatar uploads."""
from pathlib import Path
from typing import Optional, Tuple
from PIL import Image
# JPEG can be reported as 'JPEG' or 'MPO' (for multi-picture format from some cameras)
ALLOWED_FORMATS = {'PNG', 'JPEG', 'WEBP', 'MPO', 'JPG'}
MAX_SIZE = 512
MAX_FILE_SIZE = 5 * 1024 * 1024 # 5MB
def validate_image(file_path: str) -> Tuple[bool, Optional[str]]:
"""
Validate image format and file size.
Args:
file_path: Path to image file
Returns:
Tuple of (is_valid, error_message)
"""
path = Path(file_path)
# Check file size
if path.stat().st_size > MAX_FILE_SIZE:
return False, f"File size exceeds maximum of {MAX_FILE_SIZE // (1024 * 1024)}MB"
try:
with Image.open(file_path) as img:
# Verify the image can be loaded
img.load()
            # Check format against ALLOWED_FORMATS ('MPO'/'JPG' are JPEG variants)
            img_format = img.format
            if img_format not in ALLOWED_FORMATS:
                return False, f"Invalid format '{img_format}'. Allowed formats: PNG, JPEG, WEBP"
return True, None
except Exception as e:
return False, f"Invalid image file: {str(e)}"
def process_avatar(input_path: str, output_path: str, max_size: int = MAX_SIZE) -> None:
"""
Process avatar image: resize and optimize.
Resizes image to fit within max_size x max_size while maintaining aspect ratio.
Args:
input_path: Path to input image
output_path: Path to save processed image
max_size: Maximum width or height in pixels
"""
with Image.open(input_path) as img:
# Handle EXIF orientation for JPEG images
try:
from PIL import ExifTags
for orientation in ExifTags.TAGS.keys():
if ExifTags.TAGS[orientation] == 'Orientation':
break
exif = img._getexif()
if exif is not None:
orientation_value = exif.get(orientation)
if orientation_value == 3:
img = img.rotate(180, expand=True)
elif orientation_value == 6:
img = img.rotate(270, expand=True)
elif orientation_value == 8:
img = img.rotate(90, expand=True)
except (AttributeError, KeyError, IndexError, TypeError):
# No EXIF data or orientation tag
pass
# Convert to RGB if necessary (handles RGBA, P, CMYK, etc.)
if img.mode not in ('RGB', 'L'):
if img.mode == 'RGBA':
# Create white background for RGBA images
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[3]) # Use alpha channel as mask
img = background
elif img.mode == 'CMYK':
# Convert CMYK to RGB
img = img.convert('RGB')
elif img.mode == 'P':
# Convert palette mode to RGB
img = img.convert('RGB')
else:
img = img.convert('RGB')
# Calculate new size maintaining aspect ratio
img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
# Determine output format from extension
output_ext = Path(output_path).suffix.lower()
format_map = {
'.png': 'PNG',
'.jpeg': 'JPEG',
'.jpg': 'JPEG',
'.webp': 'WEBP'
}
output_format = format_map.get(output_ext, 'PNG')
# Save with optimization
save_kwargs = {'optimize': True}
if output_format == 'JPEG':
save_kwargs['quality'] = 90
img.save(output_path, format=output_format, **save_kwargs)
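
# A minimal sketch with illustrative paths: validate an upload, then emit
# a <=512 px RGB avatar next to it.
if __name__ == "__main__":
    ok, err = validate_image("upload.jpg")
    if not ok:
        raise SystemExit(err)
    process_avatar("upload.jpg", "avatars/user1.webp")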


@@ -0,0 +1,35 @@
"""
Platform detection for backend selection.
"""
import platform
from typing import Literal
def is_apple_silicon() -> bool:
"""
Check if running on Apple Silicon (arm64 macOS).
Returns:
True if on Apple Silicon, False otherwise
"""
return platform.system() == "Darwin" and platform.machine() == "arm64"
def get_backend_type() -> Literal["mlx", "pytorch"]:
"""
Detect the best backend for the current platform.
Returns:
"mlx" on Apple Silicon (if MLX is available and functional), "pytorch" otherwise
"""
if is_apple_silicon():
try:
import mlx.core # noqa: F401 — triggers native lib loading
return "mlx"
except (ImportError, OSError, RuntimeError):
# MLX not installed, or native libraries failed to load inside a
# PyInstaller bundle (OSError on missing .dylib / .metallib).
# Fall through to PyTorch.
return "pytorch"
return "pytorch"

backend/utils/progress.py Normal file

@@ -0,0 +1,315 @@
"""
Progress tracking for model downloads using Server-Sent Events.
"""
from typing import Optional, Callable, Dict, List
from fastapi.responses import StreamingResponse
import asyncio
import json
import threading
from datetime import datetime
class ProgressManager:
"""Manages download progress for multiple models.
Thread-safe: can be called from background threads (e.g., via asyncio.to_thread).
"""
# Throttle settings to prevent overwhelming SSE clients
THROTTLE_INTERVAL_SECONDS = 0.5 # Minimum time between updates
THROTTLE_PROGRESS_DELTA = 1.0 # Minimum progress change (%) to force update
def __init__(self):
self._progress: Dict[str, Dict] = {}
self._listeners: Dict[str, list] = {}
self._lock = threading.Lock() # Thread-safe lock for progress dict
self._main_loop: Optional[asyncio.AbstractEventLoop] = None
self._last_notify_time: Dict[str, float] = {} # Last notification time per model
self._last_notify_progress: Dict[str, float] = {} # Last notified progress per model
def _set_main_loop(self, loop: asyncio.AbstractEventLoop):
"""Set the main event loop for thread-safe operations."""
self._main_loop = loop
def _notify_listeners_threadsafe(self, model_name: str, progress_data: Dict):
"""Notify listeners in a thread-safe manner."""
import logging
logger = logging.getLogger(__name__)
if model_name not in self._listeners:
return
for queue in self._listeners[model_name]:
try:
# Check if we're in the main event loop thread
try:
running_loop = asyncio.get_running_loop()
# We're in an async context, can use put_nowait directly
queue.put_nowait(progress_data.copy())
except RuntimeError:
# Not in async context (running in background thread)
# Use call_soon_threadsafe to safely put on queue
if self._main_loop and self._main_loop.is_running():
self._main_loop.call_soon_threadsafe(
lambda q=queue, d=progress_data.copy(): q.put_nowait(d) if not q.full() else None
)
else:
logger.debug(f"No main loop available for {model_name}, skipping notification")
except asyncio.QueueFull:
logger.warning(f"Queue full for {model_name}, dropping update")
except Exception as e:
logger.warning(f"Error notifying listener for {model_name}: {e}")
def update_progress(
self,
model_name: str,
current: int,
total: int,
filename: Optional[str] = None,
status: str = "downloading",
):
"""
Update progress for a model download.
Thread-safe: can be called from background threads.
Progress updates are throttled to prevent overwhelming SSE clients.
Updates are sent at most every THROTTLE_INTERVAL_SECONDS, or when
progress changes by at least THROTTLE_PROGRESS_DELTA percent.
Args:
model_name: Name of the model (e.g., "qwen-tts-1.7B", "whisper-base")
current: Current bytes downloaded
total: Total bytes to download
filename: Current file being downloaded
status: Status string (downloading, extracting, complete, error)
"""
import logging
import time
logger = logging.getLogger(__name__)
        # Calculate progress percentage, clamped to the 0-100 range.
        # This guards against out-of-range values from edge cases like:
        # - current > total temporarily during aggregation
        # - mixing file-count progress with byte-count progress
        if total > 0:
            progress_pct = min(100.0, max(0.0, current / total * 100))
        else:
            progress_pct = 0.0
progress_data = {
"model_name": model_name,
"current": current,
"total": total,
"progress": progress_pct,
"filename": filename,
"status": status,
"timestamp": datetime.now().isoformat(),
}
# Thread-safe update of progress dict (always update internal state)
with self._lock:
self._progress[model_name] = progress_data
# Check if we should notify listeners (throttling)
current_time = time.time()
last_time = self._last_notify_time.get(model_name, 0)
last_progress = self._last_notify_progress.get(model_name, -100)
time_delta = current_time - last_time
progress_delta = abs(progress_pct - last_progress)
# Always notify for complete/error status, or if throttle conditions are met
should_notify = (
status in ("complete", "error") or
time_delta >= self.THROTTLE_INTERVAL_SECONDS or
progress_delta >= self.THROTTLE_PROGRESS_DELTA
)
if not should_notify:
return # Skip this update (throttled)
# Update throttle tracking
self._last_notify_time[model_name] = current_time
self._last_notify_progress[model_name] = progress_pct
# Notify all listeners (thread-safe)
listener_count = len(self._listeners.get(model_name, []))
if listener_count > 0:
logger.debug(f"Notifying {listener_count} listeners for {model_name}: {progress_pct:.1f}% ({filename})")
self._notify_listeners_threadsafe(model_name, progress_data)
else:
logger.debug(f"No listeners for {model_name}, progress update stored: {progress_pct:.1f}%")
def get_progress(self, model_name: str) -> Optional[Dict]:
"""Get current progress for a model. Thread-safe."""
with self._lock:
progress = self._progress.get(model_name)
return progress.copy() if progress else None
def get_all_active(self) -> List[Dict]:
"""Get all active downloads (status is 'downloading' or 'extracting'). Thread-safe."""
active = []
with self._lock:
            for progress in self._progress.values():
status = progress.get("status", "")
if status in ("downloading", "extracting"):
active.append(progress.copy())
return active
    def create_progress_callback(self, model_name: str, filename: Optional[str] = None) -> Callable:
"""
Create a progress callback function for HuggingFace downloads.
Args:
model_name: Name of the model
filename: Optional filename filter
Returns:
Callback function
"""
def callback(progress: Dict):
"""HuggingFace Hub progress callback."""
if "total" in progress and "current" in progress:
current = progress.get("current", 0)
total = progress.get("total", 0)
file_name = progress.get("filename", filename)
self.update_progress(
model_name=model_name,
current=current,
total=total,
filename=file_name,
status="downloading",
)
return callback
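    # Usage sketch (hedged: the exact downloader wiring lives elsewhere in the
    # app). Any chunked download loop can feed the returned callback with
    # {"current", "total", "filename"} dicts, e.g.:
    #
    #   cb = get_progress_manager().create_progress_callback("whisper-base")
    #   downloaded = 0
    #   for chunk in response.iter_content(chunk_size=1 << 20):
    #       out_file.write(chunk)
    #       downloaded += len(chunk)
    #       cb({"current": downloaded, "total": total_bytes, "filename": name})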
async def subscribe(self, model_name: str):
"""
Subscribe to progress updates for a model.
Yields progress updates as Server-Sent Events.
"""
import logging
logger = logging.getLogger(__name__)
# Store the main event loop for thread-safe operations
try:
self._main_loop = asyncio.get_running_loop()
except RuntimeError:
pass
        queue = asyncio.Queue(maxsize=10)
        # Register under the lock: background threads snapshot this dict in
        # _notify_listeners_threadsafe
        with self._lock:
            if model_name not in self._listeners:
                self._listeners[model_name] = []
            self._listeners[model_name].append(queue)
            listener_count = len(self._listeners[model_name])
        logger.info(f"SSE client subscribed to {model_name}, total listeners: {listener_count}")
try:
# Send initial progress if available and still in progress (thread-safe read)
with self._lock:
initial_progress = self._progress.get(model_name)
if initial_progress:
initial_progress = initial_progress.copy()
if initial_progress:
status = initial_progress.get('status')
# Only send initial progress if download is actually in progress
# Don't send old 'complete' or 'error' status from previous downloads
if status in ('downloading', 'extracting'):
logger.info(f"Sending initial progress for {model_name}: {status}")
yield f"data: {json.dumps(initial_progress)}\n\n"
else:
logger.info(f"Skipping initial progress for {model_name} (status: {status})")
else:
logger.info(f"No initial progress available for {model_name}")
# Stream updates
while True:
try:
# Wait for update with timeout
progress = await asyncio.wait_for(queue.get(), timeout=1.0)
logger.debug(f"Sending progress update for {model_name}: {progress.get('status')} - {progress.get('progress', 0):.1f}%")
yield f"data: {json.dumps(progress)}\n\n"
# Stop if complete or error
if progress.get("status") in ("complete", "error"):
logger.info(f"Download {progress.get('status')} for {model_name}, closing SSE connection")
break
except asyncio.TimeoutError:
# Send heartbeat
yield ": heartbeat\n\n"
continue
        except asyncio.CancelledError:
            logger.debug(f"SSE client disconnected from {model_name}")
            raise  # propagate cancellation; the finally block still runs
        except (BrokenPipeError, ConnectionResetError):
            logger.debug(f"SSE connection closed for {model_name}")
        finally:
            # Remove from listeners under the lock; guard against the queue
            # having already been removed.
            with self._lock:
                listeners = self._listeners.get(model_name)
                if listeners and queue in listeners:
                    listeners.remove(queue)
                    if not listeners:
                        del self._listeners[model_name]
                remaining = len(self._listeners.get(model_name, []))
            logger.info(f"SSE client unsubscribed from {model_name}, remaining listeners: {remaining}")
def mark_complete(self, model_name: str):
"""Mark a model download as complete. Thread-safe."""
import logging
logger = logging.getLogger(__name__)
with self._lock:
if model_name in self._progress:
self._progress[model_name]["status"] = "complete"
self._progress[model_name]["progress"] = 100.0
progress_data = self._progress[model_name].copy()
else:
logger.warning(f"Cannot mark {model_name} as complete: not found in progress")
return
logger.info(f"Marked {model_name} as complete")
# Notify listeners (thread-safe)
self._notify_listeners_threadsafe(model_name, progress_data)
def mark_error(self, model_name: str, error: str):
"""Mark a model download as failed. Thread-safe."""
import logging
logger = logging.getLogger(__name__)
with self._lock:
if model_name in self._progress:
self._progress[model_name]["status"] = "error"
self._progress[model_name]["error"] = error
progress_data = self._progress[model_name].copy()
else:
# Create new progress entry for error
progress_data = {
"model_name": model_name,
"current": 0,
"total": 0,
"progress": 0,
"filename": None,
"status": "error",
"error": error,
"timestamp": datetime.now().isoformat(),
}
self._progress[model_name] = progress_data
logger.error(f"Marked {model_name} as error: {error}")
# Notify listeners (thread-safe)
self._notify_listeners_threadsafe(model_name, progress_data)
# Global progress manager instance
_progress_manager: Optional[ProgressManager] = None
def get_progress_manager() -> ProgressManager:
"""Get or create the global progress manager."""
global _progress_manager
if _progress_manager is None:
_progress_manager = ProgressManager()
return _progress_manager
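# Usage sketch (the helper name and route shape are illustrative assumptions,
# not the app's real wiring): `subscribe` yields SSE frames, so a FastAPI
# handler can pass it straight to StreamingResponse.
def example_progress_response(model_name: str) -> StreamingResponse:
    """Hypothetical handler body streaming progress for one model."""
    manager = get_progress_manager()
    return StreamingResponse(
        manager.subscribe(model_name),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache"},
    )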

102
backend/utils/tasks.py Normal file
View File

@@ -0,0 +1,102 @@
"""
Task tracking for active downloads and generations.
"""
from typing import Optional, Dict, List
from datetime import datetime
from dataclasses import dataclass, field
@dataclass
class DownloadTask:
"""Represents an active download task."""
model_name: str
status: str = "downloading" # downloading, extracting, complete, error
    started_at: datetime = field(default_factory=datetime.now)
error: Optional[str] = None
@dataclass
class GenerationTask:
"""Represents an active generation task."""
task_id: str
profile_id: str
text_preview: str # First 50 chars of text
    started_at: datetime = field(default_factory=datetime.now)
class TaskManager:
"""Manages active downloads and generations."""
def __init__(self):
self._active_downloads: Dict[str, DownloadTask] = {}
self._active_generations: Dict[str, GenerationTask] = {}
def start_download(self, model_name: str) -> None:
"""Mark a download as started."""
self._active_downloads[model_name] = DownloadTask(
model_name=model_name,
status="downloading",
)
def complete_download(self, model_name: str) -> None:
"""Mark a download as complete."""
if model_name in self._active_downloads:
del self._active_downloads[model_name]
def error_download(self, model_name: str, error: str) -> None:
"""Mark a download as failed."""
if model_name in self._active_downloads:
self._active_downloads[model_name].status = "error"
self._active_downloads[model_name].error = error
def start_generation(self, task_id: str, profile_id: str, text: str) -> None:
"""Mark a generation as started."""
text_preview = text[:50] + "..." if len(text) > 50 else text
self._active_generations[task_id] = GenerationTask(
task_id=task_id,
profile_id=profile_id,
text_preview=text_preview,
)
def complete_generation(self, task_id: str) -> None:
"""Mark a generation as complete."""
if task_id in self._active_generations:
del self._active_generations[task_id]
def get_active_downloads(self) -> List[DownloadTask]:
"""Get all active downloads."""
return list(self._active_downloads.values())
def get_active_generations(self) -> List[GenerationTask]:
"""Get all active generations."""
return list(self._active_generations.values())
def cancel_download(self, model_name: str) -> bool:
"""Cancel/dismiss a download task (removes it from active list)."""
return self._active_downloads.pop(model_name, None) is not None
def clear_all(self) -> None:
"""Clear all download and generation tasks."""
self._active_downloads.clear()
self._active_generations.clear()
def is_download_active(self, model_name: str) -> bool:
"""Check if a download is active."""
return model_name in self._active_downloads
def is_generation_active(self, task_id: str) -> bool:
"""Check if a generation is active."""
return task_id in self._active_generations
# Global task manager instance
_task_manager: Optional[TaskManager] = None
def get_task_manager() -> TaskManager:
"""Get or create the global task manager."""
global _task_manager
if _task_manager is None:
_task_manager = TaskManager()
return _task_manager
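# Usage sketch (illustrative; the real call sites live in the API layer): pair
# start_generation/complete_generation around the actual work so a task always
# drops out of the active list, even when generation fails.
def example_tracked_generation(task_id: str, profile_id: str, text: str) -> None:
    """Hypothetical wrapper showing the intended start/complete pairing."""
    manager = get_task_manager()
    manager.start_generation(task_id, profile_id, text)
    try:
        pass  # ... run TTS generation here ...
    finally:
        manager.complete_generation(task_id)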