Update project and configurations
This commit is contained in:
235
intelligent_cabin/archive/tests/test_bert.py
Normal file
235
intelligent_cabin/archive/tests/test_bert.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# PROJECT_ROOT is the parent of the directory that contains this file. It is
# put at the front of sys.path (only if not already present), presumably so
# the absolute `app.*` imports that follow resolve when this file is run
# directly as a script -- TODO confirm against the repository layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
from app.services.router import build_matcher_pipeline
|
||||
|
||||
|
||||
# Model directory used by resolve_model_path() when neither the caller nor settings supplies one.
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
|
||||
# Labels that summarize_business_view() treats as "not a business intent" and folds into a reject.
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}
|
||||
|
||||
|
||||
def resolve_model_path(model_path: str) -> Path:
    """Pick the model directory to load.

    Sources are tried in order: the explicit ``model_path`` argument, then
    ``settings.classifier_model_path``, then ``DEFAULT_MODEL_DIR``. Each
    source is stripped on its own before it is considered, so a blank or
    whitespace-only value falls through to the next source instead of
    masking it, and ``None`` is treated like an empty string.

    Args:
        model_path: Path supplied by the caller; may be empty.

    Returns:
        The resolved model directory.
    """
    explicit = (model_path or "").strip()
    if explicit:
        return Path(explicit)

    # Settings are only consulted when no usable explicit value was given.
    configured = (settings.classifier_model_path or "").strip()
    if configured:
        return Path(configured)

    return DEFAULT_MODEL_DIR
|
||||
|
||||
|
||||
def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path:
    """Pick the label-map file that accompanies the model.

    Sources are tried in order: the explicit ``label_map_path`` argument,
    then ``settings.classifier_label_map_path``, then ``label_map.json``
    inside ``model_path``. Each source is stripped on its own before it is
    considered, so a blank or whitespace-only value falls through to the
    next source instead of masking it, and ``None`` is treated like an
    empty string.

    Args:
        label_map_path: Path supplied by the caller; may be empty.
        model_path: Already-resolved model directory, used for the final
            fallback location.

    Returns:
        The resolved label-map path.
    """
    explicit = (label_map_path or "").strip()
    if explicit:
        return Path(explicit)

    # Settings are only consulted when no usable explicit value was given.
    configured = (settings.classifier_label_map_path or "").strip()
    if configured:
        return Path(configured)

    return model_path / "label_map.json"
|
||||
|
||||
|
||||
def build_classifier(
    *,
    model_path: Path,
    label_map_path: Path,
    threshold: float,
    top_k: int,
) -> BertIntentClassifier:
    """Create a ``BertIntentClassifier`` from resolved paths and thresholds.

    Both paths are handed to the classifier as strings and ``fallback`` is
    always passed as ``None``.

    Args:
        model_path: Model directory.
        label_map_path: Label-map file.
        threshold: Threshold forwarded to the classifier.
        top_k: Candidate count forwarded to the classifier.

    Returns:
        The newly constructed classifier.
    """
    options = {
        "model_path": str(model_path),
        "threshold": threshold,
        "label_map_path": str(label_map_path),
        "fallback": None,
        "top_k": top_k,
    }
    return BertIntentClassifier(**options)
|
||||
|
||||
|
||||
def _last_stage_named(stages, name):
    """Return the last entry of ``stages`` whose ``stage`` equals ``name``, else ``None``."""
    for entry in reversed(stages):
        if entry.stage == name:
            return entry
    return None


def predict_once(
    text: str,
    *,
    model_path: Path,
    label_map_path: Path,
    threshold: float,
    top_k: int,
    warmup: bool,
) -> dict[str, object]:
    """Classify ``text`` once and return a JSON-serialisable report.

    A fresh intent registry and classifier are built on every call. The text
    is sent to the classifier directly and also through a matcher pipeline
    configured with the single ``"classifier"`` stage, so the report shows
    both the raw prediction and the routing decision.

    Args:
        text: Utterance to classify.
        model_path: Model directory.
        label_map_path: Label-map file.
        threshold: Classifier threshold.
        top_k: Number of candidates requested from the classifier.
        warmup: When true, run ``classifier.warmup`` before predicting.

    Returns:
        A dict with the keys ``text``, ``config``, ``classifier_result`` and
        ``route_result``; scores are rounded to four decimals.

    Raises:
        RuntimeError: If warmup was requested and reported failure.
    """
    registry = build_intent_registry()
    classifier = build_classifier(
        model_path=model_path,
        label_map_path=label_map_path,
        threshold=threshold,
        top_k=top_k,
    )

    # warmup_ok stays None when no warmup was requested.
    warmup_ok = classifier.warmup(settings.classifier_warmup_text) if warmup else None
    if warmup and not warmup_ok:
        raise RuntimeError(
            getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed"
        )

    prediction = classifier.predict(text, registry.list())

    routed = build_matcher_pipeline(
        registry,
        ["classifier"],
        classifier=classifier,
        route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
        clarify_margin_threshold=settings.local_clarify_margin_threshold,
        classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
        classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
    ).match(text)
    debug = routed.debug
    fusion_stage = _last_stage_named(debug.stages, "fusion")
    classifier_stage = _last_stage_named(debug.stages, "classifier")

    config_view = {
        "model_path": str(model_path),
        "label_map_path": str(label_map_path),
        "threshold": threshold,
        "top_k": top_k,
        "warmup_requested": warmup,
        "warmup_ok": warmup_ok,
        "warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None),
        "warmup_error_message": getattr(classifier, "_warmup_error_message", None),
    }

    known_candidates = [
        {"intent_id": candidate.intent_id, "score": round(candidate_score, 4)}
        for candidate, candidate_score in (prediction.candidates or [])
    ]
    classifier_view = {
        "predicted_intent": None if prediction.intent is None else prediction.intent.intent_id,
        "score": round(prediction.score, 4),
        "model_name": prediction.model_name,
        "backend_name": prediction.backend_name,
        "used_fallback": prediction.used_fallback,
        "fallback_reason": prediction.fallback_reason,
        "error_message": prediction.error_message,
        "raw_label": prediction.raw_label,
        "raw_candidates": prediction.raw_candidates or [],
        "known_candidates": known_candidates,
    }

    route_view = {
        "decision": debug.decision,
        "decision_reason": debug.decision_reason,
        "matched_stage": debug.matched_stage,
        "selected_intent": debug.selected_intent,
        "confidence_grade": debug.confidence_grade,
        "unknown_detected": debug.unknown_detected,
        "classifier_score": None if classifier_stage is None else round(classifier_stage.score, 4),
        "fusion_score": None if fusion_stage is None else round(fusion_stage.score, 4),
    }

    return {
        "text": text,
        "config": config_view,
        "classifier_result": classifier_view,
        "route_result": route_view,
    }
|
||||
|
||||
|
||||
def summarize_business_view(result: dict[str, object]) -> dict[str, object]:
    """Re-express a ``predict_once`` report from the business point of view.

    If the raw label (or, failing that, the predicted intent) is one of
    ``NON_BUSINESS_LABELS``, the predicted and selected intents are cleared
    and the routing decision is rewritten to a reject. Otherwise the report
    is only annotated. The input's ``classifier_result`` and ``route_result``
    dicts are shallow-copied, so ``result`` itself is not modified.

    Args:
        result: Report as returned by ``predict_once``.

    Returns:
        A new dict with the keys ``text``, ``config``, ``classifier_result``
        and ``route_result``.
    """
    clf_view = dict(result.get("classifier_result") or {})
    route_view = dict(result.get("route_result") or {})

    # The raw label wins when it is itself a non-business label.
    raw = clf_view.get("raw_label")
    label = raw if raw in NON_BUSINESS_LABELS else clf_view.get("predicted_intent")

    if label not in NON_BUSINESS_LABELS:
        clf_view.update(
            non_business_label=None,
            business_interpretation="known_business_intent_or_uncertain",
        )
    else:
        clf_view.update(
            predicted_intent=None,
            non_business_label=label,
            business_interpretation="non_business_label_detected",
        )
        route_view.update(
            selected_intent=None,
            decision="reject",
            decision_reason="classifier detected a non-business label",
            unknown_detected=True,
        )

    return {
        "text": result.get("text"),
        "config": result.get("config"),
        "classifier_result": clf_view,
        "route_result": route_view,
    }
|
||||
|
||||
|
||||
def interactive_loop(
    *,
    model_path: Path,
    label_map_path: Path,
    threshold: float,
    top_k: int,
    warmup: bool,
    mode: str,
) -> None:
    """Read utterances from stdin and print a prediction report for each.

    A banner with the effective configuration is printed first. Blank lines
    are ignored. The loop ends on ``exit``, ``quit`` or ``q``
    (case-insensitive), on end-of-input, or on Ctrl+C at the prompt; the
    latter two both finish the current line and return quietly instead of
    surfacing a traceback.

    Args:
        model_path: Model directory.
        label_map_path: Label-map file.
        threshold: Classifier threshold.
        top_k: Number of candidates requested from the classifier.
        warmup: Forwarded to ``predict_once`` on every query.
        mode: ``"business"`` post-processes each report with
            ``summarize_business_view``; any other value prints it as is.
    """
    print("当前 BERT 测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
    print(f"model_path={model_path}")
    print(f"label_map_path={label_map_path}")
    print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}")

    while True:
        try:
            text = input("\n请输入问题> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Terminate the prompt line so the shell prompt starts cleanly.
            print()
            break

        if not text:
            continue
        if text.lower() in {"exit", "quit", "q"}:
            break

        result = predict_once(
            text,
            model_path=model_path,
            label_map_path=label_map_path,
            threshold=threshold,
            top_k=top_k,
            warmup=warmup,
        )
        if mode == "business":
            result = summarize_business_view(result)
        print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point.

    Parses the arguments, resolves the model and label-map paths, and then
    either prints one report for ``--text`` (when it is non-blank) or starts
    the interactive loop.
    """
    parser = argparse.ArgumentParser(description="当前项目 BERT 意图识别测试脚本")
    parser.add_argument("--text", type=str, default="", help="单次测试文本")
    parser.add_argument(
        "--threshold",
        type=float,
        default=settings.classifier_bert_threshold,
        help="BERT 置信度阈值",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=settings.classifier_top_k,
        help="返回候选数量",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="",
        help="模型目录,默认取 .env 或 models/local_bert_intent",
    )
    parser.add_argument(
        "--label-map-path",
        type=str,
        default="",
        help="标签映射文件,默认取 .env 或 model_path/label_map.json",
    )
    parser.add_argument("--warmup", action="store_true", help="先执行一次 warmup 再预测")
    parser.add_argument(
        "--mode",
        choices=("classifier", "business"),
        default="classifier",
        help="classifier 显示原始分类结果;business 会把非业务标签折叠成未命中业务意图",
    )
    args = parser.parse_args()

    # Keyword arguments shared by the single-shot and interactive paths.
    model_dir = resolve_model_path(args.model_path)
    shared = {
        "model_path": model_dir,
        "label_map_path": resolve_label_map_path(args.label_map_path, model_dir),
        "threshold": args.threshold,
        "top_k": args.top_k,
        "warmup": args.warmup,
    }

    single_text = args.text.strip()
    if not single_text:
        interactive_loop(mode=args.mode, **shared)
        return

    report = predict_once(single_text, **shared)
    if args.mode == "business":
        report = summarize_business_view(report)
    print(json.dumps(report, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user