276 lines
12 KiB
Python
276 lines
12 KiB
Python
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import sys
|
||
from collections import Counter, defaultdict
|
||
from pathlib import Path
|
||
|
||
# Repository root: the grandparent directory of this (resolved) script file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Put the repository root at the front of the import path so the ``app``
# package below can be imported when this file is executed as a script.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.bootstrap import build_intent_registry
from app.services.joint_nlu import JointBertNLU


# Default input locations (each can be overridden on the command line).
TEST_PATH = PROJECT_ROOT.joinpath("app", "data", "joint_nlu_eval_independent.jsonl")
MODEL_DIR = PROJECT_ROOT.joinpath("models", "local_joint_bert_nlu")

# Default output locations for the structured result and the Markdown report.
REPORT_DIR = PROJECT_ROOT.joinpath("reports")
RESULT_PATH = REPORT_DIR.joinpath("joint_nlu_independent_result.json")
REPORT_PATH = REPORT_DIR.joinpath("joint_nlu_independent_report.md")

# Training summary file looked up inside the default model directory.
TRAIN_SUMMARY_PATH = MODEL_DIR.joinpath("train_summary.json")
|
||
|
||
|
||
def load_cases(file_path: Path) -> list[dict[str, object]]:
    """Read evaluation cases from a UTF-8 JSONL file.

    Blank lines are ignored. A record is also ignored when neither its
    ``expected_intent_id`` nor its ``intent_id`` yields a non-empty string
    after stripping whitespace.

    Args:
        file_path: Location of the JSONL file, one JSON object per line.

    Returns:
        One dict per kept record with the keys ``text``,
        ``expected_intent_id``, ``expected_slots`` and ``category``
        (``"unknown"`` when the record has no truthy category).

    Raises:
        KeyError: If a kept record has no ``text`` field.
    """
    raw_lines = file_path.read_text(encoding="utf-8").splitlines()
    non_blank = (stripped for stripped in (raw.strip() for raw in raw_lines) if stripped)

    parsed: list[dict[str, object]] = []
    for entry in non_blank:
        record = json.loads(entry)
        # ``expected_intent_id`` wins; ``intent_id`` is the fallback field.
        intent_label = record.get("expected_intent_id") or record.get("intent_id") or ""
        intent_label = str(intent_label).strip()
        if intent_label:
            parsed.append(
                {
                    "text": str(record["text"]),
                    "expected_intent_id": intent_label,
                    "expected_slots": dict(record.get("expected_slots") or {}),
                    "category": str(record.get("category") or "unknown"),
                }
            )
    return parsed
|
||
|
||
|
||
def load_train_summary(file_path: Path) -> dict[str, object]:
    """Return the parsed JSON training summary, or ``{}`` if the file is absent.

    Args:
        file_path: Location of the UTF-8 encoded JSON summary file.

    Returns:
        The decoded JSON content when the path exists, otherwise an empty dict.
    """
    if file_path.exists():
        return json.loads(file_path.read_text(encoding="utf-8"))
    return {}
|
||
|
||
|
||
def compare_slots(expected: dict[str, object], predicted: dict[str, object]) -> dict[str, object]:
    """Diff the predicted slot mapping against the expected one.

    Args:
        expected: Reference slot name -> value mapping.
        predicted: Model slot name -> value mapping.

    Returns:
        A dict with:
        ``missing_keys`` - sorted slot names only present in ``expected``;
        ``extra_keys`` - sorted slot names only present in ``predicted``;
        ``wrong_values`` - one entry per shared slot whose values differ;
        ``matched_keys`` - number of shared slots whose values are equal;
        ``exact`` - True when the three lists above are all empty.
    """
    want_names, got_names = set(expected), set(predicted)

    hits = 0
    mismatches: list[dict[str, object]] = []
    # Walk the shared slot names in sorted order so the diff is deterministic.
    for name in sorted(want_names & got_names):
        want_value, got_value = expected[name], predicted[name]
        if want_value == got_value:
            hits += 1
            continue
        mismatches.append({"slot_name": name, "expected": want_value, "predicted": got_value})

    only_expected = sorted(want_names - got_names)
    only_predicted = sorted(got_names - want_names)
    return {
        "missing_keys": only_expected,
        "extra_keys": only_predicted,
        "wrong_values": mismatches,
        "matched_keys": hits,
        "exact": not (only_expected or only_predicted or mismatches),
    }
|
||
|
||
|
||
def compute_metrics(results: list[dict[str, object]]) -> dict[str, float]:
    """Aggregate intent, slot and joint metrics over evaluated samples.

    Slot precision/recall/F1 are micro-averaged over slot names: a shared slot
    with an equal value is a true positive; a shared slot with a different
    value counts as both a false positive and a false negative; predicted-only
    and expected-only slots are false positives and false negatives.

    Args:
        results: Per-sample records carrying ``intent_ok``, ``slot_exact``,
            ``joint_ok``, ``expected_slots`` and ``predicted_slots``.

    Returns:
        Six metrics rounded to 4 decimals. Every ratio whose denominator is
        zero (including all of them for an empty input) is reported as 0.0.
    """
    sample_total = len(results)
    flag_hits = {
        flag: sum(1 for row in results if row[flag])
        for flag in ("intent_ok", "slot_exact", "joint_ok")
    }

    true_pos = false_pos = false_neg = 0
    for row in results:
        gold = dict(row["expected_slots"])
        guess = dict(row["predicted_slots"])
        gold_names, guess_names = set(gold), set(guess)
        overlap = gold_names & guess_names
        agreeing = sum(1 for name in overlap if gold[name] == guess[name])
        clashing = sum(1 for name in overlap if gold[name] != guess[name])
        true_pos += agreeing
        # A wrong value is penalised on both the precision and the recall side.
        false_pos += len(guess_names - gold_names) + clashing
        false_neg += len(gold_names - guess_names) + clashing

    predicted_total = true_pos + false_pos
    expected_total = true_pos + false_neg
    precision = true_pos / predicted_total if predicted_total else 0.0
    recall = true_pos / expected_total if expected_total else 0.0
    if precision + recall:
        slot_f1 = 2 * precision * recall / (precision + recall)
    else:
        slot_f1 = 0.0

    def _share(hit_count: int) -> float:
        # Fraction of samples, guarded against an empty result list.
        return round(hit_count / sample_total, 4) if sample_total else 0.0

    return {
        "intent_accuracy": _share(flag_hits["intent_ok"]),
        "slot_exact_match": _share(flag_hits["slot_exact"]),
        "joint_exact_match": _share(flag_hits["joint_ok"]),
        "slot_micro_precision": round(precision, 4),
        "slot_micro_recall": round(recall, 4),
        "slot_micro_f1": round(slot_f1, 4),
    }
|
||
|
||
|
||
def summarize_by_category(results: list[dict[str, object]]) -> list[dict[str, object]]:
    """Compute metrics separately for every sample category.

    Args:
        results: Per-sample records; each must carry a ``category`` field in
            addition to the fields ``compute_metrics`` reads.

    Returns:
        One entry per category, ordered by category name, holding the
        ``category``, its ``sample_count`` and its ``metrics``.
    """
    buckets: dict[str, list[dict[str, object]]] = {}
    for row in results:
        buckets.setdefault(str(row["category"]), []).append(row)

    return [
        {
            "category": name,
            "sample_count": len(members),
            "metrics": compute_metrics(members),
        }
        for name, members in sorted(buckets.items())
    ]
|
||
|
||
|
||
def collect_top_confusions(results: list[dict[str, object]], limit: int = 12) -> list[dict[str, object]]:
    """Rank the most frequent (expected, predicted) intent mix-ups.

    Args:
        results: Per-sample records with ``intent_ok``, ``expected_intent_id``
            and ``predicted_intent_id``.
        limit: Maximum number of confusion pairs to return.

    Returns:
        Entries of the form ``{"expected", "predicted", "count"}`` for samples
        whose intent was wrong, most frequent pair first.
    """
    wrong_pairs = (
        (str(row["expected_intent_id"]), str(row["predicted_intent_id"]))
        for row in results
        if not row["intent_ok"]
    )
    ranking = Counter(wrong_pairs).most_common(limit)
    return [
        {"expected": gold, "predicted": guess, "count": occurrences}
        for (gold, guess), occurrences in ranking
    ]
|
||
|
||
|
||
def collect_failures(results: list[dict[str, object]], limit: int = 20) -> list[dict[str, object]]:
    """Pick the worst samples that failed the joint (intent + slots) check.

    Ordering, most severe first: wrong intent before correct intent, then more
    slot errors (missing + extra + wrong-valued), then longer text. Samples
    that tie on all three keep their input order.

    Args:
        results: Per-sample records with ``joint_ok``, ``intent_ok``,
            ``slot_diff`` and ``text``.
        limit: Maximum number of failures to return.

    Returns:
        At most ``limit`` failing records in the order described above.
    """

    def _severity(row: dict[str, object]) -> tuple[int, int, int]:
        diff = row["slot_diff"]
        slot_error_total = sum(
            len(diff[field]) for field in ("missing_keys", "extra_keys", "wrong_values")
        )
        intent_penalty = 0 if row["intent_ok"] else 1
        # Components are negated so a plain ascending sort yields "worst first".
        return (-intent_penalty, -slot_error_total, -len(str(row["text"])))

    failing = [row for row in results if not row["joint_ok"]]
    failing.sort(key=_severity)
    return failing[:limit]
|
||
|
||
|
||
def main() -> None:
    """Run the independent Joint NLU evaluation from the command line.

    Loads the evaluation cases, runs the model on each one, aggregates the
    metrics, then writes a structured JSON result and a Markdown report and
    prints a one-line JSON summary to stdout.

    Differences from the previous behaviour:
    * the training summary is read from the model directory actually being
      evaluated (``--model-path``) instead of always from the default one;
    * the parent directories of ``--result-path`` and ``--report-path`` are
      created, so custom output locations no longer fail on a missing folder.
    """
    parser = argparse.ArgumentParser(description="Joint NLU 独立评测与失败样例回放")
    parser.add_argument("--test-path", type=str, default=str(TEST_PATH), help="评测集路径")
    parser.add_argument("--model-path", type=str, default=str(MODEL_DIR), help="Joint NLU 模型路径")
    parser.add_argument("--result-path", type=str, default=str(RESULT_PATH), help="结构化结果输出路径")
    parser.add_argument("--report-path", type=str, default=str(REPORT_PATH), help="Markdown 报告输出路径")
    args = parser.parse_args()

    result_path = Path(args.result_path)
    report_path = Path(args.report_path)
    # With the default --model-path this equals TRAIN_SUMMARY_PATH.
    train_summary_path = Path(args.model_path) / "train_summary.json"

    cases = load_cases(Path(args.test_path))
    registry = build_intent_registry()
    nlu = JointBertNLU(model_path=args.model_path)

    results: list[dict[str, object]] = []
    for case in cases:
        prediction = nlu.predict(str(case["text"]), registry.list())
        predicted_slots = dict(prediction.slots)
        slot_diff = compare_slots(dict(case["expected_slots"]), predicted_slots)
        # A missing prediction is recorded as the literal string "None".
        predicted_intent_id = prediction.intent_id or "None"
        intent_ok = predicted_intent_id == case["expected_intent_id"]
        slot_exact = bool(slot_diff["exact"])
        results.append(
            {
                "text": case["text"],
                "category": case["category"],
                "expected_intent_id": case["expected_intent_id"],
                "predicted_intent_id": predicted_intent_id,
                "expected_slots": case["expected_slots"],
                "predicted_slots": predicted_slots,
                "intent_score": round(prediction.intent_score, 4),
                "intent_ok": intent_ok,
                "slot_exact": slot_exact,
                "joint_ok": intent_ok and slot_exact,
                "slot_diff": slot_diff,
                "top_candidates": [
                    {"intent_id": item.intent_id, "score": round(item.score, 4)}
                    for item in prediction.candidates
                ],
            }
        )

    summary = {
        "model_path": args.model_path,
        "test_path": args.test_path,
        "sample_count": len(results),
        "metrics": compute_metrics(results),
        "per_category": summarize_by_category(results),
        "top_confusions": collect_top_confusions(results),
        "failure_examples": collect_failures(results),
        "train_summary": load_train_summary(train_summary_path),
        "results": results,
    }

    # Create the directories of the paths we are about to write, which may
    # differ from the default REPORT_DIR when the CLI overrides are used.
    for output_path in (result_path, report_path):
        output_path.parent.mkdir(parents=True, exist_ok=True)

    result_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    report_path.write_text(render_report(summary), encoding="utf-8")
    print(json.dumps({"sample_count": summary["sample_count"], "metrics": summary["metrics"]}, ensure_ascii=False))
|
||
|
||
|
||
def render_report(summary: dict[str, object]) -> str:
    """Render the evaluation summary as a Markdown report.

    Sections, in order: overview, training summary, per-category results,
    top intent confusions, failure replay and a fixed conclusions block.

    Args:
        summary: Must provide ``model_path``, ``test_path``, ``sample_count``,
            ``metrics``, ``per_category``, ``top_confusions`` and
            ``failure_examples``; ``train_summary`` is optional.

    Returns:
        The report text, lines joined with ``"\\n"`` and ending in a newline.
    """
    overall = summary["metrics"]
    category_rows = summary["per_category"]
    confusion_rows = summary["top_confusions"]
    failure_rows = summary["failure_examples"]
    training = summary.get("train_summary") or {}

    # --- Overview -----------------------------------------------------------
    report = [
        "# Joint NLU 独立评测报告",
        "",
        "## 概览",
        f"- 模型目录:`{summary['model_path']}`",
        f"- 评测集:`{summary['test_path']}`",
        f"- 样本数:`{summary['sample_count']}`",
    ]
    metric_names = (
        "intent_accuracy",
        "slot_exact_match",
        "joint_exact_match",
        "slot_micro_precision",
        "slot_micro_recall",
        "slot_micro_f1",
    )
    report.extend(f"- `{name}`:`{overall[name]}`" for name in metric_names)

    # --- Training summary ---------------------------------------------------
    report += ["", "## 训练摘要"]
    if not training:
        report += ["- 未找到训练摘要。", ""]
    else:
        train_size = training.get("train_size", "unknown")
        eval_size = training.get("eval_size", "unknown")
        train_metrics = training.get("metrics", {})
        report += [
            f"- 训练集 / 评测集:`{train_size} / {eval_size}`",
            f"- 训练阶段 `intent_accuracy`:`{train_metrics.get('intent_accuracy', 'unknown')}`",
            f"- 训练阶段 `slot_exact_match`:`{train_metrics.get('slot_exact_match', 'unknown')}`",
            "",
        ]

    # --- Per-category results -----------------------------------------------
    report.append("## 分类别结果")
    for row in category_rows:
        scores = row["metrics"]
        report.append(
            f"- `{row['category']}`: count={row['sample_count']}"
            f" intent_acc={scores['intent_accuracy']}"
            f" slot_exact={scores['slot_exact_match']}"
            f" joint_exact={scores['joint_exact_match']}"
        )

    # --- Top intent confusions ----------------------------------------------
    report += ["", "## 主要意图混淆"]
    if confusion_rows:
        report.extend(
            f"- 期望 `{row['expected']}`,预测成 `{row['predicted']}`:`{row['count']}` 次"
            for row in confusion_rows
        )
    else:
        report.append("- 未发现意图混淆。")

    # --- Failure replay -----------------------------------------------------
    report += ["", "## 失败样例回放"]
    if failure_rows:
        for row in failure_rows:
            diff = row["slot_diff"]
            cells = [
                f"- 文本:`{row['text']}`",
                f"类别:`{row['category']}`",
                f"期望意图:`{row['expected_intent_id']}`",
                f"预测意图:`{row['predicted_intent_id']}`",
                f"期望槽位:`{row['expected_slots']}`",
                f"预测槽位:`{row['predicted_slots']}`",
                f"缺失槽位:`{diff['missing_keys']}`",
                f"多出槽位:`{diff['extra_keys']}`",
            ]
            report.append(" | ".join(cells))
    else:
        report.append("- 无失败样例。")

    # --- Conclusions (static guidance text) ---------------------------------
    report += [
        "",
        "## 结论",
        "- 先看 `failure_replay` 是否仍然错,能直接判断先前多意图失败到底是联合模型本体问题还是上层组合问题。",
        "- 若 `slot_music` 或 `slot_destination` 仍不稳,优先补 span 标注,不要回退到规则抽槽。",
        "- 若 `no_slot_control` 很稳但 `failure_replay` 中仍有大量错误,下一步应补长尾控制语义数据,而不是急着上更复杂结构。",
        "",
    ]
    return "\n".join(report)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|