Files
2026-06-11 16:28:00 +08:00

251 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import re
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol
from urllib import error, request
from app.services.session_store import SessionState
# Routing decision for one utterance: "open_social" sends it to the social
# responder, "none" means the social router does not claim it.
SocialCategory = Literal["none", "open_social"]
# Sub-type labels for short social utterances; these are the keys of the
# router's short-pattern table and the values of SocialRouteResult.short_kind.
ShortSocialKind = Literal["greeting", "thanks", "goodbye", "capability"]
@dataclass
class SocialRouteResult:
    """Outcome of running one utterance through the social router."""

    # Routing decision ("none" or "open_social").
    category: SocialCategory
    # Human-readable explanation of which rule produced the decision.
    reason: str
    # Short-social sub-type when one was recognised, otherwise None.
    short_kind: ShortSocialKind | None = field(default=None)
@dataclass
class SocialReplyResult:
    """Reply produced for a social utterance, plus where it came from."""

    # Text to return to the user.
    text: str
    # Name of the backend that produced the text.
    backend: str
    # Name of the model reported for this reply.
    model_name: str
    # Failure description when a fallback reply was used, otherwise None.
    error_message: str | None = field(default=None)
    # Free-form extra data; every instance gets its own empty dict.
    metadata: dict[str, Any] = field(default_factory=dict)
class SocialResponder(Protocol):
    """Structural interface for objects that can answer a social utterance."""

    def reply(self, text: str, session: SessionState) -> SocialReplyResult:
        """Return a reply for *text*, given the current *session*."""
class SocialRouter:
    """Rule-based router that decides whether an utterance is small talk.

    The decision is made on a normalized copy of the text (lower-cased, all
    whitespace removed) by checking, in order: task keywords, short social
    phrases, capability regexes, open-social phrases, open-social regexes and
    finally a short follow-up to a previous open-social turn.
    """

    # NOTE(review): blank ("") entries were removed from the "greeting" tuple
    # and from _TASK_KEYWORDS. An empty string is a substring of every text,
    # so those entries made every utterance look like a task (and would have
    # made every utterance a greeting). If they stood for real keywords lost
    # to an encoding problem, restore the intended text from version control.
    _SHORT_SOCIAL_PATTERNS: dict[ShortSocialKind, tuple[str, ...]] = {
        "greeting": ("你好", "您好", "哈喽", "hi", "hello", "在吗", "在不在"),
        "thanks": ("谢谢", "谢啦", "多谢", "thanks", "thank you", "辛苦了"),
        "goodbye": ("再见", "拜拜", "回头见", "bye", "goodbye"),
        "capability": (
            "你是谁",
            "你叫什么",
            "你叫什么名字",
            "你叫啥",
            "怎么称呼你",
            "介绍一下你自己",
            "你能做什么",
            "你会什么",
            "你可以做什么",
        ),
    }
    _OPEN_SOCIAL_PATTERNS: tuple[str, ...] = (
        "天气",
        "讲个笑话",
        "笑话",
        "无聊",
        "有点累",
        "有点困",
        "有点烦",
        "开心",
        "不开心",
        "真不错",
        "真好",
        "聊聊天",
        "你觉得",
        "你怎么看",
        "你说呢",
    )
    _OPEN_SOCIAL_REGEXES: tuple[re.Pattern[str], ...] = (
        re.compile(r"今天.*(不错|真好|挺好|真舒服)"),
        re.compile(r"(好|真)热啊"),
        re.compile(r"(好|真)冷啊"),
        re.compile(r"我今天.*(累|困|烦|开心|难过)"),
        re.compile(r".*(怎么样|如何|咋样)[?]?$"),
    )
    _CAPABILITY_REGEXES: tuple[re.Pattern[str], ...] = (
        re.compile(r"你.*(叫.*名字|叫什么|叫啥)[?]?$"),
        re.compile(r"(怎么称呼你|介绍一下你自己)[?]?$"),
        re.compile(r"你.*(能做什么|会什么|可以做什么)[?]?$"),
    )
    _TASK_KEYWORDS: tuple[str, ...] = (
        "订单",
        "物流",
        "取消",
        "转人工",
        "导航",
        "空调",
        "温度",
        "调到",
        "播放",
        "音乐",
        "歌曲",
        "车窗",
        "座椅",
        "后视镜",
        "灯光",
        "除雾",
        "确认",
        "不用",
    )

    def route(self, text: str, session: SessionState) -> SocialRouteResult:
        """Classify *text* as social ("open_social") or not ("none").

        Args:
            text: Raw user utterance.
            session: Current session; only ``context_memory`` is read, to
                detect a follow-up to a previous open-social turn.

        Returns:
            A SocialRouteResult carrying the category, the reason for the
            decision and, for short social phrases, the matched sub-type.
        """
        normalized = self._normalize(text)
        if not normalized:
            return SocialRouteResult(category="none", reason="empty text")
        # Task detection runs first so that commands containing a social
        # phrase are never routed to the social responder.
        if self._looks_like_task(normalized):
            return SocialRouteResult(category="none", reason="contains task keywords")
        for short_kind, patterns in self._SHORT_SOCIAL_PATTERNS.items():
            if self._contains_any(normalized, patterns):
                return SocialRouteResult(
                    category="open_social",
                    short_kind=short_kind,
                    reason=f"matched social pattern routed to llm: {short_kind}",
                )
        if any(regex.search(normalized) for regex in self._CAPABILITY_REGEXES):
            return SocialRouteResult(
                category="open_social",
                short_kind="capability",
                reason="matched capability social regex routed to llm",
            )
        if self._contains_any(normalized, self._OPEN_SOCIAL_PATTERNS):
            return SocialRouteResult(category="open_social", reason="matched open social phrase")
        if any(regex.search(normalized) for regex in self._OPEN_SOCIAL_REGEXES):
            return SocialRouteResult(category="open_social", reason="matched open social regex")
        # A short utterance right after an open-social turn is treated as a
        # continuation of that chat even when no pattern matches.
        if session.context_memory.get("last_dialog_mode") == "open_social" and len(normalized) <= 14:
            return SocialRouteResult(category="open_social", reason="follow-up to previous open social turn")
        return SocialRouteResult(category="none", reason="no social pattern matched")

    @staticmethod
    def _contains_any(normalized: str, patterns: tuple[str, ...]) -> bool:
        """Return True if any non-empty pattern is a substring of *normalized*.

        Empty patterns are skipped because ``"" in s`` is true for every
        string and would turn the check into an unconditional match.
        """
        return any(bool(pattern) and pattern in normalized for pattern in patterns)

    def _normalize(self, text: str) -> str:
        """Lower-case *text* and remove all whitespace."""
        return re.sub(r"\s+", "", text.strip().lower())

    def _looks_like_task(self, normalized: str) -> bool:
        """Return True if *normalized* contains a task keyword or starts with a command verb."""
        if self._contains_any(normalized, self._TASK_KEYWORDS):
            return True
        return bool(re.match(r"^(查|帮我查|打开|关闭|设置|调|导航|播放|取消|转)(.+)", normalized))
class DashScopeSocialResponder:
    """Social responder that calls a chat-completions HTTP endpoint.

    Every failure path (missing configuration, transport error, undecodable
    or empty response) returns a canned local fallback reply instead of
    raising, with the cause recorded in ``error_message``.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        timeout_seconds: float = 6.0,
    ) -> None:
        """Store the connection settings.

        Args:
            base_url: Service base URL, with or without a trailing
                ``/chat/completions``; trailing slashes are stripped.
            api_key: Bearer token sent in the Authorization header.
            model_name: Model identifier sent in the request payload.
            timeout_seconds: Timeout passed to ``urlopen``.
        """
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._model_name = model_name
        self._timeout_seconds = timeout_seconds

    def reply(self, text: str, session: SessionState) -> SocialReplyResult:
        """Ask the model for a short social reply to *text*.

        Args:
            text: The user's utterance.
            session: Current session; its last user/agent text, current
                intent and status are sent to the model as context.

        Returns:
            The model's reply with backend "dashscope", or a local fallback
            reply whose ``error_message`` describes what went wrong.
        """
        # Imported locally so this block does not depend on a new
        # file-level import.
        from http.client import HTTPException

        if not self._base_url or not self._api_key or not self._model_name:
            return self._fallback(
                "可以和你聊两句,你也可以继续让我处理查询或控制。",
                "social responder is not configured",
            )
        req = request.Request(
            self._endpoint(),
            data=json.dumps(self._build_payload(text, session)).encode("utf-8"),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._api_key}",
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=self._timeout_seconds) as response:
                data = json.loads(response.read().decode("utf-8"))
        except (OSError, HTTPException, ValueError) as exc:
            # OSError covers URLError/HTTPError, timeouts and connection
            # resets; HTTPException covers truncated bodies raised while
            # reading; ValueError covers invalid UTF-8 and invalid JSON.
            return self._fallback("是啊,听起来今天状态不错。", str(exc))
        content = self._extract_content(data)
        if not content:
            return self._fallback(
                "可以和你聊两句,你也可以继续说说看。",
                "empty social response",
            )
        return SocialReplyResult(
            text=content,
            backend="dashscope",
            model_name=self._model_name,
        )

    def _build_payload(self, text: str, session: SessionState) -> dict[str, Any]:
        """Build the chat-completions request body for *text* and *session*."""
        return {
            "model": self._model_name,
            "temperature": 0.6,
            "enable_thinking": False,
            "max_tokens": 120,
            "messages": [
                {
                    "role": "system",
                    "content": (
                        "你是智能座舱助手,负责处理所有闲聊、问候、身份问答、能力介绍和开放聊天。"
                        "请用自然、口语化、简短的中文回答,优先 1-3 句,总长度尽量不超过 50 个字。"
                        "如果用户在打招呼、问你是谁、问你叫什么名字、问你会什么,请直接自然回答,不要像固定菜单。"
                        "可以结合用户上下文自然接话,但不要过度展开。"
                        "不要编造已经执行了任何车辆或客服动作。"
                        "不要输出 JSON不要长篇解释。"
                    ),
                },
                {
                    "role": "user",
                    # The utterance and session context are sent as one JSON
                    # document; ensure_ascii=False keeps Chinese text readable.
                    "content": json.dumps(
                        {
                            "text": text,
                            "context": {
                                "last_user_text": session.last_user_text,
                                "last_agent_text": session.last_agent_text,
                                "current_intent": session.current_intent,
                                "status": session.status,
                            },
                        },
                        ensure_ascii=False,
                    ),
                },
            ],
        }

    def _fallback(self, text: str, error_message: str) -> SocialReplyResult:
        """Build a local fallback reply carrying *error_message*."""
        return SocialReplyResult(
            text=text,
            backend="local-fallback",
            model_name="social-fallback",
            error_message=error_message,
        )

    def _endpoint(self) -> str:
        """Return the full chat-completions URL, appending the path if needed."""
        if self._base_url.endswith("/chat/completions"):
            return self._base_url
        return f"{self._base_url}/chat/completions"

    def _extract_content(self, payload: dict[str, Any]) -> str:
        """Return the stripped text of the first choice, or "" if unavailable.

        Handles ``content`` given either as a plain string or as a list of
        parts, of which only ``{"type": "text"}`` parts are used. Any
        unexpected shape (non-dict payload, choice or message; null content)
        yields "" so the caller falls back instead of crashing or replying
        with the literal text "None".
        """
        if not isinstance(payload, dict):
            return ""
        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first_choice = choices[0]
        if not isinstance(first_choice, dict):
            return ""
        message = first_choice.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        if content is None:
            return ""
        if isinstance(content, list):
            parts = [
                str(item.get("text") or "").strip()
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            ]
            return "".join(parts).strip()
        return str(content).strip()