Files
ai-device/intelligent_cabin/archive/scripts/test_local_bert_intent.py
2026-06-11 16:28:00 +08:00

98 lines
3.6 KiB
Python

from __future__ import annotations

import argparse
import functools
import json
import sys
from pathlib import Path
def find_project_root(script_path: Path) -> Path:
    """Return the nearest ancestor of *script_path* that contains an ``app`` directory.

    The search starts at ``script_path.parents[1]``, the directory this script
    originally assumed to be the project root, and walks upwards from there.
    If that directory already contains ``app/``, the result is unchanged from the
    original hard-coded ``parents[1]``. If the script has been moved deeper into
    the tree, the enclosing directory that really holds ``app/`` is found instead.
    For example, this script now lives under ``archive/scripts/``, where
    ``parents[1]`` is ``archive/``.
    NOTE(review): confirm where the ``app`` package lives relative to archive/.

    Falls back to ``script_path.parents[1]`` when no ancestor contains ``app/``.
    """
    default_root = script_path.parents[1]
    for candidate in (default_root, *default_root.parents):
        if (candidate / "app").is_dir():
            return candidate
    return default_root


PROJECT_ROOT = find_project_root(Path(__file__).resolve())
# Put the project root first on sys.path so ``app.*`` imports resolve when the
# file is executed directly as a script.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from app.core.config import settings
from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier
from app.services.router import build_matcher_pipeline
# Fine-tuned local BERT intent model directory, resolved relative to the project root.
DEFAULT_MODEL_DIR = PROJECT_ROOT.joinpath("models", "local_bert_intent")
# Label map expected inside the model directory.
DEFAULT_LABEL_MAP = DEFAULT_MODEL_DIR.joinpath("label_map.json")
@functools.lru_cache(maxsize=8)
def build_classifier(threshold: float, top_k: int) -> BertIntentClassifier:
    """Return a local BERT intent classifier configured with *threshold* and *top_k*.

    The classifier reads its weights from ``DEFAULT_MODEL_DIR`` and its label map
    from ``DEFAULT_LABEL_MAP``. It is built with ``fallback=None``, so no
    secondary classifier is configured.

    Results are memoised per ``(threshold, top_k)``. ``predict_once`` calls this
    function for every utterance, and without the cache the interactive loop
    would construct a fresh classifier, repeating any model loading, on each
    input line. The cache is bounded (``maxsize=8``) because arbitrary float
    thresholds would otherwise grow it without limit. Exceptions raised during
    construction are not cached.

    Args:
        threshold: Confidence threshold passed to ``BertIntentClassifier``.
        top_k: Number of candidates the classifier should return.

    Returns:
        A (possibly shared, cached) ``BertIntentClassifier`` instance.
    """
    return BertIntentClassifier(
        model_path=str(DEFAULT_MODEL_DIR),
        threshold=threshold,
        label_map_path=str(DEFAULT_LABEL_MAP),
        fallback=None,
        top_k=top_k,
    )
def predict_once(text: str, threshold: float, top_k: int) -> dict[str, object]:
    """Classify *text* once and summarise both the raw prediction and the routing outcome.

    The steps are:

    1. Obtain a classifier via ``build_classifier`` and an intent registry via
       ``build_intent_registry``.
    2. Let the classifier score *text* against every registered intent.
    3. Route *text* through a matcher pipeline that has only the
       ``"classifier"`` stage and uses the local routing thresholds from
       ``settings``.

    Args:
        text: Utterance to classify.
        threshold: Confidence threshold passed to ``build_classifier``.
        top_k: Candidate count passed to ``build_classifier``.

    Returns:
        A JSON-serialisable dict whose key order is the order the fields are
        printed in:

        * Scores are rounded to 4 decimal places.
        * ``predicted_intent`` is ``None`` when the classifier returned no intent.
        * ``fusion_top_score`` is ``None`` when the route debug info records no
          ``"fusion"`` stage.
    """
    classifier = build_classifier(threshold=threshold, top_k=top_k)
    registry = build_intent_registry()
    prediction = classifier.predict(text, registry.list())

    pipeline = build_matcher_pipeline(
        registry,
        ["classifier"],
        classifier=classifier,
        route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
        clarify_margin_threshold=settings.local_clarify_margin_threshold,
    )
    routed = pipeline.match(text)
    debug = routed.debug

    # Take the most recently recorded "fusion" stage, if any.
    fusion_stage = None
    for stage in reversed(debug.stages):
        if stage.stage == "fusion":
            fusion_stage = stage
            break

    report: dict[str, object] = {
        "text": text,
        "predicted_intent": None if prediction.intent is None else prediction.intent.intent_id,
        "score": round(prediction.score, 4),
        "model_name": prediction.model_name,
        "backend": prediction.backend_name,
        "raw_label": prediction.raw_label,
        "fallback_reason": prediction.fallback_reason,
        "error_message": prediction.error_message,
        "decision": debug.decision,
        "decision_reason": debug.decision_reason,
        "confidence_grade": debug.confidence_grade,
        "unknown_detected": debug.unknown_detected,
        "fusion_top_score": None if fusion_stage is None else round(fusion_stage.score, 4),
    }
    # Candidates may be missing; treat that as an empty list.
    report["top_candidates"] = [
        {"intent_id": candidate.intent_id, "score": round(candidate_score, 4)}
        for candidate, candidate_score in (prediction.candidates or [])
    ]
    return report
def interactive_loop(threshold: float, top_k: int) -> None:
    """Run a read-eval-print loop that classifies each line the user types.

    Input handling:

    * Blank lines are ignored.
    * ``exit``, ``quit`` or ``q`` (case-insensitive) ends the session.
    * End of input (Ctrl+D) or Ctrl+C at the prompt also ends the session.

    If predicting one utterance raises, the error is reported on stderr and
    the loop keeps going, so one bad input does not end the session.
    Successful predictions are printed as pretty JSON.

    Args:
        threshold: Confidence threshold forwarded to ``predict_once``.
        top_k: Number of candidates forwarded to ``predict_once``.
    """
    print("本地 BERT 意图测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
    while True:
        try:
            text = input("\n请输入问题> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the prompt line so the shell prompt starts on a fresh line.
            print()
            break
        if not text:
            continue
        if text.lower() in {"exit", "quit", "q"}:
            break
        try:
            result = predict_once(text, threshold=threshold, top_k=top_k)
        except Exception as exc:  # REPL boundary: report the failure and keep the session alive.
            print(f"预测失败: {type(exc).__name__}: {exc}", file=sys.stderr)
            continue
        print(json.dumps(result, ensure_ascii=False, indent=2))
def main() -> None:
    """Command-line entry point for the local BERT intent test script.

    Options:

    * ``--text``: classify that text once and print the result as JSON.
    * ``--threshold``: confidence threshold passed to the classifier.
    * ``--top-k``: number of candidates passed to the classifier.

    When ``--text`` is absent or contains only whitespace, an interactive
    prompt is started instead.
    """
    parser = argparse.ArgumentParser(description="本地 BERT 意图识别测试脚本")
    parser.add_argument("--text", type=str, default="", help="单次测试文本")
    parser.add_argument("--threshold", type=float, default=0.0, help="BERT 置信度阈值")
    parser.add_argument("--top-k", type=int, default=3, help="返回候选数量")
    args = parser.parse_args()

    single_text = args.text.strip()
    if not single_text:
        interactive_loop(threshold=args.threshold, top_k=args.top_k)
        return

    payload = predict_once(single_text, threshold=args.threshold, top_k=args.top_k)
    print(json.dumps(payload, ensure_ascii=False, indent=2))
if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0, the same
    # as falling off the end of the script.
    raise SystemExit(main())