431 lines
17 KiB
Python
431 lines
17 KiB
Python
from __future__ import annotations

import json
import re
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from time import perf_counter
from typing import Any

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

from app.schemas.intent import IntentDefinition
|
|
|
|
|
|
# Slots an intent may fill in addition to the required slots of its definition.
OPTIONAL_SLOT_NAMES_BY_INTENT: dict[str, set[str]] = dict(
    cabin_play_music={"song", "genre"},
)

# Model labels that must never be reported as a real intent.
BLOCKED_INTENT_LABELS: set[str] = {"__out_of_scope__", "__social__"}


def allowed_slot_names(intent_id: str, required_slots: list[str] | None = None) -> set[str]:
    """Return the slot names ``intent_id`` is allowed to fill.

    The result is the union of ``required_slots`` (``None`` counts as empty)
    and the optional slots registered for the intent in
    ``OPTIONAL_SLOT_NAMES_BY_INTENT``. A fresh set is returned, so callers may
    mutate it without touching the module-level table.
    """
    names = set(OPTIONAL_SLOT_NAMES_BY_INTENT.get(intent_id, ()))
    if required_slots:
        names.update(required_slots)
    return names
|
|
|
|
|
|
@dataclass
class JointSlot:
    """One slot span decoded from the model's per-token slot labels.

    ``start``/``end`` are character offsets into the normalized input text,
    used as ``text[start:end]`` (``end`` is exclusive).
    """

    slot_name: str  # slot label without its B-/I- style prefix
    value: str  # covered text, stripped of surrounding whitespace
    start: int  # offset of the first covered character
    end: int  # offset one past the last covered character
    score: float = field(default=0.0)  # mean token probability when decoded
|
|
|
|
|
|
@dataclass
class JointCandidate:
    """A single ranked intent hypothesis produced by the joint model."""

    # Label taken from the checkpoint metadata's ``intent_labels`` list.
    intent_id: str
    # Model probability for the label (softmax or sigmoid, per the metadata).
    score: float
|
|
|
|
|
|
@dataclass
class JointNluResult:
    """Outcome of one joint NLU prediction.

    All fields have defaults, so a failed prediction can be represented as
    ``JointNluResult(error_message=...)``.
    """

    # Winning intent, or None when no candidate was accepted.
    intent_id: str | None = None
    # Score of the top candidate (0.0 when there was none).
    intent_score: float = 0.0
    # Ranked intent candidates that survived filtering.
    candidates: list[JointCandidate] = field(default_factory=list)
    # Candidates above the multi-intent threshold.
    multi_intent_candidates: list[JointCandidate] = field(default_factory=list)
    # Slot name -> extracted value (some slots may be converted to numbers).
    slots: dict[str, Any] = field(default_factory=dict)
    # The span-level slot predictions behind ``slots``.
    slot_items: list[JointSlot] = field(default_factory=list)
    model_name: str = "joint-bert-local"
    backend_name: str = "joint-bert-local"
    # Text of the exception raised during prediction, if any.
    error_message: str | None = None
|
|
|
|
|
|
class JointBertForNLU(torch.nn.Module):
    """Transformer encoder with an intent head and a per-token slot head.

    The intent head reads the hidden state at position 0 of the sequence; the
    slot head is applied to every position.
    """

    def __init__(
        self,
        base_model_name: str,
        num_intents: int,
        num_slot_labels: int,
        encoder_config_path: str | Path | None = None,
    ) -> None:
        """Build the encoder and both classification heads.

        Args:
            base_model_name: Model id/path passed to ``AutoModel.from_pretrained``
                when no local encoder config is supplied.
            num_intents: Output size of the intent head.
            num_slot_labels: Output size of the slot head.
            encoder_config_path: If given, the encoder is built from this local
                config via ``AutoModel.from_config`` (architecture only, no
                pretrained weights), instead of downloading/loading weights.
        """
        super().__init__()
        if encoder_config_path is None:
            self.encoder = AutoModel.from_pretrained(base_model_name)
        else:
            local_config = AutoConfig.from_pretrained(encoder_config_path, local_files_only=True)
            self.encoder = AutoModel.from_config(local_config)

        cfg = self.encoder.config
        width = int(cfg.hidden_size)
        # Reuse the encoder's own dropout rate when it declares one.
        self.dropout = torch.nn.Dropout(float(getattr(cfg, "hidden_dropout_prob", 0.1)))
        # Attribute names are state_dict keys; keep them stable.
        self.intent_classifier = torch.nn.Linear(width, num_intents)
        self.slot_classifier = torch.nn.Linear(width, num_slot_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(intent_logits, slot_logits)`` for a tokenized batch."""
        inputs = dict(input_ids=input_ids, attention_mask=attention_mask)
        if token_type_ids is not None:
            inputs["token_type_ids"] = token_type_ids

        hidden = self.dropout(self.encoder(**inputs).last_hidden_state)
        # NOTE(review): the position-0 vector is taken from the already
        # dropped-out states and dropped out again; in training mode this
        # applies dropout twice (a no-op in eval mode).
        first_token = self.dropout(hidden[:, 0])

        intent_logits = self.intent_classifier(first_token)
        return intent_logits, self.slot_classifier(hidden)
|
|
|
|
|
|
class JointBertNLU:
    """Local joint intent-classification + slot-filling NLU backend.

    Wraps a fine-tuned ``JointBertForNLU`` checkpoint stored under
    ``model_path``. Tokenizer, model and metadata are loaded lazily on the first
    non-empty prediction (see ``_load_runtime``). Raw predictions are memoized
    per normalized text in a small LRU cache, so ``predict``,
    ``predict_multi_intents`` and the ``extract_slots*`` helpers can be called on
    the same utterance without re-running the model.
    """

    # First integer or decimal number in a slot value. In str patterns \d
    # matches only Unicode decimal digits, all of which int()/float() accept
    # (str.isdigit() also accepts characters such as "²" that int() rejects).
    _NUMBER_PATTERN = re.compile(r"\d+(?:\.\d+)?")

    def __init__(
        self,
        model_path: str,
        intent_threshold: float | None = None,
        multi_intent_threshold: float | None = None,
        top_k: int = 5,
        max_multi_intents: int = 4,
        max_cache_size: int = 8,
    ) -> None:
        """Store configuration only; nothing is read from disk here.

        Args:
            model_path: Directory with ``joint_nlu_config.json``,
                ``model_state.pt`` and files loadable by ``AutoTokenizer``.
            intent_threshold: Minimum score for the top intent; ``None`` uses
                the metadata ``intent_threshold`` (default 0.35).
            multi_intent_threshold: Minimum score for multi-intent candidates;
                ``None`` uses metadata ``multi_intent_threshold``, then
                metadata ``intent_threshold``, then 0.45.
            top_k: Maximum number of ranked candidates returned.
            max_multi_intents: Default cap on multi-intent candidates.
            max_cache_size: Number of raw predictions kept in the LRU cache;
                values <= 0 effectively disable caching.
        """
        self._model_path = Path(model_path)
        self._intent_threshold = intent_threshold
        self._multi_intent_threshold = multi_intent_threshold
        self._top_k = top_k
        self._max_multi_intents = max_multi_intents
        self._max_cache_size = max_cache_size
        self._runtime: tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device] | None = None
        self._warmup_elapsed_ms: float | None = None
        self._warmup_error_message: str | None = None
        self._warmed_up = False
        self._cache: OrderedDict[str, dict[str, Any]] = OrderedDict()

    def warmup(self, sample_text: str = "把空调调到22度") -> bool:
        """Run one prediction so the runtime is loaded before real traffic.

        Records the elapsed time in ``_warmup_elapsed_ms`` and, on failure, the
        exception text in ``_warmup_error_message``.

        Returns:
            ``True`` if the sample prediction succeeded, ``False`` otherwise.
        """
        started_at = perf_counter()
        try:
            self._predict_raw(sample_text)
        except Exception as exc:
            # Best-effort: a failed warmup is reported, never raised.
            self._warmup_error_message = str(exc)
            self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
            return False
        self._warmup_error_message = None
        self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
        self._warmed_up = True
        return True

    def predict(self, text: str, intents: list[IntentDefinition]) -> JointNluResult:
        """Classify ``text`` and extract slots for the winning intent.

        Only intents listed in ``intents`` (and not in ``BLOCKED_INTENT_LABELS``)
        are considered. If no candidate reaches the intent threshold, the result
        has ``intent_id=None`` and no slots but still carries the candidates.
        Exceptions from the model are reported in ``error_message``.
        """
        try:
            raw_result = self._predict_raw(text)
        except Exception as exc:
            return JointNluResult(error_message=str(exc))

        candidates = self._filter_known_candidates(raw_result["candidates"], intents, limit=self._top_k)
        # Reuse this ranking rather than calling predict_multi_intents(), which
        # would repeat _predict_raw (a second forward pass when caching is off)
        # and silently turn a failure of that second call into [].
        multi_candidates = self._select_multi_intents(
            candidates,
            self._resolved_multi_intent_threshold(),
            self._max_multi_intents,
        )

        top_candidate = candidates[0] if candidates else None
        if top_candidate is None or top_candidate.score < self._resolved_intent_threshold():
            return JointNluResult(
                intent_id=None,
                intent_score=top_candidate.score if top_candidate is not None else 0.0,
                candidates=candidates,
                multi_intent_candidates=multi_candidates,
                slots={},
                slot_items=[],
            )

        intent_def = next((intent for intent in intents if intent.intent_id == top_candidate.intent_id), None)
        if intent_def is None:
            # Defensive: candidates are already restricted to known intents.
            return JointNluResult(
                intent_id=None,
                intent_score=top_candidate.score,
                candidates=candidates,
                multi_intent_candidates=multi_candidates,
                slots={},
                slot_items=[],
            )

        slot_items = self._filter_slot_items(raw_result["slot_items"], intent_def.intent_id, intent_def.required_slots)
        return JointNluResult(
            intent_id=top_candidate.intent_id,
            intent_score=top_candidate.score,
            candidates=candidates,
            multi_intent_candidates=multi_candidates,
            slots=self._slot_items_to_dict(slot_items),
            slot_items=slot_items,
        )

    def predict_multi_intents(
        self,
        text: str,
        intents: list[IntentDefinition],
        threshold: float | None = None,
        max_labels: int | None = None,
        top_k: int | None = None,
    ) -> list[JointCandidate]:
        """Return every known intent scoring at least ``threshold``, best first.

        Args:
            threshold: Defaults to the resolved multi-intent threshold.
            max_labels: Cap on returned candidates; defaults to
                ``max_multi_intents``. Values <= 0 return ``[]``.
            top_k: How many ranked candidates to inspect; a falsy value falls
                back to the instance ``top_k``.

        Returns ``[]`` if the model raises.
        """
        try:
            raw_result = self._predict_raw(text)
        except Exception:
            return []
        # Resolved after _predict_raw so metadata-provided defaults are loaded.
        if threshold is None:
            threshold = self._resolved_multi_intent_threshold()
        if max_labels is None:
            max_labels = self._max_multi_intents
        ranked = self._filter_known_candidates(raw_result["candidates"], intents, limit=top_k or self._top_k)
        return self._select_multi_intents(ranked, threshold, max_labels)

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slots for ``intent`` found in ``text`` (``{}`` on error)."""
        try:
            raw_result = self._predict_raw(text)
        except Exception:
            return {}
        slot_items = self._filter_slot_items(raw_result["slot_items"], intent.intent_id, intent.required_slots)
        return self._slot_items_to_dict(slot_items)

    def extract_slots_by_intent_id(
        self,
        text: str,
        intent_id: str,
        required_slots: list[str] | None = None,
    ) -> dict[str, Any]:
        """Like ``extract_slots`` but keyed by intent id (``{}`` on error)."""
        try:
            raw_result = self._predict_raw(text)
        except Exception:
            return {}
        slot_items = self._filter_slot_items(raw_result["slot_items"], intent_id, required_slots or [])
        return self._slot_items_to_dict(slot_items)

    def _select_multi_intents(
        self,
        ranked: list[JointCandidate],
        threshold: float,
        max_labels: int,
    ) -> list[JointCandidate]:
        """Keep at most ``max_labels`` candidates whose score is >= ``threshold``.

        ``ranked`` is expected in descending score order, so the first
        qualifying items are the best ones. The limit is applied before
        anything is selected, so a non-positive ``max_labels`` yields ``[]``.
        """
        if max_labels <= 0:
            return []
        return [item for item in ranked if item.score >= threshold][:max_labels]

    def _filter_known_candidates(
        self,
        candidates: list[JointCandidate],
        intents: list[IntentDefinition],
        limit: int | None = None,
    ) -> list[JointCandidate]:
        """Drop candidates that are unknown or blocked; keep order; apply ``limit``."""
        allowed = {intent.intent_id for intent in intents} - BLOCKED_INTENT_LABELS
        filtered = [item for item in candidates if item.intent_id in allowed]
        return filtered[:limit] if limit is not None else filtered

    def _slot_items_to_dict(self, slot_items: list[JointSlot]) -> dict[str, Any]:
        """Flatten slot spans into ``{slot_name: value}`` (later spans win).

        ``temperature`` values are converted to the first number they contain:
        an ``int`` for whole numbers ("22度" -> 22) and a ``float`` for decimal
        ones ("22.5度" -> 22.5). Values with no decimal digit (e.g. Chinese
        numerals) are kept as the original string.
        """
        slots: dict[str, Any] = {}
        for item in slot_items:
            if item.slot_name == "temperature":
                match = self._NUMBER_PATTERN.search(item.value)
                if match is not None:
                    number = match.group()
                    slots[item.slot_name] = float(number) if "." in number else int(number)
                    continue
            slots[item.slot_name] = item.value
        return slots

    def _filter_slot_items(
        self,
        slot_items: list[JointSlot],
        intent_id: str,
        required_slots: list[str],
    ) -> list[JointSlot]:
        """Keep spans whose slot name is allowed for the intent, deduplicated.

        Allowed names come from ``allowed_slot_names``. Spans are deduplicated
        on ``(slot_name, start, end)``, keeping the first occurrence.
        """
        allowed = allowed_slot_names(intent_id, required_slots)
        if not allowed:
            return []
        deduped: list[JointSlot] = []
        seen: set[tuple[str, int, int]] = set()
        for item in slot_items:
            if item.slot_name not in allowed:
                continue
            key = (item.slot_name, item.start, item.end)
            if key in seen:
                continue
            seen.add(key)
            deduped.append(item)
        return deduped

    def _cache_store(self, key: str, result: dict[str, Any]) -> None:
        """Insert ``result`` as most recently used and evict down to the limit."""
        self._cache[key] = result
        self._cache.move_to_end(key)
        # The emptiness guard keeps a negative limit from calling popitem() on
        # an empty dict (KeyError) and thus failing every prediction.
        while self._cache and len(self._cache) > self._max_cache_size:
            self._cache.popitem(last=False)

    def _predict_raw(self, text: str) -> dict[str, Any]:
        """Run the model on ``text`` and return unfiltered predictions.

        Returns ``{"candidates": [...], "slot_items": [...]}`` where candidates
        cover every intent label sorted by descending score. Blank input
        returns empty lists without loading the model. Results are cached per
        stripped text; cached dicts are shared, so callers must not mutate them.
        """
        normalized = (text or "").strip()
        if not normalized:
            return {"candidates": [], "slot_items": []}

        cached = self._cache.get(normalized)
        if cached is not None:
            self._cache.move_to_end(normalized)
            return cached

        tokenizer, model, metadata, device = self._load_runtime()
        encoded = tokenizer(
            normalized,
            truncation=True,
            max_length=int(metadata.get("max_length", 64)),
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # Offsets are not a model input; keep them (first and only batch row)
        # to map token labels back to character spans.
        offset_mapping = encoded.pop("offset_mapping")[0].tolist()
        model_inputs = {key: value.to(device) for key, value in encoded.items()}
        with torch.no_grad():
            intent_logits, slot_logits = model(**model_inputs)

        slot_probs = torch.softmax(slot_logits, dim=-1)[0].detach().cpu()
        slot_ids = torch.argmax(slot_probs, dim=-1).tolist()
        intent_probs = self._intent_probabilities(intent_logits.detach().cpu()[0], metadata)
        intent_labels = metadata.get("intent_labels", [])
        slot_labels = metadata.get("slot_labels", [])

        ranked = sorted(enumerate(intent_probs), key=lambda pair: pair[1], reverse=True)
        candidates = [
            JointCandidate(intent_id=str(intent_labels[index]), score=float(score)) for index, score in ranked
        ]
        slot_items = self._decode_slot_items(
            text=normalized,
            offset_mapping=offset_mapping,
            slot_ids=slot_ids,
            slot_probs=slot_probs,
            slot_labels=slot_labels,
        )
        result = {
            "candidates": candidates,
            "slot_items": slot_items,
        }
        self._cache_store(normalized, result)
        return result

    def _intent_probabilities(self, intent_logits: torch.Tensor, metadata: dict[str, Any]) -> list[float]:
        """Convert intent logits to probabilities.

        Uses a sigmoid when metadata ``intent_task`` is ``"multi_label"``,
        otherwise a softmax over the last dimension.
        """
        task_type = str(metadata.get("intent_task", "single_label")).strip() or "single_label"
        if task_type == "multi_label":
            return torch.sigmoid(intent_logits).tolist()
        return torch.softmax(intent_logits, dim=-1).tolist()

    def _decode_slot_items(
        self,
        text: str,
        offset_mapping: list[list[int]],
        slot_ids: list[int],
        slot_probs: torch.Tensor,
        slot_labels: list[str],
    ) -> list[JointSlot]:
        """Group per-token slot labels into character spans of ``text``.

        A label ``"<prefix>-<name>"`` starts a new span when the prefix is
        ``"B"`` or the name differs from the open span; otherwise it extends
        the open span. ``"O"`` labels and tokens with an empty offset range
        (typically special tokens) close the open span. Each span's score is
        the mean probability of its tokens' predicted labels.
        """
        items: list[JointSlot] = []
        current_name: str | None = None
        current_start: int | None = None
        current_end: int | None = None
        current_scores: list[float] = []

        def flush() -> None:
            # Emit the open span (if non-empty after stripping) and reset state.
            nonlocal current_name, current_start, current_end, current_scores
            if current_name is None or current_start is None or current_end is None or current_start >= current_end:
                current_name = None
                current_start = None
                current_end = None
                current_scores = []
                return
            value = text[current_start:current_end].strip()
            if value:
                items.append(
                    JointSlot(
                        slot_name=current_name,
                        value=value,
                        start=current_start,
                        end=current_end,
                        score=round(sum(current_scores) / max(len(current_scores), 1), 4),
                    )
                )
            current_name = None
            current_start = None
            current_end = None
            current_scores = []

        for index, label_id in enumerate(slot_ids):
            if index >= len(offset_mapping):
                break
            start, end = offset_mapping[index]
            if end <= start:
                flush()
                continue
            # Out-of-range label ids are treated as "outside".
            label = str(slot_labels[label_id]) if label_id < len(slot_labels) else "O"
            token_score = float(slot_probs[index][label_id].item())
            if label == "O":
                flush()
                continue
            prefix, _, name = label.partition("-")
            if prefix == "B" or current_name != name:
                flush()
                current_name = name
                current_start = start
                current_end = end
                current_scores = [token_score]
                continue
            current_end = end
            current_scores.append(token_score)
        flush()
        return items

    def _load_runtime(self) -> tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device]:
        """Load once and return ``(tokenizer, model, metadata, device)``.

        Raises:
            FileNotFoundError: if ``model_path``, ``joint_nlu_config.json`` or
                ``model_state.pt`` is missing.
            KeyError: if the metadata lacks ``base_model_name``,
                ``intent_labels`` or ``slot_labels``.
        """
        if self._runtime is not None:
            return self._runtime
        if not self._model_path.exists():
            raise FileNotFoundError(f"joint nlu model path not found: {self._model_path}")
        metadata_path = self._model_path / "joint_nlu_config.json"
        state_dict_path = self._model_path / "model_state.pt"
        if not metadata_path.exists():
            raise FileNotFoundError(f"joint nlu config missing: {metadata_path}")
        if not state_dict_path.exists():
            raise FileNotFoundError(f"joint nlu model state missing: {state_dict_path}")

        metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
        tokenizer = AutoTokenizer.from_pretrained(self._model_path)
        model = JointBertForNLU(
            base_model_name=str(metadata["base_model_name"]),
            num_intents=len(metadata["intent_labels"]),
            num_slot_labels=len(metadata["slot_labels"]),
            encoder_config_path=self._resolve_encoder_config_path(metadata),
        )
        # NOTE(review): torch.load unpickles the file. It is a local artifact
        # here; if model directories can come from untrusted sources, pass
        # weights_only=True (torch >= 1.13).
        state_dict = torch.load(state_dict_path, map_location="cpu")
        model.load_state_dict(state_dict)
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        model.to(device)
        # Inference only: switch off dropout once instead of on every call.
        model.eval()
        self._runtime = (tokenizer, model, metadata, device)
        return self._runtime

    def _resolve_encoder_config_path(self, metadata: dict[str, Any]) -> Path | None:
        """Find a local encoder ``config.json`` directory, if any.

        Checks, in order: ``model_path`` itself, the metadata
        ``base_model_name`` when it is a non-empty local directory, then the
        sibling directories ``local_bert_intent`` and
        ``local_bert_multi_intent``. Returns ``None`` when none has a
        ``config.json`` (the caller then loads ``base_model_name`` remotely).
        """
        if (self._model_path / "config.json").exists():
            return self._model_path

        base_model_name = str(metadata.get("base_model_name", ""))
        # An empty name would become Path(".") and probe the working directory.
        if base_model_name:
            base_model_path = Path(base_model_name)
            if (base_model_path / "config.json").exists():
                return base_model_path

        for candidate_name in ("local_bert_intent", "local_bert_multi_intent"):
            candidate_path = self._model_path.parent / candidate_name
            if (candidate_path / "config.json").exists():
                return candidate_path
        return None

    def _resolved_intent_threshold(self) -> float:
        """Explicit threshold, else metadata ``intent_threshold``, else 0.35."""
        if self._intent_threshold is not None:
            return self._intent_threshold
        metadata = self._runtime[2] if self._runtime is not None else {}
        return float(metadata.get("intent_threshold", 0.35))

    def _resolved_multi_intent_threshold(self) -> float:
        """Explicit threshold, else metadata multi/intent threshold, else 0.45."""
        if self._multi_intent_threshold is not None:
            return self._multi_intent_threshold
        metadata = self._runtime[2] if self._runtime is not None else {}
        return float(metadata.get("multi_intent_threshold", metadata.get("intent_threshold", 0.45)))
|