"""
app/services/knowledge_llm.py

When the BERT NLU does not match, use an LLM together with the
knowledge_search function call to query the local knowledge base.

Flow:
1. Send the request to the LLM with tools=[knowledge_search].
2. If the LLM returns tool_calls -> run KnowledgeStore.search() -> append the
   results to the conversation and call the LLM again.
3. The LLM produces the final reply: reply_text + knowledge_doc_id (optional).

Returns a KnowledgeReplyResult containing:
- reply_text: short natural-language summary
- doc_id / doc_content: the matched knowledge document (used by the front end
  to render a KnowledgeArtifact)
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
from urllib import error, request
|
||
|
||
from app.services.knowledge_store import KnowledgeDoc, KnowledgeStore
|
||
|
||
|
||
@dataclass
class KnowledgeReplyResult:
    """Result of one knowledge-base question/answer exchange.

    The three leading fields are required; everything else is optional and
    defaults to "absent" (``None`` or an empty dict).
    """

    # Text of the reply shown to the user.
    reply_text: str
    # Identifier of the backend that produced the reply.
    backend: str
    # Name of the model that produced the reply.
    model_name: str

    # Matched knowledge document, if any.
    doc_id: str | None = None
    # Raw Markdown content of the document, used by the front end for rendering.
    doc_content: str | None = None
    doc_title: str | None = None

    # Reason the reply degraded to a fallback, when applicable.
    error_message: str | None = None

    # Free-form extra data; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
# ── LLM tool definition (OpenAI function-calling format, DashScope compatible) ──

# JSON Schema of the single argument accepted by the knowledge_search tool.
_KNOWLEDGE_SEARCH_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "query": {
            "type": "string",
            "description": "搜索关键词,如'虚焊报警'、'激光扫描仪操作'、'弯管模具调节'等",
        },
    },
    "required": ["query"],
}

# Tool descriptor sent to the LLM in the ``tools`` list of a chat request.
_KNOWLEDGE_SEARCH_TOOL: dict[str, Any] = {
    "type": "function",
    "function": {
        "name": "knowledge_search",
        "description": (
            "搜索本地设备知识库,获取焊管机/弯管机产线相关的故障排查、操作规程等知识。"
            "当用户问到设备故障、报警处理、操作方法、工艺参数时请调用此工具。"
        ),
        "parameters": _KNOWLEDGE_SEARCH_PARAMETERS,
    },
}
|
||
|
||
# System prompt sent as the first message of every conversation.
# One literal per prompt line; each line ends with a newline.
_SYSTEM_PROMPT = (
    "你是焊管机产线智能助手,负责回答操作工人关于设备故障、工艺调节、仪器使用的问题。\n"
    "你有一个工具 knowledge_search 可以查询本地设备知识库,遇到设备类问题时请先调用它。\n"
    "回答时语言简洁、口语化,先给出结论,再说步骤,总长度不超过 100 字。\n"
    "如果工具返回了相关知识,请基于知识内容回答,不要编造。\n"
    '如果没有找到相关知识,诚实告知"暂未找到相关资料,建议联系技术支持"。\n'
)
|
||
|
||
|
||
class DashScopeKnowledgeLLM:
    """Knowledge-base question answering over DashScope's OpenAI-compatible API.

    The model is offered a single ``knowledge_search`` tool.  When it asks for
    the tool, the search runs against the injected knowledge store and the
    result is fed back so the model can write the final answer.  Any failure
    (missing configuration, transport error, malformed response) degrades to
    a fixed local fallback reply instead of raising.
    """

    # Maximum number of characters (not bytes) of the best document's content
    # that is passed back to the LLM as the tool result.
    _EXCERPT_CHARS = 800

    def __init__(
        self,
        base_url: str,
        api_key: str,
        model_name: str,
        knowledge_store: KnowledgeStore,
        timeout_seconds: float = 12.0,
        max_tool_rounds: int = 2,
    ) -> None:
        """Store the connection settings and the knowledge store.

        Args:
            base_url: API base URL; trailing slashes are removed.  May already
                end in ``/chat/completions``.
            api_key: Bearer token sent in the ``Authorization`` header.
            model_name: Model identifier sent in each request.
            knowledge_store: Object whose ``search(query, top_k=...)`` method
                performs the local knowledge lookup.
            timeout_seconds: Timeout for each HTTP request.
            max_tool_rounds: Upper bound on chat requests issued per
                ``reply`` call.
        """
        # ``or ""`` lets an unset base URL reach the "not configured" check in
        # ``reply`` instead of failing here.
        self._base_url = (base_url or "").rstrip("/")
        self._api_key = api_key
        self._model_name = model_name
        self._store = knowledge_store
        self._timeout = timeout_seconds
        self._max_tool_rounds = max_tool_rounds

    # ── Main entry point ────────────────────────────────────────────────────

    def reply(self, user_text: str) -> KnowledgeReplyResult:
        """Run the function-calling conversation and return the final reply.

        At most ``max_tool_rounds`` chat requests are issued.  A response that
        asks for tool calls consumes one request, so when the last allowed
        request still asks for a tool the method returns the local fallback.

        Args:
            user_text: The user's question.

        Returns:
            A ``KnowledgeReplyResult``.  On success ``backend`` is
            ``"dashscope"`` and the ``doc_*`` fields describe the most recent
            knowledge document found by a search (if any).  On any failure the
            local fallback result is returned with ``error_message`` set.
        """
        if not self._base_url or not self._api_key or not self._model_name:
            return self._local_fallback(user_text, "LLM not configured")

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        found_doc: KnowledgeDoc | None = None

        for _ in range(self._max_tool_rounds):
            raw = self._chat(messages, tools=[_KNOWLEDGE_SEARCH_TOOL])
            if raw is None:
                return self._local_fallback(user_text, "LLM request failed")

            choice = self._first_choice(raw)
            if choice is None:
                return self._local_fallback(user_text, "empty choices")

            # Responses may carry explicit nulls; normalise before using .get().
            message = choice.get("message")
            if not isinstance(message, dict):
                message = {}
            tool_calls = message.get("tool_calls")
            if not isinstance(tool_calls, list):
                tool_calls = []

            # ── Tool-call branch ────────────────────────────────────────────
            if choice.get("finish_reason") == "tool_calls" or tool_calls:
                messages.append({"role": "assistant", **message})

                for call in tool_calls:
                    if not isinstance(call, dict):
                        continue
                    function = call.get("function")
                    if not isinstance(function, dict):
                        function = {}
                    fn_name = function.get("name", "")

                    if fn_name == "knowledge_search":
                        query = self._parse_search_query(
                            function.get("arguments"), user_text
                        )
                        tool_result, doc = self._run_knowledge_search(query)
                        # Keep an earlier hit when a later search finds nothing.
                        if doc is not None:
                            found_doc = doc
                    else:
                        tool_result = f"Unknown tool: {fn_name}"

                    messages.append({
                        "role": "tool",
                        "tool_call_id": call.get("id", "call_0"),
                        "content": tool_result,
                    })
                # Ask the LLM again with the tool results appended.
                continue

            # ── Plain text reply ────────────────────────────────────────────
            content = self._extract_content(message)
            if not content:
                return self._local_fallback(user_text, "empty content")

            return KnowledgeReplyResult(
                reply_text=content,
                backend="dashscope",
                model_name=self._model_name,
                doc_id=found_doc.doc_id if found_doc else None,
                doc_content=found_doc.content if found_doc else None,
                doc_title=found_doc.title if found_doc else None,
            )

        # Request budget exhausted while the model was still asking for tools.
        return self._local_fallback(user_text, "max tool rounds exceeded")

    # ── Tool execution ──────────────────────────────────────────────────────

    @staticmethod
    def _parse_search_query(raw_arguments: Any, default: str) -> str:
        """Extract the ``query`` argument of a ``knowledge_search`` call.

        Args:
            raw_arguments: The ``arguments`` field of the tool call: normally
                a JSON-encoded object, but an already-decoded dict is accepted.
            default: Value used when no usable query can be extracted.

        Returns:
            The query string, or ``default`` when the arguments are missing,
            not valid JSON, not a JSON object, or carry a blank or non-string
            ``query``.
        """
        arguments = raw_arguments
        if isinstance(arguments, (str, bytes, bytearray)):
            try:
                arguments = json.loads(arguments)
            except ValueError:  # includes json.JSONDecodeError
                return default
        if not isinstance(arguments, dict):
            return default
        query = arguments.get("query")
        if not isinstance(query, str) or not query.strip():
            return default
        return query

    def _run_knowledge_search(self, query: str) -> tuple[str, KnowledgeDoc | None]:
        """Search the local knowledge store.

        Args:
            query: Search keywords.

        Returns:
            ``(tool_result_text, best_doc)``.  ``best_doc`` is ``None`` when
            the store returns no results.
        """
        results = self._store.search(query, top_k=2)
        if not results:
            return "未找到相关知识文档。", None

        # Only the first result is forwarded to the LLM.
        best = results[0]
        excerpt = best.doc.content[: self._EXCERPT_CHARS]
        tool_text = (
            f"[知识库检索结果]\n"
            f"文档:{best.doc.title}\n"
            f"命中关键词:{', '.join(best.matched_keywords)}\n\n"
            f"{excerpt}"
        )
        return tool_text, best.doc

    # ── HTTP call ───────────────────────────────────────────────────────────

    def _chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any] | None:
        """POST one chat-completions request.

        Args:
            messages: Conversation so far.
            tools: Optional tool descriptors; when given, ``tool_choice`` is
                set to ``"auto"``.

        Returns:
            The decoded JSON object, or ``None`` on any transport, decoding or
            URL error, or when the body is not a JSON object.
        """
        # Local import keeps this block self-contained.
        from http.client import HTTPException

        payload: dict[str, Any] = {
            "model": self._model_name,
            "temperature": 0.3,
            "enable_thinking": False,
            "max_tokens": 300,
            "messages": messages,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        try:
            # Request() raises ValueError for a URL without a scheme, so it
            # belongs inside the guarded region as well.
            req = request.Request(
                self._endpoint(),
                data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self._api_key}",
                },
                method="POST",
            )
            with request.urlopen(req, timeout=self._timeout) as resp:
                body = json.loads(resp.read().decode("utf-8"))
        except (error.URLError, TimeoutError, ValueError, OSError, HTTPException):
            # HTTPException covers e.g. IncompleteRead, which is not an OSError.
            return None
        return body if isinstance(body, dict) else None

    def _endpoint(self) -> str:
        """Return the chat-completions URL derived from the base URL."""
        if self._base_url.endswith("/chat/completions"):
            return self._base_url
        return f"{self._base_url}/chat/completions"

    def _first_choice(self, payload: dict[str, Any]) -> dict[str, Any] | None:
        """Return the first entry of ``payload["choices"]``.

        Returns ``None`` when ``choices`` is missing, not a list, empty, or
        its first entry is not a dict.
        """
        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return None
        first = choices[0]
        return first if isinstance(first, dict) else None

    def _extract_content(self, message: dict[str, Any]) -> str:
        """Return the stripped text content of an assistant message.

        Handles both a plain string and a list of content parts (only parts
        with ``type == "text"`` are used).  A missing or ``None`` content
        yields ``""`` rather than the literal string ``"None"``.
        """
        content = message.get("content")
        if content is None:
            return ""
        if isinstance(content, list):
            return "".join(
                str(part.get("text") or "").strip()
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            ).strip()
        return str(content).strip()

    def _local_fallback(self, _user_text: str, reason: str) -> KnowledgeReplyResult:
        """Build the fixed fallback reply, recording ``reason`` as the error."""
        return KnowledgeReplyResult(
            reply_text="暂未找到相关资料,建议联系技术支持或查阅设备手册。",
            backend="local-fallback",
            model_name="knowledge-fallback",
            error_message=reason,
        )
|