Files
yuanzhipeng 384a1fbcb2 feat(create_mcp): 添加MCP工具输出Schema定义
添加了固定的输出Schema定义,包含code、message和data字段,
用于规范MCP工具的返回格式,提高API响应的一致性。

- 定义了标准的输出Schema结构
- 包含响应状态码、消息和数据字段
- code和message为必需字段
2026-01-04 10:14:23 +08:00

359 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from pathlib import Path
from typing import Any
import asyncio
import logging
# 支持直接运行和模块导入两种方式
try:
from .utils import load_json, generate_tool_name, generate_input_schema
from .utils import get_skill_by_id, DataSourceAPIClient, process_skill_response, test_sql_with_schema
from .utils import get_database_id, get_skill_id, get_env_config
from .utils.logger_config import logger_config
except ImportError:
from utils import load_json, generate_tool_name, generate_input_schema
from utils import get_skill_by_id, DataSourceAPIClient, process_skill_response, test_sql_with_schema
from utils import get_database_id, get_skill_id, get_env_config
from utils.logger_config import logger_config
from mcp.server.models import InitializationOptions
from mcp.server import NotificationOptions, Server
import mcp.types as types
# 初始化 MCP 专用日志器
mcp_logger = logger_config.setup_mcp_logging()
# ========== 数据源配置 ==========
# 数据源类型常量
DATA_SOURCE_API = "api" # 仅使用API数据
DATA_SOURCE_LOCAL = "local" # 仅使用本地JSON数据
DATA_SOURCE_BOTH = "both" # 合并本地和API数据
# 默认数据源(可修改)
DEFAULT_DATA_SOURCE = DATA_SOURCE_LOCAL
# ================================
def get_queries():
"""
获取业务查询配置
Returns:
list: 包含所有业务查询配置的列表
"""
try:
# 获取当前文件所在目录
current_dir = Path(__file__).parent
# 构建 businessQueries.json 的路径
json_path = current_dir / "businessQueries.json"
mcp_logger.debug(f"正在读取业务查询配置文件: {json_path}")
# 使用 load_json 方法读取 JSON 文件
queries = load_json(json_path)
mcp_logger.info(f"成功加载 {len(queries)} 个业务查询配置")
return queries
except Exception as e:
mcp_logger.error(f"加载业务查询配置失败: {e}", exc_info=True)
raise
def generate_tool_schema_from_query(query: dict) -> types.Tool:
"""
根据查询配置生成 MCP 工具模式
Args:
query: 单个查询配置字典
Returns:
types.Tool: MCP 工具对象
"""
try:
# 获取参数定义并生成 inputSchema
parameters = query.get('parameters', {})
input_schema = generate_input_schema(parameters)
# 生成工具名称(格式: tool_拼音_id
# tool_name = generate_tool_name(query['businessName'], query['id'])
tool_name = query['businessName']
# 构建工具描述,包含业务名称和业务描述
description = f"{query['businessName']}: {query['businessDescription']}"
mcp_logger.debug(f"生成工具模式: {tool_name} - {query['businessName']}")
return types.Tool(
name=tool_name,
description=description,
inputSchema=input_schema
)
except Exception as e:
mcp_logger.error(f"生成工具模式失败: {query.get('id', 'unknown')}, 错误: {e}", exc_info=True)
raise
# 创建 MCP 服务器实例
server = Server("lzwcai-mcp-sqlexecutor")
# 缓存查询配置,避免重复加载
_queries_cache = None
async def get_queries_cache(source: str = None):
"""
获取或初始化查询配置缓存
Args:
source: 数据源类型(默认使用 DEFAULT_DATA_SOURCE
- "api": 仅使用API数据
- "local": 仅使用本地JSON数据
- "both": 合并本地和API数据
Returns:
查询配置列表
"""
global _queries_cache
if _queries_cache is None:
source = source or DEFAULT_DATA_SOURCE
mcp_logger.info(f"初始化查询配置(数据源: {source}...")
if source == DATA_SOURCE_LOCAL:
_queries_cache = get_queries()
mcp_logger.info(f"本地配置: {len(_queries_cache)}")
elif source == DATA_SOURCE_API:
try:
_queries_cache = await call_third_party_api()
mcp_logger.info(f"API配置: {len(_queries_cache)}")
mcp_logger.info(f"API配置数组: {_queries_cache}")
except Exception as e:
mcp_logger.warning(f"API获取失败降级使用本地配置: {e}")
_queries_cache = get_queries()
else: # DATA_SOURCE_BOTH
local = get_queries()
try:
api = await call_third_party_api()
except Exception as e:
mcp_logger.warning(f"API获取失败: {e}")
api = []
_queries_cache = local + api
mcp_logger.info(f"配置总数: {len(_queries_cache)} 条(本地{len(local)}+API{len(api)}")
return _queries_cache
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
"""
列出所有动态生成的 MCP 工具
Returns:
list[types.Tool]: 所有可用的工具列表
"""
try:
mcp_logger.info("收到列出工具请求")
queries = await get_queries_cache()
tools = []
for query in queries:
tool = generate_tool_schema_from_query(query)
tools.append(tool)
mcp_logger.info(f"成功生成 {len(tools)} 个 MCP 工具")
mcp_logger.debug(f"工具列表: {[tool.name for tool in tools]}")
return tools
except Exception as e:
mcp_logger.error(f"列出工具失败: {e}", exc_info=True)
raise
@server.call_tool()
async def handle_call_tool(
name: str,
arguments: dict[str, Any] | None
) -> list[types.TextContent]:
"""
处理工具调用请求
Args:
name: 工具名称
arguments: 工具参数
Returns:
list[types.TextContent]: 工具执行结果(返回参数和对应的接口配置)
"""
try:
mcp_logger.info(f"收到工具调用请求: {name}")
mcp_logger.debug(f"工具参数: {arguments}")
# 获取查询配置缓存
queries = await get_queries_cache()
# 根据工具名称查找对应的 item接口配置
tool_item = None
for query in queries:
# tool_name = generate_tool_name(query['businessName'], query['id'])
tool_name = query['businessName']
if tool_name == name:
tool_item = query
break
# 构建返回结果
import json
if tool_item:
request_data = {
"datasourceId": tool_item.get("datasourceId"),
"businessName": tool_item.get("businessName"),
"businessDescription": tool_item.get("businessDescription"),
"sqlTemplate": tool_item.get("sqlTemplate"),
"parameters": tool_item.get("parameters"),
"testParams": arguments or {}
}
# 如果 arguments 中有 targetDatabaseName 且有值,添加到 request_data
if arguments and arguments.get("targetDatabaseName"):
request_data["targetDatabaseName"] = arguments["targetDatabaseName"]
mcp_logger.debug(f"添加目标数据库名称: {arguments['targetDatabaseName']}")
# 调用测试SQL API
try:
mcp_logger.info("正在调用测试SQL API...")
api_response = test_sql_with_schema(request_data)
mcp_logger.info("测试SQL API调用成功")
# 只返回 API 响应结果
result_text = json.dumps(api_response, ensure_ascii=False, indent=2)
except Exception as e:
error_msg = f"调用测试SQL API失败: {str(e)}"
mcp_logger.error(error_msg, exc_info=True)
result_text = json.dumps({"error": error_msg}, ensure_ascii=False, indent=2)
else:
error_msg = f"未找到工具 {name} 对应的配置"
result_text = json.dumps({"error": error_msg}, ensure_ascii=False, indent=2)
mcp_logger.debug(f"工具调用结果: {result_text}")
return [
types.TextContent(
type="text",
text=result_text
)
]
except Exception as e:
error_msg = f"工具调用失败: {name}, 错误: {e}"
mcp_logger.error(error_msg, exc_info=True)
return [
types.TextContent(
type="text",
text=f"错误: {error_msg}"
)
]
async def call_third_party_api(skill_id: str = None) -> list:
"""
调用第三方API获取技能信息并返回处理后的数据
Args:
skill_id: 技能ID默认从环境变量 SKILL_ID 读取,如果未设置则使用 1981000305474482178
Returns:
处理后的查询配置列表businessQueries格式
Example:
queries = await call_third_party_api()
# 返回: [{"id": "...", "businessName": "...", ...}, ...]
"""
try:
# 如果没有传入 skill_id则从环境变量读取
if skill_id is None:
skill_id = get_skill_id()
mcp_logger.info(f"调用第三方APIskill_id: {skill_id}")
# 获取原始数据
raw_result = get_skill_by_id(skill_id)
mcp_logger.info(f"成功{raw_result}")
# 处理并返回
processed_queries = process_skill_response(raw_result)
mcp_logger.info(f"成功获取并处理 {len(processed_queries)} 条数据")
return processed_queries
except Exception as e:
mcp_logger.error(f"API调用失败: {e}", exc_info=True)
raise
async def async_main():
"""MCP 服务器异步主函数"""
try:
mcp_logger.info("=" * 60)
mcp_logger.info("正在启动 MCP 服务器: lzwcai-mcp-sqlexecutor")
mcp_logger.info("版本: 0.1.0")
mcp_logger.info("=" * 60)
# 输出环境配置信息
env_config = get_env_config()
mcp_logger.info(f"环境配置 - Database ID: {env_config['database_id']}")
mcp_logger.info(f"环境配置 - Skill ID: {env_config['skill_id']}")
mcp_logger.info(f"环境配置 - Backend Base URL: {env_config['backend_base_url']}")
mcp_logger.info("=" * 60)
from mcp.server.stdio import stdio_server
async with stdio_server() as (read_stream, write_stream):
mcp_logger.info("MCP 服务器已启动,等待客户端连接...")
await server.run(
read_stream,
write_stream,
InitializationOptions(
server_name="lzwcai-mcp-sqlexecutor",
server_version="0.1.0",
capabilities=server.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
mcp_logger.info("MCP 服务器已关闭")
except Exception as e:
mcp_logger.error(f"MCP 服务器运行失败: {e}", exc_info=True)
raise
def main():
"""入口点函数(用于 console_scripts"""
try:
# 初始化系统日志
# MCP协议使用stdio通信必须禁用控制台输出以避免干扰JSON-RPC通信
logger_config.setup_logging(
app_name="lzwcai_mcp_sqlexecutor",
log_level=logging.INFO,
console_output=False # 禁用控制台输出
)
mcp_logger.info("开始运行 MCP SQL Executor 服务器")
asyncio.run(async_main())
except KeyboardInterrupt:
mcp_logger.info("收到中断信号,正在关闭服务器...")
except Exception as e:
mcp_logger.error(f"程序运行失败: {e}", exc_info=True)
raise
if __name__ == "__main__":
main()