from __future__ import annotations import json import os import time import unittest from unittest.mock import patch from fastapi.testclient import TestClient os.environ["AGENT_CLASSIFIER_BACKEND"] = "mock" os.environ["AGENT_CLASSIFIER_WARMUP_ENABLED"] = "false" from app.main import app from app.schemas.chat import ChatResponse def _fake_response() -> ChatResponse: return ChatResponse( session_id="sess_stream_1", reply_type="workflow_result", reply_text="好,空调已经打开了。", intent="cabin_ac_on", status="completed", trace_id="trace_stream_1", ) class ChatStreamTests(unittest.TestCase): def test_chat_stream_returns_final_only_when_fast(self) -> None: client = TestClient(app) with patch("app.main.agent_service.handle_chat", return_value=_fake_response()): response = client.post( "/api/v1/agent/chat-stream", json={ "session_id": "sess_stream_1", "user_id": "user_stream_1", "channel": "test", "input_text": "打开车窗", "input_type": "text", }, ) self.assertEqual(response.status_code, 200) lines = [line.strip() for line in response.text.splitlines() if line.strip()] self.assertEqual(len(lines), 1) final_event = json.loads(lines[0]) self.assertEqual(final_event.get("type"), "final") def test_chat_stream_returns_ack_then_final_when_slow_request(self) -> None: client = TestClient(app) def _slow_handle_chat(_request): time.sleep(1.2) return _fake_response() with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat): response = client.post( "/api/v1/agent/chat-stream", json={ "session_id": "sess_stream_1", "user_id": "user_stream_1", "channel": "test", "input_text": "打开车窗", "input_type": "text", }, ) self.assertEqual(response.status_code, 200) lines = [line.strip() for line in response.text.splitlines() if line.strip()] self.assertGreaterEqual(len(lines), 2) ack_event = json.loads(lines[0]) final_event = json.loads(lines[-1]) self.assertEqual(ack_event.get("type"), "ack") self.assertEqual(final_event.get("type"), "final") self.assertIn("data", final_event) self.assertIn("reply_text", final_event["data"]) def test_chat_stream_returns_ack_then_final_when_slow_social_request(self) -> None: client = TestClient(app) def _slow_handle_chat(_request): time.sleep(1.2) return _fake_response() with patch("app.main.agent_service.handle_chat", side_effect=_slow_handle_chat): response = client.post( "/api/v1/agent/chat-stream", json={ "session_id": "sess_stream_1", "user_id": "user_stream_1", "channel": "test", "input_text": "今天天气如何", "input_type": "text", }, ) self.assertEqual(response.status_code, 200) lines = [line.strip() for line in response.text.splitlines() if line.strip()] self.assertGreaterEqual(len(lines), 2) ack_event = json.loads(lines[0]) final_event = json.loads(lines[-1]) self.assertEqual(ack_event.get("type"), "ack") self.assertEqual(final_event.get("type"), "final") if __name__ == "__main__": unittest.main()