first commit

finetuning/README.md (new file, 121 lines)
@@ -0,0 +1,121 @@
## Fine Tuning Qwen3-TTS-12Hz-1.7B/0.6B-Base

The Qwen3-TTS-12Hz-1.7B/0.6B-Base model series currently supports single-speaker fine-tuning. Please run `pip install qwen-tts` first, then clone the repository and change into the fine-tuning directory:

```bash
git clone https://github.com/QwenLM/Qwen3-TTS.git
cd Qwen3-TTS/finetuning
```

Then follow the steps below to complete the entire fine-tuning workflow. Multi-speaker fine-tuning and other advanced fine-tuning features will be supported in future releases.

### 1) Input JSONL format

Prepare your training file as a JSONL file (one JSON object per line); a minimal helper script for assembling it is sketched at the end of this section. Each line must contain:

- `audio`: path to the target training audio (wav)
- `text`: transcript corresponding to `audio`
- `ref_audio`: path to the reference speaker audio (wav)

Example:

```jsonl
{"audio":"./data/utt0001.wav","text":"其实我真的有发现,我是一个特别善于观察别人情绪的人。","ref_audio":"./data/ref.wav"}
{"audio":"./data/utt0002.wav","text":"She said she would be here by noon.","ref_audio":"./data/ref.wav"}
```

`ref_audio` recommendations:

- Strongly recommended: use the same `ref_audio` for all samples.
- Keeping `ref_audio` identical across the dataset usually improves speaker consistency and stability during generation.
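
For reference, here is a minimal sketch of how `train_raw.jsonl` could be assembled. The directory layout is an assumption for illustration: each training clip `./data/uttXXXX.wav` has a matching `./data/uttXXXX.txt` transcript, and `./data/ref.wav` is the shared reference clip. Adapt the paths to your own data.

```python
import json
from pathlib import Path

DATA_DIR = Path("./data")          # assumed layout: uttXXXX.wav with a matching uttXXXX.txt
REF_AUDIO = DATA_DIR / "ref.wav"   # one shared reference clip for every sample

with open("train_raw.jsonl", "w", encoding="utf-8") as out:
    for wav_path in sorted(DATA_DIR.glob("utt*.wav")):
        txt_path = wav_path.with_suffix(".txt")
        if not txt_path.exists():
            continue  # skip clips without a transcript
        record = {
            "audio": str(wav_path),
            "text": txt_path.read_text(encoding="utf-8").strip(),
            "ref_audio": str(REF_AUDIO),
        }
        out.write(json.dumps(record, ensure_ascii=False) + "\n")
```
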
### 2) Prepare data (extract `audio_codes`)

Convert `train_raw.jsonl` into a training JSONL that includes `audio_codes`:

```bash
python prepare_data.py \
    --device cuda:0 \
    --tokenizer_model_path Qwen/Qwen3-TTS-Tokenizer-12Hz \
    --input_jsonl train_raw.jsonl \
    --output_jsonl train_with_codes.jsonl
```
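
If you want to sanity-check the output before training, each line of `train_with_codes.jsonl` should now carry an `audio_codes` field. Based on how `dataset.py` consumes it, this is a list of codec frames with 16 codebook indices per frame, at roughly 12 frames per second of audio. A quick check along those lines (assuming `soundfile` is installed, as in the inference example below):

```python
import json
import soundfile as sf

with open("train_with_codes.jsonl", encoding="utf-8") as f:
    for line in f:
        item = json.loads(line)
        codes = item["audio_codes"]              # expected shape: (frames, 16)
        info = sf.info(item["audio"])
        duration = info.frames / info.samplerate
        print(f"{item['audio']}: {len(codes)} frames, "
              f"{len(codes[0])} codebooks, "
              f"{len(codes) / duration:.1f} frames/sec")
```
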
### 3) Fine-tune

Run SFT using the prepared JSONL:

```bash
python sft_12hz.py \
    --init_model_path Qwen/Qwen3-TTS-12Hz-1.7B-Base \
    --output_model_path output \
    --train_jsonl train_with_codes.jsonl \
    --batch_size 32 \
    --lr 2e-6 \
    --num_epochs 10 \
    --speaker_name speaker_test
```

Checkpoints will be written to:

- `output/checkpoint-epoch-0`
- `output/checkpoint-epoch-1`
- `output/checkpoint-epoch-2`
- ...
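
Note on batch size: `sft_12hz.py` constructs its `Accelerator` with `gradient_accumulation_steps=4` and bf16 mixed precision, so the effective batch size is `--batch_size` × 4 per process; with `--batch_size 32` as above, that is an effective batch of 128. If you launch multiple processes (for example via `accelerate launch`), multiply by the number of processes as well. Adjust `--batch_size` to fit your GPU memory; the one-click script below uses a smaller default.
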
### 4) Quick inference test

```python
import torch
import soundfile as sf
from qwen_tts import Qwen3TTSModel

device = "cuda:0"
tts = Qwen3TTSModel.from_pretrained(
    "output/checkpoint-epoch-2",
    device_map=device,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

wavs, sr = tts.generate_custom_voice(
    text="She said she would be here by noon.",
    speaker="speaker_test",
)
sf.write("output.wav", wavs[0], sr)
```
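
The `speaker` argument must match the `--speaker_name` used during fine-tuning: the training script registers that name in the checkpoint's `config.json` (under `talker_config.spk_id`) and sets `tts_model_type` to `custom_voice`, so a different name will not resolve to the fine-tuned voice.
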
### One-click shell script example

```bash
#!/usr/bin/env bash
set -e

DEVICE="cuda:0"
TOKENIZER_MODEL_PATH="Qwen/Qwen3-TTS-Tokenizer-12Hz"
INIT_MODEL_PATH="Qwen/Qwen3-TTS-12Hz-1.7B-Base"

RAW_JSONL="train_raw.jsonl"
TRAIN_JSONL="train_with_codes.jsonl"
OUTPUT_DIR="output"

BATCH_SIZE=2
LR=2e-5
EPOCHS=3
SPEAKER_NAME="speaker_1"

python prepare_data.py \
    --device ${DEVICE} \
    --tokenizer_model_path ${TOKENIZER_MODEL_PATH} \
    --input_jsonl ${RAW_JSONL} \
    --output_jsonl ${TRAIN_JSONL}

python sft_12hz.py \
    --init_model_path ${INIT_MODEL_PATH} \
    --output_model_path ${OUTPUT_DIR} \
    --train_jsonl ${TRAIN_JSONL} \
    --batch_size ${BATCH_SIZE} \
    --lr ${LR} \
    --num_epochs ${EPOCHS} \
    --speaker_name ${SPEAKER_NAME}
```

finetuning/dataset.py (new file, 218 lines)
@@ -0,0 +1,218 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union

import librosa
import numpy as np
import torch
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
from qwen_tts.core.models.modeling_qwen3_tts import mel_spectrogram
from torch.utils.data import Dataset

AudioLike = Union[
    str,                      # wav path, URL, base64
    np.ndarray,               # waveform (requires sr)
    Tuple[np.ndarray, int],   # (waveform, sr)
]

MaybeList = Union[Any, List[Any]]


class TTSDataset(Dataset):
    def __init__(self, data_list, processor, config: Qwen3TTSConfig, lag_num=-1):
        self.data_list = data_list
        self.processor = processor
        self.lag_num = lag_num
        self.config = config

    def __len__(self):
        return len(self.data_list)

    def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
        audio, sr = librosa.load(x, sr=None, mono=True)
        if audio.ndim > 1:
            audio = np.mean(audio, axis=-1)
        return audio.astype(np.float32), int(sr)

    def _normalize_audio_inputs(self, audios: Union[AudioLike, List[AudioLike]]) -> List[Tuple[np.ndarray, int]]:
        """
        Normalize audio inputs into a list of (waveform, sr).

        Supported forms:
          - str: wav path / URL / base64 audio string
          - np.ndarray: waveform (NOT allowed alone here because sr is unknown)
          - (np.ndarray, sr): waveform + sampling rate
          - list of the above

        Args:
            audios:
                Audio input(s).

        Returns:
            List[Tuple[np.ndarray, int]]:
                List of (float32 waveform, original sr).

        Raises:
            ValueError: If a numpy waveform is provided without sr.
        """
        if isinstance(audios, list):
            items = audios
        else:
            items = [audios]

        out: List[Tuple[np.ndarray, int]] = []
        for a in items:
            if isinstance(a, str):
                out.append(self._load_audio_to_np(a))
            elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
                out.append((a[0].astype(np.float32), int(a[1])))
            elif isinstance(a, np.ndarray):
                raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
            else:
                raise TypeError(f"Unsupported audio input type: {type(a)}")
        return out

    def _build_assistant_text(self, text: str) -> str:
        return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"

    def _ensure_list(self, x: MaybeList) -> List[Any]:
        return x if isinstance(x, list) else [x]

    def _tokenize_texts(self, text) -> List[torch.Tensor]:
        inputs = self.processor(text=text, return_tensors="pt", padding=True)
        input_id = inputs["input_ids"]
        input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
        return input_id

    @torch.inference_mode()
    def extract_mels(self, audio, sr):
        assert sr == 24000, "Only 24kHz audio is supported"
        mels = mel_spectrogram(
            torch.from_numpy(audio).unsqueeze(0),
            n_fft=1024,
            num_mels=128,
            sampling_rate=24000,
            hop_size=256,
            win_size=1024,
            fmin=0,
            fmax=12000,
        ).transpose(1, 2)
        return mels

    def __getitem__(self, idx):
        item = self.data_list[idx]

        audio_path = item["audio"]
        text = item["text"]
        audio_codes = item["audio_codes"]
        language = item.get("language", "Auto")
        ref_audio_path = item["ref_audio"]

        text = self._build_assistant_text(text)
        text_ids = self._tokenize_texts(text)

        audio_codes = torch.tensor(audio_codes, dtype=torch.long)

        ref_audio_list = self._ensure_list(ref_audio_path)
        normalized = self._normalize_audio_inputs(ref_audio_list)
        wav, sr = normalized[0]

        ref_mel = self.extract_mels(audio=wav, sr=sr)

        return {
            "text_ids": text_ids[:, :-5],   # (1, t)
            "audio_codes": audio_codes,     # (t, 16)
            "ref_mel": ref_mel,
        }
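
    # Sequence layout built by collate_fn (two parallel channels over one time axis;
    # text_len = number of text tokens, codec_len = number of codec frames):
    #   pos 0-2                     : first three text tokens (codec channel stays 0)
    #   pos 3-7   text channel      : tts_pad x4, then tts_bos at pos 7
    #             codec channel     : nothink, think_bos, think_eos,
    #                                 speaker-embedding slot (pos 6), codec_pad
    #   pos 8 .. 8+text_len-4       : remaining text tokens / codec_pad
    #   pos 8+text_len-3            : tts_eos / codec_pad
    #   pos 8+text_len-2            : tts_pad / codec_bos
    #   next codec_len positions    : tts_pad / codebook-0 codes (also the training labels)
    #   final position              : tts_pad / codec_eos (also labelled)
    # codec_embedding_mask leaves pos 6 False so sft_12hz.py can inject the speaker
    # embedding there; codec_mask marks the codec-frame span used for the sub-talker loss.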
    def collate_fn(self, batch):
        assert self.lag_num == -1

        item_length = [b["text_ids"].shape[1] + b["audio_codes"].shape[0] for b in batch]
        max_length = max(item_length) + 8
        b, t = len(batch), max_length

        input_ids = torch.zeros((b, t, 2), dtype=torch.long)
        codec_ids = torch.zeros((b, t, 16), dtype=torch.long)
        text_embedding_mask = torch.zeros((b, t), dtype=torch.bool)
        codec_embedding_mask = torch.zeros((b, t), dtype=torch.bool)
        codec_mask = torch.zeros((b, t), dtype=torch.bool)
        attention_mask = torch.zeros((b, t), dtype=torch.long)
        codec_0_labels = torch.full((b, t), -100, dtype=torch.long)

        for i, data in enumerate(batch):
            text_ids = data["text_ids"]
            audio_codec_0 = data["audio_codes"][:, 0]
            audio_codecs = data["audio_codes"]

            text_ids_len = text_ids.shape[1]
            codec_ids_len = audio_codec_0.shape[0]

            # text channel
            input_ids[i, :3, 0] = text_ids[0, :3]
            input_ids[i, 3:7, 0] = self.config.tts_pad_token_id
            input_ids[i, 7, 0] = self.config.tts_bos_token_id
            input_ids[i, 8:8 + text_ids_len - 3, 0] = text_ids[0, 3:]
            input_ids[i, 8 + text_ids_len - 3, 0] = self.config.tts_eos_token_id
            input_ids[i, 8 + text_ids_len - 2:8 + text_ids_len + codec_ids_len, 0] = self.config.tts_pad_token_id
            text_embedding_mask[i, :8 + text_ids_len + codec_ids_len] = True

            # codec channel
            # input_ids[i, :3, 1] = 0
            input_ids[i, 3:8, 1] = torch.tensor(
                [
                    self.config.talker_config.codec_nothink_id,
                    self.config.talker_config.codec_think_bos_id,
                    self.config.talker_config.codec_think_eos_id,
                    0,  # for speaker embedding
                    self.config.talker_config.codec_pad_id,
                ]
            )
            input_ids[i, 8:8 + text_ids_len - 3, 1] = self.config.talker_config.codec_pad_id
            input_ids[i, 8 + text_ids_len - 3, 1] = self.config.talker_config.codec_pad_id
            input_ids[i, 8 + text_ids_len - 2, 1] = self.config.talker_config.codec_bos_id
            input_ids[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len, 1] = audio_codec_0
            input_ids[i, 8 + text_ids_len - 1 + codec_ids_len, 1] = self.config.talker_config.codec_eos_token_id

            codec_0_labels[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len] = audio_codec_0
            codec_0_labels[i, 8 + text_ids_len - 1 + codec_ids_len] = self.config.talker_config.codec_eos_token_id

            codec_ids[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len, :] = audio_codecs

            codec_embedding_mask[i, 3:8 + text_ids_len + codec_ids_len] = True
            codec_embedding_mask[i, 6] = False  # for speaker embedding

            codec_mask[i, 8 + text_ids_len - 1:8 + text_ids_len - 1 + codec_ids_len] = True
            attention_mask[i, :8 + text_ids_len + codec_ids_len] = True

        ref_mels = [data["ref_mel"] for data in batch]
        ref_mels = torch.cat(ref_mels, dim=0)

        return {
            "input_ids": input_ids,
            "ref_mels": ref_mels,
            "attention_mask": attention_mask,
            "text_embedding_mask": text_embedding_mask.unsqueeze(-1),
            "codec_embedding_mask": codec_embedding_mask.unsqueeze(-1),
            "codec_0_labels": codec_0_labels,
            "codec_ids": codec_ids,
            "codec_mask": codec_mask,
        }

finetuning/prepare_data.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json

from qwen_tts import Qwen3TTSTokenizer

# Number of audio files encoded per tokenizer call.
BATCH_INFER_NUM = 32


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--tokenizer_model_path", type=str, default="Qwen/Qwen3-TTS-Tokenizer-12Hz")
    parser.add_argument("--input_jsonl", type=str, required=True)
    parser.add_argument("--output_jsonl", type=str, required=True)
    args = parser.parse_args()

    tokenizer_12hz = Qwen3TTSTokenizer.from_pretrained(
        args.tokenizer_model_path,
        device_map=args.device,
    )

    with open(args.input_jsonl) as f:
        total_lines = [json.loads(line.strip()) for line in f]

    final_lines = []
    batch_lines = []
    batch_audios = []
    for line in total_lines:
        batch_lines.append(line)
        batch_audios.append(line["audio"])

        # Encode a full batch of audio files into discrete codec codes.
        if len(batch_lines) >= BATCH_INFER_NUM:
            enc_res = tokenizer_12hz.encode(batch_audios)
            for code, line in zip(enc_res.audio_codes, batch_lines):
                line["audio_codes"] = code.cpu().tolist()
                final_lines.append(line)
            batch_lines.clear()
            batch_audios.clear()

    # Flush the final partial batch.
    if len(batch_audios) > 0:
        enc_res = tokenizer_12hz.encode(batch_audios)
        for code, line in zip(enc_res.audio_codes, batch_lines):
            line["audio_codes"] = code.cpu().tolist()
            final_lines.append(line)
        batch_lines.clear()
        batch_audios.clear()

    final_lines = [json.dumps(line, ensure_ascii=False) for line in final_lines]

    with open(args.output_jsonl, "w") as f:
        for line in final_lines:
            f.write(line + "\n")


if __name__ == "__main__":
    main()

finetuning/sft_12hz.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import shutil

import torch
from accelerate import Accelerator
from dataset import TTSDataset
from qwen_tts.inference.qwen3_tts_model import Qwen3TTSModel
from safetensors.torch import save_file
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoConfig

# Speaker embedding captured from the first batch; written into the saved
# checkpoint so the fine-tuned speaker can be selected by name at inference.
target_speaker_embedding = None


def train():
    global target_speaker_embedding

    parser = argparse.ArgumentParser()
    parser.add_argument("--init_model_path", type=str, default="Qwen/Qwen3-TTS-12Hz-1.7B-Base")
    parser.add_argument("--output_model_path", type=str, default="output")
    parser.add_argument("--train_jsonl", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--speaker_name", type=str, default="speaker_test")
    args = parser.parse_args()

    accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="bf16", log_with="tensorboard")

    MODEL_PATH = args.init_model_path

    qwen3tts = Qwen3TTSModel.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    config = AutoConfig.from_pretrained(MODEL_PATH)

    with open(args.train_jsonl) as f:
        train_data = [json.loads(line) for line in f]
    dataset = TTSDataset(train_data, qwen3tts.processor, config)
    train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=dataset.collate_fn)

    optimizer = AdamW(qwen3tts.model.parameters(), lr=args.lr, weight_decay=0.01)

    model, optimizer, train_dataloader = accelerator.prepare(
        qwen3tts.model, optimizer, train_dataloader
    )

    num_epochs = args.num_epochs
    model.train()
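
    # Per-batch training step:
    #   1. A speaker embedding is computed from the reference mels and detached,
    #      so the speaker encoder itself is not updated.
    #   2. Text-channel and codec-channel token embeddings are masked and summed over
    #      the shared time axis; the speaker embedding is injected at position 6 (the
    #      slot reserved in TTSDataset.collate_fn), and the embeddings of codebooks
    #      1-15 are added over the codec-frame span.
    #   3. The talker is trained to predict codebook 0 (codec_0_labels); its hidden
    #      states at the codec-frame positions feed forward_sub_talker_finetune,
    #      whose loss over the remaining codebooks is added with weight 0.3.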
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):

                input_ids = batch["input_ids"]
                codec_ids = batch["codec_ids"]
                ref_mels = batch["ref_mels"]
                text_embedding_mask = batch["text_embedding_mask"]
                codec_embedding_mask = batch["codec_embedding_mask"]
                attention_mask = batch["attention_mask"]
                codec_0_labels = batch["codec_0_labels"]
                codec_mask = batch["codec_mask"]

                speaker_embedding = model.speaker_encoder(ref_mels.to(model.device).to(model.dtype)).detach()
                if target_speaker_embedding is None:
                    target_speaker_embedding = speaker_embedding

                input_text_ids = input_ids[:, :, 0]
                input_codec_ids = input_ids[:, :, 1]

                input_text_embedding = model.talker.model.text_embedding(input_text_ids) * text_embedding_mask
                input_codec_embedding = model.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
                input_codec_embedding[:, 6, :] = speaker_embedding

                input_embeddings = input_text_embedding + input_codec_embedding

                for i in range(1, 16):
                    codec_i_embedding = model.talker.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, :, i])
                    codec_i_embedding = codec_i_embedding * codec_mask.unsqueeze(-1)
                    input_embeddings = input_embeddings + codec_i_embedding

                outputs = model.talker(
                    inputs_embeds=input_embeddings[:, :-1, :],
                    attention_mask=attention_mask[:, :-1],
                    labels=codec_0_labels[:, 1:],
                    output_hidden_states=True,
                )

                hidden_states = outputs.hidden_states[0][-1]
                talker_hidden_states = hidden_states[codec_mask[:, :-1]]
                talker_codec_ids = codec_ids[codec_mask]

                sub_talker_logits, sub_talker_loss = model.talker.forward_sub_talker_finetune(talker_codec_ids, talker_hidden_states)

                loss = outputs.loss + 0.3 * sub_talker_loss

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad()

                if step % 10 == 0:
                    accelerator.print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")

        if accelerator.is_main_process:
            # Save a full checkpoint by copying the base model directory, then
            # overwriting config.json and model.safetensors.
            output_dir = os.path.join(args.output_model_path, f"checkpoint-epoch-{epoch}")
            shutil.copytree(MODEL_PATH, output_dir, dirs_exist_ok=True)

            # Register the fine-tuned speaker name in the checkpoint config.
            input_config_file = os.path.join(MODEL_PATH, "config.json")
            output_config_file = os.path.join(output_dir, "config.json")
            with open(input_config_file, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
            config_dict["tts_model_type"] = "custom_voice"
            talker_config = config_dict.get("talker_config", {})
            talker_config["spk_id"] = {
                args.speaker_name: 3000
            }
            talker_config["spk_is_dialect"] = {
                args.speaker_name: False
            }
            config_dict["talker_config"] = talker_config

            with open(output_config_file, "w", encoding="utf-8") as f:
                json.dump(config_dict, f, indent=2, ensure_ascii=False)

            unwrapped_model = accelerator.unwrap_model(model)
            state_dict = {k: v.detach().to("cpu") for k, v in unwrapped_model.state_dict().items()}

            # The speaker encoder is only needed to produce the embedding during
            # training, so its weights are not shipped with the checkpoint.
            drop_prefix = "speaker_encoder"
            keys_to_drop = [k for k in state_dict.keys() if k.startswith(drop_prefix)]
            for k in keys_to_drop:
                del state_dict[k]

            # Store the captured speaker embedding in the codec embedding table at
            # row 3000, matching the spk_id registered in config.json.
            weight = state_dict["talker.model.codec_embedding.weight"]
            state_dict["talker.model.codec_embedding.weight"][3000] = target_speaker_embedding[0].detach().to(weight.device).to(weight.dtype)
            save_path = os.path.join(output_dir, "model.safetensors")
            save_file(state_dict, save_path)


if __name__ == "__main__":
    train()