Initial upload for secondary development
This commit is contained in:
138
chatlog_fastAPI/routers/groups.py
Normal file
138
chatlog_fastAPI/routers/groups.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import aiosqlite, json
|
||||
from database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/groups", tags=["groups"])
|
||||
|
||||
class GroupCreate(BaseModel):
|
||||
talker: str
|
||||
name: Optional[str] = ""
|
||||
poll_interval: int = 300
|
||||
|
||||
class GroupPatch(BaseModel):
|
||||
analysis_prompt: Optional[str] = None
|
||||
|
||||
class InitParams(BaseModel):
|
||||
start_time: int # unix 秒
|
||||
end_time: int # unix 秒
|
||||
|
||||
@router.post("")
|
||||
async def create_group(body: GroupCreate, db: aiosqlite.Connection = Depends(get_db)):
|
||||
try:
|
||||
await db.execute(
|
||||
"INSERT INTO groups (talker, name, poll_interval) VALUES (?, ?, ?)",
|
||||
(body.talker, body.name, body.poll_interval)
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM groups WHERE talker=?", (body.talker,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
return dict(row)
|
||||
except aiosqlite.IntegrityError:
|
||||
raise HTTPException(409, "talker already exists")
|
||||
|
||||
@router.get("")
|
||||
async def list_groups(db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT * FROM groups") as cur:
|
||||
rows = await cur.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
@router.patch("/{group_id}")
|
||||
async def patch_group(group_id: int, body: GroupPatch, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "group not found")
|
||||
updates = body.model_dump(exclude_none=True)
|
||||
if "analysis_prompt" in updates:
|
||||
await db.execute(
|
||||
"UPDATE groups SET analysis_prompt=? WHERE id=?",
|
||||
(updates["analysis_prompt"], group_id),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
return dict(await cur.fetchone())
|
||||
|
||||
@router.post("/{group_id}/init")
|
||||
async def trigger_init(
|
||||
group_id: int,
|
||||
body: InitParams,
|
||||
db: aiosqlite.Connection = Depends(get_db),
|
||||
):
|
||||
"""对指定时间区间内全部消息做一次性 AI 分类。串行执行。"""
|
||||
async with db.execute("SELECT * FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
group = await cur.fetchone()
|
||||
if not group:
|
||||
raise HTTPException(404, "group not found")
|
||||
|
||||
from services.topic_engine import get_classifying_group
|
||||
busy = get_classifying_group()
|
||||
if busy is not None:
|
||||
raise HTTPException(409, f"已有群正在分析(group_id={busy}),请等待完成后再试")
|
||||
|
||||
if body.end_time <= body.start_time:
|
||||
raise HTTPException(400, "end_time 必须大于 start_time")
|
||||
|
||||
await db.execute(
|
||||
"INSERT INTO ai_tasks (group_id, type, status, progress) VALUES (?, 'classify_window', 'running', ?)",
|
||||
(group_id, json.dumps({"processed": 0, "total": 0})),
|
||||
)
|
||||
await db.commit()
|
||||
async with db.execute("SELECT last_insert_rowid()") as cur:
|
||||
task = await cur.fetchone()
|
||||
task_id = task[0]
|
||||
|
||||
from services.topic_engine import run_classify_window
|
||||
import asyncio
|
||||
asyncio.create_task(
|
||||
run_classify_window(group_id, task_id, dict(group), body.start_time, body.end_time)
|
||||
)
|
||||
return {"task_id": task_id}
|
||||
|
||||
@router.get("/{group_id}/task")
|
||||
async def get_task(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute(
|
||||
"SELECT * FROM ai_tasks WHERE group_id=? ORDER BY id DESC LIMIT 1", (group_id,)
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "no task found")
|
||||
return dict(row)
|
||||
|
||||
@router.delete("/{group_id}")
|
||||
async def delete_group(group_id: int, db: aiosqlite.Connection = Depends(get_db)):
|
||||
async with db.execute("SELECT id FROM groups WHERE id=?", (group_id,)) as cur:
|
||||
row = await cur.fetchone()
|
||||
if not row:
|
||||
raise HTTPException(404, "group not found")
|
||||
# 级联删除关联数据(FTS → 知识库报告 → 话题消息 → 话题 → 任务 → 群组)
|
||||
# 注意:必须先清 knowledge_fts/knowledge_docs,否则 SQLite 复用 topic_id 时
|
||||
# 残留报告会"接"到新建话题上,造成跨群串报告。
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM knowledge_fts WHERE doc_id IN (
|
||||
SELECT id FROM knowledge_docs WHERE topic_id IN (
|
||||
SELECT id FROM topics WHERE group_id=?
|
||||
)
|
||||
)
|
||||
""",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM knowledge_docs WHERE topic_id IN (
|
||||
SELECT id FROM topics WHERE group_id=?
|
||||
)
|
||||
""",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute(
|
||||
"DELETE FROM topic_messages WHERE topic_id IN (SELECT id FROM topics WHERE group_id=?)",
|
||||
(group_id,)
|
||||
)
|
||||
await db.execute("DELETE FROM topics WHERE group_id=?", (group_id,))
|
||||
await db.execute("DELETE FROM ai_tasks WHERE group_id=?", (group_id,))
|
||||
await db.execute("DELETE FROM groups WHERE id=?", (group_id,))
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
Reference in New Issue
Block a user