94 lines
3.7 KiB
Python
94 lines
3.7 KiB
Python
from __future__ import annotations
|
||
|
||
import unittest
|
||
|
||
from app.schemas.intent import IntentDefinition
|
||
from app.services.classifier import JointBertIntentClassifier
|
||
from app.services.planner import HeuristicWorkflowPlanner
|
||
from app.services.router import JointBertSlotExtractor
|
||
|
||
|
||
class FakeJointNLU:
    """Deterministic stand-in for the joint intent/slot NLU model used by these tests.

    Every prediction is looked up from hard-coded tables keyed by the exact input
    text, so the classifier, slot extractor and planner can be exercised without a
    real model.
    """

    def __init__(self) -> None:
        # Whole-utterance results returned by ``predict``, keyed by the exact input text.
        self._predictions = {
            "把空调调到22度": {
                "intent_id": "cabin_set_ac",
                "intent_score": 0.93,
                "candidates": [("cabin_set_ac", 0.93), ("cabin_ac_on", 0.04)],
                "slots": {"temperature": 22},
            },
            "导航去公司,然后把空调调到22度": {
                "intent_id": "cabin_nav_to",
                "intent_score": 0.88,
                "candidates": [("cabin_nav_to", 0.88), ("cabin_set_ac", 0.72)],
                "slots": {"destination": "公司"},
            },
        }
        # Slot results for ``extract_slots*``, keyed by (text, intent_id).
        # Each key must appear only once: a repeated key in a dict literal
        # silently overwrites the earlier entry.
        self._slot_predictions = {
            ("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
            ("导航去公司", "cabin_nav_to"): {"destination": "公司"},
        }

    def warmup(self, sample_text: str = "打开车窗") -> bool:
        """Pretend to warm up the model; *sample_text* is ignored and True is returned."""
        _ = sample_text
        return True

    def predict(self, text: str, intents: list[IntentDefinition]):
        """Return the canned ``JointNluResult`` for *text*.

        *intents* is accepted for interface compatibility and ignored.

        Raises:
            KeyError: if no canned prediction exists for *text*.
        """
        # Imported inside the method so the fake only needs app.services.joint_nlu
        # when predict() is actually called.
        from app.services.joint_nlu import JointCandidate, JointNluResult

        _ = intents
        try:
            raw = self._predictions[text]
        except KeyError:
            # Keep the exception type, but name the missing utterance to ease debugging.
            raise KeyError(f"FakeJointNLU has no canned prediction for {text!r}") from None

        candidates = [
            JointCandidate(intent_id=intent_id, score=score)
            for intent_id, score in raw["candidates"]
        ]
        return JointNluResult(
            intent_id=raw["intent_id"],
            intent_score=raw["intent_score"],
            candidates=candidates,
            # Copy so callers cannot mutate the fixture table.
            slots=dict(raw["slots"]),
        )

    def extract_slots(self, text: str, intent: IntentDefinition):
        """Return a copy of the canned slots for (*text*, ``intent.intent_id``), or ``{}``."""
        return self.extract_slots_by_intent_id(text, intent.intent_id)

    def extract_slots_by_intent_id(self, text: str, intent_id: str, required_slots=None):
        """Return a copy of the canned slots for (*text*, *intent_id*), or ``{}`` if unknown.

        *required_slots* is accepted for interface compatibility and ignored.
        """
        _ = required_slots
        return dict(self._slot_predictions.get((text, intent_id), {}))
|
||
|
||
|
||
class JointNLUIntegrationTests(unittest.TestCase):
    """Wire ``FakeJointNLU`` into the classifier, slot extractor and planner."""

    def setUp(self) -> None:
        # Intent catalogue shared by every test; each test gets a fresh fake NLU.
        self.intents = [
            IntentDefinition(
                intent_id="cabin_set_ac",
                plugin_id="x",
                domain="cabin",
                required_slots=["temperature"],
            ),
            IntentDefinition(
                intent_id="cabin_nav_to",
                plugin_id="x",
                domain="cabin",
                required_slots=["destination"],
            ),
            IntentDefinition(intent_id="cabin_ac_on", plugin_id="x", domain="cabin"),
        ]
        self.fake_nlu = FakeJointNLU()

    def test_joint_classifier_uses_joint_nlu_intent_head(self) -> None:
        """The classifier should report the fake model's top intent and candidates."""
        classifier = JointBertIntentClassifier(self.fake_nlu, threshold=0.3, top_k=2)

        result = classifier.predict("把空调调到22度", self.intents)

        self.assertIsNotNone(result.intent)
        self.assertEqual(result.intent.intent_id, "cabin_set_ac")
        # Assert non-empty first so a regression fails cleanly instead of raising IndexError.
        self.assertTrue(result.raw_candidates)
        self.assertEqual(result.raw_candidates[0]["intent_id"], "cabin_set_ac")

    def test_joint_slot_extractor_uses_joint_nlu_slots(self) -> None:
        """The slot extractor should return the fake model's slots for the given intent."""
        extractor = JointBertSlotExtractor(self.fake_nlu)

        slots = extractor.extract("把空调调到22度", self.intents[0])

        self.assertEqual(slots, {"temperature": 22})

    def test_planner_prefers_joint_nlu_slots_for_each_clause(self) -> None:
        """Each clause of a compound utterance should get its slots from the fake model."""
        planner = HeuristicWorkflowPlanner(joint_nlu=self.fake_nlu)

        result = planner.plan("导航去公司,然后把空调调到22度", self.intents)

        self.assertTrue(result.accepted)
        # The utterance has two clauses; check the count before indexing into steps.
        self.assertEqual(len(result.steps), 2)
        self.assertEqual(result.steps[0].slots, {"destination": "公司"})
        self.assertEqual(result.steps[1].slots, {"temperature": 22})
|
||
|
||
|
||
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main(module="__main__")
|