from __future__ import annotations import unittest import torch from app.schemas.intent import IntentDefinition from app.services.multi_intent_detector import BertMultiIntentDetector, JointBertMultiIntentDetector class FakeTokenizer: def __call__(self, text, truncation=True, padding=False, return_tensors="pt"): _ = (text, truncation, padding, return_tensors) return { "input_ids": torch.tensor([[101, 102]], dtype=torch.long), "attention_mask": torch.tensor([[1, 1]], dtype=torch.long), } class FakeModel: def __init__(self, logits: list[float], id2label: dict[int, str]) -> None: self.config = type("Config", (), {"id2label": id2label})() self._logits = torch.tensor([logits], dtype=torch.float32) def eval(self) -> None: return None def __call__(self, **kwargs): _ = kwargs return type("Output", (), {"logits": self._logits})() class RuntimeBackedDetector(BertMultiIntentDetector): def __init__( self, logits: list[float], id2label: dict[int, str], threshold: float = 0.45, top_k: int = 8, max_labels: int = 4, ) -> None: super().__init__( model_path="unused", threshold=threshold, top_k=top_k, max_labels=max_labels, ) self._runtime = (torch, FakeTokenizer(), FakeModel(logits, id2label)) def _get_runtime(self): return self._runtime class MultiIntentDetectorTests(unittest.TestCase): def test_detector_filters_blocked_and_unknown_labels(self) -> None: detector = RuntimeBackedDetector( logits=[2.4, 2.0, 3.2, 2.6], id2label={ 0: "cabin_window_open", 1: "cabin_play_music", 2: "__social__", 3: "unknown_intent", }, threshold=0.8, top_k=4, ) intents = [ IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"), IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"), ] result = detector.detect("打开车窗并播放音乐", intents) self.assertTrue(result.detected) self.assertEqual(result.backend_name, "bert-multi-label") self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"]) def test_detector_respects_threshold_and_max_labels(self) -> None: detector = RuntimeBackedDetector( logits=[2.8, 2.5, 2.2], id2label={ 0: "cabin_window_open", 1: "cabin_play_music", 2: "cabin_nav_to", }, threshold=0.89, top_k=3, max_labels=2, ) intents = [ IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"), IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"), IntentDefinition(intent_id="cabin_nav_to", plugin_id="plugin.nav", domain="cabin"), ] result = detector.detect("开窗放歌去公司", intents) self.assertTrue(result.detected) self.assertEqual(len(result.candidates), 2) self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"]) def test_joint_bert_detector_wraps_shared_runtime(self) -> None: intents = [ IntentDefinition(intent_id="cabin_window_open", plugin_id="plugin.window", domain="cabin"), IntentDefinition(intent_id="cabin_play_music", plugin_id="plugin.music", domain="cabin"), ] class FakeJointNlu: def __init__(self) -> None: self.calls: list[dict[str, object]] = [] def predict_multi_intents(self, text, known_intents, threshold=0.45, max_labels=4, top_k=8): self.calls.append( { "text": text, "threshold": threshold, "max_labels": max_labels, "top_k": top_k, "known_count": len(known_intents), } ) return [ type("Candidate", (), {"intent_id": "cabin_window_open", "score": 0.93})(), type("Candidate", (), {"intent_id": "cabin_play_music", "score": 0.88})(), ] def warmup(self, sample_text="") -> bool: _ = sample_text return True fake_nlu = FakeJointNlu() detector = JointBertMultiIntentDetector(fake_nlu, threshold=0.5, top_k=6, max_labels=3) result = detector.detect("打开车窗并播放音乐", intents) self.assertTrue(result.detected) self.assertEqual(result.backend_name, "joint-bert-multi-label") self.assertEqual([item.intent_id for item in result.candidates], ["cabin_window_open", "cabin_play_music"]) self.assertEqual(fake_nlu.calls[0]["threshold"], 0.5) self.assertEqual(fake_nlu.calls[0]["top_k"], 6) if __name__ == "__main__": unittest.main()