139 lines
5.0 KiB
Python
139 lines
5.0 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
||
from pydantic import BaseModel
|
||
from typing import Optional
|
||
import aiosqlite, json
|
||
from database import get_db
|
||
|
||
# Every endpoint below is mounted under /api/groups and listed under the
# "groups" tag in the generated OpenAPI docs.
router = APIRouter(tags=["groups"], prefix="/api/groups")


class GroupCreate(BaseModel):
    """Request body for ``POST /api/groups``.

    The three fields are written as-is into the ``groups`` table by
    ``create_group``.
    """

    # Chat identifier to track; create_group reports an integrity error on
    # insert as 409 "talker already exists".
    talker: str
    # Optional display name; an explicit null is accepted and passed through.
    name: Optional[str] = ""
    # Polling interval — unit not stated here (presumably seconds; TODO confirm).
    poll_interval: int = 300


class GroupPatch(BaseModel):
    """Partial-update body for ``PATCH /api/groups/{group_id}``.

    A field left as ``None`` means "leave unchanged" in ``patch_group``.
    """

    # Replacement analysis prompt for the group.
    analysis_prompt: Optional[str] = None


class InitParams(BaseModel):
    """Time window for ``POST /api/groups/{group_id}/init``.

    ``trigger_init`` rejects the request when ``end_time <= start_time``.
    """

    start_time: int  # window start, Unix timestamp in seconds
    end_time: int  # window end, Unix timestamp in seconds


@router.post("")
async def create_group(body: GroupCreate, db: aiosqlite.Connection = Depends(get_db)):
    """Insert a new group and return the stored row as a dict.

    Raises:
        HTTPException(409): the INSERT or its commit raised
            ``aiosqlite.IntegrityError``; reported to the client as a
            duplicate ``talker``.
    """
    # Keep the try narrow: only the write path can raise IntegrityError.
    try:
        await db.execute(
            "INSERT INTO groups (talker, name, poll_interval) VALUES (?, ?, ?)",
            (body.talker, body.name, body.poll_interval),
        )
        await db.commit()
    except aiosqlite.IntegrityError as exc:
        raise HTTPException(409, "talker already exists") from exc

    # Read the row back to include database-assigned columns such as id.
    # Looking it up by talker assumes talker is unique (the 409 path implies it).
    async with db.execute("SELECT * FROM groups WHERE talker=?", (body.talker,)) as cur:
        row = await cur.fetchone()
    return dict(row)


@router.get("")
async def list_groups(db: aiosqlite.Connection = Depends(get_db)):
    """Return every row of the ``groups`` table as a list of dicts.

    No ORDER BY is applied, so rows come back in whatever order SQLite yields.
    """
    async with db.execute("SELECT * FROM groups") as cursor:
        return list(map(dict, await cursor.fetchall()))


@router.patch("/{group_id}")
async def patch_group(group_id: int, body: GroupPatch, db: aiosqlite.Connection = Depends(get_db)):
    """Apply a partial update to a group and return the resulting row.

    Only ``analysis_prompt`` is patchable. When it is omitted or null nothing
    is written, so a null value cannot be used to clear the prompt.

    Raises:
        HTTPException(404): no group with ``group_id`` exists.
    """
    async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
        found = await cur.fetchone()
    if found is None:
        raise HTTPException(404, "group not found")

    new_prompt = body.analysis_prompt
    if new_prompt is not None:
        await db.execute(
            "UPDATE groups SET analysis_prompt=? WHERE id=?",
            (new_prompt, group_id),
        )
        await db.commit()

    async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
        refreshed = await cur.fetchone()
    return dict(refreshed)


# Strong references to in-flight background jobs. The asyncio event loop keeps
# only weak references to tasks, so a task nobody references can be garbage
# collected before it finishes. Each task removes itself when done.
_background_tasks = set()


@router.post("/{group_id}/init")
async def trigger_init(
    group_id: int,
    body: InitParams,
    db: aiosqlite.Connection = Depends(get_db),
):
    """Start a one-off AI classification of all messages in a time window.

    Only one group may be classified at a time. The work runs as a background
    asyncio task, and a new ``ai_tasks`` row (type ``classify_window``,
    status ``running``) is created for it.

    Returns:
        ``{"task_id": <id of the new ai_tasks row>}``.

    Raises:
        HTTPException(400): ``end_time`` is not after ``start_time``.
        HTTPException(404): no group with ``group_id`` exists.
        HTTPException(409): another group is currently being classified.
    """
    import asyncio

    # Import the service up front so an import failure cannot leave behind an
    # ai_tasks row stuck in 'running'.
    from services.topic_engine import get_classifying_group, run_classify_window

    # Validate the window before touching the database.
    if body.end_time <= body.start_time:
        raise HTTPException(400, "end_time 必须大于 start_time")

    async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
        group = await cur.fetchone()
    if not group:
        raise HTTPException(404, "group not found")

    busy = get_classifying_group()
    if busy is not None:
        raise HTTPException(409, f"已有群正在分析(group_id={busy}),请等待完成后再试")
    # NOTE(review): this check and the task launch below are separated by awaits,
    # so two concurrent requests could both pass it — confirm whether
    # run_classify_window guards against that itself.

    # Take the id from the INSERT's own cursor instead of a separate
    # last_insert_rowid() query issued after commit.
    async with db.execute(
        "INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'classify_window', 'running', ?)",
        (group_id, json.dumps({"processed": 0, "total": 0})),
    ) as cur:
        task_id = cur.lastrowid
    await db.commit()

    job = asyncio.create_task(
        run_classify_window(group_id, task_id, dict(group), body.start_time, body.end_time)
    )
    _background_tasks.add(job)
    job.add_done_callback(_background_tasks.discard)

    return {"task_id": task_id}


@router.get("/{group_id}/task")
async def get_task(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
    """Return the newest ``ai_tasks`` row (highest id) for a group.

    Raises:
        HTTPException(404): the group has no task rows. Whether the group
            itself exists is not checked.
    """
    latest_sql = "SELECT * FROM ai_tasks WHERE group_id=? ORDER BY id DESC LIMIT 1"
    async with db.execute(latest_sql, (group_id,)) as cursor:
        latest = await cursor.fetchone()
    if latest is None:
        raise HTTPException(404, "no task found")
    return dict(latest)


@router.delete("/{group_id}")
async def delete_group(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
    """Delete a group together with all data linked to it.

    Deletion order: FTS entries -> knowledge-base reports -> topic messages ->
    topics -> AI tasks -> the group row, followed by a single commit.

    Raises:
        HTTPException(404): no group with ``group_id`` exists.
    """
    async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
        if await cur.fetchone() is None:
            raise HTTPException(404, "group not found")

    # NOTE(review): a classification task started by trigger_init for this group
    # is not stopped here — confirm run_classify_window tolerates the group
    # disappearing underneath it.

    topic_ids = "SELECT id FROM topics WHERE group_id=?"
    cascade = (
        # knowledge_fts / knowledge_docs must be cleared first: otherwise, when
        # SQLite reuses a topic_id, leftover reports get attached to the newly
        # created topic and reports leak across groups.
        "DELETE FROM knowledge_fts WHERE doc_id IN "
        f"(SELECT id FROM knowledge_docs WHERE topic_id IN ({topic_ids}))",
        f"DELETE FROM knowledge_docs WHERE topic_id IN ({topic_ids})",
        f"DELETE FROM topic_messages WHERE topic_id IN ({topic_ids})",
        "DELETE FROM topics WHERE group_id=?",
        "DELETE FROM ai_tasks WHERE group_id=?",
        "DELETE FROM groups WHERE id=?",
    )
    # Every statement takes exactly one parameter: the group id.
    for statement in cascade:
        await db.execute(statement, (group_id,))
    await db.commit()

    return {"ok": True}