# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union

import librosa
import numpy as np
import torch
from torch.utils.data import Dataset

from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import mel_spectrogram

AudioLike = Union[
    str,                     # wav path, URL, base64
    np.ndarray,              # waveform (requires sr)
    Tuple[np.ndarray, int],  # (waveform, sr)
]
MaybeList = Union[Any, List[Any]]


class TTSDataset(Dataset):
    def __init__(self, data_list, processor, config: Qwen3TTSConfig, lag_num=-1):
        self.data_list = data_list
        self.processor = processor
        self.lag_num = lag_num
        self.config = config

    def __len__(self):
        return len(self.data_list)

    def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
        audio, sr = librosa.load(x, sr=None, mono=True)
        if audio.ndim > 1:
            audio = np.mean(audio, axis=-1)
        return audio.astype(np.float32), int(sr)

    def _normalize_audio_inputs(
        self, audios: Union[AudioLike, List[AudioLike]]
    ) -> List[Tuple[np.ndarray, int]]:
        """
        Normalize audio inputs into a list of (waveform, sr).

        Supported forms:
        - str: wav path / URL / base64 audio string
        - np.ndarray: waveform (NOT allowed alone here because sr is unknown)
        - (np.ndarray, sr): waveform + sampling rate
        - list of the above

        Args:
            audios: Audio input(s).

        Returns:
            List[Tuple[np.ndarray, int]]: List of (float32 waveform, original sr).

        Raises:
            ValueError: If a numpy waveform is provided without sr.
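
        Example (illustrative; ``wav`` stands for any float32 numpy waveform):
            self._normalize_audio_inputs(["ref.wav", (wav, 24000)])
            # -> [(waveform loaded from "ref.wav", original sr), (wav, 24000)]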
""" if isinstance(audios, list): items = audios else: items = [audios] out: List[Tuple[np.ndarray, int]] = [] for a in items: if isinstance(a, str): out.append(self._load_audio_to_np(a)) elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray): out.append((a[0].astype(np.float32), int(a[1]))) elif isinstance(a, np.ndarray): raise ValueError("For numpy waveform input, pass a tuple (audio, sr).") else: raise TypeError(f"Unsupported audio input type: {type(a)}") return out def _build_assistant_text(self, text: str) -> str: return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" def _ensure_list(self, x: MaybeList) -> List[Any]: return x if isinstance(x, list) else [x] def _tokenize_texts(self, text) -> List[torch.Tensor]: input = self.processor(text=text, return_tensors="pt", padding=True) input_id = input["input_ids"] input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id return input_id @torch.inference_mode() def extract_mels(self, audio, sr): assert sr == 24000, "Only support 24kHz audio" mels = mel_spectrogram( torch.from_numpy(audio).unsqueeze(0), n_fft=1024, num_mels=128, sampling_rate=24000, hop_size=256, win_size=1024, fmin=0, fmax=12000 ).transpose(1, 2) return mels def __getitem__(self, idx): item = self.data_list[idx] audio_path = item["audio"] text = item["text"] audio_codes = item["audio_codes"] language = item.get('language','Auto') ref_audio_path = item['ref_audio'] text = self._build_assistant_text(text) text_ids = self._tokenize_texts(text) audio_codes = torch.tensor(audio_codes, dtype=torch.long) ref_audio_list = self._ensure_list(ref_audio_path) normalized = self._normalize_audio_inputs(ref_audio_list) wav,sr = normalized[0] ref_mel = self.extract_mels(audio=wav, sr=sr) return { "text_ids": text_ids[:,:-5], # 1 , t "audio_codes":audio_codes, # t, 16 "ref_mel":ref_mel } def collate_fn(self, batch): assert self.lag_num == -1 item_length = [b['text_ids'].shape[1] + b['audio_codes'].shape[0] for b in batch] max_length = max(item_length) + 8 b,t = len(batch),max_length input_ids = torch.zeros((b,t,2),dtype=torch.long) codec_ids = torch.zeros((b,t,16),dtype=torch.long) text_embedding_mask = torch.zeros((b,t),dtype=torch.bool) codec_embedding_mask = torch.zeros((b,t),dtype=torch.bool) codec_mask = torch.zeros((b,t),dtype=torch.bool) attention_mask = torch.zeros((b,t),dtype=torch.long) codec_0_labels = torch.full((b, t), -100, dtype=torch.long) for i,data in enumerate(batch): text_ids = data['text_ids'] audio_codec_0 = data['audio_codes'][:,0] audio_codecs = data['audio_codes'] text_ids_len = text_ids.shape[1] codec_ids_len = audio_codec_0.shape[0] # text channel input_ids[i, :3, 0] = text_ids[0,:3] input_ids[i, 3:7, 0] = self.config.tts_pad_token_id input_ids[i, 7, 0] = self.config.tts_bos_token_id input_ids[i, 8:8+text_ids_len-3, 0] = text_ids[0,3:] input_ids[i, 8+text_ids_len-3, 0] = self.config.tts_eos_token_id input_ids[i, 8+text_ids_len-2:8+text_ids_len+codec_ids_len , 0] = self.config.tts_pad_token_id text_embedding_mask[i, :8+text_ids_len+codec_ids_len] = True # codec channel # input_ids[i, :3, 1] = 0 input_ids[i, 3:8 ,1] = torch.tensor( [ self.config.talker_config.codec_nothink_id, self.config.talker_config.codec_think_bos_id, self.config.talker_config.codec_think_eos_id, 0, # for speaker embedding self.config.talker_config.codec_pad_id ] ) input_ids[i, 8:8+text_ids_len-3 ,1] = self.config.talker_config.codec_pad_id input_ids[i, 8+text_ids_len-3 ,1] = self.config.talker_config.codec_pad_id input_ids[i, 8+text_ids_len-2 
            input_ids[i, 8 + text_ids_len - 2, 1] = self.config.talker_config.codec_bos_id
            input_ids[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len, 1] = audio_codec_0
            input_ids[i, 8 + text_ids_len - 1 + codec_ids_len, 1] = self.config.talker_config.codec_eos_token_id

            codec_0_labels[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len] = audio_codec_0
            codec_0_labels[i, 8 + text_ids_len - 1 + codec_ids_len] = self.config.talker_config.codec_eos_token_id

            codec_ids[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len, :] = audio_codecs
            codec_embedding_mask[i, 3:8 + text_ids_len + codec_ids_len] = True
            codec_embedding_mask[i, 6] = False  # the speaker-embedding slot is filled elsewhere
            codec_mask[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len] = True
            attention_mask[i, :8 + text_ids_len + codec_ids_len] = True

        # NOTE: torch.cat along dim 0 requires every ref_mel to have the same
        # number of frames, i.e. equal-length reference audio across the batch.
        ref_mels = [data["ref_mel"] for data in batch]
        ref_mels = torch.cat(ref_mels, dim=0)

        return {
            "input_ids": input_ids,
            "ref_mels": ref_mels,
            "attention_mask": attention_mask,
            "text_embedding_mask": text_embedding_mask.unsqueeze(-1),
            "codec_embedding_mask": codec_embedding_mask.unsqueeze(-1),
            "codec_0_labels": codec_0_labels,
            "codec_ids": codec_ids,
            "codec_mask": codec_mask,
        }
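

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the training pipeline): a minimal
# smoke test of __getitem__ and collate_fn. The stub processor, the stub
# config values, and the manifest entry below are assumptions made for the
# example; in real use the processor and Qwen3TTSConfig come from the
# qwen_tts package and data_list comes from a prepared manifest.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Stub processor: returns dummy token ids but mimics the call signature
    # used by _tokenize_texts (text=..., return_tensors="pt", padding=True).
    class _StubProcessor:
        def __call__(self, text, return_tensors="pt", padding=True):
            return {"input_ids": torch.arange(12, dtype=torch.long)}

    # Stub config exposing only the token ids collate_fn reads; the real
    # values live in Qwen3TTSConfig / its talker_config.
    config = SimpleNamespace(
        tts_pad_token_id=0,
        tts_bos_token_id=1,
        tts_eos_token_id=2,
        talker_config=SimpleNamespace(
            codec_nothink_id=3,
            codec_think_bos_id=4,
            codec_think_eos_id=5,
            codec_pad_id=6,
            codec_bos_id=7,
            codec_eos_token_id=8,
        ),
    )

    # Hypothetical manifest entry. Passing (waveform, sr) for "ref_audio"
    # avoids file I/O; 24 kHz is required by extract_mels.
    data_list = [
        {
            "audio": "sample.wav",           # kept in the manifest; not read here
            "text": "Hello world.",
            "audio_codes": [[0] * 16] * 50,  # (t, 16) codec ids
            "ref_audio": (np.zeros(24000, dtype=np.float32), 24000),
        }
    ]

    dataset = TTSDataset(data_list, processor=_StubProcessor(), config=config)
    batch = dataset.collate_fn([dataset[0]])
    print({k: tuple(v.shape) for k, v in batch.items()})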