"""FastAPI entry point for the agent service.

Exposes health, demo page, chat (plain and NDJSON-streamed), slot filling,
intent listing and demo runtime-configuration endpoints.
"""

import json
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from pathlib import Path
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse

from app.core.config import settings
from app.core.bootstrap import build_agent_service_with_runtime, build_intent_registry
from app.schemas.chat import ChatRequest, ChatResponse, FillSlotsRequest
from app.schemas.demo import DemoRuntimeConfig, DemoRuntimeUpdateRequest

app = FastAPI(title=settings.app_name)

# CORS: allow the Canvas front end to call this API cross-origin.
# In production, replace allow_origins with the real front-end domain(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

demo_html_path = Path(__file__).parent / "static" / "demo.html"

# Worker pool used by /chat-stream so the response generator can wait on the
# chat call with a timeout and emit an interim "ack" line when it is slow.
chat_stream_executor = ThreadPoolExecutor(max_workers=8)

# How long /chat-stream waits for the final answer before sending the ack line.
CHAT_STREAM_ACK_TIMEOUT_SECONDS = 1.0

# Runtime configuration currently reported by the demo endpoints; replaced as
# a whole by update_demo_runtime().
runtime_config = DemoRuntimeConfig(
    matcher_pipeline=settings.matcher_pipeline,
    classifier_backend=settings.classifier_backend,
    session_backend=settings.session_backend,
    slot_extractor_backend=settings.slot_extractor_backend,
    planner_backend=settings.planner_backend,
    planner_model_name=settings.planner_model_name,
)

# Active agent service; swapped by update_demo_runtime().
agent_service = build_agent_service_with_runtime(
    matcher_pipeline=runtime_config.matcher_pipeline,
    classifier_backend=runtime_config.classifier_backend,
    session_backend=runtime_config.session_backend,
)

intent_registry = build_intent_registry()


def _ndjson_line(payload: dict[str, object]) -> str:
    """Serialize *payload* as one newline-terminated NDJSON record."""
    return json.dumps(payload, ensure_ascii=False) + "\n"


def _error_line(exc: Exception) -> str:
    """Build the NDJSON "error" record carrying the message of *exc*."""
    return _ndjson_line({"type": "error", "message": str(exc)})


@app.get("/health")
def health() -> dict[str, str]:
    """Return a static OK status together with the configured environment."""
    return {"status": "ok", "env": settings.app_env}


@app.get("/")
@app.get("/demo")
def demo() -> FileResponse:
    """Serve the static demo HTML page."""
    return FileResponse(demo_html_path)


@app.post("/api/v1/agent/chat", response_model=ChatResponse)
def chat(request: ChatRequest) -> ChatResponse:
    """Handle one chat request with the active agent service."""
    return agent_service.handle_chat(request)


@app.post("/api/v1/agent/chat-stream")
def chat_stream(request: ChatRequest) -> StreamingResponse:
    """Handle one chat request and stream the outcome as NDJSON.

    The stream contains, in order:

    * optionally one ``{"type": "ack", ...}`` record, sent only when the
      answer is not ready within ``CHAT_STREAM_ACK_TIMEOUT_SECONDS``;
    * exactly one terminal record: ``{"type": "final", "data": ...}`` on
      success or ``{"type": "error", "message": ...}`` on failure.
    """

    def stream() -> Iterator[str]:
        future = chat_stream_executor.submit(agent_service.handle_chat, request)
        try:
            try:
                response = future.result(timeout=CHAT_STREAM_ACK_TIMEOUT_SECONDS)
            except TimeoutError:
                # On Python 3.11+ concurrent.futures.TimeoutError is the
                # builtin TimeoutError, so the handler itself may have raised
                # it. Only acknowledge when the work is genuinely still
                # running; otherwise fall through and report the real outcome.
                if not future.done():
                    yield _ndjson_line(
                        {
                            "type": "ack",
                            "reply_text": "好的,正在处理中,请稍等一下。",
                            "status": "processing",
                            "trace_id": uuid4().hex,
                        }
                    )
                # Block until the handler finishes; re-raises its exception.
                response = future.result()
            yield _ndjson_line(
                {
                    "type": "final",
                    "data": response.model_dump(mode="json"),
                }
            )
        except Exception as exc:  # pragma: no cover - stream error fallback
            # Headers are already sent once streaming starts, so failures are
            # reported in-band as an "error" record instead of an HTTP status.
            yield _error_line(exc)

    return StreamingResponse(stream(), media_type="application/x-ndjson")


@app.post("/api/v1/agent/fill-slots", response_model=ChatResponse)
def fill_slots(request: FillSlotsRequest) -> ChatResponse:
    """Handle a slot-filling request with the active agent service."""
    return agent_service.handle_fill_slots(request)


@app.get("/api/v1/intents")
def list_intents() -> list[dict[str, object]]:
    """Return every intent in the registry as a plain dict."""
    return [intent.model_dump() for intent in intent_registry.list()]


@app.get("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def get_demo_runtime() -> DemoRuntimeConfig:
    """Return the runtime configuration currently in effect."""
    return runtime_config


@app.post("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def update_demo_runtime(request: DemoRuntimeUpdateRequest) -> DemoRuntimeConfig:
    """Rebuild the agent service from *request* and make it the active one.

    Raises:
        HTTPException: 400 when the matcher pipeline is anything other than
            the single ``classifier`` stage, or when building the new agent
            service fails.
    """
    global agent_service, runtime_config

    matcher_stages = [
        stage.strip()
        for stage in request.matcher_pipeline.split(",")
        if stage.strip()
    ]
    if matcher_stages != ["classifier"]:
        raise HTTPException(
            status_code=400,
            detail="Only classifier matcher pipeline is supported in bert-first mode",
        )

    # Use the normalized form from here on so the builder and the stored
    # config see exactly what was validated (no stray spaces or commas).
    matcher_pipeline = ",".join(matcher_stages)

    try:
        next_service = build_agent_service_with_runtime(
            matcher_pipeline=matcher_pipeline,
            classifier_backend=request.classifier_backend,
            session_backend=request.session_backend,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Build the new config before touching the globals: if this raises, the
    # active service and the reported config stay consistent with each other.
    next_config = DemoRuntimeConfig(
        matcher_pipeline=matcher_pipeline,
        classifier_backend=request.classifier_backend,
        session_backend=request.session_backend,
        slot_extractor_backend=settings.slot_extractor_backend,
        planner_backend=settings.planner_backend,
        planner_model_name=settings.planner_model_name,
    )

    agent_service = next_service
    runtime_config = next_config
    return runtime_config