Files
ai-device/intelligent_cabin/app/services/knowledge_llm.py
2026-06-11 16:28:00 +08:00

239 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
app/services/knowledge_llm.py
当 BERT NLU 未命中时,使用 LLM + knowledge_search function call 查询本地知识库。
流程:
1. 构建 tools=[knowledge_search] 发给 LLM
2. 若 LLM 返回 tool_calls → 执行 KnowledgeStore.search() → 拼结果再发一次 LLM
3. LLM 生成最终回复 reply_text + knowledge_doc_id(可选)
返回 KnowledgeReplyResult,包含:
- reply_text: 简短自然语言摘要
- doc_id / doc_content: 命中的知识文档(供前端渲染 KnowledgeArtifact)
"""
from __future__ import annotations

import json
from dataclasses import dataclass, field
from http.client import HTTPException
from typing import Any
from urllib import error, request

from app.services.knowledge_store import KnowledgeDoc, KnowledgeStore
@dataclass
class KnowledgeReplyResult:
    """Outcome of one knowledge-base Q&A turn.

    Built either from an LLM answer or by the local fallback. The ``doc_*``
    fields are filled only when a knowledge document was matched.
    """

    # Short natural-language answer for the user.
    reply_text: str
    # Which path produced the reply (e.g. "dashscope" or "local-fallback").
    backend: str
    # Model identifier, or fallback label, that produced the reply.
    model_name: str
    # Identifier of the matched knowledge document, if any.
    doc_id: str | None = field(default=None)
    # Raw Markdown content of the matched document, for front-end rendering.
    doc_content: str | None = field(default=None)
    # Title of the matched document, if any.
    doc_title: str | None = field(default=None)
    # Why the fallback was used; None for a normal LLM answer.
    error_message: str | None = field(default=None)
    # Free-form extra data; each instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
# ── LLM tool definition (OpenAI function-calling format, DashScope-compatible) ──
# Offered to the model on every request; key order matches the wire format.
_KNOWLEDGE_SEARCH_TOOL: dict[str, Any] = dict(
    type="function",
    function=dict(
        name="knowledge_search",
        description=(
            "搜索本地设备知识库,获取焊管机/弯管机产线相关的故障排查、操作规程等知识。"
            "当用户问到设备故障、报警处理、操作方法、工艺参数时请调用此工具。"
        ),
        parameters=dict(
            type="object",
            properties=dict(
                query=dict(
                    type="string",
                    description="搜索关键词,如'虚焊报警''激光扫描仪操作''弯管模具调节'",
                ),
            ),
            required=["query"],
        ),
    ),
)
# System prompt for the production-line assistant. In summary, it tells the model to:
# - call knowledge_search for equipment questions;
# - keep answers short and colloquial (conclusion first, then steps, at most 100 characters);
# - answer only from what the tool returned;
# - say it found nothing ("暂未找到相关资料...") when there is no match.
# NOTE: this text is sent verbatim as the system message, so it is behavior, not
# documentation; keep its bytes (including full-width punctuation) intact.
_SYSTEM_PROMPT = """\
你是焊管机产线智能助手,负责回答操作工人关于设备故障、工艺调节、仪器使用的问题。
你有一个工具 knowledge_search 可以查询本地设备知识库,遇到设备类问题时请先调用它。
回答时语言简洁、口语化,先给出结论,再说步骤,总长度不超过 100 字。
如果工具返回了相关知识,请基于知识内容回答,不要编造。
如果没有找到相关知识,诚实告知"暂未找到相关资料,建议联系技术支持"
"""
class DashScopeKnowledgeLLM:
    """Knowledge-base Q&A over DashScope's OpenAI-compatible API using function calling.

    The model is offered a single ``knowledge_search`` tool. Each call it makes is
    executed against the local ``KnowledgeStore``, and the results are sent back
    until the model answers in plain text. Missing configuration, HTTP/transport
    errors, malformed responses and empty answers all degrade to
    ``_local_fallback`` instead of raising. Errors raised by the knowledge store
    itself are not caught here.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        knowledge_store: KnowledgeStore,
        timeout_seconds: float = 12.0,
        max_tool_rounds: int = 2,
    ) -> None:
        """Store connection settings; no network I/O happens here.

        Args:
            base_url: API base URL. A trailing "/" is stripped, and
                "/chat/completions" is appended unless already present.
            api_key: Sent as a Bearer token in the Authorization header.
            model_name: Value of the request's "model" field.
            knowledge_store: Store searched when the model calls knowledge_search.
            timeout_seconds: Timeout passed to each HTTP request.
            max_tool_rounds: Maximum number of tool-calling rounds per reply.
                Negative values are treated as 0.
        """
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._model_name = model_name
        self._store = knowledge_store
        self._timeout = timeout_seconds
        self._max_tool_rounds = max(0, max_tool_rounds)

    # ── Main entry point ──────────────────────────────────────────────────────

    def reply(self, user_text: str) -> KnowledgeReplyResult:
        """Answer *user_text*, letting the model call knowledge_search as needed.

        At most ``max_tool_rounds`` rounds of tool calls are executed. The request
        after the last allowed round is sent with ``tool_choice="none"``. This
        forces the model to answer from results it already received, instead of
        requesting another search whose output would never be read.

        Returns:
            On success, a result with ``backend="dashscope"``. It carries the most
            recently matched knowledge document, if any.
            Otherwise, the local fallback result, whose ``error_message`` names
            the failure.
        """
        if not self._base_url or not self._api_key or not self._model_name:
            return self._local_fallback(user_text, "LLM not configured")

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        found_doc: KnowledgeDoc | None = None

        # One request per allowed tool round, plus one final request for the answer.
        for round_index in range(self._max_tool_rounds + 1):
            tools_allowed = round_index < self._max_tool_rounds
            raw = self._chat(
                messages,
                tools=[_KNOWLEDGE_SEARCH_TOOL],
                tool_choice="auto" if tools_allowed else "none",
            )
            if raw is None:
                return self._local_fallback(user_text, "LLM request failed")
            choice = self._first_choice(raw)
            if choice is None:
                return self._local_fallback(user_text, "empty choices")

            message = choice.get("message")
            if not isinstance(message, dict):
                message = {}
            # `or []` also covers an explicit "tool_calls": null in the response.
            tool_calls = message.get("tool_calls") or []

            # ── Tool-call branch: run each call, feed results back, ask again ──
            if tool_calls and tools_allowed:
                # The assistant turn (with its tool_calls) must precede the tool results.
                messages.append({"role": "assistant", **message})
                for index, tool_call in enumerate(tool_calls):
                    tool_result, doc = self._execute_tool_call(tool_call, user_text)
                    # Keep an earlier hit if a later search comes back empty.
                    if doc is not None:
                        found_doc = doc
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.get("id") or f"call_{index}",
                        "content": tool_result,
                    })
                continue

            # ── Plain-text answer ──
            content = self._extract_content(message)
            if not content:
                # Tool calls still pending on the last round mean no answer was given.
                reason = "max tool rounds exceeded" if tool_calls else "empty content"
                return self._local_fallback(user_text, reason)
            return KnowledgeReplyResult(
                reply_text=content,
                backend="dashscope",
                model_name=self._model_name,
                doc_id=found_doc.doc_id if found_doc else None,
                doc_content=found_doc.content if found_doc else None,
                doc_title=found_doc.title if found_doc else None,
            )

        # Not reached: the final iteration never takes the tool-call branch.
        return self._local_fallback(user_text, "max tool rounds exceeded")

    # ── Tool execution ────────────────────────────────────────────────────────

    def _execute_tool_call(
        self, tool_call: dict[str, Any], user_text: str
    ) -> tuple[str, KnowledgeDoc | None]:
        """Run one tool call from the model; return (tool message content, matched doc)."""
        function = tool_call.get("function")
        if not isinstance(function, dict):
            function = {}
        fn_name = function.get("name", "")
        if fn_name != "knowledge_search":
            return f"Unknown tool: {fn_name}", None
        query = self._parse_query(function.get("arguments"), user_text)
        return self._run_knowledge_search(query)

    @staticmethod
    def _parse_query(arguments: Any, default: str) -> str:
        """Extract the ``query`` argument of a knowledge_search call.

        *arguments* is normally a JSON-encoded object string. Any of the
        following fall back to *default* (the user's original text):
        - malformed JSON;
        - a JSON value that is not an object;
        - a missing, blank or non-string ``query``.
        """
        if isinstance(arguments, dict):
            args: Any = arguments
        else:
            try:
                args = json.loads(arguments or "{}")
            except (TypeError, json.JSONDecodeError):
                return default
        if not isinstance(args, dict):
            return default
        query = args.get("query")
        if isinstance(query, str) and query.strip():
            return query
        return default

    def _run_knowledge_search(self, query: str) -> tuple[str, KnowledgeDoc | None]:
        """Search the local store; return (tool result text for the LLM, best doc or None).

        Only the top hit is reported back. Its content is truncated to 800
        characters (not bytes) to bound the size of the tool message.
        """
        results = self._store.search(query, top_k=2)
        if not results:
            return "未找到相关知识文档。", None
        best = results[0]
        excerpt = best.doc.content[:800]
        tool_text = (
            f"[知识库检索结果]\n"
            f"文档:{best.doc.title}\n"
            f"命中关键词:{', '.join(best.matched_keywords)}\n\n"
            f"{excerpt}"
        )
        return tool_text, best.doc

    # ── HTTP ──────────────────────────────────────────────────────────────────

    def _chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        tool_choice: str = "auto",
    ) -> dict[str, Any] | None:
        """POST one chat-completions request and return the decoded JSON object.

        Args:
            messages: Conversation so far, in OpenAI chat format.
            tools: Tool definitions to offer. When empty/None, neither "tools"
                nor "tool_choice" is sent.
            tool_choice: Sent alongside *tools*. "none" forbids further tool calls.

        Returns:
            The decoded response object. Returns None if any of these occur:
            - a transport error or timeout;
            - an HTTP error status;
            - an undecodable or invalid JSON body;
            - a JSON body that is not an object.
        """
        payload: dict[str, Any] = {
            "model": self._model_name,
            "temperature": 0.3,
            "enable_thinking": False,
            "max_tokens": 300,
            "messages": messages,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = tool_choice
        req = request.Request(
            self._endpoint(),
            data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self._api_key}",
            },
            method="POST",
        )
        try:
            with request.urlopen(req, timeout=self._timeout) as resp:
                body = json.loads(resp.read().decode("utf-8"))
        except (error.URLError, OSError, ValueError, HTTPException):
            # The exceptions caught here map to these failure cases:
            # - URLError/HTTPError/timeouts are OSError subclasses;
            # - ValueError covers bad JSON and undecodable bytes;
            # - HTTPException covers protocol failures such as IncompleteRead.
            return None
        return body if isinstance(body, dict) else None

    def _endpoint(self) -> str:
        """Return the full chat-completions URL derived from the base URL."""
        if self._base_url.endswith("/chat/completions"):
            return self._base_url
        return f"{self._base_url}/chat/completions"

    def _first_choice(self, payload: dict[str, Any]) -> dict[str, Any] | None:
        """Return ``payload["choices"][0]`` if it is a dict, else None."""
        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return None
        first = choices[0]
        return first if isinstance(first, dict) else None

    def _extract_content(self, message: dict[str, Any]) -> str:
        """Return the message's text content, stripped; "" when there is none.

        Handles both plain-string content and a list of typed parts, where only
        parts with ``type == "text"`` are kept.
        """
        content = message.get("content")
        # An explicit null (common alongside tool_calls) means "no text", not "None".
        if content is None:
            return ""
        if isinstance(content, list):
            return "".join(
                str(item.get("text", "")).strip()
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            ).strip()
        return str(content).strip()

    def _local_fallback(self, _user_text: str, reason: str) -> KnowledgeReplyResult:
        """Build the fixed "nothing found" reply; *reason* goes into error_message.

        The user text is accepted for interface symmetry but is not used.
        """
        return KnowledgeReplyResult(
            reply_text="暂未找到相关资料,建议联系技术支持或查阅设备手册。",
            backend="local-fallback",
            model_name="knowledge-fallback",
            error_message=reason,
        )