Files
2026-06-11 16:28:00 +08:00

431 lines
17 KiB
Python

from __future__ import annotations

import json
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from time import perf_counter
from typing import Any

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

from app.schemas.intent import IntentDefinition
# Extra slot names an intent may carry on top of the required slots passed in
# by the caller, keyed by intent id.
OPTIONAL_SLOT_NAMES_BY_INTENT: dict[str, set[str]] = {
    "cabin_play_music": {"genre", "song"},
}

# Intent labels that are never returned as candidates, even when the caller
# lists them among the known intents.
BLOCKED_INTENT_LABELS = {"__out_of_scope__", "__social__"}
def allowed_slot_names(intent_id: str, required_slots: list[str] | None = None) -> set[str]:
    """Return every slot name that may be reported for ``intent_id``.

    Args:
        intent_id: Identifier used to look up optional slots in
            ``OPTIONAL_SLOT_NAMES_BY_INTENT``.
        required_slots: Slot names required by the intent; ``None`` is
            treated as an empty list.

    Returns:
        A new set holding the required names plus any optional names
        registered for the intent.
    """
    optional = OPTIONAL_SLOT_NAMES_BY_INTENT.get(intent_id, set())
    return set(required_slots or []).union(optional)
@dataclass
class JointSlot:
    """One slot span decoded from the token-level slot predictions."""

    # Slot type, i.e. the part of the tag label after the "B-"/"I-" prefix.
    slot_name: str
    # Surface text of the span, sliced from the input and stripped.
    value: str
    # Character offsets of the span in the input text (start inclusive,
    # end exclusive).
    start: int
    end: int
    # Mean probability of the predicted label over the span's tokens.
    score: float = 0.0
@dataclass
class JointCandidate:
    """An intent label paired with the probability the model assigned to it."""

    # Intent label taken from the model metadata.
    intent_id: str
    # Softmax or sigmoid probability, depending on the model's intent task.
    score: float
@dataclass
class JointNluResult:
    """Outcome of one joint intent/slot prediction.

    Every field has a default, so ``JointNluResult()`` represents
    "nothing recognised".
    """

    # Accepted intent, or None when no candidate was accepted.
    intent_id: str | None = None
    # Score of the best known candidate (0.0 when there is none).
    intent_score: float = 0.0
    # Ranked single-intent candidates.
    candidates: list[JointCandidate] = field(default_factory=list)
    # Candidates selected by the multi-intent threshold.
    multi_intent_candidates: list[JointCandidate] = field(default_factory=list)
    # Slot name -> value mapping derived from ``slot_items``.
    slots: dict[str, Any] = field(default_factory=dict)
    # Slot spans with offsets and scores.
    slot_items: list[JointSlot] = field(default_factory=list)
    model_name: str = "joint-bert-local"
    backend_name: str = "joint-bert-local"
    # Set to the exception text when inference failed.
    error_message: str | None = None
class JointBertForNLU(torch.nn.Module):
    """Transformer encoder with two linear heads: sentence intent and per-token slot tags."""

    def __init__(
        self,
        base_model_name: str,
        num_intents: int,
        num_slot_labels: int,
        encoder_config_path: str | Path | None = None,
    ) -> None:
        """Build the encoder and both classification heads.

        Args:
            base_model_name: Name or path handed to
                ``AutoModel.from_pretrained`` when no local config is given.
            num_intents: Output size of the intent head.
            num_slot_labels: Output size of the slot head.
            encoder_config_path: If set, the encoder is instantiated from
                this local configuration via ``AutoModel.from_config``
                instead of loading ``base_model_name``.
        """
        super().__init__()
        if encoder_config_path is None:
            self.encoder = AutoModel.from_pretrained(base_model_name)
        else:
            local_config = AutoConfig.from_pretrained(encoder_config_path, local_files_only=True)
            self.encoder = AutoModel.from_config(local_config)

        encoder_cfg = self.encoder.config
        width = int(encoder_cfg.hidden_size)
        # Fall back to 0.1 for configs that do not define hidden_dropout_prob.
        self.dropout = torch.nn.Dropout(float(getattr(encoder_cfg, "hidden_dropout_prob", 0.1)))
        self.intent_classifier = torch.nn.Linear(width, num_intents)
        self.slot_classifier = torch.nn.Linear(width, num_slot_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder and return ``(intent_logits, slot_logits)``.

        ``token_type_ids`` is forwarded to the encoder only when provided.
        The intent head reads the hidden state at sequence position 0; the
        slot head reads every position.
        """
        encoder_inputs = dict(input_ids=input_ids, attention_mask=attention_mask)
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids

        hidden_states = self.dropout(self.encoder(**encoder_inputs).last_hidden_state)
        # The first-position vector passes through dropout a second time.
        first_token_state = self.dropout(hidden_states[:, 0])
        return self.intent_classifier(first_token_state), self.slot_classifier(hidden_states)
class JointBertNLU:
    """Lazy-loading inference wrapper around a locally stored joint intent/slot model.

    The model directory must contain ``joint_nlu_config.json`` and
    ``model_state.pt``; both are read on the first prediction. Raw
    predictions are kept in a small LRU cache keyed by the stripped input
    text, so the public methods can share one forward pass per utterance.
    """

    # First integer or decimal number inside a slot value. ``\d`` matches only
    # Unicode decimal digits, which ``int()`` and ``float()`` can always parse.
    _NUMBER_PATTERN = re.compile(r"\d+(?:\.\d+)?")

    def __init__(
        self,
        model_path: str,
        intent_threshold: float | None = None,
        multi_intent_threshold: float | None = None,
        top_k: int = 5,
        max_multi_intents: int = 4,
        max_cache_size: int = 8,
    ) -> None:
        """Store the configuration; nothing is loaded from disk here.

        Args:
            model_path: Directory holding the model files.
            intent_threshold: Minimum score for accepting the top intent.
                ``None`` defers to the model metadata (default 0.35).
            multi_intent_threshold: Minimum score for multi-intent
                selection. ``None`` defers to the model metadata.
            top_k: Maximum number of ranked candidates to keep.
            max_multi_intents: Maximum number of multi-intent labels.
            max_cache_size: Maximum number of cached raw predictions.
        """
        self._model_path = Path(model_path)
        self._intent_threshold = intent_threshold
        self._multi_intent_threshold = multi_intent_threshold
        self._top_k = top_k
        self._max_multi_intents = max_multi_intents
        self._max_cache_size = max_cache_size
        self._runtime: tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device] | None = None
        self._warmup_elapsed_ms: float | None = None
        self._warmup_error_message: str | None = None
        self._warmed_up = False
        self._cache: OrderedDict[str, dict[str, Any]] = OrderedDict()

    def warmup(self, sample_text: str = "把空调调到22度") -> bool:
        """Run one prediction so the model is loaded before real traffic.

        Records the elapsed time in milliseconds and, on failure, the
        exception text.

        Returns:
            ``True`` when the sample prediction succeeded, else ``False``.
        """
        started_at = perf_counter()
        try:
            self._predict_raw(sample_text)
        except Exception as exc:
            self._warmup_error_message = str(exc)
            return False
        else:
            self._warmup_error_message = None
            self._warmed_up = True
            return True
        finally:
            # Recorded for both outcomes.
            self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)

    def predict(self, text: str, intents: list[IntentDefinition]) -> JointNluResult:
        """Predict the intent and slots of ``text``, limited to ``intents``.

        Returns:
            A ``JointNluResult``. ``intent_id`` is ``None`` when no known
            candidate reaches the intent threshold. If inference raises,
            the result carries only ``error_message``.
        """
        try:
            raw_result = self._predict_raw(text)
        except Exception as exc:
            return JointNluResult(error_message=str(exc))

        raw_candidates = raw_result["candidates"]
        candidates = self._filter_known_candidates(raw_candidates, intents, limit=self._top_k)
        # Reuse the raw prediction instead of running inference a second time.
        multi_candidates = self._select_multi_intents(raw_candidates, intents)

        top_candidate = candidates[0] if candidates else None
        rejected = top_candidate is None or top_candidate.score < self._resolved_intent_threshold()
        intent_def = None
        if not rejected:
            intent_def = next(
                (intent for intent in intents if intent.intent_id == top_candidate.intent_id),
                None,
            )
        if intent_def is None:
            return JointNluResult(
                intent_id=None,
                intent_score=top_candidate.score if top_candidate is not None else 0.0,
                candidates=candidates,
                multi_intent_candidates=multi_candidates,
                slots={},
                slot_items=[],
            )

        slot_items = self._filter_slot_items(
            raw_result["slot_items"], intent_def.intent_id, intent_def.required_slots
        )
        return JointNluResult(
            intent_id=top_candidate.intent_id,
            intent_score=top_candidate.score,
            candidates=candidates,
            multi_intent_candidates=multi_candidates,
            slots=self._slot_items_to_dict(slot_items),
            slot_items=slot_items,
        )

    def predict_multi_intents(
        self,
        text: str,
        intents: list[IntentDefinition],
        threshold: float | None = None,
        max_labels: int | None = None,
        top_k: int | None = None,
    ) -> list[JointCandidate]:
        """Return the known intents whose score reaches ``threshold``.

        Args:
            text: Utterance to classify.
            intents: Intent definitions the result is restricted to.
            threshold: Minimum score; ``None`` uses the configured or
                metadata threshold.
            max_labels: Maximum number of results; ``None`` uses the
                configured value. Zero or less yields an empty list.
            top_k: Number of ranked candidates considered; ``None`` or 0
                uses the configured value.

        Returns:
            Selected candidates, best first. Empty when inference fails
            (deliberate best-effort behaviour).
        """
        try:
            raw_result = self._predict_raw(text)
        except Exception:
            return []
        return self._select_multi_intents(
            raw_result["candidates"],
            intents,
            threshold=threshold,
            max_labels=max_labels,
            top_k=top_k,
        )

    def _select_multi_intents(
        self,
        raw_candidates: list[JointCandidate],
        intents: list[IntentDefinition],
        threshold: float | None = None,
        max_labels: int | None = None,
        top_k: int | None = None,
    ) -> list[JointCandidate]:
        """Apply the multi-intent threshold and label cap to ranked candidates."""
        if threshold is None:
            threshold = self._resolved_multi_intent_threshold()
        if max_labels is None:
            max_labels = self._max_multi_intents
        if max_labels <= 0:
            # A non-positive cap means "select nothing".
            return []
        ranked = self._filter_known_candidates(raw_candidates, intents, limit=top_k or self._top_k)
        selected = [item for item in ranked if not (item.score < threshold)]
        return selected[:max_labels]

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slot values of ``text`` that are allowed for ``intent``.

        Returns an empty dict when inference fails.
        """
        return self.extract_slots_by_intent_id(text, intent.intent_id, intent.required_slots)

    def extract_slots_by_intent_id(
        self,
        text: str,
        intent_id: str,
        required_slots: list[str] | None = None,
    ) -> dict[str, Any]:
        """Return the slot values of ``text`` allowed for ``intent_id``.

        Allowed names are ``required_slots`` plus the optional slots
        registered for the intent. Returns an empty dict when inference
        fails.
        """
        try:
            raw_result = self._predict_raw(text)
        except Exception:
            return {}
        slot_items = self._filter_slot_items(raw_result["slot_items"], intent_id, required_slots or [])
        return self._slot_items_to_dict(slot_items)

    def _filter_known_candidates(
        self,
        candidates: list[JointCandidate],
        intents: list[IntentDefinition],
        limit: int | None = None,
    ) -> list[JointCandidate]:
        """Keep candidates that are in ``intents`` and not blocked, in order.

        ``limit`` caps the number of results; ``None`` means no cap.
        """
        known_intents = {intent.intent_id for intent in intents}
        filtered = [
            item
            for item in candidates
            if item.intent_id in known_intents and item.intent_id not in BLOCKED_INTENT_LABELS
        ]
        return filtered if limit is None else filtered[:limit]

    def _slot_items_to_dict(self, slot_items: list[JointSlot]) -> dict[str, Any]:
        """Collapse slot spans into a ``{slot_name: value}`` mapping.

        Later spans overwrite earlier ones with the same name. A
        ``temperature`` value becomes the first number found in its text:
        an ``int``, or a ``float`` when it has a fractional part. When the
        text holds no number the raw string is kept.
        """
        slots: dict[str, Any] = {}
        for item in slot_items:
            value: Any = item.value
            if item.slot_name == "temperature":
                # Take the first number only, so "22.5度" gives 22.5 and
                # characters int() cannot parse are never passed to it.
                match = self._NUMBER_PATTERN.search(item.value)
                if match is not None:
                    number = match.group()
                    value = float(number) if "." in number else int(number)
            slots[item.slot_name] = value
        return slots

    def _filter_slot_items(
        self,
        slot_items: list[JointSlot],
        intent_id: str,
        required_slots: list[str],
    ) -> list[JointSlot]:
        """Keep the spans allowed for the intent, dropping exact duplicates.

        Two spans are duplicates when name, start and end all match; the
        first occurrence is kept.
        """
        allowed = allowed_slot_names(intent_id, required_slots)
        if not allowed:
            return []
        unique: dict[tuple[str, int, int], JointSlot] = {}
        for item in slot_items:
            if item.slot_name in allowed:
                unique.setdefault((item.slot_name, item.start, item.end), item)
        return list(unique.values())

    def _predict_raw(self, text: str) -> dict[str, Any]:
        """Run the model on ``text`` and return the unfiltered prediction.

        Returns:
            A dict with ``"candidates"`` (all intent labels, best first)
            and ``"slot_items"`` (decoded slot spans). Blank input gives
            empty lists without loading the model. The returned dict is
            the cached object itself, so callers must not mutate it.

        Raises:
            Whatever ``_load_runtime``, the tokenizer or the model raise.
        """
        normalized = (text or "").strip()
        if not normalized:
            return {"candidates": [], "slot_items": []}

        cached = self._cache.get(normalized)
        if cached is not None:
            # Mark the entry as most recently used.
            self._cache.move_to_end(normalized)
            return cached

        tokenizer, model, metadata, device = self._load_runtime()
        encoded = tokenizer(
            normalized,
            truncation=True,
            max_length=int(metadata.get("max_length", 64)),
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # The offsets are needed for span decoding but are not a model input.
        offset_mapping = encoded.pop("offset_mapping")[0].tolist()
        encoded = {key: value.to(device) for key, value in encoded.items()}

        model.eval()
        with torch.no_grad():
            intent_logits, slot_logits = model(**encoded)

        slot_probs = torch.softmax(slot_logits, dim=-1)[0].detach().cpu()
        slot_ids = torch.argmax(slot_probs, dim=-1).tolist()
        intent_probs = self._intent_probabilities(intent_logits.detach().cpu()[0], metadata)
        intent_labels = metadata.get("intent_labels", [])
        slot_labels = metadata.get("slot_labels", [])

        ranked = sorted(enumerate(intent_probs), key=lambda pair: pair[1], reverse=True)
        result = {
            "candidates": [
                JointCandidate(intent_id=str(intent_labels[index]), score=float(score))
                for index, score in ranked
            ],
            "slot_items": self._decode_slot_items(
                text=normalized,
                offset_mapping=offset_mapping,
                slot_ids=slot_ids,
                slot_probs=slot_probs,
                slot_labels=slot_labels,
            ),
        }

        self._cache[normalized] = result
        # Evict least recently used entries beyond the configured size.
        while len(self._cache) > self._max_cache_size:
            self._cache.popitem(last=False)
        return result

    def _intent_probabilities(self, intent_logits: torch.Tensor, metadata: dict[str, Any]) -> list[float]:
        """Convert intent logits to probabilities.

        Uses an element-wise sigmoid when the metadata key ``intent_task``
        is ``"multi_label"``, and a softmax otherwise.
        """
        task_type = str(metadata.get("intent_task", "single_label")).strip() or "single_label"
        if task_type == "multi_label":
            return torch.sigmoid(intent_logits).tolist()
        return torch.softmax(intent_logits, dim=-1).tolist()

    def _decode_slot_items(
        self,
        text: str,
        offset_mapping: list[list[int]],
        slot_ids: list[int],
        slot_probs: torch.Tensor,
        slot_labels: list[str],
    ) -> list[JointSlot]:
        """Merge per-token tag predictions into character-level slot spans.

        A span starts at a ``B-`` label or whenever the slot name changes,
        and is extended by following tokens with the same name. Tokens
        with an empty offset range, the ``O`` label, an unknown label id,
        or a label with no ``-name`` part end the current span. Each
        span's score is the mean token probability, rounded to 4 decimals.
        """
        items: list[JointSlot] = []
        span_name: str | None = None
        span_start = 0
        span_end = 0
        span_scores: list[float] = []

        def flush() -> None:
            """Emit the open span, if any, and reset the span state."""
            nonlocal span_name, span_scores
            if span_name is not None and span_start < span_end:
                value = text[span_start:span_end].strip()
                if value:
                    items.append(
                        JointSlot(
                            slot_name=span_name,
                            value=value,
                            start=span_start,
                            end=span_end,
                            score=round(sum(span_scores) / max(len(span_scores), 1), 4),
                        )
                    )
            span_name = None
            span_scores = []

        # zip() stops at the shorter sequence, so extra label ids are ignored.
        for index, (label_id, (start, end)) in enumerate(zip(slot_ids, offset_mapping)):
            label = str(slot_labels[label_id]) if label_id < len(slot_labels) else "O"
            prefix, _, name = label.partition("-")
            # Labels without a name part (e.g. "PAD") cannot name a slot and
            # are treated like "O".
            if end <= start or label == "O" or not name:
                flush()
                continue
            token_score = float(slot_probs[index][label_id].item())
            if prefix == "B" or name != span_name:
                flush()
                span_name, span_start, span_end = name, start, end
                span_scores = [token_score]
            else:
                span_end = end
                span_scores.append(token_score)
        flush()
        return items

    def _load_runtime(self) -> tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device]:
        """Load tokenizer, model, metadata and device once, then reuse them.

        Raises:
            FileNotFoundError: If the model directory, the metadata file
                or the state dict file is missing.
            KeyError: If the metadata lacks ``base_model_name``,
                ``intent_labels`` or ``slot_labels``.
        """
        if self._runtime is not None:
            return self._runtime
        if not self._model_path.exists():
            raise FileNotFoundError(f"joint nlu model path not found: {self._model_path}")

        metadata_path = self._model_path / "joint_nlu_config.json"
        state_dict_path = self._model_path / "model_state.pt"
        if not metadata_path.exists():
            raise FileNotFoundError(f"joint nlu config missing: {metadata_path}")
        if not state_dict_path.exists():
            raise FileNotFoundError(f"joint nlu model state missing: {state_dict_path}")

        metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        tokenizer = AutoTokenizer.from_pretrained(self._model_path)
        model = JointBertForNLU(
            base_model_name=str(metadata["base_model_name"]),
            num_intents=len(metadata["intent_labels"]),
            num_slot_labels=len(metadata["slot_labels"]),
            encoder_config_path=self._resolve_encoder_config_path(metadata),
        )
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run arbitrary code.
        state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        model.to(device)
        self._runtime = (tokenizer, model, metadata, device)
        return self._runtime

    def _resolve_encoder_config_path(self, metadata: dict[str, Any]) -> Path | None:
        """Find a local directory containing the encoder's ``config.json``.

        Checked in order: the model directory, the path named by
        ``base_model_name``, then the sibling directories
        ``local_bert_intent`` and ``local_bert_multi_intent``. Returns
        ``None`` when none of them has a config.
        """
        if (self._model_path / "config.json").exists():
            return self._model_path
        base_model_path = Path(str(metadata.get("base_model_name", "")))
        if base_model_path.exists() and (base_model_path / "config.json").exists():
            return base_model_path
        for candidate_name in ("local_bert_intent", "local_bert_multi_intent"):
            candidate_path = self._model_path.parent / candidate_name
            if (candidate_path / "config.json").exists():
                return candidate_path
        return None

    def _resolved_intent_threshold(self) -> float:
        """Return the intent threshold.

        Order: constructor value, metadata ``intent_threshold``, 0.35.
        Metadata is available only after the runtime has been loaded.
        """
        if self._intent_threshold is not None:
            return self._intent_threshold
        metadata = self._runtime[2] if self._runtime is not None else {}
        return float(metadata.get("intent_threshold", 0.35))

    def _resolved_multi_intent_threshold(self) -> float:
        """Return the multi-intent threshold.

        Order: constructor value, metadata ``multi_intent_threshold``,
        metadata ``intent_threshold``, 0.45.
        """
        if self._multi_intent_threshold is not None:
            return self._multi_intent_threshold
        metadata = self._runtime[2] if self._runtime is not None else {}
        return float(metadata.get("multi_intent_threshold", metadata.get("intent_threshold", 0.45)))