152 lines
4.8 KiB
Python
152 lines
4.8 KiB
Python
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
from contextlib import asynccontextmanager
|
|
import asyncio
|
|
import logging
|
|
from pathlib import Path
|
|
import httpx
|
|
from database import get_active_db_path, get_current_wxid, init_db, reset_wxid_cache, update_db_path
|
|
from scheduler import start_scheduler
|
|
from config import settings
|
|
from routers import search, groups, topics, knowledge, ai, sse, files, chatlog_proxy
|
|
from routers import settings as settings_router
|
|
from services.chatlog_context import get_chatlog_context, update_chatlog_context
|
|
from services.media_resolver import diagnose_media
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
# Request body for POST /api/system/chatlog-context.
# Every field is optional: strings default to "" except `platform`
# ("windows"), and `version` defaults to 4.
class ChatlogContextRequest(BaseModel):
    # Account and directory settings.
    account: str = ""
    workDir: str = ""
    dataDir: str = ""

    # Target platform and version number.
    platform: str = "windows"
    version: int = 4

    # Chatlog executable and its version string.
    chatlogExe: str = ""
    chatlogVersion: str = ""
|
|
|
|
async def _account_watch_loop(interval_seconds: float = 60.0):
    """Periodically re-check the current WeChat account and switch databases.

    Runs forever: sleeps ``interval_seconds``, then awaits
    ``update_db_path()`` so that an account switch is followed by a switch
    of the active database. The loop only ends when the task is cancelled.

    Args:
        interval_seconds: Delay between two checks. Defaults to 60 seconds,
            the interval that was previously hard-coded.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            await update_db_path()
        except Exception as e:
            # Best-effort watcher: a failed check must not kill the loop.
            # Cancellation is a BaseException and therefore still propagates.
            log.warning("[account_watch] update_db_path error: %s", e)
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialise on startup, clean up on shutdown.

    Startup awaits ``init_db()`` and ``start_scheduler()``, then launches the
    background account-watch task. On exit the task is cancelled and awaited.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    await init_db()
    await start_scheduler()
    # Start the background account-watch task.
    task = asyncio.create_task(_account_watch_loop())
    try:
        yield
    finally:
        # Runs even when an exception is thrown into the context manager,
        # so the watcher task is never left dangling.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
|
|
|
|
# ASGI application; startup and shutdown work is handled by `lifespan`.
app = FastAPI(lifespan=lifespan)
|
|
|
|
@app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, exc: RuntimeError):
    """Convert an uncaught ``RuntimeError`` into a JSON 500 response.

    Args:
        request: The request whose handling raised the error.
        exc: The raised exception.

    Returns:
        A ``JSONResponse`` with status 500 and body ``{"detail": str(exc)}``.
    """
    # The error is turned into a normal response here, so log the traceback
    # explicitly; otherwise it would be lost.
    log.error(
        "RuntimeError while handling %s %s",
        request.method,
        request.url.path,
        exc_info=exc,
    )
    return JSONResponse(status_code=500, content={"detail": str(exc)})
|
|
|
|
# CORS: origins come from settings; all methods and all headers are allowed.
app.add_middleware(
    CORSMiddleware, allow_origins=settings.cors_origins, allow_methods=["*"], allow_headers=["*"]
)
|
|
|
|
# Register every feature router on the application, in this fixed order.
for _router_module in (
    search,
    groups,
    topics,
    knowledge,
    ai,
    sse,
    files,
    settings_router,
    chatlog_proxy,
):
    app.include_router(_router_module.router)
del _router_module
|
|
|
|
|
|
# Health probe. Always answers {"ok": True, ...}; the chatlog backend status is
# reported separately through "chatlog_ok" / "chatlog_error".
@app.get("/health")
async def health():
    chatlog_ok = False
    chatlog_error = ""
    try:
        # Short timeout, and trust_env=False so proxy settings from the
        # environment are ignored for this probe.
        async with httpx.AsyncClient(timeout=3.0, trust_env=False) as client:
            resp = await client.get(
                f"{settings.chatlog_base_url}/api/v1/session",
                params={"limit": 1, "format": "json"},
            )
    except Exception as e:
        # Some exceptions (e.g. timeouts) have an empty message; fall back to
        # the class name so a failure never yields an empty error string.
        chatlog_error = str(e) or type(e).__name__
    else:
        chatlog_ok = resp.status_code == 200
        if not chatlog_ok:
            chatlog_error = f"HTTP {resp.status_code}"

    # The account id is only looked up when the chatlog backend answered OK.
    wxid = await get_current_wxid() if chatlog_ok else "default"
    return {
        "ok": True,
        "chatlog_ok": chatlog_ok,
        "chatlog_error": chatlog_error,
        "wxid": wxid,
        "db_path": get_active_db_path(),
        "data_dir": settings.data_dir,
    }
|
|
|
|
|
|
# Drop the cached account id, force a database-path update, then report the
# resulting account id and database path.
@app.post("/api/system/refresh-account")
async def refresh_account():
    reset_wxid_cache()
    active_db = await update_db_path(force=True)
    account_id = await get_current_wxid()
    return dict(ok=True, wxid=account_id, db_path=active_db)
|
|
|
|
|
|
# Pass the posted context to update_chatlog_context() and return its result.
@app.post("/api/system/chatlog-context")
async def set_chatlog_context(body: ChatlogContextRequest):
    payload = body.model_dump()
    updated = update_chatlog_context(payload)
    return {"ok": True, "context": updated}
|
|
|
|
|
|
# Return whatever get_chatlog_context() currently reports.
@app.get("/api/system/chatlog-context")
async def read_chatlog_context():
    current = get_chatlog_context()
    return {"ok": True, "context": current}
|
|
|
|
|
|
# Forward the `kind` and `key` query parameters to diagnose_media() and
# return its result unchanged.
@app.get("/api/system/media-diagnostics")
async def media_diagnostics(kind: str = "voice", key: str = ""):
    report = await diagnose_media(kind, key)
    return report
|
|
|
|
|
|
static_dir = Path(settings.static_dir)


def _resolve_static_file(root: Path, relative: str):
    """Return the existing file under ``root`` named by ``relative``, else None.

    ``relative`` comes from the request URL and is untrusted: it may contain
    ``..`` segments or be an absolute path. The candidate is resolved and must
    stay inside the resolved ``root``; anything else yields ``None``.
    Because paths are resolved, a symlink pointing outside ``root`` is
    rejected as well.

    Args:
        root: Directory that served files must live in.
        relative: Request path, relative to ``root``.

    Returns:
        The resolved ``Path`` of a regular file inside ``root``, or ``None``.
    """
    try:
        root_resolved = root.resolve()
        candidate = (root / relative).resolve()
        # Raises ValueError when candidate is not located under root.
        candidate.relative_to(root_resolved)
    except (OSError, RuntimeError, ValueError):
        # Also covers unresolvable paths (embedded NUL byte, symlink loop).
        return None
    return candidate if candidate.is_file() else None


def _register_static_file(name: str) -> None:
    """Serve ``static_dir / name`` at ``/<name>``.

    The file name is bound here, in a closure, rather than as a handler
    parameter with a default: FastAPI would expose such a parameter as a
    query parameter and let clients choose the file to read.
    """
    file_path = static_dir / name

    @app.get(f"/{name}", include_in_schema=False)
    async def _serve_static_file():
        return FileResponse(file_path)


# The SPA routes are only registered when the static directory exists.
if static_dir.exists():
    assets_dir = static_dir / "assets"
    if assets_dir.exists():
        app.mount("/assets", StaticFiles(directory=str(assets_dir)), name="assets")

    for static_name in ("favicon.svg", "icons.svg"):
        static_file = static_dir / static_name
        if static_file.exists():
            _register_static_file(static_name)

    @app.get("/", include_in_schema=False)
    async def spa_index():
        """Serve the SPA entry page."""
        return FileResponse(static_dir / "index.html")

    @app.get("/{full_path:path}", include_in_schema=False)
    async def spa_fallback(full_path: str):
        """Serve a file from the static directory, else the SPA entry page.

        Paths that do not name a regular file inside ``static_dir``
        (including traversal attempts) fall back to ``index.html``.
        """
        resolved = _resolve_static_file(static_dir, full_path)
        if resolved is not None:
            return FileResponse(resolved)
        return FileResponse(static_dir / "index.html")
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Launch the server programmatically so the app also runs correctly when
    # packaged with PyInstaller.
    uvicorn.run(app, host="127.0.0.1", port=8000, reload=False)
|