408 lines
16 KiB
Python
408 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from time import perf_counter
|
|
from typing import Any, Protocol
|
|
|
|
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
|
|
from app.schemas.intent import IntentDefinition
|
|
from app.services.classifier import IntentClassifier
|
|
from app.services.intent_registry import IntentRegistry
|
|
from app.services.joint_nlu import JointBertNLU
|
|
|
|
|
|
@dataclass
class IntentMatchResult:
    """Outcome reported by a single matcher stage.

    Attributes:
        intent: The intent the stage selected, or ``None`` when it selected
            nothing.
        stage_debug: Diagnostic record describing how the stage reached its
            outcome.
    """

    # Selected intent definition; None means the stage made no selection.
    intent: IntentDefinition | None
    # Per-stage trace (scores, candidates, reason, ...).
    stage_debug: MatcherStageDebug
|
|
|
|
|
|
@dataclass
class RouteMatchResult:
    """Final routing outcome produced after all matcher stages have run.

    Attributes:
        intent: The intent to act on, or ``None`` when routing did not end in
            an executable intent.
        debug: Routing-level diagnostics, including the per-stage traces.
    """

    # Intent resolved for execution; None for any non-executing decision.
    intent: IntentDefinition | None
    # Aggregated routing trace for the whole match call.
    debug: RoutingDebug
|
|
|
|
|
|
class IntentMatcher(Protocol):
    """Structural interface for a single intent-matching stage."""

    def match(self, text: str) -> IntentMatchResult:
        """Match *text* against known intents and report the stage outcome."""
        ...
|
|
|
|
|
|
class SlotExtractor(Protocol):
    """Structural interface for components that pull slot values out of text."""

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slot values found in *text* for the given *intent*."""
        ...
|
|
|
|
|
|
class ClassifierIntentMatcher:
    """Matcher stage that wraps an ``IntentClassifier`` prediction.

    The classifier is asked to score the text against every intent the
    registry lists; the prediction is then repackaged as an
    ``IntentMatchResult`` whose debug record carries the stage name
    ``"classifier"``.
    """

    def __init__(self, registry: IntentRegistry, classifier: IntentClassifier) -> None:
        self._registry = registry
        self._classifier = classifier

    @staticmethod
    def _lookup_raw_label(raw_candidates, intent_id, score):
        """Return the label of the first raw candidate matching id and score.

        Falls back to *intent_id* itself when no raw candidate matches.
        """
        for item in raw_candidates:
            if item.get("intent_id") == intent_id and float(item.get("score", 0.0)) == score:
                return item.get("label")
        return intent_id

    def match(self, text: str) -> IntentMatchResult:
        """Run the classifier on *text* and describe the outcome.

        Returns:
            An ``IntentMatchResult`` whose ``intent`` is the classifier's
            selected intent (or ``None``) and whose ``stage_debug`` records
            the score, candidates, backend, fallback details and the margin
            between the two highest candidate scores.
        """
        prediction = self._classifier.predict(text, self._registry.list())
        scored_intents = prediction.candidates or []
        raw_candidates = prediction.raw_candidates or []
        backend = prediction.backend_name or prediction.model_name

        # Margin between the two best scores; with a single candidate the
        # margin is that candidate's score, with none it is zero.
        descending_scores = [float(score) for _, score in scored_intents]
        descending_scores.sort(reverse=True)
        if len(descending_scores) >= 2:
            top_margin = descending_scores[0] - descending_scores[1]
        elif descending_scores:
            top_margin = descending_scores[0]
        else:
            top_margin = 0.0

        candidates = [
            IntentCandidate(
                intent_id=intent.intent_id,
                score=score,
                reason="classifier candidate",
                model_name=backend,
                raw_label=self._lookup_raw_label(raw_candidates, intent.intent_id, score),
            )
            for intent, score in scored_intents
        ]

        metadata: dict[str, Any] = {
            "decision_model": prediction.model_name,
            # The threshold is read from the classifier if it exposes one.
            "threshold": getattr(self._classifier, "_threshold", None),
            "raw_candidates": raw_candidates,
            "top_margin": round(top_margin, 4),
        }
        if prediction.fallback_reason:
            metadata["fallback_reason"] = prediction.fallback_reason

        # Fields shared by the accepted and the rejected debug record.
        stage_fields: dict[str, Any] = {
            "stage": "classifier",
            "score": prediction.score,
            "model_name": prediction.model_name,
            "backend": backend,
            "fallback_used": prediction.used_fallback,
            "raw_label": prediction.raw_label,
            "error_message": prediction.error_message,
            "metadata": metadata,
            "candidates": candidates,
        }

        chosen = prediction.intent
        if chosen is None:
            return IntentMatchResult(
                intent=None,
                stage_debug=MatcherStageDebug(
                    accepted=False,
                    reason=prediction.fallback_reason
                    or "classifier below threshold or no intent selected",
                    **stage_fields,
                ),
            )

        if prediction.used_fallback:
            reason = f"fallback selected best candidate: {prediction.fallback_reason}"
        else:
            reason = "bert classifier selected best candidate"
        return IntentMatchResult(
            intent=chosen,
            stage_debug=MatcherStageDebug(
                accepted=True,
                selected_intent=chosen.intent_id,
                reason=reason,
                **stage_fields,
            ),
        )
|
|
|
|
|
|
class MultiStageIntentMatcher:
    """Run matcher stages in order and fuse them into a routing decision.

    Every configured matcher is executed and timed; a synthetic ``"fusion"``
    stage is then derived from the ``"classifier"`` stage trace and decides
    between ``execute``, ``clarify``, ``route_to_cloud`` and ``reject``.

    NOTE(review): with the default thresholds, ``route_to_cloud_threshold``
    (0.75) is greater than ``classifier_execute_score_threshold`` (0.55), so
    no score can satisfy the "clarify" band and the "medium" grade is never
    returned. Confirm the intended ordering of these two thresholds.
    """

    def __init__(
        self,
        registry: IntentRegistry,
        matchers: list[IntentMatcher],
        route_to_cloud_threshold: float = 0.75,
        clarify_margin_threshold: float = 0.12,
        classifier_execute_score_threshold: float = 0.55,
        classifier_execute_margin_threshold: float = 0.18,
    ) -> None:
        self._registry = registry
        self._matchers = matchers
        self._route_to_cloud_threshold = route_to_cloud_threshold
        self._clarify_margin_threshold = clarify_margin_threshold
        self._classifier_execute_score_threshold = classifier_execute_score_threshold
        self._classifier_execute_margin_threshold = classifier_execute_margin_threshold

    def match(self, text: str) -> RouteMatchResult:
        """Run all stages on *text* and return the fused routing result.

        Each stage's ``elapsed_ms`` is filled in here (milliseconds, rounded
        to three decimals). The result carries a resolved intent only when
        the fusion decision is ``"execute"`` and an intent id was selected.
        """
        traces: list[MatcherStageDebug] = []
        started = perf_counter()

        for matcher in self._matchers:
            tick = perf_counter()
            outcome = matcher.match(text)
            elapsed_ms = round((perf_counter() - tick) * 1000, 3)
            outcome.stage_debug.elapsed_ms = elapsed_ms
            traces.append(outcome.stage_debug)

        tick = perf_counter()
        fusion = self._build_fusion_stage(traces)
        fusion.elapsed_ms = round((perf_counter() - tick) * 1000, 3)
        traces.append(fusion)
        total_ms = round((perf_counter() - started) * 1000, 3)

        decision = str(fusion.metadata.get("decision", "reject"))
        grade = str(fusion.metadata.get("grade", "low"))
        unknown = bool(fusion.metadata.get("unknown_detected", False))

        # Only an "execute" decision with a concrete intent id is resolved
        # through the registry; every other outcome reports no intent.
        if decision == "execute" and fusion.selected_intent is not None:
            resolved_intent = self._registry.get(fusion.selected_intent)
            reported_intent_id = fusion.selected_intent
        else:
            resolved_intent = None
            reported_intent_id = None

        return RouteMatchResult(
            intent=resolved_intent,
            debug=RoutingDebug(
                selected_intent=reported_intent_id,
                matched_stage=fusion.stage,
                decision=decision,
                decision_reason=fusion.reason,
                confidence_grade=grade,
                total_match_latency_ms=total_ms,
                unknown_detected=unknown,
                stages=traces,
            ),
        )

    @staticmethod
    def _rejected_fusion_stage(reason: str) -> MatcherStageDebug:
        """Build the fusion record used when there is nothing to decide on."""
        return MatcherStageDebug(
            stage="fusion",
            accepted=False,
            reason=reason,
            model_name="fusion-router",
            backend="bert-first-fusion",
            metadata={
                "grade": "low",
                "decision": "reject",
                "unknown_detected": True,
                "ranked_intents": [],
            },
            candidates=[],
        )

    def _build_fusion_stage(self, stage_traces: list[MatcherStageDebug]) -> MatcherStageDebug:
        """Derive the fusion decision from the classifier stage trace.

        The first candidate of the classifier stage is treated as the top
        candidate and the second as the runner-up; candidates are used in the
        order the classifier stage supplied them.
        """
        classifier_stage = None
        for trace in stage_traces:
            if trace.stage == "classifier":
                classifier_stage = trace
                break
        if classifier_stage is None:
            return self._rejected_fusion_stage("classifier stage is missing")

        ranked = list(classifier_stage.candidates or [])
        if not ranked and classifier_stage.selected_intent is not None:
            # No candidate list, but the stage did select an intent:
            # synthesise a single candidate from the stage itself.
            ranked = [
                IntentCandidate(
                    intent_id=classifier_stage.selected_intent,
                    score=classifier_stage.score,
                    reason="classifier selected intent",
                    model_name=classifier_stage.model_name,
                )
            ]
        if not ranked:
            return self._rejected_fusion_stage("classifier did not produce a usable candidate")

        best = ranked[0]
        top_score = float(best.score)
        runner_up_score = float(ranked[1].score) if len(ranked) > 1 else 0.0
        top_margin = top_score - runner_up_score
        grade = self._fusion_grade(top_score)
        classifier_backend = str(classifier_stage.backend or classifier_stage.model_name or "")

        clears_execute_bar = (
            top_score >= self._classifier_execute_score_threshold
            and top_margin >= self._classifier_execute_margin_threshold
        )
        bert_confident = "bert" in classifier_backend and clears_execute_bar

        # "Clarify" band: at or above the cloud-routing threshold but below
        # the execute threshold, with a close runner-up.
        in_clarify_band = (
            self._route_to_cloud_threshold <= top_score < self._classifier_execute_score_threshold
        )
        ambiguous = (
            in_clarify_band
            and top_margin < self._clarify_margin_threshold
            and len(ranked) > 1
        )

        accepted = bert_confident and not ambiguous
        possible_known_intent = top_score >= self._route_to_cloud_threshold or top_score >= 0.24
        unknown_detected = not (accepted or possible_known_intent)

        if accepted:
            decision = "execute"
            reason = f"bert classifier is confident enough to execute (grade={grade})"
        elif ambiguous:
            decision = "clarify"
            reason = "bert top candidates are too close and require a short clarification"
        elif possible_known_intent:
            decision = "route_to_cloud"
            reason = "bert signal is not stable enough, routing to cloud planner"
        else:
            decision = "reject"
            reason = "bert signal is below local capability threshold"

        top_five = ranked[:5]
        metadata = {
            "grade": grade,
            "classifier_signal": round(top_score, 4),
            "classifier_margin": round(top_margin, 4),
            "classifier_backend": classifier_backend or None,
            "bert_classifier_confident": bert_confident,
            "top_margin": round(top_margin, 4),
            "route_to_cloud_threshold": self._route_to_cloud_threshold,
            "clarify_margin_threshold": self._clarify_margin_threshold,
            "classifier_execute_score_threshold": self._classifier_execute_score_threshold,
            "classifier_execute_margin_threshold": self._classifier_execute_margin_threshold,
            "decision": decision,
            "unknown_detected": unknown_detected,
            "ranked_intents": [
                {"intent_id": item.intent_id, "score": round(float(item.score), 4)}
                for item in top_five
            ],
        }

        # Every decision except "reject" keeps the top candidate's id.
        reported_intent = None if decision == "reject" else best.intent_id
        return MatcherStageDebug(
            stage="fusion",
            accepted=accepted or ambiguous,
            selected_intent=reported_intent,
            score=top_score,
            reason=reason,
            model_name="fusion-router",
            backend="bert-first-fusion",
            metadata=metadata,
            candidates=top_five,
        )

    def _fusion_grade(self, score: float) -> str:
        """Map *score* to ``"high"``, ``"medium"`` or ``"low"``.

        The execute threshold is checked before the cloud-routing threshold.
        """
        graded_floors = (
            (self._classifier_execute_score_threshold, "high"),
            (self._route_to_cloud_threshold, "medium"),
        )
        for floor, label in graded_floors:
            if score >= floor:
                return label
        return "low"
|
|
|
|
|
|
class HeuristicSlotExtractor:
    """Rule-based slot extractor built on regular expressions and keywords.

    Slots produced (each only when found):
        * ``order_id`` - one ASCII letter followed by at least five digits.
        * ``temperature`` - a two-digit number directly before "度".
        * ``destination`` - for intent ``cabin_nav_to`` only.
        * ``song`` / ``genre`` - for intent ``cabin_play_music`` only.
    """

    # ``\b`` is not usable here: in Python's default Unicode mode CJK
    # characters count as word characters, so there is no word boundary
    # between e.g. "单" and "A" and ids embedded in Chinese text would never
    # match. Explicit ASCII lookarounds keep the boundary semantics for
    # ASCII text while allowing adjacent CJK characters.
    _ORDER_ID_PATTERN = re.compile(r"(?<![A-Za-z0-9_])[A-Za-z]\d{5,}(?![A-Za-z0-9_])")

    # The lookbehind stops the pattern from capturing the trailing two digits
    # of a longer number (e.g. "125度" must not yield 25).
    _TEMPERATURE_PATTERN = re.compile(r"(?<!\d)(\d{2})\s*度")

    # Tried in order; the most specific phrasing comes first.
    _DESTINATION_PATTERNS = (
        re.compile(r"导航去(?P<destination>.+)"),
        re.compile(r"导航到(?P<destination>.+)"),
        re.compile(r"去(?P<destination>.+)"),
    )

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Extract slot values from *text* for the given *intent*.

        Args:
            text: The user utterance.
            intent: The matched intent; only its ``intent_id`` is consulted.

        Returns:
            A dict with the slots that were found; empty when none were.
        """
        slots: dict[str, Any] = {}

        order_id_match = self._ORDER_ID_PATTERN.search(text)
        if order_id_match:
            slots["order_id"] = order_id_match.group(0)

        temperature_match = self._TEMPERATURE_PATTERN.search(text)
        if temperature_match:
            slots["temperature"] = int(temperature_match.group(1))

        if intent.intent_id == "cabin_nav_to":
            destination = self._extract_destination(text)
            if destination:
                slots["destination"] = destination

        if intent.intent_id == "cabin_play_music":
            if "播放" in text:
                # Everything after the first "播放" is taken as the song.
                music_target = text.split("播放", maxsplit=1)[-1].strip(" ,。")
                if music_target:
                    slots["song"] = music_target
            elif "音乐" in text:
                # Music was mentioned without a play request: default genre.
                slots["genre"] = "轻音乐"

        return slots

    def _extract_destination(self, text: str) -> str | None:
        """Return the destination phrase found in *text*, or ``None``.

        A pattern whose captured text is empty after stripping surrounding
        punctuation is skipped in favour of the next pattern.
        """
        for pattern in self._DESTINATION_PATTERNS:
            match = pattern.search(text)
            if match:
                destination = match.group("destination").strip(" ,。")
                if destination:
                    return destination
        return None
|
|
|
|
|
|
class JointBertSlotExtractor:
    """Slot extractor that delegates to a ``JointBertNLU`` instance."""

    def __init__(self, nlu: JointBertNLU) -> None:
        self._nlu = nlu

    def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return whatever ``nlu.extract_slots(text, intent)`` produces."""
        slots = self._nlu.extract_slots(text, intent)
        return slots
|
|
|
|
|
|
class Router(Protocol):
    """Structural interface for the routing facade used by callers."""

    def route(self, text: str) -> RouteMatchResult:
        """Route *text* and return the full result including diagnostics."""
        ...

    def match_intent(self, text: str) -> IntentDefinition | None:
        """Return only the intent selected for *text*, if any."""
        ...

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Return the slot values found in *text* for *intent*."""
        ...
|
|
|
|
|
|
class IntentRouter:
    """Facade combining an intent matcher with a slot extractor.

    NOTE(review): ``matcher`` is annotated as ``IntentMatcher`` (whose
    ``match`` is declared to return ``IntentMatchResult``) while ``route`` is
    annotated as returning ``RouteMatchResult``; presumably the matcher passed
    in practice is a ``MultiStageIntentMatcher`` - confirm and align the
    annotations.
    """

    def __init__(self, matcher: IntentMatcher, slot_extractor: SlotExtractor) -> None:
        self._matcher = matcher
        self._slot_extractor = slot_extractor

    def route(self, text: str) -> RouteMatchResult:
        """Return the matcher's result for *text* unchanged."""
        routed = self._matcher.match(text)
        return routed

    def match_intent(self, text: str) -> IntentDefinition | None:
        """Route *text* and return only the ``intent`` of the result."""
        routed = self.route(text)
        return routed.intent

    def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
        """Delegate slot extraction to the configured slot extractor."""
        extractor = self._slot_extractor
        return extractor.extract(text, intent)
|
|
|
|
|
|
def build_matcher_pipeline(
    registry: IntentRegistry,
    stages: list[str],
    classifier: IntentClassifier | None = None,
    route_to_cloud_threshold: float = 0.75,
    clarify_margin_threshold: float = 0.12,
    classifier_execute_score_threshold: float = 0.55,
    classifier_execute_margin_threshold: float = 0.18,
) -> MultiStageIntentMatcher:
    """Assemble the classifier-only matcher pipeline.

    Args:
        registry: Intent registry handed to the matchers.
        stages: Requested stage names; entries are stripped and blank ones
            dropped. An empty request means ``["classifier"]``.
        classifier: Classifier backing the stage. When ``None`` a
            ``NullIntentMatcher`` is used in its place.
        route_to_cloud_threshold: Forwarded to ``MultiStageIntentMatcher``.
        clarify_margin_threshold: Forwarded to ``MultiStageIntentMatcher``.
        classifier_execute_score_threshold: Forwarded to
            ``MultiStageIntentMatcher``.
        classifier_execute_margin_threshold: Forwarded to
            ``MultiStageIntentMatcher``.

    Returns:
        A ``MultiStageIntentMatcher`` wrapping exactly one matcher.

    Raises:
        ValueError: If the normalised stage list is anything other than
            ``["classifier"]``.
    """
    stripped_names = (raw.strip() for raw in stages)
    requested = [name for name in stripped_names if name] or ["classifier"]
    if requested != ["classifier"]:
        raise ValueError("Only classifier matcher pipeline is supported in bert-first mode")

    if classifier is None:
        stage_matcher = NullIntentMatcher()
    else:
        stage_matcher = ClassifierIntentMatcher(registry, classifier)

    thresholds = {
        "route_to_cloud_threshold": route_to_cloud_threshold,
        "clarify_margin_threshold": clarify_margin_threshold,
        "classifier_execute_score_threshold": classifier_execute_score_threshold,
        "classifier_execute_margin_threshold": classifier_execute_margin_threshold,
    }
    return MultiStageIntentMatcher(registry, [stage_matcher], **thresholds)
|
|
|
|
|
|
class NullIntentMatcher:
    """Placeholder matcher used when no classifier is available."""

    def match(self, text: str) -> IntentMatchResult:
        """Ignore *text* and report an unaccepted ``"null"`` stage."""
        del text  # intentionally unused
        unavailable = MatcherStageDebug(
            stage="null",
            accepted=False,
            reason="matcher unavailable",
            candidates=[],
        )
        return IntentMatchResult(intent=None, stage_debug=unavailable)
|