Initial upload for secondary development
This commit is contained in:
56
chatlog_fastAPI/ChatLabBackend.spec
Normal file
56
chatlog_fastAPI/ChatLabBackend.spec
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
|
||||
from PyInstaller.utils.hooks import collect_data_files, collect_submodules
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files("jieba")
|
||||
|
||||
hiddenimports = []
|
||||
hiddenimports += collect_submodules("uvicorn")
|
||||
hiddenimports += collect_submodules("fastapi")
|
||||
hiddenimports += collect_submodules("pydantic_settings")
|
||||
hiddenimports += collect_submodules("aiosqlite")
|
||||
hiddenimports += collect_submodules("apscheduler")
|
||||
|
||||
a = Analysis(
|
||||
["run_backend.py"],
|
||||
pathex=[],
|
||||
binaries=[],
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
noarchive=False,
|
||||
optimize=0,
|
||||
)
|
||||
pyz = PYZ(a.pure)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
[],
|
||||
exclude_binaries=True,
|
||||
name="ChatLabBackend",
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
|
||||
coll = COLLECT(
|
||||
exe,
|
||||
a.binaries,
|
||||
a.datas,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
name="ChatLabBackend",
|
||||
)
|
||||
55
chatlog_fastAPI/config.py
Normal file
55
chatlog_fastAPI/config.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import List
|
||||
|
||||
|
||||
def _default_data_dir() -> str:
|
||||
configured = os.environ.get("CHATLAB_DATA_DIR")
|
||||
if configured:
|
||||
return str(Path(configured).expanduser())
|
||||
appdata = os.environ.get("APPDATA")
|
||||
if appdata:
|
||||
return str(Path(appdata) / "ChatLab")
|
||||
return str(Path.home() / ".chatlab")
|
||||
|
||||
|
||||
def _default_static_dir() -> str:
|
||||
configured = os.environ.get("CHATLAB_STATIC_DIR")
|
||||
if configured:
|
||||
return str(Path(configured).expanduser())
|
||||
return str((Path(__file__).resolve().parents[1] / "chatlab-web" / "frontend" / "dist"))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
chatlog_base_url: str = "http://127.0.0.1:5030"
|
||||
ai_base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
ai_api_key: str = ""
|
||||
ai_model: str = "" # 不设默认值,必须由用户在设置页配置
|
||||
summary_model: str = "" # 不设默认值,必须由用户在设置页配置
|
||||
voice_model: str = "" # 不设默认值,必须由用户在设置页配置
|
||||
vision_model: str = "" # 不设默认值,必须由用户在设置页配置
|
||||
data_dir: str = _default_data_dir()
|
||||
static_dir: str = _default_static_dir()
|
||||
db_path: str = str(Path(_default_data_dir()) / "data" / "knowledge.db")
|
||||
cors_origins: List[str] = [
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:5173",
|
||||
"http://localhost:3000",
|
||||
]
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
try:
|
||||
Path(settings.data_dir).mkdir(parents=True, exist_ok=True)
|
||||
Path(settings.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError:
|
||||
fallback_dir = Path(tempfile.gettempdir()) / "ChatLab"
|
||||
fallback_dir.mkdir(parents=True, exist_ok=True)
|
||||
settings.data_dir = str(fallback_dir)
|
||||
settings.db_path = str(fallback_dir / "data" / "knowledge.db")
|
||||
Path(settings.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
301
chatlog_fastAPI/database.py
Normal file
301
chatlog_fastAPI/database.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import aiosqlite
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from config import settings
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_data_db_dir = Path(settings.db_path).resolve().parent
|
||||
_data_db_dir.mkdir(parents=True, exist_ok=True)
|
||||
_current_db_path = str(Path(settings.db_path).resolve())
|
||||
_initialized_dbs = set()
|
||||
|
||||
_resolved_wxid: str | None = None
|
||||
_wxid_last_resolved: float = 0.0
|
||||
_WXID_TTL = 60.0 # 60 秒后强制重新检测,确保账号切换能被感知
|
||||
STALE_SUMMARIZE_ERROR = "AI 报告生成任务超过 15 分钟未完成,已自动标记为失败,可重新生成"
|
||||
|
||||
def _db_path_for_wxid(wxid: str) -> str:
|
||||
if wxid and wxid != "default":
|
||||
safe = "".join(c for c in wxid if c.isalnum() or c in ("_", "-"))
|
||||
return str((_data_db_dir / f"knowledge_{safe}.db").resolve())
|
||||
return str(Path(settings.db_path).resolve())
|
||||
|
||||
|
||||
def reset_wxid_cache():
|
||||
global _resolved_wxid, _wxid_last_resolved
|
||||
_resolved_wxid = None
|
||||
_wxid_last_resolved = 0.0
|
||||
|
||||
|
||||
async def get_current_wxid(force: bool = False):
|
||||
global _resolved_wxid, _wxid_last_resolved
|
||||
now = time.time()
|
||||
# 已有有效缓存且未超时,直接返回
|
||||
if (
|
||||
not force
|
||||
and _resolved_wxid
|
||||
and _resolved_wxid != "default"
|
||||
and (now - _wxid_last_resolved) < _WXID_TTL
|
||||
):
|
||||
return _resolved_wxid
|
||||
# 重新解析当前 wxid
|
||||
base = settings.chatlog_base_url
|
||||
async with httpx.AsyncClient(trust_env=False, timeout=10.0) as client:
|
||||
try:
|
||||
r = await client.get(f"{base}/api/v1/chatlog", params={"talker": "filehelper", "limit": 100, "time": "1970-01-01,2099-12-31", "format": "json"})
|
||||
if r.status_code == 200:
|
||||
data = r.json()
|
||||
for msg in data.get("items", []):
|
||||
if msg.get("isSelf"):
|
||||
_resolved_wxid = msg.get("sender")
|
||||
_wxid_last_resolved = time.time()
|
||||
return _resolved_wxid
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
r = await client.get(f"{base}/api/v1/chatroom", params={"limit": 10, "format": "json"})
|
||||
if r.status_code == 200:
|
||||
rooms = r.json().get("items", [])
|
||||
for room in rooms:
|
||||
room_id = room.get("name")
|
||||
r2 = await client.get(f"{base}/api/v1/chatlog", params={"talker": room_id, "limit": 50, "time": "1970-01-01,2099-12-31", "format": "json"})
|
||||
if r2.status_code == 200:
|
||||
data2 = r2.json()
|
||||
for msg in data2.get("items", []):
|
||||
if msg.get("isSelf"):
|
||||
_resolved_wxid = msg.get("sender")
|
||||
_wxid_last_resolved = time.time()
|
||||
return _resolved_wxid
|
||||
except Exception:
|
||||
pass
|
||||
if force:
|
||||
reset_wxid_cache()
|
||||
return "default"
|
||||
|
||||
async def update_db_path(force: bool = False):
|
||||
global _current_db_path
|
||||
wxid = await get_current_wxid(force=force)
|
||||
new_path = _db_path_for_wxid(wxid)
|
||||
if new_path != _current_db_path:
|
||||
log.info(f"Switching database to {new_path}")
|
||||
_current_db_path = new_path
|
||||
await init_db(new_path)
|
||||
return _current_db_path
|
||||
|
||||
def get_active_db_path():
|
||||
return _current_db_path
|
||||
|
||||
async def get_db():
|
||||
path = get_active_db_path()
|
||||
if path not in _initialized_dbs:
|
||||
await init_db(path)
|
||||
async with aiosqlite.connect(path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
yield db
|
||||
|
||||
async def init_db(path=None):
|
||||
if path is None:
|
||||
path = get_active_db_path()
|
||||
async with aiosqlite.connect(path) as db:
|
||||
await db.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS groups (
|
||||
id INTEGER PRIMARY KEY,
|
||||
talker TEXT UNIQUE NOT NULL,
|
||||
name TEXT,
|
||||
analysis_prompt TEXT DEFAULT '',
|
||||
cursor_seq INTEGER DEFAULT 0,
|
||||
initialized INTEGER DEFAULT 0,
|
||||
poll_interval INTEGER DEFAULT 300,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS topics (
|
||||
id INTEGER PRIMARY KEY,
|
||||
group_id INTEGER REFERENCES groups(id),
|
||||
title TEXT NOT NULL,
|
||||
source TEXT DEFAULT 'manual',
|
||||
status TEXT DEFAULT 'pending',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS topic_messages (
|
||||
topic_id INTEGER REFERENCES topics(id),
|
||||
msg_seq INTEGER,
|
||||
talker TEXT,
|
||||
added_by TEXT DEFAULT 'ai',
|
||||
message_json TEXT,
|
||||
PRIMARY KEY (topic_id, msg_seq)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS knowledge_docs (
|
||||
id INTEGER PRIMARY KEY,
|
||||
topic_id INTEGER UNIQUE REFERENCES topics(id),
|
||||
content TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
curated_at DATETIME
|
||||
);
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS knowledge_fts USING fts5(
|
||||
doc_id UNINDEXED,
|
||||
title,
|
||||
content
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS ai_tasks (
|
||||
id INTEGER PRIMARY KEY,
|
||||
group_id INTEGER REFERENCES groups(id),
|
||||
type TEXT,
|
||||
status TEXT,
|
||||
progress TEXT,
|
||||
error TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS app_settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
""")
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE ai_tasks
|
||||
SET status='error', error=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE type='summarize'
|
||||
AND status='running'
|
||||
AND datetime(updated_at) <= datetime('now', '-15 minutes')
|
||||
""",
|
||||
(STALE_SUMMARIZE_ERROR,),
|
||||
)
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE topics
|
||||
SET status='error', updated_at=CURRENT_TIMESTAMP
|
||||
WHERE status='processing'
|
||||
AND datetime(updated_at) <= datetime('now', '-15 minutes')
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with db.execute("PRAGMA table_info(topic_messages)") as cur:
|
||||
topic_message_cols = {row[1] for row in await cur.fetchall()}
|
||||
if "message_json" not in topic_message_cols:
|
||||
await db.execute("ALTER TABLE topic_messages ADD COLUMN message_json TEXT")
|
||||
await db.commit()
|
||||
log.info(f"[init_db] added topic_messages.message_json in {path}")
|
||||
|
||||
async with db.execute("PRAGMA table_info(groups)") as cur:
|
||||
group_cols = {row[1] for row in await cur.fetchall()}
|
||||
if "analysis_prompt" not in group_cols:
|
||||
await db.execute("ALTER TABLE groups ADD COLUMN analysis_prompt TEXT DEFAULT ''")
|
||||
await db.commit()
|
||||
log.info(f"[init_db] added groups.analysis_prompt in {path}")
|
||||
|
||||
async with db.execute("PRAGMA table_info(topics)") as cur:
|
||||
topic_cols = {row[1] for row in await cur.fetchall()}
|
||||
if "source" not in topic_cols:
|
||||
await db.execute("ALTER TABLE topics ADD COLUMN source TEXT DEFAULT 'manual'")
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE topics
|
||||
SET source = CASE
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM topic_messages tm
|
||||
WHERE tm.topic_id = topics.id AND tm.added_by = 'user'
|
||||
) THEN 'manual'
|
||||
WHEN EXISTS (
|
||||
SELECT 1 FROM topic_messages tm
|
||||
WHERE tm.topic_id = topics.id AND COALESCE(tm.added_by, 'ai') = 'ai'
|
||||
) THEN 'ai'
|
||||
ELSE 'manual'
|
||||
END
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
log.info(f"[init_db] added topics.source in {path}")
|
||||
|
||||
async with db.execute("PRAGMA table_info(knowledge_docs)") as cur:
|
||||
knowledge_cols = {row[1] for row in await cur.fetchall()}
|
||||
if "curated_at" not in knowledge_cols:
|
||||
await db.execute("ALTER TABLE knowledge_docs ADD COLUMN curated_at DATETIME")
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE knowledge_docs
|
||||
SET curated_at = updated_at
|
||||
WHERE updated_at IS NOT NULL
|
||||
AND created_at IS NOT NULL
|
||||
AND updated_at > created_at
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
log.info(f"[init_db] added knowledge_docs.curated_at in {path}")
|
||||
|
||||
# 迁移 topics 表到 AUTOINCREMENT,防止 SQLite rowid 复用导致旧 knowledge_docs
|
||||
# 被新建话题"接"上(跨群串报告的根因)。每次启动检测一次,已迁移则跳过。
|
||||
async with db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='sqlite_sequence'"
|
||||
) as cur:
|
||||
has_seq_tbl = await cur.fetchone() is not None
|
||||
needs_migrate = True
|
||||
if has_seq_tbl:
|
||||
async with db.execute(
|
||||
"SELECT 1 FROM sqlite_sequence WHERE name='topics'"
|
||||
) as cur:
|
||||
if await cur.fetchone():
|
||||
needs_migrate = False
|
||||
if needs_migrate:
|
||||
await db.executescript("""
|
||||
BEGIN;
|
||||
CREATE TABLE topics_new (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
group_id INTEGER REFERENCES groups(id),
|
||||
title TEXT NOT NULL,
|
||||
source TEXT DEFAULT 'manual',
|
||||
status TEXT DEFAULT 'pending',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
INSERT INTO topics_new (id, group_id, title, source, status, created_at, updated_at)
|
||||
SELECT id, group_id, title, COALESCE(source, 'manual'), status, created_at, updated_at FROM topics;
|
||||
DROP TABLE topics;
|
||||
ALTER TABLE topics_new RENAME TO topics;
|
||||
COMMIT;
|
||||
""")
|
||||
await db.execute(
|
||||
"INSERT OR REPLACE INTO sqlite_sequence(name, seq) "
|
||||
"SELECT 'topics', COALESCE(MAX(id), 0) FROM topics"
|
||||
)
|
||||
await db.commit()
|
||||
log.info(f"[init_db] migrated topics table to AUTOINCREMENT in {path}")
|
||||
|
||||
# 孤儿数据清理:删除 topic_id 不存在于 topics 的 knowledge_docs 及其 FTS。
|
||||
# 历史上删群时遗漏过这两张表,需要每次启动幂等修复。
|
||||
await db.execute("""
|
||||
DELETE FROM knowledge_fts WHERE doc_id IN (
|
||||
SELECT id FROM knowledge_docs
|
||||
WHERE topic_id NOT IN (SELECT id FROM topics)
|
||||
)
|
||||
""")
|
||||
await db.execute("""
|
||||
DELETE FROM knowledge_docs
|
||||
WHERE topic_id NOT IN (SELECT id FROM topics)
|
||||
""")
|
||||
# 错绑数据清理:doc 创建时间早于其指向的 topic 创建时间,说明 doc 是历史残留、
|
||||
# topic 是后建的(rowid 复用),doc 应清掉。合法 doc 必然在 topic 之后生成。
|
||||
await db.execute("""
|
||||
DELETE FROM knowledge_fts WHERE doc_id IN (
|
||||
SELECT d.id FROM knowledge_docs d
|
||||
JOIN topics t ON t.id = d.topic_id
|
||||
WHERE d.created_at < t.created_at
|
||||
)
|
||||
""")
|
||||
await db.execute("""
|
||||
DELETE FROM knowledge_docs WHERE id IN (
|
||||
SELECT d.id FROM knowledge_docs d
|
||||
JOIN topics t ON t.id = d.topic_id
|
||||
WHERE d.created_at < t.created_at
|
||||
)
|
||||
""")
|
||||
await db.commit()
|
||||
_initialized_dbs.add(path)
|
||||
151
chatlog_fastAPI/main.py
Normal file
151
chatlog_fastAPI/main.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
from database import get_active_db_path, get_current_wxid, init_db, reset_wxid_cache, update_db_path
|
||||
from scheduler import start_scheduler
|
||||
from config import settings
|
||||
from routers import search, groups, topics, knowledge, ai, sse, files, chatlog_proxy
|
||||
from routers import settings as settings_router
|
||||
from services.chatlog_context import get_chatlog_context, update_chatlog_context
|
||||
from services.media_resolver import diagnose_media
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatlogContextRequest(BaseModel):
|
||||
account: str = ""
|
||||
workDir: str = ""
|
||||
dataDir: str = ""
|
||||
platform: str = "windows"
|
||||
version: int = 4
|
||||
chatlogExe: str = ""
|
||||
chatlogVersion: str = ""
|
||||
|
||||
async def _account_watch_loop():
|
||||
"""每 60 秒检测一次当前微信账号,如账号切换则自动切换数据库。"""
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
try:
|
||||
await update_db_path()
|
||||
except Exception as e:
|
||||
log.warning(f"[account_watch] update_db_path error: {e}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await init_db()
|
||||
await start_scheduler()
|
||||
# 启动后台账号监控任务
|
||||
task = asyncio.create_task(_account_watch_loop())
|
||||
yield
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@app.exception_handler(RuntimeError)
|
||||
async def runtime_error_handler(request: Request, exc: RuntimeError):
|
||||
return JSONResponse(status_code=500, content={"detail": str(exc)})
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(search.router)
|
||||
app.include_router(groups.router)
|
||||
app.include_router(topics.router)
|
||||
app.include_router(knowledge.router)
|
||||
app.include_router(ai.router)
|
||||
app.include_router(sse.router)
|
||||
app.include_router(files.router)
|
||||
app.include_router(settings_router.router)
|
||||
app.include_router(chatlog_proxy.router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
chatlog_ok = False
|
||||
chatlog_error = ""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3.0, trust_env=False) as client:
|
||||
resp = await client.get(f"{settings.chatlog_base_url}/api/v1/session", params={"limit": 1, "format": "json"})
|
||||
chatlog_ok = resp.status_code == 200
|
||||
if not chatlog_ok:
|
||||
chatlog_error = f"HTTP {resp.status_code}"
|
||||
except Exception as e:
|
||||
chatlog_error = str(e)
|
||||
|
||||
wxid = await get_current_wxid() if chatlog_ok else "default"
|
||||
return {
|
||||
"ok": True,
|
||||
"chatlog_ok": chatlog_ok,
|
||||
"chatlog_error": chatlog_error,
|
||||
"wxid": wxid,
|
||||
"db_path": get_active_db_path(),
|
||||
"data_dir": settings.data_dir,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/system/refresh-account")
|
||||
async def refresh_account():
|
||||
reset_wxid_cache()
|
||||
db_path = await update_db_path(force=True)
|
||||
wxid = await get_current_wxid()
|
||||
return {"ok": True, "wxid": wxid, "db_path": db_path}
|
||||
|
||||
|
||||
@app.post("/api/system/chatlog-context")
|
||||
async def set_chatlog_context(body: ChatlogContextRequest):
|
||||
return {"ok": True, "context": update_chatlog_context(body.model_dump())}
|
||||
|
||||
|
||||
@app.get("/api/system/chatlog-context")
|
||||
async def read_chatlog_context():
|
||||
return {"ok": True, "context": get_chatlog_context()}
|
||||
|
||||
|
||||
@app.get("/api/system/media-diagnostics")
|
||||
async def media_diagnostics(kind: str = "voice", key: str = ""):
|
||||
return await diagnose_media(kind, key)
|
||||
|
||||
|
||||
static_dir = Path(settings.static_dir)
|
||||
if static_dir.exists():
|
||||
assets_dir = static_dir / "assets"
|
||||
if assets_dir.exists():
|
||||
app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")
|
||||
for static_name in ("favicon.svg", "icons.svg"):
|
||||
static_file = static_dir / static_name
|
||||
|
||||
if static_file.exists():
|
||||
@app.get(f"/{static_name}", include_in_schema=False)
|
||||
async def _serve_static_file(name=static_name):
|
||||
return FileResponse(static_dir / name)
|
||||
|
||||
@app.get("/", include_in_schema=False)
|
||||
async def spa_index():
|
||||
return FileResponse(static_dir / "index.html")
|
||||
|
||||
@app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def spa_fallback(full_path: str):
|
||||
path = static_dir / full_path
|
||||
if path.exists() and path.is_file():
|
||||
return FileResponse(path)
|
||||
return FileResponse(static_dir / "index.html")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# 为了在使用 PyInstaller 打包时也能正常运行
|
||||
uvicorn.run(app, host="127.0.0.1", port=8000, reload=False)
|
||||
8
chatlog_fastAPI/requirements.txt
Normal file
8
chatlog_fastAPI/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
httpx>=0.27.0,<0.28.0
|
||||
openai>=1.56.1,<3.0.0
|
||||
apscheduler
|
||||
jieba
|
||||
aiosqlite
|
||||
pydantic-settings
|
||||
0
chatlog_fastAPI/routers/__init__.py
Normal file
0
chatlog_fastAPI/routers/__init__.py
Normal file
116
chatlog_fastAPI/routers/ai.py
Normal file
116
chatlog_fastAPI/routers/ai.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Literal
|
||||
import aiosqlite, json, logging
|
||||
import httpx
|
||||
from database import get_db
|
||||
from config import settings
|
||||
from services.ai_client import get_openai_client
|
||||
from services.runtime_settings import get_ai_settings
|
||||
from services.media_parser import parse_media
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["ai"])
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_ai_client():
|
||||
return await get_openai_client()
|
||||
|
||||
|
||||
class SummarizeRequest(BaseModel):
|
||||
context: str # 已组装好的对话文本(含媒体描述)
|
||||
room_name: Optional[str] = ""
|
||||
messages: Optional[list] = None # 兼容旧调用,忽略
|
||||
|
||||
|
||||
class ParseRequest(BaseModel):
|
||||
type: Literal["voice", "image", "video"]
|
||||
key: str # voice: ServerID string; image/video: md5
|
||||
|
||||
|
||||
@router.post("/ai/parse")
|
||||
async def ai_parse(body: ParseRequest):
|
||||
"""
|
||||
通过 FastAPI 代理 AI 媒体解析:
|
||||
- voice: 从 chatlog 下载音频 → DashScope Paraformer ASR 转文字
|
||||
- image/video: 从 chatlog 下载媒体 → base64 → 视觉模型描述
|
||||
"""
|
||||
try:
|
||||
return await parse_media(body.type, body.key)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"[ai/parse] 媒体解析失败: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"媒体解析失败: {e}")
|
||||
|
||||
|
||||
@router.post("/ai/summarize/stream")
|
||||
async def summarize_stream(body: SummarizeRequest):
|
||||
"""
|
||||
接收前端已处理好的对话上下文,调用 AI 模型流式输出总结。
|
||||
前端负责先把媒体(图片/语音/视频)解析成文字再拼进 context。
|
||||
"""
|
||||
_ai = await get_ai_settings()
|
||||
if not _ai.get("ai_api_key"):
|
||||
async def err_gen():
|
||||
yield 'data: {"error": "AI 服务未配置,请在「设置」页面填入 AI API Key"}\n\n'
|
||||
return StreamingResponse(err_gen(), media_type="text/event-stream")
|
||||
if not _ai.get("ai_model"):
|
||||
async def err_gen():
|
||||
yield 'data: {"error": "知识总结模型未配置,请在「设置」页面填入模型名称(如 qwen-max)"}\n\n'
|
||||
return StreamingResponse(err_gen(), media_type="text/event-stream")
|
||||
|
||||
context = body.context.strip()
|
||||
if not context:
|
||||
async def err_gen():
|
||||
yield 'data: {"error": "对话内容为空"}\n\n'
|
||||
return StreamingResponse(err_gen(), media_type="text/event-stream")
|
||||
|
||||
room = body.room_name or "会话"
|
||||
|
||||
system_prompt = (
|
||||
"你是一位专业的对话分析助手。"
|
||||
"请根据提供的聊天记录(可能包含图片描述、语音转文字、视频描述等多媒体内容)"
|
||||
"生成一份结构清晰的 Markdown 总结。"
|
||||
"总结应包含:主要话题、关键信息点、媒体内容要点、待办事项(如有)。"
|
||||
"只输出 Markdown 格式内容,不要有任何额外说明。"
|
||||
)
|
||||
user_prompt = (
|
||||
f"群聊:{room}\n\n"
|
||||
f"以下是聊天记录(含多媒体内容描述):\n\n"
|
||||
f"{context[:12000]}\n\n" # 限制 token 数
|
||||
f"请生成总结:"
|
||||
)
|
||||
|
||||
async def generate():
|
||||
try:
|
||||
_client, _ai = await _get_ai_client()
|
||||
stream = await _client.chat.completions.create(
|
||||
model=_ai["ai_model"],
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
stream=True,
|
||||
temperature=0.3,
|
||||
)
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta.content if chunk.choices else None
|
||||
if delta:
|
||||
yield f"data: {json.dumps({'delta': delta}, ensure_ascii=False)}\n\n"
|
||||
yield 'data: {"done": true}\n\n'
|
||||
except Exception as e:
|
||||
log.error(f"[summarize/stream] LLM 调用失败: {e}", exc_info=True)
|
||||
yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task(task_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM ai_tasks WHERE id=?", (task_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return dict(row)
|
||||
93
chatlog_fastAPI/routers/chatlog_proxy.py
Normal file
93
chatlog_fastAPI/routers/chatlog_proxy.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from config import settings
|
||||
|
||||
router = APIRouter(tags=["chatlog-proxy"])
|
||||
|
||||
HOP_BY_HOP_HEADERS = {
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"te",
|
||||
"trailers",
|
||||
"transfer-encoding",
|
||||
"upgrade",
|
||||
}
|
||||
|
||||
|
||||
def _copy_headers(headers: httpx.Headers) -> dict[str, str]:
|
||||
copied: dict[str, str] = {}
|
||||
for key, value in headers.items():
|
||||
if key.lower() not in HOP_BY_HOP_HEADERS:
|
||||
copied[key] = value
|
||||
return copied
|
||||
|
||||
|
||||
async def _proxy_chatlog(request: Request, upstream_path: str) -> Response:
|
||||
query = request.url.query
|
||||
target = f"{settings.chatlog_base_url}{upstream_path}"
|
||||
if query:
|
||||
target = f"{target}?{query}"
|
||||
|
||||
body = await request.body()
|
||||
headers = {
|
||||
key: value
|
||||
for key, value in request.headers.items()
|
||||
if key.lower() not in HOP_BY_HOP_HEADERS and key.lower() != "host"
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=None, trust_env=False, follow_redirects=True) as client:
|
||||
upstream = await client.request(
|
||||
request.method,
|
||||
target,
|
||||
content=body if body else None,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
response_headers = _copy_headers(upstream.headers)
|
||||
return StreamingResponse(
|
||||
iter([upstream.content]),
|
||||
status_code=upstream.status_code,
|
||||
media_type=upstream.headers.get("content-type"),
|
||||
headers=response_headers,
|
||||
)
|
||||
|
||||
|
||||
@router.api_route("/api/v1/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
|
||||
async def proxy_api_v1(path: str, request: Request):
|
||||
return await _proxy_chatlog(request, f"/api/v1/{path}")
|
||||
|
||||
|
||||
async def _proxy_media(kind: str, path: str, request: Request):
|
||||
safe_path = "/".join(quote(part, safe="") for part in path.split("/"))
|
||||
return await _proxy_chatlog(request, f"/{kind}/{safe_path}")
|
||||
|
||||
|
||||
@router.api_route("/image/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_image(path: str, request: Request):
|
||||
return await _proxy_media("image", path, request)
|
||||
|
||||
|
||||
@router.api_route("/voice/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_voice(path: str, request: Request):
|
||||
return await _proxy_media("voice", path, request)
|
||||
|
||||
|
||||
@router.api_route("/video/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_video(path: str, request: Request):
|
||||
return await _proxy_media("video", path, request)
|
||||
|
||||
|
||||
@router.api_route("/file/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_file(path: str, request: Request):
|
||||
return await _proxy_media("file", path, request)
|
||||
|
||||
|
||||
@router.api_route("/data/{path:path}", methods=["GET", "POST", "OPTIONS"])
|
||||
async def proxy_data(path: str, request: Request):
|
||||
return await _proxy_media("data", path, request)
|
||||
190
chatlog_fastAPI/routers/files.py
Normal file
190
chatlog_fastAPI/routers/files.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
from config import settings
|
||||
from services.chatlog_client import chatlog_client
|
||||
|
||||
router = APIRouter(prefix="/api/files", tags=["files"])
|
||||
|
||||
|
||||
OFFICE_MEDIA_TYPES = {
|
||||
".xls": "application/vnd.ms-excel",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".ppt": "application/vnd.ms-powerpoint",
|
||||
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
".doc": "application/msword",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".pdf": "application/pdf",
|
||||
".dwg": "application/acad",
|
||||
}
|
||||
|
||||
|
||||
def _connect_hardlink_db(hardlink_db: Path) -> sqlite3.Connection:
|
||||
"""
|
||||
chatlog may keep hardlink.db open. Copying a tiny snapshot avoids transient
|
||||
"unable to open database file" errors on Windows while keeping reads safe.
|
||||
"""
|
||||
tmp = Path(tempfile.gettempdir()) / f"chatlab_hardlink_{os.getpid()}_{hardlink_db.stat().st_mtime_ns}.db"
|
||||
if not tmp.exists() or tmp.stat().st_size != hardlink_db.stat().st_size:
|
||||
shutil.copy2(hardlink_db, tmp)
|
||||
con = sqlite3.connect(tmp)
|
||||
con.row_factory = sqlite3.Row
|
||||
return con
|
||||
|
||||
|
||||
def _safe_download_name(name: str, fallback: str) -> str:
|
||||
name = (name or fallback).replace("\r", "").replace("\n", "").strip()
|
||||
return name or fallback
|
||||
|
||||
|
||||
def _content_disposition(filename: str) -> str:
|
||||
quoted = quote(filename)
|
||||
ascii_fallback = re.sub(r"[^A-Za-z0-9._-]+", "_", filename) or "download"
|
||||
return f"attachment; filename=\"{ascii_fallback}\"; filename*=UTF-8''{quoted}"
|
||||
|
||||
|
||||
def _guess_media_type(filename: str, fallback: str = "") -> str:
|
||||
ext = Path(filename or "").suffix.lower()
|
||||
return OFFICE_MEDIA_TYPES.get(ext) or mimetypes.guess_type(filename)[0] or fallback or "application/octet-stream"
|
||||
|
||||
|
||||
async def _proxy_chatlog_file(md5: str, filename: str = ""):
|
||||
url = f"{settings.chatlog_base_url}/file/{quote(md5, safe='')}"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30, trust_env=False, follow_redirects=True) as client:
|
||||
resp = await client.get(url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if resp.status_code != 200 or resp.content == b'"media not found"':
|
||||
return None
|
||||
|
||||
headers = {
|
||||
"Content-Length": str(len(resp.content)),
|
||||
"X-ChatLab-File-Source": "chatlog",
|
||||
}
|
||||
if filename:
|
||||
headers["Content-Disposition"] = _content_disposition(filename)
|
||||
media_type = _guess_media_type(filename, resp.headers.get("content-type") or "")
|
||||
return StreamingResponse(iter([resp.content]), media_type=media_type, headers=headers)
|
||||
|
||||
|
||||
def _xwechat_roots_from_hardlink_db(hardlink_db: Path) -> list[Path]:
|
||||
roots: list[Path] = []
|
||||
try:
|
||||
con = _connect_hardlink_db(hardlink_db)
|
||||
row = con.execute("SELECT ValueStdStr FROM db_info WHERE Key='uuid'").fetchone()
|
||||
raw = row["ValueStdStr"] if row else ""
|
||||
except Exception:
|
||||
raw = ""
|
||||
|
||||
if raw:
|
||||
m = re.search(r"([A-Za-z]:\\[^|]+?xwechat_files)", raw)
|
||||
if m:
|
||||
roots.append(Path(m.group(1)))
|
||||
|
||||
roots.extend([
|
||||
Path.home() / "xwechat_files",
|
||||
Path.home() / "Documents" / "WeChat Files",
|
||||
])
|
||||
uniq: list[Path] = []
|
||||
seen = set()
|
||||
for root in roots:
|
||||
s = str(root).lower()
|
||||
if s not in seen:
|
||||
uniq.append(root)
|
||||
seen.add(s)
|
||||
return uniq
|
||||
|
||||
|
||||
def _find_local_file(hardlink_db: Path, md5: str, requested_name: str = "") -> Path | None:
|
||||
try:
|
||||
con = _connect_hardlink_db(hardlink_db)
|
||||
row = con.execute(
|
||||
"""
|
||||
SELECT md5, file_name, file_size, dir1, dir2
|
||||
FROM file_hardlink_info_v4
|
||||
WHERE md5=?
|
||||
ORDER BY _rowid_ DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(md5,),
|
||||
).fetchone()
|
||||
except Exception:
|
||||
row = None
|
||||
if not row:
|
||||
return None
|
||||
|
||||
names = [requested_name, row["file_name"]]
|
||||
names = [n for n in names if n]
|
||||
size = int(row["file_size"] or 0)
|
||||
roots = _xwechat_roots_from_hardlink_db(hardlink_db)
|
||||
|
||||
for root in roots:
|
||||
if not root.exists():
|
||||
continue
|
||||
for name in names:
|
||||
for candidate in root.rglob(name):
|
||||
try:
|
||||
if candidate.is_file() and (not size or candidate.stat().st_size == size):
|
||||
return candidate
|
||||
except Exception:
|
||||
continue
|
||||
if size:
|
||||
# Fallback by size in the common file store. This is intentionally limited
|
||||
# to msg/file to avoid scanning unrelated huge trees for every request.
|
||||
for file_root in root.glob("*/msg/file"):
|
||||
if not file_root.exists():
|
||||
continue
|
||||
for candidate in file_root.rglob("*"):
|
||||
try:
|
||||
if candidate.is_file() and candidate.stat().st_size == size:
|
||||
if not names or candidate.name in names:
|
||||
return candidate
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/{md5}")
|
||||
async def get_file(md5: str, filename: str = Query("")):
|
||||
md5 = md5.strip()
|
||||
if not re.fullmatch(r"[0-9a-fA-F]{8,64}", md5):
|
||||
raise HTTPException(400, "文件 md5 不合法")
|
||||
|
||||
filename = _safe_download_name(filename, md5)
|
||||
proxied = await _proxy_chatlog_file(md5, filename)
|
||||
if proxied:
|
||||
return proxied
|
||||
|
||||
db_paths = await chatlog_client.get_db_paths()
|
||||
hardlink_paths = db_paths.get("media") or []
|
||||
for raw_path in hardlink_paths:
|
||||
hardlink_db = Path(raw_path)
|
||||
if not hardlink_db.exists():
|
||||
continue
|
||||
local_file = _find_local_file(hardlink_db, md5, filename)
|
||||
if local_file:
|
||||
media_type = _guess_media_type(filename or local_file.name)
|
||||
return FileResponse(
|
||||
path=str(local_file),
|
||||
filename=filename or local_file.name,
|
||||
media_type=media_type,
|
||||
headers={
|
||||
"Content-Disposition": _content_disposition(filename or local_file.name),
|
||||
"Content-Length": str(local_file.stat().st_size),
|
||||
"X-ChatLab-File-Source": "local-hardlink",
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(404, "原文件未找到,可能未解密或已清理")
|
||||
138
chatlog_fastAPI/routers/groups.py
Normal file
138
chatlog_fastAPI/routers/groups.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite, json
|
||||
from database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/groups", tags=["groups"])
|
||||
|
||||
class GroupCreate(BaseModel):
|
||||
talker: str
|
||||
name: Optional[str] = ""
|
||||
poll_interval: int = 300
|
||||
|
||||
class GroupPatch(BaseModel):
|
||||
analysis_prompt: Optional[str] = None
|
||||
|
||||
class InitParams(BaseModel):
|
||||
start_time: int # unix 秒
|
||||
end_time: int # unix 秒
|
||||
|
||||
@router.post("")
|
||||
async def create_group(body: GroupCreate, db: aiosqlite.Connection = Depends(get_db)):
|
||||
try:
|
||||
await db.execute(
|
||||
"INSERT INTO groups (talker, name, poll_interval) VALUES (?, ?, ?)",
|
||||
(body.talker, body.name, body.poll_interval)
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM groups WHERE talker=?", (body.talker,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
return dict(row)
|
||||
except aiosqlite.IntegrityError:
|
||||
raise HTTPException(409, "talker already exists")
|
||||
|
||||
@router.get("")
|
||||
async def list_groups(db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM groups") as cur:
|
||||
rows = await cur.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
@router.patch("/{group_id}")
|
||||
async def patch_group(group_id: int, body: GroupPatch, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "group not found")
|
||||
updates = body.model_dump(exclude_none=True)
|
||||
if "analysis_prompt" in updates:
|
||||
await db.execute(
|
||||
"UPDATE groups SET analysis_prompt=? WHERE id=?",
|
||||
(updates["analysis_prompt"], group_id),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
return dict(await cur.fetchone())
|
||||
|
||||
@router.post("/{group_id}/init")
|
||||
async def trigger_init(
|
||||
group_id: int,
|
||||
body: InitParams,
|
||||
db: aiosqlite.Connection = Depends(get_db),
|
||||
):
|
||||
"""对指定时间区间内全部消息做一次性 AI 分类。串行执行。"""
|
||||
async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
group = await cur.fetchone()
|
||||
if not group:
|
||||
raise HTTPException(404, "group not found")
|
||||
|
||||
from services.topic_engine import get_classifying_group
|
||||
busy = get_classifying_group()
|
||||
if busy is not None:
|
||||
raise HTTPException(409, f"已有群正在分析(group_id={busy}),请等待完成后再试")
|
||||
|
||||
if body.end_time <= body.start_time:
|
||||
raise HTTPException(400, "end_time 必须大于 start_time")
|
||||
|
||||
await db.execute(
|
||||
"INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'classify_window', 'running', ?)",
|
||||
(group_id, json.dumps({"processed": 0, "total": 0})),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT last_insert_rowid()") as cur:
|
||||
task = await cur.fetchone()
|
||||
task_id = task[0]
|
||||
|
||||
from services.topic_engine import run_classify_window
|
||||
import asyncio
|
||||
asyncio.create_task(
|
||||
run_classify_window(group_id, task_id, dict(group), body.start_time, body.end_time)
|
||||
)
|
||||
return {"task_id": task_id}
|
||||
|
||||
@router.get("/{group_id}/task")
|
||||
async def get_task(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute(
|
||||
"SELECT * FROM ai_tasks WHERE group_id=? ORDER BY id DESC LIMIT 1", (group_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "no task found")
|
||||
return dict(row)
|
||||
|
||||
@router.delete("/{group_id}")
|
||||
async def delete_group(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "group not found")
|
||||
# 级联删除关联数据(FTS → 知识库报告 → 话题消息 → 话题 → 任务 → 群组)
|
||||
# 注意:必须先清 knowledge_fts/knowledge_docs,否则 SQLite 复用 topic_id 时
|
||||
# 残留报告会"接"到新建话题上,造成跨群串报告。
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM knowledge_fts WHERE doc_id IN (
|
||||
SELECT id FROM knowledge_docs WHERE topic_id IN (
|
||||
SELECT id FROM topics WHERE group_id=?
|
||||
)
|
||||
)
|
||||
""",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM knowledge_docs WHERE topic_id IN (
|
||||
SELECT id FROM topics WHERE group_id=?
|
||||
)
|
||||
""",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute(
|
||||
"DELETE FROM topic_messages WHERE topic_id IN (SELECT id FROM topics WHERE group_id=?)",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute("DELETE FROM topics WHERE group_id=?", (group_id,))
|
||||
await db.execute("DELETE FROM ai_tasks WHERE group_id=?", (group_id,))
|
||||
await db.execute("DELETE FROM groups WHERE id=?", (group_id,))
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
67
chatlog_fastAPI/routers/knowledge.py
Normal file
67
chatlog_fastAPI/routers/knowledge.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite
|
||||
from database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/knowledge", tags=["knowledge"])
|
||||
|
||||
class KnowledgePatch(BaseModel):
|
||||
content: str
|
||||
|
||||
@router.get("")
|
||||
async def list_knowledge(
|
||||
keyword: Optional[str] = None,
|
||||
db: aiosqlite.Connection = Depends(get_db)
|
||||
):
|
||||
if keyword:
|
||||
# FTS5 查询前先用 jieba 分词,提高中文召回率
|
||||
from services.fts import build_match_query
|
||||
fts_query = build_match_query(keyword)
|
||||
if not fts_query:
|
||||
return []
|
||||
async with db.execute(
|
||||
"SELECT k.id, k.topic_id, k.created_at, k.updated_at, t.title, t.group_id, g.name as group_name "
|
||||
"FROM knowledge_docs k JOIN topics t ON k.topic_id=t.id "
|
||||
"LEFT JOIN groups g ON t.group_id=g.id "
|
||||
"WHERE k.id IN (SELECT doc_id FROM knowledge_fts WHERE knowledge_fts MATCH ?)",
|
||||
(fts_query,)
|
||||
) as cur:
|
||||
return [dict(r) for r in await cur.fetchall()]
|
||||
async with db.execute(
|
||||
"SELECT k.id, k.topic_id, k.created_at, k.updated_at, t.title, t.group_id, g.name as group_name "
|
||||
"FROM knowledge_docs k JOIN topics t ON k.topic_id=t.id "
|
||||
"LEFT JOIN groups g ON t.group_id=g.id "
|
||||
"ORDER BY g.name, k.updated_at DESC"
|
||||
) as cur:
|
||||
return [dict(r) for r in await cur.fetchall()]
|
||||
|
||||
@router.get("/{doc_id}")
|
||||
async def get_knowledge(doc_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM knowledge_docs WHERE id=?", (doc_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
return dict(row)
|
||||
|
||||
@router.patch("/{doc_id}")
|
||||
async def patch_knowledge(doc_id: int, body: KnowledgePatch, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute(
|
||||
"UPDATE knowledge_docs SET content=?, updated_at=CURRENT_TIMESTAMP, curated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(body.content, doc_id)
|
||||
)
|
||||
await db.commit()
|
||||
# update FTS
|
||||
async with db.execute("SELECT topic_id FROM knowledge_docs WHERE id=?", (doc_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
async with db.execute("SELECT title FROM topics WHERE id=?", (row["topic_id"],)) as cur:
|
||||
topic = await cur.fetchone()
|
||||
await db.execute("DELETE FROM knowledge_fts WHERE doc_id=?", (doc_id,))
|
||||
from services.fts import tokenize
|
||||
await db.execute(
|
||||
"INSERT INTO knowledge_fts (doc_id, title, content) VALUES (?, ?, ?)",
|
||||
(doc_id, tokenize(topic["title"]), tokenize(body.content))
|
||||
)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
144
chatlog_fastAPI/routers/search.py
Normal file
144
chatlog_fastAPI/routers/search.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from urllib.parse import quote
|
||||
from services.chatlog_client import MessageIndexNotReady, chatlog_client
|
||||
from services.message_formatter import extract_quote
|
||||
|
||||
router = APIRouter(prefix="/api/search", tags=["search"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search(
|
||||
talker: str = Query(..., description="群/联系人 ID"),
|
||||
time: str = Query("", description="时间范围,如 2024-01-01,2024-01-31"),
|
||||
sender: str = Query("", description="发送者 ID,可选"),
|
||||
keyword: str = Query("", description="关键词,可选"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=500),
|
||||
):
|
||||
"""透传 chatlog /api/v1/chatlog,返回 {"total": N, "items": [...]}"""
|
||||
offset = (page - 1) * page_size
|
||||
try:
|
||||
data = await chatlog_client.get_messages(
|
||||
talker,
|
||||
time=time,
|
||||
sender=sender,
|
||||
keyword=keyword,
|
||||
limit=page_size,
|
||||
offset=offset,
|
||||
)
|
||||
except MessageIndexNotReady as e:
|
||||
raise HTTPException(status_code=503, detail=str(e)) from e
|
||||
for item in data.get("items", []) or []:
|
||||
contents = item.get("contents") or item.get("Contents") or {}
|
||||
if not isinstance(contents, dict):
|
||||
contents = {}
|
||||
try:
|
||||
is_file = int(item.get("type") or item.get("Type") or 0) == 49 and int(
|
||||
item.get("subType") or item.get("sub_type") or item.get("SubType") or 0
|
||||
) == 6
|
||||
except Exception:
|
||||
is_file = False
|
||||
file_md5 = str(contents.get("md5") or "") if is_file else ""
|
||||
item["is_file"] = is_file
|
||||
item["file_name"] = (
|
||||
contents.get("title") or contents.get("fileName") or contents.get("filename") or ""
|
||||
) if is_file else ""
|
||||
item["file_md5"] = file_md5
|
||||
item["quote"] = item.get("quote") or extract_quote(item)
|
||||
file_name = item["file_name"]
|
||||
item["file_url"] = f"/api/files/{quote(file_md5, safe='')}?filename={quote(file_name or file_md5, safe='')}" if file_md5 else ""
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/chatrooms")
|
||||
async def chatrooms(
|
||||
keyword: str = Query("", description="关键词搜索"),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
):
|
||||
"""获取所有可用的微信群聊列表"""
|
||||
fetch_limit = min(2000, offset + limit)
|
||||
rooms_data = await chatlog_client.get_chatrooms(keyword=keyword, limit=fetch_limit, offset=0)
|
||||
if isinstance(rooms_data, list):
|
||||
room_items = rooms_data
|
||||
total = len(room_items)
|
||||
else:
|
||||
room_items = rooms_data.get("items") or rooms_data.get("data") or []
|
||||
total = rooms_data.get("total", len(room_items))
|
||||
|
||||
merged = []
|
||||
seen = set()
|
||||
|
||||
def get_room_id(item: dict) -> str:
|
||||
return str(item.get("name") or item.get("Name") or item.get("userName") or item.get("UserName") or "")
|
||||
|
||||
def add_room(item: dict):
|
||||
room_id = get_room_id(item)
|
||||
if not room_id or not room_id.endswith("@chatroom") or room_id in seen:
|
||||
return
|
||||
seen.add(room_id)
|
||||
merged.append(item)
|
||||
|
||||
for item in room_items:
|
||||
if isinstance(item, dict):
|
||||
add_room(item)
|
||||
|
||||
# Freshly imported phone records may exist in sessions/messages before
|
||||
# chatroom metadata is populated. Merge @chatroom sessions as fallback.
|
||||
try:
|
||||
session_items = await chatlog_client.get_sessions(keyword="", limit=2000)
|
||||
except Exception:
|
||||
session_items = []
|
||||
|
||||
lowered_keyword = (keyword or "").lower()
|
||||
for session in session_items:
|
||||
if not isinstance(session, dict):
|
||||
continue
|
||||
user_name = str(session.get("userName") or session.get("UserName") or "")
|
||||
if not user_name.endswith("@chatroom"):
|
||||
continue
|
||||
nick_name = session.get("nickName") or session.get("NickName") or ""
|
||||
remark = session.get("remark") or session.get("Remark") or ""
|
||||
if lowered_keyword:
|
||||
haystack = f"{user_name} {nick_name} {remark}".lower()
|
||||
if lowered_keyword not in haystack:
|
||||
continue
|
||||
add_room({
|
||||
"name": user_name,
|
||||
"nickName": nick_name,
|
||||
"remark": remark,
|
||||
"source": "session",
|
||||
})
|
||||
|
||||
return {"total": max(total, len(merged)), "items": merged[offset:offset + limit]}
|
||||
|
||||
@router.get("/avatar")
|
||||
async def avatar(wxid: str = Query(...)):
|
||||
url = await chatlog_client.get_avatar_url(wxid)
|
||||
return {"url": url}
|
||||
|
||||
|
||||
@router.get("/members")
|
||||
async def members(
|
||||
talker: str = Query(..., description="群 ID"),
|
||||
time: str = Query("", description="统计时间范围,可选"),
|
||||
):
|
||||
"""
|
||||
获取群成员列表(按发言量降序)
|
||||
返回 {"members": [...], "total": N}
|
||||
每个成员:userName, displayName, msgCount, lastSpeakTime
|
||||
"""
|
||||
return await chatlog_client.get_chatroom_members(talker, time=time)
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def sessions(
|
||||
keyword: str = Query("", description="关键词搜索"),
|
||||
limit: int = Query(500, ge=1, le=2000),
|
||||
):
|
||||
"""
|
||||
获取所有会话列表,含最新一条消息预览和时间(来自微信原生 Session 表)。
|
||||
返回:[{ userName, nickName, remark, content, nTime, nOrder }]
|
||||
"""
|
||||
items = await chatlog_client.get_sessions(keyword=keyword, limit=limit)
|
||||
return items
|
||||
64
chatlog_fastAPI/routers/settings.py
Normal file
64
chatlog_fastAPI/routers/settings.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite
|
||||
from database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/settings", tags=["settings"])
|
||||
|
||||
EDITABLE_KEYS = [
|
||||
"ai_base_url", "ai_api_key", "ai_model", "summary_model",
|
||||
"vision_model", "voice_model", "topic_analysis_prompt",
|
||||
]
|
||||
|
||||
|
||||
def _mask_key(value: str) -> str:
|
||||
if not value or len(value) <= 8:
|
||||
return "*" * len(value) if value else ""
|
||||
return value[:3] + "*" * (len(value) - 7) + value[-4:]
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_settings(db: aiosqlite.Connection = Depends(get_db)):
|
||||
result = {}
|
||||
placeholders = ",".join("?" for _ in EDITABLE_KEYS)
|
||||
async with db.execute(
|
||||
f"SELECT key, value FROM app_settings WHERE key IN ({placeholders})",
|
||||
EDITABLE_KEYS,
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
for row in rows:
|
||||
k, v = row["key"], row["value"]
|
||||
result[k] = _mask_key(v) if k == "ai_api_key" else v
|
||||
for k in EDITABLE_KEYS:
|
||||
if k not in result:
|
||||
result[k] = ""
|
||||
return result
|
||||
|
||||
|
||||
class SettingsUpdate(BaseModel):
|
||||
ai_base_url: Optional[str] = None
|
||||
ai_api_key: Optional[str] = None
|
||||
ai_model: Optional[str] = None
|
||||
summary_model: Optional[str] = None
|
||||
vision_model: Optional[str] = None
|
||||
voice_model: Optional[str] = None
|
||||
topic_analysis_prompt: Optional[str] = None
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def update_settings(body: SettingsUpdate, db: aiosqlite.Connection = Depends(get_db)):
|
||||
updates = body.model_dump(exclude_none=True)
|
||||
for k, v in updates.items():
|
||||
if k not in EDITABLE_KEYS:
|
||||
continue
|
||||
if k == "ai_api_key" and "*" in v:
|
||||
continue
|
||||
await db.execute(
|
||||
"INSERT INTO app_settings (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = ?",
|
||||
(k, v, v),
|
||||
)
|
||||
await db.commit()
|
||||
from services.runtime_settings import invalidate_cache
|
||||
invalidate_cache()
|
||||
return {"status": "ok"}
|
||||
40
chatlog_fastAPI/routers/sse.py
Normal file
40
chatlog_fastAPI/routers/sse.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import asyncio, json, logging
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from services.chatlog_client import chatlog_client
|
||||
from services.message_formatter import attach_quote
|
||||
|
||||
router = APIRouter(prefix="/api/sse", tags=["sse"])
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/chatlog")
|
||||
async def sse_chatlog(talker: str = Query(...)):
|
||||
async def generate():
|
||||
try:
|
||||
data = await chatlog_client.get_messages(talker, limit=1, offset=0)
|
||||
last_total = data.get("total", 0)
|
||||
except Exception:
|
||||
last_total = 0
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(2)
|
||||
try:
|
||||
data = await chatlog_client.get_messages(talker, limit=50, offset=last_total)
|
||||
msgs = data.get("messages") or data.get("items") or []
|
||||
new_total = data.get("total", last_total)
|
||||
for msg in msgs:
|
||||
attach_quote(msg)
|
||||
yield f"data: {json.dumps(msg, ensure_ascii=False)}\n\n"
|
||||
if new_total > last_total:
|
||||
last_total = new_total
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
log.warning(f"[sse] poll error: {e}")
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
435
chatlog_fastAPI/routers/topics.py
Normal file
435
chatlog_fastAPI/routers/topics.py
Normal file
@@ -0,0 +1,435 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite, json
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
from database import get_db
|
||||
from services.message_formatter import extract_quote
|
||||
|
||||
router = APIRouter(prefix="/api/topics", tags=["topics"])
|
||||
CHATLOG_BATCH_SIZE = 80
|
||||
NAME_LOOKUP_PAGE_SIZE = 500
|
||||
NAME_LOOKUP_MAX_ITEMS = 5000
|
||||
STALE_SUMMARIZE_MINUTES = 15
|
||||
|
||||
class TopicCreate(BaseModel):
|
||||
group_id: int
|
||||
title: str
|
||||
|
||||
class TopicPatch(BaseModel):
|
||||
title: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
class MessageAdd(BaseModel):
|
||||
msg_seq: int
|
||||
talker: str
|
||||
|
||||
|
||||
async def _mark_stale_summarize_tasks(db: aiosqlite.Connection, group_id: int, topic_id: int) -> None:
|
||||
error = "AI 报告生成任务超过 15 分钟未完成,已自动标记为失败,可重新生成"
|
||||
stale_window = f"-{STALE_SUMMARIZE_MINUTES} minutes"
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE ai_tasks
|
||||
SET status='error', error=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE group_id=?
|
||||
AND type='summarize'
|
||||
AND status='running'
|
||||
AND datetime(updated_at) <= datetime('now', ?)
|
||||
""",
|
||||
(error, group_id, stale_window),
|
||||
)
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE topics
|
||||
SET status='error', updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?
|
||||
AND status='processing'
|
||||
AND datetime(updated_at) <= datetime('now', ?)
|
||||
""",
|
||||
(topic_id, stale_window),
|
||||
)
|
||||
|
||||
|
||||
def _normalize_chatlog_message(item: dict, fallback_seq: int = 0) -> dict:
|
||||
contents = item.get("contents") or item.get("Contents") or {}
|
||||
if not isinstance(contents, dict):
|
||||
contents = {}
|
||||
media_key = (
|
||||
contents.get("rawmd5")
|
||||
or contents.get("md5")
|
||||
or contents.get("path")
|
||||
or item.get("media_key")
|
||||
or item.get("mediaKey")
|
||||
or ""
|
||||
)
|
||||
voice_key = (
|
||||
str(contents.get("voice"))
|
||||
if contents.get("voice")
|
||||
else item.get("voice_key") or item.get("voiceKey") or ""
|
||||
)
|
||||
raw_type = item.get("type") or item.get("Type") or 1
|
||||
raw_sub_type = item.get("sub_type") or item.get("subType") or item.get("SubType") or 0
|
||||
try:
|
||||
is_file = int(raw_type) == 49 and int(raw_sub_type) == 6
|
||||
except Exception:
|
||||
is_file = False
|
||||
file_md5 = str(contents.get("md5") or item.get("file_md5") or item.get("fileMd5") or "") if is_file else ""
|
||||
file_name = (
|
||||
contents.get("title")
|
||||
or contents.get("fileName")
|
||||
or contents.get("filename")
|
||||
or item.get("file_name")
|
||||
or item.get("fileName")
|
||||
or ""
|
||||
) if is_file else ""
|
||||
file_url = f"/api/files/{quote(file_md5, safe='')}?filename={quote(file_name or file_md5, safe='')}" if file_md5 else ""
|
||||
return {
|
||||
"seq": item.get("seq") or item.get("Seq") or item.get("sort_seq") or fallback_seq or 0,
|
||||
"sender": item.get("sender") or item.get("Sender") or "",
|
||||
"sender_name": (
|
||||
item.get("sender_name")
|
||||
or item.get("senderName")
|
||||
or item.get("SenderName")
|
||||
or item.get("sender")
|
||||
or item.get("Sender")
|
||||
or ""
|
||||
),
|
||||
"create_time": item.get("create_time") or item.get("time") or item.get("CreateTime") or "",
|
||||
"content": item.get("content") or item.get("Content") or "",
|
||||
"type": raw_type,
|
||||
"sub_type": raw_sub_type,
|
||||
"contents": contents,
|
||||
"media_key": media_key,
|
||||
"voice_key": voice_key,
|
||||
"image_path": media_key,
|
||||
"voice_path": voice_key,
|
||||
"video_path": media_key,
|
||||
"file_path": media_key,
|
||||
"link_url": contents.get("url") or item.get("link_url") or "",
|
||||
"link_title": contents.get("title") or item.get("link_title") or "",
|
||||
"link_desc": contents.get("desc") or item.get("link_desc") or "",
|
||||
"link_thumb": contents.get("thumbUrl") or contents.get("thumb_url") or item.get("link_thumb") or "",
|
||||
"link_source": contents.get("sourceName") or contents.get("source_name") or item.get("link_source") or "",
|
||||
"quote": item.get("quote") or extract_quote(item),
|
||||
"is_file": is_file,
|
||||
"file_name": file_name,
|
||||
"file_md5": file_md5,
|
||||
"file_url": file_url,
|
||||
}
|
||||
|
||||
|
||||
def _message_from_snapshot(raw: str | None, fallback_seq: int) -> dict | None:
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
item = json.loads(raw)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(item, dict):
|
||||
return None
|
||||
return _normalize_chatlog_message(item, fallback_seq)
|
||||
|
||||
|
||||
def _looks_like_raw_sender_id(value: str | None) -> bool:
|
||||
value = (value or "").strip()
|
||||
return (
|
||||
not value
|
||||
or value.startswith("wxid_")
|
||||
or value.startswith("gh_")
|
||||
or value.endswith("@chatroom")
|
||||
or value.startswith("chatroom_")
|
||||
)
|
||||
|
||||
|
||||
def _sender_display_name(item: dict, sender: str = "") -> str:
|
||||
for key in (
|
||||
"sender_name",
|
||||
"senderName",
|
||||
"SenderName",
|
||||
"accountName",
|
||||
"groupNickname",
|
||||
"displayName",
|
||||
"nickName",
|
||||
"remark",
|
||||
):
|
||||
value = str(item.get(key) or "").strip()
|
||||
if value and value != sender and not _looks_like_raw_sender_id(value):
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
def _message_date(value: str | None) -> str | None:
|
||||
value = (value or "").strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00")).strftime("%Y-%m-%d")
|
||||
except Exception:
|
||||
return value[:10] if len(value) >= 10 and value[4:5] == "-" else None
|
||||
|
||||
|
||||
async def _build_sender_name_map(talker: str, messages: list[dict]) -> dict[str, str]:
|
||||
names: dict[str, str] = {}
|
||||
|
||||
for msg in messages:
|
||||
sender = str(msg.get("sender") or "").strip()
|
||||
name = _sender_display_name(msg, sender)
|
||||
if sender and name:
|
||||
names[sender] = name
|
||||
|
||||
missing = {
|
||||
str(msg.get("sender") or "").strip()
|
||||
for msg in messages
|
||||
if str(msg.get("sender") or "").strip()
|
||||
and _looks_like_raw_sender_id(str(msg.get("sender_name") or "").strip())
|
||||
}
|
||||
missing = {sender for sender in missing if sender not in names}
|
||||
if not missing:
|
||||
return names
|
||||
|
||||
from services.chatlog_client import chatlog_client
|
||||
|
||||
dates = sorted({
|
||||
date
|
||||
for msg in messages
|
||||
for date in [_message_date(str(msg.get("create_time") or ""))]
|
||||
if date
|
||||
})
|
||||
if dates:
|
||||
time_range = f"{dates[0]},{dates[-1]}"
|
||||
offset = 0
|
||||
seen = 0
|
||||
while missing and seen < NAME_LOOKUP_MAX_ITEMS:
|
||||
try:
|
||||
data = await chatlog_client.get_messages(
|
||||
talker,
|
||||
time=time_range,
|
||||
limit=NAME_LOOKUP_PAGE_SIZE,
|
||||
offset=offset,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch sender names talker={talker}: {e}")
|
||||
break
|
||||
items = data.get("items", []) or []
|
||||
if not items:
|
||||
break
|
||||
for item in items:
|
||||
sender = str(item.get("sender") or item.get("Sender") or "").strip()
|
||||
if sender not in missing:
|
||||
continue
|
||||
name = _sender_display_name(item, sender)
|
||||
if name:
|
||||
names[sender] = name
|
||||
missing.discard(sender)
|
||||
seen += len(items)
|
||||
offset += len(items)
|
||||
total = int(data.get("total") or 0)
|
||||
if total and offset >= total:
|
||||
break
|
||||
if len(items) < NAME_LOOKUP_PAGE_SIZE:
|
||||
break
|
||||
|
||||
if missing:
|
||||
try:
|
||||
members = await chatlog_client.get_chatroom_members(talker)
|
||||
raw_members = members.get("members", []) if isinstance(members, dict) else []
|
||||
for member in raw_members:
|
||||
sender = str(member.get("userName") or member.get("UserName") or "").strip()
|
||||
if sender not in missing:
|
||||
continue
|
||||
name = _sender_display_name(
|
||||
{
|
||||
"displayName": member.get("displayName") or member.get("DisplayName"),
|
||||
"nickName": member.get("nickName") or member.get("NickName"),
|
||||
"remark": member.get("remark") or member.get("Remark"),
|
||||
},
|
||||
sender,
|
||||
)
|
||||
if name:
|
||||
names[sender] = name
|
||||
missing.discard(sender)
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch chatroom members talker={talker}: {e}")
|
||||
|
||||
return names
|
||||
|
||||
|
||||
async def _fill_sender_names(talker: str, messages: list[dict]) -> None:
|
||||
if not talker or not messages:
|
||||
return
|
||||
names = await _build_sender_name_map(talker, messages)
|
||||
if not names:
|
||||
return
|
||||
for msg in messages:
|
||||
sender = str(msg.get("sender") or "").strip()
|
||||
if not sender or sender not in names:
|
||||
continue
|
||||
current = str(msg.get("sender_name") or "").strip()
|
||||
if not current or current == sender or _looks_like_raw_sender_id(current):
|
||||
msg["sender_name"] = names[sender]
|
||||
msg["senderName"] = names[sender]
|
||||
|
||||
@router.get("")
|
||||
async def list_topics(
|
||||
group_id: Optional[int] = None,
|
||||
status: Optional[str] = None,
|
||||
keyword: Optional[str] = None,
|
||||
db: aiosqlite.Connection = Depends(get_db)
|
||||
):
|
||||
sql = "SELECT * FROM topics WHERE 1=1"
|
||||
params = []
|
||||
if group_id:
|
||||
sql += " AND group_id=?"; params.append(group_id)
|
||||
if status:
|
||||
sql += " AND status=?"; params.append(status)
|
||||
if keyword:
|
||||
sql += " AND title LIKE ?"; params.append(f"%{keyword}%")
|
||||
async with db.execute(sql, params) as cur:
|
||||
return [dict(r) for r in await cur.fetchall()]
|
||||
|
||||
@router.post("")
|
||||
async def create_topic(body: TopicCreate, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute(
|
||||
"INSERT INTO topics (group_id, title, source) VALUES (?, ?, 'manual')",
|
||||
(body.group_id, body.title),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM topics ORDER BY id DESC LIMIT 1") as cur:
|
||||
return dict(await cur.fetchone())
|
||||
|
||||
@router.get("/{topic_id}")
|
||||
async def get_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "not found")
|
||||
|
||||
# 拿到该话题下关联的所有 seq
|
||||
async with db.execute("SELECT * FROM topic_messages WHERE topic_id=? ORDER BY msg_seq ASC", (topic_id,)) as cur:
|
||||
msg_rows = await cur.fetchall()
|
||||
|
||||
msgs = []
|
||||
if msg_rows:
|
||||
# 获取群聊 ID
|
||||
talker = msg_rows[0]["talker"]
|
||||
seq_list = [r["msg_seq"] for r in msg_rows]
|
||||
|
||||
# 分批调用 5030 接口获取真实消息内容,避免大批量 batch 触发 500。
|
||||
from services.chatlog_client import chatlog_client
|
||||
fetched_by_seq: dict[int, dict] = {}
|
||||
for i in range(0, len(seq_list), CHATLOG_BATCH_SIZE):
|
||||
chunk = seq_list[i: i + CHATLOG_BATCH_SIZE]
|
||||
try:
|
||||
msgs_data = await chatlog_client.get_messages_batch(talker, chunk)
|
||||
raw_items = msgs_data.get("items", []) or msgs_data.get("Items", [])
|
||||
for item in raw_items:
|
||||
normalized = _normalize_chatlog_message(item)
|
||||
seq = normalized.get("seq")
|
||||
if seq:
|
||||
fetched_by_seq[int(seq)] = normalized
|
||||
except Exception as e:
|
||||
print(f"Failed to fetch real messages chunk topic={topic_id}: {e}")
|
||||
|
||||
for r in msg_rows:
|
||||
seq = int(r["msg_seq"])
|
||||
if seq in fetched_by_seq:
|
||||
msgs.append(fetched_by_seq[seq])
|
||||
continue
|
||||
snap = _message_from_snapshot(r["message_json"] if "message_json" in r.keys() else None, seq)
|
||||
if snap:
|
||||
msgs.append(snap)
|
||||
continue
|
||||
msgs.append({
|
||||
"seq": seq,
|
||||
"sender": "",
|
||||
"sender_name": "系统提示",
|
||||
"create_time": "",
|
||||
"content": f"原始消息无法从 chatlog 找回 (seq: {seq})",
|
||||
"type": 1,
|
||||
"sub_type": 0,
|
||||
})
|
||||
|
||||
# 获取知识文档
|
||||
if msg_rows:
|
||||
await _fill_sender_names(talker, msgs)
|
||||
|
||||
async with db.execute("SELECT id, topic_id, content, created_at, updated_at FROM knowledge_docs WHERE topic_id=?", (topic_id,)) as cur:
|
||||
doc = await cur.fetchone()
|
||||
|
||||
return {**dict(row), "messages": msgs, "knowledge_doc": dict(doc) if doc else None}
|
||||
|
||||
@router.patch("/{topic_id}")
|
||||
async def patch_topic(topic_id: int, body: TopicPatch, db: aiosqlite.Connection = Depends(get_db)):
|
||||
if body.title:
|
||||
await db.execute(
|
||||
"UPDATE topics SET title=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(body.title, topic_id),
|
||||
)
|
||||
if body.status:
|
||||
await db.execute(
|
||||
"UPDATE topics SET status=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(body.status, topic_id),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
|
||||
return dict(await cur.fetchone())
|
||||
|
||||
@router.delete("/{topic_id}")
|
||||
async def delete_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
# 先拿到 doc_id,用于清 FTS
|
||||
async with db.execute("SELECT id FROM knowledge_docs WHERE topic_id=?", (topic_id,)) as cur:
|
||||
doc_row = await cur.fetchone()
|
||||
if doc_row:
|
||||
await db.execute("DELETE FROM knowledge_fts WHERE doc_id=?", (doc_row["id"],))
|
||||
await db.execute("DELETE FROM topic_messages WHERE topic_id=?", (topic_id,))
|
||||
await db.execute("DELETE FROM knowledge_docs WHERE topic_id=?", (topic_id,))
|
||||
await db.execute("DELETE FROM topics WHERE id=?", (topic_id,))
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/{topic_id}/messages")
|
||||
async def add_message(topic_id: int, body: MessageAdd, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute(
|
||||
"INSERT OR IGNORE INTO topic_messages (topic_id, msg_seq, talker, added_by) VALUES (?, ?, ?, 'user')",
|
||||
(topic_id, body.msg_seq, body.talker)
|
||||
)
|
||||
await db.execute(
|
||||
"UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(topic_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.delete("/{topic_id}/messages/{seq}")
|
||||
async def remove_message(topic_id: int, seq: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
await db.execute("DELETE FROM topic_messages WHERE topic_id=? AND msg_seq=?", (topic_id, seq))
|
||||
await db.execute(
|
||||
"UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
|
||||
(topic_id,),
|
||||
)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/{topic_id}/summarize")
|
||||
async def summarize(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
|
||||
topic = await cur.fetchone()
|
||||
if not topic:
|
||||
raise HTTPException(404, "not found")
|
||||
topic_data = dict(topic)
|
||||
await _mark_stale_summarize_tasks(db, topic_data["group_id"], topic_id)
|
||||
# 创建 ai_tasks 记录以追踪进度
|
||||
await db.execute(
|
||||
"INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'summarize', 'running', ?)",
|
||||
(topic_data["group_id"], json.dumps({"processed": 0, "total": 1}))
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT last_insert_rowid() AS id") as cur:
|
||||
task_row = await cur.fetchone()
|
||||
task_id = task_row["id"]
|
||||
from services.summary_engine import run_summarize
|
||||
import asyncio
|
||||
asyncio.create_task(run_summarize(topic_id, topic_data, task_id))
|
||||
return {"ok": True, "task_id": task_id}
|
||||
13
chatlog_fastAPI/run_backend.py
Normal file
13
chatlog_fastAPI/run_backend.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from main import app
|
||||
|
||||
|
||||
def main():
|
||||
port = int(os.environ.get("CHATLAB_BACKEND_PORT", "8000"))
|
||||
uvicorn.run(app, host="127.0.0.1", port=port, reload=False, log_level="info")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
49
chatlog_fastAPI/scheduler.py
Normal file
49
chatlog_fastAPI/scheduler.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
APScheduler — 仅保留 wxid/数据库切换检测。
|
||||
(不再运行任何 AI 分类轮询:AI 分析改为用户手动按时间窗口触发)
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from database import update_db_path
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
|
||||
_sync_failures = 0
|
||||
|
||||
|
||||
def register_poll_job(group_id: int, poll_interval: int):
|
||||
"""已废弃。保留空函数避免其他模块旧引用炸。"""
|
||||
log.debug(f"[scheduler] register_poll_job called (no-op now): group={group_id}")
|
||||
|
||||
|
||||
def _reschedule_sync(seconds: int):
|
||||
if scheduler.get_job("sync_jobs"):
|
||||
scheduler.remove_job("sync_jobs")
|
||||
scheduler.add_job(_sync_jobs, "interval", seconds=seconds, id="sync_jobs")
|
||||
|
||||
|
||||
async def _sync_jobs():
|
||||
"""定期触发 wxid 重新检测,让账号切换能自动切换数据库。"""
|
||||
global _sync_failures
|
||||
try:
|
||||
await update_db_path()
|
||||
if _sync_failures > 0:
|
||||
_sync_failures = 0
|
||||
_reschedule_sync(10)
|
||||
except Exception as e:
|
||||
_sync_failures += 1
|
||||
log.error(f"[scheduler] sync error (consecutive={_sync_failures}): {e}")
|
||||
if _sync_failures == 3:
|
||||
_reschedule_sync(60)
|
||||
log.warning("[scheduler] sync backoff to 60s after 3 failures")
|
||||
|
||||
|
||||
async def start_scheduler():
|
||||
scheduler.add_job(_sync_jobs, "interval", seconds=10, id="sync_jobs")
|
||||
scheduler.start()
|
||||
# Do not block FastAPI startup on chatlog. Electron starts the backend
|
||||
# before chatlog, so the first account sync must happen in the background.
|
||||
scheduler.add_job(_sync_jobs, "date", id="sync_jobs_initial")
|
||||
log.info("[scheduler] started (db-path watcher only, no poll jobs)")
|
||||
0
chatlog_fastAPI/services/__init__.py
Normal file
0
chatlog_fastAPI/services/__init__.py
Normal file
31
chatlog_fastAPI/services/ai_client.py
Normal file
31
chatlog_fastAPI/services/ai_client.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from services.runtime_settings import get_ai_settings
|
||||
|
||||
_client_cache: dict[tuple[str, str], AsyncOpenAI] = {}
|
||||
_http_client_cache: dict[tuple[str, str], httpx.AsyncClient] = {}
|
||||
|
||||
|
||||
async def get_openai_client() -> tuple[AsyncOpenAI, dict]:
|
||||
settings = await get_ai_settings()
|
||||
cache_key = (
|
||||
settings.get("ai_base_url") or "",
|
||||
settings.get("ai_api_key") or "",
|
||||
)
|
||||
|
||||
if cache_key not in _client_cache:
|
||||
for http_client in _http_client_cache.values():
|
||||
await http_client.aclose()
|
||||
_client_cache.clear()
|
||||
_http_client_cache.clear()
|
||||
|
||||
http_client = httpx.AsyncClient(timeout=httpx.Timeout(600.0, connect=30.0))
|
||||
_http_client_cache[cache_key] = http_client
|
||||
_client_cache[cache_key] = AsyncOpenAI(
|
||||
api_key=settings.get("ai_api_key") or "missing",
|
||||
base_url=settings.get("ai_base_url"),
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
return _client_cache[cache_key], settings
|
||||
203
chatlog_fastAPI/services/chatlog_client.py
Normal file
203
chatlog_fastAPI/services/chatlog_client.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import httpx
|
||||
import asyncio
|
||||
from typing import List
|
||||
from config import settings
|
||||
|
||||
|
||||
class ChatlogHTTPError(RuntimeError):
|
||||
def __init__(self, status_code: int, method: str, path: str, detail: str):
|
||||
self.status_code = status_code
|
||||
self.method = method
|
||||
self.path = path
|
||||
self.detail = detail
|
||||
super().__init__(f"chatlog HTTP {status_code}: {method} {path} body={detail!r}")
|
||||
|
||||
|
||||
class MessageIndexNotReady(RuntimeError):
|
||||
"""Raised when chatlog has sessions but its message time index is not usable yet."""
|
||||
|
||||
|
||||
class ChatlogClient:
|
||||
def __init__(self):
|
||||
self.base = settings.chatlog_base_url
|
||||
self._contact_db_file = None
|
||||
|
||||
async def _get(self, path: str, params: dict, timeout: float = 30.0) -> dict:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout, trust_env=False) as client:
|
||||
r = await client.get(f"{self.base}{path}", params=params)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except httpx.TimeoutException:
|
||||
raise RuntimeError(f"chatlog timeout: GET {path}")
|
||||
except httpx.HTTPStatusError as e:
|
||||
detail = self._response_detail(e.response)
|
||||
raise ChatlogHTTPError(e.response.status_code, "GET", path, detail)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"chatlog request failed: {e}")
|
||||
|
||||
async def _post(self, path: str, body: dict, timeout: float = 30.0) -> dict:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout, trust_env=False) as client:
|
||||
r = await client.post(f"{self.base}{path}", json=body)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except httpx.TimeoutException:
|
||||
raise RuntimeError(f"chatlog timeout: POST {path}")
|
||||
except httpx.HTTPStatusError as e:
|
||||
detail = self._response_detail(e.response)
|
||||
raise ChatlogHTTPError(e.response.status_code, "POST", path, detail)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"chatlog request failed: {e}")
|
||||
|
||||
def _response_detail(self, response: httpx.Response) -> str:
|
||||
try:
|
||||
body = response.json()
|
||||
if isinstance(body, dict):
|
||||
return str(body.get("error") or body.get("detail") or body)
|
||||
return str(body)
|
||||
except Exception:
|
||||
return response.text
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
talker: str,
|
||||
time: str = "",
|
||||
sender: str = "",
|
||||
keyword: str = "",
|
||||
min_seq: int = 0,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> dict:
|
||||
params: dict = {
|
||||
"talker": talker,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"format": "json",
|
||||
}
|
||||
if time:
|
||||
params["time"] = time
|
||||
else:
|
||||
params["time"] = "1970-01-01,2099-12-31"
|
||||
if sender:
|
||||
params["sender"] = sender
|
||||
if keyword:
|
||||
params["keyword"] = keyword
|
||||
if min_seq > 0:
|
||||
params["min_seq"] = min_seq
|
||||
|
||||
try:
|
||||
data = await self._get("/api/v1/chatlog", params)
|
||||
except ChatlogHTTPError as e:
|
||||
detail = e.detail.lower()
|
||||
if e.status_code == 404 and "time range not found" in detail:
|
||||
await asyncio.sleep(0.2)
|
||||
try:
|
||||
data = await self._get("/api/v1/chatlog", params)
|
||||
except ChatlogHTTPError as retry_error:
|
||||
if (
|
||||
retry_error.status_code == 404
|
||||
and "time range not found" in retry_error.detail.lower()
|
||||
):
|
||||
raise MessageIndexNotReady(
|
||||
"自动解密仍在处理消息库,请稍后刷新聊天记录;如果长时间为空,请在微信里打开该聊天并翻看历史消息。"
|
||||
) from retry_error
|
||||
raise
|
||||
elif e.status_code == 404 and "not found" in detail:
|
||||
# chatlog sometimes reports a valid date window as missing while it is warming/querying.
|
||||
await asyncio.sleep(0.2)
|
||||
try:
|
||||
data = await self._get("/api/v1/chatlog", params)
|
||||
except ChatlogHTTPError as retry_error:
|
||||
retry_detail = retry_error.detail.lower()
|
||||
if (
|
||||
retry_error.status_code == 404
|
||||
and "time range not found" in retry_detail
|
||||
):
|
||||
raise MessageIndexNotReady(
|
||||
"自动解密仍在处理消息库,请稍后刷新聊天记录;如果长时间为空,请在微信里打开该聊天并翻看历史消息。"
|
||||
) from retry_error
|
||||
if retry_error.status_code == 404 and "not found" in retry_detail:
|
||||
return {"total": 0, "items": []}
|
||||
raise
|
||||
else:
|
||||
raise
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
return {"total": len(data), "items": data}
|
||||
|
||||
async def get_message(self, talker: str, seq: int) -> dict | None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0, trust_env=False) as client:
|
||||
r = await client.get(
|
||||
f"{self.base}/api/v1/chatlog/message",
|
||||
params={"talker": talker, "seq": seq},
|
||||
)
|
||||
if r.status_code == 404:
|
||||
return None
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
except httpx.TimeoutException:
|
||||
raise RuntimeError("chatlog timeout: get_message")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"chatlog request failed: {e}")
|
||||
|
||||
async def get_messages_batch(self, talker: str, seqs: List[int]) -> dict:
|
||||
return await self._post("/api/v1/chatlog/batch", {"talker": talker, "seqs": seqs})
|
||||
|
||||
async def get_chatrooms(self, keyword: str = "", limit: int = 100, offset: int = 0) -> dict:
|
||||
params: dict = {"limit": limit, "offset": offset, "format": "json"}
|
||||
if keyword:
|
||||
params["keyword"] = keyword
|
||||
return await self._get("/api/v1/chatroom", params, timeout=10.0)
|
||||
|
||||
async def get_contacts(self, keyword: str = "", limit: int = 100, offset: int = 0) -> dict:
|
||||
params: dict = {"limit": limit, "offset": offset, "format": "json"}
|
||||
if keyword:
|
||||
params["keyword"] = keyword
|
||||
return await self._get("/api/v1/contact", params, timeout=10.0)
|
||||
|
||||
async def get_chatroom_members(self, talker: str, time: str = "") -> dict:
|
||||
params: dict = {"talker": talker}
|
||||
if time:
|
||||
params["time"] = time
|
||||
return await self._get("/api/v1/chatroom/members", params)
|
||||
|
||||
async def get_sessions(self, keyword: str = "", limit: int = 500) -> list:
|
||||
params: dict = {"limit": limit, "format": "json"}
|
||||
if keyword:
|
||||
params["keyword"] = keyword
|
||||
data = await self._get("/api/v1/session", params, timeout=15.0)
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
return data.get("items", data.get("data", []))
|
||||
|
||||
|
||||
async def get_avatar_url(self, wxid: str) -> str:
|
||||
if self._contact_db_file is None:
|
||||
try:
|
||||
db_list = await self._get("/api/v1/db", {})
|
||||
self._contact_db_file = (db_list.get("contact") or [""])[0]
|
||||
except Exception:
|
||||
self._contact_db_file = ""
|
||||
if not self._contact_db_file:
|
||||
return ""
|
||||
safe_wxid = wxid.replace("'", "''")
|
||||
sql = f"SELECT small_head_url, big_head_url FROM contact WHERE username='{safe_wxid}' LIMIT 1"
|
||||
params = {"group": "contact", "file": self._contact_db_file, "sql": sql}
|
||||
try:
|
||||
rows = await self._get("/api/v1/db/query", params, timeout=5.0)
|
||||
if rows:
|
||||
url = rows[0].get("small_head_url") or rows[0].get("big_head_url") or ""
|
||||
if url:
|
||||
return url
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
async def get_db_paths(self) -> dict:
|
||||
data = await self._get("/api/v1/db", {}, timeout=10.0)
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
chatlog_client = ChatlogClient()
|
||||
35
chatlog_fastAPI/services/chatlog_context.py
Normal file
35
chatlog_fastAPI/services/chatlog_context.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatlogContext:
|
||||
account: str = ""
|
||||
work_dir: str = ""
|
||||
data_dir: str = ""
|
||||
platform: str = "windows"
|
||||
version: int = 4
|
||||
chatlog_exe: str = ""
|
||||
chatlog_version: str = ""
|
||||
|
||||
|
||||
_context = ChatlogContext()
|
||||
|
||||
|
||||
def update_chatlog_context(payload: dict) -> dict:
|
||||
global _context
|
||||
_context = ChatlogContext(
|
||||
account=str(payload.get("account") or ""),
|
||||
work_dir=str(payload.get("workDir") or payload.get("work_dir") or ""),
|
||||
data_dir=str(payload.get("dataDir") or payload.get("data_dir") or ""),
|
||||
platform=str(payload.get("platform") or "windows"),
|
||||
version=int(payload.get("version") or 4),
|
||||
chatlog_exe=str(payload.get("chatlogExe") or payload.get("chatlog_exe") or ""),
|
||||
chatlog_version=str(payload.get("chatlogVersion") or payload.get("chatlog_version") or ""),
|
||||
)
|
||||
return get_chatlog_context()
|
||||
|
||||
|
||||
def get_chatlog_context() -> dict:
|
||||
return asdict(_context)
|
||||
25
chatlog_fastAPI/services/fts.py
Normal file
25
chatlog_fastAPI/services/fts.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import jieba
|
||||
import re
|
||||
|
||||
def tokenize(text: str) -> str:
|
||||
return " ".join(jieba.cut(text))
|
||||
|
||||
|
||||
def build_match_query(text: str, limit: int = 12) -> str:
|
||||
"""Build a safe FTS5 MATCH query from user/model text."""
|
||||
terms: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for token in tokenize(text or "").split():
|
||||
token = token.strip()
|
||||
if not token or not re.search(r"\w", token, flags=re.UNICODE):
|
||||
continue
|
||||
upper = token.upper()
|
||||
if upper in {"AND", "OR", "NOT", "NEAR"}:
|
||||
continue
|
||||
if token in seen:
|
||||
continue
|
||||
seen.add(token)
|
||||
terms.append('"' + token.replace('"', '""') + '"')
|
||||
if len(terms) >= limit:
|
||||
break
|
||||
return " OR ".join(terms)
|
||||
142
chatlog_fastAPI/services/media_parser.py
Normal file
142
chatlog_fastAPI/services/media_parser.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import base64
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from services.ai_client import get_openai_client
|
||||
from services.media_resolver import resolve_media
|
||||
from services.runtime_settings import get_ai_settings
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_ai_client():
|
||||
return await get_openai_client()
|
||||
|
||||
|
||||
async def parse_media(kind: str, key: str) -> dict:
|
||||
"""
|
||||
Parse one chatlog media object into text.
|
||||
|
||||
kind: voice, image, or video.
|
||||
key: chatlog media key.
|
||||
"""
|
||||
if kind not in {"voice", "image", "video"}:
|
||||
raise HTTPException(400, "不支持的媒体类型")
|
||||
if not key:
|
||||
raise HTTPException(400, "媒体 key 不能为空")
|
||||
|
||||
ai = await get_ai_settings()
|
||||
if not ai.get("ai_api_key"):
|
||||
raise HTTPException(503, "AI 服务未配置,请在设置页填写 AI API Key")
|
||||
if kind == "voice" and not ai.get("voice_model"):
|
||||
raise HTTPException(503, "语音模型未配置,请在设置页填写语音模型名称,例如 paraformer-v2")
|
||||
if kind in ("image", "video") and not ai.get("vision_model"):
|
||||
raise HTTPException(503, "视觉模型未配置,请在设置页填写视觉模型名称,例如 qwen-vl-plus")
|
||||
|
||||
media = await resolve_media(kind, key)
|
||||
if kind == "voice":
|
||||
return {"text": await _parse_voice(media.bytes, media.content_type)}
|
||||
return {"text": await _parse_visual(kind, media.bytes, media.content_type)}
|
||||
|
||||
|
||||
async def _parse_voice(media_bytes: bytes, content_type: str) -> str:
|
||||
b64_audio = base64.b64encode(media_bytes).decode()
|
||||
audio_ct = content_type.lower()
|
||||
if "silk" in audio_ct or "x-silk" in audio_ct:
|
||||
audio_mime = "audio/silk"
|
||||
elif "amr" in audio_ct:
|
||||
audio_mime = "audio/amr"
|
||||
elif "ogg" in audio_ct or "opus" in audio_ct:
|
||||
audio_mime = "audio/ogg"
|
||||
elif "wav" in audio_ct:
|
||||
audio_mime = "audio/wav"
|
||||
else:
|
||||
audio_mime = "audio/mpeg"
|
||||
|
||||
data_uri = f"data:{audio_mime};base64,{b64_audio}"
|
||||
_, ai = await _get_ai_client()
|
||||
asr_headers = {
|
||||
"Authorization": f"Bearer {ai['ai_api_key']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=60) as http:
|
||||
submit = await http.post(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
|
||||
headers={**asr_headers, "X-DashScope-Async": "enable"},
|
||||
json={
|
||||
"model": ai["voice_model"],
|
||||
"input": {"file_urls": [data_uri]},
|
||||
"parameters": {"language_hints": ["zh", "en"]},
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
submit_data = submit.json()
|
||||
if submit.status_code not in (200, 201):
|
||||
raise HTTPException(500, f"提交识别任务失败: {submit_data.get('message', submit_data)}")
|
||||
|
||||
task_id = submit_data.get("output", {}).get("task_id")
|
||||
if not task_id:
|
||||
raise HTTPException(500, f"未获取到 task_id: {submit_data}")
|
||||
|
||||
for _ in range(30):
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(1)
|
||||
poll = await http.get(
|
||||
f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}",
|
||||
headers=asr_headers,
|
||||
timeout=10,
|
||||
)
|
||||
poll_data = poll.json()
|
||||
status = poll_data.get("output", {}).get("task_status", "")
|
||||
if status == "SUCCEEDED":
|
||||
results = poll_data.get("output", {}).get("results", [])
|
||||
log.info("[media_parser] ASR SUCCEEDED results: %s", results)
|
||||
if not results:
|
||||
return "(识别结果为空)"
|
||||
trans_url = results[0].get("transcription_url", "")
|
||||
if trans_url:
|
||||
trans_resp = await http.get(trans_url, timeout=10)
|
||||
trans_data = trans_resp.json()
|
||||
log.info("[media_parser] transcription_url content: %s", str(trans_data)[:500])
|
||||
transcripts = trans_data.get("transcripts", [])
|
||||
text = transcripts[0].get("text", "") if transcripts else ""
|
||||
else:
|
||||
text = results[0].get("transcription", "")
|
||||
return text or "(识别结果为空)"
|
||||
if status in ("FAILED", "CANCELLED"):
|
||||
raise HTTPException(500, f"识别任务失败: {poll_data.get('output', {}).get('message', status)}")
|
||||
|
||||
raise HTTPException(500, "语音识别超时(30秒)")
|
||||
|
||||
|
||||
async def _parse_visual(kind: str, media_bytes: bytes, content_type: str) -> str:
|
||||
b64 = base64.b64encode(media_bytes).decode()
|
||||
ct = content_type.lower()
|
||||
if "png" in ct:
|
||||
mime = "image/png"
|
||||
elif "webp" in ct:
|
||||
mime = "image/webp"
|
||||
else:
|
||||
mime = "image/jpeg"
|
||||
data_url = f"data:{mime};base64,{b64}"
|
||||
prompt = "请用中文简洁描述这张图片的内容。" if kind == "image" else "请用中文简洁描述这个视频截图的内容。"
|
||||
|
||||
client, ai = await _get_ai_client()
|
||||
resp_ai = await client.chat.completions.create(
|
||||
model=ai["vision_model"],
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": data_url}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=300,
|
||||
)
|
||||
return resp_ai.choices[0].message.content or ""
|
||||
174
chatlog_fastAPI/services/media_resolver.py
Normal file
174
chatlog_fastAPI/services/media_resolver.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from config import settings
|
||||
from services.chatlog_context import get_chatlog_context
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedMedia:
|
||||
bytes: bytes
|
||||
content_type: str
|
||||
url: str
|
||||
|
||||
|
||||
def _media_url(kind: str, key: str, thumb: bool = False) -> str:
|
||||
url = f"{settings.chatlog_base_url}/{kind}/{key}"
|
||||
if thumb:
|
||||
url += "?thumb=1"
|
||||
return url
|
||||
|
||||
|
||||
def _read_voice_resource_status(key: str) -> dict:
|
||||
ctx = get_chatlog_context()
|
||||
work_dir = ctx.get("work_dir") or ""
|
||||
if not work_dir:
|
||||
return {"checked": False, "reason": "missing_work_dir"}
|
||||
|
||||
db_path = Path(work_dir) / "db_storage" / "message" / "message_resource.db"
|
||||
if not db_path.exists():
|
||||
return {"checked": False, "reason": "message_resource_db_missing", "path": str(db_path)}
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(f"file:{db_path.as_posix()}?mode=ro", uri=True)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
info = conn.execute(
|
||||
"SELECT * FROM MessageResourceInfo WHERE message_svr_id=?",
|
||||
(int(key),),
|
||||
).fetchone()
|
||||
if not info:
|
||||
return {
|
||||
"checked": True,
|
||||
"found": False,
|
||||
"path": str(db_path),
|
||||
"message": "当前已解密资源库里没有这条语音的媒体资源记录",
|
||||
}
|
||||
details = conn.execute(
|
||||
"SELECT type,size,status,data_index FROM MessageResourceDetail WHERE message_id=?",
|
||||
(info["message_id"],),
|
||||
).fetchall()
|
||||
return {
|
||||
"checked": True,
|
||||
"found": True,
|
||||
"path": str(db_path),
|
||||
"message_id": info["message_id"],
|
||||
"resources": [dict(row) for row in details],
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
except Exception as exc:
|
||||
return {"checked": False, "reason": "resource_db_read_failed", "error": str(exc), "path": str(db_path)}
|
||||
|
||||
|
||||
def _download_failure_message(kind: str, key: str, status_code: int | None, body: str = "") -> str:
|
||||
if kind == "voice":
|
||||
base = "底层语音文件未读取成功"
|
||||
if status_code:
|
||||
base += f"(chatlog /voice 返回 HTTP {status_code})"
|
||||
return (
|
||||
f"{base}。请先确认已安装新版程序并重新识别当前微信账号;"
|
||||
"如果仍失败,说明当前 chatlog 版本还不能解析该 WeChat 4.x 语音资源。"
|
||||
)
|
||||
if status_code:
|
||||
return f"从 chatlog 下载媒体失败: HTTP {status_code}"
|
||||
return f"从 chatlog 下载媒体失败: {body or 'unknown error'}"
|
||||
|
||||
|
||||
async def diagnose_media(kind: str, key: str) -> dict:
|
||||
if kind not in {"voice", "image", "video"}:
|
||||
raise HTTPException(400, "不支持的媒体类型")
|
||||
if not key:
|
||||
raise HTTPException(400, "媒体 key 不能为空")
|
||||
|
||||
url = _media_url(kind, key, thumb=kind in {"image", "video"})
|
||||
result = {
|
||||
"ok": False,
|
||||
"kind": kind,
|
||||
"key": key,
|
||||
"url": url,
|
||||
"chatlog_base_url": settings.chatlog_base_url,
|
||||
"chatlog_context": get_chatlog_context(),
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=20, trust_env=False, follow_redirects=True) as client:
|
||||
try:
|
||||
resp = await client.get(url)
|
||||
content_type = resp.headers.get("content-type", "")
|
||||
result.update(
|
||||
{
|
||||
"status_code": resp.status_code,
|
||||
"content_type": content_type,
|
||||
"content_length": len(resp.content or b""),
|
||||
"ok": resp.status_code < 400 and bool(resp.content),
|
||||
}
|
||||
)
|
||||
if resp.status_code >= 400:
|
||||
result["error"] = _download_failure_message(kind, key, resp.status_code, resp.text[:500])
|
||||
result["response_preview"] = resp.text[:500]
|
||||
elif not resp.content:
|
||||
result["error"] = "chatlog 返回了空媒体文件"
|
||||
except Exception as exc:
|
||||
result.update({"error": f"无法连接 chatlog 媒体接口: {exc}", "exception": str(exc)})
|
||||
|
||||
if kind == "voice":
|
||||
result["resource_db"] = _read_voice_resource_status(key)
|
||||
return result
|
||||
|
||||
|
||||
async def resolve_media(kind: str, key: str) -> ResolvedMedia:
|
||||
if kind not in {"voice", "image", "video"}:
|
||||
raise HTTPException(400, "不支持的媒体类型")
|
||||
if not key:
|
||||
raise HTTPException(400, "媒体 key 不能为空")
|
||||
|
||||
url = _media_url(kind, key, thumb=kind in {"image", "video"})
|
||||
async with httpx.AsyncClient(timeout=60, trust_env=False, follow_redirects=True) as client:
|
||||
try:
|
||||
resp = await client.get(url)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
diagnostics = await diagnose_media(kind, key)
|
||||
log.warning("[media_resolver] media download failed: %s", diagnostics)
|
||||
raise HTTPException(
|
||||
502,
|
||||
{
|
||||
"message": _download_failure_message(kind, key, exc.response.status_code, exc.response.text[:500]),
|
||||
"diagnostics": diagnostics,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
diagnostics = await diagnose_media(kind, key)
|
||||
log.warning("[media_resolver] media download exception: %s", diagnostics)
|
||||
raise HTTPException(
|
||||
502,
|
||||
{
|
||||
"message": _download_failure_message(kind, key, None, str(exc)),
|
||||
"diagnostics": diagnostics,
|
||||
},
|
||||
)
|
||||
|
||||
if not resp.content:
|
||||
diagnostics = await diagnose_media(kind, key)
|
||||
raise HTTPException(
|
||||
502,
|
||||
{
|
||||
"message": "chatlog 返回了空媒体文件",
|
||||
"diagnostics": diagnostics,
|
||||
},
|
||||
)
|
||||
|
||||
return ResolvedMedia(
|
||||
bytes=resp.content,
|
||||
content_type=resp.headers.get("content-type", "application/octet-stream"),
|
||||
url=url,
|
||||
)
|
||||
253
chatlog_fastAPI/services/message_formatter.py
Normal file
253
chatlog_fastAPI/services/message_formatter.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import html
|
||||
import json
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Any
|
||||
|
||||
|
||||
QUOTE_CONTENT_LIMIT = 600
|
||||
|
||||
|
||||
def extract_contents(item: dict) -> dict:
|
||||
contents = item.get("contents") or item.get("Contents") or {}
|
||||
return contents if isinstance(contents, dict) else {}
|
||||
|
||||
|
||||
def clean_message_text(value: Any) -> str:
|
||||
text = html.unescape(str(value or "")).strip()
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
if len(text) > QUOTE_CONTENT_LIMIT:
|
||||
text = text[:QUOTE_CONTENT_LIMIT] + "..."
|
||||
return text
|
||||
|
||||
|
||||
def _local_name(tag: str) -> str:
|
||||
return tag.rsplit("}", 1)[-1]
|
||||
|
||||
|
||||
def _safe_int(value: Any) -> int | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
try:
|
||||
return int(str(value).strip())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _first(data: dict, *keys: str) -> Any:
|
||||
for key in keys:
|
||||
value = data.get(key)
|
||||
if value not in (None, ""):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _has_quote_indicator(data: dict) -> bool:
|
||||
keys = {str(key) for key in data.keys()}
|
||||
indicators = {
|
||||
"quote",
|
||||
"refermsg",
|
||||
"referMsg",
|
||||
"refer",
|
||||
"recordInfo",
|
||||
"recordinfo",
|
||||
"fromusr",
|
||||
"fromUser",
|
||||
"chatusr",
|
||||
"chatUser",
|
||||
"displayname",
|
||||
"displayName",
|
||||
"referContent",
|
||||
"svrid",
|
||||
"newmsgid",
|
||||
"newMsgId",
|
||||
}
|
||||
return bool(keys & indicators)
|
||||
|
||||
|
||||
def _decode_json(value: str) -> Any:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _xml_node_text(node: ET.Element, names: set[str]) -> str:
|
||||
for child in node.iter():
|
||||
if _local_name(child.tag) in names:
|
||||
text = "".join(child.itertext()).strip()
|
||||
if text:
|
||||
return text
|
||||
return ""
|
||||
|
||||
|
||||
def _quote_from_xml(value: str) -> dict | None:
|
||||
text = html.unescape(value or "").strip()
|
||||
if "<" not in text or ">" not in text:
|
||||
return None
|
||||
try:
|
||||
root = ET.fromstring(text)
|
||||
except Exception:
|
||||
try:
|
||||
root = ET.fromstring(f"<root>{text}</root>")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
refer_node = None
|
||||
for node in root.iter():
|
||||
if _local_name(node.tag).lower() == "refermsg":
|
||||
refer_node = node
|
||||
break
|
||||
if refer_node is None:
|
||||
return None
|
||||
|
||||
content = _xml_node_text(refer_node, {"content", "title", "desc"})
|
||||
sender_name = _xml_node_text(refer_node, {"displayname", "nickname", "fromnickname"})
|
||||
sender = _xml_node_text(refer_node, {"fromusr", "chatusr", "sender"})
|
||||
msg_type = _safe_int(_xml_node_text(refer_node, {"type"}))
|
||||
seq = _safe_int(_xml_node_text(refer_node, {"seq", "msgid", "newmsgid", "svrid"}))
|
||||
|
||||
return _normalize_quote(
|
||||
{
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"content": content,
|
||||
"type": msg_type,
|
||||
"seq": seq,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _find_quote_payload(value: Any, allow_plain_text: bool = False) -> dict | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return None
|
||||
decoded = _decode_json(text) if text[:1] in ("{", "[") else None
|
||||
if decoded is not None:
|
||||
return _find_quote_payload(decoded, allow_plain_text=allow_plain_text)
|
||||
xml_quote = _quote_from_xml(text)
|
||||
if xml_quote:
|
||||
return xml_quote
|
||||
if allow_plain_text:
|
||||
return _normalize_quote({"content": text})
|
||||
return None
|
||||
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
quote = _find_quote_payload(item, allow_plain_text=allow_plain_text)
|
||||
if quote:
|
||||
return quote
|
||||
return None
|
||||
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
|
||||
for key in ("quote", "refermsg", "referMsg", "refer", "recordInfo", "recordinfo"):
|
||||
if key in value:
|
||||
quote = _find_quote_payload(value.get(key), allow_plain_text=True)
|
||||
if quote:
|
||||
return quote
|
||||
|
||||
quote = _normalize_quote(value) if allow_plain_text or _has_quote_indicator(value) else None
|
||||
if quote:
|
||||
return quote
|
||||
|
||||
for nested in value.values():
|
||||
quote = _find_quote_payload(nested, allow_plain_text=False)
|
||||
if quote:
|
||||
return quote
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_quote(data: dict) -> dict | None:
|
||||
content = clean_message_text(
|
||||
_first(
|
||||
data,
|
||||
"content",
|
||||
"Content",
|
||||
"text",
|
||||
"title",
|
||||
"desc",
|
||||
"digest",
|
||||
"displayContent",
|
||||
"referContent",
|
||||
)
|
||||
)
|
||||
if not content:
|
||||
return None
|
||||
|
||||
sender = clean_message_text(
|
||||
_first(data, "sender", "Sender", "fromusr", "fromUser", "chatusr", "chatUser", "from")
|
||||
)
|
||||
sender_name = clean_message_text(
|
||||
_first(data, "sender_name", "senderName", "SenderName", "displayname", "displayName", "nickname", "nickName")
|
||||
)
|
||||
msg_type = _safe_int(_first(data, "type", "Type", "msgType", "subType"))
|
||||
seq = _safe_int(_first(data, "seq", "Seq", "sort_seq", "msgid", "msgId", "newmsgid", "newMsgId", "svrid"))
|
||||
|
||||
return {
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"content": content,
|
||||
"type": msg_type,
|
||||
"seq": seq,
|
||||
}
|
||||
|
||||
|
||||
def extract_quote(item: dict | None) -> dict | None:
|
||||
if not isinstance(item, dict):
|
||||
return None
|
||||
|
||||
contents = extract_contents(item)
|
||||
explicit_sources = (
|
||||
item.get("quote"),
|
||||
item.get("Quote"),
|
||||
item.get("refer"),
|
||||
item.get("recordInfo"),
|
||||
contents.get("quote"),
|
||||
contents.get("refer"),
|
||||
contents.get("refermsg"),
|
||||
contents.get("referMsg"),
|
||||
contents.get("recordInfo"),
|
||||
contents.get("recordinfo"),
|
||||
)
|
||||
for source in explicit_sources:
|
||||
quote = _find_quote_payload(source, allow_plain_text=True)
|
||||
if quote:
|
||||
return quote
|
||||
|
||||
for source in (
|
||||
contents.get("appmsg"),
|
||||
item.get("content"),
|
||||
item.get("Content"),
|
||||
):
|
||||
quote = _find_quote_payload(source, allow_plain_text=False)
|
||||
if quote:
|
||||
return quote
|
||||
return None
|
||||
|
||||
|
||||
def attach_quote(item: dict) -> dict:
|
||||
item["quote"] = extract_quote(item)
|
||||
return item
|
||||
|
||||
|
||||
def quote_to_text(quote: dict | None) -> str:
|
||||
if not quote:
|
||||
return ""
|
||||
sender = quote.get("sender_name") or quote.get("sender") or "未知"
|
||||
seq = quote.get("seq")
|
||||
seq_text = f" seq={seq}" if seq else ""
|
||||
return f"[引用消息{seq_text}] {sender}: {quote.get('content') or ''}".strip()
|
||||
|
||||
|
||||
def append_quote_text(base_text: str, item: dict) -> str:
|
||||
parts = [base_text.strip()] if base_text and base_text.strip() else []
|
||||
quote_text = quote_to_text(extract_quote(item))
|
||||
if quote_text:
|
||||
parts.append(quote_text)
|
||||
return ";".join(parts)
|
||||
139
chatlog_fastAPI/services/report_learning.py
Normal file
139
chatlog_fastAPI/services/report_learning.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import re
|
||||
import aiosqlite
|
||||
|
||||
from services.fts import build_match_query
|
||||
|
||||
MAX_EXAMPLES = 3
|
||||
MAX_EXAMPLE_CHARS = 1800
|
||||
MAX_CONTEXT_CHARS = 5200
|
||||
|
||||
|
||||
def _compact(text: str, limit: int = MAX_EXAMPLE_CHARS) -> str:
|
||||
text = re.sub(r"\n{3,}", "\n\n", (text or "").strip())
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[:limit].rstrip() + "\n..."
|
||||
|
||||
|
||||
def _format_examples(rows: list[aiosqlite.Row], purpose: str) -> str:
|
||||
if not rows:
|
||||
return ""
|
||||
heading = {
|
||||
"topic": "历史人工修订报告参考(用于学习话题命名和分类口径)",
|
||||
"summary": "历史人工修订报告参考(只学习结构、措辞和关注点,不得照抄历史事实)",
|
||||
}.get(purpose, "历史人工修订报告参考")
|
||||
parts = [heading]
|
||||
total = len(parts[0])
|
||||
for idx, row in enumerate(rows, 1):
|
||||
block = (
|
||||
f"\n\n--- 示例 {idx} ---\n"
|
||||
f"群聊:{row['group_name'] or row['talker'] or row['group_id']}\n"
|
||||
f"话题标题:{row['title']}\n"
|
||||
f"报告内容:\n{_compact(row['content'])}"
|
||||
)
|
||||
if total + len(block) > MAX_CONTEXT_CHARS:
|
||||
break
|
||||
parts.append(block)
|
||||
total += len(block)
|
||||
return "".join(parts).strip()
|
||||
|
||||
|
||||
async def build_report_learning_context(
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
group_id: int | None,
|
||||
query: str = "",
|
||||
exclude_topic_id: int | None = None,
|
||||
purpose: str = "summary",
|
||||
limit: int = MAX_EXAMPLES,
|
||||
) -> str:
|
||||
params: list[object] = []
|
||||
exclude_sql = ""
|
||||
if exclude_topic_id is not None:
|
||||
exclude_sql = " AND t.id<>?"
|
||||
params.append(exclude_topic_id)
|
||||
|
||||
selected: list[aiosqlite.Row] = []
|
||||
seen_doc_ids: set[int] = set()
|
||||
|
||||
if group_id is not None:
|
||||
async with db.execute(
|
||||
f"""
|
||||
SELECT k.id, k.content, k.updated_at, t.id AS topic_id, t.title, t.group_id,
|
||||
g.name AS group_name, g.talker
|
||||
FROM knowledge_docs k
|
||||
JOIN topics t ON t.id = k.topic_id
|
||||
LEFT JOIN groups g ON g.id = t.group_id
|
||||
WHERE k.curated_at IS NOT NULL
|
||||
AND t.group_id=?
|
||||
{exclude_sql}
|
||||
ORDER BY k.curated_at DESC, k.updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
[group_id, *params, limit],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
for row in rows:
|
||||
selected.append(row)
|
||||
seen_doc_ids.add(int(row["id"]))
|
||||
|
||||
if len(selected) < limit:
|
||||
remaining = limit - len(selected)
|
||||
fts_query = build_match_query(query or "")
|
||||
if fts_query:
|
||||
async with db.execute(
|
||||
f"""
|
||||
SELECT k.id, k.content, k.updated_at, t.id AS topic_id, t.title, t.group_id,
|
||||
g.name AS group_name, g.talker
|
||||
FROM knowledge_docs k
|
||||
JOIN topics t ON t.id = k.topic_id
|
||||
LEFT JOIN groups g ON g.id = t.group_id
|
||||
WHERE k.curated_at IS NOT NULL
|
||||
AND k.id IN (SELECT doc_id FROM knowledge_fts WHERE knowledge_fts MATCH ?)
|
||||
{exclude_sql}
|
||||
ORDER BY CASE WHEN t.group_id=? THEN 0 ELSE 1 END,
|
||||
k.curated_at DESC,
|
||||
k.updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
[fts_query, *params, group_id or -1, remaining * 3],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
for row in rows:
|
||||
doc_id = int(row["id"])
|
||||
if doc_id in seen_doc_ids:
|
||||
continue
|
||||
selected.append(row)
|
||||
seen_doc_ids.add(doc_id)
|
||||
if len(selected) >= limit:
|
||||
break
|
||||
|
||||
if len(selected) < limit:
|
||||
remaining = limit - len(selected)
|
||||
async with db.execute(
|
||||
f"""
|
||||
SELECT k.id, k.content, k.updated_at, t.id AS topic_id, t.title, t.group_id,
|
||||
g.name AS group_name, g.talker
|
||||
FROM knowledge_docs k
|
||||
JOIN topics t ON t.id = k.topic_id
|
||||
LEFT JOIN groups g ON g.id = t.group_id
|
||||
WHERE k.curated_at IS NOT NULL
|
||||
{exclude_sql}
|
||||
ORDER BY CASE WHEN t.group_id=? THEN 0 ELSE 1 END,
|
||||
k.curated_at DESC,
|
||||
k.updated_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
[*params, group_id or -1, remaining * 3],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
for row in rows:
|
||||
doc_id = int(row["id"])
|
||||
if doc_id in seen_doc_ids:
|
||||
continue
|
||||
selected.append(row)
|
||||
seen_doc_ids.add(doc_id)
|
||||
if len(selected) >= limit:
|
||||
break
|
||||
|
||||
return _format_examples(selected[:limit], purpose)
|
||||
45
chatlog_fastAPI/services/runtime_settings.py
Normal file
45
chatlog_fastAPI/services/runtime_settings.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
import aiosqlite
|
||||
from config import settings as default_settings
|
||||
from database import get_active_db_path
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_cache: dict | None = None
|
||||
|
||||
|
||||
def invalidate_cache():
|
||||
global _cache
|
||||
_cache = None
|
||||
|
||||
|
||||
async def get_ai_settings() -> dict:
|
||||
global _cache
|
||||
if _cache is not None:
|
||||
return _cache
|
||||
|
||||
# ai_base_url 保留默认值(阿里云兼容 OpenAI 格式地址),其余字段必须由用户在设置页配置
|
||||
result = {
|
||||
"ai_base_url": default_settings.ai_base_url,
|
||||
"ai_api_key": "",
|
||||
"ai_model": "",
|
||||
"summary_model": "",
|
||||
"vision_model": "",
|
||||
"voice_model": "",
|
||||
"topic_analysis_prompt": "",
|
||||
}
|
||||
|
||||
try:
|
||||
path = get_active_db_path()
|
||||
async with aiosqlite.connect(path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute("SELECT key, value FROM app_settings") as cur:
|
||||
rows = await cur.fetchall()
|
||||
for row in rows:
|
||||
if row["key"] in result and row["value"]:
|
||||
result[row["key"]] = row["value"]
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to read runtime settings: {e}")
|
||||
|
||||
_cache = result
|
||||
return result
|
||||
476
chatlog_fastAPI/services/summary_engine.py
Normal file
476
chatlog_fastAPI/services/summary_engine.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
售后报告生成引擎
|
||||
- 从 topic_messages 拿到所有 msg_seq
|
||||
- 通过 chatlog batch 接口批量拉回消息原文
|
||||
- 用配置的总结模型生成 Markdown 售后事件报告
|
||||
- 写入 knowledge_docs + knowledge_fts(jieba 分词)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import aiosqlite
|
||||
from urllib.parse import quote
|
||||
|
||||
from database import get_active_db_path
|
||||
from services.ai_client import get_openai_client
|
||||
from services.fts import tokenize
|
||||
from services.message_formatter import append_quote_text, extract_contents, extract_quote
|
||||
from services.report_learning import build_report_learning_context
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
CHATLOG_BATCH_SIZE = 80
|
||||
SUMMARY_LLM_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
async def _get_client():
|
||||
return await get_openai_client()
|
||||
|
||||
|
||||
def _message_line(item: dict, fallback_seq: int = 0) -> tuple[int, str] | None:
|
||||
if not item:
|
||||
return None
|
||||
seq = item.get("seq") or item.get("Seq") or item.get("sort_seq") or fallback_seq or 0
|
||||
time_str = item.get("create_time") or item.get("time") or item.get("CreateTime") or ""
|
||||
sender = (
|
||||
item.get("sender_name")
|
||||
or item.get("senderName")
|
||||
or item.get("SenderName")
|
||||
or item.get("sender")
|
||||
or item.get("Sender")
|
||||
or ""
|
||||
)
|
||||
content = _message_text(item)
|
||||
if not content:
|
||||
return None
|
||||
return int(seq), f"[{time_str}] {sender}: {content}"
|
||||
|
||||
|
||||
def _message_meta(item: dict, fallback_seq: int = 0) -> dict:
|
||||
return {
|
||||
"seq": int(item.get("seq") or item.get("Seq") or item.get("sort_seq") or fallback_seq or 0),
|
||||
"time": item.get("create_time") or item.get("time") or item.get("CreateTime") or "",
|
||||
"sender": (
|
||||
item.get("sender_name")
|
||||
or item.get("senderName")
|
||||
or item.get("SenderName")
|
||||
or item.get("sender")
|
||||
or item.get("Sender")
|
||||
or ""
|
||||
),
|
||||
"type": item.get("type") or item.get("Type") or 1,
|
||||
}
|
||||
|
||||
|
||||
def _extract_contents(item: dict) -> dict:
|
||||
return extract_contents(item)
|
||||
|
||||
|
||||
def _message_text(item: dict) -> str:
|
||||
content = item.get("content") or item.get("Content") or ""
|
||||
contents = _extract_contents(item)
|
||||
if isinstance(content, str) and content.lstrip().startswith("<") and extract_quote(item):
|
||||
content = ""
|
||||
|
||||
link_title = contents.get("title") or item.get("link_title") or ""
|
||||
link_desc = contents.get("desc") or item.get("link_desc") or ""
|
||||
link_source = contents.get("sourceName") or contents.get("source_name") or item.get("link_source") or ""
|
||||
link_url = contents.get("url") or item.get("link_url") or ""
|
||||
|
||||
if link_title:
|
||||
parts = [f"[链接卡片] {link_title}"]
|
||||
if link_desc:
|
||||
parts.append(link_desc)
|
||||
if link_source:
|
||||
parts.append(f"来源:{link_source}")
|
||||
if link_url:
|
||||
parts.append(f"URL:{link_url}")
|
||||
if content and content not in parts:
|
||||
parts.append(content)
|
||||
return append_quote_text(";".join(parts), item)
|
||||
|
||||
return append_quote_text(content, item)
|
||||
|
||||
|
||||
def _extract_image_key(item: dict) -> str:
|
||||
contents = _extract_contents(item)
|
||||
key = (
|
||||
contents.get("rawmd5")
|
||||
or contents.get("md5")
|
||||
or contents.get("path")
|
||||
or item.get("media_key")
|
||||
or item.get("mediaKey")
|
||||
or item.get("image_path")
|
||||
or ""
|
||||
)
|
||||
return str(key).replace("\\", "/")
|
||||
|
||||
|
||||
def _is_image_message(item: dict) -> bool:
|
||||
try:
|
||||
return int(item.get("type") or item.get("Type") or 0) == 3
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _media_path(kind: str, key: str) -> str:
|
||||
return f"/{kind}/" + "/".join(quote(part) for part in key.split("/"))
|
||||
|
||||
|
||||
def _image_url(key: str) -> str:
|
||||
return f"{_media_path('image', key)}?thumb=1"
|
||||
|
||||
|
||||
def _collect_image_evidence(messages: list[dict]) -> tuple[list[dict], list[dict]]:
|
||||
images: list[dict] = []
|
||||
failures: list[dict] = []
|
||||
|
||||
for item in messages:
|
||||
if not _is_image_message(item):
|
||||
continue
|
||||
meta = _message_meta(item)
|
||||
key = _extract_image_key(item)
|
||||
if not key:
|
||||
failures.append({**meta, "url": "", "reason": "图片无法展示,缺少图片文件标识"})
|
||||
continue
|
||||
|
||||
url = _image_url(key)
|
||||
images.append({**meta, "key": key, "url": url})
|
||||
|
||||
return images, failures
|
||||
|
||||
|
||||
def _image_evidence_context(images: list[dict], failures: list[dict]) -> str:
|
||||
lines: list[str] = []
|
||||
if images:
|
||||
lines.append("系统将作为原始材料插入报告的现场图片:")
|
||||
for img in images:
|
||||
lines.append(f"- [{img['time']}] {img['sender']} seq={img['seq']} url={img['url']}")
|
||||
if failures:
|
||||
lines.append("无法展示的图片清单:")
|
||||
for img in failures:
|
||||
link = f",查看图片:{img['url']}" if img.get("url") else ""
|
||||
lines.append(f"- [{img['time']}] {img['sender']} seq={img['seq']}:{img['reason']}{link}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _image_success_markdown(images: list[dict]) -> str:
|
||||
if not images:
|
||||
return ""
|
||||
blocks = ["### 现场图片"]
|
||||
for img in images:
|
||||
alt = f"现场图片 - {img['time']} {img['sender']}".strip()
|
||||
blocks.extend(
|
||||
[
|
||||
f"",
|
||||
f"来源:{img['time']} {img['sender']} seq={img['seq']}",
|
||||
"",
|
||||
]
|
||||
)
|
||||
return "\n".join(blocks).strip()
|
||||
|
||||
|
||||
def _image_failure_markdown(failures: list[dict]) -> str:
|
||||
if not failures:
|
||||
return ""
|
||||
lines = ["## 图片展示提示"]
|
||||
for img in failures:
|
||||
link = f",查看图片:{img['url']}" if img.get("url") else ""
|
||||
lines.append(f"- [{img['time']}] {img['sender']} seq={img['seq']}:{img['reason']}{link}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _insert_after_heading(content: str, heading: str, addition: str) -> str:
|
||||
if not addition:
|
||||
return content
|
||||
lines = content.splitlines()
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip() == heading:
|
||||
return "\n".join(lines[: i + 1] + ["", addition, ""] + lines[i + 1 :]).strip()
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("# "):
|
||||
return "\n".join(lines[: i + 1] + ["", heading, "", addition, ""] + lines[i + 1 :]).strip()
|
||||
return f"{heading}\n\n{addition}\n\n{content}".strip()
|
||||
|
||||
|
||||
def _merge_image_sections(content: str, successes: list[dict], failures: list[dict]) -> str:
|
||||
result = _insert_after_heading(content, "## 关键聊天依据", _image_success_markdown(successes))
|
||||
failure_md = _image_failure_markdown(failures)
|
||||
if failure_md:
|
||||
result = f"{result.rstrip()}\n\n{failure_md}"
|
||||
return result.strip()
|
||||
|
||||
|
||||
def _line_from_snapshot(raw: str | None, fallback_seq: int) -> str | None:
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
item = json.loads(raw)
|
||||
except Exception:
|
||||
return None
|
||||
line = _message_line(item, fallback_seq)
|
||||
return line[1] if line else None
|
||||
|
||||
MARKDOWN_TEMPLATE = """\
|
||||
# {title}
|
||||
|
||||
请按聊天记录中的实际内容生成一份【具体售后问题点】报告,不要照抄固定字段,也不要输出占位文案。
|
||||
|
||||
必须围绕以下结构组织,按内容决定是否保留章节,不要输出空章节:
|
||||
## 问题摘要
|
||||
## 关键聊天依据
|
||||
## 当前处理状态
|
||||
## 是否解决
|
||||
## AI 建议/解决方法
|
||||
|
||||
输出规则:
|
||||
- 只写聊天记录中能直接识别或合理归纳的信息。
|
||||
- 没有识别到的客户、门店、联系人、合同、订单、物流、日期、价格、原因等信息直接省略。
|
||||
- 不要写“未从聊天记录中识别”“待补充”“未知”“无”等占位内容。
|
||||
- “是否解决”只能从聊天记录判断,取值限定为:已解决、未解决、处理中、待确认。
|
||||
- 如果聊天内容不足以形成明确售后问题点,仍然按当前话题内容整理,但用更保守的“待确认”结论。
|
||||
- “AI 建议/解决方法”必须放在文档下方,并附注:注:此方法由 AI 生成,仅供参考,请以人工复核和现场实际情况为准。
|
||||
- 只输出 Markdown 报告,不要输出这些规则本身。
|
||||
"""
|
||||
|
||||
|
||||
async def _mark_summarize_failed(topic_id: int, task_id: int | None, error: str):
|
||||
path = get_active_db_path()
|
||||
message = error or "AI 报告生成失败"
|
||||
try:
|
||||
async with aiosqlite.connect(path) as db:
|
||||
await db.execute(
|
||||
"UPDATE topics SET status = 'error', updated_at = CURRENT_TIMESTAMP WHERE id = ?",
|
||||
(topic_id,),
|
||||
)
|
||||
if task_id is not None:
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE ai_tasks
|
||||
SET status='error', progress=?, error=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?
|
||||
""",
|
||||
(json.dumps({"processed": 0, "total": 1}), message, task_id),
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
log.warning(f"[summarize] 标记失败状态失败 topic={topic_id} task={task_id}: {exc}")
|
||||
|
||||
|
||||
async def _run_summarize_impl(topic_id: int, topic: dict, task_id: int | None = None):
|
||||
"""
|
||||
为指定话题生成/更新 Markdown 售后事件报告。
|
||||
由 POST /api/topics/{id}/summarize(手动触发)调用。
|
||||
task_id: 若提供,则更新 ai_tasks 表的状态和进度。
|
||||
"""
|
||||
path = get_active_db_path()
|
||||
|
||||
async def _update_task(status: str, processed: int = 0, total: int = 1, error: str = ""):
|
||||
"""辅助函数:更新 ai_tasks 状态和进度"""
|
||||
if task_id is None:
|
||||
return
|
||||
try:
|
||||
async with aiosqlite.connect(path) as _db:
|
||||
_db.row_factory = aiosqlite.Row
|
||||
await _db.execute(
|
||||
"""
|
||||
UPDATE ai_tasks
|
||||
SET status=?, progress=?, error=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?
|
||||
""",
|
||||
(status, json.dumps({"processed": processed, "total": total}), error or None, task_id)
|
||||
)
|
||||
await _db.commit()
|
||||
except Exception as e:
|
||||
log.warning(f"[summarize] 更新 task {task_id} 失败: {e}")
|
||||
path = get_active_db_path()
|
||||
async with aiosqlite.connect(path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
# 将话题状态置为 processing
|
||||
await db.execute("UPDATE topics SET status = 'processing', updated_at = CURRENT_TIMESTAMP WHERE id = ?", (topic_id,))
|
||||
await db.commit()
|
||||
await _update_task("running", 0, 1)
|
||||
|
||||
# 1. 拿到该话题的所有消息 seq 和群 talker
|
||||
async with db.execute(
|
||||
"""
|
||||
SELECT tm.msg_seq, tm.talker, tm.message_json
|
||||
FROM topic_messages tm
|
||||
WHERE tm.topic_id = ?
|
||||
ORDER BY tm.msg_seq
|
||||
""",
|
||||
(topic_id,),
|
||||
) as cur:
|
||||
msg_rows = await cur.fetchall()
|
||||
|
||||
if not msg_rows:
|
||||
log.warning(f"[summarize] topic={topic_id} 没有消息,跳过")
|
||||
error = "该话题没有关联消息,无法生成 AI 报告"
|
||||
await db.execute("UPDATE topics SET status = 'error', updated_at = CURRENT_TIMESTAMP WHERE id = ?", (topic_id,))
|
||||
await db.commit()
|
||||
await _update_task("error", 0, 1, error)
|
||||
return
|
||||
|
||||
seqs = [r["msg_seq"] for r in msg_rows]
|
||||
# talker 在 topic_messages 里存的是群 ID(chatlog 叫 talker)
|
||||
group_talker = msg_rows[0]["talker"]
|
||||
|
||||
# 2. 批量从 chatlog 拉取消息原文(最多 100 条/批)
|
||||
from services.chatlog_client import chatlog_client
|
||||
messages_text: list[str] = []
|
||||
message_items: dict[int, dict] = {}
|
||||
|
||||
fetched_lines: dict[int, str] = {}
|
||||
for i in range(0, len(seqs), CHATLOG_BATCH_SIZE):
|
||||
chunk_seqs = seqs[i: i + CHATLOG_BATCH_SIZE]
|
||||
try:
|
||||
result = await chatlog_client.get_messages_batch(group_talker, chunk_seqs)
|
||||
for m in result.get("items", []):
|
||||
meta = _message_meta(m)
|
||||
if meta["seq"]:
|
||||
message_items[meta["seq"]] = m
|
||||
line = _message_line(m)
|
||||
if line:
|
||||
fetched_lines[line[0]] = line[1]
|
||||
except Exception as e:
|
||||
log.error(f"[summarize] batch 拉取失败 topic={topic_id}: {e}")
|
||||
|
||||
for r in msg_rows:
|
||||
seq = int(r["msg_seq"])
|
||||
if seq in fetched_lines:
|
||||
messages_text.append(fetched_lines[seq])
|
||||
continue
|
||||
snap_raw = r["message_json"] if "message_json" in r.keys() else None
|
||||
if seq not in message_items and snap_raw:
|
||||
try:
|
||||
snap_item = json.loads(snap_raw)
|
||||
if isinstance(snap_item, dict):
|
||||
message_items[seq] = snap_item
|
||||
except Exception:
|
||||
pass
|
||||
snap_line = _line_from_snapshot(snap_raw, seq)
|
||||
if snap_line:
|
||||
messages_text.append(snap_line)
|
||||
|
||||
image_successes, image_failures = _collect_image_evidence(
|
||||
[message_items[seq] for seq in seqs if seq in message_items]
|
||||
)
|
||||
|
||||
if not messages_text and not image_successes and not image_failures:
|
||||
log.warning(f"[summarize] topic={topic_id} 从 chatlog 获取到 0 条有效消息")
|
||||
error = "未能从 chatlog 获取到有效消息,无法生成 AI 报告"
|
||||
await db.execute("UPDATE topics SET status = 'error', updated_at = CURRENT_TIMESTAMP WHERE id = ?", (topic_id,))
|
||||
await db.commit()
|
||||
await _update_task("error", 0, 1, error)
|
||||
return
|
||||
|
||||
chat_text = "\n".join(messages_text) if messages_text else "无文字消息,仅有图片或媒体证据。"
|
||||
image_context = _image_evidence_context(image_successes, image_failures)
|
||||
learning_context = await build_report_learning_context(
|
||||
db,
|
||||
group_id=topic.get("group_id"),
|
||||
query=f"{topic.get('title', '')}\n{chat_text[:2000]}",
|
||||
exclude_topic_id=topic_id,
|
||||
purpose="summary",
|
||||
)
|
||||
|
||||
# 3. 构建 Prompt
|
||||
template_filled = MARKDOWN_TEMPLATE.format(title=topic["title"])
|
||||
prompt = (
|
||||
f"售后问题点话题:{topic['title']}\n\n"
|
||||
f"以下是该售后问题点关联的完整微信群聊天记录(按时间顺序):\n\n"
|
||||
f"{chat_text}\n\n"
|
||||
f"以下是系统将插入报告的现场图片信息(如有):\n\n{image_context or '无现场图片。'}\n\n"
|
||||
"请根据上述聊天记录输出一份 Markdown 报告。\n"
|
||||
"报告要求:\n"
|
||||
"1. 保持售后问题点口径,优先提炼问题现象、涉及产品/部件、现场材料、处理过程和处理结果。\n"
|
||||
"2. 只能使用聊天记录中能直接识别或合理归纳的信息,不要编造客户、合同、订单、物流、日期、价格、原因或处理结果。\n"
|
||||
"3. 不要输出空字段、空项目、空章节、空表格;某个章节没有有效内容时整段省略。\n"
|
||||
"4. 「是否解决」必须写在文档中,并使用:已解决 / 未解决 / 处理中 / 待确认。\n"
|
||||
"5. 「AI 建议/解决方法」必须写在文档中,且在段末附上固定注释:注:此方法由 AI 生成,仅供参考,请以人工复核和现场实际情况为准。\n"
|
||||
"6. 如果聊天内容不足以形成明确售后问题点,也不要编造结论;只按聊天中已有事实给出保守的待确认判断。\n"
|
||||
"7. 图片会由系统作为「现场图片」原始材料插入「关键聊天依据」;你不要猜测图片内容,也不要自行输出图片 Markdown 或图片说明。\n"
|
||||
"8. 如果聊天文字中有人描述图片内容,可以引用这些文字;但不要根据图片本身编造故障细节。\n"
|
||||
"9. 聊天记录中的「[引用消息]」属于当前回复的上下文证据,可以用于理解被回复的问题和处理过程。\n"
|
||||
"10. 只输出 Markdown 报告,不要输出模板说明或额外解释。\n\n"
|
||||
f"以下是本企业报告库中人工修订过的历史报告示例(如有)。请只学习它们的栏目结构、措辞风格、问题关注点和结论表达方式;不得复制历史事实、客户名、设备状态或处理结果到当前报告:\n\n{learning_context or '暂无可学习的人工修订报告。'}\n\n"
|
||||
f"{template_filled}"
|
||||
)
|
||||
|
||||
# 4. 调用 LLM
|
||||
try:
|
||||
_client, _ai = await _get_client()
|
||||
async with asyncio.timeout(SUMMARY_LLM_TIMEOUT_SECONDS):
|
||||
resp = await _client.chat.completions.create(
|
||||
model=_ai["summary_model"],
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"你是资深售后运营与设备服务工程师,负责根据微信群聊天记录整理具体售后问题点报告。"
|
||||
"你必须忠实依据聊天记录,只输出已识别到的有效信息,缺失信息直接省略,不得编造。"
|
||||
"你要在文档中明确给出是否解决结论,并给出 AI 建议/解决方法和免责声明。只输出 Markdown 报告,不要有任何额外说明。"
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.2,
|
||||
)
|
||||
content = resp.choices[0].message.content.strip()
|
||||
content = _merge_image_sections(content, image_successes, image_failures)
|
||||
except TimeoutError:
|
||||
error = "AI 报告生成超时,请检查模型/API或稍后重试"
|
||||
log.error(f"[summarize] LLM 调用超时 topic={topic_id}")
|
||||
await db.execute("UPDATE topics SET status = 'error', updated_at = CURRENT_TIMESTAMP WHERE id = ?", (topic_id,))
|
||||
await db.commit()
|
||||
await _update_task("error", 0, 1, error)
|
||||
return
|
||||
except Exception as e:
|
||||
log.error(f"[summarize] LLM 调用失败 topic={topic_id}: {e}", exc_info=True)
|
||||
await db.execute("UPDATE topics SET status = 'error', updated_at = CURRENT_TIMESTAMP WHERE id = ?", (topic_id,))
|
||||
await db.commit()
|
||||
await _update_task("error", 0, 1, str(e) or "LLM 调用失败")
|
||||
return
|
||||
|
||||
# 5. 写入 knowledge_docs
|
||||
async with db.execute(
|
||||
"SELECT id FROM knowledge_docs WHERE topic_id = ?", (topic_id,)
|
||||
) as cur:
|
||||
existing = await cur.fetchone()
|
||||
|
||||
if existing:
|
||||
doc_id = existing["id"]
|
||||
await db.execute(
|
||||
"UPDATE knowledge_docs SET content = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
|
||||
(content, doc_id),
|
||||
)
|
||||
else:
|
||||
await db.execute(
|
||||
"INSERT INTO knowledge_docs (topic_id, content) VALUES (?, ?)",
|
||||
(topic_id, content),
|
||||
)
|
||||
async with db.execute("SELECT last_insert_rowid() AS id") as cur:
|
||||
doc_id = (await cur.fetchone())["id"]
|
||||
|
||||
# 6. 更新 FTS(先删后插)
|
||||
await db.execute("DELETE FROM knowledge_fts WHERE doc_id = ?", (doc_id,))
|
||||
await db.execute(
|
||||
"INSERT INTO knowledge_fts (doc_id, title, content) VALUES (?, ?, ?)",
|
||||
(doc_id, tokenize(topic["title"]), tokenize(content)),
|
||||
)
|
||||
|
||||
await db.execute("UPDATE topics SET status = 'completed', updated_at = CURRENT_TIMESTAMP WHERE id = ?", (topic_id,))
|
||||
await db.commit()
|
||||
await _update_task("done", 1, 1)
|
||||
log.info(f"[summarize] topic={topic_id} doc={doc_id} 生成完成({len(content)} 字符)")
|
||||
|
||||
|
||||
async def run_summarize(topic_id: int, topic: dict, task_id: int | None = None):
|
||||
try:
|
||||
await _run_summarize_impl(topic_id, topic, task_id)
|
||||
except Exception as e:
|
||||
error = str(e) or e.__class__.__name__
|
||||
log.error(f"[summarize] 未捕获异常 topic={topic_id}: {error}", exc_info=True)
|
||||
await _mark_summarize_failed(topic_id, task_id, error)
|
||||
1094
chatlog_fastAPI/services/topic_engine.py
Normal file
1094
chatlog_fastAPI/services/topic_engine.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user