from __future__ import annotations import argparse import json import sys from pathlib import Path PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from app.core.config import settings from app.core.bootstrap import build_intent_registry from app.services.classifier import BertIntentClassifier from app.services.router import build_matcher_pipeline DEFAULT_MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent" DEFAULT_LABEL_MAP = DEFAULT_MODEL_DIR / "label_map.json" def build_classifier(threshold: float, top_k: int) -> BertIntentClassifier: return BertIntentClassifier( model_path=str(DEFAULT_MODEL_DIR), threshold=threshold, label_map_path=str(DEFAULT_LABEL_MAP), fallback=None, top_k=top_k, ) def predict_once(text: str, threshold: float, top_k: int) -> dict[str, object]: classifier = build_classifier(threshold=threshold, top_k=top_k) registry = build_intent_registry() intents = registry.list() result = classifier.predict(text, intents) matcher = build_matcher_pipeline( registry, ["classifier"], classifier=classifier, route_to_cloud_threshold=settings.local_route_to_cloud_threshold, clarify_margin_threshold=settings.local_clarify_margin_threshold, ) route_result = matcher.match(text) fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None) return { "text": text, "predicted_intent": result.intent.intent_id if result.intent is not None else None, "score": round(result.score, 4), "model_name": result.model_name, "backend": result.backend_name, "raw_label": result.raw_label, "fallback_reason": result.fallback_reason, "error_message": result.error_message, "decision": route_result.debug.decision, "decision_reason": route_result.debug.decision_reason, "confidence_grade": route_result.debug.confidence_grade, "unknown_detected": route_result.debug.unknown_detected, "fusion_top_score": round(fusion_stage.score, 4) if fusion_stage is not None else None, "top_candidates": [ {"intent_id": intent.intent_id, "score": round(score, 4)} for intent, score in (result.candidates or []) ], } def interactive_loop(threshold: float, top_k: int) -> None: print("本地 BERT 意图测试已启动,输入一句话直接查看预测结果,输入 exit 退出。") while True: try: text = input("\n请输入问题> ").strip() except EOFError: print() break if not text: continue if text.lower() in {"exit", "quit", "q"}: break result = predict_once(text, threshold=threshold, top_k=top_k) print(json.dumps(result, ensure_ascii=False, indent=2)) def main() -> None: parser = argparse.ArgumentParser(description="本地 BERT 意图识别测试脚本") parser.add_argument("--text", type=str, default="", help="单次测试文本") parser.add_argument("--threshold", type=float, default=0.0, help="BERT 置信度阈值") parser.add_argument("--top-k", type=int, default=3, help="返回候选数量") args = parser.parse_args() if args.text.strip(): print(json.dumps(predict_once(args.text.strip(), threshold=args.threshold, top_k=args.top_k), ensure_ascii=False, indent=2)) return interactive_loop(threshold=args.threshold, top_k=args.top_k) if __name__ == "__main__": main()