# File: ai-device/intelligent_cabin/archive/tests/test_chat_stream.py
# Captured 2026-06-11 16:28:00 +08:00 (110 lines, 3.7 KiB, Python).

from __future__ import annotations
import json
import os
import time
import unittest
from unittest.mock import patch
from fastapi.testclient import TestClient
os.environ["AGENT_CLASSIFIER_BACKEND"] = "mock"
os.environ["AGENT_CLASSIFIER_WARMUP_ENABLED"] = "false"
from app.main import app
from app.schemas.chat import ChatResponse
def _fake_response() -> ChatResponse:
    """Build the canned ChatResponse returned by the patched ``handle_chat``.

    Every field is a fixed literal, so each call yields an equal response for
    session ``sess_stream_1`` / trace ``trace_stream_1``.

    NOTE(review): the intent and reply text describe turning the A/C on,
    while the streaming tests send "打开车窗" ("open the window"). This is
    harmless because ``handle_chat`` is mocked, but consider aligning them.
    """
    canned_fields = {
        "session_id": "sess_stream_1",
        "reply_type": "workflow_result",
        "reply_text": "好,空调已经打开了。",
        "intent": "cabin_ac_on",
        "status": "completed",
        "trace_id": "trace_stream_1",
    }
    return ChatResponse(**canned_fields)
class ChatStreamTests(unittest.TestCase):
    """Tests for the ``/api/v1/agent/chat-stream`` endpoint.

    Every test patches ``app.main.agent_service.handle_chat``, so the
    assertions cover only the streamed envelope (``ack`` / ``final`` events
    emitted one JSON object per line), not the agent pipeline itself.
    """

    # Delay injected into the patched handler to push the endpoint onto its
    # slow path.
    # NOTE(review): presumably the endpoint emits an "ack" event once
    # handle_chat has run longer than roughly 1 s. Keep this above that
    # threshold and confirm it against the chat-stream handler in app.main.
    _SLOW_HANDLER_DELAY_S = 1.2

    @staticmethod
    def _build_payload(input_text: str) -> dict[str, str]:
        """Return a chat-stream request body for ``input_text``.

        All fields other than ``input_text`` are fixed test identifiers.
        """
        return {
            "session_id": "sess_stream_1",
            "user_id": "user_stream_1",
            "channel": "test",
            "input_text": input_text,
            "input_type": "text",
        }

    @classmethod
    def _make_slow_handler(cls):
        """Return a ``handle_chat`` stand-in that sleeps, then answers.

        It accepts any arguments, so the mock does not break if the endpoint
        passes the request by keyword rather than positionally.
        """

        def _slow_handle_chat(*_args, **_kwargs):
            time.sleep(cls._SLOW_HANDLER_DELAY_S)
            return _fake_response()

        return _slow_handle_chat

    def _post_chat_stream(self, input_text: str, **patch_kwargs) -> list[dict]:
        """POST ``input_text`` to the stream endpoint with ``handle_chat`` patched.

        Args:
            input_text: Text placed in the request's ``input_text`` field.
            **patch_kwargs: Forwarded to ``unittest.mock.patch``
                (e.g. ``return_value=`` or ``side_effect=``).

        Returns:
            The decoded events, one per non-blank line of the response body.
            Every such line must be valid JSON.
        """
        client = TestClient(app)
        with patch("app.main.agent_service.handle_chat", **patch_kwargs):
            response = client.post(
                "/api/v1/agent/chat-stream",
                json=self._build_payload(input_text),
            )
        self.assertEqual(response.status_code, 200)
        stripped = (raw.strip() for raw in response.text.splitlines())
        return [json.loads(line) for line in stripped if line]

    def _assert_ack_then_final(self, events: list[dict]) -> None:
        """Assert the slow-path shape.

        The first event must be ``ack``. The last must be ``final`` and carry
        ``data.reply_text``.
        """
        self.assertGreaterEqual(len(events), 2)
        self.assertEqual(events[0].get("type"), "ack")
        final_event = events[-1]
        self.assertEqual(final_event.get("type"), "final")
        self.assertIn("data", final_event)
        self.assertIn("reply_text", final_event["data"])

    def test_chat_stream_returns_final_only_when_fast(self) -> None:
        events = self._post_chat_stream("打开车窗", return_value=_fake_response())
        self.assertEqual(len(events), 1)
        self.assertEqual(events[0].get("type"), "final")

    def test_chat_stream_returns_ack_then_final_when_slow_request(self) -> None:
        events = self._post_chat_stream(
            "打开车窗", side_effect=self._make_slow_handler()
        )
        self._assert_ack_then_final(events)

    def test_chat_stream_returns_ack_then_final_when_slow_social_request(self) -> None:
        events = self._post_chat_stream(
            "今天天气如何", side_effect=self._make_slow_handler()
        )
        self._assert_ack_then_final(events)
if __name__ == "__main__":
    # Allow running this module directly (python test_chat_stream.py) in
    # addition to discovery by a test runner.
    unittest.main()