from __future__ import annotations import importlib from dataclasses import dataclass, field from pathlib import Path from time import perf_counter from typing import Any, Protocol from app.schemas.intent import IntentDefinition from app.services.joint_nlu import JointBertNLU @dataclass class MultiIntentCandidate: intent_id: str score: float label: str | None = None @dataclass class MultiIntentDetectionResult: detected: bool = False candidates: list[MultiIntentCandidate] = field(default_factory=list) reason: str | None = None backend_name: str = "none" error_message: str | None = None raw_scores: list[dict[str, Any]] = field(default_factory=list) class MultiIntentDetector(Protocol): def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult: ... class JointBertMultiIntentDetector: """ A multi-intent detector backed by the same Joint BERT runtime as single-intent and slot extraction. """ def __init__( self, nlu: JointBertNLU, threshold: float | None = None, top_k: int = 8, max_labels: int = 4, ) -> None: self._nlu = nlu self._threshold = threshold self._top_k = top_k self._max_labels = max_labels def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult: candidates = self._nlu.predict_multi_intents( text, intents, threshold=self._threshold, max_labels=self._max_labels, top_k=self._top_k, ) return MultiIntentDetectionResult( detected=len(candidates) >= 2, candidates=[ MultiIntentCandidate(intent_id=item.intent_id, score=item.score, label=item.intent_id) for item in candidates ], reason=f"joint bert multi-label candidates={len(candidates)} threshold={self._threshold}", backend_name="joint-bert-multi-label", raw_scores=[ {"intent_id": item.intent_id, "label": item.intent_id, "score": float(item.score)} for item in candidates ], ) def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool: return self._nlu.warmup(sample_text) class BertMultiIntentDetector: """ A stage-2 multi-intent detector backed by a dedicated multi-label BERT head. """ def __init__( self, model_path: str, threshold: float = 0.45, top_k: int = 8, max_labels: int = 4, blocked_labels: set[str] | None = None, ) -> None: self._model_path = model_path self._threshold = threshold self._top_k = top_k self._max_labels = max_labels self._blocked_labels = blocked_labels or {"__social__", "__out_of_scope__"} self._tokenizer = None self._model = None self._torch = None self._error_message: str | None = None self._warmed_up = False self._warmup_elapsed_ms: float | None = None def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult: runtime = self._get_runtime() if runtime is None: return MultiIntentDetectionResult( detected=False, reason="multi-label detector unavailable", backend_name="bert-multi-label", error_message=self._error_message, ) torch, tokenizer, model = runtime intent_map = {intent.intent_id: intent for intent in intents} try: encoded = tokenizer( text, truncation=True, padding=False, return_tensors="pt", ) model.eval() with torch.no_grad(): logits = model(**encoded).logits.squeeze(0) probs = torch.sigmoid(logits).detach().cpu().tolist() except Exception as exc: return MultiIntentDetectionResult( detected=False, reason="multi-label forward failed", backend_name="bert-multi-label", error_message=str(exc), ) id2label = getattr(model.config, "id2label", {}) or {} raw_scores: list[dict[str, Any]] = [] for index, score in enumerate(probs): label = str(id2label.get(index, f"LABEL_{index}")) raw_scores.append( { "label": label, "intent_id": label, "score": float(score), } ) raw_scores.sort(key=lambda item: item["score"], reverse=True) raw_top = raw_scores[: self._top_k] candidates: list[MultiIntentCandidate] = [] for item in raw_top: intent_id = str(item.get("intent_id") or "") if intent_id in self._blocked_labels: continue if intent_id not in intent_map: continue score = float(item.get("score", 0.0)) if score < self._threshold: continue candidates.append( MultiIntentCandidate( intent_id=intent_id, score=score, label=str(item.get("label") or intent_id), ) ) if len(candidates) >= self._max_labels: break return MultiIntentDetectionResult( detected=len(candidates) >= 2, candidates=candidates, reason=f"bert multi-label candidates={len(candidates)} threshold={self._threshold}", backend_name="bert-multi-label", raw_scores=raw_top, ) def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool: if self._warmed_up: return True started_at = perf_counter() runtime = self._get_runtime() if runtime is None: self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) return False torch, tokenizer, model = runtime try: encoded = tokenizer(sample_text, truncation=True, padding=False, return_tensors="pt") model.eval() with torch.no_grad(): _ = model(**encoded).logits except Exception as exc: self._error_message = str(exc) self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) return False self._warmed_up = True self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) return True def _get_runtime(self): if self._tokenizer is not None and self._model is not None and self._torch is not None: return self._torch, self._tokenizer, self._model if not self._model_path or not Path(self._model_path).exists(): self._error_message = "multi-intent model path is empty or missing" return None try: transformers = importlib.import_module("transformers") torch = importlib.import_module("torch") except ImportError as exc: self._error_message = str(exc) return None try: self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._model_path) self._model = transformers.AutoModelForSequenceClassification.from_pretrained(self._model_path) self._torch = torch except Exception as exc: self._error_message = str(exc) return None return self._torch, self._tokenizer, self._model SigmoidBertMultiIntentDetector = BertMultiIntentDetector