110 lines
3.7 KiB
Python
110 lines
3.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
# Select the mock classifier backend and disable warm-up for the whole module.
# This must happen before ``app.main`` is imported further down, which
# presumably reads these variables at import time — confirm in app config.
os.environ.update(
    {
        "AGENT_CLASSIFIER_BACKEND": "mock",
        "AGENT_CLASSIFIER_WARMUP_ENABLED": "false",
    }
)
|
|
|
|
from app.main import app
|
|
from app.schemas.chat import ChatResponse
|
|
|
|
|
|
def _fake_response() -> ChatResponse:
    """Return the canned ``ChatResponse`` used as the patched handler's result.

    The tests in this module patch ``app.main.agent_service.handle_chat`` so
    that it produces this object, giving the stream endpoint a completed
    workflow result to serialize.
    """
    canned_fields = {
        "session_id": "sess_stream_1",
        "reply_type": "workflow_result",
        # Reply text reads "OK, the air conditioner is now on."
        "reply_text": "好,空调已经打开了。",
        "intent": "cabin_ac_on",
        "status": "completed",
        "trace_id": "trace_stream_1",
    }
    return ChatResponse(**canned_fields)
|
|
|
|
|
|
class ChatStreamTests(unittest.TestCase):
    """Tests for the streaming ``/api/v1/agent/chat-stream`` endpoint.

    ``app.main.agent_service.handle_chat`` is patched in every test, so only
    the endpoint's streaming shape is checked: a single ``final`` event when
    the handler returns quickly, and an ``ack`` event followed by a ``final``
    event when it is slow.
    """

    _ENDPOINT = "/api/v1/agent/chat-stream"

    # Artificial latency injected into the patched handler to force the
    # ack path. NOTE(review): presumably chosen to exceed the endpoint's ack
    # threshold defined in app.main (not visible here) — confirm. If that
    # threshold is raised above this value, the slow tests will fail rather
    # than exercise the ack path.
    _SLOW_HANDLER_DELAY_S = 1.2

    def _post_chat_stream(self, input_text: str, *, slow: bool = False) -> list[str]:
        """POST a text request to the stream endpoint with ``handle_chat`` patched.

        Args:
            input_text: Utterance sent in the request's ``input_text`` field.
            slow: When true, the patched handler sleeps for
                ``_SLOW_HANDLER_DELAY_S`` seconds before returning
                ``_fake_response()``. Otherwise it returns immediately.

        Returns:
            The stripped, non-blank lines of the response body, in order.
        """
        if slow:
            delay = self._SLOW_HANDLER_DELAY_S

            def _slow_handle_chat(_request):
                time.sleep(delay)
                return _fake_response()

            mock_kwargs = {"side_effect": _slow_handle_chat}
        else:
            mock_kwargs = {"return_value": _fake_response()}

        client = TestClient(app)
        # The handler is fully mocked, so the canned reply does not need to
        # match the request text (e.g. "open the window" vs. the AC reply).
        with patch("app.main.agent_service.handle_chat", **mock_kwargs) as mock_handle:
            response = client.post(
                self._ENDPOINT,
                json={
                    "session_id": "sess_stream_1",
                    "user_id": "user_stream_1",
                    "channel": "test",
                    "input_text": input_text,
                    "input_type": "text",
                },
            )

        self.assertEqual(response.status_code, 200)
        # If the endpoint bypassed the patched handler, the event assertions
        # would no longer be testing the mocked slow/fast behaviour.
        self.assertTrue(mock_handle.called, "patched handle_chat was never invoked")
        return [line.strip() for line in response.text.splitlines() if line.strip()]

    def _assert_final_event(self, raw_line: str) -> None:
        """Assert that ``raw_line`` is a JSON ``final`` event carrying a reply text."""
        final_event = json.loads(raw_line)
        self.assertEqual(final_event.get("type"), "final")
        self.assertIn("data", final_event)
        self.assertIn("reply_text", final_event["data"])

    def _assert_ack_then_final(self, lines: list[str]) -> None:
        """Assert the stream opens with an ``ack`` event and ends with ``final``."""
        self.assertGreaterEqual(len(lines), 2)
        ack_event = json.loads(lines[0])
        self.assertEqual(ack_event.get("type"), "ack")
        self._assert_final_event(lines[-1])

    def test_chat_stream_returns_final_only_when_fast(self) -> None:
        lines = self._post_chat_stream("打开车窗")

        # A fast handler should produce the final result with no ack first.
        self.assertEqual(len(lines), 1)
        self._assert_final_event(lines[0])

    def test_chat_stream_returns_ack_then_final_when_slow_request(self) -> None:
        lines = self._post_chat_stream("打开车窗", slow=True)
        self._assert_ack_then_final(lines)

    def test_chat_stream_returns_ack_then_final_when_slow_social_request(self) -> None:
        lines = self._post_chat_stream("今天天气如何", slow=True)
        self._assert_ack_then_final(lines)
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly as a script, in addition to discovery
    # by a test runner; module="__main__" is unittest.main's default.
    unittest.main(module="__main__")
|