150 lines
5.3 KiB
Python
150 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import unittest
|
|
|
|
from app.plugins.base import PluginRegistry
|
|
from app.plugins.mock import MockPluginExecutor
|
|
from app.services.classifier import MockIntentClassifier
|
|
from app.services.agent_service import AgentService
|
|
from app.services.intent_registry import IntentRegistry
|
|
from app.services.response_policy import ResponsePolicy
|
|
from app.services.rewrite_engine import ContextRewriteEngine
|
|
from app.services.router import HeuristicSlotExtractor, IntentRouter, build_matcher_pipeline
|
|
from app.services.session_store import InMemorySessionStore
|
|
from app.schemas.chat import ChatRequest, FillSlotsRequest
|
|
|
|
|
|
class _BertLikeMockClassifier(MockIntentClassifier):
    """``MockIntentClassifier`` whose predictions are labelled ``"bert-local"``.

    Scoring is delegated unchanged to the parent mock. Only the ``model_name``
    and ``backend_name`` fields of the returned prediction are overwritten.
    Presumably this lets the tests exercise code paths keyed on the local BERT
    backend without loading a real model — confirm against the router/service.
    """

    def predict(self, text, intents):
        prediction = super().predict(text, intents)
        # Chained assignment: model_name is set first, then backend_name.
        prediction.model_name = prediction.backend_name = "bert-local"
        return prediction
|
|
|
|
|
|
def _build_service() -> AgentService:
    """Assemble an ``AgentService`` wired entirely with test doubles.

    The wiring is:

    * intents loaded from ``app/data/intents.json``. The path is relative;
      presumably it is resolved against the current working directory, so run
      the tests from the project root — confirm;
    * a plugin registry populated by ``MockPluginExecutor``;
    * a router whose matcher pipeline contains only the ``"classifier"`` stage,
      backed by ``_BertLikeMockClassifier``, plus heuristic slot extraction;
    * a fresh ``InMemorySessionStore`` on every call;
    * a default rewrite engine and response policy, and no planner.
    """
    intent_registry = IntentRegistry.from_json("app/data/intents.json")

    plugin_registry = PluginRegistry()
    MockPluginExecutor().register(plugin_registry)

    # threshold=0.0 presumably lets every candidate through the mock
    # classifier — confirm in MockIntentClassifier.
    mock_classifier = _BertLikeMockClassifier(threshold=0.0, top_k=3)
    matcher = build_matcher_pipeline(
        intent_registry,
        ["classifier"],
        classifier=mock_classifier,
    )
    intent_router = IntentRouter(
        matcher=matcher,
        slot_extractor=HeuristicSlotExtractor(),
    )

    return AgentService(
        intent_registry=intent_registry,
        router=intent_router,
        plugins=plugin_registry,
        session_store=InMemorySessionStore(),
        rewrite_engine=ContextRewriteEngine(),
        response_policy=ResponsePolicy(),
        planner=None,
    )
|
|
|
|
|
|
class IntentCoverageAndStopTests(unittest.TestCase):
    """End-to-end checks for intent coverage, task cancellation and AC adjustment.

    ``setUp`` calls ``_build_service()`` before every test method. That function
    creates a new in-memory session store, so sessions seeded in one test are
    not visible to another.
    """

    # Every request in this suite is sent on behalf of the same user.
    _USER_ID = "user_1"

    def setUp(self) -> None:
        self.service = _build_service()
        # An independent registry loaded from the same catalog file, used only
        # to inspect the catalog directly.
        self.registry = IntentRegistry.from_json("app/data/intents.json")

    def _chat(self, session_id, text):
        """Send one chat turn as the suite's user and return the service response."""
        request = ChatRequest(
            session_id=session_id,
            user_id=self._USER_ID,
            input_text=text,
        )
        return self.service.handle_chat(request)

    def _seed_session(self, session_id, intent, **memory) -> None:
        """Set a session's current intent and context-memory entries, then persist it."""
        store = self.service.session_store
        session = store.get_or_create(session_id, self._USER_ID)
        session.current_intent = intent
        for key, value in memory.items():
            session.context_memory[key] = value
        store.save(session)

    def _assert_workflow_result(self, response, intent) -> None:
        """Check that ``response`` is a workflow result routed to ``intent``."""
        self.assertEqual(response.reply_type, "workflow_result")
        self.assertEqual(response.intent, intent)

    def test_intent_catalog_has_at_least_30_items(self) -> None:
        catalog = self.registry.list()
        self.assertGreaterEqual(len(catalog), 30)

    def test_close_ac_routes_to_power_off(self) -> None:
        # "关闭空调" = "turn off the AC".
        response = self._chat("sess_close_ac", "关闭空调")
        self._assert_workflow_result(response, "cabin_ac_off")
        self.assertIn("已关闭空调", response.reply_text)

    def test_open_window_is_covered(self) -> None:
        # "打开车窗" = "open the car window".
        response = self._chat("sess_window_open", "打开车窗")
        self._assert_workflow_result(response, "cabin_window_open")
        self.assertIn("已打开车窗", response.reply_text)

    def test_stop_current_task_while_waiting_for_slot(self) -> None:
        session_id = "sess_stop_task"

        # "空调调到" ("set the AC to ...") omits the value, so the service is
        # expected to ask for the temperature slot.
        first = self._chat(session_id, "空调调到")
        self.assertEqual(first.reply_type, "ask_slot")
        self.assertEqual(first.pending_slots, ["temperature"])

        # Replying to the slot prompt with "不用了" ("never mind") should stop
        # the pending task and clear the outstanding slots.
        stopped = self.service.handle_fill_slots(
            FillSlotsRequest(
                session_id=session_id,
                user_id=self._USER_ID,
                input_text="不用了",
            )
        )
        self.assertEqual(stopped.reply_type, "text")
        self.assertEqual(stopped.status, "stopped")
        self.assertEqual(stopped.pending_slots, [])
        self.assertIn("已停止当前任务", stopped.reply_text)

    def test_relative_ac_adjustment_uses_two_degree_step(self) -> None:
        # With a remembered temperature of 24, "调低一点" ("a bit lower") is
        # expected to set the AC to 22.
        self._seed_session("sess_ac_lower", "cabin_set_ac", last_temperature=24)

        response = self._chat("sess_ac_lower", "调低一点")
        self._assert_workflow_result(response, "cabin_set_ac")
        self.assertEqual(response.filled_slots.get("temperature"), 22)
        self.assertIn("22", response.reply_text)

    def test_relative_ac_adjustment_without_history_uses_default_baseline(self) -> None:
        # No last_temperature is stored here. The expected 22 presumably comes
        # from a 24-degree default baseline minus the 2-degree step — confirm
        # in AgentService.
        self._seed_session("sess_ac_lower_default", "cabin_ac_on")

        response = self._chat("sess_ac_lower_default", "调低一点")
        self._assert_workflow_result(response, "cabin_set_ac")
        self.assertEqual(response.filled_slots.get("temperature"), 22)
        self.assertIn("22", response.reply_text)

    def test_temperature_is_clamped_before_execution(self) -> None:
        # Out-of-range inputs are expected to be clamped into [16, 30].
        normalize = self.service._normalize_temperature_value
        self.assertEqual(normalize(-1), 16)
        self.assertEqual(normalize(40), 30)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")
|