Update project and configurations
This commit is contained in:
1
intelligent_cabin/app/services/__init__.py
Normal file
1
intelligent_cabin/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Application services for orchestration and session management."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1305
intelligent_cabin/app/services/agent_service.py
Normal file
1305
intelligent_cabin/app/services/agent_service.py
Normal file
File diff suppressed because it is too large
Load Diff
600
intelligent_cabin/app/services/classifier.py
Normal file
600
intelligent_cabin/app/services/classifier.py
Normal file
@@ -0,0 +1,600 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Any, Protocol
|
||||
from urllib import error, request
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassificationResult:
|
||||
intent: IntentDefinition | None
|
||||
score: float = 0.0
|
||||
model_name: str = "mock-classifier"
|
||||
candidates: list[tuple[IntentDefinition, float]] | None = None
|
||||
backend_name: str | None = None
|
||||
used_fallback: bool = False
|
||||
fallback_reason: str | None = None
|
||||
error_message: str | None = None
|
||||
raw_label: str | None = None
|
||||
raw_candidates: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class IntentClassifier(Protocol):
|
||||
def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult:
|
||||
...
|
||||
|
||||
|
||||
class MockIntentClassifier:
|
||||
"""A local classifier stub that mimics a BERT-style scoring interface."""
|
||||
|
||||
def __init__(self, threshold: float = 1.2, top_k: int = 3) -> None:
|
||||
self._threshold = threshold
|
||||
self._top_k = top_k
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult:
|
||||
query_tokens = self._tokenize(text)
|
||||
if not query_tokens:
|
||||
return ClassificationResult(
|
||||
intent=None,
|
||||
score=0.0,
|
||||
backend_name="mock-classifier",
|
||||
raw_candidates=[],
|
||||
)
|
||||
|
||||
scored_intents: list[tuple[IntentDefinition, float]] = []
|
||||
for intent in intents:
|
||||
score = self._score_intent(query_tokens, intent)
|
||||
scored_intents.append((intent, score))
|
||||
|
||||
scored_intents.sort(key=lambda item: item[1], reverse=True)
|
||||
best_intent: IntentDefinition | None = scored_intents[0][0] if scored_intents else None
|
||||
best_score = scored_intents[0][1] if scored_intents else 0.0
|
||||
top_candidates = [(intent, score) for intent, score in scored_intents[: self._top_k] if score > 0]
|
||||
raw_candidates = [
|
||||
{
|
||||
"label": intent.intent_id,
|
||||
"intent_id": intent.intent_id,
|
||||
"score": score,
|
||||
}
|
||||
for intent, score in top_candidates
|
||||
]
|
||||
|
||||
if best_score < self._threshold:
|
||||
return ClassificationResult(
|
||||
intent=None,
|
||||
score=best_score,
|
||||
candidates=top_candidates,
|
||||
backend_name="mock-classifier",
|
||||
fallback_reason="below threshold",
|
||||
raw_candidates=raw_candidates,
|
||||
)
|
||||
return ClassificationResult(
|
||||
intent=best_intent,
|
||||
score=best_score,
|
||||
candidates=top_candidates,
|
||||
backend_name="mock-classifier",
|
||||
raw_label=best_intent.intent_id if best_intent is not None else None,
|
||||
raw_candidates=raw_candidates,
|
||||
)
|
||||
|
||||
def _score_intent(self, query_tokens: set[str], intent: IntentDefinition) -> float:
|
||||
score = 0.0
|
||||
for keyword in intent.keywords:
|
||||
keyword_tokens = self._tokenize(keyword)
|
||||
overlap = len(query_tokens & keyword_tokens)
|
||||
if overlap:
|
||||
score += overlap * 1.4
|
||||
|
||||
for example in intent.examples:
|
||||
example_tokens = self._tokenize(example)
|
||||
overlap = len(query_tokens & example_tokens)
|
||||
if not example_tokens:
|
||||
continue
|
||||
coverage = overlap / len(example_tokens)
|
||||
score = max(score, overlap + coverage)
|
||||
|
||||
if intent.domain == "customer_service" and any(token in query_tokens for token in {"订单", "物流", "快递"}):
|
||||
score += 0.2
|
||||
if intent.domain == "cabin" and any(token in query_tokens for token in {"导航", "空调", "音乐", "歌曲"}):
|
||||
score += 0.2
|
||||
return score
|
||||
|
||||
def _tokenize(self, text: str) -> set[str]:
|
||||
cleaned = re.sub(r"[,。!?、\s]+", " ", text.strip().lower())
|
||||
tokens = {token for token in cleaned.split(" ") if token}
|
||||
compact = cleaned.replace(" ", "")
|
||||
for size in (2, 3, 4):
|
||||
for index in range(0, max(len(compact) - size + 1, 0)):
|
||||
tokens.add(compact[index : index + size])
|
||||
return tokens
|
||||
|
||||
|
||||
class BertIntentClassifier:
|
||||
"""
|
||||
A pluggable local classifier interface for future BERT/Transformer models.
|
||||
|
||||
Expected model behavior:
|
||||
- input: user text
|
||||
- output labels: intent_id strings, or numeric indices mapped through a label file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
threshold: float = 0.5,
|
||||
label_map_path: str | None = None,
|
||||
fallback: IntentClassifier | None = None,
|
||||
top_k: int = 3,
|
||||
) -> None:
|
||||
self._model_path = model_path
|
||||
self._threshold = threshold
|
||||
self._label_map_path = label_map_path
|
||||
self._fallback = fallback
|
||||
self._top_k = top_k
|
||||
self._pipeline = None
|
||||
self._label_map = self._load_label_map(label_map_path)
|
||||
self._warmed_up = False
|
||||
self._warmup_elapsed_ms: float | None = None
|
||||
self._warmup_error_message: str | None = None
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult:
|
||||
pipeline = self._get_pipeline()
|
||||
if pipeline is None:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend="bert-local",
|
||||
fallback_reason="bert model is unavailable",
|
||||
error_message=self._pipeline_error_message(),
|
||||
)
|
||||
|
||||
try:
|
||||
raw_output = pipeline(text, truncation=True, top_k=self._top_k)
|
||||
except Exception as exc:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend="bert-local",
|
||||
fallback_reason="bert inference failed",
|
||||
error_message=str(exc),
|
||||
)
|
||||
|
||||
normalized = self._normalize_pipeline_candidates(raw_output)
|
||||
if not normalized:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend="bert-local",
|
||||
fallback_reason="bert returned empty result",
|
||||
raw_candidates=[],
|
||||
)
|
||||
|
||||
resolved = self._resolve_candidates(normalized, intents)
|
||||
top_candidate = resolved["top_candidate"]
|
||||
if top_candidate["intent"] is None:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend="bert-local",
|
||||
fallback_reason="bert label is not mapped to a known intent",
|
||||
score=top_candidate["score"],
|
||||
raw_label=top_candidate["label"],
|
||||
raw_candidates=normalized,
|
||||
)
|
||||
if top_candidate["score"] < self._threshold:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend="bert-local",
|
||||
fallback_reason="bert score is below threshold",
|
||||
score=top_candidate["score"],
|
||||
raw_label=top_candidate["label"],
|
||||
raw_candidates=normalized,
|
||||
)
|
||||
|
||||
return ClassificationResult(
|
||||
intent=top_candidate["intent"],
|
||||
score=top_candidate["score"],
|
||||
model_name="bert-local",
|
||||
candidates=resolved["known_candidates"],
|
||||
backend_name="bert-local",
|
||||
raw_label=top_candidate["label"],
|
||||
raw_candidates=normalized,
|
||||
)
|
||||
|
||||
def _get_pipeline(self):
|
||||
if self._pipeline is not None:
|
||||
return self._pipeline
|
||||
if not self._model_path or not Path(self._model_path).exists():
|
||||
return None
|
||||
try:
|
||||
transformers = importlib.import_module("transformers")
|
||||
except ImportError:
|
||||
return None
|
||||
self._pipeline = transformers.pipeline(
|
||||
"text-classification",
|
||||
model=self._model_path,
|
||||
tokenizer=self._model_path,
|
||||
)
|
||||
return self._pipeline
|
||||
|
||||
def warmup(self, sample_text: str = "打开车窗") -> bool:
|
||||
if self._warmed_up:
|
||||
return True
|
||||
started_at = perf_counter()
|
||||
pipeline = self._get_pipeline()
|
||||
if pipeline is None:
|
||||
self._warmup_error_message = self._pipeline_error_message()
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
return False
|
||||
try:
|
||||
warmup_texts = [sample_text.strip() or "打开车窗", f"请帮我{sample_text.strip() or '打开车窗'}"]
|
||||
for text in dict.fromkeys(warmup_texts):
|
||||
pipeline(text, truncation=True, top_k=self._top_k)
|
||||
except Exception as exc:
|
||||
self._warmup_error_message = str(exc)
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
return False
|
||||
self._warmup_error_message = None
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
self._warmed_up = True
|
||||
return True
|
||||
|
||||
def _pipeline_error_message(self) -> str:
|
||||
if not self._model_path:
|
||||
return "AGENT_CLASSIFIER_MODEL_PATH is empty"
|
||||
if not Path(self._model_path).exists():
|
||||
return f"model path not found: {self._model_path}"
|
||||
try:
|
||||
importlib.import_module("transformers")
|
||||
except ImportError:
|
||||
return "transformers is not installed"
|
||||
return "pipeline init failed"
|
||||
|
||||
def _load_label_map(self, label_map_path: str | None) -> dict[str, str]:
|
||||
if not label_map_path:
|
||||
return {}
|
||||
path = Path(label_map_path)
|
||||
if not path.exists():
|
||||
return {}
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
return {str(key): str(value) for key, value in data.items()}
|
||||
|
||||
def _resolve_label(self, label: str) -> str | None:
|
||||
if label in self._label_map:
|
||||
return self._label_map[label]
|
||||
return label or None
|
||||
|
||||
def _normalize_pipeline_candidates(self, raw_output: Any) -> list[dict[str, Any]]:
|
||||
items = raw_output
|
||||
if isinstance(items, list) and items and isinstance(items[0], list):
|
||||
items = items[0]
|
||||
if not isinstance(items, list):
|
||||
items = [items]
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
label = str(item.get("label", ""))
|
||||
score = float(item.get("score", 0.0))
|
||||
normalized.append(
|
||||
{
|
||||
"label": label,
|
||||
"intent_id": self._resolve_label(label),
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
return normalized[: self._top_k]
|
||||
|
||||
def _resolve_candidates(
|
||||
self,
|
||||
normalized_candidates: list[dict[str, Any]],
|
||||
intents: list[IntentDefinition],
|
||||
) -> dict[str, Any]:
|
||||
intent_map = {intent.intent_id: intent for intent in intents}
|
||||
known_candidates: list[tuple[IntentDefinition, float]] = []
|
||||
resolved_items: list[dict[str, Any]] = []
|
||||
for item in normalized_candidates:
|
||||
intent = intent_map.get(str(item.get("intent_id") or ""))
|
||||
if intent is not None:
|
||||
known_candidates.append((intent, float(item.get("score", 0.0))))
|
||||
resolved_items.append(
|
||||
{
|
||||
"intent": intent,
|
||||
"label": str(item.get("label", "")),
|
||||
"score": float(item.get("score", 0.0)),
|
||||
}
|
||||
)
|
||||
return {
|
||||
"known_candidates": known_candidates,
|
||||
"top_candidate": resolved_items[0],
|
||||
}
|
||||
|
||||
def _predict_with_fallback(
|
||||
self,
|
||||
text: str,
|
||||
intents: list[IntentDefinition],
|
||||
attempted_backend: str,
|
||||
fallback_reason: str,
|
||||
score: float = 0.0,
|
||||
raw_label: str | None = None,
|
||||
raw_candidates: list[dict[str, Any]] | None = None,
|
||||
error_message: str | None = None,
|
||||
) -> ClassificationResult:
|
||||
if self._fallback is None:
|
||||
return ClassificationResult(
|
||||
intent=None,
|
||||
score=score,
|
||||
model_name=attempted_backend,
|
||||
backend_name=attempted_backend,
|
||||
used_fallback=False,
|
||||
fallback_reason=fallback_reason,
|
||||
error_message=error_message,
|
||||
raw_label=raw_label,
|
||||
raw_candidates=raw_candidates or [],
|
||||
)
|
||||
fallback_result = self._fallback.predict(text, intents)
|
||||
return ClassificationResult(
|
||||
intent=fallback_result.intent,
|
||||
score=fallback_result.score,
|
||||
model_name=fallback_result.model_name,
|
||||
candidates=fallback_result.candidates,
|
||||
backend_name=attempted_backend,
|
||||
used_fallback=True,
|
||||
fallback_reason=fallback_reason,
|
||||
error_message=error_message,
|
||||
raw_label=raw_label,
|
||||
raw_candidates=raw_candidates or fallback_result.raw_candidates or [],
|
||||
)
|
||||
|
||||
|
||||
class JointBertIntentClassifier:
|
||||
def __init__(
|
||||
self,
|
||||
nlu: JointBertNLU,
|
||||
threshold: float = 0.35,
|
||||
top_k: int = 3,
|
||||
) -> None:
|
||||
self._nlu = nlu
|
||||
self._threshold = threshold
|
||||
self._top_k = top_k
|
||||
|
||||
def warmup(self, sample_text: str = "打开车窗") -> bool:
|
||||
return self._nlu.warmup(sample_text)
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult:
|
||||
result = self._nlu.predict(text, intents)
|
||||
raw_candidates = [
|
||||
{
|
||||
"label": item.intent_id,
|
||||
"intent_id": item.intent_id,
|
||||
"score": item.score,
|
||||
}
|
||||
for item in result.candidates[: self._top_k]
|
||||
]
|
||||
known_candidates = [
|
||||
(intent, item.score)
|
||||
for item in result.candidates[: self._top_k]
|
||||
for intent in intents
|
||||
if intent.intent_id == item.intent_id
|
||||
]
|
||||
selected_intent = next((intent for intent in intents if intent.intent_id == result.intent_id), None)
|
||||
if selected_intent is None or result.intent_score < self._threshold:
|
||||
return ClassificationResult(
|
||||
intent=None,
|
||||
score=result.intent_score,
|
||||
model_name=result.model_name,
|
||||
candidates=known_candidates,
|
||||
backend_name=result.backend_name,
|
||||
fallback_reason=result.error_message or "joint bert score is below threshold or no intent selected",
|
||||
error_message=result.error_message,
|
||||
raw_label=result.intent_id,
|
||||
raw_candidates=raw_candidates,
|
||||
)
|
||||
return ClassificationResult(
|
||||
intent=selected_intent,
|
||||
score=result.intent_score,
|
||||
model_name=result.model_name,
|
||||
candidates=known_candidates,
|
||||
backend_name=result.backend_name,
|
||||
raw_label=result.intent_id,
|
||||
raw_candidates=raw_candidates,
|
||||
)
|
||||
|
||||
|
||||
class RemoteIntentClassifier:
|
||||
"""
|
||||
A remote classifier client.
|
||||
|
||||
Expected response payload:
|
||||
{
|
||||
"intent_id": "cs_query_order",
|
||||
"score": 0.98,
|
||||
"model_name": "bert-remote"
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
timeout_seconds: float = 3.0,
|
||||
threshold: float = 0.5,
|
||||
fallback: IntentClassifier | None = None,
|
||||
label_map_path: str | None = None,
|
||||
top_k: int = 3,
|
||||
) -> None:
|
||||
self._endpoint = endpoint
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._threshold = threshold
|
||||
self._fallback = fallback
|
||||
self._label_map = self._load_label_map(label_map_path)
|
||||
self._top_k = top_k
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]) -> ClassificationResult:
|
||||
if not self._endpoint:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend="remote-classifier",
|
||||
fallback_reason="remote endpoint is not configured",
|
||||
error_message="AGENT_CLASSIFIER_REMOTE_URL is empty",
|
||||
)
|
||||
|
||||
payload = json.dumps(
|
||||
{
|
||||
"text": text,
|
||||
"top_k": self._top_k,
|
||||
"labels": [intent.intent_id for intent in intents],
|
||||
}
|
||||
).encode("utf-8")
|
||||
req = request.Request(
|
||||
self._endpoint,
|
||||
data=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with request.urlopen(req, timeout=self._timeout_seconds) as response:
|
||||
data = json.loads(response.read().decode("utf-8"))
|
||||
except (error.URLError, TimeoutError, ValueError) as exc:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend="remote-classifier",
|
||||
fallback_reason="remote inference failed",
|
||||
error_message=str(exc),
|
||||
)
|
||||
|
||||
model_name = str(data.get("model_name", "remote-classifier"))
|
||||
normalized = self._normalize_remote_candidates(data)
|
||||
if not normalized:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend=model_name,
|
||||
fallback_reason="remote response has no candidates",
|
||||
raw_candidates=[],
|
||||
)
|
||||
|
||||
intent_map = {intent.intent_id: intent for intent in intents}
|
||||
known_candidates = [
|
||||
(intent_map[item["intent_id"]], item["score"])
|
||||
for item in normalized
|
||||
if item["intent_id"] in intent_map
|
||||
]
|
||||
top_candidate = normalized[0]
|
||||
selected_intent = intent_map.get(top_candidate["intent_id"])
|
||||
if selected_intent is None:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend=model_name,
|
||||
fallback_reason="remote top label is not mapped to a known intent",
|
||||
score=top_candidate["score"],
|
||||
raw_label=top_candidate["label"],
|
||||
raw_candidates=normalized,
|
||||
)
|
||||
if top_candidate["score"] < self._threshold:
|
||||
return self._predict_with_fallback(
|
||||
text,
|
||||
intents,
|
||||
attempted_backend=model_name,
|
||||
fallback_reason="remote score is below threshold",
|
||||
score=top_candidate["score"],
|
||||
raw_label=top_candidate["label"],
|
||||
raw_candidates=normalized,
|
||||
)
|
||||
return ClassificationResult(
|
||||
intent=selected_intent,
|
||||
score=top_candidate["score"],
|
||||
model_name=model_name,
|
||||
candidates=known_candidates,
|
||||
backend_name=model_name,
|
||||
raw_label=top_candidate["label"],
|
||||
raw_candidates=normalized,
|
||||
)
|
||||
|
||||
def _load_label_map(self, label_map_path: str | None) -> dict[str, str]:
|
||||
if not label_map_path:
|
||||
return {}
|
||||
path = Path(label_map_path)
|
||||
if not path.exists():
|
||||
return {}
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
return {str(key): str(value) for key, value in data.items()}
|
||||
|
||||
def _resolve_label(self, label: str) -> str | None:
|
||||
if label in self._label_map:
|
||||
return self._label_map[label]
|
||||
return label or None
|
||||
|
||||
def _normalize_remote_candidates(self, data: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
raw_candidates = data.get("candidates") or data.get("predictions") or []
|
||||
if not raw_candidates and data.get("intent_id"):
|
||||
raw_candidates = [
|
||||
{
|
||||
"intent_id": data.get("intent_id"),
|
||||
"label": data.get("label") or data.get("intent_id"),
|
||||
"score": data.get("score", 0.0),
|
||||
}
|
||||
]
|
||||
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for item in raw_candidates:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
label = str(item.get("label") or item.get("intent_id") or "")
|
||||
intent_id = str(item.get("intent_id") or self._resolve_label(label) or "")
|
||||
normalized.append(
|
||||
{
|
||||
"label": label,
|
||||
"intent_id": intent_id,
|
||||
"score": float(item.get("score", 0.0)),
|
||||
}
|
||||
)
|
||||
return normalized[: self._top_k]
|
||||
|
||||
def _predict_with_fallback(
|
||||
self,
|
||||
text: str,
|
||||
intents: list[IntentDefinition],
|
||||
attempted_backend: str,
|
||||
fallback_reason: str,
|
||||
score: float = 0.0,
|
||||
raw_label: str | None = None,
|
||||
raw_candidates: list[dict[str, Any]] | None = None,
|
||||
error_message: str | None = None,
|
||||
) -> ClassificationResult:
|
||||
if self._fallback is None:
|
||||
return ClassificationResult(
|
||||
intent=None,
|
||||
score=score,
|
||||
model_name=attempted_backend,
|
||||
backend_name=attempted_backend,
|
||||
used_fallback=False,
|
||||
fallback_reason=fallback_reason,
|
||||
error_message=error_message,
|
||||
raw_label=raw_label,
|
||||
raw_candidates=raw_candidates or [],
|
||||
)
|
||||
fallback_result = self._fallback.predict(text, intents)
|
||||
return ClassificationResult(
|
||||
intent=fallback_result.intent,
|
||||
score=fallback_result.score,
|
||||
model_name=fallback_result.model_name,
|
||||
candidates=fallback_result.candidates,
|
||||
backend_name=attempted_backend,
|
||||
used_fallback=True,
|
||||
fallback_reason=fallback_reason,
|
||||
error_message=error_message,
|
||||
raw_label=raw_label,
|
||||
raw_candidates=raw_candidates or fallback_result.raw_candidates or [],
|
||||
)
|
||||
173
intelligent_cabin/app/services/config_loader.py
Normal file
173
intelligent_cabin/app/services/config_loader.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from app.schemas.configuration import (
|
||||
ActionsConfig,
|
||||
ContextRewriteConfig,
|
||||
DialogActsConfig,
|
||||
DialogRulesConfig,
|
||||
DomainConfig,
|
||||
FormsConfig,
|
||||
ResponsesConfig,
|
||||
WorkflowTemplatesConfig,
|
||||
)
|
||||
from app.services.dialog_act import DialogActEngine
|
||||
from app.services.dialog_rules import DialogRuleEngine
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.rewrite_engine import ContextRewriteEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeConfigBundle:
|
||||
intent_registry: IntentRegistry
|
||||
response_templates: dict[str, str]
|
||||
intent_hints: dict[str, str]
|
||||
dialog_rules: DialogRuleEngine
|
||||
dialog_act_engine: DialogActEngine
|
||||
workflow_templates: WorkflowTemplatesConfig
|
||||
rewrite_engine: ContextRewriteEngine = field(default_factory=ContextRewriteEngine)
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
def __init__(
|
||||
self,
|
||||
domain_path: str,
|
||||
action_path: str,
|
||||
response_path: str,
|
||||
form_path: str | None = None,
|
||||
rule_path: str | None = None,
|
||||
dialog_act_path: str | None = None,
|
||||
workflow_path: str | None = None,
|
||||
legacy_intent_path: str | None = None,
|
||||
context_rewrite_path: str | None = None,
|
||||
) -> None:
|
||||
self._domain_path = Path(domain_path)
|
||||
self._action_path = Path(action_path)
|
||||
self._response_path = Path(response_path)
|
||||
self._form_path = Path(form_path) if form_path else None
|
||||
self._rule_path = Path(rule_path) if rule_path else None
|
||||
self._dialog_act_path = Path(dialog_act_path) if dialog_act_path else None
|
||||
self._workflow_path = Path(workflow_path) if workflow_path else None
|
||||
self._legacy_intent_path = Path(legacy_intent_path) if legacy_intent_path else None
|
||||
self._context_rewrite_path = Path(context_rewrite_path) if context_rewrite_path else None
|
||||
|
||||
def load(self) -> RuntimeConfigBundle:
|
||||
if self._domain_path.exists() and self._action_path.exists():
|
||||
return self._load_from_config_files()
|
||||
if self._legacy_intent_path is not None and self._legacy_intent_path.exists():
|
||||
return RuntimeConfigBundle(
|
||||
intent_registry=IntentRegistry.from_json(str(self._legacy_intent_path)),
|
||||
response_templates=self._load_response_templates(),
|
||||
intent_hints={},
|
||||
dialog_rules=self._load_dialog_rules(),
|
||||
dialog_act_engine=self._load_dialog_act_engine(),
|
||||
workflow_templates=self._load_workflow_templates(),
|
||||
rewrite_engine=self._load_rewrite_engine(),
|
||||
)
|
||||
raise FileNotFoundError(
|
||||
"no runtime config found, expected config/*.yml or legacy intent json"
|
||||
)
|
||||
|
||||
def _load_from_config_files(self) -> RuntimeConfigBundle:
|
||||
domain = DomainConfig.model_validate(self._read_structured_file(self._domain_path))
|
||||
actions = ActionsConfig.model_validate(self._read_structured_file(self._action_path))
|
||||
forms = self._load_forms()
|
||||
action_map = {item.action_id: item for item in actions.actions}
|
||||
form_map = {item.intent_id: item for item in forms.forms}
|
||||
intents = []
|
||||
for item in domain.intents:
|
||||
form = form_map.get(item.intent_id)
|
||||
if form is not None:
|
||||
item = item.model_copy(
|
||||
update={
|
||||
"required_slots": form.required_slots,
|
||||
"ask_templates": form.ask_templates,
|
||||
}
|
||||
)
|
||||
intents.append(item.to_intent_definition(action_map))
|
||||
intent_hints = {
|
||||
item.intent_id: item.label.strip()
|
||||
for item in domain.intents
|
||||
if item.label and item.label.strip()
|
||||
}
|
||||
return RuntimeConfigBundle(
|
||||
intent_registry=IntentRegistry(intents),
|
||||
response_templates=self._load_response_templates(),
|
||||
intent_hints=intent_hints,
|
||||
dialog_rules=self._load_dialog_rules(),
|
||||
dialog_act_engine=self._load_dialog_act_engine(),
|
||||
workflow_templates=self._load_workflow_templates(),
|
||||
rewrite_engine=self._load_rewrite_engine(),
|
||||
)
|
||||
|
||||
def _load_response_templates(self) -> dict[str, str]:
|
||||
if not self._response_path.exists():
|
||||
return {}
|
||||
raw = self._read_structured_file(self._response_path)
|
||||
parsed = ResponsesConfig.model_validate(raw)
|
||||
return parsed.templates
|
||||
|
||||
def _load_forms(self) -> FormsConfig:
|
||||
if self._form_path is None or not self._form_path.exists():
|
||||
return FormsConfig()
|
||||
raw = self._read_structured_file(self._form_path)
|
||||
return FormsConfig.model_validate(raw)
|
||||
|
||||
def _load_dialog_rules(self) -> DialogRuleEngine:
|
||||
if self._rule_path is None or not self._rule_path.exists():
|
||||
return DialogRuleEngine()
|
||||
raw = self._read_structured_file(self._rule_path)
|
||||
parsed = DialogRulesConfig.model_validate(raw)
|
||||
return DialogRuleEngine(
|
||||
stop_phrases=tuple(parsed.stop.phrases) or DialogRuleEngine.stop_phrases,
|
||||
positive_confirmation_tokens=tuple(parsed.confirmation.positive_tokens)
|
||||
or DialogRuleEngine.positive_confirmation_tokens,
|
||||
negative_confirmation_tokens=tuple(parsed.confirmation.negative_tokens)
|
||||
or DialogRuleEngine.negative_confirmation_tokens,
|
||||
confirmation_required_intents=tuple(parsed.confirmation.required_intents)
|
||||
or DialogRuleEngine.confirmation_required_intents,
|
||||
confirmation_required_risk_levels=tuple(parsed.confirmation.required_risk_levels)
|
||||
or DialogRuleEngine.confirmation_required_risk_levels,
|
||||
metadata={"source": str(self._rule_path)},
|
||||
)
|
||||
|
||||
def _load_dialog_act_engine(self) -> DialogActEngine:
|
||||
if self._dialog_act_path is None or not self._dialog_act_path.exists():
|
||||
return DialogActEngine()
|
||||
raw = self._read_structured_file(self._dialog_act_path)
|
||||
parsed = DialogActsConfig.model_validate(raw)
|
||||
return DialogActEngine(
|
||||
patterns={
|
||||
item.act_id: tuple(item.phrases)
|
||||
for item in parsed.acts
|
||||
},
|
||||
numeric_patterns={
|
||||
item.act_id: tuple(item.numeric_patterns)
|
||||
for item in parsed.acts
|
||||
if item.numeric_patterns
|
||||
},
|
||||
)
|
||||
|
||||
def _load_rewrite_engine(self) -> ContextRewriteEngine:
|
||||
if self._context_rewrite_path is None or not self._context_rewrite_path.exists():
|
||||
return ContextRewriteEngine()
|
||||
raw = self._read_structured_file(self._context_rewrite_path)
|
||||
config = ContextRewriteConfig.model_validate(raw)
|
||||
return ContextRewriteEngine(config=config)
|
||||
|
||||
def _load_workflow_templates(self) -> WorkflowTemplatesConfig:
|
||||
if self._workflow_path is None or not self._workflow_path.exists():
|
||||
return WorkflowTemplatesConfig()
|
||||
raw = self._read_structured_file(self._workflow_path)
|
||||
return WorkflowTemplatesConfig.model_validate(raw)
|
||||
|
||||
def _read_structured_file(self, path: Path) -> dict[str, Any]:
|
||||
if path.suffix.lower() == ".json":
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
return yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
51
intelligent_cabin/app/services/dialog_act.py
Normal file
51
intelligent_cabin/app/services/dialog_act.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogActEngine:
|
||||
"""
|
||||
基于配置词表的对话行为检测器。
|
||||
- patterns : act_id → 触发词组 tuple,逐词包含匹配
|
||||
- numeric_patterns : act_id → 正则 tuple,全文正则匹配(用于数字类 inform)
|
||||
|
||||
词表和正则均从 config/dialog_acts.yml 加载,不同设备部署时修改配置文件即可,无需改代码。
|
||||
"""
|
||||
|
||||
patterns: dict[str, tuple[str, ...]] = field(
|
||||
default_factory=lambda: {
|
||||
"affirm": ("确认", "好的", "继续", "可以", "确定"),
|
||||
"deny": ("不要", "不行", "否", "不"),
|
||||
"cancel": ("取消", "算了", "不用了", "停止"),
|
||||
"modify": ("改成", "换成", "再低一点", "再高一点", "调大", "调小"),
|
||||
"chitchat": ("你好", "谢谢", "再见", "天气", "真不错"),
|
||||
"request": ("帮我", "打开", "关闭", "导航", "播放", "查询", "查"),
|
||||
"inform": (),
|
||||
}
|
||||
)
|
||||
# act_id → 正则表达式 tuple(全文 search,任意命中即触发)
|
||||
numeric_patterns: dict[str, tuple[str, ...]] = field(
|
||||
default_factory=lambda: {
|
||||
"inform": (r"\d+",),
|
||||
}
|
||||
)
|
||||
|
||||
def detect(self, text: str) -> str:
|
||||
normalized = re.sub(r"\s+", "", text.strip().lower())
|
||||
if not normalized:
|
||||
return "unknown"
|
||||
|
||||
# 1. 词表包含匹配(保持原有优先级顺序)
|
||||
for act_id, phrases in self.patterns.items():
|
||||
if any(phrase and phrase in normalized for phrase in phrases):
|
||||
return act_id
|
||||
|
||||
# 2. 正则匹配(主要用于 inform 的数字检测)
|
||||
for act_id, regexes in self.numeric_patterns.items():
|
||||
for pattern in regexes:
|
||||
if re.search(pattern, normalized):
|
||||
return act_id
|
||||
|
||||
return "unknown"
|
||||
62
intelligent_cabin/app/services/dialog_rules.py
Normal file
62
intelligent_cabin/app/services/dialog_rules.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogRuleEngine:
|
||||
stop_phrases: tuple[str, ...] = (
|
||||
"不用了",
|
||||
"算了",
|
||||
"先不要了",
|
||||
"先这样吧",
|
||||
"停一下",
|
||||
"停止",
|
||||
"停止当前任务",
|
||||
"结束这次操作",
|
||||
"别弄了",
|
||||
"不需要了",
|
||||
)
|
||||
positive_confirmation_tokens: tuple[str, ...] = (
|
||||
"确认",
|
||||
"好的",
|
||||
"是",
|
||||
"继续",
|
||||
"可以",
|
||||
"确定",
|
||||
"yes",
|
||||
"ok",
|
||||
)
|
||||
negative_confirmation_tokens: tuple[str, ...] = (
|
||||
"取消",
|
||||
"不用",
|
||||
"不要",
|
||||
"否",
|
||||
"no",
|
||||
"算了",
|
||||
"停止",
|
||||
)
|
||||
confirmation_required_intents: tuple[str, ...] = ("cs_cancel_order",)
|
||||
confirmation_required_risk_levels: tuple[str, ...] = ("high",)
|
||||
metadata: dict[str, object] = field(default_factory=dict)
|
||||
|
||||
def is_stop_request(self, text: str) -> bool:
|
||||
normalized = text.strip().lower().replace(" ", "")
|
||||
if not normalized:
|
||||
return False
|
||||
return any(phrase in normalized for phrase in self.stop_phrases)
|
||||
|
||||
def parse_confirmation_decision(self, text: str) -> bool | None:
|
||||
normalized = text.strip().lower()
|
||||
if not normalized:
|
||||
return None
|
||||
if any(token == normalized or token in normalized for token in self.negative_confirmation_tokens):
|
||||
return False
|
||||
if any(token == normalized or token in normalized for token in self.positive_confirmation_tokens):
|
||||
return True
|
||||
return None
|
||||
|
||||
def requires_confirmation(self, intent_id: str, risk_level: str) -> bool:
|
||||
if intent_id in self.confirmation_required_intents:
|
||||
return True
|
||||
return risk_level in self.confirmation_required_risk_levels
|
||||
29
intelligent_cabin/app/services/intent_registry.py
Normal file
29
intelligent_cabin/app/services/intent_registry.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
|
||||
|
||||
class IntentRegistry:
|
||||
def __init__(self, intents: list[IntentDefinition]) -> None:
|
||||
self._intents = {intent.intent_id: intent for intent in intents}
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, file_path: str) -> "IntentRegistry":
|
||||
data = json.loads(Path(file_path).read_text(encoding="utf-8"))
|
||||
intents = [IntentDefinition.model_validate(item) for item in data]
|
||||
return cls(intents)
|
||||
|
||||
def get(self, intent_id: str) -> IntentDefinition:
|
||||
return self._intents[intent_id]
|
||||
|
||||
def list(self) -> list[IntentDefinition]:
|
||||
return list(self._intents.values())
|
||||
|
||||
def match(self, text: str) -> IntentDefinition | None:
|
||||
for intent in self._intents.values():
|
||||
if any(keyword in text for keyword in intent.keywords):
|
||||
return intent
|
||||
return None
|
||||
430
intelligent_cabin/app/services/joint_nlu.py
Normal file
430
intelligent_cabin/app/services/joint_nlu.py
Normal file
@@ -0,0 +1,430 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
|
||||
|
||||
OPTIONAL_SLOT_NAMES_BY_INTENT: dict[str, set[str]] = {
|
||||
"cabin_play_music": {"song", "genre"},
|
||||
}
|
||||
|
||||
BLOCKED_INTENT_LABELS = {"__social__", "__out_of_scope__"}
|
||||
|
||||
|
||||
def allowed_slot_names(intent_id: str, required_slots: list[str] | None = None) -> set[str]:
|
||||
required = set(required_slots or [])
|
||||
return required | OPTIONAL_SLOT_NAMES_BY_INTENT.get(intent_id, set())
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointSlot:
|
||||
slot_name: str
|
||||
value: str
|
||||
start: int
|
||||
end: int
|
||||
score: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointCandidate:
|
||||
intent_id: str
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointNluResult:
|
||||
intent_id: str | None = None
|
||||
intent_score: float = 0.0
|
||||
candidates: list[JointCandidate] = field(default_factory=list)
|
||||
multi_intent_candidates: list[JointCandidate] = field(default_factory=list)
|
||||
slots: dict[str, Any] = field(default_factory=dict)
|
||||
slot_items: list[JointSlot] = field(default_factory=list)
|
||||
model_name: str = "joint-bert-local"
|
||||
backend_name: str = "joint-bert-local"
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class JointBertForNLU(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_model_name: str,
|
||||
num_intents: int,
|
||||
num_slot_labels: int,
|
||||
encoder_config_path: str | Path | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if encoder_config_path is not None:
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_config_path, local_files_only=True)
|
||||
self.encoder = AutoModel.from_config(encoder_config)
|
||||
else:
|
||||
self.encoder = AutoModel.from_pretrained(base_model_name)
|
||||
hidden_size = int(self.encoder.config.hidden_size)
|
||||
dropout_prob = float(getattr(self.encoder.config, "hidden_dropout_prob", 0.1))
|
||||
self.dropout = torch.nn.Dropout(dropout_prob)
|
||||
self.intent_classifier = torch.nn.Linear(hidden_size, num_intents)
|
||||
self.slot_classifier = torch.nn.Linear(hidden_size, num_slot_labels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
encoder_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
if token_type_ids is not None:
|
||||
encoder_kwargs["token_type_ids"] = token_type_ids
|
||||
outputs = self.encoder(**encoder_kwargs)
|
||||
sequence_output = self.dropout(outputs.last_hidden_state)
|
||||
pooled_output = self.dropout(sequence_output[:, 0])
|
||||
intent_logits = self.intent_classifier(pooled_output)
|
||||
slot_logits = self.slot_classifier(sequence_output)
|
||||
return intent_logits, slot_logits
|
||||
|
||||
|
||||
class JointBertNLU:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
intent_threshold: float | None = None,
|
||||
multi_intent_threshold: float | None = None,
|
||||
top_k: int = 5,
|
||||
max_multi_intents: int = 4,
|
||||
max_cache_size: int = 8,
|
||||
) -> None:
|
||||
self._model_path = Path(model_path)
|
||||
self._intent_threshold = intent_threshold
|
||||
self._multi_intent_threshold = multi_intent_threshold
|
||||
self._top_k = top_k
|
||||
self._max_multi_intents = max_multi_intents
|
||||
self._max_cache_size = max_cache_size
|
||||
self._runtime: tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device] | None = None
|
||||
self._warmup_elapsed_ms: float | None = None
|
||||
self._warmup_error_message: str | None = None
|
||||
self._warmed_up = False
|
||||
self._cache: OrderedDict[str, dict[str, Any]] = OrderedDict()
|
||||
|
||||
def warmup(self, sample_text: str = "把空调调到22度") -> bool:
|
||||
started_at = perf_counter()
|
||||
try:
|
||||
self._predict_raw(sample_text)
|
||||
except Exception as exc:
|
||||
self._warmup_error_message = str(exc)
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
return False
|
||||
self._warmup_error_message = None
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
self._warmed_up = True
|
||||
return True
|
||||
|
||||
def predict(self, text: str, intents: list[IntentDefinition]) -> JointNluResult:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception as exc:
|
||||
return JointNluResult(error_message=str(exc))
|
||||
candidates = self._filter_known_candidates(raw_result["candidates"], intents, limit=self._top_k)
|
||||
multi_candidates = self.predict_multi_intents(text, intents)
|
||||
top_candidate = candidates[0] if candidates else None
|
||||
if top_candidate is None or top_candidate.score < self._resolved_intent_threshold():
|
||||
return JointNluResult(
|
||||
intent_id=None,
|
||||
intent_score=top_candidate.score if top_candidate is not None else 0.0,
|
||||
candidates=candidates,
|
||||
multi_intent_candidates=multi_candidates,
|
||||
slots={},
|
||||
slot_items=[],
|
||||
)
|
||||
intent_def = next((intent for intent in intents if intent.intent_id == top_candidate.intent_id), None)
|
||||
if intent_def is None:
|
||||
return JointNluResult(
|
||||
intent_id=None,
|
||||
intent_score=top_candidate.score,
|
||||
candidates=candidates,
|
||||
multi_intent_candidates=multi_candidates,
|
||||
slots={},
|
||||
slot_items=[],
|
||||
)
|
||||
slot_items = self._filter_slot_items(raw_result["slot_items"], intent_def.intent_id, intent_def.required_slots)
|
||||
return JointNluResult(
|
||||
intent_id=top_candidate.intent_id,
|
||||
intent_score=top_candidate.score,
|
||||
candidates=candidates,
|
||||
multi_intent_candidates=multi_candidates,
|
||||
slots=self._slot_items_to_dict(slot_items),
|
||||
slot_items=slot_items,
|
||||
)
|
||||
|
||||
def predict_multi_intents(
|
||||
self,
|
||||
text: str,
|
||||
intents: list[IntentDefinition],
|
||||
threshold: float | None = None,
|
||||
max_labels: int | None = None,
|
||||
top_k: int | None = None,
|
||||
) -> list[JointCandidate]:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception:
|
||||
return []
|
||||
threshold = self._multi_intent_threshold if threshold is None else threshold
|
||||
if threshold is None:
|
||||
threshold = self._resolved_multi_intent_threshold()
|
||||
max_labels = self._max_multi_intents if max_labels is None else max_labels
|
||||
ranked = self._filter_known_candidates(raw_result["candidates"], intents, limit=top_k or self._top_k)
|
||||
selected: list[JointCandidate] = []
|
||||
for item in ranked:
|
||||
if item.score < threshold:
|
||||
continue
|
||||
selected.append(item)
|
||||
if len(selected) >= max_labels:
|
||||
break
|
||||
return selected
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception:
|
||||
return {}
|
||||
slot_items = self._filter_slot_items(raw_result["slot_items"], intent.intent_id, intent.required_slots)
|
||||
return self._slot_items_to_dict(slot_items)
|
||||
|
||||
def extract_slots_by_intent_id(
|
||||
self,
|
||||
text: str,
|
||||
intent_id: str,
|
||||
required_slots: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
raw_result = self._predict_raw(text)
|
||||
except Exception:
|
||||
return {}
|
||||
slot_items = self._filter_slot_items(raw_result["slot_items"], intent_id, required_slots or [])
|
||||
return self._slot_items_to_dict(slot_items)
|
||||
|
||||
def _filter_known_candidates(
|
||||
self,
|
||||
candidates: list[JointCandidate],
|
||||
intents: list[IntentDefinition],
|
||||
limit: int | None = None,
|
||||
) -> list[JointCandidate]:
|
||||
known_intents = {intent.intent_id for intent in intents}
|
||||
filtered = [
|
||||
item
|
||||
for item in candidates
|
||||
if item.intent_id in known_intents and item.intent_id not in BLOCKED_INTENT_LABELS
|
||||
]
|
||||
return filtered[:limit] if limit is not None else filtered
|
||||
|
||||
def _slot_items_to_dict(self, slot_items: list[JointSlot]) -> dict[str, Any]:
|
||||
slots: dict[str, Any] = {}
|
||||
for item in slot_items:
|
||||
if item.slot_name == "temperature":
|
||||
digits = "".join(ch for ch in item.value if ch.isdigit())
|
||||
if digits:
|
||||
slots[item.slot_name] = int(digits)
|
||||
continue
|
||||
slots[item.slot_name] = item.value
|
||||
return slots
|
||||
|
||||
def _filter_slot_items(
|
||||
self,
|
||||
slot_items: list[JointSlot],
|
||||
intent_id: str,
|
||||
required_slots: list[str],
|
||||
) -> list[JointSlot]:
|
||||
allowed = allowed_slot_names(intent_id, required_slots)
|
||||
if not allowed:
|
||||
return []
|
||||
filtered = [item for item in slot_items if item.slot_name in allowed]
|
||||
deduped: list[JointSlot] = []
|
||||
seen: set[tuple[str, int, int]] = set()
|
||||
for item in filtered:
|
||||
key = (item.slot_name, item.start, item.end)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(item)
|
||||
return deduped
|
||||
|
||||
def _predict_raw(self, text: str) -> dict[str, Any]:
|
||||
normalized = (text or "").strip()
|
||||
if not normalized:
|
||||
return {"candidates": [], "slot_items": []}
|
||||
if normalized in self._cache:
|
||||
cached = self._cache.pop(normalized)
|
||||
self._cache[normalized] = cached
|
||||
return cached
|
||||
tokenizer, model, metadata, device = self._load_runtime()
|
||||
encoded = tokenizer(
|
||||
normalized,
|
||||
truncation=True,
|
||||
max_length=int(metadata.get("max_length", 64)),
|
||||
return_offsets_mapping=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
offset_mapping = encoded.pop("offset_mapping")[0].tolist()
|
||||
encoded = {key: value.to(device) for key, value in encoded.items()}
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
intent_logits, slot_logits = model(**encoded)
|
||||
slot_probs = torch.softmax(slot_logits, dim=-1)[0].detach().cpu()
|
||||
slot_ids = torch.argmax(slot_probs, dim=-1).tolist()
|
||||
intent_probs = self._intent_probabilities(intent_logits.detach().cpu()[0], metadata)
|
||||
intent_labels = metadata.get("intent_labels", [])
|
||||
slot_labels = metadata.get("slot_labels", [])
|
||||
candidates = [
|
||||
JointCandidate(intent_id=str(intent_labels[index]), score=float(score))
|
||||
for index, score in sorted(
|
||||
list(enumerate(intent_probs)),
|
||||
key=lambda item: item[1],
|
||||
reverse=True,
|
||||
)
|
||||
]
|
||||
slot_items = self._decode_slot_items(
|
||||
text=normalized,
|
||||
offset_mapping=offset_mapping,
|
||||
slot_ids=slot_ids,
|
||||
slot_probs=slot_probs,
|
||||
slot_labels=slot_labels,
|
||||
)
|
||||
result = {
|
||||
"candidates": candidates,
|
||||
"slot_items": slot_items,
|
||||
}
|
||||
self._cache[normalized] = result
|
||||
while len(self._cache) > self._max_cache_size:
|
||||
self._cache.popitem(last=False)
|
||||
return result
|
||||
|
||||
def _intent_probabilities(self, intent_logits: torch.Tensor, metadata: dict[str, Any]) -> list[float]:
|
||||
task_type = str(metadata.get("intent_task", "single_label")).strip() or "single_label"
|
||||
if task_type == "multi_label":
|
||||
return torch.sigmoid(intent_logits).tolist()
|
||||
return torch.softmax(intent_logits, dim=-1).tolist()
|
||||
|
||||
def _decode_slot_items(
|
||||
self,
|
||||
text: str,
|
||||
offset_mapping: list[list[int]],
|
||||
slot_ids: list[int],
|
||||
slot_probs: torch.Tensor,
|
||||
slot_labels: list[str],
|
||||
) -> list[JointSlot]:
|
||||
items: list[JointSlot] = []
|
||||
current_name: str | None = None
|
||||
current_start: int | None = None
|
||||
current_end: int | None = None
|
||||
current_scores: list[float] = []
|
||||
|
||||
def flush() -> None:
|
||||
nonlocal current_name, current_start, current_end, current_scores
|
||||
if current_name is None or current_start is None or current_end is None or current_start >= current_end:
|
||||
current_name = None
|
||||
current_start = None
|
||||
current_end = None
|
||||
current_scores = []
|
||||
return
|
||||
value = text[current_start:current_end].strip()
|
||||
if value:
|
||||
items.append(
|
||||
JointSlot(
|
||||
slot_name=current_name,
|
||||
value=value,
|
||||
start=current_start,
|
||||
end=current_end,
|
||||
score=round(sum(current_scores) / max(len(current_scores), 1), 4),
|
||||
)
|
||||
)
|
||||
current_name = None
|
||||
current_start = None
|
||||
current_end = None
|
||||
current_scores = []
|
||||
|
||||
for index, label_id in enumerate(slot_ids):
|
||||
if index >= len(offset_mapping):
|
||||
break
|
||||
start, end = offset_mapping[index]
|
||||
if end <= start:
|
||||
flush()
|
||||
continue
|
||||
label = str(slot_labels[label_id]) if label_id < len(slot_labels) else "O"
|
||||
token_score = float(slot_probs[index][label_id].item())
|
||||
if label == "O":
|
||||
flush()
|
||||
continue
|
||||
prefix, _, name = label.partition("-")
|
||||
if prefix == "B" or current_name != name:
|
||||
flush()
|
||||
current_name = name
|
||||
current_start = start
|
||||
current_end = end
|
||||
current_scores = [token_score]
|
||||
continue
|
||||
current_end = end
|
||||
current_scores.append(token_score)
|
||||
flush()
|
||||
return items
|
||||
|
||||
def _load_runtime(self) -> tuple[AutoTokenizer, JointBertForNLU, dict[str, Any], torch.device]:
|
||||
if self._runtime is not None:
|
||||
return self._runtime
|
||||
if not self._model_path.exists():
|
||||
raise FileNotFoundError(f"joint nlu model path not found: {self._model_path}")
|
||||
metadata_path = self._model_path / "joint_nlu_config.json"
|
||||
state_dict_path = self._model_path / "model_state.pt"
|
||||
if not metadata_path.exists():
|
||||
raise FileNotFoundError(f"joint nlu config missing: {metadata_path}")
|
||||
if not state_dict_path.exists():
|
||||
raise FileNotFoundError(f"joint nlu model state missing: {state_dict_path}")
|
||||
metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
|
||||
tokenizer = AutoTokenizer.from_pretrained(self._model_path)
|
||||
model = JointBertForNLU(
|
||||
base_model_name=str(metadata["base_model_name"]),
|
||||
num_intents=len(metadata["intent_labels"]),
|
||||
num_slot_labels=len(metadata["slot_labels"]),
|
||||
encoder_config_path=self._resolve_encoder_config_path(metadata),
|
||||
)
|
||||
state_dict = torch.load(state_dict_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict)
|
||||
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
||||
model.to(device)
|
||||
self._runtime = (tokenizer, model, metadata, device)
|
||||
return self._runtime
|
||||
|
||||
def _resolve_encoder_config_path(self, metadata: dict[str, Any]) -> Path | None:
|
||||
local_config = self._model_path / "config.json"
|
||||
if local_config.exists():
|
||||
return self._model_path
|
||||
|
||||
base_model_path = Path(str(metadata.get("base_model_name", "")))
|
||||
if base_model_path.exists() and (base_model_path / "config.json").exists():
|
||||
return base_model_path
|
||||
|
||||
for candidate_name in ("local_bert_intent", "local_bert_multi_intent"):
|
||||
candidate_path = self._model_path.parent / candidate_name
|
||||
if (candidate_path / "config.json").exists():
|
||||
return candidate_path
|
||||
return None
|
||||
|
||||
def _resolved_intent_threshold(self) -> float:
|
||||
if self._intent_threshold is not None:
|
||||
return self._intent_threshold
|
||||
metadata = self._runtime[2] if self._runtime is not None else {}
|
||||
return float(metadata.get("intent_threshold", 0.35))
|
||||
|
||||
def _resolved_multi_intent_threshold(self) -> float:
|
||||
if self._multi_intent_threshold is not None:
|
||||
return self._multi_intent_threshold
|
||||
metadata = self._runtime[2] if self._runtime is not None else {}
|
||||
return float(metadata.get("multi_intent_threshold", metadata.get("intent_threshold", 0.45)))
|
||||
238
intelligent_cabin/app/services/knowledge_llm.py
Normal file
238
intelligent_cabin/app/services/knowledge_llm.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
app/services/knowledge_llm.py
|
||||
|
||||
当 BERT NLU 未命中时,使用 LLM + knowledge_search function call 查询本地知识库。
|
||||
|
||||
流程:
|
||||
1. 构建 tools=[knowledge_search] 发给 LLM
|
||||
2. 若 LLM 返回 tool_calls → 执行 KnowledgeStore.search() → 拼结果再发一次 LLM
|
||||
3. LLM 生成最终回复 reply_text + knowledge_doc_id(可选)
|
||||
|
||||
返回 KnowledgeReplyResult,包含:
|
||||
- reply_text: 简短自然语言摘要
|
||||
- doc_id / doc_content: 命中的知识文档(供前端渲染 KnowledgeArtifact)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from urllib import error, request
|
||||
|
||||
from app.services.knowledge_store import KnowledgeDoc, KnowledgeStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeReplyResult:
|
||||
reply_text: str
|
||||
backend: str
|
||||
model_name: str
|
||||
doc_id: str | None = None
|
||||
doc_content: str | None = None # 原始 MD 内容,前端渲染用
|
||||
doc_title: str | None = None
|
||||
error_message: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ── LLM 工具定义(OpenAI function calling 格式,DashScope 兼容)────────────────
|
||||
_KNOWLEDGE_SEARCH_TOOL: dict[str, Any] = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "knowledge_search",
|
||||
"description": (
|
||||
"搜索本地设备知识库,获取焊管机/弯管机产线相关的故障排查、操作规程等知识。"
|
||||
"当用户问到设备故障、报警处理、操作方法、工艺参数时请调用此工具。"
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "搜索关键词,如'虚焊报警'、'激光扫描仪操作'、'弯管模具调节'等",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_SYSTEM_PROMPT = """\
|
||||
你是焊管机产线智能助手,负责回答操作工人关于设备故障、工艺调节、仪器使用的问题。
|
||||
你有一个工具 knowledge_search 可以查询本地设备知识库,遇到设备类问题时请先调用它。
|
||||
回答时语言简洁、口语化,先给出结论,再说步骤,总长度不超过 100 字。
|
||||
如果工具返回了相关知识,请基于知识内容回答,不要编造。
|
||||
如果没有找到相关知识,诚实告知"暂未找到相关资料,建议联系技术支持"。
|
||||
"""
|
||||
|
||||
|
||||
class DashScopeKnowledgeLLM:
|
||||
"""使用 DashScope(OpenAI 兼容 API)+ function calling 的知识库问答器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
knowledge_store: KnowledgeStore,
|
||||
timeout_seconds: float = 12.0,
|
||||
max_tool_rounds: int = 2,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._api_key = api_key
|
||||
self._model_name = model_name
|
||||
self._store = knowledge_store
|
||||
self._timeout = timeout_seconds
|
||||
self._max_tool_rounds = max_tool_rounds
|
||||
|
||||
# ── 主入口 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def reply(self, user_text: str) -> KnowledgeReplyResult:
|
||||
"""完整 function-call 对话流(最多 max_tool_rounds 轮工具调用)。"""
|
||||
if not self._base_url or not self._api_key or not self._model_name:
|
||||
return self._local_fallback(user_text, "LLM not configured")
|
||||
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_text},
|
||||
]
|
||||
|
||||
found_doc: KnowledgeDoc | None = None
|
||||
|
||||
for _round in range(self._max_tool_rounds):
|
||||
raw = self._chat(messages, tools=[_KNOWLEDGE_SEARCH_TOOL])
|
||||
if raw is None:
|
||||
return self._local_fallback(user_text, "LLM request failed")
|
||||
|
||||
choice = self._first_choice(raw)
|
||||
if choice is None:
|
||||
return self._local_fallback(user_text, "empty choices")
|
||||
|
||||
finish_reason = choice.get("finish_reason", "")
|
||||
message = choice.get("message", {})
|
||||
|
||||
# ── 工具调用分支 ─────────────────────────────────────────────────
|
||||
if finish_reason == "tool_calls" or message.get("tool_calls"):
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
messages.append({"role": "assistant", **message})
|
||||
|
||||
for tc in tool_calls:
|
||||
fn_name = tc.get("function", {}).get("name", "")
|
||||
fn_args_raw = tc.get("function", {}).get("arguments", "{}")
|
||||
tc_id = tc.get("id", "call_0")
|
||||
|
||||
if fn_name == "knowledge_search":
|
||||
try:
|
||||
fn_args = json.loads(fn_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
fn_args = {"query": user_text}
|
||||
|
||||
query = fn_args.get("query", user_text)
|
||||
tool_result, found_doc = self._run_knowledge_search(query)
|
||||
else:
|
||||
tool_result = f"Unknown tool: {fn_name}"
|
||||
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tc_id,
|
||||
"content": tool_result,
|
||||
})
|
||||
# 继续下一轮 LLM 调用
|
||||
continue
|
||||
|
||||
# ── 正常文本回复 ─────────────────────────────────────────────────
|
||||
content = self._extract_content(message)
|
||||
if not content:
|
||||
return self._local_fallback(user_text, "empty content")
|
||||
|
||||
return KnowledgeReplyResult(
|
||||
reply_text=content,
|
||||
backend="dashscope",
|
||||
model_name=self._model_name,
|
||||
doc_id=found_doc.doc_id if found_doc else None,
|
||||
doc_content=found_doc.content if found_doc else None,
|
||||
doc_title=found_doc.title if found_doc else None,
|
||||
)
|
||||
|
||||
# 超出工具调用轮数,直接本地兜底
|
||||
return self._local_fallback(user_text, "max tool rounds exceeded")
|
||||
|
||||
# ── 内部工具执行 ───────────────────────────────────────────────────────────
|
||||
|
||||
def _run_knowledge_search(self, query: str) -> tuple[str, KnowledgeDoc | None]:
|
||||
"""执行本地知识库搜索,返回 (tool_result_str, best_doc)。"""
|
||||
results = self._store.search(query, top_k=2)
|
||||
if not results:
|
||||
return "未找到相关知识文档。", None
|
||||
|
||||
best = results[0]
|
||||
# 给 LLM 的 tool result:文档标题 + 正文(截断到 800 字节)
|
||||
excerpt = best.doc.content[:800]
|
||||
tool_text = (
|
||||
f"[知识库检索结果]\n"
|
||||
f"文档:{best.doc.title}\n"
|
||||
f"命中关键词:{', '.join(best.matched_keywords)}\n\n"
|
||||
f"{excerpt}"
|
||||
)
|
||||
return tool_text, best.doc
|
||||
|
||||
# ── HTTP 调用 ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
payload: dict[str, Any] = {
|
||||
"model": self._model_name,
|
||||
"temperature": 0.3,
|
||||
"enable_thinking": False,
|
||||
"max_tokens": 300,
|
||||
"messages": messages,
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
req = request.Request(
|
||||
self._endpoint(),
|
||||
data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with request.urlopen(req, timeout=self._timeout) as resp:
|
||||
return json.loads(resp.read().decode("utf-8"))
|
||||
except (error.URLError, TimeoutError, ValueError, OSError):
|
||||
return None
|
||||
|
||||
def _endpoint(self) -> str:
|
||||
if self._base_url.endswith("/chat/completions"):
|
||||
return self._base_url
|
||||
return f"{self._base_url}/chat/completions"
|
||||
|
||||
def _first_choice(self, payload: dict[str, Any]) -> dict[str, Any] | None:
|
||||
choices = payload.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return None
|
||||
return choices[0]
|
||||
|
||||
def _extract_content(self, message: dict[str, Any]) -> str:
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
return "".join(
|
||||
str(item.get("text", "")).strip()
|
||||
for item in content
|
||||
if isinstance(item, dict) and item.get("type") == "text"
|
||||
).strip()
|
||||
return str(content).strip()
|
||||
|
||||
def _local_fallback(self, _user_text: str, reason: str) -> KnowledgeReplyResult:
|
||||
return KnowledgeReplyResult(
|
||||
reply_text="暂未找到相关资料,建议联系技术支持或查阅设备手册。",
|
||||
backend="local-fallback",
|
||||
model_name="knowledge-fallback",
|
||||
error_message=reason,
|
||||
)
|
||||
152
intelligent_cabin/app/services/knowledge_store.py
Normal file
152
intelligent_cabin/app/services/knowledge_store.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
app/services/knowledge_store.py
|
||||
|
||||
本地 Markdown 知识库加载与关键词检索。
|
||||
- 所有 .md 文件存放在 config/knowledge/ 目录
|
||||
- 基于关键词打分,支持多文档排序返回
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeDoc:
|
||||
"""一篇知识文档的元数据与正文。"""
|
||||
|
||||
doc_id: str # 文件名(不含扩展名)
|
||||
title: str # MD 首行 # 标题,无则用文件名
|
||||
content: str # 完整原始 Markdown 内容
|
||||
keywords: list[str] = field(default_factory=list) # 从正文抽取的高频词
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
doc: KnowledgeDoc
|
||||
score: float
|
||||
matched_keywords: list[str]
|
||||
|
||||
|
||||
class KnowledgeStore:
|
||||
"""从 config/knowledge/*.md 加载知识库,提供关键词检索。"""
|
||||
|
||||
def __init__(self, knowledge_dir: str | Path) -> None:
|
||||
self._dir = Path(knowledge_dir)
|
||||
self._docs: dict[str, KnowledgeDoc] = {}
|
||||
self._load()
|
||||
|
||||
# ── 公开 API ───────────────────────────────────────────────────────────────
|
||||
|
||||
def search(self, query: str, top_k: int = 3) -> list[SearchResult]:
|
||||
"""根据 query 检索最相关的知识文档,返回最多 top_k 条。"""
|
||||
query_tokens = self._tokenize(query)
|
||||
if not query_tokens:
|
||||
return []
|
||||
|
||||
results: list[SearchResult] = []
|
||||
for doc in self._docs.values():
|
||||
score, matched = self._score(doc, query_tokens)
|
||||
if score > 0:
|
||||
results.append(SearchResult(doc=doc, score=score, matched_keywords=matched))
|
||||
|
||||
results.sort(key=lambda r: r.score, reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
def get(self, doc_id: str) -> KnowledgeDoc | None:
|
||||
return self._docs.get(doc_id)
|
||||
|
||||
def all_doc_ids(self) -> list[str]:
|
||||
return list(self._docs.keys())
|
||||
|
||||
def reload(self) -> None:
|
||||
"""热重载知识库(添加新 MD 文件后调用)。"""
|
||||
self._docs.clear()
|
||||
self._load()
|
||||
|
||||
# ── 内部逻辑 ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _load(self) -> None:
|
||||
if not self._dir.exists():
|
||||
return
|
||||
for md_path in sorted(self._dir.glob("*.md")):
|
||||
doc = self._parse_md(md_path)
|
||||
self._docs[doc.doc_id] = doc
|
||||
|
||||
def _parse_md(self, path: Path) -> KnowledgeDoc:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
doc_id = path.stem
|
||||
|
||||
# 提取第一个 # 标题作为文档标题
|
||||
title_match = re.search(r"^#+\s+(.+)", content, re.MULTILINE)
|
||||
title = title_match.group(1).strip() if title_match else doc_id
|
||||
|
||||
# 提取关键词:去标点后的中文词段(2~6字)
|
||||
keywords = self._extract_keywords(content)
|
||||
return KnowledgeDoc(doc_id=doc_id, title=title, content=content, keywords=keywords)
|
||||
|
||||
def _extract_keywords(self, content: str) -> list[str]:
|
||||
"""提取 MD 正文中的中文词段作为候选关键词。"""
|
||||
# 去掉 Markdown 语法符号
|
||||
text = re.sub(r"[#`*_>|~\[\]()!]", " ", content)
|
||||
text = re.sub(r"https?://\S+", " ", text)
|
||||
# 中文词段(2-6 个汉字)
|
||||
words = re.findall(r"[\u4e00-\u9fff]{2,6}", text)
|
||||
# 去重,保留顺序
|
||||
seen: set[str] = set()
|
||||
unique: list[str] = []
|
||||
for w in words:
|
||||
if w not in seen:
|
||||
seen.add(w)
|
||||
unique.append(w)
|
||||
return unique
|
||||
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
"""将 query 分割成候选检索词。
|
||||
|
||||
策略:
|
||||
1. 提取所有连续中文字段(2字以上)作为候选
|
||||
2. 在连续中文字段上做滑动窗口(2-5字),覆盖子串匹配
|
||||
避免整句 '虚焊报警怎么办' 作为单一 token 无法匹配 '虚焊报警'
|
||||
"""
|
||||
# 提取所有连续中文片段
|
||||
chinese_chunks = re.findall(r"[\u4e00-\u9fff]+", text)
|
||||
tokens: list[str] = []
|
||||
for chunk in chinese_chunks:
|
||||
# 滑动窗口:长度 2 到 min(5, len(chunk))
|
||||
for size in range(2, min(6, len(chunk) + 1)):
|
||||
for start in range(len(chunk) - size + 1):
|
||||
tokens.append(chunk[start : start + size])
|
||||
# 整体 chunk 也加入(用于长词精确匹配)
|
||||
if len(chunk) > 1:
|
||||
tokens.append(chunk)
|
||||
# 去重保序
|
||||
seen: set[str] = set()
|
||||
unique: list[str] = []
|
||||
for t in tokens:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
unique.append(t)
|
||||
return unique
|
||||
|
||||
def _score(self, doc: KnowledgeDoc, query_tokens: list[str]) -> tuple[float, list[str]]:
|
||||
"""给文档打分:命中 title 得 3 分,命中 content 得 1 分(上限 5)。"""
|
||||
score = 0.0
|
||||
matched: list[str] = []
|
||||
seen: set[str] = set()
|
||||
content_lower = doc.content.lower()
|
||||
title_lower = doc.title.lower()
|
||||
|
||||
for token in query_tokens:
|
||||
token_lower = token.lower()
|
||||
if token_lower in seen:
|
||||
continue
|
||||
in_title = token_lower in title_lower
|
||||
in_content = token_lower in content_lower
|
||||
if in_title or in_content:
|
||||
seen.add(token_lower)
|
||||
matched.append(token)
|
||||
score += 3.0 if in_title else 1.0
|
||||
|
||||
return min(score, 15.0), matched # 上限 15,避免极端高分
|
||||
219
intelligent_cabin/app/services/multi_intent_detector.py
Normal file
219
intelligent_cabin/app/services/multi_intent_detector.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from time import perf_counter
|
||||
from typing import Any, Protocol
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiIntentCandidate:
|
||||
intent_id: str
|
||||
score: float
|
||||
label: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiIntentDetectionResult:
|
||||
detected: bool = False
|
||||
candidates: list[MultiIntentCandidate] = field(default_factory=list)
|
||||
reason: str | None = None
|
||||
backend_name: str = "none"
|
||||
error_message: str | None = None
|
||||
raw_scores: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class MultiIntentDetector(Protocol):
|
||||
def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
|
||||
...
|
||||
|
||||
|
||||
class JointBertMultiIntentDetector:
|
||||
"""
|
||||
A multi-intent detector backed by the same Joint BERT runtime as single-intent and slot extraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nlu: JointBertNLU,
|
||||
threshold: float | None = None,
|
||||
top_k: int = 8,
|
||||
max_labels: int = 4,
|
||||
) -> None:
|
||||
self._nlu = nlu
|
||||
self._threshold = threshold
|
||||
self._top_k = top_k
|
||||
self._max_labels = max_labels
|
||||
|
||||
def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
|
||||
candidates = self._nlu.predict_multi_intents(
|
||||
text,
|
||||
intents,
|
||||
threshold=self._threshold,
|
||||
max_labels=self._max_labels,
|
||||
top_k=self._top_k,
|
||||
)
|
||||
return MultiIntentDetectionResult(
|
||||
detected=len(candidates) >= 2,
|
||||
candidates=[
|
||||
MultiIntentCandidate(intent_id=item.intent_id, score=item.score, label=item.intent_id)
|
||||
for item in candidates
|
||||
],
|
||||
reason=f"joint bert multi-label candidates={len(candidates)} threshold={self._threshold}",
|
||||
backend_name="joint-bert-multi-label",
|
||||
raw_scores=[
|
||||
{"intent_id": item.intent_id, "label": item.intent_id, "score": float(item.score)}
|
||||
for item in candidates
|
||||
],
|
||||
)
|
||||
|
||||
def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
|
||||
return self._nlu.warmup(sample_text)
|
||||
|
||||
|
||||
class BertMultiIntentDetector:
|
||||
"""
|
||||
A stage-2 multi-intent detector backed by a dedicated multi-label BERT head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
threshold: float = 0.45,
|
||||
top_k: int = 8,
|
||||
max_labels: int = 4,
|
||||
blocked_labels: set[str] | None = None,
|
||||
) -> None:
|
||||
self._model_path = model_path
|
||||
self._threshold = threshold
|
||||
self._top_k = top_k
|
||||
self._max_labels = max_labels
|
||||
self._blocked_labels = blocked_labels or {"__social__", "__out_of_scope__"}
|
||||
self._tokenizer = None
|
||||
self._model = None
|
||||
self._torch = None
|
||||
self._error_message: str | None = None
|
||||
self._warmed_up = False
|
||||
self._warmup_elapsed_ms: float | None = None
|
||||
|
||||
def detect(self, text: str, intents: list[IntentDefinition]) -> MultiIntentDetectionResult:
|
||||
runtime = self._get_runtime()
|
||||
if runtime is None:
|
||||
return MultiIntentDetectionResult(
|
||||
detected=False,
|
||||
reason="multi-label detector unavailable",
|
||||
backend_name="bert-multi-label",
|
||||
error_message=self._error_message,
|
||||
)
|
||||
torch, tokenizer, model = runtime
|
||||
intent_map = {intent.intent_id: intent for intent in intents}
|
||||
try:
|
||||
encoded = tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
padding=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
logits = model(**encoded).logits.squeeze(0)
|
||||
probs = torch.sigmoid(logits).detach().cpu().tolist()
|
||||
except Exception as exc:
|
||||
return MultiIntentDetectionResult(
|
||||
detected=False,
|
||||
reason="multi-label forward failed",
|
||||
backend_name="bert-multi-label",
|
||||
error_message=str(exc),
|
||||
)
|
||||
|
||||
id2label = getattr(model.config, "id2label", {}) or {}
|
||||
raw_scores: list[dict[str, Any]] = []
|
||||
for index, score in enumerate(probs):
|
||||
label = str(id2label.get(index, f"LABEL_{index}"))
|
||||
raw_scores.append(
|
||||
{
|
||||
"label": label,
|
||||
"intent_id": label,
|
||||
"score": float(score),
|
||||
}
|
||||
)
|
||||
raw_scores.sort(key=lambda item: item["score"], reverse=True)
|
||||
raw_top = raw_scores[: self._top_k]
|
||||
|
||||
candidates: list[MultiIntentCandidate] = []
|
||||
for item in raw_top:
|
||||
intent_id = str(item.get("intent_id") or "")
|
||||
if intent_id in self._blocked_labels:
|
||||
continue
|
||||
if intent_id not in intent_map:
|
||||
continue
|
||||
score = float(item.get("score", 0.0))
|
||||
if score < self._threshold:
|
||||
continue
|
||||
candidates.append(
|
||||
MultiIntentCandidate(
|
||||
intent_id=intent_id,
|
||||
score=score,
|
||||
label=str(item.get("label") or intent_id),
|
||||
)
|
||||
)
|
||||
if len(candidates) >= self._max_labels:
|
||||
break
|
||||
|
||||
return MultiIntentDetectionResult(
|
||||
detected=len(candidates) >= 2,
|
||||
candidates=candidates,
|
||||
reason=f"bert multi-label candidates={len(candidates)} threshold={self._threshold}",
|
||||
backend_name="bert-multi-label",
|
||||
raw_scores=raw_top,
|
||||
)
|
||||
|
||||
def warmup(self, sample_text: str = "打开空调并打开车窗") -> bool:
|
||||
if self._warmed_up:
|
||||
return True
|
||||
started_at = perf_counter()
|
||||
runtime = self._get_runtime()
|
||||
if runtime is None:
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
return False
|
||||
torch, tokenizer, model = runtime
|
||||
try:
|
||||
encoded = tokenizer(sample_text, truncation=True, padding=False, return_tensors="pt")
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
_ = model(**encoded).logits
|
||||
except Exception as exc:
|
||||
self._error_message = str(exc)
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
return False
|
||||
self._warmed_up = True
|
||||
self._warmup_elapsed_ms = round((perf_counter() - started_at) * 1000, 3)
|
||||
return True
|
||||
|
||||
def _get_runtime(self):
|
||||
if self._tokenizer is not None and self._model is not None and self._torch is not None:
|
||||
return self._torch, self._tokenizer, self._model
|
||||
if not self._model_path or not Path(self._model_path).exists():
|
||||
self._error_message = "multi-intent model path is empty or missing"
|
||||
return None
|
||||
try:
|
||||
transformers = importlib.import_module("transformers")
|
||||
torch = importlib.import_module("torch")
|
||||
except ImportError as exc:
|
||||
self._error_message = str(exc)
|
||||
return None
|
||||
try:
|
||||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._model_path)
|
||||
self._model = transformers.AutoModelForSequenceClassification.from_pretrained(self._model_path)
|
||||
self._torch = torch
|
||||
except Exception as exc:
|
||||
self._error_message = str(exc)
|
||||
return None
|
||||
return self._torch, self._tokenizer, self._model
|
||||
|
||||
|
||||
SigmoidBertMultiIntentDetector = BertMultiIntentDetector
|
||||
1347
intelligent_cabin/app/services/planner.py
Normal file
1347
intelligent_cabin/app/services/planner.py
Normal file
File diff suppressed because it is too large
Load Diff
299
intelligent_cabin/app/services/response_policy.py
Normal file
299
intelligent_cabin/app/services/response_policy.py
Normal file
@@ -0,0 +1,299 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.intent import IntentDefinition
|
||||
|
||||
|
||||
class ResponsePolicy:
|
||||
_DEFAULT_INTENT_HINTS = {
|
||||
"cabin_nav_to": "导航",
|
||||
"cabin_nav_cancel": "结束导航",
|
||||
"cabin_set_ac": "调空调",
|
||||
"cabin_ac_on": "打开空调",
|
||||
"cabin_ac_off": "关闭空调",
|
||||
"cabin_fan_up": "调大风量",
|
||||
"cabin_fan_down": "调小风量",
|
||||
"cabin_defog_front_on": "打开前挡除雾",
|
||||
"cabin_defog_rear_on": "打开后挡除雾",
|
||||
"cabin_window_open": "打开车窗",
|
||||
"cabin_window_close": "关闭车窗",
|
||||
"cabin_sunroof_open": "打开天窗",
|
||||
"cabin_sunroof_close": "关闭天窗",
|
||||
"cabin_trunk_open": "打开后备箱",
|
||||
"cabin_trunk_close": "关闭后备箱",
|
||||
"cabin_lock_doors": "锁车门",
|
||||
"cabin_unlock_doors": "解锁车门",
|
||||
"cabin_play_music": "播放音乐",
|
||||
"cabin_pause_music": "暂停音乐",
|
||||
"cabin_next_track": "下一首",
|
||||
"cabin_previous_track": "上一首",
|
||||
"cabin_volume_up": "调大音量",
|
||||
"cabin_volume_down": "调小音量",
|
||||
"cabin_volume_mute": "静音",
|
||||
"cabin_lights_on": "打开车灯",
|
||||
"cabin_lights_off": "关闭车灯",
|
||||
"cabin_reading_light_on": "打开阅读灯",
|
||||
"cabin_reading_light_off": "关闭阅读灯",
|
||||
"cabin_seat_heat_on": "打开座椅加热",
|
||||
"cabin_seat_heat_off": "关闭座椅加热",
|
||||
"cabin_seat_vent_on": "打开座椅通风",
|
||||
"cabin_seat_vent_off": "关闭座椅通风",
|
||||
"cabin_mirror_fold": "折叠后视镜",
|
||||
"cabin_mirror_unfold": "展开后视镜",
|
||||
"cabin_wiper_on": "打开雨刷",
|
||||
"cabin_wiper_off": "关闭雨刷",
|
||||
"cabin_screen_brightness_up": "调亮屏幕",
|
||||
"cabin_screen_brightness_down": "调暗屏幕",
|
||||
"cabin_answer_call": "接听电话",
|
||||
"cabin_hang_up_call": "挂断电话",
|
||||
"cs_query_order": "查订单",
|
||||
"cs_query_logistics": "查物流",
|
||||
"cs_cancel_order": "取消订单",
|
||||
"cs_transfer_human": "转人工",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
templates: dict[str, str] | None = None,
|
||||
intent_hints: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self._templates = templates or {}
|
||||
self._intent_hints = {**self._DEFAULT_INTENT_HINTS, **(intent_hints or {})}
|
||||
|
||||
def ask_for_slot(self, intent: IntentDefinition, slot_name: str, default_template: str) -> str:
|
||||
if slot_name == "order_id":
|
||||
if intent.intent_id == "cs_cancel_order":
|
||||
return self._template("ask_cancel_order_id", "请告诉我订单号。")
|
||||
return self._template("ask_order_id", "请提供订单号。")
|
||||
if slot_name == "destination":
|
||||
return self._template("ask_destination", "请告诉我要去哪里。")
|
||||
if slot_name == "temperature":
|
||||
return self._template("ask_temperature", "请告诉我要设置多少度。")
|
||||
if slot_name == "media_query":
|
||||
return self._template("ask_media_query", "想听什么风格或者具体的歌名?")
|
||||
return default_template.strip() or "请补充一个关键信息。"
|
||||
|
||||
def workflow_result(self, intent: IntentDefinition, plugin_result: dict[str, Any]) -> str:
|
||||
if not plugin_result.get("success", True):
|
||||
return self._template("workflow_failed", "这次没处理成功,请稍后再试。")
|
||||
message = str(plugin_result.get("message") or "").strip()
|
||||
if not message:
|
||||
return self.ack(intent)
|
||||
if len(message) > 42:
|
||||
return message[:39].rstrip(",。;; ") + "..."
|
||||
return message
|
||||
|
||||
def workflow_summary(self, messages: list[str]) -> str:
|
||||
cleaned = [item.strip() for item in messages if item and item.strip()]
|
||||
if not cleaned:
|
||||
return self._template("workflow_summary_empty", "好的,已经处理完成。")
|
||||
if len(cleaned) == 1:
|
||||
return cleaned[0]
|
||||
natural_clauses: list[str] = []
|
||||
previous_subject: str | None = None
|
||||
for index, item in enumerate(cleaned[:3]):
|
||||
clause, subject = self._vehicle_style_clause(item, index=index, previous_subject=previous_subject)
|
||||
natural_clauses.append(clause)
|
||||
previous_subject = subject or previous_subject
|
||||
summary = f"好,{','.join(natural_clauses)}。"
|
||||
if len(cleaned) > 3:
|
||||
summary = summary.rstrip("。") + ",其余步骤也已完成。"
|
||||
if len(summary) > 70:
|
||||
return summary[:67].rstrip(",。;; ") + "..."
|
||||
return summary
|
||||
|
||||
def ask_for_confirmation(self, intent: IntentDefinition, detail: str | None = None) -> str:
|
||||
if intent.intent_id == "cs_cancel_order":
|
||||
if detail:
|
||||
return f"即将取消订单,{detail}。请回复“确认”或“取消”。"
|
||||
return "即将取消订单。请回复“确认”或“取消”。"
|
||||
if detail:
|
||||
return f"{detail}。请回复“确认”或“取消”。"
|
||||
return "请确认是否继续执行。回复“确认”或“取消”。"
|
||||
|
||||
def confirm_retry(self) -> str:
|
||||
return self._template("confirm_retry", "我需要一个明确确认。请回复“确认”继续,或回复“取消”终止。")
|
||||
|
||||
def confirm_cancelled(self) -> str:
|
||||
return self._template("confirm_cancelled", "好的,已取消这一步。")
|
||||
|
||||
def step_skipped(self, intent: IntentDefinition, reason: str | None = None) -> str:
|
||||
if intent.intent_id == "cs_cancel_order":
|
||||
base = "订单取消步骤未执行。"
|
||||
else:
|
||||
base = "这一步已跳过。"
|
||||
if reason:
|
||||
return f"{base}{reason}"
|
||||
return base
|
||||
|
||||
def ack(self, intent: IntentDefinition | None = None) -> str:
|
||||
if intent is None:
|
||||
return self._template("ack_default", "收到,马上处理。")
|
||||
if intent.domain == "cabin":
|
||||
return self._template("ack_cabin", "好的,马上处理。")
|
||||
return self._template("ack_service", "收到,我来处理。")
|
||||
|
||||
def reject(self) -> str:
|
||||
return self._template("reject", "这个我暂时做不了,但我可以帮你查询、控制或转人工。")
|
||||
|
||||
def short_social(self, social_kind: str) -> str:
|
||||
if social_kind == "greeting":
|
||||
return self._template("short_social_greeting", "你好,我在。")
|
||||
if social_kind == "thanks":
|
||||
return self._template("short_social_thanks", "不客气。")
|
||||
if social_kind == "goodbye":
|
||||
return self._template("short_social_goodbye", "好的,有需要再叫我。")
|
||||
if social_kind == "capability":
|
||||
return self._template(
|
||||
"short_social_capability",
|
||||
"我可以帮你查订单、查物流、取消订单、导航、调空调、播放音乐或转人工。",
|
||||
)
|
||||
return self._template("short_social_default", "我在。")
|
||||
|
||||
def open_social_fallback(self) -> str:
|
||||
return self._template("open_social_fallback", "可以和你聊两句,你也可以继续告诉我想处理什么。")
|
||||
|
||||
def with_pending_hint(self, text: str, pending_hint: str | None = None) -> str:
|
||||
base = text.strip() or self.open_social_fallback()
|
||||
hint = (pending_hint or "").strip()
|
||||
if not hint:
|
||||
return base
|
||||
return f"{base} {hint}"
|
||||
|
||||
def pending_task_hint(self, status: str, pending_slots: list[str], current_intent: str | None = None) -> str | None:
|
||||
if status == "waiting_confirmation":
|
||||
return self._template("pending_confirmation_hint", "当前这一步还在等你确认,回复“确认”或“取消”即可。")
|
||||
if status == "waiting_slot" and pending_slots:
|
||||
if pending_slots[0] == "order_id":
|
||||
return self._template("pending_slot_order_id", "当前还缺订单号,你继续告诉我订单号就行。")
|
||||
if pending_slots[0] == "temperature":
|
||||
return self._template("pending_slot_temperature", "当前还缺温度,你继续告诉我要设置多少度就行。")
|
||||
if pending_slots[0] == "destination":
|
||||
return self._template("pending_slot_destination", "当前还缺目的地,你继续告诉我要去哪里就行。")
|
||||
if pending_slots[0] == "media_query":
|
||||
return self._template("pending_slot_media_query", "当前还缺歌名或风格,你直接说歌名、歌手或风格就行。")
|
||||
return self._template("pending_slot_default", "当前还缺一个关键信息,你继续补充就行。")
|
||||
if status == "running" and current_intent:
|
||||
return self._template("pending_running", "当前任务还在继续,你也可以直接继续下一个指令。")
|
||||
return None
|
||||
|
||||
def task_stopped(self) -> str:
|
||||
return self._template("task_stopped", "好的,已停止当前任务。")
|
||||
|
||||
def clarify(self, candidate_intents: list[str]) -> str:
|
||||
options = [
|
||||
self._intent_hints.get(intent_id, intent_id)
|
||||
for intent_id in candidate_intents
|
||||
if intent_id
|
||||
]
|
||||
deduped: list[str] = []
|
||||
for item in options:
|
||||
if item not in deduped:
|
||||
deduped.append(item)
|
||||
if not deduped:
|
||||
return "我理解得还不够确定,你是想查询、控制,还是转人工?"
|
||||
if len(deduped) == 1:
|
||||
return f"请确认一下,你是想{deduped[0]}吗?"
|
||||
if len(deduped) == 2:
|
||||
return f"请确认一下,你是想{deduped[0]}还是{deduped[1]}?"
|
||||
return f"请确认一下,你是想{deduped[0]}、{deduped[1]},还是{deduped[2]}?"
|
||||
|
||||
def fallback(self) -> str:
|
||||
return self._template("fallback", "我还没完全听懂,你可以换个简短说法,或告诉我是查询、控制还是转人工。")
|
||||
|
||||
def _template(self, key: str, default: str) -> str:
|
||||
value = str(self._templates.get(key, default)).strip()
|
||||
return value or default
|
||||
|
||||
def _naturalize_workflow_message(self, text: str) -> str:
|
||||
normalized = text.strip().rstrip("。;; ")
|
||||
normalized = re.sub(r"^好的[,,\s]*", "", normalized)
|
||||
normalized = re.sub(r"^收到[,,\s]*", "", normalized)
|
||||
if normalized.startswith("已将"):
|
||||
normalized = normalized[2:]
|
||||
elif normalized.startswith("已经将"):
|
||||
normalized = normalized[3:]
|
||||
elif normalized.startswith("已经"):
|
||||
normalized = normalized[2:]
|
||||
elif normalized.startswith("已"):
|
||||
normalized = normalized[1:]
|
||||
normalized = normalized.strip(",, ")
|
||||
if not normalized:
|
||||
return "已经处理好了"
|
||||
if normalized.endswith("了"):
|
||||
return normalized
|
||||
return f"{normalized}了"
|
||||
|
||||
def _vehicle_style_clause(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
index: int,
|
||||
previous_subject: str | None = None,
|
||||
) -> tuple[str, str | None]:
|
||||
normalized = self._naturalize_workflow_message(text)
|
||||
|
||||
match = re.match(r"^(打开|关闭)(.+)了$", normalized)
|
||||
if match:
|
||||
action, subject = match.groups()
|
||||
subject = subject.strip()
|
||||
if action == "打开":
|
||||
if previous_subject and previous_subject == subject:
|
||||
return "也打开了", subject
|
||||
if index > 0:
|
||||
return f"{subject}也打开了", subject
|
||||
return f"{subject}已经打开了", subject
|
||||
if previous_subject and previous_subject == subject:
|
||||
return "也帮你关上了", subject
|
||||
if index > 0:
|
||||
return f"{subject}也帮你关上了", subject
|
||||
return f"{subject}已经关上了", subject
|
||||
|
||||
match = re.match(r"^(锁定|解锁)(.+)了$", normalized)
|
||||
if match:
|
||||
action, subject = match.groups()
|
||||
subject = subject.strip()
|
||||
action_text = "锁好了" if action == "锁定" else "解锁了"
|
||||
if previous_subject and previous_subject == subject:
|
||||
return f"也{action_text}", subject
|
||||
if index > 0:
|
||||
return f"{subject}也{action_text}", subject
|
||||
return f"{subject}已经{action_text}", subject
|
||||
|
||||
match = re.match(r"^(.+)调到\s*(.+)度了$", normalized)
|
||||
if match:
|
||||
subject, value = match.groups()
|
||||
subject = subject.strip()
|
||||
value = value.strip()
|
||||
if previous_subject and previous_subject == subject:
|
||||
return f"也调到 {value} 度了", subject
|
||||
if index > 0:
|
||||
return f"{subject}也调到 {value} 度了", subject
|
||||
return f"{subject}调到 {value} 度了", subject
|
||||
|
||||
match = re.match(r"^(调大|调小)(.+)了$", normalized)
|
||||
if match:
|
||||
action, subject = match.groups()
|
||||
subject = subject.strip()
|
||||
if previous_subject and previous_subject == subject:
|
||||
return f"也{action}了", subject
|
||||
if index > 0:
|
||||
return f"{subject}也{action}了", subject
|
||||
return f"{subject}已经{action}了", subject
|
||||
|
||||
if normalized.startswith("正在播放 "):
|
||||
target = normalized[len("正在播放 ") :].strip()
|
||||
if index > 0:
|
||||
return f"也开始播放 {target} 了", "播放"
|
||||
return f"开始播放 {target} 了", "播放"
|
||||
|
||||
if normalized.startswith("订单 ") and normalized.endswith(" 已取消"):
|
||||
order_text = normalized[:-4].strip()
|
||||
return f"{order_text}已经取消了", "订单"
|
||||
|
||||
if normalized.startswith("订单 ") and "当前" in normalized:
|
||||
return normalized, "订单"
|
||||
|
||||
return normalized, None
|
||||
108
intelligent_cabin/app/services/rewrite_engine.py
Normal file
108
intelligent_cabin/app/services/rewrite_engine.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.schemas.configuration import ContextRewriteConfig, ParamContextDefinition
|
||||
from app.services.session_store import SessionState
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewriteResult:
|
||||
original_text: str
|
||||
rewritten_text: str
|
||||
applied: bool = False
|
||||
reason: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ContextRewriteEngine:
|
||||
"""
|
||||
将短句 follow-up(如"再快一点"、"电压高一点")改写为完整命令(如"速度设为 85 mm/min"),
|
||||
使其能复用本地快路径而不必每轮重做完整规划。
|
||||
|
||||
改写规则完全由外部配置文件(context_rewrite.yml)驱动,不硬编码业务参数,
|
||||
适用于不同设备(线切割 / 激光切割 / 注塑机等)的部署切换。
|
||||
"""
|
||||
|
||||
def __init__(self, config: ContextRewriteConfig | None = None) -> None:
|
||||
self._config = config or ContextRewriteConfig()
|
||||
# 构建反向索引:intent_id → ParamContextDefinition
|
||||
self._intent_index: dict[str, ParamContextDefinition] = {}
|
||||
for ctx in self._config.param_contexts:
|
||||
for intent_id in ctx.intent_ids:
|
||||
self._intent_index[intent_id] = ctx
|
||||
|
||||
# ------------------------------------------------------------------ public
|
||||
|
||||
def rewrite(self, text: str, session: SessionState) -> RewriteResult:
|
||||
normalized = text.strip()
|
||||
if not normalized:
|
||||
return RewriteResult(original_text=text, rewritten_text=text)
|
||||
|
||||
current_intent = session.current_intent
|
||||
if current_intent and current_intent in self._intent_index:
|
||||
ctx = self._intent_index[current_intent]
|
||||
result = self._rewrite_param_adjustment(normalized, session, ctx)
|
||||
if result.applied:
|
||||
return result
|
||||
|
||||
return RewriteResult(original_text=text, rewritten_text=text)
|
||||
|
||||
# ----------------------------------------------------------------- private
|
||||
|
||||
def _rewrite_param_adjustment(
|
||||
self,
|
||||
text: str,
|
||||
session: SessionState,
|
||||
ctx: ParamContextDefinition,
|
||||
) -> RewriteResult:
|
||||
direction: str | None = None
|
||||
if any(phrase and phrase in text for phrase in ctx.up_phrases):
|
||||
direction = "up"
|
||||
elif any(phrase and phrase in text for phrase in ctx.down_phrases):
|
||||
direction = "down"
|
||||
|
||||
if direction is None:
|
||||
return RewriteResult(original_text=text, rewritten_text=text)
|
||||
|
||||
previous_value = self._last_slot_value(session, ctx.slot_name)
|
||||
base_value = previous_value if previous_value is not None else ctx.default_value
|
||||
delta = ctx.step if direction == "up" else -ctx.step
|
||||
|
||||
if isinstance(ctx.min_value, float) or isinstance(ctx.max_value, float) or isinstance(ctx.step, float):
|
||||
next_value: int | float = max(float(ctx.min_value), min(float(ctx.max_value), float(base_value) + float(delta)))
|
||||
else:
|
||||
next_value = max(int(ctx.min_value), min(int(ctx.max_value), int(base_value) + int(delta)))
|
||||
|
||||
rewritten = ctx.rewrite_template.format(value=next_value)
|
||||
return RewriteResult(
|
||||
original_text=text,
|
||||
rewritten_text=rewritten,
|
||||
applied=True,
|
||||
reason=f"normalize relative {ctx.slot_name} adjustment into an explicit target",
|
||||
metadata={
|
||||
"cache_hit": True,
|
||||
"rewrite_type": "param_adjustment",
|
||||
"slot_name": ctx.slot_name,
|
||||
"direction": direction,
|
||||
"previous_value": previous_value,
|
||||
"base_value": base_value,
|
||||
"next_value": next_value,
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _last_slot_value(session: SessionState, slot_name: str) -> int | float | None:
|
||||
raw = session.context_memory.get(f"last_{slot_name}", session.slots.get(slot_name))
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, (int, float)):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
return int(raw) if raw.isdigit() else float(raw)
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
407
intelligent_cabin/app/services/router.py
Normal file
407
intelligent_cabin/app/services/router.py
Normal file
@@ -0,0 +1,407 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from time import perf_counter
|
||||
from typing import Any, Protocol
|
||||
|
||||
from app.schemas.debug import IntentCandidate, MatcherStageDebug, RoutingDebug
|
||||
from app.schemas.intent import IntentDefinition
|
||||
from app.services.classifier import IntentClassifier
|
||||
from app.services.intent_registry import IntentRegistry
|
||||
from app.services.joint_nlu import JointBertNLU
|
||||
|
||||
|
||||
@dataclass
|
||||
class IntentMatchResult:
|
||||
intent: IntentDefinition | None
|
||||
stage_debug: MatcherStageDebug
|
||||
|
||||
|
||||
@dataclass
|
||||
class RouteMatchResult:
|
||||
intent: IntentDefinition | None
|
||||
debug: RoutingDebug
|
||||
|
||||
|
||||
class IntentMatcher(Protocol):
|
||||
def match(self, text: str) -> IntentMatchResult:
|
||||
...
|
||||
|
||||
|
||||
class SlotExtractor(Protocol):
|
||||
def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
|
||||
...
|
||||
|
||||
|
||||
class ClassifierIntentMatcher:
|
||||
def __init__(self, registry: IntentRegistry, classifier: IntentClassifier) -> None:
|
||||
self._registry = registry
|
||||
self._classifier = classifier
|
||||
|
||||
def match(self, text: str) -> IntentMatchResult:
|
||||
result = self._classifier.predict(text, self._registry.list())
|
||||
ranked_classifier_scores = sorted(
|
||||
[float(score) for _, score in (result.candidates or [])],
|
||||
reverse=True,
|
||||
)
|
||||
classifier_top_margin = (
|
||||
ranked_classifier_scores[0] - ranked_classifier_scores[1]
|
||||
if len(ranked_classifier_scores) >= 2
|
||||
else ranked_classifier_scores[0] if ranked_classifier_scores else 0.0
|
||||
)
|
||||
candidates = [
|
||||
IntentCandidate(
|
||||
intent_id=intent.intent_id,
|
||||
score=score,
|
||||
reason="classifier candidate",
|
||||
model_name=result.backend_name or result.model_name,
|
||||
raw_label=next(
|
||||
(
|
||||
item.get("label")
|
||||
for item in (result.raw_candidates or [])
|
||||
if item.get("intent_id") == intent.intent_id and float(item.get("score", 0.0)) == score
|
||||
),
|
||||
intent.intent_id,
|
||||
),
|
||||
)
|
||||
for intent, score in (result.candidates or [])
|
||||
]
|
||||
metadata: dict[str, Any] = {
|
||||
"decision_model": result.model_name,
|
||||
"threshold": getattr(self._classifier, "_threshold", None),
|
||||
"raw_candidates": result.raw_candidates or [],
|
||||
"top_margin": round(classifier_top_margin, 4),
|
||||
}
|
||||
if result.fallback_reason:
|
||||
metadata["fallback_reason"] = result.fallback_reason
|
||||
if result.intent is None:
|
||||
return IntentMatchResult(
|
||||
intent=None,
|
||||
stage_debug=MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=False,
|
||||
score=result.score,
|
||||
reason=result.fallback_reason or "classifier below threshold or no intent selected",
|
||||
model_name=result.model_name,
|
||||
backend=result.backend_name or result.model_name,
|
||||
fallback_used=result.used_fallback,
|
||||
raw_label=result.raw_label,
|
||||
error_message=result.error_message,
|
||||
metadata=metadata,
|
||||
candidates=candidates,
|
||||
),
|
||||
)
|
||||
return IntentMatchResult(
|
||||
intent=result.intent,
|
||||
stage_debug=MatcherStageDebug(
|
||||
stage="classifier",
|
||||
accepted=True,
|
||||
selected_intent=result.intent.intent_id,
|
||||
score=result.score,
|
||||
reason=(
|
||||
"bert classifier selected best candidate"
|
||||
if not result.used_fallback
|
||||
else f"fallback selected best candidate: {result.fallback_reason}"
|
||||
),
|
||||
model_name=result.model_name,
|
||||
backend=result.backend_name or result.model_name,
|
||||
fallback_used=result.used_fallback,
|
||||
raw_label=result.raw_label,
|
||||
error_message=result.error_message,
|
||||
metadata=metadata,
|
||||
candidates=candidates,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class MultiStageIntentMatcher:
|
||||
def __init__(
|
||||
self,
|
||||
registry: IntentRegistry,
|
||||
matchers: list[IntentMatcher],
|
||||
route_to_cloud_threshold: float = 0.75,
|
||||
clarify_margin_threshold: float = 0.12,
|
||||
classifier_execute_score_threshold: float = 0.55,
|
||||
classifier_execute_margin_threshold: float = 0.18,
|
||||
) -> None:
|
||||
self._registry = registry
|
||||
self._matchers = matchers
|
||||
self._route_to_cloud_threshold = route_to_cloud_threshold
|
||||
self._clarify_margin_threshold = clarify_margin_threshold
|
||||
self._classifier_execute_score_threshold = classifier_execute_score_threshold
|
||||
self._classifier_execute_margin_threshold = classifier_execute_margin_threshold
|
||||
|
||||
def match(self, text: str) -> RouteMatchResult:
|
||||
stage_traces: list[MatcherStageDebug] = []
|
||||
match_started_at = perf_counter()
|
||||
for matcher in self._matchers:
|
||||
stage_started_at = perf_counter()
|
||||
result = matcher.match(text)
|
||||
result.stage_debug.elapsed_ms = round((perf_counter() - stage_started_at) * 1000, 3)
|
||||
stage_traces.append(result.stage_debug)
|
||||
fusion_started_at = perf_counter()
|
||||
fusion_stage = self._build_fusion_stage(stage_traces)
|
||||
fusion_stage.elapsed_ms = round((perf_counter() - fusion_started_at) * 1000, 3)
|
||||
stage_traces.append(fusion_stage)
|
||||
total_match_latency_ms = round((perf_counter() - match_started_at) * 1000, 3)
|
||||
decision = str(fusion_stage.metadata.get("decision", "reject"))
|
||||
confidence_grade = str(fusion_stage.metadata.get("grade", "low"))
|
||||
unknown_detected = bool(fusion_stage.metadata.get("unknown_detected", False))
|
||||
decision_reason = fusion_stage.reason
|
||||
if decision == "execute" and fusion_stage.selected_intent is not None:
|
||||
intent = self._registry.get(fusion_stage.selected_intent)
|
||||
return RouteMatchResult(
|
||||
intent=intent,
|
||||
debug=RoutingDebug(
|
||||
selected_intent=fusion_stage.selected_intent,
|
||||
matched_stage=fusion_stage.stage,
|
||||
decision=decision,
|
||||
decision_reason=decision_reason,
|
||||
confidence_grade=confidence_grade,
|
||||
total_match_latency_ms=total_match_latency_ms,
|
||||
unknown_detected=unknown_detected,
|
||||
stages=stage_traces,
|
||||
),
|
||||
)
|
||||
return RouteMatchResult(
|
||||
intent=None,
|
||||
debug=RoutingDebug(
|
||||
selected_intent=None,
|
||||
matched_stage=fusion_stage.stage,
|
||||
decision=decision,
|
||||
decision_reason=decision_reason,
|
||||
confidence_grade=confidence_grade,
|
||||
total_match_latency_ms=total_match_latency_ms,
|
||||
unknown_detected=unknown_detected,
|
||||
stages=stage_traces,
|
||||
),
|
||||
)
|
||||
|
||||
def _build_fusion_stage(self, stage_traces: list[MatcherStageDebug]) -> MatcherStageDebug:
|
||||
classifier_stage = next((stage for stage in stage_traces if stage.stage == "classifier"), None)
|
||||
if classifier_stage is None:
|
||||
return MatcherStageDebug(
|
||||
stage="fusion",
|
||||
accepted=False,
|
||||
reason="classifier stage is missing",
|
||||
model_name="fusion-router",
|
||||
backend="bert-first-fusion",
|
||||
metadata={
|
||||
"grade": "low",
|
||||
"decision": "reject",
|
||||
"unknown_detected": True,
|
||||
"ranked_intents": [],
|
||||
},
|
||||
candidates=[],
|
||||
)
|
||||
ranked_candidates = list(classifier_stage.candidates or [])
|
||||
if not ranked_candidates and classifier_stage.selected_intent is not None:
|
||||
ranked_candidates = [
|
||||
IntentCandidate(
|
||||
intent_id=classifier_stage.selected_intent,
|
||||
score=classifier_stage.score,
|
||||
reason="classifier selected intent",
|
||||
model_name=classifier_stage.model_name,
|
||||
)
|
||||
]
|
||||
if not ranked_candidates:
|
||||
return MatcherStageDebug(
|
||||
stage="fusion",
|
||||
accepted=False,
|
||||
reason="classifier did not produce a usable candidate",
|
||||
model_name="fusion-router",
|
||||
backend="bert-first-fusion",
|
||||
metadata={
|
||||
"grade": "low",
|
||||
"decision": "reject",
|
||||
"unknown_detected": True,
|
||||
"ranked_intents": [],
|
||||
},
|
||||
candidates=[],
|
||||
)
|
||||
selected_candidate = ranked_candidates[0]
|
||||
selected_intent = selected_candidate.intent_id
|
||||
top_score = float(selected_candidate.score)
|
||||
second_score = float(ranked_candidates[1].score) if len(ranked_candidates) > 1 else 0.0
|
||||
top_margin = top_score - second_score
|
||||
grade = self._fusion_grade(top_score)
|
||||
classifier_backend = str(classifier_stage.backend or classifier_stage.model_name or "")
|
||||
classifier_signal = top_score
|
||||
classifier_margin = top_margin
|
||||
bert_classifier_confident = (
|
||||
"bert" in classifier_backend
|
||||
and classifier_signal >= self._classifier_execute_score_threshold
|
||||
and classifier_margin >= self._classifier_execute_margin_threshold
|
||||
)
|
||||
ambiguous = (
|
||||
classifier_signal >= self._route_to_cloud_threshold
|
||||
and classifier_signal < self._classifier_execute_score_threshold
|
||||
and top_margin < self._clarify_margin_threshold
|
||||
and len(ranked_candidates) > 1
|
||||
)
|
||||
accepted = bert_classifier_confident and not ambiguous
|
||||
possible_known_intent = classifier_signal >= self._route_to_cloud_threshold or classifier_signal >= 0.24
|
||||
unknown_detected = not accepted and not possible_known_intent
|
||||
if accepted:
|
||||
decision = "execute"
|
||||
reason = f"bert classifier is confident enough to execute (grade={grade})"
|
||||
elif ambiguous:
|
||||
decision = "clarify"
|
||||
reason = "bert top candidates are too close and require a short clarification"
|
||||
elif possible_known_intent:
|
||||
decision = "route_to_cloud"
|
||||
reason = "bert signal is not stable enough, routing to cloud planner"
|
||||
else:
|
||||
decision = "reject"
|
||||
reason = "bert signal is below local capability threshold"
|
||||
metadata = {
|
||||
"grade": grade,
|
||||
"classifier_signal": round(classifier_signal, 4),
|
||||
"classifier_margin": round(classifier_margin, 4),
|
||||
"classifier_backend": classifier_backend or None,
|
||||
"bert_classifier_confident": bert_classifier_confident,
|
||||
"top_margin": round(top_margin, 4),
|
||||
"route_to_cloud_threshold": self._route_to_cloud_threshold,
|
||||
"clarify_margin_threshold": self._clarify_margin_threshold,
|
||||
"classifier_execute_score_threshold": self._classifier_execute_score_threshold,
|
||||
"classifier_execute_margin_threshold": self._classifier_execute_margin_threshold,
|
||||
"decision": decision,
|
||||
"unknown_detected": unknown_detected,
|
||||
"ranked_intents": [
|
||||
{"intent_id": item.intent_id, "score": round(float(item.score), 4)}
|
||||
for item in ranked_candidates[:5]
|
||||
],
|
||||
}
|
||||
return MatcherStageDebug(
|
||||
stage="fusion",
|
||||
accepted=accepted or ambiguous,
|
||||
selected_intent=selected_intent if decision in {"execute", "clarify", "route_to_cloud"} else None,
|
||||
score=top_score,
|
||||
reason=reason,
|
||||
model_name="fusion-router",
|
||||
backend="bert-first-fusion",
|
||||
metadata=metadata,
|
||||
candidates=ranked_candidates[:5],
|
||||
)
|
||||
|
||||
def _fusion_grade(self, score: float) -> str:
|
||||
if score >= self._classifier_execute_score_threshold:
|
||||
return "high"
|
||||
if score >= self._route_to_cloud_threshold:
|
||||
return "medium"
|
||||
return "low"
|
||||
|
||||
|
||||
class HeuristicSlotExtractor:
|
||||
def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
|
||||
slots: dict[str, Any] = {}
|
||||
order_id_match = re.search(r"\b[A-Za-z]\d{5,}\b", text)
|
||||
if order_id_match:
|
||||
slots["order_id"] = order_id_match.group(0)
|
||||
|
||||
temperature_match = re.search(r"(\d{2})\s*度", text)
|
||||
if temperature_match:
|
||||
slots["temperature"] = int(temperature_match.group(1))
|
||||
|
||||
if intent.intent_id == "cabin_nav_to":
|
||||
destination = self._extract_destination(text)
|
||||
if destination:
|
||||
slots["destination"] = destination
|
||||
|
||||
if intent.intent_id == "cabin_play_music":
|
||||
if "播放" in text:
|
||||
music_target = text.split("播放", maxsplit=1)[-1].strip(" ,。")
|
||||
if music_target:
|
||||
slots["song"] = music_target
|
||||
elif "音乐" in text:
|
||||
slots["genre"] = "轻音乐"
|
||||
|
||||
return slots
|
||||
|
||||
def _extract_destination(self, text: str) -> str | None:
|
||||
patterns = [
|
||||
r"导航去(?P<destination>.+)",
|
||||
r"导航到(?P<destination>.+)",
|
||||
r"去(?P<destination>.+)",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
destination = match.group("destination").strip(" ,。")
|
||||
if destination:
|
||||
return destination
|
||||
return None
|
||||
|
||||
|
||||
class JointBertSlotExtractor:
|
||||
def __init__(self, nlu: JointBertNLU) -> None:
|
||||
self._nlu = nlu
|
||||
|
||||
def extract(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
|
||||
return self._nlu.extract_slots(text, intent)
|
||||
|
||||
|
||||
class Router(Protocol):
|
||||
def route(self, text: str) -> RouteMatchResult:
|
||||
...
|
||||
|
||||
def match_intent(self, text: str) -> IntentDefinition | None:
|
||||
...
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
|
||||
...
|
||||
|
||||
|
||||
class IntentRouter:
|
||||
def __init__(self, matcher: IntentMatcher, slot_extractor: SlotExtractor) -> None:
|
||||
self._matcher = matcher
|
||||
self._slot_extractor = slot_extractor
|
||||
|
||||
def route(self, text: str) -> RouteMatchResult:
|
||||
return self._matcher.match(text)
|
||||
|
||||
def match_intent(self, text: str) -> IntentDefinition | None:
|
||||
return self.route(text).intent
|
||||
|
||||
def extract_slots(self, text: str, intent: IntentDefinition) -> dict[str, Any]:
|
||||
return self._slot_extractor.extract(text, intent)
|
||||
|
||||
|
||||
def build_matcher_pipeline(
|
||||
registry: IntentRegistry,
|
||||
stages: list[str],
|
||||
classifier: IntentClassifier | None = None,
|
||||
route_to_cloud_threshold: float = 0.75,
|
||||
clarify_margin_threshold: float = 0.12,
|
||||
classifier_execute_score_threshold: float = 0.55,
|
||||
classifier_execute_margin_threshold: float = 0.18,
|
||||
) -> MultiStageIntentMatcher:
|
||||
normalized_stages = [stage.strip() for stage in stages if stage.strip()]
|
||||
if not normalized_stages:
|
||||
normalized_stages = ["classifier"]
|
||||
if normalized_stages != ["classifier"]:
|
||||
raise ValueError("Only classifier matcher pipeline is supported in bert-first mode")
|
||||
matcher = ClassifierIntentMatcher(registry, classifier) if classifier is not None else NullIntentMatcher()
|
||||
return MultiStageIntentMatcher(
|
||||
registry,
|
||||
[matcher],
|
||||
route_to_cloud_threshold=route_to_cloud_threshold,
|
||||
clarify_margin_threshold=clarify_margin_threshold,
|
||||
classifier_execute_score_threshold=classifier_execute_score_threshold,
|
||||
classifier_execute_margin_threshold=classifier_execute_margin_threshold,
|
||||
)
|
||||
|
||||
|
||||
class NullIntentMatcher:
|
||||
def match(self, text: str) -> IntentMatchResult:
|
||||
_ = text
|
||||
return IntentMatchResult(
|
||||
intent=None,
|
||||
stage_debug=MatcherStageDebug(
|
||||
stage="null",
|
||||
accepted=False,
|
||||
reason="matcher unavailable",
|
||||
candidates=[],
|
||||
),
|
||||
)
|
||||
125
intelligent_cabin/app/services/session_store.py
Normal file
125
intelligent_cabin/app/services/session_store.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: str
|
||||
status: str = "idle"
|
||||
current_intent: str | None = None
|
||||
pending_slots: list[str] = field(default_factory=list)
|
||||
slots: dict[str, Any] = field(default_factory=dict)
|
||||
workflow: dict[str, Any] | None = None
|
||||
routing_debug: dict[str, Any] | None = None
|
||||
last_user_text: str | None = None
|
||||
last_agent_text: str | None = None
|
||||
context_memory: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"user_id": self.user_id,
|
||||
"channel": self.channel,
|
||||
"status": self.status,
|
||||
"current_intent": self.current_intent,
|
||||
"pending_slots": self.pending_slots,
|
||||
"slots": self.slots,
|
||||
"workflow": self.workflow,
|
||||
"routing_debug": self.routing_debug,
|
||||
"last_user_text": self.last_user_text,
|
||||
"last_agent_text": self.last_agent_text,
|
||||
"context_memory": self.context_memory,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SessionState":
|
||||
return cls(
|
||||
session_id=data["session_id"],
|
||||
user_id=data["user_id"],
|
||||
channel=data.get("channel", "app"),
|
||||
status=data.get("status", "idle"),
|
||||
current_intent=data.get("current_intent"),
|
||||
pending_slots=list(data.get("pending_slots", [])),
|
||||
slots=dict(data.get("slots", {})),
|
||||
workflow=data.get("workflow"),
|
||||
routing_debug=data.get("routing_debug"),
|
||||
last_user_text=data.get("last_user_text"),
|
||||
last_agent_text=data.get("last_agent_text"),
|
||||
context_memory=dict(data.get("context_memory", {})),
|
||||
)
|
||||
|
||||
|
||||
class SessionStore(Protocol):
|
||||
def get_or_create(self, session_id: str, user_id: str, channel: str = "app") -> SessionState:
|
||||
...
|
||||
|
||||
def get(self, session_id: str) -> SessionState | None:
|
||||
...
|
||||
|
||||
def save(self, session: SessionState) -> SessionState:
|
||||
...
|
||||
|
||||
|
||||
class InMemorySessionStore:
|
||||
def __init__(self) -> None:
|
||||
self._sessions: dict[str, SessionState] = {}
|
||||
|
||||
def get_or_create(self, session_id: str, user_id: str, channel: str = "app") -> SessionState:
|
||||
session = self._sessions.get(session_id)
|
||||
if session is None:
|
||||
session = SessionState(session_id=session_id, user_id=user_id, channel=channel)
|
||||
self._sessions[session_id] = session
|
||||
return session
|
||||
|
||||
def get(self, session_id: str) -> SessionState | None:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def save(self, session: SessionState) -> SessionState:
|
||||
self._sessions[session.session_id] = session
|
||||
return session
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str,
|
||||
key_prefix: str = "agent:session",
|
||||
ttl_seconds: int = 86400,
|
||||
) -> None:
|
||||
redis_module = importlib.import_module("redis")
|
||||
self._client = redis_module.Redis.from_url(redis_url, decode_responses=True)
|
||||
self._key_prefix = key_prefix
|
||||
self._ttl_seconds = ttl_seconds
|
||||
|
||||
def get_or_create(self, session_id: str, user_id: str, channel: str = "app") -> SessionState:
|
||||
session = self.get(session_id)
|
||||
if session is not None:
|
||||
return session
|
||||
|
||||
session = SessionState(session_id=session_id, user_id=user_id, channel=channel)
|
||||
self.save(session)
|
||||
return session
|
||||
|
||||
def get(self, session_id: str) -> SessionState | None:
|
||||
payload = self._client.get(self._build_key(session_id))
|
||||
if payload is None:
|
||||
return None
|
||||
return SessionState.from_dict(json.loads(payload))
|
||||
|
||||
def save(self, session: SessionState) -> SessionState:
|
||||
self._client.set(
|
||||
self._build_key(session.session_id),
|
||||
json.dumps(session.to_dict(), ensure_ascii=False),
|
||||
ex=self._ttl_seconds,
|
||||
)
|
||||
return session
|
||||
|
||||
def _build_key(self, session_id: str) -> str:
|
||||
return f"{self._key_prefix}:{session_id}"
|
||||
250
intelligent_cabin/app/services/social.py
Normal file
250
intelligent_cabin/app/services/social.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol
|
||||
from urllib import error, request
|
||||
|
||||
from app.services.session_store import SessionState
|
||||
|
||||
|
||||
SocialCategory = Literal["none", "open_social"]
|
||||
ShortSocialKind = Literal["greeting", "thanks", "goodbye", "capability"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SocialRouteResult:
|
||||
category: SocialCategory
|
||||
reason: str
|
||||
short_kind: ShortSocialKind | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SocialReplyResult:
|
||||
text: str
|
||||
backend: str
|
||||
model_name: str
|
||||
error_message: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class SocialResponder(Protocol):
|
||||
def reply(self, text: str, session: SessionState) -> SocialReplyResult:
|
||||
...
|
||||
|
||||
|
||||
class SocialRouter:
|
||||
_SHORT_SOCIAL_PATTERNS: dict[ShortSocialKind, tuple[str, ...]] = {
|
||||
"greeting": ("你好", "您好", "嗨", "哈喽", "hi", "hello", "在吗", "在不在"),
|
||||
"thanks": ("谢谢", "谢啦", "多谢", "thanks", "thank you", "辛苦了"),
|
||||
"goodbye": ("再见", "拜拜", "回头见", "bye", "goodbye"),
|
||||
"capability": (
|
||||
"你是谁",
|
||||
"你叫什么",
|
||||
"你叫什么名字",
|
||||
"你叫啥",
|
||||
"怎么称呼你",
|
||||
"介绍一下你自己",
|
||||
"你能做什么",
|
||||
"你会什么",
|
||||
"你可以做什么",
|
||||
),
|
||||
}
|
||||
_OPEN_SOCIAL_PATTERNS: tuple[str, ...] = (
|
||||
"天气",
|
||||
"讲个笑话",
|
||||
"笑话",
|
||||
"无聊",
|
||||
"有点累",
|
||||
"有点困",
|
||||
"有点烦",
|
||||
"开心",
|
||||
"不开心",
|
||||
"真不错",
|
||||
"真好",
|
||||
"聊聊天",
|
||||
"你觉得",
|
||||
"你怎么看",
|
||||
"你说呢",
|
||||
)
|
||||
_OPEN_SOCIAL_REGEXES: tuple[re.Pattern[str], ...] = (
|
||||
re.compile(r"今天.*(不错|真好|挺好|真舒服)"),
|
||||
re.compile(r"(好|真)热啊"),
|
||||
re.compile(r"(好|真)冷啊"),
|
||||
re.compile(r"我今天.*(累|困|烦|开心|难过)"),
|
||||
re.compile(r".*(怎么样|如何|咋样)[??]?$"),
|
||||
)
|
||||
_CAPABILITY_REGEXES: tuple[re.Pattern[str], ...] = (
|
||||
re.compile(r"你.*(叫.*名字|叫什么|叫啥)[??]?$"),
|
||||
re.compile(r"(怎么称呼你|介绍一下你自己)[??]?$"),
|
||||
re.compile(r"你.*(能做什么|会什么|可以做什么)[??]?$"),
|
||||
)
|
||||
_TASK_KEYWORDS: tuple[str, ...] = (
|
||||
"订单",
|
||||
"物流",
|
||||
"取消",
|
||||
"转人工",
|
||||
"导航",
|
||||
"去",
|
||||
"到",
|
||||
"空调",
|
||||
"温度",
|
||||
"调到",
|
||||
"播放",
|
||||
"音乐",
|
||||
"歌曲",
|
||||
"车窗",
|
||||
"座椅",
|
||||
"后视镜",
|
||||
"灯光",
|
||||
"除雾",
|
||||
"确认",
|
||||
"不用",
|
||||
)
|
||||
|
||||
def route(self, text: str, session: SessionState) -> SocialRouteResult:
|
||||
normalized = self._normalize(text)
|
||||
if not normalized:
|
||||
return SocialRouteResult(category="none", reason="empty text")
|
||||
if self._looks_like_task(normalized):
|
||||
return SocialRouteResult(category="none", reason="contains task keywords")
|
||||
for short_kind, patterns in self._SHORT_SOCIAL_PATTERNS.items():
|
||||
if any(pattern in normalized for pattern in patterns):
|
||||
return SocialRouteResult(
|
||||
category="open_social",
|
||||
short_kind=short_kind,
|
||||
reason=f"matched social pattern routed to llm: {short_kind}",
|
||||
)
|
||||
if any(regex.search(normalized) for regex in self._CAPABILITY_REGEXES):
|
||||
return SocialRouteResult(
|
||||
category="open_social",
|
||||
short_kind="capability",
|
||||
reason="matched capability social regex routed to llm",
|
||||
)
|
||||
if any(pattern in normalized for pattern in self._OPEN_SOCIAL_PATTERNS):
|
||||
return SocialRouteResult(category="open_social", reason="matched open social phrase")
|
||||
if any(regex.search(normalized) for regex in self._OPEN_SOCIAL_REGEXES):
|
||||
return SocialRouteResult(category="open_social", reason="matched open social regex")
|
||||
if session.context_memory.get("last_dialog_mode") == "open_social" and len(normalized) <= 14:
|
||||
return SocialRouteResult(category="open_social", reason="follow-up to previous open social turn")
|
||||
return SocialRouteResult(category="none", reason="no social pattern matched")
|
||||
|
||||
def _normalize(self, text: str) -> str:
|
||||
return re.sub(r"\s+", "", text.strip().lower())
|
||||
|
||||
def _looks_like_task(self, normalized: str) -> bool:
|
||||
if any(keyword in normalized for keyword in self._TASK_KEYWORDS):
|
||||
return True
|
||||
return bool(re.match(r"^(查|帮我查|打开|关闭|设置|调|导航|播放|取消|转)(.+)", normalized))
|
||||
|
||||
|
||||
class DashScopeSocialResponder:
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
timeout_seconds: float = 6.0,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._api_key = api_key
|
||||
self._model_name = model_name
|
||||
self._timeout_seconds = timeout_seconds
|
||||
|
||||
def reply(self, text: str, session: SessionState) -> SocialReplyResult:
|
||||
if not self._base_url or not self._api_key or not self._model_name:
|
||||
return SocialReplyResult(
|
||||
text="可以和你聊两句,你也可以继续让我处理查询或控制。",
|
||||
backend="local-fallback",
|
||||
model_name="social-fallback",
|
||||
error_message="social responder is not configured",
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": self._model_name,
|
||||
"temperature": 0.6,
|
||||
"enable_thinking": False,
|
||||
"max_tokens": 120,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"你是智能座舱助手,负责处理所有闲聊、问候、身份问答、能力介绍和开放聊天。"
|
||||
"请用自然、口语化、简短的中文回答,优先 1-3 句,总长度尽量不超过 50 个字。"
|
||||
"如果用户在打招呼、问你是谁、问你叫什么名字、问你会什么,请直接自然回答,不要像固定菜单。"
|
||||
"可以结合用户上下文自然接话,但不要过度展开。"
|
||||
"不要编造已经执行了任何车辆或客服动作。"
|
||||
"不要输出 JSON,不要长篇解释。"
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"text": text,
|
||||
"context": {
|
||||
"last_user_text": session.last_user_text,
|
||||
"last_agent_text": session.last_agent_text,
|
||||
"current_intent": session.current_intent,
|
||||
"status": session.status,
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
},
|
||||
],
|
||||
}
|
||||
req = request.Request(
|
||||
self._endpoint(),
|
||||
data=json.dumps(payload).encode("utf-8"),
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with request.urlopen(req, timeout=self._timeout_seconds) as response:
|
||||
data = json.loads(response.read().decode("utf-8"))
|
||||
except (error.URLError, TimeoutError, ValueError) as exc:
|
||||
return SocialReplyResult(
|
||||
text="是啊,听起来今天状态不错。",
|
||||
backend="local-fallback",
|
||||
model_name="social-fallback",
|
||||
error_message=str(exc),
|
||||
)
|
||||
|
||||
content = self._extract_content(data)
|
||||
if not content:
|
||||
return SocialReplyResult(
|
||||
text="可以和你聊两句,你也可以继续说说看。",
|
||||
backend="local-fallback",
|
||||
model_name="social-fallback",
|
||||
error_message="empty social response",
|
||||
)
|
||||
return SocialReplyResult(
|
||||
text=content,
|
||||
backend="dashscope",
|
||||
model_name=self._model_name,
|
||||
)
|
||||
|
||||
def _endpoint(self) -> str:
|
||||
if self._base_url.endswith("/chat/completions"):
|
||||
return self._base_url
|
||||
return f"{self._base_url}/chat/completions"
|
||||
|
||||
def _extract_content(self, payload: dict[str, Any]) -> str:
|
||||
choices = payload.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
return ""
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
str(item.get("text", "")).strip()
|
||||
for item in content
|
||||
if isinstance(item, dict) and item.get("type") == "text"
|
||||
]
|
||||
return "".join(parts).strip()
|
||||
return str(content).strip()
|
||||
Reference in New Issue
Block a user