Update project and configurations
This commit is contained in:
144
intelligent_cabin/archive/tests/test_multi_intent_detector.py
Normal file
144
intelligent_cabin/archive/tests/test_multi_intent_detector.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector
|
||||
|
||||
|
||||
class FakeTokenizer:
|
||||
def __call__(self, text, truncation=True, padding=False, return_tensors="pt"):
|
||||
_ = (text, truncation, padding, return_tensors)
|
||||
return {
|
||||
"input_ids": torch.tensor([[101, 102]], dtype=torch.long),
|
||||
"attention_mask": torch.tensor([[1, 1]], dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
class FakeModel:
|
||||
def __init__(self, logits: list[float], id2label: dict[int, str]) -> None:
|
||||
self.config = type("Config", (), {"id2label": id2label})()
|
||||
self._logits = torch.tensor([logits], dtype=torch.float32)
|
||||
|
||||
def eval(self) -> None:
|
||||
return None
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
_ = kwargs
|
||||
return type("Output", (), {"logits": self._logits})()
|
||||
|
||||
|
||||
class RuntimeBackedDetector(BertMultiIntentDetector):
|
||||
def __init__(
|
||||
self,
|
||||
logits: list[float],
|
||||
id2label: dict[int, str],
|
||||
threshold: float = 0.45,
|
||||
top_k: int = 8,
|
||||
max_labels: int = 4,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model_path="unused",
|
||||
threshold=threshold,
|
||||
top_k=top_k,
|
||||
max_labels=max_labels,
|
||||
)
|
||||
self._runtime = (torch, FakeTokenizer(), FakeModel(logits, id2label))
|
||||
|
||||
def _get_runtime(self):
|
||||
return self._runtime
|
||||
|
||||
|
||||
class MultiIntentDetectorTests(unittest.TestCase):
|
||||
def test_detector_filters_blocked_and_unknown_labels(self) -> None:
|
||||
detector = RuntimeBackedDetector(
|
||||
logits=[2.4, 2.0, 3.2, 2.6],
|
||||
id2label={
|
||||
0: "cabin_window_open",
|
||||
1: "cabin_play_music",
|
||||
2: "__social__",
|
||||
3: "unknown_intent",
|
||||
},
|
||||
threshold=0.8,
|
||||
top_k=4,
|
||||
)
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
]
|
||||
|
||||
result = detector.detect("打开车窗并播放音乐", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(result.backend_name, "bert-multi-label")
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_detector_respects_threshold_and_max_labels(self) -> None:
|
||||
detector = RuntimeBackedDetector(
|
||||
logits=[2.8, 2.5, 2.2],
|
||||
id2label={
|
||||
0: "cabin_window_open",
|
||||
1: "cabin_play_music",
|
||||
2: "cabin_nav_to",
|
||||
},
|
||||
threshold=0.89,
|
||||
top_k=3,
|
||||
max_labels=2,
|
||||
)
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_nav_to", plugin_id="plugin.nav", domain="cabin"),
|
||||
]
|
||||
|
||||
result = detector.detect("开窗放歌去公司", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(len(result.candidates), 2)
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
|
||||
def test_joint_bert_detector_wraps_shared_runtime(self) -> None:
|
||||
intents = [
|
||||
IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"),
|
||||
IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"),
|
||||
]
|
||||
|
||||
class FakeJointNlu:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict[str, object]] = []
|
||||
|
||||
def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8):
|
||||
self.calls.append(
|
||||
{
|
||||
"text": text,
|
||||
"threshold": threshold,
|
||||
"max_labels": max_labels,
|
||||
"top_k": top_k,
|
||||
"known_count": len(known_intents),
|
||||
}
|
||||
)
|
||||
return [
|
||||
type("Candidate", (), {"intent_id": "cabin_window_open", "score": 0.93})(),
|
||||
type("Candidate", (), {"intent_id": "cabin_play_music", "score": 0.88})(),
|
||||
]
|
||||
|
||||
def warmup(self, sample_text="") -> bool:
|
||||
_ = sample_text
|
||||
return True
|
||||
|
||||
fake_nlu = FakeJointNlu()
|
||||
detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3)
|
||||
|
||||
result = detector.detect("打开车窗并播放音乐", intents)
|
||||
|
||||
self.assertTrue(result.detected)
|
||||
self.assertEqual(result.backend_name, "joint-bert-multi-label")
|
||||
self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"])
|
||||
self.assertEqual(fake_nlu.calls[0]["threshold"], 0.5)
|
||||
self.assertEqual(fake_nlu.calls[0]["top_k"], 6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user