# backend/utils/chunked_tts.py — initial commit (new file, 299 lines).
"""
|
||||
Chunked TTS generation utilities.
|
||||
|
||||
Splits long text into sentence-boundary chunks, generates audio per-chunk
|
||||
via any TTSBackend, and concatenates with crossfade. All logic is
|
||||
engine-agnostic — it wraps the standard ``TTSBackend.generate()`` interface.
|
||||
|
||||
Short text (≤ max_chunk_chars) uses the single-shot fast path with zero
|
||||
overhead.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger("voicebox.chunked-tts")
# Default chunk size in characters. Can be overridden per-request via
|
||||
# the ``max_chunk_chars`` field on GenerationRequest.
|
||||
DEFAULT_MAX_CHUNK_CHARS = 800
|
||||
|
||||
# Common abbreviations that should NOT be treated as sentence endings.
|
||||
# Lowercase for case-insensitive matching.
|
||||
_ABBREVIATIONS = frozenset(
|
||||
{
|
||||
"mr",
|
||||
"mrs",
|
||||
"ms",
|
||||
"dr",
|
||||
"prof",
|
||||
"sr",
|
||||
"jr",
|
||||
"st",
|
||||
"ave",
|
||||
"blvd",
|
||||
"inc",
|
||||
"ltd",
|
||||
"corp",
|
||||
"dept",
|
||||
"est",
|
||||
"approx",
|
||||
"vs",
|
||||
"etc",
|
||||
"e.g",
|
||||
"i.e",
|
||||
"a.m",
|
||||
"p.m",
|
||||
"u.s",
|
||||
"u.s.a",
|
||||
"u.k",
|
||||
}
|
||||
)
|
||||
|
||||
# Paralinguistic tags used by Chatterbox Turbo. The splitter must never
|
||||
# cut inside one of these.
|
||||
_PARA_TAG_RE = re.compile(r"\[[^\]]*\]")
|
||||
|
||||
|
||||
def split_text_into_chunks(text: str, max_chars: int = DEFAULT_MAX_CHUNK_CHARS) -> List[str]:
|
||||
"""Split *text* at natural boundaries into chunks of at most *max_chars*.
|
||||
|
||||
Priority: sentence-end (``.!?`` not preceded by an abbreviation and not
|
||||
inside brackets) → clause boundary (``;:,—``) → whitespace → hard cut.
|
||||
|
||||
Paralinguistic tags like ``[laugh]`` are treated as atomic and will not
|
||||
be split across chunks.
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
if len(text) <= max_chars:
|
||||
return [text]
|
||||
|
||||
chunks: List[str] = []
|
||||
remaining = text
|
||||
|
||||
while remaining:
|
||||
remaining = remaining.lstrip()
|
||||
if not remaining:
|
||||
break
|
||||
if len(remaining) <= max_chars:
|
||||
chunks.append(remaining)
|
||||
break
|
||||
|
||||
segment = remaining[:max_chars]
|
||||
|
||||
# Try to split at the last real sentence ending
|
||||
split_pos = _find_last_sentence_end(segment)
|
||||
if split_pos == -1:
|
||||
split_pos = _find_last_clause_boundary(segment)
|
||||
if split_pos == -1:
|
||||
split_pos = segment.rfind(" ")
|
||||
if split_pos == -1:
|
||||
# Absolute fallback: hard cut but avoid splitting inside a tag
|
||||
split_pos = _safe_hard_cut(segment, max_chars)
|
||||
|
||||
chunk = remaining[: split_pos + 1].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
remaining = remaining[split_pos + 1 :]
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def _find_last_sentence_end(text: str) -> int:
|
||||
"""Return the index of the last sentence-ending punctuation in *text*.
|
||||
|
||||
Skips periods that follow common abbreviations (``Dr.``, ``Mr.``, etc.)
|
||||
and periods inside bracket tags (``[laugh]``). Also handles CJK
|
||||
sentence-ending punctuation (``。!?``).
|
||||
"""
|
||||
best = -1
|
||||
# ASCII sentence ends
|
||||
for m in re.finditer(r"[.!?](?:\s|$)", text):
|
||||
pos = m.start()
|
||||
char = text[pos]
|
||||
# Skip periods after abbreviations
|
||||
if char == ".":
|
||||
# Walk backwards to find the preceding word
|
||||
word_start = pos - 1
|
||||
while word_start >= 0 and text[word_start].isalpha():
|
||||
word_start -= 1
|
||||
word = text[word_start + 1 : pos].lower()
|
||||
if word in _ABBREVIATIONS:
|
||||
continue
|
||||
# Skip decimal numbers (digit immediately before the period)
|
||||
if word_start >= 0 and text[word_start].isdigit():
|
||||
continue
|
||||
# Skip if we're inside a bracket tag
|
||||
if _inside_bracket_tag(text, pos):
|
||||
continue
|
||||
best = pos
|
||||
# CJK sentence-ending punctuation
|
||||
for m in re.finditer(r"[\u3002\uff01\uff1f]", text):
|
||||
if m.start() > best:
|
||||
best = m.start()
|
||||
return best
|
||||
|
||||
|
||||
def _find_last_clause_boundary(text: str) -> int:
|
||||
"""Return the index of the last clause-boundary punctuation."""
|
||||
best = -1
|
||||
for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
|
||||
pos = m.start()
|
||||
# Skip if inside a bracket tag
|
||||
if _inside_bracket_tag(text, pos):
|
||||
continue
|
||||
best = pos
|
||||
return best
|
||||
|
||||
|
||||
def _inside_bracket_tag(text: str, pos: int) -> bool:
|
||||
"""Return True if *pos* falls inside a ``[...]`` tag."""
|
||||
for m in _PARA_TAG_RE.finditer(text):
|
||||
if m.start() < pos < m.end():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _safe_hard_cut(segment: str, max_chars: int) -> int:
|
||||
"""Find a hard-cut position that doesn't split a ``[tag]``."""
|
||||
cut = max_chars - 1
|
||||
# Check if the cut falls inside a bracket tag; if so, move before it
|
||||
for m in _PARA_TAG_RE.finditer(segment):
|
||||
if m.start() < cut < m.end():
|
||||
return m.start() - 1 if m.start() > 0 else cut
|
||||
return cut
def concatenate_audio_chunks(
    chunks: List[np.ndarray],
    sample_rate: int,
    crossfade_ms: int = 50,
) -> np.ndarray:
    """Join 1-D float32 audio arrays into one stream, click-free.

    Neighbouring chunks are blended over a short linear crossfade; with
    ``crossfade_ms == 0`` they are simply butted together.

    Parameters
    ----------
    chunks : List[np.ndarray]
        1-D float32 arrays, all assumed to be at *sample_rate* Hz.
    sample_rate : int
        Sampling rate, used to convert *crossfade_ms* into samples.
    crossfade_ms : int
        Length of the linear fade between chunks (default 50).

    Returns
    -------
    np.ndarray
        The concatenated audio; an empty float32 array for no chunks.
    """
    if not chunks:
        return np.array([], dtype=np.float32)
    if len(chunks) == 1:
        # Single chunk: nothing to blend, hand it back untouched.
        return chunks[0]

    fade_len = int(sample_rate * crossfade_ms / 1000)
    merged = np.array(chunks[0], dtype=np.float32, copy=True)

    for piece in chunks[1:]:
        if len(piece) == 0:
            # Empty chunks contribute nothing; skip them entirely.
            continue
        overlap = min(fade_len, len(merged), len(piece))
        if overlap <= 0:
            # Hard cut: no crossfade requested (or nothing left to fade).
            merged = np.concatenate([merged, piece])
            continue
        # The tail of `merged` fades out while the head of `piece` fades
        # in over the same `overlap` samples.
        fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
        fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
        merged[-overlap:] = merged[-overlap:] * fade_out + piece[:overlap] * fade_in
        merged = np.concatenate([merged, piece[overlap:]])

    return merged
async def generate_chunked(
    backend,
    text: str,
    voice_prompt: dict,
    language: str = "en",
    seed: int | None = None,
    instruct: str | None = None,
    max_chunk_chars: int = DEFAULT_MAX_CHUNK_CHARS,
    crossfade_ms: int = 50,
    trim_fn=None,
) -> Tuple[np.ndarray, int]:
    """Generate audio with automatic chunking for long text.

    Text shorter than *max_chunk_chars* goes straight through
    ``backend.generate()`` with no extra work. Longer text is split at
    natural sentence boundaries, each piece is generated on its own,
    optionally trimmed (useful for Chatterbox engines that hallucinate
    trailing noise), and the pieces are joined with a crossfade (or a
    hard cut when *crossfade_ms* is 0).

    Parameters
    ----------
    backend : TTSBackend
        Any backend implementing the ``generate()`` protocol.
    text : str
        Input text (may be arbitrarily long).
    voice_prompt, language, seed, instruct
        Forwarded to ``backend.generate()`` verbatim.
    max_chunk_chars : int
        Maximum characters per chunk (default 800).
    crossfade_ms : int
        Crossfade duration in milliseconds between chunks; 0 means a
        hard cut with no overlap (default 50).
    trim_fn : callable | None
        Optional ``(audio, sample_rate) -> audio`` post-processing
        function applied to each chunk before concatenation (e.g.
        ``trim_tts_output`` for Chatterbox engines).

    Returns
    -------
    (audio, sample_rate) : Tuple[np.ndarray, int]
    """
    pieces = split_text_into_chunks(text, max_chunk_chars)

    if len(pieces) <= 1:
        # Short text — single-shot fast path, no concatenation needed.
        audio, sample_rate = await backend.generate(
            text,
            voice_prompt,
            language,
            seed,
            instruct,
        )
        if trim_fn is not None:
            audio = trim_fn(audio, sample_rate)
        return audio, sample_rate

    # Long text — render each piece independently, then stitch.
    logger.info(
        "Splitting %d chars into %d chunks (max %d chars each)",
        len(text),
        len(pieces),
        max_chunk_chars,
    )
    total = len(pieces)
    rendered: List[np.ndarray] = []
    out_rate: int | None = None

    for index, piece in enumerate(pieces):
        logger.info(
            "Generating chunk %d/%d (%d chars)",
            index + 1,
            total,
            len(piece),
        )
        # Offset the seed per piece to avoid correlated RNG artefacts,
        # while staying deterministic: the same (text, seed) pair always
        # yields the same output.
        piece_seed = None if seed is None else seed + index

        piece_audio, piece_rate = await backend.generate(
            piece,
            voice_prompt,
            language,
            piece_seed,
            instruct,
        )
        if trim_fn is not None:
            piece_audio = trim_fn(piece_audio, piece_rate)

        rendered.append(np.asarray(piece_audio, dtype=np.float32))
        if out_rate is None:
            # First chunk fixes the output rate; later chunks are assumed
            # to match (same backend, same voice).
            out_rate = piece_rate

    audio = concatenate_audio_chunks(rendered, out_rate, crossfade_ms=crossfade_ms)
    return audio, out_rate