Files
get_wechat/chatlog_fastAPI/routers/topics.py

436 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from typing import Optional
import aiosqlite, json
from datetime import datetime
from urllib.parse import quote
from database import get_db
from services.message_formatter import extract_quote
# Router for the topic endpoints; every route below is mounted under /api/topics.
router = APIRouter(prefix="/api/topics", tags=["topics"])
# Number of message seqs sent per get_messages_batch call in get_topic.
CHATLOG_BATCH_SIZE = 80
# Page size (limit) used when paging chat history to resolve sender names.
NAME_LOOKUP_PAGE_SIZE = 500
# Upper bound on history items scanned during one sender-name lookup.
NAME_LOOKUP_MAX_ITEMS = 5000
# Age in minutes after which running summarize work is marked as failed.
STALE_SUMMARIZE_MINUTES = 15
class TopicCreate(BaseModel):
    """Request body accepted by the create-topic endpoint."""

    # Group the new topic is attached to.
    group_id: int
    # Title stored for the new topic.
    title: str
class TopicPatch(BaseModel):
    """Request body for a partial topic update; omitted fields default to None."""

    # New title, or None to leave the title alone.
    title: Optional[str] = None
    # New status, or None to leave the status alone.
    status: Optional[str] = None
class MessageAdd(BaseModel):
    """Request body identifying one chatlog message to link to a topic."""

    # Sequence number of the message.
    msg_seq: int
    # Talker value stored alongside the seq in topic_messages.
    talker: str
async def _mark_stale_summarize_tasks(db: aiosqlite.Connection, group_id: int, topic_id: int) -> None:
    """Mark summarize work that has been stuck for too long as failed.

    Two updates are issued:

    * every ``ai_tasks`` row of type ``summarize`` for ``group_id`` that is
      still ``running`` and was last updated at least
      ``STALE_SUMMARIZE_MINUTES`` minutes ago becomes ``error``;
    * the topic ``topic_id`` is moved from ``processing`` to ``error`` under
      the same age condition.

    This function does not commit; the statements become durable with the
    next commit on ``db``.

    Args:
        db: Open database connection.
        group_id: Group whose stale summarize tasks are flagged.
        topic_id: Topic whose stale ``processing`` status is reset.
    """
    # Build the message from the constant so the text shown to users cannot
    # drift from the window that is actually applied below.
    error = f"AI 报告生成任务超过 {STALE_SUMMARIZE_MINUTES} 分钟未完成,已自动标记为失败,可重新生成"
    # SQLite datetime modifier, e.g. "-15 minutes".
    stale_window = f"-{STALE_SUMMARIZE_MINUTES} minutes"
    await db.execute(
        """
        UPDATE ai_tasks
        SET status='error', error=?, updated_at=CURRENT_TIMESTAMP
        WHERE group_id=?
          AND type='summarize'
          AND status='running'
          AND datetime(updated_at) <= datetime('now', ?)
        """,
        (error, group_id, stale_window),
    )
    await db.execute(
        """
        UPDATE topics
        SET status='error', updated_at=CURRENT_TIMESTAMP
        WHERE id=?
          AND status='processing'
          AND datetime(updated_at) <= datetime('now', ?)
        """,
        (topic_id, stale_window),
    )
def _normalize_chatlog_message(item: dict, fallback_seq: int = 0) -> dict:
    """Flatten a raw chatlog message into the dict shape served by this router.

    Fields are looked up under several spellings (``seq``/``Seq``,
    ``sender``/``Sender``, ...); a falsy value falls through to the next
    candidate and finally to a default.

    Args:
        item: Raw message dict.
        fallback_seq: Sequence number used when ``item`` carries none.

    Returns:
        A new dict with normalized keys. For file messages (type 49,
        sub type 6) that have an md5, ``file_url`` points at
        ``/api/files/<md5>?filename=<name>``; otherwise the file fields are
        empty strings.
    """
    contents = item.get("contents") or item.get("Contents") or {}
    if not isinstance(contents, dict):
        contents = {}

    media_key = (
        contents.get("rawmd5")
        or contents.get("md5")
        or contents.get("path")
        or item.get("media_key")
        or item.get("mediaKey")
        or ""
    )
    voice = contents.get("voice")
    voice_key = str(voice) if voice else (item.get("voice_key") or item.get("voiceKey") or "")

    raw_type = item.get("type") or item.get("Type") or 1
    raw_sub_type = item.get("sub_type") or item.get("subType") or item.get("SubType") or 0
    try:
        is_file = int(raw_type) == 49 and int(raw_sub_type) == 6
    except (TypeError, ValueError, OverflowError):
        # Non-numeric type markers simply mean "not a file message".
        is_file = False

    file_md5 = ""
    file_name = ""
    if is_file:
        file_md5 = str(contents.get("md5") or item.get("file_md5") or item.get("fileMd5") or "")
        # Coerce to str like file_md5: quote() raises TypeError on non-string,
        # non-bytes input, so a numeric title would otherwise break the URL.
        file_name = str(
            contents.get("title")
            or contents.get("fileName")
            or contents.get("filename")
            or item.get("file_name")
            or item.get("fileName")
            or ""
        )
    file_url = ""
    if file_md5:
        # safe='' so that "/" inside either component is percent-encoded too.
        file_url = f"/api/files/{quote(file_md5, safe='')}?filename={quote(file_name or file_md5, safe='')}"

    return {
        "seq": item.get("seq") or item.get("Seq") or item.get("sort_seq") or fallback_seq or 0,
        "sender": item.get("sender") or item.get("Sender") or "",
        "sender_name": (
            item.get("sender_name")
            or item.get("senderName")
            or item.get("SenderName")
            or item.get("sender")
            or item.get("Sender")
            or ""
        ),
        "create_time": item.get("create_time") or item.get("time") or item.get("CreateTime") or "",
        "content": item.get("content") or item.get("Content") or "",
        "type": raw_type,
        "sub_type": raw_sub_type,
        "contents": contents,
        "media_key": media_key,
        "voice_key": voice_key,
        # The per-media path fields all reuse the same keys.
        "image_path": media_key,
        "voice_path": voice_key,
        "video_path": media_key,
        "file_path": media_key,
        "link_url": contents.get("url") or item.get("link_url") or "",
        "link_title": contents.get("title") or item.get("link_title") or "",
        "link_desc": contents.get("desc") or item.get("link_desc") or "",
        "link_thumb": contents.get("thumbUrl") or contents.get("thumb_url") or item.get("link_thumb") or "",
        "link_source": contents.get("sourceName") or contents.get("source_name") or item.get("link_source") or "",
        # An explicit quote on the item wins; otherwise it is extracted.
        "quote": item.get("quote") or extract_quote(item),
        "is_file": is_file,
        "file_name": file_name,
        "file_md5": file_md5,
        "file_url": file_url,
    }
def _message_from_snapshot(raw: str | None, fallback_seq: int) -> dict | None:
if not raw:
return None
try:
item = json.loads(raw)
except Exception:
return None
if not isinstance(item, dict):
return None
return _normalize_chatlog_message(item, fallback_seq)
def _looks_like_raw_sender_id(value: str | None) -> bool:
value = (value or "").strip()
return (
not value
or value.startswith("wxid_")
or value.startswith("gh_")
or value.endswith("@chatroom")
or value.startswith("chatroom_")
)
def _sender_display_name(item: dict, sender: str = "") -> str:
    """Pick the first usable human-readable name found on ``item``.

    Candidate keys are tried in a fixed priority order. A candidate is
    rejected when it is blank, equal to ``sender``, or looks like a raw id.
    Returns an empty string when no candidate qualifies.
    """
    name_keys = (
        "sender_name",
        "senderName",
        "SenderName",
        "accountName",
        "groupNickname",
        "displayName",
        "nickName",
        "remark",
    )
    candidates = (str(item.get(key) or "").strip() for key in name_keys)
    usable = (
        text
        for text in candidates
        if text and text != sender and not _looks_like_raw_sender_id(text)
    )
    return next(usable, "")
def _message_date(value: str | None) -> str | None:
value = (value or "").strip()
if not value:
return None
try:
return datetime.fromisoformat(value.replace("Z", "+00:00")).strftime("%Y-%m-%d")
except Exception:
return value[:10] if len(value) >= 10 and value[4:5] == "-" else None
async def _lookup_names_from_history(
    talker: str,
    time_range: str,
    missing: set[str],
    names: dict[str, str],
) -> None:
    """Page through chat history and resolve display names for ``missing`` senders.

    Resolved senders are added to ``names`` and removed from ``missing``
    (both are mutated in place). The scan stops when nothing is missing,
    when ``NAME_LOOKUP_MAX_ITEMS`` items have been read, on a short or empty
    page, when the reported total is reached, or on any fetch error.
    """
    from services.chatlog_client import chatlog_client

    offset = 0
    while missing and offset < NAME_LOOKUP_MAX_ITEMS:
        try:
            data = await chatlog_client.get_messages(
                talker,
                time=time_range,
                limit=NAME_LOOKUP_PAGE_SIZE,
                offset=offset,
            )
        except Exception as e:
            print(f"Failed to fetch sender names talker={talker}: {e}")
            break
        # The lookup is best effort: an unexpected payload ends the scan
        # instead of raising into the request handler.
        if not isinstance(data, dict):
            break
        items = data.get("items", []) or []
        if not items:
            break
        for item in items:
            if not isinstance(item, dict):
                continue
            sender = str(item.get("sender") or item.get("Sender") or "").strip()
            if sender not in missing:
                continue
            name = _sender_display_name(item, sender)
            if name:
                names[sender] = name
                missing.discard(sender)
        offset += len(items)
        try:
            total = int(data.get("total") or 0)
        except (TypeError, ValueError):
            # A malformed total is treated as "unknown".
            total = 0
        if total and offset >= total:
            break
        if len(items) < NAME_LOOKUP_PAGE_SIZE:
            break


async def _lookup_names_from_members(
    talker: str,
    missing: set[str],
    names: dict[str, str],
) -> None:
    """Resolve display names for ``missing`` senders from the chatroom member list.

    Resolved senders are added to ``names`` and removed from ``missing``.
    Any failure is printed and otherwise ignored.
    """
    from services.chatlog_client import chatlog_client

    try:
        members = await chatlog_client.get_chatroom_members(talker)
        raw_members = members.get("members", []) if isinstance(members, dict) else []
        for member in raw_members:
            sender = str(member.get("userName") or member.get("UserName") or "").strip()
            if sender not in missing:
                continue
            name = _sender_display_name(
                {
                    "displayName": member.get("displayName") or member.get("DisplayName"),
                    "nickName": member.get("nickName") or member.get("NickName"),
                    "remark": member.get("remark") or member.get("Remark"),
                },
                sender,
            )
            if name:
                names[sender] = name
                missing.discard(sender)
    except Exception as e:
        print(f"Failed to fetch chatroom members talker={talker}: {e}")


async def _build_sender_name_map(talker: str, messages: list[dict]) -> dict[str, str]:
    """Map sender ids to display names for the given messages.

    Names are taken from the messages themselves first. Senders whose
    ``sender_name`` still looks like a raw id are then looked up in the chat
    history covering the messages' date range and, if still unresolved, in
    the chatroom member list. Both remote lookups are best effort.

    Args:
        talker: Conversation the messages belong to.
        messages: Normalized message dicts.

    Returns:
        Dict of sender id to display name; unresolved senders are absent.
    """
    names: dict[str, str] = {}
    for msg in messages:
        sender = str(msg.get("sender") or "").strip()
        name = _sender_display_name(msg, sender)
        if sender and name:
            names[sender] = name

    # Second pass on purpose: a sender counts as resolved if ANY of its
    # messages carried a usable name, so `names` must be complete first.
    missing: set[str] = set()
    for msg in messages:
        sender = str(msg.get("sender") or "").strip()
        if not sender or sender in names:
            continue
        if _looks_like_raw_sender_id(str(msg.get("sender_name") or "").strip()):
            missing.add(sender)
    if not missing:
        return names

    dates = sorted({
        date
        for msg in messages
        for date in [_message_date(str(msg.get("create_time") or ""))]
        if date
    })
    if dates:
        await _lookup_names_from_history(talker, f"{dates[0]},{dates[-1]}", missing, names)
    if missing:
        await _lookup_names_from_members(talker, missing, names)
    return names
async def _fill_sender_names(talker: str, messages: list[dict]) -> None:
    """Replace raw-looking sender names on ``messages`` with display names.

    The messages are modified in place: when a display name is known for a
    sender and the message's current ``sender_name`` is blank, equal to the
    sender id, or looks like a raw id, both ``sender_name`` and
    ``senderName`` are overwritten. Nothing happens when ``talker`` or
    ``messages`` is empty or no names could be resolved.
    """
    if not (talker and messages):
        return
    resolved = await _build_sender_name_map(talker, messages)
    if not resolved:
        return
    for message in messages:
        sender_id = str(message.get("sender") or "").strip()
        if sender_id and sender_id in resolved:
            shown = str(message.get("sender_name") or "").strip()
            needs_name = (not shown) or shown == sender_id or _looks_like_raw_sender_id(shown)
            if needs_name:
                message["sender_name"] = message["senderName"] = resolved[sender_id]
@router.get("")
async def list_topics(
    group_id: Optional[int] = None,
    status: Optional[str] = None,
    keyword: Optional[str] = None,
    db: aiosqlite.Connection = Depends(get_db)
):
    """List topics, optionally narrowed by group, status and title keyword.

    A filter is applied only when its value is truthy, so ``None``, ``0``
    and ``""`` all mean "no filter". The keyword is matched as a substring
    of the title via ``LIKE``.
    """
    query = "SELECT * FROM topics WHERE 1=1"
    args = []
    # (supplied value, SQL fragment, bound parameter), in the order applied.
    filters = (
        (group_id, " AND group_id=?", group_id),
        (status, " AND status=?", status),
        (keyword, " AND title LIKE ?", f"%{keyword}%"),
    )
    for supplied, clause, bound in filters:
        if supplied:
            query += clause
            args.append(bound)
    async with db.execute(query, args) as cur:
        rows = await cur.fetchall()
    return [dict(record) for record in rows]
@router.post("")
async def create_topic(body: TopicCreate, db: aiosqlite.Connection = Depends(get_db)):
    """Create a manual topic and return the stored row as a dict.

    Args:
        body: Group id and title of the new topic.
        db: Open database connection.
    """
    async with db.execute(
        "INSERT INTO topics (group_id, title, source) VALUES (?, ?, 'manual')",
        (body.group_id, body.title),
    ) as cur:
        # Take the id from this INSERT's cursor. Re-reading "the newest row"
        # could return a topic created concurrently by another request.
        topic_id = cur.lastrowid
    await db.commit()
    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        return dict(await cur.fetchone())
@router.get("/{topic_id}")
async def get_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
    """Return one topic with its messages and knowledge document.

    Each linked message is taken from the chatlog service when available,
    else from the JSON snapshot stored on the link row, else replaced by a
    placeholder entry. Raises a 404 HTTPException when the topic is missing.
    """
    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        topic_row = await cur.fetchone()
    if not topic_row:
        raise HTTPException(404, "not found")

    # Collect every seq linked to this topic.
    link_query = "SELECT * FROM topic_messages WHERE topic_id=? ORDER BY msg_seq ASC"
    async with db.execute(link_query, (topic_id,)) as cur:
        linked_rows = await cur.fetchall()

    messages = []
    if linked_rows:
        # The talker is read from the first linked row only.
        talker = linked_rows[0]["talker"]
        wanted_seqs = [link["msg_seq"] for link in linked_rows]

        # Call the 5030 endpoint in batches to fetch the real message
        # content, so that one large batch does not trigger a 500.
        from services.chatlog_client import chatlog_client

        live_by_seq: dict[int, dict] = {}
        for start in range(0, len(wanted_seqs), CHATLOG_BATCH_SIZE):
            batch = wanted_seqs[start: start + CHATLOG_BATCH_SIZE]
            try:
                payload = await chatlog_client.get_messages_batch(talker, batch)
                for raw in payload.get("items", []) or payload.get("Items", []):
                    live = _normalize_chatlog_message(raw)
                    live_seq = live.get("seq")
                    if live_seq:
                        live_by_seq[int(live_seq)] = live
            except Exception as e:
                print(f"Failed to fetch real messages chunk topic={topic_id}: {e}")

        for link in linked_rows:
            seq = int(link["msg_seq"])
            stored = link["message_json"] if "message_json" in link.keys() else None
            # Preference order: live message, stored snapshot, placeholder.
            entry = live_by_seq.get(seq) or _message_from_snapshot(stored, seq)
            if not entry:
                entry = {
                    "seq": seq,
                    "sender": "",
                    "sender_name": "系统提示",
                    "create_time": "",
                    "content": f"原始消息无法从 chatlog 找回 (seq: {seq})",
                    "type": 1,
                    "sub_type": 0,
                }
            messages.append(entry)

        await _fill_sender_names(talker, messages)

    # Fetch the knowledge document.
    doc_query = "SELECT id, topic_id, content, created_at, updated_at FROM knowledge_docs WHERE topic_id=?"
    async with db.execute(doc_query, (topic_id,)) as cur:
        doc_row = await cur.fetchone()
    knowledge_doc = dict(doc_row) if doc_row else None
    return {**dict(topic_row), "messages": messages, "knowledge_doc": knowledge_doc}
@router.patch("/{topic_id}")
async def patch_topic(topic_id: int, body: TopicPatch, db: aiosqlite.Connection = Depends(get_db)):
    """Update a topic's title and/or status and return the stored row.

    Only truthy fields of ``body`` are applied; each applied change also
    marks the topic's source as ``manual``.

    Raises:
        HTTPException: 404 when no topic with ``topic_id`` exists.
    """
    if body.title:
        await db.execute(
            "UPDATE topics SET title=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
            (body.title, topic_id),
        )
    if body.status:
        await db.execute(
            "UPDATE topics SET status=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
            (body.status, topic_id),
        )
    await db.commit()
    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        row = await cur.fetchone()
    # An unknown id used to reach dict(None) and surface as a 500.
    if not row:
        raise HTTPException(404, "not found")
    return dict(row)
@router.delete("/{topic_id}")
async def delete_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
    """Delete a topic together with its message links and knowledge document.

    Always responds with ``{"ok": True}``, whether or not the topic existed.
    """
    # Look up the doc id first; it is needed to clear the FTS entry.
    async with db.execute("SELECT id FROM knowledge_docs WHERE topic_id=?", (topic_id,)) as cur:
        knowledge_doc = await cur.fetchone()
    if knowledge_doc:
        await db.execute("DELETE FROM knowledge_fts WHERE doc_id=?", (knowledge_doc["id"],))

    # Dependent rows are removed before the topic row itself.
    for statement in (
        "DELETE FROM topic_messages WHERE topic_id=?",
        "DELETE FROM knowledge_docs WHERE topic_id=?",
        "DELETE FROM topics WHERE id=?",
    ):
        await db.execute(statement, (topic_id,))
    await db.commit()
    return {"ok": True}
@router.post("/{topic_id}/messages")
async def add_message(topic_id: int, body: MessageAdd, db: aiosqlite.Connection = Depends(get_db)):
    """Link a message to a topic and mark the topic as manually edited.

    The link is inserted with ``INSERT OR IGNORE`` and ``added_by='user'``.
    """
    link_sql = "INSERT OR IGNORE INTO topic_messages (topic_id, msg_seq, talker, added_by) VALUES (?, ?, ?, 'user')"
    touch_sql = "UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?"
    await db.execute(link_sql, (topic_id, body.msg_seq, body.talker))
    await db.execute(touch_sql, (topic_id,))
    await db.commit()
    return {"ok": True}
@router.delete("/{topic_id}/messages/{seq}")
async def remove_message(topic_id: int, seq: int, db: aiosqlite.Connection = Depends(get_db)):
    """Unlink one message from a topic and mark the topic as manually edited."""
    unlink_sql = "DELETE FROM topic_messages WHERE topic_id=? AND msg_seq=?"
    touch_sql = "UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?"
    await db.execute(unlink_sql, (topic_id, seq))
    await db.execute(touch_sql, (topic_id,))
    await db.commit()
    return {"ok": True}
# Strong references to in-flight summarize tasks. The event loop keeps only
# weak references to tasks, so a task nobody else references may be garbage
# collected before it finishes.
_SUMMARIZE_BACKGROUND_TASKS: set = set()


@router.post("/{topic_id}/summarize")
async def summarize(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
    """Start background summarization of a topic.

    Stale summarize work is first marked as failed, then an ``ai_tasks`` row
    in ``running`` state is created and ``run_summarize`` is scheduled as a
    background task.

    Returns:
        ``{"ok": True, "task_id": <id of the new ai_tasks row>}``.

    Raises:
        HTTPException: 404 when no topic with ``topic_id`` exists.
    """
    import asyncio
    from services.summary_engine import run_summarize

    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        topic = await cur.fetchone()
    if not topic:
        raise HTTPException(404, "not found")
    topic_data = dict(topic)
    await _mark_stale_summarize_tasks(db, topic_data["group_id"], topic_id)

    # Create an ai_tasks record so progress can be tracked.
    async with db.execute(
        "INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'summarize', 'running', ?)",
        (topic_data["group_id"], json.dumps({"processed": 0, "total": 1})),
    ) as cur:
        # Read the id from this INSERT's cursor, not from a later
        # last_insert_rowid() query that another insert could affect.
        task_id = cur.lastrowid
    await db.commit()

    task = asyncio.create_task(run_summarize(topic_id, topic_data, task_id))
    # Hold the task until it completes, then drop the reference.
    _SUMMARIZE_BACKGROUND_TASKS.add(task)
    task.add_done_callback(_SUMMARIZE_BACKGROUND_TASKS.discard)
    return {"ok": True, "task_id": task_id}