251 lines
9.1 KiB
Python
251 lines
9.1 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Literal, Protocol
|
||
from urllib import error, request
|
||
|
||
from app.services.session_store import SessionState
|
||
|
||
|
||
SocialCategory = Literal["none", "open_social"]
|
||
ShortSocialKind = Literal["greeting", "thanks", "goodbye", "capability"]
|
||
|
||
|
||
@dataclass
|
||
class SocialRouteResult:
|
||
category: SocialCategory
|
||
reason: str
|
||
short_kind: ShortSocialKind | None = None
|
||
|
||
|
||
@dataclass
|
||
class SocialReplyResult:
|
||
text: str
|
||
backend: str
|
||
model_name: str
|
||
error_message: str | None = None
|
||
metadata: dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
class SocialResponder(Protocol):
|
||
def reply(self, text: str, session: SessionState) -> SocialReplyResult:
|
||
...
|
||
|
||
|
||
class SocialRouter:
|
||
_SHORT_SOCIAL_PATTERNS: dict[ShortSocialKind, tuple[str, ...]] = {
|
||
"greeting": ("你好", "您好", "嗨", "哈喽", "hi", "hello", "在吗", "在不在"),
|
||
"thanks": ("谢谢", "谢啦", "多谢", "thanks", "thank you", "辛苦了"),
|
||
"goodbye": ("再见", "拜拜", "回头见", "bye", "goodbye"),
|
||
"capability": (
|
||
"你是谁",
|
||
"你叫什么",
|
||
"你叫什么名字",
|
||
"你叫啥",
|
||
"怎么称呼你",
|
||
"介绍一下你自己",
|
||
"你能做什么",
|
||
"你会什么",
|
||
"你可以做什么",
|
||
),
|
||
}
|
||
_OPEN_SOCIAL_PATTERNS: tuple[str, ...] = (
|
||
"天气",
|
||
"讲个笑话",
|
||
"笑话",
|
||
"无聊",
|
||
"有点累",
|
||
"有点困",
|
||
"有点烦",
|
||
"开心",
|
||
"不开心",
|
||
"真不错",
|
||
"真好",
|
||
"聊聊天",
|
||
"你觉得",
|
||
"你怎么看",
|
||
"你说呢",
|
||
)
|
||
_OPEN_SOCIAL_REGEXES: tuple[re.Pattern[str], ...] = (
|
||
re.compile(r"今天.*(不错|真好|挺好|真舒服)"),
|
||
re.compile(r"(好|真)热啊"),
|
||
re.compile(r"(好|真)冷啊"),
|
||
re.compile(r"我今天.*(累|困|烦|开心|难过)"),
|
||
re.compile(r".*(怎么样|如何|咋样)[??]?$"),
|
||
)
|
||
_CAPABILITY_REGEXES: tuple[re.Pattern[str], ...] = (
|
||
re.compile(r"你.*(叫.*名字|叫什么|叫啥)[??]?$"),
|
||
re.compile(r"(怎么称呼你|介绍一下你自己)[??]?$"),
|
||
re.compile(r"你.*(能做什么|会什么|可以做什么)[??]?$"),
|
||
)
|
||
_TASK_KEYWORDS: tuple[str, ...] = (
|
||
"订单",
|
||
"物流",
|
||
"取消",
|
||
"转人工",
|
||
"导航",
|
||
"去",
|
||
"到",
|
||
"空调",
|
||
"温度",
|
||
"调到",
|
||
"播放",
|
||
"音乐",
|
||
"歌曲",
|
||
"车窗",
|
||
"座椅",
|
||
"后视镜",
|
||
"灯光",
|
||
"除雾",
|
||
"确认",
|
||
"不用",
|
||
)
|
||
|
||
def route(self, text: str, session: SessionState) -> SocialRouteResult:
|
||
normalized = self._normalize(text)
|
||
if not normalized:
|
||
return SocialRouteResult(category="none", reason="empty text")
|
||
if self._looks_like_task(normalized):
|
||
return SocialRouteResult(category="none", reason="contains task keywords")
|
||
for short_kind, patterns in self._SHORT_SOCIAL_PATTERNS.items():
|
||
if any(pattern in normalized for pattern in patterns):
|
||
return SocialRouteResult(
|
||
category="open_social",
|
||
short_kind=short_kind,
|
||
reason=f"matched social pattern routed to llm: {short_kind}",
|
||
)
|
||
if any(regex.search(normalized) for regex in self._CAPABILITY_REGEXES):
|
||
return SocialRouteResult(
|
||
category="open_social",
|
||
short_kind="capability",
|
||
reason="matched capability social regex routed to llm",
|
||
)
|
||
if any(pattern in normalized for pattern in self._OPEN_SOCIAL_PATTERNS):
|
||
return SocialRouteResult(category="open_social", reason="matched open social phrase")
|
||
if any(regex.search(normalized) for regex in self._OPEN_SOCIAL_REGEXES):
|
||
return SocialRouteResult(category="open_social", reason="matched open social regex")
|
||
if session.context_memory.get("last_dialog_mode") == "open_social" and len(normalized) <= 14:
|
||
return SocialRouteResult(category="open_social", reason="follow-up to previous open social turn")
|
||
return SocialRouteResult(category="none", reason="no social pattern matched")
|
||
|
||
def _normalize(self, text: str) -> str:
|
||
return re.sub(r"\s+", "", text.strip().lower())
|
||
|
||
def _looks_like_task(self, normalized: str) -> bool:
|
||
if any(keyword in normalized for keyword in self._TASK_KEYWORDS):
|
||
return True
|
||
return bool(re.match(r"^(查|帮我查|打开|关闭|设置|调|导航|播放|取消|转)(.+)", normalized))
|
||
|
||
|
||
class DashScopeSocialResponder:
|
||
def __init__(
|
||
self,
|
||
base_url: str,
|
||
api_key: str,
|
||
model_name: str,
|
||
timeout_seconds: float = 6.0,
|
||
) -> None:
|
||
self._base_url = base_url.rstrip("/")
|
||
self._api_key = api_key
|
||
self._model_name = model_name
|
||
self._timeout_seconds = timeout_seconds
|
||
|
||
def reply(self, text: str, session: SessionState) -> SocialReplyResult:
|
||
if not self._base_url or not self._api_key or not self._model_name:
|
||
return SocialReplyResult(
|
||
text="可以和你聊两句,你也可以继续让我处理查询或控制。",
|
||
backend="local-fallback",
|
||
model_name="social-fallback",
|
||
error_message="social responder is not configured",
|
||
)
|
||
|
||
payload = {
|
||
"model": self._model_name,
|
||
"temperature": 0.6,
|
||
"enable_thinking": False,
|
||
"max_tokens": 120,
|
||
"messages": [
|
||
{
|
||
"role": "system",
|
||
"content": (
|
||
"你是智能座舱助手,负责处理所有闲聊、问候、身份问答、能力介绍和开放聊天。"
|
||
"请用自然、口语化、简短的中文回答,优先 1-3 句,总长度尽量不超过 50 个字。"
|
||
"如果用户在打招呼、问你是谁、问你叫什么名字、问你会什么,请直接自然回答,不要像固定菜单。"
|
||
"可以结合用户上下文自然接话,但不要过度展开。"
|
||
"不要编造已经执行了任何车辆或客服动作。"
|
||
"不要输出 JSON,不要长篇解释。"
|
||
),
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": json.dumps(
|
||
{
|
||
"text": text,
|
||
"context": {
|
||
"last_user_text": session.last_user_text,
|
||
"last_agent_text": session.last_agent_text,
|
||
"current_intent": session.current_intent,
|
||
"status": session.status,
|
||
},
|
||
},
|
||
ensure_ascii=False,
|
||
),
|
||
},
|
||
],
|
||
}
|
||
req = request.Request(
|
||
self._endpoint(),
|
||
data=json.dumps(payload).encode("utf-8"),
|
||
headers={
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {self._api_key}",
|
||
},
|
||
method="POST",
|
||
)
|
||
try:
|
||
with request.urlopen(req, timeout=self._timeout_seconds) as response:
|
||
data = json.loads(response.read().decode("utf-8"))
|
||
except (error.URLError, TimeoutError, ValueError) as exc:
|
||
return SocialReplyResult(
|
||
text="是啊,听起来今天状态不错。",
|
||
backend="local-fallback",
|
||
model_name="social-fallback",
|
||
error_message=str(exc),
|
||
)
|
||
|
||
content = self._extract_content(data)
|
||
if not content:
|
||
return SocialReplyResult(
|
||
text="可以和你聊两句,你也可以继续说说看。",
|
||
backend="local-fallback",
|
||
model_name="social-fallback",
|
||
error_message="empty social response",
|
||
)
|
||
return SocialReplyResult(
|
||
text=content,
|
||
backend="dashscope",
|
||
model_name=self._model_name,
|
||
)
|
||
|
||
def _endpoint(self) -> str:
|
||
if self._base_url.endswith("/chat/completions"):
|
||
return self._base_url
|
||
return f"{self._base_url}/chat/completions"
|
||
|
||
def _extract_content(self, payload: dict[str, Any]) -> str:
|
||
choices = payload.get("choices")
|
||
if not isinstance(choices, list) or not choices:
|
||
return ""
|
||
message = choices[0].get("message", {})
|
||
content = message.get("content", "")
|
||
if isinstance(content, list):
|
||
parts = [
|
||
str(item.get("text", "")).strip()
|
||
for item in content
|
||
if isinstance(item, dict) and item.get("type") == "text"
|
||
]
|
||
return "".join(parts).strip()
|
||
return str(content).strip()
|