Files
ai-device/intelligent_cabin/app/services/knowledge_store.py
2026-06-11 16:28:00 +08:00

153 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
app/services/knowledge_store.py
本地 Markdown 知识库加载与关键词检索。
- 所有 .md 文件存放在 config/knowledge/ 目录
- 基于关键词打分,支持多文档排序返回
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class KnowledgeDoc:
    """Metadata and body text of a single knowledge-base document.

    Attributes:
        doc_id: File name without its extension.
        title: Text of the first Markdown ``#`` heading; the file name is
            used when the document has no heading.
        content: The complete, unmodified Markdown source.
        keywords: Keywords extracted from the document body.
    """

    doc_id: str
    title: str
    content: str
    # A fresh list per instance, so documents never share keyword storage.
    keywords: list[str] = field(default_factory=list)
@dataclass
class SearchResult:
    """One ranked hit produced by a knowledge-base search.

    Attributes:
        doc: The document that matched the query.
        score: Relevance score assigned to the document.
        matched_keywords: Query tokens that were found in the document.
    """

    doc: KnowledgeDoc
    score: float
    matched_keywords: list[str]
class KnowledgeStore:
    """Load ``*.md`` files from a directory and rank them by keyword overlap.

    Each Markdown file directly inside ``knowledge_dir`` becomes one
    ``KnowledgeDoc`` keyed by its file stem. A query is broken into
    overlapping runs of Chinese characters, and every run is matched as a
    substring against each document's title and body.
    """

    # Points awarded per matched query token, and the per-document ceiling.
    TITLE_HIT_SCORE = 3.0
    CONTENT_HIT_SCORE = 1.0
    MAX_SCORE = 15.0

    # Bounds (in characters) of the sliding window used to tokenize queries.
    MIN_TOKEN_LEN = 2
    MAX_WINDOW_LEN = 5

    # Heading marker, same-line whitespace, then a title that starts with a
    # visible character. Keeping the separator on one line stops an empty
    # heading from capturing the following body line as its title.
    _HEADING_RE = re.compile(r"^#+[^\S\n]+(\S.*)", re.MULTILINE)
    _MD_SYMBOL_RE = re.compile(r"[#`*_>|~\[\]()!]")
    _URL_RE = re.compile(r"https?://\S+")
    _KEYWORD_RE = re.compile(r"[\u4e00-\u9fff]{2,6}")
    _CJK_RUN_RE = re.compile(r"[\u4e00-\u9fff]+")

    def __init__(self, knowledge_dir: str | Path) -> None:
        """Create the store and load every document found in ``knowledge_dir``.

        Args:
            knowledge_dir: Directory that holds the ``*.md`` files.
        """
        self._dir = Path(knowledge_dir)
        self._docs: dict[str, KnowledgeDoc] = {}
        self._load()

    # ── Public API ──────────────────────────────────────────────────────────

    def search(self, query: str, top_k: int = 3) -> list[SearchResult]:
        """Return the documents most relevant to ``query``.

        Args:
            query: Free-text query. Only its Chinese-character runs are used
                for matching; other text is ignored by the tokenizer.
            top_k: Maximum number of results. Values ``<= 0`` yield ``[]``.

        Returns:
            At most ``top_k`` results with a positive score, highest score
            first. Documents with equal scores keep their load order.
        """
        # A negative top_k would otherwise slice from the end (results[:-1])
        # and return almost everything.
        if top_k <= 0:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        results: list[SearchResult] = []
        for doc in self._docs.values():
            score, matched = self._score(doc, query_tokens)
            if score > 0:
                results.append(
                    SearchResult(doc=doc, score=score, matched_keywords=matched)
                )

        # list.sort is stable, so ties stay in load order.
        results.sort(key=lambda r: r.score, reverse=True)
        return results[:top_k]

    def get(self, doc_id: str) -> KnowledgeDoc | None:
        """Return the document stored under ``doc_id``, or ``None`` if absent."""
        return self._docs.get(doc_id)

    def all_doc_ids(self) -> list[str]:
        """Return the ids of all loaded documents, in load order."""
        return list(self._docs)

    def reload(self) -> None:
        """Re-read the knowledge directory (call after adding new files).

        The previous documents stay in place if reading any file raises.
        """
        self._load()

    # ── Internals ───────────────────────────────────────────────────────────

    def _load(self) -> None:
        """Parse every ``*.md`` file in the directory and replace the doc map.

        A missing directory results in an empty store. The new mapping is
        built completely before it is assigned, so an exception raised while
        parsing leaves the current documents untouched.
        """
        docs: dict[str, KnowledgeDoc] = {}
        if self._dir.is_dir():
            # Sorted so load order (and therefore tie-breaking) is stable.
            for md_path in sorted(self._dir.glob("*.md")):
                doc = self._parse_md(md_path)
                docs[doc.doc_id] = doc
        self._docs = docs

    def _parse_md(self, path: Path) -> KnowledgeDoc:
        """Read one Markdown file and build its ``KnowledgeDoc``.

        Args:
            path: Path of the Markdown file.

        Returns:
            A document whose id is the file stem.
        """
        # "utf-8-sig" decodes plain UTF-8 and also drops a leading BOM, which
        # would otherwise sit in front of a first-line "# Title" and hide it
        # from the heading pattern.
        content = path.read_text(encoding="utf-8-sig")
        doc_id = path.stem
        return KnowledgeDoc(
            doc_id=doc_id,
            title=self._extract_title(content, doc_id),
            content=content,
            keywords=self._extract_keywords(content),
        )

    def _extract_title(self, content: str, fallback: str) -> str:
        """Return the text of the first Markdown heading in ``content``.

        Args:
            content: Markdown source.
            fallback: Value returned when no non-empty heading exists.
        """
        match = self._HEADING_RE.search(content)
        return match.group(1).strip() if match else fallback

    def _extract_keywords(self, content: str) -> list[str]:
        """Collect runs of 2-6 Chinese characters as candidate keywords.

        Markdown punctuation and URLs are blanked out first. The result is
        de-duplicated and keeps first-occurrence order.
        """
        text = self._MD_SYMBOL_RE.sub(" ", content)
        text = self._URL_RE.sub(" ", text)
        # dict.fromkeys drops duplicates while preserving insertion order.
        return list(dict.fromkeys(self._KEYWORD_RE.findall(text)))

    def _tokenize(self, text: str) -> list[str]:
        """Split a query into candidate search tokens.

        For every run of consecutive Chinese characters this emits all
        substrings of length ``MIN_TOKEN_LEN`` to ``MAX_WINDOW_LEN`` (shortest
        first), then the whole run. The sliding window lets a full sentence
        still match a shorter phrase it contains. Runs of a single character
        and all non-Chinese text produce no tokens.

        Returns:
            Unique tokens in first-occurrence order.
        """
        tokens: list[str] = []
        for chunk in self._CJK_RUN_RE.findall(text):
            longest = min(self.MAX_WINDOW_LEN, len(chunk))
            for size in range(self.MIN_TOKEN_LEN, longest + 1):
                tokens.extend(
                    chunk[start : start + size]
                    for start in range(len(chunk) - size + 1)
                )
            # The whole run is kept too, for exact matches on long phrases.
            if len(chunk) >= self.MIN_TOKEN_LEN:
                tokens.append(chunk)
        return list(dict.fromkeys(tokens))

    def _score(
        self, doc: KnowledgeDoc, query_tokens: list[str]
    ) -> tuple[float, list[str]]:
        """Score ``doc`` against the query tokens.

        A token found in the title adds ``TITLE_HIT_SCORE`` (3); a token found
        only in the body adds ``CONTENT_HIT_SCORE`` (1). Matching is a
        case-insensitive substring test, and a token counts once. The total is
        capped at ``MAX_SCORE`` (15) so that long queries cannot produce
        extreme scores.

        Returns:
            ``(score, matched)`` where ``matched`` lists the tokens that hit,
            in query order.
        """
        title = doc.title.lower()
        content = doc.content.lower()

        score = 0.0
        matched: list[str] = []
        seen: set[str] = set()
        for token in query_tokens:
            needle = token.lower()
            if needle in seen:
                continue
            if needle in title:
                score += self.TITLE_HIT_SCORE
            elif needle in content:
                score += self.CONTENT_HIT_SCORE
            else:
                continue
            seen.add(needle)
            matched.append(token)
        return min(score, self.MAX_SCORE), matched