from __future__ import annotations import importlib import json import re from dataclasses import dataclass from pathlib import Path from time import perf_counter from typing import Any, Protocol from urllib import error, request from app.schemas.intent import IntentDefinition from app.services.joint_nlu import JointBertNLU @dataclass class ClassificationResult: intent: IntentDefinition | None score: float = 0.0 model_name: str = "mock-classifier" candidates: list[tuple[IntentDefinition, float]] | None = None backend_name: str | None = None used_fallback: bool = False fallback_reason: str | None = None error_message: str | None = None raw_label: str | None = None raw_candidates: list[dict[str, Any]] | None = None class IntentClassifier(Protocol): def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult: ... class MockIntentClassifier: """A local classifier stub that mimics a BERT-style scoring interface.""" def __init__(self, threshold: float = 1.2, top_k: int = 3) -> None: self._threshold = threshold self._top_k = top_k def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult: query_tokens = self._tokenize(text) if not query_tokens: return ClassificationResult( intent=None, score=0.0, backend_name="mock-classifier", raw_candidates=[], ) scored_intents: list[tuple[IntentDefinition, float]] = [] for intent in intents: score = self._score_intent(query_tokens, intent) scored_intents.append((intent, score)) scored_intents.sort(key=lambda item: item[1], reverse=True) best_intent: IntentDefinition | None = scored_intents[0][0] if scored_intents else None best_score = scored_intents[0][1] if scored_intents else 0.0 top_candidates = [(intent, score) for intent, score in scored_intents[: self._top_k] if score > 0] raw_candidates = [ { "label": intent.intent_id, "intent_id": intent.intent_id, "score": score, } for intent, score in top_candidates ] if best_score < self._threshold: return ClassificationResult( intent=None, score=best_score, candidates=top_candidates, backend_name="mock-classifier", fallback_reason="below threshold", raw_candidates=raw_candidates, ) return ClassificationResult( intent=best_intent, score=best_score, candidates=top_candidates, backend_name="mock-classifier", raw_label=best_intent.intent_id if best_intent is not None else None, raw_candidates=raw_candidates, ) def _score_intent(self, query_tokens: set[str], intent: IntentDefinition) -> float: score = 0.0 for keyword in intent.keywords: keyword_tokens = self._tokenize(keyword) overlap = len(query_tokens & keyword_tokens) if overlap: score += overlap * 1.4 for example in intent.examples: example_tokens = self._tokenize(example) overlap = len(query_tokens & example_tokens) if not example_tokens: continue coverage = overlap / len(example_tokens) score = max(score, overlap + coverage) if intent.domain == "customer_service" and any(token in query_tokens for token in {"订单", "物流", "快递"}): score += 0.2 if intent.domain == "cabin" and any(token in query_tokens for token in {"导航", "空调", "音乐", "歌曲"}): score += 0.2 return score def _tokenize(self, text: str) -> set[str]: cleaned = re.sub(r"[,。!?、\s]+", " ", text.strip().lower()) tokens = {token for token in cleaned.split(" ") if token} compact = cleaned.replace(" ", "") for size in (2, 3, 4): for index in range(0, max(len(compact) - size + 1, 0)): tokens.add(compact[index : index + size]) return tokens class BertIntentClassifier: """ A pluggable local classifier interface for future BERT/Transformer models. Expected model behavior: - input: user text - output labels: intent_id strings, or numeric indices mapped through a label file """ def __init__( self, model_path: str, threshold: float = 0.5, label_map_path: str | None = None, fallback: IntentClassifier | None = None, top_k: int = 3, ) -> None: self._model_path = model_path self._threshold = threshold self._label_map_path = label_map_path self._fallback = fallback self._top_k = top_k self._pipeline = None self._label_map = self._load_label_map(label_map_path) self._warmed_up = False self._warmup_elapsed_ms: float | None = None self._warmup_error_message: str | None = None def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult: pipeline = self._get_pipeline() if pipeline is None: return self._predict_with_fallback( text, intents, attempted_backend="bert-local", fallback_reason="bert model is unavailable", error_message=self._pipeline_error_message(), ) try: raw_output = pipeline(text, truncation=True, top_k=self._top_k) except Exception as exc: return self._predict_with_fallback( text, intents, attempted_backend="bert-local", fallback_reason="bert inference failed", error_message=str(exc), ) normalized = self._normalize_pipeline_candidates(raw_output) if not normalized: return self._predict_with_fallback( text, intents, attempted_backend="bert-local", fallback_reason="bert returned empty result", raw_candidates=[], ) resolved = self._resolve_candidates(normalized, intents) top_candidate = resolved["top_candidate"] if top_candidate["intent"] is None: return self._predict_with_fallback( text, intents, attempted_backend="bert-local", fallback_reason="bert label is not mapped to a known intent", score=top_candidate["score"], raw_label=top_candidate["label"], raw_candidates=normalized, ) if top_candidate["score"] < self._threshold: return self._predict_with_fallback( text, intents, attempted_backend="bert-local", fallback_reason="bert score is below threshold", score=top_candidate["score"], raw_label=top_candidate["label"], raw_candidates=normalized, ) return ClassificationResult( intent=top_candidate["intent"], score=top_candidate["score"], model_name="bert-local", candidates=resolved["known_candidates"], backend_name="bert-local", raw_label=top_candidate["label"], raw_candidates=normalized, ) def _get_pipeline(self): if self._pipeline is not None: return self._pipeline if not self._model_path or not Path(self._model_path).exists(): return None try: transformers = importlib.import_module("transformers") except ImportError: return None self._pipeline = transformers.pipeline( "text-classification", model=self._model_path, tokenizer=self._model_path, ) return self._pipeline def warmup(self, sample_text: str = "打开车窗") -> bool: if self._warmed_up: return True started_at = perf_counter() pipeline = self._get_pipeline() if pipeline is None: self._warmup_error_message = self._pipeline_error_message() self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) return False try: warmup_texts = [sample_text.strip() or "打开车窗", f"请帮我{sample_text.strip() or '打开车窗'}"] for text in dict.fromkeys(warmup_texts): pipeline(text, truncation=True, top_k=self._top_k) except Exception as exc: self._warmup_error_message = str(exc) self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) return False self._warmup_error_message = None self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3) self._warmed_up = True return True def _pipeline_error_message(self) -> str: if not self._model_path: return "AGENT_CLASSIFIER_MODEL_PATH is empty" if not Path(self._model_path).exists(): return f"model path not found: {self._model_path}" try: importlib.import_module("transformers") except ImportError: return "transformers is not installed" return "pipeline init failed" def _load_label_map(self, label_map_path: str | None) -> dict[str, str]: if not label_map_path: return {} path = Path(label_map_path) if not path.exists(): return {} data = json.loads(path.read_text(encoding="utf-8")) return {str(key): str(value) for key, value in data.items()} def _resolve_label(self, label: str) -> str | None: if label in self._label_map: return self._label_map[label] return label or None def _normalize_pipeline_candidates(self, raw_output: Any) -> list[dict[str, Any]]: items = raw_output if isinstance(items, list) and items and isinstance(items[0], list): items = items[0] if not isinstance(items, list): items = [items] normalized: list[dict[str, Any]] = [] for item in items: if not isinstance(item, dict): continue label = str(item.get("label", "")) score = float(item.get("score", 0.0)) normalized.append( { "label": label, "intent_id": self._resolve_label(label), "score": score, } ) return normalized[: self._top_k] def _resolve_candidates( self, normalized_candidates: list[dict[str, Any]], intents: list[IntentDefinition], ) -> dict[str, Any]: intent_map = {intent.intent_id: intent for intent in intents} known_candidates: list[tuple[IntentDefinition, float]] = [] resolved_items: list[dict[str, Any]] = [] for item in normalized_candidates: intent = intent_map.get(str(item.get("intent_id") or "")) if intent is not None: known_candidates.append((intent, float(item.get("score", 0.0)))) resolved_items.append( { "intent": intent, "label": str(item.get("label", "")), "score": float(item.get("score", 0.0)), } ) return { "known_candidates": known_candidates, "top_candidate": resolved_items[0], } def _predict_with_fallback( self, text: str, intents: list[IntentDefinition], attempted_backend: str, fallback_reason: str, score: float = 0.0, raw_label: str | None = None, raw_candidates: list[dict[str, Any]] | None = None, error_message: str | None = None, ) -> ClassificationResult: if self._fallback is None: return ClassificationResult( intent=None, score=score, model_name=attempted_backend, backend_name=attempted_backend, used_fallback=False, fallback_reason=fallback_reason, error_message=error_message, raw_label=raw_label, raw_candidates=raw_candidates or [], ) fallback_result = self._fallback.predict(text, intents) return ClassificationResult( intent=fallback_result.intent, score=fallback_result.score, model_name=fallback_result.model_name, candidates=fallback_result.candidates, backend_name=attempted_backend, used_fallback=True, fallback_reason=fallback_reason, error_message=error_message, raw_label=raw_label, raw_candidates=raw_candidates or fallback_result.raw_candidates or [], ) class JointBertIntentClassifier: def __init__( self, nlu: JointBertNLU, threshold: float = 0.35, top_k: int = 3, ) -> None: self._nlu = nlu self._threshold = threshold self._top_k = top_k def warmup(self, sample_text: str = "打开车窗") -> bool: return self._nlu.warmup(sample_text) def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult: result = self._nlu.predict(text, intents) raw_candidates = [ { "label": item.intent_id, "intent_id": item.intent_id, "score": item.score, } for item in result.candidates[: self._top_k] ] known_candidates = [ (intent, item.score) for item in result.candidates[: self._top_k] for intent in intents if intent.intent_id == item.intent_id ] selected_intent = next((intent for intent in intents if intent.intent_id == result.intent_id), None) if selected_intent is None or result.intent_score < self._threshold: return ClassificationResult( intent=None, score=result.intent_score, model_name=result.model_name, candidates=known_candidates, backend_name=result.backend_name, fallback_reason=result.error_message or "joint bert score is below threshold or no intent selected", error_message=result.error_message, raw_label=result.intent_id, raw_candidates=raw_candidates, ) return ClassificationResult( intent=selected_intent, score=result.intent_score, model_name=result.model_name, candidates=known_candidates, backend_name=result.backend_name, raw_label=result.intent_id, raw_candidates=raw_candidates, ) class RemoteIntentClassifier: """ A remote classifier client. Expected response payload: { "intent_id": "cs_query_order", "score": 0.98, "model_name": "bert-remote" } """ def __init__( self, endpoint: str, timeout_seconds: float = 3.0, threshold: float = 0.5, fallback: IntentClassifier | None = None, label_map_path: str | None = None, top_k: int = 3, ) -> None: self._endpoint = endpoint self._timeout_seconds = timeout_seconds self._threshold = threshold self._fallback = fallback self._label_map = self._load_label_map(label_map_path) self._top_k = top_k def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult: if not self._endpoint: return self._predict_with_fallback( text, intents, attempted_backend="remote-classifier", fallback_reason="remote endpoint is not configured", error_message="AGENT_CLASSIFIER_REMOTE_URL is empty", ) payload = json.dumps( { "text": text, "top_k": self._top_k, "labels": [intent.intent_id for intent in intents], } ).encode("utf-8") req = request.Request( self._endpoint, data=payload, headers={"Content-Type": "application/json"}, method="POST", ) try: with request.urlopen(req, timeout=self._timeout_seconds) as response: data = json.loads(response.read().decode("utf-8")) except (error.URLError, TimeoutError, ValueError) as exc: return self._predict_with_fallback( text, intents, attempted_backend="remote-classifier", fallback_reason="remote inference failed", error_message=str(exc), ) model_name = str(data.get("model_name", "remote-classifier")) normalized = self._normalize_remote_candidates(data) if not normalized: return self._predict_with_fallback( text, intents, attempted_backend=model_name, fallback_reason="remote response has no candidates", raw_candidates=[], ) intent_map = {intent.intent_id: intent for intent in intents} known_candidates = [ (intent_map[item["intent_id"]], item["score"]) for item in normalized if item["intent_id"] in intent_map ] top_candidate = normalized[0] selected_intent = intent_map.get(top_candidate["intent_id"]) if selected_intent is None: return self._predict_with_fallback( text, intents, attempted_backend=model_name, fallback_reason="remote top label is not mapped to a known intent", score=top_candidate["score"], raw_label=top_candidate["label"], raw_candidates=normalized, ) if top_candidate["score"] < self._threshold: return self._predict_with_fallback( text, intents, attempted_backend=model_name, fallback_reason="remote score is below threshold", score=top_candidate["score"], raw_label=top_candidate["label"], raw_candidates=normalized, ) return ClassificationResult( intent=selected_intent, score=top_candidate["score"], model_name=model_name, candidates=known_candidates, backend_name=model_name, raw_label=top_candidate["label"], raw_candidates=normalized, ) def _load_label_map(self, label_map_path: str | None) -> dict[str, str]: if not label_map_path: return {} path = Path(label_map_path) if not path.exists(): return {} data = json.loads(path.read_text(encoding="utf-8")) return {str(key): str(value) for key, value in data.items()} def _resolve_label(self, label: str) -> str | None: if label in self._label_map: return self._label_map[label] return label or None def _normalize_remote_candidates(self, data: dict[str, Any]) -> list[dict[str, Any]]: raw_candidates = data.get("candidates") or data.get("predictions") or [] if not raw_candidates and data.get("intent_id"): raw_candidates = [ { "intent_id": data.get("intent_id"), "label": data.get("label") or data.get("intent_id"), "score": data.get("score", 0.0), } ] normalized: list[dict[str, Any]] = [] for item in raw_candidates: if not isinstance(item, dict): continue label = str(item.get("label") or item.get("intent_id") or "") intent_id = str(item.get("intent_id") or self._resolve_label(label) or "") normalized.append( { "label": label, "intent_id": intent_id, "score": float(item.get("score", 0.0)), } ) return normalized[: self._top_k] def _predict_with_fallback( self, text: str, intents: list[IntentDefinition], attempted_backend: str, fallback_reason: str, score: float = 0.0, raw_label: str | None = None, raw_candidates: list[dict[str, Any]] | None = None, error_message: str | None = None, ) -> ClassificationResult: if self._fallback is None: return ClassificationResult( intent=None, score=score, model_name=attempted_backend, backend_name=attempted_backend, used_fallback=False, fallback_reason=fallback_reason, error_message=error_message, raw_label=raw_label, raw_candidates=raw_candidates or [], ) fallback_result = self._fallback.predict(text, intents) return ClassificationResult( intent=fallback_result.intent, score=fallback_result.score, model_name=fallback_result.model_name, candidates=fallback_result.candidates, backend_name=attempted_backend, used_fallback=True, fallback_reason=fallback_reason, error_message=error_message, raw_label=raw_label, raw_candidates=raw_candidates or fallback_result.raw_candidates or [], )