220 lines
7.6 KiB
Python
220 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from time import perf_counter
|
|
from typing import Any, Protocol
|
|
|
|
from app.schemas.intent import IntentDefinition
|
|
from app.services.joint_nlu import JointBertNLU
|
|
|
|
|
|
@dataclass
class MultiIntentCandidate:
    """A single intent proposed by a multi-intent detector.

    Fields are declared in the order ``intent_id``, ``score``, ``label`` so that
    positional construction stays valid.
    """

    # Identifier of the proposed intent.
    intent_id: str
    # Score the detector backend assigned to this intent.
    score: float
    # Backend label the candidate was derived from, if any.
    label: str | None = None
|
|
|
|
|
|
@dataclass
class MultiIntentDetectionResult:
    """Outcome of one multi-intent detection call.

    Every field has a default, so ``MultiIntentDetectionResult()`` is an empty,
    non-detected result attributed to the backend name ``"none"``.
    """

    # Whether the text was judged to contain multiple intents.
    detected: bool = False
    # Selected candidates; a fresh list per instance (never shared).
    candidates: list[MultiIntentCandidate] = field(default_factory=list)
    # Human-readable explanation of how the result was produced.
    reason: str | None = None
    # Name of the backend that produced the result.
    backend_name: str = "none"
    # Error text when the backend could not run.
    error_message: str | None = None
    # Backend-specific score records (dicts), kept for diagnostics.
    raw_scores: list[dict[str, Any]] = field(default_factory=list)
|
|
|
|
|
|
class MultiIntentDetector(Protocol):
    """Structural interface implemented by every multi-intent detection backend."""

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """Score ``text`` against the available ``intents`` and describe the outcome."""
|
|
|
|
|
|
class JointBertMultiIntentDetector:
    """Multi-intent detector backed by the shared Joint BERT runtime.

    Scoring is delegated to ``JointBertNLU.predict_multi_intents`` (the same
    runtime used for single-intent classification and slot extraction). This
    class only adapts that output into a ``MultiIntentDetectionResult``.
    """

    def __init__(
        self,
        nlu: JointBertNLU,
        threshold: float | None = None,
        top_k: int = 8,
        max_labels: int = 4,
    ) -> None:
        """Store the NLU runtime and the selection knobs forwarded to it.

        Args:
            nlu: Joint BERT runtime performing the multi-label prediction.
            threshold: Forwarded as ``threshold``; ``None`` is passed through
                unchanged (presumably selecting the runtime's own default —
                confirm in ``JointBertNLU``).
            top_k: Forwarded as ``top_k``.
            max_labels: Forwarded as ``max_labels``.
        """
        self._nlu = nlu
        self._threshold = threshold
        self._top_k = top_k
        self._max_labels = max_labels

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """Run multi-label prediction on ``text`` and wrap the predictions.

        Args:
            text: Utterance to analyse.
            intents: Intent definitions forwarded to the runtime.

        Returns:
            A result whose ``detected`` flag is true when at least two
            predictions come back. Candidates and ``raw_scores`` carry the same
            intents, with scores coerced to builtin ``float``.
        """
        # Materialise once: works for any iterable the runtime returns (not
        # only sized lists), and both output views share one float-coerced
        # score. Previously candidates kept the raw score object while
        # raw_scores used float(), so the two views could disagree in type.
        scored = [
            (item.intent_id, float(item.score))
            for item in self._nlu.predict_multi_intents(
                text,
                intents,
                threshold=self._threshold,
                max_labels=self._max_labels,
                top_k=self._top_k,
            )
        ]
        return MultiIntentDetectionResult(
            detected=len(scored) >= 2,
            candidates=[
                MultiIntentCandidate(intent_id=intent_id, score=score, label=intent_id)
                for intent_id, score in scored
            ],
            reason=f"joint bert multi-label candidates={len(scored)} threshold={self._threshold}",
            backend_name="joint-bert-multi-label",
            raw_scores=[
                {"intent_id": intent_id, "label": intent_id, "score": score}
                for intent_id, score in scored
            ],
        )

    def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
        """Warm up the shared runtime with ``sample_text``.

        The default sample is a two-action command ("turn on the AC and open
        the window"). Returns whatever ``JointBertNLU.warmup`` returns.
        """
        return self._nlu.warmup(sample_text)
|
|
|
|
|
|
class BertMultiIntentDetector:
    """Stage-2 multi-intent detector backed by a dedicated multi-label BERT head.

    The tokenizer and model are loaded lazily from ``model_path`` via
    ``transformers.AutoTokenizer`` / ``AutoModelForSequenceClassification`` on
    first use. Each label is scored independently with ``torch.sigmoid``, so
    several intents can clear the threshold at the same time.
    """

    # Labels never reported as candidates unless the caller overrides them.
    _DEFAULT_BLOCKED_LABELS = frozenset({"__social__", "__out_of_scope__"})

    def __init__(
        self,
        model_path: str,
        threshold: float = 0.45,
        top_k: int = 8,
        max_labels: int = 4,
        blocked_labels: set[str] | None = None,
    ) -> None:
        """Configure the detector; no model is loaded here.

        Args:
            model_path: Local directory of the fine-tuned model; it must exist.
            threshold: Minimum sigmoid probability for a label to be a candidate.
            top_k: Number of highest-scoring labels considered (and reported in
                ``raw_scores``).
            max_labels: Maximum number of candidates returned.
            blocked_labels: Labels never returned as candidates. ``None``
                selects the defaults; an explicit (even empty) collection is
                honoured as given.
        """
        self._model_path = model_path
        self._threshold = threshold
        self._top_k = top_k
        self._max_labels = max_labels
        # `is None` rather than `or`: an explicit empty set means "block nothing".
        # Copy so later mutation of the caller's set cannot change this detector.
        self._blocked_labels = set(
            self._DEFAULT_BLOCKED_LABELS if blocked_labels is None else blocked_labels
        )
        self._tokenizer = None
        self._model = None
        self._torch = None
        self._error_message: str | None = None
        self._warmed_up = False
        self._warmup_elapsed_ms: float | None = None

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """Score ``text`` with the multi-label head and pick candidate intents.

        Returns a non-detected result (with ``error_message``) when the runtime
        cannot be loaded or the forward pass fails; otherwise ``detected`` is
        true when at least two candidates pass all filters.
        """
        runtime = self._get_runtime()
        if runtime is None:
            return MultiIntentDetectionResult(
                detected=False,
                reason="multi-label detector unavailable",
                backend_name="bert-multi-label",
                error_message=self._error_message,
            )
        torch, _tokenizer, model = runtime
        known_ids = {intent.intent_id for intent in intents}
        try:
            logits = self._forward(runtime, text).squeeze(0)
            probs = torch.sigmoid(logits).detach().cpu().tolist()
        except Exception as exc:  # inference boundary: report instead of raising
            return MultiIntentDetectionResult(
                detected=False,
                reason="multi-label forward failed",
                backend_name="bert-multi-label",
                error_message=str(exc),
            )

        raw_top = self._rank_labels(model, probs)
        candidates = self._select_candidates(raw_top, known_ids)
        return MultiIntentDetectionResult(
            detected=len(candidates) >= 2,
            candidates=candidates,
            reason=f"bert multi-label candidates={len(candidates)} threshold={self._threshold}",
            backend_name="bert-multi-label",
            raw_scores=raw_top,
        )

    def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
        """Load the runtime and run one forward pass on ``sample_text``.

        The default sample is a two-action command ("turn on the AC and open
        the window"). Returns True once warmed up; later calls short-circuit.
        ``_warmup_elapsed_ms`` records the duration of every attempt.
        """
        if self._warmed_up:
            return True
        started_at = perf_counter()
        try:
            runtime = self._get_runtime()
            if runtime is None:
                return False
            try:
                self._forward(runtime, sample_text)
            except Exception as exc:
                self._error_message = str(exc)
                return False
            self._warmed_up = True
            return True
        finally:
            self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)

    @staticmethod
    def _forward(runtime, text: str):
        """Tokenize ``text`` and return the model's raw logits without gradient tracking."""
        torch, tokenizer, model = runtime
        encoded = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
        model.eval()
        with torch.no_grad():
            return model(**encoded).logits

    def _rank_labels(self, model, probs: list[float]) -> list[dict[str, Any]]:
        """Pair each probability with its label name and keep the ``top_k`` highest."""
        id2label = getattr(model.config, "id2label", {}) or {}
        raw_scores: list[dict[str, Any]] = []
        for index, score in enumerate(probs):
            # Accept int keys (the usual case) and str keys (raw JSON-style configs).
            label = str(id2label.get(index, id2label.get(str(index), f"LABEL_{index}")))
            raw_scores.append({"label": label, "intent_id": label, "score": float(score)})
        raw_scores.sort(key=lambda item: item["score"], reverse=True)
        return raw_scores[: self._top_k]

    def _select_candidates(
        self, ranked: list[dict[str, Any]], known_ids: set[str]
    ) -> list[MultiIntentCandidate]:
        """Keep known, unblocked labels above threshold, capped at ``max_labels``."""
        limit = max(self._max_labels, 0)
        candidates: list[MultiIntentCandidate] = []
        for item in ranked:
            # Check the cap before appending so max_labels <= 0 yields nothing.
            if len(candidates) >= limit:
                break
            intent_id = str(item.get("intent_id") or "")
            if intent_id in self._blocked_labels or intent_id not in known_ids:
                continue
            score = float(item.get("score", 0.0))
            if score < self._threshold:
                continue
            candidates.append(
                MultiIntentCandidate(
                    intent_id=intent_id,
                    score=score,
                    label=str(item.get("label") or intent_id),
                )
            )
        return candidates

    def _get_runtime(self):
        """Return the cached ``(torch, tokenizer, model)`` triple, loading it on first use.

        Returns None and records ``_error_message`` in three cases: the model
        path is empty or missing, ``transformers``/``torch`` cannot be
        imported, or loading raises. Failures are not cached, so the next call
        retries the load.
        """
        if self._tokenizer is not None and self._model is not None and self._torch is not None:
            return self._torch, self._tokenizer, self._model
        if not self._model_path or not Path(self._model_path).exists():
            self._error_message = "multi-intent model path is empty or missing"
            return None
        try:
            transformers = importlib.import_module("transformers")
            torch = importlib.import_module("torch")
        except ImportError as exc:
            self._error_message = str(exc)
            return None
        try:
            tokenizer = transformers.AutoTokenizer.from_pretrained(self._model_path)
            model = transformers.AutoModelForSequenceClassification.from_pretrained(self._model_path)
        except Exception as exc:
            self._error_message = str(exc)
            return None
        # Publish all three together so a half-loaded runtime is never cached.
        self._tokenizer, self._model, self._torch = tokenizer, model, torch
        return self._torch, self._tokenizer, self._model
|
|
|
|
|
|
# Alternative name for BertMultiIntentDetector, whose detect() scores each
# label independently with torch.sigmoid (multi-label rather than softmax).
SigmoidBertMultiIntentDetector = BertMultiIntentDetector
|