# Source: ai-device/intelligent_cabin/app/services/router.py
# Snapshot: 2026-06-11 16:28:00 +08:00 (408 lines, 16 KiB, Python)

from __future__ import annotations
import re
from dataclasses import dataclass
from time import perf_counter
from typing import Any, Protocol
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
from app.schemas.intent import IntentDefinition
from app.services.classifier import IntentClassifier
from app.services.intent_registry import IntentRegistry
from app.services.joint_nlu import JointBertNLU
@dataclass
class IntentMatchResult:
    """Outcome produced by a single matcher stage.

    Bundles the intent the stage settled on together with the debug trace
    that records how the stage reached that outcome.
    """

    # Intent selected by the stage; None when the stage selected nothing.
    intent: IntentDefinition | None
    # Debug record describing the stage (score, reason, candidates, ...).
    stage_debug: MatcherStageDebug
@dataclass
class RouteMatchResult:
    """Final routing outcome for one utterance.

    Carries the intent to act on (if any) plus the routing debug record
    that summarises the decision and every stage trace.
    """

    # Intent to execute; None whenever routing did not resolve to an intent.
    intent: IntentDefinition | None
    # Aggregated routing diagnostics (decision, grade, latency, stages).
    debug: RoutingDebug
class IntentMatcher(Protocol):
    """Structural interface for anything that can match text to an intent."""

    def match(self, text: str) -> IntentMatchResult:
        """Match ``text`` and return the stage result with its debug trace."""
        ...
class SlotExtractor(Protocol):
    """Structural interface for components that pull slot values from text."""

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return a mapping of slot name to value for ``text`` under ``intent``."""
        ...
class ClassifierIntentMatcher:
    """Matcher stage that delegates intent selection to an ``IntentClassifier``.

    The classifier is asked to predict over every intent the registry lists;
    its prediction is then translated into an ``IntentMatchResult`` whose
    debug record carries the scored candidates and classifier metadata.
    """

    def __init__(self, registry: IntentRegistry, classifier: IntentClassifier) -> None:
        self._registry = registry
        self._classifier = classifier

    @staticmethod
    def _raw_label_for(prediction: Any, intent_id: str, score: Any) -> Any:
        """Return the raw label of the raw candidate matching id and score.

        Falls back to ``intent_id`` when no raw candidate has both the same
        ``intent_id`` and exactly the same score.
        """
        for raw in prediction.raw_candidates or []:
            if raw.get("intent_id") == intent_id and float(raw.get("score", 0.0)) == score:
                return raw.get("label")
        return intent_id

    def match(self, text: str) -> IntentMatchResult:
        """Run the classifier on ``text`` and wrap the prediction.

        Returns an accepted result carrying the predicted intent, or a
        rejected result (``intent=None``) when the classifier chose nothing.
        """
        prediction = self._classifier.predict(text, self._registry.list())

        # Margin between the two best scores; with a single candidate the
        # margin is that candidate's score, with none it is 0.0.
        descending_scores = sorted(
            (float(value) for _, value in (prediction.candidates or [])),
            reverse=True,
        )
        if len(descending_scores) >= 2:
            margin = descending_scores[0] - descending_scores[1]
        elif descending_scores:
            margin = descending_scores[0]
        else:
            margin = 0.0

        trace_candidates = []
        for definition, value in prediction.candidates or []:
            trace_candidates.append(
                IntentCandidate(
                    intent_id=definition.intent_id,
                    score=value,
                    reason="classifier candidate",
                    model_name=prediction.backend_name or prediction.model_name,
                    raw_label=self._raw_label_for(prediction, definition.intent_id, value),
                )
            )

        details: dict[str, Any] = {
            "decision_model": prediction.model_name,
            "threshold": getattr(self._classifier, "_threshold", None),
            "raw_candidates": prediction.raw_candidates or [],
            "top_margin": round(margin, 4),
        }
        if prediction.fallback_reason:
            details["fallback_reason"] = prediction.fallback_reason

        # Fields that are identical for the accepted and rejected records.
        shared_fields: dict[str, Any] = {
            "stage": "classifier",
            "score": prediction.score,
            "model_name": prediction.model_name,
            "backend": prediction.backend_name or prediction.model_name,
            "fallback_used": prediction.used_fallback,
            "raw_label": prediction.raw_label,
            "error_message": prediction.error_message,
            "metadata": details,
            "candidates": trace_candidates,
        }

        chosen = prediction.intent
        if chosen is None:
            rejected = MatcherStageDebug(
                accepted=False,
                reason=prediction.fallback_reason or "classifier below threshold or no intent selected",
                **shared_fields,
            )
            return IntentMatchResult(intent=None, stage_debug=rejected)

        if prediction.used_fallback:
            why = f"fallback selected best candidate: {prediction.fallback_reason}"
        else:
            why = "bert classifier selected best candidate"
        accepted = MatcherStageDebug(
            accepted=True,
            selected_intent=chosen.intent_id,
            reason=why,
            **shared_fields,
        )
        return IntentMatchResult(intent=chosen, stage_debug=accepted)
class MultiStageIntentMatcher:
    """Run every matcher stage, then fuse the traces into a routing decision.

    The fusion step looks only at the stage named ``"classifier"`` and maps
    its top score and top-two margin onto one of four decisions:
    ``execute``, ``clarify``, ``route_to_cloud`` or ``reject``.

    NOTE(review): with the default thresholds (``route_to_cloud_threshold``
    0.75 > ``classifier_execute_score_threshold`` 0.55) the ``clarify``
    condition (score >= 0.75 and score < 0.55) can never hold and
    ``_fusion_grade`` can never return ``"medium"``. Both only become
    reachable when the route-to-cloud threshold is configured below the
    execute threshold - confirm the intended defaults.
    """

    # Scores at or above this floor still count as "possibly a known intent"
    # (and are routed to the cloud) even when below route_to_cloud_threshold.
    _POSSIBLE_KNOWN_INTENT_FLOOR = 0.24

    def __init__(
        self,
        registry: IntentRegistry,
        matchers: list[IntentMatcher],
        route_to_cloud_threshold: float = 0.75,
        clarify_margin_threshold: float = 0.12,
        classifier_execute_score_threshold: float = 0.55,
        classifier_execute_margin_threshold: float = 0.18,
    ) -> None:
        """Store the registry, the ordered matcher stages and the thresholds."""
        self._registry = registry
        self._matchers = matchers
        self._route_to_cloud_threshold = route_to_cloud_threshold
        self._clarify_margin_threshold = clarify_margin_threshold
        self._classifier_execute_score_threshold = classifier_execute_score_threshold
        self._classifier_execute_margin_threshold = classifier_execute_margin_threshold

    def match(self, text: str) -> RouteMatchResult:
        """Run all stages on ``text`` and return the fused routing result.

        Every stage is executed and timed (no early exit); the fusion stage
        is appended as the last trace. ``intent`` and ``selected_intent`` are
        populated only when the fused decision is ``"execute"``.
        """
        stage_traces: list[MatcherStageDebug] = []
        match_started_at = perf_counter()
        for matcher in self._matchers:
            stage_started_at = perf_counter()
            result = matcher.match(text)
            result.stage_debug.elapsed_ms = round((perf_counter() - stage_started_at) * 1000, 3)
            stage_traces.append(result.stage_debug)

        fusion_started_at = perf_counter()
        fusion_stage = self._build_fusion_stage(stage_traces)
        fusion_stage.elapsed_ms = round((perf_counter() - fusion_started_at) * 1000, 3)
        stage_traces.append(fusion_stage)
        total_match_latency_ms = round((perf_counter() - match_started_at) * 1000, 3)

        decision = str(fusion_stage.metadata.get("decision", "reject"))
        confidence_grade = str(fusion_stage.metadata.get("grade", "low"))
        unknown_detected = bool(fusion_stage.metadata.get("unknown_detected", False))

        # Only an "execute" decision exposes the intent to the caller; every
        # other decision reports no intent and no selected intent id.
        should_execute = decision == "execute" and fusion_stage.selected_intent is not None
        intent = self._registry.get(fusion_stage.selected_intent) if should_execute else None
        return RouteMatchResult(
            intent=intent,
            debug=RoutingDebug(
                selected_intent=fusion_stage.selected_intent if should_execute else None,
                matched_stage=fusion_stage.stage,
                decision=decision,
                decision_reason=fusion_stage.reason,
                confidence_grade=confidence_grade,
                total_match_latency_ms=total_match_latency_ms,
                unknown_detected=unknown_detected,
                stages=stage_traces,
            ),
        )

    @staticmethod
    def _candidate_score(candidate: Any) -> float:
        """Return the candidate's score as a float, treating ``None`` as 0.0."""
        score = candidate.score
        return 0.0 if score is None else float(score)

    @staticmethod
    def _rank_candidates(candidates: list[IntentCandidate]) -> list[IntentCandidate]:
        """Return ``candidates`` ordered by score, highest first.

        The sort is stable, so input that is already ranked (including ties)
        keeps its original order.
        """
        return sorted(candidates, key=MultiStageIntentMatcher._candidate_score, reverse=True)

    @staticmethod
    def _reject_stage(reason: str) -> MatcherStageDebug:
        """Build the fusion record used when there is nothing to decide on."""
        return MatcherStageDebug(
            stage="fusion",
            accepted=False,
            reason=reason,
            model_name="fusion-router",
            backend="bert-first-fusion",
            metadata={
                "grade": "low",
                "decision": "reject",
                "unknown_detected": True,
                "ranked_intents": [],
            },
            candidates=[],
        )

    def _build_fusion_stage(self, stage_traces: list[MatcherStageDebug]) -> MatcherStageDebug:
        """Derive the fusion stage (decision + diagnostics) from stage traces."""
        classifier_stage = next((stage for stage in stage_traces if stage.stage == "classifier"), None)
        if classifier_stage is None:
            return self._reject_stage("classifier stage is missing")

        # The classifier stage does not promise ranked candidates, so rank
        # them here before treating index 0 as the best candidate.
        ranked_candidates = self._rank_candidates(list(classifier_stage.candidates or []))
        if not ranked_candidates and classifier_stage.selected_intent is not None:
            ranked_candidates = [
                IntentCandidate(
                    intent_id=classifier_stage.selected_intent,
                    # A missing stage score is treated as 0.0 rather than
                    # crashing the float() conversions below.
                    score=classifier_stage.score if classifier_stage.score is not None else 0.0,
                    reason="classifier selected intent",
                    model_name=classifier_stage.model_name,
                )
            ]
        if not ranked_candidates:
            return self._reject_stage("classifier did not produce a usable candidate")

        selected_candidate = ranked_candidates[0]
        selected_intent = selected_candidate.intent_id
        top_score = self._candidate_score(selected_candidate)
        # With a single candidate the runner-up score is taken as 0.0.
        second_score = self._candidate_score(ranked_candidates[1]) if len(ranked_candidates) > 1 else 0.0
        top_margin = top_score - second_score
        grade = self._fusion_grade(top_score)

        classifier_backend = str(classifier_stage.backend or classifier_stage.model_name or "")
        classifier_signal = top_score
        classifier_margin = top_margin

        # Local execution requires a bert backend plus both a high enough
        # score and a wide enough lead over the runner-up.
        bert_classifier_confident = (
            "bert" in classifier_backend
            and classifier_signal >= self._classifier_execute_score_threshold
            and classifier_margin >= self._classifier_execute_margin_threshold
        )
        ambiguous = (
            classifier_signal >= self._route_to_cloud_threshold
            and classifier_signal < self._classifier_execute_score_threshold
            and top_margin < self._clarify_margin_threshold
            and len(ranked_candidates) > 1
        )
        accepted = bert_classifier_confident and not ambiguous
        possible_known_intent = (
            classifier_signal >= self._route_to_cloud_threshold
            or classifier_signal >= self._POSSIBLE_KNOWN_INTENT_FLOOR
        )
        unknown_detected = not accepted and not possible_known_intent

        if accepted:
            decision = "execute"
            reason = f"bert classifier is confident enough to execute (grade={grade})"
        elif ambiguous:
            decision = "clarify"
            reason = "bert top candidates are too close and require a short clarification"
        elif possible_known_intent:
            decision = "route_to_cloud"
            reason = "bert signal is not stable enough, routing to cloud planner"
        else:
            decision = "reject"
            reason = "bert signal is below local capability threshold"

        metadata: dict[str, Any] = {
            "grade": grade,
            "classifier_signal": round(classifier_signal, 4),
            "classifier_margin": round(classifier_margin, 4),
            "classifier_backend": classifier_backend or None,
            "bert_classifier_confident": bert_classifier_confident,
            "top_margin": round(top_margin, 4),
            "route_to_cloud_threshold": self._route_to_cloud_threshold,
            "clarify_margin_threshold": self._clarify_margin_threshold,
            "classifier_execute_score_threshold": self._classifier_execute_score_threshold,
            "classifier_execute_margin_threshold": self._classifier_execute_margin_threshold,
            "decision": decision,
            "unknown_detected": unknown_detected,
            "ranked_intents": [
                {"intent_id": item.intent_id, "score": round(self._candidate_score(item), 4)}
                for item in ranked_candidates[:5]
            ],
        }
        return MatcherStageDebug(
            stage="fusion",
            accepted=accepted or ambiguous,
            selected_intent=selected_intent if decision in {"execute", "clarify", "route_to_cloud"} else None,
            score=top_score,
            reason=reason,
            model_name="fusion-router",
            backend="bert-first-fusion",
            metadata=metadata,
            candidates=ranked_candidates[:5],
        )

    def _fusion_grade(self, score: float) -> str:
        """Bucket ``score`` into ``"high"``, ``"medium"`` or ``"low"``.

        The execute threshold is checked first, so ``"medium"`` is only
        reachable when the route-to-cloud threshold is below it.
        """
        if score >= self._classifier_execute_score_threshold:
            return "high"
        if score >= self._route_to_cloud_threshold:
            return "medium"
        return "low"
class HeuristicSlotExtractor:
    """Regex/keyword based slot extractor for Chinese cabin commands.

    ``order_id`` and ``temperature`` are looked for in every utterance;
    ``destination`` and ``song``/``genre`` are only extracted for the
    ``cabin_nav_to`` and ``cabin_play_music`` intents respectively.
    """

    # re.ASCII makes ``\b`` an ASCII word boundary. In the default Unicode
    # mode CJK characters count as word characters, so an id glued to Chinese
    # text (e.g. "订单A12345的") has no boundary around it and never matches.
    _ORDER_ID_PATTERN = re.compile(r"\b[A-Za-z]\d{5,}\b", re.ASCII)
    # The lookbehind stops the two digits from being the tail of a longer
    # number ("100度" must not be read as 0 degrees).
    _TEMPERATURE_PATTERN = re.compile(r"(?<!\d)(\d{2})\s*度")
    # Tried in order: the explicit "navigate" phrasings win over a bare "去".
    _DESTINATION_PATTERNS = (
        re.compile(r"导航去(?P<destination>.+)"),
        re.compile(r"导航到(?P<destination>.+)"),
        re.compile(r"去(?P<destination>.+)"),
    )
    # Characters trimmed from both ends of extracted free-text values.
    _TRIM_CHARS = " ,。"

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slots found in ``text`` for ``intent``.

        Args:
            text: The raw user utterance.
            intent: The matched intent; only its ``intent_id`` is read.

        Returns:
            A dict that may contain ``order_id`` (str), ``temperature`` (int),
            ``destination`` (str), ``song`` (str) or ``genre`` (str). Empty
            when nothing was recognised.
        """
        slots: dict[str, Any] = {}

        order_id_match = self._ORDER_ID_PATTERN.search(text)
        if order_id_match:
            slots["order_id"] = order_id_match.group(0)

        temperature_match = self._TEMPERATURE_PATTERN.search(text)
        if temperature_match:
            slots["temperature"] = int(temperature_match.group(1))

        if intent.intent_id == "cabin_nav_to":
            destination = self._extract_destination(text)
            if destination:
                slots["destination"] = destination

        if intent.intent_id == "cabin_play_music":
            if "播放" in text:
                # Everything after the first "播放" is taken as the song.
                music_target = text.split("播放", maxsplit=1)[-1].strip(self._TRIM_CHARS)
                if music_target:
                    slots["song"] = music_target
            elif "音乐" in text:
                slots["genre"] = "轻音乐"

        return slots

    def _extract_destination(self, text: str) -> str | None:
        """Return the text following a navigation phrase, or ``None``.

        A pattern whose capture is empty after trimming is skipped so the
        next pattern still gets a chance.
        """
        for pattern in self._DESTINATION_PATTERNS:
            match = pattern.search(text)
            if match:
                destination = match.group("destination").strip(self._TRIM_CHARS)
                if destination:
                    return destination
        return None
class JointBertSlotExtractor:
    """Slot extractor that forwards every request to a ``JointBertNLU``."""

    def __init__(self, nlu: JointBertNLU) -> None:
        """Keep a reference to the NLU used for slot extraction."""
        self._nlu = nlu

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return whatever ``extract_slots`` on the NLU yields for the inputs."""
        extracted = self._nlu.extract_slots(text, intent)
        return extracted
class Router(Protocol):
    """Structural interface of a router: routing, intent lookup and slots."""

    def route(self, text: str) -> RouteMatchResult:
        """Route ``text`` and return the intent together with routing debug."""
        ...

    def match_intent(self, text: str) -> IntentDefinition | None:
        """Return only the matched intent for ``text``, or ``None``."""
        ...

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slot values found in ``text`` for ``intent``."""
        ...
class IntentRouter:
    """Facade combining an intent matcher with a slot extractor.

    NOTE(review): ``matcher`` is annotated as ``IntentMatcher`` (whose
    ``match`` is declared to return ``IntentMatchResult``), yet ``route``
    hands its result back as a ``RouteMatchResult`` - confirm which matcher
    type callers are expected to pass.
    """

    def __init__(self, matcher: IntentMatcher, slot_extractor: SlotExtractor) -> None:
        """Remember the collaborators used for matching and slot extraction."""
        self._matcher = matcher
        self._slot_extractor = slot_extractor

    def route(self, text: str) -> RouteMatchResult:
        """Return the matcher's result for ``text`` unchanged."""
        outcome = self._matcher.match(text)
        return outcome

    def match_intent(self, text: str) -> IntentDefinition | None:
        """Return just the ``intent`` field of ``route(text)``."""
        routed = self.route(text)
        return routed.intent

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Delegate slot extraction to the configured slot extractor."""
        found = self._slot_extractor.extract(text, intent)
        return found
def build_matcher_pipeline(
    registry: IntentRegistry,
    stages: list[str],
    classifier: IntentClassifier | None = None,
    route_to_cloud_threshold: float = 0.75,
    clarify_margin_threshold: float = 0.12,
    classifier_execute_score_threshold: float = 0.55,
    classifier_execute_margin_threshold: float = 0.18,
) -> MultiStageIntentMatcher:
    """Assemble the multi-stage matcher for the requested stage list.

    Stage names are stripped and blank entries dropped; an empty list means
    ``["classifier"]``. Any other configuration is refused.

    Raises:
        ValueError: If the normalized stages are anything but ``["classifier"]``.

    When no classifier is supplied, a ``NullIntentMatcher`` takes its place.
    """
    stage_names = [name for name in (raw.strip() for raw in stages) if name] or ["classifier"]
    if stage_names != ["classifier"]:
        raise ValueError("Only classifier matcher pipeline is supported in bert-first mode")

    if classifier is None:
        stage_matcher = NullIntentMatcher()
    else:
        stage_matcher = ClassifierIntentMatcher(registry, classifier)

    thresholds = {
        "route_to_cloud_threshold": route_to_cloud_threshold,
        "clarify_margin_threshold": clarify_margin_threshold,
        "classifier_execute_score_threshold": classifier_execute_score_threshold,
        "classifier_execute_margin_threshold": classifier_execute_margin_threshold,
    }
    return MultiStageIntentMatcher(registry, [stage_matcher], **thresholds)
class NullIntentMatcher:
    """Stand-in matcher that never selects an intent."""

    def match(self, text: str) -> IntentMatchResult:
        """Ignore ``text`` and report that no matcher is available."""
        del text  # the input plays no role in the result
        unavailable = MatcherStageDebug(
            stage="null",
            accepted=False,
            reason="matcher unavailable",
            candidates=[],
        )
        return IntentMatchResult(intent=None, stage_debug=unavailable)