Update project and configurations
This commit is contained in:
151
intelligent_cabin/app/main.py
Normal file
151
intelligent_cabin/app/main.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.bootstrap import build_agent_service_with_runtime, build_intent_registry
|
||||
from app.schemas.chat import ChatRequest, ChatResponse, FillSlotsRequest
|
||||
from app.schemas.demo import DemoRuntimeConfig, DemoRuntimeUpdateRequest
|
||||
|
||||
# Application instance; its title comes from the project settings.
app = FastAPI(title=settings.app_name)

# CORS: allow the Canvas front end to call this API cross-origin.
# In production, replace allow_origins with the real front-end domain(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# Location of the static demo page, resolved relative to this module.
demo_html_path = Path(__file__).parent.joinpath("static", "demo.html")

# Shared worker pool (8 threads) used to run chat requests in the background.
chat_stream_executor = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
# Runtime configuration reported by the demo endpoints, seeded from settings.
runtime_config = DemoRuntimeConfig(
    matcher_pipeline=settings.matcher_pipeline,
    classifier_backend=settings.classifier_backend,
    session_backend=settings.session_backend,
    slot_extractor_backend=settings.slot_extractor_backend,
    planner_backend=settings.planner_backend,
    planner_model_name=settings.planner_model_name,
)

# Agent service built from the three values that the runtime-update endpoint
# is able to change; the remaining config fields are not passed to the builder.
agent_service = build_agent_service_with_runtime(
    matcher_pipeline=runtime_config.matcher_pipeline,
    classifier_backend=runtime_config.classifier_backend,
    session_backend=runtime_config.session_backend,
)

# Registry backing the intent listing endpoint.
intent_registry = build_intent_registry()
|
||||
|
||||
|
||||
@app.get("/health")
def health() -> dict[str, str]:
    """Return a static "ok" status plus the configured application environment."""
    report = {"status": "ok", "env": settings.app_env}
    return report
|
||||
|
||||
|
||||
@app.get("/")
@app.get("/demo")
def demo() -> FileResponse:
    """Serve the static demo page; registered for both ``/`` and ``/demo``."""
    page = FileResponse(demo_html_path)
    return page
|
||||
|
||||
|
||||
@app.post("/api/v1/agent/chat", response_model=ChatResponse)
def chat(request: ChatRequest) -> ChatResponse:
    """Handle one chat request by delegating to the current agent service."""
    reply = agent_service.handle_chat(request)
    return reply
|
||||
|
||||
|
||||
def _ndjson_line(payload: dict[str, object]) -> str:
    """Serialize *payload* as a single NDJSON line without escaping non-ASCII."""
    return json.dumps(payload, ensure_ascii=False) + "\n"


def _stream_error_line(exc: Exception) -> str:
    """Build the NDJSON ``error`` event emitted when the stream cannot finish."""
    return _ndjson_line({"type": "error", "message": str(exc)})


@app.post("/api/v1/agent/chat-stream")
def chat_stream(request: ChatRequest) -> StreamingResponse:
    """Run a chat request in the worker pool and stream the outcome as NDJSON.

    The response body is a sequence of JSON lines:

    * an optional ``{"type": "ack", ...}`` line, sent when the agent has not
      produced a result within one second;
    * then exactly one terminal line, either ``{"type": "final", "data": ...}``
      with the serialized response or ``{"type": "error", "message": ...}``.

    Args:
        request: The chat request forwarded unchanged to
            ``agent_service.handle_chat``.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson``.
    """

    def stream():
        future = chat_stream_executor.submit(agent_service.handle_chat, request)
        try:
            try:
                response = future.result(timeout=1.0)
            except TimeoutError:
                # Since Python 3.11 concurrent.futures.TimeoutError *is* the
                # builtin TimeoutError, so this handler also sees a timeout
                # raised inside handle_chat itself. Only acknowledge when the
                # work is genuinely still running; otherwise fall through so
                # the real outcome (result or error) is reported directly.
                if not future.done():
                    yield _ndjson_line(
                        {
                            "type": "ack",
                            "reply_text": "好的,正在处理中,请稍等一下。",
                            "status": "processing",
                            "trace_id": uuid4().hex,
                        }
                    )
                # Block until the worker finishes; re-raises its exception.
                response = future.result()

            # Serialize inside the guarded region so a failing model_dump is
            # reported as an error event instead of breaking the stream.
            final_line = _ndjson_line(
                {
                    "type": "final",
                    "data": response.model_dump(mode="json"),
                }
            )
        except Exception as exc:  # pragma: no cover - stream error fallback
            yield _stream_error_line(exc)
            return

        yield final_line

    return StreamingResponse(stream(), media_type="application/x-ndjson")
|
||||
|
||||
|
||||
@app.post("/api/v1/agent/fill-slots", response_model=ChatResponse)
def fill_slots(request: FillSlotsRequest) -> ChatResponse:
    """Handle a slot-filling request by delegating to the current agent service."""
    reply = agent_service.handle_fill_slots(request)
    return reply
|
||||
|
||||
|
||||
@app.get("/api/v1/intents")
def list_intents() -> list[dict[str, object]]:
    """Return every intent in the registry, each dumped to a plain dict."""
    dumped: list[dict[str, object]] = []
    for intent in intent_registry.list():
        dumped.append(intent.model_dump())
    return dumped
|
||||
|
||||
|
||||
@app.get("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def get_demo_runtime() -> DemoRuntimeConfig:
    """Return the module-level runtime configuration currently in effect."""
    current = runtime_config
    return current
|
||||
|
||||
|
||||
@app.post("/api/v1/demo/runtime", response_model=DemoRuntimeConfig)
def update_demo_runtime(request: DemoRuntimeUpdateRequest) -> DemoRuntimeConfig:
    """Rebuild the agent service with new runtime settings and publish them.

    Only the matcher pipeline, classifier backend and session backend can be
    changed; the slot extractor and planner fields are always taken from
    ``settings``.

    Args:
        request: Requested matcher pipeline (comma-separated stage names),
            classifier backend and session backend.

    Returns:
        The runtime configuration that is now in effect.

    Raises:
        HTTPException: 400 when the pipeline is anything other than the single
            ``classifier`` stage, or when building the agent service fails.
    """
    global agent_service, runtime_config

    matcher_stages = [
        name for stage in request.matcher_pipeline.split(",") if (name := stage.strip())
    ]
    if matcher_stages != ["classifier"]:
        raise HTTPException(status_code=400, detail="Only classifier matcher pipeline is supported in bert-first mode")

    # Use the normalized pipeline from here on, so the value that passed
    # validation is exactly the value handed to the builder and stored
    # (e.g. " classifier ," is accepted above and becomes "classifier").
    matcher_pipeline = ",".join(matcher_stages)

    try:
        next_service = build_agent_service_with_runtime(
            matcher_pipeline=matcher_pipeline,
            classifier_backend=request.classifier_backend,
            session_backend=request.session_backend,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    next_config = DemoRuntimeConfig(
        matcher_pipeline=matcher_pipeline,
        classifier_backend=request.classifier_backend,
        session_backend=request.session_backend,
        slot_extractor_backend=settings.slot_extractor_backend,
        planner_backend=settings.planner_backend,
        planner_model_name=settings.planner_model_name,
    )

    # Swap the globals only after both objects were built successfully, so a
    # failure above never leaves the service and its reported config out of sync.
    agent_service = next_service
    runtime_config = next_config
    return next_config
|
||||
Reference in New Issue
Block a user