from __future__ import annotations import argparse import json from pathlib import Path import sys import torch from torch.utils.data import DataLoader from transformers import AutoModelForSequenceClassification, AutoTokenizer PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from scripts.train_local_bert_multi_intent import ( BATCH_SIZE, OUTPUT_DIR, TOP_K, THRESHOLD, MultiLabelIntentDataset, load_all_samples, split_samples, set_seed, ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Evaluate local BERT multi-intent detector.") parser.add_argument("--model-path", default=str(OUTPUT_DIR), help="Path to trained multi-intent model.") parser.add_argument("--threshold", type=float, default=THRESHOLD, help="Probability threshold.") parser.add_argument("--top-k", type=int, default=TOP_K, help="Top-k for recall@k.") parser.add_argument( "--dataset", choices=("dev", "all"), default="dev", help="Evaluate on the held-out dev split or all combined samples.", ) return parser.parse_args() def compute_metrics( probabilities: list[list[float]], targets: list[list[float]], threshold: float, top_k: int, ) -> dict[str, float]: true_positive = 0 false_positive = 0 false_negative = 0 exact_match = 0 recall_at_k_total = 0.0 total = len(probabilities) for scores, target in zip(probabilities, targets): predicted = {index for index, score in enumerate(scores) if score >= threshold} expected = {index for index, value in enumerate(target) if value >= 0.5} if predicted == expected: exact_match += 1 true_positive += len(predicted & expected) false_positive += len(predicted - expected) false_negative += len(expected - predicted) top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k] if expected: recall_at_k_total += len(set(top_indices) & expected) / len(expected) precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0 recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0 micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0 return { "micro_precision": round(precision, 4), "micro_recall": round(recall, 4), "micro_f1": round(micro_f1, 4), "exact_match": round(exact_match / total, 4) if total else 0.0, "recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0, } def main() -> None: args = parse_args() set_seed(42) samples = load_all_samples() _, dev_samples = split_samples(samples) eval_samples = samples if args.dataset == "all" else dev_samples model_path = Path(args.model_path) if not model_path.exists(): raise FileNotFoundError(f"model path not found: {model_path}") model = AutoModelForSequenceClassification.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path) label_to_id = {str(label): int(index) for label, index in (model.config.label2id or {}).items()} if not label_to_id: raise RuntimeError("label2id is missing from model config") dataset = MultiLabelIntentDataset(eval_samples, tokenizer, label_to_id) loader = DataLoader(dataset, batch_size=BATCH_SIZE) device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") model.to(device) model.eval() probabilities: list[list[float]] = [] targets: list[list[float]] = [] with torch.no_grad(): for batch in loader: input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) labels = batch["labels"].to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask) probabilities.extend(torch.sigmoid(outputs.logits).detach().cpu().tolist()) targets.extend(labels.detach().cpu().tolist()) metrics = compute_metrics(probabilities, targets, threshold=args.threshold, top_k=args.top_k) result = { "model_path": str(model_path), "dataset": args.dataset, "sample_size": len(eval_samples), "threshold": args.threshold, "top_k": args.top_k, "metrics": metrics, } print(json.dumps(result, ensure_ascii=False, indent=2)) if __name__ == "__main__": main()