""" Chunked TTS generation utilities. Splits long text into sentence-boundary chunks, generates audio per-chunk via any TTSBackend, and concatenates with crossfade. All logic is engine-agnostic — it wraps the standard ``TTSBackend.generate()`` interface. Short text (≤ max_chunk_chars) uses the single-shot fast path with zero overhead. """ import logging import re from typing import List, Tuple import numpy as np logger = logging.getLogger("voicebox.chunked-tts") # Default chunk size in characters. Can be overridden per-request via # the ``max_chunk_chars`` field on GenerationRequest. DEFAULT_MAX_CHUNK_CHARS = 800 # Common abbreviations that should NOT be treated as sentence endings. # Lowercase for case-insensitive matching. _ABBREVIATIONS = frozenset( { "mr", "mrs", "ms", "dr", "prof", "sr", "jr", "st", "ave", "blvd", "inc", "ltd", "corp", "dept", "est", "approx", "vs", "etc", "e.g", "i.e", "a.m", "p.m", "u.s", "u.s.a", "u.k", } ) # Paralinguistic tags used by Chatterbox Turbo. The splitter must never # cut inside one of these. _PARA_TAG_RE = re.compile(r"\[[^\]]*\]") def split_text_into_chunks(text: str, max_chars: int = DEFAULT_MAX_CHUNK_CHARS) -> List[str]: """Split *text* at natural boundaries into chunks of at most *max_chars*. Priority: sentence-end (``.!?`` not preceded by an abbreviation and not inside brackets) → clause boundary (``;:,—``) → whitespace → hard cut. Paralinguistic tags like ``[laugh]`` are treated as atomic and will not be split across chunks. """ text = text.strip() if not text: return [] if len(text) <= max_chars: return [text] chunks: List[str] = [] remaining = text while remaining: remaining = remaining.lstrip() if not remaining: break if len(remaining) <= max_chars: chunks.append(remaining) break segment = remaining[:max_chars] # Try to split at the last real sentence ending split_pos = _find_last_sentence_end(segment) if split_pos == -1: split_pos = _find_last_clause_boundary(segment) if split_pos == -1: split_pos = segment.rfind(" ") if split_pos == -1: # Absolute fallback: hard cut but avoid splitting inside a tag split_pos = _safe_hard_cut(segment, max_chars) chunk = remaining[: split_pos + 1].strip() if chunk: chunks.append(chunk) remaining = remaining[split_pos + 1 :] return chunks def _find_last_sentence_end(text: str) -> int: """Return the index of the last sentence-ending punctuation in *text*. Skips periods that follow common abbreviations (``Dr.``, ``Mr.``, etc.) and periods inside bracket tags (``[laugh]``). Also handles CJK sentence-ending punctuation (``。!?``). """ best = -1 # ASCII sentence ends for m in re.finditer(r"[.!?](?:\s|$)", text): pos = m.start() char = text[pos] # Skip periods after abbreviations if char == ".": # Walk backwards to find the preceding word word_start = pos - 1 while word_start >= 0 and text[word_start].isalpha(): word_start -= 1 word = text[word_start + 1 : pos].lower() if word in _ABBREVIATIONS: continue # Skip decimal numbers (digit immediately before the period) if word_start >= 0 and text[word_start].isdigit(): continue # Skip if we're inside a bracket tag if _inside_bracket_tag(text, pos): continue best = pos # CJK sentence-ending punctuation for m in re.finditer(r"[\u3002\uff01\uff1f]", text): if m.start() > best: best = m.start() return best def _find_last_clause_boundary(text: str) -> int: """Return the index of the last clause-boundary punctuation.""" best = -1 for m in re.finditer(r"[;:,\u2014](?:\s|$)", text): pos = m.start() # Skip if inside a bracket tag if _inside_bracket_tag(text, pos): continue best = pos return best def _inside_bracket_tag(text: str, pos: int) -> bool: """Return True if *pos* falls inside a ``[...]`` tag.""" for m in _PARA_TAG_RE.finditer(text): if m.start() < pos < m.end(): return True return False def _safe_hard_cut(segment: str, max_chars: int) -> int: """Find a hard-cut position that doesn't split a ``[tag]``.""" cut = max_chars - 1 # Check if the cut falls inside a bracket tag; if so, move before it for m in _PARA_TAG_RE.finditer(segment): if m.start() < cut < m.end(): return m.start() - 1 if m.start() > 0 else cut return cut def concatenate_audio_chunks( chunks: List[np.ndarray], sample_rate: int, crossfade_ms: int = 50, ) -> np.ndarray: """Concatenate audio arrays with a short crossfade to eliminate clicks. Each chunk is expected to be a 1-D float32 ndarray at *sample_rate* Hz. """ if not chunks: return np.array([], dtype=np.float32) if len(chunks) == 1: return chunks[0] crossfade_samples = int(sample_rate * crossfade_ms / 1000) result = np.array(chunks[0], dtype=np.float32, copy=True) for chunk in chunks[1:]: if len(chunk) == 0: continue overlap = min(crossfade_samples, len(result), len(chunk)) if overlap > 0: fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32) fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32) result[-overlap:] = result[-overlap:] * fade_out + chunk[:overlap] * fade_in result = np.concatenate([result, chunk[overlap:]]) else: result = np.concatenate([result, chunk]) return result async def generate_chunked( backend, text: str, voice_prompt: dict, language: str = "en", seed: int | None = None, instruct: str | None = None, max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS, crossfade_ms: int = 50, trim_fn=None, ) -> Tuple[np.ndarray, int]: """Generate audio with automatic chunking for long text. For text shorter than *max_chunk_chars* this is a thin wrapper around ``backend.generate()`` with zero overhead. For longer text the input is split at natural sentence boundaries, each chunk is generated independently, optionally trimmed (useful for Chatterbox engines that hallucinate trailing noise), and the results are concatenated with a crossfade (or hard cut if *crossfade_ms* is 0). Parameters ---------- backend : TTSBackend Any backend implementing the ``generate()`` protocol. text : str Input text (may be arbitrarily long). voice_prompt, language, seed, instruct Forwarded to ``backend.generate()`` verbatim. max_chunk_chars : int Maximum characters per chunk (default 800). crossfade_ms : int Crossfade duration in milliseconds between chunks. 0 for a hard cut with no overlap (default 50). trim_fn : callable | None Optional ``(audio, sample_rate) -> audio`` post-processing function applied to each chunk before concatenation (e.g. ``trim_tts_output`` for Chatterbox engines). Returns ------- (audio, sample_rate) : Tuple[np.ndarray, int] """ chunks = split_text_into_chunks(text, max_chunk_chars) if len(chunks) <= 1: # Short text — single-shot fast path audio, sample_rate = await backend.generate( text, voice_prompt, language, seed, instruct, ) if trim_fn is not None: audio = trim_fn(audio, sample_rate) return audio, sample_rate # Long text — chunked generation logger.info( "Splitting %d chars into %d chunks (max %d chars each)", len(text), len(chunks), max_chunk_chars, ) audio_chunks: List[np.ndarray] = [] sample_rate: int | None = None for i, chunk_text in enumerate(chunks): logger.info( "Generating chunk %d/%d (%d chars)", i + 1, len(chunks), len(chunk_text), ) # Vary the seed per chunk to avoid correlated RNG artefacts, # but keep it deterministic so the same (text, seed) pair # always produces the same output. chunk_seed = (seed + i) if seed is not None else None chunk_audio, chunk_sr = await backend.generate( chunk_text, voice_prompt, language, chunk_seed, instruct, ) if trim_fn is not None: chunk_audio = trim_fn(chunk_audio, chunk_sr) audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32)) if sample_rate is None: sample_rate = chunk_sr audio = concatenate_audio_chunks(audio_chunks, sample_rate, crossfade_ms=crossfade_ms) return audio, sample_rate