Files
ai-device/intelligent_cabin/archive/scripts/eval_local_bert_intent.py
2026-06-11 16:28:00 +08:00

232 lines
8.5 KiB
Python

from __future__ import annotations
import argparse
import json
from collections import Counter, defaultdict
from pathlib import Path
import sys
# Directory one level above the folder that contains this script. It is put at
# the front of sys.path (once) so that the `app.*` imports below can be
# resolved when the file is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier
# Default evaluation set (JSONL); can be overridden with --test-path.
TEST_PATH = PROJECT_ROOT / "app/data/bert_intent_eval_independent.jsonl"
# Directory of the local model; label_map.json and train_summary.json are
# looked up inside it.
MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
# Directory holding the default report outputs.
REPORT_DIR = PROJECT_ROOT / "reports"
# Default Markdown report path; can be overridden with --report-path.
REPORT_PATH = REPORT_DIR / "bert_local_test_report.md"
# Default structured (JSON) result path; can be overridden with --result-path.
RESULT_PATH = REPORT_DIR / "bert_local_test_result.json"
# Passed as `threshold` to BertIntentClassifier and echoed in the report.
# NOTE(review): 0.0 presumably disables confidence-based rejection so the raw
# top prediction is what gets scored - confirm against the classifier.
BERT_THRESHOLD = 0.0
# Optional training summary; a missing file is treated as an empty summary.
TRAIN_SUMMARY_PATH = MODEL_DIR / "train_summary.json"
def load_cases(file_path: Path) -> list[dict[str, str]]:
    """Load evaluation cases from a JSONL file.

    Every non-blank line is parsed as one JSON object. The expected label is
    taken from ``expected_label`` or, failing that, ``intent_id``; records
    that provide neither are skipped. The category is the record's
    ``category`` when it is present and non-empty, otherwise the result of
    ``infer_category`` for the expected label.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        One dict per usable record with the keys ``text``,
        ``expected_label`` and ``category`` (all strings).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a labelled record has no ``text`` field.
    """
    cases: list[dict[str, str]] = []
    # Iterate the file object instead of calling ``str.splitlines()``: the
    # latter also breaks on U+2028, U+2029 and U+0085, which may appear
    # unescaped inside a JSON string and would cut a valid record in half.
    with file_path.open(encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            payload = json.loads(line)
            expected_label = str(
                payload.get("expected_label") or payload.get("intent_id") or ""
            ).strip()
            if not expected_label:
                # Unlabelled records cannot be scored.
                continue
            category = str(
                payload.get("category") or infer_category(expected_label)
            ).strip()
            cases.append(
                {
                    "text": str(payload["text"]),
                    "expected_label": expected_label,
                    "category": category,
                }
            )
    return cases
def load_train_summary(file_path: Path) -> dict[str, object]:
    """Return the parsed training summary, or an empty dict if the file is absent.

    Args:
        file_path: Location of the UTF-8 encoded JSON summary file.

    Returns:
        The decoded JSON content, or ``{}`` when ``file_path`` does not exist.
    """
    if file_path.exists():
        raw_summary = file_path.read_text(encoding="utf-8")
        return json.loads(raw_summary)
    return {}
def infer_category(label: str) -> str:
    """Derive the evaluation category from an intent label.

    The two reserved labels ``__social__`` and ``__out_of_scope__`` map to
    ``social`` and ``out_of_scope``; every other label counts as ``business``.
    """
    reserved_categories = {
        "__social__": "social",
        "__out_of_scope__": "out_of_scope",
    }
    return reserved_categories.get(label, "business")
def resolve_predicted_label(result) -> str:
    """Pick the label to compare against the expected label.

    Preference order: the ``intent_id`` of ``result.intent`` when an intent
    is set, then the stringified ``result.raw_label`` when it is truthy, and
    finally the literal string ``"None"``.
    """
    matched_intent = result.intent
    if matched_intent is None:
        raw = result.raw_label
        return str(raw) if raw else "None"
    return matched_intent.intent_id
def main() -> None:
    """Evaluate the local BERT intent classifier and write the reports.

    Parses ``--test-path``, ``--result-path`` and ``--report-path`` from the
    command line, runs every loaded case through
    ``BertIntentClassifier.predict``, and aggregates overall, per-category
    and per-label accuracy together with an expected -> predicted confusion
    table. The structured summary is written as JSON, the rendered report as
    Markdown, and a one-line JSON digest (accuracy, case count, error count)
    is printed to stdout.
    """
    parser = argparse.ArgumentParser(description="本地 BERT 独立评测脚本")
    parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="评测集路径")
    parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化评测结果输出路径")
    parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 评测报告输出路径")
    args = parser.parse_args()

    intent_registry = build_intent_registry()
    intents = intent_registry.list()
    classifier = BertIntentClassifier(
        model_path=str(MODEL_DIR),
        threshold=BERT_THRESHOLD,
        label_map_path=str(MODEL_DIR / "label_map.json"),
        fallback=None,
        top_k=3,
    )
    cases = load_cases(Path(args.test_path))

    results: list[dict[str, object]] = []
    confusion: dict[str, Counter[str]] = defaultdict(Counter)
    # Totals and hits are tallied while predicting so the per-label and
    # per-category tables do not need to rescan `results` for every key.
    label_totals: Counter[str] = Counter()
    label_hits: Counter[str] = Counter()
    category_totals: Counter[str] = Counter()
    category_hits: Counter[str] = Counter()

    for case in cases:
        result = classifier.predict(case["text"], intents)
        predicted = resolve_predicted_label(result)
        expected = case["expected_label"]
        category = case["category"]
        ok = predicted == expected

        confusion[expected][predicted] += 1
        label_totals[expected] += 1
        category_totals[category] += 1
        if ok:
            label_hits[expected] += 1
            category_hits[category] += 1

        results.append(
            {
                "text": case["text"],
                "category": category,
                "expected_label": expected,
                "predicted_label": predicted,
                "score": round(result.score, 4),
                "raw_label": result.raw_label,
                "ok": ok,
                "top_candidates": [
                    {"intent_id": intent.intent_id, "score": round(score, 4)}
                    for intent, score in (result.candidates or [])
                ],
            }
        )

    correct = sum(label_hits.values())
    # An empty test set yields 0.0 instead of dividing by zero.
    accuracy = correct / len(cases) if cases else 0.0
    train_summary = load_train_summary(TRAIN_SUMMARY_PATH)

    # Every key in the *_totals counters has a count of at least one, so the
    # divisions below are safe.
    per_label_stats: list[dict[str, object]] = [
        {
            "label": label,
            "category": infer_category(label),
            "total": label_totals[label],
            "correct": label_hits[label],
            "accuracy": round(label_hits[label] / label_totals[label], 4),
        }
        for label in sorted(label_totals)
    ]
    per_category_stats: list[dict[str, object]] = [
        {
            "category": category,
            "total": category_totals[category],
            "correct": category_hits[category],
            "accuracy": round(category_hits[category] / category_totals[category], 4),
        }
        for category in sorted(category_totals)
    ]

    errors = [item for item in results if not item["ok"]]
    summary = {
        "model_dir": str(MODEL_DIR),
        "threshold": BERT_THRESHOLD,
        "test_path": str(args.test_path),
        "test_case_count": len(cases),
        "accuracy": round(accuracy, 4),
        "train_summary": train_summary,
        "per_category": per_category_stats,
        "per_label": per_label_stats,
        "errors": errors,
        "confusion": {key: dict(value) for key, value in confusion.items()},
    }

    result_path = Path(args.result_path)
    report_path = Path(args.report_path)
    # Create the directories of the paths actually written to, so custom
    # --result-path / --report-path values in not-yet-existing folders work.
    # With the default paths this creates REPORT_DIR.
    for output_path in (result_path, report_path):
        output_path.parent.mkdir(parents=True, exist_ok=True)

    result_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    report_path.write_text(render_report(summary), encoding="utf-8")
    print(json.dumps({"accuracy": summary["accuracy"], "test_case_count": len(cases), "error_count": len(errors)}, ensure_ascii=False))
def render_report(summary: dict[str, object]) -> str:
per_category = summary["per_category"]
per_label = summary["per_label"]
errors = summary["errors"]
train_summary = summary.get("train_summary") or {}
lines = [
"# 本地 BERT 意图识别测试报告",
"",
"## 概览",
f"- 模型目录:`{summary['model_dir']}`",
f"- 评测集:`{summary['test_path']}`",
f"- 评测阈值:`{summary['threshold']}`",
f"- 测试样本数:`{summary['test_case_count']}`",
f"- 总体准确率:`{summary['accuracy']}`",
"",
"## 训练摘要",
]
if train_summary:
lines.extend(
[
f"- 基座模型:`{train_summary.get('base_model', 'unknown')}`",
f"- 训练集 / 验证集:`{train_summary.get('train_size', 'unknown')} / {train_summary.get('dev_size', 'unknown')}`",
f"- 最佳验证准确率:`{train_summary.get('best_dev_accuracy', 'unknown')}`",
f"- 训练设备:`{train_summary.get('device', 'unknown')}`",
"",
]
)
else:
lines.extend(["- 未找到训练摘要。", ""])
lines.extend(
[
"## 分类别结果",
]
)
for item in per_category:
lines.append(
f"- `{item['category']}`: {item['correct']}/{item['total']} = {item['accuracy']}"
)
lines.extend(["", "## 分标签结果"])
for item in per_label:
lines.append(
f"- `{item['label']}` ({item['category']}): {item['correct']}/{item['total']} = {item['accuracy']}"
)
lines.extend(["", "## 错误样例"])
if not errors:
lines.append("- 无错误样例。")
else:
for item in errors[:10]:
lines.append(
f"- 文本:`{item['text']}` | 类别:`{item['category']}` | 期望:`{item['expected_label']}` | 预测:`{item['predicted_label']}` | 分数:`{item['score']}`"
)
lines.extend(
[
"",
"## 结论",
"- 当前本地 MacBERT 已具备较强的业务意图识别能力,可作为本地快链路分类器。",
"- 误判主要集中在方向相反或语义接近的控制指令,下一步应补充对抗样本和真实口语表达。",
"- 上线前建议继续补充 ASR 错字、多轮短句和多意图子句级样本。",
"",
]
)
return "\n".join(lines)
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()