from __future__ import annotations import argparse import json import sys from pathlib import Path PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from app.core.config import settings from app.core.bootstrap import build_intent_registry from app.services.classifier import BertIntentClassifier from app.services.router import build_matcher_pipeline DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent" NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"} def resolve_model_path(model_path: str) -> Path: configured = (model_path or settings.classifier_model_path).strip() if configured: return Path(configured) return DEFAULT_MODEL_DIR def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path: configured = (label_map_path or settings.classifier_label_map_path).strip() if configured: return Path(configured) return model_path / "label_map.json" def build_classifier( *, model_path: Path, label_map_path: Path, threshold: float, top_k: int, ) -> BertIntentClassifier: return BertIntentClassifier( model_path=str(model_path), threshold=threshold, label_map_path=str(label_map_path), fallback=None, top_k=top_k, ) def predict_once( text: str, *, model_path: Path, label_map_path: Path, threshold: float, top_k: int, warmup: bool, ) -> dict[str, object]: registry = build_intent_registry() classifier = build_classifier( model_path=model_path, label_map_path=label_map_path, threshold=threshold, top_k=top_k, ) warmup_ok = None if warmup: warmup_ok = classifier.warmup(settings.classifier_warmup_text) if not warmup_ok: error_message = getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed" raise RuntimeError(error_message) intents = registry.list() result = classifier.predict(text, intents) matcher = build_matcher_pipeline( registry, ["classifier"], classifier=classifier, route_to_cloud_threshold=settings.local_route_to_cloud_threshold, clarify_margin_threshold=settings.local_clarify_margin_threshold, classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold, classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold, ) route_result = matcher.match(text) fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None) classifier_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "classifier"), None) return { "text": text, "config": { "model_path": str(model_path), "label_map_path": str(label_map_path), "threshold": threshold, "top_k": top_k, "warmup_requested": warmup, "warmup_ok": warmup_ok, "warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None), "warmup_error_message": getattr(classifier, "_warmup_error_message", None), }, "classifier_result": { "predicted_intent": result.intent.intent_id if result.intent is not None else None, "score": round(result.score, 4), "model_name": result.model_name, "backend_name": result.backend_name, "used_fallback": result.used_fallback, "fallback_reason": result.fallback_reason, "error_message": result.error_message, "raw_label": result.raw_label, "raw_candidates": result.raw_candidates or [], "known_candidates": [ {"intent_id": intent.intent_id, "score": round(score, 4)} for intent, score in (result.candidates or []) ], }, "route_result": { "decision": route_result.debug.decision, "decision_reason": route_result.debug.decision_reason, "matched_stage": route_result.debug.matched_stage, "selected_intent": route_result.debug.selected_intent, "confidence_grade": route_result.debug.confidence_grade, "unknown_detected": route_result.debug.unknown_detected, "classifier_score": round(classifier_stage.score, 4) if classifier_stage is not None else None, "fusion_score": round(fusion_stage.score, 4) if fusion_stage is not None else None, }, } def summarize_business_view(result: dict[str, object]) -> dict[str, object]: classifier_result = dict(result.get("classifier_result") or {}) route_result = dict(result.get("route_result") or {}) predicted_intent = classifier_result.get("predicted_intent") raw_label = classifier_result.get("raw_label") effective_label = raw_label if raw_label in NON_BUSINESS_LABELS else predicted_intent if effective_label in NON_BUSINESS_LABELS: classifier_result["predicted_intent"] = None classifier_result["non_business_label"] = effective_label classifier_result["business_interpretation"] = "non_business_label_detected" route_result["selected_intent"] = None route_result["decision"] = "reject" route_result["decision_reason"] = "classifier detected a non-business label" route_result["unknown_detected"] = True else: classifier_result["non_business_label"] = None classifier_result["business_interpretation"] = "known_business_intent_or_uncertain" return { "text": result.get("text"), "config": result.get("config"), "classifier_result": classifier_result, "route_result": route_result, } def interactive_loop( *, model_path: Path, label_map_path: Path, threshold: float, top_k: int, warmup: bool, mode: str, ) -> None: print("当前 BERT 测试已启动,输入一句话直接查看预测结果,输入 exit 退出。") print(f"model_path={model_path}") print(f"label_map_path={label_map_path}") print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}") while True: try: text = input("\n请输入问题> ").strip() except EOFError: print() break if not text: continue if text.lower() in {"exit", "quit", "q"}: break result = predict_once( text, model_path=model_path, label_map_path=label_map_path, threshold=threshold, top_k=top_k, warmup=warmup, ) if mode == "business": result = summarize_business_view(result) print(json.dumps(result, ensure_ascii=False, indent=2)) def main() -> None: parser = argparse.ArgumentParser(description="当前项目 BERT 意图识别测试脚本") parser.add_argument("--text", type=str, default="", help="单次测试文本") parser.add_argument("--threshold", type=float, default=settings.classifier_bert_threshold, help="BERT 置信度阈值") parser.add_argument("--top-k", type=int, default=settings.classifier_top_k, help="返回候选数量") parser.add_argument("--model-path", type=str, default="", help="模型目录,默认取 .env 或 models/local_bert_intent") parser.add_argument("--label-map-path", type=str, default="", help="标签映射文件,默认取 .env 或 model_path/label_map.json") parser.add_argument("--warmup", action="store_true", help="先执行一次 warmup 再预测") parser.add_argument( "--mode", choices=("classifier", "business"), default="classifier", help="classifier 显示原始分类结果;business 会把非业务标签折叠成未命中业务意图", ) args = parser.parse_args() model_path = resolve_model_path(args.model_path) label_map_path = resolve_label_map_path(args.label_map_path, model_path) if args.text.strip(): result = predict_once( args.text.strip(), model_path=model_path, label_map_path=label_map_path, threshold=args.threshold, top_k=args.top_k, warmup=args.warmup, ) if args.mode == "business": result = summarize_business_view(result) print(json.dumps(result, ensure_ascii=False, indent=2)) return interactive_loop( model_path=model_path, label_map_path=label_map_path, threshold=args.threshold, top_k=args.top_k, warmup=args.warmup, mode=args.mode, ) if __name__ == "__main__": main()