185 lines
7.9 KiB
Python
185 lines
7.9 KiB
Python
from __future__ import annotations
|
||
|
||
import unittest
|
||
from pathlib import Path
|
||
|
||
from app.schemas.configuration import WorkflowTemplatesConfig
|
||
from app.services.classifier import ClassificationResult
|
||
from app.services.multi_intent_detector import MultiIntentCandidate, MultiIntentDetectionResult
|
||
from app.services.planner import CompositeWorkflowPlanner, HeuristicWorkflowPlanner, TemplateWorkflowPlanner
|
||
from app.services.intent_registry import IntentRegistry
|
||
|
||
|
||
class FakeClauseClassifier:
    """Test double for the clause classifier consumed by HeuristicWorkflowPlanner.

    Answers purely from a lookup table keyed by exact clause text. The returned
    result never names a top-level intent (``intent=None``, ``score=0.0``); the
    canned candidates are exposed through ``raw_candidates`` instead.
    """

    def __init__(self, predictions: dict[str, list[dict[str, float | str]]]) -> None:
        # Clause text -> list of candidate dicts, handed back as-is by predict().
        self._predictions = predictions

    def predict(self, text, intents):
        """Return a ClassificationResult carrying the canned candidates for ``text``.

        Unknown text yields an empty candidate list. ``intents`` is accepted only
        to match the real classifier's call signature and is ignored.
        """
        del intents  # unused: the fake never consults the intent catalogue
        canned = self._predictions.get(text, [])
        return ClassificationResult(
            intent=None,
            score=0.0,
            model_name="fake-bert-clause",
            backend_name="fake-bert-clause",
            raw_candidates=canned,
        )
|
||
|
||
|
||
class FakeMultiIntentDetector:
    """Test double for the multi-intent detector that replays canned results.

    Each utterance maps to an ordered list of ``(intent_id, score)`` pairs. The
    result is flagged as detected only when two or more candidates are found.
    """

    def __init__(self, predictions: dict[str, list[tuple[str, float]]]) -> None:
        # Utterance text -> ordered (intent_id, score) pairs.
        self._predictions = predictions

    def detect(self, text, intents):
        """Build a MultiIntentDetectionResult from the pairs stored for ``text``.

        Every pair becomes a MultiIntentCandidate whose ``label`` repeats the
        intent id. Unknown text produces no candidates (and thus no detection).
        ``intents`` is accepted for signature compatibility and ignored.
        """
        del intents  # unused: answers come solely from the lookup table
        found = []
        for candidate_id, candidate_score in self._predictions.get(text, []):
            found.append(
                MultiIntentCandidate(intent_id=candidate_id, score=candidate_score, label=candidate_id)
            )
        return MultiIntentDetectionResult(
            detected=len(found) > 1,
            candidates=found,
            reason="fake detector",
            backend_name="fake-multi",
        )
|
||
|
||
|
||
class WorkflowTemplateTests(unittest.TestCase):
    """Behavioral tests for the template, heuristic and composite workflow planners."""

    # NOTE(review): both fixture paths are relative to the current working
    # directory, so this suite only works when run from the project root.
    _INTENTS_PATH = "app/data/intents.json"
    _WORKFLOWS_PATH = Path("config/workflows.yml")

    def setUp(self) -> None:
        """Load a fresh intent registry and workflow-template config for each test."""
        self.registry = IntentRegistry.from_json(self._INTENTS_PATH)
        # NOTE(review): a ``.yml`` file is parsed with ``model_validate_json``.
        # This only works while config/workflows.yml stays JSON-compatible;
        # confirm that is intentional.
        self.templates = WorkflowTemplatesConfig.model_validate_json(
            self._WORKFLOWS_PATH.read_text(encoding="utf-8")
        )

    @staticmethod
    def _intent_ids(result) -> list[str]:
        """Return the intent ids of ``result.steps`` in plan order."""
        return [step.intent_id for step in result.steps]

    def test_template_planner_matches_sequence_template(self) -> None:
        planner = TemplateWorkflowPlanner(self.templates)

        result = planner.plan("打开车窗,然后把空调调到20度", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.backend, "local-template")
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual(self._intent_ids(result), ["cabin_window_open", "cabin_set_ac"])
        # The second step should carry the temperature slot extracted from "20度".
        self.assertEqual(result.steps[1].slots.get("temperature"), 20)

    def test_template_planner_matches_conditional_template(self) -> None:
        planner = TemplateWorkflowPlanner(self.templates)

        result = planner.plan("查一下订单A123456,如果还没发货就取消", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "conditional")
        self.assertEqual(self._intent_ids(result), ["cs_query_order", "cs_cancel_order"])
        # The cancel step depends on step 1 (the query) and must be confirmed.
        self.assertEqual(result.steps[1].depends_on, [1])
        self.assertTrue(result.steps[1].requires_confirmation)

    def test_composite_planner_falls_back_to_heuristic_when_template_misses(self) -> None:
        planner = CompositeWorkflowPlanner([TemplateWorkflowPlanner(self.templates), HeuristicWorkflowPlanner()])

        result = planner.plan("打开车窗,并且播放轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        # NOTE(review): the test name promises a fallback to the heuristic
        # planner, but this check accepts either backend, so it also passes
        # when a template matches. Tighten to "local-heuristic" once it is
        # confirmed that no template in config/workflows.yml covers this input.
        self.assertIn(result.backend, {"local-template", "local-heuristic"})
        self.assertEqual(result.workflow_type, "sequence")

    def test_heuristic_planner_parses_ac_then_window_close_sequence(self) -> None:
        planner = HeuristicWorkflowPlanner()

        result = planner.plan("打开空调,再把窗户降下来", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.backend, "local-heuristic")
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual(self._intent_ids(result), ["cabin_ac_on", "cabin_window_close"])

    def test_planner_metadata_contains_clause_analysis(self) -> None:
        planner = HeuristicWorkflowPlanner()

        result = planner.plan("打开空调,然后打开车窗", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertTrue(result.metadata.get("multi_intent_detected"))
        clause_analysis = result.metadata.get("clause_analysis", [])
        # The length check also guards the indexing below.
        self.assertEqual(len(clause_analysis), 2)
        self.assertEqual(clause_analysis[0].get("selected_intent_id"), "cabin_ac_on")
        self.assertEqual(clause_analysis[1].get("selected_intent_id"), "cabin_window_open")

    def test_heuristic_planner_supports_shared_action_parallel_objects(self) -> None:
        planner = HeuristicWorkflowPlanner()

        result = planner.plan("打开空调和车窗", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual(self._intent_ids(result), ["cabin_ac_on", "cabin_window_open"])

    def test_heuristic_planner_supports_parallel_objects_with_suffix_action(self) -> None:
        planner = HeuristicWorkflowPlanner()

        result = planner.plan("把车窗和天窗打开", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual(self._intent_ids(result), ["cabin_window_open", "cabin_sunroof_open"])

    def test_heuristic_planner_supports_parallel_clause_with_bing_connector(self) -> None:
        planner = HeuristicWorkflowPlanner()

        result = planner.plan("打开车窗并播放轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual(self._intent_ids(result), ["cabin_window_open", "cabin_play_music"])

    def test_heuristic_planner_can_use_clause_classifier_to_rescue_semantic_clause(self) -> None:
        planner = HeuristicWorkflowPlanner(
            clause_classifier=FakeClauseClassifier(
                {
                    "车里太闷了": [
                        {"label": "cabin_window_open", "intent_id": "cabin_window_open", "score": 0.83},
                    ],
                    "来点轻音乐": [
                        {"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
                    ],
                }
            )
        )

        result = planner.plan("车里太闷了,然后来点轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(result.workflow_type, "sequence")
        self.assertEqual(self._intent_ids(result), ["cabin_window_open", "cabin_play_music"])
        clause_analysis = result.metadata.get("clause_analysis", [])
        # Assert non-emptiness before indexing so a regression is reported as
        # a test failure rather than an IndexError.
        self.assertTrue(clause_analysis, "expected clause_analysis metadata")
        first_candidates = clause_analysis[0].get("candidates", [])
        self.assertTrue(first_candidates, "expected candidates for the first clause")
        self.assertGreater(first_candidates[0].get("model_score", 0.0), 0.8)

    def test_heuristic_planner_can_use_multi_intent_detector_prior(self) -> None:
        planner = HeuristicWorkflowPlanner(
            clause_classifier=FakeClauseClassifier(
                {
                    "来点轻音乐": [
                        {"label": "cabin_play_music", "intent_id": "cabin_play_music", "score": 0.91},
                    ],
                }
            ),
            multi_intent_detector=FakeMultiIntentDetector(
                {
                    "顺便开下车窗,再来点轻音乐": [
                        ("cabin_window_open", 0.87),
                        ("cabin_play_music", 0.82),
                    ]
                }
            ),
        )

        result = planner.plan("顺便开下车窗,再来点轻音乐", self.registry.list())

        self.assertTrue(result.accepted)
        self.assertEqual(self._intent_ids(result), ["cabin_window_open", "cabin_play_music"])
        detector_meta = result.metadata.get("multi_intent_detector") or {}
        self.assertTrue(detector_meta.get("detected"))
        self.assertEqual(len(detector_meta.get("candidates", [])), 2)
||
|
||
|
||
# Allow running this module directly (e.g. ``python <this file>``) in
# addition to discovery by a test runner.
if __name__ == "__main__":
    unittest.main()
|