Update project and configurations
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.classifier import JointBertIntentClassifier
|
||||
from app.services.planner import HeuristicWorkflowPlanner
|
||||
from app.services.router import JointBertSlotExtractor
|
||||
|
||||
|
||||
class FakeJointNLU:
|
||||
def __init__(self) -> None:
|
||||
self._predictions = {
|
||||
"把空调调到22度": {
|
||||
"intent_id": "cabin_set_ac",
|
||||
"intent_score": 0.93,
|
||||
"candidates": [("cabin_set_ac", 0.93), ("cabin_ac_on", 0.04)],
|
||||
"slots": {"temperature": 22},
|
||||
},
|
||||
"导航去公司,然后把空调调到22度": {
|
||||
"intent_id": "cabin_nav_to",
|
||||
"intent_score": 0.88,
|
||||
"candidates": [("cabin_nav_to", 0.88), ("cabin_set_ac", 0.72)],
|
||||
"slots": {"destination": "公司"},
|
||||
},
|
||||
}
|
||||
self._slot_predictions = {
|
||||
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
|
||||
("导航去公司", "cabin_nav_to"): {"destination": "公司"},
|
||||
("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
|
||||
}
|
||||
|
||||
def warmup(self, sample_text: str = "打开车窗") -> bool:
|
||||
_ = sample_text
|
||||
return True
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]):
|
||||
from app.services.joint_nlu import JointCandidate, JointNluResult
|
||||
|
||||
raw = self._predictions[text]
|
||||
candidates = [JointCandidate(intent_id=intent_id, score=score) for intent_id, score in raw["candidates"]]
|
||||
return JointNluResult(
|
||||
intent_id=raw["intent_id"],
|
||||
intent_score=raw["intent_score"],
|
||||
candidates=candidates,
|
||||
slots=dict(raw["slots"]),
|
||||
)
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition):
|
||||
return dict(self._slot_predictions.get((text, intent.intent_id), {}))
|
||||
|
||||
def extract_slots_by_intent_id(self, text: str, intent_id: str, required_slots=None):
|
||||
_ = required_slots
|
||||
return dict(self._slot_predictions.get((text, intent_id), {}))
|
||||
|
||||
|
||||
class JointNLUIntegrationTests(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.intents = [
|
||||
IntentDefinition(intent_id="cabin_set_ac", plugin_id="x", domain="cabin", required_slots=["temperature"]),
|
||||
IntentDefinition(intent_id="cabin_nav_to", plugin_id="x", domain="cabin", required_slots=["destination"]),
|
||||
IntentDefinition(intent_id="cabin_ac_on", plugin_id="x", domain="cabin"),
|
||||
]
|
||||
self.fake_nlu = FakeJointNLU()
|
||||
|
||||
def test_joint_classifier_uses_joint_nlu_intent_head(self) -> None:
|
||||
classifier = JointBertIntentClassifier(self.fake_nlu, threshold=0.3, top_k=2)
|
||||
|
||||
result = classifier.predict("把空调调到22度", self.intents)
|
||||
|
||||
self.assertIsNotNone(result.intent)
|
||||
self.assertEqual(result.intent.intent_id, "cabin_set_ac")
|
||||
self.assertEqual(result.raw_candidates[0]["intent_id"], "cabin_set_ac")
|
||||
|
||||
def test_joint_slot_extractor_uses_joint_nlu_slots(self) -> None:
|
||||
extractor = JointBertSlotExtractor(self.fake_nlu)
|
||||
|
||||
slots = extractor.extract("把空调调到22度", self.intents[0])
|
||||
|
||||
self.assertEqual(slots, {"temperature": 22})
|
||||
|
||||
def test_planner_prefers_joint_nlu_slots_for_each_clause(self) -> None:
|
||||
planner = HeuristicWorkflowPlanner(joint_nlu=self.fake_nlu)
|
||||
|
||||
result = planner.plan("导航去公司,然后把空调调到22度", self.intents)
|
||||
|
||||
self.assertTrue(result.accepted)
|
||||
self.assertEqual(result.steps[0].slots, {"destination": "公司"})
|
||||
self.assertEqual(result.steps[1].slots, {"temperature": 22})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user