Update project and configurations
This commit is contained in:
219
intelligent_cabin/app/services/multi_intent_detector.py
Normal file
219
intelligent_cabin/app/services/multi_intent_detector.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Any, Protocol
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
@dataclass
class MultiIntentCandidate:
    """A single intent proposed by a multi-intent detector.

    Attributes:
        intent_id: Identifier of the proposed intent.
        score: Score the detector backend assigned to this intent.
        label: Optional label reported by the backend; ``None`` when the
            backend did not supply one.
    """

    intent_id: str
    score: float
    label: str | None = field(default=None)
|
||||
|
||||
|
||||
@dataclass
class MultiIntentDetectionResult:
    """Outcome of one multi-intent detection call.

    Every field has a default, so ``MultiIntentDetectionResult()`` is a valid
    "nothing detected" result. The two list fields use ``default_factory`` so
    each instance gets its own list.

    Attributes:
        detected: Whether the detector reports a multi-intent utterance.
        candidates: Intents the detector kept.
        reason: Free-form explanation of how the result was produced.
        backend_name: Name of the backend that produced the result.
        error_message: Error text when the backend failed, otherwise ``None``.
        raw_scores: Score records as plain dictionaries.
    """

    # Verdict and the intents supporting it.
    detected: bool = False
    candidates: list[MultiIntentCandidate] = field(default_factory=list)

    # Diagnostics describing how (or why not) the verdict was reached.
    reason: str | None = None
    backend_name: str = "none"
    error_message: str | None = None

    # Score records exposed for inspection.
    raw_scores: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class MultiIntentDetector(Protocol):
    """Structural interface for multi-intent detectors.

    Any object exposing a compatible ``detect`` method satisfies this
    protocol; no inheritance is required.
    """

    def detect(
        self,
        text: str,
        intents: list[IntentDefinition],
    ) -> MultiIntentDetectionResult:
        """Detect the intents expressed in ``text``.

        Args:
            text: The utterance to analyse.
            intents: Intent definitions the detector may choose from.

        Returns:
            The detection result for ``text``.
        """
        ...
|
||||
|
||||
|
||||
class JointBertMultiIntentDetector:
    """
    A multi-intent detector backed by the same Joint BERT runtime as single-intent and slot extraction.

    All model work is delegated to the injected ``JointBertNLU`` object; this
    class only forwards its configuration and repackages the predictions.
    """

    def __init__(
        self,
        nlu: JointBertNLU,
        threshold: float | None = None,
        top_k: int = 8,
        max_labels: int = 4,
    ) -> None:
        """Store the NLU runtime and the options forwarded on every call.

        Args:
            nlu: Runtime providing ``predict_multi_intents`` and ``warmup``.
            threshold: Passed through as ``threshold``; ``None`` is forwarded
                unchanged.
            top_k: Passed through as ``top_k``.
            max_labels: Passed through as ``max_labels``.
        """
        self._nlu = nlu
        self._threshold = threshold
        self._top_k = top_k
        self._max_labels = max_labels

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """Ask the NLU runtime for multi-intent predictions and wrap them.

        Args:
            text: The utterance to analyse.
            intents: Intent definitions handed to the runtime.

        Returns:
            A result whose ``detected`` flag is true when the runtime returned
            at least two predictions.
        """
        predictions = self._nlu.predict_multi_intents(
            text,
            intents,
            threshold=self._threshold,
            max_labels=self._max_labels,
            top_k=self._top_k,
        )
        total = len(predictions)

        picked: list[MultiIntentCandidate] = []
        for prediction in predictions:
            picked.append(
                MultiIntentCandidate(
                    intent_id=prediction.intent_id,
                    score=prediction.score,
                    label=prediction.intent_id,
                )
            )

        score_rows: list[dict[str, Any]] = []
        for prediction in predictions:
            score_rows.append(
                {
                    "intent_id": prediction.intent_id,
                    "label": prediction.intent_id,
                    "score": float(prediction.score),
                }
            )

        return MultiIntentDetectionResult(
            detected=total >= 2,
            candidates=picked,
            reason=f"joint bert multi-label candidates={total} threshold={self._threshold}",
            backend_name="joint-bert-multi-label",
            raw_scores=score_rows,
        )

    def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
        """Warm up the underlying runtime and return whatever it reports."""
        outcome = self._nlu.warmup(sample_text)
        return outcome
|
||||
|
||||
|
||||
class BertMultiIntentDetector:
    """
    A stage-2 multi-intent detector backed by a dedicated multi-label BERT head.

    The tokenizer and model are loaded lazily from ``model_path`` through
    ``transformers`` the first time they are needed. Load and forward-pass
    failures are reported through the returned result (or ``warmup``'s return
    value) and remembered in ``_error_message`` instead of being raised.
    """

    _BACKEND_NAME = "bert-multi-label"
    _DEFAULT_BLOCKED_LABELS = frozenset({"__social__", "__out_of_scope__"})

    def __init__(
        self,
        model_path: str,
        threshold: float = 0.45,
        top_k: int = 8,
        max_labels: int = 4,
        blocked_labels: set[str] | None = None,
    ) -> None:
        """Configure the detector; no model is loaded here.

        Args:
            model_path: Directory handed to ``from_pretrained``.
            threshold: Minimum sigmoid score a label needs to be kept.
            top_k: Number of highest-scoring labels considered and reported
                in ``raw_scores``.
            max_labels: Selection stops once this many candidates are kept.
            blocked_labels: Labels that are never returned as candidates.
                ``None`` selects the defaults; an explicit empty set blocks
                nothing.
        """
        self._model_path = model_path
        self._threshold = threshold
        self._top_k = top_k
        self._max_labels = max_labels
        # Compare against None (not truthiness) so that an explicit empty set
        # is honoured, and copy so later caller-side mutation cannot leak in.
        self._blocked_labels = (
            set(self._DEFAULT_BLOCKED_LABELS) if blocked_labels is None else set(blocked_labels)
        )
        self._tokenizer = None
        self._model = None
        self._torch = None
        self._error_message: str | None = None
        self._warmed_up = False
        self._warmup_elapsed_ms: float | None = None

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """Score ``text`` against every model label and keep the known intents.

        Args:
            text: The utterance to analyse.
            intents: Intent definitions; only labels matching one of their
                ``intent_id`` values can become candidates.

        Returns:
            A result whose ``detected`` flag is true when at least two
            candidates were kept. When the runtime cannot be loaded or the
            forward pass fails, ``detected`` is false and ``error_message``
            carries the failure text.
        """
        runtime = self._get_runtime()
        if runtime is None:
            return MultiIntentDetectionResult(
                detected=False,
                reason="multi-label detector unavailable",
                backend_name=self._BACKEND_NAME,
                error_message=self._error_message,
            )
        torch, _, model = runtime
        known_ids = {intent.intent_id for intent in intents}
        try:
            # Drop the batch dimension, then score each label independently.
            logits = self._forward(runtime, text).squeeze(0)
            probs = torch.sigmoid(logits).detach().cpu().tolist()
        except Exception as exc:
            return MultiIntentDetectionResult(
                detected=False,
                reason="multi-label forward failed",
                backend_name=self._BACKEND_NAME,
                error_message=str(exc),
            )

        id2label = getattr(model.config, "id2label", {}) or {}
        raw_top = self._rank_scores(probs, id2label, self._top_k)
        candidates = self._select_candidates(raw_top, known_ids)

        return MultiIntentDetectionResult(
            detected=len(candidates) >= 2,
            candidates=candidates,
            reason=f"bert multi-label candidates={len(candidates)} threshold={self._threshold}",
            backend_name=self._BACKEND_NAME,
            raw_scores=raw_top,
        )

    def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
        """Load the runtime and run one forward pass on ``sample_text``.

        The elapsed time of the attempt, in milliseconds, is stored in
        ``_warmup_elapsed_ms`` whether or not it succeeds. After one
        successful warm-up, later calls return ``True`` immediately.

        Returns:
            ``True`` when the detector is warmed up, ``False`` when loading
            or the forward pass failed.
        """
        if self._warmed_up:
            return True
        started_at = perf_counter()
        try:
            runtime = self._get_runtime()
            if runtime is None:
                return False
            try:
                self._forward(runtime, sample_text)
            except Exception as exc:
                self._error_message = str(exc)
                return False
            self._warmed_up = True
            return True
        finally:
            # Single place that records the timing for every exit path.
            self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)

    @staticmethod
    def _forward(runtime: tuple[Any, Any, Any], text: str) -> Any:
        """Tokenize ``text`` and return the model's logits without gradients."""
        torch, tokenizer, model = runtime
        encoded = tokenizer(
            text,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        model.eval()
        with torch.no_grad():
            return model(**encoded).logits

    @staticmethod
    def _rank_scores(probs: list[float], id2label: dict[Any, Any], top_k: int) -> list[dict[str, Any]]:
        """Label each score, sort by score descending and keep the first ``top_k``."""
        rows: list[dict[str, Any]] = []
        for index, score in enumerate(probs):
            # Fall back to a positional name when the config has no label.
            label = str(id2label.get(index, f"LABEL_{index}"))
            rows.append({"label": label, "intent_id": label, "score": float(score)})
        rows.sort(key=lambda item: item["score"], reverse=True)
        return rows[:top_k]

    def _select_candidates(self, ranked: list[dict[str, Any]], known_ids: set[str]) -> list[MultiIntentCandidate]:
        """Filter ranked score rows down to allowed, known, above-threshold intents."""
        candidates: list[MultiIntentCandidate] = []
        for item in ranked:
            intent_id = item["intent_id"]
            if intent_id in self._blocked_labels or intent_id not in known_ids:
                continue
            score = item["score"]
            if score < self._threshold:
                continue
            candidates.append(
                MultiIntentCandidate(intent_id=intent_id, score=score, label=item["label"])
            )
            if len(candidates) >= self._max_labels:
                break
        return candidates

    def _get_runtime(self) -> tuple[Any, Any, Any] | None:
        """Return ``(torch, tokenizer, model)``, loading them on first use.

        Returns:
            The cached or freshly loaded runtime, or ``None`` when the model
            path is empty or missing, the libraries cannot be imported, or
            loading fails. The reason is stored in ``_error_message``.
        """
        if self._tokenizer is not None and self._model is not None and self._torch is not None:
            return self._torch, self._tokenizer, self._model
        if not self._model_path or not Path(self._model_path).exists():
            self._error_message = "multi-intent model path is empty or missing"
            return None
        try:
            transformers = importlib.import_module("transformers")
            torch = importlib.import_module("torch")
        except ImportError as exc:
            self._error_message = str(exc)
            return None
        try:
            tokenizer = transformers.AutoTokenizer.from_pretrained(self._model_path)
            model = transformers.AutoModelForSequenceClassification.from_pretrained(self._model_path)
        except Exception as exc:
            self._error_message = str(exc)
            return None
        # Publish all three together so a failed model load never leaves a
        # tokenizer cached without its model.
        self._torch, self._tokenizer, self._model = torch, tokenizer, model
        return torch, tokenizer, model
|
||||
|
||||
|
||||
# Second name for BertMultiIntentDetector (which scores labels with a sigmoid);
# both names refer to the same class object.
# NOTE(review): presumably kept so existing imports of this name keep working — confirm.
SigmoidBertMultiIntentDetector = BertMultiIntentDetector
|
||||
Reference in New Issue
Block a user