181 lines
6.8 KiB
Python
181 lines
6.8 KiB
Python
from __future__ import annotations
|
|
|
|
import unittest
|
|
|
|
from app.plugins.base import PluginRegistry
|
|
from app.schemas.chat import ChatRequest, FillSlotsRequest
|
|
from app.schemas.debug import RoutingDebug
|
|
from app.schemas.intent import IntentDefinition
|
|
from app.schemas.workflow import Workflow, WorkflowStep
|
|
from app.services.agent_service import AgentService
|
|
from app.services.intent_registry import IntentRegistry
|
|
from app.services.session_store import InMemorySessionStore
|
|
from app.services.social import SocialReplyResult, SocialRouter
|
|
|
|
|
|
class _FailingRouter:
|
|
def route(self, text: str): # pragma: no cover - should not be called in these tests
|
|
raise AssertionError(f"router should not be called for social input: {text}")
|
|
|
|
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, object]:
|
|
_ = (text, intent)
|
|
return {}
|
|
|
|
|
|
class _FakeSocialResponder:
    """Deterministic stand-in for the social-reply backend used by the tests.

    Chooses a canned reply by simple keyword matching on the stripped input,
    so tests can assert on stable substrings of the reply text.
    """

    # Ordered (keywords, reply) rules: the first rule with any keyword
    # contained in the stripped input wins; otherwise ``_FALLBACK`` is used.
    _RULES = (
        (("你好",), "你好呀,我在,想聊什么都可以。"),
        (("名字", "你是谁"), "我是一名智能座舱助手,你可以直接叫我座舱助手。"),
        (("天气",), "是啊,今天确实挺舒服的。"),
    )
    _FALLBACK = "我在,咱们继续聊。"

    def reply(self, text: str, session) -> SocialReplyResult:
        """Return a canned ``SocialReplyResult`` for ``text``.

        Args:
            text: The user's utterance; matched after ``str.strip()``.
            session: Accepted to match the responder call signature; unused.

        Returns:
            A ``SocialReplyResult`` with backend ``"fake-cloud"`` and model
            ``"fake-social"``.
        """
        del session  # unused by this fake
        normalized = text.strip()
        # Keep the ``text`` parameter intact; the chosen reply gets its own name.
        matches = (
            canned
            for keywords, canned in self._RULES
            if any(keyword in normalized for keyword in keywords)
        )
        reply_text = next(matches, self._FALLBACK)
        return SocialReplyResult(
            text=reply_text,
            backend="fake-cloud",
            model_name="fake-social",
        )
|
|
|
|
|
|
def _intent(intent_id: str, plugin_id: str) -> IntentDefinition:
    """Build a minimal ``IntentDefinition`` for tests.

    Ids starting with ``cs_`` go to the ``service`` domain and everything else
    to ``cabin``. Only ``cs_cancel_order`` is high-risk and requires an
    ``order_id`` slot with an ask prompt. Every other intent is low-risk with
    no required slots. Keywords and examples are always empty.
    """
    domain = "service" if intent_id.startswith("cs_") else "cabin"
    if intent_id == "cs_cancel_order":
        risk_level = "high"
        required_slots = ["order_id"]
        ask_templates = {"order_id": "请告诉我订单号。"}
    else:
        risk_level = "low"
        required_slots = []
        ask_templates = {}
    return IntentDefinition(
        intent_id=intent_id,
        plugin_id=plugin_id,
        domain=domain,
        risk_level=risk_level,
        required_slots=required_slots,
        ask_templates=ask_templates,
        keywords=[],
        examples=[],
    )
|
|
|
|
|
|
class SocialChatTests(unittest.TestCase):
    """End-to-end checks of ``AgentService`` on social (small-talk) input.

    The business router is a ``_FailingRouter``, so a social turn that leaks
    into intent routing fails loudly. Social replies come from
    ``_FakeSocialResponder``.
    """

    def setUp(self) -> None:
        self.session_store = InMemorySessionStore()

        def cancel_order(slots):
            # Mock plugin: report success and echo the order id back.
            return {"success": True, "message": f"已取消订单 {slots.get('order_id', '')}。"}

        self.plugins = PluginRegistry()
        self.plugins.register("mock.cancel_order", cancel_order)

        registry = IntentRegistry([_intent("cs_cancel_order", "mock.cancel_order")])
        self.service = AgentService(
            intent_registry=registry,
            router=_FailingRouter(),
            plugins=self.plugins,
            session_store=self.session_store,
            social_router=SocialRouter(),
            social_responder=_FakeSocialResponder(),
        )

    # -- helpers (not collected as tests: no ``test`` prefix) -----------------

    def _chat(self, session_id: str, text: str):
        """Send ``text`` through ``handle_chat`` as ``user_1``."""
        request = ChatRequest(session_id=session_id, user_id="user_1", input_text=text)
        return self.service.handle_chat(request)

    def _fill_slots(self, session_id: str, text: str):
        """Send ``text`` through ``handle_fill_slots`` as ``user_1``."""
        request = FillSlotsRequest(session_id=session_id, user_id="user_1", input_text=text)
        return self.service.handle_fill_slots(request)

    def _assert_plain_social(self, response) -> None:
        """Assert the turn was routed to open social chat with no pending workflow."""
        self.assertEqual(response.decision, "open_social")
        self.assertEqual(response.status, "social")

    def _seed_pending_cancel_confirmation(self, session_id: str) -> None:
        """Store a session whose cancel-order workflow is awaiting confirmation."""
        session = self.session_store.get_or_create(session_id, "user_1")
        session.current_intent = "cs_cancel_order"
        session.status = "waiting_confirmation"
        session.pending_slots = ["confirmation"]
        session.slots = {"order_id": "A123456"}
        session.routing_debug = RoutingDebug(
            selected_intent="cs_cancel_order",
            decision="execute",
        ).model_dump()

        confirm_step = WorkflowStep(
            step=1,
            step_id="step_confirm",
            intent_id="cs_cancel_order",
            plugin_id="mock.cancel_order",
            action="cancel_order",
            status="waiting_confirmation",
            slots={"order_id": "A123456"},
            requires_confirmation=True,
        )
        workflow_meta = {
            "pending_confirmation": {
                "step_id": "step_confirm",
                "intent_id": "cs_cancel_order",
                "detail": "确认取消订单 A123456",
            },
            "confirmed_steps": [],
            "step_results": {},
        }
        session.workflow = Workflow(
            workflow_id=f"wf_{session_id}",
            workflow_type="conditional",
            domain="service",
            intent_id="cs_cancel_order",
            status="waiting_confirmation",
            risk_level="high",
            slots={"order_id": "A123456"},
            steps=[confirm_step],
            meta=workflow_meta,
        ).model_dump()
        self.session_store.save(session)

    # -- tests -----------------------------------------------------------------

    def test_greeting_social_reply_uses_social_responder(self) -> None:
        response = self._chat("sess_social_hi", "你好")
        self._assert_plain_social(response)
        self.assertIn("你好呀", response.reply_text)

    def test_capability_social_question_does_not_fall_into_business_intent(self) -> None:
        response = self._chat("sess_social_name", "你叫什么名字")
        self._assert_plain_social(response)
        # Must not be mistaken for a business intent that asks for a slot.
        self.assertNotEqual(response.reply_type, "ask_slot")
        self.assertIn("智能座舱助手", response.reply_text)

    def test_open_social_reply_uses_social_responder(self) -> None:
        response = self._chat("sess_social_open", "今天天气真不错啊")
        self._assert_plain_social(response)
        self.assertIn("挺舒服", response.reply_text)

    def test_social_turn_does_not_break_waiting_confirmation(self) -> None:
        self._seed_pending_cancel_confirmation("sess_confirm")

        # Small talk mid-confirmation is answered socially but keeps the
        # confirmation pending and reminds the user how to respond.
        social_response = self._fill_slots("sess_confirm", "今天天气真不错啊")
        self.assertEqual(social_response.decision, "open_social")
        self.assertEqual(social_response.status, "waiting_confirmation")
        self.assertEqual(social_response.pending_slots, ["confirmation"])
        self.assertIn("回复“确认”或“取消”即可", social_response.reply_text)

        # The pending confirmation still completes the workflow afterwards.
        confirm_response = self._fill_slots("sess_confirm", "确认")
        self.assertEqual(confirm_response.reply_type, "workflow_result")
        self.assertEqual(confirm_response.status, "completed")
        self.assertIn("已取消订单", confirm_response.reply_text)
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly with the Python interpreter.
    unittest.main()
|