"""Intent matching, fusion-based routing decisions, and slot extraction."""

from __future__ import annotations

import re
from dataclasses import dataclass
from time import perf_counter
from typing import Any, Protocol

from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
from app.schemas.intent import IntentDefinition
from app.services.classifier import IntentClassifier
from app.services.intent_registry import IntentRegistry
from app.services.joint_nlu import JointBertNLU


@dataclass
class IntentMatchResult:
    """Outcome of a single matcher stage: the chosen intent plus its debug trace."""

    intent: IntentDefinition | None
    stage_debug: MatcherStageDebug


@dataclass
class RouteMatchResult:
    """Outcome of the whole routing pipeline: the final intent plus routing debug."""

    intent: IntentDefinition | None
    debug: RoutingDebug


class IntentMatcher(Protocol):
    """Structural interface for a single matching stage."""

    def match(self, text: str) -> IntentMatchResult: ...


class SlotExtractor(Protocol):
    """Structural interface for extracting slot values for a given intent."""

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]: ...


class ClassifierIntentMatcher:
    """Matcher stage that delegates to an ``IntentClassifier``."""

    def __init__(self, registry: IntentRegistry, classifier: IntentClassifier) -> None:
        self._registry = registry
        self._classifier = classifier

    def match(self, text: str) -> IntentMatchResult:
        """Run the classifier on *text* and wrap its prediction in a stage trace.

        The returned stage is marked accepted only when the classifier result
        carries an intent; candidates and metadata are attached in both cases.
        """
        result = self._classifier.predict(text, self._registry.list())

        ranked_classifier_scores = sorted(
            [float(score) for _, score in (result.candidates or [])],
            reverse=True,
        )
        # Margin between the two best scores; with a single candidate the
        # margin is that candidate's score, with none it is 0.0.
        if len(ranked_classifier_scores) >= 2:
            classifier_top_margin = ranked_classifier_scores[0] - ranked_classifier_scores[1]
        elif ranked_classifier_scores:
            classifier_top_margin = ranked_classifier_scores[0]
        else:
            classifier_top_margin = 0.0

        candidates = [
            IntentCandidate(
                intent_id=intent.intent_id,
                score=score,
                reason="classifier candidate",
                model_name=result.backend_name or result.model_name,
                # Use the raw label of the raw candidate with the same intent
                # id and score; fall back to the intent id when none matches.
                raw_label=next(
                    (
                        item.get("label")
                        for item in (result.raw_candidates or [])
                        if item.get("intent_id") == intent.intent_id
                        and float(item.get("score", 0.0)) == score
                    ),
                    intent.intent_id,
                ),
            )
            for intent, score in (result.candidates or [])
        ]

        metadata: dict[str, Any] = {
            "decision_model": result.model_name,
            "threshold": getattr(self._classifier, "_threshold", None),
            "raw_candidates": result.raw_candidates or [],
            "top_margin": round(classifier_top_margin, 4),
        }
        if result.fallback_reason:
            metadata["fallback_reason"] = result.fallback_reason

        if result.intent is None:
            return IntentMatchResult(
                intent=None,
                stage_debug=MatcherStageDebug(
                    stage="classifier",
                    accepted=False,
                    score=result.score,
                    reason=result.fallback_reason
                    or "classifier below threshold or no intent selected",
                    model_name=result.model_name,
                    backend=result.backend_name or result.model_name,
                    fallback_used=result.used_fallback,
                    raw_label=result.raw_label,
                    error_message=result.error_message,
                    metadata=metadata,
                    candidates=candidates,
                ),
            )

        return IntentMatchResult(
            intent=result.intent,
            stage_debug=MatcherStageDebug(
                stage="classifier",
                accepted=True,
                selected_intent=result.intent.intent_id,
                score=result.score,
                reason=(
                    "bert classifier selected best candidate"
                    if not result.used_fallback
                    else f"fallback selected best candidate: {result.fallback_reason}"
                ),
                model_name=result.model_name,
                backend=result.backend_name or result.model_name,
                fallback_used=result.used_fallback,
                raw_label=result.raw_label,
                error_message=result.error_message,
                metadata=metadata,
                candidates=candidates,
            ),
        )


class MultiStageIntentMatcher:
    """Runs matcher stages in order, then fuses their traces into one decision.

    The fusion decision is one of ``"execute"``, ``"clarify"``,
    ``"route_to_cloud"`` or ``"reject"``.
    """

    def __init__(
        self,
        registry: IntentRegistry,
        matchers: list[IntentMatcher],
        route_to_cloud_threshold: float = 0.75,
        clarify_margin_threshold: float = 0.12,
        classifier_execute_score_threshold: float = 0.55,
        classifier_execute_margin_threshold: float = 0.18,
    ) -> None:
        self._registry = registry
        self._matchers = matchers
        self._route_to_cloud_threshold = route_to_cloud_threshold
        self._clarify_margin_threshold = clarify_margin_threshold
        self._classifier_execute_score_threshold = classifier_execute_score_threshold
        self._classifier_execute_margin_threshold = classifier_execute_margin_threshold

    def match(self, text: str) -> RouteMatchResult:
        """Run every matcher on *text*, append the fusion stage, and route.

        Each stage trace gets its elapsed time in milliseconds. An intent is
        returned only when the fusion decision is ``"execute"`` and the fusion
        stage selected an intent; otherwise the result carries ``intent=None``.
        """
        stage_traces: list[MatcherStageDebug] = []
        match_started_at = perf_counter()

        for matcher in self._matchers:
            stage_started_at = perf_counter()
            result = matcher.match(text)
            result.stage_debug.elapsed_ms = round((perf_counter() - stage_started_at) * 1000, 3)
            stage_traces.append(result.stage_debug)

        fusion_started_at = perf_counter()
        fusion_stage = self._build_fusion_stage(stage_traces)
        fusion_stage.elapsed_ms = round((perf_counter() - fusion_started_at) * 1000, 3)
        stage_traces.append(fusion_stage)

        total_match_latency_ms = round((perf_counter() - match_started_at) * 1000, 3)
        decision = str(fusion_stage.metadata.get("decision", "reject"))
        confidence_grade = str(fusion_stage.metadata.get("grade", "low"))
        unknown_detected = bool(fusion_stage.metadata.get("unknown_detected", False))
        decision_reason = fusion_stage.reason

        if decision == "execute" and fusion_stage.selected_intent is not None:
            intent = self._registry.get(fusion_stage.selected_intent)
            return RouteMatchResult(
                intent=intent,
                debug=RoutingDebug(
                    selected_intent=fusion_stage.selected_intent,
                    matched_stage=fusion_stage.stage,
                    decision=decision,
                    decision_reason=decision_reason,
                    confidence_grade=confidence_grade,
                    total_match_latency_ms=total_match_latency_ms,
                    unknown_detected=unknown_detected,
                    stages=stage_traces,
                ),
            )

        return RouteMatchResult(
            intent=None,
            debug=RoutingDebug(
                selected_intent=None,
                matched_stage=fusion_stage.stage,
                decision=decision,
                decision_reason=decision_reason,
                confidence_grade=confidence_grade,
                total_match_latency_ms=total_match_latency_ms,
                unknown_detected=unknown_detected,
                stages=stage_traces,
            ),
        )

    @staticmethod
    def _rejected_fusion_stage(reason: str) -> MatcherStageDebug:
        """Build the fusion stage used when there is nothing to rank."""
        return MatcherStageDebug(
            stage="fusion",
            accepted=False,
            reason=reason,
            model_name="fusion-router",
            backend="bert-first-fusion",
            metadata={
                "grade": "low",
                "decision": "reject",
                "unknown_detected": True,
                "ranked_intents": [],
            },
            candidates=[],
        )

    def _build_fusion_stage(self, stage_traces: list[MatcherStageDebug]) -> MatcherStageDebug:
        """Derive the routing decision from the ``"classifier"`` stage trace.

        Uses the first two entries of the classifier's candidate list as the
        top and runner-up scores, compares them with the configured
        thresholds, and records the decision and supporting numbers in the
        returned stage's metadata.
        """
        classifier_stage = next(
            (stage for stage in stage_traces if stage.stage == "classifier"),
            None,
        )
        if classifier_stage is None:
            return self._rejected_fusion_stage("classifier stage is missing")

        ranked_candidates = list(classifier_stage.candidates or [])
        if not ranked_candidates and classifier_stage.selected_intent is not None:
            # No candidate list, but the stage still picked an intent:
            # synthesize a single candidate from the stage itself.
            ranked_candidates = [
                IntentCandidate(
                    intent_id=classifier_stage.selected_intent,
                    score=classifier_stage.score,
                    reason="classifier selected intent",
                    model_name=classifier_stage.model_name,
                )
            ]
        if not ranked_candidates:
            return self._rejected_fusion_stage("classifier did not produce a usable candidate")

        # NOTE(review): candidates are taken in the order the classifier stage
        # supplied them; this code does not sort them by score — confirm the
        # classifier returns them best-first.
        selected_candidate = ranked_candidates[0]
        selected_intent = selected_candidate.intent_id
        top_score = float(selected_candidate.score)
        second_score = float(ranked_candidates[1].score) if len(ranked_candidates) > 1 else 0.0
        top_margin = top_score - second_score
        grade = self._fusion_grade(top_score)

        classifier_backend = str(classifier_stage.backend or classifier_stage.model_name or "")
        classifier_signal = top_score
        classifier_margin = top_margin

        bert_classifier_confident = (
            "bert" in classifier_backend
            and classifier_signal >= self._classifier_execute_score_threshold
            and classifier_margin >= self._classifier_execute_margin_threshold
        )
        # NOTE(review): this requires the signal to be >= route_to_cloud_threshold
        # and < classifier_execute_score_threshold. With the default values
        # (0.75 and 0.55) no score satisfies both, so "clarify" is never chosen
        # unless the thresholds are configured the other way round — confirm
        # the intended ordering.
        ambiguous = (
            classifier_signal >= self._route_to_cloud_threshold
            and classifier_signal < self._classifier_execute_score_threshold
            and top_margin < self._clarify_margin_threshold
            and len(ranked_candidates) > 1
        )
        accepted = bert_classifier_confident and not ambiguous
        # NOTE(review): the hard-coded 0.24 floor makes the first comparison
        # redundant whenever route_to_cloud_threshold > 0.24 — confirm intent.
        possible_known_intent = (
            classifier_signal >= self._route_to_cloud_threshold or classifier_signal >= 0.24
        )
        unknown_detected = not accepted and not possible_known_intent

        if accepted:
            decision = "execute"
            reason = f"bert classifier is confident enough to execute (grade={grade})"
        elif ambiguous:
            decision = "clarify"
            reason = "bert top candidates are too close and require a short clarification"
        elif possible_known_intent:
            decision = "route_to_cloud"
            reason = "bert signal is not stable enough, routing to cloud planner"
        else:
            decision = "reject"
            reason = "bert signal is below local capability threshold"

        metadata = {
            "grade": grade,
            "classifier_signal": round(classifier_signal, 4),
            "classifier_margin": round(classifier_margin, 4),
            "classifier_backend": classifier_backend or None,
            "bert_classifier_confident": bert_classifier_confident,
            "top_margin": round(top_margin, 4),
            "route_to_cloud_threshold": self._route_to_cloud_threshold,
            "clarify_margin_threshold": self._clarify_margin_threshold,
            "classifier_execute_score_threshold": self._classifier_execute_score_threshold,
            "classifier_execute_margin_threshold": self._classifier_execute_margin_threshold,
            "decision": decision,
            "unknown_detected": unknown_detected,
            "ranked_intents": [
                {"intent_id": item.intent_id, "score": round(float(item.score), 4)}
                for item in ranked_candidates[:5]
            ],
        }
        return MatcherStageDebug(
            stage="fusion",
            accepted=accepted or ambiguous,
            selected_intent=(
                selected_intent
                if decision in {"execute", "clarify", "route_to_cloud"}
                else None
            ),
            score=top_score,
            reason=reason,
            model_name="fusion-router",
            backend="bert-first-fusion",
            metadata=metadata,
            candidates=ranked_candidates[:5],
        )

    def _fusion_grade(self, score: float) -> str:
        """Map *score* to ``"high"``, ``"medium"`` or ``"low"``."""
        if score >= self._classifier_execute_score_threshold:
            return "high"
        # NOTE(review): with the default thresholds (execute 0.55 < cloud 0.75)
        # this branch cannot be reached, so "medium" is never returned.
        if score >= self._route_to_cloud_threshold:
            return "medium"
        return "low"


class HeuristicSlotExtractor:
    """Slot extractor based on regular expressions and keyword checks."""

    # A single letter followed by at least five digits, on word boundaries.
    _ORDER_ID_PATTERN = re.compile(r"\b[A-Za-z]\d{5,}\b")
    # Exactly two digits, optional whitespace, then the character "度".
    _TEMPERATURE_PATTERN = re.compile(r"(\d{2})\s*度")
    # Tried in order, most specific prefix first. The named group must be
    # called "destination" because _extract_destination reads it by name.
    _DESTINATION_PATTERNS = (
        re.compile(r"导航去(?P<destination>.+)"),
        re.compile(r"导航到(?P<destination>.+)"),
        re.compile(r"去(?P<destination>.+)"),
    )

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slots found in *text*.

        ``order_id`` and ``temperature`` are looked for regardless of intent;
        ``destination`` and ``song``/``genre`` only for the ``cabin_nav_to``
        and ``cabin_play_music`` intents respectively.
        """
        slots: dict[str, Any] = {}

        order_id_match = self._ORDER_ID_PATTERN.search(text)
        if order_id_match:
            slots["order_id"] = order_id_match.group(0)

        temperature_match = self._TEMPERATURE_PATTERN.search(text)
        if temperature_match:
            slots["temperature"] = int(temperature_match.group(1))

        if intent.intent_id == "cabin_nav_to":
            destination = self._extract_destination(text)
            if destination:
                slots["destination"] = destination

        if intent.intent_id == "cabin_play_music":
            if "播放" in text:
                # Everything after the first "播放" is treated as the song.
                music_target = text.split("播放", maxsplit=1)[-1].strip(" ,。")
                if music_target:
                    slots["song"] = music_target
            elif "音乐" in text:
                slots["genre"] = "轻音乐"

        return slots

    def _extract_destination(self, text: str) -> str | None:
        """Return the text following a navigation prefix, or ``None``.

        The first pattern that yields a non-empty destination after stripping
        spaces, commas and "。" wins.
        """
        for pattern in self._DESTINATION_PATTERNS:
            match = pattern.search(text)
            if match:
                destination = match.group("destination").strip(" ,。")
                if destination:
                    return destination
        return None


class JointBertSlotExtractor:
    """Slot extractor that delegates to a ``JointBertNLU`` instance."""

    def __init__(self, nlu: JointBertNLU) -> None:
        self._nlu = nlu

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return whatever ``nlu.extract_slots`` produces for *text* and *intent*."""
        return self._nlu.extract_slots(text, intent)


class Router(Protocol):
    """Structural interface combining routing, intent matching and slot extraction."""

    def route(self, text: str) -> RouteMatchResult: ...

    def match_intent(self, text: str) -> IntentDefinition | None: ...

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]: ...


class IntentRouter:
    """Router that forwards to a matcher and a slot extractor."""

    def __init__(self, matcher: IntentMatcher, slot_extractor: SlotExtractor) -> None:
        self._matcher = matcher
        self._slot_extractor = slot_extractor

    def route(self, text: str) -> RouteMatchResult:
        """Return the matcher's result for *text*."""
        return self._matcher.match(text)

    def match_intent(self, text: str) -> IntentDefinition | None:
        """Return only the intent of ``route(text)``."""
        return self.route(text).intent

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slot extractor's result for *text* and *intent*."""
        return self._slot_extractor.extract(text, intent)


def build_matcher_pipeline(
    registry: IntentRegistry,
    stages: list[str],
    classifier: IntentClassifier | None = None,
    route_to_cloud_threshold: float = 0.75,
    clarify_margin_threshold: float = 0.12,
    classifier_execute_score_threshold: float = 0.55,
    classifier_execute_margin_threshold: float = 0.18,
) -> MultiStageIntentMatcher:
    """Build a ``MultiStageIntentMatcher`` with a single classifier stage.

    Blank stage names are dropped and an empty list defaults to
    ``["classifier"]``. When *classifier* is ``None`` a ``NullIntentMatcher``
    is used in its place.

    Raises:
        ValueError: If the normalized stage list is anything other than
            ``["classifier"]``.
    """
    normalized_stages = [stage.strip() for stage in stages if stage.strip()]
    if not normalized_stages:
        normalized_stages = ["classifier"]
    if normalized_stages != ["classifier"]:
        raise ValueError("Only classifier matcher pipeline is supported in bert-first mode")

    matcher = (
        ClassifierIntentMatcher(registry, classifier)
        if classifier is not None
        else NullIntentMatcher()
    )
    return MultiStageIntentMatcher(
        registry,
        [matcher],
        route_to_cloud_threshold=route_to_cloud_threshold,
        clarify_margin_threshold=clarify_margin_threshold,
        classifier_execute_score_threshold=classifier_execute_score_threshold,
        classifier_execute_margin_threshold=classifier_execute_margin_threshold,
    )


class NullIntentMatcher:
    """Matcher that never selects an intent."""

    def match(self, text: str) -> IntentMatchResult:
        """Ignore *text* and return an unaccepted ``"null"`` stage."""
        _ = text
        return IntentMatchResult(
            intent=None,
            stage_debug=MatcherStageDebug(
                stage="null",
                accepted=False,
                reason="matcher unavailable",
                candidates=[],
            ),
        )