Update project and configurations
This commit is contained in:
430
intelligent_cabin/app/services/joint_nlu.py
Normal file
430
intelligent_cabin/app/services/joint_nlu.py
Normal file
@@ -0,0 +1,430 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
|
||||
|
||||
OPTIONAL_SLOT_NAMES_BY_INTENT: dict[str, set[str]] = {
|
||||
"cabin_play_music": {"song", "genre"},
|
||||
}
|
||||
|
||||
BLOCKED_INTENT_LABELS = {"__social__", "__out_of_scope__"}
|
||||
|
||||
|
||||
def allowed_slot_names(intent_id: str, required_slots: list[str] | None = None) -> set[str]:
|
||||
required = set(required_slots or [])
|
||||
return required | OPTIONAL_SLOT_NAMES_BY_INTENT.get(intent_id, set())
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointSlot:
|
||||
slot_name: str
|
||||
value: str
|
||||
start: int
|
||||
end: int
|
||||
score: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointCandidate:
|
||||
intent_id: str
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointNluResult:
|
||||
intent_id: str | None = None
|
||||
intent_score: float = 0.0
|
||||
candidates: list[JointCandidate] = field(default_factory=list)
|
||||
multi_intent_candidates: list[JointCandidate] = field(default_factory=list)
|
||||
slots: dict[str, Any] = field(default_factory=dict)
|
||||
slot_items: list[JointSlot] = field(default_factory=list)
|
||||
model_name: str = "joint-bert-local"
|
||||
backend_name: str = "joint-bert-local"
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class JointBertForNLU(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_model_name: str,
|
||||
num_intents: int,
|
||||
num_slot_labels: int,
|
||||
encoder_config_path: str | Path | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if encoder_config_path is not None:
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_config_path, local_files_only=True)
|
||||
self.encoder = AutoModel.from_config(encoder_config)
|
||||
else:
|
||||
self.encoder = AutoModel.from_pretrained(base_model_name)
|
||||
hidden_size = int(self.encoder.config.hidden_size)
|
||||
dropout_prob = float(getattr(self.encoder.config, "hidden_dropout_prob", 0.1))
|
||||
self.dropout = torch.nn.Dropout(dropout_prob)
|
||||
self.intent_classifier = torch.nn.Linear(hidden_size, num_intents)
|
||||
self.slot_classifier = torch.nn.Linear(hidden_size, num_slot_labels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
encoder_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
if token_type_ids is not None:
|
||||
encoder_kwargs["token_type_ids"] = token_type_ids
|
||||
outputs = self.encoder(**encoder_kwargs)
|
||||
sequence_output = self.dropout(outputs.last_hidden_state)
|
||||
pooled_output = self.dropout(sequence_output[:, 0])
|
||||
intent_logits = self.intent_classifier(pooled_output)
|
||||
slot_logits = self.slot_classifier(sequence_output)
|
||||
return intent_logits, slot_logits
|
||||
|
||||
|
||||
class JointBertNLU:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
intent_threshold: float | None = None,
|
||||
multi_intent_threshold: float | None = None,
|
||||
top_k: int = 5,
|
||||
max_multi_intents: int = 4,
|
||||
max_cache_size: int = 8,
|
||||
) -> None:
|
||||
self._model_path = Path(model_path)
|
||||
self._intent_threshold = intent_threshold
|
||||
self._multi_intent_threshold = multi_intent_threshold
|
||||
self._top_k = top_k
|
||||
self._max_multi_intents = max_multi_intents
|
||||
self._max_cache_size = max_cache_size
|
||||
self._runtime: tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device] | None = None
|
||||
self._warmup_elapsed_ms: float | None = None
|
||||
self._warmup_error_message: str | None = None
|
||||
self._warmed_up = False
|
||||
self._cache: OrderedDict[str, dict[str, Any]] = OrderedDict()
|
||||
|
||||
def warmup(self, sample_text: str = "把空调调到22度") -> bool:
|
||||
started_at = perf_counter()
|
||||
try:
|
||||
self._predict_raw(sample_text)
|
||||
except Exception as exc:
|
||||
self._warmup_error_message = str(exc)
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
return False
|
||||
self._warmup_error_message = None
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
self._warmed_up = True
|
||||
return True
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]) -> JointNluResult:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception as exc:
|
||||
return JointNluResult(error_message=str(exc))
|
||||
candidates = self._filter_known_candidates(raw_result["candidates"], intents, limit=self._top_k)
|
||||
multi_candidates = self.predict_multi_intents(text, intents)
|
||||
top_candidate = candidates[0] if candidates else None
|
||||
if top_candidate is None or top_candidate.score < self._resolved_intent_threshold():
|
||||
return JointNluResult(
|
||||
intent_id=None,
|
||||
intent_score=top_candidate.score if top_candidate is not None else 0.0,
|
||||
candidates=candidates,
|
||||
multi_intent_candidates=multi_candidates,
|
||||
slots={},
|
||||
slot_items=[],
|
||||
)
|
||||
intent_def = next((intent for intent in intents if intent.intent_id == top_candidate.intent_id), None)
|
||||
if intent_def is None:
|
||||
return JointNluResult(
|
||||
intent_id=None,
|
||||
intent_score=top_candidate.score,
|
||||
candidates=candidates,
|
||||
multi_intent_candidates=multi_candidates,
|
||||
slots={},
|
||||
slot_items=[],
|
||||
)
|
||||
slot_items = self._filter_slot_items(raw_result["slot_items"], intent_def.intent_id, intent_def.required_slots)
|
||||
return JointNluResult(
|
||||
intent_id=top_candidate.intent_id,
|
||||
intent_score=top_candidate.score,
|
||||
candidates=candidates,
|
||||
multi_intent_candidates=multi_candidates,
|
||||
slots=self._slot_items_to_dict(slot_items),
|
||||
slot_items=slot_items,
|
||||
)
|
||||
|
||||
def predict_multi_intents(
|
||||
self,
|
||||
text: str,
|
||||
intents: list[IntentDefinition],
|
||||
threshold: float | None = None,
|
||||
max_labels: int | None = None,
|
||||
top_k: int | None = None,
|
||||
) -> list[JointCandidate]:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception:
|
||||
return []
|
||||
threshold = self._multi_intent_threshold if threshold is None else threshold
|
||||
if threshold is None:
|
||||
threshold = self._resolved_multi_intent_threshold()
|
||||
max_labels = self._max_multi_intents if max_labels is None else max_labels
|
||||
ranked = self._filter_known_candidates(raw_result["candidates"], intents, limit=top_k or self._top_k)
|
||||
selected: list[JointCandidate] = []
|
||||
for item in ranked:
|
||||
if item.score < threshold:
|
||||
continue
|
||||
selected.append(item)
|
||||
if len(selected) >= max_labels:
|
||||
break
|
||||
return selected
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception:
|
||||
return {}
|
||||
slot_items = self._filter_slot_items(raw_result["slot_items"], intent.intent_id, intent.required_slots)
|
||||
return self._slot_items_to_dict(slot_items)
|
||||
|
||||
def extract_slots_by_intent_id(
|
||||
self,
|
||||
text: str,
|
||||
intent_id: str,
|
||||
required_slots: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception:
|
||||
return {}
|
||||
slot_items = self._filter_slot_items(raw_result["slot_items"], intent_id, required_slots or [])
|
||||
return self._slot_items_to_dict(slot_items)
|
||||
|
||||
def _filter_known_candidates(
|
||||
self,
|
||||
candidates: list[JointCandidate],
|
||||
intents: list[IntentDefinition],
|
||||
limit: int | None = None,
|
||||
) -> list[JointCandidate]:
|
||||
known_intents = {intent.intent_id for intent in intents}
|
||||
filtered = [
|
||||
item
|
||||
for item in candidates
|
||||
if item.intent_id in known_intents and item.intent_id not in BLOCKED_INTENT_LABELS
|
||||
]
|
||||
return filtered[:limit] if limit is not None else filtered
|
||||
|
||||
def _slot_items_to_dict(self, slot_items: list[JointSlot]) -> dict[str, Any]:
|
||||
slots: dict[str, Any] = {}
|
||||
for item in slot_items:
|
||||
if item.slot_name == "temperature":
|
||||
digits = "".join(ch for ch in item.value if ch.isdigit())
|
||||
if digits:
|
||||
slots[item.slot_name] = int(digits)
|
||||
continue
|
||||
slots[item.slot_name] = item.value
|
||||
return slots
|
||||
|
||||
def _filter_slot_items(
|
||||
self,
|
||||
slot_items: list[JointSlot],
|
||||
intent_id: str,
|
||||
required_slots: list[str],
|
||||
) -> list[JointSlot]:
|
||||
allowed = allowed_slot_names(intent_id, required_slots)
|
||||
if not allowed:
|
||||
return []
|
||||
filtered = [item for item in slot_items if item.slot_name in allowed]
|
||||
deduped: list[JointSlot] = []
|
||||
seen: set[tuple[str, int, int]] = set()
|
||||
for item in filtered:
|
||||
key = (item.slot_name, item.start, item.end)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(item)
|
||||
return deduped
|
||||
|
||||
def _predict_raw(self, text: str) -> dict[str, Any]:
|
||||
normalized = (text or "").strip()
|
||||
if not normalized:
|
||||
return {"candidates": [], "slot_items": []}
|
||||
if normalized in self._cache:
|
||||
cached = self._cache.pop(normalized)
|
||||
self._cache[normalized] = cached
|
||||
return cached
|
||||
tokenizer, model, metadata, device = self._load_runtime()
|
||||
encoded = tokenizer(
|
||||
normalized,
|
||||
truncation=True,
|
||||
max_length=int(metadata.get("max_length", 64)),
|
||||
return_offsets_mapping=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
offset_mapping = encoded.pop("offset_mapping")[0].tolist()
|
||||
encoded = {key: value.to(device) for key, value in encoded.items()}
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
intent_logits, slot_logits = model(**encoded)
|
||||
slot_probs = torch.softmax(slot_logits, dim=-1)[0].detach().cpu()
|
||||
slot_ids = torch.argmax(slot_probs, dim=-1).tolist()
|
||||
intent_probs = self._intent_probabilities(intent_logits.detach().cpu()[0], metadata)
|
||||
intent_labels = metadata.get("intent_labels", [])
|
||||
slot_labels = metadata.get("slot_labels", [])
|
||||
candidates = [
|
||||
JointCandidate(intent_id=str(intent_labels[index]), score=float(score))
|
||||
for index, score in sorted(
|
||||
list(enumerate(intent_probs)),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
slot_items = self._decode_slot_items(
|
||||
text=normalized,
|
||||
offset_mapping=offset_mapping,
|
||||
slot_ids=slot_ids,
|
||||
slot_probs=slot_probs,
|
||||
slot_labels=slot_labels,
|
||||
)
|
||||
result = {
|
||||
"candidates": candidates,
|
||||
"slot_items": slot_items,
|
||||
}
|
||||
self._cache[normalized] = result
|
||||
while len(self._cache) > self._max_cache_size:
|
||||
self._cache.popitem(last=False)
|
||||
return result
|
||||
|
||||
def _intent_probabilities(self, intent_logits: torch.Tensor, metadata: dict[str, Any]) -> list[float]:
|
||||
task_type = str(metadata.get("intent_task", "single_label")).strip() or "single_label"
|
||||
if task_type == "multi_label":
|
||||
return torch.sigmoid(intent_logits).tolist()
|
||||
return torch.softmax(intent_logits, dim=-1).tolist()
|
||||
|
||||
def _decode_slot_items(
|
||||
self,
|
||||
text: str,
|
||||
offset_mapping: list[list[int]],
|
||||
slot_ids: list[int],
|
||||
slot_probs: torch.Tensor,
|
||||
slot_labels: list[str],
|
||||
) -> list[JointSlot]:
|
||||
items: list[JointSlot] = []
|
||||
current_name: str | None = None
|
||||
current_start: int | None = None
|
||||
current_end: int | None = None
|
||||
current_scores: list[float] = []
|
||||
|
||||
def flush() -> None:
|
||||
nonlocal current_name, current_start, current_end, current_scores
|
||||
if current_name is None or current_start is None or current_end is None or current_start >= current_end:
|
||||
current_name = None
|
||||
current_start = None
|
||||
current_end = None
|
||||
current_scores = []
|
||||
return
|
||||
value = text[current_start:current_end].strip()
|
||||
if value:
|
||||
items.append(
|
||||
JointSlot(
|
||||
slot_name=current_name,
|
||||
value=value,
|
||||
start=current_start,
|
||||
end=current_end,
|
||||
score=round(sum(current_scores) / max(len(current_scores), 1), 4),
|
||||
)
|
||||
)
|
||||
current_name = None
|
||||
current_start = None
|
||||
current_end = None
|
||||
current_scores = []
|
||||
|
||||
for index, label_id in enumerate(slot_ids):
|
||||
if index >= len(offset_mapping):
|
||||
break
|
||||
start, end = offset_mapping[index]
|
||||
if end <= start:
|
||||
flush()
|
||||
continue
|
||||
label = str(slot_labels[label_id]) if label_id < len(slot_labels) else "O"
|
||||
token_score = float(slot_probs[index][label_id].item())
|
||||
if label == "O":
|
||||
flush()
|
||||
continue
|
||||
prefix, _, name = label.partition("-")
|
||||
if prefix == "B" or current_name != name:
|
||||
flush()
|
||||
current_name = name
|
||||
current_start = start
|
||||
current_end = end
|
||||
current_scores = [token_score]
|
||||
continue
|
||||
current_end = end
|
||||
current_scores.append(token_score)
|
||||
flush()
|
||||
return items
|
||||
|
||||
def _load_runtime(self) -> tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device]:
|
||||
if self._runtime is not None:
|
||||
return self._runtime
|
||||
if not self._model_path.exists():
|
||||
raise FileNotFoundError(f"joint nlu model path not found: {self._model_path}")
|
||||
metadata_path = self._model_path / "joint_nlu_config.json"
|
||||
state_dict_path = self._model_path / "model_state.pt"
|
||||
if not metadata_path.exists():
|
||||
raise FileNotFoundError(f"joint nlu config missing: {metadata_path}")
|
||||
if not state_dict_path.exists():
|
||||
raise FileNotFoundError(f"joint nlu model state missing: {state_dict_path}")
|
||||
metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
|
||||
tokenizer = AutoTokenizer.from_pretrained(self._model_path)
|
||||
model = JointBertForNLU(
|
||||
base_model_name=str(metadata["base_model_name"]),
|
||||
num_intents=len(metadata["intent_labels"]),
|
||||
num_slot_labels=len(metadata["slot_labels"]),
|
||||
encoder_config_path=self._resolve_encoder_config_path(metadata),
|
||||
)
|
||||
state_dict = torch.load(state_dict_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict)
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
self._runtime = (tokenizer, model, metadata, device)
|
||||
return self._runtime
|
||||
|
||||
def _resolve_encoder_config_path(self, metadata: dict[str, Any]) -> Path | None:
|
||||
local_config = self._model_path / "config.json"
|
||||
if local_config.exists():
|
||||
return self._model_path
|
||||
|
||||
base_model_path = Path(str(metadata.get("base_model_name", "")))
|
||||
if base_model_path.exists() and (base_model_path / "config.json").exists():
|
||||
return base_model_path
|
||||
|
||||
for candidate_name in ("local_bert_intent", "local_bert_multi_intent"):
|
||||
candidate_path = self._model_path.parent / candidate_name
|
||||
if (candidate_path / "config.json").exists():
|
||||
return candidate_path
|
||||
return None
|
||||
|
||||
def _resolved_intent_threshold(self) -> float:
|
||||
if self._intent_threshold is not None:
|
||||
return self._intent_threshold
|
||||
metadata = self._runtime[2] if self._runtime is not None else {}
|
||||
return float(metadata.get("intent_threshold", 0.35))
|
||||
|
||||
def _resolved_multi_intent_threshold(self) -> float:
|
||||
if self._multi_intent_threshold is not None:
|
||||
return self._multi_intent_threshold
|
||||
metadata = self._runtime[2] if self._runtime is not None else {}
|
||||
return float(metadata.get("multi_intent_threshold", metadata.get("intent_threshold", 0.45)))
|
||||
Reference in New Issue
Block a user