153 lines
5.7 KiB
Python
153 lines
5.7 KiB
Python
"""
|
||
app/services/knowledge_store.py
|
||
|
||
本地 Markdown 知识库加载与关键词检索。
|
||
- 所有 .md 文件存放在 config/knowledge/ 目录
|
||
- 基于关键词打分,支持多文档排序返回
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
|
||
|
||
@dataclass
class KnowledgeDoc:
    """Metadata and body of one knowledge document (one Markdown file)."""

    # File name without its extension (Path.stem); used as the lookup key.
    doc_id: str
    # Text of the first Markdown heading found anywhere in the file
    # (not necessarily the first line); falls back to doc_id when there is none.
    title: str
    # Full raw Markdown content of the file.
    content: str
    # Distinct runs of 2-6 CJK characters from the body, in order of first
    # appearance (NOT frequency-ranked). Currently not used by search scoring.
    keywords: list[str] = field(default_factory=list)
|
||
|
||
|
||
@dataclass(init=True, repr=True, eq=True)
class SearchResult:
    """One ranked hit produced by ``KnowledgeStore.search``."""

    # The knowledge document that matched.
    doc: KnowledgeDoc
    # Relevance score; results are ordered from highest to lowest.
    score: float
    # Query terms that were found in the document's title or body.
    matched_keywords: list[str]
|
||
|
||
|
||
class KnowledgeStore:
    """Load a knowledge base from ``<knowledge_dir>/*.md`` and offer keyword search.

    Each Markdown file becomes one KnowledgeDoc keyed by its file stem. Search
    matches CJK n-grams of the query against each document's title and body
    (plain substring tests); see ``_score`` for the scoring rule.
    """

    def __init__(self, knowledge_dir: str | Path) -> None:
        """Create the store and load every ``*.md`` file in ``knowledge_dir``.

        A missing directory is not an error: the store simply starts empty.
        """
        self._dir = Path(knowledge_dir)
        # Replaced wholesale (never mutated in place) by _load().
        self._docs: dict[str, KnowledgeDoc] = {}
        self._load()

    # ── Public API ─────────────────────────────────────────────────────────────

    def search(self, query: str, top_k: int = 3) -> list[SearchResult]:
        """Return up to ``top_k`` knowledge documents most relevant to ``query``.

        Args:
            query: Free-text query. Only its runs of CJK ideographs are used as
                search terms, so a query without any 2+ character CJK run
                yields no results.
            top_k: Maximum number of results; values <= 0 return ``[]``.

        Returns:
            Results with a positive score, sorted by score descending. Ties keep
            document load order (sorted file path order) because the sort is stable.
        """
        # Without this guard a negative top_k would slice oddly ([:-1] drops
        # only the last hit), so treat any non-positive value as "nothing".
        if top_k <= 0:
            return []
        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        # Take one reference up front; reload() publishes a new dict instead of
        # clearing/refilling this one, so this iteration is never disturbed.
        docs = self._docs
        results: list[SearchResult] = []
        for doc in docs.values():
            score, matched = self._score(doc, query_tokens)
            if score > 0:
                results.append(SearchResult(doc=doc, score=score, matched_keywords=matched))

        results.sort(key=lambda r: r.score, reverse=True)
        return results[:top_k]

    def get(self, doc_id: str) -> KnowledgeDoc | None:
        """Return the document whose file stem is ``doc_id``, or None."""
        return self._docs.get(doc_id)

    def all_doc_ids(self) -> list[str]:
        """Return the ids of all loaded documents, in load order."""
        return list(self._docs.keys())

    def reload(self) -> None:
        """Hot-reload the knowledge base (call after adding or editing .md files).

        The new document set is fully built before it replaces the old one: if
        reading a file raises, the previously loaded documents stay in place.
        """
        self._load()

    # ── Internals ──────────────────────────────────────────────────────────────

    def _load(self) -> None:
        """Build a fresh doc_id -> KnowledgeDoc index from ``self._dir`` and swap it in."""
        if not self._dir.exists():
            # Missing directory -> empty knowledge base rather than an error.
            self._docs = {}
            return
        docs: dict[str, KnowledgeDoc] = {}
        for md_path in sorted(self._dir.glob("*.md")):
            # glob("*.md") can also match a directory named like "x.md".
            if not md_path.is_file():
                continue
            doc = self._parse_md(md_path)
            docs[doc.doc_id] = doc
        # Publish with a single assignment; the old dict is never mutated.
        self._docs = docs

    def _parse_md(self, path: Path) -> KnowledgeDoc:
        """Read one Markdown file and build its KnowledgeDoc.

        Raises:
            OSError / UnicodeDecodeError: if the file cannot be read as UTF-8.
        """
        # "utf-8-sig" strips a leading BOM if present (identical to "utf-8"
        # otherwise). A kept BOM would sit before a first-line "#" heading and
        # defeat the "^#" anchor used for title extraction.
        content = path.read_text(encoding="utf-8-sig")
        doc_id = path.stem
        title = self._extract_title(content, doc_id)
        keywords = self._extract_keywords(content)
        return KnowledgeDoc(doc_id=doc_id, title=title, content=content, keywords=keywords)

    @staticmethod
    def _extract_title(content: str, default: str) -> str:
        """Return the text of the first heading line ("#"s, blank, text) or ``default``.

        Any heading level counts (``#``, ``##``, ...), whichever appears first.
        """
        # [ \t] rather than \s: \s also matches "\n", which let a bare "#" line
        # borrow the following line as its title. (.*\S) rejects blank headings.
        match = re.search(r"^#+[ \t]+(.*\S)", content, re.MULTILINE)
        return match.group(1).strip() if match else default

    def _extract_keywords(self, content: str) -> list[str]:
        """Return runs of 2-6 CJK characters from ``content``, de-duplicated in order.

        A CJK run longer than 6 characters is split into consecutive
        non-overlapping 6-character pieces (a trailing single character is dropped).
        """
        # Replace Markdown syntax characters with spaces first, so a URL that is
        # followed by e.g. ")" does not swallow the CJK text after it below.
        text = re.sub(r"[#`*_>|~\[\]()!]", " ", content)
        text = re.sub(r"https?://\S+", " ", text)
        words = re.findall(r"[\u4e00-\u9fff]{2,6}", text)
        # dict keeps insertion order, so this is an order-preserving de-dup.
        return list(dict.fromkeys(words))

    def _tokenize(self, text: str) -> list[str]:
        """Split a query into candidate search terms.

        Strategy:
        1. Take every maximal run of CJK ideographs (U+4E00-U+9FFF).
        2. For each run, emit every substring of length 2..5 (sliding window),
           plus the whole run when it is longer than one character.
        This lets a sentence-like query such as '虚焊报警怎么办' still match a
        document that only contains '虚焊报警'.

        Non-CJK characters are ignored. Terms are returned de-duplicated in
        first-seen order.
        """
        tokens: list[str] = []
        for chunk in re.findall(r"[\u4e00-\u9fff]+", text):
            # Window sizes run from 2 to min(5, len(chunk)).
            for size in range(2, min(6, len(chunk) + 1)):
                for start in range(len(chunk) - size + 1):
                    tokens.append(chunk[start : start + size])
            # Also add the whole run so long phrases can match exactly.
            if len(chunk) > 1:
                tokens.append(chunk)
        return list(dict.fromkeys(tokens))

    def _score(self, doc: KnowledgeDoc, query_tokens: list[str]) -> tuple[float, list[str]]:
        """Score one document against the query terms.

        Each distinct term earns 3 points if it occurs in the title, otherwise
        1 point if it occurs in the body; the total is capped at 15. Overlapping
        n-grams of one query phrase are separate terms, so a long matching
        phrase accumulates several points.

        Returns:
            ``(score, matched_terms)`` with matched terms in query order.
        """
        # NOTE: _tokenize currently emits only CJK terms, for which lower() is a
        # no-op; the case-folding only matters if Latin terms are added later.
        content_lower = doc.content.lower()
        title_lower = doc.title.lower()
        score = 0.0
        matched: list[str] = []
        seen: set[str] = set()

        for token in query_tokens:
            token_lower = token.lower()
            if token_lower in seen:
                continue
            in_title = token_lower in title_lower
            in_content = token_lower in content_lower
            if in_title or in_content:
                seen.add(token_lower)
                matched.append(token)
                score += 3.0 if in_title else 1.0

        # Cap at 15 so a document containing many query n-grams cannot reach
        # an extreme score.
        return min(score, 15.0), matched
|