# Command-line helper: run the Joint BERT NLU model on one sentence and print
# the predicted intent, intent candidates and slots as a JSON document.
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# The directory one level above this file's folder is put on sys.path (once)
# so that the ``app`` package imports below resolve when the script is run
# directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from app.core.bootstrap import build_intent_registry
from app.services.joint_nlu import JointBertNLU


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the ``--text`` and ``--model-path`` options."""
    cli = argparse.ArgumentParser(description="评测 Joint BERT NLU 单句意图与槽位输出")
    cli.add_argument("--text", type=str, required=True, help="待评测文本")
    cli.add_argument(
        "--model-path",
        type=str,
        default="models/local_joint_bert_nlu",
        help="模型目录",
    )
    return cli


def _intent_rows(entries) -> list[dict[str, object]]:
    """Turn intent candidates into dicts with the score rounded to 4 decimals."""
    rows = []
    for entry in entries:
        rows.append({"intent_id": entry.intent_id, "score": round(entry.score, 4)})
    return rows


def _slot_rows(entries) -> list[dict[str, object]]:
    """Turn slot items into plain dicts; the slot score is passed through as-is."""
    rows = []
    for entry in entries:
        rows.append(
            {
                "slot_name": entry.slot_name,
                "value": entry.value,
                "start": entry.start,
                "end": entry.end,
                "score": entry.score,
            }
        )
    return rows


def main() -> None:
    """Predict intent and slots for the ``--text`` argument and print them as JSON.

    The model is loaded from ``--model-path`` and is given the entries
    returned by the intent registry's ``list()``. The report is written to
    stdout as indented, non-ASCII-escaped JSON.
    """
    options = _build_parser().parse_args()

    intent_registry = build_intent_registry()
    model = JointBertNLU(model_path=options.model_path)
    prediction = model.predict(options.text, intent_registry.list())

    # Keys are inserted in the order they must appear in the printed JSON.
    report = {"text": options.text}
    report["intent_id"] = prediction.intent_id
    report["intent_score"] = round(prediction.intent_score, 4)
    report["candidates"] = _intent_rows(prediction.candidates)
    report["multi_intent_candidates"] = _intent_rows(prediction.multi_intent_candidates)
    report["slots"] = prediction.slots
    report["slot_items"] = _slot_rows(prediction.slot_items)
    report["error_message"] = prediction.error_message

    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    print(rendered)


if __name__ == "__main__":
    main()