# Training script for the local BERT intent classifier.
# (Listing metadata carried over from the source page: 685 lines, 23 KiB, Python.)
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import random
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
from torch.utils.data import DataLoader, Dataset
|
||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||
import yaml
|
||
|
||
|
||
# --- Filesystem layout -------------------------------------------------------
# Directory one level above the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Optional JSONL file of labelled utterances (objects with "text" and "intent_id").
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
# Optional YAML domain file; its "intents" entries contribute seed texts.
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
# Destination for the fine-tuned model, tokenizer, label map, and training summary.
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_intent"

# --- Training hyperparameters (all consumed by main()) -----------------------
# Checkpoint used when the AGENT_BERT_BASE_MODEL environment variable is unset or blank.
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
# Length every tokenized utterance is padded/truncated to.
MAX_LENGTH = 48
BATCH_SIZE = 8
EPOCHS = 16
LEARNING_RATE = 2e-5
# Seed handed to set_seed() at the start of main().
SEED = 42

# --- Slot values -------------------------------------------------------------
# render_template() picks one value per list (by template index modulo list length)
# to fill the {order_id}, {destination}, {temperature}, {song}, and {genre} placeholders.
ORDER_IDS = ["A123456", "A700001", "A800002", "A900005", "A202501", "A808001"]
DESTINATIONS = ["公司停车场", "浦东机场", "徐家汇", "虹桥机场", "最近的充电站", "南京东路"]
TEMPERATURES = [18, 20, 21, 22, 23, 24, 26]
SONGS = ["夜曲", "稻香", "青花瓷", "晴天", "告白气球"]
GENRES = ["轻音乐", "摇滚", "古典音乐", "民谣", "爵士"]

# --- Special labels ----------------------------------------------------------
# Intent ids used as keys in TEMPLATES / INTENT_REPLACEMENTS for small talk and
# for requests outside the supported intents.
SOCIAL_LABEL = "__social__"
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
|
||
|
||
# Utterance templates per intent id.
# augment_samples() renders every template through render_template(), which fills
# any {order_id} / {destination} / {temperature} / {song} / {genre} placeholder with
# a value chosen by the template's position in its list. Templates without
# placeholders are used verbatim.
TEMPLATES: dict[str, list[str]] = {
    # Customer-service intents.
    "cs_query_order": [
        "查一下订单{order_id}现在什么状态",
        "我的订单{order_id}到哪一步了",
        "帮我看看{order_id}这个订单",
        "确认下{order_id}订单状态",
        "订单{order_id}现在处理到哪里",
        "看下{order_id}这单进度",
        "订单号{order_id}目前怎么样",
        "帮忙确认订单{order_id}",
        "订单{order_id}有结果了吗",
        "帮我追一下订单{order_id}",
        "订单{order_id}现在受理了吗",
        "看看{order_id}这单现在啥情况",
        "帮我查查{order_id}订单进展",
        "{order_id}这个订单处理好了没",
        "{order_id}这笔订单现在进展到哪了",
    ],
    "cs_query_logistics": [
        "快递{order_id}到哪儿了",
        "帮我查{order_id}物流进度",
        "看看{order_id}配送状态",
        "订单{order_id}物流更新了吗",
        "查询{order_id}的快递信息",
        "我的{order_id}现在派送到哪了",
        "查一下{order_id}这单物流",
        "配送单{order_id}走到哪里了",
        "帮我看下{order_id}快递到没到",
        "物流单号{order_id}现在在哪",
        "订单{order_id}物流到哪一步了",
        "{order_id}这单现在派件了吗",
        "帮我追踪{order_id}运输轨迹",
        "{order_id}快件现在运到哪里了",
        "我想看{order_id}的配送更新",
    ],
    "cs_cancel_order": [
        "帮我取消{order_id}这个订单",
        "{order_id}别要了给我撤销",
        "把订单{order_id}取消掉",
        "我不要{order_id}了",
        "撤销一下{order_id}订单",
        "订单{order_id}不要发了",
        "帮我把{order_id}退掉并取消",
        "把{order_id}这一单停掉",
        "{order_id}这单直接取消",
        "订单号{order_id}撤回一下",
        "订单{order_id}我不想要了",
        "{order_id}这笔订单先别发了",
        "把{order_id}这单给我撤单",
        "订单{order_id}停掉吧",
        "{order_id}这个快给我取消了",
    ],
    "cs_transfer_human": [
        "我要找人工客服处理",
        "现在转人工",
        "麻烦给我接人工服务",
        "帮我呼叫真人客服",
        "别机器人了我要人工",
        "转真人客服",
        "我要人工坐席",
        "帮我接人工处理",
        "叫人工客服来",
        "直接给我转人工",
        "这个问题给我人工跟进",
        "安排真人客服接手",
        "机器人处理不了,转人工",
        "帮我叫个客服专员",
        "我要人工来处理这事",
    ],
    # In-cabin (vehicle) intents.
    "cabin_nav_to": [
        "导航到{destination}",
        "带我去{destination}",
        "我要去{destination}",
        "去{destination}",
        "开导航去{destination}",
        "帮我导航到{destination}",
        "送我去{destination}",
        "现在去{destination}",
        "带路到{destination}",
        "去一下{destination}",
        "规划路线去{destination}",
        "直接开去{destination}",
        "给我导到{destination}",
        "{destination}怎么走,导航一下",
        "出发去{destination}",
    ],
    "cabin_set_ac": [
        "把空调设到{temperature}度",
        "车里温度调成{temperature}度",
        "冷气开到{temperature}度",
        "空调给我调到{temperature}度",
        "温度改成{temperature}度",
        "车内设成{temperature}度",
        "把温度打到{temperature}度",
        "空调调为{temperature}度",
        "帮我把车里调成{temperature}度",
        "冷风调到{temperature}度",
        "把车内温度设为{temperature}度",
        "空调温度改到{temperature}度",
        "冷气帮我调到{temperature}度",
        "舱内调成{temperature}度",
        "给我把温度定在{temperature}度",
        "把车里弄凉快点",
        "车里太热了,降一点",
        "把里面调凉快一点",
        "有点热,降温",
        "空调再冷一点",
        "车内温度低一点",
        "把里面弄暖和点",
        "车里太冷了,升一点温度",
    ],
    "cabin_ac_on": [
        "把空调打开",
        "开一下冷气",
        "把冷风开起来",
        "车里热,空调开开",
        "打开制冷",
        "空调启动一下",
    ],
    "cabin_window_open": [
        "把车窗打开",
        "开下窗",
        "窗户开一点",
        "帮我透透气",
        "车里太闷了,开下窗",
        "顺便开下车窗",
        "把窗户降一点",
        "把玻璃打开一点",
    ],
    "cabin_window_close": [
        "把车窗关上",
        "窗户关一下",
        "把窗升起来",
        "外面太吵了,把窗关了",
        "把窗户关严",
    ],
    "cabin_fan_down": [
        "风别这么大",
        "风小一点",
        "别吹这么猛",
        "把风量调小一点",
        "出风弱一点",
    ],
    "cabin_fan_up": [
        "风再大一点",
        "把风量开大点",
        "出风强一点",
        "风不够,调大些",
    ],
    "cabin_defog_front_on": [
        "前挡起雾了,除一下",
        "把前挡风玻璃雾气清掉",
        "前窗看不清了,开除雾",
    ],
    "cabin_defog_rear_on": [
        "后挡有雾,开下除雾",
        "后玻璃起雾了,清一下",
        "后窗看不清了,除雾",
    ],
    "cabin_play_music": [
        "播放一首{genre}",
        "来点{genre}",
        "我想听{genre}",
        "给我播点{genre}",
        "放一首{song}",
        "来一首{song}",
        "播放{song}",
        "放点音乐,来个{genre}",
        "我想听首{song}",
        "给我来点歌,放{song}",
        "随机放点{genre}",
        "帮我播首{song}",
        "来点适合开车听的{genre}",
        "打开音乐,放{song}",
        "给我放一些{genre}",
        "放点歌",
        "来首歌",
        "整点音乐",
        "车里放点歌",
        "来点能听的",
    ],
    # Special labels (keys are the module-level label constants, not string literals).
    SOCIAL_LABEL: [
        "你好",
        "嗨",
        "哈喽",
        "早上好",
        "晚上好",
        "你叫什么名字",
        "你是谁",
        "你能做什么",
        "今天天气不错",
        "陪我聊聊天",
    ],
    OUT_OF_SCOPE_LABEL: [
        "帮我点个外卖",
        "订一张去北京的机票",
        "帮我买杯咖啡",
        "给我订一家酒店",
        "人类诞生的意义是什么",
        "帮我写一份年终总结",
        "推荐一部电影",
        "讲个笑话",
        "帮我做一道数学题",
        "去美团叫个外卖",
    ],
}
|
||
|
||
# Per-intent (source, target) phrase substitutions used by augment_samples().
# For every sample of the intent whose text contains ``source``, one extra variant
# is generated with the first occurrence of ``source`` replaced by ``target``.
# Each pair is applied independently to the original text (pairs are not chained).
INTENT_REPLACEMENTS: dict[str, list[tuple[str, str]]] = {
    "cs_query_order": [
        ("订单", "这单"),
        ("查一下", "看一下"),
        ("帮我", "麻烦帮我"),
        ("现在什么状态", "现在啥状态"),
        ("处理到哪里", "进展到哪里"),
    ],
    "cs_query_logistics": [
        ("物流", "快递"),
        ("快递", "配送"),
        ("配送", "派送"),
        ("帮我", "麻烦帮我"),
        ("现在在哪", "现在到哪了"),
    ],
    "cs_cancel_order": [
        ("取消", "撤销"),
        ("撤销", "撤单"),
        ("订单", "这单"),
        ("帮我", "麻烦帮我"),
        ("不要发了", "别发了"),
    ],
    "cs_transfer_human": [
        ("人工客服", "真人客服"),
        ("人工", "人工坐席"),
        ("帮我", "麻烦帮我"),
    ],
    "cabin_nav_to": [
        ("导航", "带路"),
        ("带我", "送我"),
        ("去", "前往"),
    ],
    "cabin_set_ac": [
        ("空调", "车里温度"),
        ("调到", "设到"),
        ("温度", "车内温度"),
        ("凉快点", "冷一点"),
        ("暖和点", "热一点"),
    ],
    "cabin_ac_on": [
        ("空调", "冷气"),
        ("打开", "开"),
        ("冷风", "制冷"),
    ],
    "cabin_window_open": [
        ("车窗", "窗户"),
        ("打开", "开"),
        ("透透气", "通通风"),
    ],
    "cabin_window_close": [
        ("关上", "关掉"),
        ("车窗", "窗户"),
        ("关严", "关好"),
    ],
    "cabin_fan_down": [
        ("风量", "风"),
        ("调小", "调低"),
        ("别吹这么猛", "风小一点"),
    ],
    "cabin_fan_up": [
        ("风量", "风"),
        ("调大", "调高"),
    ],
    "cabin_defog_front_on": [
        ("前挡", "前窗"),
        ("除雾", "除一下"),
    ],
    "cabin_defog_rear_on": [
        ("后挡", "后窗"),
        ("除雾", "清一下雾"),
    ],
    "cabin_play_music": [
        ("播放", "放"),
        ("来点", "播点"),
        ("我想听", "给我来点"),
        ("放点歌", "来首歌"),
    ],
    SOCIAL_LABEL: [
        ("你好", "您好"),
        ("哈喽", "hello"),
        ("你叫什么名字", "怎么称呼你"),
    ],
    OUT_OF_SCOPE_LABEL: [
        ("点个外卖", "叫个外卖"),
        ("订一家酒店", "订个酒店"),
        ("讲个笑话", "说个笑话"),
    ],
}
|
||
|
||
|
||
@dataclass
class Sample:
    """A single labelled training utterance.

    Attributes:
        text: The utterance as it is fed to the tokenizer.
        intent_id: Identifier of the intent the utterance is labelled with.
    """

    text: str
    intent_id: str
|
||
|
||
|
||
# Hand-written (text, intent_id) pairs that are converted to Sample objects and
# appended to the training data by augment_samples() (skipping exact duplicates).
# NOTE(review): judging by the name and content these target easily confused
# intents (order vs. logistics vs. cancel, window vs. defog, AC vs. fan) — confirm.
HARD_NEGATIVE_RAW_SAMPLES: list[tuple[str, str]] = [
    # Order / logistics / cancellation phrasings that share vocabulary.
    ("订单A700101物流到哪了", "cs_query_logistics"),
    ("帮我看下订单A700102配送到哪里了", "cs_query_logistics"),
    ("订单A700103现在派件了吗", "cs_query_logistics"),
    ("A700104这单物流有没有更新", "cs_query_logistics"),
    ("查一下订单A700105运输轨迹", "cs_query_logistics"),
    ("订单A700106不要了,帮我撤单", "cs_cancel_order"),
    ("A700107这单别发了,直接取消", "cs_cancel_order"),
    ("把订单A700108停掉吧", "cs_cancel_order"),
    ("A700109这个订单我不想要了", "cs_cancel_order"),
    ("订单A700110给我撤回", "cs_cancel_order"),
    ("订单A700111现在受理了吗", "cs_query_order"),
    ("帮我看看A700112这单处理得怎么样了", "cs_query_order"),
    ("A700113订单目前进展如何", "cs_query_order"),
    ("查下订单A700114现在什么情况", "cs_query_order"),
    ("帮我确认订单A700115是否已经处理", "cs_query_order"),
    # Small talk.
    ("你好呀", SOCIAL_LABEL),
    ("嗨,在吗", SOCIAL_LABEL),
    ("今天天气真不错", SOCIAL_LABEL),
    ("你叫什么名字呀", SOCIAL_LABEL),
    ("你是做什么的", SOCIAL_LABEL),
    ("陪我随便聊聊", SOCIAL_LABEL),
    # Colloquial in-cabin requests.
    ("帮我透透气", "cabin_window_open"),
    ("车里太闷了,开下窗", "cabin_window_open"),
    ("把窗户降一点", "cabin_window_open"),
    ("前挡起雾了,除一下", "cabin_defog_front_on"),
    ("后挡有雾,开下除雾", "cabin_defog_rear_on"),
    ("把车里弄凉快点", "cabin_set_ac"),
    ("车里太热了,降一点", "cabin_set_ac"),
    ("空调再冷一点", "cabin_set_ac"),
    ("把里面弄暖和点", "cabin_set_ac"),
    ("风别这么大", "cabin_fan_down"),
    ("别吹这么猛", "cabin_fan_down"),
    ("风再大一点", "cabin_fan_up"),
    ("放点歌", "cabin_play_music"),
    ("来首歌", "cabin_play_music"),
    ("整点音乐", "cabin_play_music"),
    # Very short fragments.
    ("透透气", "cabin_window_open"),
    ("通通风", "cabin_window_open"),
    ("车里太闷了", "cabin_window_open"),
    ("凉快点", "cabin_set_ac"),
    ("暖和点", "cabin_set_ac"),
    # Requests outside the supported intents.
    ("帮我点一份麻辣烫", OUT_OF_SCOPE_LABEL),
    ("给我订今晚的酒店", OUT_OF_SCOPE_LABEL),
    ("帮我买张电影票", OUT_OF_SCOPE_LABEL),
    ("人为什么会做梦", OUT_OF_SCOPE_LABEL),
    ("帮我做个旅游攻略", OUT_OF_SCOPE_LABEL),
    ("帮我点肯德基外卖", OUT_OF_SCOPE_LABEL),
    # Utterances that explicitly mention (and negate) a different action.
    ("透透气,别给我除雾", "cabin_window_open"),
    ("后挡有雾,不是开窗,是除雾", "cabin_defog_rear_on"),
    ("前挡看不清了,除雾不要开窗", "cabin_defog_front_on"),
    ("凉快点,不是把风量调小", "cabin_set_ac"),
    ("别吹这么猛,不是降温", "cabin_fan_down"),
    ("来首歌,不是切下一首", "cabin_play_music"),
    ("放点歌,不是暂停音乐", "cabin_play_music"),
]
|
||
|
||
# Sample objects built once at import time from the raw (text, intent_id) pairs,
# in the same order as HARD_NEGATIVE_RAW_SAMPLES.
HARD_NEGATIVE_SAMPLES: list[Sample] = [
    Sample(text=raw_text, intent_id=raw_intent)
    for raw_text, raw_intent in HARD_NEGATIVE_RAW_SAMPLES
]
|
||
|
||
|
||
class IntentDataset(Dataset):
    """Map-style dataset that tokenizes one ``Sample`` on every lookup.

    Args:
        samples: Labelled utterances; indexed directly by ``__getitem__``.
        tokenizer: Callable tokenizer invoked with the utterance text and the
            keyword arguments shown in ``__getitem__``.
        label_to_id: Mapping from ``Sample.intent_id`` to the integer class index.
    """

    def __init__(self, samples: list[Sample], tokenizer, label_to_id: dict[str, int]) -> None:
        self._samples, self._tokenizer, self._label_to_id = samples, tokenizer, label_to_id

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self._samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return ``input_ids``, ``attention_mask``, and ``labels`` for one sample.

        Raises:
            IndexError: If ``index`` is outside the sample list.
            KeyError: If the sample's intent id is missing from ``label_to_id``.
        """
        item = self._samples[index]
        features = self._tokenizer(
            item.text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        # squeeze(0) drops a leading size-1 dimension; presumably the tokenizer
        # returns (1, MAX_LENGTH) tensors for a single text — TODO confirm.
        row = {name: features[name].squeeze(0) for name in ("input_ids", "attention_mask")}
        row["labels"] = torch.tensor(self._label_to_id[item.intent_id], dtype=torch.long)
        return row
|
||
|
||
|
||
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch with the same value.

    When the MPS backend reports itself available, its generator is seeded
    explicitly as well.

    Args:
        seed: Value passed to every seeding call.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.backends.mps.is_available():
        return
    torch.mps.manual_seed(seed)
|
||
|
||
|
||
def load_samples(file_path: Path) -> list[Sample]:
    """Read labelled utterances from a UTF-8 JSONL file.

    Each non-blank line must be a JSON object with ``text`` and ``intent_id``
    entries; both are converted to ``str``.

    Args:
        file_path: Location of the JSONL file.

    Returns:
        The samples in file order, or an empty list when the file is missing.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``text`` or ``intent_id``.
    """
    if not file_path.exists():
        return []

    loaded: list[Sample] = []
    for raw_line in file_path.read_text(encoding="utf-8").splitlines():
        stripped = raw_line.strip()
        if stripped:
            record = json.loads(stripped)
            loaded.append(Sample(text=str(record["text"]), intent_id=str(record["intent_id"])))
    return loaded
|
||
|
||
|
||
def load_domain_samples(file_path: Path) -> list[Sample]:
    """Derive training samples from the seed texts of a YAML domain file.

    The file is expected to hold a top-level ``intents`` list. Every entry with
    a non-blank ``intent_id`` contributes its ``examples``, ``keywords``, and
    ``label`` as seed texts; each seed is expanded with
    ``expand_seed_variants`` and exact ``(text, intent_id)`` duplicates are
    dropped.

    Args:
        file_path: Location of the YAML domain file.

    Returns:
        The derived samples, or an empty list when the file is missing, empty,
        or has no ``intents``.
    """
    if not file_path.exists():
        return []
    payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
    # ``or []`` (rather than a .get default) also covers an explicit ``intents:`` null.
    intents = payload.get("intents") or []

    samples: list[Sample] = []
    seen: set[tuple[str, str]] = set()
    for item in intents:
        intent_id = str(item.get("intent_id") or "").strip()
        if not intent_id:
            continue

        seed_texts = _as_text_list(item.get("examples"))
        seed_texts.extend(_as_text_list(item.get("keywords")))
        label = str(item.get("label") or "").strip()
        if label:
            seed_texts.append(label)

        for text in seed_texts:
            normalized = str(text).strip()
            if not normalized:
                continue
            for variant in expand_seed_variants(normalized):
                key = (variant, intent_id)
                if key in seen:
                    continue
                seen.add(key)
                samples.append(Sample(text=variant, intent_id=intent_id))
    return samples


def _as_text_list(value: object) -> list:
    """Normalize a YAML ``examples``/``keywords`` value to a list of entries."""
    if not value:
        return []
    if isinstance(value, str):
        # A bare scalar must stay one seed text; list()/extend() would split it
        # into single characters.
        return [value]
    return list(value)
|
||
|
||
|
||
def expand_seed_variants(text: str) -> list[str]:
    """Expand one seed text into a few polite / terse phrasings.

    The seed is trimmed of surrounding whitespace and punctuation, then
    combined with variants that drop "一下", "帮我", or "请" and variants that
    add a "帮我" / "请" prefix or a "一下" suffix. Each variant has internal
    whitespace collapsed and surrounding punctuation removed.

    The result is de-duplicated and returned in a fixed order (the seed itself
    first), so downstream shuffling stays reproducible for a given random
    seed; iterating a ``set`` here would make the order vary between runs.

    Args:
        text: Seed utterance.

    Returns:
        Unique non-empty variants, or an empty list when nothing remains after
        trimming.
    """
    punctuation = ",。!?;; "
    normalized = text.strip().strip(punctuation)
    if not normalized:
        return []

    candidates = (
        normalized,
        normalized.replace("一下", "").strip(),
        normalized.replace("帮我", "").strip(),
        normalized.replace("请", "").strip(),
        f"帮我{normalized}",
        f"请{normalized}",
        f"{normalized}一下",
    )
    compacted = (" ".join(candidate.split()).strip(punctuation) for candidate in candidates)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(variant for variant in compacted if variant))
|
||
|
||
|
||
def load_training_samples() -> list[Sample]:
    """Load the JSONL samples followed by the domain-derived samples.

    Returns:
        Samples from ``TRAIN_PATH`` and then ``DOMAIN_PATH``, keeping only the
        first occurrence of each ``(text, intent_id)`` pair.
    """
    combined = load_samples(TRAIN_PATH)
    combined.extend(load_domain_samples(DOMAIN_PATH))

    # A dict keeps insertion order, so setdefault retains the first sample per key.
    first_by_key: dict[tuple[str, str], Sample] = {}
    for candidate in combined:
        first_by_key.setdefault((candidate.text, candidate.intent_id), candidate)
    return list(first_by_key.values())
|
||
|
||
|
||
def augment_samples(samples: list[Sample]) -> list[Sample]:
    """Extend ``samples`` with templated, hand-written, and rephrased variants.

    Steps, in order:

    1. Render every entry of ``TEMPLATES`` via ``render_template``.
    2. Append ``HARD_NEGATIVE_SAMPLES``.
    3. For every sample collected so far, add one variant per matching
       ``INTENT_REPLACEMENTS`` pair (first occurrence replaced) and one
       "compact" variant with the first "帮我", "麻烦", "请", and "一下" removed.
    4. Shuffle the result with the global ``random`` state.

    Exact ``(text, intent_id)`` duplicates are never added. A replacement is
    skipped when the matched phrase is already part of the replacement target
    in the text; applying it would only double the text (for example "人工坐席"
    becoming "人工坐席坐席").

    Args:
        samples: Base samples; the list itself is not modified.

    Returns:
        A new, shuffled list containing the base samples and all additions.
    """
    augmented = list(samples)
    seen = {(sample.text, sample.intent_id) for sample in augmented}

    def add(candidate: Sample) -> None:
        # Append the candidate unless an identical (text, intent) pair exists.
        key = (candidate.text, candidate.intent_id)
        if key not in seen:
            seen.add(key)
            augmented.append(candidate)

    for intent_id, templates in TEMPLATES.items():
        for index, template in enumerate(templates):
            add(render_template(intent_id, template, index))

    for sample in HARD_NEGATIVE_SAMPLES:
        add(sample)

    # Iterate over a snapshot so variants are derived only from the samples
    # gathered above, never from other variants.
    for sample in list(augmented):
        text = sample.text
        for source, target in INTENT_REPLACEMENTS.get(sample.intent_id, []):
            if source not in text or _first_match_already_expanded(text, source, target):
                continue
            add(Sample(text=text.replace(source, target, 1), intent_id=sample.intent_id))

        compact = _strip_polite_tokens(text)
        if compact and compact != text:
            add(Sample(text=compact, intent_id=sample.intent_id))

    random.shuffle(augmented)
    return augmented


def _first_match_already_expanded(text: str, source: str, target: str) -> bool:
    """Return True when the first ``source`` hit in ``text`` already sits inside ``target``."""
    offset = target.find(source)
    position = text.find(source)
    if offset < 0 or position < offset:
        # Target does not contain source, source is absent, or there is not
        # enough text before the match for the target to fit.
        return False
    return text.startswith(target, position - offset)


def _strip_polite_tokens(text: str) -> str:
    """Remove the first "帮我", "麻烦", "请", and "一下" and trim edge punctuation."""
    compact = text
    for token in ("帮我", "麻烦", "请", "一下"):
        compact = compact.replace(token, "", 1)
    return compact.strip(" ,。!?")
|
||
|
||
|
||
def render_template(intent_id: str, template: str, index: int) -> Sample:
    """Fill a template's placeholders and wrap the result in a ``Sample``.

    Each placeholder value is taken from its slot list at position
    ``index`` modulo the list length, so consecutive templates cycle through
    the available values.

    Args:
        intent_id: Intent the rendered utterance is labelled with.
        template: ``str.format`` template using any of ``{order_id}``,
            ``{destination}``, ``{temperature}``, ``{song}``, ``{genre}``.
        index: Position of the template within its intent's template list.

    Returns:
        The rendered sample.
    """
    slot_pools = {
        "order_id": ORDER_IDS,
        "destination": DESTINATIONS,
        "temperature": TEMPERATURES,
        "song": SONGS,
        "genre": GENRES,
    }
    slot_values = {slot: pool[index % len(pool)] for slot, pool in slot_pools.items()}
    return Sample(text=template.format(**slot_values), intent_id=intent_id)
|
||
|
||
|
||
def split_samples(samples: list[Sample]) -> tuple[list[Sample], list[Sample]]:
    """Split samples into train and dev parts, intent by intent.

    Each intent's samples are shuffled and the first 80% (rounded down, but at
    least one) go to the train part; the remainder goes to the dev part. Both
    parts are shuffled again before being returned. All shuffling uses the
    global ``random`` state.

    Args:
        samples: Samples to split; the list itself is not modified.

    Returns:
        ``(train_samples, dev_samples)``.
    """
    by_intent: dict[str, list[Sample]] = {}
    for entry in samples:
        if entry.intent_id not in by_intent:
            by_intent[entry.intent_id] = []
        by_intent[entry.intent_id].append(entry)

    train_part: list[Sample] = []
    dev_part: list[Sample] = []
    for bucket in by_intent.values():
        random.shuffle(bucket)
        boundary = int(len(bucket) * 0.8)
        if boundary < 1:
            # Guarantee every intent is represented in the train part.
            boundary = 1
        train_part += bucket[:boundary]
        dev_part += bucket[boundary:]

    for part in (train_part, dev_part):
        random.shuffle(part)
    return train_part, dev_part
|
||
|
||
|
||
def accuracy(model, loader: DataLoader, device: torch.device) -> float:
    """Return the fraction of examples in ``loader`` the model classifies correctly.

    The model is switched to eval mode (and left there) and gradients are
    disabled while predicting.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments and returning an object with a ``logits`` attribute.
        loader: Iterable of batches with ``input_ids``, ``attention_mask``,
            and ``labels`` tensors.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        Correct predictions divided by the number of labels, or ``0.0`` when
        the loader yields no labels.
    """
    model.eval()
    hits, label_count = 0, 0
    with torch.no_grad():
        for batch in loader:
            token_ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["labels"].to(device)
            logits = model(input_ids=token_ids, attention_mask=mask).logits
            predicted = logits.argmax(dim=-1)
            hits += int((predicted == targets).sum().item())
            label_count += int(targets.numel())
    if not label_count:
        return 0.0
    return hits / label_count
|
||
|
||
|
||
def resolve_base_model() -> str:
    """Return the base checkpoint name to fine-tune.

    The ``AGENT_BERT_BASE_MODEL`` environment variable wins when it holds a
    non-blank value (surrounding whitespace is stripped); otherwise
    ``DEFAULT_BASE_MODEL`` is returned.
    """
    override = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
    return override or DEFAULT_BASE_MODEL
|
||
|
||
|
||
def main() -> None:
    """Fine-tune the intent classifier and write all artefacts to ``OUTPUT_DIR``.

    Pipeline:

    1. Seed the RNGs, then load and augment the training samples.
    2. Build the label mapping from the sorted set of intent ids.
    3. Split into train/dev, tokenize lazily through ``IntentDataset``.
    4. Train for ``EPOCHS`` epochs with AdamW, evaluating dev accuracy after
       each epoch and remembering the weights of the best epoch.
    5. Restore the best weights and save the model, tokenizer,
       ``label_map.json``, and ``train_summary.json``.

    The training device is the first available of CUDA, MPS, and CPU.
    """
    set_seed(SEED)
    samples = augment_samples(load_training_samples())
    # Sorting makes the label indices independent of sample order.
    intents = sorted({sample.intent_id for sample in samples})
    label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
    id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}

    train_samples, dev_samples = split_samples(samples)
    base_model = resolve_base_model()
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    train_dataset = IntentDataset(train_samples, tokenizer, label_to_id)
    dev_dataset = IntentDataset(dev_samples, tokenizer, label_to_id)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(intents),
        id2label=id_to_label,
        label2id=label_to_id,
    )

    # Prefer an accelerator when one exists; previously CUDA machines fell
    # through to the CPU because only MPS was checked.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    best_dev_acc = 0.0
    best_state = None

    for epoch in range(1, EPOCHS + 1):
        # accuracy() leaves the model in eval mode, so re-enable training mode.
        model.train()
        total_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += float(loss.item())

        dev_acc = accuracy(model, dev_loader, device)
        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"epoch={epoch} loss={avg_loss:.4f} dev_acc={dev_acc:.4f}")
        # ``>=`` keeps the latest epoch among ties (including an empty dev set,
        # where accuracy is always 0.0).
        if dev_acc >= best_dev_acc:
            best_dev_acc = dev_acc
            # Snapshot on the CPU so the copy does not occupy accelerator memory.
            best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    # Maps "LABEL_<index>" names to the project's intent ids.
    label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
    (OUTPUT_DIR / "label_map.json").write_text(
        json.dumps(label_map, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    train_summary = {
        "base_model": base_model,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "train_size": len(train_samples),
        "dev_size": len(dev_samples),
        "best_dev_accuracy": round(best_dev_acc, 4),
        "device": str(device),
    }
    (OUTPUT_DIR / "train_summary.json").write_text(
        json.dumps(train_summary, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(json.dumps(train_summary, ensure_ascii=False, indent=2))
|
||
|
||
|
||
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|