Update project and configurations

This commit is contained in:
Zou-Seay
2026-06-11 16:28:00 +08:00
parent 12d3922091
commit a29a91867d
237 changed files with 164880 additions and 90 deletions

View File

@@ -0,0 +1,219 @@
from __future__ import annotations
import importlib
from dataclasses import dataclass, field
from pathlib import Path
from time import perf_counter
from typing import Any, Protocol
from app.schemas.intent import IntentDefinition
from app.services.joint_nlu import JointBertNLU
@dataclass
class MultiIntentCandidate:
    """A single intent proposed by a multi-intent detector, together with its score."""

    # Identifier of the proposed intent.
    intent_id: str
    # Score the detector assigned to this intent.
    score: float
    # Optional label string reported by the backend; None when not provided.
    label: str | None = None
@dataclass
class MultiIntentDetectionResult:
    """Outcome of one multi-intent detection call, including diagnostics."""

    # Whether the detector flagged the input as multi-intent.
    detected: bool = False
    # Candidates that were accepted by the detector (fresh list per instance).
    candidates: list[MultiIntentCandidate] = field(default_factory=list)
    # Free-form, human-readable explanation of the outcome.
    reason: str | None = None
    # Name of the backend that produced this result.
    backend_name: str = "none"
    # Error text when the backend could not produce scores.
    error_message: str | None = None
    # Per-label score dictionaries kept for diagnostics (fresh list per instance).
    raw_scores: list[dict[str, Any]] = field(default_factory=list)
class MultiIntentDetector(Protocol):
    """Structural interface implemented by the multi-intent detector backends."""

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """Score ``text`` against ``intents`` and return the detection result."""
class JointBertMultiIntentDetector:
    """
    Multi-intent detector that reuses the Joint BERT runtime also used for
    single-intent classification and slot extraction.

    All scoring is delegated to the wrapped ``JointBertNLU`` instance; this
    class only forwards its configuration and repackages the predictions.
    """

    def __init__(
        self,
        nlu: JointBertNLU,
        threshold: float | None = None,
        top_k: int = 8,
        max_labels: int = 4,
    ) -> None:
        """
        Store the NLU runtime and the options forwarded on every ``detect`` call.

        Args:
            nlu: Runtime providing ``predict_multi_intents`` and ``warmup``.
            threshold: Forwarded unchanged as ``threshold``; may be ``None``.
            top_k: Forwarded unchanged as ``top_k``.
            max_labels: Forwarded unchanged as ``max_labels``.
        """
        self._nlu = nlu
        self._threshold = threshold
        self._top_k = top_k
        self._max_labels = max_labels

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """
        Ask the NLU runtime for multi-intent predictions and wrap them.

        The result is flagged as detected when at least two predictions
        come back.
        """
        predictions = self._nlu.predict_multi_intents(
            text,
            intents,
            threshold=self._threshold,
            max_labels=self._max_labels,
            top_k=self._top_k,
        )
        hit_count = len(predictions)

        # The intent id doubles as the label for this backend.
        accepted = [
            MultiIntentCandidate(
                intent_id=prediction.intent_id,
                score=prediction.score,
                label=prediction.intent_id,
            )
            for prediction in predictions
        ]
        score_rows = [
            {
                "intent_id": prediction.intent_id,
                "label": prediction.intent_id,
                "score": float(prediction.score),
            }
            for prediction in predictions
        ]

        return MultiIntentDetectionResult(
            detected=hit_count >= 2,
            candidates=accepted,
            reason=f"joint bert multi-label candidates={hit_count} threshold={self._threshold}",
            backend_name="joint-bert-multi-label",
            raw_scores=score_rows,
        )

    def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
        """Forward ``sample_text`` to the NLU runtime's ``warmup`` and return its result."""
        return self._nlu.warmup(sample_text)
class BertMultiIntentDetector:
    """
    A stage-2 multi-intent detector backed by a dedicated multi-label BERT head.

    The tokenizer and model are loaded lazily from ``model_path`` through
    ``transformers`` the first time they are needed. Every label's logit is
    passed through a sigmoid, the labels are ranked by score, and the ones
    that are known, not blocked and at or above ``threshold`` become
    candidates. Failures (missing model, missing packages, forward errors)
    are reported through the returned result instead of being raised.
    """

    _BACKEND_NAME = "bert-multi-label"
    _DEFAULT_BLOCKED_LABELS = frozenset({"__social__", "__out_of_scope__"})

    def __init__(
        self,
        model_path: str,
        threshold: float = 0.45,
        top_k: int = 8,
        max_labels: int = 4,
        blocked_labels: set[str] | None = None,
    ) -> None:
        """
        Store the configuration; nothing is loaded until first use.

        Args:
            model_path: Directory or file handed to ``from_pretrained``.
            threshold: Minimum sigmoid score for a label to become a candidate.
            top_k: Number of highest-scoring labels kept for inspection.
            max_labels: Maximum number of candidates returned.
            blocked_labels: Labels that are never returned as candidates.
                ``None`` selects the default blocked labels; an explicitly
                empty set blocks nothing.
        """
        self._model_path = model_path
        self._threshold = threshold
        self._top_k = top_k
        self._max_labels = max_labels
        # Only ``None`` selects the defaults, so an empty set is honoured.
        # Copy so later mutation of the caller's set cannot change filtering.
        self._blocked_labels = (
            set(self._DEFAULT_BLOCKED_LABELS) if blocked_labels is None else set(blocked_labels)
        )
        self._tokenizer = None
        self._model = None
        self._torch = None
        self._error_message: str | None = None
        self._warmed_up = False
        self._warmup_elapsed_ms: float | None = None

    def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
        """
        Score ``text`` against the model's labels and return the accepted intents.

        Returns:
            A result whose ``detected`` flag is true when at least two
            candidates were accepted. ``raw_scores`` holds the ``top_k``
            highest-scoring labels before any filtering. When the runtime is
            unavailable or the forward pass fails, an undetected result
            carrying the error message is returned.
        """
        runtime = self._get_runtime()
        if runtime is None:
            return MultiIntentDetectionResult(
                detected=False,
                reason="multi-label detector unavailable",
                backend_name=self._BACKEND_NAME,
                error_message=self._error_message,
            )
        torch, _tokenizer, model = runtime
        known_intent_ids = {intent.intent_id for intent in intents}

        try:
            logits = self._forward_logits(runtime, text).squeeze(0)
            probs = torch.sigmoid(logits).detach().cpu().tolist()
        except Exception as exc:
            return MultiIntentDetectionResult(
                detected=False,
                reason="multi-label forward failed",
                backend_name=self._BACKEND_NAME,
                error_message=str(exc),
            )

        id2label = getattr(model.config, "id2label", {}) or {}
        raw_scores: list[dict[str, Any]] = []
        for index, score in enumerate(probs):
            label = str(id2label.get(index, f"LABEL_{index}"))
            raw_scores.append({"label": label, "intent_id": label, "score": float(score)})
        raw_scores.sort(key=lambda item: item["score"], reverse=True)
        # Clamp so a negative top_k cannot slice from the wrong end of the list.
        raw_top = raw_scores[: max(self._top_k, 0)]

        candidates: list[MultiIntentCandidate] = []
        for item in raw_top:
            # Checked before appending so the cap also holds for max_labels <= 0.
            if len(candidates) >= self._max_labels:
                break
            intent_id = item["intent_id"]
            if intent_id in self._blocked_labels or intent_id not in known_intent_ids:
                continue
            if item["score"] < self._threshold:
                continue
            candidates.append(
                MultiIntentCandidate(intent_id=intent_id, score=item["score"], label=item["label"])
            )

        return MultiIntentDetectionResult(
            detected=len(candidates) >= 2,
            candidates=candidates,
            reason=f"bert multi-label candidates={len(candidates)} threshold={self._threshold}",
            backend_name=self._BACKEND_NAME,
            raw_scores=raw_top,
        )

    def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
        """
        Load the runtime and run one forward pass over ``sample_text``.

        Returns:
            ``True`` once a warm-up pass has succeeded (later calls return
            immediately), ``False`` when the runtime is unavailable or the
            pass raised. The elapsed time in milliseconds is recorded on the
            instance in every case where work was attempted.
        """
        if self._warmed_up:
            return True
        started_at = perf_counter()
        try:
            runtime = self._get_runtime()
            if runtime is None:
                return False
            try:
                self._forward_logits(runtime, sample_text)
            except Exception as exc:
                self._error_message = str(exc)
                return False
            self._warmed_up = True
            return True
        finally:
            # Record the duration on every exit path, success or failure.
            self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)

    def _forward_logits(self, runtime: tuple[Any, Any, Any], text: str) -> Any:
        """Tokenize ``text`` and return the model's logits from a no-grad forward pass."""
        torch, tokenizer, model = runtime
        encoded = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
        model.eval()
        with torch.no_grad():
            return model(**encoded).logits

    def _get_runtime(self) -> tuple[Any, Any, Any] | None:
        """
        Return ``(torch, tokenizer, model)``, loading them on first use.

        Returns ``None`` and stores a message in ``_error_message`` when the
        model path is empty or missing, when ``transformers`` or ``torch``
        cannot be imported, or when loading the tokenizer or model fails.
        """
        if self._tokenizer is not None and self._model is not None and self._torch is not None:
            return self._torch, self._tokenizer, self._model
        if not self._model_path or not Path(self._model_path).exists():
            self._error_message = "multi-intent model path is empty or missing"
            return None
        try:
            transformers = importlib.import_module("transformers")
            torch = importlib.import_module("torch")
        except ImportError as exc:
            self._error_message = str(exc)
            return None
        try:
            tokenizer = transformers.AutoTokenizer.from_pretrained(self._model_path)
            model = transformers.AutoModelForSequenceClassification.from_pretrained(self._model_path)
        except Exception as exc:
            self._error_message = str(exc)
            return None
        # Commit all three together so a failed load never leaves a partial cache.
        self._torch, self._tokenizer, self._model = torch, tokenizer, model
        return torch, tokenizer, model
# Alternate name for BertMultiIntentDetector (which scores labels with a sigmoid);
# both names refer to the same class object.
SigmoidBertMultiIntentDetector = BertMultiIntentDetector