Initial commit

backend/utils/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Utils package

backend/utils/audio.py (new file, 318 lines)
@@ -0,0 +1,318 @@
"""
Audio processing utilities.
"""

import numpy as np
import soundfile as sf
import librosa
from typing import Tuple, Optional


def normalize_audio(
    audio: np.ndarray,
    target_db: float = -20.0,
    peak_limit: float = 0.85,
) -> np.ndarray:
    """
    Normalize audio to target loudness with peak limiting.

    Args:
        audio: Input audio array
        target_db: Target RMS level in dB
        peak_limit: Peak limit (0.0-1.0)

    Returns:
        Normalized audio array
    """
    # Convert to float32
    audio = audio.astype(np.float32)

    # Calculate current RMS
    rms = np.sqrt(np.mean(audio**2))

    # Calculate target RMS
    target_rms = 10**(target_db / 20)

    # Apply gain
    if rms > 0:
        gain = target_rms / rms
        audio = audio * gain

    # Peak limiting
    audio = np.clip(audio, -peak_limit, peak_limit)

    return audio


def load_audio(
    path: str,
    sample_rate: int = 24000,
    mono: bool = True,
) -> Tuple[np.ndarray, int]:
    """
    Load an audio file, resampling to *sample_rate* and converting to float32.

    Args:
        path: Path to audio file
        sample_rate: Target sample rate
        mono: Convert to mono

    Returns:
        Tuple of (audio_array, sample_rate)
    """
    audio, sr = librosa.load(path, sr=sample_rate, mono=mono)
    return audio, sr


def save_audio(
    audio: np.ndarray,
    path: str,
    sample_rate: int = 24000,
) -> None:
    """
    Save audio file with atomic write and error handling.

    Writes to a temporary file first, then atomically renames to the
    target path. This prevents corrupted/partial WAV files if the
    process is interrupted mid-write.

    Args:
        audio: Audio array
        path: Output path
        sample_rate: Sample rate

    Raises:
        OSError: If file cannot be written
    """
    from pathlib import Path
    import os

    temp_path = f"{path}.tmp"
    try:
        # Ensure parent directory exists
        Path(path).parent.mkdir(parents=True, exist_ok=True)

        # Write to temporary file first (explicit format since .tmp
        # extension is not recognised by soundfile)
        sf.write(temp_path, audio, sample_rate, format='WAV')

        # Atomic rename to final path
        os.replace(temp_path, path)

    except Exception as e:
        # Clean up temp file on failure
        try:
            if Path(temp_path).exists():
                Path(temp_path).unlink()
        except Exception:
            pass  # Best effort cleanup

        raise OSError(f"Failed to save audio to {path}: {e}") from e


def trim_tts_output(
    audio: np.ndarray,
    sample_rate: int = 24000,
    frame_ms: int = 20,
    silence_threshold_db: float = -40.0,
    min_silence_ms: int = 200,
    max_internal_silence_ms: int = 1000,
    fade_ms: int = 30,
) -> np.ndarray:
    """
    Trim trailing silence and post-silence hallucination from TTS output.

    Chatterbox sometimes produces ``[speech][silence][hallucinated noise]``.
    This detects internal silence gaps longer than *max_internal_silence_ms*
    and cuts the audio at that boundary, then trims trailing silence and
    applies a short cosine fade-out.

    Args:
        audio: Input audio array (mono float32)
        sample_rate: Sample rate in Hz
        frame_ms: Frame size for RMS energy calculation
        silence_threshold_db: dB threshold below which a frame is silence
        min_silence_ms: Minimum trailing silence to keep
        max_internal_silence_ms: Cut after any silence gap longer than this
        fade_ms: Cosine fade-out duration in ms

    Returns:
        Trimmed audio array
    """
    frame_len = int(sample_rate * frame_ms / 1000)
    if frame_len == 0 or len(audio) < frame_len:
        return audio

    n_frames = len(audio) // frame_len
    threshold_linear = 10 ** (silence_threshold_db / 20)

    # Compute per-frame RMS
    rms = np.array(
        [
            np.sqrt(np.mean(audio[i * frame_len : (i + 1) * frame_len] ** 2))
            for i in range(n_frames)
        ]
    )
    is_speech = rms >= threshold_linear

    # Find first speech frame
    first_speech = 0
    for i, s in enumerate(is_speech):
        if s:
            first_speech = max(0, i - 1)  # keep 1 frame padding
            break

    # Walk forward from first speech; cut at long internal silence gaps
    max_silence_frames = int(max_internal_silence_ms / frame_ms)
    consecutive_silence = 0
    cut_frame = n_frames

    for i in range(first_speech, n_frames):
        if is_speech[i]:
            consecutive_silence = 0
        else:
            consecutive_silence += 1
            if consecutive_silence >= max_silence_frames:
                cut_frame = i - consecutive_silence + 1
                break

    # Trim trailing silence from the cut point
    min_silence_frames = int(min_silence_ms / frame_ms)
    end_frame = cut_frame
    while end_frame > first_speech and not is_speech[end_frame - 1]:
        end_frame -= 1
    # Keep a short tail
    end_frame = min(end_frame + min_silence_frames, cut_frame)

    # Convert frames back to samples
    start_sample = first_speech * frame_len
    end_sample = min(end_frame * frame_len, len(audio))

    trimmed = audio[start_sample:end_sample].copy()

    # Cosine fade-out
    fade_samples = int(sample_rate * fade_ms / 1000)
    if fade_samples > 0 and len(trimmed) > fade_samples:
        fade = np.cos(np.linspace(0, np.pi / 2, fade_samples)) ** 2
        trimmed[-fade_samples:] *= fade

    return trimmed


def preprocess_reference_audio(
    audio: np.ndarray,
    sample_rate: int,
    peak_target: float = 0.95,
    trim_top_db: float = 40.0,
    edge_padding_ms: int = 100,
) -> np.ndarray:
    """
    Clean up a reference-audio sample before validation/storage.

    Removes DC offset, trims leading/trailing silence, and caps the peak so a
    slightly-hot recording doesn't get rejected downstream as "clipping". The
    goal is to accept reasonable real-world recordings — not to repair badly
    distorted ones. True clipping artifacts inside the waveform can't be
    recovered by peak scaling and will still sound bad.

    Args:
        audio: Mono audio array.
        sample_rate: Sample rate of ``audio`` in Hz.
        peak_target: Peak amplitude cap in [0, 1]. Applied only if the input
            peak exceeds this value.
        trim_top_db: Silence threshold for edge trimming, in dB below peak.
            40 dB sits below normal speech dynamic range (≈30 dB) so soft
            trailing syllables are preserved, while still catching obvious
            leading/trailing silence. Lower values are more aggressive;
            librosa's own default is 60.
        edge_padding_ms: Milliseconds of padding to add back at each edge
            *only if* trimming shortened the waveform, so TTS engines have a
            brief silence to anchor on without ever making the output longer
            than the input.

    Returns:
        Preprocessed audio array (float32).
    """
    audio = audio.astype(np.float32, copy=False)

    if audio.size == 0:
        return audio

    audio = audio - float(np.mean(audio))

    trimmed, _ = librosa.effects.trim(audio, top_db=trim_top_db)
    if 0 < trimmed.size < audio.size:
        pad_each = int(sample_rate * edge_padding_ms / 1000)
        # Never pad past the original length — for near-max-duration uploads
        # an unconditional pad would push them over the 30 s ceiling and
        # trigger a spurious "too long" rejection.
        headroom = (audio.size - trimmed.size) // 2
        pad = min(pad_each, max(headroom, 0))
        if pad > 0:
            trimmed = np.pad(trimmed, (pad, pad), mode="constant")
        audio = trimmed

    peak = float(np.abs(audio).max())
    if peak > peak_target and peak > 0:
        audio = audio * (peak_target / peak)

    return audio


def validate_reference_audio(
    audio_path: str,
    min_duration: float = 2.0,
    max_duration: float = 30.0,
    min_rms: float = 0.01,
) -> Tuple[bool, Optional[str]]:
    """
    Validate reference audio for voice cloning.

    Args:
        audio_path: Path to audio file
        min_duration: Minimum duration in seconds
        max_duration: Maximum duration in seconds
        min_rms: Minimum RMS level

    Returns:
        Tuple of (is_valid, error_message)
    """
    result = validate_and_load_reference_audio(
        audio_path, min_duration, max_duration, min_rms
    )
    return (result[0], result[1])


def validate_and_load_reference_audio(
    audio_path: str,
    min_duration: float = 2.0,
    max_duration: float = 30.0,
    min_rms: float = 0.01,
) -> Tuple[bool, Optional[str], Optional[np.ndarray], Optional[int]]:
    """
    Validate and load reference audio in a single pass.

    Applies :func:`preprocess_reference_audio` before checks so that
    slightly-hot recordings aren't rejected as clipping. Duration and RMS
    checks run on the preprocessed waveform.

    Returns:
        Tuple of (is_valid, error_message, audio_array, sample_rate)
    """
    try:
        audio, sr = load_audio(audio_path)
        audio = preprocess_reference_audio(audio, sr)
        duration = len(audio) / sr

        if duration < min_duration:
            return False, f"Audio too short (minimum {min_duration} seconds)", None, None
        if duration > max_duration:
            return False, f"Audio too long (maximum {max_duration} seconds)", None, None

        rms = np.sqrt(np.mean(audio**2))
        if rms < min_rms:
            return False, "Audio is too quiet or silent", None, None

        return True, None, audio, sr
    except Exception as e:
        return False, f"Error validating audio: {str(e)}", None, None
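
A minimal usage sketch for the helpers above; the paths are illustrative, not part of the commit:

from backend.utils.audio import load_audio, normalize_audio, trim_tts_output, save_audio

audio, sr = load_audio("clips/take1.wav", sample_rate=24000)  # hypothetical path
audio = normalize_audio(audio, target_db=-20.0)   # RMS gain, then hard peak clip
audio = trim_tts_output(audio, sample_rate=sr)    # cut post-silence hallucination
save_audio(audio, "clips/take1.clean.wav", sample_rate=sr)  # atomic temp-file write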

backend/utils/cache.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""
Voice prompt caching utilities.
"""

import hashlib
import logging
import torch
from pathlib import Path
from typing import Optional, Union, Dict, Any

from .. import config

logger = logging.getLogger(__name__)


def _get_cache_dir() -> Path:
    """Get cache directory from config."""
    return config.get_cache_dir()


# In-memory cache - can store dict (voice prompt) or tensor (legacy)
_memory_cache: dict[str, Union[torch.Tensor, Dict[str, Any]]] = {}


def get_cache_key(audio_path: str, reference_text: str) -> str:
    """
    Generate cache key from audio file and reference text.

    Args:
        audio_path: Path to audio file
        reference_text: Reference text

    Returns:
        Cache key (MD5 hash)
    """
    # Read audio file
    with open(audio_path, "rb") as f:
        audio_bytes = f.read()

    # Combine audio bytes and text
    combined = audio_bytes + reference_text.encode("utf-8")

    # Generate hash
    return hashlib.md5(combined).hexdigest()


def get_cached_voice_prompt(
    cache_key: str,
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
    """
    Get cached voice prompt if available.

    Args:
        cache_key: Cache key

    Returns:
        Cached voice prompt (dict or tensor) or None
    """
    # Check in-memory cache
    if cache_key in _memory_cache:
        return _memory_cache[cache_key]

    # Check disk cache
    cache_file = _get_cache_dir() / f"{cache_key}.prompt"
    if cache_file.exists():
        try:
            prompt = torch.load(cache_file, weights_only=True)
            _memory_cache[cache_key] = prompt
            return prompt
        except Exception:
            # Cache file corrupted, delete it
            cache_file.unlink()

    return None


def cache_voice_prompt(
    cache_key: str,
    voice_prompt: Union[torch.Tensor, Dict[str, Any]],
) -> None:
    """
    Cache voice prompt to memory and disk.

    Args:
        cache_key: Cache key
        voice_prompt: Voice prompt (dict or tensor)
    """
    # Store in memory
    _memory_cache[cache_key] = voice_prompt

    # Store on disk (torch.save can handle both dicts and tensors)
    cache_file = _get_cache_dir() / f"{cache_key}.prompt"
    torch.save(voice_prompt, cache_file)


def clear_voice_prompt_cache() -> int:
    """
    Clear all voice prompt caches (memory and disk).

    Returns:
        Number of cache files deleted
    """
    # Clear memory cache
    _memory_cache.clear()

    # Clear disk cache
    cache_dir = _get_cache_dir()
    deleted_count = 0

    if cache_dir.exists():
        # Delete prompt cache files
        for cache_file in cache_dir.glob("*.prompt"):
            try:
                cache_file.unlink()
                deleted_count += 1
            except Exception as e:
                logger.warning("Failed to delete cache file %s: %s", cache_file, e)

        # Delete combined audio files
        for audio_file in cache_dir.glob("combined_*.wav"):
            try:
                audio_file.unlink()
                deleted_count += 1
            except Exception as e:
                logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)

    return deleted_count


def clear_profile_cache(profile_id: str) -> int:
    """
    Clear cache files for a specific profile.

    Args:
        profile_id: Profile ID

    Returns:
        Number of cache files deleted
    """
    cache_dir = _get_cache_dir()
    deleted_count = 0

    if cache_dir.exists():
        # Delete combined audio files for this profile
        pattern = f"combined_{profile_id}_*.wav"
        for audio_file in cache_dir.glob(pattern):
            try:
                audio_file.unlink()
                deleted_count += 1
            except Exception as e:
                logger.warning("Failed to delete combined audio file %s: %s", audio_file, e)

    return deleted_count
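
A sketch of the intended cache round-trip; ``encode_voice_prompt`` is a hypothetical stand-in for the real prompt encoder:

from backend.utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt

key = get_cache_key("voices/alice.wav", "Reference transcript.")  # MD5 over audio bytes + text
prompt = get_cached_voice_prompt(key)  # memory first, then disk
if prompt is None:
    prompt = encode_voice_prompt("voices/alice.wav")  # hypothetical expensive step
    cache_voice_prompt(key, prompt)  # stores in memory and as <cache_dir>/<key>.prompt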

backend/utils/chunked_tts.py (new file, 299 lines)
@@ -0,0 +1,299 @@
"""
Chunked TTS generation utilities.

Splits long text into sentence-boundary chunks, generates audio per-chunk
via any TTSBackend, and concatenates with crossfade. All logic is
engine-agnostic — it wraps the standard ``TTSBackend.generate()`` interface.

Short text (≤ max_chunk_chars) uses the single-shot fast path with zero
overhead.
"""

import logging
import re
from typing import List, Tuple

import numpy as np

logger = logging.getLogger("voicebox.chunked-tts")

# Default chunk size in characters. Can be overridden per-request via
# the ``max_chunk_chars`` field on GenerationRequest.
DEFAULT_MAX_CHUNK_CHARS = 800

# Common abbreviations that should NOT be treated as sentence endings.
# Lowercase for case-insensitive matching.
_ABBREVIATIONS = frozenset(
    {
        "mr",
        "mrs",
        "ms",
        "dr",
        "prof",
        "sr",
        "jr",
        "st",
        "ave",
        "blvd",
        "inc",
        "ltd",
        "corp",
        "dept",
        "est",
        "approx",
        "vs",
        "etc",
        "e.g",
        "i.e",
        "a.m",
        "p.m",
        "u.s",
        "u.s.a",
        "u.k",
    }
)

# Paralinguistic tags used by Chatterbox Turbo. The splitter must never
# cut inside one of these.
_PARA_TAG_RE = re.compile(r"\[[^\]]*\]")


def split_text_into_chunks(text: str, max_chars: int = DEFAULT_MAX_CHUNK_CHARS) -> List[str]:
    """Split *text* at natural boundaries into chunks of at most *max_chars*.

    Priority: sentence-end (``.!?`` not preceded by an abbreviation and not
    inside brackets) → clause boundary (``;:,—``) → whitespace → hard cut.

    Paralinguistic tags like ``[laugh]`` are treated as atomic and will not
    be split across chunks.
    """
    text = text.strip()
    if not text:
        return []
    if len(text) <= max_chars:
        return [text]

    chunks: List[str] = []
    remaining = text

    while remaining:
        remaining = remaining.lstrip()
        if not remaining:
            break
        if len(remaining) <= max_chars:
            chunks.append(remaining)
            break

        segment = remaining[:max_chars]

        # Try to split at the last real sentence ending
        split_pos = _find_last_sentence_end(segment)
        if split_pos == -1:
            split_pos = _find_last_clause_boundary(segment)
        if split_pos == -1:
            split_pos = segment.rfind(" ")
        if split_pos == -1:
            # Absolute fallback: hard cut but avoid splitting inside a tag
            split_pos = _safe_hard_cut(segment, max_chars)

        chunk = remaining[: split_pos + 1].strip()
        if chunk:
            chunks.append(chunk)
        remaining = remaining[split_pos + 1 :]

    return chunks


def _find_last_sentence_end(text: str) -> int:
    """Return the index of the last sentence-ending punctuation in *text*.

    Skips periods that follow common abbreviations (``Dr.``, ``Mr.``, etc.)
    and periods inside bracket tags (``[laugh]``). Also handles CJK
    sentence-ending punctuation (``。！？``).
    """
    best = -1
    # ASCII sentence ends
    for m in re.finditer(r"[.!?](?:\s|$)", text):
        pos = m.start()
        char = text[pos]
        # Skip periods after abbreviations
        if char == ".":
            # Walk backwards over letters and internal dots so dotted
            # abbreviations ("p.m.", "e.g.") match the lookup table too
            word_start = pos - 1
            while word_start >= 0 and (text[word_start].isalpha() or text[word_start] == "."):
                word_start -= 1
            word = text[word_start + 1 : pos].lower().strip(".")
            if word in _ABBREVIATIONS:
                continue
            # Skip decimal numbers (digit immediately before the period)
            if word_start >= 0 and text[word_start].isdigit():
                continue
        # Skip if we're inside a bracket tag
        if _inside_bracket_tag(text, pos):
            continue
        best = pos
    # CJK sentence-ending punctuation
    for m in re.finditer(r"[\u3002\uff01\uff1f]", text):
        if m.start() > best:
            best = m.start()
    return best


def _find_last_clause_boundary(text: str) -> int:
    """Return the index of the last clause-boundary punctuation."""
    best = -1
    for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
        pos = m.start()
        # Skip if inside a bracket tag
        if _inside_bracket_tag(text, pos):
            continue
        best = pos
    return best


def _inside_bracket_tag(text: str, pos: int) -> bool:
    """Return True if *pos* falls inside a ``[...]`` tag."""
    for m in _PARA_TAG_RE.finditer(text):
        if m.start() < pos < m.end():
            return True
    return False


def _safe_hard_cut(segment: str, max_chars: int) -> int:
    """Find a hard-cut position that doesn't split a ``[tag]``."""
    cut = max_chars - 1
    # Check if the cut falls inside a bracket tag; if so, move before it
    for m in _PARA_TAG_RE.finditer(segment):
        if m.start() < cut < m.end():
            return m.start() - 1 if m.start() > 0 else cut
    return cut


def concatenate_audio_chunks(
    chunks: List[np.ndarray],
    sample_rate: int,
    crossfade_ms: int = 50,
) -> np.ndarray:
    """Concatenate audio arrays with a short crossfade to eliminate clicks.

    Each chunk is expected to be a 1-D float32 ndarray at *sample_rate* Hz.
    """
    if not chunks:
        return np.array([], dtype=np.float32)
    if len(chunks) == 1:
        return chunks[0]

    crossfade_samples = int(sample_rate * crossfade_ms / 1000)
    result = np.array(chunks[0], dtype=np.float32, copy=True)

    for chunk in chunks[1:]:
        if len(chunk) == 0:
            continue
        overlap = min(crossfade_samples, len(result), len(chunk))
        if overlap > 0:
            fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
            fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
            result[-overlap:] = result[-overlap:] * fade_out + chunk[:overlap] * fade_in
            result = np.concatenate([result, chunk[overlap:]])
        else:
            result = np.concatenate([result, chunk])

    return result


async def generate_chunked(
    backend,
    text: str,
    voice_prompt: dict,
    language: str = "en",
    seed: int | None = None,
    instruct: str | None = None,
    max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
    crossfade_ms: int = 50,
    trim_fn=None,
) -> Tuple[np.ndarray, int]:
    """Generate audio with automatic chunking for long text.

    For text shorter than *max_chunk_chars* this is a thin wrapper around
    ``backend.generate()`` with zero overhead.

    For longer text the input is split at natural sentence boundaries,
    each chunk is generated independently, optionally trimmed (useful for
    Chatterbox engines that hallucinate trailing noise), and the results
    are concatenated with a crossfade (or hard cut if *crossfade_ms* is 0).

    Parameters
    ----------
    backend : TTSBackend
        Any backend implementing the ``generate()`` protocol.
    text : str
        Input text (may be arbitrarily long).
    voice_prompt, language, seed, instruct
        Forwarded to ``backend.generate()`` verbatim.
    max_chunk_chars : int
        Maximum characters per chunk (default 800).
    crossfade_ms : int
        Crossfade duration in milliseconds between chunks. 0 for a hard
        cut with no overlap (default 50).
    trim_fn : callable | None
        Optional ``(audio, sample_rate) -> audio`` post-processing
        function applied to each chunk before concatenation (e.g.
        ``trim_tts_output`` for Chatterbox engines).

    Returns
    -------
    (audio, sample_rate) : Tuple[np.ndarray, int]
    """
    chunks = split_text_into_chunks(text, max_chunk_chars)

    if len(chunks) <= 1:
        # Short text — single-shot fast path
        audio, sample_rate = await backend.generate(
            text,
            voice_prompt,
            language,
            seed,
            instruct,
        )
        if trim_fn is not None:
            audio = trim_fn(audio, sample_rate)
        return audio, sample_rate

    # Long text — chunked generation
    logger.info(
        "Splitting %d chars into %d chunks (max %d chars each)",
        len(text),
        len(chunks),
        max_chunk_chars,
    )
    audio_chunks: List[np.ndarray] = []
    sample_rate: int | None = None

    for i, chunk_text in enumerate(chunks):
        logger.info(
            "Generating chunk %d/%d (%d chars)",
            i + 1,
            len(chunks),
            len(chunk_text),
        )
        # Vary the seed per chunk to avoid correlated RNG artefacts,
        # but keep it deterministic so the same (text, seed) pair
        # always produces the same output.
        chunk_seed = (seed + i) if seed is not None else None

        chunk_audio, chunk_sr = await backend.generate(
            chunk_text,
            voice_prompt,
            language,
            chunk_seed,
            instruct,
        )
        if trim_fn is not None:
            chunk_audio = trim_fn(chunk_audio, chunk_sr)

        audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32))
        if sample_rate is None:
            sample_rate = chunk_sr

    audio = concatenate_audio_chunks(audio_chunks, sample_rate, crossfade_ms=crossfade_ms)
    return audio, sample_rate
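
An illustrative run of the splitter; the sample text is made up, and the expected output follows from the rules above:

from backend.utils.chunked_tts import split_text_into_chunks

text = "Dr. Smith arrived at 3 p.m. It was raining. [laugh] Everyone cheered!"
print(split_text_into_chunks(text, max_chars=45))
# ['Dr. Smith arrived at 3 p.m. It was raining.', '[laugh] Everyone cheered!']
# "Dr." and "p.m." are not treated as sentence ends, and [laugh] stays whole.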

backend/utils/dac_shim.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""
Minimal shim for descript-audio-codec (DAC).

TADA only imports Snake1d from dac.nn.layers and dac.model.dac.
The real DAC package pulls in descript-audiotools which depends on
onnx, tensorboard, protobuf, matplotlib, pystoi, etc. — none of
which are needed for TADA's runtime use of Snake1d.

This shim provides the exact Snake1d implementation (MIT-licensed,
from https://github.com/descriptinc/descript-audio-codec) so we can
avoid the entire audiotools dependency chain.

If the real DAC package is installed, this module is never used —
Python's import system will find the site-packages version first.
Install this shim only when descript-audio-codec is NOT installed.
"""

import sys
import types

import torch
import torch.nn as nn


# ── Snake activation (from dac/nn/layers.py) ────────────────────────

# NOTE: The original DAC code uses @torch.jit.script here for a 1.4x
# speedup. We omit it because TorchScript calls inspect.getsource()
# which fails inside a PyInstaller frozen binary (no .py source files).
def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return snake(x, self.alpha)


# ── Register as dac.nn.layers and dac.model.dac ─────────────────────

def install_dac_shim() -> None:
    """Register fake dac package modules in sys.modules.

    Only installs the shim if 'dac' is not already importable
    (i.e. the real descript-audio-codec is not installed).
    """
    try:
        import dac  # noqa: F401 — real package exists, do nothing
        return
    except ImportError:
        pass

    # Create the module tree: dac -> dac.nn -> dac.nn.layers
    #                             -> dac.model -> dac.model.dac
    dac_pkg = types.ModuleType("dac")
    dac_pkg.__path__ = []  # make it a package
    dac_pkg.__package__ = "dac"

    dac_nn = types.ModuleType("dac.nn")
    dac_nn.__path__ = []
    dac_nn.__package__ = "dac.nn"

    dac_nn_layers = types.ModuleType("dac.nn.layers")
    dac_nn_layers.__package__ = "dac.nn"
    dac_nn_layers.Snake1d = Snake1d
    dac_nn_layers.snake = snake

    dac_model = types.ModuleType("dac.model")
    dac_model.__path__ = []
    dac_model.__package__ = "dac.model"

    dac_model_dac = types.ModuleType("dac.model.dac")
    dac_model_dac.__package__ = "dac.model"
    dac_model_dac.Snake1d = Snake1d

    # Wire up submodules
    dac_pkg.nn = dac_nn
    dac_pkg.model = dac_model
    dac_nn.layers = dac_nn_layers
    dac_model.dac = dac_model_dac

    # Register in sys.modules
    sys.modules["dac"] = dac_pkg
    sys.modules["dac.nn"] = dac_nn
    sys.modules["dac.nn.layers"] = dac_nn_layers
    sys.modules["dac.model"] = dac_model
    sys.modules["dac.model.dac"] = dac_model_dac
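
The shim must run before anything imports ``dac``; a minimal sketch:

from backend.utils.dac_shim import install_dac_shim

install_dac_shim()  # no-op when the real descript-audio-codec is importable

from dac.nn.layers import Snake1d  # resolves against the shim if DAC is absent
layer = Snake1d(channels=64)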

backend/utils/effects.py (new file, 373 lines)
@@ -0,0 +1,373 @@
"""
Audio post-processing effects engine.

Uses Spotify's pedalboard library to apply professional-grade DSP effects
to generated audio. Effects are described as a JSON-serializable chain
(list of effect dicts) so they can be stored in the database and sent
over the API.

Supported effect types:
- chorus (flanger-style with short delays)
- reverb (room reverb)
- delay (echo / delay line)
- compressor (dynamic range compression)
- gain (volume adjustment in dB)
- highpass (high-pass filter)
- lowpass (low-pass filter)
- pitch_shift (semitone pitch shifting)
"""

from __future__ import annotations

import numpy as np
from typing import Any, Dict, List, Optional

from pedalboard import (
    Pedalboard,
    Chorus,
    Reverb,
    Compressor,
    Gain,
    HighpassFilter,
    LowpassFilter,
    Delay,
    PitchShift,
)


# Each param definition: (default, min, max, step, description)
EFFECT_REGISTRY: Dict[str, Dict[str, Any]] = {
    "chorus": {
        "cls": Chorus,
        "label": "Chorus / Flanger",
        "description": "Modulated delay for flanging or chorus effects. Short centre_delay_ms (<10) gives flanger; longer gives chorus.",
        "params": {
            "rate_hz": {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01, "description": "LFO speed (Hz)"},
            "depth": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Modulation depth"},
            "feedback": {"default": 0.0, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
            "centre_delay_ms": {
                "default": 7.0,
                "min": 0.5,
                "max": 50.0,
                "step": 0.1,
                "description": "Centre delay (ms)",
            },
            "mix": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
        },
    },
    "reverb": {
        "cls": Reverb,
        "label": "Reverb",
        "description": "Room reverb effect.",
        "params": {
            "room_size": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Room size"},
            "damping": {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "description": "High frequency damping"},
            "wet_level": {"default": 0.33, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet level"},
            "dry_level": {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Dry level"},
            "width": {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Stereo width"},
        },
    },
    "delay": {
        "cls": Delay,
        "label": "Delay",
        "description": "Echo / delay line.",
        "params": {
            "delay_seconds": {
                "default": 0.3,
                "min": 0.01,
                "max": 2.0,
                "step": 0.01,
                "description": "Delay time (seconds)",
            },
            "feedback": {"default": 0.3, "min": 0.0, "max": 0.95, "step": 0.01, "description": "Feedback amount"},
            "mix": {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01, "description": "Wet/dry mix"},
        },
    },
    "compressor": {
        "cls": Compressor,
        "label": "Compressor",
        "description": "Dynamic range compression for consistent loudness.",
        "params": {
            "threshold_db": {"default": -20.0, "min": -60.0, "max": 0.0, "step": 0.5, "description": "Threshold (dB)"},
            "ratio": {"default": 4.0, "min": 1.0, "max": 20.0, "step": 0.1, "description": "Compression ratio"},
            "attack_ms": {"default": 10.0, "min": 0.1, "max": 100.0, "step": 0.1, "description": "Attack time (ms)"},
            "release_ms": {
                "default": 100.0,
                "min": 10.0,
                "max": 1000.0,
                "step": 1.0,
                "description": "Release time (ms)",
            },
        },
    },
    "gain": {
        "cls": Gain,
        "label": "Gain",
        "description": "Volume adjustment in decibels.",
        "params": {
            "gain_db": {"default": 0.0, "min": -40.0, "max": 40.0, "step": 0.5, "description": "Gain (dB)"},
        },
    },
    "highpass": {
        "cls": HighpassFilter,
        "label": "High-Pass Filter",
        "description": "Removes frequencies below the cutoff.",
        "params": {
            "cutoff_frequency_hz": {
                "default": 80.0,
                "min": 20.0,
                "max": 8000.0,
                "step": 1.0,
                "description": "Cutoff frequency (Hz)",
            },
        },
    },
    "lowpass": {
        "cls": LowpassFilter,
        "label": "Low-Pass Filter",
        "description": "Removes frequencies above the cutoff.",
        "params": {
            "cutoff_frequency_hz": {
                "default": 8000.0,
                "min": 200.0,
                "max": 20000.0,
                "step": 1.0,
                "description": "Cutoff frequency (Hz)",
            },
        },
    },
    "pitch_shift": {
        "cls": PitchShift,
        "label": "Pitch Shift",
        "description": "Shift pitch up or down by semitones.",
        "params": {
            "semitones": {"default": 0.0, "min": -12.0, "max": 12.0, "step": 0.5, "description": "Semitones to shift"},
        },
    },
}


BUILTIN_PRESETS: Dict[str, Dict[str, Any]] = {
    "robotic": {
        "name": "Robotic",
        "sort_order": 0,
        "description": "Metallic robotic voice (flanger with slow LFO and high feedback)",
        "effects_chain": [
            {
                "type": "chorus",
                "enabled": True,
                "params": {
                    "rate_hz": 0.2,
                    "depth": 1.0,
                    "feedback": 0.35,
                    "centre_delay_ms": 7.0,
                    "mix": 0.5,
                },
            },
        ],
    },
    "radio": {
        "name": "Radio",
        "sort_order": 1,
        "description": "Thin AM-radio voice with band-pass filtering and light compression",
        "effects_chain": [
            {
                "type": "highpass",
                "enabled": True,
                "params": {"cutoff_frequency_hz": 300.0},
            },
            {
                "type": "lowpass",
                "enabled": True,
                "params": {"cutoff_frequency_hz": 3500.0},
            },
            {
                "type": "compressor",
                "enabled": True,
                "params": {
                    "threshold_db": -15.0,
                    "ratio": 6.0,
                    "attack_ms": 5.0,
                    "release_ms": 50.0,
                },
            },
            {
                "type": "gain",
                "enabled": True,
                "params": {"gain_db": 6.0},
            },
        ],
    },
    "echo_chamber": {
        "name": "Echo Chamber",
        "sort_order": 2,
        "description": "Spacious reverb with trailing echo",
        "effects_chain": [
            {
                "type": "reverb",
                "enabled": True,
                "params": {
                    "room_size": 0.85,
                    "damping": 0.3,
                    "wet_level": 0.45,
                    "dry_level": 0.55,
                    "width": 1.0,
                },
            },
            {
                "type": "delay",
                "enabled": True,
                "params": {
                    "delay_seconds": 0.25,
                    "feedback": 0.3,
                    "mix": 0.2,
                },
            },
        ],
    },
    "deep_voice": {
        "name": "Deep Voice",
        "sort_order": 99,
        "description": "Lower pitch with added warmth",
        "effects_chain": [
            {
                "type": "pitch_shift",
                "enabled": True,
                "params": {"semitones": -3.0},
            },
            {
                "type": "lowpass",
                "enabled": True,
                "params": {"cutoff_frequency_hz": 6000.0},
            },
            {
                "type": "compressor",
                "enabled": True,
                "params": {
                    "threshold_db": -18.0,
                    "ratio": 3.0,
                    "attack_ms": 10.0,
                    "release_ms": 150.0,
                },
            },
        ],
    },
}


def get_available_effects() -> List[Dict[str, Any]]:
    """Return the list of available effect types with their parameter definitions.

    Used by the frontend to build the effects chain editor UI.
    """
    result = []
    for effect_type, info in EFFECT_REGISTRY.items():
        result.append(
            {
                "type": effect_type,
                "label": info["label"],
                "description": info["description"],
                "params": {name: {k: v for k, v in pdef.items()} for name, pdef in info["params"].items()},
            }
        )
    return result


def get_builtin_presets() -> Dict[str, Dict[str, Any]]:
    """Return all built-in effect presets."""
    return BUILTIN_PRESETS


def validate_effects_chain(effects_chain: List[Dict[str, Any]]) -> Optional[str]:
    """Validate an effects chain configuration.

    Returns None if valid, or an error message string.
    """
    if not isinstance(effects_chain, list):
        return "effects_chain must be a list"

    for i, effect in enumerate(effects_chain):
        if not isinstance(effect, dict):
            return f"Effect at index {i} must be a dict"

        effect_type = effect.get("type")
        if effect_type not in EFFECT_REGISTRY:
            return f"Unknown effect type '{effect_type}' at index {i}. Available: {list(EFFECT_REGISTRY.keys())}"

        params = effect.get("params", {})
        if not isinstance(params, dict):
            return f"Effect '{effect_type}' at index {i}: params must be a dict"

        registry = EFFECT_REGISTRY[effect_type]
        for param_name, value in params.items():
            if param_name not in registry["params"]:
                return f"Effect '{effect_type}' at index {i}: unknown param '{param_name}'"

            pdef = registry["params"][param_name]
            if not isinstance(value, (int, float)):
                return f"Effect '{effect_type}' at index {i}: param '{param_name}' must be a number"
            if value < pdef["min"] or value > pdef["max"]:
                return (
                    f"Effect '{effect_type}' at index {i}: param '{param_name}' "
                    f"must be between {pdef['min']} and {pdef['max']} (got {value})"
                )

    return None


def build_pedalboard(effects_chain: List[Dict[str, Any]]) -> Pedalboard:
    """Build a Pedalboard instance from an effects chain config.

    Skips effects where ``enabled`` is ``False``.
    """
    plugins = []
    for effect in effects_chain:
        if not effect.get("enabled", True):
            continue

        effect_type = effect["type"]
        registry = EFFECT_REGISTRY[effect_type]
        cls = registry["cls"]

        # Merge defaults with provided params
        params = {}
        for pname, pdef in registry["params"].items():
            params[pname] = effect.get("params", {}).get(pname, pdef["default"])

        plugins.append(cls(**params))

    return Pedalboard(plugins)


def apply_effects(
    audio: np.ndarray,
    sample_rate: int,
    effects_chain: List[Dict[str, Any]],
) -> np.ndarray:
    """Apply an effects chain to audio data.

    Args:
        audio: Input audio array (1-D mono float32).
        sample_rate: Sample rate in Hz.
        effects_chain: List of effect configuration dicts.

    Returns:
        Processed audio array.
    """
    if not effects_chain:
        return audio

    board = build_pedalboard(effects_chain)

    # pedalboard expects shape (channels, samples)
    if audio.ndim == 1:
        audio_2d = audio[np.newaxis, :]
    else:
        audio_2d = audio

    processed = board(audio_2d.astype(np.float32), sample_rate)

    # Return same dimensionality as input
    if audio.ndim == 1:
        return processed[0]
    return processed
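
A short sketch applying one of the built-in presets to dummy audio:

import numpy as np
from backend.utils.effects import BUILTIN_PRESETS, validate_effects_chain, apply_effects

chain = BUILTIN_PRESETS["radio"]["effects_chain"]
assert validate_effects_chain(chain) is None  # None means the chain is valid

audio = np.zeros(24000, dtype=np.float32)  # one second of dummy mono audio at 24 kHz
processed = apply_effects(audio, 24000, chain)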

backend/utils/hf_offline_patch.py (new file, 270 lines)
@@ -0,0 +1,270 @@
"""Monkey-patch huggingface_hub to force offline mode with cached models.

Prevents mlx_audio / transformers from making network requests when models
are already downloaded. Must be imported BEFORE mlx_audio.
"""

import logging
import os
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Union

logger = logging.getLogger(__name__)


# huggingface_hub reads ``HF_HUB_OFFLINE`` once at import time into
# ``huggingface_hub.constants.HF_HUB_OFFLINE``; transformers mirrors that into
# ``transformers.utils.hub._is_offline_mode`` at *its* import time. Toggling
# ``os.environ`` after either module is imported does not flip those cached
# bools, and the hot paths (``_http._default_backend_factory``,
# ``transformers.utils.hub.is_offline_mode``) read the bools — not the env.
# We mutate the cached constants directly, guarded by a refcount so
# concurrent inference threads share a single offline window safely.

_offline_lock = threading.RLock()
_offline_refcount = 0
_saved_env: Optional[str] = None
_saved_hf_const: Optional[bool] = None
_saved_transformers_const: Optional[bool] = None


@contextmanager
def force_offline_if_cached(is_cached: bool, model_label: str = ""):
    """Force offline mode for the duration of a cached-model operation.

    Flips ``HF_HUB_OFFLINE`` in the process env **and** in the cached bools
    inside ``huggingface_hub.constants`` / ``transformers.utils.hub`` so HTTP
    adapters and offline-mode checks actually see the change. Uses a refcount
    so multiple concurrent inference threads share a single offline window
    and the last one to exit restores state.

    If *is_cached* is ``False`` the block runs normally (network allowed).

    Args:
        is_cached: Whether the model weights are already on disk.
        model_label: Human-readable name used in log messages.
    """
    if not is_cached:
        yield
        return

    global _offline_refcount, _saved_env, _saved_hf_const, _saved_transformers_const

    with _offline_lock:
        if _offline_refcount == 0:
            # Snapshot prior state, apply new state, roll back on *any*
            # failure. Catching only ImportError here would let a partially
            # broken install (RuntimeError, AttributeError from a half-init
            # module, etc.) leave the cached HF constants mutated without
            # bumping the refcount — a persistent offline leak that outlives
            # the process and is miserable to debug.
            prev_env = os.environ.get("HF_HUB_OFFLINE")
            prev_hf: Optional[bool] = None
            prev_tf: Optional[bool] = None
            try:
                try:
                    import huggingface_hub.constants as hf_const

                    prev_hf = hf_const.HF_HUB_OFFLINE
                    hf_const.HF_HUB_OFFLINE = True
                except ImportError:
                    prev_hf = None

                try:
                    import transformers.utils.hub as tf_hub

                    prev_tf = tf_hub._is_offline_mode
                    tf_hub._is_offline_mode = True
                except ImportError:
                    prev_tf = None

                os.environ["HF_HUB_OFFLINE"] = "1"
            except BaseException:
                # Roll back whatever we already changed, then re-raise so
                # the caller sees the real failure.
                if prev_hf is not None:
                    try:
                        import huggingface_hub.constants as hf_const

                        hf_const.HF_HUB_OFFLINE = prev_hf
                    except ImportError:
                        pass
                if prev_tf is not None:
                    try:
                        import transformers.utils.hub as tf_hub

                        tf_hub._is_offline_mode = prev_tf
                    except ImportError:
                        pass
                if prev_env is not None:
                    os.environ["HF_HUB_OFFLINE"] = prev_env
                else:
                    os.environ.pop("HF_HUB_OFFLINE", None)
                raise

            _saved_env = prev_env
            _saved_hf_const = prev_hf
            _saved_transformers_const = prev_tf
            logger.info(
                "[offline-guard] %s is cached — forcing offline mode",
                model_label or "model",
            )
        _offline_refcount += 1

    try:
        yield
    finally:
        with _offline_lock:
            _offline_refcount -= 1
            if _offline_refcount == 0:
                if _saved_env is not None:
                    os.environ["HF_HUB_OFFLINE"] = _saved_env
                else:
                    os.environ.pop("HF_HUB_OFFLINE", None)
                if _saved_hf_const is not None:
                    try:
                        import huggingface_hub.constants as hf_const

                        hf_const.HF_HUB_OFFLINE = _saved_hf_const
                    except ImportError:
                        pass
                if _saved_transformers_const is not None:
                    try:
                        import transformers.utils.hub as tf_hub

                        tf_hub._is_offline_mode = _saved_transformers_const
                    except ImportError:
                        pass
                _saved_env = None
                _saved_hf_const = None
                _saved_transformers_const = None


_mistral_regex_patched = False


def patch_transformers_mistral_regex():
    """Make transformers' tokenizer load robust to HuggingFace metadata failures.

    transformers 4.57.x added ``PreTrainedTokenizerBase._patch_mistral_regex``
    which unconditionally calls ``huggingface_hub.model_info(repo_id)`` during
    every non-local tokenizer load to check whether the model is a Mistral
    variant. That call raises on ``HF_HUB_OFFLINE=1`` and on plain network
    failures, killing unrelated loads (Qwen TTS, TADA, etc.).

    Voicebox never loads Mistral models, so the rewrite the function would
    apply is a no-op for us anyway. Wrap the method so any exception from the
    metadata lookup returns the tokenizer unchanged — matching the success-path
    behavior for non-Mistral repos (transformers 4.57.3,
    ``tokenization_utils_base.py:2503``).
    """
    global _mistral_regex_patched
    if _mistral_regex_patched:
        return

    try:
        from transformers.tokenization_utils_base import PreTrainedTokenizerBase
    except ImportError:
        logger.debug("transformers not available, skipping mistral-regex patch")
        return

    original = getattr(PreTrainedTokenizerBase, "_patch_mistral_regex", None)
    if original is None:
        logger.debug(
            "transformers has no _patch_mistral_regex attribute, skipping patch",
        )
        return

    def safe_patch_mistral_regex(cls, tokenizer, pretrained_model_name_or_path, *args, **kwargs):
        try:
            return original(tokenizer, pretrained_model_name_or_path, *args, **kwargs)
        except Exception as exc:
            logger.debug(
                "[mistral-regex-patch] suppressed %s for %r, returning tokenizer as-is",
                type(exc).__name__,
                pretrained_model_name_or_path,
            )
            return tokenizer

    PreTrainedTokenizerBase._patch_mistral_regex = classmethod(safe_patch_mistral_regex)
    _mistral_regex_patched = True
    logger.debug("installed _patch_mistral_regex wrapper")


def patch_huggingface_hub_offline():
    """Monkey-patch huggingface_hub to force offline mode."""
    try:
        import huggingface_hub  # noqa: F401 -- need the package loaded
        from huggingface_hub import constants as hf_constants
        from huggingface_hub.file_download import _try_to_load_from_cache

        original_try_load = _try_to_load_from_cache

        def _patched_try_to_load_from_cache(
            repo_id: str,
            filename: str,
            cache_dir: Union[str, Path, None] = None,
            revision: Optional[str] = None,
            repo_type: Optional[str] = None,
        ):
            result = original_try_load(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                revision=revision,
                repo_type=repo_type,
            )

            if result is None:
                cache_path = Path(hf_constants.HF_HUB_CACHE) / f"models--{repo_id.replace('/', '--')}"
                logger.debug("file not cached: %s/%s (expected at %s)", repo_id, filename, cache_path)
            else:
                logger.debug("cache hit: %s/%s", repo_id, filename)

            return result

        import huggingface_hub.file_download as fd

        fd._try_to_load_from_cache = _patched_try_to_load_from_cache
        logger.debug("huggingface_hub patched for offline mode")

    except ImportError:
        logger.debug("huggingface_hub not available, skipping offline patch")
    except Exception:
        logger.exception("failed to patch huggingface_hub for offline mode")


def ensure_original_qwen_config_cached():
    """Symlink the original Qwen repo cache to the MLX community version.

    mlx_audio may try to fetch config from the original Qwen repo. If only
    the MLX community variant is cached, create a symlink so the cache lookup
    succeeds without a network request.
    """
    try:
        from huggingface_hub import constants as hf_constants
    except ImportError:
        return

    original_repo = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
    mlx_repo = "mlx-community/Qwen3-TTS-12Hz-1.7B-Base-bf16"

    cache_dir = Path(hf_constants.HF_HUB_CACHE)
    original_path = cache_dir / f"models--{original_repo.replace('/', '--')}"
    mlx_path = cache_dir / f"models--{mlx_repo.replace('/', '--')}"

    if not original_path.exists() and mlx_path.exists():
        try:
            original_path.parent.mkdir(parents=True, exist_ok=True)
            original_path.symlink_to(mlx_path, target_is_directory=True)
            logger.info("created cache symlink: %s -> %s", original_repo, mlx_repo)
        except Exception:
            logger.warning("could not create cache symlink for %s", original_repo, exc_info=True)


if os.environ.get("VOICEBOX_OFFLINE_PATCH", "1") != "0":
    patch_huggingface_hub_offline()
    patch_transformers_mistral_regex()
    ensure_original_qwen_config_cached()
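
A sketch of the offline guard wrapped around a model load; how "cached" is determined is the caller's concern (hypothetical here):

from backend.utils.hf_offline_patch import force_offline_if_cached

weights_cached = True  # hypothetical; the real caller inspects the HF cache
with force_offline_if_cached(weights_cached, model_label="Qwen3-TTS"):
    ...  # hub/transformers loads inside this block see HF_HUB_OFFLINE=1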

backend/utils/hf_progress.py (new file, 383 lines)
@@ -0,0 +1,383 @@
"""
HuggingFace Hub download progress tracking.
"""

from typing import Optional, Callable
from contextlib import contextmanager
import logging
import threading
import sys

logger = logging.getLogger(__name__)


class HFProgressTracker:
    """Tracks HuggingFace Hub download progress by intercepting tqdm."""

    def __init__(self, progress_callback: Optional[Callable] = None, filter_non_downloads: bool = False):
        self.progress_callback = progress_callback
        self.filter_non_downloads = filter_non_downloads  # Only filter if True
        self._original_tqdm_class = None
        self._lock = threading.Lock()
        self._total_downloaded = 0
        self._total_size = 0
        self._file_sizes = {}  # Track sizes of individual files
        self._file_downloaded = {}  # Track downloaded bytes per file
        self._current_filename = ""
        self._active_tqdms = {}  # Track active tqdm instances
        self._hf_tqdm_original_update = None  # For monkey-patching hf's tqdm

    def _create_tracked_tqdm_class(self):
        """Create a tqdm subclass that tracks progress."""
        tracker = self
        original_tqdm = self._original_tqdm_class

        class TrackedTqdm(original_tqdm):
            """A tqdm subclass that reports progress to our tracker."""

            def __init__(self, *args, **kwargs):
                # Extract filename from desc before passing to parent
                desc = kwargs.get("desc", "")
                if not desc and args:
                    first_arg = args[0]
                    if isinstance(first_arg, str):
                        desc = first_arg

                filename = ""
                if desc:
                    # Try to extract filename from description
                    # HuggingFace Hub uses format like "model.safetensors: 0%|..."
                    if ":" in desc:
                        filename = desc.split(":")[0].strip()
                    else:
                        filename = desc.strip()

                # Filter out non-standard kwargs that huggingface_hub might pass
                # These are custom kwargs that tqdm doesn't understand
                filtered_kwargs = {}
                # Known tqdm kwargs - pass these through
                tqdm_kwargs = {
                    "iterable",
                    "desc",
                    "total",
                    "leave",
                    "file",
                    "ncols",
                    "mininterval",
                    "maxinterval",
                    "miniters",
                    "ascii",
                    "disable",
                    "unit",
                    "unit_scale",
                    "dynamic_ncols",
                    "smoothing",
                    "bar_format",
                    "initial",
                    "position",
                    "postfix",
                    "unit_divisor",
                    "write_bytes",
                    "lock_args",
                    "nrows",
                    "colour",
                    "color",
                    "delay",
                    "gui",
                    "disable_default",
                    "pos",
                }
                for key, value in kwargs.items():
                    if key in tqdm_kwargs:
                        filtered_kwargs[key] = value

                # Force-enable the progress bar — we're tracking progress ourselves,
                # we don't need tqdm to render to a terminal, but we DO need
                # self.n to be updated when update() is called.
                filtered_kwargs["disable"] = False

                # Try to initialize with filtered kwargs, fall back to all kwargs if that fails
                try:
                    super().__init__(*args, **filtered_kwargs)
                except TypeError:
                    # If filtering failed, try with all kwargs (maybe tqdm version accepts them)
                    kwargs["disable"] = False
                    super().__init__(*args, **kwargs)

                self._tracker_filename = filename or "unknown"

                with tracker._lock:
                    if filename:
                        tracker._current_filename = filename
                    tracker._active_tqdms[id(self)] = {
                        "filename": self._tracker_filename,
                    }

            def update(self, n=1):
                result = super().update(n)

                # Report progress
                with tracker._lock:
                    if id(self) in tracker._active_tqdms:
                        filename = tracker._active_tqdms[id(self)]["filename"]
                        current = getattr(self, "n", 0)
                        total = getattr(self, "total", 0)

                        if total and total > 0:
                            # Always filter out non-byte progress bars (e.g., "Fetching 12 files")
                            # These cause crazy percentages because they're counting files, not bytes
                            if self._is_non_byte_progress(filename):
                                return result

                            # When model is cached, also filter out generation-related progress
                            if tracker.filter_non_downloads:
                                if not self._is_download_progress(filename):
                                    return result

                            # Update per-file tracking
                            tracker._file_sizes[filename] = total
                            tracker._file_downloaded[filename] = current

                            # Calculate totals across all files
                            tracker._total_size = sum(tracker._file_sizes.values())
                            tracker._total_downloaded = sum(tracker._file_downloaded.values())

                            # Only report progress once we have a meaningful total (at least 1MB)
                            # This avoids the "100% at 0MB" issue when small config
                            # files are counted before the real model files
                            MIN_TOTAL_BYTES = 1_000_000  # 1MB
                            if tracker._total_size < MIN_TOTAL_BYTES:
                                return result

                            # Call progress callback
                            if tracker.progress_callback:
                                tracker.progress_callback(tracker._total_downloaded, tracker._total_size, filename)

                return result

            def _is_non_byte_progress(self, filename: str) -> bool:
                """Check if this progress bar should be SKIPPED (returns True to skip).

                We want to track byte-based progress bars. This method identifies
                progress bars that count files/items instead of bytes, which would
                cause crazy percentages if mixed with our byte counting.

                Returns:
                    True = SKIP this bar (it's not byte-based)
                    False = TRACK this bar (it counts bytes)
                """
                if not filename:
                    return False

                filename_lower = filename.lower()

                # Skip "Fetching X files" - it counts files (total=12), not bytes
                # Don't skip "Downloading (incomplete total...)" - that IS byte-based
                skip_patterns = [
                    "fetching",  # "Fetching 12 files" has total=12 files, not bytes
                ]
                return any(pattern in filename_lower for pattern in skip_patterns)

            def _is_download_progress(self, filename: str) -> bool:
                """Check if this is a real file download progress bar vs internal processing."""
                if not filename or filename == "unknown":
                    return False

                # Real downloads have file extensions
                download_extensions = [
                    ".safetensors",
                    ".bin",
                    ".pt",
                    ".pth",  # Model weights
                    ".json",
                    ".txt",
                    ".py",  # Config files
                    ".msgpack",
                    ".h5",  # Other formats
                ]

                filename_lower = filename.lower()
                has_extension = any(filename_lower.endswith(ext) for ext in download_extensions)

                # Skip generation-related progress indicators
                skip_patterns = ["segment", "processing", "generating", "loading"]
                has_skip_pattern = any(pattern in filename_lower for pattern in skip_patterns)

                return has_extension and not has_skip_pattern

            def close(self):
                with tracker._lock:
                    if id(self) in tracker._active_tqdms:
                        del tracker._active_tqdms[id(self)]
                return super().close()

        return TrackedTqdm

    @contextmanager
    def patch_download(self):
        """Context manager to patch tqdm for progress tracking."""
        try:
            import tqdm as tqdm_module

            # Store original tqdm class
            self._original_tqdm_class = tqdm_module.tqdm

            # Reset totals
            with self._lock:
                self._total_downloaded = 0
                self._total_size = 0
                self._file_sizes = {}
                self._file_downloaded = {}
                self._current_filename = ""
                self._active_tqdms = {}

            # Create our tracked tqdm class
            tracked_tqdm = self._create_tracked_tqdm_class()

            # Patch tqdm.tqdm
            tqdm_module.tqdm = tracked_tqdm

            # Also patch tqdm.auto.tqdm if it exists (used by huggingface_hub)
            self._original_tqdm_auto = None
            if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
                self._original_tqdm_auto = tqdm_module.auto.tqdm
                tqdm_module.auto.tqdm = tracked_tqdm

            # Patch in sys.modules to catch already-imported references
            # huggingface_hub uses: from tqdm.auto import tqdm as base_tqdm
            # So we need to patch both 'tqdm' and 'base_tqdm' attributes
            self._patched_modules = {}
            tqdm_attr_names = ["tqdm", "base_tqdm", "old_tqdm"]  # Various names used

            patched_count = 0
            for module_name in list(sys.modules.keys()):
                if "huggingface" in module_name or module_name.startswith("tqdm"):
                    try:
                        module = sys.modules[module_name]
                        for attr_name in tqdm_attr_names:
                            if hasattr(module, attr_name):
                                attr = getattr(module, attr_name)
                                # Only patch if it's a tqdm class (not already patched)
                                is_tqdm_class = (
                                    attr is self._original_tqdm_class
                                    or (self._original_tqdm_auto and attr is self._original_tqdm_auto)
                                    or (
                                        hasattr(attr, "__name__")
                                        and attr.__name__ == "tqdm"
                                        and hasattr(attr, "update")
                                    )  # tqdm classes have update method
                                )
                                if is_tqdm_class:
                                    key = f"{module_name}.{attr_name}"
                                    self._patched_modules[key] = (module, attr_name, attr)
                                    setattr(module, attr_name, tracked_tqdm)
                                    patched_count += 1
                    except (AttributeError, TypeError):
                        pass

            # ALSO monkey-patch the update method on huggingface_hub's tqdm class
            # This is needed because the class was already defined at import time
            self._hf_tqdm_original_update = None
            try:
                from huggingface_hub.utils import tqdm as hf_tqdm_module

                if hasattr(hf_tqdm_module, "tqdm"):
                    hf_tqdm_class = hf_tqdm_module.tqdm
|
||||
self._hf_tqdm_original_update = hf_tqdm_class.update
|
||||
|
||||
# Create a wrapper that calls our tracking
|
||||
tracker = self # Reference to HFProgressTracker instance
|
||||
|
||||
def patched_update(tqdm_self, n=1):
|
||||
result = tracker._hf_tqdm_original_update(tqdm_self, n)
|
||||
|
||||
# Track this progress
|
||||
with tracker._lock:
|
||||
desc = getattr(tqdm_self, "desc", "") or ""
|
||||
current = getattr(tqdm_self, "n", 0)
|
||||
total = getattr(tqdm_self, "total", 0) or 0
|
||||
|
||||
# Skip non-byte progress bars
|
||||
if "fetching" in desc.lower():
|
||||
return result
|
||||
|
||||
# Skip until we have a meaningful total (at least 1MB)
|
||||
# This avoids the "100% at 0MB" issue when small config
|
||||
# files are counted before the real model files
|
||||
MIN_TOTAL_BYTES = 1_000_000 # 1MB
|
||||
if total >= MIN_TOTAL_BYTES:
|
||||
tracker._total_downloaded = current
|
||||
tracker._total_size = total
|
||||
|
||||
if tracker.progress_callback:
|
||||
tracker.progress_callback(current, total, desc)
|
||||
|
||||
return result
|
||||
|
||||
hf_tqdm_class.update = patched_update
|
||||
patched_count += 1
|
||||
logger.debug("Monkey-patched huggingface_hub.utils.tqdm.tqdm.update")
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.warning("Could not monkey-patch hf_tqdm: %s", e)
|
||||
|
||||
logger.debug("Patched %d tqdm references", patched_count)
|
||||
|
||||
yield
|
||||
|
||||
except ImportError:
|
||||
# If tqdm not available, just yield without patching
|
||||
yield
|
||||
finally:
|
||||
# Restore original tqdm
|
||||
if self._original_tqdm_class:
|
||||
try:
|
||||
import tqdm as tqdm_module
|
||||
|
||||
tqdm_module.tqdm = self._original_tqdm_class
|
||||
|
||||
if self._original_tqdm_auto:
|
||||
tqdm_module.auto.tqdm = self._original_tqdm_auto
|
||||
|
||||
# Restore patched modules
|
||||
for key, (module, attr_name, original) in self._patched_modules.items():
|
||||
try:
|
||||
if module and original:
|
||||
setattr(module, attr_name, original)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
self._patched_modules = {}
|
||||
|
||||
# Restore hf_tqdm's original update method
|
||||
if self._hf_tqdm_original_update:
|
||||
try:
|
||||
from huggingface_hub.utils import tqdm as hf_tqdm_module
|
||||
|
||||
if hasattr(hf_tqdm_module, "tqdm"):
|
||||
hf_tqdm_module.tqdm.update = self._hf_tqdm_original_update
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
self._hf_tqdm_original_update = None
|
||||
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
def create_hf_progress_callback(model_name: str, progress_manager):
|
||||
"""Create a progress callback for HuggingFace downloads."""
|
||||
|
||||
def callback(downloaded: int, total: int, filename: str = ""):
|
||||
"""Progress callback.
|
||||
|
||||
Note: We send updates even when total=0 (unknown) to provide feedback
|
||||
during the "incomplete total" phase of huggingface_hub downloads.
|
||||
The frontend handles total=0 gracefully.
|
||||
"""
|
||||
progress_manager.update_progress(
|
||||
model_name=model_name,
|
||||
current=downloaded,
|
||||
total=total,
|
||||
filename=filename or "",
|
||||
status="downloading",
|
||||
)
|
||||
|
||||
return callback
|
||||
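For orientation, a minimal sketch of how these pieces wire together. The HFProgressTracker constructor arguments are an assumption inferred from the attributes used above (progress_callback, filter_non_downloads); snapshot_download is the standard huggingface_hub entry point, and the model names are placeholders.

from huggingface_hub import snapshot_download
from backend.utils.progress import get_progress_manager

# Assumed constructor: the tracker only needs the attributes referenced above.
callback = create_hf_progress_callback("whisper-base", get_progress_manager())
tracker = HFProgressTracker(progress_callback=callback)

with tracker.patch_download():
    # tqdm bars created by huggingface_hub inside this block are intercepted
    # and aggregated into byte totals for the callback.
    snapshot_download(repo_id="openai/whisper-base")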
114
backend/utils/images.py
Normal file
@@ -0,0 +1,114 @@
"""Image processing utilities for avatar uploads."""

from pathlib import Path
from typing import Optional, Tuple
from PIL import Image

# JPEG can be reported as 'JPEG' or 'MPO' (a multi-picture format produced by some cameras)
ALLOWED_FORMATS = {'PNG', 'JPEG', 'WEBP', 'MPO', 'JPG'}
MAX_SIZE = 512
MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB


def validate_image(file_path: str) -> Tuple[bool, Optional[str]]:
    """
    Validate image format and file size.

    Args:
        file_path: Path to image file

    Returns:
        Tuple of (is_valid, error_message)
    """
    path = Path(file_path)

    # Check file size
    if path.stat().st_size > MAX_FILE_SIZE:
        return False, f"File size exceeds maximum of {MAX_FILE_SIZE // (1024 * 1024)}MB"

    try:
        with Image.open(file_path) as img:
            # Verify the image can be loaded
            img.load()

            # Check format (normalize JPEG variants)
            img_format = img.format
            if img_format in ('MPO', 'JPG'):
                img_format = 'JPEG'

            if img_format not in {'PNG', 'JPEG', 'WEBP'}:
                return False, f"Invalid format '{img_format}'. Allowed formats: PNG, JPEG, WEBP"

        return True, None
    except Exception as e:
        return False, f"Invalid image file: {str(e)}"


def process_avatar(input_path: str, output_path: str, max_size: int = MAX_SIZE) -> None:
    """
    Process avatar image: resize and optimize.

    Resizes the image to fit within max_size x max_size while maintaining aspect ratio.

    Args:
        input_path: Path to input image
        output_path: Path to save processed image
        max_size: Maximum width or height in pixels
    """
    with Image.open(input_path) as img:
        # Handle EXIF orientation for JPEG images
        try:
            from PIL import ExifTags
            for orientation in ExifTags.TAGS.keys():
                if ExifTags.TAGS[orientation] == 'Orientation':
                    break
            exif = img._getexif()
            if exif is not None:
                orientation_value = exif.get(orientation)
                if orientation_value == 3:
                    img = img.rotate(180, expand=True)
                elif orientation_value == 6:
                    img = img.rotate(270, expand=True)
                elif orientation_value == 8:
                    img = img.rotate(90, expand=True)
        except (AttributeError, KeyError, IndexError, TypeError):
            # No EXIF data or orientation tag
            pass

        # Convert to RGB if necessary (handles RGBA, P, CMYK, etc.)
        if img.mode not in ('RGB', 'L'):
            if img.mode == 'RGBA':
                # Composite onto a white background, using the alpha channel as mask
                background = Image.new('RGB', img.size, (255, 255, 255))
                background.paste(img, mask=img.split()[3])
                img = background
            else:
                # CMYK, palette ('P'), and all other modes convert directly
                img = img.convert('RGB')

        # Resize in place, maintaining aspect ratio
        img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        # Determine output format from extension
        output_ext = Path(output_path).suffix.lower()

        format_map = {
            '.png': 'PNG',
            '.jpeg': 'JPEG',
            '.jpg': 'JPEG',
            '.webp': 'WEBP',
        }

        output_format = format_map.get(output_ext, 'PNG')

        # Save with optimization
        save_kwargs = {'optimize': True}
        if output_format == 'JPEG':
            save_kwargs['quality'] = 90

        img.save(output_path, format=output_format, **save_kwargs)
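A short usage sketch of the validate-then-process flow; the file paths are placeholders, not part of the module.

is_valid, error = validate_image("/tmp/upload.jpg")
if is_valid:
    # Writes a <=512px RGB image; the output format follows the extension
    process_avatar("/tmp/upload.jpg", "/data/avatars/user_42.webp")
else:
    print(f"Upload rejected: {error}")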
35
backend/utils/platform_detect.py
Normal file
@@ -0,0 +1,35 @@
"""
Platform detection for backend selection.
"""

import platform
from typing import Literal


def is_apple_silicon() -> bool:
    """
    Check if running on Apple Silicon (arm64 macOS).

    Returns:
        True if on Apple Silicon, False otherwise
    """
    return platform.system() == "Darwin" and platform.machine() == "arm64"


def get_backend_type() -> Literal["mlx", "pytorch"]:
    """
    Detect the best backend for the current platform.

    Returns:
        "mlx" on Apple Silicon (if MLX is available and functional), "pytorch" otherwise
    """
    if is_apple_silicon():
        try:
            import mlx.core  # noqa: F401 — triggers native lib loading
            return "mlx"
        except (ImportError, OSError, RuntimeError):
            # MLX not installed, or native libraries failed to load inside a
            # PyInstaller bundle (OSError on missing .dylib / .metallib).
            # Fall through to PyTorch.
            return "pytorch"
    return "pytorch"
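Callers branch on the result once at startup; a sketch, where the import path matches this repo layout and the log line is illustrative.

from backend.utils.platform_detect import get_backend_type, is_apple_silicon

backend = get_backend_type()
# "mlx" implies is_apple_silicon() is True and mlx.core imported cleanly;
# any failure mode (missing wheel, broken .dylib) degrades to "pytorch".
print(f"Selected inference backend: {backend} (apple_silicon={is_apple_silicon()})")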
315
backend/utils/progress.py
Normal file
@@ -0,0 +1,315 @@
"""
Progress tracking for model downloads using Server-Sent Events.
"""

from typing import Optional, Callable, Dict, List
import asyncio
import json
import logging
import threading
import time
from datetime import datetime

logger = logging.getLogger(__name__)


class ProgressManager:
    """Manages download progress for multiple models.

    Thread-safe: can be called from background threads (e.g., via asyncio.to_thread).
    """

    # Throttle settings to prevent overwhelming SSE clients
    THROTTLE_INTERVAL_SECONDS = 0.5  # Minimum time between updates
    THROTTLE_PROGRESS_DELTA = 1.0  # Minimum progress change (%) to force an update

    def __init__(self):
        self._progress: Dict[str, Dict] = {}
        self._listeners: Dict[str, list] = {}
        self._lock = threading.Lock()  # Protects the progress dict
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None
        self._last_notify_time: Dict[str, float] = {}  # Last notification time per model
        self._last_notify_progress: Dict[str, float] = {}  # Last notified progress per model

    def _set_main_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the main event loop for thread-safe operations."""
        self._main_loop = loop

    def _notify_listeners_threadsafe(self, model_name: str, progress_data: Dict):
        """Notify listeners in a thread-safe manner."""
        if model_name not in self._listeners:
            return

        for queue in self._listeners[model_name]:
            try:
                try:
                    # If we're in an async context, we can put directly
                    asyncio.get_running_loop()
                    queue.put_nowait(progress_data.copy())
                except RuntimeError:
                    # Not in an async context (running in a background thread):
                    # use call_soon_threadsafe to safely put on the queue
                    if self._main_loop and self._main_loop.is_running():
                        self._main_loop.call_soon_threadsafe(
                            lambda q=queue, d=progress_data.copy(): q.put_nowait(d) if not q.full() else None
                        )
                    else:
                        logger.debug(f"No main loop available for {model_name}, skipping notification")
            except asyncio.QueueFull:
                logger.warning(f"Queue full for {model_name}, dropping update")
            except Exception as e:
                logger.warning(f"Error notifying listener for {model_name}: {e}")

    def update_progress(
        self,
        model_name: str,
        current: int,
        total: int,
        filename: Optional[str] = None,
        status: str = "downloading",
    ):
        """
        Update progress for a model download.

        Thread-safe: can be called from background threads.

        Progress updates are throttled to prevent overwhelming SSE clients.
        Updates are sent at most every THROTTLE_INTERVAL_SECONDS, or when
        progress changes by at least THROTTLE_PROGRESS_DELTA percent.

        Args:
            model_name: Name of the model (e.g., "qwen-tts-1.7B", "whisper-base")
            current: Current bytes downloaded
            total: Total bytes to download
            filename: Current file being downloaded
            status: Status string (downloading, extracting, complete, error)
        """
        # Calculate the progress percentage, clamped to the 0-100 range.
        # This prevents nonsensical percentages from edge cases like:
        # - current > total temporarily during aggregation
        # - mixing file-count progress with byte-count progress
        if total > 0:
            progress_pct = min(100.0, max(0.0, current / total * 100))
        else:
            progress_pct = 0.0

        progress_data = {
            "model_name": model_name,
            "current": current,
            "total": total,
            "progress": progress_pct,
            "filename": filename,
            "status": status,
            "timestamp": datetime.now().isoformat(),
        }

        # Thread-safe update of the progress dict (always update internal state)
        with self._lock:
            self._progress[model_name] = progress_data

            # Check whether we should notify listeners (throttling)
            current_time = time.time()
            last_time = self._last_notify_time.get(model_name, 0)
            last_progress = self._last_notify_progress.get(model_name, -100)

            time_delta = current_time - last_time
            progress_delta = abs(progress_pct - last_progress)

            # Always notify for complete/error status, or when throttle conditions are met
            should_notify = (
                status in ("complete", "error")
                or time_delta >= self.THROTTLE_INTERVAL_SECONDS
                or progress_delta >= self.THROTTLE_PROGRESS_DELTA
            )

            if not should_notify:
                return  # Skip this update (throttled)

            # Update throttle tracking
            self._last_notify_time[model_name] = current_time
            self._last_notify_progress[model_name] = progress_pct

            # Notify all listeners (thread-safe)
            listener_count = len(self._listeners.get(model_name, []))

            if listener_count > 0:
                logger.debug(f"Notifying {listener_count} listeners for {model_name}: {progress_pct:.1f}% ({filename})")
                self._notify_listeners_threadsafe(model_name, progress_data)
            else:
                logger.debug(f"No listeners for {model_name}, progress update stored: {progress_pct:.1f}%")

    def get_progress(self, model_name: str) -> Optional[Dict]:
        """Get current progress for a model. Thread-safe."""
        with self._lock:
            progress = self._progress.get(model_name)
            return progress.copy() if progress else None

    def get_all_active(self) -> List[Dict]:
        """Get all active downloads (status is 'downloading' or 'extracting'). Thread-safe."""
        active = []
        with self._lock:
            for model_name, progress in self._progress.items():
                if progress.get("status", "") in ("downloading", "extracting"):
                    active.append(progress.copy())
        return active

    def create_progress_callback(self, model_name: str, filename: Optional[str] = None) -> Callable[[Dict], None]:
        """
        Create a progress callback function for HuggingFace downloads.

        Args:
            model_name: Name of the model
            filename: Optional filename filter

        Returns:
            Callback function
        """
        def callback(progress: Dict):
            """HuggingFace Hub progress callback."""
            if "total" in progress and "current" in progress:
                self.update_progress(
                    model_name=model_name,
                    current=progress.get("current", 0),
                    total=progress.get("total", 0),
                    filename=progress.get("filename", filename),
                    status="downloading",
                )

        return callback

    async def subscribe(self, model_name: str):
        """
        Subscribe to progress updates for a model.

        Yields progress updates as Server-Sent Events.
        """
        # Store the main event loop for thread-safe operations
        try:
            self._main_loop = asyncio.get_running_loop()
        except RuntimeError:
            pass

        queue = asyncio.Queue(maxsize=10)

        # Add to listeners
        if model_name not in self._listeners:
            self._listeners[model_name] = []
        self._listeners[model_name].append(queue)

        logger.info(f"SSE client subscribed to {model_name}, total listeners: {len(self._listeners[model_name])}")

        try:
            # Send initial progress if available and still in progress (thread-safe read)
            with self._lock:
                initial_progress = self._progress.get(model_name)
                if initial_progress:
                    initial_progress = initial_progress.copy()

            if initial_progress:
                status = initial_progress.get('status')
                # Only send initial progress if a download is actually in progress;
                # don't send stale 'complete' or 'error' status from previous downloads
                if status in ('downloading', 'extracting'):
                    logger.info(f"Sending initial progress for {model_name}: {status}")
                    yield f"data: {json.dumps(initial_progress)}\n\n"
                else:
                    logger.info(f"Skipping initial progress for {model_name} (status: {status})")
            else:
                logger.info(f"No initial progress available for {model_name}")

            # Stream updates
            while True:
                try:
                    # Wait for an update, with a timeout so we can heartbeat
                    progress = await asyncio.wait_for(queue.get(), timeout=1.0)
                    logger.debug(f"Sending progress update for {model_name}: {progress.get('status')} - {progress.get('progress', 0):.1f}%")
                    yield f"data: {json.dumps(progress)}\n\n"

                    # Stop if complete or error
                    if progress.get("status") in ("complete", "error"):
                        logger.info(f"Download {progress.get('status')} for {model_name}, closing SSE connection")
                        break
                except asyncio.TimeoutError:
                    # Send heartbeat
                    yield ": heartbeat\n\n"
                    continue
        except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError):
            logger.debug(f"SSE client disconnected from {model_name}")
        finally:
            # Remove from listeners
            if model_name in self._listeners:
                self._listeners[model_name].remove(queue)
                if not self._listeners[model_name]:
                    del self._listeners[model_name]
            logger.info(f"SSE client unsubscribed from {model_name}, remaining listeners: {len(self._listeners.get(model_name, []))}")

    def mark_complete(self, model_name: str):
        """Mark a model download as complete. Thread-safe."""
        with self._lock:
            if model_name in self._progress:
                self._progress[model_name]["status"] = "complete"
                self._progress[model_name]["progress"] = 100.0
                progress_data = self._progress[model_name].copy()
            else:
                logger.warning(f"Cannot mark {model_name} as complete: not found in progress")
                return

        logger.info(f"Marked {model_name} as complete")
        # Notify listeners (thread-safe)
        self._notify_listeners_threadsafe(model_name, progress_data)

    def mark_error(self, model_name: str, error: str):
        """Mark a model download as failed. Thread-safe."""
        with self._lock:
            if model_name in self._progress:
                self._progress[model_name]["status"] = "error"
                self._progress[model_name]["error"] = error
                progress_data = self._progress[model_name].copy()
            else:
                # Create a new progress entry for the error
                progress_data = {
                    "model_name": model_name,
                    "current": 0,
                    "total": 0,
                    "progress": 0,
                    "filename": None,
                    "status": "error",
                    "error": error,
                    "timestamp": datetime.now().isoformat(),
                }
                self._progress[model_name] = progress_data

        logger.error(f"Marked {model_name} as error: {error}")
        # Notify listeners (thread-safe)
        self._notify_listeners_threadsafe(model_name, progress_data)


# Global progress manager instance
_progress_manager: Optional[ProgressManager] = None


def get_progress_manager() -> ProgressManager:
    """Get or create the global progress manager."""
    global _progress_manager
    if _progress_manager is None:
        _progress_manager = ProgressManager()
    return _progress_manager
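To complete the picture, a sketch of how subscribe() would be exposed over HTTP. The route path and app object are assumptions; StreamingResponse accepting an async generator is standard FastAPI.

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/api/models/{model_name}/progress")
async def model_progress(model_name: str):
    # subscribe() is an async generator yielding pre-formatted SSE frames,
    # so it can be streamed as-is with the text/event-stream media type.
    return StreamingResponse(
        get_progress_manager().subscribe(model_name),
        media_type="text/event-stream",
    )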
102
backend/utils/tasks.py
Normal file
@@ -0,0 +1,102 @@
"""
Task tracking for active downloads and generations.
"""

from typing import Optional, Dict, List
from datetime import datetime
from dataclasses import dataclass, field


@dataclass
class DownloadTask:
    """Represents an active download task."""
    model_name: str
    status: str = "downloading"  # downloading, extracting, complete, error
    started_at: datetime = field(default_factory=datetime.utcnow)
    error: Optional[str] = None


@dataclass
class GenerationTask:
    """Represents an active generation task."""
    task_id: str
    profile_id: str
    text_preview: str  # First 50 characters of the text
    started_at: datetime = field(default_factory=datetime.utcnow)


class TaskManager:
    """Manages active downloads and generations."""

    def __init__(self):
        self._active_downloads: Dict[str, DownloadTask] = {}
        self._active_generations: Dict[str, GenerationTask] = {}

    def start_download(self, model_name: str) -> None:
        """Mark a download as started."""
        self._active_downloads[model_name] = DownloadTask(
            model_name=model_name,
            status="downloading",
        )

    def complete_download(self, model_name: str) -> None:
        """Mark a download as complete."""
        if model_name in self._active_downloads:
            del self._active_downloads[model_name]

    def error_download(self, model_name: str, error: str) -> None:
        """Mark a download as failed."""
        if model_name in self._active_downloads:
            self._active_downloads[model_name].status = "error"
            self._active_downloads[model_name].error = error

    def start_generation(self, task_id: str, profile_id: str, text: str) -> None:
        """Mark a generation as started."""
        text_preview = text[:50] + "..." if len(text) > 50 else text
        self._active_generations[task_id] = GenerationTask(
            task_id=task_id,
            profile_id=profile_id,
            text_preview=text_preview,
        )

    def complete_generation(self, task_id: str) -> None:
        """Mark a generation as complete."""
        if task_id in self._active_generations:
            del self._active_generations[task_id]

    def get_active_downloads(self) -> List[DownloadTask]:
        """Get all active downloads."""
        return list(self._active_downloads.values())

    def get_active_generations(self) -> List[GenerationTask]:
        """Get all active generations."""
        return list(self._active_generations.values())

    def cancel_download(self, model_name: str) -> bool:
        """Cancel/dismiss a download task (removes it from the active list)."""
        return self._active_downloads.pop(model_name, None) is not None

    def clear_all(self) -> None:
        """Clear all download and generation tasks."""
        self._active_downloads.clear()
        self._active_generations.clear()

    def is_download_active(self, model_name: str) -> bool:
        """Check if a download is active."""
        return model_name in self._active_downloads

    def is_generation_active(self, task_id: str) -> bool:
        """Check if a generation is active."""
        return task_id in self._active_generations


# Global task manager instance
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Get or create the global task manager."""
    global _task_manager
    if _task_manager is None:
        _task_manager = TaskManager()
    return _task_manager
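A sketch of the intended lifecycle around a download worker; the model name and the download_model call are hypothetical placeholders.

manager = get_task_manager()
manager.start_download("whisper-base")
try:
    download_model("whisper-base")  # hypothetical worker call
    manager.complete_download("whisper-base")
except Exception as exc:
    # Keeps the task visible with status="error" until dismissed
    manager.error_download("whisper-base", str(exc))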