Update project and configurations
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
DEFAULT_TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_eval_independent.jsonl"
|
||||
|
||||
|
||||
def load_cases(path: Path) -> list[dict[str, object]]:
|
||||
rows: list[dict[str, object]] = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
rows.append(payload)
|
||||
return rows
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Joint BERT 多意图独立评测")
|
||||
parser.add_argument("--model-path", type=str, default="models/local_joint_bert_nlu")
|
||||
parser.add_argument("--test-path", type=str, default=str(DEFAULT_TEST_PATH))
|
||||
args = parser.parse_args()
|
||||
|
||||
registry = build_intent_registry()
|
||||
nlu = JointBertNLU(model_path=args.model_path)
|
||||
cases = load_cases(Path(args.test_path))
|
||||
|
||||
tp = 0
|
||||
fp = 0
|
||||
fn = 0
|
||||
exact = 0
|
||||
failures: list[dict[str, object]] = []
|
||||
category_correct: Counter[str] = Counter()
|
||||
category_total: Counter[str] = Counter()
|
||||
|
||||
for case in cases:
|
||||
text = str(case["text"])
|
||||
expected = sorted({str(item) for item in case.get("expected_intent_ids", [])})
|
||||
predicted = sorted(item.intent_id for item in nlu.predict_multi_intents(text, registry.list(), top_k=8, max_labels=4))
|
||||
expected_set = set(expected)
|
||||
predicted_set = set(predicted)
|
||||
tp += len(expected_set & predicted_set)
|
||||
fp += len(predicted_set - expected_set)
|
||||
fn += len(expected_set - predicted_set)
|
||||
category = str(case.get("category") or "unknown")
|
||||
category_total[category] += 1
|
||||
if expected_set == predicted_set:
|
||||
exact += 1
|
||||
category_correct[category] += 1
|
||||
else:
|
||||
failures.append(
|
||||
{
|
||||
"text": text,
|
||||
"expected_intent_ids": expected,
|
||||
"predicted_intent_ids": predicted,
|
||||
"category": category,
|
||||
}
|
||||
)
|
||||
|
||||
precision = tp / max(tp + fp, 1)
|
||||
recall = tp / max(tp + fn, 1)
|
||||
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
result = {
|
||||
"sample_count": len(cases),
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(f1, 4),
|
||||
"exact_match": round(exact / max(len(cases), 1), 4),
|
||||
"per_category_exact_match": {
|
||||
category: round(category_correct[category] / max(total, 1), 4)
|
||||
for category, total in sorted(category_total.items())
|
||||
},
|
||||
"failures": failures[:20],
|
||||
}
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
59
intelligent_cabin/archive/scripts/eval_joint_bert_nlu.py
Normal file
59
intelligent_cabin/archive/scripts/eval_joint_bert_nlu.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="评测 Joint BERT NLU 单句意图与槽位输出")
|
||||
parser.add_argument("--text", type=str, required=True, help="待评测文本")
|
||||
parser.add_argument("--model-path", type=str, default="models/local_joint_bert_nlu", help="模型目录")
|
||||
args = parser.parse_args()
|
||||
|
||||
registry = build_intent_registry()
|
||||
nlu = JointBertNLU(model_path=args.model_path)
|
||||
result = nlu.predict(args.text, registry.list())
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"text": args.text,
|
||||
"intent_id": result.intent_id,
|
||||
"intent_score": round(result.intent_score, 4),
|
||||
"candidates": [
|
||||
{"intent_id": item.intent_id, "score": round(item.score, 4)}
|
||||
for item in result.candidates
|
||||
],
|
||||
"multi_intent_candidates": [
|
||||
{"intent_id": item.intent_id, "score": round(item.score, 4)}
|
||||
for item in result.multi_intent_candidates
|
||||
],
|
||||
"slots": result.slots,
|
||||
"slot_items": [
|
||||
{
|
||||
"slot_name": item.slot_name,
|
||||
"value": item.value,
|
||||
"start": item.start,
|
||||
"end": item.end,
|
||||
"score": item.score,
|
||||
}
|
||||
for item in result.slot_items
|
||||
],
|
||||
"error_message": result.error_message,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
275
intelligent_cabin/archive/scripts/eval_joint_nlu_independent.py
Normal file
275
intelligent_cabin/archive/scripts/eval_joint_nlu_independent.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
TEST_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval_independent.jsonl"
|
||||
MODEL_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"
|
||||
REPORT_DIR = PROJECT_ROOT / "reports"
|
||||
RESULT_PATH = REPORT_DIR / "joint_nlu_independent_result.json"
|
||||
REPORT_PATH = REPORT_DIR / "joint_nlu_independent_report.md"
|
||||
TRAIN_SUMMARY_PATH = MODEL_DIR / "train_summary.json"
|
||||
|
||||
|
||||
def load_cases(file_path: Path) -> list[dict[str, object]]:
|
||||
cases: list[dict[str, object]] = []
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
expected_intent_id = str(payload.get("expected_intent_id") or payload.get("intent_id") or "").strip()
|
||||
if not expected_intent_id:
|
||||
continue
|
||||
cases.append(
|
||||
{
|
||||
"text": str(payload["text"]),
|
||||
"expected_intent_id": expected_intent_id,
|
||||
"expected_slots": dict(payload.get("expected_slots") or {}),
|
||||
"category": str(payload.get("category") or "unknown"),
|
||||
}
|
||||
)
|
||||
return cases
|
||||
|
||||
|
||||
def load_train_summary(file_path: Path) -> dict[str, object]:
|
||||
if not file_path.exists():
|
||||
return {}
|
||||
return json.loads(file_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def compare_slots(expected: dict[str, object], predicted: dict[str, object]) -> dict[str, object]:
|
||||
expected_keys = set(expected)
|
||||
predicted_keys = set(predicted)
|
||||
missing_keys = sorted(expected_keys - predicted_keys)
|
||||
extra_keys = sorted(predicted_keys - expected_keys)
|
||||
wrong_values: list[dict[str, object]] = []
|
||||
matched_keys = 0
|
||||
for key in sorted(expected_keys & predicted_keys):
|
||||
if expected[key] == predicted[key]:
|
||||
matched_keys += 1
|
||||
else:
|
||||
wrong_values.append(
|
||||
{
|
||||
"slot_name": key,
|
||||
"expected": expected[key],
|
||||
"predicted": predicted[key],
|
||||
}
|
||||
)
|
||||
exact = not missing_keys and not extra_keys and not wrong_values
|
||||
return {
|
||||
"missing_keys": missing_keys,
|
||||
"extra_keys": extra_keys,
|
||||
"wrong_values": wrong_values,
|
||||
"matched_keys": matched_keys,
|
||||
"exact": exact,
|
||||
}
|
||||
|
||||
|
||||
def compute_metrics(results: list[dict[str, object]]) -> dict[str, float]:
|
||||
total = len(results)
|
||||
intent_correct = sum(1 for item in results if item["intent_ok"])
|
||||
slot_exact = sum(1 for item in results if item["slot_exact"])
|
||||
joint_exact = sum(1 for item in results if item["joint_ok"])
|
||||
|
||||
slot_tp = 0
|
||||
slot_fp = 0
|
||||
slot_fn = 0
|
||||
for item in results:
|
||||
expected = dict(item["expected_slots"])
|
||||
predicted = dict(item["predicted_slots"])
|
||||
expected_keys = set(expected)
|
||||
predicted_keys = set(predicted)
|
||||
slot_tp += sum(1 for key in expected_keys & predicted_keys if expected[key] == predicted[key])
|
||||
slot_fp += len(predicted_keys - expected_keys)
|
||||
slot_fn += len(expected_keys - predicted_keys)
|
||||
slot_fp += sum(1 for key in expected_keys & predicted_keys if expected[key] != predicted[key])
|
||||
slot_fn += sum(1 for key in expected_keys & predicted_keys if expected[key] != predicted[key])
|
||||
|
||||
precision = slot_tp / (slot_tp + slot_fp) if (slot_tp + slot_fp) else 0.0
|
||||
recall = slot_tp / (slot_tp + slot_fn) if (slot_tp + slot_fn) else 0.0
|
||||
slot_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"intent_accuracy": round(intent_correct / total, 4) if total else 0.0,
|
||||
"slot_exact_match": round(slot_exact / total, 4) if total else 0.0,
|
||||
"joint_exact_match": round(joint_exact / total, 4) if total else 0.0,
|
||||
"slot_micro_precision": round(precision, 4),
|
||||
"slot_micro_recall": round(recall, 4),
|
||||
"slot_micro_f1": round(slot_f1, 4),
|
||||
}
|
||||
|
||||
|
||||
def summarize_by_category(results: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
|
||||
for item in results:
|
||||
grouped[str(item["category"])].append(item)
|
||||
summary: list[dict[str, object]] = []
|
||||
for category, items in sorted(grouped.items()):
|
||||
summary.append(
|
||||
{
|
||||
"category": category,
|
||||
"sample_count": len(items),
|
||||
"metrics": compute_metrics(items),
|
||||
}
|
||||
)
|
||||
return summary
|
||||
|
||||
|
||||
def collect_top_confusions(results: list[dict[str, object]], limit: int = 12) -> list[dict[str, object]]:
|
||||
counter: Counter[tuple[str, str]] = Counter()
|
||||
for item in results:
|
||||
if item["intent_ok"]:
|
||||
continue
|
||||
counter[(str(item["expected_intent_id"]), str(item["predicted_intent_id"]))] += 1
|
||||
return [
|
||||
{"expected": expected, "predicted": predicted, "count": count}
|
||||
for (expected, predicted), count in counter.most_common(limit)
|
||||
]
|
||||
|
||||
|
||||
def collect_failures(results: list[dict[str, object]], limit: int = 20) -> list[dict[str, object]]:
|
||||
failures = [item for item in results if not item["joint_ok"]]
|
||||
|
||||
def sort_key(item: dict[str, object]) -> tuple[int, int, int]:
|
||||
slot_errors = len(item["slot_diff"]["missing_keys"]) + len(item["slot_diff"]["extra_keys"]) + len(item["slot_diff"]["wrong_values"])
|
||||
return (0 if item["intent_ok"] else 1, slot_errors, len(str(item["text"])))
|
||||
|
||||
return sorted(failures, key=sort_key, reverse=True)[:limit]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Joint NLU 独立评测与失败样例回放")
|
||||
parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="评测集路径")
|
||||
parser.add_argument("--model-path", type=str, default=str(MODEL_DIR), help="Joint NLU 模型路径")
|
||||
parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化结果输出路径")
|
||||
parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 报告输出路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
cases = load_cases(Path(args.test_path))
|
||||
registry = build_intent_registry()
|
||||
nlu = JointBertNLU(model_path=args.model_path)
|
||||
results: list[dict[str, object]] = []
|
||||
for case in cases:
|
||||
prediction = nlu.predict(str(case["text"]), registry.list())
|
||||
predicted_slots = dict(prediction.slots)
|
||||
slot_diff = compare_slots(dict(case["expected_slots"]), predicted_slots)
|
||||
predicted_intent_id = prediction.intent_id or "None"
|
||||
intent_ok = predicted_intent_id == case["expected_intent_id"]
|
||||
joint_ok = intent_ok and bool(slot_diff["exact"])
|
||||
results.append(
|
||||
{
|
||||
"text": case["text"],
|
||||
"category": case["category"],
|
||||
"expected_intent_id": case["expected_intent_id"],
|
||||
"predicted_intent_id": predicted_intent_id,
|
||||
"expected_slots": case["expected_slots"],
|
||||
"predicted_slots": predicted_slots,
|
||||
"intent_score": round(prediction.intent_score, 4),
|
||||
"intent_ok": intent_ok,
|
||||
"slot_exact": bool(slot_diff["exact"]),
|
||||
"joint_ok": joint_ok,
|
||||
"slot_diff": slot_diff,
|
||||
"top_candidates": [
|
||||
{"intent_id": item.intent_id, "score": round(item.score, 4)}
|
||||
for item in prediction.candidates
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
summary = {
|
||||
"model_path": args.model_path,
|
||||
"test_path": args.test_path,
|
||||
"sample_count": len(results),
|
||||
"metrics": compute_metrics(results),
|
||||
"per_category": summarize_by_category(results),
|
||||
"top_confusions": collect_top_confusions(results),
|
||||
"failure_examples": collect_failures(results),
|
||||
"train_summary": load_train_summary(TRAIN_SUMMARY_PATH),
|
||||
"results": results,
|
||||
}
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
|
||||
print(json.dumps({"sample_count": summary["sample_count"], "metrics": summary["metrics"]}, ensure_ascii=False))
|
||||
|
||||
|
||||
def render_report(summary: dict[str, object]) -> str:
|
||||
metrics = summary["metrics"]
|
||||
per_category = summary["per_category"]
|
||||
confusions = summary["top_confusions"]
|
||||
failures = summary["failure_examples"]
|
||||
train_summary = summary.get("train_summary") or {}
|
||||
lines = [
|
||||
"# Joint NLU 独立评测报告",
|
||||
"",
|
||||
"## 概览",
|
||||
f"- 模型目录:`{summary['model_path']}`",
|
||||
f"- 评测集:`{summary['test_path']}`",
|
||||
f"- 样本数:`{summary['sample_count']}`",
|
||||
f"- `intent_accuracy`:`{metrics['intent_accuracy']}`",
|
||||
f"- `slot_exact_match`:`{metrics['slot_exact_match']}`",
|
||||
f"- `joint_exact_match`:`{metrics['joint_exact_match']}`",
|
||||
f"- `slot_micro_precision`:`{metrics['slot_micro_precision']}`",
|
||||
f"- `slot_micro_recall`:`{metrics['slot_micro_recall']}`",
|
||||
f"- `slot_micro_f1`:`{metrics['slot_micro_f1']}`",
|
||||
"",
|
||||
"## 训练摘要",
|
||||
]
|
||||
if train_summary:
|
||||
lines.extend(
|
||||
[
|
||||
f"- 训练集 / 评测集:`{train_summary.get('train_size', 'unknown')} / {train_summary.get('eval_size', 'unknown')}`",
|
||||
f"- 训练阶段 `intent_accuracy`:`{train_summary.get('metrics', {}).get('intent_accuracy', 'unknown')}`",
|
||||
f"- 训练阶段 `slot_exact_match`:`{train_summary.get('metrics', {}).get('slot_exact_match', 'unknown')}`",
|
||||
"",
|
||||
]
|
||||
)
|
||||
else:
|
||||
lines.extend(["- 未找到训练摘要。", ""])
|
||||
lines.extend(["## 分类别结果"])
|
||||
for item in per_category:
|
||||
category_metrics = item["metrics"]
|
||||
lines.append(
|
||||
f"- `{item['category']}`: count={item['sample_count']} intent_acc={category_metrics['intent_accuracy']} slot_exact={category_metrics['slot_exact_match']} joint_exact={category_metrics['joint_exact_match']}"
|
||||
)
|
||||
lines.extend(["", "## 主要意图混淆"])
|
||||
if not confusions:
|
||||
lines.append("- 未发现意图混淆。")
|
||||
else:
|
||||
for item in confusions:
|
||||
lines.append(f"- 期望 `{item['expected']}`,预测成 `{item['predicted']}`:`{item['count']}` 次")
|
||||
lines.extend(["", "## 失败样例回放"])
|
||||
if not failures:
|
||||
lines.append("- 无失败样例。")
|
||||
else:
|
||||
for item in failures:
|
||||
slot_diff = item["slot_diff"]
|
||||
lines.append(
|
||||
f"- 文本:`{item['text']}` | 类别:`{item['category']}` | 期望意图:`{item['expected_intent_id']}` | 预测意图:`{item['predicted_intent_id']}` | 期望槽位:`{item['expected_slots']}` | 预测槽位:`{item['predicted_slots']}` | 缺失槽位:`{slot_diff['missing_keys']}` | 多出槽位:`{slot_diff['extra_keys']}`"
|
||||
)
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## 结论",
|
||||
"- 先看 `failure_replay` 是否仍然错,能直接判断先前多意图失败到底是联合模型本体问题还是上层组合问题。",
|
||||
"- 若 `slot_music` 或 `slot_destination` 仍不稳,优先补 span 标注,不要回退到规则抽槽。",
|
||||
"- 若 `no_slot_control` 很稳但 `failure_replay` 中仍有大量错误,下一步应补长尾控制语义数据,而不是急着上更复杂结构。",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
231
intelligent_cabin/archive/scripts/eval_local_bert_intent.py
Normal file
231
intelligent_cabin/archive/scripts/eval_local_bert_intent.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
|
||||
|
||||
TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_eval_independent.jsonl"
|
||||
MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
|
||||
REPORT_DIR = PROJECT_ROOT / "reports"
|
||||
REPORT_PATH = REPORT_DIR / "bert_local_test_report.md"
|
||||
RESULT_PATH = REPORT_DIR / "bert_local_test_result.json"
|
||||
BERT_THRESHOLD = 0.0
|
||||
TRAIN_SUMMARY_PATH = MODEL_DIR / "train_summary.json"
|
||||
|
||||
|
||||
def load_cases(file_path: Path) -> list[dict[str, str]]:
|
||||
cases: list[dict[str, str]] = []
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
expected_label = str(payload.get("expected_label") or payload.get("intent_id") or "").strip()
|
||||
if not expected_label:
|
||||
continue
|
||||
category = str(payload.get("category") or infer_category(expected_label)).strip()
|
||||
cases.append(
|
||||
{
|
||||
"text": str(payload["text"]),
|
||||
"expected_label": expected_label,
|
||||
"category": category,
|
||||
}
|
||||
)
|
||||
return cases
|
||||
|
||||
|
||||
def load_train_summary(file_path: Path) -> dict[str, object]:
|
||||
if not file_path.exists():
|
||||
return {}
|
||||
return json.loads(file_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def infer_category(label: str) -> str:
|
||||
if label == "__social__":
|
||||
return "social"
|
||||
if label == "__out_of_scope__":
|
||||
return "out_of_scope"
|
||||
return "business"
|
||||
|
||||
|
||||
def resolve_predicted_label(result) -> str:
|
||||
if result.intent is not None:
|
||||
return result.intent.intent_id
|
||||
if result.raw_label:
|
||||
return str(result.raw_label)
|
||||
return "None"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="本地 BERT 独立评测脚本")
|
||||
parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="评测集路径")
|
||||
parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化评测结果输出路径")
|
||||
parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 评测报告输出路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
intent_registry = build_intent_registry()
|
||||
intents = intent_registry.list()
|
||||
classifier = BertIntentClassifier(
|
||||
model_path=str(MODEL_DIR),
|
||||
threshold=BERT_THRESHOLD,
|
||||
label_map_path=str(MODEL_DIR / "label_map.json"),
|
||||
fallback=None,
|
||||
top_k=3,
|
||||
)
|
||||
cases = load_cases(Path(args.test_path))
|
||||
|
||||
results: list[dict[str, object]] = []
|
||||
confusion: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
category_confusion: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
correct = 0
|
||||
|
||||
for case in cases:
|
||||
result = classifier.predict(case["text"], intents)
|
||||
predicted = resolve_predicted_label(result)
|
||||
expected = case["expected_label"]
|
||||
ok = predicted == expected
|
||||
if ok:
|
||||
correct += 1
|
||||
confusion[expected][predicted] += 1
|
||||
category_confusion[case["category"]]["correct" if ok else "wrong"] += 1
|
||||
results.append(
|
||||
{
|
||||
"text": case["text"],
|
||||
"category": case["category"],
|
||||
"expected_label": expected,
|
||||
"predicted_label": predicted,
|
||||
"score": round(result.score, 4),
|
||||
"raw_label": result.raw_label,
|
||||
"ok": ok,
|
||||
"top_candidates": [
|
||||
{"intent_id": intent.intent_id, "score": round(score, 4)}
|
||||
for intent, score in (result.candidates or [])
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
accuracy = correct / len(cases) if cases else 0.0
|
||||
train_summary = load_train_summary(TRAIN_SUMMARY_PATH)
|
||||
per_label_stats: list[dict[str, object]] = []
|
||||
for label in sorted({case["expected_label"] for case in cases}):
|
||||
label_cases = [item for item in results if item["expected_label"] == label]
|
||||
label_correct = sum(1 for item in label_cases if item["ok"])
|
||||
per_label_stats.append(
|
||||
{
|
||||
"label": label,
|
||||
"category": infer_category(label),
|
||||
"total": len(label_cases),
|
||||
"correct": label_correct,
|
||||
"accuracy": round(label_correct / len(label_cases), 4) if label_cases else 0.0,
|
||||
}
|
||||
)
|
||||
per_category_stats: list[dict[str, object]] = []
|
||||
for category in sorted({case["category"] for case in cases}):
|
||||
category_cases = [item for item in results if item["category"] == category]
|
||||
category_correct = sum(1 for item in category_cases if item["ok"])
|
||||
per_category_stats.append(
|
||||
{
|
||||
"category": category,
|
||||
"total": len(category_cases),
|
||||
"correct": category_correct,
|
||||
"accuracy": round(category_correct / len(category_cases), 4) if category_cases else 0.0,
|
||||
}
|
||||
)
|
||||
|
||||
errors = [item for item in results if not item["ok"]]
|
||||
summary = {
|
||||
"model_dir": str(MODEL_DIR),
|
||||
"threshold": BERT_THRESHOLD,
|
||||
"test_path": str(args.test_path),
|
||||
"test_case_count": len(cases),
|
||||
"accuracy": round(accuracy, 4),
|
||||
"train_summary": train_summary,
|
||||
"per_category": per_category_stats,
|
||||
"per_label": per_label_stats,
|
||||
"errors": errors,
|
||||
"confusion": {key: dict(value) for key, value in confusion.items()},
|
||||
}
|
||||
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
|
||||
print(json.dumps({"accuracy": summary["accuracy"], "test_case_count": len(cases), "error_count": len(errors)}, ensure_ascii=False))
|
||||
|
||||
|
||||
def render_report(summary: dict[str, object]) -> str:
|
||||
per_category = summary["per_category"]
|
||||
per_label = summary["per_label"]
|
||||
errors = summary["errors"]
|
||||
train_summary = summary.get("train_summary") or {}
|
||||
lines = [
|
||||
"# 本地 BERT 意图识别测试报告",
|
||||
"",
|
||||
"## 概览",
|
||||
f"- 模型目录:`{summary['model_dir']}`",
|
||||
f"- 评测集:`{summary['test_path']}`",
|
||||
f"- 评测阈值:`{summary['threshold']}`",
|
||||
f"- 测试样本数:`{summary['test_case_count']}`",
|
||||
f"- 总体准确率:`{summary['accuracy']}`",
|
||||
"",
|
||||
"## 训练摘要",
|
||||
]
|
||||
if train_summary:
|
||||
lines.extend(
|
||||
[
|
||||
f"- 基座模型:`{train_summary.get('base_model', 'unknown')}`",
|
||||
f"- 训练集 / 验证集:`{train_summary.get('train_size', 'unknown')} / {train_summary.get('dev_size', 'unknown')}`",
|
||||
f"- 最佳验证准确率:`{train_summary.get('best_dev_accuracy', 'unknown')}`",
|
||||
f"- 训练设备:`{train_summary.get('device', 'unknown')}`",
|
||||
"",
|
||||
]
|
||||
)
|
||||
else:
|
||||
lines.extend(["- 未找到训练摘要。", ""])
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"## 分类别结果",
|
||||
]
|
||||
)
|
||||
for item in per_category:
|
||||
lines.append(
|
||||
f"- `{item['category']}`: {item['correct']}/{item['total']} = {item['accuracy']}"
|
||||
)
|
||||
lines.extend(["", "## 分标签结果"])
|
||||
for item in per_label:
|
||||
lines.append(
|
||||
f"- `{item['label']}` ({item['category']}): {item['correct']}/{item['total']} = {item['accuracy']}"
|
||||
)
|
||||
lines.extend(["", "## 错误样例"])
|
||||
if not errors:
|
||||
lines.append("- 无错误样例。")
|
||||
else:
|
||||
for item in errors[:10]:
|
||||
lines.append(
|
||||
f"- 文本:`{item['text']}` | 类别:`{item['category']}` | 期望:`{item['expected_label']}` | 预测:`{item['predicted_label']}` | 分数:`{item['score']}`"
|
||||
)
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## 结论",
|
||||
"- 当前本地 MacBERT 已具备较强的业务意图识别能力,可作为本地快链路分类器。",
|
||||
"- 误判主要集中在方向相反或语义接近的控制指令,下一步应补充对抗样本和真实口语表达。",
|
||||
"- 上线前建议继续补充 ASR 错字、多轮短句和多意图子句级样本。",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from scripts.train_local_bert_multi_intent import (
|
||||
BATCH_SIZE,
|
||||
OUTPUT_DIR,
|
||||
TOP_K,
|
||||
THRESHOLD,
|
||||
MultiLabelIntentDataset,
|
||||
load_all_samples,
|
||||
split_samples,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Evaluate local BERT multi-intent detector.")
|
||||
parser.add_argument("--model-path", default=str(OUTPUT_DIR), help="Path to trained multi-intent model.")
|
||||
parser.add_argument("--threshold", type=float, default=THRESHOLD, help="Probability threshold.")
|
||||
parser.add_argument("--top-k", type=int, default=TOP_K, help="Top-k for recall@k.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
choices=("dev", "all"),
|
||||
default="dev",
|
||||
help="Evaluate on the held-out dev split or all combined samples.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
probabilities: list[list[float]],
|
||||
targets: list[list[float]],
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
recall_at_k_total = 0.0
|
||||
total = len(probabilities)
|
||||
for scores, target in zip(probabilities, targets):
|
||||
predicted = {index for index, score in enumerate(scores) if score >= threshold}
|
||||
expected = {index for index, value in enumerate(target) if value >= 0.5}
|
||||
if predicted == expected:
|
||||
exact_match += 1
|
||||
true_positive += len(predicted & expected)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
|
||||
if expected:
|
||||
recall_at_k_total += len(set(top_indices) & expected) / len(expected)
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
set_seed(42)
|
||||
samples = load_all_samples()
|
||||
_, dev_samples = split_samples(samples)
|
||||
eval_samples = samples if args.dataset == "all" else dev_samples
|
||||
model_path = Path(args.model_path)
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"model path not found: {model_path}")
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
label_to_id = {str(label): int(index) for label, index in (model.config.label2id or {}).items()}
|
||||
if not label_to_id:
|
||||
raise RuntimeError("label2id is missing from model config")
|
||||
|
||||
dataset = MultiLabelIntentDataset(eval_samples, tokenizer, label_to_id)
|
||||
loader = DataLoader(dataset, batch_size=BATCH_SIZE)
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
probabilities: list[list[float]] = []
|
||||
targets: list[list[float]] = []
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
probabilities.extend(torch.sigmoid(outputs.logits).detach().cpu().tolist())
|
||||
targets.extend(labels.detach().cpu().tolist())
|
||||
|
||||
metrics = compute_metrics(probabilities, targets, threshold=args.threshold, top_k=args.top_k)
|
||||
result = {
|
||||
"model_path": str(model_path),
|
||||
"dataset": args.dataset,
|
||||
"sample_size": len(eval_samples),
|
||||
"threshold": args.threshold,
|
||||
"top_k": args.top_k,
|
||||
"metrics": metrics,
|
||||
}
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.multi_intent_detector import BertMultiIntentDetector
|
||||
|
||||
|
||||
TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_eval_independent.jsonl"
|
||||
MODEL_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"
|
||||
REPORT_DIR = PROJECT_ROOT / "reports"
|
||||
RESULT_PATH = REPORT_DIR / "bert_multi_intent_independent_result.json"
|
||||
REPORT_PATH = REPORT_DIR / "bert_multi_intent_independent_report.md"
|
||||
THRESHOLD = 0.45
|
||||
TOP_K = 8
|
||||
MAX_LABELS = 4
|
||||
|
||||
|
||||
def load_cases(file_path: Path) -> list[dict[str, object]]:
|
||||
cases: list[dict[str, object]] = []
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
expected = sorted({str(item).strip() for item in payload.get("expected_intent_ids") or [] if str(item).strip()})
|
||||
if not expected:
|
||||
continue
|
||||
cases.append(
|
||||
{
|
||||
"text": str(payload["text"]),
|
||||
"expected_intent_ids": expected,
|
||||
"category": str(payload.get("category") or "unknown"),
|
||||
}
|
||||
)
|
||||
return cases
|
||||
|
||||
|
||||
def compute_set_metrics(results: list[dict[str, object]]) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
multi_recall_hit = 0
|
||||
single_false_alarm = 0
|
||||
total = len(results)
|
||||
single_guard_total = 0
|
||||
for item in results:
|
||||
expected = set(item["expected_intent_ids"])
|
||||
predicted = set(item["predicted_intent_ids"])
|
||||
if expected == predicted:
|
||||
exact_match += 1
|
||||
true_positive += len(expected & predicted)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
if len(expected) >= 2 and expected.issubset(predicted):
|
||||
multi_recall_hit += 1
|
||||
if len(expected) == 1:
|
||||
single_guard_total += 1
|
||||
if len(predicted) > 1:
|
||||
single_false_alarm += 1
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
multi_total = sum(1 for item in results if len(item["expected_intent_ids"]) >= 2)
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"multi_sentence_recall": round(multi_recall_hit / multi_total, 4) if multi_total else 0.0,
|
||||
"single_guard_false_alarm_rate": round(single_false_alarm / single_guard_total, 4) if single_guard_total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def summarize_by_category(results: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
grouped: dict[str, list[dict[str, object]]] = defaultdict(list)
|
||||
for item in results:
|
||||
grouped[str(item["category"])].append(item)
|
||||
summary: list[dict[str, object]] = []
|
||||
for category, items in sorted(grouped.items()):
|
||||
summary.append(
|
||||
{
|
||||
"category": category,
|
||||
"sample_count": len(items),
|
||||
"metrics": compute_set_metrics(items),
|
||||
}
|
||||
)
|
||||
return summary
|
||||
|
||||
|
||||
def collect_error_examples(results: list[dict[str, object]], limit: int = 15) -> list[dict[str, object]]:
|
||||
errors = [item for item in results if set(item["expected_intent_ids"]) != set(item["predicted_intent_ids"])]
|
||||
def sort_key(item: dict[str, object]) -> tuple[int, int]:
|
||||
expected = set(item["expected_intent_ids"])
|
||||
predicted = set(item["predicted_intent_ids"])
|
||||
miss = len(expected - predicted)
|
||||
extra = len(predicted - expected)
|
||||
return (miss + extra, miss)
|
||||
return sorted(errors, key=sort_key, reverse=True)[:limit]
|
||||
|
||||
|
||||
def top_confusions(results: list[dict[str, object]], limit: int = 12) -> list[dict[str, object]]:
|
||||
counter: Counter[tuple[str, str]] = Counter()
|
||||
for item in results:
|
||||
expected = set(item["expected_intent_ids"])
|
||||
predicted = set(item["predicted_intent_ids"])
|
||||
for miss in sorted(expected - predicted):
|
||||
for extra in sorted(predicted - expected):
|
||||
counter[(miss, extra)] += 1
|
||||
return [
|
||||
{"expected_missing": pair[0], "wrong_extra": pair[1], "count": count}
|
||||
for pair, count in counter.most_common(limit)
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="本地多标签 detector 独立评测脚本")
|
||||
parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="独立评测集路径")
|
||||
parser.add_argument("--model-path", type=str, default=str(MODEL_DIR), help="多标签模型路径")
|
||||
parser.add_argument("--threshold", type=float, default=THRESHOLD, help="检测阈值")
|
||||
parser.add_argument("--top-k", type=int, default=TOP_K, help="输出 top-k 原始分数")
|
||||
parser.add_argument("--max-labels", type=int, default=MAX_LABELS, help="最多返回标签数")
|
||||
parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化结果输出路径")
|
||||
parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 报告输出路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
cases = load_cases(Path(args.test_path))
|
||||
intents = build_intent_registry().list()
|
||||
detector = BertMultiIntentDetector(
|
||||
model_path=args.model_path,
|
||||
threshold=args.threshold,
|
||||
top_k=args.top_k,
|
||||
max_labels=args.max_labels,
|
||||
)
|
||||
|
||||
results: list[dict[str, object]] = []
|
||||
for case in cases:
|
||||
detection = detector.detect(str(case["text"]), intents)
|
||||
predicted = [candidate.intent_id for candidate in detection.candidates]
|
||||
raw_top = [
|
||||
{
|
||||
"intent_id": str(item.get("intent_id") or item.get("label") or ""),
|
||||
"score": round(float(item.get("score", 0.0)), 4),
|
||||
}
|
||||
for item in detection.raw_scores
|
||||
]
|
||||
results.append(
|
||||
{
|
||||
"text": case["text"],
|
||||
"category": case["category"],
|
||||
"expected_intent_ids": case["expected_intent_ids"],
|
||||
"predicted_intent_ids": predicted,
|
||||
"detected": detection.detected,
|
||||
"backend_name": detection.backend_name,
|
||||
"reason": detection.reason,
|
||||
"raw_top_scores": raw_top,
|
||||
}
|
||||
)
|
||||
|
||||
summary = {
|
||||
"model_path": args.model_path,
|
||||
"test_path": args.test_path,
|
||||
"threshold": args.threshold,
|
||||
"top_k": args.top_k,
|
||||
"max_labels": args.max_labels,
|
||||
"sample_count": len(results),
|
||||
"metrics": compute_set_metrics(results),
|
||||
"per_category": summarize_by_category(results),
|
||||
"top_confusions": top_confusions(results),
|
||||
"error_examples": collect_error_examples(results),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
Path(args.result_path).write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
Path(args.report_path).write_text(render_report(summary), encoding="utf-8")
|
||||
print(json.dumps({"sample_count": len(results), "metrics": summary["metrics"]}, ensure_ascii=False))
|
||||
|
||||
|
||||
def render_report(summary: dict[str, object]) -> str:
|
||||
metrics = summary["metrics"]
|
||||
per_category = summary["per_category"]
|
||||
confusions = summary["top_confusions"]
|
||||
errors = summary["error_examples"]
|
||||
lines = [
|
||||
"# 本地多标签 Detector 独立评测报告",
|
||||
"",
|
||||
"## 概览",
|
||||
f"- 模型目录:`{summary['model_path']}`",
|
||||
f"- 评测集:`{summary['test_path']}`",
|
||||
f"- 样本数:`{summary['sample_count']}`",
|
||||
f"- 阈值 / top_k / max_labels:`{summary['threshold']} / {summary['top_k']} / {summary['max_labels']}`",
|
||||
f"- `micro_precision`:`{metrics['micro_precision']}`",
|
||||
f"- `micro_recall`:`{metrics['micro_recall']}`",
|
||||
f"- `micro_f1`:`{metrics['micro_f1']}`",
|
||||
f"- `exact_match`:`{metrics['exact_match']}`",
|
||||
f"- `multi_sentence_recall`:`{metrics['multi_sentence_recall']}`",
|
||||
f"- `single_guard_false_alarm_rate`:`{metrics['single_guard_false_alarm_rate']}`",
|
||||
"",
|
||||
"## 分类别结果",
|
||||
]
|
||||
for item in per_category:
|
||||
category_metrics = item["metrics"]
|
||||
lines.append(
|
||||
f"- `{item['category']}`: count={item['sample_count']} micro_f1={category_metrics['micro_f1']} exact_match={category_metrics['exact_match']}"
|
||||
)
|
||||
lines.extend(["", "## 主要混淆"])
|
||||
if not confusions:
|
||||
lines.append("- 未发现明显混淆对。")
|
||||
else:
|
||||
for item in confusions:
|
||||
lines.append(
|
||||
f"- 漏掉 `{item['expected_missing']}`,同时误报 `{item['wrong_extra']}`:`{item['count']}` 次"
|
||||
)
|
||||
lines.extend(["", "## 错误样例"])
|
||||
if not errors:
|
||||
lines.append("- 无错误样例。")
|
||||
else:
|
||||
for item in errors:
|
||||
lines.append(
|
||||
f"- 文本:`{item['text']}` | 类别:`{item['category']}` | 期望:`{item['expected_intent_ids']}` | 预测:`{item['predicted_intent_ids']}`"
|
||||
)
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## 结论建议",
|
||||
"- 先看多意图句是否存在系统性漏召回,再看单意图是否被误报成多意图。",
|
||||
"- 若 `single_guard_false_alarm_rate` 偏高,需要先收紧 detector 阈值或补单意图负样本,再考虑进入 NER。",
|
||||
"- 若 `multi_sentence_recall` 不稳定,应继续补条件句、弱连接句和口语化多动作语料。",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
97
intelligent_cabin/archive/scripts/test_local_bert_intent.py
Normal file
97
intelligent_cabin/archive/scripts/test_local_bert_intent.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
from app.services.router import build_matcher_pipeline
|
||||
|
||||
|
||||
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
|
||||
DEFAULT_LABEL_MAP = DEFAULT_MODEL_DIR / "label_map.json"
|
||||
|
||||
|
||||
def build_classifier(threshold: float, top_k: int) -> BertIntentClassifier:
|
||||
return BertIntentClassifier(
|
||||
model_path=str(DEFAULT_MODEL_DIR),
|
||||
threshold=threshold,
|
||||
label_map_path=str(DEFAULT_LABEL_MAP),
|
||||
fallback=None,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def predict_once(text: str, threshold: float, top_k: int) -> dict[str, object]:
|
||||
classifier = build_classifier(threshold=threshold, top_k=top_k)
|
||||
registry = build_intent_registry()
|
||||
intents = registry.list()
|
||||
result = classifier.predict(text, intents)
|
||||
matcher = build_matcher_pipeline(
|
||||
registry,
|
||||
["classifier"],
|
||||
classifier=classifier,
|
||||
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
|
||||
clarify_margin_threshold=settings.local_clarify_margin_threshold,
|
||||
)
|
||||
route_result = matcher.match(text)
|
||||
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
|
||||
return {
|
||||
"text": text,
|
||||
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
|
||||
"score": round(result.score, 4),
|
||||
"model_name": result.model_name,
|
||||
"backend": result.backend_name,
|
||||
"raw_label": result.raw_label,
|
||||
"fallback_reason": result.fallback_reason,
|
||||
"error_message": result.error_message,
|
||||
"decision": route_result.debug.decision,
|
||||
"decision_reason": route_result.debug.decision_reason,
|
||||
"confidence_grade": route_result.debug.confidence_grade,
|
||||
"unknown_detected": route_result.debug.unknown_detected,
|
||||
"fusion_top_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
|
||||
"top_candidates": [
|
||||
{"intent_id": intent.intent_id, "score": round(score, 4)}
|
||||
for intent, score in (result.candidates or [])
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def interactive_loop(threshold: float, top_k: int) -> None:
|
||||
print("本地 BERT 意图测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
|
||||
while True:
|
||||
try:
|
||||
text = input("\n请输入问题> ").strip()
|
||||
except EOFError:
|
||||
print()
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
if text.lower() in {"exit", "quit", "q"}:
|
||||
break
|
||||
result = predict_once(text, threshold=threshold, top_k=top_k)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="本地 BERT 意图识别测试脚本")
|
||||
parser.add_argument("--text", type=str, default="", help="单次测试文本")
|
||||
parser.add_argument("--threshold", type=float, default=0.0, help="BERT 置信度阈值")
|
||||
parser.add_argument("--top-k", type=int, default=3, help="返回候选数量")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.text.strip():
|
||||
print(json.dumps(predict_once(args.text.strip(), threshold=args.threshold, top_k=args.top_k), ensure_ascii=False, indent=2))
|
||||
return
|
||||
interactive_loop(threshold=args.threshold, top_k=args.top_k)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
500
intelligent_cabin/archive/scripts/train_joint_bert_nlu.py
Normal file
500
intelligent_cabin/archive/scripts/train_joint_bert_nlu.py
Normal file
@@ -0,0 +1,500 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoTokenizer
|
||||
import yaml
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.services.joint_nlu import JointBertForNLU
|
||||
|
||||
|
||||
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
MULTI_TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
|
||||
SEED_PATH = PROJECT_ROOT / "app/data/joint_nlu_seed.jsonl"
|
||||
EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_eval.jsonl"
|
||||
MULTI_EVAL_PATH = PROJECT_ROOT / "app/data/joint_nlu_multilabel_eval.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_joint_bert_nlu"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
MAX_LENGTH = 64
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 8
|
||||
LEARNING_RATE = 2e-5
|
||||
SEED = 42
|
||||
IGNORE_INDEX = -100
|
||||
GENRE_KEYWORDS = ("轻音乐", "摇滚", "古典", "民谣", "爵士", "流行", "儿歌")
|
||||
DEFAULT_INTENT_THRESHOLD = 0.3
|
||||
MULTI_INTENT_REPEAT = 6
|
||||
THRESHOLD_CANDIDATES = [0.1, 0.12, 0.15, 0.18, 0.2, 0.22, 0.25, 0.28, 0.3, 0.33, 0.35, 0.38, 0.4, 0.45]
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointSample:
|
||||
text: str
|
||||
intent_ids: list[str]
|
||||
slots: list[dict[str, object]]
|
||||
|
||||
|
||||
class JointDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
samples: list[JointSample],
|
||||
tokenizer,
|
||||
intent_to_index: dict[str, int],
|
||||
slot_to_index: dict[str, int],
|
||||
) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._intent_to_index = intent_to_index
|
||||
self._slot_to_index = slot_to_index
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
max_length=MAX_LENGTH,
|
||||
padding="max_length",
|
||||
return_offsets_mapping=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
offset_mapping = encoded.pop("offset_mapping")[0].tolist()
|
||||
slot_labels = [IGNORE_INDEX] * len(offset_mapping)
|
||||
char_labels = ["O"] * len(sample.text)
|
||||
for slot in sample.slots:
|
||||
start = int(slot["start"])
|
||||
end = int(slot["end"])
|
||||
slot_name = str(slot["slot_name"])
|
||||
if start < 0 or end > len(sample.text) or start >= end:
|
||||
continue
|
||||
char_labels[start] = f"B-{slot_name}"
|
||||
for pos in range(start + 1, end):
|
||||
char_labels[pos] = f"I-{slot_name}"
|
||||
|
||||
for token_index, (start, end) in enumerate(offset_mapping):
|
||||
if end <= start:
|
||||
continue
|
||||
label = char_labels[start]
|
||||
slot_labels[token_index] = self._slot_to_index.get(label, self._slot_to_index["O"])
|
||||
intent_vector = torch.zeros(len(self._intent_to_index), dtype=torch.float32)
|
||||
for intent_id in sample.intent_ids:
|
||||
if intent_id in self._intent_to_index:
|
||||
intent_vector[self._intent_to_index[intent_id]] = 1.0
|
||||
|
||||
return {
|
||||
"input_ids": encoded["input_ids"][0],
|
||||
"attention_mask": encoded["attention_mask"][0],
|
||||
"intent_labels": intent_vector,
|
||||
"slot_labels": torch.tensor(slot_labels, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def set_seed() -> None:
|
||||
random.seed(SEED)
|
||||
torch.manual_seed(SEED)
|
||||
|
||||
|
||||
def load_jsonl(path: Path) -> list[dict[str, object]]:
|
||||
rows: list[dict[str, object]] = []
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def find_order_id_span(text: str) -> tuple[str, int, int] | None:
|
||||
match = re.search(r"[A-Za-z]\d{5,}", text)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(0), match.start(), match.end()
|
||||
|
||||
|
||||
def find_temperature_span(text: str) -> tuple[str, int, int] | None:
|
||||
match = re.search(r"(\d{2}\s*度)", text)
|
||||
if not match:
|
||||
return None
|
||||
return match.group(1), match.start(), match.end()
|
||||
|
||||
|
||||
def find_destination_span(text: str) -> tuple[str, int, int] | None:
|
||||
for pattern in (
|
||||
r"导航去(?P<destination>.+)",
|
||||
r"导航到(?P<destination>.+)",
|
||||
r"带我去(?P<destination>.+)",
|
||||
r"送我去(?P<destination>.+)",
|
||||
r"去(?P<destination>.+)",
|
||||
):
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
continue
|
||||
destination = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", match.group("destination"), maxsplit=1)[0].strip(" ,。")
|
||||
if not destination:
|
||||
continue
|
||||
start = text.find(destination)
|
||||
if start >= 0:
|
||||
return destination, start, start + len(destination)
|
||||
return None
|
||||
|
||||
|
||||
def find_music_span(text: str) -> tuple[str, str, int, int] | None:
|
||||
for genre in GENRE_KEYWORDS:
|
||||
start = text.find(genre)
|
||||
if start >= 0:
|
||||
return "genre", genre, start, start + len(genre)
|
||||
for trigger in ("播放", "来点", "放点", "听", "来首", "来一首", "放一首"):
|
||||
if trigger not in text:
|
||||
continue
|
||||
target = text.split(trigger, maxsplit=1)[-1]
|
||||
target = re.split(r"(?:然后|并且|同时|再|,|,|;|;)", target, maxsplit=1)[0].strip(" 的一首首个歌曲音乐吧呀啊,。")
|
||||
if not target or target in {"歌", "音乐"}:
|
||||
continue
|
||||
for genre in GENRE_KEYWORDS:
|
||||
if genre in target:
|
||||
start = text.find(genre)
|
||||
return "genre", genre, start, start + len(genre)
|
||||
start = text.find(target)
|
||||
if start >= 0:
|
||||
return "song", target, start, start + len(target)
|
||||
return None
|
||||
|
||||
|
||||
def annotate_slots(text: str, intent_id: str) -> list[dict[str, object]]:
|
||||
slots: list[dict[str, object]] = []
|
||||
if intent_id in {"cs_query_order", "cs_query_logistics", "cs_cancel_order"}:
|
||||
matched = find_order_id_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "order_id", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_set_ac":
|
||||
matched = find_temperature_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "temperature", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_nav_to":
|
||||
matched = find_destination_span(text)
|
||||
if matched is not None:
|
||||
value, start, end = matched
|
||||
slots.append({"slot_name": "destination", "value": value, "start": start, "end": end})
|
||||
elif intent_id == "cabin_play_music":
|
||||
matched = find_music_span(text)
|
||||
if matched is not None:
|
||||
slot_name, value, start, end = matched
|
||||
slots.append({"slot_name": slot_name, "value": value, "start": start, "end": end})
|
||||
return slots
|
||||
|
||||
|
||||
def annotate_slots_for_intents(text: str, intent_ids: list[str]) -> list[dict[str, object]]:
|
||||
merged: list[dict[str, object]] = []
|
||||
seen: set[tuple[str, int, int]] = set()
|
||||
for intent_id in intent_ids:
|
||||
for slot in annotate_slots(text, intent_id):
|
||||
key = (str(slot["slot_name"]), int(slot["start"]), int(slot["end"]))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
merged.append(slot)
|
||||
merged.sort(key=lambda item: (int(item["start"]), int(item["end"])))
|
||||
return merged
|
||||
|
||||
|
||||
def build_train_samples() -> list[JointSample]:
|
||||
samples: list[JointSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
domain_data = yaml.safe_load(DOMAIN_PATH.read_text(encoding="utf-8")) or {}
|
||||
for intent in domain_data.get("intents", []):
|
||||
intent_id = str(intent.get("intent_id", "")).strip()
|
||||
if not intent_id:
|
||||
continue
|
||||
for text in list(intent.get("examples", [])) + list(intent.get("keywords", [])):
|
||||
text = str(text).strip()
|
||||
if not text:
|
||||
continue
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))
|
||||
for row in load_jsonl(TRAIN_PATH):
|
||||
text = str(row["text"])
|
||||
intent_id = str(row["intent_id"])
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=annotate_slots(text, intent_id)))
|
||||
for row in load_jsonl(SEED_PATH):
|
||||
text = str(row["text"])
|
||||
intent_id = str(row["intent_id"])
|
||||
key = (text, (intent_id,))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(JointSample(text=text, intent_ids=[intent_id], slots=list(row.get("slots", []))))
|
||||
for row in load_jsonl(MULTI_TRAIN_PATH):
|
||||
text = str(row["text"]).strip()
|
||||
intent_ids = sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()})
|
||||
if not text or not intent_ids:
|
||||
continue
|
||||
key = (text, tuple(intent_ids))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
slots = list(row.get("slots") or annotate_slots_for_intents(text, intent_ids))
|
||||
samples.append(JointSample(text=text, intent_ids=intent_ids, slots=slots))
|
||||
if len(intent_ids) >= 2:
|
||||
for _ in range(MULTI_INTENT_REPEAT - 1):
|
||||
samples.append(JointSample(text=text, intent_ids=intent_ids, slots=list(slots)))
|
||||
random.shuffle(samples)
|
||||
return samples
|
||||
|
||||
|
||||
def build_eval_samples() -> list[JointSample]:
|
||||
rows = load_jsonl(EVAL_PATH)
|
||||
samples = [
|
||||
JointSample(
|
||||
text=str(row["text"]),
|
||||
intent_ids=[str(row["intent_id"])],
|
||||
slots=list(row.get("slots", [])),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
if MULTI_EVAL_PATH.exists():
|
||||
for row in load_jsonl(MULTI_EVAL_PATH):
|
||||
samples.append(
|
||||
JointSample(
|
||||
text=str(row["text"]),
|
||||
intent_ids=sorted({str(item).strip() for item in row.get("intent_ids", []) if str(item).strip()}),
|
||||
slots=list(row.get("slots") or annotate_slots_for_intents(str(row["text"]), list(row.get("intent_ids", [])))),
|
||||
)
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
def build_slot_labels(samples: list[JointSample]) -> list[str]:
|
||||
slot_names = sorted({str(slot["slot_name"]) for sample in samples for slot in sample.slots})
|
||||
labels = ["O"]
|
||||
for name in slot_names:
|
||||
labels.append(f"B-{name}")
|
||||
labels.append(f"I-{name}")
|
||||
return labels
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
model: JointBertForNLU,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
intent_labels: list[str],
|
||||
slot_labels: list[str],
|
||||
threshold: float,
|
||||
) -> dict[str, float]:
|
||||
model.eval()
|
||||
intent_tp = 0
|
||||
intent_fp = 0
|
||||
intent_fn = 0
|
||||
single_intent_correct = 0
|
||||
single_intent_total = 0
|
||||
intent_exact_match = 0
|
||||
correct_slot_tokens = 0
|
||||
total_slot_tokens = 0
|
||||
exact_slot_samples = 0
|
||||
total_samples = 0
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()}
|
||||
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
|
||||
predicted_probs = torch.sigmoid(intent_logits)
|
||||
predicted_multi = predicted_probs >= threshold
|
||||
gold_multi = batch["intent_labels"] > 0.5
|
||||
intent_tp += int((predicted_multi & gold_multi).sum().item())
|
||||
intent_fp += int((predicted_multi & ~gold_multi).sum().item())
|
||||
intent_fn += int((~predicted_multi & gold_multi).sum().item())
|
||||
intent_exact_match += int((predicted_multi == gold_multi).all(dim=1).sum().item())
|
||||
top_predicted = torch.argmax(predicted_probs, dim=-1)
|
||||
gold_counts = gold_multi.sum(dim=-1)
|
||||
single_mask = gold_counts == 1
|
||||
if int(single_mask.sum().item()) > 0:
|
||||
gold_top = torch.argmax(gold_multi.float(), dim=-1)
|
||||
single_intent_correct += int((top_predicted[single_mask] == gold_top[single_mask]).sum().item())
|
||||
single_intent_total += int(single_mask.sum().item())
|
||||
|
||||
predicted_slots = torch.argmax(slot_logits, dim=-1)
|
||||
mask = batch["slot_labels"] != IGNORE_INDEX
|
||||
correct_slot_tokens += int(((predicted_slots == batch["slot_labels"]) & mask).sum().item())
|
||||
total_slot_tokens += int(mask.sum().item())
|
||||
|
||||
for index in range(batch["slot_labels"].size(0)):
|
||||
gold = batch["slot_labels"][index][mask[index]]
|
||||
pred = predicted_slots[index][mask[index]]
|
||||
exact_slot_samples += int(torch.equal(gold, pred))
|
||||
total_samples += 1
|
||||
precision = intent_tp / max(intent_tp + intent_fp, 1)
|
||||
recall = intent_tp / max(intent_tp + intent_fn, 1)
|
||||
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"intent_threshold": round(threshold, 4),
|
||||
"intent_micro_precision": round(precision, 4),
|
||||
"intent_micro_recall": round(recall, 4),
|
||||
"intent_micro_f1": round(f1, 4),
|
||||
"intent_exact_match": round(intent_exact_match / max(total_samples, 1), 4),
|
||||
"single_intent_top1_accuracy": round(single_intent_correct / max(single_intent_total, 1), 4),
|
||||
"slot_token_accuracy": round(correct_slot_tokens / max(total_slot_tokens, 1), 4),
|
||||
"slot_exact_match": round(exact_slot_samples / max(total_samples, 1), 4),
|
||||
"intent_label_count": float(len(intent_labels)),
|
||||
"slot_label_count": float(len(slot_labels)),
|
||||
}
|
||||
|
||||
|
||||
def search_best_threshold(
|
||||
model: JointBertForNLU,
|
||||
dataloader: DataLoader,
|
||||
device: torch.device,
|
||||
intent_labels: list[str],
|
||||
slot_labels: list[str],
|
||||
) -> dict[str, float]:
|
||||
best_metrics: dict[str, float] | None = None
|
||||
for threshold in THRESHOLD_CANDIDATES:
|
||||
metrics = compute_metrics(
|
||||
model,
|
||||
dataloader,
|
||||
device,
|
||||
intent_labels,
|
||||
slot_labels,
|
||||
threshold=threshold,
|
||||
)
|
||||
if best_metrics is None:
|
||||
best_metrics = metrics
|
||||
continue
|
||||
current_score = (metrics["intent_micro_f1"], metrics["intent_exact_match"], metrics["slot_exact_match"])
|
||||
best_score = (
|
||||
best_metrics["intent_micro_f1"],
|
||||
best_metrics["intent_exact_match"],
|
||||
best_metrics["slot_exact_match"],
|
||||
)
|
||||
if current_score > best_score:
|
||||
best_metrics = metrics
|
||||
assert best_metrics is not None
|
||||
return best_metrics
|
||||
|
||||
|
||||
def build_pos_weight(samples: list[JointSample], intent_labels: list[str]) -> torch.Tensor:
|
||||
positive_counts = {label: 0 for label in intent_labels}
|
||||
for sample in samples:
|
||||
sample_intents = set(sample.intent_ids)
|
||||
for label in intent_labels:
|
||||
if label in sample_intents:
|
||||
positive_counts[label] += 1
|
||||
total = max(len(samples), 1)
|
||||
weights: list[float] = []
|
||||
for label in intent_labels:
|
||||
positives = max(positive_counts[label], 1)
|
||||
negatives = max(total - positives, 1)
|
||||
weight = negatives / positives
|
||||
weights.append(min(max(weight, 1.0), 12.0))
|
||||
return torch.tensor(weights, dtype=torch.float32)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed()
|
||||
train_samples = build_train_samples()
|
||||
eval_samples = build_eval_samples()
|
||||
intent_labels = sorted({intent_id for sample in train_samples + eval_samples for intent_id in sample.intent_ids})
|
||||
slot_labels = build_slot_labels(train_samples + eval_samples)
|
||||
intent_to_index = {label: index for index, label in enumerate(intent_labels)}
|
||||
slot_to_index = {label: index for index, label in enumerate(slot_labels)}
|
||||
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_BASE_MODEL)
|
||||
train_dataset = JointDataset(train_samples, tokenizer, intent_to_index, slot_to_index)
|
||||
eval_dataset = JointDataset(eval_samples, tokenizer, intent_to_index, slot_to_index)
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model = JointBertForNLU(
|
||||
base_model_name=DEFAULT_BASE_MODEL,
|
||||
num_intents=len(intent_labels),
|
||||
num_slot_labels=len(slot_labels),
|
||||
)
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
pos_weight = build_pos_weight(train_samples, intent_labels).to(device)
|
||||
best_metrics: dict[str, float] | None = None
|
||||
best_state: dict[str, torch.Tensor] | None = None
|
||||
|
||||
for epoch in range(EPOCHS):
|
||||
model.train()
|
||||
epoch_loss = 0.0
|
||||
for batch in train_loader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()}
|
||||
optimizer.zero_grad()
|
||||
intent_logits, slot_logits = model(batch["input_ids"], batch["attention_mask"])
|
||||
intent_loss = torch.nn.functional.binary_cross_entropy_with_logits(
|
||||
intent_logits,
|
||||
batch["intent_labels"],
|
||||
pos_weight=pos_weight,
|
||||
)
|
||||
slot_loss = torch.nn.functional.cross_entropy(
|
||||
slot_logits.view(-1, slot_logits.size(-1)),
|
||||
batch["slot_labels"].view(-1),
|
||||
ignore_index=IGNORE_INDEX,
|
||||
)
|
||||
loss = intent_loss + slot_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
epoch_loss += float(loss.item())
|
||||
|
||||
metrics = search_best_threshold(model, eval_loader, device, intent_labels, slot_labels)
|
||||
metrics["train_loss"] = round(epoch_loss / max(len(train_loader), 1), 4)
|
||||
print(json.dumps({"epoch": epoch + 1, **metrics}, ensure_ascii=False))
|
||||
if best_metrics is None or metrics["intent_micro_f1"] > best_metrics["intent_micro_f1"]:
|
||||
best_metrics = metrics
|
||||
best_state = {key: value.detach().cpu() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is None or best_metrics is None:
|
||||
raise RuntimeError("joint nlu training did not produce a best checkpoint")
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
torch.save(best_state, OUTPUT_DIR / "model_state.pt")
|
||||
config = {
|
||||
"base_model_name": DEFAULT_BASE_MODEL,
|
||||
"intent_task": "multi_label",
|
||||
"intent_labels": intent_labels,
|
||||
"slot_labels": slot_labels,
|
||||
"max_length": MAX_LENGTH,
|
||||
"intent_threshold": float(best_metrics["intent_threshold"]),
|
||||
"multi_intent_threshold": float(best_metrics["intent_threshold"]),
|
||||
"max_multi_intents": 4,
|
||||
}
|
||||
(OUTPUT_DIR / "joint_nlu_config.json").write_text(json.dumps(config, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"train_size": len(train_samples),
|
||||
"eval_size": len(eval_samples),
|
||||
"metrics": best_metrics,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
684
intelligent_cabin/archive/scripts/train_local_bert_intent.py
Normal file
684
intelligent_cabin/archive/scripts/train_local_bert_intent.py
Normal file
@@ -0,0 +1,684 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
import yaml
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
TRAIN_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_intent"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
MAX_LENGTH = 48
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 16
|
||||
LEARNING_RATE = 2e-5
|
||||
SEED = 42
|
||||
|
||||
ORDER_IDS = ["A123456", "A700001", "A800002", "A900005", "A202501", "A808001"]
|
||||
DESTINATIONS = ["公司停车场", "浦东机场", "徐家汇", "虹桥机场", "最近的充电站", "南京东路"]
|
||||
TEMPERATURES = [18, 20, 21, 22, 23, 24, 26]
|
||||
SONGS = ["夜曲", "稻香", "青花瓷", "晴天", "告白气球"]
|
||||
GENRES = ["轻音乐", "摇滚", "古典音乐", "民谣", "爵士"]
|
||||
SOCIAL_LABEL = "__social__"
|
||||
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
|
||||
|
||||
TEMPLATES: dict[str, list[str]] = {
|
||||
"cs_query_order": [
|
||||
"查一下订单{order_id}现在什么状态",
|
||||
"我的订单{order_id}到哪一步了",
|
||||
"帮我看看{order_id}这个订单",
|
||||
"确认下{order_id}订单状态",
|
||||
"订单{order_id}现在处理到哪里",
|
||||
"看下{order_id}这单进度",
|
||||
"订单号{order_id}目前怎么样",
|
||||
"帮忙确认订单{order_id}",
|
||||
"订单{order_id}有结果了吗",
|
||||
"帮我追一下订单{order_id}",
|
||||
"订单{order_id}现在受理了吗",
|
||||
"看看{order_id}这单现在啥情况",
|
||||
"帮我查查{order_id}订单进展",
|
||||
"{order_id}这个订单处理好了没",
|
||||
"{order_id}这笔订单现在进展到哪了",
|
||||
],
|
||||
"cs_query_logistics": [
|
||||
"快递{order_id}到哪儿了",
|
||||
"帮我查{order_id}物流进度",
|
||||
"看看{order_id}配送状态",
|
||||
"订单{order_id}物流更新了吗",
|
||||
"查询{order_id}的快递信息",
|
||||
"我的{order_id}现在派送到哪了",
|
||||
"查一下{order_id}这单物流",
|
||||
"配送单{order_id}走到哪里了",
|
||||
"帮我看下{order_id}快递到没到",
|
||||
"物流单号{order_id}现在在哪",
|
||||
"订单{order_id}物流到哪一步了",
|
||||
"{order_id}这单现在派件了吗",
|
||||
"帮我追踪{order_id}运输轨迹",
|
||||
"{order_id}快件现在运到哪里了",
|
||||
"我想看{order_id}的配送更新",
|
||||
],
|
||||
"cs_cancel_order": [
|
||||
"帮我取消{order_id}这个订单",
|
||||
"{order_id}别要了给我撤销",
|
||||
"把订单{order_id}取消掉",
|
||||
"我不要{order_id}了",
|
||||
"撤销一下{order_id}订单",
|
||||
"订单{order_id}不要发了",
|
||||
"帮我把{order_id}退掉并取消",
|
||||
"把{order_id}这一单停掉",
|
||||
"{order_id}这单直接取消",
|
||||
"订单号{order_id}撤回一下",
|
||||
"订单{order_id}我不想要了",
|
||||
"{order_id}这笔订单先别发了",
|
||||
"把{order_id}这单给我撤单",
|
||||
"订单{order_id}停掉吧",
|
||||
"{order_id}这个快给我取消了",
|
||||
],
|
||||
"cs_transfer_human": [
|
||||
"我要找人工客服处理",
|
||||
"现在转人工",
|
||||
"麻烦给我接人工服务",
|
||||
"帮我呼叫真人客服",
|
||||
"别机器人了我要人工",
|
||||
"转真人客服",
|
||||
"我要人工坐席",
|
||||
"帮我接人工处理",
|
||||
"叫人工客服来",
|
||||
"直接给我转人工",
|
||||
"这个问题给我人工跟进",
|
||||
"安排真人客服接手",
|
||||
"机器人处理不了,转人工",
|
||||
"帮我叫个客服专员",
|
||||
"我要人工来处理这事",
|
||||
],
|
||||
"cabin_nav_to": [
|
||||
"导航到{destination}",
|
||||
"带我去{destination}",
|
||||
"我要去{destination}",
|
||||
"去{destination}",
|
||||
"开导航去{destination}",
|
||||
"帮我导航到{destination}",
|
||||
"送我去{destination}",
|
||||
"现在去{destination}",
|
||||
"带路到{destination}",
|
||||
"去一下{destination}",
|
||||
"规划路线去{destination}",
|
||||
"直接开去{destination}",
|
||||
"给我导到{destination}",
|
||||
"{destination}怎么走,导航一下",
|
||||
"出发去{destination}",
|
||||
],
|
||||
"cabin_set_ac": [
|
||||
"把空调设到{temperature}度",
|
||||
"车里温度调成{temperature}度",
|
||||
"冷气开到{temperature}度",
|
||||
"空调给我调到{temperature}度",
|
||||
"温度改成{temperature}度",
|
||||
"车内设成{temperature}度",
|
||||
"把温度打到{temperature}度",
|
||||
"空调调为{temperature}度",
|
||||
"帮我把车里调成{temperature}度",
|
||||
"冷风调到{temperature}度",
|
||||
"把车内温度设为{temperature}度",
|
||||
"空调温度改到{temperature}度",
|
||||
"冷气帮我调到{temperature}度",
|
||||
"舱内调成{temperature}度",
|
||||
"给我把温度定在{temperature}度",
|
||||
"把车里弄凉快点",
|
||||
"车里太热了,降一点",
|
||||
"把里面调凉快一点",
|
||||
"有点热,降温",
|
||||
"空调再冷一点",
|
||||
"车内温度低一点",
|
||||
"把里面弄暖和点",
|
||||
"车里太冷了,升一点温度",
|
||||
],
|
||||
"cabin_ac_on": [
|
||||
"把空调打开",
|
||||
"开一下冷气",
|
||||
"把冷风开起来",
|
||||
"车里热,空调开开",
|
||||
"打开制冷",
|
||||
"空调启动一下",
|
||||
],
|
||||
"cabin_window_open": [
|
||||
"把车窗打开",
|
||||
"开下窗",
|
||||
"窗户开一点",
|
||||
"帮我透透气",
|
||||
"车里太闷了,开下窗",
|
||||
"顺便开下车窗",
|
||||
"把窗户降一点",
|
||||
"把玻璃打开一点",
|
||||
],
|
||||
"cabin_window_close": [
|
||||
"把车窗关上",
|
||||
"窗户关一下",
|
||||
"把窗升起来",
|
||||
"外面太吵了,把窗关了",
|
||||
"把窗户关严",
|
||||
],
|
||||
"cabin_fan_down": [
|
||||
"风别这么大",
|
||||
"风小一点",
|
||||
"别吹这么猛",
|
||||
"把风量调小一点",
|
||||
"出风弱一点",
|
||||
],
|
||||
"cabin_fan_up": [
|
||||
"风再大一点",
|
||||
"把风量开大点",
|
||||
"出风强一点",
|
||||
"风不够,调大些",
|
||||
],
|
||||
"cabin_defog_front_on": [
|
||||
"前挡起雾了,除一下",
|
||||
"把前挡风玻璃雾气清掉",
|
||||
"前窗看不清了,开除雾",
|
||||
],
|
||||
"cabin_defog_rear_on": [
|
||||
"后挡有雾,开下除雾",
|
||||
"后玻璃起雾了,清一下",
|
||||
"后窗看不清了,除雾",
|
||||
],
|
||||
"cabin_play_music": [
|
||||
"播放一首{genre}",
|
||||
"来点{genre}",
|
||||
"我想听{genre}",
|
||||
"给我播点{genre}",
|
||||
"放一首{song}",
|
||||
"来一首{song}",
|
||||
"播放{song}",
|
||||
"放点音乐,来个{genre}",
|
||||
"我想听首{song}",
|
||||
"给我来点歌,放{song}",
|
||||
"随机放点{genre}",
|
||||
"帮我播首{song}",
|
||||
"来点适合开车听的{genre}",
|
||||
"打开音乐,放{song}",
|
||||
"给我放一些{genre}",
|
||||
"放点歌",
|
||||
"来首歌",
|
||||
"整点音乐",
|
||||
"车里放点歌",
|
||||
"来点能听的",
|
||||
],
|
||||
SOCIAL_LABEL: [
|
||||
"你好",
|
||||
"嗨",
|
||||
"哈喽",
|
||||
"早上好",
|
||||
"晚上好",
|
||||
"你叫什么名字",
|
||||
"你是谁",
|
||||
"你能做什么",
|
||||
"今天天气不错",
|
||||
"陪我聊聊天",
|
||||
],
|
||||
OUT_OF_SCOPE_LABEL: [
|
||||
"帮我点个外卖",
|
||||
"订一张去北京的机票",
|
||||
"帮我买杯咖啡",
|
||||
"给我订一家酒店",
|
||||
"人类诞生的意义是什么",
|
||||
"帮我写一份年终总结",
|
||||
"推荐一部电影",
|
||||
"讲个笑话",
|
||||
"帮我做一道数学题",
|
||||
"去美团叫个外卖",
|
||||
],
|
||||
}
|
||||
|
||||
INTENT_REPLACEMENTS: dict[str, list[tuple[str, str]]] = {
|
||||
"cs_query_order": [
|
||||
("订单", "这单"),
|
||||
("查一下", "看一下"),
|
||||
("帮我", "麻烦帮我"),
|
||||
("现在什么状态", "现在啥状态"),
|
||||
("处理到哪里", "进展到哪里"),
|
||||
],
|
||||
"cs_query_logistics": [
|
||||
("物流", "快递"),
|
||||
("快递", "配送"),
|
||||
("配送", "派送"),
|
||||
("帮我", "麻烦帮我"),
|
||||
("现在在哪", "现在到哪了"),
|
||||
],
|
||||
"cs_cancel_order": [
|
||||
("取消", "撤销"),
|
||||
("撤销", "撤单"),
|
||||
("订单", "这单"),
|
||||
("帮我", "麻烦帮我"),
|
||||
("不要发了", "别发了"),
|
||||
],
|
||||
"cs_transfer_human": [
|
||||
("人工客服", "真人客服"),
|
||||
("人工", "人工坐席"),
|
||||
("帮我", "麻烦帮我"),
|
||||
],
|
||||
"cabin_nav_to": [
|
||||
("导航", "带路"),
|
||||
("带我", "送我"),
|
||||
("去", "前往"),
|
||||
],
|
||||
"cabin_set_ac": [
|
||||
("空调", "车里温度"),
|
||||
("调到", "设到"),
|
||||
("温度", "车内温度"),
|
||||
("凉快点", "冷一点"),
|
||||
("暖和点", "热一点"),
|
||||
],
|
||||
"cabin_ac_on": [
|
||||
("空调", "冷气"),
|
||||
("打开", "开"),
|
||||
("冷风", "制冷"),
|
||||
],
|
||||
"cabin_window_open": [
|
||||
("车窗", "窗户"),
|
||||
("打开", "开"),
|
||||
("透透气", "通通风"),
|
||||
],
|
||||
"cabin_window_close": [
|
||||
("关上", "关掉"),
|
||||
("车窗", "窗户"),
|
||||
("关严", "关好"),
|
||||
],
|
||||
"cabin_fan_down": [
|
||||
("风量", "风"),
|
||||
("调小", "调低"),
|
||||
("别吹这么猛", "风小一点"),
|
||||
],
|
||||
"cabin_fan_up": [
|
||||
("风量", "风"),
|
||||
("调大", "调高"),
|
||||
],
|
||||
"cabin_defog_front_on": [
|
||||
("前挡", "前窗"),
|
||||
("除雾", "除一下"),
|
||||
],
|
||||
"cabin_defog_rear_on": [
|
||||
("后挡", "后窗"),
|
||||
("除雾", "清一下雾"),
|
||||
],
|
||||
"cabin_play_music": [
|
||||
("播放", "放"),
|
||||
("来点", "播点"),
|
||||
("我想听", "给我来点"),
|
||||
("放点歌", "来首歌"),
|
||||
],
|
||||
SOCIAL_LABEL: [
|
||||
("你好", "您好"),
|
||||
("哈喽", "hello"),
|
||||
("你叫什么名字", "怎么称呼你"),
|
||||
],
|
||||
OUT_OF_SCOPE_LABEL: [
|
||||
("点个外卖", "叫个外卖"),
|
||||
("订一家酒店", "订个酒店"),
|
||||
("讲个笑话", "说个笑话"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Sample:
|
||||
text: str
|
||||
intent_id: str
|
||||
|
||||
|
||||
HARD_NEGATIVE_RAW_SAMPLES: list[tuple[str, str]] = [
|
||||
("订单A700101物流到哪了", "cs_query_logistics"),
|
||||
("帮我看下订单A700102配送到哪里了", "cs_query_logistics"),
|
||||
("订单A700103现在派件了吗", "cs_query_logistics"),
|
||||
("A700104这单物流有没有更新", "cs_query_logistics"),
|
||||
("查一下订单A700105运输轨迹", "cs_query_logistics"),
|
||||
("订单A700106不要了,帮我撤单", "cs_cancel_order"),
|
||||
("A700107这单别发了,直接取消", "cs_cancel_order"),
|
||||
("把订单A700108停掉吧", "cs_cancel_order"),
|
||||
("A700109这个订单我不想要了", "cs_cancel_order"),
|
||||
("订单A700110给我撤回", "cs_cancel_order"),
|
||||
("订单A700111现在受理了吗", "cs_query_order"),
|
||||
("帮我看看A700112这单处理得怎么样了", "cs_query_order"),
|
||||
("A700113订单目前进展如何", "cs_query_order"),
|
||||
("查下订单A700114现在什么情况", "cs_query_order"),
|
||||
("帮我确认订单A700115是否已经处理", "cs_query_order"),
|
||||
("你好呀", SOCIAL_LABEL),
|
||||
("嗨,在吗", SOCIAL_LABEL),
|
||||
("今天天气真不错", SOCIAL_LABEL),
|
||||
("你叫什么名字呀", SOCIAL_LABEL),
|
||||
("你是做什么的", SOCIAL_LABEL),
|
||||
("陪我随便聊聊", SOCIAL_LABEL),
|
||||
("帮我透透气", "cabin_window_open"),
|
||||
("车里太闷了,开下窗", "cabin_window_open"),
|
||||
("把窗户降一点", "cabin_window_open"),
|
||||
("前挡起雾了,除一下", "cabin_defog_front_on"),
|
||||
("后挡有雾,开下除雾", "cabin_defog_rear_on"),
|
||||
("把车里弄凉快点", "cabin_set_ac"),
|
||||
("车里太热了,降一点", "cabin_set_ac"),
|
||||
("空调再冷一点", "cabin_set_ac"),
|
||||
("把里面弄暖和点", "cabin_set_ac"),
|
||||
("风别这么大", "cabin_fan_down"),
|
||||
("别吹这么猛", "cabin_fan_down"),
|
||||
("风再大一点", "cabin_fan_up"),
|
||||
("放点歌", "cabin_play_music"),
|
||||
("来首歌", "cabin_play_music"),
|
||||
("整点音乐", "cabin_play_music"),
|
||||
("透透气", "cabin_window_open"),
|
||||
("通通风", "cabin_window_open"),
|
||||
("车里太闷了", "cabin_window_open"),
|
||||
("凉快点", "cabin_set_ac"),
|
||||
("暖和点", "cabin_set_ac"),
|
||||
("帮我点一份麻辣烫", OUT_OF_SCOPE_LABEL),
|
||||
("给我订今晚的酒店", OUT_OF_SCOPE_LABEL),
|
||||
("帮我买张电影票", OUT_OF_SCOPE_LABEL),
|
||||
("人为什么会做梦", OUT_OF_SCOPE_LABEL),
|
||||
("帮我做个旅游攻略", OUT_OF_SCOPE_LABEL),
|
||||
("帮我点肯德基外卖", OUT_OF_SCOPE_LABEL),
|
||||
("透透气,别给我除雾", "cabin_window_open"),
|
||||
("后挡有雾,不是开窗,是除雾", "cabin_defog_rear_on"),
|
||||
("前挡看不清了,除雾不要开窗", "cabin_defog_front_on"),
|
||||
("凉快点,不是把风量调小", "cabin_set_ac"),
|
||||
("别吹这么猛,不是降温", "cabin_fan_down"),
|
||||
("来首歌,不是切下一首", "cabin_play_music"),
|
||||
("放点歌,不是暂停音乐", "cabin_play_music"),
|
||||
]
|
||||
|
||||
HARD_NEGATIVE_SAMPLES: list[Sample] = [
|
||||
Sample(text=text, intent_id=intent_id) for text, intent_id in HARD_NEGATIVE_RAW_SAMPLES
|
||||
]
|
||||
|
||||
|
||||
class IntentDataset(Dataset):
|
||||
def __init__(self, samples: list[Sample], tokenizer, label_to_id: dict[str, int]) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._label_to_id = label_to_id
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=MAX_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return {
|
||||
"input_ids": encoded["input_ids"].squeeze(0),
|
||||
"attention_mask": encoded["attention_mask"].squeeze(0),
|
||||
"labels": torch.tensor(self._label_to_id[sample.intent_id], dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
|
||||
def load_samples(file_path: Path) -> list[Sample]:
|
||||
samples: list[Sample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
samples.append(Sample(text=str(payload["text"]), intent_id=str(payload["intent_id"])))
|
||||
return samples
|
||||
|
||||
|
||||
def load_domain_samples(file_path: Path) -> list[Sample]:
|
||||
if not file_path.exists():
|
||||
return []
|
||||
payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
|
||||
intents = payload.get("intents", [])
|
||||
samples: list[Sample] = []
|
||||
seen: set[tuple[str, str]] = set()
|
||||
for item in intents:
|
||||
intent_id = str(item.get("intent_id") or "").strip()
|
||||
if not intent_id:
|
||||
continue
|
||||
seed_texts = list(item.get("examples") or [])
|
||||
seed_texts.extend(item.get("keywords") or [])
|
||||
label = str(item.get("label") or "").strip()
|
||||
if label:
|
||||
seed_texts.append(label)
|
||||
for text in seed_texts:
|
||||
normalized = str(text).strip()
|
||||
if not normalized:
|
||||
continue
|
||||
for variant in expand_seed_variants(normalized):
|
||||
key = (variant, intent_id)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(Sample(text=variant, intent_id=intent_id))
|
||||
return samples
|
||||
|
||||
|
||||
def expand_seed_variants(text: str) -> list[str]:
|
||||
normalized = text.strip().strip(",。!?;; ")
|
||||
if not normalized:
|
||||
return []
|
||||
variants = {
|
||||
normalized,
|
||||
normalized.replace("一下", "").strip(),
|
||||
normalized.replace("帮我", "").strip(),
|
||||
normalized.replace("请", "").strip(),
|
||||
f"帮我{normalized}",
|
||||
f"请{normalized}",
|
||||
f"{normalized}一下",
|
||||
}
|
||||
cleaned: list[str] = []
|
||||
for item in variants:
|
||||
compact = " ".join(item.split()).strip(",。!?;; ")
|
||||
if compact:
|
||||
cleaned.append(compact)
|
||||
return cleaned
|
||||
|
||||
|
||||
def load_training_samples() -> list[Sample]:
|
||||
samples = load_samples(TRAIN_PATH)
|
||||
samples.extend(load_domain_samples(DOMAIN_PATH))
|
||||
deduped: list[Sample] = []
|
||||
seen: set[tuple[str, str]] = set()
|
||||
for sample in samples:
|
||||
key = (sample.text, sample.intent_id)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(sample)
|
||||
return deduped
|
||||
|
||||
|
||||
def augment_samples(samples: list[Sample]) -> list[Sample]:
|
||||
augmented = list(samples)
|
||||
seen = {(sample.text, sample.intent_id) for sample in augmented}
|
||||
for intent_id, templates in TEMPLATES.items():
|
||||
for index, template in enumerate(templates):
|
||||
sample = render_template(intent_id, template, index)
|
||||
key = (sample.text, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(sample)
|
||||
seen.add(key)
|
||||
|
||||
for sample in HARD_NEGATIVE_SAMPLES:
|
||||
key = (sample.text, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(sample)
|
||||
seen.add(key)
|
||||
|
||||
for sample in list(augmented):
|
||||
text = sample.text
|
||||
for source, target in INTENT_REPLACEMENTS.get(sample.intent_id, []):
|
||||
if source in text:
|
||||
variant = text.replace(source, target, 1)
|
||||
key = (variant, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(Sample(text=variant, intent_id=sample.intent_id))
|
||||
seen.add(key)
|
||||
|
||||
compact = text
|
||||
for token in ("帮我", "麻烦", "请", "一下"):
|
||||
if token in compact:
|
||||
compact = compact.replace(token, "", 1)
|
||||
compact = compact.strip(" ,。!?")
|
||||
if compact and compact != text:
|
||||
key = (compact, sample.intent_id)
|
||||
if key not in seen:
|
||||
augmented.append(Sample(text=compact, intent_id=sample.intent_id))
|
||||
seen.add(key)
|
||||
random.shuffle(augmented)
|
||||
return augmented
|
||||
|
||||
|
||||
def render_template(intent_id: str, template: str, index: int) -> Sample:
|
||||
order_id = ORDER_IDS[index % len(ORDER_IDS)]
|
||||
destination = DESTINATIONS[index % len(DESTINATIONS)]
|
||||
temperature = TEMPERATURES[index % len(TEMPERATURES)]
|
||||
song = SONGS[index % len(SONGS)]
|
||||
genre = GENRES[index % len(GENRES)]
|
||||
text = template.format(
|
||||
order_id=order_id,
|
||||
destination=destination,
|
||||
temperature=temperature,
|
||||
song=song,
|
||||
genre=genre,
|
||||
)
|
||||
return Sample(text=text, intent_id=intent_id)
|
||||
|
||||
|
||||
def split_samples(samples: list[Sample]) -> tuple[list[Sample], list[Sample]]:
|
||||
grouped: dict[str, list[Sample]] = {}
|
||||
for sample in samples:
|
||||
grouped.setdefault(sample.intent_id, []).append(sample)
|
||||
train_samples: list[Sample] = []
|
||||
dev_samples: list[Sample] = []
|
||||
for items in grouped.values():
|
||||
random.shuffle(items)
|
||||
cut = max(1, int(len(items) * 0.8))
|
||||
train_samples.extend(items[:cut])
|
||||
dev_samples.extend(items[cut:])
|
||||
random.shuffle(train_samples)
|
||||
random.shuffle(dev_samples)
|
||||
return train_samples, dev_samples
|
||||
|
||||
|
||||
def accuracy(model, loader: DataLoader, device: torch.device) -> float:
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
preds = outputs.logits.argmax(dim=-1)
|
||||
correct += int((preds == labels).sum().item())
|
||||
total += int(labels.numel())
|
||||
return correct / total if total else 0.0
|
||||
|
||||
|
||||
def resolve_base_model() -> str:
|
||||
configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
return DEFAULT_BASE_MODEL
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed(SEED)
|
||||
samples = augment_samples(load_training_samples())
|
||||
intents = sorted({sample.intent_id for sample in samples})
|
||||
label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
|
||||
id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
|
||||
|
||||
train_samples, dev_samples = split_samples(samples)
|
||||
base_model = resolve_base_model()
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
train_dataset = IntentDataset(train_samples, tokenizer, label_to_id)
|
||||
dev_dataset = IntentDataset(dev_samples, tokenizer, label_to_id)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
base_model,
|
||||
num_labels=len(intents),
|
||||
id2label=id_to_label,
|
||||
label2id=label_to_id,
|
||||
)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
best_dev_acc = 0.0
|
||||
best_state = None
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += float(loss.item())
|
||||
|
||||
dev_acc = accuracy(model, dev_loader, device)
|
||||
avg_loss = total_loss / max(len(train_loader), 1)
|
||||
print(f"epoch={epoch} loss={avg_loss:.4f} dev_acc={dev_acc:.4f}")
|
||||
if dev_acc >= best_dev_acc:
|
||||
best_dev_acc = dev_acc
|
||||
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(OUTPUT_DIR)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
|
||||
(OUTPUT_DIR / "label_map.json").write_text(
|
||||
json.dumps(label_map, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
train_summary = {
|
||||
"base_model": base_model,
|
||||
"epochs": EPOCHS,
|
||||
"batch_size": BATCH_SIZE,
|
||||
"learning_rate": LEARNING_RATE,
|
||||
"train_size": len(train_samples),
|
||||
"dev_size": len(dev_samples),
|
||||
"best_dev_accuracy": round(best_dev_acc, 4),
|
||||
"device": str(device),
|
||||
}
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(train_summary, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
print(json.dumps(train_summary, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,415 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
import yaml
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
SINGLE_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_train.jsonl"
|
||||
MULTI_LABEL_PATH = PROJECT_ROOT / "app/data/bert_intent_multilabel_train.jsonl"
|
||||
DOMAIN_PATH = PROJECT_ROOT / "config/domain.yml"
|
||||
OUTPUT_DIR = PROJECT_ROOT / "models/local_bert_multi_intent"
|
||||
DEFAULT_BASE_MODEL = "hfl/chinese-macbert-base"
|
||||
SOCIAL_LABEL = "__social__"
|
||||
OUT_OF_SCOPE_LABEL = "__out_of_scope__"
|
||||
BLOCKED_LABELS = {SOCIAL_LABEL, OUT_OF_SCOPE_LABEL}
|
||||
MAX_LENGTH = 48
|
||||
BATCH_SIZE = 8
|
||||
EPOCHS = 12
|
||||
LEARNING_RATE = 2e-5
|
||||
THRESHOLD = 0.5
|
||||
TOP_K = 4
|
||||
SEED = 42
|
||||
|
||||
CONNECTOR_VARIANTS: tuple[tuple[str, str], ...] = (
|
||||
("并", "然后"),
|
||||
("然后", "并"),
|
||||
("顺便", "再"),
|
||||
("再", "顺便"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MultiLabelSample:
|
||||
text: str
|
||||
intent_ids: tuple[str, ...]
|
||||
|
||||
|
||||
class MultiLabelIntentDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
samples: list[MultiLabelSample],
|
||||
tokenizer,
|
||||
label_to_id: dict[str, int],
|
||||
) -> None:
|
||||
self._samples = samples
|
||||
self._tokenizer = tokenizer
|
||||
self._label_to_id = label_to_id
|
||||
self._label_size = len(label_to_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._samples)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
|
||||
sample = self._samples[index]
|
||||
encoded = self._tokenizer(
|
||||
sample.text,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=MAX_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = torch.zeros(self._label_size, dtype=torch.float32)
|
||||
for intent_id in sample.intent_ids:
|
||||
labels[self._label_to_id[intent_id]] = 1.0
|
||||
return {
|
||||
"input_ids": encoded["input_ids"].squeeze(0),
|
||||
"attention_mask": encoded["attention_mask"].squeeze(0),
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
|
||||
def resolve_base_model() -> str:
|
||||
configured = os.getenv("AGENT_BERT_BASE_MODEL", "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
return DEFAULT_BASE_MODEL
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
return " ".join(str(text).strip().split())
|
||||
|
||||
|
||||
def normalize_intent_ids(intent_ids: list[str] | tuple[str, ...]) -> tuple[str, ...]:
|
||||
cleaned = sorted(
|
||||
{
|
||||
str(intent_id).strip()
|
||||
for intent_id in intent_ids
|
||||
if str(intent_id).strip() and str(intent_id).strip() not in BLOCKED_LABELS
|
||||
}
|
||||
)
|
||||
return tuple(cleaned)
|
||||
|
||||
|
||||
def expand_single_label_variants(text: str) -> list[str]:
|
||||
normalized = text.strip().strip(",。!?;; ")
|
||||
if not normalized:
|
||||
return []
|
||||
variants = {
|
||||
normalized,
|
||||
normalized.replace("一下", "").strip(),
|
||||
normalized.replace("帮我", "").strip(),
|
||||
normalized.replace("请", "").strip(),
|
||||
f"帮我{normalized}",
|
||||
f"请{normalized}",
|
||||
f"{normalized}一下",
|
||||
}
|
||||
cleaned: list[str] = []
|
||||
for item in variants:
|
||||
compact = " ".join(item.split()).strip(",。!?;; ")
|
||||
if compact:
|
||||
cleaned.append(compact)
|
||||
return cleaned
|
||||
|
||||
|
||||
def load_single_label_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
samples: list[MultiLabelSample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
intent_ids = normalize_intent_ids([str(payload.get("intent_id") or "")])
|
||||
if not intent_ids:
|
||||
continue
|
||||
text = normalize_text(str(payload.get("text") or ""))
|
||||
if not text:
|
||||
continue
|
||||
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def load_domain_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
if not file_path.exists():
|
||||
return []
|
||||
payload = yaml.safe_load(file_path.read_text(encoding="utf-8")) or {}
|
||||
intents = payload.get("intents", [])
|
||||
samples: list[MultiLabelSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
for item in intents:
|
||||
intent_ids = normalize_intent_ids([str(item.get("intent_id") or "")])
|
||||
if not intent_ids:
|
||||
continue
|
||||
seed_texts = list(item.get("examples") or [])
|
||||
seed_texts.extend(item.get("keywords") or [])
|
||||
label = str(item.get("label") or "").strip()
|
||||
if label:
|
||||
seed_texts.append(label)
|
||||
for text in seed_texts:
|
||||
normalized = normalize_text(text)
|
||||
if not normalized:
|
||||
continue
|
||||
for variant in expand_single_label_variants(normalized):
|
||||
key = (variant, intent_ids)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
samples.append(MultiLabelSample(text=variant, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def load_multilabel_samples(file_path: Path) -> list[MultiLabelSample]:
|
||||
samples: list[MultiLabelSample] = []
|
||||
if not file_path.exists():
|
||||
return samples
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
payload = json.loads(line)
|
||||
intent_ids = normalize_intent_ids(list(payload.get("intent_ids") or []))
|
||||
if len(intent_ids) < 2:
|
||||
continue
|
||||
text = normalize_text(str(payload.get("text") or ""))
|
||||
if not text:
|
||||
continue
|
||||
samples.append(MultiLabelSample(text=text, intent_ids=intent_ids))
|
||||
return samples
|
||||
|
||||
|
||||
def augment_multilabel_samples(samples: list[MultiLabelSample]) -> list[MultiLabelSample]:
|
||||
augmented = list(samples)
|
||||
seen = {(sample.text, sample.intent_ids) for sample in augmented}
|
||||
for sample in list(samples):
|
||||
variants = {
|
||||
sample.text,
|
||||
f"帮我{sample.text}",
|
||||
f"请{sample.text}",
|
||||
sample.text.replace(",", ", "),
|
||||
sample.text.replace(",", ""),
|
||||
}
|
||||
for source, target in CONNECTOR_VARIANTS:
|
||||
if source in sample.text:
|
||||
variants.add(sample.text.replace(source, target, 1))
|
||||
for variant in variants:
|
||||
normalized = normalize_text(variant).strip(",。!?;; ")
|
||||
key = (normalized, sample.intent_ids)
|
||||
if normalized and key not in seen:
|
||||
augmented.append(MultiLabelSample(text=normalized, intent_ids=sample.intent_ids))
|
||||
seen.add(key)
|
||||
return augmented
|
||||
|
||||
|
||||
def load_all_samples() -> list[MultiLabelSample]:
|
||||
samples = load_single_label_samples(SINGLE_LABEL_PATH)
|
||||
samples.extend(load_domain_samples(DOMAIN_PATH))
|
||||
samples.extend(augment_multilabel_samples(load_multilabel_samples(MULTI_LABEL_PATH)))
|
||||
deduped: list[MultiLabelSample] = []
|
||||
seen: set[tuple[str, tuple[str, ...]]] = set()
|
||||
for sample in samples:
|
||||
key = (sample.text, sample.intent_ids)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(sample)
|
||||
random.shuffle(deduped)
|
||||
return deduped
|
||||
|
||||
|
||||
def split_samples(samples: list[MultiLabelSample]) -> tuple[list[MultiLabelSample], list[MultiLabelSample]]:
|
||||
grouped: dict[tuple[str, ...], list[MultiLabelSample]] = {}
|
||||
for sample in samples:
|
||||
grouped.setdefault(sample.intent_ids, []).append(sample)
|
||||
train_samples: list[MultiLabelSample] = []
|
||||
dev_samples: list[MultiLabelSample] = []
|
||||
for items in grouped.values():
|
||||
random.shuffle(items)
|
||||
if len(items) == 1:
|
||||
train_samples.extend(items)
|
||||
continue
|
||||
cut = max(1, int(len(items) * 0.8))
|
||||
if cut >= len(items):
|
||||
cut = len(items) - 1
|
||||
train_samples.extend(items[:cut])
|
||||
dev_samples.extend(items[cut:])
|
||||
if not dev_samples:
|
||||
dev_samples = train_samples[-max(1, min(32, len(train_samples) // 5 or 1)) :]
|
||||
train_samples = train_samples[: len(train_samples) - len(dev_samples)]
|
||||
random.shuffle(train_samples)
|
||||
random.shuffle(dev_samples)
|
||||
return train_samples, dev_samples
|
||||
|
||||
|
||||
def logits_to_probabilities(logits: torch.Tensor) -> list[list[float]]:
|
||||
return torch.sigmoid(logits).detach().cpu().tolist()
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
probabilities: list[list[float]],
|
||||
targets: list[list[float]],
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
recall_at_k_total = 0.0
|
||||
total = len(probabilities)
|
||||
for scores, target in zip(probabilities, targets):
|
||||
predicted = {index for index, score in enumerate(scores) if score >= threshold}
|
||||
expected = {index for index, value in enumerate(target) if value >= 0.5}
|
||||
if predicted == expected:
|
||||
exact_match += 1
|
||||
true_positive += len(predicted & expected)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
|
||||
if expected:
|
||||
recall_at_k_total += len(set(top_indices) & expected) / len(expected)
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def evaluate(model, loader: DataLoader, device: torch.device, threshold: float, top_k: int) -> tuple[float, dict[str, float]]:
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
probabilities: list[list[float]] = []
|
||||
targets: list[list[float]] = []
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
total_loss += float(outputs.loss.item())
|
||||
probabilities.extend(logits_to_probabilities(outputs.logits))
|
||||
targets.extend(labels.detach().cpu().tolist())
|
||||
avg_loss = total_loss / max(len(loader), 1)
|
||||
return avg_loss, compute_metrics(probabilities, targets, threshold=threshold, top_k=top_k)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
set_seed(SEED)
|
||||
samples = load_all_samples()
|
||||
intents = sorted({intent_id for sample in samples for intent_id in sample.intent_ids})
|
||||
label_to_id = {intent_id: index for index, intent_id in enumerate(intents)}
|
||||
id_to_label = {index: intent_id for intent_id, index in label_to_id.items()}
|
||||
|
||||
train_samples, dev_samples = split_samples(samples)
|
||||
base_model = resolve_base_model()
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
train_dataset = MultiLabelIntentDataset(train_samples, tokenizer, label_to_id)
|
||||
dev_dataset = MultiLabelIntentDataset(dev_samples, tokenizer, label_to_id)
|
||||
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
base_model,
|
||||
num_labels=len(intents),
|
||||
id2label=id_to_label,
|
||||
label2id=label_to_id,
|
||||
problem_type="multi_label_classification",
|
||||
)
|
||||
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
|
||||
|
||||
best_dev_f1 = 0.0
|
||||
best_state = None
|
||||
best_metrics: dict[str, float] = {}
|
||||
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
total_loss += float(loss.item())
|
||||
|
||||
dev_loss, dev_metrics = evaluate(model, dev_loader, device, threshold=THRESHOLD, top_k=TOP_K)
|
||||
avg_loss = total_loss / max(len(train_loader), 1)
|
||||
print(
|
||||
" ".join(
|
||||
[
|
||||
f"epoch={epoch}",
|
||||
f"train_loss={avg_loss:.4f}",
|
||||
f"dev_loss={dev_loss:.4f}",
|
||||
f"dev_micro_f1={dev_metrics['micro_f1']:.4f}",
|
||||
f"dev_exact_match={dev_metrics['exact_match']:.4f}",
|
||||
]
|
||||
)
|
||||
)
|
||||
if dev_metrics["micro_f1"] >= best_dev_f1:
|
||||
best_dev_f1 = dev_metrics["micro_f1"]
|
||||
best_metrics = dict(dev_metrics)
|
||||
best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
|
||||
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(OUTPUT_DIR)
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
label_map = {f"LABEL_{index}": intent_id for index, intent_id in id_to_label.items()}
|
||||
(OUTPUT_DIR / "label_map.json").write_text(
|
||||
json.dumps(label_map, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
train_summary = {
|
||||
"task_type": "multi_label_intent_detection",
|
||||
"base_model": base_model,
|
||||
"epochs": EPOCHS,
|
||||
"batch_size": BATCH_SIZE,
|
||||
"learning_rate": LEARNING_RATE,
|
||||
"threshold": THRESHOLD,
|
||||
"top_k": TOP_K,
|
||||
"train_size": len(train_samples),
|
||||
"dev_size": len(dev_samples),
|
||||
"label_count": len(intents),
|
||||
"labels": intents,
|
||||
"best_dev_metrics": best_metrics,
|
||||
"device": str(device),
|
||||
}
|
||||
(OUTPUT_DIR / "train_summary.json").write_text(
|
||||
json.dumps(train_summary, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
print(json.dumps(train_summary, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user