Update project and configurations

This commit is contained in:
Zou-Seay
2026-06-11 16:28:00 +08:00
parent 12d3922091
commit a29a91867d
237 changed files with 164880 additions and 90 deletions

View File

@@ -0,0 +1,235 @@
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from app.core.config import settings
from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier
from app.services.router import build_matcher_pipeline
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}
def resolve_model_path(model_path: str) -> Path:
configured = (model_path or settings.classifier_model_path).strip()
if configured:
return Path(configured)
return DEFAULT_MODEL_DIR
def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path:
configured = (label_map_path or settings.classifier_label_map_path).strip()
if configured:
return Path(configured)
return model_path / "label_map.json"
def build_classifier(
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
) -> BertIntentClassifier:
return BertIntentClassifier(
model_path=str(model_path),
threshold=threshold,
label_map_path=str(label_map_path),
fallback=None,
top_k=top_k,
)
def predict_once(
text: str,
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
warmup: bool,
) -> dict[str, object]:
registry = build_intent_registry()
classifier = build_classifier(
model_path=model_path,
label_map_path=label_map_path,
threshold=threshold,
top_k=top_k,
)
warmup_ok = None
if warmup:
warmup_ok = classifier.warmup(settings.classifier_warmup_text)
if not warmup_ok:
error_message = getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed"
raise RuntimeError(error_message)
intents = registry.list()
result = classifier.predict(text, intents)
matcher = build_matcher_pipeline(
registry,
["classifier"],
classifier=classifier,
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
clarify_margin_threshold=settings.local_clarify_margin_threshold,
classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
)
route_result = matcher.match(text)
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
classifier_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "classifier"), None)
return {
"text": text,
"config": {
"model_path": str(model_path),
"label_map_path": str(label_map_path),
"threshold": threshold,
"top_k": top_k,
"warmup_requested": warmup,
"warmup_ok": warmup_ok,
"warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None),
"warmup_error_message": getattr(classifier, "_warmup_error_message", None),
},
"classifier_result": {
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
"score": round(result.score, 4),
"model_name": result.model_name,
"backend_name": result.backend_name,
"used_fallback": result.used_fallback,
"fallback_reason": result.fallback_reason,
"error_message": result.error_message,
"raw_label": result.raw_label,
"raw_candidates": result.raw_candidates or [],
"known_candidates": [
{"intent_id": intent.intent_id, "score": round(score, 4)}
for intent, score in (result.candidates or [])
],
},
"route_result": {
"decision": route_result.debug.decision,
"decision_reason": route_result.debug.decision_reason,
"matched_stage": route_result.debug.matched_stage,
"selected_intent": route_result.debug.selected_intent,
"confidence_grade": route_result.debug.confidence_grade,
"unknown_detected": route_result.debug.unknown_detected,
"classifier_score": round(classifier_stage.score, 4) if classifier_stage is not None else None,
"fusion_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
},
}
def summarize_business_view(result: dict[str, object]) -> dict[str, object]:
classifier_result = dict(result.get("classifier_result") or {})
route_result = dict(result.get("route_result") or {})
predicted_intent = classifier_result.get("predicted_intent")
raw_label = classifier_result.get("raw_label")
effective_label = raw_label if raw_label in NON_BUSINESS_LABELS else predicted_intent
if effective_label in NON_BUSINESS_LABELS:
classifier_result["predicted_intent"] = None
classifier_result["non_business_label"] = effective_label
classifier_result["business_interpretation"] = "non_business_label_detected"
route_result["selected_intent"] = None
route_result["decision"] = "reject"
route_result["decision_reason"] = "classifier detected a non-business label"
route_result["unknown_detected"] = True
else:
classifier_result["non_business_label"] = None
classifier_result["business_interpretation"] = "known_business_intent_or_uncertain"
return {
"text": result.get("text"),
"config": result.get("config"),
"classifier_result": classifier_result,
"route_result": route_result,
}
def interactive_loop(
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
warmup: bool,
mode: str,
) -> None:
print("当前 BERT 测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
print(f"model_path={model_path}")
print(f"label_map_path={label_map_path}")
print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}")
while True:
try:
text = input("\n请输入问题> ").strip()
except EOFError:
print()
break
if not text:
continue
if text.lower() in {"exit", "quit", "q"}:
break
result = predict_once(
text,
model_path=model_path,
label_map_path=label_map_path,
threshold=threshold,
top_k=top_k,
warmup=warmup,
)
if mode == "business":
result = summarize_business_view(result)
print(json.dumps(result, ensure_ascii=False, indent=2))
def main() -> None:
parser = argparse.ArgumentParser(description="当前项目 BERT 意图识别测试脚本")
parser.add_argument("--text", type=str, default="", help="单次测试文本")
parser.add_argument("--threshold", type=float, default=settings.classifier_bert_threshold, help="BERT 置信度阈值")
parser.add_argument("--top-k", type=int, default=settings.classifier_top_k, help="返回候选数量")
parser.add_argument("--model-path", type=str, default="", help="模型目录,默认取 .env 或 models/local_bert_intent")
parser.add_argument("--label-map-path", type=str, default="", help="标签映射文件,默认取 .env 或 model_path/label_map.json")
parser.add_argument("--warmup", action="store_true", help="先执行一次 warmup 再预测")
parser.add_argument(
"--mode",
choices=("classifier", "business"),
default="classifier",
help="classifier 显示原始分类结果business 会把非业务标签折叠成未命中业务意图",
)
args = parser.parse_args()
model_path = resolve_model_path(args.model_path)
label_map_path = resolve_label_map_path(args.label_map_path, model_path)
if args.text.strip():
result = predict_once(
args.text.strip(),
model_path=model_path,
label_map_path=label_map_path,
threshold=args.threshold,
top_k=args.top_k,
warmup=args.warmup,
)
if args.mode == "business":
result = summarize_business_view(result)
print(json.dumps(result, ensure_ascii=False, indent=2))
return
interactive_loop(
model_path=model_path,
label_map_path=label_map_path,
threshold=args.threshold,
top_k=args.top_k,
warmup=args.warmup,
mode=args.mode,
)
if __name__ == "__main__":
main()