Files
ai-device/intelligent_cabin/archive/tests/test_multi_intent_detector.py
2026-06-11 16:28:00 +08:00

145 lines
5.2 KiB
Python

from __future__ import annotations
import unittest
import torch
from app.schemas.intent import IntentDefinition
from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector
class FakeTokenizer:
    """Tokenizer stub that ignores its input and yields a fixed encoding.

    Calling an instance always produces the same two-token batch, so tests
    never need a real vocabulary or tokenizer files.
    """

    def __call__(self, text, truncation=True, padding=False, return_tensors="pt"):
        """Return a constant ``input_ids`` / ``attention_mask`` pair.

        All parameters are accepted only so the call signature is usable by
        the code under test; none of them influences the result.
        """
        _ = (text, truncation, padding, return_tensors)
        # One sequence of two token ids, with every position attended to.
        token_ids = torch.tensor([[101, 102]], dtype=torch.long)
        mask = torch.tensor([[1, 1]], dtype=torch.long)
        return {"input_ids": token_ids, "attention_mask": mask}
class FakeModel:
    """Minimal stand-in for a classification model used by the detector tests.

    It exposes a ``config`` object carrying ``id2label`` and, when called,
    returns an object whose ``logits`` attribute is a fixed tensor built from
    the values given at construction time.
    """

    def __init__(self, logits: list[float], id2label: dict[int, str]) -> None:
        """Store the label mapping and the canned logits.

        Args:
            logits: One score per label; wrapped in an outer list so the
                resulting tensor has a leading dimension of size 1.
            id2label: Mapping from logit index to label name.
        """
        # Anonymous object that only carries the ``id2label`` attribute.
        self.config = type("Config", (), {"id2label": id2label})()
        self._logits = torch.tensor([logits], dtype=torch.float32)

    def eval(self) -> "FakeModel":
        """Do nothing and return ``self``.

        ``torch.nn.Module.eval`` returns the module itself, so returning
        ``self`` keeps the fake usable with ``model = model.eval()`` and with
        chained calls, while callers that ignore the result are unaffected.
        """
        return self

    def __call__(self, **kwargs):
        """Ignore all inputs and return an object exposing the canned ``logits``."""
        _ = kwargs
        return type("Output", (), {"logits": self._logits})()
class RuntimeBackedDetector(BertMultiIntentDetector):
    """``BertMultiIntentDetector`` wired to in-memory fakes instead of a real model.

    The constructor forwards the tuning parameters to the base class with a
    placeholder ``model_path`` and stores a ``(torch, tokenizer, model)``
    tuple that ``_get_runtime`` hands back unchanged.
    """

    def __init__(
        self,
        logits: list[float],
        id2label: dict[int, str],
        threshold: float = 0.45,
        top_k: int = 8,
        max_labels: int = 4,
    ) -> None:
        """Build the detector around a ``FakeTokenizer`` and a ``FakeModel``.

        Args:
            logits: Canned scores the fake model will return.
            id2label: Index-to-label mapping exposed by the fake model config.
            threshold: Forwarded to the base class.
            top_k: Forwarded to the base class.
            max_labels: Forwarded to the base class.
        """
        # "unused" is a placeholder path; the runtime below is returned
        # directly by _get_runtime.
        super().__init__(model_path="unused", threshold=threshold, top_k=top_k, max_labels=max_labels)
        fake_tokenizer = FakeTokenizer()
        fake_model = FakeModel(logits, id2label)
        self._runtime = (torch, fake_tokenizer, fake_model)

    def _get_runtime(self):
        """Return the pre-built ``(torch, tokenizer, model)`` tuple.

        NOTE(review): presumably this overrides the base-class hook that
        loads the real runtime -- confirm against BertMultiIntentDetector.
        """
        return self._runtime
class MultiIntentDetectorTests(unittest.TestCase):
    """Tests for the multi-intent detectors, driven entirely by fakes."""

    @staticmethod
    def _build_intents(*pairs):
        """Create ``cabin``-domain intent definitions from (intent_id, plugin_id) pairs."""
        return [
            IntentDefinition(intent_id=intent_id, plugin_id=plugin_id, domain="cabin")
            for intent_id, plugin_id in pairs
        ]

    @staticmethod
    def _candidate_ids(result):
        """Return the ``intent_id`` of every candidate in ``result``, in order."""
        return [candidate.intent_id for candidate in result.candidates]

    def test_detector_filters_blocked_and_unknown_labels(self) -> None:
        """Labels absent from the known intents must not show up as candidates."""
        # "__social__" and "unknown_intent" carry the two highest logits but
        # are not in the known-intent list passed to detect().
        label_map = {
            0: "cabin_window_open",
            1: "cabin_play_music",
            2: "__social__",
            3: "unknown_intent",
        }
        detector = RuntimeBackedDetector(
            logits=[2.4, 2.0, 3.2, 2.6],
            id2label=label_map,
            threshold=0.8,
            top_k=4,
        )
        known = self._build_intents(
            ("cabin_window_open", "plugin.window"),
            ("cabin_play_music", "plugin.music"),
        )

        outcome = detector.detect("打开车窗并播放音乐", known)

        self.assertTrue(outcome.detected)
        self.assertEqual(outcome.backend_name, "bert-multi-label")
        self.assertEqual(self._candidate_ids(outcome), ["cabin_window_open", "cabin_play_music"])

    def test_detector_respects_threshold_and_max_labels(self) -> None:
        """With three known labels and ``max_labels=2`` only two candidates remain."""
        # NOTE(review): if scores are sigmoid(logit), all three logits clear
        # 0.89, so only max_labels trims the list -- confirm the threshold is
        # actually exercised here.
        label_map = {
            0: "cabin_window_open",
            1: "cabin_play_music",
            2: "cabin_nav_to",
        }
        detector = RuntimeBackedDetector(
            logits=[2.8, 2.5, 2.2],
            id2label=label_map,
            threshold=0.89,
            top_k=3,
            max_labels=2,
        )
        known = self._build_intents(
            ("cabin_window_open", "plugin.window"),
            ("cabin_play_music", "plugin.music"),
            ("cabin_nav_to", "plugin.nav"),
        )

        outcome = detector.detect("开窗放歌去公司", known)

        self.assertTrue(outcome.detected)
        self.assertEqual(len(outcome.candidates), 2)
        self.assertEqual(self._candidate_ids(outcome), ["cabin_window_open", "cabin_play_music"])

    def test_joint_bert_detector_wraps_shared_runtime(self) -> None:
        """The joint detector forwards its settings to the NLU object and wraps its output."""
        known = self._build_intents(
            ("cabin_window_open", "plugin.window"),
            ("cabin_play_music", "plugin.music"),
        )

        class FakeJointNlu:
            """Records every ``predict_multi_intents`` call and returns two fixed candidates."""

            def __init__(self) -> None:
                self.calls: list[dict[str, object]] = []

            def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8):
                # Keep a record of the arguments so the test can inspect them.
                record = {
                    "text": text,
                    "threshold": threshold,
                    "max_labels": max_labels,
                    "top_k": top_k,
                    "known_count": len(known_intents),
                }
                self.calls.append(record)
                canned = (("cabin_window_open", 0.93), ("cabin_play_music", 0.88))
                return [
                    type("Candidate", (), {"intent_id": intent_id, "score": score})()
                    for intent_id, score in canned
                ]

            def warmup(self, sample_text="") -> bool:
                _ = sample_text
                return True

        fake_nlu = FakeJointNlu()
        detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3)

        outcome = detector.detect("打开车窗并播放音乐", known)

        self.assertTrue(outcome.detected)
        self.assertEqual(outcome.backend_name, "joint-bert-multi-label")
        self.assertEqual(self._candidate_ids(outcome), ["cabin_window_open", "cabin_play_music"])
        first_call = fake_nlu.calls[0]
        self.assertEqual(first_call["threshold"], 0.5)
        self.assertEqual(first_call["top_k"], 6)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()