Files
2026-06-11 16:28:00 +08:00

236 lines
8.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# Directory one level above the folder that contains this script. It is pushed
# to the front of sys.path (only if not already present) before the ``app.*``
# imports below run, presumably so the script can be launched directly without
# installing the project -- TODO confirm the intended layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from app.core.config import settings
from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier
from app.services.router import build_matcher_pipeline
# Model directory used when neither the caller nor the settings name one.
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
# Labels that summarize_business_view() treats as "not a business intent":
# when the classifier reports one of them, the business view rejects the input.
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}
def resolve_model_path(model_path: str) -> Path:
    """Pick the model directory to load.

    Precedence is: the explicit ``model_path`` argument, then
    ``settings.classifier_model_path``, then ``DEFAULT_MODEL_DIR``. Each
    candidate is stripped on its own, so a blank or whitespace-only value
    falls through to the next source instead of ending the search early.

    Args:
        model_path: Path supplied by the caller; may be empty.

    Returns:
        The first non-blank candidate as a ``Path``, or ``DEFAULT_MODEL_DIR``
        when both candidates are blank.
    """
    explicit = (model_path or "").strip()
    if explicit:
        return Path(explicit)

    # ``or ""`` keeps an unset (None) setting from raising on ``.strip()``.
    configured = (settings.classifier_model_path or "").strip()
    if configured:
        return Path(configured)

    return DEFAULT_MODEL_DIR
def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path:
    """Pick the label-map file that accompanies the model.

    Precedence is: the explicit ``label_map_path`` argument, then
    ``settings.classifier_label_map_path``, then ``label_map.json`` inside
    ``model_path``. Each candidate is stripped on its own, so a blank or
    whitespace-only value falls through to the next source.

    Args:
        label_map_path: Path supplied by the caller; may be empty.
        model_path: Resolved model directory, used only for the final fallback.

    Returns:
        The first non-blank candidate as a ``Path``, or
        ``model_path / "label_map.json"`` when both candidates are blank.
    """
    explicit = (label_map_path or "").strip()
    if explicit:
        return Path(explicit)

    # ``or ""`` keeps an unset (None) setting from raising on ``.strip()``.
    configured = (settings.classifier_label_map_path or "").strip()
    if configured:
        return Path(configured)

    return model_path / "label_map.json"
def build_classifier(
    *,
    model_path: Path,
    label_map_path: Path,
    threshold: float,
    top_k: int,
) -> BertIntentClassifier:
    """Create the BERT intent classifier used by this script.

    Args:
        model_path: Model location; handed to the classifier as a string.
        label_map_path: Label-map file; handed to the classifier as a string.
        threshold: Value forwarded as the classifier's ``threshold``.
        top_k: Value forwarded as the classifier's ``top_k``.

    Returns:
        A ``BertIntentClassifier`` constructed with ``fallback=None``.
    """
    classifier_kwargs = {
        "model_path": str(model_path),
        "threshold": threshold,
        "label_map_path": str(label_map_path),
        "fallback": None,
        "top_k": top_k,
    }
    return BertIntentClassifier(**classifier_kwargs)
def _latest_stage(stages, stage_name):
    """Return the last entry of ``stages`` whose ``.stage`` is ``stage_name``, or None."""
    for entry in reversed(stages):
        if entry.stage == stage_name:
            return entry
    return None


def predict_once(
    text: str,
    *,
    model_path: Path,
    label_map_path: Path,
    threshold: float,
    top_k: int,
    warmup: bool,
) -> dict[str, object]:
    """Classify ``text`` once and report classifier and routing details.

    A fresh intent registry and classifier are built on every call. The text
    is first passed to ``classifier.predict`` and then to a matcher pipeline
    configured with the single stage ``"classifier"``; both outcomes are
    folded into one dictionary.

    Args:
        text: Utterance to classify.
        model_path: Model location forwarded to ``build_classifier``.
        label_map_path: Label-map file forwarded to ``build_classifier``.
        threshold: Classifier threshold.
        top_k: Number of candidates requested from the classifier.
        warmup: When true, call ``classifier.warmup`` before predicting.

    Returns:
        A dict with the keys ``text``, ``config``, ``classifier_result`` and
        ``route_result``.

    Raises:
        RuntimeError: If warmup was requested and ``classifier.warmup``
            returned a falsy value.
    """
    registry = build_intent_registry()
    classifier = build_classifier(
        model_path=model_path,
        label_map_path=label_map_path,
        threshold=threshold,
        top_k=top_k,
    )

    # Stays None when no warmup was requested, so the report can tell
    # "not attempted" apart from a successful warmup.
    warmup_ok = None
    if warmup:
        warmup_ok = classifier.warmup(settings.classifier_warmup_text)
        if not warmup_ok:
            raise RuntimeError(
                getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed"
            )

    # Direct classifier call.
    prediction = classifier.predict(text, registry.list())

    # Same text through the classifier-only matcher pipeline.
    pipeline = build_matcher_pipeline(
        registry,
        ["classifier"],
        classifier=classifier,
        route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
        clarify_margin_threshold=settings.local_clarify_margin_threshold,
        classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
        classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
    )
    routed = pipeline.match(text)
    debug = routed.debug
    fusion_stage = _latest_stage(debug.stages, "fusion")
    classifier_stage = _latest_stage(debug.stages, "classifier")

    config_view = {
        "model_path": str(model_path),
        "label_map_path": str(label_map_path),
        "threshold": threshold,
        "top_k": top_k,
        "warmup_requested": warmup,
        "warmup_ok": warmup_ok,
        "warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None),
        "warmup_error_message": getattr(classifier, "_warmup_error_message", None),
    }

    top_intent = prediction.intent
    classifier_view = {
        "predicted_intent": None if top_intent is None else top_intent.intent_id,
        "score": round(prediction.score, 4),
        "model_name": prediction.model_name,
        "backend_name": prediction.backend_name,
        "used_fallback": prediction.used_fallback,
        "fallback_reason": prediction.fallback_reason,
        "error_message": prediction.error_message,
        "raw_label": prediction.raw_label,
        "raw_candidates": prediction.raw_candidates or [],
        "known_candidates": [
            {"intent_id": candidate.intent_id, "score": round(candidate_score, 4)}
            for candidate, candidate_score in (prediction.candidates or [])
        ],
    }

    route_view = {
        "decision": debug.decision,
        "decision_reason": debug.decision_reason,
        "matched_stage": debug.matched_stage,
        "selected_intent": debug.selected_intent,
        "confidence_grade": debug.confidence_grade,
        "unknown_detected": debug.unknown_detected,
        "classifier_score": (
            None if classifier_stage is None else round(classifier_stage.score, 4)
        ),
        "fusion_score": None if fusion_stage is None else round(fusion_stage.score, 4),
    }

    return {
        "text": text,
        "config": config_view,
        "classifier_result": classifier_view,
        "route_result": route_view,
    }
def summarize_business_view(result: dict[str, object]) -> dict[str, object]:
    """Collapse non-business labels in a ``predict_once`` result.

    The ``classifier_result`` and ``route_result`` sections are shallow-copied
    before being edited, so the sub-dicts of ``result`` are not mutated. If
    the raw label (or, failing that, the predicted intent) is one of
    ``NON_BUSINESS_LABELS``, the copy is rewritten to show no business intent
    and a ``"reject"`` routing decision. Otherwise only the two
    interpretation fields are added to the classifier section.

    Args:
        result: Dictionary in the shape produced by ``predict_once``.

    Returns:
        A new dict with the keys ``text``, ``config``, ``classifier_result``
        and ``route_result``.
    """
    classifier_view = dict(result.get("classifier_result") or {})
    route_view = dict(result.get("route_result") or {})

    # The raw label wins when it is itself a non-business label; otherwise
    # judge by the predicted intent.
    raw_label = classifier_view.get("raw_label")
    if raw_label in NON_BUSINESS_LABELS:
        label = raw_label
    else:
        label = classifier_view.get("predicted_intent")

    if label in NON_BUSINESS_LABELS:
        classifier_view.update(
            predicted_intent=None,
            non_business_label=label,
            business_interpretation="non_business_label_detected",
        )
        route_view.update(
            selected_intent=None,
            decision="reject",
            decision_reason="classifier detected a non-business label",
            unknown_detected=True,
        )
    else:
        classifier_view.update(
            non_business_label=None,
            business_interpretation="known_business_intent_or_uncertain",
        )

    return {
        "text": result.get("text"),
        "config": result.get("config"),
        "classifier_result": classifier_view,
        "route_result": route_view,
    }
def interactive_loop(
    *,
    model_path: Path,
    label_map_path: Path,
    threshold: float,
    top_k: int,
    warmup: bool,
    mode: str,
) -> None:
    """Read utterances from stdin and print a prediction for each one.

    The loop ends on ``exit`` / ``quit`` / ``q`` (case-insensitive), on
    end-of-input, or when the user presses Ctrl-C at the prompt. Blank lines
    are ignored.

    Note: every utterance goes through ``predict_once``, which builds a new
    registry and classifier per call (and re-runs warmup when ``warmup`` is
    true).

    Args:
        model_path: Model location forwarded to ``predict_once``.
        label_map_path: Label-map file forwarded to ``predict_once``.
        threshold: Classifier threshold.
        top_k: Number of candidates requested from the classifier.
        warmup: Whether ``predict_once`` should warm the classifier up first.
        mode: ``"business"`` post-processes each result with
            ``summarize_business_view``; any other value prints it as is.
    """
    print("当前 BERT 测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
    print(f"model_path={model_path}")
    print(f"label_map_path={label_map_path}")
    print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}")

    while True:
        try:
            text = input("\n请输入问题> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / closed stdin and Ctrl-C both mean "stop asking"; end the
            # prompt line and leave quietly instead of dumping a traceback.
            print()
            break

        if not text:
            continue
        if text.lower() in {"exit", "quit", "q"}:
            break

        result = predict_once(
            text,
            model_path=model_path,
            label_map_path=label_map_path,
            threshold=threshold,
            top_k=top_k,
            warmup=warmup,
        )
        if mode == "business":
            result = summarize_business_view(result)
        print(json.dumps(result, ensure_ascii=False, indent=2))
def main() -> None:
    """Parse command-line options and run a single prediction or the REPL.

    When ``--text`` contains non-blank text, that text is classified once and
    the JSON result is printed. Otherwise the interactive loop is started
    with the same options.
    """
    cli = argparse.ArgumentParser(description="当前项目 BERT 意图识别测试脚本")
    cli.add_argument(
        "--text",
        type=str,
        default="",
        help="单次测试文本",
    )
    cli.add_argument(
        "--threshold",
        type=float,
        default=settings.classifier_bert_threshold,
        help="BERT 置信度阈值",
    )
    cli.add_argument(
        "--top-k",
        type=int,
        default=settings.classifier_top_k,
        help="返回候选数量",
    )
    cli.add_argument(
        "--model-path",
        type=str,
        default="",
        help="模型目录,默认取 .env 或 models/local_bert_intent",
    )
    cli.add_argument(
        "--label-map-path",
        type=str,
        default="",
        help="标签映射文件,默认取 .env 或 model_path/label_map.json",
    )
    cli.add_argument(
        "--warmup",
        action="store_true",
        help="先执行一次 warmup 再预测",
    )
    cli.add_argument(
        "--mode",
        choices=("classifier", "business"),
        default="classifier",
        help="classifier 显示原始分类结果business 会把非业务标签折叠成未命中业务意图",
    )
    options = cli.parse_args()

    model_dir = resolve_model_path(options.model_path)
    label_map_file = resolve_label_map_path(options.label_map_path, model_dir)

    # Options shared by the one-shot path and the interactive path.
    shared = {
        "model_path": model_dir,
        "label_map_path": label_map_file,
        "threshold": options.threshold,
        "top_k": options.top_k,
        "warmup": options.warmup,
    }

    single_text = options.text.strip()
    if not single_text:
        interactive_loop(mode=options.mode, **shared)
        return

    outcome = predict_once(single_text, **shared)
    if options.mode == "business":
        outcome = summarize_business_view(outcome)
    print(json.dumps(outcome, ensure_ascii=False, indent=2))
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()