# File listing metadata: 145 lines, 5.2 KiB, Python
from __future__ import annotations
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from app.schemas.intent import IntentDefinition
|
|
from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector
|
|
|
|
|
|
class FakeTokenizer:
    """Tokenizer stand-in that returns the same two-token encoding for any input."""

    def __call__(self, text, truncation=True, padding=False, return_tensors="pt"):
        """Return a fixed encoding with ``input_ids`` and ``attention_mask``.

        All arguments are accepted for signature compatibility and ignored.
        Both tensors are int64 with a single row of two entries.
        """
        del text, truncation, padding, return_tensors
        token_ids = torch.tensor([[101, 102]], dtype=torch.long)
        mask = torch.tensor([[1, 1]], dtype=torch.long)
        return {"input_ids": token_ids, "attention_mask": mask}
|
|
|
|
|
|
class FakeModel:
    """Model stand-in that always produces the logits it was built with.

    Provides a ``config`` object with an ``id2label`` attribute, an ``eval()``
    method, and a call interface returning an object with a ``logits`` tensor.
    """

    def __init__(self, logits: list[float], id2label: dict[int, str]) -> None:
        """Store the canned logits and the label mapping.

        Args:
            logits: One score per label; stored as a single-row float32 tensor.
            id2label: Mapping from label index to label name, exposed as
                ``config.id2label``.
        """
        self.config = type("Config", (), {"id2label": id2label})()
        self._logits = torch.tensor([logits], dtype=torch.float32)

    def eval(self) -> "FakeModel":
        """Do nothing and return ``self``.

        ``torch.nn.Module.eval`` returns the module, so returning ``self`` keeps
        ``model = model.eval()`` and chained calls working against this fake.
        """
        return self

    def __call__(self, **kwargs):
        """Ignore the keyword inputs and return an object carrying the canned logits."""
        del kwargs
        return type("Output", (), {"logits": self._logits})()
|
|
|
|
|
|
class RuntimeBackedDetector(BertMultiIntentDetector):
    """``BertMultiIntentDetector`` subclass wired to an in-memory fake runtime.

    ``_get_runtime`` is overridden to return a prebuilt
    ``(torch, FakeTokenizer, FakeModel)`` triple, so ``model_path`` is passed
    to the base class only as a placeholder.
    """

    def __init__(
        self,
        logits: list[float],
        id2label: dict[int, str],
        threshold: float = 0.45,
        top_k: int = 8,
        max_labels: int = 4,
    ) -> None:
        """Configure the base detector and build the fake runtime.

        Args:
            logits: Scores the fake model will return on every call.
            id2label: Label mapping exposed through the fake model's config.
            threshold: Forwarded unchanged to the base class.
            top_k: Forwarded unchanged to the base class.
            max_labels: Forwarded unchanged to the base class.
        """
        super().__init__(
            model_path="unused",
            threshold=threshold,
            top_k=top_k,
            max_labels=max_labels,
        )
        # Assigned after the base initializer, matching the original ordering.
        fake_tokenizer = FakeTokenizer()
        fake_model = FakeModel(logits, id2label)
        self._runtime = (torch, fake_tokenizer, fake_model)

    def _get_runtime(self):
        """Return the prebuilt fake runtime triple."""
        return self._runtime
|
|
|
|
|
|
class MultiIntentDetectorTests(unittest.TestCase):
    """Tests for the BERT-backed and joint-NLU-backed multi-intent detectors."""

    @staticmethod
    def _cabin_intents(*pairs):
        """Build cabin-domain intent definitions from ``(intent_id, plugin_id)`` pairs."""
        return [
            IntentDefinition(intent_id=intent_id, plugin_id=plugin_id, domain="cabin")
            for intent_id, plugin_id in pairs
        ]

    def test_detector_filters_blocked_and_unknown_labels(self) -> None:
        """Only the two supplied intents come back; the other two labels must not appear."""
        label_names = ["cabin_window_open", "cabin_play_music", "__social__", "unknown_intent"]
        detector = RuntimeBackedDetector(
            logits=[2.4, 2.0, 3.2, 2.6],
            id2label=dict(enumerate(label_names)),
            threshold=0.8,
            top_k=4,
        )
        known = self._cabin_intents(
            ("cabin_window_open", "plugin.window"),
            ("cabin_play_music", "plugin.music"),
        )

        outcome = detector.detect("打开车窗并播放音乐", known)

        self.assertTrue(outcome.detected)
        self.assertEqual(outcome.backend_name, "bert-multi-label")
        detected_ids = [candidate.intent_id for candidate in outcome.candidates]
        self.assertEqual(detected_ids, ["cabin_window_open", "cabin_play_music"])

    def test_detector_respects_threshold_and_max_labels(self) -> None:
        """With three known labels and ``max_labels=2``, the two highest-logit intents remain."""
        # NOTE(review): if scores are sigmoid(logit), all three logits clear 0.89,
        # so only max_labels trims the result here - confirm whether the
        # threshold is meant to be exercised by this case.
        label_names = ["cabin_window_open", "cabin_play_music", "cabin_nav_to"]
        detector = RuntimeBackedDetector(
            logits=[2.8, 2.5, 2.2],
            id2label=dict(enumerate(label_names)),
            threshold=0.89,
            top_k=3,
            max_labels=2,
        )
        known = self._cabin_intents(
            ("cabin_window_open", "plugin.window"),
            ("cabin_play_music", "plugin.music"),
            ("cabin_nav_to", "plugin.nav"),
        )

        outcome = detector.detect("开窗放歌去公司", known)

        self.assertTrue(outcome.detected)
        self.assertEqual(len(outcome.candidates), 2)
        detected_ids = [candidate.intent_id for candidate in outcome.candidates]
        self.assertEqual(detected_ids, ["cabin_window_open", "cabin_play_music"])

    def test_joint_bert_detector_wraps_shared_runtime(self) -> None:
        """The joint detector reports its backend name and forwards threshold and top_k."""

        class FakeJointNlu:
            """Records each ``predict_multi_intents`` call and returns two fixed candidates."""

            def __init__(self) -> None:
                self.calls: list[dict[str, object]] = []

            def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8):
                # Keep a record of the arguments so the test can inspect them.
                self.calls.append(
                    dict(
                        text=text,
                        threshold=threshold,
                        max_labels=max_labels,
                        top_k=top_k,
                        known_count=len(known_intents),
                    )
                )
                canned = (("cabin_window_open", 0.93), ("cabin_play_music", 0.88))
                return [
                    type("Candidate", (), {"intent_id": intent_id, "score": score})()
                    for intent_id, score in canned
                ]

            def warmup(self, sample_text="") -> bool:
                del sample_text
                return True

        known = self._cabin_intents(
            ("cabin_window_open", "plugin.window"),
            ("cabin_play_music", "plugin.music"),
        )
        fake_nlu = FakeJointNlu()
        detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3)

        outcome = detector.detect("打开车窗并播放音乐", known)

        self.assertTrue(outcome.detected)
        self.assertEqual(outcome.backend_name, "joint-bert-multi-label")
        detected_ids = [candidate.intent_id for candidate in outcome.candidates]
        self.assertEqual(detected_ids, ["cabin_window_open", "cabin_play_music"])
        first_call = fake_nlu.calls[0]
        self.assertEqual(first_call["threshold"], 0.5)
        self.assertEqual(first_call["top_k"], 6)
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|