Files
ai-device/intelligent_cabin/archive/scripts/eval_local_bert_multi_intent.py
2026-06-11 16:28:00 +08:00

124 lines
4.5 KiB
Python

from __future__ import annotations
import argparse
import json
from pathlib import Path
import sys
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Directory two levels above this file; it is put on sys.path so that the
# `scripts.*` import that follows resolves when this file is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    # Prepend so this checkout takes priority over other `scripts` packages.
    sys.path.insert(0, _project_root_entry)
from scripts.train_local_bert_multi_intent import (
BATCH_SIZE,
OUTPUT_DIR,
TOP_K,
THRESHOLD,
MultiLabelIntentDataset,
load_all_samples,
split_samples,
set_seed,
)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for an evaluation run.

    Returns:
        Namespace with ``model_path``, ``threshold``, ``top_k`` and
        ``dataset`` attributes. The first three default to ``OUTPUT_DIR``,
        ``THRESHOLD`` and ``TOP_K`` imported from the training script;
        ``dataset`` defaults to ``"dev"``.
    """
    cli = argparse.ArgumentParser(description="Evaluate local BERT multi-intent detector.")

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs = (
        (
            "--model-path",
            {"default": str(OUTPUT_DIR), "help": "Path to trained multi-intent model."},
        ),
        (
            "--threshold",
            {"type": float, "default": THRESHOLD, "help": "Probability threshold."},
        ),
        (
            "--top-k",
            {"type": int, "default": TOP_K, "help": "Top-k for recall@k."},
        ),
        (
            "--dataset",
            {
                "choices": ("dev", "all"),
                "default": "dev",
                "help": "Evaluate on the held-out dev split or all combined samples.",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    namespace = cli.parse_args()
    return namespace
def compute_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    """Compute multi-label metrics from per-sample scores and gold vectors.

    A label is *predicted* when its score is ``>= threshold`` and *expected*
    when its target value is ``>= 0.5``.

    Args:
        probabilities: One list of per-label scores per sample.
        targets: One list of per-label gold values per sample, aligned with
            ``probabilities``.
        threshold: Minimum score for a label to count as predicted.
        top_k: Number of highest-scoring labels considered for recall@k.

    Returns:
        Dict with ``micro_precision``, ``micro_recall``, ``micro_f1``,
        ``exact_match`` and ``recall_at_k``, each rounded to 4 decimals.
        Every ratio is ``0.0`` when its denominator is zero.

    Raises:
        ValueError: If ``probabilities`` and ``targets`` differ in length, or
            if ``top_k`` is negative.
    """
    total = len(probabilities)
    if total != len(targets):
        # zip() would silently drop the surplus rows and skew every ratio.
        raise ValueError(
            f"probabilities and targets differ in length: {total} != {len(targets)}"
        )
    if top_k < 0:
        # A negative slice bound would mean "all but the last |k| labels".
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    true_positive = 0
    false_positive = 0
    false_negative = 0
    exact_match = 0
    recall_at_k_total = 0.0
    labelled_samples = 0  # samples with at least one expected label

    for scores, target in zip(probabilities, targets):
        predicted = {index for index, score in enumerate(scores) if score >= threshold}
        expected = {index for index, value in enumerate(target) if value >= 0.5}

        if predicted == expected:
            exact_match += 1
        true_positive += len(predicted & expected)
        false_positive += len(predicted - expected)
        false_negative += len(expected - predicted)

        if not expected:
            # Recall is undefined without gold labels; leave the sample out of
            # the recall@k average instead of scoring it as zero.
            continue
        labelled_samples += 1
        # Stable sort: ties keep the lower label index first.
        top_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:top_k]
        recall_at_k_total += len(expected.intersection(top_indices)) / len(expected)

    predicted_total = true_positive + false_positive
    expected_total = true_positive + false_negative
    precision = true_positive / predicted_total if predicted_total else 0.0
    recall = true_positive / expected_total if expected_total else 0.0
    micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    recall_at_k = recall_at_k_total / labelled_samples if labelled_samples else 0.0

    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": round(exact_match / total, 4) if total else 0.0,
        "recall_at_k": round(recall_at_k, 4),
    }
def _select_device() -> torch.device:
    """Return the device to evaluate on: CUDA if available, else MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _collect_predictions(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> tuple[list[list[float]], list[list[float]]]:
    """Run the model over ``loader`` and return (sigmoid scores, gold labels)."""
    probabilities: list[list[float]] = []
    targets: list[list[float]] = []
    with torch.no_grad():
        for batch in loader:
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            probabilities.extend(torch.sigmoid(outputs.logits).detach().cpu().tolist())
            # Labels are only consumed on the host, so they are never copied
            # to the accelerator.
            targets.extend(batch["labels"].detach().cpu().tolist())
    return probabilities, targets


def main() -> None:
    """Evaluate a trained multi-intent model and print a JSON report.

    Steps: parse CLI options, check the model path, seed with 42, load and
    split the samples, load model and tokenizer, score every evaluation
    sample, and print the metrics from ``compute_metrics`` as indented JSON.

    Raises:
        FileNotFoundError: If ``--model-path`` does not exist.
        RuntimeError: If the model config carries no ``label2id`` mapping.
    """
    args = parse_args()

    # Fail fast on a bad path before spending time loading the dataset.
    model_path = Path(args.model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"model path not found: {model_path}")

    # Seed before loading/splitting; presumably this keeps the dev split
    # identical to the one used in training -- confirm against split_samples.
    set_seed(42)
    samples = load_all_samples()
    _, dev_samples = split_samples(samples)
    eval_samples = samples if args.dataset == "all" else dev_samples

    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    label_to_id = {str(label): int(index) for label, index in (model.config.label2id or {}).items()}
    if not label_to_id:
        raise RuntimeError("label2id is missing from model config")

    dataset = MultiLabelIntentDataset(eval_samples, tokenizer, label_to_id)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)

    device = _select_device()
    model.to(device)
    model.eval()
    probabilities, targets = _collect_predictions(model, loader, device)

    metrics = compute_metrics(probabilities, targets, threshold=args.threshold, top_k=args.top_k)
    result = {
        "model_path": str(model_path),
        "dataset": args.dataset,
        "sample_size": len(eval_samples),
        "threshold": args.threshold,
        "top_k": args.top_k,
        "metrics": metrics,
    }
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()