436 lines
16 KiB
Python
436 lines
16 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
import aiosqlite, json
|
||
from datetime import datetime
|
||
from urllib.parse import quote
|
||
from database import get_db
|
||
from services.message_formatter import extract_quote
|
||
|
||
# Every route declared in this module is served under the /api/topics prefix.
router = APIRouter(prefix="/api/topics", tags=["topics"])
# Number of message seqs passed to one chatlog_client.get_messages_batch call
# in get_topic (large batches are split to avoid upstream 500 responses).
CHATLOG_BATCH_SIZE = 80
# Page size (the `limit` argument) used by _build_sender_name_map when it
# pages through chat history looking for sender display names.
NAME_LOOKUP_PAGE_SIZE = 500
# Upper bound on how many history items _build_sender_name_map scans before
# it stops paging.
NAME_LOOKUP_MAX_ITEMS = 5000
# Age, in minutes, after which _mark_stale_summarize_tasks treats a running
# summarize task / processing topic as failed.
STALE_SUMMARIZE_MINUTES = 15
|
||
|
||
# Request body accepted by create_topic (POST /api/topics).
class TopicCreate(BaseModel):
    # Value stored in topics.group_id for the new row.
    group_id: int
    # Value stored in topics.title for the new row.
    title: str
|
||
|
||
# Request body accepted by patch_topic (PATCH /api/topics/{topic_id}).
# Both fields are optional; patch_topic only applies a field when it is truthy.
class TopicPatch(BaseModel):
    # New title; left unchanged when omitted or empty.
    title: Optional[str] = None
    # New status; left unchanged when omitted or empty.
    status: Optional[str] = None
|
||
|
||
# Request body accepted by add_message (POST /api/topics/{topic_id}/messages).
class MessageAdd(BaseModel):
    # Stored in topic_messages.msg_seq.
    msg_seq: int
    # Stored in topic_messages.talker; get_topic later passes it to chatlog_client
    # as the conversation to fetch messages from.
    talker: str
|
||
|
||
|
||
async def _mark_stale_summarize_tasks(db: aiosqlite.Connection, group_id: int, topic_id: int) -> None:
    """Mark summarize work that has not progressed for too long as failed.

    Two UPDATE statements are executed on ``db``; nothing is committed here,
    so the caller is responsible for committing.

    Args:
        db: Open database connection.
        group_id: Group whose running ``summarize`` rows in ``ai_tasks`` are
            checked. Note the filter is per group, not per topic.
        topic_id: Topic whose ``processing`` status is reset to ``error``
            when it is stale.
    """
    # Derive the message from the constant so the text cannot drift from the
    # threshold that is actually applied below.
    error = f"AI 报告生成任务超过 {STALE_SUMMARIZE_MINUTES} 分钟未完成,已自动标记为失败,可重新生成"
    # SQLite datetime modifier, e.g. "-15 minutes".
    stale_window = f"-{STALE_SUMMARIZE_MINUTES} minutes"
    await db.execute(
        """
        UPDATE ai_tasks
        SET status='error', error=?, updated_at=CURRENT_TIMESTAMP
        WHERE group_id=?
          AND type='summarize'
          AND status='running'
          AND datetime(updated_at) <= datetime('now', ?)
        """,
        (error, group_id, stale_window),
    )
    await db.execute(
        """
        UPDATE topics
        SET status='error', updated_at=CURRENT_TIMESTAMP
        WHERE id=?
          AND status='processing'
          AND datetime(updated_at) <= datetime('now', ?)
        """,
        (topic_id, stale_window),
    )
|
||
|
||
|
||
def _normalize_chatlog_message(item: dict, fallback_seq: int = 0) -> dict:
|
||
contents = item.get("contents") or item.get("Contents") or {}
|
||
if not isinstance(contents, dict):
|
||
contents = {}
|
||
media_key = (
|
||
contents.get("rawmd5")
|
||
or contents.get("md5")
|
||
or contents.get("path")
|
||
or item.get("media_key")
|
||
or item.get("mediaKey")
|
||
or ""
|
||
)
|
||
voice_key = (
|
||
str(contents.get("voice"))
|
||
if contents.get("voice")
|
||
else item.get("voice_key") or item.get("voiceKey") or ""
|
||
)
|
||
raw_type = item.get("type") or item.get("Type") or 1
|
||
raw_sub_type = item.get("sub_type") or item.get("subType") or item.get("SubType") or 0
|
||
try:
|
||
is_file = int(raw_type) == 49 and int(raw_sub_type) == 6
|
||
except Exception:
|
||
is_file = False
|
||
file_md5 = str(contents.get("md5") or item.get("file_md5") or item.get("fileMd5") or "") if is_file else ""
|
||
file_name = (
|
||
contents.get("title")
|
||
or contents.get("fileName")
|
||
or contents.get("filename")
|
||
or item.get("file_name")
|
||
or item.get("fileName")
|
||
or ""
|
||
) if is_file else ""
|
||
file_url = f"/api/files/{quote(file_md5, safe='')}?filename={quote(file_name or file_md5, safe='')}" if file_md5 else ""
|
||
return {
|
||
"seq": item.get("seq") or item.get("Seq") or item.get("sort_seq") or fallback_seq or 0,
|
||
"sender": item.get("sender") or item.get("Sender") or "",
|
||
"sender_name": (
|
||
item.get("sender_name")
|
||
or item.get("senderName")
|
||
or item.get("SenderName")
|
||
or item.get("sender")
|
||
or item.get("Sender")
|
||
or ""
|
||
),
|
||
"create_time": item.get("create_time") or item.get("time") or item.get("CreateTime") or "",
|
||
"content": item.get("content") or item.get("Content") or "",
|
||
"type": raw_type,
|
||
"sub_type": raw_sub_type,
|
||
"contents": contents,
|
||
"media_key": media_key,
|
||
"voice_key": voice_key,
|
||
"image_path": media_key,
|
||
"voice_path": voice_key,
|
||
"video_path": media_key,
|
||
"file_path": media_key,
|
||
"link_url": contents.get("url") or item.get("link_url") or "",
|
||
"link_title": contents.get("title") or item.get("link_title") or "",
|
||
"link_desc": contents.get("desc") or item.get("link_desc") or "",
|
||
"link_thumb": contents.get("thumbUrl") or contents.get("thumb_url") or item.get("link_thumb") or "",
|
||
"link_source": contents.get("sourceName") or contents.get("source_name") or item.get("link_source") or "",
|
||
"quote": item.get("quote") or extract_quote(item),
|
||
"is_file": is_file,
|
||
"file_name": file_name,
|
||
"file_md5": file_md5,
|
||
"file_url": file_url,
|
||
}
|
||
|
||
|
||
def _message_from_snapshot(raw: str | None, fallback_seq: int) -> dict | None:
|
||
if not raw:
|
||
return None
|
||
try:
|
||
item = json.loads(raw)
|
||
except Exception:
|
||
return None
|
||
if not isinstance(item, dict):
|
||
return None
|
||
return _normalize_chatlog_message(item, fallback_seq)
|
||
|
||
|
||
def _looks_like_raw_sender_id(value: str | None) -> bool:
|
||
value = (value or "").strip()
|
||
return (
|
||
not value
|
||
or value.startswith("wxid_")
|
||
or value.startswith("gh_")
|
||
or value.endswith("@chatroom")
|
||
or value.startswith("chatroom_")
|
||
)
|
||
|
||
|
||
def _sender_display_name(item: dict, sender: str = "") -> str:
|
||
for key in (
|
||
"sender_name",
|
||
"senderName",
|
||
"SenderName",
|
||
"accountName",
|
||
"groupNickname",
|
||
"displayName",
|
||
"nickName",
|
||
"remark",
|
||
):
|
||
value = str(item.get(key) or "").strip()
|
||
if value and value != sender and not _looks_like_raw_sender_id(value):
|
||
return value
|
||
return ""
|
||
|
||
|
||
def _message_date(value: str | None) -> str | None:
|
||
value = (value or "").strip()
|
||
if not value:
|
||
return None
|
||
try:
|
||
return datetime.fromisoformat(value.replace("Z", "+00:00")).strftime("%Y-%m-%d")
|
||
except Exception:
|
||
return value[:10] if len(value) >= 10 and value[4:5] == "-" else None
|
||
|
||
|
||
async def _build_sender_name_map(talker: str, messages: list[dict]) -> dict[str, str]:
    """Map sender ids to display names for the given messages.

    Names are collected in three best-effort passes, each only consulted for
    senders still unresolved after the previous one:

    1. display-name fields already present on ``messages``;
    2. chat history of ``talker`` over the date range spanned by
       ``messages``, paged ``NAME_LOOKUP_PAGE_SIZE`` items at a time and
       capped at ``NAME_LOOKUP_MAX_ITEMS`` items;
    3. the chatroom member list of ``talker``.

    Remote failures are printed and swallowed so that a name lookup problem
    never fails the caller.

    Args:
        talker: Conversation id passed to ``chatlog_client``.
        messages: Normalized message dicts (``sender``, ``sender_name``,
            ``create_time`` are read).

    Returns:
        Dict of sender id -> display name; senders that could not be
        resolved are absent.
    """
    names: dict[str, str] = {}
    senders: set[str] = set()

    for msg in messages:
        sender = str(msg.get("sender") or "").strip()
        if not sender:
            continue
        senders.add(sender)
        name = _sender_display_name(msg, sender)
        if name:
            names[sender] = name

    # A sender is unresolved when no message yielded a usable name for it.
    # This includes senders whose sender_name merely repeats the sender id
    # (the fallback applied during normalization), which _fill_sender_names
    # also treats as "needs a name".
    missing = senders - names.keys()
    if not missing:
        return names

    # Imported lazily, matching the other call sites in this module.
    from services.chatlog_client import chatlog_client

    dates = sorted({
        date
        for date in (_message_date(str(msg.get("create_time") or "")) for msg in messages)
        if date
    })
    if dates:
        time_range = f"{dates[0]},{dates[-1]}"
        offset = 0
        try:
            while missing and offset < NAME_LOOKUP_MAX_ITEMS:
                data = await chatlog_client.get_messages(
                    talker,
                    time=time_range,
                    limit=NAME_LOOKUP_PAGE_SIZE,
                    offset=offset,
                )
                items = data.get("items", []) or []
                if not items:
                    break
                for item in items:
                    sender = str(item.get("sender") or item.get("Sender") or "").strip()
                    if sender not in missing:
                        continue
                    name = _sender_display_name(item, sender)
                    if name:
                        names[sender] = name
                        missing.discard(sender)
                offset += len(items)
                total = int(data.get("total") or 0)
                if total and offset >= total:
                    break
                # A short page means the history is exhausted.
                if len(items) < NAME_LOOKUP_PAGE_SIZE:
                    break
        except Exception as e:
            # Covers both transport errors and malformed pages; names found
            # so far are kept.
            print(f"Failed to fetch sender names talker={talker}: {e}")

    if missing:
        try:
            members = await chatlog_client.get_chatroom_members(talker)
            raw_members = members.get("members", []) if isinstance(members, dict) else []
            for member in raw_members:
                sender = str(member.get("userName") or member.get("UserName") or "").strip()
                if sender not in missing:
                    continue
                name = _sender_display_name(
                    {
                        "displayName": member.get("displayName") or member.get("DisplayName"),
                        "nickName": member.get("nickName") or member.get("NickName"),
                        "remark": member.get("remark") or member.get("Remark"),
                    },
                    sender,
                )
                if name:
                    names[sender] = name
                    missing.discard(sender)
        except Exception as e:
            print(f"Failed to fetch chatroom members talker={talker}: {e}")

    return names
|
||
|
||
|
||
async def _fill_sender_names(talker: str, messages: list[dict]) -> None:
|
||
if not talker or not messages:
|
||
return
|
||
names = await _build_sender_name_map(talker, messages)
|
||
if not names:
|
||
return
|
||
for msg in messages:
|
||
sender = str(msg.get("sender") or "").strip()
|
||
if not sender or sender not in names:
|
||
continue
|
||
current = str(msg.get("sender_name") or "").strip()
|
||
if not current or current == sender or _looks_like_raw_sender_id(current):
|
||
msg["sender_name"] = names[sender]
|
||
msg["senderName"] = names[sender]
|
||
|
||
# List topics, optionally narrowed by group, status and a title keyword.
# Each filter is applied only when its value is truthy; `keyword` is matched
# as a substring of the title via LIKE. Returns a list of row dicts.
@router.get("")
async def list_topics(
    group_id: Optional[int] = None,
    status: Optional[str] = None,
    keyword: Optional[str] = None,
    db: aiosqlite.Connection = Depends(get_db)
):
    conditions = []
    params = []
    if group_id:
        conditions.append(" AND group_id=?")
        params.append(group_id)
    if status:
        conditions.append(" AND status=?")
        params.append(status)
    if keyword:
        conditions.append(" AND title LIKE ?")
        params.append(f"%{keyword}%")

    # "WHERE 1=1" lets every optional condition be appended uniformly with AND.
    sql = "SELECT * FROM topics WHERE 1=1" + "".join(conditions)
    async with db.execute(sql, params) as cur:
        rows = await cur.fetchall()
        return [dict(found) for found in rows]
|
||
|
||
# Create a manually-sourced topic and return the stored row as a dict.
#
# The new row is read back by the id reported for this very INSERT
# (cursor.lastrowid) instead of "the highest id in the table", which could
# belong to a topic inserted concurrently by another request.
@router.post("")
async def create_topic(body: TopicCreate, db: aiosqlite.Connection = Depends(get_db)):
    async with db.execute(
        "INSERT INTO topics (group_id, title, source) VALUES (?, ?, 'manual')",
        (body.group_id, body.title),
    ) as insert_cur:
        topic_id = insert_cur.lastrowid
    await db.commit()
    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        return dict(await cur.fetchone())
|
||
|
||
# Return one topic together with its messages and knowledge document.
#
# Message bodies are resolved per linked seq in this order:
#   1. the live message fetched from chatlog_client,
#   2. the JSON snapshot stored on the topic_messages row (if that column exists),
#   3. a placeholder telling the user the original message could not be recovered.
# Raises HTTPException(404) when the topic does not exist.
@router.get("/{topic_id}")
async def get_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        row = await cur.fetchone()
    if not row:
        raise HTTPException(404, "not found")

    # All seqs linked to this topic, oldest first.
    async with db.execute(
        "SELECT * FROM topic_messages WHERE topic_id=? ORDER BY msg_seq ASC", (topic_id,)
    ) as cur:
        msg_rows = await cur.fetchall()

    msgs = []
    if msg_rows:
        # The conversation id is taken from the first linked row.
        talker = msg_rows[0]["talker"]
        seq_list = [linked["msg_seq"] for linked in msg_rows]

        # Fetch the real message content in chunks; one large batch request
        # can make the upstream service answer with a 500.
        from services.chatlog_client import chatlog_client

        fetched_by_seq: dict[int, dict] = {}
        for start in range(0, len(seq_list), CHATLOG_BATCH_SIZE):
            chunk = seq_list[start:start + CHATLOG_BATCH_SIZE]
            try:
                msgs_data = await chatlog_client.get_messages_batch(talker, chunk)
                raw_items = msgs_data.get("items", []) or msgs_data.get("Items", [])
                for raw_item in raw_items:
                    normalized = _normalize_chatlog_message(raw_item)
                    found_seq = normalized.get("seq")
                    if found_seq:
                        fetched_by_seq[int(found_seq)] = normalized
            except Exception as e:
                # A failed chunk is not fatal: its seqs fall back to snapshots below.
                print(f"Failed to fetch real messages chunk topic={topic_id}: {e}")

        for linked in msg_rows:
            seq = int(linked["msg_seq"])
            message = fetched_by_seq.get(seq)
            if message is None:
                snapshot_raw = linked["message_json"] if "message_json" in linked.keys() else None
                message = _message_from_snapshot(snapshot_raw, seq) or {
                    "seq": seq,
                    "sender": "",
                    "sender_name": "系统提示",
                    "create_time": "",
                    "content": f"原始消息无法从 chatlog 找回 (seq: {seq})",
                    "type": 1,
                    "sub_type": 0,
                }
            msgs.append(message)

        await _fill_sender_names(talker, msgs)

    # Knowledge document attached to the topic, if any.
    async with db.execute(
        "SELECT id, topic_id, content, created_at, updated_at FROM knowledge_docs WHERE topic_id=?",
        (topic_id,),
    ) as cur:
        doc = await cur.fetchone()

    knowledge_doc = dict(doc) if doc else None
    return {**dict(row), "messages": msgs, "knowledge_doc": knowledge_doc}
|
||
|
||
# Update a topic's title and/or status and return the updated row as a dict.
#
# Only truthy fields of the body are applied; every applied change also marks
# the topic as manually sourced and bumps updated_at.
# Raises HTTPException(404) when the topic does not exist (previously this
# surfaced as an unhandled TypeError from dict(None)).
@router.patch("/{topic_id}")
async def patch_topic(topic_id: int, body: TopicPatch, db: aiosqlite.Connection = Depends(get_db)):
    if body.title:
        await db.execute(
            "UPDATE topics SET title=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
            (body.title, topic_id),
        )
    if body.status:
        await db.execute(
            "UPDATE topics SET status=?, source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?",
            (body.status, topic_id),
        )
    await db.commit()
    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        row = await cur.fetchone()
    if not row:
        # Same response as get_topic for an unknown id.
        raise HTTPException(404, "not found")
    return dict(row)
|
||
|
||
# Delete a topic together with its message links, knowledge document and the
# document's full-text index entry. Always answers {"ok": True}.
@router.delete("/{topic_id}")
async def delete_topic(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
    # The document id must be read before the document row is removed, because
    # the FTS table is keyed by doc_id rather than topic_id.
    async with db.execute("SELECT id FROM knowledge_docs WHERE topic_id=?", (topic_id,)) as cur:
        doc_row = await cur.fetchone()
    if doc_row:
        await db.execute("DELETE FROM knowledge_fts WHERE doc_id=?", (doc_row["id"],))

    by_topic = (topic_id,)
    for statement in (
        "DELETE FROM topic_messages WHERE topic_id=?",
        "DELETE FROM knowledge_docs WHERE topic_id=?",
        "DELETE FROM topics WHERE id=?",
    ):
        await db.execute(statement, by_topic)

    await db.commit()
    return {"ok": True}
|
||
|
||
# Link one message (by seq and talker) to a topic on behalf of the user.
# An already existing link is left as is (INSERT OR IGNORE); the topic is
# marked as manually sourced either way. Always answers {"ok": True}.
@router.post("/{topic_id}/messages")
async def add_message(topic_id: int, body: MessageAdd, db: aiosqlite.Connection = Depends(get_db)):
    link_sql = "INSERT OR IGNORE INTO topic_messages (topic_id, msg_seq, talker, added_by) VALUES (?, ?, ?, 'user')"
    touch_sql = "UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?"

    await db.execute(link_sql, (topic_id, body.msg_seq, body.talker))
    await db.execute(touch_sql, (topic_id,))
    await db.commit()
    return {"ok": True}
|
||
|
||
# Unlink one message (by seq) from a topic and mark the topic as manually
# sourced. Always answers {"ok": True}.
@router.delete("/{topic_id}/messages/{seq}")
async def remove_message(topic_id: int, seq: int, db: aiosqlite.Connection = Depends(get_db)):
    unlink_sql = "DELETE FROM topic_messages WHERE topic_id=? AND msg_seq=?"
    touch_sql = "UPDATE topics SET source='manual', updated_at=CURRENT_TIMESTAMP WHERE id=?"

    await db.execute(unlink_sql, (topic_id, seq))
    await db.execute(touch_sql, (topic_id,))
    await db.commit()
    return {"ok": True}
|
||
|
||
# Strong references to in-flight summarize tasks. The event loop only keeps
# weak references to tasks, so a task whose result is discarded can be
# garbage-collected before it finishes.
_SUMMARIZE_TASKS: set = set()


# Start an AI summary for a topic in the background.
#
# Stale earlier runs are first marked as failed, then a new ai_tasks row is
# created to track progress and run_summarize is scheduled on the event loop.
# Returns {"ok": True, "task_id": <id of the new ai_tasks row>}.
# Raises HTTPException(404) when the topic does not exist.
@router.post("/{topic_id}/summarize")
async def summarize(topic_id: int, db: aiosqlite.Connection = Depends(get_db)):
    async with db.execute("SELECT * FROM topics WHERE id=?", (topic_id,)) as cur:
        topic = await cur.fetchone()
    if not topic:
        raise HTTPException(404, "not found")
    topic_data = dict(topic)
    await _mark_stale_summarize_tasks(db, topic_data["group_id"], topic_id)

    # Create the ai_tasks row used to track progress. The id is taken from
    # this INSERT's own cursor so it cannot be confused with a row inserted
    # by another statement on the same connection.
    async with db.execute(
        "INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'summarize', 'running', ?)",
        (topic_data["group_id"], json.dumps({"processed": 0, "total": 1})),
    ) as insert_cur:
        task_id = insert_cur.lastrowid
    await db.commit()

    from services.summary_engine import run_summarize
    import asyncio

    task = asyncio.create_task(run_summarize(topic_id, topic_data, task_id))
    # Hold the task until it completes, then drop the reference.
    _SUMMARIZE_TASKS.add(task)
    task.add_done_callback(_SUMMARIZE_TASKS.discard)
    return {"ok": True, "task_id": task_id}
|