from __future__ import annotations import json from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path from time import perf_counter from typing import Any import torch from transformers import AutoConfig, AutoModel, AutoTokenizer from app.schemas.intent import IntentDefinition OPTIONAL_SLOT_NAMES_BY_INTENT: dict[str, set[str]] = { "cabin_play_music": {"song", "genre"}, } BLOCKED_INTENT_LABELS = {"__social__", "__out_of_scope__"} def allowed_slot_names(intent_id: str, required_slots: list[str] | None = None) -> set[str]: required = set(required_slots or []) return required | OPTIONAL_SLOT_NAMES_BY_INTENT.get(intent_id, set()) @dataclass class JointSlot: slot_name: str value: str start: int end: int score: float = 0.0 @dataclass class JointCandidate: intent_id: str score: float @dataclass class JointNluResult: intent_id: str | None = None intent_score: float = 0.0 candidates: list[JointCandidate] = field(default_factory=list) multi_intent_candidates: list[JointCandidate] = field(default_factory=list) slots: dict[str, Any] = field(default_factory=dict) slot_items: list[JointSlot] = field(default_factory=list) model_name: str = "joint-bert-local" backend_name: str = "joint-bert-local" error_message: str | None = None class JointBertForNLU(torch.nn.Module): def __init__( self, base_model_name: str, num_intents: int, num_slot_labels: int, encoder_config_path: str | Path | None = None, ) -> None: super().__init__() if encoder_config_path is not None: encoder_config = AutoConfig.from_pretrained(encoder_config_path, local_files_only=True) self.encoder = AutoModel.from_config(encoder_config) else: self.encoder = AutoModel.from_pretrained(base_model_name) hidden_size = int(self.encoder.config.hidden_size) dropout_prob = float(getattr(self.encoder.config, "hidden_dropout_prob", 0.1)) self.dropout = torch.nn.Dropout(dropout_prob) self.intent_classifier = torch.nn.Linear(hidden_size, num_intents) self.slot_classifier = torch.nn.Linear(hidden_size, num_slot_labels) def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: encoder_kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, } if token_type_ids is not None: encoder_kwargs["token_type_ids"] = token_type_ids outputs = self.encoder(**encoder_kwargs) sequence_output = self.dropout(outputs.last_hidden_state) pooled_output = self.dropout(sequence_output[:, 0]) intent_logits = self.intent_classifier(pooled_output) slot_logits = self.slot_classifier(sequence_output) return intent_logits, slot_logits class JointBertNLU: def __init__( self, model_path: str, intent_threshold: float | None = None, multi_intent_threshold: float | None = None, top_k: int = 5, max_multi_intents: int = 4, max_cache_size: int = 8, ) -> None: self._model_path = Path(model_path) self._intent_threshold = intent_threshold self._multi_intent_threshold = multi_intent_threshold self._top_k = top_k self._max_multi_intents = max_multi_intents self._max_cache_size = max_cache_size self._runtime: tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device] | None = None self._warmup_elapsed_ms: float | None = None self._warmup_error_message: str | None = None self._warmed_up = False self._cache: OrderedDict[str, dict[str, Any]] = OrderedDict() def warmup(self, sample_text: str = "把空调调到22度") -> bool: started_at = perf_counter() try: self._predict_raw(sample_text) except Exception as exc: self._warmup_error_message = str(exc) self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) return False self._warmup_error_message = None self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) self._warmed_up = True return True def predict(self, text: str, intents: list[IntentDefinition]) -> JointNluResult: try: raw_result = self._predict_raw(text) except Exception as exc: return JointNluResult(error_message=str(exc)) candidates = self._filter_known_candidates(raw_result["candidates"], intents, limit=self._top_k) multi_candidates = self.predict_multi_intents(text, intents) top_candidate = candidates[0] if candidates else None if top_candidate is None or top_candidate.score < self._resolved_intent_threshold(): return JointNluResult( intent_id=None, intent_score=top_candidate.score if top_candidate is not None else 0.0, candidates=candidates, multi_intent_candidates=multi_candidates, slots={}, slot_items=[], ) intent_def = next((intent for intent in intents if intent.intent_id == top_candidate.intent_id), None) if intent_def is None: return JointNluResult( intent_id=None, intent_score=top_candidate.score, candidates=candidates, multi_intent_candidates=multi_candidates, slots={}, slot_items=[], ) slot_items = self._filter_slot_items(raw_result["slot_items"], intent_def.intent_id, intent_def.required_slots) return JointNluResult( intent_id=top_candidate.intent_id, intent_score=top_candidate.score, candidates=candidates, multi_intent_candidates=multi_candidates, slots=self._slot_items_to_dict(slot_items), slot_items=slot_items, ) def predict_multi_intents( self, text: str, intents: list[IntentDefinition], threshold: float | None = None, max_labels: int | None = None, top_k: int | None = None, ) -> list[JointCandidate]: try: raw_result = self._predict_raw(text) except Exception: return [] threshold = self._multi_intent_threshold if threshold is None else threshold if threshold is None: threshold = self._resolved_multi_intent_threshold() max_labels = self._max_multi_intents if max_labels is None else max_labels ranked = self._filter_known_candidates(raw_result["candidates"], intents, limit=top_k or self._top_k) selected: list[JointCandidate] = [] for item in ranked: if item.score < threshold: continue selected.append(item) if len(selected) >= max_labels: break return selected def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]: try: raw_result = self._predict_raw(text) except Exception: return {} slot_items = self._filter_slot_items(raw_result["slot_items"], intent.intent_id, intent.required_slots) return self._slot_items_to_dict(slot_items) def extract_slots_by_intent_id( self, text: str, intent_id: str, required_slots: list[str] | None = None, ) -> dict[str, Any]: try: raw_result = self._predict_raw(text) except Exception: return {} slot_items = self._filter_slot_items(raw_result["slot_items"], intent_id, required_slots or []) return self._slot_items_to_dict(slot_items) def _filter_known_candidates( self, candidates: list[JointCandidate], intents: list[IntentDefinition], limit: int | None = None, ) -> list[JointCandidate]: known_intents = {intent.intent_id for intent in intents} filtered = [ item for item in candidates if item.intent_id in known_intents and item.intent_id not in BLOCKED_INTENT_LABELS ] return filtered[:limit] if limit is not None else filtered def _slot_items_to_dict(self, slot_items: list[JointSlot]) -> dict[str, Any]: slots: dict[str, Any] = {} for item in slot_items: if item.slot_name == "temperature": digits = "".join(ch for ch in item.value if ch.isdigit()) if digits: slots[item.slot_name] = int(digits) continue slots[item.slot_name] = item.value return slots def _filter_slot_items( self, slot_items: list[JointSlot], intent_id: str, required_slots: list[str], ) -> list[JointSlot]: allowed = allowed_slot_names(intent_id, required_slots) if not allowed: return [] filtered = [item for item in slot_items if item.slot_name in allowed] deduped: list[JointSlot] = [] seen: set[tuple[str, int, int]] = set() for item in filtered: key = (item.slot_name, item.start, item.end) if key in seen: continue seen.add(key) deduped.append(item) return deduped def _predict_raw(self, text: str) -> dict[str, Any]: normalized = (text or "").strip() if not normalized: return {"candidates": [], "slot_items": []} if normalized in self._cache: cached = self._cache.pop(normalized) self._cache[normalized] = cached return cached tokenizer, model, metadata, device = self._load_runtime() encoded = tokenizer( normalized, truncation=True, max_length=int(metadata.get("max_length", 64)), return_offsets_mapping=True, return_tensors="pt", ) offset_mapping = encoded.pop("offset_mapping")[0].tolist() encoded = {key: value.to(device) for key, value in encoded.items()} model.eval() with torch.no_grad(): intent_logits, slot_logits = model(**encoded) slot_probs = torch.softmax(slot_logits, dim=-1)[0].detach().cpu() slot_ids = torch.argmax(slot_probs, dim=-1).tolist() intent_probs = self._intent_probabilities(intent_logits.detach().cpu()[0], metadata) intent_labels = metadata.get("intent_labels", []) slot_labels = metadata.get("slot_labels", []) candidates = [ JointCandidate(intent_id=str(intent_labels[index]), score=float(score)) for index, score in sorted( list(enumerate(intent_probs)), key=lambda item: item[1], reverse=True, ) ] slot_items = self._decode_slot_items( text=normalized, offset_mapping=offset_mapping, slot_ids=slot_ids, slot_probs=slot_probs, slot_labels=slot_labels, ) result = { "candidates": candidates, "slot_items": slot_items, } self._cache[normalized] = result while len(self._cache) > self._max_cache_size: self._cache.popitem(last=False) return result def _intent_probabilities(self, intent_logits: torch.Tensor, metadata: dict[str, Any]) -> list[float]: task_type = str(metadata.get("intent_task", "single_label")).strip() or "single_label" if task_type == "multi_label": return torch.sigmoid(intent_logits).tolist() return torch.softmax(intent_logits, dim=-1).tolist() def _decode_slot_items( self, text: str, offset_mapping: list[list[int]], slot_ids: list[int], slot_probs: torch.Tensor, slot_labels: list[str], ) -> list[JointSlot]: items: list[JointSlot] = [] current_name: str | None = None current_start: int | None = None current_end: int | None = None current_scores: list[float] = [] def flush() -> None: nonlocal current_name, current_start, current_end, current_scores if current_name is None or current_start is None or current_end is None or current_start >= current_end: current_name = None current_start = None current_end = None current_scores = [] return value = text[current_start:current_end].strip() if value: items.append( JointSlot( slot_name=current_name, value=value, start=current_start, end=current_end, score=round(sum(current_scores) / max(len(current_scores), 1), 4), ) ) current_name = None current_start = None current_end = None current_scores = [] for index, label_id in enumerate(slot_ids): if index >= len(offset_mapping): break start, end = offset_mapping[index] if end <= start: flush() continue label = str(slot_labels[label_id]) if label_id < len(slot_labels) else "O" token_score = float(slot_probs[index][label_id].item()) if label == "O": flush() continue prefix, _, name = label.partition("-") if prefix == "B" or current_name != name: flush() current_name = name current_start = start current_end = end current_scores = [token_score] continue current_end = end current_scores.append(token_score) flush() return items def _load_runtime(self) -> tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device]: if self._runtime is not None: return self._runtime if not self._model_path.exists(): raise FileNotFoundError(f"joint nlu model path not found: {self._model_path}") metadata_path = self._model_path / "joint_nlu_config.json" state_dict_path = self._model_path / "model_state.pt" if not metadata_path.exists(): raise FileNotFoundError(f"joint nlu config missing: {metadata_path}") if not state_dict_path.exists(): raise FileNotFoundError(f"joint nlu model state missing: {state_dict_path}") metadata = json.loads(metadata_path.read_text(encoding="utf-8")) tokenizer = AutoTokenizer.from_pretrained(self._model_path) model = JointBertForNLU( base_model_name=str(metadata["base_model_name"]), num_intents=len(metadata["intent_labels"]), num_slot_labels=len(metadata["slot_labels"]), encoder_config_path=self._resolve_encoder_config_path(metadata), ) state_dict = torch.load(state_dict_path, map_location="cpu") model.load_state_dict(state_dict) device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") model.to(device) self._runtime = (tokenizer, model, metadata, device) return self._runtime def _resolve_encoder_config_path(self, metadata: dict[str, Any]) -> Path | None: local_config = self._model_path / "config.json" if local_config.exists(): return self._model_path base_model_path = Path(str(metadata.get("base_model_name", ""))) if base_model_path.exists() and (base_model_path / "config.json").exists(): return base_model_path for candidate_name in ("local_bert_intent", "local_bert_multi_intent"): candidate_path = self._model_path.parent / candidate_name if (candidate_path / "config.json").exists(): return candidate_path return None def _resolved_intent_threshold(self) -> float: if self._intent_threshold is not None: return self._intent_threshold metadata = self._runtime[2] if self._runtime is not None else {} return float(metadata.get("intent_threshold", 0.35)) def _resolved_multi_intent_threshold(self) -> float: if self._multi_intent_threshold is not None: return self._multi_intent_threshold metadata = self._runtime[2] if self._runtime is not None else {} return float(metadata.get("multi_intent_threshold", metadata.get("intent_threshold", 0.45)))