Update project and configurations
This commit is contained in:
684
intelligent_cabin/archive/scripts/train_local_bert_intent.py
Normal file
684
intelligent_cabin/archive/scripts/train_local_bert_intent.py
Normal file
@@ -0,0 +1,684 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
import yaml
|
||||
|
||||
|
||||
# --- Filesystem layout ------------------------------------------------------
# NOTE(review): parents[1] is the directory one level above this script's
# folder. If the script lives under archive/scripts/, that is archive/ rather
# than the repository root -- confirm the data/config paths below still resolve.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_intent"

# --- Training hyper-parameters ----------------------------------------------
# Checkpoint used unless the AGENT_BERT_BASE_MODEL environment variable is set.
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
MAX_LENGTH = 48  # tokenizer padding/truncation length
BATCH_SIZE = 8
EPOCHS = 16
LEARNING_RATE = 2e-5
SEED = 42

# --- Slot values substituted into TEMPLATES ---------------------------------
ORDER_IDS = ["A123456", "A700001", "A800002", "A900005", "A202501", "A808001"]
DESTINATIONS = ["公司停车场", "浦东机场", "徐家汇", "虹桥机场", "最近的充电站", "南京东路"]
TEMPERATURES = [18, 20, 21, 22, 23, 24, 26]
SONGS = ["夜曲", "稻香", "青花瓷", "晴天", "告白气球"]
GENRES = ["轻音乐", "摇滚", "古典音乐", "民谣", "爵士"]

# --- Pseudo-intent labels for non-task utterances ---------------------------
SOCIAL_LABEL = "__social__"
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
|
||||
|
||||
# Utterance templates per intent id. Placeholders ({order_id}, {destination},
# {temperature}, {song}, {genre}) are filled by render_template(); templates
# without placeholders are used verbatim.
TEMPLATES: dict[str, list[str]] = {
    # Customer service: order status.
    "cs_query_order": [
        "查一下订单{order_id}现在什么状态",
        "我的订单{order_id}到哪一步了",
        "帮我看看{order_id}这个订单",
        "确认下{order_id}订单状态",
        "订单{order_id}现在处理到哪里",
        "看下{order_id}这单进度",
        "订单号{order_id}目前怎么样",
        "帮忙确认订单{order_id}",
        "订单{order_id}有结果了吗",
        "帮我追一下订单{order_id}",
        "订单{order_id}现在受理了吗",
        "看看{order_id}这单现在啥情况",
        "帮我查查{order_id}订单进展",
        "{order_id}这个订单处理好了没",
        "{order_id}这笔订单现在进展到哪了",
    ],
    # Customer service: shipping / logistics status.
    "cs_query_logistics": [
        "快递{order_id}到哪儿了",
        "帮我查{order_id}物流进度",
        "看看{order_id}配送状态",
        "订单{order_id}物流更新了吗",
        "查询{order_id}的快递信息",
        "我的{order_id}现在派送到哪了",
        "查一下{order_id}这单物流",
        "配送单{order_id}走到哪里了",
        "帮我看下{order_id}快递到没到",
        "物流单号{order_id}现在在哪",
        "订单{order_id}物流到哪一步了",
        "{order_id}这单现在派件了吗",
        "帮我追踪{order_id}运输轨迹",
        "{order_id}快件现在运到哪里了",
        "我想看{order_id}的配送更新",
    ],
    # Customer service: order cancellation.
    "cs_cancel_order": [
        "帮我取消{order_id}这个订单",
        "{order_id}别要了给我撤销",
        "把订单{order_id}取消掉",
        "我不要{order_id}了",
        "撤销一下{order_id}订单",
        "订单{order_id}不要发了",
        "帮我把{order_id}退掉并取消",
        "把{order_id}这一单停掉",
        "{order_id}这单直接取消",
        "订单号{order_id}撤回一下",
        "订单{order_id}我不想要了",
        "{order_id}这笔订单先别发了",
        "把{order_id}这单给我撤单",
        "订单{order_id}停掉吧",
        "{order_id}这个快给我取消了",
    ],
    # Customer service: hand over to a human agent.
    "cs_transfer_human": [
        "我要找人工客服处理",
        "现在转人工",
        "麻烦给我接人工服务",
        "帮我呼叫真人客服",
        "别机器人了我要人工",
        "转真人客服",
        "我要人工坐席",
        "帮我接人工处理",
        "叫人工客服来",
        "直接给我转人工",
        "这个问题给我人工跟进",
        "安排真人客服接手",
        "机器人处理不了,转人工",
        "帮我叫个客服专员",
        "我要人工来处理这事",
    ],
    # Cabin: navigation.
    "cabin_nav_to": [
        "导航到{destination}",
        "带我去{destination}",
        "我要去{destination}",
        "去{destination}",
        "开导航去{destination}",
        "帮我导航到{destination}",
        "送我去{destination}",
        "现在去{destination}",
        "带路到{destination}",
        "去一下{destination}",
        "规划路线去{destination}",
        "直接开去{destination}",
        "给我导到{destination}",
        "{destination}怎么走,导航一下",
        "出发去{destination}",
    ],
    # Cabin: set / adjust A/C temperature (explicit and relative phrasings).
    "cabin_set_ac": [
        "把空调设到{temperature}度",
        "车里温度调成{temperature}度",
        "冷气开到{temperature}度",
        "空调给我调到{temperature}度",
        "温度改成{temperature}度",
        "车内设成{temperature}度",
        "把温度打到{temperature}度",
        "空调调为{temperature}度",
        "帮我把车里调成{temperature}度",
        "冷风调到{temperature}度",
        "把车内温度设为{temperature}度",
        "空调温度改到{temperature}度",
        "冷气帮我调到{temperature}度",
        "舱内调成{temperature}度",
        "给我把温度定在{temperature}度",
        "把车里弄凉快点",
        "车里太热了,降一点",
        "把里面调凉快一点",
        "有点热,降温",
        "空调再冷一点",
        "车内温度低一点",
        "把里面弄暖和点",
        "车里太冷了,升一点温度",
    ],
    # Cabin: switch A/C on.
    "cabin_ac_on": [
        "把空调打开",
        "开一下冷气",
        "把冷风开起来",
        "车里热,空调开开",
        "打开制冷",
        "空调启动一下",
    ],
    # Cabin: open windows.
    "cabin_window_open": [
        "把车窗打开",
        "开下窗",
        "窗户开一点",
        "帮我透透气",
        "车里太闷了,开下窗",
        "顺便开下车窗",
        "把窗户降一点",
        "把玻璃打开一点",
    ],
    # Cabin: close windows.
    "cabin_window_close": [
        "把车窗关上",
        "窗户关一下",
        "把窗升起来",
        "外面太吵了,把窗关了",
        "把窗户关严",
    ],
    # Cabin: lower fan speed.
    "cabin_fan_down": [
        "风别这么大",
        "风小一点",
        "别吹这么猛",
        "把风量调小一点",
        "出风弱一点",
    ],
    # Cabin: raise fan speed.
    "cabin_fan_up": [
        "风再大一点",
        "把风量开大点",
        "出风强一点",
        "风不够,调大些",
    ],
    # Cabin: front windshield defog.
    "cabin_defog_front_on": [
        "前挡起雾了,除一下",
        "把前挡风玻璃雾气清掉",
        "前窗看不清了,开除雾",
    ],
    # Cabin: rear window defog.
    "cabin_defog_rear_on": [
        "后挡有雾,开下除雾",
        "后玻璃起雾了,清一下",
        "后窗看不清了,除雾",
    ],
    # Cabin: music playback (by genre, by song, or unspecified).
    "cabin_play_music": [
        "播放一首{genre}",
        "来点{genre}",
        "我想听{genre}",
        "给我播点{genre}",
        "放一首{song}",
        "来一首{song}",
        "播放{song}",
        "放点音乐,来个{genre}",
        "我想听首{song}",
        "给我来点歌,放{song}",
        "随机放点{genre}",
        "帮我播首{song}",
        "来点适合开车听的{genre}",
        "打开音乐,放{song}",
        "给我放一些{genre}",
        "放点歌",
        "来首歌",
        "整点音乐",
        "车里放点歌",
        "来点能听的",
    ],
    # Small talk / greetings.
    SOCIAL_LABEL: [
        "你好",
        "嗨",
        "哈喽",
        "早上好",
        "晚上好",
        "你叫什么名字",
        "你是谁",
        "你能做什么",
        "今天天气不错",
        "陪我聊聊天",
    ],
    # Requests labelled as out of scope.
    OUT_OF_SCOPE_LABEL: [
        "帮我点个外卖",
        "订一张去北京的机票",
        "帮我买杯咖啡",
        "给我订一家酒店",
        "人类诞生的意义是什么",
        "帮我写一份年终总结",
        "推荐一部电影",
        "讲个笑话",
        "帮我做一道数学题",
        "去美团叫个外卖",
    ],
}
|
||||
|
||||
# Per-intent (source, target) phrase swaps. augment_samples() applies each pair
# independently to a sample's text, replacing only the first occurrence of the
# source phrase, to create paraphrased variants.
INTENT_REPLACEMENTS: dict[str, list[tuple[str, str]]] = {
    "cs_query_order": [
        ("订单", "这单"),
        ("查一下", "看一下"),
        ("帮我", "麻烦帮我"),
        ("现在什么状态", "现在啥状态"),
        ("处理到哪里", "进展到哪里"),
    ],
    "cs_query_logistics": [
        ("物流", "快递"),
        ("快递", "配送"),
        ("配送", "派送"),
        ("帮我", "麻烦帮我"),
        ("现在在哪", "现在到哪了"),
    ],
    "cs_cancel_order": [
        ("取消", "撤销"),
        ("撤销", "撤单"),
        ("订单", "这单"),
        ("帮我", "麻烦帮我"),
        ("不要发了", "别发了"),
    ],
    "cs_transfer_human": [
        ("人工客服", "真人客服"),
        ("人工", "人工坐席"),
        ("帮我", "麻烦帮我"),
    ],
    "cabin_nav_to": [
        ("导航", "带路"),
        ("带我", "送我"),
        ("去", "前往"),
    ],
    "cabin_set_ac": [
        ("空调", "车里温度"),
        ("调到", "设到"),
        ("温度", "车内温度"),
        ("凉快点", "冷一点"),
        ("暖和点", "热一点"),
    ],
    "cabin_ac_on": [
        ("空调", "冷气"),
        ("打开", "开"),
        ("冷风", "制冷"),
    ],
    "cabin_window_open": [
        ("车窗", "窗户"),
        ("打开", "开"),
        ("透透气", "通通风"),
    ],
    "cabin_window_close": [
        ("关上", "关掉"),
        ("车窗", "窗户"),
        ("关严", "关好"),
    ],
    "cabin_fan_down": [
        ("风量", "风"),
        ("调小", "调低"),
        ("别吹这么猛", "风小一点"),
    ],
    "cabin_fan_up": [
        ("风量", "风"),
        ("调大", "调高"),
    ],
    "cabin_defog_front_on": [
        ("前挡", "前窗"),
        ("除雾", "除一下"),
    ],
    "cabin_defog_rear_on": [
        ("后挡", "后窗"),
        ("除雾", "清一下雾"),
    ],
    "cabin_play_music": [
        ("播放", "放"),
        ("来点", "播点"),
        ("我想听", "给我来点"),
        ("放点歌", "来首歌"),
    ],
    SOCIAL_LABEL: [
        ("你好", "您好"),
        ("哈喽", "hello"),
        ("你叫什么名字", "怎么称呼你"),
    ],
    OUT_OF_SCOPE_LABEL: [
        ("点个外卖", "叫个外卖"),
        ("订一家酒店", "订个酒店"),
        ("讲个笑话", "说个笑话"),
    ],
}
|
||||
|
||||
|
||||
@dataclass
class Sample:
    """One labelled training utterance."""

    # Raw utterance text fed to the tokenizer.
    text: str
    # Intent identifier used as the classification label.
    intent_id: str
|
||||
|
||||
|
||||
# Hand-written (text, intent_id) pairs that are added to the training set in
# addition to the rendered templates. Several entries contrast intents that
# share vocabulary (order vs. logistics vs. cancel, window vs. defog, fan vs.
# temperature).
HARD_NEGATIVE_RAW_SAMPLES: list[tuple[str, str]] = [
    # Logistics queries that mention "订单".
    ("订单A700101物流到哪了", "cs_query_logistics"),
    ("帮我看下订单A700102配送到哪里了", "cs_query_logistics"),
    ("订单A700103现在派件了吗", "cs_query_logistics"),
    ("A700104这单物流有没有更新", "cs_query_logistics"),
    ("查一下订单A700105运输轨迹", "cs_query_logistics"),
    # Cancellations.
    ("订单A700106不要了,帮我撤单", "cs_cancel_order"),
    ("A700107这单别发了,直接取消", "cs_cancel_order"),
    ("把订单A700108停掉吧", "cs_cancel_order"),
    ("A700109这个订单我不想要了", "cs_cancel_order"),
    ("订单A700110给我撤回", "cs_cancel_order"),
    # Order status queries.
    ("订单A700111现在受理了吗", "cs_query_order"),
    ("帮我看看A700112这单处理得怎么样了", "cs_query_order"),
    ("A700113订单目前进展如何", "cs_query_order"),
    ("查下订单A700114现在什么情况", "cs_query_order"),
    ("帮我确认订单A700115是否已经处理", "cs_query_order"),
    # Small talk.
    ("你好呀", SOCIAL_LABEL),
    ("嗨,在吗", SOCIAL_LABEL),
    ("今天天气真不错", SOCIAL_LABEL),
    ("你叫什么名字呀", SOCIAL_LABEL),
    ("你是做什么的", SOCIAL_LABEL),
    ("陪我随便聊聊", SOCIAL_LABEL),
    # Colloquial cabin commands.
    ("帮我透透气", "cabin_window_open"),
    ("车里太闷了,开下窗", "cabin_window_open"),
    ("把窗户降一点", "cabin_window_open"),
    ("前挡起雾了,除一下", "cabin_defog_front_on"),
    ("后挡有雾,开下除雾", "cabin_defog_rear_on"),
    ("把车里弄凉快点", "cabin_set_ac"),
    ("车里太热了,降一点", "cabin_set_ac"),
    ("空调再冷一点", "cabin_set_ac"),
    ("把里面弄暖和点", "cabin_set_ac"),
    ("风别这么大", "cabin_fan_down"),
    ("别吹这么猛", "cabin_fan_down"),
    ("风再大一点", "cabin_fan_up"),
    ("放点歌", "cabin_play_music"),
    ("来首歌", "cabin_play_music"),
    ("整点音乐", "cabin_play_music"),
    # Very short fragments.
    ("透透气", "cabin_window_open"),
    ("通通风", "cabin_window_open"),
    ("车里太闷了", "cabin_window_open"),
    ("凉快点", "cabin_set_ac"),
    ("暖和点", "cabin_set_ac"),
    # Out-of-scope requests.
    ("帮我点一份麻辣烫", OUT_OF_SCOPE_LABEL),
    ("给我订今晚的酒店", OUT_OF_SCOPE_LABEL),
    ("帮我买张电影票", OUT_OF_SCOPE_LABEL),
    ("人为什么会做梦", OUT_OF_SCOPE_LABEL),
    ("帮我做个旅游攻略", OUT_OF_SCOPE_LABEL),
    ("帮我点肯德基外卖", OUT_OF_SCOPE_LABEL),
    # Utterances that explicitly negate a competing intent.
    ("透透气,别给我除雾", "cabin_window_open"),
    ("后挡有雾,不是开窗,是除雾", "cabin_defog_rear_on"),
    ("前挡看不清了,除雾不要开窗", "cabin_defog_front_on"),
    ("凉快点,不是把风量调小", "cabin_set_ac"),
    ("别吹这么猛,不是降温", "cabin_fan_down"),
    ("来首歌,不是切下一首", "cabin_play_music"),
    ("放点歌,不是暂停音乐", "cabin_play_music"),
]
|
||||
|
||||
# HARD_NEGATIVE_RAW_SAMPLES converted to Sample objects; each raw pair is
# (text, intent_id), matching Sample's field order.
HARD_NEGATIVE_SAMPLES: list[Sample] = [Sample(*pair) for pair in HARD_NEGATIVE_RAW_SAMPLES]
|
||||
|
||||
|
||||
class IntentDataset(Dataset):
    """Torch dataset that tokenizes Sample objects on access.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors. Texts are padded/truncated to ``MAX_LENGTH`` tokens.
    """

    def __init__(self, samples: list[Sample], tokenizer, label_to_id: dict[str, int]) -> None:
        """Store the samples, the tokenizer callable and the intent-to-index map."""
        self._samples = samples
        self._tokenizer = tokenizer
        self._label_to_id = label_to_id

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self._samples)

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Tokenize the sample at ``index`` and attach its integer label."""
        item = self._samples[index]
        tokenized = self._tokenizer(
            item.text,
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch axis of size 1; drop it
        # so the DataLoader can stack items itself.
        features = {
            field: tokenized[field].squeeze(0)
            for field in ("input_ids", "attention_mask")
        }
        label_index = self._label_to_id[item.intent_id]
        features["labels"] = torch.tensor(label_index, dtype=torch.long)
        return features
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch with ``seed``.

    When the MPS backend is available its generator is seeded as well.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    mps_backend = torch.backends.mps
    if mps_backend.is_available():
        torch.mps.manual_seed(seed)
|
||||
|
||||
|
||||
def load_samples(file_path: Path) -> list[Sample]:
    """Load labelled samples from a JSON Lines file.

    Each non-blank line must be a JSON object with ``text`` and ``intent_id``
    keys; both values are converted to ``str``.

    Args:
        file_path: Path of the ``.jsonl`` file.

    Returns:
        The samples in file order, or an empty list if the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``text`` or ``intent_id``.
    """
    samples: list[Sample] = []
    if not file_path.exists():
        return samples
    # Iterate the file in text mode instead of using str.splitlines():
    # splitlines() also breaks on U+2028/U+2029/U+0085, which may legally
    # appear unescaped inside a JSON string and would cut a record in half.
    with file_path.open(encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            payload = json.loads(line)
            samples.append(Sample(text=str(payload["text"]), intent_id=str(payload["intent_id"])))
    return samples
|
||||
|
||||
|
||||
def load_domain_samples(file_path: Path) -> list[Sample]:
    """Build seed samples from the ``intents`` list of a domain YAML file.

    For every intent entry with a non-empty ``intent_id``, its ``examples``,
    ``keywords`` and ``label`` values are used as seed texts. Each seed text is
    expanded with ``expand_seed_variants`` and every distinct
    ``(text, intent_id)`` pair becomes one sample.

    Args:
        file_path: Path of the YAML file.

    Returns:
        The de-duplicated samples in first-seen order, or an empty list if the
        file does not exist.
    """
    if not file_path.exists():
        return []
    payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
    # `intents:` may be present but null in the YAML; treat that like a
    # missing key, consistent with the `or []` handling of the fields below.
    intents = payload.get("intents") or []
    samples: list[Sample] = []
    seen: set[tuple[str, str]] = set()
    for item in intents:
        intent_id = str(item.get("intent_id") or "").strip()
        if not intent_id:
            continue
        seed_texts = list(item.get("examples") or [])
        seed_texts.extend(item.get("keywords") or [])
        label = str(item.get("label") or "").strip()
        if label:
            seed_texts.append(label)
        for text in seed_texts:
            normalized = str(text).strip()
            if not normalized:
                continue
            for variant in expand_seed_variants(normalized):
                key = (variant, intent_id)
                if key in seen:
                    continue
                seen.add(key)
                samples.append(Sample(text=variant, intent_id=intent_id))
    return samples
|
||||
|
||||
|
||||
def expand_seed_variants(text: str) -> list[str]:
    """Generate simple phrasing variants of a seed utterance.

    The seed is stripped of surrounding whitespace and punctuation, then
    variants are produced by removing the fillers "一下", "帮我" and "请" and by
    adding the prefixes "帮我" / "请" and the suffix "一下". Each variant has its
    whitespace collapsed and surrounding punctuation stripped again.

    Args:
        text: Seed utterance.

    Returns:
        Distinct, non-empty variants in a fixed order (the normalized seed
        first). An empty list if nothing remains after normalization.
    """
    punctuation = ",。!?;; "
    normalized = text.strip().strip(punctuation)
    if not normalized:
        return []
    candidates = (
        normalized,
        normalized.replace("一下", "").strip(),
        normalized.replace("帮我", "").strip(),
        normalized.replace("请", "").strip(),
        f"帮我{normalized}",
        f"请{normalized}",
        f"{normalized}一下",
    )
    # A dict keeps insertion order, so the result is stable across runs. A set
    # would iterate in an order that depends on string-hash randomization,
    # which makes seeded shuffles/splits downstream irreproducible.
    unique: dict[str, None] = {}
    for candidate in candidates:
        compact = " ".join(candidate.split()).strip(punctuation)
        if compact:
            unique.setdefault(compact, None)
    return list(unique)
|
||||
|
||||
|
||||
def load_training_samples() -> list[Sample]:
    """Combine JSONL samples and domain-YAML samples without duplicates.

    Samples from ``TRAIN_PATH`` come first, followed by those derived from
    ``DOMAIN_PATH``. Only the first occurrence of each ``(text, intent_id)``
    pair is kept.
    """
    combined = load_samples(TRAIN_PATH)
    combined.extend(load_domain_samples(DOMAIN_PATH))
    # dict preserves insertion order; setdefault keeps the first occurrence.
    first_seen: dict[tuple[str, str], Sample] = {}
    for entry in combined:
        first_seen.setdefault((entry.text, entry.intent_id), entry)
    return list(first_seen.values())
|
||||
|
||||
|
||||
def augment_samples(samples: list[Sample]) -> list[Sample]:
    """Extend ``samples`` with templated, hand-written and paraphrased variants.

    Steps, in order:
      1. Render every template in ``TEMPLATES``.
      2. Add ``HARD_NEGATIVE_SAMPLES``.
      3. For each sample collected so far, add one variant per applicable
         ``INTENT_REPLACEMENTS`` pair (first occurrence replaced) and one
         "compact" variant with filler words removed.

    Duplicate ``(text, intent_id)`` pairs are skipped. The input list is not
    modified; the returned list is shuffled with the ``random`` module.
    """
    augmented = list(samples)
    seen = {(existing.text, existing.intent_id) for existing in augmented}

    def remember(candidate: Sample) -> None:
        """Append ``candidate`` unless an equal (text, intent) pair exists."""
        key = (candidate.text, candidate.intent_id)
        if key in seen:
            return
        augmented.append(candidate)
        seen.add(key)

    for intent_id, templates in TEMPLATES.items():
        for position, template in enumerate(templates):
            remember(render_template(intent_id, template, position))

    for hard_negative in HARD_NEGATIVE_SAMPLES:
        remember(hard_negative)

    # Iterate over a snapshot: variants appended below are not re-expanded.
    for base in tuple(augmented):
        base_text = base.text

        for source, target in INTENT_REPLACEMENTS.get(base.intent_id, []):
            if source not in base_text:
                continue
            remember(Sample(text=base_text.replace(source, target, 1), intent_id=base.intent_id))

        # Drop the first occurrence of each filler word, then trim punctuation.
        shortened = base_text
        for filler in ("帮我", "麻烦", "请", "一下"):
            shortened = shortened.replace(filler, "", 1)
        shortened = shortened.strip(" ,。!?")
        if shortened and shortened != base_text:
            remember(Sample(text=shortened, intent_id=base.intent_id))

    random.shuffle(augmented)
    return augmented
|
||||
|
||||
|
||||
def render_template(intent_id: str, template: str, index: int) -> Sample:
    """Fill a template's placeholders and wrap the result in a Sample.

    Slot values are chosen by cycling through each value list with ``index``,
    so the same ``(template, index)`` always renders the same text.
    """

    def pick(pool):
        """Return the pool entry selected by ``index`` (wrapping around)."""
        return pool[index % len(pool)]

    slots = {
        "order_id": pick(ORDER_IDS),
        "destination": pick(DESTINATIONS),
        "temperature": pick(TEMPERATURES),
        "song": pick(SONGS),
        "genre": pick(GENRES),
    }
    return Sample(text=template.format(**slots), intent_id=intent_id)
|
||||
|
||||
|
||||
def split_samples(
    samples: list[Sample],
    *,
    train_ratio: float = 0.8,
) -> tuple[list[Sample], list[Sample]]:
    """Split samples into train and dev sets, stratified by intent.

    Samples are grouped by ``intent_id``; each group is shuffled and its first
    ``max(1, int(len(group) * train_ratio))`` items go to the train set, the
    rest to the dev set. Every intent therefore has at least one train sample.
    Both result lists are shuffled. The input list is not modified.

    Args:
        samples: Samples to split.
        train_ratio: Fraction of each intent's samples assigned to training.
            Must satisfy ``0 < train_ratio <= 1``.

    Returns:
        ``(train_samples, dev_samples)``.

    Raises:
        ValueError: If ``train_ratio`` is outside ``(0, 1]``.
    """
    if not 0.0 < train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in (0, 1], got {train_ratio!r}")

    grouped: dict[str, list[Sample]] = {}
    for sample in samples:
        grouped.setdefault(sample.intent_id, []).append(sample)

    train_samples: list[Sample] = []
    dev_samples: list[Sample] = []
    for items in grouped.values():
        random.shuffle(items)
        # Keep at least one training example even for tiny intents.
        cut = max(1, int(len(items) * train_ratio))
        train_samples.extend(items[:cut])
        dev_samples.extend(items[cut:])

    random.shuffle(train_samples)
    random.shuffle(dev_samples)
    return train_samples, dev_samples
|
||||
|
||||
|
||||
def accuracy(model, loader: DataLoader, device: torch.device) -> float:
    """Return the fraction of examples in ``loader`` the model classifies correctly.

    The model is switched to eval mode and run without gradient tracking. The
    predicted class is the argmax over ``outputs.logits``. Returns ``0.0`` when
    the loader yields no examples.
    """
    model.eval()
    hits = 0
    examples = 0
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            gold = batch["labels"].to(device)
            logits = model(input_ids=ids, attention_mask=mask).logits
            predicted = logits.argmax(dim=-1)
            hits += int((predicted == gold).sum().item())
            examples += int(gold.numel())
    if not examples:
        return 0.0
    return hits / examples
|
||||
|
||||
|
||||
def resolve_base_model() -> str:
    """Return the base checkpoint name to fine-tune.

    The ``AGENT_BERT_BASE_MODEL`` environment variable wins when it is set to a
    non-blank value (surrounding whitespace is stripped); otherwise
    ``DEFAULT_BASE_MODEL`` is used.
    """
    override = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
    return override or DEFAULT_BASE_MODEL
|
||||
|
||||
|
||||
def _select_device() -> torch.device:
    """Pick the training device: CUDA if available, then MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _train_one_epoch(model, loader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device) -> float:
    """Run one optimization pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch in loader:
        optimizer.zero_grad()
        outputs = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["labels"].to(device),
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += float(loss.item())
    # max(..., 1) guards against division by zero for an empty loader.
    return total_loss / max(len(loader), 1)


def main() -> None:
    """Fine-tune the intent classifier and write the artifacts to ``OUTPUT_DIR``.

    Pipeline: seed the RNGs, build and augment the sample set, split it per
    intent, fine-tune the base model for ``EPOCHS`` epochs with AdamW, keep the
    weights of the epoch with the best dev accuracy, then save the model, the
    tokenizer, ``label_map.json`` and ``train_summary.json``.
    """
    set_seed(SEED)
    samples = augment_samples(load_training_samples())
    # Sorting makes the intent -> index assignment independent of sample order.
    intents = sorted({sample.intent_id for sample in samples})
    label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
    id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}

    train_samples, dev_samples = split_samples(samples)
    base_model = resolve_base_model()
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    train_dataset = IntentDataset(train_samples, tokenizer, label_to_id)
    dev_dataset = IntentDataset(dev_samples, tokenizer, label_to_id)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=len(intents),
        id2label=id_to_label,
        label2id=label_to_id,
    )

    device = _select_device()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    best_dev_acc = 0.0
    best_state = None

    for epoch in range(1, EPOCHS + 1):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, device)
        dev_acc = accuracy(model, dev_loader, device)
        print(f"epoch={epoch} loss={avg_loss:.4f} dev_acc={dev_acc:.4f}")
        # ">=" means that on a tie the later epoch's weights are kept.
        if dev_acc >= best_dev_acc:
            best_dev_acc = dev_acc
            # Snapshot on CPU so the copy does not occupy accelerator memory.
            best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    # Map "LABEL_<index>" names to intent ids and store them next to the model.
    label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
    (OUTPUT_DIR / "label_map.json").write_text(
        json.dumps(label_map, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    train_summary = {
        "base_model": base_model,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "train_size": len(train_samples),
        "dev_size": len(dev_samples),
        "best_dev_accuracy": round(best_dev_acc, 4),
        "device": str(device),
    }
    (OUTPUT_DIR / "train_summary.json").write_text(
        json.dumps(train_summary, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    print(json.dumps(train_summary, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user