from __future__ import annotations import json import random import re import sys from dataclasses import dataclass from pathlib import Path import torch from torch.utils.data import DataLoader, Dataset from transformers import AutoTokenizer import yaml PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from app.services.joint_nlu import JointBertForNLU TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl" MULTI_TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl" SEED_PATH = PROJECT_ROOT / "app/data/joint_nlu_seed.jsonl" EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval.jsonl" MULTI_EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_multilabel_eval.jsonl" DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml" OUTPUT_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu" DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base" MAX_LENGTH = 64 BATCH_SIZE = 8 EPOCHS = 8 LEARNING_RATE = 2e-5 SEED = 42 IGNORE_INDEX = -100 GENRE_KEYWORDS = ("轻音乐", "摇滚", "古典", "民谣", "爵士", "流行", "儿歌") DEFAULT_INTENT_THRESHOLD = 0.3 MULTI_INTENT_REPEAT = 6 THRESHOLD_CANDIDATES = [0.1, 0.12, 0.15, 0.18, 0.2, 0.22, 0.25, 0.28, 0.3, 0.33, 0.35, 0.38, 0.4, 0.45] @dataclass class JointSample: text: str intent_ids: list[str] slots: list[dict[str, object]] class JointDataset(Dataset): def __init__( self, samples: list[JointSample], tokenizer, intent_to_index: dict[str, int], slot_to_index: dict[str, int], ) -> None: self._samples = samples self._tokenizer = tokenizer self._intent_to_index = intent_to_index self._slot_to_index = slot_to_index def __len__(self) -> int: return len(self._samples) def __getitem__(self, index: int) -> dict[str, torch.Tensor]: sample = self._samples[index] encoded = self._tokenizer( sample.text, truncation=True, max_length=MAX_LENGTH, padding="max_length", return_offsets_mapping=True, return_tensors="pt", ) offset_mapping = encoded.pop("offset_mapping")[0].tolist() slot_labels = [IGNORE_INDEX] * len(offset_mapping) char_labels = ["O"] * len(sample.text) for slot in sample.slots: start = int(slot["start"]) end = int(slot["end"]) slot_name = str(slot["slot_name"]) if start < 0 or end > len(sample.text) or start >= end: continue char_labels[start] = f"B-{slot_name}" for pos in range(start + 1, end): char_labels[pos] = f"I-{slot_name}" for token_index, (start, end) in enumerate(offset_mapping): if end <= start: continue label = char_labels[start] slot_labels[token_index] = self._slot_to_index.get(label, self._slot_to_index["O"]) intent_vector = torch.zeros(len(self._intent_to_index), dtype=torch.float32) for intent_id in sample.intent_ids: if intent_id in self._intent_to_index: intent_vector[self._intent_to_index[intent_id]] = 1.0 return { "input_ids": encoded["input_ids"][0], "attention_mask": encoded["attention_mask"][0], "intent_labels": intent_vector, "slot_labels": torch.tensor(slot_labels, dtype=torch.long), } def set_seed() -> None: random.seed(SEED) torch.manual_seed(SEED) def load_jsonl(path: Path) -> list[dict[str, object]]: rows: list[dict[str, object]] = [] with path.open("r", encoding="utf-8") as handle: for line in handle: line = line.strip() if not line: continue rows.append(json.loads(line)) return rows def find_order_id_span(text: str) -> tuple[str, int, int] | None: match = re.search(r"[A-Za-z]\d{5,}", text) if not match: return None return match.group(0), match.start(), match.end() def find_temperature_span(text: str) -> tuple[str, int, int] | None: match = re.search(r"(\d{2}\s*度)", text) if not match: return None return match.group(1), match.start(), match.end() def find_destination_span(text: str) -> tuple[str, int, int] | None: for pattern in ( r"导航去(?P.+)", r"导航到(?P.+)", r"带我去(?P.+)", r"送我去(?P.+)", r"去(?P.+)", ): match = re.search(pattern, text) if not match: continue destination = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", match.group("destination"), maxsplit=1)[0].strip(" ,。") if not destination: continue start = text.find(destination) if start >= 0: return destination, start, start + len(destination) return None def find_music_span(text: str) -> tuple[str, str, int, int] | None: for genre in GENRE_KEYWORDS: start = text.find(genre) if start >= 0: return "genre", genre, start, start + len(genre) for trigger in ("播放", "来点", "放点", "听", "来首", "来一首", "放一首"): if trigger not in text: continue target = text.split(trigger, maxsplit=1)[-1] target = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", target, maxsplit=1)[0].strip(" 的一首首个歌曲音乐吧呀啊,。") if not target or target in {"歌", "音乐"}: continue for genre in GENRE_KEYWORDS: if genre in target: start = text.find(genre) return "genre", genre, start, start + len(genre) start = text.find(target) if start >= 0: return "song", target, start, start + len(target) return None def annotate_slots(text: str, intent_id: str) -> list[dict[str, object]]: slots: list[dict[str, object]] = [] if intent_id in {"cs_query_order", "cs_query_logistics", "cs_cancel_order"}: matched = find_order_id_span(text) if matched is not None: value, start, end = matched slots.append({"slot_name": "order_id", "value": value, "start": start, "end": end}) elif intent_id == "cabin_set_ac": matched = find_temperature_span(text) if matched is not None: value, start, end = matched slots.append({"slot_name": "temperature", "value": value, "start": start, "end": end}) elif intent_id == "cabin_nav_to": matched = find_destination_span(text) if matched is not None: value, start, end = matched slots.append({"slot_name": "destination", "value": value, "start": start, "end": end}) elif intent_id == "cabin_play_music": matched = find_music_span(text) if matched is not None: slot_name, value, start, end = matched slots.append({"slot_name": slot_name, "value": value, "start": start, "end": end}) return slots def annotate_slots_for_intents(text: str, intent_ids: list[str]) -> list[dict[str, object]]: merged: list[dict[str, object]] = [] seen: set[tuple[str, int, int]] = set() for intent_id in intent_ids: for slot in annotate_slots(text, intent_id): key = (str(slot["slot_name"]), int(slot["start"]), int(slot["end"])) if key in seen: continue seen.add(key) merged.append(slot) merged.sort(key=lambda item: (int(item["start"]), int(item["end"]))) return merged def build_train_samples() -> list[JointSample]: samples: list[JointSample] = [] seen: set[tuple[str, tuple[str, ...]]] = set() domain_data = yaml.safe_load(DOMAIN_PATH.read_text(encoding="utf-8")) or {} for intent in domain_data.get("intents", []): intent_id = str(intent.get("intent_id", "")).strip() if not intent_id: continue for text in list(intent.get("examples", [])) + list(intent.get("keywords", [])): text = str(text).strip() if not text: continue key = (text, (intent_id,)) if key in seen: continue seen.add(key) samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id))) for row in load_jsonl(TRAIN_PATH): text = str(row["text"]) intent_id = str(row["intent_id"]) key = (text, (intent_id,)) if key in seen: continue seen.add(key) samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id))) for row in load_jsonl(SEED_PATH): text = str(row["text"]) intent_id = str(row["intent_id"]) key = (text, (intent_id,)) if key in seen: continue seen.add(key) samples.append(JointSample(text=text, intent_ids=[intent_id], slots=list(row.get("slots", [])))) for row in load_jsonl(MULTI_TRAIN_PATH): text = str(row["text"]).strip() intent_ids = sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()}) if not text or not intent_ids: continue key = (text, tuple(intent_ids)) if key in seen: continue seen.add(key) slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids)) samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots)) if len(intent_ids) >= 2: for _ in range(MULTI_INTENT_REPEAT - 1): samples.append(JointSample(text=text, intent_ids=intent_ids, slots=list(slots))) random.shuffle(samples) return samples def build_eval_samples() -> list[JointSample]: rows = load_jsonl(EVAL_PATH) samples = [ JointSample( text=str(row["text"]), intent_ids=[str(row["intent_id"])], slots=list(row.get("slots", [])), ) for row in rows ] if MULTI_EVAL_PATH.exists(): for row in load_jsonl(MULTI_EVAL_PATH): samples.append( JointSample( text=str(row["text"]), intent_ids=sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()}), slots=list(row.get("slots") or annotate_slots_for_intents(str(row["text"]), list(row.get("intent_ids", [])))), ) ) return samples def build_slot_labels(samples: list[JointSample]) -> list[str]: slot_names = sorted({str(slot["slot_name"]) for sample in samples for slot in sample.slots}) labels = ["O"] for name in slot_names: labels.append(f"B-{name}") labels.append(f"I-{name}") return labels def compute_metrics( model: JointBertForNLU, dataloader: DataLoader, device: torch.device, intent_labels: list[str], slot_labels: list[str], threshold: float, ) -> dict[str, float]: model.eval() intent_tp = 0 intent_fp = 0 intent_fn = 0 single_intent_correct = 0 single_intent_total = 0 intent_exact_match = 0 correct_slot_tokens = 0 total_slot_tokens = 0 exact_slot_samples = 0 total_samples = 0 with torch.no_grad(): for batch in dataloader: batch = {key: value.to(device) for key, value in batch.items()} intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"]) predicted_probs = torch.sigmoid(intent_logits) predicted_multi = predicted_probs >= threshold gold_multi = batch["intent_labels"] > 0.5 intent_tp += int((predicted_multi & gold_multi).sum().item()) intent_fp += int((predicted_multi & ~gold_multi).sum().item()) intent_fn += int((~predicted_multi & gold_multi).sum().item()) intent_exact_match += int((predicted_multi == gold_multi).all(dim=1).sum().item()) top_predicted = torch.argmax(predicted_probs, dim=-1) gold_counts = gold_multi.sum(dim=-1) single_mask = gold_counts == 1 if int(single_mask.sum().item()) > 0: gold_top = torch.argmax(gold_multi.float(), dim=-1) single_intent_correct += int((top_predicted[single_mask] == gold_top[single_mask]).sum().item()) single_intent_total += int(single_mask.sum().item()) predicted_slots = torch.argmax(slot_logits, dim=-1) mask = batch["slot_labels"] != IGNORE_INDEX correct_slot_tokens += int(((predicted_slots == batch["slot_labels"]) & mask).sum().item()) total_slot_tokens += int(mask.sum().item()) for index in range(batch["slot_labels"].size(0)): gold = batch["slot_labels"][index][mask[index]] pred = predicted_slots[index][mask[index]] exact_slot_samples += int(torch.equal(gold, pred)) total_samples += 1 precision = intent_tp / max(intent_tp + intent_fp, 1) recall = intent_tp / max(intent_tp + intent_fn, 1) f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0 return { "intent_threshold": round(threshold, 4), "intent_micro_precision": round(precision, 4), "intent_micro_recall": round(recall, 4), "intent_micro_f1": round(f1, 4), "intent_exact_match": round(intent_exact_match / max(total_samples, 1), 4), "single_intent_top1_accuracy": round(single_intent_correct / max(single_intent_total, 1), 4), "slot_token_accuracy": round(correct_slot_tokens / max(total_slot_tokens, 1), 4), "slot_exact_match": round(exact_slot_samples / max(total_samples, 1), 4), "intent_label_count": float(len(intent_labels)), "slot_label_count": float(len(slot_labels)), } def search_best_threshold( model: JointBertForNLU, dataloader: DataLoader, device: torch.device, intent_labels: list[str], slot_labels: list[str], ) -> dict[str, float]: best_metrics: dict[str, float] | None = None for threshold in THRESHOLD_CANDIDATES: metrics = compute_metrics( model, dataloader, device, intent_labels, slot_labels, threshold=threshold, ) if best_metrics is None: best_metrics = metrics continue current_score = (metrics["intent_micro_f1"], metrics["intent_exact_match"], metrics["slot_exact_match"]) best_score = ( best_metrics["intent_micro_f1"], best_metrics["intent_exact_match"], best_metrics["slot_exact_match"], ) if current_score > best_score: best_metrics = metrics assert best_metrics is not None return best_metrics def build_pos_weight(samples: list[JointSample], intent_labels: list[str]) -> torch.Tensor: positive_counts = {label: 0 for label in intent_labels} for sample in samples: sample_intents = set(sample.intent_ids) for label in intent_labels: if label in sample_intents: positive_counts[label] += 1 total = max(len(samples), 1) weights: list[float] = [] for label in intent_labels: positives = max(positive_counts[label], 1) negatives = max(total - positives, 1) weight = negatives / positives weights.append(min(max(weight, 1.0), 12.0)) return torch.tensor(weights, dtype=torch.float32) def main() -> None: set_seed() train_samples = build_train_samples() eval_samples = build_eval_samples() intent_labels = sorted({intent_id for sample in train_samples + eval_samples for intent_id in sample.intent_ids}) slot_labels = build_slot_labels(train_samples + eval_samples) intent_to_index = {label: index for index, label in enumerate(intent_labels)} slot_to_index = {label: index for index, label in enumerate(slot_labels)} tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL) train_dataset = JointDataset(train_samples, tokenizer, intent_to_index, slot_to_index) eval_dataset = JointDataset(eval_samples, tokenizer, intent_to_index, slot_to_index) train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False) device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") model = JointBertForNLU( base_model_name=DEFAULT_BASE_MODEL, num_intents=len(intent_labels), num_slot_labels=len(slot_labels), ) model.to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) pos_weight = build_pos_weight(train_samples, intent_labels).to(device) best_metrics: dict[str, float] | None = None best_state: dict[str, torch.Tensor] | None = None for epoch in range(EPOCHS): model.train() epoch_loss = 0.0 for batch in train_loader: batch = {key: value.to(device) for key, value in batch.items()} optimizer.zero_grad() intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"]) intent_loss = torch.nn.functional.binary_cross_entropy_with_logits( intent_logits, batch["intent_labels"], pos_weight=pos_weight, ) slot_loss = torch.nn.functional.cross_entropy( slot_logits.view(-1, slot_logits.size(-1)), batch["slot_labels"].view(-1), ignore_index=IGNORE_INDEX, ) loss = intent_loss + slot_loss loss.backward() optimizer.step() epoch_loss += float(loss.item()) metrics = search_best_threshold(model, eval_loader, device, intent_labels, slot_labels) metrics["train_loss"] = round(epoch_loss / max(len(train_loader), 1), 4) print(json.dumps({"epoch": epoch + 1, **metrics}, ensure_ascii=False)) if best_metrics is None or metrics["intent_micro_f1"] > best_metrics["intent_micro_f1"]: best_metrics = metrics best_state = {key: value.detach().cpu() for key, value in model.state_dict().items()} if best_state is None or best_metrics is None: raise RuntimeError("joint nlu training did not produce a best checkpoint") OUTPUT_DIR.mkdir(parents=True, exist_ok=True) tokenizer.save_pretrained(OUTPUT_DIR) torch.save(best_state, OUTPUT_DIR / "model_state.pt") config = { "base_model_name": DEFAULT_BASE_MODEL, "intent_task": "multi_label", "intent_labels": intent_labels, "slot_labels": slot_labels, "max_length": MAX_LENGTH, "intent_threshold": float(best_metrics["intent_threshold"]), "multi_intent_threshold": float(best_metrics["intent_threshold"]), "max_multi_intents": 4, } (OUTPUT_DIR / "joint_nlu_config.json").write_text(json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8") (OUTPUT_DIR / "train_summary.json").write_text( json.dumps( { "train_size": len(train_samples), "eval_size": len(eval_samples), "metrics": best_metrics, }, ensure_ascii=False, indent=2, ), encoding="utf-8", ) if __name__ == "__main__": main()