152 lines
5.3 KiB
Python
152 lines
5.3 KiB
Python
import json
|
||
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
||
from pathlib import Path
|
||
from uuid import uuid4
|
||
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse, StreamingResponse
|
||
|
||
|
||
from app.core.config import settings
|
||
from app.core.bootstrap import build_agent_service_with_runtime, build_intent_registry
|
||
from app.schemas.chat import ChatRequest, ChatResponse, FillSlotsRequest
|
||
from app.schemas.demo import DemoRuntimeConfig, DemoRuntimeUpdateRequest
|
||
|
||
app = FastAPI(title=settings.app_name)

# CORS: allow the Canvas front end to call this API cross-origin.
# In production, replace allow_origins with the actual front-end domain(s).
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
)
|
||
|
||
# Static demo page that ships in the "static" directory next to this module.
demo_html_path = Path(__file__).parent.joinpath("static", "demo.html")

# Worker pool used by the streaming chat endpoint to run the chat handler
# off the response generator; capped at 8 threads.
chat_stream_executor = ThreadPoolExecutor(8)
|
||
|
||
# Runtime configuration reported by the demo endpoints, seeded from the
# project settings when the module is imported.
runtime_config = DemoRuntimeConfig(
    planner_model_name=settings.planner_model_name,
    planner_backend=settings.planner_backend,
    slot_extractor_backend=settings.slot_extractor_backend,
    session_backend=settings.session_backend,
    classifier_backend=settings.classifier_backend,
    matcher_pipeline=settings.matcher_pipeline,
)
|
||
# Agent service built from the three runtime-switchable settings; the
# runtime update endpoint rebinds this name when the configuration changes.
agent_service = build_agent_service_with_runtime(
    session_backend=runtime_config.session_backend,
    classifier_backend=runtime_config.classifier_backend,
    matcher_pipeline=runtime_config.matcher_pipeline,
)

# Registry backing the intent listing endpoint.
intent_registry = build_intent_registry()
|
||
|
||
|
||
@app.get("/health")
def health() -> dict[str, str]:
    """Report a fixed ``ok`` status together with the configured environment name."""
    return dict(status="ok", env=settings.app_env)
|
||
|
||
|
||
@app.get("/")
@app.get("/demo")
def demo() -> FileResponse:
    """Serve the demo HTML page on both ``/`` and ``/demo``."""
    page = FileResponse(demo_html_path)
    return page
|
||
|
||
|
||
@app.post("/api/v1/agent/chat", response_model=ChatResponse)
def chat(request: ChatRequest) -> ChatResponse:
    """Pass the chat request to the current agent service and return its response."""
    reply = agent_service.handle_chat(request)
    return reply
|
||
|
||
|
||
@app.post("/api/v1/agent/chat-stream")
def chat_stream(request: ChatRequest) -> StreamingResponse:
    """Run the chat handler on the worker pool and stream the outcome as NDJSON.

    Every emitted line is one JSON object followed by a newline:

    * ``ack`` - sent only when the handler has not finished within one second;
    * ``final`` - carries the JSON-mode dump of the handler's response;
    * ``error`` - carries the text of any exception raised while waiting for
      the handler or while serializing its response.
    """

    def encode(event: dict) -> str:
        # ensure_ascii=False keeps non-ASCII text unescaped in the output line.
        return json.dumps(event, ensure_ascii=False) + "\n"

    def failure(exc: Exception) -> str:
        return encode({"type": "error", "message": str(exc)})

    def stream():
        pending = chat_stream_executor.submit(agent_service.handle_chat, request)
        try:
            result = pending.result(timeout=1.0)
        except TimeoutError:
            # Handler is still running: acknowledge first, then block until done.
            yield encode(
                {
                    "type": "ack",
                    "reply_text": "好的,正在处理中,请稍等一下。",
                    "status": "processing",
                    "trace_id": uuid4().hex,
                }
            )
            try:
                result = pending.result()
            except Exception as exc:  # pragma: no cover - stream error fallback
                yield failure(exc)
                return
        except Exception as exc:  # pragma: no cover - stream error fallback
            yield failure(exc)
            return

        try:
            yield encode({"type": "final", "data": result.model_dump(mode="json")})
        except Exception as exc:  # pragma: no cover - stream error fallback
            yield failure(exc)

    return StreamingResponse(stream(), media_type="application/x-ndjson")
|
||
|
||
|
||
@app.post("/api/v1/agent/fill-slots", response_model=ChatResponse)
def fill_slots(request: FillSlotsRequest) -> ChatResponse:
    """Pass the slot-filling request to the current agent service and return its response."""
    reply = agent_service.handle_fill_slots(request)
    return reply
|
||
|
||
|
||
@app.get("/api/v1/intents")
def list_intents() -> list[dict[str, object]]:
    """Return every intent in the registry, each dumped to a plain dict."""
    dumped: list[dict[str, object]] = []
    for entry in intent_registry.list():
        dumped.append(entry.model_dump())
    return dumped
|
||
|
||
|
||
@app.get("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def get_demo_runtime() -> DemoRuntimeConfig:
    """Return the module-level runtime configuration currently in effect."""
    current = runtime_config
    return current
|
||
|
||
|
||
@app.post("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def update_demo_runtime(request: DemoRuntimeUpdateRequest) -> DemoRuntimeConfig:
    """Rebuild the agent service from the requested runtime settings.

    The matcher pipeline is a comma-separated list of stage names. After
    trimming whitespace and dropping empty entries it must consist of the
    single stage ``classifier``. That normalized form is what gets passed to
    the service builder and stored in the returned configuration, so the
    value in use always matches the value that was validated.

    The slot extractor and planner fields are not switchable here; they are
    copied from the project settings.

    Returns:
        The runtime configuration now in effect.

    Raises:
        HTTPException: status 400 when the pipeline is anything other than
            ``classifier``, or when building the agent service fails (the
            exception text is used as the detail).
    """
    global agent_service, runtime_config

    matcher_stages = [stage.strip() for stage in request.matcher_pipeline.split(",") if stage.strip()]
    if matcher_stages != ["classifier"]:
        raise HTTPException(status_code=400, detail="Only classifier matcher pipeline is supported in bert-first mode")

    # Use the validated, normalized pipeline rather than the raw request
    # string, which may still contain stray whitespace or empty segments.
    matcher_pipeline = ",".join(matcher_stages)

    try:
        next_service = build_agent_service_with_runtime(
            matcher_pipeline=matcher_pipeline,
            classifier_backend=request.classifier_backend,
            session_backend=request.session_backend,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Build the new config before touching the globals: if this raises, the
    # live service and the reported configuration both stay as they were.
    next_config = DemoRuntimeConfig(
        matcher_pipeline=matcher_pipeline,
        classifier_backend=request.classifier_backend,
        session_backend=request.session_backend,
        slot_extractor_backend=settings.slot_extractor_backend,
        planner_backend=settings.planner_backend,
        planner_model_name=settings.planner_model_name,
    )

    agent_service, runtime_config = next_service, next_config
    return runtime_config
|