Update project and configurations
This commit is contained in:
97
intelligent_cabin/archive/scripts/test_local_bert_intent.py
Normal file
97
intelligent_cabin/archive/scripts/test_local_bert_intent.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
from app.services.router import build_matcher_pipeline
|
||||
|
||||
|
||||
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models/local_bert_intent"
|
||||
DEFAULT_LABEL_MAP = DEFAULT_MODEL_DIR / "label_map.json"
|
||||
|
||||
|
||||
def build_classifier(threshold: float, top_k: int) -> BertIntentClassifier:
|
||||
return BertIntentClassifier(
|
||||
model_path=str(DEFAULT_MODEL_DIR),
|
||||
threshold=threshold,
|
||||
label_map_path=str(DEFAULT_LABEL_MAP),
|
||||
fallback=None,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def predict_once(text: str, threshold: float, top_k: int) -> dict[str, object]:
|
||||
classifier = build_classifier(threshold=threshold, top_k=top_k)
|
||||
registry = build_intent_registry()
|
||||
intents = registry.list()
|
||||
result = classifier.predict(text, intents)
|
||||
matcher = build_matcher_pipeline(
|
||||
registry,
|
||||
["classifier"],
|
||||
classifier=classifier,
|
||||
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
|
||||
clarify_margin_threshold=settings.local_clarify_margin_threshold,
|
||||
)
|
||||
route_result = matcher.match(text)
|
||||
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
|
||||
return {
|
||||
"text": text,
|
||||
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
|
||||
"score": round(result.score, 4),
|
||||
"model_name": result.model_name,
|
||||
"backend": result.backend_name,
|
||||
"raw_label": result.raw_label,
|
||||
"fallback_reason": result.fallback_reason,
|
||||
"error_message": result.error_message,
|
||||
"decision": route_result.debug.decision,
|
||||
"decision_reason": route_result.debug.decision_reason,
|
||||
"confidence_grade": route_result.debug.confidence_grade,
|
||||
"unknown_detected": route_result.debug.unknown_detected,
|
||||
"fusion_top_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
|
||||
"top_candidates": [
|
||||
{"intent_id": intent.intent_id, "score": round(score, 4)}
|
||||
for intent, score in (result.candidates or [])
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def interactive_loop(threshold: float, top_k: int) -> None:
|
||||
print("本地 BERT 意图测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
|
||||
while True:
|
||||
try:
|
||||
text = input("\n请输入问题> ").strip()
|
||||
except EOFError:
|
||||
print()
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
if text.lower() in {"exit", "quit", "q"}:
|
||||
break
|
||||
result = predict_once(text, threshold=threshold, top_k=top_k)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="本地 BERT 意图识别测试脚本")
|
||||
parser.add_argument("--text", type=str, default="", help="单次测试文本")
|
||||
parser.add_argument("--threshold", type=float, default=0.0, help="BERT 置信度阈值")
|
||||
parser.add_argument("--top-k", type=int, default=3, help="返回候选数量")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.text.strip():
|
||||
print(json.dumps(predict_once(args.text.strip(), threshold=args.threshold, top_k=args.top_k), ensure_ascii=False, indent=2))
|
||||
return
|
||||
interactive_loop(threshold=args.threshold, top_k=args.top_k)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user