# Source listing metadata: 98 lines, 3.6 KiB, Python.
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Directory one level above the folder that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Put the project root at the front of sys.path (once) so that the `app.*`
# imports below can be resolved when this file is executed directly.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from app.core.config import settings
|
|
from app.core.bootstrap import build_intent_registry
|
|
from app.services.classifier import BertIntentClassifier
|
|
from app.services.router import build_matcher_pipeline
|
|
|
|
|
|
# Default on-disk location of the local BERT intent model, relative to the
# project root, and the label-map JSON file stored inside that directory.
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
DEFAULT_LABEL_MAP = DEFAULT_MODEL_DIR.joinpath("label_map.json")
|
|
|
|
|
|
def build_classifier(threshold: float, top_k: int) -> BertIntentClassifier:
    """Construct a ``BertIntentClassifier`` pointed at the default model files.

    Args:
        threshold: Forwarded unchanged as the classifier's ``threshold``.
        top_k: Forwarded unchanged as the classifier's ``top_k``.

    Returns:
        A new classifier configured with ``DEFAULT_MODEL_DIR`` as the model
        path, ``DEFAULT_LABEL_MAP`` as the label-map path and no fallback.
    """
    classifier_options = {
        "model_path": str(DEFAULT_MODEL_DIR),
        "threshold": threshold,
        "label_map_path": str(DEFAULT_LABEL_MAP),
        "fallback": None,
        "top_k": top_k,
    }
    return BertIntentClassifier(**classifier_options)
|
|
|
|
|
|
# Classifiers already built by predict_once, keyed by (threshold, top_k).
# The interactive loop calls predict_once once per typed query; without this
# cache a brand-new classifier would be constructed for every single query.
_CLASSIFIER_CACHE: dict[tuple[float, int], BertIntentClassifier] = {}


def _get_classifier(threshold: float, top_k: int) -> BertIntentClassifier:
    """Return the classifier for this configuration, building it on first use."""
    cache_key = (threshold, top_k)
    classifier = _CLASSIFIER_CACHE.get(cache_key)
    if classifier is None:
        classifier = build_classifier(threshold=threshold, top_k=top_k)
        _CLASSIFIER_CACHE[cache_key] = classifier
    return classifier


def predict_once(text: str, threshold: float, top_k: int) -> dict[str, object]:
    """Classify ``text`` and return a JSON-serialisable summary of the outcome.

    The text is evaluated twice: once by calling the classifier's ``predict``
    directly against the registry's intents, and once through a matcher
    pipeline built from the stage list ``["classifier"]`` with the routing
    thresholds taken from ``settings``.

    The classifier is built once per ``(threshold, top_k)`` pair and reused by
    later calls with the same values. The intent registry and the matcher
    pipeline are still rebuilt on every call.

    Args:
        text: The utterance to classify.
        threshold: Confidence threshold forwarded to ``build_classifier``.
        top_k: Candidate count forwarded to ``build_classifier``.

    Returns:
        A dict with the direct prediction (``predicted_intent``, ``score``,
        ``model_name``, ``backend``, ``raw_label``, ``fallback_reason``,
        ``error_message``, ``top_candidates``) and the routing debug fields
        (``decision``, ``decision_reason``, ``confidence_grade``,
        ``unknown_detected``, ``fusion_top_score``). Scores are rounded to
        four decimal places.
    """
    classifier = _get_classifier(threshold, top_k)
    registry = build_intent_registry()
    intents = registry.list()

    # Direct classifier prediction, independent of the routing pipeline.
    result = classifier.predict(text, intents)

    # Routing decision for the same text via the matcher pipeline.
    matcher = build_matcher_pipeline(
        registry,
        ["classifier"],
        classifier=classifier,
        route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
        clarify_margin_threshold=settings.local_clarify_margin_threshold,
    )
    route_result = matcher.match(text)
    debug = route_result.debug

    # Use the last recorded "fusion" stage, if the pipeline produced one.
    fusion_stage = next(
        (stage for stage in reversed(debug.stages) if stage.stage == "fusion"),
        None,
    )

    return {
        "text": text,
        "predicted_intent": result.intent.intent_id if result.intent is not None else None,
        "score": round(result.score, 4),
        "model_name": result.model_name,
        "backend": result.backend_name,
        "raw_label": result.raw_label,
        "fallback_reason": result.fallback_reason,
        "error_message": result.error_message,
        "decision": debug.decision,
        "decision_reason": debug.decision_reason,
        "confidence_grade": debug.confidence_grade,
        "unknown_detected": debug.unknown_detected,
        "fusion_top_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
        "top_candidates": [
            {"intent_id": intent.intent_id, "score": round(score, 4)}
            for intent, score in (result.candidates or [])
        ],
    }
|
|
|
|
|
|
def interactive_loop(threshold: float, top_k: int) -> None:
    """Read utterances from stdin and print the prediction for each as JSON.

    The loop ends when the user types ``exit``, ``quit`` or ``q``
    (case-insensitive), when stdin reaches end-of-file, or when the user
    presses Ctrl-C at the prompt. Blank input is ignored.

    Args:
        threshold: Confidence threshold forwarded to ``predict_once``.
        top_k: Candidate count forwarded to ``predict_once``.
    """
    print("本地 BERT 意图测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
    while True:
        try:
            text = input("\n请输入问题> ").strip()
        except (EOFError, KeyboardInterrupt):
            # End-of-file or Ctrl-C at the prompt: finish the prompt line and
            # leave the loop quietly instead of surfacing a traceback.
            print()
            break
        if not text:
            continue
        if text.lower() in {"exit", "quit", "q"}:
            break
        result = predict_once(text, threshold=threshold, top_k=top_k)
        print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
|
|
|
|
def main() -> None:
    """Parse the command line and run a single prediction or the prompt loop.

    When ``--text`` contains non-whitespace characters, that text is classified
    once and the result is printed as indented JSON. Otherwise the interactive
    loop is started with the given ``--threshold`` and ``--top-k`` values.
    """
    cli = argparse.ArgumentParser(description="本地 BERT 意图识别测试脚本")
    cli.add_argument("--text", type=str, default="", help="单次测试文本")
    cli.add_argument("--threshold", type=float, default=0.0, help="BERT 置信度阈值")
    cli.add_argument("--top-k", type=int, default=3, help="返回候选数量")
    options = cli.parse_args()

    single_text = options.text.strip()
    if not single_text:
        # No usable one-shot text was supplied: fall back to the prompt loop.
        interactive_loop(threshold=options.threshold, top_k=options.top_k)
        return

    report = predict_once(single_text, threshold=options.threshold, top_k=options.top_k)
    print(json.dumps(report, ensure_ascii=False, indent=2))
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|