Update project and configurations

This commit is contained in:
Zou-Seay
2026-06-11 16:28:00 +08:00
parent 12d3922091
commit a29a91867d
237 changed files with 164880 additions and 90 deletions

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import unittest
from app.plugins.base import PluginRegistry
from app.schemas.chat import ChatRequest
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
from app.schemas.intent import IntentDefinition
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.planner import PlanningResult
from app.services.session_store import InMemorySessionStore
class _RouteToCloudRouter:
def route(self, text: str):
_ = text
return type(
"RouteResult",
(),
{
"intent": None,
"debug": RoutingDebug(
selected_intent="cabin_nav_to",
matched_stage="fusion",
decision="route_to_cloud",
decision_reason="local signal is not stable enough, routing to cloud planner",
confidence_grade="low",
stages=[
MatcherStageDebug(
stage="fusion",
accepted=False,
selected_intent="cabin_nav_to",
score=0.88,
reason="route to cloud",
candidates=[
IntentCandidate(intent_id="cabin_nav_to", score=0.88, reason="fusion", model_name="fusion"),
IntentCandidate(intent_id="cabin_play_music", score=0.75, reason="fusion", model_name="fusion"),
],
)
],
),
},
)()
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
_ = (text, intent)
return {}
class _PlannerRejects:
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
_ = (text, intents, context)
return PlanningResult(
accepted=False,
workflow_type="single",
model_name="qwen3.5-plus",
backend="dashscope",
reason="cloud planner could not produce a stable executable step",
)
class _PlannerOutOfScope:
def plan(self, text: str, intents: list[IntentDefinition], context: dict[str, object] | None = None) -> PlanningResult:
_ = (text, intents, context)
return PlanningResult(
accepted=False,
workflow_type="single",
model_name="qwen3.5-plus",
backend="dashscope",
reason="The provided intent catalog only contains cabin and service actions. There is no matching intent for ordering food via a third-party app action.",
)
def _intent(intent_id: str) -> IntentDefinition:
return IntentDefinition(
intent_id=intent_id,
plugin_id=f"mock.{intent_id}",
domain="cabin",
keywords=[],
examples=[],
)
class AgentCloudRouteTests(unittest.TestCase):
def test_route_to_cloud_returns_explicit_clarify_feedback_when_planner_does_not_accept(self) -> None:
service = AgentService(
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
router=_RouteToCloudRouter(),
plugins=PluginRegistry(),
session_store=InMemorySessionStore(),
planner=_PlannerRejects(),
)
response = service.handle_chat(
ChatRequest(
session_id="sess_cloud_route",
user_id="user_1",
input_text="带我过去",
)
)
self.assertEqual(response.decision, "route_to_cloud")
self.assertEqual(response.reply_type, "clarify")
self.assertEqual(response.status, "route_to_cloud")
self.assertIn("请确认一下", response.reply_text)
def test_route_to_cloud_rejects_when_planner_marks_request_out_of_scope(self) -> None:
service = AgentService(
intent_registry=IntentRegistry([_intent("cabin_nav_to"), _intent("cabin_play_music")]),
router=_RouteToCloudRouter(),
plugins=PluginRegistry(),
session_store=InMemorySessionStore(),
planner=_PlannerOutOfScope(),
)
response = service.handle_chat(
ChatRequest(
session_id="sess_cloud_route_reject",
user_id="user_1",
input_text="去美团叫个外卖",
)
)
self.assertEqual(response.reply_type, "reject")
self.assertEqual(response.decision, "reject")
self.assertEqual(response.status, "rejected")
self.assertIn("做不了", response.reply_text)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,235 @@
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from app.core.config import settings
from app.core.bootstrap import build_intent_registry
from app.services.classifier import BertIntentClassifier
from app.services.router import build_matcher_pipeline
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "local_bert_intent"
NON_BUSINESS_LABELS = {"__social__", "__out_of_scope__"}
def resolve_model_path(model_path: str) -> Path:
configured = (model_path or settings.classifier_model_path).strip()
if configured:
return Path(configured)
return DEFAULT_MODEL_DIR
def resolve_label_map_path(label_map_path: str, model_path: Path) -> Path:
configured = (label_map_path or settings.classifier_label_map_path).strip()
if configured:
return Path(configured)
return model_path / "label_map.json"
def build_classifier(
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
) -> BertIntentClassifier:
return BertIntentClassifier(
model_path=str(model_path),
threshold=threshold,
label_map_path=str(label_map_path),
fallback=None,
top_k=top_k,
)
def predict_once(
text: str,
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
warmup: bool,
) -> dict[str, object]:
registry = build_intent_registry()
classifier = build_classifier(
model_path=model_path,
label_map_path=label_map_path,
threshold=threshold,
top_k=top_k,
)
warmup_ok = None
if warmup:
warmup_ok = classifier.warmup(settings.classifier_warmup_text)
if not warmup_ok:
error_message = getattr(classifier, "_warmup_error_message", None) or "BERT warmup failed"
raise RuntimeError(error_message)
intents = registry.list()
result = classifier.predict(text, intents)
matcher = build_matcher_pipeline(
registry,
["classifier"],
classifier=classifier,
route_to_cloud_threshold=settings.local_route_to_cloud_threshold,
clarify_margin_threshold=settings.local_clarify_margin_threshold,
classifier_execute_score_threshold=settings.local_classifier_execute_score_threshold,
classifier_execute_margin_threshold=settings.local_classifier_execute_margin_threshold,
)
route_result = matcher.match(text)
fusion_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "fusion"), None)
classifier_stage = next((stage for stage in reversed(route_result.debug.stages) if stage.stage == "classifier"), None)
return {
"text": text,
"config": {
"model_path": str(model_path),
"label_map_path": str(label_map_path),
"threshold": threshold,
"top_k": top_k,
"warmup_requested": warmup,
"warmup_ok": warmup_ok,
"warmup_elapsed_ms": getattr(classifier, "_warmup_elapsed_ms", None),
"warmup_error_message": getattr(classifier, "_warmup_error_message", None),
},
"classifier_result": {
"predicted_intent": result.intent.intent_id if result.intent is not None else None,
"score": round(result.score, 4),
"model_name": result.model_name,
"backend_name": result.backend_name,
"used_fallback": result.used_fallback,
"fallback_reason": result.fallback_reason,
"error_message": result.error_message,
"raw_label": result.raw_label,
"raw_candidates": result.raw_candidates or [],
"known_candidates": [
{"intent_id": intent.intent_id, "score": round(score, 4)}
for intent, score in (result.candidates or [])
],
},
"route_result": {
"decision": route_result.debug.decision,
"decision_reason": route_result.debug.decision_reason,
"matched_stage": route_result.debug.matched_stage,
"selected_intent": route_result.debug.selected_intent,
"confidence_grade": route_result.debug.confidence_grade,
"unknown_detected": route_result.debug.unknown_detected,
"classifier_score": round(classifier_stage.score, 4) if classifier_stage is not None else None,
"fusion_score": round(fusion_stage.score, 4) if fusion_stage is not None else None,
},
}
def summarize_business_view(result: dict[str, object]) -> dict[str, object]:
classifier_result = dict(result.get("classifier_result") or {})
route_result = dict(result.get("route_result") or {})
predicted_intent = classifier_result.get("predicted_intent")
raw_label = classifier_result.get("raw_label")
effective_label = raw_label if raw_label in NON_BUSINESS_LABELS else predicted_intent
if effective_label in NON_BUSINESS_LABELS:
classifier_result["predicted_intent"] = None
classifier_result["non_business_label"] = effective_label
classifier_result["business_interpretation"] = "non_business_label_detected"
route_result["selected_intent"] = None
route_result["decision"] = "reject"
route_result["decision_reason"] = "classifier detected a non-business label"
route_result["unknown_detected"] = True
else:
classifier_result["non_business_label"] = None
classifier_result["business_interpretation"] = "known_business_intent_or_uncertain"
return {
"text": result.get("text"),
"config": result.get("config"),
"classifier_result": classifier_result,
"route_result": route_result,
}
def interactive_loop(
*,
model_path: Path,
label_map_path: Path,
threshold: float,
top_k: int,
warmup: bool,
mode: str,
) -> None:
print("当前 BERT 测试已启动,输入一句话直接查看预测结果,输入 exit 退出。")
print(f"model_path={model_path}")
print(f"label_map_path={label_map_path}")
print(f"threshold={threshold} top_k={top_k} warmup={warmup} mode={mode}")
while True:
try:
text = input("\n请输入问题> ").strip()
except EOFError:
print()
break
if not text:
continue
if text.lower() in {"exit", "quit", "q"}:
break
result = predict_once(
text,
model_path=model_path,
label_map_path=label_map_path,
threshold=threshold,
top_k=top_k,
warmup=warmup,
)
if mode == "business":
result = summarize_business_view(result)
print(json.dumps(result, ensure_ascii=False, indent=2))
def main() -> None:
parser = argparse.ArgumentParser(description="当前项目 BERT 意图识别测试脚本")
parser.add_argument("--text", type=str, default="", help="单次测试文本")
parser.add_argument("--threshold", type=float, default=settings.classifier_bert_threshold, help="BERT 置信度阈值")
parser.add_argument("--top-k", type=int, default=settings.classifier_top_k, help="返回候选数量")
parser.add_argument("--model-path", type=str, default="", help="模型目录,默认取 .env 或 models/local_bert_intent")
parser.add_argument("--label-map-path", type=str, default="", help="标签映射文件,默认取 .env 或 model_path/label_map.json")
parser.add_argument("--warmup", action="store_true", help="先执行一次 warmup 再预测")
parser.add_argument(
"--mode",
choices=("classifier", "business"),
default="classifier",
help="classifier 显示原始分类结果business 会把非业务标签折叠成未命中业务意图",
)
args = parser.parse_args()
model_path = resolve_model_path(args.model_path)
label_map_path = resolve_label_map_path(args.label_map_path, model_path)
if args.text.strip():
result = predict_once(
args.text.strip(),
model_path=model_path,
label_map_path=label_map_path,
threshold=args.threshold,
top_k=args.top_k,
warmup=args.warmup,
)
if args.mode == "business":
result = summarize_business_view(result)
print(json.dumps(result, ensure_ascii=False, indent=2))
return
interactive_loop(
model_path=model_path,
label_map_path=label_map_path,
threshold=args.threshold,
top_k=args.top_k,
warmup=args.warmup,
mode=args.mode,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
import json
import os
import time
import unittest
from unittest.mock import patch
from fastapi.testclient import TestClient
os.environ["AGENT_CLASSIFIER_BACKEND"] = "mock"
os.environ["AGENT_CLASSIFIER_WARMUP_ENABLED"] = "false"
from app.main import app
from app.schemas.chat import ChatResponse
def _fake_response() -> ChatResponse:
return ChatResponse(
session_id="sess_stream_1",
reply_type="workflow_result",
reply_text="好,空调已经打开了。",
intent="cabin_ac_on",
status="completed",
trace_id="trace_stream_1",
)
class ChatStreamTests(unittest.TestCase):
def test_chat_stream_returns_final_only_when_fast(self) -> None:
client = TestClient(app)
with patch("app.main.agent_service.handle_chat", return_value=_fake_response()):
response = client.post(
"/api/v1/agent/chat-stream",
json={
"session_id": "sess_stream_1",
"user_id": "user_stream_1",
"channel": "test",
"input_text": "打开车窗",
"input_type": "text",
},
)
self.assertEqual(response.status_code, 200)
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
self.assertEqual(len(lines), 1)
final_event = json.loads(lines[0])
self.assertEqual(final_event.get("type"), "final")
def test_chat_stream_returns_ack_then_final_when_slow_request(self) -> None:
client = TestClient(app)
def _slow_handle_chat(_request):
time.sleep(1.2)
return _fake_response()
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
response = client.post(
"/api/v1/agent/chat-stream",
json={
"session_id": "sess_stream_1",
"user_id": "user_stream_1",
"channel": "test",
"input_text": "打开车窗",
"input_type": "text",
},
)
self.assertEqual(response.status_code, 200)
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
self.assertGreaterEqual(len(lines), 2)
ack_event = json.loads(lines[0])
final_event = json.loads(lines[-1])
self.assertEqual(ack_event.get("type"), "ack")
self.assertEqual(final_event.get("type"), "final")
self.assertIn("data", final_event)
self.assertIn("reply_text", final_event["data"])
def test_chat_stream_returns_ack_then_final_when_slow_social_request(self) -> None:
client = TestClient(app)
def _slow_handle_chat(_request):
time.sleep(1.2)
return _fake_response()
with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat):
response = client.post(
"/api/v1/agent/chat-stream",
json={
"session_id": "sess_stream_1",
"user_id": "user_stream_1",
"channel": "test",
"input_text": "今天天气如何",
"input_type": "text",
},
)
self.assertEqual(response.status_code, 200)
lines = [line.strip() for line in response.text.splitlines() if line.strip()]
self.assertGreaterEqual(len(lines), 2)
ack_event = json.loads(lines[0])
final_event = json.loads(lines[-1])
self.assertEqual(ack_event.get("type"), "ack")
self.assertEqual(final_event.get("type"), "final")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import unittest
from unittest.mock import patch
from app.core.bootstrap import build_planner, load_runtime_bundle
from app.core.config import settings
from app.services.planner import CompositeWorkflowPlanner
from app.services.config_loader import ConfigLoader
from app.services.dialog_rules import DialogRuleEngine
from app.services.response_policy import ResponsePolicy
class ConfigLoaderTests(unittest.TestCase):
def test_loader_reads_domain_actions_and_responses(self) -> None:
bundle = ConfigLoader(
domain_path="config/domain.yml",
action_path="config/actions.yml",
response_path="config/responses.yml",
form_path="config/forms.yml",
rule_path="config/rules.yml",
dialog_act_path="config/dialog_acts.yml",
workflow_path="config/workflows.yml",
legacy_intent_path="app/data/intents.json",
).load()
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
self.assertEqual(bundle.intent_registry.get("cabin_window_open").plugin_id, "plugin.cabin.window.open")
self.assertEqual(bundle.intent_hints.get("cabin_window_open"), "打开车窗")
self.assertEqual(bundle.response_templates.get("task_stopped"), "好的,已停止当前任务。")
self.assertEqual(bundle.intent_registry.get("cabin_set_ac").required_slots, ["temperature"])
self.assertTrue(bundle.dialog_rules.is_stop_request("先不要了"))
self.assertEqual(bundle.dialog_rules.parse_confirmation_decision("确认"), True)
self.assertEqual(bundle.dialog_act_engine.detect("你好"), "chitchat")
self.assertGreaterEqual(len(bundle.workflow_templates.templates), 2)
def test_bootstrap_runtime_bundle_is_available(self) -> None:
bundle = load_runtime_bundle()
self.assertGreaterEqual(len(bundle.intent_registry.list()), 30)
self.assertIn("fallback", bundle.response_templates)
self.assertEqual(bundle.dialog_act_engine.detect("确认"), "affirm")
def test_response_policy_can_be_driven_by_config_templates(self) -> None:
policy = ResponsePolicy(
templates={"reject": "这个能力暂未开通。"},
intent_hints={"cabin_window_open": "开车窗"},
)
self.assertEqual(policy.reject(), "这个能力暂未开通。")
self.assertEqual(policy.clarify(["cabin_window_open"]), "请确认一下,你是想开车窗吗?")
def test_response_policy_formats_multi_step_summary_naturally(self) -> None:
policy = ResponsePolicy()
summary = policy.workflow_summary(["好的,已打开空调。", "已将空调调到 20 度。"])
self.assertEqual(summary, "好,空调已经打开了,也调到 20 度了。")
def test_response_policy_formats_multi_step_summary_in_vehicle_style(self) -> None:
policy = ResponsePolicy()
summary = policy.workflow_summary(["好的,已打开空调。", "好的,已关闭车窗。"])
self.assertEqual(summary, "好,空调已经打开了,车窗也帮你关上了。")
def test_build_planner_prefers_local_planners_before_cloud(self) -> None:
with patch.object(settings, "planner_backend", "dashscope"):
planner = build_planner()
self.assertIsInstance(planner, CompositeWorkflowPlanner)
self.assertIsInstance(planner._planners[0], CompositeWorkflowPlanner)
def test_dialog_rule_engine_supports_configured_confirmation_and_stop(self) -> None:
rules = DialogRuleEngine(
stop_phrases=("先不用了",),
positive_confirmation_tokens=("好,继续",),
negative_confirmation_tokens=("取消吧",),
confirmation_required_intents=("foo",),
confirmation_required_risk_levels=("high",),
)
self.assertTrue(rules.is_stop_request("先不用了"))
self.assertTrue(rules.parse_confirmation_decision("好,继续"))
self.assertFalse(rules.parse_confirmation_decision("取消吧"))
self.assertTrue(rules.requires_confirmation("foo", "low"))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,202 @@
from __future__ import annotations
import unittest
from app.plugins.base import PluginRegistry
from app.plugins.mock import MockPluginExecutor
from app.schemas.chat import ChatRequest
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
from app.schemas.workflow import Workflow, WorkflowStep
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.planner import HeuristicWorkflowPlanner
from app.services.response_policy import ResponsePolicy
from app.services.rewrite_engine import ContextRewriteEngine
from app.services.router import RouteMatchResult
from app.services.session_store import InMemorySessionStore
class _FailIfCalledPlanner:
def plan(self, text, intents, context=None):
_ = (intents, context)
raise AssertionError(f"planner should not be called for single intent request: {text}")
class _ScriptedRouter:
def __init__(self, registry: IntentRegistry) -> None:
self._registry = registry
self._route_map = {
"来点music": self._route_result("cabin_play_music", ["cabin_play_music"]),
"打开车窗和空调": self._route_result("cabin_window_open", ["cabin_window_open", "cabin_ac_on"]),
"关闭车窗": self._route_result("cabin_window_close", ["cabin_window_close", "cabin_window_open"]),
}
self._slot_map = {
("播放黄昏", "cabin_play_music"): {"song": "黄昏"},
("来一首黄昏", "cabin_play_music"): {"song": "黄昏"},
("来点黄昏", "cabin_play_music"): {"song": "黄昏"},
}
def route(self, text: str) -> RouteMatchResult:
if text not in self._route_map:
raise AssertionError(f"unexpected route request: {text}")
return self._route_map[text]
def extract_slots(self, text: str, intent) -> dict[str, object]:
return dict(self._slot_map.get((text, intent.intent_id), {}))
def _route_result(self, selected_intent: str, candidates: list[str]) -> RouteMatchResult:
intent = self._registry.get(selected_intent)
stage = MatcherStageDebug(
stage="fusion",
accepted=True,
selected_intent=selected_intent,
score=1.0,
candidates=[
IntentCandidate(intent_id=intent_id, score=max(0.5, 1.0 - index * 0.1))
for index, intent_id in enumerate(candidates)
],
)
return RouteMatchResult(
intent=intent,
debug=RoutingDebug(
selected_intent=selected_intent,
matched_stage="fusion",
decision="execute",
stages=[stage],
),
)
class DialogContinuationAndMultiIntentTests(unittest.TestCase):
def setUp(self) -> None:
self.registry = IntentRegistry.from_json("app/data/intents.json")
self.plugins = PluginRegistry()
MockPluginExecutor().register(self.plugins)
self.service = AgentService(
intent_registry=self.registry,
router=_ScriptedRouter(self.registry),
plugins=self.plugins,
session_store=InMemorySessionStore(),
rewrite_engine=ContextRewriteEngine(),
response_policy=ResponsePolicy(),
planner=HeuristicWorkflowPlanner(),
)
def test_music_followup_in_chat_continues_waiting_slot(self) -> None:
first = self.service.handle_chat(
ChatRequest(
session_id="sess_music_followup",
user_id="user_1",
input_text="来点music",
)
)
self.assertEqual(first.reply_type, "ask_slot")
self.assertEqual(first.pending_slots, ["media_query"])
second = self.service.handle_chat(
ChatRequest(
session_id="sess_music_followup",
user_id="user_1",
input_text="黄昏",
)
)
self.assertEqual(second.reply_type, "workflow_result")
self.assertEqual(second.intent, "cabin_play_music")
self.assertEqual(second.filled_slots.get("song"), "黄昏")
self.assertIn("黄昏", second.reply_text)
def test_parallel_compound_request_enters_planner(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_parallel_compound",
user_id="user_1",
input_text="打开车窗和空调",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.workflow.workflow_type, "sequence")
step_intents = [step.intent_id for step in response.workflow.steps]
self.assertEqual(step_intents, ["cabin_window_open", "cabin_ac_on"])
self.assertIn("车窗", response.reply_text)
self.assertIn("空调", response.reply_text)
def test_single_cabin_intent_does_not_enter_planner_from_top2_domain_candidates(self) -> None:
service = AgentService(
intent_registry=self.registry,
router=_ScriptedRouter(self.registry),
plugins=self.plugins,
session_store=InMemorySessionStore(),
rewrite_engine=ContextRewriteEngine(),
response_policy=ResponsePolicy(),
planner=_FailIfCalledPlanner(),
)
response = service.handle_chat(
ChatRequest(
session_id="sess_single_cabin_no_planner",
user_id="user_1",
input_text="关闭车窗",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_window_close")
self.assertEqual(response.routing_debug.decision, "execute")
self.assertFalse(any(stage.stage == "planner" for stage in response.routing_debug.stages))
def test_waiting_confirmation_can_continue_via_chat(self) -> None:
session = self.service.session_store.get_or_create("sess_confirm_chat", "user_1")
session.current_intent = "cs_cancel_order"
session.status = "waiting_confirmation"
session.pending_slots = ["confirmation"]
session.slots = {"order_id": "A123456"}
session.workflow = Workflow(
workflow_id="wf_confirm_chat",
workflow_type="conditional",
domain="customer_service",
intent_id="cs_cancel_order",
status="waiting_confirmation",
risk_level="high",
slots={"order_id": "A123456"},
steps=[
WorkflowStep(
step=1,
step_id="step_cancel",
intent_id="cs_cancel_order",
plugin_id="plugin.order.cancel",
action="cancel_order",
status="waiting_confirmation",
slots={"order_id": "A123456"},
requires_confirmation=True,
)
],
meta={
"pending_confirmation": {
"step_id": "step_cancel",
"intent_id": "cs_cancel_order",
"detail": "确认取消订单 A123456",
},
"step_results": {},
"confirmed_steps": [],
},
).model_dump()
self.service.session_store.save(session)
response = self.service.handle_chat(
ChatRequest(
session_id="sess_confirm_chat",
user_id="user_1",
input_text="确认",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.status, "completed")
self.assertIn("A123456", response.reply_text)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,149 @@
from __future__ import annotations
import unittest
from app.plugins.base import PluginRegistry
from app.plugins.mock import MockPluginExecutor
from app.services.classifier import MockIntentClassifier
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.response_policy import ResponsePolicy
from app.services.rewrite_engine import ContextRewriteEngine
from app.services.router import HeuristicSlotExtractor, IntentRouter, build_matcher_pipeline
from app.services.session_store import InMemorySessionStore
from app.schemas.chat import ChatRequest, FillSlotsRequest
class _BertLikeMockClassifier(MockIntentClassifier):
def predict(self, text, intents):
result = super().predict(text, intents)
result.model_name = "bert-local"
result.backend_name = "bert-local"
return result
def _build_service() -> AgentService:
registry = IntentRegistry.from_json("app/data/intents.json")
plugins = PluginRegistry()
MockPluginExecutor().register(plugins)
router = IntentRouter(
matcher=build_matcher_pipeline(
registry,
["classifier"],
classifier=_BertLikeMockClassifier(threshold=0.0, top_k=3),
),
slot_extractor=HeuristicSlotExtractor(),
)
return AgentService(
intent_registry=registry,
router=router,
plugins=plugins,
session_store=InMemorySessionStore(),
rewrite_engine=ContextRewriteEngine(),
response_policy=ResponsePolicy(),
planner=None,
)
class IntentCoverageAndStopTests(unittest.TestCase):
def setUp(self) -> None:
self.service = _build_service()
self.registry = IntentRegistry.from_json("app/data/intents.json")
def test_intent_catalog_has_at_least_30_items(self) -> None:
self.assertGreaterEqual(len(self.registry.list()), 30)
def test_close_ac_routes_to_power_off(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_close_ac",
user_id="user_1",
input_text="关闭空调",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_ac_off")
self.assertIn("已关闭空调", response.reply_text)
def test_open_window_is_covered(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_window_open",
user_id="user_1",
input_text="打开车窗",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_window_open")
self.assertIn("已打开车窗", response.reply_text)
def test_stop_current_task_while_waiting_for_slot(self) -> None:
first = self.service.handle_chat(
ChatRequest(
session_id="sess_stop_task",
user_id="user_1",
input_text="空调调到",
)
)
self.assertEqual(first.reply_type, "ask_slot")
self.assertEqual(first.pending_slots, ["temperature"])
stopped = self.service.handle_fill_slots(
FillSlotsRequest(
session_id="sess_stop_task",
user_id="user_1",
input_text="不用了",
)
)
self.assertEqual(stopped.reply_type, "text")
self.assertEqual(stopped.status, "stopped")
self.assertEqual(stopped.pending_slots, [])
self.assertIn("已停止当前任务", stopped.reply_text)
def test_relative_ac_adjustment_uses_two_degree_step(self) -> None:
session = self.service.session_store.get_or_create("sess_ac_lower", "user_1")
session.current_intent = "cabin_set_ac"
session.context_memory["last_temperature"] = 24
self.service.session_store.save(session)
response = self.service.handle_chat(
ChatRequest(
session_id="sess_ac_lower",
user_id="user_1",
input_text="调低一点",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_set_ac")
self.assertEqual(response.filled_slots.get("temperature"), 22)
self.assertIn("22", response.reply_text)
def test_relative_ac_adjustment_without_history_uses_default_baseline(self) -> None:
session = self.service.session_store.get_or_create("sess_ac_lower_default", "user_1")
session.current_intent = "cabin_ac_on"
self.service.session_store.save(session)
response = self.service.handle_chat(
ChatRequest(
session_id="sess_ac_lower_default",
user_id="user_1",
input_text="调低一点",
)
)
self.assertEqual(response.reply_type, "workflow_result")
self.assertEqual(response.intent, "cabin_set_ac")
self.assertEqual(response.filled_slots.get("temperature"), 22)
self.assertIn("22", response.reply_text)
def test_temperature_is_clamped_before_execution(self) -> None:
self.assertEqual(self.service._normalize_temperature_value(-1), 16)
self.assertEqual(self.service._normalize_temperature_value(40), 30)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
import unittest
from app.schemas.intent import IntentDefinition
from app.services.classifier import JointBertIntentClassifier
from app.services.planner import HeuristicWorkflowPlanner
from app.services.router import JointBertSlotExtractor
class FakeJointNLU:
def __init__(self) -> None:
self._predictions = {
"把空调调到22度": {
"intent_id": "cabin_set_ac",
"intent_score": 0.93,
"candidates": [("cabin_set_ac", 0.93), ("cabin_ac_on", 0.04)],
"slots": {"temperature": 22},
},
"导航去公司然后把空调调到22度": {
"intent_id": "cabin_nav_to",
"intent_score": 0.88,
"candidates": [("cabin_nav_to", 0.88), ("cabin_set_ac", 0.72)],
"slots": {"destination": "公司"},
},
}
self._slot_predictions = {
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
("导航去公司", "cabin_nav_to"): {"destination": "公司"},
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
}
def warmup(self, sample_text: str = "打开车窗") -> bool:
_ = sample_text
return True
def predict(self, text: str, intents: list[IntentDefinition]):
from app.services.joint_nlu import JointCandidate, JointNluResult
raw = self._predictions[text]
candidates = [JointCandidate(intent_id=intent_id, score=score) for intent_id, score in raw["candidates"]]
return JointNluResult(
intent_id=raw["intent_id"],
intent_score=raw["intent_score"],
candidates=candidates,
slots=dict(raw["slots"]),
)
def extract_slots(self, text: str, intent: IntentDefinition):
return dict(self._slot_predictions.get((text, intent.intent_id), {}))
def extract_slots_by_intent_id(self, text: str, intent_id: str, required_slots=None):
_ = required_slots
return dict(self._slot_predictions.get((text, intent_id), {}))
class JointNLUIntegrationTests(unittest.TestCase):
def setUp(self) -> None:
self.intents = [
IntentDefinition(intent_id="cabin_set_ac", plugin_id="x", domain="cabin", required_slots=["temperature"]),
IntentDefinition(intent_id="cabin_nav_to", plugin_id="x", domain="cabin", required_slots=["destination"]),
IntentDefinition(intent_id="cabin_ac_on", plugin_id="x", domain="cabin"),
]
self.fake_nlu = FakeJointNLU()
def test_joint_classifier_uses_joint_nlu_intent_head(self) -> None:
classifier = JointBertIntentClassifier(self.fake_nlu, threshold=0.3, top_k=2)
result = classifier.predict("把空调调到22度", self.intents)
self.assertIsNotNone(result.intent)
self.assertEqual(result.intent.intent_id, "cabin_set_ac")
self.assertEqual(result.raw_candidates[0]["intent_id"], "cabin_set_ac")
def test_joint_slot_extractor_uses_joint_nlu_slots(self) -> None:
extractor = JointBertSlotExtractor(self.fake_nlu)
slots = extractor.extract("把空调调到22度", self.intents[0])
self.assertEqual(slots, {"temperature": 22})
def test_planner_prefers_joint_nlu_slots_for_each_clause(self) -> None:
planner = HeuristicWorkflowPlanner(joint_nlu=self.fake_nlu)
result = planner.plan("导航去公司然后把空调调到22度", self.intents)
self.assertTrue(result.accepted)
self.assertEqual(result.steps[0].slots, {"destination": "公司"})
self.assertEqual(result.steps[1].slots, {"temperature": 22})
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,144 @@
from __future__ import annotations
import unittest
import torch
from app.schemas.intent import IntentDefinition
from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector
class FakeTokenizer:
def __call__(self, text, truncation=True, padding=False, return_tensors="pt"):
_ = (text, truncation, padding, return_tensors)
return {
"input_ids": torch.tensor([[101, 102]], dtype=torch.long),
"attention_mask": torch.tensor([[1, 1]], dtype=torch.long),
}
class FakeModel:
def __init__(self, logits: list[float], id2label: dict[int, str]) -> None:
self.config = type("Config", (), {"id2label": id2label})()
self._logits = torch.tensor([logits], dtype=torch.float32)
def eval(self) -> None:
return None
def __call__(self, **kwargs):
_ = kwargs
return type("Output", (), {"logits": self._logits})()
class RuntimeBackedDetector(BertMultiIntentDetector):
def __init__(
self,
logits: list[float],
id2label: dict[int, str],
threshold: float = 0.45,
top_k: int = 8,
max_labels: int = 4,
) -> None:
super().__init__(
model_path="unused",
threshold=threshold,
top_k=top_k,
max_labels=max_labels,
)
self._runtime = (torch, FakeTokenizer(), FakeModel(logits, id2label))
def _get_runtime(self):
return self._runtime
class MultiIntentDetectorTests(unittest.TestCase):
def test_detector_filters_blocked_and_unknown_labels(self) -> None:
detector = RuntimeBackedDetector(
logits=[2.4, 2.0, 3.2, 2.6],
id2label={
0: "cabin_window_open",
1: "cabin_play_music",
2: "__social__",
3: "unknown_intent",
},
threshold=0.8,
top_k=4,
)
intents = [
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
]
result = detector.detect("打开车窗并播放音乐", intents)
self.assertTrue(result.detected)
self.assertEqual(result.backend_name, "bert-multi-label")
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
def test_detector_respects_threshold_and_max_labels(self) -> None:
detector = RuntimeBackedDetector(
logits=[2.8, 2.5, 2.2],
id2label={
0: "cabin_window_open",
1: "cabin_play_music",
2: "cabin_nav_to",
},
threshold=0.89,
top_k=3,
max_labels=2,
)
intents = [
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
IntentDefinition(intent_id="cabin_nav_to", plugin_id="plugin.nav", domain="cabin"),
]
result = detector.detect("开窗放歌去公司", intents)
self.assertTrue(result.detected)
self.assertEqual(len(result.candidates), 2)
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
def test_joint_bert_detector_wraps_shared_runtime(self) -> None:
intents = [
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
]
class FakeJointNlu:
def __init__(self) -> None:
self.calls: list[dict[str, object]] = []
def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8):
self.calls.append(
{
"text": text,
"threshold": threshold,
"max_labels": max_labels,
"top_k": top_k,
"known_count": len(known_intents),
}
)
return [
type("Candidate", (), {"intent_id": "cabin_window_open", "score": 0.93})(),
type("Candidate", (), {"intent_id": "cabin_play_music", "score": 0.88})(),
]
def warmup(self, sample_text="") -> bool:
_ = sample_text
return True
fake_nlu = FakeJointNlu()
detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3)
result = detector.detect("打开车窗并播放音乐", intents)
self.assertTrue(result.detected)
self.assertEqual(result.backend_name, "joint-bert-multi-label")
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
self.assertEqual(fake_nlu.calls[0]["threshold"], 0.5)
self.assertEqual(fake_nlu.calls[0]["top_k"], 6)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,195 @@
from __future__ import annotations
import unittest
from app.schemas.debug import IntentCandidate, MatcherStageDebug
from app.schemas.intent import IntentDefinition
from app.services.intent_registry import IntentRegistry
from app.services.router import IntentMatchResult, MultiStageIntentMatcher
class _FakeMatcher:
def __init__(self, stage_debug: MatcherStageDebug) -> None:
self._stage_debug = stage_debug
def match(self, text: str) -> IntentMatchResult:
_ = text
return IntentMatchResult(intent=None, stage_debug=self._stage_debug)
def _intent(intent_id: str) -> IntentDefinition:
return IntentDefinition(
intent_id=intent_id,
plugin_id=f"mock.{intent_id}",
domain="test",
keywords=[],
examples=[],
)
class RouterDecisionTests(unittest.TestCase):
def setUp(self) -> None:
self.registry = IntentRegistry([_intent("alpha"), _intent("beta"), _intent("gamma")])
def test_execute_when_bert_classifier_is_clear(self) -> None:
matcher = MultiStageIntentMatcher(
registry=self.registry,
matchers=[
_FakeMatcher(
MatcherStageDebug(
stage="classifier",
accepted=True,
selected_intent="alpha",
score=0.92,
reason="classifier selected best candidate",
backend="joint-bert-local",
candidates=[
IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="joint-bert-local"),
IntentCandidate(intent_id="beta", score=0.21, reason="classifier", model_name="joint-bert-local"),
],
)
),
],
)
result = matcher.match("alpha")
self.assertEqual(result.debug.decision, "execute")
self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")
def test_clarify_when_bert_top_candidates_are_too_close(self) -> None:
matcher = MultiStageIntentMatcher(
registry=self.registry,
matchers=[
_FakeMatcher(
MatcherStageDebug(
stage="classifier",
accepted=True,
selected_intent="alpha",
score=0.22,
reason="classifier selected best candidate",
backend="bert-local",
metadata={"threshold": 0.2},
candidates=[
IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
IntentCandidate(intent_id="beta", score=0.28, reason="classifier", model_name="bert-local"),
],
)
),
],
route_to_cloud_threshold=0.2,
)
result = matcher.match("ambiguous request")
self.assertEqual(result.debug.decision, "clarify")
self.assertIsNone(result.intent)
self.assertEqual(result.debug.confidence_grade, "medium")
def test_route_to_cloud_when_bert_signal_is_weak_but_known(self) -> None:
matcher = MultiStageIntentMatcher(
registry=self.registry,
matchers=[
_FakeMatcher(
MatcherStageDebug(
stage="classifier",
accepted=False,
selected_intent="alpha",
score=0.29,
reason="classifier below execute threshold",
backend="joint-bert-local",
candidates=[
IntentCandidate(intent_id="alpha", score=0.29, reason="classifier", model_name="joint-bert-local"),
IntentCandidate(intent_id="beta", score=0.14, reason="classifier", model_name="joint-bert-local"),
],
)
),
],
)
result = matcher.match("weak symbolic request")
self.assertEqual(result.debug.decision, "route_to_cloud")
self.assertIsNone(result.intent)
def test_reject_when_no_branch_has_usable_signal(self) -> None:
matcher = MultiStageIntentMatcher(
registry=self.registry,
matchers=[
_FakeMatcher(
MatcherStageDebug(
stage="classifier",
accepted=False,
score=0.12,
reason="classifier below threshold",
backend="bert-local",
metadata={"threshold": 0.2},
candidates=[],
)
),
],
)
result = matcher.match("unknown request")
self.assertEqual(result.debug.decision, "reject")
self.assertTrue(result.debug.unknown_detected)
self.assertIsNone(result.intent)
def test_route_to_cloud_for_low_confidence_classifier_only_bert_signal(self) -> None:
matcher = MultiStageIntentMatcher(
registry=self.registry,
matchers=[
_FakeMatcher(
MatcherStageDebug(
stage="classifier",
accepted=True,
selected_intent="alpha",
score=0.31,
reason="classifier selected best candidate",
backend="bert-local",
metadata={"threshold": 0.0, "top_margin": 0.04},
candidates=[
IntentCandidate(intent_id="alpha", score=0.31, reason="classifier", model_name="bert-local"),
IntentCandidate(intent_id="beta", score=0.27, reason="classifier", model_name="bert-local"),
],
)
),
],
)
result = matcher.match("bert only weak request")
self.assertEqual(result.debug.decision, "route_to_cloud")
self.assertIsNone(result.intent)
def test_execute_for_high_confidence_classifier_only_bert_signal(self) -> None:
matcher = MultiStageIntentMatcher(
registry=self.registry,
matchers=[
_FakeMatcher(
MatcherStageDebug(
stage="classifier",
accepted=True,
selected_intent="alpha",
score=0.92,
reason="classifier selected best candidate",
backend="bert-local",
metadata={"threshold": 0.0, "top_margin": 0.63},
candidates=[
IntentCandidate(intent_id="alpha", score=0.92, reason="classifier", model_name="bert-local"),
IntentCandidate(intent_id="beta", score=0.29, reason="classifier", model_name="bert-local"),
],
)
),
],
)
result = matcher.match("bert only strong request")
self.assertEqual(result.debug.decision, "execute")
self.assertEqual(result.intent.intent_id if result.intent else None, "alpha")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,180 @@
from __future__ import annotations
import unittest
from app.plugins.base import PluginRegistry
from app.schemas.chat import ChatRequest, FillSlotsRequest
from app.schemas.debug import RoutingDebug
from app.schemas.intent import IntentDefinition
from app.schemas.workflow import Workflow, WorkflowStep
from app.services.agent_service import AgentService
from app.services.intent_registry import IntentRegistry
from app.services.session_store import InMemorySessionStore
from app.services.social import SocialReplyResult, SocialRouter
class _FailingRouter:
def route(self, text: str): # pragma: no cover - should not be called in these tests
raise AssertionError(f"router should not be called for social input: {text}")
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
_ = (text, intent)
return {}
class _FakeSocialResponder:
def reply(self, text: str, session) -> SocialReplyResult:
_ = (text, session)
normalized = text.strip()
if "你好" in normalized:
text = "你好呀,我在,想聊什么都可以。"
elif "名字" in normalized or "你是谁" in normalized:
text = "我是一名智能座舱助手,你可以直接叫我座舱助手。"
elif "天气" in normalized:
text = "是啊,今天确实挺舒服的。"
else:
text = "我在,咱们继续聊。"
return SocialReplyResult(
text=text,
backend="fake-cloud",
model_name="fake-social",
)
def _intent(intent_id: str, plugin_id: str) -> IntentDefinition:
return IntentDefinition(
intent_id=intent_id,
plugin_id=plugin_id,
domain="service" if intent_id.startswith("cs_") else "cabin",
risk_level="high" if intent_id == "cs_cancel_order" else "low",
required_slots=["order_id"] if intent_id == "cs_cancel_order" else [],
ask_templates={"order_id": "请告诉我订单号。"} if intent_id == "cs_cancel_order" else {},
keywords=[],
examples=[],
)
class SocialChatTests(unittest.TestCase):
def setUp(self) -> None:
self.session_store = InMemorySessionStore()
self.plugins = PluginRegistry()
self.plugins.register(
"mock.cancel_order",
lambda slots: {"success": True, "message": f"已取消订单 {slots.get('order_id', '')}"},
)
self.service = AgentService(
intent_registry=IntentRegistry([_intent("cs_cancel_order", "mock.cancel_order")]),
router=_FailingRouter(),
plugins=self.plugins,
session_store=self.session_store,
social_router=SocialRouter(),
social_responder=_FakeSocialResponder(),
)
def test_greeting_social_reply_uses_social_responder(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_social_hi",
user_id="user_1",
input_text="你好",
)
)
self.assertEqual(response.decision, "open_social")
self.assertEqual(response.status, "social")
self.assertIn("你好呀", response.reply_text)
def test_capability_social_question_does_not_fall_into_business_intent(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_social_name",
user_id="user_1",
input_text="你叫什么名字",
)
)
self.assertEqual(response.decision, "open_social")
self.assertEqual(response.status, "social")
self.assertNotEqual(response.reply_type, "ask_slot")
self.assertIn("智能座舱助手", response.reply_text)
def test_open_social_reply_uses_social_responder(self) -> None:
response = self.service.handle_chat(
ChatRequest(
session_id="sess_social_open",
user_id="user_1",
input_text="今天天气真不错啊",
)
)
self.assertEqual(response.decision, "open_social")
self.assertEqual(response.status, "social")
self.assertIn("挺舒服", response.reply_text)
def test_social_turn_does_not_break_waiting_confirmation(self) -> None:
session = self.session_store.get_or_create("sess_confirm", "user_1")
session.current_intent = "cs_cancel_order"
session.status = "waiting_confirmation"
session.pending_slots = ["confirmation"]
session.slots = {"order_id": "A123456"}
session.routing_debug = RoutingDebug(selected_intent="cs_cancel_order", decision="execute").model_dump()
session.workflow = Workflow(
workflow_id="wf_sess_confirm",
workflow_type="conditional",
domain="service",
intent_id="cs_cancel_order",
status="waiting_confirmation",
risk_level="high",
slots={"order_id": "A123456"},
steps=[
WorkflowStep(
step=1,
step_id="step_confirm",
intent_id="cs_cancel_order",
plugin_id="mock.cancel_order",
action="cancel_order",
status="waiting_confirmation",
slots={"order_id": "A123456"},
requires_confirmation=True,
)
],
meta={
"pending_confirmation": {
"step_id": "step_confirm",
"intent_id": "cs_cancel_order",
"detail": "确认取消订单 A123456",
},
"confirmed_steps": [],
"step_results": {},
},
).model_dump()
self.session_store.save(session)
social_response = self.service.handle_fill_slots(
FillSlotsRequest(
session_id="sess_confirm",
user_id="user_1",
input_text="今天天气真不错啊",
)
)
self.assertEqual(social_response.decision, "open_social")
self.assertEqual(social_response.status, "waiting_confirmation")
self.assertEqual(social_response.pending_slots, ["confirmation"])
self.assertIn("回复“确认”或“取消”即可", social_response.reply_text)
confirm_response = self.service.handle_fill_slots(
FillSlotsRequest(
session_id="sess_confirm",
user_id="user_1",
input_text="确认",
)
)
self.assertEqual(confirm_response.reply_type, "workflow_result")
self.assertEqual(confirm_response.status, "completed")
self.assertIn("已取消订单", confirm_response.reply_text)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,184 @@
from __future__ import annotations
import unittest
from pathlib import Path
from app.schemas.configuration import WorkflowTemplatesConfig
from app.services.classifier import ClassificationResult
from app.services.multi_intent_detector import MultiIntentCandidate, MultiIntentDetectionResult
from app.services.planner import CompositeWorkflowPlanner, HeuristicWorkflowPlanner, TemplateWorkflowPlanner
from app.services.intent_registry import IntentRegistry
class FakeClauseClassifier:
def __init__(self, predictions: dict[str, list[dict[str, float | str]]]) -> None:
self._predictions = predictions
def predict(self, text, intents):
_ = intents
return ClassificationResult(
intent=None,
score=0.0,
model_name="fake-bert-clause",
backend_name="fake-bert-clause",
raw_candidates=self._predictions.get(text, []),
)
class FakeMultiIntentDetector:
def __init__(self, predictions: dict[str, list[tuple[str, float]]]) -> None:
self._predictions = predictions
def detect(self, text, intents):
_ = intents
candidates = [
MultiIntentCandidate(intent_id=intent_id, score=score, label=intent_id)
for intent_id, score in self._predictions.get(text, [])
]
return MultiIntentDetectionResult(
detected=len(candidates) >= 2,
candidates=candidates,
reason="fake detector",
backend_name="fake-multi",
)
class WorkflowTemplateTests(unittest.TestCase):
def setUp(self) -> None:
self.registry = IntentRegistry.from_json("app/data/intents.json")
self.templates = WorkflowTemplatesConfig.model_validate_json(
Path("config/workflows.yml").read_text(encoding="utf-8")
)
def test_template_planner_matches_sequence_template(self) -> None:
planner = TemplateWorkflowPlanner(self.templates)
result = planner.plan("打开车窗然后把空调调到20度", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual(result.backend, "local-template")
self.assertEqual(result.workflow_type, "sequence")
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_set_ac"])
self.assertEqual(result.steps[1].slots.get("temperature"), 20)
def test_template_planner_matches_conditional_template(self) -> None:
planner = TemplateWorkflowPlanner(self.templates)
result = planner.plan("查一下订单A123456如果还没发货就取消", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual(result.workflow_type, "conditional")
self.assertEqual([step.intent_id for step in result.steps], ["cs_query_order", "cs_cancel_order"])
self.assertEqual(result.steps[1].depends_on, [1])
self.assertTrue(result.steps[1].requires_confirmation)
def test_composite_planner_falls_back_to_heuristic_when_template_misses(self) -> None:
planner = CompositeWorkflowPlanner([TemplateWorkflowPlanner(self.templates), HeuristicWorkflowPlanner()])
result = planner.plan("打开车窗,并且播放轻音乐", self.registry.list())
self.assertTrue(result.accepted)
self.assertIn(result.backend, {"local-template", "local-heuristic"})
self.assertEqual(result.workflow_type, "sequence")
def test_heuristic_planner_parses_ac_then_window_close_sequence(self) -> None:
planner = HeuristicWorkflowPlanner()
result = planner.plan("打开空调,再把窗户降下来", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual(result.backend, "local-heuristic")
self.assertEqual(result.workflow_type, "sequence")
self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_close"])
def test_planner_metadata_contains_clause_analysis(self) -> None:
planner = HeuristicWorkflowPlanner()
result = planner.plan("打开空调,然后打开车窗", self.registry.list())
self.assertTrue(result.accepted)
self.assertTrue(result.metadata.get("multi_intent_detected"))
clause_analysis = result.metadata.get("clause_analysis", [])
self.assertEqual(len(clause_analysis), 2)
self.assertEqual(clause_analysis[0].get("selected_intent_id"), "cabin_ac_on")
self.assertEqual(clause_analysis[1].get("selected_intent_id"), "cabin_window_open")
def test_heuristic_planner_supports_shared_action_parallel_objects(self) -> None:
planner = HeuristicWorkflowPlanner()
result = planner.plan("打开空调和车窗", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual(result.workflow_type, "sequence")
self.assertEqual([step.intent_id for step in result.steps], ["cabin_ac_on", "cabin_window_open"])
def test_heuristic_planner_supports_parallel_objects_with_suffix_action(self) -> None:
planner = HeuristicWorkflowPlanner()
result = planner.plan("把车窗和天窗打开", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual(result.workflow_type, "sequence")
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_sunroof_open"])
def test_heuristic_planner_supports_parallel_clause_with_bing_connector(self) -> None:
planner = HeuristicWorkflowPlanner()
result = planner.plan("打开车窗并播放轻音乐", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual(result.workflow_type, "sequence")
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
def test_heuristic_planner_can_use_clause_classifier_to_rescue_semantic_clause(self) -> None:
planner = HeuristicWorkflowPlanner(
clause_classifier=FakeClauseClassifier(
{
"车里太闷了": [
{"label": "cabin_window_open", "intent_id": "cabin_window_open", "score": 0.83},
],
"来点轻音乐": [
{"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
],
}
)
)
result = planner.plan("车里太闷了,然后来点轻音乐", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual(result.workflow_type, "sequence")
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
clause_analysis = result.metadata.get("clause_analysis", [])
self.assertGreater(clause_analysis[0].get("candidates", [])[0].get("model_score", 0.0), 0.8)
def test_heuristic_planner_can_use_multi_intent_detector_prior(self) -> None:
planner = HeuristicWorkflowPlanner(
clause_classifier=FakeClauseClassifier(
{
"来点轻音乐": [
{"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
],
}
),
multi_intent_detector=FakeMultiIntentDetector(
{
"顺便开下车窗,再来点轻音乐": [
("cabin_window_open", 0.87),
("cabin_play_music", 0.82),
]
}
),
)
result = planner.plan("顺便开下车窗,再来点轻音乐", self.registry.list())
self.assertTrue(result.accepted)
self.assertEqual([step.intent_id for step in result.steps], ["cabin_window_open", "cabin_play_music"])
detector_meta = result.metadata.get("multi_intent_detector") or {}
self.assertTrue(detector_meta.get("detected"))
self.assertEqual(len(detector_meta.get("candidates", [])), 2)
if __name__ == "__main__":
unittest.main()