Update project and configurations
This commit is contained in:
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from scripts.train_local_bert_multi_intent import (
|
||||
BATCH_SIZE,
|
||||
OUTPUT_DIR,
|
||||
TOP_K,
|
||||
THRESHOLD,
|
||||
MultiLabelIntentDataset,
|
||||
load_all_samples,
|
||||
split_samples,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Evaluate local BERT multi-intent detector.")
|
||||
parser.add_argument("--model-path", default=str(OUTPUT_DIR), help="Path to trained multi-intent model.")
|
||||
parser.add_argument("--threshold", type=float, default=THRESHOLD, help="Probability threshold.")
|
||||
parser.add_argument("--top-k", type=int, default=TOP_K, help="Top-k for recall@k.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
choices=("dev", "all"),
|
||||
default="dev",
|
||||
help="Evaluate on the held-out dev split or all combined samples.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
probabilities: list[list[float]],
|
||||
targets: list[list[float]],
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> dict[str, float]:
|
||||
true_positive = 0
|
||||
false_positive = 0
|
||||
false_negative = 0
|
||||
exact_match = 0
|
||||
recall_at_k_total = 0.0
|
||||
total = len(probabilities)
|
||||
for scores, target in zip(probabilities, targets):
|
||||
predicted = {index for index, score in enumerate(scores) if score >= threshold}
|
||||
expected = {index for index, value in enumerate(target) if value >= 0.5}
|
||||
if predicted == expected:
|
||||
exact_match += 1
|
||||
true_positive += len(predicted & expected)
|
||||
false_positive += len(predicted - expected)
|
||||
false_negative += len(expected - predicted)
|
||||
top_indices = sorted(range(len(scores)), key=lambda index: scores[index], reverse=True)[:top_k]
|
||||
if expected:
|
||||
recall_at_k_total += len(set(top_indices) & expected) / len(expected)
|
||||
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) else 0.0
|
||||
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) else 0.0
|
||||
micro_f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
|
||||
return {
|
||||
"micro_precision": round(precision, 4),
|
||||
"micro_recall": round(recall, 4),
|
||||
"micro_f1": round(micro_f1, 4),
|
||||
"exact_match": round(exact_match / total, 4) if total else 0.0,
|
||||
"recall_at_k": round(recall_at_k_total / total, 4) if total else 0.0,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
set_seed(42)
|
||||
samples = load_all_samples()
|
||||
_, dev_samples = split_samples(samples)
|
||||
eval_samples = samples if args.dataset == "all" else dev_samples
|
||||
model_path = Path(args.model_path)
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"model path not found: {model_path}")
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
label_to_id = {str(label): int(index) for label, index in (model.config.label2id or {}).items()}
|
||||
if not label_to_id:
|
||||
raise RuntimeError("label2id is missing from model config")
|
||||
|
||||
dataset = MultiLabelIntentDataset(eval_samples, tokenizer, label_to_id)
|
||||
loader = DataLoader(dataset, batch_size=BATCH_SIZE)
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
probabilities: list[list[float]] = []
|
||||
targets: list[list[float]] = []
|
||||
with torch.no_grad():
|
||||
for batch in loader:
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
attention_mask = batch["attention_mask"].to(device)
|
||||
labels = batch["labels"].to(device)
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
probabilities.extend(torch.sigmoid(outputs.logits).detach().cpu().tolist())
|
||||
targets.extend(labels.detach().cpu().tolist())
|
||||
|
||||
metrics = compute_metrics(probabilities, targets, threshold=args.threshold, top_k=args.top_k)
|
||||
result = {
|
||||
"model_path": str(model_path),
|
||||
"dataset": args.dataset,
|
||||
"sample_size": len(eval_samples),
|
||||
"threshold": args.threshold,
|
||||
"top_k": args.top_k,
|
||||
"metrics": metrics,
|
||||
}
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user