Initial upload for secondary development
This commit is contained in:
0
chatlog_fastAPI/routers/__init__.py
Normal file
0
chatlog_fastAPI/routers/__init__.py
Normal file
116
chatlog_fastAPI/routers/ai.py
Normal file
116
chatlog_fastAPI/routers/ai.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Literal
|
||||
import aiosqlite, json, logging
|
||||
import httpx
|
||||
from database import get_db
|
||||
from config import settings
|
||||
from services.ai_client import get_openai_client
|
||||
from services.runtime_settings import get_ai_settings
|
||||
from services.media_parser import parse_media
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["ai"])
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_ai_client():
|
||||
return await get_openai_client()
|
||||
|
||||
|
||||
class SummarizeRequest(BaseModel):
|
||||
context: str # 已组装好的对话文本(含媒体描述)
|
||||
room_name: Optional[str] = ""
|
||||
messages: Optional[list] = None # 兼容旧调用,忽略
|
||||
|
||||
|
||||
class ParseRequest(BaseModel):
|
||||
type: Literal["voice", "image", "video"]
|
||||
key: str # voice: ServerID string; image/video: md5
|
||||
|
||||
|
||||
@router.post("/ai/parse")
|
||||
async def ai_parse(body: ParseRequest):
|
||||
"""
|
||||
通过 FastAPI 代理 AI 媒体解析:
|
||||
- voice: 从 chatlog 下载音频 → DashScope Paraformer ASR 转文字
|
||||
- image/video: 从 chatlog 下载媒体 → base64 → 视觉模型描述
|
||||
"""
|
||||
try:
|
||||
return await parse_media(body.type, body.key)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"[ai/parse] 媒体解析失败: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"媒体解析失败: {e}")
|
||||
|
||||
|
||||
@router.post("/ai/summarize/stream")
|
||||
async def summarize_stream(body: SummarizeRequest):
|
||||
"""
|
||||
接收前端已处理好的对话上下文,调用 AI 模型流式输出总结。
|
||||
前端负责先把媒体(图片/语音/视频)解析成文字再拼进 context。
|
||||
"""
|
||||
_ai = await get_ai_settings()
|
||||
if not _ai.get("ai_api_key"):
|
||||
async def err_gen():
|
||||
yield 'data: {"error": "AI 服务未配置,请在「设置」页面填入 AI API Key"}\n\n'
|
||||
return StreamingResponse(err_gen(), media_type="text/event-stream")
|
||||
if not _ai.get("ai_model"):
|
||||
async def err_gen():
|
||||
yield 'data: {"error": "知识总结模型未配置,请在「设置」页面填入模型名称(如 qwen-max)"}\n\n'
|
||||
return StreamingResponse(err_gen(), media_type="text/event-stream")
|
||||
|
||||
context = body.context.strip()
|
||||
if not context:
|
||||
async def err_gen():
|
||||
yield 'data: {"error": "对话内容为空"}\n\n'
|
||||
return StreamingResponse(err_gen(), media_type="text/event-stream")
|
||||
|
||||
room = body.room_name or "会话"
|
||||
|
||||
system_prompt = (
|
||||
"你是一位专业的对话分析助手。"
|
||||
"请根据提供的聊天记录(可能包含图片描述、语音转文字、视频描述等多媒体内容)"
|
||||
"生成一份结构清晰的 Markdown 总结。"
|
||||
"总结应包含:主要话题、关键信息点、媒体内容要点、待办事项(如有)。"
|
||||
"只输出 Markdown 格式内容,不要有任何额外说明。"
|
||||
)
|
||||
user_prompt = (
|
||||
f"群聊:{room}\n\n"
|
||||
f"以下是聊天记录(含多媒体内容描述):\n\n"
|
||||
f"{context[:12000]}\n\n" # 限制 token 数
|
||||
f"请生成总结:"
|
||||
)
|
||||
|
||||
async def generate():
|
||||
try:
|
||||
_client, _ai = await _get_ai_client()
|
||||
stream = await _client.chat.completions.create(
|
||||
model=_ai["ai_model"],
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
stream=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta.content if chunk.choices else None
|
||||
if delta:
|
||||
yield f"data: {json.dumps({'delta': delta}, ensure_ascii=False)}\n\n"
|
||||
yield 'data: {"done": true}\n\n'
|
||||
except Exception as e:
|
||||
log.error(f"[summarize/stream] LLM 调用失败: {e}", exc_info=True)
|
||||
yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task(task_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM ai_tasks WHERE id=?", (task_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return dict(row)
|
||||
93
chatlog_fastAPI/routers/chatlog_proxy.py
Normal file
93
chatlog_fastAPI/routers/chatlog_proxy.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from config import settings
|
||||
|
||||
router = APIRouter(tags=["chatlog-proxy"])
|
||||
|
||||
HOP_BY_HOP_HEADERS = {
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"te",
|
||||
"trailers",
|
||||
"transfer-encoding",
|
||||
"upgrade",
|
||||
}
|
||||
|
||||
|
||||
def _copy_headers(headers: httpx.Headers) -> dict[str, str]:
|
||||
copied: dict[str, str] = {}
|
||||
for key, value in headers.items():
|
||||
if key.lower() not in HOP_BY_HOP_HEADERS:
|
||||
copied[key] = value
|
||||
return copied
|
||||
|
||||
|
||||
async def _proxy_chatlog(request: Request, upstream_path: str) -> Response:
|
||||
query = request.url.query
|
||||
target = f"{settings.chatlog_base_url}{upstream_path}"
|
||||
if query:
|
||||
target = f"{target}?{query}"
|
||||
|
||||
body = await request.body()
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in request.headers.items()
|
||||
if key.lower() not in HOP_BY_HOP_HEADERS and key.lower() != "host"
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=None, trust_env=False, follow_redirects=True) as client:
|
||||
upstream = await client.request(
|
||||
request.method,
|
||||
target,
|
||||
content=body if body else None,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
response_headers = _copy_headers(upstream.headers)
|
||||
return StreamingResponse(
|
||||
iter([upstream.content]),
|
||||
status_code=upstream.status_code,
|
||||
media_type=upstream.headers.get("content-type"),
|
||||
headers=response_headers,
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/api/v1/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
|
||||
async def proxy_api_v1(path: str, request: Request):
|
||||
return await _proxy_chatlog(request, f"/api/v1/{path}")
|
||||
|
||||
|
||||
async def _proxy_media(kind: str, path: str, request: Request):
|
||||
safe_path = "/".join(quote(part, safe="") for part in path.split("/"))
|
||||
return await _proxy_chatlog(request, f"/{kind}/{safe_path}")
|
||||
|
||||
|
||||
@router.api_route("/image/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_image(path: str, request: Request):
|
||||
return await _proxy_media("image", path, request)
|
||||
|
||||
|
||||
@router.api_route("/voice/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_voice(path: str, request: Request):
|
||||
return await _proxy_media("voice", path, request)
|
||||
|
||||
|
||||
@router.api_route("/video/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_video(path: str, request: Request):
|
||||
return await _proxy_media("video", path, request)
|
||||
|
||||
|
||||
@router.api_route("/file/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_file(path: str, request: Request):
|
||||
return await _proxy_media("file", path, request)
|
||||
|
||||
|
||||
@router.api_route("/data/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_data(path: str, request: Request):
|
||||
return await _proxy_media("data", path, request)
|
||||
190
chatlog_fastAPI/routers/files.py
Normal file
190
chatlog_fastAPI/routers/files.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
from config import settings
|
||||
from services.chatlog_client import chatlog_client
|
||||
|
||||
router = APIRouter(prefix="/api/files", tags=["files"])
|
||||
|
||||
|
||||
OFFICE_MEDIA_TYPES = {
|
||||
".xls": "application/vnd.ms-excel",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".ppt": "application/vnd.ms-powerpoint",
|
||||
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
".doc": "application/msword",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".pdf": "application/pdf",
|
||||
".dwg": "application/acad",
|
||||
}
|
||||
|
||||
|
||||
def _connect_hardlink_db(hardlink_db: Path) -> sqlite3.Connection:
|
||||
"""
|
||||
chatlog may keep hardlink.db open. Copying a tiny snapshot avoids transient
|
||||
"unable to open database file" errors on Windows while keeping reads safe.
|
||||
"""
|
||||
tmp = Path(tempfile.gettempdir()) / f"chatlab_hardlink_{os.getpid()}_{hardlink_db.stat().st_mtime_ns}.db"
|
||||
if not tmp.exists() or tmp.stat().st_size != hardlink_db.stat().st_size:
|
||||
shutil.copy2(hardlink_db, tmp)
|
||||
con = sqlite3.connect(tmp)
|
||||
con.row_factory = sqlite3.Row
|
||||
return con
|
||||
|
||||
|
||||
def _safe_download_name(name: str, fallback: str) -> str:
|
||||
name = (name or fallback).replace("\r", "").replace("\n", "").strip()
|
||||
return name or fallback
|
||||
|
||||
|
||||
def _content_disposition(filename: str) -> str:
|
||||
quoted = quote(filename)
|
||||
ascii_fallback = re.sub(r"[^A-Za-z0-9._-]+", "_", filename) or "download"
|
||||
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quoted}"
|
||||
|
||||
|
||||
def _guess_media_type(filename: str, fallback: str = "") -> str:
|
||||
ext = Path(filename or "").suffix.lower()
|
||||
return OFFICE_MEDIA_TYPES.get(ext) or mimetypes.guess_type(filename)[0] or fallback or "application/octet-stream"
|
||||
|
||||
|
||||
async def _proxy_chatlog_file(md5: str, filename: str = ""):
|
||||
url = f"{settings.chatlog_base_url}/file/{quote(md5, safe='')}"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30, trust_env=False, follow_redirects=True) as client:
|
||||
resp = await client.get(url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if resp.status_code != 200 or resp.content == b'"media not found"':
|
||||
return None
|
||||
|
||||
headers = {
|
||||
"Content-Length": str(len(resp.content)),
|
||||
"X-ChatLab-File-Source": "chatlog",
|
||||
}
|
||||
if filename:
|
||||
headers["Content-Disposition"] = _content_disposition(filename)
|
||||
media_type = _guess_media_type(filename, resp.headers.get("content-type") or "")
|
||||
return StreamingResponse(iter([resp.content]), media_type=media_type, headers=headers)
|
||||
|
||||
|
||||
def _xwechat_roots_from_hardlink_db(hardlink_db: Path) -> list[Path]:
|
||||
roots: list[Path] = []
|
||||
try:
|
||||
con = _connect_hardlink_db(hardlink_db)
|
||||
row = con.execute("SELECT ValueStdStr FROM db_info WHERE Key='uuid'").fetchone()
|
||||
raw = row["ValueStdStr"] if row else ""
|
||||
except Exception:
|
||||
raw = ""
|
||||
|
||||
if raw:
|
||||
m = re.search(r"([A-Za-z]:\\[^|]+?xwechat_files)", raw)
|
||||
if m:
|
||||
roots.append(Path(m.group(1)))
|
||||
|
||||
roots.extend([
|
||||
Path.home() / "xwechat_files",
|
||||
Path.home() / "Documents" / "WeChat Files",
|
||||
])
|
||||
uniq: list[Path] = []
|
||||
seen = set()
|
||||
for root in roots:
|
||||
s = str(root).lower()
|
||||
if s not in seen:
|
||||
uniq.append(root)
|
||||
seen.add(s)
|
||||
return uniq
|
||||
|
||||
|
||||
def _find_local_file(hardlink_db: Path, md5: str, requested_name: str = "") -> Path | None:
|
||||
try:
|
||||
con = _connect_hardlink_db(hardlink_db)
|
||||
row = con.execute(
|
||||
"""
|
||||
SELECT md5, file_name, file_size, dir1, dir2
|
||||
FROM file_hardlink_info_v4
|
||||
WHERE md5=?
|
||||
ORDER BY _rowid_ DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(md5,),
|
||||
).fetchone()
|
||||
except Exception:
|
||||
row = None
|
||||
if not row:
|
||||
return None
|
||||
|
||||
names = [requested_name, row["file_name"]]
|
||||
names = [n for n in names if n]
|
||||
size = int(row["file_size"] or 0)
|
||||
roots = _xwechat_roots_from_hardlink_db(hardlink_db)
|
||||
|
||||
for root in roots:
|
||||
if not root.exists():
|
||||
continue
|
||||
for name in names:
|
||||
for candidate in root.rglob(name):
|
||||
try:
|
||||
if candidate.is_file() and (not size or candidate.stat().st_size == size):
|
||||
return candidate
|
||||
except Exception:
|
||||
continue
|
||||
if size:
|
||||
# Fallback by size in the common file store. This is intentionally limited
|
||||
# to msg/file to avoid scanning unrelated huge trees for every request.
|
||||
for file_root in root.glob("*/msg/file"):
|
||||
if not file_root.exists():
|
||||
continue
|
||||
for candidate in file_root.rglob("*"):
|
||||
try:
|
||||
if candidate.is_file() and candidate.stat().st_size == size:
|
||||
if not names or candidate.name in names:
|
||||
return candidate
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/{md5}")
|
||||
async def get_file(md5: str, filename: str = Query("")):
|
||||
md5 = md5.strip()
|
||||
if not re.fullmatch(r"[0-9a-fA-F]{8,64}", md5):
|
||||
raise HTTPException(400, "文件 md5 不合法")
|
||||
|
||||
filename = _safe_download_name(filename, md5)
|
||||
proxied = await _proxy_chatlog_file(md5, filename)
|
||||
if proxied:
|
||||
return proxied
|
||||
|
||||
db_paths = await chatlog_client.get_db_paths()
|
||||
hardlink_paths = db_paths.get("media") or []
|
||||
for raw_path in hardlink_paths:
|
||||
hardlink_db = Path(raw_path)
|
||||
if not hardlink_db.exists():
|
||||
continue
|
||||
local_file = _find_local_file(hardlink_db, md5, filename)
|
||||
if local_file:
|
||||
media_type = _guess_media_type(filename or local_file.name)
|
||||
return FileResponse(
|
||||
path=str(local_file),
|
||||
filename=filename or local_file.name,
|
||||
media_type=media_type,
|
||||
headers={
|
||||
"Content-Disposition": _content_disposition(filename or local_file.name),
|
||||
"Content-Length": str(local_file.stat().st_size),
|
||||
"X-ChatLab-File-Source": "local-hardlink",
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(404, "原文件未找到,可能未解密或已清理")
|
||||
138
chatlog_fastAPI/routers/groups.py
Normal file
138
chatlog_fastAPI/routers/groups.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite, json
|
||||
from database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/groups", tags=["groups"])
|
||||
|
||||
class GroupCreate(BaseModel):
|
||||
talker: str
|
||||
name: Optional[str] = ""
|
||||
poll_interval: int = 300
|
||||
|
||||
class GroupPatch(BaseModel):
|
||||
analysis_prompt: Optional[str] = None
|
||||
|
||||
class InitParams(BaseModel):
|
||||
start_time: int # unix 秒
|
||||
end_time: int # unix 秒
|
||||
|
||||
@router.post("")
|
||||
async def create_group(body: GroupCreate, db: aiosqlite.Connection = Depends(get_db)):
|
||||
try:
|
||||
await db.execute(
|
||||
"INSERT INTO groups (talker, name, poll_interval) VALUES (?, ?, ?)",
|
||||
(body.talker, body.name, body.poll_interval)
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM groups WHERE talker=?", (body.talker,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
return dict(row)
|
||||
except aiosqlite.IntegrityError:
|
||||
raise HTTPException(409, "talker already exists")
|
||||
|
||||
@router.get("")
|
||||
async def list_groups(db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM groups") as cur:
|
||||
rows = await cur.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
@router.patch("/{group_id}")
|
||||
async def patch_group(group_id: int, body: GroupPatch, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "group not found")
|
||||
updates = body.model_dump(exclude_none=True)
|
||||
if "analysis_prompt" in updates:
|
||||
await db.execute(
|
||||
"UPDATE groups SET analysis_prompt=? WHERE id=?",
|
||||
(updates["analysis_prompt"], group_id),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
return dict(await cur.fetchone())
|
||||
|
||||
@router.post("/{group_id}/init")
|
||||
async def trigger_init(
|
||||
group_id: int,
|
||||
body: InitParams,
|
||||
db: aiosqlite.Connection = Depends(get_db),
|
||||
):
|
||||
"""对指定时间区间内全部消息做一次性 AI 分类。串行执行。"""
|
||||
async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
group = await cur.fetchone()
|
||||
if not group:
|
||||
raise HTTPException(404, "group not found")
|
||||
|
||||
from services.topic_engine import get_classifying_group
|
||||
busy = get_classifying_group()
|
||||
if busy is not None:
|
||||
raise HTTPException(409, f"已有群正在分析(group_id={busy}),请等待完成后再试")
|
||||
|
||||
if body.end_time <= body.start_time:
|
||||
raise HTTPException(400, "end_time 必须大于 start_time")
|
||||
|
||||
await db.execute(
|
||||
"INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'classify_window', 'running', ?)",
|
||||
(group_id, json.dumps({"processed": 0, "total": 0})),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT last_insert_rowid()") as cur:
|
||||
task = await cur.fetchone()
|
||||
task_id = task[0]
|
||||
|
||||
from services.topic_engine import run_classify_window
|
||||
import asyncio
|
||||
asyncio.create_task(
|
||||
run_classify_window(group_id, task_id, dict(group), body.start_time, body.end_time)
|
||||
)
|
||||
return {"task_id": task_id}
|
||||
|
||||
@router.get("/{group_id}/task")
|
||||
async def get_task(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute(
|
||||
"SELECT * FROM ai_tasks WHERE group_id=? ORDER BY id DESC LIMIT 1", (group_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "no task found")
|
||||
return dict(row)
|
||||
|
||||
@router.delete("/{group_id}")
|
||||
async def delete_group(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "group not found")
|
||||
# 级联删除关联数据(FTS → 知识库报告 → 话题消息 → 话题 → 任务 → 群组)
|
||||
# 注意:必须先清 knowledge_fts/knowledge_docs,否则 SQLite 复用 topic_id 时
|
||||
# 残留报告会"接"到新建话题上,造成跨群串报告。
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM knowledge_fts WHERE doc_id IN (
|
||||
SELECT id FROM knowledge_docs WHERE topic_id IN (
|
||||
SELECT id FROM topics WHERE group_id=?
|
||||
)
|
||||
)
|
||||
""",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM knowledge_docs WHERE topic_id IN (
|
||||
SELECT id FROM topics WHERE group_id=?
|
||||
)
|
||||
""",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute(
|
||||
"DELETE FROM topic_messages WHERE topic_id IN (SELECT id FROM topics WHERE group_id=?)",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute("DELETE FROM topics WHERE group_id=?", (group_id,))
|
||||
await db.execute("DELETE FROM ai_tasks WHERE group_id=?", (group_id,))
|
||||
await db.execute("DELETE FROM groups WHERE id=?", (group_id,))
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
67
chatlog_fastAPI/routers/knowledge.py
Normal file
67
chatlog_fastAPI/routers/knowledge.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite
|
||||
from database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/knowledge", tags=["knowledge"])
|
||||
|
||||
class KnowledgePatch(BaseModel):
|
||||
content: str
|
||||
|
||||
@router.get("")
|
||||
async def list_knowledge(
|
||||
keyword: Optional[str] = None,
|
||||
db: aiosqlite.Connection = Depends(get_db)
|
||||
):
|
||||
if keyword:
|
||||
# FTS5 查询前先用 jieba 分词,提高中文召回率
|
||||
from services.fts import build_match_query
|
||||
fts_query = build_match_query(keyword)
|
||||
if not fts_query:
|
||||
return []
|
||||
async with db.execute(
|
||||
"SELECT k.id, k.topic_id, k.created_at, k.updated_at, t.title, t.group_id, g.name as group_name "
|
||||
"FROM knowledge_docs k JOIN topics t ON k.topic_id=t.id "
|
||||
"LEFT JOIN groups g ON t.group_id=g.id "
|
||||
"WHERE k.id IN (SELECT doc_id FROM knowledge_fts WHERE knowledge_fts MATCH ?)",
|
||||
(fts_query,)
|
||||
) as cur:
|
||||
return [dict(r) for r in await cur.fetchall()]
|
||||
async with db.execute(
|
||||
"SELECT k.id, k.topic_id, k.created_at, k.updated_at, t.title, t.group_id, g.name as group_name "
|
||||
"FROM knowledge_docs k JOIN topics t ON k.topic_id=t.id "
|
||||
"LEFT JOIN groups g ON t.group_id=g.id "
|
||||
"ORDER BY g.name, k.updated_at DESC"
|
||||
) as cur:
|
||||
return [dict(r) for r in await cur.fetchall()]
|
||||
|
||||
@router.get("/{doc_id}")
|
||||
async def get_knowledge(doc_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM knowledge_docs WHERE id=?", (doc_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return dict(row)
|
||||
|
||||
@router.patch("/{doc_id}")
|
||||
async def patch_knowledge(doc_id: int, body: KnowledgePatch, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute(
|
||||
"UPDATE knowledge_docs SET content=?, updated_at=CURRENT_TIMESTAMP, curated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(body.content, doc_id)
|
||||
)
|
||||
await db.commit()
|
||||
# update FTS
|
||||
async with db.execute("SELECT topic_id FROM knowledge_docs WHERE id=?", (doc_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
async with db.execute("SELECT title FROM topics WHERE id=?", (row["topic_id"],)) as cur:
|
||||
topic = await cur.fetchone()
|
||||
await db.execute("DELETE FROM knowledge_fts WHERE doc_id=?", (doc_id,))
|
||||
from services.fts import tokenize
|
||||
await db.execute(
|
||||
"INSERT INTO knowledge_fts (doc_id, title, content) VALUES (?, ?, ?)",
|
||||
(doc_id, tokenize(topic["title"]), tokenize(body.content))
|
||||
)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
144
chatlog_fastAPI/routers/search.py
Normal file
144
chatlog_fastAPI/routers/search.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from urllib.parse import quote
|
||||
from services.chatlog_client import MessageIndexNotReady, chatlog_client
|
||||
from services.message_formatter import extract_quote
|
||||
|
||||
router = APIRouter(prefix="/api/search", tags=["search"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search(
|
||||
talker: str = Query(..., description="群/联系人 ID"),
|
||||
time: str = Query("", description="时间范围,如 2024-01-01,2024-01-31"),
|
||||
sender: str = Query("", description="发送者 ID,可选"),
|
||||
keyword: str = Query("", description="关键词,可选"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=500),
|
||||
):
|
||||
"""透传 chatlog /api/v1/chatlog,返回 {"total": N, "items": [...]}"""
|
||||
offset = (page - 1) * page_size
|
||||
try:
|
||||
data = await chatlog_client.get_messages(
|
||||
talker,
|
||||
time=time,
|
||||
sender=sender,
|
||||
keyword=keyword,
|
||||
limit=page_size,
|
||||
offset=offset,
|
||||
)
|
||||
except MessageIndexNotReady as e:
|
||||
raise HTTPException(status_code=503, detail=str(e)) from e
|
||||
for item in data.get("items", []) or []:
|
||||
contents = item.get("contents") or item.get("Contents") or {}
|
||||
if not isinstance(contents, dict):
|
||||
contents = {}
|
||||
try:
|
||||
is_file = int(item.get("type") or item.get("Type") or 0) == 49 and int(
|
||||
item.get("subType") or item.get("sub_type") or item.get("SubType") or 0
|
||||
) == 6
|
||||
except Exception:
|
||||
is_file = False
|
||||
file_md5 = str(contents.get("md5") or "") if is_file else ""
|
||||
item["is_file"] = is_file
|
||||
item["file_name"] = (
|
||||
contents.get("title") or contents.get("fileName") or contents.get("filename") or ""
|
||||
) if is_file else ""
|
||||
item["file_md5"] = file_md5
|
||||
item["quote"] = item.get("quote") or extract_quote(item)
|
||||
file_name = item["file_name"]
|
||||
item["file_url"] = f"/api/files/{quote(file_md5, safe='')}?filename={quote(file_name or file_md5, safe='')}" if file_md5 else ""
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/chatrooms")
|
||||
async def chatrooms(
|
||||
keyword: str = Query("", description="关键词搜索"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""获取所有可用的微信群聊列表"""
|
||||
fetch_limit = min(2000, offset + limit)
|
||||
rooms_data = await chatlog_client.get_chatrooms(keyword=keyword, limit=fetch_limit, offset=0)
|
||||
if isinstance(rooms_data, list):
|
||||
room_items = rooms_data
|
||||
total = len(room_items)
|
||||
else:
|
||||
room_items = rooms_data.get("items") or rooms_data.get("data") or []
|
||||
total = rooms_data.get("total", len(room_items))
|
||||
|
||||
merged = []
|
||||
seen = set()
|
||||
|
||||
def get_room_id(item: dict) -> str:
|
||||
return str(item.get("name") or item.get("Name") or item.get("userName") or item.get("UserName") or "")
|
||||
|
||||
def add_room(item: dict):
|
||||
room_id = get_room_id(item)
|
||||
if not room_id or not room_id.endswith("@chatroom") or room_id in seen:
|
||||
return
|
||||
seen.add(room_id)
|
||||
merged.append(item)
|
||||
|
||||
for item in room_items:
|
||||
if isinstance(item, dict):
|
||||
add_room(item)
|
||||
|
||||
# Freshly imported phone records may exist in sessions/messages before
|
||||
# chatroom metadata is populated. Merge @chatroom sessions as fallback.
|
||||
try:
|
||||
session_items = await chatlog_client.get_sessions(keyword="", limit=2000)
|
||||
except Exception:
|
||||
session_items = []
|
||||
|
||||
lowered_keyword = (keyword or "").lower()
|
||||
for session in session_items:
|
||||
if not isinstance(session, dict):
|
||||
continue
|
||||
user_name = str(session.get("userName") or session.get("UserName") or "")
|
||||
if not user_name.endswith("@chatroom"):
|
||||
continue
|
||||
nick_name = session.get("nickName") or session.get("NickName") or ""
|
||||
remark = session.get("remark") or session.get("Remark") or ""
|
||||
if lowered_keyword:
|
||||
haystack = f"{user_name} {nick_name} {remark}".lower()
|
||||
if lowered_keyword not in haystack:
|
||||
continue
|
||||
add_room({
|
||||
"name": user_name,
|
||||
"nickName": nick_name,
|
||||
"remark": remark,
|
||||
"source": "session",
|
||||
})
|
||||
|
||||
return {"total": max(total, len(merged)), "items": merged[offset:offset + limit]}
|
||||
|
||||
@router.get("/avatar")
|
||||
async def avatar(wxid: str = Query(...)):
|
||||
url = await chatlog_client.get_avatar_url(wxid)
|
||||
return {"url": url}
|
||||
|
||||
|
||||
@router.get("/members")
|
||||
async def members(
|
||||
talker: str = Query(..., description="群 ID"),
|
||||
time: str = Query("", description="统计时间范围,可选"),
|
||||
):
|
||||
"""
|
||||
获取群成员列表(按发言量降序)
|
||||
返回 {"members": [...], "total": N}
|
||||
每个成员:userName, displayName, msgCount, lastSpeakTime
|
||||
"""
|
||||
return await chatlog_client.get_chatroom_members(talker, time=time)
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def sessions(
|
||||
keyword: str = Query("", description="关键词搜索"),
|
||||
limit: int = Query(500, ge=1, le=2000),
|
||||
):
|
||||
"""
|
||||
获取所有会话列表,含最新一条消息预览和时间(来自微信原生 Session 表)。
|
||||
返回:[{ userName, nickName, remark, content, nTime, nOrder }]
|
||||
"""
|
||||
items = await chatlog_client.get_sessions(keyword=keyword, limit=limit)
|
||||
return items
|
||||
64
chatlog_fastAPI/routers/settings.py
Normal file
64
chatlog_fastAPI/routers/settings.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite
|
||||
from database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
||||
|
||||
EDITABLE_KEYS = [
|
||||
"ai_base_url", "ai_api_key", "ai_model", "summary_model",
|
||||
"vision_model", "voice_model", "topic_analysis_prompt",
|
||||
]
|
||||
|
||||
|
||||
def _mask_key(value: str) -> str:
|
||||
if not value or len(value) <= 8:
|
||||
return "*" * len(value) if value else ""
|
||||
return value[:3] + "*" * (len(value) - 7) + value[-4:]
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_settings(db: aiosqlite.Connection = Depends(get_db)):
|
||||
result = {}
|
||||
placeholders = ",".join("?" for _ in EDITABLE_KEYS)
|
||||
async with db.execute(
|
||||
f"SELECT key, value FROM app_settings WHERE key IN ({placeholders})",
|
||||
EDITABLE_KEYS,
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
for row in rows:
|
||||
k, v = row["key"], row["value"]
|
||||
result[k] = _mask_key(v) if k == "ai_api_key" else v
|
||||
for k in EDITABLE_KEYS:
|
||||
if k not in result:
|
||||
result[k] = ""
|
||||
return result
|
||||
|
||||
|
||||
class SettingsUpdate(BaseModel):
|
||||
ai_base_url: Optional[str] = None
|
||||
ai_api_key: Optional[str] = None
|
||||
ai_model: Optional[str] = None
|
||||
summary_model: Optional[str] = None
|
||||
vision_model: Optional[str] = None
|
||||
voice_model: Optional[str] = None
|
||||
topic_analysis_prompt: Optional[str] = None
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def update_settings(body: SettingsUpdate, db: aiosqlite.Connection = Depends(get_db)):
|
||||
updates = body.model_dump(exclude_none=True)
|
||||
for k, v in updates.items():
|
||||
if k not in EDITABLE_KEYS:
|
||||
continue
|
||||
if k == "ai_api_key" and "*" in v:
|
||||
continue
|
||||
await db.execute(
|
||||
"INSERT INTO app_settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = ?",
|
||||
(k, v, v),
|
||||
)
|
||||
await db.commit()
|
||||
from services.runtime_settings import invalidate_cache
|
||||
invalidate_cache()
|
||||
return {"status": "ok"}
|
||||
40
chatlog_fastAPI/routers/sse.py
Normal file
40
chatlog_fastAPI/routers/sse.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import asyncio, json, logging
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from services.chatlog_client import chatlog_client
|
||||
from services.message_formatter import attach_quote
|
||||
|
||||
router = APIRouter(prefix="/api/sse", tags=["sse"])
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/chatlog")
|
||||
async def sse_chatlog(talker: str = Query(...)):
|
||||
async def generate():
|
||||
try:
|
||||
data = await chatlog_client.get_messages(talker, limit=1, offset=0)
|
||||
last_total = data.get("total", 0)
|
||||
except Exception:
|
||||
last_total = 0
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(2)
|
||||
try:
|
||||
data = await chatlog_client.get_messages(talker, limit=50, offset=last_total)
|
||||
msgs = data.get("messages") or data.get("items") or []
|
||||
new_total = data.get("total", last_total)
|
||||
for msg in msgs:
|
||||
attach_quote(msg)
|
||||
yield f"data: {json.dumps(msg, ensure_ascii=False)}\n\n"
|
||||
if new_total > last_total:
|
||||
last_total = new_total
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
log.warning(f"[sse] poll error: {e}")
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
435
chatlog_fastAPI/routers/topics.py
Normal file
435
chatlog_fastAPI/routers/topics.py
Normal file
@@ -0,0 +1,435 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite, json
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
from database import get_db
|
||||
from services.message_formatter import extract_quote
|
||||
|
||||
router = APIRouter(prefix="/api/topics", tags=["topics"])
|
||||
CHATLOG_BATCH_SIZE = 80
|
||||
NAME_LOOKUP_PAGE_SIZE = 500
|
||||
NAME_LOOKUP_MAX_ITEMS = 5000
|
||||
STALE_SUMMARIZE_MINUTES = 15
|
||||
|
||||
class TopicCreate(BaseModel):
|
||||
group_id: int
|
||||
title: str
|
||||
|
||||
class TopicPatch(BaseModel):
|
||||
title: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
class MessageAdd(BaseModel):
|
||||
msg_seq: int
|
||||
talker: str
|
||||
|
||||
|
||||
async def _mark_stale_summarize_tasks(db: aiosqlite.Connection, group_id: int, topic_id: int) -> None:
|
||||
error = "AI 报告生成任务超过 15 分钟未完成,已自动标记为失败,可重新生成"
|
||||
stale_window = f"-{STALE_SUMMARIZE_MINUTES} minutes"
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE ai_tasks
|
||||
SET status='error', error=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE group_id=?
|
||||
AND type='summarize'
|
||||
AND status='running'
|
||||
AND datetime(updated_at) <= datetime('now', ?)
|
||||
""",
|
||||
(error, group_id, stale_window),
|
||||
)
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE topics
|
||||
SET status='error', updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?
|
||||
AND status='processing'
|
||||
AND datetime(updated_at) <= datetime('now', ?)
|
||||
""",
|
||||
(topic_id, stale_window),
|
||||
)
|
||||
|
||||
|
||||
def _normalize_chatlog_message(item: dict, fallback_seq: int = 0) -> dict:
|
||||
contents = item.get("contents") or item.get("Contents") or {}
|
||||
if not isinstance(contents, dict):
|
||||
contents = {}
|
||||
media_key = (
|
||||
contents.get("rawmd5")
|
||||
or contents.get("md5")
|
||||
or contents.get("path")
|
||||
or item.get("media_key")
|
||||
or item.get("mediaKey")
|
||||
or ""
|
||||
)
|
||||
voice_key = (
|
||||
str(contents.get("voice"))
|
||||
if contents.get("voice")
|
||||
else item.get("voice_key") or item.get("voiceKey") or ""
|
||||
)
|
||||
raw_type = item.get("type") or item.get("Type") or 1
|
||||
raw_sub_type = item.get("sub_type") or item.get("subType") or item.get("SubType") or 0
|
||||
try:
|
||||
is_file = int(raw_type) == 49 and int(raw_sub_type) == 6
|
||||
except Exception:
|
||||
is_file = False
|
||||
file_md5 = str(contents.get("md5") or item.get("file_md5") or item.get("fileMd5") or "") if is_file else ""
|
||||
file_name = (
|
||||
contents.get("title")
|
||||
or contents.get("fileName")
|
||||
or contents.get("filename")
|
||||
or item.get("file_name")
|
||||
or item.get("fileName")
|
||||
or ""
|
||||
) if is_file else ""
|
||||
file_url = f"/api/files/{quote(file_md5, safe='')}?filename={quote(file_name or file_md5, safe='')}" if file_md5 else ""
|
||||
return {
|
||||
"seq": item.get("seq") or item.get("Seq") or item.get("sort_seq") or fallback_seq or 0,
|
||||
"sender": item.get("sender") or item.get("Sender") or "",
|
||||
"sender_name": (
|
||||
item.get("sender_name")
|
||||
or item.get("senderName")
|
||||
or item.get("SenderName")
|
||||
or item.get("sender")
|
||||
or item.get("Sender")
|
||||
or ""
|
||||
),
|
||||
"create_time": item.get("create_time") or item.get("time") or item.get("CreateTime") or "",
|
||||
"content": item.get("content") or item.get("Content") or "",
|
||||
"type": raw_type,
|
||||
"sub_type": raw_sub_type,
|
||||
"contents": contents,
|
||||
"media_key": media_key,
|
||||
"voice_key": voice_key,
|
||||
"image_path": media_key,
|
||||
"voice_path": voice_key,
|
||||
"video_path": media_key,
|
||||
"file_path": media_key,
|
||||
"link_url": contents.get("url") or item.get("link_url") or "",
|
||||
"link_title": contents.get("title") or item.get("link_title") or "",
|
||||
"link_desc": contents.get("desc") or item.get("link_desc") or "",
|
||||
"link_thumb": contents.get("thumbUrl") or contents.get("thumb_url") or item.get("link_thumb") or "",
|
||||
"link_source": contents.get("sourceName") or contents.get("source_name") or item.get("link_source") or "",
|
||||
"quote": item.get("quote") or extract_quote(item),
|
||||
"is_file": is_file,
|
||||
"file_name": file_name,
|
||||
"file_md5": file_md5,
|
||||
"file_url": file_url,
|
||||
}
|
||||
|
||||
|
||||
def _message_from_snapshot(raw: str | None, fallback_seq: int) -> dict | None:
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
item = json.loads(raw)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(item, dict):
|
||||
return None
|
||||
return _normalize_chatlog_message(item, fallback_seq)
|
||||
|
||||
|
||||
def _looks_like_raw_sender_id(value: str | None) -> bool:
|
||||
value = (value or "").strip()
|
||||
return (
|
||||
not value
|
||||
or value.startswith("wxid_")
|
||||
or value.startswith("gh_")
|
||||
or value.endswith("@chatroom")
|
||||
or value.startswith("chatroom_")
|
||||
)
|
||||
|
||||
|
||||
def _sender_display_name(item: dict, sender: str = "") -> str:
|
||||
for key in (
|
||||
"sender_name",
|
||||
"senderName",
|
||||
"SenderName",
|
||||
"accountName",
|
||||
"groupNickname",
|
||||
"displayName",
|
||||
"nickName",
|
||||
"remark",
|
||||
):
|
||||
value = str(item.get(key) or "").strip()
|
||||
if value and value != sender and not _looks_like_raw_sender_id(value):
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
def _message_date(value: str | None) -> str | None:
|
||||
value = (value or "").strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00")).strftime("%Y-%m-%d")
|
||||
except Exception:
|
||||
return value[:10] if len(value) >= 10 and value[4:5] == "-" else None
|
||||
|
||||
|
||||
async def _build_sender_name_map(talker: str, messages: list[dict]) -> dict[str, str]:
|
||||
names: dict[str, str] = {}
|
||||
|
||||
for msg in messages:
|
||||
sender = str(msg.get("sender") or "").strip()
|
||||
name = _sender_display_name(msg, sender)
|
||||
if sender and name:
|
||||
names[sender] = name
|
||||
|
||||
missing = {
|
||||
str(msg.get("sender") or "").strip()
|
||||
for msg in messages
|
||||
if str(msg.get("sender") or "").strip()
|
||||
and _looks_like_raw_sender_id(str(msg.get("sender_name") or "").strip())
|
||||
}
|
||||
missing = {sender for sender in missing if sender not in names}
|
||||
if not missing:
|
||||
return names
|
||||
|
||||
from services.chatlog_client import chatlog_client
|
||||
|
||||
dates = sorted({
|
||||
date
|
||||
for msg in messages
|
||||
for date in [_message_date(str(msg.get("create_time") or ""))]
|
||||
if date
|
||||
})
|
||||
if dates:
|
||||
time_range = f"{dates[0]},{dates[-1]}"
|
||||
offset = 0
|
||||
seen = 0
|
||||
while missing and seen < NAME_LOOKUP_MAX_ITEMS:
|
||||
try:
|
||||
data = await chatlog_client.get_messages(
|
||||
talker,
|
||||
time=time_range,
|
||||
limit=NAME_LOOKUP_PAGE_SIZE,
|
||||
offset=offset,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch sender names talker={talker}: {e}")
|
||||
break
|
||||
items = data.get("items", []) or []
|
||||
if not items:
|
||||
break
|
||||
for item in items:
|
||||
sender = str(item.get("sender") or item.get("Sender") or "").strip()
|
||||
if sender not in missing:
|
||||
continue
|
||||
name = _sender_display_name(item, sender)
|
||||
if name:
|
||||
names[sender] = name
|
||||
missing.discard(sender)
|
||||
seen += len(items)
|
||||
offset += len(items)
|
||||
total = int(data.get("total") or 0)
|
||||
if total and offset >= total:
|
||||
break
|
||||
if len(items) < NAME_LOOKUP_PAGE_SIZE:
|
||||
break
|
||||
|
||||
if missing:
|
||||
try:
|
||||
members = await chatlog_client.get_chatroom_members(talker)
|
||||
raw_members = members.get("members", []) if isinstance(members, dict) else []
|
||||
for member in raw_members:
|
||||
sender = str(member.get("userName") or member.get("UserName") or "").strip()
|
||||
if sender not in missing:
|
||||
continue
|
||||
name = _sender_display_name(
|
||||
{
|
||||
"displayName": member.get("displayName") or member.get("DisplayName"),
|
||||
"nickName": member.get("nickName") or member.get("NickName"),
|
||||
"remark": member.get("remark") or member.get("Remark"),
|
||||
},
|
||||
sender,
|
||||
)
|
||||
if name:
|
||||
names[sender] = name
|
||||
missing.discard(sender)
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch chatroom members talker={talker}: {e}")
|
||||
|
||||
return names
|
||||
|
||||
|
||||
async def _fill_sender_names(talker: str, messages: list[dict]) -> None:
|
||||
if not talker or not messages:
|
||||
return
|
||||
names = await _build_sender_name_map(talker, messages)
|
||||
if not names:
|
||||
return
|
||||
for msg in messages:
|
||||
sender = str(msg.get("sender") or "").strip()
|
||||
if not sender or sender not in names:
|
||||
continue
|
||||
current = str(msg.get("sender_name") or "").strip()
|
||||
if not current or current == sender or _looks_like_raw_sender_id(current):
|
||||
msg["sender_name"] = names[sender]
|
||||
msg["senderName"] = names[sender]
|
||||
|
||||
@router.get("")
|
||||
async def list_topics(
|
||||
group_id: Optional[int] = None,
|
||||
status: Optional[str] = None,
|
||||
keyword: Optional[str] = None,
|
||||
db: aiosqlite.Connection = Depends(get_db)
|
||||
):
|
||||
sql = "SELECT * FROM topics WHERE 1=1"
|
||||
params = []
|
||||
if group_id:
|
||||
sql += " AND group_id=?"; params.append(group_id)
|
||||
if status:
|
||||
sql += " AND status=?"; params.append(status)
|
||||
if keyword:
|
||||
sql += " AND title LIKE ?"; params.append(f"%{keyword}%")
|
||||
async with db.execute(sql, params) as cur:
|
||||
return [dict(r) for r in await cur.fetchall()]
|
||||
|
||||
@router.post("")
|
||||
async def create_topic(body: TopicCreate, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute(
|
||||
"INSERT INTO topics (group_id, title, source) VALUES (?, ?, 'manual')",
|
||||
(body.group_id, body.title),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM topics ORDER BY id DESC LIMIT 1") as cur:
|
||||
return dict(await cur.fetchone())
|
||||
|
||||
@router.get("/{topic_id}")
|
||||
async def get_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
|
||||
# 拿到该话题下关联的所有 seq
|
||||
async with db.execute("SELECT * FROM topic_messages WHERE topic_id=? ORDER BY msg_seq ASC", (topic_id,)) as cur:
|
||||
msg_rows = await cur.fetchall()
|
||||
|
||||
msgs = []
|
||||
if msg_rows:
|
||||
# 获取群聊 ID
|
||||
talker = msg_rows[0]["talker"]
|
||||
seq_list = [r["msg_seq"] for r in msg_rows]
|
||||
|
||||
# 分批调用 5030 接口获取真实消息内容,避免大批量 batch 触发 500。
|
||||
from services.chatlog_client import chatlog_client
|
||||
fetched_by_seq: dict[int, dict] = {}
|
||||
for i in range(0, len(seq_list), CHATLOG_BATCH_SIZE):
|
||||
chunk = seq_list[i: i + CHATLOG_BATCH_SIZE]
|
||||
try:
|
||||
msgs_data = await chatlog_client.get_messages_batch(talker, chunk)
|
||||
raw_items = msgs_data.get("items", []) or msgs_data.get("Items", [])
|
||||
for item in raw_items:
|
||||
normalized = _normalize_chatlog_message(item)
|
||||
seq = normalized.get("seq")
|
||||
if seq:
|
||||
fetched_by_seq[int(seq)] = normalized
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch real messages chunk topic={topic_id}: {e}")
|
||||
|
||||
for r in msg_rows:
|
||||
seq = int(r["msg_seq"])
|
||||
if seq in fetched_by_seq:
|
||||
msgs.append(fetched_by_seq[seq])
|
||||
continue
|
||||
snap = _message_from_snapshot(r["message_json"] if "message_json" in r.keys() else None, seq)
|
||||
if snap:
|
||||
msgs.append(snap)
|
||||
continue
|
||||
msgs.append({
|
||||
"seq": seq,
|
||||
"sender": "",
|
||||
"sender_name": "系统提示",
|
||||
"create_time": "",
|
||||
"content": f"原始消息无法从 chatlog 找回 (seq: {seq})",
|
||||
"type": 1,
|
||||
"sub_type": 0,
|
||||
})
|
||||
|
||||
# 获取知识文档
|
||||
if msg_rows:
|
||||
await _fill_sender_names(talker, msgs)
|
||||
|
||||
async with db.execute("SELECT id, topic_id, content, created_at, updated_at FROM knowledge_docs WHERE topic_id=?", (topic_id,)) as cur:
|
||||
doc = await cur.fetchone()
|
||||
|
||||
return {**dict(row), "messages": msgs, "knowledge_doc": dict(doc) if doc else None}
|
||||
|
||||
@router.patch("/{topic_id}")
|
||||
async def patch_topic(topic_id: int, body: TopicPatch, db: aiosqlite.Connection = Depends(get_db)):
|
||||
if body.title:
|
||||
await db.execute(
|
||||
"UPDATE topics SET title=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(body.title, topic_id),
|
||||
)
|
||||
if body.status:
|
||||
await db.execute(
|
||||
"UPDATE topics SET status=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(body.status, topic_id),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
|
||||
return dict(await cur.fetchone())
|
||||
|
||||
@router.delete("/{topic_id}")
|
||||
async def delete_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
# 先拿到 doc_id,用于清 FTS
|
||||
async with db.execute("SELECT id FROM knowledge_docs WHERE topic_id=?", (topic_id,)) as cur:
|
||||
doc_row = await cur.fetchone()
|
||||
if doc_row:
|
||||
await db.execute("DELETE FROM knowledge_fts WHERE doc_id=?", (doc_row["id"],))
|
||||
await db.execute("DELETE FROM topic_messages WHERE topic_id=?", (topic_id,))
|
||||
await db.execute("DELETE FROM knowledge_docs WHERE topic_id=?", (topic_id,))
|
||||
await db.execute("DELETE FROM topics WHERE id=?", (topic_id,))
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/{topic_id}/messages")
|
||||
async def add_message(topic_id: int, body: MessageAdd, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute(
|
||||
"INSERT OR IGNORE INTO topic_messages (topic_id, msg_seq, talker, added_by) VALUES (?, ?, ?, 'user')",
|
||||
(topic_id, body.msg_seq, body.talker)
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(topic_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.delete("/{topic_id}/messages/{seq}")
|
||||
async def remove_message(topic_id: int, seq: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute("DELETE FROM topic_messages WHERE topic_id=? AND msg_seq=?", (topic_id, seq))
|
||||
await db.execute(
|
||||
"UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(topic_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/{topic_id}/summarize")
|
||||
async def summarize(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
|
||||
topic = await cur.fetchone()
|
||||
if not topic:
|
||||
raise HTTPException(404, "not found")
|
||||
topic_data = dict(topic)
|
||||
await _mark_stale_summarize_tasks(db, topic_data["group_id"], topic_id)
|
||||
# 创建 ai_tasks 记录以追踪进度
|
||||
await db.execute(
|
||||
"INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'summarize', 'running', ?)",
|
||||
(topic_data["group_id"], json.dumps({"processed": 0, "total": 1}))
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT last_insert_rowid() AS id") as cur:
|
||||
task_row = await cur.fetchone()
|
||||
task_id = task_row["id"]
|
||||
from services.summary_engine import run_summarize
|
||||
import asyncio
|
||||
asyncio.create_task(run_summarize(topic_id, topic_data, task_id))
|
||||
return {"ok": True, "task_id": task_id}
|
||||
Reference in New Issue
Block a user