from __future__ import annotations import json import os import random from dataclasses import dataclass from pathlib import Path import torch from torch.utils.data import DataLoader, Dataset from transformers import AutoModelForSequenceClassification, AutoTokenizer import yaml PROJECT_ROOT = Path(__file__).resolve().parents[1] TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl" DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml" OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_intent" DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base" MAX_LENGTH = 48 BATCH_SIZE = 8 EPOCHS = 16 LEARNING_RATE = 2e-5 SEED = 42 ORDER_IDS = ["A123456", "A700001", "A800002", "A900005", "A202501", "A808001"] DESTINATIONS = ["公司停车场", "浦东机场", "徐家汇", "虹桥机场", "最近的充电站", "南京东路"] TEMPERATURES = [18, 20, 21, 22, 23, 24, 26] SONGS = ["夜曲", "稻香", "青花瓷", "晴天", "告白气球"] GENRES = ["轻音乐", "摇滚", "古典音乐", "民谣", "爵士"] SOCIAL_LABEL = "__social__" OUT_OF_SCOPE_LABEL = "__out_of_scope__" TEMPLATES: dict[str, list[str]] = { "cs_query_order": [ "查一下订单{order_id}现在什么状态", "我的订单{order_id}到哪一步了", "帮我看看{order_id}这个订单", "确认下{order_id}订单状态", "订单{order_id}现在处理到哪里", "看下{order_id}这单进度", "订单号{order_id}目前怎么样", "帮忙确认订单{order_id}", "订单{order_id}有结果了吗", "帮我追一下订单{order_id}", "订单{order_id}现在受理了吗", "看看{order_id}这单现在啥情况", "帮我查查{order_id}订单进展", "{order_id}这个订单处理好了没", "{order_id}这笔订单现在进展到哪了", ], "cs_query_logistics": [ "快递{order_id}到哪儿了", "帮我查{order_id}物流进度", "看看{order_id}配送状态", "订单{order_id}物流更新了吗", "查询{order_id}的快递信息", "我的{order_id}现在派送到哪了", "查一下{order_id}这单物流", "配送单{order_id}走到哪里了", "帮我看下{order_id}快递到没到", "物流单号{order_id}现在在哪", "订单{order_id}物流到哪一步了", "{order_id}这单现在派件了吗", "帮我追踪{order_id}运输轨迹", "{order_id}快件现在运到哪里了", "我想看{order_id}的配送更新", ], "cs_cancel_order": [ "帮我取消{order_id}这个订单", "{order_id}别要了给我撤销", "把订单{order_id}取消掉", "我不要{order_id}了", "撤销一下{order_id}订单", "订单{order_id}不要发了", "帮我把{order_id}退掉并取消", "把{order_id}这一单停掉", "{order_id}这单直接取消", "订单号{order_id}撤回一下", "订单{order_id}我不想要了", "{order_id}这笔订单先别发了", "把{order_id}这单给我撤单", "订单{order_id}停掉吧", "{order_id}这个快给我取消了", ], "cs_transfer_human": [ "我要找人工客服处理", "现在转人工", "麻烦给我接人工服务", "帮我呼叫真人客服", "别机器人了我要人工", "转真人客服", "我要人工坐席", "帮我接人工处理", "叫人工客服来", "直接给我转人工", "这个问题给我人工跟进", "安排真人客服接手", "机器人处理不了,转人工", "帮我叫个客服专员", "我要人工来处理这事", ], "cabin_nav_to": [ "导航到{destination}", "带我去{destination}", "我要去{destination}", "去{destination}", "开导航去{destination}", "帮我导航到{destination}", "送我去{destination}", "现在去{destination}", "带路到{destination}", "去一下{destination}", "规划路线去{destination}", "直接开去{destination}", "给我导到{destination}", "{destination}怎么走,导航一下", "出发去{destination}", ], "cabin_set_ac": [ "把空调设到{temperature}度", "车里温度调成{temperature}度", "冷气开到{temperature}度", "空调给我调到{temperature}度", "温度改成{temperature}度", "车内设成{temperature}度", "把温度打到{temperature}度", "空调调为{temperature}度", "帮我把车里调成{temperature}度", "冷风调到{temperature}度", "把车内温度设为{temperature}度", "空调温度改到{temperature}度", "冷气帮我调到{temperature}度", "舱内调成{temperature}度", "给我把温度定在{temperature}度", "把车里弄凉快点", "车里太热了,降一点", "把里面调凉快一点", "有点热,降温", "空调再冷一点", "车内温度低一点", "把里面弄暖和点", "车里太冷了,升一点温度", ], "cabin_ac_on": [ "把空调打开", "开一下冷气", "把冷风开起来", "车里热,空调开开", "打开制冷", "空调启动一下", ], "cabin_window_open": [ "把车窗打开", "开下窗", "窗户开一点", "帮我透透气", "车里太闷了,开下窗", "顺便开下车窗", "把窗户降一点", "把玻璃打开一点", ], "cabin_window_close": [ "把车窗关上", "窗户关一下", "把窗升起来", "外面太吵了,把窗关了", "把窗户关严", ], "cabin_fan_down": [ "风别这么大", "风小一点", "别吹这么猛", "把风量调小一点", "出风弱一点", ], "cabin_fan_up": [ "风再大一点", "把风量开大点", "出风强一点", "风不够,调大些", ], "cabin_defog_front_on": [ "前挡起雾了,除一下", "把前挡风玻璃雾气清掉", "前窗看不清了,开除雾", ], "cabin_defog_rear_on": [ "后挡有雾,开下除雾", "后玻璃起雾了,清一下", "后窗看不清了,除雾", ], "cabin_play_music": [ "播放一首{genre}", "来点{genre}", "我想听{genre}", "给我播点{genre}", "放一首{song}", "来一首{song}", "播放{song}", "放点音乐,来个{genre}", "我想听首{song}", "给我来点歌,放{song}", "随机放点{genre}", "帮我播首{song}", "来点适合开车听的{genre}", "打开音乐,放{song}", "给我放一些{genre}", "放点歌", "来首歌", "整点音乐", "车里放点歌", "来点能听的", ], SOCIAL_LABEL: [ "你好", "嗨", "哈喽", "早上好", "晚上好", "你叫什么名字", "你是谁", "你能做什么", "今天天气不错", "陪我聊聊天", ], OUT_OF_SCOPE_LABEL: [ "帮我点个外卖", "订一张去北京的机票", "帮我买杯咖啡", "给我订一家酒店", "人类诞生的意义是什么", "帮我写一份年终总结", "推荐一部电影", "讲个笑话", "帮我做一道数学题", "去美团叫个外卖", ], } INTENT_REPLACEMENTS: dict[str, list[tuple[str, str]]] = { "cs_query_order": [ ("订单", "这单"), ("查一下", "看一下"), ("帮我", "麻烦帮我"), ("现在什么状态", "现在啥状态"), ("处理到哪里", "进展到哪里"), ], "cs_query_logistics": [ ("物流", "快递"), ("快递", "配送"), ("配送", "派送"), ("帮我", "麻烦帮我"), ("现在在哪", "现在到哪了"), ], "cs_cancel_order": [ ("取消", "撤销"), ("撤销", "撤单"), ("订单", "这单"), ("帮我", "麻烦帮我"), ("不要发了", "别发了"), ], "cs_transfer_human": [ ("人工客服", "真人客服"), ("人工", "人工坐席"), ("帮我", "麻烦帮我"), ], "cabin_nav_to": [ ("导航", "带路"), ("带我", "送我"), ("去", "前往"), ], "cabin_set_ac": [ ("空调", "车里温度"), ("调到", "设到"), ("温度", "车内温度"), ("凉快点", "冷一点"), ("暖和点", "热一点"), ], "cabin_ac_on": [ ("空调", "冷气"), ("打开", "开"), ("冷风", "制冷"), ], "cabin_window_open": [ ("车窗", "窗户"), ("打开", "开"), ("透透气", "通通风"), ], "cabin_window_close": [ ("关上", "关掉"), ("车窗", "窗户"), ("关严", "关好"), ], "cabin_fan_down": [ ("风量", "风"), ("调小", "调低"), ("别吹这么猛", "风小一点"), ], "cabin_fan_up": [ ("风量", "风"), ("调大", "调高"), ], "cabin_defog_front_on": [ ("前挡", "前窗"), ("除雾", "除一下"), ], "cabin_defog_rear_on": [ ("后挡", "后窗"), ("除雾", "清一下雾"), ], "cabin_play_music": [ ("播放", "放"), ("来点", "播点"), ("我想听", "给我来点"), ("放点歌", "来首歌"), ], SOCIAL_LABEL: [ ("你好", "您好"), ("哈喽", "hello"), ("你叫什么名字", "怎么称呼你"), ], OUT_OF_SCOPE_LABEL: [ ("点个外卖", "叫个外卖"), ("订一家酒店", "订个酒店"), ("讲个笑话", "说个笑话"), ], } @dataclass class Sample: text: str intent_id: str HARD_NEGATIVE_RAW_SAMPLES: list[tuple[str, str]] = [ ("订单A700101物流到哪了", "cs_query_logistics"), ("帮我看下订单A700102配送到哪里了", "cs_query_logistics"), ("订单A700103现在派件了吗", "cs_query_logistics"), ("A700104这单物流有没有更新", "cs_query_logistics"), ("查一下订单A700105运输轨迹", "cs_query_logistics"), ("订单A700106不要了,帮我撤单", "cs_cancel_order"), ("A700107这单别发了,直接取消", "cs_cancel_order"), ("把订单A700108停掉吧", "cs_cancel_order"), ("A700109这个订单我不想要了", "cs_cancel_order"), ("订单A700110给我撤回", "cs_cancel_order"), ("订单A700111现在受理了吗", "cs_query_order"), ("帮我看看A700112这单处理得怎么样了", "cs_query_order"), ("A700113订单目前进展如何", "cs_query_order"), ("查下订单A700114现在什么情况", "cs_query_order"), ("帮我确认订单A700115是否已经处理", "cs_query_order"), ("你好呀", SOCIAL_LABEL), ("嗨,在吗", SOCIAL_LABEL), ("今天天气真不错", SOCIAL_LABEL), ("你叫什么名字呀", SOCIAL_LABEL), ("你是做什么的", SOCIAL_LABEL), ("陪我随便聊聊", SOCIAL_LABEL), ("帮我透透气", "cabin_window_open"), ("车里太闷了,开下窗", "cabin_window_open"), ("把窗户降一点", "cabin_window_open"), ("前挡起雾了,除一下", "cabin_defog_front_on"), ("后挡有雾,开下除雾", "cabin_defog_rear_on"), ("把车里弄凉快点", "cabin_set_ac"), ("车里太热了,降一点", "cabin_set_ac"), ("空调再冷一点", "cabin_set_ac"), ("把里面弄暖和点", "cabin_set_ac"), ("风别这么大", "cabin_fan_down"), ("别吹这么猛", "cabin_fan_down"), ("风再大一点", "cabin_fan_up"), ("放点歌", "cabin_play_music"), ("来首歌", "cabin_play_music"), ("整点音乐", "cabin_play_music"), ("透透气", "cabin_window_open"), ("通通风", "cabin_window_open"), ("车里太闷了", "cabin_window_open"), ("凉快点", "cabin_set_ac"), ("暖和点", "cabin_set_ac"), ("帮我点一份麻辣烫", OUT_OF_SCOPE_LABEL), ("给我订今晚的酒店", OUT_OF_SCOPE_LABEL), ("帮我买张电影票", OUT_OF_SCOPE_LABEL), ("人为什么会做梦", OUT_OF_SCOPE_LABEL), ("帮我做个旅游攻略", OUT_OF_SCOPE_LABEL), ("帮我点肯德基外卖", OUT_OF_SCOPE_LABEL), ("透透气,别给我除雾", "cabin_window_open"), ("后挡有雾,不是开窗,是除雾", "cabin_defog_rear_on"), ("前挡看不清了,除雾不要开窗", "cabin_defog_front_on"), ("凉快点,不是把风量调小", "cabin_set_ac"), ("别吹这么猛,不是降温", "cabin_fan_down"), ("来首歌,不是切下一首", "cabin_play_music"), ("放点歌,不是暂停音乐", "cabin_play_music"), ] HARD_NEGATIVE_SAMPLES: list[Sample] = [ Sample(text=text, intent_id=intent_id) for text, intent_id in HARD_NEGATIVE_RAW_SAMPLES ] class IntentDataset(Dataset): def __init__(self, samples: list[Sample], tokenizer, label_to_id: dict[str, int]) -> None: self._samples = samples self._tokenizer = tokenizer self._label_to_id = label_to_id def __len__(self) -> int: return len(self._samples) def __getitem__(self, index: int) -> dict[str, torch.Tensor]: sample = self._samples[index] encoded = self._tokenizer( sample.text, truncation=True, padding="max_length", max_length=MAX_LENGTH, return_tensors="pt", ) return { "input_ids": encoded["input_ids"].squeeze(0), "attention_mask": encoded["attention_mask"].squeeze(0), "labels": torch.tensor(self._label_to_id[sample.intent_id], dtype=torch.long), } def set_seed(seed: int) -> None: random.seed(seed) torch.manual_seed(seed) if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) def load_samples(file_path: Path) -> list[Sample]: samples: list[Sample] = [] if not file_path.exists(): return samples for line in file_path.read_text(encoding="utf-8").splitlines(): line = line.strip() if not line: continue payload = json.loads(line) samples.append(Sample(text=str(payload["text"]), intent_id=str(payload["intent_id"]))) return samples def load_domain_samples(file_path: Path) -> list[Sample]: if not file_path.exists(): return [] payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {} intents = payload.get("intents", []) samples: list[Sample] = [] seen: set[tuple[str, str]] = set() for item in intents: intent_id = str(item.get("intent_id") or "").strip() if not intent_id: continue seed_texts = list(item.get("examples") or []) seed_texts.extend(item.get("keywords") or []) label = str(item.get("label") or "").strip() if label: seed_texts.append(label) for text in seed_texts: normalized = str(text).strip() if not normalized: continue for variant in expand_seed_variants(normalized): key = (variant, intent_id) if key in seen: continue seen.add(key) samples.append(Sample(text=variant, intent_id=intent_id)) return samples def expand_seed_variants(text: str) -> list[str]: normalized = text.strip().strip(",。!?;; ") if not normalized: return [] variants = { normalized, normalized.replace("一下", "").strip(), normalized.replace("帮我", "").strip(), normalized.replace("请", "").strip(), f"帮我{normalized}", f"请{normalized}", f"{normalized}一下", } cleaned: list[str] = [] for item in variants: compact = " ".join(item.split()).strip(",。!?;; ") if compact: cleaned.append(compact) return cleaned def load_training_samples() -> list[Sample]: samples = load_samples(TRAIN_PATH) samples.extend(load_domain_samples(DOMAIN_PATH)) deduped: list[Sample] = [] seen: set[tuple[str, str]] = set() for sample in samples: key = (sample.text, sample.intent_id) if key in seen: continue seen.add(key) deduped.append(sample) return deduped def augment_samples(samples: list[Sample]) -> list[Sample]: augmented = list(samples) seen = {(sample.text, sample.intent_id) for sample in augmented} for intent_id, templates in TEMPLATES.items(): for index, template in enumerate(templates): sample = render_template(intent_id, template, index) key = (sample.text, sample.intent_id) if key not in seen: augmented.append(sample) seen.add(key) for sample in HARD_NEGATIVE_SAMPLES: key = (sample.text, sample.intent_id) if key not in seen: augmented.append(sample) seen.add(key) for sample in list(augmented): text = sample.text for source, target in INTENT_REPLACEMENTS.get(sample.intent_id, []): if source in text: variant = text.replace(source, target, 1) key = (variant, sample.intent_id) if key not in seen: augmented.append(Sample(text=variant, intent_id=sample.intent_id)) seen.add(key) compact = text for token in ("帮我", "麻烦", "请", "一下"): if token in compact: compact = compact.replace(token, "", 1) compact = compact.strip(" ,。!?") if compact and compact != text: key = (compact, sample.intent_id) if key not in seen: augmented.append(Sample(text=compact, intent_id=sample.intent_id)) seen.add(key) random.shuffle(augmented) return augmented def render_template(intent_id: str, template: str, index: int) -> Sample: order_id = ORDER_IDS[index % len(ORDER_IDS)] destination = DESTINATIONS[index % len(DESTINATIONS)] temperature = TEMPERATURES[index % len(TEMPERATURES)] song = SONGS[index % len(SONGS)] genre = GENRES[index % len(GENRES)] text = template.format( order_id=order_id, destination=destination, temperature=temperature, song=song, genre=genre, ) return Sample(text=text, intent_id=intent_id) def split_samples(samples: list[Sample]) -> tuple[list[Sample], list[Sample]]: grouped: dict[str, list[Sample]] = {} for sample in samples: grouped.setdefault(sample.intent_id, []).append(sample) train_samples: list[Sample] = [] dev_samples: list[Sample] = [] for items in grouped.values(): random.shuffle(items) cut = max(1, int(len(items) * 0.8)) train_samples.extend(items[:cut]) dev_samples.extend(items[cut:]) random.shuffle(train_samples) random.shuffle(dev_samples) return train_samples, dev_samples def accuracy(model, loader: DataLoader, device: torch.device) -> float: model.eval() correct = 0 total = 0 with torch.no_grad(): for batch in loader: input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask) preds = outputs.logits.argmax(dim=-1) correct += int((preds == labels).sum().item()) total += int(labels.numel()) return correct / total if total else 0.0 def resolve_base_model() -> str: configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip() if configured: return configured return DEFAULT_BASE_MODEL def main() -> None: set_seed(SEED) samples = augment_samples(load_training_samples()) intents = sorted({sample.intent_id for sample in samples}) label_to_id = {intent_id: index for index, intent_id in enumerate(intents)} id_to_label = {index: intent_id for intent_id, index in label_to_id.items()} train_samples, dev_samples = split_samples(samples) base_model = resolve_base_model() tokenizer = AutoTokenizer.from_pretrained(base_model) train_dataset = IntentDataset(train_samples, tokenizer, label_to_id) dev_dataset = IntentDataset(dev_samples, tokenizer, label_to_id) train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE) model = AutoModelForSequenceClassification.from_pretrained( base_model, num_labels=len(intents), id2label=id_to_label, label2id=label_to_id, ) device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") model.to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) best_dev_acc = 0.0 best_state = None for epoch in range(1, EPOCHS + 1): model.train() total_loss = 0.0 for batch in train_loader: optimizer.zero_grad() input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs.loss loss.backward() optimizer.step() total_loss += float(loss.item()) dev_acc = accuracy(model, dev_loader, device) avg_loss = total_loss / max(len(train_loader), 1) print(f"epoch={epoch} loss={avg_loss:.4f} dev_acc={dev_acc:.4f}") if dev_acc >= best_dev_acc: best_dev_acc = dev_acc best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()} if best_state is not None: model.load_state_dict(best_state) OUTPUT_DIR.mkdir(parents=True, exist_ok=True) model.save_pretrained(OUTPUT_DIR) tokenizer.save_pretrained(OUTPUT_DIR) label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()} (OUTPUT_DIR / "label_map.json").write_text( json.dumps(label_map, ensure_ascii=False, indent=2), encoding="utf-8", ) train_summary = { "base_model": base_model, "epochs": EPOCHS, "batch_size": BATCH_SIZE, "learning_rate": LEARNING_RATE, "train_size": len(train_samples), "dev_size": len(dev_samples), "best_dev_accuracy": round(best_dev_acc, 4), "device": str(device), } (OUTPUT_DIR / "train_summary.json").write_text( json.dumps(train_summary, ensure_ascii=False, indent=2), encoding="utf-8", ) print(json.dumps(train_summary, ensure_ascii=False, indent=2)) if __name__ == "__main__": main()