Update project and configurations
This commit is contained in:
132
intelligent_cabin/archive/tests/test_agent_cloud_route.py
Normal file
132
intelligent_cabin/archive/tests/test_agent_cloud_route.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.schemas.chat import ChatRequest
|
||||
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.planner import PlanningResult
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
|
||||
|
||||
class _RouteToCloudRouter:
|
||||
def route(self, text: str):
|
||||
_ = text
|
||||
return type(
|
||||
"RouteResult",
|
||||
(),
|
||||
{
|
||||
"intent": None,
|
||||
"debug": RoutingDebug(
|
||||
selected_intent="cabin_nav_to",
|
||||
matched_stage="fusion",
|
||||
decision="route_to_cloud",
|
||||
decision_reason="local signal is not stable enough, routing to cloud planner",
|
||||
confidence_grade="low",
|
||||
stages=[
|
||||
MatcherStageDebug(
|
||||
stage="fusion",
|
||||
accepted=False,
|
||||
selected_intent="cabin_nav_to",
|
||||
score=0.88,
|
||||
reason="route to cloud",
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="cabin_nav_to", score=0.88, reason="fusion", model_name="fusion"),
|
||||
IntentCandidate(intent_id="cabin_play_music", score=0.75, reason="fusion", model_name="fusion"),
|
||||
],
|
||||
)
|
||||
],
|
||||
),
|
||||
},
|
||||
)()
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
|
||||
_ = (text, intent)
|
||||
return {}
|
||||
|
||||
|
||||
class _PlannerRejects:
|
||||
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
|
||||
_ = (text, intents, context)
|
||||
return PlanningResult(
|
||||
accepted=False,
|
||||
workflow_type="single",
|
||||
model_name="qwen3.5-plus",
|
||||
backend="dashscope",
|
||||
reason="cloud planner could not produce a stable executable step",
|
||||
)
|
||||
|
||||
|
||||
class _PlannerOutOfScope:
|
||||
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
|
||||
_ = (text, intents, context)
|
||||
return PlanningResult(
|
||||
accepted=False,
|
||||
workflow_type="single",
|
||||
model_name="qwen3.5-plus",
|
||||
backend="dashscope",
|
||||
reason="The provided intent catalog only contains cabin and service actions. There is no matching intent for ordering food via a third-party app action.",
|
||||
)
|
||||
|
||||
|
||||
def _intent(intent_id: str) -> IntentDefinition:
|
||||
return IntentDefinition(
|
||||
intent_id=intent_id,
|
||||
plugin_id=f"mock.{intent_id}",
|
||||
domain="cabin",
|
||||
keywords=[],
|
||||
examples=[],
|
||||
)
|
||||
|
||||
|
||||
class AgentCloudRouteTests(unittest.TestCase):
|
||||
def test_route_to_cloud_returns_explicit_clarify_feedback_when_planner_does_not_accept(self) -> None:
|
||||
service = AgentService(
|
||||
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
|
||||
router=_RouteToCloudRouter(),
|
||||
plugins=PluginRegistry(),
|
||||
session_store=InMemorySessionStore(),
|
||||
planner=_PlannerRejects(),
|
||||
)
|
||||
|
||||
response = service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_cloud_route",
|
||||
user_id="user_1",
|
||||
input_text="带我过去",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "route_to_cloud")
|
||||
self.assertEqual(response.reply_type, "clarify")
|
||||
self.assertEqual(response.status, "route_to_cloud")
|
||||
self.assertIn("请确认一下", response.reply_text)
|
||||
|
||||
def test_route_to_cloud_rejects_when_planner_marks_request_out_of_scope(self) -> None:
|
||||
service = AgentService(
|
||||
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
|
||||
router=_RouteToCloudRouter(),
|
||||
plugins=PluginRegistry(),
|
||||
session_store=InMemorySessionStore(),
|
||||
planner=_PlannerOutOfScope(),
|
||||
)
|
||||
|
||||
response = service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_cloud_route_reject",
|
||||
user_id="user_1",
|
||||
input_text="去美团叫个外卖",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "reject")
|
||||
self.assertEqual(response.decision, "reject")
|
||||
self.assertEqual(response.status, "rejected")
|
||||
self.assertIn("做不了", response.reply_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
235
intelligent_cabin/archive/tests/test_bert.py
Normal file
235
intelligent_cabin/archive/tests/test_bert.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.bootstrap import build_intent_registry
|
||||
from app.services.classifier import BertIntentClassifier
|
||||
from app.services.router import build_matcher_pipeline
|
||||
|
||||
|
||||
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
|
||||
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}
|
||||
|
||||
|
||||
def resolve_model_path(model_path: str) -> Path:
|
||||
configured = (model_path or settings.classifier_model_path).strip()
|
||||
if configured:
|
||||
return Path(configured)
|
||||
return DEFAULT_MODEL_DIR
|
||||
|
||||
|
||||
def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path:
|
||||
configured = (label_map_path or settings.classifier_label_map_path).strip()
|
||||
if configured:
|
||||
return Path(configured)
|
||||
return model_path / "label_map.json"
|
||||
|
||||
|
||||
def build_classifier(
|
||||
*,
|
||||
model_path: Path,
|
||||
label_map_path: Path,
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
) -> BertIntentClassifier:
|
||||
return BertIntentClassifier(
|
||||
model_path=str(model_path),
|
||||
threshold=threshold,
|
||||
label_map_path=str(label_map_path),
|
||||
fallback=None,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def predict_once(
|
||||
text: str,
|
||||
*,
|
||||
model_path: Path,
|
||||
label_map_path: Path,
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
warmup: bool,
|
||||
) -> dict[str, object]:
|
||||
registry = build_intent_registry()
|
||||
classifier = build_classifier(
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=threshold,
|
||||
top_k=top_k,
|
||||
)
|
||||
warmup_ok = None
|
||||
if warmup:
|
||||
warmup_ok = classifier.warmup(settings.classifier_warmup_text)
|
||||
if not warmup_ok:
|
||||
error_message = getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed"
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
intents = registry.list()
|
||||
result = classifier.predict(text, intents)
|
||||
matcher = build_matcher_pipeline(
|
||||
registry,
|
||||
["classifier"],
|
||||
classifier=classifier,
|
||||
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
|
||||
clarify_margin_threshold=settings.local_clarify_margin_threshold,
|
||||
classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
|
||||
classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
|
||||
)
|
||||
route_result = matcher.match(text)
|
||||
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
|
||||
classifier_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "classifier"), None)
|
||||
|
||||
return {
|
||||
"text": text,
|
||||
"config": {
|
||||
"model_path": str(model_path),
|
||||
"label_map_path": str(label_map_path),
|
||||
"threshold": threshold,
|
||||
"top_k": top_k,
|
||||
"warmup_requested": warmup,
|
||||
"warmup_ok": warmup_ok,
|
||||
"warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None),
|
||||
"warmup_error_message": getattr(classifier, "_warmup_error_message", None),
|
||||
},
|
||||
"classifier_result": {
|
||||
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
|
||||
"score": round(result.score, 4),
|
||||
"model_name": result.model_name,
|
||||
"backend_name": result.backend_name,
|
||||
"used_fallback": result.used_fallback,
|
||||
"fallback_reason": result.fallback_reason,
|
||||
"error_message": result.error_message,
|
||||
"raw_label": result.raw_label,
|
||||
"raw_candidates": result.raw_candidates or [],
|
||||
"known_candidates": [
|
||||
{"intent_id": intent.intent_id, "score": round(score, 4)}
|
||||
for intent, score in (result.candidates or [])
|
||||
],
|
||||
},
|
||||
"route_result": {
|
||||
"decision": route_result.debug.decision,
|
||||
"decision_reason": route_result.debug.decision_reason,
|
||||
"matched_stage": route_result.debug.matched_stage,
|
||||
"selected_intent": route_result.debug.selected_intent,
|
||||
"confidence_grade": route_result.debug.confidence_grade,
|
||||
"unknown_detected": route_result.debug.unknown_detected,
|
||||
"classifier_score": round(classifier_stage.score, 4) if classifier_stage is not None else None,
|
||||
"fusion_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def summarize_business_view(result: dict[str, object]) -> dict[str, object]:
|
||||
classifier_result = dict(result.get("classifier_result") or {})
|
||||
route_result = dict(result.get("route_result") or {})
|
||||
predicted_intent = classifier_result.get("predicted_intent")
|
||||
raw_label = classifier_result.get("raw_label")
|
||||
effective_label = raw_label if raw_label in NON_BUSINESS_LABELS else predicted_intent
|
||||
if effective_label in NON_BUSINESS_LABELS:
|
||||
classifier_result["predicted_intent"] = None
|
||||
classifier_result["non_business_label"] = effective_label
|
||||
classifier_result["business_interpretation"] = "non_business_label_detected"
|
||||
route_result["selected_intent"] = None
|
||||
route_result["decision"] = "reject"
|
||||
route_result["decision_reason"] = "classifier detected a non-business label"
|
||||
route_result["unknown_detected"] = True
|
||||
else:
|
||||
classifier_result["non_business_label"] = None
|
||||
classifier_result["business_interpretation"] = "known_business_intent_or_uncertain"
|
||||
return {
|
||||
"text": result.get("text"),
|
||||
"config": result.get("config"),
|
||||
"classifier_result": classifier_result,
|
||||
"route_result": route_result,
|
||||
}
|
||||
|
||||
|
||||
def interactive_loop(
|
||||
*,
|
||||
model_path: Path,
|
||||
label_map_path: Path,
|
||||
threshold: float,
|
||||
top_k: int,
|
||||
warmup: bool,
|
||||
mode: str,
|
||||
) -> None:
|
||||
print("当前 BERT 测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
|
||||
print(f"model_path={model_path}")
|
||||
print(f"label_map_path={label_map_path}")
|
||||
print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}")
|
||||
while True:
|
||||
try:
|
||||
text = input("\n请输入问题> ").strip()
|
||||
except EOFError:
|
||||
print()
|
||||
break
|
||||
if not text:
|
||||
continue
|
||||
if text.lower() in {"exit", "quit", "q"}:
|
||||
break
|
||||
result = predict_once(
|
||||
text,
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=threshold,
|
||||
top_k=top_k,
|
||||
warmup=warmup,
|
||||
)
|
||||
if mode == "business":
|
||||
result = summarize_business_view(result)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="当前项目 BERT 意图识别测试脚本")
|
||||
parser.add_argument("--text", type=str, default="", help="单次测试文本")
|
||||
parser.add_argument("--threshold", type=float, default=settings.classifier_bert_threshold, help="BERT 置信度阈值")
|
||||
parser.add_argument("--top-k", type=int, default=settings.classifier_top_k, help="返回候选数量")
|
||||
parser.add_argument("--model-path", type=str, default="", help="模型目录,默认取 .env 或 models/local_bert_intent")
|
||||
parser.add_argument("--label-map-path", type=str, default="", help="标签映射文件,默认取 .env 或 model_path/label_map.json")
|
||||
parser.add_argument("--warmup", action="store_true", help="先执行一次 warmup 再预测")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=("classifier", "business"),
|
||||
default="classifier",
|
||||
help="classifier 显示原始分类结果;business 会把非业务标签折叠成未命中业务意图",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = resolve_model_path(args.model_path)
|
||||
label_map_path = resolve_label_map_path(args.label_map_path, model_path)
|
||||
|
||||
if args.text.strip():
|
||||
result = predict_once(
|
||||
args.text.strip(),
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=args.threshold,
|
||||
top_k=args.top_k,
|
||||
warmup=args.warmup,
|
||||
)
|
||||
if args.mode == "business":
|
||||
result = summarize_business_view(result)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
return
|
||||
|
||||
interactive_loop(
|
||||
model_path=model_path,
|
||||
label_map_path=label_map_path,
|
||||
threshold=args.threshold,
|
||||
top_k=args.top_k,
|
||||
warmup=args.warmup,
|
||||
mode=args.mode,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
109
intelligent_cabin/archive/tests/test_chat_stream.py
Normal file
109
intelligent_cabin/archive/tests/test_chat_stream.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
os.environ["AGENT_CLASSIFIER_BACKEND"] = "mock"
|
||||
os.environ["AGENT_CLASSIFIER_WARMUP_ENABLED"] = "false"
|
||||
|
||||
from app.main import app
|
||||
from app.schemas.chat import ChatResponse
|
||||
|
||||
|
||||
def _fake_response() -> ChatResponse:
|
||||
return ChatResponse(
|
||||
session_id="sess_stream_1",
|
||||
reply_type="workflow_result",
|
||||
reply_text="好,空调已经打开了。",
|
||||
intent="cabin_ac_on",
|
||||
status="completed",
|
||||
trace_id="trace_stream_1",
|
||||
)
|
||||
|
||||
|
||||
class ChatStreamTests(unittest.TestCase):
|
||||
def test_chat_stream_returns_final_only_when_fast(self) -> None:
|
||||
client = TestClient(app)
|
||||
with patch("app.main.agent_service.handle_chat", return_value=_fake_response()):
|
||||
response = client.post(
|
||||
"/api/v1/agent/chat-stream",
|
||||
json={
|
||||
"session_id": "sess_stream_1",
|
||||
"user_id": "user_stream_1",
|
||||
"channel": "test",
|
||||
"input_text": "打开车窗",
|
||||
"input_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
|
||||
self.assertEqual(len(lines), 1)
|
||||
final_event = json.loads(lines[0])
|
||||
self.assertEqual(final_event.get("type"), "final")
|
||||
|
||||
def test_chat_stream_returns_ack_then_final_when_slow_request(self) -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
def _slow_handle_chat(_request):
|
||||
time.sleep(1.2)
|
||||
return _fake_response()
|
||||
|
||||
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
|
||||
response = client.post(
|
||||
"/api/v1/agent/chat-stream",
|
||||
json={
|
||||
"session_id": "sess_stream_1",
|
||||
"user_id": "user_stream_1",
|
||||
"channel": "test",
|
||||
"input_text": "打开车窗",
|
||||
"input_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
|
||||
self.assertGreaterEqual(len(lines), 2)
|
||||
|
||||
ack_event = json.loads(lines[0])
|
||||
final_event = json.loads(lines[-1])
|
||||
self.assertEqual(ack_event.get("type"), "ack")
|
||||
self.assertEqual(final_event.get("type"), "final")
|
||||
self.assertIn("data", final_event)
|
||||
self.assertIn("reply_text", final_event["data"])
|
||||
|
||||
def test_chat_stream_returns_ack_then_final_when_slow_social_request(self) -> None:
|
||||
client = TestClient(app)
|
||||
|
||||
def _slow_handle_chat(_request):
|
||||
time.sleep(1.2)
|
||||
return _fake_response()
|
||||
|
||||
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
|
||||
response = client.post(
|
||||
"/api/v1/agent/chat-stream",
|
||||
json={
|
||||
"session_id": "sess_stream_1",
|
||||
"user_id": "user_stream_1",
|
||||
"channel": "test",
|
||||
"input_text": "今天天气如何",
|
||||
"input_type": "text",
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
|
||||
self.assertGreaterEqual(len(lines), 2)
|
||||
ack_event = json.loads(lines[0])
|
||||
final_event = json.loads(lines[-1])
|
||||
self.assertEqual(ack_event.get("type"), "ack")
|
||||
self.assertEqual(final_event.get("type"), "final")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
90
intelligent_cabin/archive/tests/test_config_loader.py
Normal file
90
intelligent_cabin/archive/tests/test_config_loader.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.core.bootstrap import build_planner, load_runtime_bundle
|
||||
from app.core.config import settings
|
||||
from app.services.planner import CompositeWorkflowPlanner
|
||||
from app.services.config_loader import ConfigLoader
|
||||
from app.services.dialog_rules import DialogRuleEngine
|
||||
from app.services.response_policy import ResponsePolicy
|
||||
|
||||
|
||||
class ConfigLoaderTests(unittest.TestCase):
|
||||
def test_loader_reads_domain_actions_and_responses(self) -> None:
|
||||
bundle = ConfigLoader(
|
||||
domain_path="config/domain.yml",
|
||||
action_path="config/actions.yml",
|
||||
response_path="config/responses.yml",
|
||||
form_path="config/forms.yml",
|
||||
rule_path="config/rules.yml",
|
||||
dialog_act_path="config/dialog_acts.yml",
|
||||
workflow_path="config/workflows.yml",
|
||||
legacy_intent_path="app/data/intents.json",
|
||||
).load()
|
||||
|
||||
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
|
||||
self.assertEqual(bundle.intent_registry.get("cabin_window_open").plugin_id, "plugin.cabin.window.open")
|
||||
self.assertEqual(bundle.intent_hints.get("cabin_window_open"), "打开车窗")
|
||||
self.assertEqual(bundle.response_templates.get("task_stopped"), "好的,已停止当前任务。")
|
||||
self.assertEqual(bundle.intent_registry.get("cabin_set_ac").required_slots, ["temperature"])
|
||||
self.assertTrue(bundle.dialog_rules.is_stop_request("先不要了"))
|
||||
self.assertEqual(bundle.dialog_rules.parse_confirmation_decision("确认"), True)
|
||||
self.assertEqual(bundle.dialog_act_engine.detect("你好"), "chitchat")
|
||||
self.assertGreaterEqual(len(bundle.workflow_templates.templates), 2)
|
||||
|
||||
def test_bootstrap_runtime_bundle_is_available(self) -> None:
|
||||
bundle = load_runtime_bundle()
|
||||
|
||||
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
|
||||
self.assertIn("fallback", bundle.response_templates)
|
||||
self.assertEqual(bundle.dialog_act_engine.detect("确认"), "affirm")
|
||||
|
||||
def test_response_policy_can_be_driven_by_config_templates(self) -> None:
|
||||
policy = ResponsePolicy(
|
||||
templates={"reject": "这个能力暂未开通。"},
|
||||
intent_hints={"cabin_window_open": "开车窗"},
|
||||
)
|
||||
|
||||
self.assertEqual(policy.reject(), "这个能力暂未开通。")
|
||||
self.assertEqual(policy.clarify(["cabin_window_open"]), "请确认一下,你是想开车窗吗?")
|
||||
|
||||
def test_response_policy_formats_multi_step_summary_naturally(self) -> None:
|
||||
policy = ResponsePolicy()
|
||||
|
||||
summary = policy.workflow_summary(["好的,已打开空调。", "已将空调调到 20 度。"])
|
||||
|
||||
self.assertEqual(summary, "好,空调已经打开了,也调到 20 度了。")
|
||||
|
||||
def test_response_policy_formats_multi_step_summary_in_vehicle_style(self) -> None:
|
||||
policy = ResponsePolicy()
|
||||
|
||||
summary = policy.workflow_summary(["好的,已打开空调。", "好的,已关闭车窗。"])
|
||||
|
||||
self.assertEqual(summary, "好,空调已经打开了,车窗也帮你关上了。")
|
||||
|
||||
def test_build_planner_prefers_local_planners_before_cloud(self) -> None:
|
||||
with patch.object(settings, "planner_backend", "dashscope"):
|
||||
planner = build_planner()
|
||||
|
||||
self.assertIsInstance(planner, CompositeWorkflowPlanner)
|
||||
self.assertIsInstance(planner._planners[0], CompositeWorkflowPlanner)
|
||||
|
||||
def test_dialog_rule_engine_supports_configured_confirmation_and_stop(self) -> None:
|
||||
rules = DialogRuleEngine(
|
||||
stop_phrases=("先不用了",),
|
||||
positive_confirmation_tokens=("好,继续",),
|
||||
negative_confirmation_tokens=("取消吧",),
|
||||
confirmation_required_intents=("foo",),
|
||||
confirmation_required_risk_levels=("high",),
|
||||
)
|
||||
|
||||
self.assertTrue(rules.is_stop_request("先不用了"))
|
||||
self.assertTrue(rules.parse_confirmation_decision("好,继续"))
|
||||
self.assertFalse(rules.parse_confirmation_decision("取消吧"))
|
||||
self.assertTrue(rules.requires_confirmation("foo", "low"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,202 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.plugins.mock import MockPluginExecutor
|
||||
from app.schemas.chat import ChatRequest
|
||||
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
|
||||
from app.schemas.workflow import Workflow, WorkflowStep
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.planner import HeuristicWorkflowPlanner
|
||||
from app.services.response_policy import ResponsePolicy
|
||||
from app.services.rewrite_engine import ContextRewriteEngine
|
||||
from app.services.router import RouteMatchResult
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
|
||||
|
||||
class _FailIfCalledPlanner:
|
||||
def plan(self, text, intents, context=None):
|
||||
_ = (intents, context)
|
||||
raise AssertionError(f"planner should not be called for single intent request: {text}")
|
||||
|
||||
|
||||
class _ScriptedRouter:
|
||||
def __init__(self, registry: IntentRegistry) -> None:
|
||||
self._registry = registry
|
||||
self._route_map = {
|
||||
"来点music": self._route_result("cabin_play_music", ["cabin_play_music"]),
|
||||
"打开车窗和空调": self._route_result("cabin_window_open", ["cabin_window_open", "cabin_ac_on"]),
|
||||
"关闭车窗": self._route_result("cabin_window_close", ["cabin_window_close", "cabin_window_open"]),
|
||||
}
|
||||
self._slot_map = {
|
||||
("播放黄昏", "cabin_play_music"): {"song": "黄昏"},
|
||||
("来一首黄昏", "cabin_play_music"): {"song": "黄昏"},
|
||||
("来点黄昏", "cabin_play_music"): {"song": "黄昏"},
|
||||
}
|
||||
|
||||
def route(self, text: str) -> RouteMatchResult:
|
||||
if text not in self._route_map:
|
||||
raise AssertionError(f"unexpected route request: {text}")
|
||||
return self._route_map[text]
|
||||
|
||||
def extract_slots(self, text: str, intent) -> dict[str, object]:
|
||||
return dict(self._slot_map.get((text, intent.intent_id), {}))
|
||||
|
||||
def _route_result(self, selected_intent: str, candidates: list[str]) -> RouteMatchResult:
|
||||
intent = self._registry.get(selected_intent)
|
||||
stage = MatcherStageDebug(
|
||||
stage="fusion",
|
||||
accepted=True,
|
||||
selected_intent=selected_intent,
|
||||
score=1.0,
|
||||
candidates=[
|
||||
IntentCandidate(intent_id=intent_id, score=max(0.5, 1.0 - index * 0.1))
|
||||
for index, intent_id in enumerate(candidates)
|
||||
],
|
||||
)
|
||||
return RouteMatchResult(
|
||||
intent=intent,
|
||||
debug=RoutingDebug(
|
||||
selected_intent=selected_intent,
|
||||
matched_stage="fusion",
|
||||
decision="execute",
|
||||
stages=[stage],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DialogContinuationAndMultiIntentTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
self.plugins = PluginRegistry()
|
||||
MockPluginExecutor().register(self.plugins)
|
||||
self.service = AgentService(
|
||||
intent_registry=self.registry,
|
||||
router=_ScriptedRouter(self.registry),
|
||||
plugins=self.plugins,
|
||||
session_store=InMemorySessionStore(),
|
||||
rewrite_engine=ContextRewriteEngine(),
|
||||
response_policy=ResponsePolicy(),
|
||||
planner=HeuristicWorkflowPlanner(),
|
||||
)
|
||||
|
||||
def test_music_followup_in_chat_continues_waiting_slot(self) -> None:
|
||||
first = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_music_followup",
|
||||
user_id="user_1",
|
||||
input_text="来点music",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(first.reply_type, "ask_slot")
|
||||
self.assertEqual(first.pending_slots, ["media_query"])
|
||||
|
||||
second = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_music_followup",
|
||||
user_id="user_1",
|
||||
input_text="黄昏",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(second.reply_type, "workflow_result")
|
||||
self.assertEqual(second.intent, "cabin_play_music")
|
||||
self.assertEqual(second.filled_slots.get("song"), "黄昏")
|
||||
self.assertIn("黄昏", second.reply_text)
|
||||
|
||||
def test_parallel_compound_request_enters_planner(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_parallel_compound",
|
||||
user_id="user_1",
|
||||
input_text="打开车窗和空调",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.workflow.workflow_type, "sequence")
|
||||
step_intents = [step.intent_id for step in response.workflow.steps]
|
||||
self.assertEqual(step_intents, ["cabin_window_open", "cabin_ac_on"])
|
||||
self.assertIn("车窗", response.reply_text)
|
||||
self.assertIn("空调", response.reply_text)
|
||||
|
||||
def test_single_cabin_intent_does_not_enter_planner_from_top2_domain_candidates(self) -> None:
|
||||
service = AgentService(
|
||||
intent_registry=self.registry,
|
||||
router=_ScriptedRouter(self.registry),
|
||||
plugins=self.plugins,
|
||||
session_store=InMemorySessionStore(),
|
||||
rewrite_engine=ContextRewriteEngine(),
|
||||
response_policy=ResponsePolicy(),
|
||||
planner=_FailIfCalledPlanner(),
|
||||
)
|
||||
|
||||
response = service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_single_cabin_no_planner",
|
||||
user_id="user_1",
|
||||
input_text="关闭车窗",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_window_close")
|
||||
self.assertEqual(response.routing_debug.decision, "execute")
|
||||
self.assertFalse(any(stage.stage == "planner" for stage in response.routing_debug.stages))
|
||||
|
||||
def test_waiting_confirmation_can_continue_via_chat(self) -> None:
|
||||
session = self.service.session_store.get_or_create("sess_confirm_chat", "user_1")
|
||||
session.current_intent = "cs_cancel_order"
|
||||
session.status = "waiting_confirmation"
|
||||
session.pending_slots = ["confirmation"]
|
||||
session.slots = {"order_id": "A123456"}
|
||||
session.workflow = Workflow(
|
||||
workflow_id="wf_confirm_chat",
|
||||
workflow_type="conditional",
|
||||
domain="customer_service",
|
||||
intent_id="cs_cancel_order",
|
||||
status="waiting_confirmation",
|
||||
risk_level="high",
|
||||
slots={"order_id": "A123456"},
|
||||
steps=[
|
||||
WorkflowStep(
|
||||
step=1,
|
||||
step_id="step_cancel",
|
||||
intent_id="cs_cancel_order",
|
||||
plugin_id="plugin.order.cancel",
|
||||
action="cancel_order",
|
||||
status="waiting_confirmation",
|
||||
slots={"order_id": "A123456"},
|
||||
requires_confirmation=True,
|
||||
)
|
||||
],
|
||||
meta={
|
||||
"pending_confirmation": {
|
||||
"step_id": "step_cancel",
|
||||
"intent_id": "cs_cancel_order",
|
||||
"detail": "确认取消订单 A123456",
|
||||
},
|
||||
"step_results": {},
|
||||
"confirmed_steps": [],
|
||||
},
|
||||
).model_dump()
|
||||
self.service.session_store.save(session)
|
||||
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_confirm_chat",
|
||||
user_id="user_1",
|
||||
input_text="确认",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.status, "completed")
|
||||
self.assertIn("A123456", response.reply_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
149
intelligent_cabin/archive/tests/test_intent_coverage_and_stop.py
Normal file
149
intelligent_cabin/archive/tests/test_intent_coverage_and_stop.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.plugins.mock import MockPluginExecutor
|
||||
from app.services.classifier import MockIntentClassifier
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.response_policy import ResponsePolicy
|
||||
from app.services.rewrite_engine import ContextRewriteEngine
|
||||
from app.services.router import HeuristicSlotExtractor, IntentRouter, build_matcher_pipeline
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
from app.schemas.chat import ChatRequest, FillSlotsRequest
|
||||
|
||||
|
||||
class _BertLikeMockClassifier(MockIntentClassifier):
|
||||
def predict(self, text, intents):
|
||||
result = super().predict(text, intents)
|
||||
result.model_name = "bert-local"
|
||||
result.backend_name = "bert-local"
|
||||
return result
|
||||
|
||||
|
||||
def _build_service() -> AgentService:
|
||||
registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
plugins = PluginRegistry()
|
||||
MockPluginExecutor().register(plugins)
|
||||
router = IntentRouter(
|
||||
matcher=build_matcher_pipeline(
|
||||
registry,
|
||||
["classifier"],
|
||||
classifier=_BertLikeMockClassifier(threshold=0.0, top_k=3),
|
||||
),
|
||||
slot_extractor=HeuristicSlotExtractor(),
|
||||
)
|
||||
return AgentService(
|
||||
intent_registry=registry,
|
||||
router=router,
|
||||
plugins=plugins,
|
||||
session_store=InMemorySessionStore(),
|
||||
rewrite_engine=ContextRewriteEngine(),
|
||||
response_policy=ResponsePolicy(),
|
||||
planner=None,
|
||||
)
|
||||
|
||||
|
||||
class IntentCoverageAndStopTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.service = _build_service()
|
||||
self.registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
|
||||
def test_intent_catalog_has_at_least_30_items(self) -> None:
|
||||
self.assertGreaterEqual(len(self.registry.list()), 30)
|
||||
|
||||
def test_close_ac_routes_to_power_off(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_close_ac",
|
||||
user_id="user_1",
|
||||
input_text="关闭空调",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_ac_off")
|
||||
self.assertIn("已关闭空调", response.reply_text)
|
||||
|
||||
def test_open_window_is_covered(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_window_open",
|
||||
user_id="user_1",
|
||||
input_text="打开车窗",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_window_open")
|
||||
self.assertIn("已打开车窗", response.reply_text)
|
||||
|
||||
def test_stop_current_task_while_waiting_for_slot(self) -> None:
|
||||
first = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_stop_task",
|
||||
user_id="user_1",
|
||||
input_text="空调调到",
|
||||
)
|
||||
)
|
||||
self.assertEqual(first.reply_type, "ask_slot")
|
||||
self.assertEqual(first.pending_slots, ["temperature"])
|
||||
|
||||
stopped = self.service.handle_fill_slots(
|
||||
FillSlotsRequest(
|
||||
session_id="sess_stop_task",
|
||||
user_id="user_1",
|
||||
input_text="不用了",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(stopped.reply_type, "text")
|
||||
self.assertEqual(stopped.status, "stopped")
|
||||
self.assertEqual(stopped.pending_slots, [])
|
||||
self.assertIn("已停止当前任务", stopped.reply_text)
|
||||
|
||||
def test_relative_ac_adjustment_uses_two_degree_step(self) -> None:
|
||||
session = self.service.session_store.get_or_create("sess_ac_lower", "user_1")
|
||||
session.current_intent = "cabin_set_ac"
|
||||
session.context_memory["last_temperature"] = 24
|
||||
self.service.session_store.save(session)
|
||||
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_ac_lower",
|
||||
user_id="user_1",
|
||||
input_text="调低一点",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_set_ac")
|
||||
self.assertEqual(response.filled_slots.get("temperature"), 22)
|
||||
self.assertIn("22", response.reply_text)
|
||||
|
||||
def test_relative_ac_adjustment_without_history_uses_default_baseline(self) -> None:
|
||||
session = self.service.session_store.get_or_create("sess_ac_lower_default", "user_1")
|
||||
session.current_intent = "cabin_ac_on"
|
||||
self.service.session_store.save(session)
|
||||
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_ac_lower_default",
|
||||
user_id="user_1",
|
||||
input_text="调低一点",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.reply_type, "workflow_result")
|
||||
self.assertEqual(response.intent, "cabin_set_ac")
|
||||
self.assertEqual(response.filled_slots.get("temperature"), 22)
|
||||
self.assertIn("22", response.reply_text)
|
||||
|
||||
def test_temperature_is_clamped_before_execution(self) -> None:
|
||||
self.assertEqual(self.service._normalize_temperature_value(-1), 16)
|
||||
self.assertEqual(self.service._normalize_temperature_value(40), 30)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.classifier import JointBertIntentClassifier
|
||||
from app.services.planner import HeuristicWorkflowPlanner
|
||||
from app.services.router import JointBertSlotExtractor
|
||||
|
||||
|
||||
class FakeJointNLU:
|
||||
def __init__(self) -> None:
|
||||
self._predictions = {
|
||||
"把空调调到22度": {
|
||||
"intent_id": "cabin_set_ac",
|
||||
"intent_score": 0.93,
|
||||
"candidates": [("cabin_set_ac", 0.93), ("cabin_ac_on", 0.04)],
|
||||
"slots": {"temperature": 22},
|
||||
},
|
||||
"导航去公司,然后把空调调到22度": {
|
||||
"intent_id": "cabin_nav_to",
|
||||
"intent_score": 0.88,
|
||||
"candidates": [("cabin_nav_to", 0.88), ("cabin_set_ac", 0.72)],
|
||||
"slots": {"destination": "公司"},
|
||||
},
|
||||
}
|
||||
self._slot_predictions = {
|
||||
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
|
||||
("导航去公司", "cabin_nav_to"): {"destination": "公司"},
|
||||
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
|
||||
}
|
||||
|
||||
def warmup(self, sample_text: str = "打开车窗") -> bool:
|
||||
_ = sample_text
|
||||
return True
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]):
|
||||
from app.services.joint_nlu import JointCandidate, JointNluResult
|
||||
|
||||
raw = self._predictions[text]
|
||||
candidates = [JointCandidate(intent_id=intent_id, score=score) for intent_id, score in raw["candidates"]]
|
||||
return JointNluResult(
|
||||
intent_id=raw["intent_id"],
|
||||
intent_score=raw["intent_score"],
|
||||
candidates=candidates,
|
||||
slots=dict(raw["slots"]),
|
||||
)
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition):
|
||||
return dict(self._slot_predictions.get((text, intent.intent_id), {}))
|
||||
|
||||
def extract_slots_by_intent_id(self, text: str, intent_id: str, required_slots=None):
|
||||
_ = required_slots
|
||||
return dict(self._slot_predictions.get((text, intent_id), {}))
|
||||
|
||||
|
||||
class JointNLUIntegrationTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.intents = [
|
||||
IntentDefinition(intent_id="cabin_set_ac", plugin_id="x", domain="cabin", required_slots=["temperature"]),
|
||||
IntentDefinition(intent_id="cabin_nav_to", plugin_id="x", domain="cabin", required_slots=["destination"]),
|
||||
IntentDefinition(intent_id="cabin_ac_on", plugin_id="x", domain="cabin"),
|
||||
]
|
||||
self.fake_nlu = FakeJointNLU()
|
||||
|
||||
def test_joint_classifier_uses_joint_nlu_intent_head(self) -> None:
|
||||
classifier = JointBertIntentClassifier(self.fake_nlu, threshold=0.3, top_k=2)
|
||||
|
||||
result = classifier.predict("把空调调到22度", self.intents)
|
||||
|
||||
self.assertIsNotNone(result.intent)
|
||||
self.assertEqual(result.intent.intent_id, "cabin_set_ac")
|
||||
self.assertEqual(result.raw_candidates[0]["intent_id"], "cabin_set_ac")
|
||||
|
||||
def test_joint_slot_extractor_uses_joint_nlu_slots(self) -> None:
|
||||
extractor = JointBertSlotExtractor(self.fake_nlu)
|
||||
|
||||
slots = extractor.extract("把空调调到22度", self.intents[0])
|
||||
|
||||
self.assertEqual(slots, {"temperature": 22})
|
||||
|
||||
def test_planner_prefers_joint_nlu_slots_for_each_clause(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner(joint_nlu=self.fake_nlu)
|
||||
|
||||
result = planner.plan("导航去公司,然后把空调调到22度", self.intents)
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.steps[0].slots, {"destination": "公司"})
|
||||
self.assertEqual(result.steps[1].slots, {"temperature": 22})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
144
intelligent_cabin/archive/tests/test_multi_intent_detector.py
Normal file
144
intelligent_cabin/archive/tests/test_multi_intent_detector.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
def __call__(self, text, truncation=True, padding=False, return_tensors="pt"):
|
||||
_ = (text, truncation, padding, return_tensors)
|
||||
return {
|
||||
"input_ids": torch.tensor([[101, 102]], dtype=torch.long),
|
||||
"attention_mask": torch.tensor([[1, 1]], dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
class FakeModel:
|
||||
def __init__(self, logits: list[float], id2label: dict[int, str]) -> None:
|
||||
self.config = type("Config", (), {"id2label": id2label})()
|
||||
self._logits = torch.tensor([logits], dtype=torch.float32)
|
||||
|
||||
def eval(self) -> None:
|
||||
return None
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
_ = kwargs
|
||||
return type("Output", (), {"logits": self._logits})()
|
||||
|
||||
|
||||
class RuntimeBackedDetector(BertMultiIntentDetector):
|
||||
def __init__(
|
||||
self,
|
||||
logits: list[float],
|
||||
id2label: dict[int, str],
|
||||
threshold: float = 0.45,
|
||||
top_k: int = 8,
|
||||
max_labels: int = 4,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model_path="unused",
|
||||
threshold=threshold,
|
||||
top_k=top_k,
|
||||
max_labels=max_labels,
|
||||
)
|
||||
self._runtime = (torch, FakeTokenizer(), FakeModel(logits, id2label))
|
||||
|
||||
def _get_runtime(self):
|
||||
return self._runtime
|
||||
|
||||
|
||||
class MultiIntentDetectorTests(unittest.TestCase):
|
||||
def test_detector_filters_blocked_and_unknown_labels(self) -> None:
|
||||
detector = RuntimeBackedDetector(
|
||||
logits=[2.4, 2.0, 3.2, 2.6],
|
||||
id2label={
|
||||
0: "cabin_window_open",
|
||||
1: "cabin_play_music",
|
||||
2: "__social__",
|
||||
3: "unknown_intent",
|
||||
},
|
||||
threshold=0.8,
|
||||
top_k=4,
|
||||
)
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
]
|
||||
|
||||
result = detector.detect("打开车窗并播放音乐", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(result.backend_name, "bert-multi-label")
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_detector_respects_threshold_and_max_labels(self) -> None:
|
||||
detector = RuntimeBackedDetector(
|
||||
logits=[2.8, 2.5, 2.2],
|
||||
id2label={
|
||||
0: "cabin_window_open",
|
||||
1: "cabin_play_music",
|
||||
2: "cabin_nav_to",
|
||||
},
|
||||
threshold=0.89,
|
||||
top_k=3,
|
||||
max_labels=2,
|
||||
)
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_nav_to", plugin_id="plugin.nav", domain="cabin"),
|
||||
]
|
||||
|
||||
result = detector.detect("开窗放歌去公司", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(len(result.candidates), 2)
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_joint_bert_detector_wraps_shared_runtime(self) -> None:
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
]
|
||||
|
||||
class FakeJointNlu:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, object]] = []
|
||||
|
||||
def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8):
|
||||
self.calls.append(
|
||||
{
|
||||
"text": text,
|
||||
"threshold": threshold,
|
||||
"max_labels": max_labels,
|
||||
"top_k": top_k,
|
||||
"known_count": len(known_intents),
|
||||
}
|
||||
)
|
||||
return [
|
||||
type("Candidate", (), {"intent_id": "cabin_window_open", "score": 0.93})(),
|
||||
type("Candidate", (), {"intent_id": "cabin_play_music", "score": 0.88})(),
|
||||
]
|
||||
|
||||
def warmup(self, sample_text="") -> bool:
|
||||
_ = sample_text
|
||||
return True
|
||||
|
||||
fake_nlu = FakeJointNlu()
|
||||
detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3)
|
||||
|
||||
result = detector.detect("打开车窗并播放音乐", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(result.backend_name, "joint-bert-multi-label")
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
self.assertEqual(fake_nlu.calls[0]["threshold"], 0.5)
|
||||
self.assertEqual(fake_nlu.calls[0]["top_k"], 6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
195
intelligent_cabin/archive/tests/test_router_decisions.py
Normal file
195
intelligent_cabin/archive/tests/test_router_decisions.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.schemas.debug import IntentCandidate, MatcherStageDebug
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.router import IntentMatchResult, MultiStageIntentMatcher
|
||||
|
||||
|
||||
class _FakeMatcher:
|
||||
def __init__(self, stage_debug: MatcherStageDebug) -> None:
|
||||
self._stage_debug = stage_debug
|
||||
|
||||
def match(self, text: str) -> IntentMatchResult:
|
||||
_ = text
|
||||
return IntentMatchResult(intent=None, stage_debug=self._stage_debug)
|
||||
|
||||
|
||||
def _intent(intent_id: str) -> IntentDefinition:
|
||||
return IntentDefinition(
|
||||
intent_id=intent_id,
|
||||
plugin_id=f"mock.{intent_id}",
|
||||
domain="test",
|
||||
keywords=[],
|
||||
examples=[],
|
||||
)
|
||||
|
||||
|
||||
class RouterDecisionTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.registry = IntentRegistry([_intent("alpha"), _intent("beta"), _intent("gamma")])
|
||||
|
||||
def test_execute_when_bert_classifier_is_clear(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.92,
|
||||
reason="classifier selected best candidate",
|
||||
backend="joint-bert-local",
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="joint-bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.21, reason="classifier", model_name="joint-bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("alpha")
|
||||
|
||||
self.assertEqual(result.debug.decision, "execute")
|
||||
self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")
|
||||
|
||||
def test_clarify_when_bert_top_candidates_are_too_close(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.22,
|
||||
reason="classifier selected best candidate",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.2},
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.28, reason="classifier", model_name="bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
route_to_cloud_threshold=0.2,
|
||||
)
|
||||
|
||||
result = matcher.match("ambiguous request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "clarify")
|
||||
self.assertIsNone(result.intent)
|
||||
self.assertEqual(result.debug.confidence_grade, "medium")
|
||||
|
||||
def test_route_to_cloud_when_bert_signal_is_weak_but_known(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=False,
|
||||
selected_intent="alpha",
|
||||
score=0.29,
|
||||
reason="classifier below execute threshold",
|
||||
backend="joint-bert-local",
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.29, reason="classifier", model_name="joint-bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.14, reason="classifier", model_name="joint-bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("weak symbolic request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "route_to_cloud")
|
||||
self.assertIsNone(result.intent)
|
||||
|
||||
def test_reject_when_no_branch_has_usable_signal(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=False,
|
||||
score=0.12,
|
||||
reason="classifier below threshold",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.2},
|
||||
candidates=[],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("unknown request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "reject")
|
||||
self.assertTrue(result.debug.unknown_detected)
|
||||
self.assertIsNone(result.intent)
|
||||
|
||||
def test_route_to_cloud_for_low_confidence_classifier_only_bert_signal(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.31,
|
||||
reason="classifier selected best candidate",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.0, "top_margin": 0.04},
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.27, reason="classifier", model_name="bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("bert only weak request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "route_to_cloud")
|
||||
self.assertIsNone(result.intent)
|
||||
|
||||
def test_execute_for_high_confidence_classifier_only_bert_signal(self) -> None:
|
||||
matcher = MultiStageIntentMatcher(
|
||||
registry=self.registry,
|
||||
matchers=[
|
||||
_FakeMatcher(
|
||||
MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent="alpha",
|
||||
score=0.92,
|
||||
reason="classifier selected best candidate",
|
||||
backend="bert-local",
|
||||
metadata={"threshold": 0.0, "top_margin": 0.63},
|
||||
candidates=[
|
||||
IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="bert-local"),
|
||||
IntentCandidate(intent_id="beta", score=0.29, reason="classifier", model_name="bert-local"),
|
||||
],
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = matcher.match("bert only strong request")
|
||||
|
||||
self.assertEqual(result.debug.decision, "execute")
|
||||
self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
180
intelligent_cabin/archive/tests/test_social_chat.py
Normal file
180
intelligent_cabin/archive/tests/test_social_chat.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.plugins.base import PluginRegistry
|
||||
from app.schemas.chat import ChatRequest, FillSlotsRequest
|
||||
from app.schemas.debug import RoutingDebug
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.schemas.workflow import Workflow, WorkflowStep
|
||||
from app.services.agent_service import AgentService
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.session_store import InMemorySessionStore
|
||||
from app.services.social import SocialReplyResult, SocialRouter
|
||||
|
||||
|
||||
class _FailingRouter:
|
||||
def route(self, text: str): # pragma: no cover - should not be called in these tests
|
||||
raise AssertionError(f"router should not be called for social input: {text}")
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
|
||||
_ = (text, intent)
|
||||
return {}
|
||||
|
||||
|
||||
class _FakeSocialResponder:
|
||||
def reply(self, text: str, session) -> SocialReplyResult:
|
||||
_ = (text, session)
|
||||
normalized = text.strip()
|
||||
if "你好" in normalized:
|
||||
text = "你好呀,我在,想聊什么都可以。"
|
||||
elif "名字" in normalized or "你是谁" in normalized:
|
||||
text = "我是一名智能座舱助手,你可以直接叫我座舱助手。"
|
||||
elif "天气" in normalized:
|
||||
text = "是啊,今天确实挺舒服的。"
|
||||
else:
|
||||
text = "我在,咱们继续聊。"
|
||||
return SocialReplyResult(
|
||||
text=text,
|
||||
backend="fake-cloud",
|
||||
model_name="fake-social",
|
||||
)
|
||||
|
||||
|
||||
def _intent(intent_id: str, plugin_id: str) -> IntentDefinition:
|
||||
return IntentDefinition(
|
||||
intent_id=intent_id,
|
||||
plugin_id=plugin_id,
|
||||
domain="service" if intent_id.startswith("cs_") else "cabin",
|
||||
risk_level="high" if intent_id == "cs_cancel_order" else "low",
|
||||
required_slots=["order_id"] if intent_id == "cs_cancel_order" else [],
|
||||
ask_templates={"order_id": "请告诉我订单号。"} if intent_id == "cs_cancel_order" else {},
|
||||
keywords=[],
|
||||
examples=[],
|
||||
)
|
||||
|
||||
|
||||
class SocialChatTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.session_store = InMemorySessionStore()
|
||||
self.plugins = PluginRegistry()
|
||||
self.plugins.register(
|
||||
"mock.cancel_order",
|
||||
lambda slots: {"success": True, "message": f"已取消订单 {slots.get('order_id', '')}。"},
|
||||
)
|
||||
self.service = AgentService(
|
||||
intent_registry=IntentRegistry([_intent("cs_cancel_order", "mock.cancel_order")]),
|
||||
router=_FailingRouter(),
|
||||
plugins=self.plugins,
|
||||
session_store=self.session_store,
|
||||
social_router=SocialRouter(),
|
||||
social_responder=_FakeSocialResponder(),
|
||||
)
|
||||
|
||||
def test_greeting_social_reply_uses_social_responder(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_social_hi",
|
||||
user_id="user_1",
|
||||
input_text="你好",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "open_social")
|
||||
self.assertEqual(response.status, "social")
|
||||
self.assertIn("你好呀", response.reply_text)
|
||||
|
||||
def test_capability_social_question_does_not_fall_into_business_intent(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_social_name",
|
||||
user_id="user_1",
|
||||
input_text="你叫什么名字",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "open_social")
|
||||
self.assertEqual(response.status, "social")
|
||||
self.assertNotEqual(response.reply_type, "ask_slot")
|
||||
self.assertIn("智能座舱助手", response.reply_text)
|
||||
|
||||
def test_open_social_reply_uses_social_responder(self) -> None:
|
||||
response = self.service.handle_chat(
|
||||
ChatRequest(
|
||||
session_id="sess_social_open",
|
||||
user_id="user_1",
|
||||
input_text="今天天气真不错啊",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(response.decision, "open_social")
|
||||
self.assertEqual(response.status, "social")
|
||||
self.assertIn("挺舒服", response.reply_text)
|
||||
|
||||
def test_social_turn_does_not_break_waiting_confirmation(self) -> None:
|
||||
session = self.session_store.get_or_create("sess_confirm", "user_1")
|
||||
session.current_intent = "cs_cancel_order"
|
||||
session.status = "waiting_confirmation"
|
||||
session.pending_slots = ["confirmation"]
|
||||
session.slots = {"order_id": "A123456"}
|
||||
session.routing_debug = RoutingDebug(selected_intent="cs_cancel_order", decision="execute").model_dump()
|
||||
session.workflow = Workflow(
|
||||
workflow_id="wf_sess_confirm",
|
||||
workflow_type="conditional",
|
||||
domain="service",
|
||||
intent_id="cs_cancel_order",
|
||||
status="waiting_confirmation",
|
||||
risk_level="high",
|
||||
slots={"order_id": "A123456"},
|
||||
steps=[
|
||||
WorkflowStep(
|
||||
step=1,
|
||||
step_id="step_confirm",
|
||||
intent_id="cs_cancel_order",
|
||||
plugin_id="mock.cancel_order",
|
||||
action="cancel_order",
|
||||
status="waiting_confirmation",
|
||||
slots={"order_id": "A123456"},
|
||||
requires_confirmation=True,
|
||||
)
|
||||
],
|
||||
meta={
|
||||
"pending_confirmation": {
|
||||
"step_id": "step_confirm",
|
||||
"intent_id": "cs_cancel_order",
|
||||
"detail": "确认取消订单 A123456",
|
||||
},
|
||||
"confirmed_steps": [],
|
||||
"step_results": {},
|
||||
},
|
||||
).model_dump()
|
||||
self.session_store.save(session)
|
||||
|
||||
social_response = self.service.handle_fill_slots(
|
||||
FillSlotsRequest(
|
||||
session_id="sess_confirm",
|
||||
user_id="user_1",
|
||||
input_text="今天天气真不错啊",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(social_response.decision, "open_social")
|
||||
self.assertEqual(social_response.status, "waiting_confirmation")
|
||||
self.assertEqual(social_response.pending_slots, ["confirmation"])
|
||||
self.assertIn("回复“确认”或“取消”即可", social_response.reply_text)
|
||||
|
||||
confirm_response = self.service.handle_fill_slots(
|
||||
FillSlotsRequest(
|
||||
session_id="sess_confirm",
|
||||
user_id="user_1",
|
||||
input_text="确认",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(confirm_response.reply_type, "workflow_result")
|
||||
self.assertEqual(confirm_response.status, "completed")
|
||||
self.assertIn("已取消订单", confirm_response.reply_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
184
intelligent_cabin/archive/tests/test_workflow_templates.py
Normal file
184
intelligent_cabin/archive/tests/test_workflow_templates.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from app.schemas.configuration import WorkflowTemplatesConfig
|
||||
from app.services.classifier import ClassificationResult
|
||||
from app.services.multi_intent_detector import MultiIntentCandidate, MultiIntentDetectionResult
|
||||
from app.services.planner import CompositeWorkflowPlanner, HeuristicWorkflowPlanner, TemplateWorkflowPlanner
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
|
||||
|
||||
class FakeClauseClassifier:
|
||||
def __init__(self, predictions: dict[str, list[dict[str, float | str]]]) -> None:
|
||||
self._predictions = predictions
|
||||
|
||||
def predict(self, text, intents):
|
||||
_ = intents
|
||||
return ClassificationResult(
|
||||
intent=None,
|
||||
score=0.0,
|
||||
model_name="fake-bert-clause",
|
||||
backend_name="fake-bert-clause",
|
||||
raw_candidates=self._predictions.get(text, []),
|
||||
)
|
||||
|
||||
|
||||
class FakeMultiIntentDetector:
|
||||
def __init__(self, predictions: dict[str, list[tuple[str, float]]]) -> None:
|
||||
self._predictions = predictions
|
||||
|
||||
def detect(self, text, intents):
|
||||
_ = intents
|
||||
candidates = [
|
||||
MultiIntentCandidate(intent_id=intent_id, score=score, label=intent_id)
|
||||
for intent_id, score in self._predictions.get(text, [])
|
||||
]
|
||||
return MultiIntentDetectionResult(
|
||||
detected=len(candidates) >= 2,
|
||||
candidates=candidates,
|
||||
reason="fake detector",
|
||||
backend_name="fake-multi",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowTemplateTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.registry = IntentRegistry.from_json("app/data/intents.json")
|
||||
self.templates = WorkflowTemplatesConfig.model_validate_json(
|
||||
Path("config/workflows.yml").read_text(encoding="utf-8")
|
||||
)
|
||||
|
||||
def test_template_planner_matches_sequence_template(self) -> None:
|
||||
planner = TemplateWorkflowPlanner(self.templates)
|
||||
|
||||
result = planner.plan("打开车窗,然后把空调调到20度", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.backend, "local-template")
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_set_ac"])
|
||||
self.assertEqual(result.steps[1].slots.get("temperature"), 20)
|
||||
|
||||
def test_template_planner_matches_conditional_template(self) -> None:
|
||||
planner = TemplateWorkflowPlanner(self.templates)
|
||||
|
||||
result = planner.plan("查一下订单A123456,如果还没发货就取消", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "conditional")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cs_query_order", "cs_cancel_order"])
|
||||
self.assertEqual(result.steps[1].depends_on, [1])
|
||||
self.assertTrue(result.steps[1].requires_confirmation)
|
||||
|
||||
def test_composite_planner_falls_back_to_heuristic_when_template_misses(self) -> None:
|
||||
planner = CompositeWorkflowPlanner([TemplateWorkflowPlanner(self.templates), HeuristicWorkflowPlanner()])
|
||||
|
||||
result = planner.plan("打开车窗,并且播放轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertIn(result.backend, {"local-template", "local-heuristic"})
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
|
||||
def test_heuristic_planner_parses_ac_then_window_close_sequence(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开空调,再把窗户降下来", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.backend, "local-heuristic")
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_close"])
|
||||
|
||||
def test_planner_metadata_contains_clause_analysis(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开空调,然后打开车窗", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertTrue(result.metadata.get("multi_intent_detected"))
|
||||
clause_analysis = result.metadata.get("clause_analysis", [])
|
||||
self.assertEqual(len(clause_analysis), 2)
|
||||
self.assertEqual(clause_analysis[0].get("selected_intent_id"), "cabin_ac_on")
|
||||
self.assertEqual(clause_analysis[1].get("selected_intent_id"), "cabin_window_open")
|
||||
|
||||
def test_heuristic_planner_supports_shared_action_parallel_objects(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开空调和车窗", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_open"])
|
||||
|
||||
def test_heuristic_planner_supports_parallel_objects_with_suffix_action(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("把车窗和天窗打开", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_sunroof_open"])
|
||||
|
||||
def test_heuristic_planner_supports_parallel_clause_with_bing_connector(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner()
|
||||
|
||||
result = planner.plan("打开车窗并播放轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_heuristic_planner_can_use_clause_classifier_to_rescue_semantic_clause(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner(
|
||||
clause_classifier=FakeClauseClassifier(
|
||||
{
|
||||
"车里太闷了": [
|
||||
{"label": "cabin_window_open", "intent_id": "cabin_window_open", "score": 0.83},
|
||||
],
|
||||
"来点轻音乐": [
|
||||
{"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
result = planner.plan("车里太闷了,然后来点轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.workflow_type, "sequence")
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
|
||||
clause_analysis = result.metadata.get("clause_analysis", [])
|
||||
self.assertGreater(clause_analysis[0].get("candidates", [])[0].get("model_score", 0.0), 0.8)
|
||||
|
||||
def test_heuristic_planner_can_use_multi_intent_detector_prior(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner(
|
||||
clause_classifier=FakeClauseClassifier(
|
||||
{
|
||||
"来点轻音乐": [
|
||||
{"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
|
||||
],
|
||||
}
|
||||
),
|
||||
multi_intent_detector=FakeMultiIntentDetector(
|
||||
{
|
||||
"顺便开下车窗,再来点轻音乐": [
|
||||
("cabin_window_open", 0.87),
|
||||
("cabin_play_music", 0.82),
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
result = planner.plan("顺便开下车窗,再来点轻音乐", self.registry.list())
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
|
||||
detector_meta = result.metadata.get("multi_intent_detector") or {}
|
||||
self.assertTrue(detector_meta.get("detected"))
|
||||
self.assertEqual(len(detector_meta.get("candidates", [])), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user