Files
ai-device/intelligent_cabin/app/main.py
2026-06-11 16:28:00 +08:00

152 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from pathlib import Path
from uuid import uuid4
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from app.core.config import settings
from app.core.bootstrap import build_agent_service_with_runtime, build_intent_registry
from app.schemas.chat import ChatRequest, ChatResponse, FillSlotsRequest
from app.schemas.demo import DemoRuntimeConfig, DemoRuntimeUpdateRequest
# ASGI application; its title is taken from the project settings.
app = FastAPI(title=settings.app_name)

# CORS: allow the Canvas front end to call this API cross-origin.
# In production, replace allow_origins with the real front-end domain(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Demo page served by the "/" and "/demo" routes, located next to this module.
demo_html_path = Path(__file__).parent.joinpath("static", "demo.html")

# Worker pool on which the streaming chat endpoint runs each chat call.
chat_stream_executor = ThreadPoolExecutor(max_workers=8)

# Runtime configuration exposed by the demo endpoints, seeded from settings.
runtime_config = DemoRuntimeConfig(
    matcher_pipeline=settings.matcher_pipeline,
    classifier_backend=settings.classifier_backend,
    session_backend=settings.session_backend,
    slot_extractor_backend=settings.slot_extractor_backend,
    planner_backend=settings.planner_backend,
    planner_model_name=settings.planner_model_name,
)

# Agent service that handles requests; only the matcher, classifier and
# session settings of the runtime configuration are passed to the builder.
agent_service = build_agent_service_with_runtime(
    matcher_pipeline=runtime_config.matcher_pipeline,
    classifier_backend=runtime_config.classifier_backend,
    session_backend=runtime_config.session_backend,
)

# Registry backing the intent listing endpoint.
intent_registry = build_intent_registry()
@app.get("/health")
def health() -> dict[str, str]:
    """Return a fixed ``ok`` status together with the configured environment."""
    return dict(status="ok", env=settings.app_env)
@app.get("/")
@app.get("/demo")
def demo() -> FileResponse:
    """Serve the demo HTML page; registered under both ``/`` and ``/demo``."""
    page = FileResponse(demo_html_path)
    return page
@app.post("/api/v1/agent/chat", response_model=ChatResponse)
def chat(request: ChatRequest) -> ChatResponse:
    """Pass the chat request to the current agent service and return its reply."""
    reply = agent_service.handle_chat(request)
    return reply
def _ndjson_line(payload: dict[str, object]) -> str:
    """Serialize *payload* as one NDJSON record: JSON text plus a newline."""
    return json.dumps(payload, ensure_ascii=False) + "\n"


@app.post("/api/v1/agent/chat-stream")
def chat_stream(request: ChatRequest) -> StreamingResponse:
    """Stream the result of a chat request as newline-delimited JSON.

    The chat call is submitted to ``chat_stream_executor``. When it has not
    finished within one second, an ``ack`` record is emitted first so the
    client gets immediate feedback. The stream then ends with either a
    ``final`` record holding the serialized response or an ``error`` record
    holding the exception message.

    Args:
        request: The chat request forwarded to ``agent_service.handle_chat``.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson``.
    """

    def stream():
        future = chat_stream_executor.submit(agent_service.handle_chat, request)
        try:
            try:
                response = future.result(timeout=1.0)
            except TimeoutError:
                # On Python 3.11+ concurrent.futures.TimeoutError is the
                # builtin TimeoutError, so a TimeoutError raised by the chat
                # call itself also lands here. Only acknowledge when the call
                # is genuinely still running.
                if not future.done():
                    yield _ndjson_line(
                        {
                            "type": "ack",
                            "reply_text": "好的,正在处理中,请稍等一下。",
                            "status": "processing",
                            "trace_id": uuid4().hex,
                        }
                    )
                # Block until the call completes; a failure propagates to the
                # shared error handler below.
                response = future.result()
            yield _ndjson_line(
                {
                    "type": "final",
                    "data": response.model_dump(mode="json"),
                }
            )
        except Exception as exc:  # pragma: no cover - stream error fallback
            # NOTE(review): the raw exception text is sent to the client;
            # confirm that this is acceptable outside of demo use.
            yield _ndjson_line(
                {
                    "type": "error",
                    "message": str(exc),
                }
            )

    return StreamingResponse(stream(), media_type="application/x-ndjson")
@app.post("/api/v1/agent/fill-slots", response_model=ChatResponse)
def fill_slots(request: FillSlotsRequest) -> ChatResponse:
    """Pass the slot-filling request to the current agent service."""
    reply = agent_service.handle_fill_slots(request)
    return reply
@app.get("/api/v1/intents")
def list_intents() -> list[dict[str, object]]:
    """Return every intent in the registry, each dumped to a plain dict."""
    registered = intent_registry.list()
    return [entry.model_dump() for entry in registered]
@app.get("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def get_demo_runtime() -> DemoRuntimeConfig:
    """Return the module-level runtime configuration currently in effect."""
    current = runtime_config
    return current
@app.post("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def update_demo_runtime(request: DemoRuntimeUpdateRequest) -> DemoRuntimeConfig:
    """Rebuild the agent service from the requested backends and publish it.

    Replaces the module-level ``agent_service`` and ``runtime_config``.

    Args:
        request: Requested matcher pipeline, classifier backend and session
            backend.

    Returns:
        The newly stored runtime configuration.

    Raises:
        HTTPException: With status 400 when the matcher pipeline is anything
            other than the single ``classifier`` stage, or when building the
            new agent service raises.
    """
    global agent_service, runtime_config

    # Split the comma-separated pipeline, trim each entry and drop blanks.
    trimmed = (part.strip() for part in request.matcher_pipeline.split(","))
    stages = [part for part in trimmed if part]
    if stages != ["classifier"]:
        raise HTTPException(
            status_code=400,
            detail="Only classifier matcher pipeline is supported in bert-first mode",
        )

    # Build the replacement before touching the globals, so a failed build
    # leaves the current service and configuration in place.
    try:
        rebuilt_service = build_agent_service_with_runtime(
            matcher_pipeline=request.matcher_pipeline,
            classifier_backend=request.classifier_backend,
            session_backend=request.session_backend,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    agent_service = rebuilt_service
    # Slot extractor and planner values always come from settings, not from
    # the request.
    runtime_config = DemoRuntimeConfig(
        matcher_pipeline=request.matcher_pipeline,
        classifier_backend=request.classifier_backend,
        session_backend=request.session_backend,
        slot_extractor_backend=settings.slot_extractor_backend,
        planner_backend=settings.planner_backend,
        planner_model_name=settings.planner_model_name,
    )
    return runtime_config