124 lines
4.5 KiB
Python
124 lines
4.5 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
import sys
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
# Put the parent of this file's directory on the import path (presumably so
# the ``scripts`` package imported below resolves when the file is run
# directly). The entry is inserted at the front, and only if it is missing.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from scripts.train_local_bert_multi_intent import (
|
|
BATCH_SIZE,
|
|
OUTPUT_DIR,
|
|
TOP_K,
|
|
THRESHOLD,
|
|
MultiLabelIntentDataset,
|
|
load_all_samples,
|
|
split_samples,
|
|
set_seed,
|
|
)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Recognized options:
        --model-path: string, defaults to ``str(OUTPUT_DIR)``.
        --threshold: float, defaults to ``THRESHOLD``.
        --top-k: int, defaults to ``TOP_K``.
        --dataset: either ``"dev"`` or ``"all"``, defaults to ``"dev"``.

    Returns:
        The ``argparse.Namespace`` built from the process arguments.
    """
    cli = argparse.ArgumentParser(
        description="Evaluate local BERT multi-intent detector.",
    )
    cli.add_argument(
        "--model-path",
        default=str(OUTPUT_DIR),
        help="Path to trained multi-intent model.",
    )
    cli.add_argument(
        "--threshold",
        type=float,
        default=THRESHOLD,
        help="Probability threshold.",
    )
    cli.add_argument(
        "--top-k",
        type=int,
        default=TOP_K,
        help="Top-k for recall@k.",
    )
    cli.add_argument(
        "--dataset",
        choices=("dev", "all"),
        default="dev",
        help="Evaluate on the held-out dev split or all combined samples.",
    )
    return cli.parse_args()
|
|
|
|
|
|
def compute_metrics(
    probabilities: list[list[float]],
    targets: list[list[float]],
    threshold: float,
    top_k: int,
) -> dict[str, float]:
    """Compute micro-averaged multi-label metrics for a batch of predictions.

    For every sample, the predicted label set is the indices whose score is
    ``>= threshold`` and the expected label set is the indices whose target
    value is ``>= 0.5``.

    Args:
        probabilities: Per-sample list of per-label scores.
        targets: Per-sample list of per-label target values, aligned with
            ``probabilities``.
        threshold: Minimum score for a label to count as predicted.
        top_k: Number of highest-scoring labels considered for recall@k.

    Returns:
        A dict with ``micro_precision``, ``micro_recall``, ``micro_f1``,
        ``exact_match`` and ``recall_at_k``, each rounded to 4 decimals.
        All values are ``0.0`` for empty input. ``recall_at_k`` is averaged
        over *all* samples; a sample with no expected labels contributes 0.

    Raises:
        ValueError: If ``probabilities`` and ``targets`` have different
            lengths, or if ``top_k`` is negative.
    """
    # zip() would silently drop the surplus rows while the denominators below
    # still use len(probabilities), yielding quietly wrong metrics.
    if len(probabilities) != len(targets):
        raise ValueError(
            "probabilities and targets must have the same length: "
            f"{len(probabilities)} != {len(targets)}"
        )
    # A negative k would make the slice below mean "all but the last |k|".
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    total = len(probabilities)
    true_positive = 0
    false_positive = 0
    false_negative = 0
    exact_match = 0
    recall_at_k_total = 0.0

    for scores, target in zip(probabilities, targets):
        predicted = {index for index, score in enumerate(scores) if score >= threshold}
        expected = {index for index, value in enumerate(target) if value >= 0.5}

        if predicted == expected:
            exact_match += 1
        true_positive += len(predicted & expected)
        false_positive += len(predicted - expected)
        false_negative += len(expected - predicted)

        # Ranking is only needed when there is something to recall.
        if expected:
            ranked = sorted(
                range(len(scores)),
                key=lambda index: scores[index],
                reverse=True,
            )
            recall_at_k_total += len(expected.intersection(ranked[:top_k])) / len(expected)

    predicted_count = true_positive + false_positive
    expected_count = true_positive + false_negative
    precision = true_positive / predicted_count if predicted_count else 0.0
    recall = true_positive / expected_count if expected_count else 0.0
    micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0

    return {
        "micro_precision": round(precision, 4),
        "micro_recall": round(recall, 4),
        "micro_f1": round(micro_f1, 4),
        "exact_match": round(exact_match / total, 4) if total else 0.0,
        "recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
    }
|
|
|
|
|
|
def main() -> None:
    """Evaluate a trained multi-intent model and print a JSON report.

    Steps: parse the CLI options, verify the model path, load the samples and
    pick the evaluation set (``"all"`` samples or the dev split), load model
    and tokenizer from ``--model-path``, run sigmoid inference batch by batch,
    and print the metrics from ``compute_metrics`` together with the run
    settings as indented JSON on stdout.

    Raises:
        FileNotFoundError: If ``--model-path`` does not exist.
        RuntimeError: If the loaded model config carries no ``label2id``.
    """
    args = parse_args()

    # Fail fast on a bad path before spending time loading the dataset.
    model_path = Path(args.model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"model path not found: {model_path}")

    # Fixed seed; presumably keeps the train/dev split reproducible
    # (split_samples is defined in the training script -- confirm there).
    set_seed(42)
    samples = load_all_samples()
    _, dev_samples = split_samples(samples)
    eval_samples = samples if args.dataset == "all" else dev_samples

    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Normalize the label mapping stored in the model config to str -> int.
    label_to_id = {str(label): int(index) for label, index in (model.config.label2id or {}).items()}
    if not label_to_id:
        raise RuntimeError("label2id is missing from model config")

    dataset = MultiLabelIntentDataset(eval_samples, tokenizer, label_to_id)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    model.eval()

    probabilities: list[list[float]] = []
    targets: list[list[float]] = []
    with torch.no_grad():
        for batch in loader:
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            probabilities.extend(torch.sigmoid(outputs.logits).detach().cpu().tolist())
            # Labels are only consumed on the host, so they are not shipped to
            # the accelerator first; .cpu() is a no-op for host tensors.
            targets.extend(batch["labels"].detach().cpu().tolist())

    metrics = compute_metrics(probabilities, targets, threshold=args.threshold, top_k=args.top_k)
    result = {
        "model_path": str(model_path),
        "dataset": args.dataset,
        "sample_size": len(eval_samples),
        "threshold": args.threshold,
        "top_k": args.top_k,
        "metrics": metrics,
    }
    print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|