from __future__ import annotations

import json
import os
import random
from dataclasses import dataclass
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import yaml

# Directory one level above the folder that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# JSONL training data; each record is read for its "text" and "intent_id" keys.
SINGLE_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
# JSONL training data; each record is read for its "text" and "intent_ids" keys.
MULTI_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
# YAML file whose "intents" entries supply extra seed texts
# (their "examples", "keywords" and "label" values).
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
# Destination for the fine-tuned model, tokenizer, label_map.json and
# train_summary.json.
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"

# Checkpoint used when the AGENT_BERT_BASE_MODEL environment variable is
# unset or blank.
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"

# Intent ids listed here are dropped by normalize_intent_ids(), so they never
# become training labels.
SOCIAL_LABEL = "__social__"
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
BLOCKED_LABELS = {SOCIAL_LABEL, OUT_OF_SCOPE_LABEL}

MAX_LENGTH = 48  # tokenizer truncation / padding length
BATCH_SIZE = 8  # batch size of both the train and dev DataLoader
EPOCHS = 12  # number of passes over the training split
LEARNING_RATE = 2e-5  # AdamW learning rate
THRESHOLD = 0.5  # sigmoid score at or above which a label counts as predicted
TOP_K = 4  # k used for the recall@k dev metric
SEED = 42  # seed passed to set_seed() at the start of main()

# (source, target) pairs: during multi-label augmentation the first occurrence
# of `source` in a sample text is replaced by `target` to create a variant.
CONNECTOR_VARIANTS: tuple[tuple[str, str], ...] = (
    ("并", "然后"),
    ("然后", "并"),
    ("顺便", "再"),
    ("再", "顺便"),
)


@dataclass(frozen=True)
class MultiLabelSample:
    """One training example: an utterance and the intent ids attached to it."""

    # Utterance text fed to the tokenizer.
    text: str
    # Intent ids for this utterance; the loaders in this file build it with
    # normalize_intent_ids().
    intent_ids: tuple[str, ...]
class MultiLabelIntentDataset(Dataset):
    """Dataset that tokenizes samples and pairs them with multi-hot label vectors."""

    def __init__(
        self,
        samples: list[MultiLabelSample],
        tokenizer,
        label_to_id: dict[str, int],
    ) -> None:
        """Store the samples, the tokenizer and the intent-id -> index mapping."""
        self._samples = samples
        self._tokenizer = tokenizer
        self._label_to_id = label_to_id
        self._label_size = len(label_to_id)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self._samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Tokenize sample ``index`` and build its multi-hot float label vector.

        Raises:
            KeyError: if the sample carries an intent id missing from
                ``label_to_id``.
        """
        sample = self._samples[index]
        encoded = self._tokenizer(
            sample.text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        labels = torch.zeros(self._label_size, dtype=torch.float32)
        for intent_id in sample.intent_ids:
            labels[self._label_to_id[intent_id]] = 1.0
        return {
            # squeeze(0) drops the leading dimension added by return_tensors="pt"
            # so the DataLoader can stack items into a batch.
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": labels,
        }


def set_seed(seed: int) -> None:
    """Seed Python's ``random`` and torch (plus the MPS backend when available)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


def resolve_base_model() -> str:
    """Return AGENT_BERT_BASE_MODEL if set and non-blank, else DEFAULT_BASE_MODEL."""
    configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
    return configured or DEFAULT_BASE_MODEL


def normalize_text(text: str) -> str:
    """Trim ``text`` and collapse every whitespace run into a single space."""
    return " ".join(str(text).strip().split())


def normalize_intent_ids(intent_ids: list[str] | tuple[str, ...]) -> tuple[str, ...]:
    """Return the sorted, de-duplicated, non-empty intent ids not in BLOCKED_LABELS."""
    stripped = (str(intent_id).strip() for intent_id in intent_ids)
    return tuple(sorted({item for item in stripped if item and item not in BLOCKED_LABELS}))


def expand_single_label_variants(text: str) -> list[str]:
    """Generate simple phrasing variants of ``text``.

    Variants are the text itself, the text with "一下" / "帮我" / "请" removed,
    and the text with "帮我" / "请" prepended or "一下" appended.

    The result is de-duplicated and returned in a fixed order. The candidates
    were previously collected in a ``set``, whose iteration order for strings
    changes between interpreter runs (hash randomization), which made the
    sample order -- and therefore the seeded shuffle and train/dev split --
    non-reproducible.

    Returns:
        The distinct non-empty variants, or ``[]`` when ``text`` is empty
        after stripping.
    """
    normalized = text.strip().strip(",。!?;; ")
    if not normalized:
        return []
    candidates = (
        normalized,
        normalized.replace("一下", "").strip(),
        normalized.replace("帮我", "").strip(),
        normalized.replace("请", "").strip(),
        f"帮我{normalized}",
        f"请{normalized}",
        f"{normalized}一下",
    )
    # A dict keeps first-seen order while dropping duplicates.
    ordered: dict[str, None] = {}
    for item in candidates:
        compact = " ".join(item.split()).strip(",。!?;; ")
        if compact:
            ordered.setdefault(compact, None)
    return list(ordered)


def _read_jsonl(file_path: Path) -> list:
    """Parse every non-blank line of ``file_path`` as JSON; ``[]`` if the file is missing."""
    if not file_path.exists():
        return []
    records = []
    for raw_line in file_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if line:
            records.append(json.loads(line))
    return records


def load_single_label_samples(file_path: Path) -> list[MultiLabelSample]:
    """Load single-intent samples from a JSONL file.

    Each record is read for its ``intent_id`` and ``text`` keys; records whose
    intent is empty or blocked, or whose text is empty, are skipped. A missing
    file yields an empty list.
    """
    samples: list[MultiLabelSample] = []
    for payload in _read_jsonl(file_path):
        intent_ids = normalize_intent_ids([str(payload.get("intent_id") or "")])
        if not intent_ids:
            continue
        text = normalize_text(str(payload.get("text") or ""))
        if not text:
            continue
        samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
    return samples


def load_domain_samples(file_path: Path) -> list[MultiLabelSample]:
    """Build single-intent samples from the domain YAML file.

    For every entry under ``intents`` the ``examples``, ``keywords`` and
    ``label`` values are used as seed texts, each expanded with
    ``expand_single_label_variants``. Duplicate (text, intents) pairs are
    emitted once. A missing file yields an empty list.
    """
    if not file_path.exists():
        return []
    payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
    # `or []` also covers an explicit `intents: null`, which .get(..., [])
    # would pass through as None and crash the loop below.
    intents = payload.get("intents") or []
    samples: list[MultiLabelSample] = []
    seen: set[tuple[str, tuple[str, ...]]] = set()
    for item in intents:
        intent_ids = normalize_intent_ids([str(item.get("intent_id") or "")])
        if not intent_ids:
            continue
        seed_texts = list(item.get("examples") or [])
        seed_texts.extend(item.get("keywords") or [])
        label = str(item.get("label") or "").strip()
        if label:
            seed_texts.append(label)
        for text in seed_texts:
            normalized = normalize_text(text)
            if not normalized:
                continue
            for variant in expand_single_label_variants(normalized):
                key = (variant, intent_ids)
                if key in seen:
                    continue
                seen.add(key)
                samples.append(MultiLabelSample(text=variant, intent_ids=intent_ids))
    return samples


def load_multilabel_samples(file_path: Path) -> list[MultiLabelSample]:
    """Load multi-intent samples from a JSONL file.

    Each record is read for its ``intent_ids`` and ``text`` keys; records with
    fewer than two usable intents or with empty text are skipped. A missing
    file yields an empty list.
    """
    samples: list[MultiLabelSample] = []
    for payload in _read_jsonl(file_path):
        intent_ids = normalize_intent_ids(list(payload.get("intent_ids") or []))
        if len(intent_ids) < 2:
            continue
        text = normalize_text(str(payload.get("text") or ""))
        if not text:
            continue
        samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
    return samples


def augment_multilabel_samples(samples: list[MultiLabelSample]) -> list[MultiLabelSample]:
    """Return ``samples`` followed by textual variants of each sample.

    Variants add a "帮我" / "请" prefix, alter commas, and swap the first
    occurrence of each connector listed in CONNECTOR_VARIANTS. Variants are
    generated in a fixed order (a list rather than a ``set``) so the output
    order is identical on every run; (text, intents) pairs already present are
    not added again.
    """
    originals = list(samples)
    augmented = list(originals)
    seen = {(sample.text, sample.intent_ids) for sample in augmented}
    for sample in originals:
        variants = [
            sample.text,
            f"帮我{sample.text}",
            f"请{sample.text}",
            sample.text.replace(",", ", "),
            sample.text.replace(",", ""),
        ]
        variants.extend(
            sample.text.replace(source, target, 1)
            for source, target in CONNECTOR_VARIANTS
            if source in sample.text
        )
        for variant in variants:
            normalized = normalize_text(variant).strip(",。!?;; ")
            key = (normalized, sample.intent_ids)
            if normalized and key not in seen:
                augmented.append(MultiLabelSample(text=normalized, intent_ids=sample.intent_ids))
                seen.add(key)
    return augmented


def load_all_samples() -> list[MultiLabelSample]:
    """Load, merge, de-duplicate and shuffle samples from all three sources."""
    samples = load_single_label_samples(SINGLE_LABEL_PATH)
    samples.extend(load_domain_samples(DOMAIN_PATH))
    samples.extend(augment_multilabel_samples(load_multilabel_samples(MULTI_LABEL_PATH)))
    deduped: list[MultiLabelSample] = []
    seen: set[tuple[str, tuple[str, ...]]] = set()
    for sample in samples:
        key = (sample.text, sample.intent_ids)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(sample)
    random.shuffle(deduped)
    return deduped


def split_samples(samples: list[MultiLabelSample]) -> tuple[list[MultiLabelSample], list[MultiLabelSample]]:
    """Split samples into (train, dev), stratified by intent combination.

    Each group of samples sharing the same ``intent_ids`` is shuffled and cut
    at roughly 80/20, always leaving at least one item on each side; groups
    with a single sample go entirely to train. If no group produced dev
    samples, the tail of the train list (at most 32 items) becomes the dev set.
    """
    grouped: dict[tuple[str, ...], list[MultiLabelSample]] = {}
    for sample in samples:
        grouped.setdefault(sample.intent_ids, []).append(sample)

    train_samples: list[MultiLabelSample] = []
    dev_samples: list[MultiLabelSample] = []
    for items in grouped.values():
        random.shuffle(items)
        if len(items) == 1:
            train_samples.extend(items)
            continue
        cut = max(1, int(len(items) * 0.8))
        if cut >= len(items):
            cut = len(items) - 1
        train_samples.extend(items[:cut])
        dev_samples.extend(items[cut:])

    if not dev_samples:
        # Every group had exactly one sample: carve a dev set off the train tail.
        dev_samples = train_samples[-max(1, min(32, len(train_samples) // 5 or 1)) :]
        train_samples = train_samples[: len(train_samples) - len(dev_samples)]

    random.shuffle(train_samples)
    random.shuffle(dev_samples)
    return train_samples, dev_samples


def logits_to_probabilities(logits: torch.Tensor) -> list[list[float]]:
    """Apply an element-wise sigmoid to ``logits`` and return nested Python lists."""
    return torch.sigmoid(logits).detach().cpu().tolist()


def compute_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    """Compute micro precision/recall/F1, exact match and recall@k.

    Args:
        probabilities: per-sample label scores.
        targets: per-sample multi-hot vectors; values >= 0.5 count as positive.
        threshold: score at or above which a label counts as predicted.
        top_k: number of highest-scoring labels considered for recall@k.

    Returns:
        A dict with keys ``micro_precision``, ``micro_recall``, ``micro_f1``,
        ``exact_match`` and ``recall_at_k``, each rounded to 4 decimals; all
        values are 0.0 for empty input.
    """
    true_positive = 0
    false_positive = 0
    false_negative = 0
    exact_match = 0
    recall_at_k_total = 0.0
    total = len(probabilities)

    for scores, target in zip(probabilities, targets):
        predicted = {index for index, score in enumerate(scores) if score >= threshold}
        expected = {index for index, value in enumerate(target) if value >= 0.5}
        if predicted == expected:
            exact_match += 1
        true_positive += len(predicted & expected)
        false_positive += len(predicted - expected)
        false_negative += len(expected - predicted)
        top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
        if expected:
            recall_at_k_total += len(set(top_indices) & expected) / len(expected)

    predicted_total = true_positive + false_positive
    expected_total = true_positive + false_negative
    precision = true_positive / predicted_total if predicted_total else 0.0
    recall = true_positive / expected_total if expected_total else 0.0
    micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": round(exact_match / total, 4) if total else 0.0,
        "recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
    }


def evaluate(model, loader: DataLoader, device: torch.device, threshold: float, top_k: int) -> tuple[float, dict[str, float]]:
    """Run ``model`` over ``loader`` without gradients.

    Puts the model in eval mode (it is not switched back to train mode here).

    Returns:
        ``(average batch loss, metrics dict from compute_metrics)``.
    """
    model.eval()
    total_loss = 0.0
    probabilities: list[list[float]] = []
    targets: list[list[float]] = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += float(outputs.loss.item())
            probabilities.extend(logits_to_probabilities(outputs.logits))
            targets.extend(labels.detach().cpu().tolist())
    avg_loss = total_loss / max(len(loader), 1)
    return avg_loss, compute_metrics(probabilities, targets, threshold=threshold, top_k=top_k)


def _select_device() -> torch.device:
    """Pick the training device: CUDA if available, then MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main() -> None:
    """Fine-tune the multi-label intent classifier and write it to OUTPUT_DIR.

    Loads and splits the samples, trains for EPOCHS epochs with AdamW, keeps
    the weights of the epoch with the best dev micro-F1 (ties go to the later
    epoch), then saves the model, tokenizer, ``label_map.json`` and
    ``train_summary.json``.

    Raises:
        RuntimeError: if no samples could be loaded or the training split is
            empty.
    """
    set_seed(SEED)
    samples = load_all_samples()
    if not samples:
        # Fail before loading any model instead of building a 0-label classifier.
        raise RuntimeError(
            f"No training samples could be loaded from {SINGLE_LABEL_PATH}, "
            f"{DOMAIN_PATH} or {MULTI_LABEL_PATH}."
        )

    intents = sorted({intent_id for sample in samples for intent_id in sample.intent_ids})
    label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
    id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
    train_samples, dev_samples = split_samples(samples)
    if not train_samples:
        raise RuntimeError("Training split is empty; more samples are required.")

    base_model = resolve_base_model()
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    train_dataset = MultiLabelIntentDataset(train_samples, tokenizer, label_to_id)
    dev_dataset = MultiLabelIntentDataset(dev_samples, tokenizer, label_to_id)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(intents),
        id2label=id_to_label,
        label2id=label_to_id,
        problem_type="multi_label_classification",
    )
    device = _select_device()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    best_dev_f1 = 0.0
    best_state = None
    best_metrics: dict[str, float] = {}
    for epoch in range(1, EPOCHS + 1):
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += float(loss.item())

        dev_loss, dev_metrics = evaluate(model, dev_loader, device, threshold=THRESHOLD, top_k=TOP_K)
        avg_loss = total_loss / max(len(train_loader), 1)
        print(
            " ".join(
                [
                    f"epoch={epoch}",
                    f"train_loss={avg_loss:.4f}",
                    f"dev_loss={dev_loss:.4f}",
                    f"dev_micro_f1={dev_metrics['micro_f1']:.4f}",
                    f"dev_exact_match={dev_metrics['exact_match']:.4f}",
                ]
            )
        )
        # ">=" keeps the latest epoch on ties and guarantees a snapshot exists
        # even when dev F1 never rises above 0.
        if dev_metrics["micro_f1"] >= best_dev_f1:
            best_dev_f1 = dev_metrics["micro_f1"]
            best_metrics = dict(dev_metrics)
            best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
    (OUTPUT_DIR / "label_map.json").write_text(
        json.dumps(label_map, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )

    train_summary = {
        "task_type": "multi_label_intent_detection",
        "base_model": base_model,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "threshold": THRESHOLD,
        "top_k": TOP_K,
        "train_size": len(train_samples),
        "dev_size": len(dev_samples),
        "label_count": len(intents),
        "labels": intents,
        "best_dev_metrics": best_metrics,
        "device": str(device),
    }
    (OUTPUT_DIR / "train_summary.json").write_text(
        json.dumps(train_summary, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(json.dumps(train_summary, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()