Files
ai-device/intelligent_cabin/archive/tests/test_joint_nlu_integration.py
2026-06-11 16:28:00 +08:00

94 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import unittest
from app.schemas.intent import IntentDefinition
from app.services.classifier import JointBertIntentClassifier
from app.services.planner import HeuristicWorkflowPlanner
from app.services.router import JointBertSlotExtractor
class FakeJointNLU:
    """Deterministic stand-in for the joint intent/slot NLU model used by these tests.

    ``predict`` serves canned whole-utterance results from ``_predictions``;
    ``extract_slots`` / ``extract_slots_by_intent_id`` serve canned
    per-(clause text, intent_id) slot results from ``_slot_predictions``.
    This lets the classifier, slot extractor and planner be exercised
    without loading a real model.
    """

    def __init__(self) -> None:
        # Whole-utterance predictions returned by ``predict``, keyed by exact text.
        self._predictions = {
            "把空调调到22度": {
                "intent_id": "cabin_set_ac",
                "intent_score": 0.93,
                "candidates": [("cabin_set_ac", 0.93), ("cabin_ac_on", 0.04)],
                "slots": {"temperature": 22},
            },
            "导航去公司然后把空调调到22度": {
                "intent_id": "cabin_nav_to",
                "intent_score": 0.88,
                "candidates": [("cabin_nav_to", 0.88), ("cabin_set_ac", 0.72)],
                "slots": {"destination": "公司"},
            },
        }
        # Slot predictions keyed by (clause text, intent_id); exactly one entry
        # per pair (a repeated key in a dict literal silently overwrites).
        self._slot_predictions = {
            ("把空调调到22度", "cabin_set_ac"): {"temperature": 22},
            ("导航去公司", "cabin_nav_to"): {"destination": "公司"},
        }

    def warmup(self, sample_text: str = "打开车窗") -> bool:
        """Pretend to warm up the model; always reports success."""
        _ = sample_text
        return True

    def predict(self, text: str, intents: list[IntentDefinition]):
        """Return the canned ``JointNluResult`` for *text*.

        Args:
            text: The exact utterance; must be a key of ``_predictions``.
            intents: Accepted for interface compatibility; ignored by the fake.

        Returns:
            A ``JointNluResult`` whose ``candidates`` are ``JointCandidate``
            objects and whose ``slots`` is a fresh copy of the canned dict.

        Raises:
            KeyError: If *text* has no canned prediction, so a test feeding
                unexpected input fails loudly instead of receiving a made-up result.
        """
        # Imported at call time rather than module level — presumably to keep the
        # test module importable without loading joint_nlu eagerly; confirm.
        from app.services.joint_nlu import JointCandidate, JointNluResult

        _ = intents
        raw = self._predictions[text]
        candidates = [
            JointCandidate(intent_id=intent_id, score=score)
            for intent_id, score in raw["candidates"]
        ]
        return JointNluResult(
            intent_id=raw["intent_id"],
            intent_score=raw["intent_score"],
            candidates=candidates,
            # Copy so a caller mutating the result cannot corrupt the fixture.
            slots=dict(raw["slots"]),
        )

    def extract_slots(self, text: str, intent: IntentDefinition):
        """Return canned slots for *text* under ``intent.intent_id`` (``{}`` if none)."""
        return self.extract_slots_by_intent_id(text, intent.intent_id)

    def extract_slots_by_intent_id(self, text: str, intent_id: str, required_slots=None):
        """Return a fresh copy of the canned slots for ``(text, intent_id)``.

        Args:
            text: Clause text to look up.
            intent_id: Intent the clause is assumed to carry.
            required_slots: Accepted for interface compatibility; ignored.

        Returns:
            A new dict (empty when no canned entry exists), safe for callers to mutate.
        """
        _ = required_slots
        return dict(self._slot_predictions.get((text, intent_id), {}))
class JointNLUIntegrationTests(unittest.TestCase):
    """Wire ``FakeJointNLU`` into the classifier, slot extractor and planner."""

    def setUp(self) -> None:
        self.intents = [
            IntentDefinition(intent_id="cabin_set_ac", plugin_id="x", domain="cabin", required_slots=["temperature"]),
            IntentDefinition(intent_id="cabin_nav_to", plugin_id="x", domain="cabin", required_slots=["destination"]),
            # No required slots; appears only as a low-scoring candidate in the fake.
            IntentDefinition(intent_id="cabin_ac_on", plugin_id="x", domain="cabin"),
        ]
        self.fake_nlu = FakeJointNLU()

    def test_joint_classifier_uses_joint_nlu_intent_head(self) -> None:
        """The classifier's chosen intent and first raw candidate come from the joint model."""
        classifier = JointBertIntentClassifier(self.fake_nlu, threshold=0.3, top_k=2)
        result = classifier.predict("把空调调到22度", self.intents)
        self.assertIsNotNone(result.intent)
        self.assertEqual(result.intent.intent_id, "cabin_set_ac")
        # Guard the index so an empty candidate list fails as an assertion.
        self.assertGreater(len(result.raw_candidates), 0)
        self.assertEqual(result.raw_candidates[0]["intent_id"], "cabin_set_ac")

    def test_joint_slot_extractor_uses_joint_nlu_slots(self) -> None:
        """The slot extractor returns the joint model's slots for the given intent."""
        extractor = JointBertSlotExtractor(self.fake_nlu)
        slots = extractor.extract("把空调调到22度", self.intents[0])
        self.assertEqual(slots, {"temperature": 22})

    def test_planner_prefers_joint_nlu_slots_for_each_clause(self) -> None:
        """Each clause of a compound command gets the joint model's slots for that clause."""
        planner = HeuristicWorkflowPlanner(joint_nlu=self.fake_nlu)
        result = planner.plan("导航去公司然后把空调调到22度", self.intents)
        self.assertTrue(result.accepted)
        # Check the step count first so a dropped clause fails clearly, not with IndexError.
        self.assertEqual(len(result.steps), 2)
        self.assertEqual(result.steps[0].slots, {"destination": "公司"})
        self.assertEqual(result.steps[1].slots, {"temperature": 22})
if __name__ == "__main__":
    # Run this file's tests when executed directly as a script.
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")