feat(lzwcai-demp-tool-server-dify-to-mcp): 初始化 Dify 集成工具模块
新增 Dify 到 MCP 的集成工具,支持通过 Dify API 将模型部署到 MCP 平台并进行推理。 该模块包含完整的服务器实现、依赖配置和命令行启动脚本。 主要功能: - 支持 Workflow 和 Completion 模式的调用 - 自动翻译工具名称为驼峰命名格式 - 提供文件上传与任务停止接口 - 兼容流式与非流式响应处理
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
||||
{}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,712 @@
|
||||
"""
|
||||
业务工具模块 - API参数处理和JSON Schema操作
|
||||
|
||||
这个模块提供了API参数到JSON Schema转换、参数验证、API配置操作等核心业务功能。
|
||||
它是连接业务API配置和MCP工具定义的桥梁。
|
||||
|
||||
主要功能:
|
||||
1. API参数列表转换为JSON Schema
|
||||
2. 参数验证和默认值处理
|
||||
3. Schema提示文本生成
|
||||
4. API配置数组操作
|
||||
5. 必填参数检查
|
||||
|
||||
核心组件:
|
||||
- 参数类型常量定义
|
||||
- JSON Schema生成器
|
||||
- 参数验证器
|
||||
- 配置操作工具
|
||||
|
||||
设计原则:
|
||||
- 类型安全:严格的类型检查和转换
|
||||
- 容错性:完善的异常处理机制
|
||||
- 可扩展性:支持新的参数类型和验证规则
|
||||
- 标准化:符合JSON Schema规范
|
||||
|
||||
作者: lzwcai
|
||||
版本: 1.0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import copy
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from ..util.logger_config import get_logger
|
||||
|
||||
# 获取日志器实例
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# ==================== 常量定义 ====================
|
||||
|
||||
class ParamType:
|
||||
"""
|
||||
API参数类型常量
|
||||
|
||||
定义业务平台API参数的所有支持类型。
|
||||
这些类型会被映射到对应的JSON Schema类型。
|
||||
"""
|
||||
STRING = "STRING" # 字符串类型
|
||||
NUMBER = "NUMBER" # 数字类型(浮点数)
|
||||
INTEGER = "INTEGER" # 整数类型
|
||||
BOOLEAN = "BOOLEAN" # 布尔类型
|
||||
ARRAY = "ARRAY" # 数组类型
|
||||
OBJECT = "OBJECT" # 对象类型
|
||||
|
||||
|
||||
class JsonSchemaType:
|
||||
"""
|
||||
JSON Schema类型常量
|
||||
|
||||
定义JSON Schema规范中的标准类型。
|
||||
用于将业务参数类型转换为标准Schema类型。
|
||||
"""
|
||||
STRING = "string" # 字符串
|
||||
NUMBER = "number" # 数字
|
||||
BOOLEAN = "boolean" # 布尔值
|
||||
ARRAY = "array" # 数组
|
||||
OBJECT = "object" # 对象
|
||||
|
||||
|
||||
class RequestType:
|
||||
"""
|
||||
请求类型常量
|
||||
|
||||
定义API参数在HTTP请求中的位置类型。
|
||||
用于将参数正确分组到请求的不同部分。
|
||||
"""
|
||||
HEADER = "header" # 请求头参数
|
||||
QUERY = "query" # 查询参数(URL参数)
|
||||
BODY = "body" # 请求体参数
|
||||
LZWCAI_CONFIG = "lzwcaiConfig" # lzwcaiConfig参数(新的用户ID存储位置)
|
||||
|
||||
|
||||
# ==================== 类型映射配置 ====================
|
||||
|
||||
# 参数类型映射表:业务参数类型 -> JSON Schema类型
|
||||
PARAM_TYPE_MAPPING = {
|
||||
ParamType.STRING: JsonSchemaType.STRING, # 字符串 -> string
|
||||
ParamType.NUMBER: JsonSchemaType.NUMBER, # 数字 -> number
|
||||
ParamType.INTEGER: JsonSchemaType.NUMBER, # 整数 -> number(JSON Schema中整数也是number)
|
||||
ParamType.BOOLEAN: JsonSchemaType.BOOLEAN, # 布尔 -> boolean
|
||||
ParamType.ARRAY: JsonSchemaType.ARRAY, # 数组 -> array
|
||||
ParamType.OBJECT: JsonSchemaType.OBJECT, # 对象 -> object
|
||||
}
|
||||
|
||||
# ==================== 默认参数配置 ====================
|
||||
|
||||
# 默认用户ID参数配置
|
||||
# 这个参数会自动添加到所有API的Schema中,用于标识当前用户
|
||||
DEFAULT_USER_ID_PARAM = {
|
||||
"paramName": "userId", # 参数名称
|
||||
"paramType": ParamType.STRING, # 参数类型:字符串
|
||||
"paramPrompts": "当前与您对话的用户信息的用户ID", # 参数描述
|
||||
"requestType": RequestType.LZWCAI_CONFIG, # 请求类型:lzwcaiConfig类型
|
||||
"required": 1, # 必填参数
|
||||
}
|
||||
|
||||
|
||||
# ==================== 核心函数 ====================
|
||||
|
||||
def generate_json_schema(api_params: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
API参数列表转JSON Schema格式
|
||||
|
||||
这是模块的核心函数,负责将业务平台的API参数列表转换为符合JSON Schema规范的对象。
|
||||
生成的Schema用于MCP工具的参数定义和验证。
|
||||
|
||||
转换流程:
|
||||
1. 验证输入参数类型
|
||||
2. 过滤掉header类型的参数(header参数单独处理)
|
||||
3. 添加默认的userId参数(存储在lzwcaiConfig中)
|
||||
4. 按请求类型分组参数(query, body, lzwcaiConfig等)
|
||||
5. 为每个分组创建Schema属性
|
||||
6. 清理空的required列表
|
||||
|
||||
Schema结构:
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "object",
|
||||
"properties": {...},
|
||||
"required": [...]
|
||||
},
|
||||
"body": {
|
||||
"type": "object",
|
||||
"properties": {...},
|
||||
"required": [...]
|
||||
},
|
||||
"lzwcaiConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"userId": {
|
||||
"type": "string",
|
||||
"description": "当前与您对话的用户信息的用户ID"
|
||||
}
|
||||
},
|
||||
"required": ["userId"]
|
||||
}
|
||||
},
|
||||
"required": ["query", "body", "lzwcaiConfig"] // 只包含有参数的分组
|
||||
}
|
||||
|
||||
参数:
|
||||
api_params: API参数对象列表,每个对象包含:
|
||||
- paramName: 参数名称
|
||||
- paramType: 参数类型(STRING, NUMBER等)
|
||||
- paramPrompts: 参数描述
|
||||
- requestType: 请求类型(header, query, body等)
|
||||
- required: 是否必填(1为必填,0为可选)
|
||||
- defaultValue: 默认值(可选)
|
||||
|
||||
返回:
|
||||
dict: 符合JSON Schema规范的对象
|
||||
|
||||
异常处理:
|
||||
TypeError: 如果api_params不是列表类型
|
||||
|
||||
设计考虑:
|
||||
- header参数被过滤掉,因为它们在HTTP请求中单独处理
|
||||
- 自动添加userId参数到lzwcaiConfig分组,确保所有API都能获取用户信息
|
||||
- 按请求类型分组,便于后续的参数处理和验证
|
||||
- 清理空的required列表,保持Schema的简洁性
|
||||
"""
|
||||
# 参数类型验证
|
||||
if not isinstance(api_params, list):
|
||||
raise TypeError("api_params must be a list")
|
||||
|
||||
logger.debug(f"生成JSON Schema,参数数量: {len(api_params)}")
|
||||
|
||||
# 创建基础Schema结构
|
||||
schema = {
|
||||
"type": JsonSchemaType.OBJECT, # 根类型为对象
|
||||
"properties": {}, # 属性定义
|
||||
"required": [] # 必填字段列表
|
||||
}
|
||||
|
||||
# 过滤参数并添加默认userId参数
|
||||
# header参数在HTTP请求中单独处理,不包含在Schema中
|
||||
filtered_params = [
|
||||
param for param in api_params
|
||||
if param.get("requestType") != RequestType.HEADER
|
||||
]
|
||||
|
||||
# 添加默认的userId参数到lzwcaiConfig分组,确保所有API都能获取用户信息
|
||||
filtered_params.append(DEFAULT_USER_ID_PARAM)
|
||||
|
||||
logger.debug(f"过滤后参数数量: {len(filtered_params)}")
|
||||
|
||||
# 按请求类型分组参数
|
||||
param_groups = _group_parameters_by_type(filtered_params)
|
||||
logger.debug(f"参数分组: {list(param_groups.keys())}")
|
||||
|
||||
# 为每个请求类型创建Schema属性
|
||||
for req_type, params in param_groups.items():
|
||||
logger.debug(f"处理参数组 {req_type},包含 {len(params)} 个参数")
|
||||
_add_request_type_to_schema(schema, req_type, params)
|
||||
|
||||
# 清理空的required列表,保持Schema简洁
|
||||
_cleanup_empty_required_lists(schema)
|
||||
|
||||
logger.debug("JSON Schema生成完成")
|
||||
return schema
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
def _group_parameters_by_type(
|
||||
params: List[Dict[str, Any]],
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
按请求类型分组参数
|
||||
|
||||
将参数列表按照requestType字段进行分组,便于后续按类型处理。
|
||||
如果参数没有指定requestType,默认归类为query类型。
|
||||
|
||||
参数:
|
||||
params: 参数字典列表
|
||||
|
||||
返回:
|
||||
Dict[str, List[Dict[str, Any]]]: 按类型分组的参数字典
|
||||
格式: {"query": [...], "body": [...], "lzwcaiConfig": [...]}
|
||||
"""
|
||||
param_groups = {}
|
||||
for param in params:
|
||||
# 获取请求类型,默认为query
|
||||
req_type = param.get("requestType", RequestType.QUERY)
|
||||
|
||||
# 记录参数名和类型,便于调试
|
||||
param_name = param.get("paramName", "未命名参数")
|
||||
logger.debug(f"参数 '{param_name}' 归类为 {req_type} 类型")
|
||||
|
||||
# 初始化分组
|
||||
if req_type not in param_groups:
|
||||
param_groups[req_type] = []
|
||||
logger.debug(f"创建新的参数分组: {req_type}")
|
||||
|
||||
# 添加参数到对应分组
|
||||
param_groups[req_type].append(param)
|
||||
|
||||
logger.debug(f"参数分组完成,共 {len(param_groups)} 个分组")
|
||||
return param_groups
|
||||
|
||||
|
||||
def _add_request_type_to_schema(
|
||||
schema: Dict[str, Any], req_type: str, params: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""
|
||||
向Schema添加请求类型分组
|
||||
|
||||
为指定的请求类型创建Schema属性,并处理该类型下的所有参数。
|
||||
每个请求类型都会成为根Schema的一个属性。
|
||||
|
||||
参数:
|
||||
schema: 根Schema对象
|
||||
req_type: 请求类型(如"query", "body", "lzwcaiConfig")
|
||||
params: 该类型下的参数列表
|
||||
"""
|
||||
# 如果该请求类型还没有在Schema中定义,创建它
|
||||
if req_type not in schema["properties"]:
|
||||
schema["properties"][req_type] = {
|
||||
"type": JsonSchemaType.OBJECT, # 每个请求类型都是对象类型
|
||||
"properties": {}, # 该类型下的参数定义
|
||||
"required": [], # 该类型下的必填参数
|
||||
}
|
||||
|
||||
# 如果该类型有参数,将其标记为根级别的必填字段
|
||||
if params:
|
||||
schema["required"].append(req_type)
|
||||
|
||||
# 处理该类型下的每个参数
|
||||
for param in params:
|
||||
_add_parameter_to_schema(schema["properties"][req_type], param)
|
||||
|
||||
|
||||
def _add_parameter_to_schema(
|
||||
type_schema: Dict[str, Any], param: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
向类型Schema添加单个参数
|
||||
|
||||
将单个参数的定义添加到指定类型的Schema中,包括参数类型、描述、
|
||||
默认值、示例等信息。
|
||||
|
||||
参数处理:
|
||||
1. 提取参数基本信息(名称、类型、描述)
|
||||
2. 处理默认值(添加到描述中并设置default字段)
|
||||
3. 映射参数类型到JSON Schema类型
|
||||
4. 创建参数定义对象
|
||||
5. 添加到Schema的properties中
|
||||
6. 处理必填参数标记
|
||||
|
||||
参数:
|
||||
type_schema: 类型Schema对象(如query、body的Schema)
|
||||
param: 参数定义字典
|
||||
"""
|
||||
# 获取参数名称,如果没有名称则跳过
|
||||
param_name = param.get("paramName")
|
||||
if not param_name:
|
||||
logger.warning(f"参数缺少名称,跳过: {param}")
|
||||
return
|
||||
|
||||
# 获取参数类型和描述
|
||||
param_type = param.get("paramType", ParamType.STRING)
|
||||
param_desc = param.get("paramPrompts", "")
|
||||
is_required = param.get("required") == 1
|
||||
|
||||
logger.debug(f"添加参数: {param_name}, 类型: {param_type}, 必填: {is_required}")
|
||||
|
||||
# 如果有默认值,添加到描述中
|
||||
if param.get("defaultValue") is not None:
|
||||
param_desc += f"(默认值为{param['defaultValue']})"
|
||||
logger.debug(f"参数 {param_name} 有默认值: {param['defaultValue']}")
|
||||
|
||||
# 将业务参数类型映射到JSON Schema类型
|
||||
json_type = PARAM_TYPE_MAPPING.get(param_type, JsonSchemaType.STRING)
|
||||
if param_type not in PARAM_TYPE_MAPPING:
|
||||
logger.warning(f"未知的参数类型 {param_type},使用默认类型 string")
|
||||
|
||||
# 创建参数定义对象
|
||||
param_def = {
|
||||
"type": json_type, # JSON Schema类型
|
||||
"description": param_desc, # 参数描述
|
||||
}
|
||||
|
||||
# 添加可选字段
|
||||
if param.get("defaultValue") is not None:
|
||||
param_def["default"] = param["defaultValue"] # 默认值
|
||||
if param.get("example") is not None:
|
||||
param_def["example"] = param["example"] # 示例值
|
||||
|
||||
# 添加到类型Schema的properties中
|
||||
type_schema["properties"][param_name] = param_def
|
||||
|
||||
# 如果是必填参数,添加到required列表中
|
||||
if is_required:
|
||||
type_schema["required"].append(param_name)
|
||||
logger.debug(f"参数 {param_name} 标记为必填")
|
||||
|
||||
|
||||
def _cleanup_empty_required_lists(schema: Dict[str, Any]) -> None:
|
||||
"""
|
||||
清理Schema中的空required列表
|
||||
|
||||
移除Schema中所有空的required数组,保持Schema的简洁性。
|
||||
这包括嵌套属性中的required列表和顶级的required列表。
|
||||
|
||||
清理规则:
|
||||
1. 遍历所有请求类型的Schema
|
||||
2. 如果某个类型的required列表为空,删除该字段
|
||||
3. 如果顶级required列表为空,删除该字段
|
||||
|
||||
参数:
|
||||
schema: 要清理的Schema对象
|
||||
"""
|
||||
# 清理嵌套属性中的空required列表
|
||||
for req_type in list(schema["properties"].keys()):
|
||||
type_schema = schema["properties"][req_type]
|
||||
if not type_schema.get("required"):
|
||||
# 如果required列表为空,删除该字段
|
||||
type_schema.pop("required", None)
|
||||
|
||||
# 清理顶级的空required列表
|
||||
if not schema.get("required"):
|
||||
schema.pop("required", None)
|
||||
|
||||
|
||||
def create_structured_data(
|
||||
schema: Dict[str, Any], params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate data structure based on schema and input parameters.
|
||||
|
||||
Automatically matches input parameters to schema fields and creates
|
||||
a properly structured data object.
|
||||
|
||||
Args:
|
||||
schema: JSON schema object defining the data structure
|
||||
params: Parameter values to fill
|
||||
|
||||
Returns:
|
||||
dict: Structured data object conforming to schema
|
||||
|
||||
Raises:
|
||||
TypeError: If schema or params are not dictionaries
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
raise TypeError("schema must be a dictionary")
|
||||
if not isinstance(params, dict):
|
||||
raise TypeError("params must be a dictionary")
|
||||
|
||||
result = {}
|
||||
|
||||
# Process each top-level schema property
|
||||
for field_name, field_schema in schema.get("properties", {}).items():
|
||||
result[field_name] = {}
|
||||
field_properties = field_schema.get("properties", {})
|
||||
|
||||
# Find matching parameters for this field
|
||||
matched_params = {
|
||||
param_name: param_value
|
||||
for param_name, param_value in params.items()
|
||||
if param_name in field_properties
|
||||
}
|
||||
|
||||
# Only include field if there are matching parameters
|
||||
if matched_params:
|
||||
result[field_name] = matched_params
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def generate_schema_prompt(schema: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate descriptive prompt from JSON Schema for LLM guidance.
|
||||
|
||||
Args:
|
||||
schema: JSON Schema object returned by generate_json_schema
|
||||
|
||||
Returns:
|
||||
str: Structured prompt text for guiding LLM parameter generation
|
||||
|
||||
Raises:
|
||||
TypeError: If schema is not a dictionary
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
raise TypeError("schema must be a dictionary")
|
||||
|
||||
prompt_parts = ["当前工具所需参数:\n"]
|
||||
|
||||
# Process each parameter group
|
||||
for group_name, group_schema in schema.get("properties", {}).items():
|
||||
prompt_parts.append(f"## {group_name} 参数组:")
|
||||
|
||||
required_params = set(group_schema.get("required", []))
|
||||
|
||||
# Process each parameter in the group
|
||||
for param_name, param_info in group_schema.get("properties", {}).items():
|
||||
param_type = param_info.get("type", JsonSchemaType.STRING)
|
||||
param_desc = param_info.get("description", "")
|
||||
required_mark = "(必填)" if param_name in required_params else "(可选)"
|
||||
|
||||
prompt_parts.append(
|
||||
f"- {param_name}{required_mark}: {param_desc}, 类型: {param_type}"
|
||||
)
|
||||
|
||||
prompt_parts.append("")
|
||||
|
||||
# Add output format guidance
|
||||
prompt_parts.extend(
|
||||
[
|
||||
"参数格式要求:JSON对象,包含所有必填字段。",
|
||||
"示例格式:",
|
||||
"```json",
|
||||
"{",
|
||||
]
|
||||
)
|
||||
|
||||
for group_name in schema.get("properties", {}):
|
||||
prompt_parts.append(f' "{group_name}": {{')
|
||||
prompt_parts.append(" // 相关参数")
|
||||
prompt_parts.append(" },")
|
||||
|
||||
prompt_parts.extend(["}", "```"])
|
||||
|
||||
return "\n".join(prompt_parts)
|
||||
|
||||
|
||||
def _remove_property_from_schema(schema: Dict[str, Any], property_name: str) -> None:
|
||||
"""
|
||||
Remove a property from schema (helper function to reduce code duplication).
|
||||
|
||||
Args:
|
||||
schema: Schema object to modify
|
||||
property_name: Name of property to remove
|
||||
"""
|
||||
# Remove from properties
|
||||
if isinstance(schema.get("properties"), dict):
|
||||
schema["properties"].pop(property_name, None)
|
||||
|
||||
# Remove from required list
|
||||
if isinstance(schema.get("required"), list):
|
||||
try:
|
||||
schema["required"].remove(property_name)
|
||||
except ValueError:
|
||||
pass # Property not in required list, ignore
|
||||
|
||||
# Remove empty required list
|
||||
if not schema["required"]:
|
||||
schema.pop("required", None)
|
||||
|
||||
|
||||
def remove_property_from_api_item(
|
||||
api_item: Dict[str, Any], property_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Remove specified property from a single API object.
|
||||
|
||||
Args:
|
||||
api_item: API object (single element from generate_api_array result)
|
||||
property_name: Name of property to remove
|
||||
|
||||
Returns:
|
||||
dict: Processed API object with property removed
|
||||
|
||||
Raises:
|
||||
ValueError: When input parameters are invalid
|
||||
TypeError: When input types are incorrect
|
||||
"""
|
||||
if not isinstance(api_item, dict) or not api_item:
|
||||
raise ValueError("api_item must be a non-empty dictionary")
|
||||
|
||||
if not isinstance(property_name, str) or not property_name.strip():
|
||||
raise ValueError("property_name must be a non-empty string")
|
||||
|
||||
# Create deep copy to avoid modifying original
|
||||
new_api = copy.deepcopy(api_item)
|
||||
|
||||
if "schema" in new_api and isinstance(new_api["schema"], dict):
|
||||
_remove_property_from_schema(new_api["schema"], property_name)
|
||||
|
||||
return new_api
|
||||
|
||||
|
||||
def remove_property_from_api_array(
|
||||
api_array: List[Dict[str, Any]], property_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remove specified property from all API objects in array.
|
||||
|
||||
Args:
|
||||
api_array: API array returned by generate_api_array
|
||||
property_name: Name of property to remove
|
||||
|
||||
Returns:
|
||||
list: Processed API array with property removed from all items
|
||||
|
||||
Raises:
|
||||
ValueError: When input parameters are invalid
|
||||
TypeError: When input types are incorrect
|
||||
"""
|
||||
if not isinstance(api_array, list) or not api_array:
|
||||
raise ValueError("api_array must be a non-empty list")
|
||||
|
||||
if not isinstance(property_name, str) or not property_name.strip():
|
||||
raise ValueError("property_name must be a non-empty string")
|
||||
|
||||
result = []
|
||||
|
||||
for api_item in api_array:
|
||||
try:
|
||||
processed_item = remove_property_from_api_item(api_item, property_name)
|
||||
result.append(processed_item)
|
||||
except Exception as e:
|
||||
result.append(copy.deepcopy(api_item)) # Keep original if processing fails
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fill_default_values_by_schema(
|
||||
schema: Dict[str, Any], arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Auto-fill default values in arguments based on JSON Schema.
|
||||
|
||||
Only fills missing parameters or those with None/empty string values.
|
||||
|
||||
Args:
|
||||
schema: JSON Schema object with grouped structure (body/query/lzwcaiConfig)
|
||||
arguments: User input parameter dictionary
|
||||
|
||||
Returns:
|
||||
dict: Parameter dictionary with default values filled
|
||||
|
||||
Raises:
|
||||
TypeError: When input types are incorrect
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
raise TypeError("schema must be a dictionary")
|
||||
|
||||
if not isinstance(arguments, dict):
|
||||
arguments = {}
|
||||
else:
|
||||
arguments = copy.deepcopy(arguments)
|
||||
|
||||
for group_name, group_schema in schema.get("properties", {}).items():
|
||||
if not isinstance(arguments.get(group_name), dict):
|
||||
arguments[group_name] = {}
|
||||
|
||||
for param_name, param_schema in group_schema.get("properties", {}).items():
|
||||
default_value = param_schema.get("default")
|
||||
current_value = arguments[group_name].get(param_name)
|
||||
|
||||
# Fill default only if parameter is missing, None, or empty string
|
||||
if (
|
||||
current_value is None or current_value == ""
|
||||
) and default_value is not None:
|
||||
arguments[group_name][param_name] = default_value
|
||||
|
||||
return arguments
|
||||
|
||||
|
||||
def generate_api_array(api_params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate API array with JSON Schema for each API.
|
||||
|
||||
Args:
|
||||
api_params: List of API parameter definitions
|
||||
|
||||
Returns:
|
||||
list: API array with schema attached to each item
|
||||
|
||||
Raises:
|
||||
TypeError: If api_params is not a list
|
||||
"""
|
||||
if not isinstance(api_params, list):
|
||||
raise TypeError("api_params must be a list")
|
||||
|
||||
api_array = []
|
||||
for param in api_params:
|
||||
try:
|
||||
schema = generate_json_schema(param.get("parameters", []))
|
||||
api_array.append({**param, "schema": schema})
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing config: {str(e)}")
|
||||
api_array.append(param) # Add without schema if generation fails
|
||||
|
||||
return api_array
|
||||
|
||||
|
||||
def check_required_arguments(
|
||||
schema: Dict[str, Any], arguments: Dict[str, Any]
|
||||
) -> List[str]:
|
||||
"""
|
||||
Validate that arguments contain all required parameters from schema.
|
||||
|
||||
Args:
|
||||
schema: JSON Schema object with grouped structure (body/query/lzwcaiConfig)
|
||||
arguments: User input parameter dictionary
|
||||
|
||||
Returns:
|
||||
list: Missing required parameters with group names, e.g., ['body.username', 'lzwcaiConfig.userId']
|
||||
|
||||
Raises:
|
||||
TypeError: When input types are incorrect
|
||||
"""
|
||||
if not isinstance(schema, dict):
|
||||
raise TypeError("schema must be a dictionary")
|
||||
|
||||
if not isinstance(arguments, dict):
|
||||
arguments = {}
|
||||
|
||||
missing_params = []
|
||||
|
||||
for group_name, group_schema in schema.get("properties", {}).items():
|
||||
group_args = arguments.get(group_name, {})
|
||||
required_params = group_schema.get("required", [])
|
||||
|
||||
for param_name in required_params:
|
||||
param_value = group_args.get(param_name) if group_args else None
|
||||
|
||||
# Consider missing if not provided, None, or empty string
|
||||
if param_value is None or param_value == "":
|
||||
# Try to get parameter description
|
||||
description = ""
|
||||
try:
|
||||
param_properties = group_schema.get("properties", {})
|
||||
param_info = param_properties.get(param_name, {})
|
||||
description = param_info.get("description", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Format missing parameter name
|
||||
if description:
|
||||
missing_params.append(f"{group_name}.{param_name}({description})")
|
||||
else:
|
||||
missing_params.append(f"{group_name}.{param_name}")
|
||||
|
||||
return missing_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage and testing
|
||||
try:
|
||||
with open(
|
||||
"E:/yh-ai/project/lzwcai-szyg/lzwcai-demp-tool-server/src/"
|
||||
"lzwcai_demp_tool_server_business_to_mcp/mcp_generator/src/parameters.json",
|
||||
"r",
|
||||
encoding="utf-8",
|
||||
) as f:
|
||||
api_params = json.load(f)
|
||||
|
||||
api_array = generate_api_array(api_params)
|
||||
result = remove_property_from_api_array(api_array, "userId")
|
||||
|
||||
# Write results to JSON file
|
||||
with open("schema1.json", "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Warning: Test data file not found. Skipping example execution.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in example execution: {str(e)}")
|
||||
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Business API configuration utility functions.
|
||||
|
||||
This module provides utilities for fetching and processing business API configurations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import requests
|
||||
from typing import Dict, List, Any
|
||||
from ..util.logger_config import get_logger
|
||||
|
||||
# 配置日志
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
def get_business_api_details(api_ids: List[int], auth_token: str = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取业务平台API详情
|
||||
|
||||
调用业务平台接口获取指定API ID列表的详细信息
|
||||
|
||||
Args:
|
||||
api_ids: API ID列表,例如 [1925128743899111425, 1925128744524062721]
|
||||
auth_token: 认证token,如果不提供则使用默认token
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: API详情列表,返回接口响应中的data字段内容
|
||||
|
||||
Raises:
|
||||
requests.RequestException: 当网络请求失败时抛出
|
||||
ValueError: 当响应格式不正确或返回错误时抛出
|
||||
|
||||
Example:
|
||||
>>> api_ids = [1925128743899111425, 1925128744524062721]
|
||||
>>> details = get_business_api_details(api_ids)
|
||||
>>> print(len(details))
|
||||
2
|
||||
"""
|
||||
if not isinstance(api_ids, list) or not api_ids:
|
||||
raise ValueError("api_ids must be a non-empty list")
|
||||
|
||||
# 默认认证token
|
||||
default_token = "eyJhbGciOiJIUzUxMiJ9.eyJsb2dpbl91c2VyX2tleSI6ImM3OGU0M2NlLTJhZjQtNGRjYy1iMWE1LTU3YjM5YTdkNTA1OSJ9.5f1lSJJdLUunZIwCfneT1DiagGN4jD-QCnFCffWmrnvEcLfpuSMWRpY7fF-6H3yZ2N5ICZ4ZQN6cx7iqwF6jKw"
|
||||
token = auth_token or default_token
|
||||
|
||||
# 接口URL - 支持环境变量配置
|
||||
# 默认URL
|
||||
default_url = "http://lzwcai-demp-corp-manager:8086/system/mcpServer/bizSys/api/getByIds"
|
||||
# 从环境变量获取URL,如果没有设置则使用默认URL
|
||||
url = os.getenv("lzwcai_mcp_dyntoolapi_auth_url", default_url)
|
||||
|
||||
# 请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
# "User-Agent": "Apifox/1.0.0 (https://apifox.com)",
|
||||
"Content-Type": "application/json",
|
||||
# "Accept": "*/*",
|
||||
# "Host": "192.168.2.236:8088",
|
||||
"Connection": "keep-alive"
|
||||
}
|
||||
|
||||
try:
|
||||
# 发送POST请求
|
||||
response = requests.post(url, headers=headers, json=api_ids, timeout=30)
|
||||
response.raise_for_status() # 检查HTTP状态码
|
||||
|
||||
# 解析响应JSON
|
||||
response_data = response.json()
|
||||
|
||||
# 检查响应格式
|
||||
if not isinstance(response_data, dict):
|
||||
raise ValueError("响应格式不正确:期望JSON对象")
|
||||
|
||||
# 检查业务状态码
|
||||
code = response_data.get("code")
|
||||
if code != 200:
|
||||
msg = response_data.get("msg", "未知错误")
|
||||
raise ValueError(f"业务接口返回错误:code={code}, msg={msg}")
|
||||
|
||||
# 获取data字段
|
||||
data = response_data.get("data", [])
|
||||
if not isinstance(data, list):
|
||||
logger.warning("响应中的data字段不是列表类型,将转换为列表")
|
||||
data = [data] if data is not None else []
|
||||
|
||||
logger.info(f"成功获取 {len(data)} 个API详情")
|
||||
return data
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
raise requests.RequestException("请求超时:接口响应时间过长")
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise requests.RequestException("连接错误:无法连接到业务平台服务器")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
raise requests.RequestException(f"HTTP错误:{e}")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("响应格式错误:无法解析JSON数据")
|
||||
except Exception as e:
|
||||
logger.error(f"获取API详情时发生未知错误:{str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def map_api_details_to_config(api_details: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
将API详情数据映射为api_config.json格式
|
||||
|
||||
Args:
|
||||
api_details: get_business_api_details方法返回的API详情列表
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 符合api_config.json格式的配置对象
|
||||
|
||||
Raises:
|
||||
ValueError: 当输入数据格式不正确时抛出
|
||||
|
||||
Example:
|
||||
>>> api_details = get_business_api_details([1925128743899111425])
|
||||
>>> config = map_api_details_to_config(api_details)
|
||||
>>> print(config["serverName"])
|
||||
lzwcai_mcp_api_converter
|
||||
"""
|
||||
if not isinstance(api_details, list):
|
||||
raise ValueError("api_details must be a list")
|
||||
|
||||
if not api_details:
|
||||
raise ValueError("api_details cannot be empty")
|
||||
|
||||
# 获取第一个API的domainUrl作为全局domainUrl
|
||||
domain_url = ""
|
||||
if api_details and isinstance(api_details[0], dict):
|
||||
domain_url = api_details[0].get("domainUrl", "")
|
||||
|
||||
# 收集所有businessPrompts用于生成description
|
||||
business_prompts = []
|
||||
for api in api_details:
|
||||
if isinstance(api, dict) and api.get("businessPrompts"):
|
||||
business_prompts.append(api["businessPrompts"])
|
||||
|
||||
# 生成description
|
||||
description = "、".join(business_prompts) if business_prompts else "业务API集合"
|
||||
|
||||
# 构建配置对象
|
||||
config = {
|
||||
"serverName": "lzwcai_mcp_api_converter",
|
||||
"description": description,
|
||||
"domainUrl": domain_url,
|
||||
"packageName": "lzwcai-mcp-dyntoolapi",
|
||||
"version": "1.0.0",
|
||||
"apiConfig": api_details # 直接使用原始API详情数据
|
||||
}
|
||||
|
||||
logger.info(f"成功映射 {len(api_details)} 个API到配置格式")
|
||||
logger.info(f"服务名称: {config['serverName']}")
|
||||
logger.info(f"域名URL: {config['domainUrl']}")
|
||||
logger.info(f"描述: {config['description']}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_business_api_config(api_ids: List[int], auth_token: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
一步到位获取业务平台API配置
|
||||
|
||||
传入API ID列表,直接返回处理好的api_config.json格式配置
|
||||
|
||||
Args:
|
||||
api_ids: API ID列表,例如 [1925128743899111425, 1925128744524062721]
|
||||
auth_token: 认证token,如果不提供则使用默认token
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 符合api_config.json格式的完整配置对象
|
||||
|
||||
Raises:
|
||||
requests.RequestException: 当网络请求失败时抛出
|
||||
ValueError: 当响应格式不正确或返回错误时抛出
|
||||
|
||||
Example:
|
||||
>>> api_ids = [1925128743899111425, 1925128744524062721]
|
||||
>>> config = get_business_api_config(api_ids)
|
||||
>>> print(config["serverName"])
|
||||
lzwcai_mcp_api_converter
|
||||
>>> print(len(config["apiConfig"]))
|
||||
2
|
||||
"""
|
||||
try:
|
||||
# 步骤1: 获取API详情
|
||||
logger.info(f"开始获取 {len(api_ids)} 个API的详情...")
|
||||
api_details = get_business_api_details(api_ids, auth_token)
|
||||
|
||||
# 步骤2: 映射为配置格式
|
||||
logger.info("开始映射为配置格式...")
|
||||
config = map_api_details_to_config(api_details)
|
||||
|
||||
logger.info(f"[SUCCESS] 成功生成API配置!包含 {len(config['apiConfig'])} 个API")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取业务API配置时发生错误: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,966 @@
|
||||
"""
|
||||
API认证服务模块
|
||||
|
||||
提供用户认证、Token管理、企业认证等功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
|
||||
sys.path.append(project_root)
|
||||
|
||||
from ..business.business_util import create_structured_data, generate_api_array
|
||||
from ..util.nested_value import get_nested_value
|
||||
from ..util.logger_config import get_logger
|
||||
from .get_auth import get_auth_data
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
"""认证相关异常"""
|
||||
|
||||
def __init__(self, message: str, error_code: Optional[str] = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class AuthConfig:
|
||||
"""认证配置类"""
|
||||
|
||||
# Token相关配置
|
||||
TOKEN_PREFIX = "lzwc"
|
||||
TOKEN_SUFFIX = "token"
|
||||
|
||||
# 重试配置
|
||||
DEFAULT_MAX_ATTEMPTS = 3
|
||||
DEFAULT_BASE_DELAY = 1
|
||||
DEFAULT_MAX_DELAY = 60
|
||||
|
||||
|
||||
class ParameterExtractor:
|
||||
"""参数提取器"""
|
||||
|
||||
@staticmethod
|
||||
def extract_param_defaults(param_list: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
将参数列表转换为 {paramName: defaultValue} 的字典
|
||||
|
||||
Args:
|
||||
param_list: 参数字典列表,每个元素包含 paramName 和 defaultValue
|
||||
|
||||
Returns:
|
||||
以 paramName 为 key,defaultValue 为 value 的字典
|
||||
"""
|
||||
if not isinstance(param_list, list):
|
||||
logger.warning(f"参数列表类型错误,期望list,实际{type(param_list)}")
|
||||
return {}
|
||||
|
||||
logger.debug(f"提取参数默认值,参数数量: {len(param_list)}")
|
||||
result = {}
|
||||
|
||||
for item in param_list:
|
||||
if not isinstance(item, dict):
|
||||
logger.warning(f"参数项类型错误,跳过: {item}")
|
||||
continue
|
||||
|
||||
param_name = item.get("paramName")
|
||||
if param_name:
|
||||
default_value = item.get("defaultValue")
|
||||
result[param_name] = default_value
|
||||
logger.debug(f"参数 {param_name}: {default_value}")
|
||||
|
||||
logger.debug(f"提取完成,共 {len(result)} 个参数")
|
||||
return result
|
||||
|
||||
|
||||
class ApiRetryExecutor:
|
||||
"""API重试执行器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_attempts: int = AuthConfig.DEFAULT_MAX_ATTEMPTS,
|
||||
base_delay: int = AuthConfig.DEFAULT_BASE_DELAY,
|
||||
max_delay: int = AuthConfig.DEFAULT_MAX_DELAY,
|
||||
):
|
||||
self.max_attempts = max_attempts
|
||||
self.base_delay = base_delay
|
||||
self.max_delay = max_delay
|
||||
|
||||
async def execute_with_retry(
|
||||
self,
|
||||
api_config: Dict[str, Any],
|
||||
user_params: Dict[str, Any],
|
||||
need_auth: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
使用指数退避策略的API调用重试
|
||||
|
||||
Args:
|
||||
api_config: API配置信息
|
||||
user_params: 用户提供的参数
|
||||
need_auth: 是否需要鉴权
|
||||
|
||||
Returns:
|
||||
API调用响应
|
||||
|
||||
Raises:
|
||||
AuthError: 认证相关错误
|
||||
"""
|
||||
from .core_server import call_api
|
||||
|
||||
attempt = 0
|
||||
last_error = None
|
||||
|
||||
while attempt < self.max_attempts:
|
||||
attempt += 1
|
||||
try:
|
||||
logger.info(f"尝试调用API,第 {attempt} 次")
|
||||
|
||||
# 准备API调用数据
|
||||
api_array = generate_api_array([api_config])
|
||||
call_api_data = create_structured_data(
|
||||
api_array[0]["schema"], user_params
|
||||
)
|
||||
|
||||
# 执行API调用
|
||||
api_res = await call_api(api_config, call_api_data, need_auth=need_auth)
|
||||
|
||||
# 检查响应是否包含错误
|
||||
if api_res and "error" not in api_res:
|
||||
logger.info(f"API调用成功,第 {attempt} 次尝试")
|
||||
return api_res
|
||||
|
||||
# 记录错误信息
|
||||
error_msg = api_res.get("error", "未知错误") if api_res else "响应为空"
|
||||
logger.warning(f"API调用返回错误: {error_msg}")
|
||||
last_error = error_msg
|
||||
|
||||
# 如果还有重试机会,等待后重试
|
||||
if attempt < self.max_attempts:
|
||||
await self._wait_for_retry(attempt)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"API调用发生异常: {error_msg}")
|
||||
last_error = error_msg
|
||||
|
||||
if attempt < self.max_attempts:
|
||||
await self._wait_for_retry(attempt)
|
||||
else:
|
||||
logger.error(
|
||||
f"API调用失败,已达到最大重试次数 ({self.max_attempts})"
|
||||
)
|
||||
raise AuthError(f"API调用失败: {error_msg}")
|
||||
|
||||
# 所有重试都失败了
|
||||
raise AuthError(
|
||||
f"API调用失败,重试{self.max_attempts}次后仍然失败: {last_error}"
|
||||
)
|
||||
|
||||
async def _wait_for_retry(self, attempt: int) -> None:
|
||||
"""等待重试的延迟逻辑"""
|
||||
# 计算延迟时间(指数退避策略)
|
||||
delay = min(self.base_delay * (2 ** (attempt - 1)), self.max_delay)
|
||||
# 添加随机抖动(0.5-1.5倍)避免同时重试
|
||||
jitter = 0.5 + random.random()
|
||||
sleep_time = delay * jitter
|
||||
|
||||
logger.info(f"等待 {sleep_time:.2f} 秒后进行第 {attempt + 1} 次重试")
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
|
||||
class EnvironmentManager:
|
||||
"""环境变量管理器,提供环境变量的增删改查功能,支持持久化存储"""
|
||||
|
||||
_env_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../", ".env_lzwcai_mcp_api_converter"))
|
||||
|
||||
@staticmethod
|
||||
def _load_persistent_tokens() -> Dict[str, str]:
|
||||
"""从.env文件加载持久化的token"""
|
||||
tokens = {}
|
||||
try:
|
||||
if os.path.exists(EnvironmentManager._env_file):
|
||||
with open(EnvironmentManager._env_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and '=' in line and not line.startswith('#'):
|
||||
key, value = line.split('=', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if key: # 确保key不为空
|
||||
tokens[key] = value
|
||||
except Exception as e:
|
||||
logger.error(f"加载.env文件失败: {e}")
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def _save_persistent_tokens(tokens: Dict[str, str]) -> bool:
|
||||
"""保存token到.env文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(EnvironmentManager._env_file), exist_ok=True)
|
||||
with open(EnvironmentManager._env_file, 'w', encoding='utf-8') as f:
|
||||
for key, value in tokens.items():
|
||||
f.write(f"{key}={value}\n")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"保存.env文件失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get(key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
获取环境变量值,优先从进程环境变量获取,然后从持久化文件获取
|
||||
|
||||
Args:
|
||||
key: 环境变量名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
环境变量值或默认值
|
||||
"""
|
||||
# 先从进程环境变量获取
|
||||
value = os.environ.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
# 从持久化文件获取
|
||||
if key.endswith('token'): # 只对token类型的环境变量使用持久化
|
||||
persistent_tokens = EnvironmentManager._load_persistent_tokens()
|
||||
value = persistent_tokens.get(key)
|
||||
if value is not None:
|
||||
# 同时设置到进程环境变量中
|
||||
os.environ[key] = value
|
||||
return value
|
||||
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def set(key: str, value: str) -> bool:
|
||||
"""
|
||||
设置环境变量,同时持久化token类型的变量
|
||||
|
||||
Args:
|
||||
key: 环境变量名
|
||||
value: 环境变量值
|
||||
|
||||
Returns:
|
||||
设置是否成功
|
||||
"""
|
||||
try:
|
||||
# 设置到进程环境变量
|
||||
os.environ[key] = str(value)
|
||||
|
||||
# 如果是token类型,同时持久化到文件
|
||||
if key.endswith('token'):
|
||||
persistent_tokens = EnvironmentManager._load_persistent_tokens()
|
||||
persistent_tokens[key] = str(value)
|
||||
if not EnvironmentManager._save_persistent_tokens(persistent_tokens):
|
||||
logger.warning(f"持久化token失败,但进程环境变量已设置: {key}")
|
||||
else:
|
||||
logger.debug(f"成功持久化token: {key}")
|
||||
|
||||
logger.debug(f"成功设置环境变量: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置环境变量失败: {key}, 错误: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(key: str) -> bool:
|
||||
"""
|
||||
检查环境变量是否存在
|
||||
|
||||
Args:
|
||||
key: 环境变量名
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
# 先检查进程环境变量
|
||||
if key in os.environ:
|
||||
return True
|
||||
|
||||
# 检查持久化文件
|
||||
if key.endswith('token'):
|
||||
persistent_tokens = EnvironmentManager._load_persistent_tokens()
|
||||
return key in persistent_tokens
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def delete(key: str) -> bool:
|
||||
"""
|
||||
删除环境变量,同时删除持久化的token
|
||||
|
||||
Args:
|
||||
key: 环境变量名
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
# 从进程环境变量删除
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
# 如果是token类型,同时从持久化文件删除
|
||||
if key.endswith('token'):
|
||||
persistent_tokens = EnvironmentManager._load_persistent_tokens()
|
||||
if key in persistent_tokens:
|
||||
del persistent_tokens[key]
|
||||
EnvironmentManager._save_persistent_tokens(persistent_tokens)
|
||||
logger.debug(f"成功删除持久化token: {key}")
|
||||
|
||||
logger.debug(f"成功删除环境变量: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"删除环境变量失败: {key}, 错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Token管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = get_logger("TokenManager")
|
||||
|
||||
def generate_token_name(self, user_id: str, biz_sys_id: str) -> str:
|
||||
"""
|
||||
生成Token名称
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
biz_sys_id: 业务系统ID
|
||||
|
||||
Returns:
|
||||
Token名称
|
||||
"""
|
||||
return (
|
||||
f"{AuthConfig.TOKEN_PREFIX}{user_id}{biz_sys_id}{AuthConfig.TOKEN_SUFFIX}"
|
||||
)
|
||||
|
||||
def check_token_exists(
|
||||
self, token_name: str
|
||||
) -> Tuple[bool, Optional[Union[str, Dict[str, Any]]]]:
|
||||
"""
|
||||
检查环境变量中是否存在Token,并返回反序列化后的值
|
||||
|
||||
Args:
|
||||
token_name: Token名称
|
||||
|
||||
Returns:
|
||||
(是否存在, Token值)
|
||||
"""
|
||||
token_value = EnvironmentManager.get(token_name)
|
||||
if token_value is None:
|
||||
return False, None
|
||||
|
||||
# 尝试反序列化Token值
|
||||
deserialized_value = self._deserialize_token_value(token_value, token_name)
|
||||
return True, deserialized_value
|
||||
|
||||
def _deserialize_token_value(
|
||||
self, token_value: str, token_name: str
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""反序列化Token值"""
|
||||
try:
|
||||
# 尝试将JSON字符串反序列化为原始结构(字典)
|
||||
return json.loads(token_value)
|
||||
except json.JSONDecodeError:
|
||||
self.logger.debug(f"Token不是JSON格式,保留原始字符串: {token_name}")
|
||||
|
||||
# 尝试处理字符串表示的字典,如 "{'key': 'value'}"
|
||||
if token_value.startswith("{") and token_value.endswith("}"):
|
||||
try:
|
||||
import ast
|
||||
|
||||
return ast.literal_eval(token_value)
|
||||
except (ValueError, SyntaxError):
|
||||
self.logger.debug(f"无法解析字典字符串: {token_value}")
|
||||
|
||||
# 返回原始字符串
|
||||
return token_value
|
||||
|
||||
def store_token(self, token_name: str, token_value: Any) -> bool:
|
||||
"""
|
||||
存储Token到环境变量,自动序列化复杂结构
|
||||
|
||||
Args:
|
||||
token_name: Token名称
|
||||
token_value: Token值
|
||||
|
||||
Returns:
|
||||
存储是否成功
|
||||
"""
|
||||
if token_value is None:
|
||||
self.logger.warning(f"Token值为None,不进行存储: {token_name}")
|
||||
return False
|
||||
|
||||
# 序列化Token值
|
||||
serialized_value = self._serialize_token_value(token_value, token_name)
|
||||
if serialized_value is None:
|
||||
return False
|
||||
|
||||
# 存储到环境变量
|
||||
success = EnvironmentManager.set(token_name, serialized_value)
|
||||
if success:
|
||||
self.logger.info(f"成功将Token存储到环境变量: {token_name}")
|
||||
else:
|
||||
self.logger.error(f"存储Token到环境变量失败: {token_name}")
|
||||
|
||||
return success
|
||||
|
||||
def _serialize_token_value(
|
||||
self, token_value: Any, token_name: str
|
||||
) -> Optional[str]:
|
||||
"""序列化Token值"""
|
||||
try:
|
||||
if isinstance(token_value, (dict, list)):
|
||||
return json.dumps(token_value)
|
||||
elif isinstance(token_value, bytes):
|
||||
return token_value.decode("utf-8")
|
||||
elif isinstance(token_value, (int, float, str)):
|
||||
return str(token_value)
|
||||
else:
|
||||
self.logger.error(
|
||||
f"不支持的Token类型: {type(token_value)},不进行存储: {token_name}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"序列化Token时出错: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
class CompanyAuthClient:
|
||||
"""企业认证客户端 - 现在使用get_auth_data方法替代HTTP API调用"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = get_logger("CompanyAuthClient")
|
||||
|
||||
async def get_auth_info(
|
||||
self, user_id: str, biz_sys_id: str
|
||||
) -> Tuple[Optional[int], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
获取鉴权类型和认证数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
biz_sys_id: 业务系统ID
|
||||
|
||||
Returns:
|
||||
(鉴权类型, 认证数据)
|
||||
"""
|
||||
try:
|
||||
self.logger.debug(f"使用get_auth_data获取认证信息: user_id={user_id}, biz_sys_id={biz_sys_id}")
|
||||
|
||||
# 使用get_auth_data方法获取认证数据
|
||||
result = get_auth_data(user_id, biz_sys_id)
|
||||
|
||||
if not result:
|
||||
self.logger.error("get_auth_data返回空结果")
|
||||
return None, None
|
||||
|
||||
return self._parse_auth_response(result, user_id, biz_sys_id)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取鉴权类型失败: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def _parse_auth_response(
|
||||
self, result: Dict[str, Any], user_id: str, biz_sys_id: str
|
||||
) -> Tuple[Optional[int], Optional[Dict[str, Any]]]:
|
||||
"""解析认证响应"""
|
||||
if result.get("code") != 200:
|
||||
error_msg = result.get("msg", "未知错误")
|
||||
self.logger.error(f"获取鉴权类型失败: {error_msg}")
|
||||
return None, None
|
||||
|
||||
auth_data = result.get("data", {})
|
||||
auth_type = auth_data.get("authType")
|
||||
|
||||
# 将字符串类型的authType转换为整数
|
||||
auth_type_int = None
|
||||
if auth_type is not None:
|
||||
try:
|
||||
auth_type_int = int(auth_type)
|
||||
self.logger.info(
|
||||
f"用户{user_id}业务系统{biz_sys_id}的鉴权类型: {auth_type_int}"
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
self.logger.warning(f"无法将authType转换为整数: {auth_type}")
|
||||
|
||||
return auth_type_int, auth_data
|
||||
|
||||
|
||||
class BusinessTokenService:
|
||||
"""业务系统Token服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = get_logger("BusinessTokenService")
|
||||
self.retry_executor = ApiRetryExecutor()
|
||||
|
||||
async def get_business_system_token(
|
||||
self, auth_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取业务系统Token
|
||||
|
||||
Args:
|
||||
auth_data: 认证数据字典
|
||||
|
||||
Returns:
|
||||
包含token信息的字典
|
||||
"""
|
||||
try:
|
||||
# 验证输入参数
|
||||
validation_result = self._validate_auth_data(auth_data)
|
||||
if not validation_result["valid"]:
|
||||
return {"success": False, "msg": validation_result["error"]}
|
||||
|
||||
# 解析配置
|
||||
config_result = self._parse_auth_config(auth_data)
|
||||
if not config_result["valid"]:
|
||||
return {"success": False, "msg": config_result["error"]}
|
||||
|
||||
# 执行API调用获取Token
|
||||
token_result = await self._execute_token_api(
|
||||
config_result["api_def"], config_result["user_params"], auth_data
|
||||
)
|
||||
|
||||
return token_result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取业务系统Token失败: {str(e)}")
|
||||
return {"success": False, "msg": f"系统错误: {str(e)}"}
|
||||
|
||||
def _validate_auth_data(self, auth_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""验证认证数据"""
|
||||
if not isinstance(auth_data, dict):
|
||||
error_msg = f"auth_data类型错误: {type(auth_data)}"
|
||||
self.logger.error(error_msg)
|
||||
return {"valid": False, "error": "认证数据格式错误"}
|
||||
|
||||
# 验证必要字段
|
||||
if "name" not in auth_data:
|
||||
self.logger.error("auth_data中缺少name字段")
|
||||
return {"valid": False, "error": "缺少name配置"}
|
||||
|
||||
if not auth_data.get("apiVO"):
|
||||
self.logger.error("缺少apiVO信息")
|
||||
return {"valid": False, "error": "缺少API配置信息"}
|
||||
|
||||
apiVO = auth_data["apiVO"]
|
||||
if not isinstance(apiVO, dict):
|
||||
error_msg = f"apiVO类型错误: {type(apiVO)}"
|
||||
self.logger.error(error_msg)
|
||||
return {"valid": False, "error": "API配置格式错误"}
|
||||
|
||||
# 检查必要字段
|
||||
required_fields = ["accountConfig", "tokenPath"]
|
||||
missing_fields = [field for field in required_fields if field not in apiVO]
|
||||
if missing_fields:
|
||||
self.logger.error(f"缺少必要字段: {missing_fields}")
|
||||
return {"valid": False, "error": f"缺少必要配置字段: {missing_fields}"}
|
||||
|
||||
return {"valid": True}
|
||||
|
||||
def _parse_auth_config(self, auth_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""解析认证配置"""
|
||||
try:
|
||||
apiVO = auth_data["apiVO"]
|
||||
|
||||
# 解析accountConfig
|
||||
account_config = apiVO["accountConfig"]
|
||||
if isinstance(account_config, str):
|
||||
try:
|
||||
account_config = json.loads(account_config)
|
||||
self.logger.info("成功解析accountConfig JSON")
|
||||
except json.JSONDecodeError as e:
|
||||
self.logger.error(f"accountConfig JSON解析失败: {str(e)}")
|
||||
return {"valid": False, "error": "配置数据格式错误"}
|
||||
|
||||
if not isinstance(account_config, dict):
|
||||
error_msg = f"accountConfig类型错误: {type(account_config)}"
|
||||
self.logger.error(error_msg)
|
||||
return {"valid": False, "error": "配置数据格式错误"}
|
||||
|
||||
# 验证API定义
|
||||
api_def = apiVO.get("tcapabilityApiVO")
|
||||
if api_def is None:
|
||||
self.logger.error("缺少API配置信息")
|
||||
return {"valid": False, "error": "缺少API配置信息"}
|
||||
|
||||
# 准备用户参数
|
||||
try:
|
||||
parameters_body = account_config.get("parametersBody", [])
|
||||
user_params = ParameterExtractor.extract_param_defaults(
|
||||
parameters_body
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"准备API参数失败: {str(e)}")
|
||||
return {"valid": False, "error": "参数准备失败"}
|
||||
|
||||
# 准备API定义映射
|
||||
api_def_map = {**api_def, "parameters": api_def.get("apiParameterList", [])}
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"api_def": api_def_map,
|
||||
"user_params": user_params,
|
||||
"account_config": account_config,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"解析认证配置失败: {str(e)}")
|
||||
return {"valid": False, "error": "配置解析失败"}
|
||||
|
||||
async def _execute_token_api(
|
||||
self,
|
||||
api_def: Dict[str, Any],
|
||||
user_params: Dict[str, Any],
|
||||
auth_data: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""执行Token获取API"""
|
||||
try:
|
||||
# 执行API调用
|
||||
api_res = await self.retry_executor.execute_with_retry(
|
||||
api_def, user_params, need_auth=False
|
||||
)
|
||||
|
||||
if not api_res:
|
||||
return {"success": False, "msg": "API调用失败"}
|
||||
|
||||
# 提取Token
|
||||
return self._extract_token_from_response(api_res, auth_data)
|
||||
|
||||
except AuthError as e:
|
||||
self.logger.error(f"API调用失败: {str(e)}")
|
||||
return {"success": False, "msg": f"API调用失败: {str(e)}"}
|
||||
except Exception as e:
|
||||
self.logger.error(f"执行Token API异常: {str(e)}")
|
||||
return {"success": False, "msg": f"API调用异常: {str(e)}"}
|
||||
|
||||
def _extract_token_from_response(
|
||||
self, api_res: Dict[str, Any], auth_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从API响应中提取Token"""
|
||||
try:
|
||||
name = auth_data["name"]
|
||||
token_path = auth_data["apiVO"]["tokenPath"]
|
||||
|
||||
value = get_nested_value({"res": api_res}, token_path)
|
||||
|
||||
if not value:
|
||||
self.logger.error("未获取到有效token")
|
||||
return {"success": False, "msg": "获取token失败"}
|
||||
|
||||
return {
|
||||
"tokenHeader": {name: value},
|
||||
"token": value,
|
||||
"msg": api_res.get("msg", "获取token成功"),
|
||||
"success": True,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"处理token结果失败: {str(e)}")
|
||||
return {"success": False, "msg": "处理token结果失败"}
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""认证服务,负责处理鉴权相关逻辑"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = get_logger("AuthService")
|
||||
self.token_manager = TokenManager()
|
||||
self.company_auth_client = CompanyAuthClient()
|
||||
self.business_token_service = BusinessTokenService()
|
||||
|
||||
async def authorize_request(
|
||||
self,
|
||||
user_id: Optional[str],
|
||||
biz_sys_id: Optional[str],
|
||||
persist_token: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
完整的请求鉴权处理
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
biz_sys_id: 业务系统ID
|
||||
persist_token: 是否持久化token,默认True
|
||||
|
||||
Returns:
|
||||
认证结果
|
||||
"""
|
||||
self.logger.info(f"开始认证处理 - 用户ID: {user_id}, 业务系统ID: {biz_sys_id}, 持久化: {persist_token}")
|
||||
|
||||
# 获取并验证token
|
||||
token_header = await self.check_user_token(user_id, biz_sys_id, persist_token=persist_token)
|
||||
self.logger.info(f"获取到Token头: {token_header}")
|
||||
|
||||
if not token_header:
|
||||
return {
|
||||
"success": False,
|
||||
"error_response": {
|
||||
"error": "获取鉴权令牌失败,请前往管理平台进行鉴权或提供临时令牌"
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tokenHeader": token_header,
|
||||
"source": "auth_service", # 标记token来源
|
||||
}
|
||||
|
||||
async def check_user_token(
|
||||
self,
|
||||
user_id: Optional[str],
|
||||
biz_sys_id: Optional[str],
|
||||
token: Optional[str] = None,
|
||||
persist_token: bool = True,
|
||||
) -> Optional[Union[str, Dict[str, Any]]]:
|
||||
"""
|
||||
检查用户Token是否有效,如无效则重新获取
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
biz_sys_id: 业务系统ID
|
||||
token: 可选的临时Token
|
||||
persist_token: 是否持久化token
|
||||
|
||||
Returns:
|
||||
Token值或None
|
||||
"""
|
||||
if not user_id or not biz_sys_id:
|
||||
self.logger.warning("用户ID或业务系统ID为空,无法检查Token")
|
||||
return None
|
||||
|
||||
# 生成Token名
|
||||
token_name = self.token_manager.generate_token_name(user_id, biz_sys_id)
|
||||
|
||||
self.logger.info(f"Token名: {token_name}")
|
||||
|
||||
# 如果不需要持久化,直接重新获取Token
|
||||
if not persist_token:
|
||||
self.logger.info("不使用持久化,直接获取新Token")
|
||||
return await self._refresh_user_token(user_id, biz_sys_id, token_name, persist_token)
|
||||
|
||||
# 检查环境变量是否存在现有Token
|
||||
exists, token_value = self.token_manager.check_token_exists(token_name)
|
||||
self.logger.info(f"Token存在性检查: {exists}, 值: {token_value}")
|
||||
|
||||
# 如果环境变量存在,直接返回值
|
||||
if exists:
|
||||
self.logger.info(
|
||||
f"从环境变量获取到用户{user_id}业务系统{biz_sys_id}的Token"
|
||||
)
|
||||
return token_value
|
||||
|
||||
# 如果提供了token参数,直接使用并存储
|
||||
if token:
|
||||
if persist_token:
|
||||
self.token_manager.store_token(token_name, token)
|
||||
return token
|
||||
|
||||
# 重新获取Token
|
||||
return await self._refresh_user_token(user_id, biz_sys_id, token_name, persist_token)
|
||||
|
||||
async def _refresh_user_token(
|
||||
self, user_id: str, biz_sys_id: str, token_name: str, persist_token: bool = True
|
||||
) -> Optional[Union[str, Dict[str, Any]]]:
|
||||
"""刷新用户Token"""
|
||||
# 获取鉴权类型和认证数据
|
||||
auth_type, auth_data = await self.company_auth_client.get_auth_info(
|
||||
user_id, biz_sys_id
|
||||
)
|
||||
|
||||
if auth_type is None:
|
||||
self.logger.error(f"无法获取用户{user_id}业务系统{biz_sys_id}的鉴权类型")
|
||||
return None
|
||||
|
||||
# 根据鉴权类型获取Token
|
||||
token_value = await self._get_token_by_auth_type(
|
||||
user_id, biz_sys_id, auth_type, auth_data
|
||||
)
|
||||
|
||||
self.logger.info(f"Token值: {token_value}")
|
||||
|
||||
# 存储Token(根据persist_token参数决定是否持久化)
|
||||
if token_value:
|
||||
if persist_token:
|
||||
success = self.token_manager.store_token(token_name, token_value)
|
||||
if not success:
|
||||
self.logger.error(f"存储用户{user_id}业务系统{biz_sys_id}的Token失败")
|
||||
else:
|
||||
self.logger.info(f"成功存储用户{user_id}业务系统{biz_sys_id}的Token到环境变量: {token_name}")
|
||||
else:
|
||||
self.logger.info(f"跳过持久化,用户{user_id}业务系统{biz_sys_id}的Token仅在内存中使用")
|
||||
else:
|
||||
self.logger.warning(f"未能获取到用户{user_id}业务系统{biz_sys_id}的Token,token_value: {token_value}")
|
||||
|
||||
return token_value
|
||||
|
||||
def clear_token(self, user_id: str, biz_sys_id: str) -> bool:
|
||||
"""
|
||||
清空指定用户的token
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
biz_sys_id: 业务系统ID
|
||||
|
||||
Returns:
|
||||
清空是否成功
|
||||
"""
|
||||
if not user_id or not biz_sys_id:
|
||||
self.logger.warning("用户ID或业务系统ID为空,无法清空Token")
|
||||
return False
|
||||
|
||||
# 生成Token名
|
||||
token_name = self.token_manager.generate_token_name(user_id, biz_sys_id)
|
||||
|
||||
# 删除环境变量中的token
|
||||
success = EnvironmentManager.delete(token_name)
|
||||
|
||||
if success:
|
||||
self.logger.info(f"成功清空用户{user_id}业务系统{biz_sys_id}的Token: {token_name}")
|
||||
else:
|
||||
self.logger.error(f"清空用户{user_id}业务系统{biz_sys_id}的Token失败: {token_name}")
|
||||
|
||||
return success
|
||||
|
||||
def clear_all_tokens(self) -> bool:
|
||||
"""
|
||||
清空所有持久化的token
|
||||
|
||||
Returns:
|
||||
清空是否成功
|
||||
"""
|
||||
try:
|
||||
# 加载所有token
|
||||
persistent_tokens = EnvironmentManager._load_persistent_tokens()
|
||||
|
||||
# 删除所有token类型的环境变量
|
||||
for token_name in list(persistent_tokens.keys()):
|
||||
if token_name.endswith('token'):
|
||||
EnvironmentManager.delete(token_name)
|
||||
|
||||
self.logger.info("成功清空所有持久化token")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"清空所有token失败: {e}")
|
||||
return False
|
||||
|
||||
async def _get_token_by_auth_type(
|
||||
self, user_id: str, biz_sys_id: str, auth_type: int, auth_data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""根据鉴权类型获取Token"""
|
||||
try:
|
||||
if auth_type == 0:
|
||||
# 直接使用apiKey作为Token
|
||||
return self._get_api_key_token(user_id, biz_sys_id, auth_data)
|
||||
elif auth_type == 1:
|
||||
# 调用登录接口获取Token
|
||||
return await self._get_login_token(user_id, biz_sys_id, auth_data)
|
||||
else:
|
||||
self.logger.warning(f"不支持的鉴权类型: {auth_type}")
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取Token失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def _get_api_key_token(
|
||||
self, user_id: str, biz_sys_id: str, auth_data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""获取API Key类型的Token"""
|
||||
api_key = auth_data.get("apiKey")
|
||||
name = auth_data.get("name")
|
||||
|
||||
if api_key and name:
|
||||
self.logger.info(f"使用apiKey作为用户{user_id}业务系统{biz_sys_id}的Token")
|
||||
return {name: api_key}
|
||||
else:
|
||||
self.logger.warning(f"用户{user_id}业务系统{biz_sys_id}的apiKey或name为空")
|
||||
return None
|
||||
|
||||
async def _get_login_token(
|
||||
self, user_id: str, biz_sys_id: str, auth_data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""获取登录类型的Token"""
|
||||
self.logger.info(f"通过登录接口获取用户{user_id}业务系统{biz_sys_id}的Token")
|
||||
|
||||
login_res = await self.business_token_service.get_business_system_token(
|
||||
auth_data
|
||||
)
|
||||
|
||||
if login_res and login_res.get("tokenHeader") is not None:
|
||||
return login_res.get("tokenHeader")
|
||||
else:
|
||||
error_msg = login_res.get("msg", "未知错误") if login_res else "未知错误"
|
||||
self.logger.error(f"获取Token失败: {error_msg}")
|
||||
return None
|
||||
|
||||
|
||||
# 兼容性函数 - 保持向后兼容
|
||||
def extract_param_defaults(param_list: list) -> Dict[str, Any]:
|
||||
"""提取参数默认值(兼容性函数)"""
|
||||
return ParameterExtractor.extract_param_defaults(param_list)
|
||||
|
||||
|
||||
async def execute_api_call_with_retry(
|
||||
api_config: Dict[str, Any],
|
||||
user_params: Dict[str, Any],
|
||||
need_auth: bool = False,
|
||||
max_attempts: int = 3,
|
||||
base_delay: int = 1,
|
||||
max_delay: int = 60,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""API重试调用(兼容性函数)"""
|
||||
executor = ApiRetryExecutor(max_attempts, base_delay, max_delay)
|
||||
return await executor.execute_with_retry(api_config, user_params, need_auth)
|
||||
|
||||
|
||||
class EnvManager:
|
||||
"""环境变量管理器(兼容性类)"""
|
||||
|
||||
@staticmethod
|
||||
def get_env(key: str, default: Optional[str] = None) -> Optional[str]:
|
||||
return EnvironmentManager.get(key, default)
|
||||
|
||||
@staticmethod
|
||||
def set_env(key: str, value: str) -> bool:
|
||||
return EnvironmentManager.set(key, value)
|
||||
|
||||
@staticmethod
|
||||
def exists_env(key: str) -> bool:
|
||||
return EnvironmentManager.exists(key)
|
||||
|
||||
|
||||
class Config:
|
||||
"""配置类(兼容性)"""
|
||||
pass
|
||||
|
||||
|
||||
async def test_auth_service():
|
||||
"""测试认证服务"""
|
||||
auth_service = AuthService()
|
||||
token_header = await auth_service.check_user_token(
|
||||
"1932715213891215361", "1932385006853664770"
|
||||
)
|
||||
logger.info(f"测试结果 - Token头: {token_header}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置日志
|
||||
from ..util.logger_config import setup_logging
|
||||
import logging
|
||||
setup_logging(log_level=logging.INFO)
|
||||
|
||||
asyncio.run(test_auth_service())
|
||||
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
API基础模块 - 核心API管理和调用功能
|
||||
|
||||
这个模块是整个系统的核心,提供了API配置管理、Schema生成、接口调用等基础功能。
|
||||
主要包含以下组件:
|
||||
|
||||
1. ApiBase类: API管理的抽象基类
|
||||
2. 工具函数: 拼音转换、Schema处理等
|
||||
3. 常量定义: 认证级别、参数类型等
|
||||
|
||||
主要功能:
|
||||
- API配置的加载和处理
|
||||
- JSON Schema的生成和验证
|
||||
- 中文接口名称转拼音命名
|
||||
- API接口的统一调用管理
|
||||
- 认证和参数处理
|
||||
|
||||
作者: lzwcai
|
||||
版本: 1.0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import pypinyin
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Tuple, Optional, Union
|
||||
|
||||
# 导入业务工具模块
|
||||
from ..business.business_util import (
|
||||
generate_json_schema, # JSON Schema生成
|
||||
generate_schema_prompt, # Schema提示文本生成
|
||||
remove_property_from_api_array, # API数组属性移除
|
||||
)
|
||||
|
||||
|
||||
# ==================== 常量定义 ====================
|
||||
|
||||
class AuthenticationLevel:
|
||||
"""
|
||||
API认证级别常量
|
||||
|
||||
定义API接口的认证要求级别:
|
||||
- REQUIRED: 需要认证(值为1)
|
||||
- NOT_REQUIRED: 不需要认证(值为0)
|
||||
"""
|
||||
REQUIRED = 1 # 需要认证
|
||||
NOT_REQUIRED = 0 # 不需要认证
|
||||
|
||||
|
||||
# 导入统一日志配置
|
||||
from ..util.logger_config import get_logger
|
||||
|
||||
# 获取日志器实例
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def pinyin_to_camel(text: str) -> str:
|
||||
"""
|
||||
中文文本转拼音驼峰命名函数
|
||||
|
||||
将中文文本转换为带'tool_'前缀的驼峰命名格式,用于生成API工具的标识符。
|
||||
这个函数是系统中重要的命名转换工具,确保中文API名称能够转换为合法的标识符。
|
||||
|
||||
转换规则:
|
||||
1. 将所有非字母数字字符(包括标点符号)替换为下划线
|
||||
2. 将空格替换为下划线
|
||||
3. 移除连续的下划线并去除首尾下划线
|
||||
4. 使用pypinyin库将中文转换为拼音
|
||||
5. 每个拼音单词首字母大写(驼峰格式)
|
||||
6. 添加'tool_'前缀以符合工具命名规范
|
||||
|
||||
参数:
|
||||
text: 要转换的中文文本
|
||||
|
||||
返回:
|
||||
str: 转换后的驼峰命名字符串,格式为'tool_XxxYyy'
|
||||
|
||||
示例:
|
||||
>>> pinyin_to_camel("用户登录")
|
||||
'tool_YongHuDengLu'
|
||||
>>> pinyin_to_camel("获取订单列表")
|
||||
'tool_HuoQuDingDanLieBiao'
|
||||
|
||||
异常处理:
|
||||
TypeError: 如果输入不是字符串类型
|
||||
ValueError: 如果输入为空字符串
|
||||
|
||||
容错机制:
|
||||
- 转换失败时使用hash值生成备用名称
|
||||
- 确保始终返回有效的标识符
|
||||
"""
|
||||
# 参数类型检查
|
||||
if not isinstance(text, str):
|
||||
raise TypeError("text must be a string")
|
||||
|
||||
# 参数内容检查
|
||||
if not text.strip():
|
||||
raise ValueError("text cannot be empty")
|
||||
|
||||
try:
|
||||
logger.debug(f"转换中文文本为拼音: {text}")
|
||||
|
||||
# 第一步:将所有非中文、非字母、非数字的字符(包括中文标点符号)替换为空格
|
||||
# 这样可以正确处理中文标点符号(包括中文括号、顿号等)
|
||||
# \u4e00-\u9fff 匹配所有中文字符
|
||||
# a-zA-Z0-9 匹配英文字母和数字
|
||||
cleaned = re.sub(r'[^\u4e00-\u9fffa-zA-Z0-9\s]', ' ', text)
|
||||
|
||||
# 第二步:将多个空格合并为一个空格,并去除首尾空格
|
||||
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
|
||||
|
||||
# 第三步:使用pypinyin库转换为拼音列表
|
||||
# pypinyin会将中文转为拼音,英文和数字保持原样
|
||||
pinyin_list = pypinyin.lazy_pinyin(cleaned)
|
||||
|
||||
# 第四步:将拼音列表转换为驼峰格式
|
||||
# 过滤掉空白字符、下划线等特殊字符,只保留有效的拼音单词
|
||||
# 注意:pypinyin对于空格会产生空字符串,需要过滤掉
|
||||
camel_case = "".join(
|
||||
word.strip().capitalize()
|
||||
for word in pinyin_list
|
||||
if word.strip() and word.strip() not in ['_', '-', '.', ' ']
|
||||
)
|
||||
|
||||
# 第五步:添加工具前缀
|
||||
result = f"tool_{camel_case}"
|
||||
|
||||
logger.debug(f"拼音转换结果: {text} -> {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"拼音转换失败,文本: '{text}', 错误: {str(e)}")
|
||||
|
||||
# 容错处理:生成基于hash的备用名称
|
||||
# 使用hash确保相同输入产生相同输出
|
||||
fallback = f"tool_Unknown_{hash(text) % 10000}"
|
||||
logger.warning(f"使用备用名称: {fallback}")
|
||||
return fallback
|
||||
|
||||
|
||||
def _process_api_schema(param: Dict[str, Any]) -> Tuple[Dict[str, Any], str]:
|
||||
"""
|
||||
API参数Schema处理函数
|
||||
|
||||
这个函数负责处理单个API配置,生成对应的JSON Schema和接口名称。
|
||||
它是API配置转换的核心函数,将业务平台的API配置转换为MCP工具所需的格式。
|
||||
|
||||
处理流程:
|
||||
1. 提取API参数列表
|
||||
2. 生成JSON Schema(包含参数类型、描述、默认值等)
|
||||
3. 将中文接口名转换为拼音格式的工具名称
|
||||
4. 返回处理后的Schema和接口名称
|
||||
|
||||
参数:
|
||||
param: API参数配置字典,包含以下字段:
|
||||
- interfaceName: 接口名称(中文)
|
||||
- parameters: 参数列表
|
||||
- 其他API配置信息
|
||||
|
||||
返回:
|
||||
tuple: (处理后的JSON Schema, 转换后的接口名称)
|
||||
- processed_schema: 符合JSON Schema规范的参数定义
|
||||
- interface_name: 转换后的拼音格式接口名称
|
||||
|
||||
异常处理:
|
||||
ImportError: 当无法导入必需模块时抛出
|
||||
Exception: 当Schema生成失败时抛出
|
||||
|
||||
注意事项:
|
||||
- userId参数的处理在请求时进行,而不是在Schema生成时
|
||||
- userId现在存储在lzwcaiConfig分组中,支持动态userId值,提高系统灵活性
|
||||
- 使用延迟导入避免循环依赖问题
|
||||
"""
|
||||
try:
|
||||
# 延迟导入避免循环依赖
|
||||
# 这些模块可能会反过来导入当前模块
|
||||
from .core_server import get_env_user_id
|
||||
# from ..business.business_util import remove_property_from_api_item
|
||||
|
||||
logger.debug(f"处理API参数: {param.get('interfaceName', 'N/A')}")
|
||||
|
||||
# 提取API参数列表并生成JSON Schema
|
||||
parameters = param.get("parameters", [])
|
||||
logger.debug(f"参数数量: {len(parameters)}")
|
||||
|
||||
# 调用业务工具模块生成标准JSON Schema
|
||||
schema = generate_json_schema(parameters)
|
||||
|
||||
# 重要说明:userId参数的处理策略
|
||||
# 为了支持动态userId值,userId参数的处理在请求时进行,
|
||||
# 而不是在Schema生成时进行。这样可以支持不同用户的动态切换。
|
||||
# userId现在存储在lzwcaiConfig分组中。
|
||||
logger.debug("Schema生成完成,userId处理将在请求时进行")
|
||||
|
||||
# 生成接口名称(中文转拼音)
|
||||
interface_name_raw = param.get("interfaceName", "")
|
||||
if interface_name_raw:
|
||||
# 使用拼音转换函数生成工具名称
|
||||
interface_name = pinyin_to_camel(interface_name_raw)
|
||||
else:
|
||||
# 备用名称,防止接口名称为空
|
||||
interface_name = "tool_Unknown"
|
||||
|
||||
logger.debug(f"生成接口名称: {interface_name_raw} -> {interface_name}")
|
||||
return schema, interface_name
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"导入必需模块失败: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"处理API Schema时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_api_configs_map(api_configs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
API配置映射处理函数
|
||||
|
||||
这个函数是API配置处理的核心,负责将原始的API配置列表转换为
|
||||
包含Schema和描述信息的处理后配置列表。每个API配置都会被转换为
|
||||
一个完整的MCP工具定义。
|
||||
|
||||
处理内容:
|
||||
1. 为每个API生成JSON Schema
|
||||
2. 转换中文接口名为拼音格式
|
||||
3. 生成工具描述(业务描述 + 参数说明)
|
||||
4. 创建完整的工具配置对象
|
||||
5. 错误处理和容错机制
|
||||
|
||||
参数:
|
||||
api_configs: API配置字典列表,每个字典包含:
|
||||
- interfaceName: 接口名称
|
||||
- businessPrompts: 业务描述
|
||||
- parameters: 参数列表
|
||||
- 其他API配置信息
|
||||
|
||||
返回:
|
||||
list: 处理后的API配置列表,每个配置包含:
|
||||
- interfaceName: 转换后的接口名称(拼音格式)
|
||||
- schema: JSON Schema对象
|
||||
- schema_description: 完整的工具描述
|
||||
- 原始配置的所有其他字段
|
||||
|
||||
异常处理:
|
||||
TypeError: 如果api_configs不是列表类型
|
||||
ValueError: 如果api_configs为空列表
|
||||
|
||||
容错机制:
|
||||
- 跳过无效的配置项(非字典类型)
|
||||
- 处理失败时保留原始配置并添加错误标记
|
||||
- 详细记录处理过程和错误信息
|
||||
"""
|
||||
# 参数验证
|
||||
if not isinstance(api_configs, list):
|
||||
raise TypeError("api_configs must be a list")
|
||||
|
||||
if not api_configs:
|
||||
raise ValueError("api_configs cannot be empty")
|
||||
|
||||
logger.info(f"开始处理 {len(api_configs)} 个API配置")
|
||||
api_array = []
|
||||
|
||||
# 遍历处理每个API配置
|
||||
for i, param in enumerate(api_configs):
|
||||
# 检查配置项类型
|
||||
if not isinstance(param, dict):
|
||||
logger.warning(
|
||||
f"跳过无效的API配置 (索引 {i}): 不是字典类型"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.debug(f"处理API配置 {i}: {param.get('interfaceName', 'N/A')}")
|
||||
|
||||
# 处理Schema和接口名称
|
||||
schema, interface_name = _process_api_schema(param)
|
||||
|
||||
# 获取业务描述
|
||||
description = param.get("businessPrompts", "")
|
||||
|
||||
# 生成参数Schema的提示文本
|
||||
schema_prompt = generate_schema_prompt(schema)
|
||||
|
||||
# 组合工具描述:业务描述 + 参数说明
|
||||
if description:
|
||||
schema_description = f"{description}\n\n{schema_prompt}"
|
||||
else:
|
||||
schema_description = f"工具描述: 暂无描述\n\n{schema_prompt}"
|
||||
|
||||
# 创建处理后的API配置对象
|
||||
processed_config = {
|
||||
**param, # 保留原始配置的所有字段
|
||||
"interfaceName": interface_name, # 更新接口名称为拼音格式
|
||||
"schema": schema, # 添加JSON Schema
|
||||
"schema_description": schema_description, # 添加完整描述
|
||||
}
|
||||
|
||||
api_array.append(processed_config)
|
||||
logger.debug(f"API配置 {i} 处理完成: {interface_name}")
|
||||
|
||||
except Exception as e:
|
||||
# 处理异常:记录错误但不中断整个处理流程
|
||||
logger.error(f"处理API配置 {i} 时出错: {str(e)}")
|
||||
logger.debug(f"错误的API配置内容: {param}")
|
||||
|
||||
# 创建错误配置对象,保留原始数据但添加错误标记
|
||||
error_config = {
|
||||
**param, # 保留原始配置
|
||||
"interfaceName": f"tool_Error_{i}", # 错误标记的接口名
|
||||
"schema": {}, # 空Schema
|
||||
"schema_description": f"配置处理错误: {str(e)}", # 错误描述
|
||||
}
|
||||
api_array.append(error_config)
|
||||
|
||||
logger.info(f"API配置处理完成,成功处理 {len(api_array)} 个配置")
|
||||
return api_array
|
||||
|
||||
|
||||
# ==================== API基础管理类 ====================
|
||||
|
||||
class ApiBase(ABC):
|
||||
"""
|
||||
API管理抽象基类
|
||||
|
||||
这个类是整个API管理系统的核心基类,提供了API配置处理、
|
||||
Schema生成、接口调用等基础功能。所有具体的API管理实现
|
||||
都应该继承这个类。
|
||||
|
||||
主要职责:
|
||||
1. API配置的加载和验证
|
||||
2. API配置到工具定义的转换
|
||||
3. 提供统一的API调用接口
|
||||
4. 管理API配置的生命周期
|
||||
5. 提供配置查询和检索功能
|
||||
|
||||
设计模式:
|
||||
- 抽象基类模式:定义API管理的标准接口
|
||||
- 模板方法模式:提供通用的处理流程
|
||||
- 策略模式:支持不同的认证和调用策略
|
||||
|
||||
属性:
|
||||
api_configs: 原始API配置列表
|
||||
api_configs_map: 处理后的API配置映射表
|
||||
"""
|
||||
|
||||
def __init__(self, api_configs: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
初始化API基础管理器
|
||||
|
||||
这个构造函数负责初始化API管理器,处理传入的API配置列表,
|
||||
并生成相应的工具定义映射表。
|
||||
|
||||
初始化流程:
|
||||
1. 验证输入参数的类型和内容
|
||||
2. 保存原始API配置
|
||||
3. 调用配置映射函数生成工具定义
|
||||
4. 记录初始化结果
|
||||
5. 可选的调试输出
|
||||
|
||||
参数:
|
||||
api_configs: API配置字典列表,每个字典包含完整的API定义
|
||||
|
||||
异常处理:
|
||||
TypeError: 如果api_configs不是列表类型
|
||||
ValueError: 如果api_configs为空列表
|
||||
|
||||
注意事项:
|
||||
- 配置验证在映射函数中进行
|
||||
- 支持部分配置失败的容错处理
|
||||
- 调试模式下可以输出Schema到文件
|
||||
"""
|
||||
# 参数类型验证
|
||||
if not isinstance(api_configs, list):
|
||||
raise TypeError("api_configs must be a list")
|
||||
|
||||
# 参数内容验证
|
||||
if not api_configs:
|
||||
raise ValueError("api_configs cannot be empty")
|
||||
|
||||
# 保存原始配置
|
||||
self.api_configs = api_configs
|
||||
|
||||
# 生成处理后的配置映射表
|
||||
# 这是核心处理步骤,将原始配置转换为MCP工具定义
|
||||
self.api_configs_map = get_api_configs_map(api_configs)
|
||||
|
||||
logger.info(f"ApiBase初始化完成,共处理 {len(self.api_configs)} 个API配置")
|
||||
|
||||
# 可选的调试输出(默认关闭)
|
||||
# 在开发和调试阶段可以启用这个功能
|
||||
# self._save_debug_schema()
|
||||
|
||||
def _save_debug_schema(self) -> None:
|
||||
"""
|
||||
保存调试Schema到文件
|
||||
|
||||
这个方法用于开发和调试阶段,将处理后的API配置映射表
|
||||
保存到JSON文件中,方便查看和分析Schema生成结果。
|
||||
|
||||
输出文件:
|
||||
output_schema.json: 包含所有处理后的API配置
|
||||
|
||||
特性:
|
||||
- UTF-8编码确保中文正确显示
|
||||
- 格式化输出便于阅读
|
||||
- 异常安全,不会影响主要功能
|
||||
"""
|
||||
try:
|
||||
with open("output_schema.json", "w", encoding="utf-8") as f:
|
||||
json.dump(self.api_configs_map, f, ensure_ascii=False, indent=4)
|
||||
logger.debug("调试Schema已保存到 output_schema.json")
|
||||
except Exception as e:
|
||||
logger.error(f"保存调试Schema失败: {str(e)}")
|
||||
|
||||
async def call_interface(
|
||||
self, api_config: Dict[str, Any], request_data: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
API接口调用方法
|
||||
|
||||
这是ApiBase类的核心方法,负责调用具体的API接口。
|
||||
它处理认证逻辑、参数传递和错误处理,为上层提供统一的API调用接口。
|
||||
|
||||
调用流程:
|
||||
1. 验证输入参数的类型
|
||||
2. 导入核心服务器模块(避免循环依赖)
|
||||
3. 判断是否需要认证
|
||||
4. 调用底层API接口
|
||||
5. 返回API响应结果
|
||||
|
||||
认证处理:
|
||||
- 根据API配置中的authenticationRequired字段判断是否需要认证
|
||||
- 支持两种认证级别:REQUIRED(1) 和 NOT_REQUIRED(0)
|
||||
- 认证逻辑由core_server模块的call_api函数处理
|
||||
|
||||
参数:
|
||||
api_config: API配置字典,包含:
|
||||
- authenticationRequired: 认证要求级别
|
||||
- apiUrl: API接口地址
|
||||
- method: HTTP方法
|
||||
- 其他API配置信息
|
||||
request_data: 请求数据字典,包含:
|
||||
- header: 请求头参数
|
||||
- query: 查询参数
|
||||
- body: 请求体参数
|
||||
- lzwcaiConfig: 配置参数(包含userId)
|
||||
|
||||
返回:
|
||||
Any: API接口的响应数据,通常是字典格式
|
||||
|
||||
异常处理:
|
||||
TypeError: 如果参数不是字典类型
|
||||
ImportError: 如果无法导入核心服务器模块
|
||||
Exception: 如果API调用失败
|
||||
|
||||
设计考虑:
|
||||
- 使用延迟导入避免循环依赖
|
||||
- 统一的错误处理和日志记录
|
||||
- 支持异步调用以提高性能
|
||||
"""
|
||||
# 参数类型验证
|
||||
if not isinstance(api_config, dict):
|
||||
raise TypeError("api_config must be a dictionary")
|
||||
|
||||
if not isinstance(request_data, dict):
|
||||
raise TypeError("request_data must be a dictionary")
|
||||
|
||||
try:
|
||||
# 延迟导入避免循环依赖
|
||||
# core_server模块可能会导入当前模块
|
||||
from .core_server import call_api
|
||||
|
||||
# 判断认证要求
|
||||
# 从API配置中获取认证要求,默认为不需要认证
|
||||
auth_required = api_config.get(
|
||||
"authenticationRequired", AuthenticationLevel.NOT_REQUIRED
|
||||
)
|
||||
need_auth = auth_required == AuthenticationLevel.REQUIRED
|
||||
|
||||
logger.info(f"调用API接口,需要认证: {need_auth}")
|
||||
logger.debug(f"API配置: {api_config.get('apiUrl', 'N/A')}")
|
||||
|
||||
# 调用底层API接口
|
||||
# call_api函数处理具体的HTTP请求、认证、参数处理等
|
||||
return await call_api(api_config, request_data, need_auth=need_auth)
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"导入call_api函数失败: {str(e)}")
|
||||
raise ImportError(
|
||||
f"无法导入必需的API调用功能: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"API调用失败: {str(e)}")
|
||||
logger.debug("API调用异常详情:", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_api_config_by_name(self, interface_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据接口名称获取API配置
|
||||
|
||||
这个方法用于根据接口名称查找对应的API配置。
|
||||
接口名称是经过拼音转换后的工具名称(如:tool_YongHuDengLu)。
|
||||
|
||||
查找逻辑:
|
||||
- 遍历所有处理后的API配置
|
||||
- 匹配interfaceName字段
|
||||
- 返回第一个匹配的配置
|
||||
|
||||
参数:
|
||||
interface_name: 要查找的接口名称(拼音格式)
|
||||
|
||||
返回:
|
||||
Optional[Dict[str, Any]]: 找到的API配置字典,未找到则返回None
|
||||
|
||||
使用场景:
|
||||
- MCP工具调用时查找对应的API配置
|
||||
- 验证工具名称是否存在
|
||||
- 获取特定工具的配置信息
|
||||
"""
|
||||
# 参数类型检查
|
||||
if not isinstance(interface_name, str):
|
||||
logger.warning(f"接口名称类型错误: {type(interface_name)}")
|
||||
return None
|
||||
|
||||
# 遍历查找匹配的配置
|
||||
for config in self.api_configs_map:
|
||||
if config.get("interfaceName") == interface_name:
|
||||
logger.debug(f"找到接口配置: {interface_name}")
|
||||
return config
|
||||
|
||||
logger.debug(f"未找到接口配置: {interface_name}")
|
||||
return None
|
||||
|
||||
def get_all_interface_names(self) -> List[str]:
|
||||
"""
|
||||
获取所有可用的接口名称列表
|
||||
|
||||
这个方法返回所有已处理的API配置的接口名称列表。
|
||||
主要用于调试、监控和工具列表展示。
|
||||
|
||||
返回:
|
||||
List[str]: 所有接口名称的列表(拼音格式)
|
||||
|
||||
特性:
|
||||
- 过滤掉空的接口名称
|
||||
- 返回的是处理后的拼音格式名称
|
||||
- 按配置顺序返回
|
||||
|
||||
使用场景:
|
||||
- 系统监控和状态检查
|
||||
- 调试和日志记录
|
||||
- 管理界面展示可用工具
|
||||
"""
|
||||
interface_names = [
|
||||
config.get("interfaceName", "")
|
||||
for config in self.api_configs_map
|
||||
if config.get("interfaceName") # 过滤空名称
|
||||
]
|
||||
|
||||
logger.debug(f"获取到 {len(interface_names)} 个接口名称")
|
||||
return interface_names
|
||||
|
||||
def get_schema_by_name(self, interface_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据接口名称获取JSON Schema
|
||||
|
||||
这个方法是get_api_config_by_name的便捷包装,
|
||||
直接返回指定接口的JSON Schema定义。
|
||||
|
||||
参数:
|
||||
interface_name: 接口名称(拼音格式)
|
||||
|
||||
返回:
|
||||
Optional[Dict[str, Any]]: JSON Schema字典,未找到则返回None
|
||||
|
||||
使用场景:
|
||||
- 参数验证
|
||||
- 文档生成
|
||||
- 客户端工具定义
|
||||
"""
|
||||
config = self.get_api_config_by_name(interface_name)
|
||||
if config:
|
||||
return config.get("schema")
|
||||
return None
|
||||
|
||||
@property
|
||||
def config_count(self) -> int:
|
||||
"""
|
||||
获取API配置数量
|
||||
|
||||
这是一个属性方法,返回当前管理的API配置总数。
|
||||
主要用于监控、日志记录和状态检查。
|
||||
|
||||
返回:
|
||||
int: API配置的数量
|
||||
"""
|
||||
return len(self.api_configs_map)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
ApiBase对象的字符串表示
|
||||
|
||||
提供对象的简洁字符串表示,主要用于调试和日志记录。
|
||||
|
||||
返回:
|
||||
str: 对象的字符串表示,格式为"ApiBase(configs=数量)"
|
||||
"""
|
||||
return f"ApiBase(configs={self.config_count})"
|
||||
@@ -0,0 +1,867 @@
|
||||
"""
|
||||
核心服务器模块
|
||||
|
||||
提供API调用、认证处理、配置管理等核心功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from .api_auth_service import AuthService
|
||||
from ..util.logger_config import get_logger
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ApiError(Exception):
|
||||
"""API调用相关异常"""
|
||||
|
||||
def __init__(self, message: str, status_code: Optional[int] = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class ResponseSaver:
|
||||
"""响应数据保存管理器"""
|
||||
|
||||
def __init__(self, save_dir: str = "lzwcai_mcp_dyntoolapi_log_call_api"):
|
||||
"""
|
||||
初始化响应保存器
|
||||
|
||||
Args:
|
||||
save_dir: 保存目录路径
|
||||
"""
|
||||
self.save_dir = Path(save_dir)
|
||||
self.save_dir.mkdir(exist_ok=True)
|
||||
logger.debug(f"响应保存目录: {self.save_dir.absolute()}")
|
||||
|
||||
def save_response(
|
||||
self,
|
||||
response_data: Dict[str, Any],
|
||||
api_url: str,
|
||||
method: str = "GET",
|
||||
request_data: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
保存API响应到本地JSON文件
|
||||
|
||||
Args:
|
||||
response_data: 响应数据
|
||||
api_url: API URL
|
||||
method: HTTP方法
|
||||
request_data: 请求数据
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
try:
|
||||
# 生成文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # 精确到毫秒
|
||||
safe_url = self._sanitize_filename(api_url)
|
||||
filename = f"{timestamp}_{method}_{safe_url}.json"
|
||||
file_path = self.save_dir / filename
|
||||
|
||||
# 构建保存的数据结构
|
||||
save_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"api_info": {
|
||||
"url": api_url,
|
||||
"method": method.upper(),
|
||||
},
|
||||
"request_data": request_data or {},
|
||||
"response_data": response_data,
|
||||
"metadata": {
|
||||
"saved_at": datetime.now().isoformat(),
|
||||
"file_name": filename,
|
||||
}
|
||||
}
|
||||
|
||||
# 保存到文件
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(save_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"API响应已保存到: {file_path}")
|
||||
return str(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存API响应失败: {str(e)}")
|
||||
return ""
|
||||
|
||||
def _sanitize_filename(self, url: str) -> str:
|
||||
"""
|
||||
清理URL以生成安全的文件名
|
||||
|
||||
Args:
|
||||
url: 原始URL
|
||||
|
||||
Returns:
|
||||
清理后的文件名部分
|
||||
"""
|
||||
# 移除协议和域名,只保留路径
|
||||
if "://" in url:
|
||||
url = url.split("://", 1)[1]
|
||||
if "/" in url:
|
||||
url = url.split("/", 1)[1] if "/" in url else url
|
||||
|
||||
# 替换特殊字符
|
||||
safe_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
|
||||
sanitized = "".join(c if c in safe_chars else "_" for c in url)
|
||||
|
||||
# 限制长度
|
||||
return sanitized[:50] if len(sanitized) > 50 else sanitized
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器"""
|
||||
|
||||
@staticmethod
|
||||
def load_api_config(file_path: str = "generator_api.json") -> Dict[str, Any]:
|
||||
"""
|
||||
加载API配置文件
|
||||
|
||||
Args:
|
||||
file_path: JSON配置文件路径
|
||||
|
||||
Returns:
|
||||
解析后的API配置字典
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
json.JSONDecodeError: JSON格式错误
|
||||
"""
|
||||
logger.debug(f"尝试加载配置文件: {file_path}")
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
config = json.load(file)
|
||||
logger.info(f"成功加载配置文件: {file_path}")
|
||||
logger.debug(f"配置内容: {config}")
|
||||
return config
|
||||
except FileNotFoundError:
|
||||
logger.error(f"配置文件未找到: {file_path}")
|
||||
raise FileNotFoundError(f"配置文件未找到: {file_path}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"配置文件JSON格式错误: {file_path} - {str(e)}")
|
||||
raise json.JSONDecodeError(
|
||||
f"配置文件JSON格式错误: {file_path} - {str(e)}", e.doc, e.pos
|
||||
)
|
||||
|
||||
|
||||
class UserManager:
|
||||
"""用户管理器"""
|
||||
|
||||
@staticmethod
|
||||
def get_user_id_from_env() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
从环境变量中获取用户ID
|
||||
|
||||
Returns:
|
||||
tuple: (是否成功获取, 用户ID值)
|
||||
"""
|
||||
try:
|
||||
user_id = os.environ.get("userId")
|
||||
if user_id:
|
||||
logger.debug(f"从环境变量获取用户ID: {user_id}")
|
||||
return True, user_id
|
||||
logger.debug("环境变量中未找到用户ID")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logger.warning(f"获取环境变量用户ID时发生异常: {e}")
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def extract_user_id_from_request(
|
||||
request_data: Dict[str, Any], is_grouped_format: bool
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从请求数据中提取用户ID
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
is_grouped_format: 是否为分组格式
|
||||
|
||||
Returns:
|
||||
用户ID或None
|
||||
"""
|
||||
# 优先从请求数据中获取userId
|
||||
if isinstance(request_data, dict):
|
||||
if is_grouped_format:
|
||||
# 按优先级顺序查找:lzwcaiConfig > header > body > userId(兼容旧版本)
|
||||
user_id = (
|
||||
request_data.get("lzwcaiConfig", {}).get("userId")
|
||||
or request_data.get("header", {}).get("userId")
|
||||
or request_data.get("body", {}).get("userId")
|
||||
or (request_data.get("userId") or {}).get("userId")
|
||||
)
|
||||
if user_id:
|
||||
logger.info(f"从请求数据中获取到用户ID: {user_id}")
|
||||
return user_id
|
||||
else:
|
||||
# 非分组格式:优先从lzwcaiConfig获取,然后是userId(兼容旧版本)
|
||||
user_id = (
|
||||
(request_data.get("lzwcaiConfig") or {}).get("userId")
|
||||
or (request_data.get("userId") or {}).get("userId")
|
||||
)
|
||||
if user_id:
|
||||
logger.info(f"从请求数据中获取到用户ID: {user_id}")
|
||||
return user_id
|
||||
|
||||
# 如果请求数据中没有userId,则从环境变量获取作为备用
|
||||
success, env_user_id = UserManager.get_user_id_from_env()
|
||||
if success:
|
||||
logger.info(f"从环境变量获取到用户ID: {env_user_id}")
|
||||
return env_user_id
|
||||
|
||||
logger.warning("未能从请求数据或环境变量中获取到用户ID")
|
||||
return None
|
||||
|
||||
|
||||
class HeaderProcessor:
|
||||
"""请求头处理器"""
|
||||
|
||||
@staticmethod
|
||||
def validate_header_value(value: Any) -> str:
|
||||
"""
|
||||
验证并标准化请求头值
|
||||
|
||||
Args:
|
||||
value: 原始值
|
||||
|
||||
Returns:
|
||||
标准化后的字符串值
|
||||
"""
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
@staticmethod
|
||||
def process_auth_headers(
|
||||
base_headers: Dict[str, Any],
|
||||
request_data: Dict[str, Any],
|
||||
auth_token: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
处理认证头信息
|
||||
|
||||
Args:
|
||||
base_headers: 基础请求头
|
||||
request_data: 请求数据
|
||||
auth_token: 认证token信息
|
||||
|
||||
Returns:
|
||||
处理后的请求头字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当base_headers不是字典类型时
|
||||
"""
|
||||
if not isinstance(base_headers, dict):
|
||||
raise ValueError("base_headers必须是字典类型")
|
||||
|
||||
if not isinstance(request_data, dict):
|
||||
request_data = {}
|
||||
|
||||
if auth_token is None:
|
||||
auth_token = {}
|
||||
|
||||
# 获取请求数据中的header部分
|
||||
request_headers = request_data.get("header", {})
|
||||
|
||||
# 按优先级合并headers: base < request < auth_token
|
||||
processed_headers = {}
|
||||
|
||||
# 1. 基础headers (key转为小写以避免重复)
|
||||
for key, value in base_headers.items():
|
||||
processed_headers[key.lower()] = HeaderProcessor.validate_header_value(value)
|
||||
|
||||
# 2. 请求中的headers (key转为小写以避免重复)
|
||||
for key, value in request_headers.items():
|
||||
processed_headers[key.lower()] = HeaderProcessor.validate_header_value(value)
|
||||
|
||||
# 3. 认证token headers (最高优先级, key转为小写以避免重复)
|
||||
for key, value in auth_token.items():
|
||||
processed_headers[key.lower()] = HeaderProcessor.validate_header_value(value)
|
||||
|
||||
return processed_headers
|
||||
|
||||
|
||||
class RequestBuilder:
|
||||
"""请求构建器"""
|
||||
|
||||
@staticmethod
|
||||
def is_grouped_format(request_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
检查请求数据是否为分组格式
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
|
||||
Returns:
|
||||
是否为分组格式
|
||||
"""
|
||||
return any(
|
||||
key in ["header", "path", "query", "body", "params", "form", "formdata"]
|
||||
for key in request_data.keys()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_url_with_path_params(base_url: str, path_params: Dict[str, Any]) -> str:
|
||||
"""
|
||||
使用路径参数构建URL
|
||||
|
||||
Args:
|
||||
base_url: 基础URL
|
||||
path_params: 路径参数
|
||||
|
||||
Returns:
|
||||
替换路径参数后的URL
|
||||
"""
|
||||
url = base_url
|
||||
for key, value in path_params.items():
|
||||
placeholder = f"{{{key}}}"
|
||||
if placeholder in url:
|
||||
url = url.replace(placeholder, str(value))
|
||||
return url
|
||||
|
||||
@staticmethod
|
||||
def extract_parameters_from_grouped_format(
|
||||
request_data: Dict[str, Any], method: str
|
||||
) -> Tuple[
|
||||
Dict[str, Any],
|
||||
Dict[str, Any],
|
||||
Optional[Dict[str, Any]],
|
||||
Optional[Dict[str, Any]],
|
||||
Optional[Dict[str, Any]],
|
||||
]:
|
||||
"""
|
||||
从分组格式请求数据中提取参数
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
method: HTTP方法
|
||||
|
||||
Returns:
|
||||
(path_params, query_params, json_data, form_data, formdata_data)
|
||||
"""
|
||||
path_params = {}
|
||||
query_params = {}
|
||||
json_data = None
|
||||
form_data = None
|
||||
formdata_data = None
|
||||
|
||||
# 处理路径参数
|
||||
if "path" in request_data:
|
||||
path_params.update(request_data["path"])
|
||||
|
||||
# 兼容params命名(与path含义相同)
|
||||
if "params" in request_data:
|
||||
path_params.update(request_data["params"])
|
||||
|
||||
# 处理查询参数
|
||||
if "query" in request_data:
|
||||
query_params.update(request_data["query"])
|
||||
|
||||
# 处理请求体数据
|
||||
if "body" in request_data and method.upper() in ["POST", "PUT", "PATCH"]:
|
||||
json_data = request_data["body"]
|
||||
|
||||
if "form" in request_data and method.upper() in ["POST", "PUT", "PATCH"]:
|
||||
form_data = request_data["form"]
|
||||
|
||||
if "formdata" in request_data and method.upper() in ["POST", "PUT", "PATCH"]:
|
||||
formdata_data = request_data["formdata"]
|
||||
|
||||
return path_params, query_params, json_data, form_data, formdata_data
|
||||
|
||||
@staticmethod
|
||||
def extract_parameters_from_flat_format(
|
||||
request_data: Dict[str, Any], parameters: List[Dict[str, Any]], method: str
|
||||
) -> Tuple[
|
||||
Dict[str, str],
|
||||
Dict[str, Any],
|
||||
Dict[str, Any],
|
||||
Optional[Dict[str, Any]],
|
||||
Optional[Dict[str, Any]],
|
||||
Optional[Dict[str, Any]],
|
||||
]:
|
||||
"""
|
||||
从扁平格式请求数据中提取参数
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
parameters: 参数配置列表
|
||||
method: HTTP方法
|
||||
|
||||
Returns:
|
||||
(headers, path_params, query_params, json_data, form_data, formdata_data)
|
||||
"""
|
||||
headers = {}
|
||||
path_params = {}
|
||||
query_params = {}
|
||||
json_data = None
|
||||
form_data = None # for application/x-www-form-urlencoded
|
||||
formdata_data = None # for multipart/form-data
|
||||
|
||||
for param in parameters:
|
||||
param_name = param.get("paramName")
|
||||
request_type = param.get("requestType")
|
||||
default_value = param.get("defaultValue")
|
||||
|
||||
# 获取参数值:优先使用请求数据中的值,否则使用默认值
|
||||
if param_name in request_data:
|
||||
param_value = request_data[param_name]
|
||||
# 如果传入的值为None或空字符串,且有默认值,则使用默认值
|
||||
if param_value in (None, "") and default_value is not None:
|
||||
param_value = default_value
|
||||
else:
|
||||
param_value = default_value
|
||||
|
||||
# 如果参数没有值且不是必需的,则跳过
|
||||
if param_value is None and param.get("required", 0) == 0:
|
||||
continue
|
||||
|
||||
# 根据请求类型分配参数
|
||||
if request_type == "header":
|
||||
headers[param_name] = (
|
||||
str(param_value) if param_value is not None else ""
|
||||
)
|
||||
elif request_type == "query":
|
||||
if param_value is not None:
|
||||
query_params[param_name] = param_value
|
||||
elif request_type in ["params", "path"]:
|
||||
if param_value is not None:
|
||||
path_params[param_name] = param_value
|
||||
elif request_type == "body" and method.upper() in ["POST", "PUT", "PATCH"]:
|
||||
if json_data is None:
|
||||
json_data = {}
|
||||
json_data[param_name] = param_value
|
||||
elif request_type == "form" and method.upper() in ["POST", "PUT", "PATCH"]:
|
||||
if form_data is None:
|
||||
form_data = {}
|
||||
form_data[param_name] = param_value
|
||||
elif request_type == "formdata" and method.upper() in ["POST", "PUT", "PATCH"]:
|
||||
if formdata_data is None:
|
||||
formdata_data = {}
|
||||
formdata_data[param_name] = param_value
|
||||
|
||||
return headers, path_params, query_params, json_data, form_data, formdata_data
|
||||
|
||||
|
||||
class ApiClient:
|
||||
"""API客户端"""
|
||||
|
||||
def __init__(self, timeout: int = 30, verify: bool = True, save_responses: bool = True, save_dir: str = "lzwcai_mcp_dyntoolapi_log_call_api"):
|
||||
"""
|
||||
初始化API客户端
|
||||
|
||||
Args:
|
||||
timeout: 请求超时时间(秒)
|
||||
verify: 是否验证SSL证书
|
||||
save_responses: 是否保存响应到本地JSON文件
|
||||
save_dir: 响应保存目录
|
||||
"""
|
||||
self.timeout = timeout
|
||||
self.verify = verify
|
||||
self.save_responses = save_responses
|
||||
self.auth_service = AuthService()
|
||||
|
||||
# 初始化响应保存器
|
||||
if self.save_responses:
|
||||
self.response_saver = ResponseSaver(save_dir)
|
||||
logger.info(f"已启用响应保存功能,保存目录: {save_dir}")
|
||||
else:
|
||||
self.response_saver = None
|
||||
logger.info("响应保存功能已禁用")
|
||||
|
||||
async def call_api(
|
||||
self,
|
||||
api_config: Dict[str, Any],
|
||||
request_data: Optional[Dict[str, Any]] = None,
|
||||
need_auth: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
通用API调用方法
|
||||
|
||||
Args:
|
||||
api_config: API配置信息
|
||||
request_data: 请求数据,支持扁平格式和分组格式
|
||||
need_auth: 是否需要认证
|
||||
|
||||
Returns:
|
||||
API响应结果
|
||||
|
||||
Raises:
|
||||
ApiError: API调用相关错误
|
||||
ValueError: 配置或参数错误
|
||||
"""
|
||||
api_url = api_config.get("apiUrl", "N/A")
|
||||
logger.info(f"开始API调用: {api_url}")
|
||||
logger.debug(f"需要认证: {need_auth}")
|
||||
logger.debug(f"请求数据: {request_data}")
|
||||
|
||||
# 验证API配置
|
||||
self._validate_api_config(api_config)
|
||||
|
||||
# 准备基础信息
|
||||
domain_url = api_config["domainUrl"]
|
||||
api_url = api_config["apiUrl"]
|
||||
method = api_config.get("method", "GET")
|
||||
|
||||
# 构建完整URL
|
||||
full_url = urljoin(domain_url.rstrip("/") + "/", api_url.lstrip("/"))
|
||||
logger.debug(f"完整URL: {full_url}")
|
||||
logger.debug(f"HTTP方法: {method}")
|
||||
|
||||
# 初始化请求参数
|
||||
headers = {}
|
||||
query_params = {}
|
||||
path_params = {}
|
||||
json_data = None
|
||||
form_data = None
|
||||
formdata_data = None
|
||||
|
||||
request_data = request_data or {}
|
||||
parameters = api_config.get("parameters", [])
|
||||
logger.debug(f"参数配置数量: {len(parameters)}")
|
||||
|
||||
# 检查请求数据格式
|
||||
is_grouped = RequestBuilder.is_grouped_format(request_data)
|
||||
logger.debug(f"请求数据格式 - 分组格式: {is_grouped}")
|
||||
|
||||
# 处理认证
|
||||
if need_auth:
|
||||
logger.debug("开始处理认证")
|
||||
auth_headers = await self._handle_authentication(
|
||||
request_data, api_config, is_grouped
|
||||
)
|
||||
headers.update(auth_headers)
|
||||
logger.debug(f"认证头信息: {auth_headers}")
|
||||
else:
|
||||
logger.debug("跳过认证")
|
||||
|
||||
# 提取参数
|
||||
if is_grouped:
|
||||
(path_params, query_params, json_data, form_data, formdata_data) = (
|
||||
RequestBuilder.extract_parameters_from_grouped_format(
|
||||
request_data, method
|
||||
)
|
||||
)
|
||||
else:
|
||||
(
|
||||
param_headers,
|
||||
path_params,
|
||||
query_params,
|
||||
json_data,
|
||||
form_data,
|
||||
formdata_data,
|
||||
) = RequestBuilder.extract_parameters_from_flat_format(
|
||||
request_data, parameters, method
|
||||
)
|
||||
headers.update(param_headers)
|
||||
|
||||
logger.info(f"请求头: {headers},request_data: {request_data}")
|
||||
# 处理请求头
|
||||
headers = HeaderProcessor.process_auth_headers(headers, request_data)
|
||||
# 替换URL中的路径参数
|
||||
if path_params:
|
||||
full_url = RequestBuilder.build_url_with_path_params(full_url, path_params)
|
||||
|
||||
# 根据请求体内容设置Content-Type (如果未被显式设置)
|
||||
if "content-type" not in headers:
|
||||
if json_data is not None:
|
||||
headers["Content-Type"] = "application/json"
|
||||
# httpx会自动为form_data设置'application/x-www-form-urlencoded'
|
||||
# httpx会自动为formdata_data设置'multipart/form-data'并添加boundary
|
||||
|
||||
# 发送请求
|
||||
logger.info(f"发送HTTP请求: {method} {full_url}")
|
||||
|
||||
# 记录重要的请求信息(INFO级别,便于调试)
|
||||
if headers:
|
||||
# 过滤敏感信息,只显示关键头部
|
||||
safe_headers = {}
|
||||
for key, value in headers.items():
|
||||
if key.lower() in ['authorization', 'x-api-key', 'token']:
|
||||
safe_headers[key] = value
|
||||
else:
|
||||
safe_headers[key] = value
|
||||
logger.info(f"请求头: {safe_headers}")
|
||||
|
||||
if query_params:
|
||||
logger.info(f"查询参数: {query_params}")
|
||||
|
||||
if json_data:
|
||||
logger.info(f"请求体 (JSON): {json_data}")
|
||||
if form_data:
|
||||
logger.info(f"请求体 (Form): {form_data}")
|
||||
if formdata_data:
|
||||
logger.info(f"请求体 (FormData): {formdata_data}")
|
||||
|
||||
# 详细的调试信息仍保留在DEBUG级别
|
||||
logger.debug(f"完整请求头: {headers}")
|
||||
logger.debug(f"完整查询参数: {query_params}")
|
||||
logger.debug(f"完整请求体 (JSON): {json_data}")
|
||||
logger.debug(f"完整请求体 (Form): {form_data}")
|
||||
logger.debug(f"完整请求体 (FormData): {formdata_data}")
|
||||
|
||||
# 发送请求并获取响应
|
||||
response = await self._send_request(
|
||||
method, full_url, headers, query_params, json_data, form_data, formdata_data
|
||||
)
|
||||
|
||||
# 保存响应到本地JSON文件
|
||||
if self.save_responses and self.response_saver:
|
||||
try:
|
||||
saved_path = self.response_saver.save_response(
|
||||
response_data=response,
|
||||
api_url=full_url,
|
||||
method=method,
|
||||
request_data=request_data
|
||||
)
|
||||
if saved_path:
|
||||
logger.info(f"响应已保存到: {saved_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存响应失败: {str(e)}")
|
||||
|
||||
return response
|
||||
|
||||
def _validate_api_config(self, api_config: Dict[str, Any]) -> None:
|
||||
"""验证API配置"""
|
||||
logger.debug("验证API配置")
|
||||
|
||||
if not api_config:
|
||||
logger.error("API配置为空")
|
||||
raise ValueError("API配置不能为空")
|
||||
|
||||
if not api_config.get("domainUrl") or not api_config.get("apiUrl"):
|
||||
logger.error(f"缺少必要配置项 - domainUrl: {api_config.get('domainUrl')}, apiUrl: {api_config.get('apiUrl')}")
|
||||
raise ValueError("缺少必要的API配置项:domainUrl或apiUrl")
|
||||
|
||||
logger.debug("API配置验证通过")
|
||||
|
||||
async def _handle_authentication(
|
||||
self, request_data: Dict[str, Any], api_config: Dict[str, Any], is_grouped: bool
|
||||
) -> Dict[str, str]:
|
||||
"""处理认证逻辑"""
|
||||
user_id = UserManager.extract_user_id_from_request(request_data, is_grouped)
|
||||
biz_sys_id = api_config.get("bizSysId")
|
||||
|
||||
auth_result = await self.auth_service.authorize_request(user_id, biz_sys_id)
|
||||
|
||||
if not auth_result["success"]:
|
||||
raise ApiError(
|
||||
f"认证失败: {auth_result.get('error_response', {})}",
|
||||
auth_result.get("error_response", {}).get("status_code"),
|
||||
)
|
||||
|
||||
return auth_result.get("tokenHeader", {})
|
||||
|
||||
|
||||
def _contains_file(self, data: Dict[str, Any]) -> bool:
|
||||
"""检查数据字典中是否包含文件类对象"""
|
||||
if not data:
|
||||
return False
|
||||
for value in data.values():
|
||||
# 检查是否为字节流或具有read属性的对象(文件句柄)
|
||||
if isinstance(value, bytes) or hasattr(value, 'read'):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _send_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
query_params: Dict[str, Any],
|
||||
json_data: Optional[Dict[str, Any]],
|
||||
form_data: Optional[Dict[str, Any]],
|
||||
formdata_data: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
async with httpx.AsyncClient(
|
||||
verify=self.verify, timeout=self.timeout
|
||||
) as client:
|
||||
try:
|
||||
# 准备请求参数
|
||||
request_kwargs = {
|
||||
"params": query_params,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
# 为有请求体的方法添加数据
|
||||
if method.upper() in ["POST", "PUT", "PATCH", "DELETE"]:
|
||||
if json_data is not None:
|
||||
request_kwargs["json"] = json_data
|
||||
elif form_data is not None:
|
||||
request_kwargs["data"] = form_data
|
||||
elif formdata_data is not None:
|
||||
# 区分文件上传和普通formdata
|
||||
if self._contains_file(formdata_data):
|
||||
request_kwargs["files"] = formdata_data
|
||||
else:
|
||||
request_kwargs["data"] = formdata_data
|
||||
|
||||
# 根据HTTP方法发送请求
|
||||
request_func = getattr(client, method.lower(), None)
|
||||
if request_func is None:
|
||||
raise ValueError(f"不支持的HTTP方法: {method}")
|
||||
|
||||
response = await request_func(url, **request_kwargs)
|
||||
|
||||
# 检查响应状态
|
||||
response.raise_for_status()
|
||||
|
||||
# 解析响应
|
||||
return self._parse_response(response)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求错误: {str(e)}")
|
||||
raise ApiError(f"请求发送失败: {str(e)}")
|
||||
except httpx.HTTPStatusError as e:
|
||||
# 记录详细的HTTP错误信息
|
||||
logger.error(f"HTTP状态错误: {e.response.status_code} - {str(e)}")
|
||||
logger.error(f"请求URL: {e.request.url}")
|
||||
logger.error(f"请求方法: {e.request.method}")
|
||||
|
||||
# 记录响应内容(用于调试)
|
||||
response_text = e.response.text
|
||||
if response_text:
|
||||
logger.error(f"响应内容: {response_text[:500]}{'...' if len(response_text) > 500 else ''}")
|
||||
|
||||
# 记录响应头(可能包含有用的错误信息)
|
||||
if e.response.headers:
|
||||
logger.info(f"响应头: {dict(e.response.headers)}")
|
||||
|
||||
return {
|
||||
"status": "error",
|
||||
"status_code": e.response.status_code,
|
||||
"error": str(e),
|
||||
"response": e.response.text,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}")
|
||||
raise ApiError(f"请求处理失败: {str(e)}")
|
||||
|
||||
def _parse_response(self, response: httpx.Response) -> Dict[str, Any]:
|
||||
"""解析HTTP响应"""
|
||||
# 记录响应信息
|
||||
logger.info(f"HTTP响应: {response.status_code} {response.reason_phrase}")
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
logger.info(f"响应类型: {content_type}")
|
||||
|
||||
# 记录响应大小
|
||||
content_length = len(response.content) if response.content else 0
|
||||
logger.info(f"响应大小: {content_length} bytes")
|
||||
|
||||
if content_type.startswith("application/json"):
|
||||
try:
|
||||
json_response = response.json()
|
||||
# 记录JSON响应的基本信息(避免记录过大的数据)
|
||||
if isinstance(json_response, dict):
|
||||
logger.info(f"JSON响应键: {list(json_response.keys())}")
|
||||
return json_response
|
||||
except Exception as e:
|
||||
logger.error(f"JSON解析失败: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"JSON解析失败: {str(e)}",
|
||||
"raw_response": response.text,
|
||||
"status_code": response.status_code,
|
||||
}
|
||||
else:
|
||||
# 对于非JSON响应,记录前100个字符
|
||||
response_preview = response.text[:100] + "..." if len(response.text) > 100 else response.text
|
||||
logger.info(f"文本响应预览: {response_preview}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"data": response.text,
|
||||
"status_code": response.status_code,
|
||||
}
|
||||
|
||||
|
||||
# 兼容性函数 - 保持向后兼容
|
||||
async def call_api(
|
||||
api_config: Dict[str, Any],
|
||||
request_data: Optional[Dict[str, Any]] = None,
|
||||
need_auth: bool = True,
|
||||
timeout: int = 30,
|
||||
verify: bool = True,
|
||||
save_responses: bool = True,
|
||||
save_dir: str = "lzwcai_mcp_dyntoolapi_log_call_api",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
通用API调用方法(兼容性函数)
|
||||
|
||||
Args:
|
||||
api_config: API配置信息
|
||||
request_data: 请求数据
|
||||
need_auth: 是否需要认证
|
||||
timeout: 请求超时时间(秒)
|
||||
verify: 是否验证SSL证书
|
||||
save_responses: 是否保存响应到本地JSON文件
|
||||
save_dir: 响应保存目录
|
||||
|
||||
Returns:
|
||||
API响应结果
|
||||
"""
|
||||
client = ApiClient(timeout=timeout, verify=verify, save_responses=save_responses, save_dir=save_dir)
|
||||
return await client.call_api(api_config, request_data, need_auth)
|
||||
|
||||
|
||||
def get_env_user_id() -> Tuple[bool, Optional[str]]:
|
||||
"""获取环境变量用户ID(兼容性函数)"""
|
||||
return UserManager.get_user_id_from_env()
|
||||
|
||||
|
||||
def extract_user_id(
|
||||
request_data: Dict[str, Any], is_grouped_format: bool
|
||||
) -> Optional[str]:
|
||||
"""提取用户ID(兼容性函数)"""
|
||||
return UserManager.extract_user_id_from_request(request_data, is_grouped_format)
|
||||
|
||||
|
||||
def process_auth_headers(
|
||||
headers: dict,
|
||||
request_data: dict,
|
||||
auth_token: dict = None,
|
||||
) -> dict:
|
||||
"""处理认证头(兼容性函数)"""
|
||||
return HeaderProcessor.process_auth_headers(headers, request_data, auth_token)
|
||||
|
||||
|
||||
def load_generator_api_config(file_path: str = "generator_api.json") -> Dict[str, Any]:
|
||||
"""加载API配置(兼容性函数)"""
|
||||
return ConfigManager.load_api_config(file_path)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 测试示例"""
|
||||
request_data = {
|
||||
"header": {},
|
||||
"body": {"username": "wangpeng1", "password": "Wp147258"},
|
||||
"lzwcaiConfig": {"userId": "test_user_123"},
|
||||
}
|
||||
|
||||
try:
|
||||
config_path = "src/lzwcai_demp_tool_server_business_to_mcp/mcp_generator/src/user_params.json"
|
||||
api_config = ConfigManager.load_api_config(config_path)
|
||||
logger.info(f"正在尝试连接: {api_config.get('domainUrl')}{api_config.get('apiUrl')}")
|
||||
client = ApiClient()
|
||||
result = await client.call_api(api_config, request_data)
|
||||
logger.info(f"请求结果: {result}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
认证数据转换器
|
||||
提供简洁的API来获取和转换认证信息
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
|
||||
sys.path.append(project_root)
|
||||
|
||||
from ..business.get_business_api import get_business_api_details
|
||||
from ..util.logger_config import get_logger
|
||||
|
||||
# 配置日志
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class AuthDataTransformer:
|
||||
"""认证数据转换器"""
|
||||
|
||||
def __init__(self, base_url: str = "http://lzwcai-demp-corp-manager:8086"):
|
||||
"""
|
||||
初始化转换器
|
||||
|
||||
Args:
|
||||
base_url: API基础URL,默认为 http://lzwcai-demp-corp-manager:8086
|
||||
"""
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.session = requests.Session()
|
||||
|
||||
# 设置默认请求头
|
||||
self.session.headers.update({
|
||||
'User-Agent': 'Python-API-Client/1.0.0',
|
||||
'Accept': '*/*',
|
||||
'Connection': 'keep-alive'
|
||||
})
|
||||
|
||||
def get_transformed_auth_data(self, user_id: str, business_system_id: str) -> Optional[Dict[Any, Any]]:
|
||||
"""
|
||||
获取并转换认证数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
business_system_id: 业务系统ID
|
||||
|
||||
Returns:
|
||||
转换后的认证数据JSON,失败时返回None
|
||||
"""
|
||||
try:
|
||||
# 1. 获取原始认证信息
|
||||
raw_data = self._get_raw_auth_info(user_id, business_system_id)
|
||||
if not raw_data:
|
||||
return None
|
||||
|
||||
# 2. 处理bizSysConfig并获取API详情
|
||||
self._process_biz_sys_config(raw_data)
|
||||
|
||||
# 3. 转换数据结构
|
||||
transformed_data = self._transform_data_structure(raw_data)
|
||||
|
||||
return transformed_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取转换认证数据时发生错误: {e}")
|
||||
return None
|
||||
|
||||
def _get_raw_auth_info(self, user_id: str, business_system_id: str) -> Optional[Dict[Any, Any]]:
|
||||
"""
|
||||
获取原始认证信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
business_system_id: 业务系统ID
|
||||
|
||||
Returns:
|
||||
原始API响应数据,失败时返回None
|
||||
"""
|
||||
# url = f"{self.base_url}/system/mcpServer/auth/info/{user_id}/{business_system_id}"
|
||||
url = f"http://lzwcai-demp-corp-manager:8086/system/mcpServer/auth/info/{user_id}/{business_system_id}"
|
||||
|
||||
try:
|
||||
response = self.session.get(url, timeout=30)
|
||||
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"响应不是有效的JSON格式: {response.text}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"请求失败,状态码: {response.status_code}, 响应: {response.text}")
|
||||
return None
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"请求异常: {e}")
|
||||
return None
|
||||
|
||||
def _process_biz_sys_config(self, response_data: Dict[Any, Any]) -> None:
|
||||
"""
|
||||
处理bizSysConfig,解析loginApiId并获取API详情
|
||||
|
||||
Args:
|
||||
response_data: API响应数据
|
||||
"""
|
||||
try:
|
||||
data = response_data.get("data", {})
|
||||
if not data:
|
||||
return
|
||||
|
||||
biz_sys_config_str = data.get("bizSysConfig")
|
||||
if not biz_sys_config_str:
|
||||
return
|
||||
|
||||
try:
|
||||
biz_sys_config = json.loads(biz_sys_config_str)
|
||||
login_api_id = biz_sys_config.get("loginApiId")
|
||||
logger.info(f"登录APIID: {login_api_id}")
|
||||
logger.debug(f"获取API详情...", get_business_api_details)
|
||||
if login_api_id and get_business_api_details:
|
||||
try:
|
||||
api_details = get_business_api_details([int(login_api_id)])
|
||||
logger.debug(f"API详情: {api_details}")
|
||||
if api_details and len(api_details) > 0:
|
||||
data["apiItemDetail"] = api_details[0]
|
||||
except Exception as e:
|
||||
logger.error(f"获取API详情时发生错误: {e}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析bizSysConfig JSON失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理bizSysConfig时发生错误: {e}")
|
||||
|
||||
def _transform_data_structure(self, result: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||
"""
|
||||
将原始数据结构转换为新的数据结构
|
||||
|
||||
Args:
|
||||
result: 原始数据
|
||||
|
||||
Returns:
|
||||
转换后的新数据结构,失败时返回None
|
||||
"""
|
||||
try:
|
||||
data = result.get("data", {})
|
||||
if not data:
|
||||
return None
|
||||
|
||||
# 解析userAuthConfig
|
||||
user_auth_config = self._safe_json_parse(data.get("userAuthConfig", "{}"))
|
||||
|
||||
# 解析bizSysConfig
|
||||
biz_sys_config = self._safe_json_parse(data.get("bizSysConfig", "{}"))
|
||||
|
||||
# 获取apiItemDetail
|
||||
api_item_detail = data.get("apiItemDetail", {})
|
||||
|
||||
# 构建新的数据结构
|
||||
new_data = {
|
||||
"authType": user_auth_config.get("authType"),
|
||||
"name": biz_sys_config.get("name"),
|
||||
"apiKey": user_auth_config.get("apiKey"),
|
||||
"apiVO": self._build_api_vo(user_auth_config, biz_sys_config, api_item_detail)
|
||||
}
|
||||
|
||||
return {
|
||||
"code": result.get("code", 200),
|
||||
"msg": result.get("msg", "成功"),
|
||||
"data": new_data
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换数据结构时发生错误: {e}")
|
||||
return None
|
||||
|
||||
def _safe_json_parse(self, json_str: str) -> Dict[Any, Any]:
|
||||
"""
|
||||
安全解析JSON字符串
|
||||
|
||||
Args:
|
||||
json_str: JSON字符串
|
||||
|
||||
Returns:
|
||||
解析后的字典,失败时返回空字典
|
||||
"""
|
||||
try:
|
||||
return json.loads(json_str) if json_str else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
|
||||
def _build_api_vo(self, user_auth_config: Dict[Any, Any], biz_sys_config: Dict[Any, Any], api_item_detail: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||
"""
|
||||
构建apiVO对象
|
||||
|
||||
Args:
|
||||
user_auth_config: 用户认证配置
|
||||
biz_sys_config: 业务系统配置
|
||||
api_item_detail: API详情
|
||||
|
||||
Returns:
|
||||
apiVO对象
|
||||
"""
|
||||
if not biz_sys_config:
|
||||
return None
|
||||
|
||||
# 获取dynamicValues和paramMappings
|
||||
dynamic_values = user_auth_config.get("dynamicValues", {})
|
||||
account_config = biz_sys_config.get("accountConfig", {})
|
||||
param_mappings = account_config.get("paramMappings", [])
|
||||
|
||||
# 预处理accountConfig - 构建为对象然后转换为JSON字符串
|
||||
processed_account_config = {}
|
||||
if param_mappings:
|
||||
processed_account_config = {"parametersBody": []}
|
||||
|
||||
for param_mapping in param_mappings:
|
||||
param_name = param_mapping.get("paramName")
|
||||
if param_name:
|
||||
param_value = dynamic_values.get(param_name, param_mapping.get("defaultValue", ""))
|
||||
|
||||
processed_account_config["parametersBody"].append({
|
||||
"paramName": param_name,
|
||||
"defaultValue": param_value,
|
||||
"requestType": param_mapping.get("requestType", "form")
|
||||
})
|
||||
|
||||
# 构建tcapabilityApiVO
|
||||
tcapability_api_vo = self._build_tcapability_api_vo(api_item_detail)
|
||||
|
||||
return {
|
||||
"accountConfig": json.dumps(processed_account_config, ensure_ascii=False), # 转换为JSON字符串
|
||||
"tokenPath": biz_sys_config.get("tokenPath"),
|
||||
"tcapabilityApiVO": tcapability_api_vo # 放在apiVO里面
|
||||
}
|
||||
|
||||
def _build_tcapability_api_vo(self, api_item_detail: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||
"""
|
||||
构建tcapabilityApiVO对象
|
||||
|
||||
Args:
|
||||
api_item_detail: API详情
|
||||
|
||||
Returns:
|
||||
tcapabilityApiVO对象
|
||||
"""
|
||||
if not api_item_detail:
|
||||
return None
|
||||
|
||||
# 复制API详情并重命名parameters为apiParameterList
|
||||
tcapability_api_vo = api_item_detail.copy()
|
||||
if "parameters" in tcapability_api_vo:
|
||||
tcapability_api_vo["apiParameterList"] = tcapability_api_vo.pop("parameters")
|
||||
|
||||
return tcapability_api_vo
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def get_auth_data(user_id: str, business_system_id: str, base_url: str = "http://lzwcai-demp-corp-manager:8086") -> Optional[Dict[Any, Any]]:
|
||||
"""
|
||||
便捷函数:获取转换后的认证数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
business_system_id: 业务系统ID
|
||||
base_url: API基础URL,默认为 http://lzwcai-demp-corp-manager:8086
|
||||
|
||||
Returns:
|
||||
转换后的认证数据JSON,失败时返回None
|
||||
|
||||
Example:
|
||||
>>> auth_data = get_auth_data("447", "1952255539442741249")
|
||||
>>> if auth_data:
|
||||
... print(f"认证类型: {auth_data['data']['authType']}")
|
||||
"""
|
||||
transformer = AuthDataTransformer(base_url)
|
||||
return transformer.get_transformed_auth_data(user_id, business_system_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
from ..util.logger_config import setup_logging
|
||||
import logging
|
||||
setup_logging(log_level=logging.INFO)
|
||||
result = get_auth_data("447", "1957354824118095874")
|
||||
if result:
|
||||
logger.info("=== 转换后的认证数据 ===")
|
||||
logger.info(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
else:
|
||||
logger.error("获取认证数据失败")
|
||||
@@ -0,0 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ToolPlugin(ABC):
|
||||
"""
|
||||
工具插件基类,所有工具插件需继承并实现相关方法
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def register(self, server):
|
||||
"""注册插件到 Server"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unregister(self, server):
|
||||
"""从 Server 注销插件"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def refresh(self, config):
|
||||
"""根据新配置刷新插件"""
|
||||
pass
|
||||
@@ -0,0 +1,788 @@
|
||||
"""
|
||||
MCP服务器创建和管理模块
|
||||
|
||||
这是lzwcai-mcp-dyntoolapi项目的核心模块,负责:
|
||||
1. 创建和配置MCP服务器
|
||||
2. 动态加载业务API配置
|
||||
3. 注册API工具插件
|
||||
4. 处理工具调用请求
|
||||
5. 支持配置热加载
|
||||
6. 支持多租户场景(内存模式)
|
||||
|
||||
主要功能:
|
||||
- 从业务平台获取API配置
|
||||
- 将API配置转换为MCP工具
|
||||
- 处理认证和参数验证
|
||||
- 支持多种传输方式(stdio、SSE)
|
||||
- 支持两种配置模式:文件模式和内存模式
|
||||
|
||||
配置模式说明:
|
||||
1. 内存模式(configMode=memory,默认):
|
||||
- 根据businessUuid创建变量business{businessUuid}存储配置
|
||||
- 配置存储在内存中,不写入本地文件
|
||||
- 支持多个租户共享同一个包实例
|
||||
- 适用于多租户SaaS场景
|
||||
|
||||
2. 文件模式(configMode=file):
|
||||
- 从业务平台获取配置后保存到本地api_config.json文件
|
||||
- 支持配置文件变更热加载
|
||||
- 适用于单租户场景
|
||||
|
||||
环境变量:
|
||||
- configMode: 配置模式(file/memory),默认为memory
|
||||
- businessUuid: 业务UUID(内存模式必需)
|
||||
- bizSysApiIds: API ID列表
|
||||
- ENABLE_CONFIG_WATCH: 是否启用配置热加载(仅文件模式)
|
||||
|
||||
作者: lzwcai
|
||||
版本: 1.1.0
|
||||
"""
|
||||
|
||||
import anyio
|
||||
import mcp.types as types
|
||||
from mcp.server.lowlevel import Server
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import re
|
||||
import uuid
|
||||
|
||||
# 导入核心模块
|
||||
from .core.api_base import ApiBase
|
||||
import mcp.types as types
|
||||
from .core.api_auth_service import AuthService
|
||||
from .core.plugin_base import ToolPlugin
|
||||
|
||||
# 导入业务工具模块
|
||||
from .business.business_util import (
|
||||
fill_default_values_by_schema, # 参数默认值填充
|
||||
check_required_arguments, # 必填参数检查
|
||||
)
|
||||
from .business.get_business_api import get_business_api_config # 业务API配置获取
|
||||
|
||||
# 导入工具模块
|
||||
from .util.logger_config import get_logger, setup_logging
|
||||
|
||||
# 获取日志器实例
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# ==================== 多租户配置存储 ====================
|
||||
# 用于存储多个租户的配置(内存模式)
|
||||
# key格式: business{businessUuid}, value: 配置字典
|
||||
business_configs = {}
|
||||
|
||||
|
||||
def load_api_configs():
|
||||
"""
|
||||
加载API配置的核心函数
|
||||
|
||||
支持两种配置模式:
|
||||
|
||||
模式一 - 内存模式(configMode=memory,默认):
|
||||
1. 根据环境变量businessUuid创建变量名:business{businessUuid}
|
||||
2. 从业务平台获取配置后存储在内存中(business_configs字典)
|
||||
3. 如果内存中已有配置则直接使用,否则从业务平台获取
|
||||
4. 不写入本地文件,支持多租户场景
|
||||
|
||||
模式二 - 文件模式(configMode=file):
|
||||
1. 从业务平台动态获取最新配置(通过get_business_api_config)
|
||||
2. 如果网络获取失败,则从本地api_config.json文件加载备份配置
|
||||
3. 成功获取后保存到本地文件作为备份
|
||||
|
||||
环境变量:
|
||||
- configMode: 配置模式,"file"(文件模式)或 "memory"(内存模式),默认为memory
|
||||
- businessUuid: 业务UUID,仅在内存模式下使用,用于区分不同租户
|
||||
- bizSysApiIds: 指定要加载的API ID列表
|
||||
|
||||
返回:
|
||||
dict: 包含完整API配置的字典,格式如下:
|
||||
{
|
||||
"packageName": "服务包名",
|
||||
"version": "版本号",
|
||||
"description": "服务描述",
|
||||
"apiConfig": [API配置列表]
|
||||
}
|
||||
|
||||
异常处理:
|
||||
- 网络获取失败时自动降级(文件模式降级到本地文件,内存模式报错)
|
||||
- 详细记录所有错误信息用于调试
|
||||
"""
|
||||
global business_configs
|
||||
|
||||
# 获取配置模式(默认为内存模式)
|
||||
config_mode = os.getenv('configMode', 'memory').lower()
|
||||
logger.info(f"配置模式: {config_mode}")
|
||||
|
||||
# 获取当前模块所在目录,用于定位配置文件
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(current_dir, "api_config.json")
|
||||
|
||||
# ==================== 模式一:内存模式(多租户支持,默认) ====================
|
||||
if config_mode == 'memory':
|
||||
logger.info("使用内存模式加载配置(多租户支持)")
|
||||
|
||||
# 获取业务UUID,如果不存在则自动生成一个
|
||||
business_uuid = os.getenv('businessUuid')
|
||||
if not business_uuid:
|
||||
# 生成随机UUID(使用uuid4,完全随机)
|
||||
business_uuid = str(uuid.uuid4())
|
||||
logger.warning(f"环境变量businessUuid未设置,已自动生成随机UUID: {business_uuid}")
|
||||
# 可选:将生成的UUID设置回环境变量,供后续使用
|
||||
os.environ['businessUuid'] = business_uuid
|
||||
else:
|
||||
logger.info(f"使用环境变量提供的businessUuid: {business_uuid}")
|
||||
|
||||
# 构建配置变量名
|
||||
config_key = f"business{business_uuid}"
|
||||
logger.info(f"租户配置变量名: {config_key}")
|
||||
|
||||
# 检查内存中是否已有该租户的配置
|
||||
if config_key in business_configs:
|
||||
logger.info(f"从内存中获取租户 {business_uuid} 的配置")
|
||||
return business_configs[config_key]
|
||||
|
||||
# 内存中没有,从业务平台获取
|
||||
logger.info(f"内存中没有租户 {business_uuid} 的配置,开始从业务平台获取...")
|
||||
|
||||
try:
|
||||
# 从环境变量获取API ID列表
|
||||
api_ids = []
|
||||
biz_sys_api_ids = os.getenv('bizSysApiIds')
|
||||
logger.debug(f"已获取环境变量bizSysApiIds: {biz_sys_api_ids}")
|
||||
|
||||
if biz_sys_api_ids:
|
||||
try:
|
||||
ids_str = biz_sys_api_ids.strip('[]')
|
||||
api_ids = []
|
||||
for id_part in ids_str.split(','):
|
||||
id_clean = id_part.strip().strip('"\'')
|
||||
if id_clean:
|
||||
try:
|
||||
api_id = int(id_clean)
|
||||
api_ids.append(api_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无法将 '{id_clean}' 转换为整数,跳过此项")
|
||||
continue
|
||||
logger.info(f"从环境变量bizSysApiIds获取到API IDs: {api_ids}")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"解析环境变量bizSysApiIds失败,使用默认值: {str(e)}")
|
||||
else:
|
||||
logger.info("未找到环境变量bizSysApiIds,使用默认API IDs")
|
||||
|
||||
# 从业务平台获取配置
|
||||
logger.info(f"调用get_business_api_config获取配置,API IDs: {api_ids}")
|
||||
config = get_business_api_config(api_ids)
|
||||
logger.info(f"成功获取业务API配置,包含 {len(config.get('apiConfig', []))} 个API配置")
|
||||
|
||||
# 存储到内存中
|
||||
business_configs[config_key] = config
|
||||
logger.info(f"配置已存储到内存变量: {config_key}")
|
||||
logger.info(f"当前内存中共有 {len(business_configs)} 个租户配置")
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取业务API配置失败: {str(e)}")
|
||||
error_msg = f"内存模式下无法获取租户 {business_uuid} 的配置: {str(e)}"
|
||||
raise Exception(error_msg)
|
||||
|
||||
# ==================== 模式二:文件模式(单租户) ====================
|
||||
elif config_mode == 'file':
|
||||
logger.info("使用文件模式加载配置(单租户)")
|
||||
|
||||
try:
|
||||
# 从环境变量获取API ID列表
|
||||
# 支持格式: "[1932682081958830081,1932682082285985793]" 或 "1932682081958830081,1932682082285985793"
|
||||
api_ids = [] # 默认空列表
|
||||
|
||||
# 尝试从环境变量获取bizSysApiIds
|
||||
biz_sys_api_ids = os.getenv('bizSysApiIds')
|
||||
logger.debug(f"已获取环境变量bizSysApiIds: {biz_sys_api_ids}")
|
||||
|
||||
if biz_sys_api_ids:
|
||||
try:
|
||||
# 解析环境变量中的字符串,支持多种格式
|
||||
# 格式1: [1932682081958830081,1932682082285985793]
|
||||
# 格式2: 1932682081958830081,1932682082285985793
|
||||
# 格式3: ["1932682081958830081","1932682082285985793"]
|
||||
# 格式4: "1932682081958830081","1932682082285985793"
|
||||
ids_str = biz_sys_api_ids.strip('[]') # 移除方括号
|
||||
|
||||
# 分割并处理每个ID,自动转换字符串数字为整数
|
||||
api_ids = []
|
||||
for id_part in ids_str.split(','):
|
||||
id_clean = id_part.strip().strip('"\'') # 移除空格和引号
|
||||
if id_clean: # 确保不是空字符串
|
||||
try:
|
||||
# 尝试转换为整数,支持字符串数字自动转换
|
||||
api_id = int(id_clean)
|
||||
api_ids.append(api_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无法将 '{id_clean}' 转换为整数,跳过此项")
|
||||
continue
|
||||
|
||||
logger.info(f"从环境变量bizSysApiIds获取到API IDs: {api_ids}")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"解析环境变量bizSysApiIds失败,使用默认值: {str(e)}")
|
||||
logger.warning(f"环境变量值: {biz_sys_api_ids}")
|
||||
else:
|
||||
logger.info("未找到环境变量bizSysApiIds,使用默认API IDs")
|
||||
|
||||
logger.info(f"调用get_business_api_config获取配置,API IDs: {api_ids}")
|
||||
|
||||
# 从业务平台获取最新配置
|
||||
config = get_business_api_config(api_ids)
|
||||
logger.info(f"成功获取业务API配置,包含 {len(config.get('apiConfig', []))} 个API配置")
|
||||
|
||||
# 将获取的配置保存到本地文件作为备份
|
||||
logger.info(f"保存配置到文件: {config_path}")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
logger.info("配置文件保存成功")
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取业务API配置失败: {str(e)}")
|
||||
logger.info("尝试从本地配置文件加载...")
|
||||
|
||||
# 降级处理:从本地文件加载配置
|
||||
try:
|
||||
logger.debug(f"加载本地API配置文件: {config_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
logger.info(f"成功加载本地API配置文件,包含 {len(config.get('apiConfig', []))} 个API配置")
|
||||
return config
|
||||
except FileNotFoundError:
|
||||
logger.error(f"本地API配置文件也未找到: {config_path}")
|
||||
raise Exception(f"无法获取业务API配置且本地配置文件不存在: {str(e)}")
|
||||
except json.JSONDecodeError as json_e:
|
||||
logger.error(f"本地API配置文件JSON格式错误: {str(json_e)}")
|
||||
raise Exception(f"无法获取业务API配置且本地配置文件格式错误: {str(e)}")
|
||||
except Exception as file_e:
|
||||
logger.error(f"加载本地API配置文件失败: {str(file_e)}")
|
||||
raise Exception(f"无法获取业务API配置且加载本地配置文件失败: {str(e)}")
|
||||
|
||||
# ==================== 无效配置模式 ====================
|
||||
else:
|
||||
error_msg = f"无效的配置模式: {config_mode},请使用 'file' 或 'memory'"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
|
||||
# ==================== MCP服务器初始化 ====================
|
||||
|
||||
logger.info("开始初始化 MCP 服务器")
|
||||
|
||||
# 加载API配置(从业务平台或本地文件)
|
||||
api_configs = load_api_configs()
|
||||
|
||||
# 设置MCP服务器的默认配置值
|
||||
default_name = "lzwcai-mcp-dyntoolapi" # 默认服务器名称
|
||||
default_version = "1.0.0" # 默认版本号
|
||||
default_instructions = "动态业务API工具服务器" # 默认服务器描述
|
||||
|
||||
# 从加载的配置中获取服务器信息,如果不存在则使用默认值
|
||||
name = api_configs.get("packageName", default_name)
|
||||
version = api_configs.get("version", default_version)
|
||||
instructions = api_configs.get("description", default_instructions)
|
||||
|
||||
logger.info(f"服务器配置 - 名称: {name}, 版本: {version}")
|
||||
logger.debug(f"服务器说明: {instructions}")
|
||||
|
||||
# 创建MCP服务器实例
|
||||
# 这个Server实例将处理所有的MCP协议通信
|
||||
app = Server(
|
||||
name=name, # 服务器名称,用于客户端识别
|
||||
version=version, # 服务器版本,用于兼容性检查
|
||||
instructions=instructions, # 服务器描述,告诉客户端这个服务器的功能
|
||||
)
|
||||
|
||||
# 初始化API基础服务
|
||||
# ApiBase负责管理所有的API配置和调用逻辑
|
||||
logger.debug("初始化 API 基础服务")
|
||||
api_base = ApiBase(api_configs.get("apiConfig"))
|
||||
logger.info(f"API 基础服务初始化完成,共 {api_base.config_count} 个API配置")
|
||||
|
||||
|
||||
class ApiToolPlugin(ToolPlugin):
|
||||
"""
|
||||
API工具插件实现类
|
||||
|
||||
这个类负责将业务API配置转换为MCP工具,并处理工具调用请求。
|
||||
它继承自ToolPlugin基类,实现了插件的标准接口。
|
||||
|
||||
主要功能:
|
||||
1. 将API配置转换为MCP工具定义
|
||||
2. 处理工具列表请求(list_tools)
|
||||
3. 处理工具调用请求(call_tool)
|
||||
4. 支持插件的注册、注销和刷新
|
||||
|
||||
属性:
|
||||
api_base: ApiBase实例,管理所有API配置
|
||||
tools: 工具列表缓存(当前未使用)
|
||||
"""
|
||||
|
||||
def __init__(self, api_base):
|
||||
"""
|
||||
初始化API工具插件
|
||||
|
||||
参数:
|
||||
api_base: ApiBase实例,包含所有API配置信息
|
||||
"""
|
||||
self.api_base = api_base
|
||||
self.tools = [] # 工具列表缓存(预留)
|
||||
logger.debug(f"初始化 API 工具插件,API配置数量: {api_base.config_count}")
|
||||
|
||||
def register(self, server):
|
||||
"""
|
||||
向MCP服务器注册插件
|
||||
|
||||
这个方法会向服务器注册两个处理器:
|
||||
1. list_tools: 返回可用工具列表
|
||||
2. call_tool: 处理工具调用请求
|
||||
|
||||
参数:
|
||||
server: MCP服务器实例
|
||||
"""
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[types.Tool]:
|
||||
"""
|
||||
处理工具列表请求
|
||||
|
||||
当MCP客户端请求可用工具列表时,这个函数会被调用。
|
||||
它会遍历所有API配置,为每个API创建一个MCP工具定义。
|
||||
|
||||
返回:
|
||||
list[types.Tool]: MCP工具定义列表
|
||||
"""
|
||||
logger.debug("处理工具列表请求")
|
||||
tools = []
|
||||
tools_configs = self.api_base.api_configs_map
|
||||
|
||||
# 遍历所有API配置,创建工具定义
|
||||
for tool_config in tools_configs:
|
||||
tool_name = tool_config["interfaceName"] # 工具名称(拼音格式)
|
||||
logger.debug(f"注册工具: {tool_name}")
|
||||
|
||||
# 创建MCP工具定义
|
||||
tools.append(
|
||||
types.Tool(
|
||||
name=tool_name, # 工具名称
|
||||
description=tool_config["schema_description"], # 工具描述(包含参数说明)
|
||||
inputSchema=tool_config["schema"], # 输入参数的JSON Schema
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"返回工具列表,共 {len(tools)} 个工具")
|
||||
return tools
|
||||
|
||||
@server.call_tool()
|
||||
async def fetch_tool(
|
||||
name: str, arguments: dict
|
||||
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||
"""
|
||||
处理工具调用请求
|
||||
|
||||
当MCP客户端调用某个工具时,这个函数会被调用。
|
||||
它负责:
|
||||
1. 查找对应的API配置
|
||||
2. 验证和处理输入参数
|
||||
3. 调用实际的API接口
|
||||
4. 返回格式化的结果
|
||||
|
||||
参数:
|
||||
name: 工具名称(对应API的interfaceName)
|
||||
arguments: 工具调用参数(JSON对象)
|
||||
|
||||
返回:
|
||||
list[types.TextContent]: 包含API调用结果的文本内容列表
|
||||
"""
|
||||
logger.info(f"调用工具: {name}")
|
||||
logger.debug(f"工具参数: {arguments}")
|
||||
|
||||
# 查找对应的工具配置
|
||||
tool_config = None
|
||||
for config in self.api_base.api_configs_map:
|
||||
if config["interfaceName"] == name:
|
||||
tool_config = config
|
||||
break
|
||||
|
||||
# 检查工具是否存在
|
||||
if tool_config is None:
|
||||
logger.error(f"未找到工具: {name}")
|
||||
return [types.TextContent(type="text", text=f"未找到工具: {name}")]
|
||||
|
||||
logger.debug(f"找到工具配置: {tool_config.get('apiUrl', 'N/A')}")
|
||||
|
||||
# ==================== 参数处理和验证 ====================
|
||||
logger.debug("开始参数处理和验证")
|
||||
|
||||
# 保存原始参数用于调试
|
||||
original_args = arguments.copy()
|
||||
|
||||
# 使用schema中的默认值补全缺失的参数
|
||||
arguments = fill_default_values_by_schema(
|
||||
tool_config.get("schema", {}), arguments
|
||||
)
|
||||
logger.debug(f"参数默认值补全完成,原始参数: {original_args}, 补全后: {arguments}")
|
||||
|
||||
# 检查必填参数是否都已提供
|
||||
missing = check_required_arguments(tool_config.get("schema", {}), arguments)
|
||||
if missing:
|
||||
missing_str = ", ".join(missing)
|
||||
logger.warning(f"缺少必填参数: {missing_str}")
|
||||
return [
|
||||
types.TextContent(
|
||||
type="text",
|
||||
text=f"请补充以下必填参数:{missing_str}。填写后再试一次哦~ 如果有疑问请联系管理员。",
|
||||
)
|
||||
]
|
||||
|
||||
# ==================== API接口调用 ====================
|
||||
logger.info(f"开始调用API接口: {tool_config.get('apiUrl', 'N/A')}")
|
||||
try:
|
||||
# 通过ApiBase调用实际的API接口
|
||||
result = await self.api_base.call_interface(tool_config, arguments)
|
||||
logger.info("API调用成功")
|
||||
logger.debug(f"API返回结果: {result}")
|
||||
|
||||
# 将结果转换为JSON格式返回给客户端
|
||||
result_json = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
return [types.TextContent(type="text", text=result_json)]
|
||||
|
||||
except Exception as e:
|
||||
# 处理API调用异常
|
||||
logger.error(f"API调用失败: {str(e)}")
|
||||
logger.debug("API调用异常详情:", exc_info=True)
|
||||
error_msg = f"API调用失败: {str(e)}"
|
||||
return [types.TextContent(type="text", text=error_msg)]
|
||||
|
||||
def unregister(self, server):
|
||||
"""
|
||||
从MCP服务器注销插件
|
||||
|
||||
目前MCP Server不支持动态注销功能,这个方法预留给未来使用。
|
||||
|
||||
参数:
|
||||
server: MCP服务器实例(当前未使用)
|
||||
"""
|
||||
# 目前 Server 不支持动态注销,预留接口
|
||||
logger.debug("插件注销请求(当前不支持动态注销)")
|
||||
pass
|
||||
|
||||
def refresh(self, config):
|
||||
"""
|
||||
刷新插件配置
|
||||
|
||||
当API配置发生变化时(如热加载),这个方法会被调用来更新插件的配置。
|
||||
|
||||
参数:
|
||||
config: 新的配置字典,包含apiConfig字段
|
||||
"""
|
||||
logger.info("刷新API工具插件配置")
|
||||
# 重新创建ApiBase实例以使用新配置
|
||||
self.api_base = ApiBase(config.get("apiConfig"))
|
||||
logger.info(f"插件配置刷新完成,共 {self.api_base.config_count} 个API配置")
|
||||
|
||||
|
||||
# ==================== 插件注册 ====================
|
||||
|
||||
# 创建并注册API工具插件
|
||||
logger.info("注册API工具插件")
|
||||
api_tool_plugin = ApiToolPlugin(api_base) # 创建插件实例
|
||||
api_tool_plugin.register(app) # 向MCP服务器注册插件
|
||||
logger.info("API工具插件注册完成")
|
||||
|
||||
|
||||
|
||||
# ==================== MCP服务器启动函数 ====================
|
||||
|
||||
def main(port: int, transport: str) -> int:
|
||||
"""
|
||||
MCP服务器主启动函数
|
||||
|
||||
根据指定的传输方式启动MCP服务器。支持两种传输方式:
|
||||
1. stdio: 标准输入输出传输(默认,用于命令行工具集成)
|
||||
2. sse: Server-Sent Events传输(用于Web集成)
|
||||
|
||||
参数:
|
||||
port: 服务器端口号(仅在SSE模式下使用)
|
||||
transport: 传输方式,"stdio" 或 "sse"
|
||||
|
||||
返回:
|
||||
int: 退出状态码,0表示成功
|
||||
"""
|
||||
|
||||
if transport == "sse":
|
||||
# ==================== SSE传输模式 ====================
|
||||
logger.info(f"启动SSE传输模式,端口: {port}")
|
||||
|
||||
# 导入SSE相关模块
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
# 创建SSE传输实例
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
async def handle_sse(request):
|
||||
"""
|
||||
处理SSE连接请求
|
||||
|
||||
这个函数处理来自Web客户端的SSE连接请求,
|
||||
建立双向通信流并运行MCP服务器。
|
||||
"""
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
# 运行MCP服务器,使用SSE流进行通信
|
||||
await app.run(
|
||||
streams[0], streams[1], app.create_initialization_options()
|
||||
)
|
||||
return Response()
|
||||
|
||||
# 创建Starlette Web应用
|
||||
starlette_app = Starlette(
|
||||
debug=True, # 开启调试模式
|
||||
routes=[
|
||||
# SSE连接端点
|
||||
Route("/sse", endpoint=handle_sse, methods=["GET"]),
|
||||
# 消息处理端点
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
],
|
||||
)
|
||||
|
||||
# 使用uvicorn启动Web服务器
|
||||
import uvicorn
|
||||
logger.info(f"启动Web服务器,监听 0.0.0.0:{port}")
|
||||
uvicorn.run(starlette_app, host="0.0.0.0", port=port)
|
||||
|
||||
else:
|
||||
# ==================== STDIO传输模式 ====================
|
||||
logger.info("启动STDIO传输模式")
|
||||
|
||||
# 导入stdio传输模块
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
async def arun():
|
||||
"""
|
||||
异步运行MCP服务器
|
||||
|
||||
使用标准输入输出流与客户端通信,
|
||||
这是MCP协议的标准传输方式。
|
||||
"""
|
||||
async with stdio_server() as streams:
|
||||
# 运行MCP服务器,使用stdio流进行通信
|
||||
await app.run(
|
||||
streams[0], streams[1], app.create_initialization_options()
|
||||
)
|
||||
|
||||
# 启动异步事件循环
|
||||
anyio.run(arun)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
def save_to_json_file(data, file_path="output/test_data.json"):
|
||||
"""
|
||||
将输入数据保存为JSON文件
|
||||
|
||||
这是一个通用的数据保存工具函数,主要用于调试和数据持久化。
|
||||
|
||||
参数:
|
||||
data: 要保存的数据(字典、列表或其他可JSON序列化的对象)
|
||||
file_path: 要保存的文件路径,默认保存到output目录
|
||||
|
||||
返回:
|
||||
bool: 操作成功返回True,失败返回False
|
||||
|
||||
特性:
|
||||
- 自动创建目录(如果不存在)
|
||||
- 使用UTF-8编码确保中文正确显示
|
||||
- 格式化输出(缩进2个空格)
|
||||
"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
# 保存数据到JSON文件
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"数据已成功保存到 {file_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"保存JSON文件时出错: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def refresh_api_configs():
|
||||
"""
|
||||
刷新API配置和工具注册(热加载功能)
|
||||
|
||||
这个函数实现了配置的热加载功能,支持两种模式:
|
||||
- 文件模式:当检测到配置文件变化时会被调用
|
||||
- 内存模式:强制重新从业务平台获取配置并更新内存
|
||||
|
||||
全局变量更新:
|
||||
- api_configs: 重新加载的API配置
|
||||
- name, version, instructions: 服务器基本信息
|
||||
- api_base: 重新创建的API基础服务实例
|
||||
- api_tool_plugin: 刷新插件配置
|
||||
- business_configs: 内存模式下更新租户配置
|
||||
|
||||
注意:
|
||||
这个函数修改全局变量,在多线程环境中需要注意线程安全。
|
||||
"""
|
||||
global api_configs, name, version, instructions, api_base, api_tool_plugin, business_configs
|
||||
|
||||
logger.info("开始刷新API配置...")
|
||||
|
||||
# 获取配置模式
|
||||
config_mode = os.getenv('configMode', 'memory').lower()
|
||||
|
||||
# 内存模式下需要清除当前租户的缓存配置,强制重新获取
|
||||
if config_mode == 'memory':
|
||||
business_uuid = os.getenv('businessUuid')
|
||||
if business_uuid:
|
||||
config_key = f"business{business_uuid}"
|
||||
if config_key in business_configs:
|
||||
logger.info(f"清除租户 {business_uuid} 的缓存配置")
|
||||
del business_configs[config_key]
|
||||
|
||||
# 重新加载API配置
|
||||
api_configs = load_api_configs()
|
||||
|
||||
# 更新服务器基本信息
|
||||
name = api_configs.get("packageName", default_name)
|
||||
version = api_configs.get("version", default_version)
|
||||
instructions = api_configs.get("description", default_instructions)
|
||||
|
||||
# 重新创建API基础服务
|
||||
api_base = ApiBase(api_configs.get("apiConfig"))
|
||||
|
||||
# 刷新插件配置
|
||||
api_tool_plugin.refresh(api_configs)
|
||||
|
||||
logger.info(f"API配置已热加载并刷新(模式:{config_mode})")
|
||||
|
||||
|
||||
# ==================== 配置热加载功能 ====================
|
||||
|
||||
def watch_config_file(interval=2):
|
||||
"""
|
||||
配置文件变更监控函数(仅文件模式)
|
||||
|
||||
这个函数在后台线程中运行,定期检查配置文件的修改时间。
|
||||
当检测到文件变更时,自动触发配置热加载。
|
||||
|
||||
注意:
|
||||
- 仅在文件模式(configMode=file)下有效
|
||||
- 内存模式下此函数会直接返回,不进行监控
|
||||
|
||||
参数:
|
||||
interval: 检查间隔时间(秒),默认2秒
|
||||
|
||||
特性:
|
||||
- 轮询方式检测文件变更
|
||||
- 异常安全,不会因为单次错误而停止监控
|
||||
- 在守护线程中运行,不会阻止程序退出
|
||||
|
||||
注意:
|
||||
这个函数会无限循环运行,直到程序退出。
|
||||
"""
|
||||
# 检查配置模式,内存模式下不需要监控文件
|
||||
config_mode = os.getenv('configMode', 'memory').lower()
|
||||
if config_mode != 'file':
|
||||
logger.info(f"当前为 {config_mode} 模式,无需监控配置文件")
|
||||
return
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_path = os.path.join(current_dir, "api_config.json")
|
||||
|
||||
try:
|
||||
# 获取初始修改时间
|
||||
last_mtime = os.path.getmtime(config_path)
|
||||
logger.debug(f"开始监控配置文件: {config_path}")
|
||||
except OSError:
|
||||
logger.warning(f"配置文件不存在,跳过热加载监控: {config_path}")
|
||||
return
|
||||
|
||||
# 无限循环监控文件变更
|
||||
while True:
|
||||
try:
|
||||
# 获取当前修改时间
|
||||
mtime = os.path.getmtime(config_path)
|
||||
|
||||
# 检查是否有变更
|
||||
if mtime != last_mtime:
|
||||
logger.info("检测到api_config.json变更,自动热加载...")
|
||||
refresh_api_configs()
|
||||
last_mtime = mtime
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"配置热加载异常: {e}")
|
||||
|
||||
# 等待下次检查
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
# ==================== 主启动函数 ====================
|
||||
|
||||
def run_main():
|
||||
"""
|
||||
程序主启动函数
|
||||
|
||||
这是整个MCP服务器的启动入口,负责:
|
||||
1. 配置日志系统(MCP模式下禁用控制台输出)
|
||||
2. 可选启动配置文件监控线程(通过环境变量ENABLE_CONFIG_WATCH控制)
|
||||
3. 启动MCP服务器
|
||||
|
||||
配置说明:
|
||||
- 使用stdio传输方式(标准MCP协议)
|
||||
- 端口8000(仅在SSE模式下使用)
|
||||
- 配置热加载功能默认关闭,可通过环境变量ENABLE_CONFIG_WATCH=true启用
|
||||
"""
|
||||
|
||||
# 在MCP模式下,禁用控制台日志输出,避免干扰stdio通信
|
||||
# 只输出到文件,确保MCP协议通信不被日志干扰
|
||||
setup_logging(console_output=False, file_output=True)
|
||||
|
||||
# 检查是否启用配置文件监控(默认关闭)
|
||||
# 通过环境变量ENABLE_CONFIG_WATCH=true来启用
|
||||
enable_watch = os.getenv('ENABLE_CONFIG_WATCH', 'false').lower() in ['true', '1', 'yes']
|
||||
|
||||
if enable_watch:
|
||||
# 启动配置文件监听线程(守护线程)
|
||||
# 守护线程会在主程序退出时自动结束
|
||||
config_watcher = threading.Thread(target=watch_config_file, daemon=True)
|
||||
config_watcher.start()
|
||||
logger.info("配置文件监控线程已启动")
|
||||
else:
|
||||
logger.info("配置文件监控功能已禁用(如需启用,请设置环境变量ENABLE_CONFIG_WATCH=true)")
|
||||
|
||||
# 启动MCP服务器
|
||||
# 使用stdio传输方式,这是MCP协议的标准传输方式
|
||||
# 如果需要Web集成,可以改为 transport="sse"
|
||||
main(port=8000, transport="stdio")
|
||||
|
||||
|
||||
# ==================== 程序入口 ====================
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
直接运行此模块时的入口点
|
||||
|
||||
通常情况下,这个模块会被main.py调用,
|
||||
但也支持直接运行进行测试。
|
||||
"""
|
||||
run_main()
|
||||
|
||||
# 以下是测试代码,正常运行时被注释掉
|
||||
# from core.api_auth_service import test_auth_service
|
||||
# asyncio.run(test_auth_service())
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,412 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
API助手工具包
|
||||
|
||||
这个工具包集成了三个核心功能:
|
||||
1. 获取接口参数信息
|
||||
2. 调用指定接口API
|
||||
3. 获取用户授权Token
|
||||
|
||||
使用lzwcai_mcp_dyntoolapi包中的方法,提供简洁易用的API操作接口。
|
||||
|
||||
作者: lzwcai
|
||||
版本: 1.0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional
|
||||
import os
|
||||
# 引入lzwcai_mcp_dyntoolapi包中的核心方法
|
||||
from lzwcai_mcp_api_converter.src.business.get_business_api import get_business_api_details
|
||||
from lzwcai_mcp_api_converter.src.core.api_base import ApiBase
|
||||
from lzwcai_mcp_api_converter.src.core.core_server import call_api
|
||||
from lzwcai_mcp_api_converter.src.core.get_auth import get_auth_data
|
||||
from lzwcai_mcp_api_converter.src.core.api_auth_service import AuthService
|
||||
from lzwcai_mcp_api_converter.src.util.logger_config import get_logger
|
||||
|
||||
# 配置日志
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 全局AuthService实例(单例模式)
|
||||
_auth_service_instance = None
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
"""
|
||||
获取全局共享的AuthService实例
|
||||
|
||||
使用单例模式确保整个应用中只有一个AuthService实例,
|
||||
这样可以复用认证状态和连接,提高性能。
|
||||
|
||||
Returns:
|
||||
AuthService: 全局共享的AuthService实例
|
||||
|
||||
使用示例:
|
||||
>>> auth_service = get_auth_service()
|
||||
>>> token_result = await auth_service.authorize_request(user_id, business_system_id)
|
||||
"""
|
||||
global _auth_service_instance
|
||||
if _auth_service_instance is None:
|
||||
logger.info("创建新的AuthService实例")
|
||||
_auth_service_instance = AuthService()
|
||||
return _auth_service_instance
|
||||
|
||||
|
||||
# ==================== 功能1: 获取接口参数信息 ====================
|
||||
|
||||
def get_api_parameters_info(api_id: int, auth_token: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定接口ID的参数信息
|
||||
|
||||
功能说明:
|
||||
- 根据接口ID获取接口的详细配置信息
|
||||
- 生成JSON Schema格式的参数定义
|
||||
- 返回包含参数列表、Schema、描述等完整信息的配置对象
|
||||
|
||||
Args:
|
||||
api_id: 接口ID,例如 1957355058730684417
|
||||
auth_token: 认证token,如果不提供则使用默认token
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 处理后的接口配置对象,包含:
|
||||
- interfaceName: 拼音格式的接口名称
|
||||
- schema: JSON Schema格式的参数定义
|
||||
- schema_description: 完整的参数描述
|
||||
- 原始API配置的所有其他字段
|
||||
|
||||
Raises:
|
||||
Exception: 当获取接口信息失败时抛出
|
||||
|
||||
使用示例:
|
||||
>>> api_info = get_api_parameters_info(1957355058730684417)
|
||||
>>> print(api_info['interfaceName']) # tool_AppZhangHaoMiMaDengLu
|
||||
>>> print(api_info['schema']) # JSON Schema对象
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始获取接口ID {api_id} 的参数信息...")
|
||||
|
||||
# 步骤1: 调用业务接口获取API详情
|
||||
api_details = get_business_api_details([api_id], auth_token)
|
||||
|
||||
if not api_details:
|
||||
logger.error(f"未找到接口ID {api_id} 的详情")
|
||||
return None
|
||||
|
||||
# 获取第一个(也是唯一一个)API详情
|
||||
api_detail = api_details[0]
|
||||
|
||||
# 步骤2: 使用ApiBase处理API配置,生成Schema
|
||||
api_base = ApiBase([api_detail])
|
||||
|
||||
# 获取处理后的配置
|
||||
processed_configs = api_base.api_configs_map
|
||||
if not processed_configs:
|
||||
logger.error("API配置处理失败")
|
||||
return None
|
||||
|
||||
processed_config = processed_configs[0]
|
||||
|
||||
logger.info(f"成功获取接口 '{processed_config.get('interfaceName', 'N/A')}' 的参数信息")
|
||||
return processed_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取接口参数信息失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# ==================== 功能2: 调用指定接口API ====================
|
||||
|
||||
async def call_api_by_id(api_id: int, request_params: Dict[str, Any], auth_token: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据接口ID调用API
|
||||
|
||||
功能说明:
|
||||
- 根据接口ID获取API配置
|
||||
- 智能处理认证(先尝试认证调用,失败后尝试无认证调用)
|
||||
- 发送HTTP请求并返回响应结果
|
||||
|
||||
Args:
|
||||
api_id: 接口ID,例如 1957355058730684417
|
||||
request_params: 请求参数,格式如:
|
||||
{
|
||||
"body": {
|
||||
"username": "test_user",
|
||||
"password": "test_password"
|
||||
},
|
||||
"lzwcaiConfig": {
|
||||
"userId": "test_user_id"
|
||||
}
|
||||
}
|
||||
auth_token: 认证token,如果不提供则使用默认token
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: API调用结果
|
||||
|
||||
Raises:
|
||||
Exception: 当API调用失败时抛出
|
||||
|
||||
使用示例:
|
||||
>>> params = {
|
||||
... "body": {"username": "wangpeng1", "password": "Wp147258"},
|
||||
... "lzwcaiConfig": {"userId": "447"}
|
||||
... }
|
||||
>>> result = await call_api_by_id(1957355058730684417, params)
|
||||
>>> print(result['code']) # 200
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始调用接口ID {api_id}...")
|
||||
|
||||
# 步骤1: 获取API配置
|
||||
api_details = get_business_api_details([api_id], auth_token)
|
||||
|
||||
if not api_details:
|
||||
logger.error(f"未找到接口ID {api_id} 的详情")
|
||||
return None
|
||||
|
||||
# 获取第一个(也是唯一一个)API详情
|
||||
api_detail = api_details[0]
|
||||
|
||||
# 步骤2: 使用ApiBase处理API配置
|
||||
api_base = ApiBase([api_detail])
|
||||
processed_configs = api_base.api_configs_map
|
||||
|
||||
if not processed_configs:
|
||||
logger.error("API配置处理失败")
|
||||
return None
|
||||
|
||||
processed_config = processed_configs[0]
|
||||
|
||||
# 步骤3: 调用API
|
||||
logger.info(f"调用接口: {processed_config.get('interfaceName', 'N/A')}")
|
||||
logger.info(f"接口地址: {processed_config.get('apiUrl', 'N/A')}")
|
||||
logger.info(f"请求方法: {processed_config.get('method', 'N/A')}")
|
||||
|
||||
# 判断是否需要认证
|
||||
need_auth = processed_config.get('authenticationRequired', 0) == 1
|
||||
logger.info(f"需要认证: {need_auth}")
|
||||
|
||||
# 尝试调用API(智能认证处理)
|
||||
try:
|
||||
result = await call_api(processed_config, request_params, need_auth=need_auth)
|
||||
except Exception as api_error:
|
||||
# 如果认证失败,尝试不需要认证的方式调用
|
||||
if "认证失败" in str(api_error) or "鉴权令牌失败" in str(api_error):
|
||||
logger.warning(f"认证调用失败,尝试无认证调用: {str(api_error)}")
|
||||
try:
|
||||
result = await call_api(processed_config, request_params, need_auth=False)
|
||||
logger.info("无认证调用成功")
|
||||
except Exception as no_auth_error:
|
||||
logger.error(f"无认证调用也失败: {str(no_auth_error)}")
|
||||
# 返回模拟的错误响应,展示调用过程
|
||||
result = {
|
||||
"error": "API调用失败",
|
||||
"auth_error": str(api_error),
|
||||
"no_auth_error": str(no_auth_error),
|
||||
"note": "这是一个模拟的错误响应,展示了API调用的完整过程"
|
||||
}
|
||||
else:
|
||||
raise api_error
|
||||
|
||||
logger.info("API调用成功")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用API失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# ==================== 功能3: 获取用户授权Token ====================
|
||||
|
||||
def get_auth_info(user_id: str, business_system_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户认证信息
|
||||
|
||||
功能说明:
|
||||
- 获取用户在指定业务系统中的认证配置
|
||||
- 包含认证类型、API密钥、登录接口等信息
|
||||
|
||||
Args:
|
||||
user_id: 用户ID,例如 "447"
|
||||
business_system_id: 业务系统ID,例如 "1957354824118095874"
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 认证信息,包括:
|
||||
- authType: 认证类型
|
||||
- name: 业务系统名称
|
||||
- apiKey: API密钥
|
||||
- apiVO: API配置信息
|
||||
|
||||
Raises:
|
||||
Exception: 当获取认证信息失败时抛出
|
||||
|
||||
使用示例:
|
||||
>>> auth_info = get_auth_info("447", "1957354824118095874")
|
||||
>>> print(auth_info['data']['authType']) # 1
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始获取用户 {user_id} 在业务系统 {business_system_id} 的认证信息...")
|
||||
|
||||
# 调用get_auth_data获取认证数据
|
||||
auth_data = get_auth_data(user_id, business_system_id)
|
||||
|
||||
if not auth_data:
|
||||
logger.error("未能获取到认证数据")
|
||||
return None
|
||||
|
||||
logger.info("成功获取认证信息")
|
||||
return auth_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取认证信息失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_business_token(user_id: str, business_system_id: str,persist_token: bool = False) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户在指定业务系统的Token
|
||||
|
||||
功能说明:
|
||||
- 使用全局共享的AuthService实例获取Token
|
||||
- 复用认证状态,提高性能
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
business_system_id: 业务系统ID
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Token获取结果
|
||||
|
||||
Raises:
|
||||
Exception: 当获取Token失败时抛出
|
||||
|
||||
使用示例:
|
||||
>>> token_result = await get_business_token("447", "1957354824118095874")
|
||||
>>> print(token_result.get('success')) # True
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始获取用户 {user_id} 在业务系统 {business_system_id} 的Token...")
|
||||
|
||||
# 使用全局共享的AuthService实例获取Token
|
||||
auth_service = get_auth_service()
|
||||
token_result = await auth_service.authorize_request(user_id, business_system_id,persist_token=persist_token)
|
||||
|
||||
logger.info("Token获取完成")
|
||||
return token_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取业务Token失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# ==================== 便捷工具函数 ====================
|
||||
|
||||
def save_result_to_file(data: Dict[str, Any], filename: str) -> None:
|
||||
"""
|
||||
保存结果到JSON文件
|
||||
|
||||
Args:
|
||||
data: 要保存的数据
|
||||
filename: 文件名
|
||||
"""
|
||||
try:
|
||||
with open(filename, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"💾 数据已保存到: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存文件失败: {str(e)}")
|
||||
|
||||
|
||||
def print_api_info(api_info: Dict[str, Any]) -> None:
|
||||
"""
|
||||
格式化打印API信息
|
||||
|
||||
Args:
|
||||
api_info: API信息字典
|
||||
"""
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("📋 接口信息:")
|
||||
logger.info(f"接口名称: {api_info.get('interfaceName', 'N/A')}")
|
||||
logger.info(f"业务描述: {api_info.get('businessPrompts', 'N/A')}")
|
||||
logger.info(f"接口地址: {api_info.get('apiUrl', 'N/A')}")
|
||||
logger.info(f"请求方法: {api_info.get('method', 'N/A')}")
|
||||
logger.info(f"需要认证: {'是' if api_info.get('authenticationRequired', 0) == 1 else '否'}")
|
||||
logger.info("="*60)
|
||||
|
||||
|
||||
def print_api_result(result: Dict[str, Any]) -> None:
|
||||
"""
|
||||
格式化打印API调用结果
|
||||
|
||||
Args:
|
||||
result: API调用结果
|
||||
"""
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("📥 API调用结果:")
|
||||
if result.get('success') is not None:
|
||||
status = "✅ 成功" if result.get('success') else "❌ 失败"
|
||||
logger.info(f"状态: {status}")
|
||||
if result.get('code') is not None:
|
||||
logger.info(f"状态码: {result.get('code')}")
|
||||
if result.get('msg'):
|
||||
logger.info(f"消息: {result.get('msg')}")
|
||||
logger.info("="*60)
|
||||
|
||||
|
||||
# ==================== 便捷使用示例 ====================
|
||||
|
||||
async def demo_usage():
|
||||
"""
|
||||
使用示例演示
|
||||
|
||||
展示如何使用工具包的三个核心功能
|
||||
"""
|
||||
logger.info("🚀 API助手工具包使用示例")
|
||||
|
||||
# 测试数据
|
||||
# api_id = 1957355058730684417
|
||||
user_id = "2"
|
||||
business_system_id = "1922839602141347842"
|
||||
# os.environ["lzwcai_mcp_dyntoolapi_auth_url"] = (
|
||||
# "http://lzwcai-demp-corp-manager:8086/system/mcpServer/bizSys/api/getByIds"
|
||||
# )
|
||||
try:
|
||||
# get_auth_info_data=await get_business_token(user_id, business_system_id)
|
||||
# print(f"获取接口参数信息: {get_auth_info_data}")
|
||||
# # 功能1: 获取接口参数信息
|
||||
# print(f"\n📋 功能1: 获取接口ID {api_id} 的参数信息")
|
||||
# api_info = get_api_parameters_info(api_id)
|
||||
# if api_info:
|
||||
# print(f"✅ 接口名称: {api_info.get('interfaceName', 'N/A')}")
|
||||
# print(f"✅ 业务描述: {api_info.get('businessPrompts', 'N/A')}")
|
||||
|
||||
# # 功能2: 调用API
|
||||
# print(f"\n🔧 功能2: 调用接口ID {api_id}")
|
||||
# request_params = {
|
||||
# "body": {
|
||||
# "username": "wangpeng1",
|
||||
# "password": "Wp147258"
|
||||
# },
|
||||
# "lzwcaiConfig": {
|
||||
# "userId": user_id
|
||||
# }
|
||||
# }
|
||||
# api_result = await call_api_by_id(api_id, request_params)
|
||||
# if api_result:
|
||||
# print(f"✅ API调用成功: {api_result.get('msg', 'N/A')}")
|
||||
|
||||
# # 功能3: 获取Token
|
||||
# logger.info(f"\n🔑 功能3: 获取用户 {user_id} 的Token")
|
||||
token_result = await get_business_token(user_id, business_system_id)
|
||||
if token_result and token_result.get('success'):
|
||||
logger.info(f"✅ Token获取成功: {token_result.get('msg', 'N/A')}")
|
||||
|
||||
logger.info(f"\n🎉 所有功能演示完成!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 演示过程中发生错误: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行使用示例
|
||||
asyncio.run(demo_usage())
|
||||
@@ -0,0 +1,554 @@
|
||||
"""
|
||||
统一日志配置模块
|
||||
|
||||
这个模块提供了整个项目的统一日志配置和管理功能,确保所有组件使用一致的日志格式和输出方式。
|
||||
|
||||
主要功能:
|
||||
1. 统一的日志格式配置
|
||||
2. 支持控制台和文件双重输出
|
||||
3. 日志文件轮转管理
|
||||
4. MCP模式下的特殊处理(禁用控制台输出)
|
||||
5. 便捷的日志器获取接口
|
||||
6. 丰富的日志工具函数
|
||||
|
||||
设计特点:
|
||||
- 单例模式确保配置一致性
|
||||
- 支持动态配置调整
|
||||
- 异常安全的编码处理
|
||||
- 详细的调试信息记录
|
||||
|
||||
作者: lzwcai
|
||||
版本: 1.0.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LoggerConfig:
|
||||
"""
|
||||
日志配置管理器
|
||||
|
||||
这个类采用单例模式管理整个项目的日志配置。
|
||||
它提供了统一的日志格式、文件轮转、编码处理等功能。
|
||||
|
||||
主要特性:
|
||||
- 单例模式:确保全局日志配置一致
|
||||
- 双重输出:同时支持控制台和文件输出
|
||||
- 文件轮转:自动管理日志文件大小和数量
|
||||
- 编码安全:正确处理中文字符
|
||||
- MCP兼容:支持MCP模式下的特殊需求
|
||||
|
||||
配置参数:
|
||||
DEFAULT_LOG_LEVEL: 默认日志级别(INFO)
|
||||
DEFAULT_LOG_FORMAT: 日志格式模板
|
||||
DEFAULT_DATE_FORMAT: 时间格式
|
||||
LOG_FILE_NAME: 日志文件名
|
||||
MAX_LOG_SIZE: 单个日志文件最大大小(10MB)
|
||||
BACKUP_COUNT: 保留的备份文件数量(5个)
|
||||
"""
|
||||
|
||||
# ==================== 默认配置常量 ====================
|
||||
|
||||
# 默认日志级别:INFO级别平衡了信息量和性能
|
||||
DEFAULT_LOG_LEVEL = logging.INFO
|
||||
|
||||
# 默认日志格式:包含时间、模块名、级别、文件位置、消息内容
|
||||
DEFAULT_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
|
||||
|
||||
# 默认时间格式:标准的年-月-日 时:分:秒格式
|
||||
DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
# ==================== 日志文件配置 ====================
|
||||
|
||||
# 日志文件名:使用项目名称作为前缀
|
||||
LOG_FILE_NAME = "lzwcai_mcp_api_converter.log"
|
||||
|
||||
# 单个日志文件最大大小:10MB
|
||||
MAX_LOG_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
# 保留的备份文件数量:5个(总共约50MB的日志存储)
|
||||
BACKUP_COUNT = 5
|
||||
|
||||
# ==================== 单例模式状态 ====================
|
||||
|
||||
# 初始化标志:确保只初始化一次
|
||||
_initialized = False
|
||||
|
||||
# 日志文件路径:记录当前使用的日志文件路径
|
||||
_log_file_path = None
|
||||
|
||||
@classmethod
|
||||
def setup_logging(
|
||||
cls,
|
||||
log_level: int = DEFAULT_LOG_LEVEL,
|
||||
log_file: Optional[str] = None,
|
||||
console_output: bool = True,
|
||||
file_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
设置项目统一日志配置
|
||||
|
||||
这是日志系统的核心初始化方法,负责配置整个项目的日志输出。
|
||||
采用单例模式,确保在整个应用生命周期中只初始化一次。
|
||||
|
||||
配置流程:
|
||||
1. 检查是否已经初始化(单例模式)
|
||||
2. 确定日志文件路径(自动或手动指定)
|
||||
3. 创建必要的目录结构
|
||||
4. 配置根日志器和处理器
|
||||
5. 设置日志格式化器
|
||||
6. 添加控制台和文件处理器
|
||||
7. 记录初始化信息
|
||||
|
||||
特殊处理:
|
||||
- MCP模式下通常禁用控制台输出,避免干扰stdio通信
|
||||
- Windows系统下的UTF-8编码处理
|
||||
- 日志文件的自动轮转管理
|
||||
|
||||
参数:
|
||||
log_level: 日志级别(DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_file: 日志文件路径,None时使用默认路径
|
||||
console_output: 是否输出到控制台(MCP模式下通常为False)
|
||||
file_output: 是否输出到文件(通常为True)
|
||||
|
||||
返回:
|
||||
str: 实际使用的日志文件路径
|
||||
|
||||
注意事项:
|
||||
- 这个方法是线程安全的
|
||||
- 重复调用会直接返回已配置的路径
|
||||
- 日志文件会自动创建必要的目录
|
||||
"""
|
||||
# 单例模式检查:如果已经初始化,直接返回
|
||||
if cls._initialized:
|
||||
return cls._log_file_path
|
||||
|
||||
# ==================== 日志文件路径配置 ====================
|
||||
|
||||
if log_file is None:
|
||||
# 自动确定日志文件路径:项目根目录 + 默认文件名
|
||||
project_root = cls._get_project_root()
|
||||
log_file = project_root / cls.LOG_FILE_NAME
|
||||
else:
|
||||
# 使用指定的日志文件路径
|
||||
log_file = Path(log_file)
|
||||
|
||||
# 确保日志目录存在(递归创建)
|
||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
cls._log_file_path = str(log_file)
|
||||
|
||||
# ==================== 包日志器配置 ====================
|
||||
|
||||
# 获取包的顶层日志器,而不是根日志器
|
||||
package_logger = logging.getLogger('lzwcai_mcp_api_converter')
|
||||
package_logger.setLevel(log_level)
|
||||
|
||||
# 作为库,不应该清除宿主应用的任何处理器
|
||||
# 也不应该让日志消息向上传播到根日志器,以免重复打印
|
||||
package_logger.propagate = False
|
||||
|
||||
# 清除此日志器上现有的处理器,避免重复配置
|
||||
for handler in package_logger.handlers[:]:
|
||||
package_logger.removeHandler(handler)
|
||||
|
||||
# ==================== 日志格式化器 ====================
|
||||
|
||||
# 创建统一的日志格式化器
|
||||
formatter = logging.Formatter(
|
||||
fmt=cls.DEFAULT_LOG_FORMAT, # 日志格式模板
|
||||
datefmt=cls.DEFAULT_DATE_FORMAT # 时间格式
|
||||
)
|
||||
|
||||
# ==================== 控制台处理器配置 ====================
|
||||
|
||||
if console_output:
|
||||
# 控制台输出处理器,支持彩色输出和UTF-8编码
|
||||
import io
|
||||
|
||||
# 处理Windows系统的编码问题
|
||||
if hasattr(sys.stdout, 'buffer'):
|
||||
# 在Windows上强制使用UTF-8编码,避免中文乱码
|
||||
# errors='replace'确保即使有编码问题也不会崩溃
|
||||
console_stream = io.TextIOWrapper(
|
||||
sys.stdout.buffer,
|
||||
encoding='utf-8',
|
||||
errors='replace'
|
||||
)
|
||||
else:
|
||||
# Unix/Linux系统通常默认支持UTF-8
|
||||
console_stream = sys.stdout
|
||||
|
||||
# 创建控制台处理器
|
||||
console_handler = logging.StreamHandler(console_stream)
|
||||
console_handler.setLevel(log_level)
|
||||
console_handler.setFormatter(formatter)
|
||||
package_logger.addHandler(console_handler)
|
||||
|
||||
# ==================== 文件处理器配置 ====================
|
||||
|
||||
if file_output:
|
||||
# 文件输出处理器,支持自动轮转
|
||||
file_handler = logging.handlers.RotatingFileHandler(
|
||||
filename=cls._log_file_path, # 日志文件路径
|
||||
maxBytes=cls.MAX_LOG_SIZE, # 单文件最大大小
|
||||
backupCount=cls.BACKUP_COUNT, # 备份文件数量
|
||||
encoding='utf-8' # 文件编码
|
||||
)
|
||||
file_handler.setLevel(log_level)
|
||||
file_handler.setFormatter(formatter)
|
||||
package_logger.addHandler(file_handler)
|
||||
|
||||
# ==================== 初始化完成标记 ====================
|
||||
|
||||
# 标记为已初始化,防止重复配置
|
||||
cls._initialized = True
|
||||
|
||||
# ==================== 记录初始化信息 ====================
|
||||
|
||||
# 获取当前模块的日志器并记录初始化信息
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"日志系统初始化完成 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logger.info(f"日志级别: {logging.getLevelName(log_level)}")
|
||||
logger.info(f"日志文件: {cls._log_file_path}")
|
||||
logger.info(f"控制台输出: {console_output}")
|
||||
logger.info(f"文件输出: {file_output}")
|
||||
logger.info(f"文件轮转: 最大{cls.MAX_LOG_SIZE // (1024*1024)}MB, 保留{cls.BACKUP_COUNT}个备份")
|
||||
logger.info("=" * 80)
|
||||
|
||||
return cls._log_file_path
|
||||
|
||||
@classmethod
|
||||
def _get_project_root(cls) -> Path:
|
||||
"""
|
||||
获取项目根目录
|
||||
|
||||
这个方法通过向上遍历目录树来查找项目根目录。
|
||||
它会寻找常见的项目标识文件来确定根目录位置。
|
||||
|
||||
查找策略:
|
||||
1. 从当前文件所在目录开始向上查找
|
||||
2. 寻找项目标识文件:pyproject.toml, setup.py, main.py
|
||||
3. 找到任一标识文件的目录即为项目根目录
|
||||
4. 如果都找不到,使用当前文件的上级目录作为备选
|
||||
|
||||
返回:
|
||||
Path: 项目根目录的路径对象
|
||||
|
||||
注意事项:
|
||||
- 这个方法假设项目结构相对标准
|
||||
- 在特殊的部署环境中可能需要调整
|
||||
- 备选方案确保总是返回有效路径
|
||||
"""
|
||||
# 从当前文件向上查找项目根目录
|
||||
current_path = Path(__file__).parent
|
||||
|
||||
# 向上遍历目录树
|
||||
while current_path.parent != current_path: # 避免到达文件系统根目录
|
||||
# 检查常见的项目标识文件
|
||||
if (current_path / "pyproject.toml").exists() or \
|
||||
(current_path / "setup.py").exists() or \
|
||||
(current_path / "main.py").exists():
|
||||
return current_path
|
||||
current_path = current_path.parent
|
||||
|
||||
# 备选方案:如果找不到标识文件,使用预设的相对路径
|
||||
# 这个路径基于当前的项目结构:util -> src -> 项目根
|
||||
return Path(__file__).parent.parent.parent
|
||||
|
||||
@classmethod
|
||||
def get_logger(cls, name: str) -> logging.Logger:
|
||||
"""
|
||||
获取配置好的日志器
|
||||
|
||||
这是获取日志器的标准方法,确保返回的日志器使用统一的配置。
|
||||
如果日志系统尚未初始化,会自动进行初始化。
|
||||
|
||||
参数:
|
||||
name: 日志器名称,通常使用模块的 __name__ 变量
|
||||
|
||||
返回:
|
||||
logging.Logger: 配置好的日志器实例
|
||||
|
||||
使用示例:
|
||||
logger = LoggerConfig.get_logger(__name__)
|
||||
logger.info("这是一条信息日志")
|
||||
|
||||
特性:
|
||||
- 自动初始化:首次调用时自动配置日志系统(MCP模式下禁用控制台输出)
|
||||
- 层次化命名:支持Python日志器的层次化命名
|
||||
- 统一配置:所有日志器使用相同的格式和输出配置
|
||||
"""
|
||||
# 检查是否已初始化,未初始化则使用默认配置初始化
|
||||
# 重要:在MCP模式下禁用控制台输出,避免干扰stdio通信
|
||||
if not cls._initialized:
|
||||
cls.setup_logging(console_output=False, file_output=True)
|
||||
|
||||
# 返回指定名称的日志器
|
||||
return logging.getLogger(name)
|
||||
|
||||
# ==================== 日志工具方法 ====================
|
||||
|
||||
@classmethod
|
||||
def log_function_entry(cls, logger: logging.Logger, func_name: str, **kwargs):
|
||||
"""
|
||||
记录函数入口日志
|
||||
|
||||
用于调试和性能分析,记录函数被调用时的参数信息。
|
||||
通常在DEBUG级别输出,不会影响生产环境的性能。
|
||||
|
||||
参数:
|
||||
logger: 日志器实例
|
||||
func_name: 函数名称
|
||||
**kwargs: 函数参数(键值对形式)
|
||||
|
||||
使用示例:
|
||||
LoggerConfig.log_function_entry(logger, "process_data", user_id=123, action="login")
|
||||
"""
|
||||
args_str = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
|
||||
logger.debug(f"进入函数 {func_name}({args_str})")
|
||||
|
||||
@classmethod
|
||||
def log_function_exit(cls, logger: logging.Logger, func_name: str, result=None):
|
||||
"""
|
||||
记录函数出口日志
|
||||
|
||||
与log_function_entry配对使用,记录函数执行完成和返回值。
|
||||
有助于跟踪函数执行流程和调试返回值问题。
|
||||
|
||||
参数:
|
||||
logger: 日志器实例
|
||||
func_name: 函数名称
|
||||
result: 函数返回值(可选)
|
||||
|
||||
使用示例:
|
||||
LoggerConfig.log_function_exit(logger, "process_data", result={"status": "success"})
|
||||
"""
|
||||
if result is not None:
|
||||
logger.debug(f"退出函数 {func_name},返回值: {result}")
|
||||
else:
|
||||
logger.debug(f"退出函数 {func_name}")
|
||||
|
||||
@classmethod
|
||||
def log_api_request(cls, logger: logging.Logger, method: str, url: str, **kwargs):
|
||||
"""
|
||||
记录API请求日志
|
||||
|
||||
标准化API请求的日志记录,包含HTTP方法、URL和请求参数。
|
||||
有助于API调用的监控和调试。
|
||||
|
||||
参数:
|
||||
logger: 日志器实例
|
||||
method: HTTP方法(GET, POST, PUT, DELETE等)
|
||||
url: 请求URL
|
||||
**kwargs: 请求参数(可选)
|
||||
|
||||
使用示例:
|
||||
LoggerConfig.log_api_request(logger, "POST", "https://api.example.com/users",
|
||||
headers={"Authorization": "Bearer xxx"})
|
||||
"""
|
||||
logger.info(f"API请求 - {method} {url}")
|
||||
if kwargs:
|
||||
logger.debug(f"请求参数: {kwargs}")
|
||||
|
||||
@classmethod
|
||||
def log_api_response(cls, logger: logging.Logger, status_code: int, response_time: float = None):
|
||||
"""
|
||||
记录API响应日志
|
||||
|
||||
记录API响应的状态码和响应时间,用于性能监控和问题诊断。
|
||||
|
||||
参数:
|
||||
logger: 日志器实例
|
||||
status_code: HTTP状态码
|
||||
response_time: 响应时间(秒,可选)
|
||||
|
||||
使用示例:
|
||||
LoggerConfig.log_api_response(logger, 200, 0.156)
|
||||
"""
|
||||
if response_time:
|
||||
logger.info(f"API响应 - 状态码: {status_code}, 响应时间: {response_time:.3f}s")
|
||||
else:
|
||||
logger.info(f"API响应 - 状态码: {status_code}")
|
||||
|
||||
@classmethod
|
||||
def log_error_with_context(cls, logger: logging.Logger, error: Exception, context: str = ""):
|
||||
"""
|
||||
记录带上下文的错误日志
|
||||
|
||||
提供丰富的错误信息记录,包含异常类型、错误消息、上下文信息和详细堆栈。
|
||||
这是错误处理的标准方法。
|
||||
|
||||
参数:
|
||||
logger: 日志器实例
|
||||
error: 异常对象
|
||||
context: 错误发生的上下文描述(可选)
|
||||
|
||||
使用示例:
|
||||
try:
|
||||
risky_operation()
|
||||
except Exception as e:
|
||||
LoggerConfig.log_error_with_context(logger, e, "处理用户请求时")
|
||||
"""
|
||||
if context:
|
||||
logger.error(f"错误发生在 {context}: {type(error).__name__}: {str(error)}")
|
||||
else:
|
||||
logger.error(f"错误: {type(error).__name__}: {str(error)}")
|
||||
# 记录详细的异常堆栈信息(仅在DEBUG级别显示)
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
获取日志器的便捷函数
|
||||
|
||||
这是LoggerConfig.get_logger的简化版本,提供更简洁的调用方式。
|
||||
推荐在模块级别使用这个函数获取日志器。
|
||||
|
||||
参数:
|
||||
name: 日志器名称,通常使用 __name__
|
||||
|
||||
返回:
|
||||
logging.Logger: 配置好的日志器实例
|
||||
|
||||
使用示例:
|
||||
logger = get_logger(__name__)
|
||||
"""
|
||||
return LoggerConfig.get_logger(name)
|
||||
|
||||
|
||||
def setup_logging(**kwargs) -> str:
|
||||
"""
|
||||
设置日志的便捷函数
|
||||
|
||||
这是LoggerConfig.setup_logging的简化版本,支持所有相同的参数。
|
||||
|
||||
参数:
|
||||
**kwargs: 传递给LoggerConfig.setup_logging的所有参数
|
||||
|
||||
返回:
|
||||
str: 日志文件路径
|
||||
|
||||
使用示例:
|
||||
log_file = setup_logging(log_level=logging.DEBUG, console_output=False)
|
||||
"""
|
||||
return LoggerConfig.setup_logging(**kwargs)
|
||||
|
||||
|
||||
# ==================== 装饰器 ====================
|
||||
|
||||
def log_function_calls(logger: Optional[logging.Logger] = None):
|
||||
"""
|
||||
函数调用日志装饰器
|
||||
|
||||
这个装饰器自动记录函数的调用和返回,包括参数和返回值。
|
||||
主要用于调试和性能分析,在生产环境中通常设置为DEBUG级别。
|
||||
|
||||
特性:
|
||||
- 自动记录函数入口和出口
|
||||
- 记录函数参数(kwargs)
|
||||
- 记录返回值
|
||||
- 自动处理异常并记录错误上下文
|
||||
- 支持自定义日志器或自动获取
|
||||
|
||||
参数:
|
||||
logger: 可选的日志器实例,None时自动获取函数所在模块的日志器
|
||||
|
||||
返回:
|
||||
装饰器函数
|
||||
|
||||
使用示例:
|
||||
@log_function_calls()
|
||||
def process_user_data(user_id, action="login"):
|
||||
# 函数实现
|
||||
return {"status": "success"}
|
||||
|
||||
# 或者指定日志器
|
||||
@log_function_calls(logger=my_logger)
|
||||
def another_function():
|
||||
pass
|
||||
|
||||
注意事项:
|
||||
- 会记录所有kwargs参数,注意不要记录敏感信息
|
||||
- 返回值也会被记录,大对象可能影响性能
|
||||
- 异常会被重新抛出,不会被吞掉
|
||||
"""
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal logger
|
||||
# 如果没有提供日志器,自动获取函数所在模块的日志器
|
||||
if logger is None:
|
||||
logger = get_logger(func.__module__)
|
||||
|
||||
func_name = func.__name__
|
||||
|
||||
# 记录函数入口(只记录kwargs,避免记录过多信息)
|
||||
LoggerConfig.log_function_entry(logger, func_name, **kwargs)
|
||||
|
||||
try:
|
||||
# 执行原函数
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# 记录函数出口和返回值
|
||||
LoggerConfig.log_function_exit(logger, func_name, result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# 记录异常信息并重新抛出
|
||||
LoggerConfig.log_error_with_context(logger, e, f"函数 {func_name}")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# ==================== 测试代码 ====================
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
日志配置测试代码
|
||||
|
||||
这个测试代码演示了日志系统的基本功能,包括:
|
||||
1. 日志系统初始化
|
||||
2. 不同级别的日志输出
|
||||
3. 日志文件路径获取
|
||||
4. 装饰器功能测试
|
||||
|
||||
运行方式:
|
||||
python -m lzwcai_mcp_api_converter.src.util.logger_config
|
||||
"""
|
||||
# 初始化日志系统(DEBUG级别,同时输出到控制台和文件)
|
||||
log_file = setup_logging(log_level=logging.DEBUG)
|
||||
test_logger = get_logger(__name__)
|
||||
|
||||
test_logger.info("开始测试日志配置...")
|
||||
|
||||
# 测试不同级别的日志输出
|
||||
test_logger.debug("这是一个调试消息 - 用于开发调试")
|
||||
test_logger.info("这是一个信息消息 - 记录重要信息")
|
||||
test_logger.warning("这是一个警告消息 - 提醒注意事项")
|
||||
test_logger.error("这是一个错误消息 - 记录错误情况")
|
||||
|
||||
# 测试工具方法
|
||||
LoggerConfig.log_api_request(test_logger, "GET", "https://api.example.com/test")
|
||||
LoggerConfig.log_api_response(test_logger, 200, 0.123)
|
||||
|
||||
# 测试装饰器
|
||||
@log_function_calls()
|
||||
def test_function(param1, param2="default"):
|
||||
"""测试函数"""
|
||||
return {"result": "success", "param1": param1}
|
||||
|
||||
# 调用测试函数
|
||||
result = test_function("test_value", param2="custom")
|
||||
|
||||
# 输出日志文件位置
|
||||
test_logger.info(f"日志文件位置: {log_file}")
|
||||
test_logger.info("日志配置测试完成!")
|
||||
@@ -0,0 +1,168 @@
|
||||
import re
|
||||
from typing import Any, Optional, Union, Dict, List
|
||||
from .logger_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 预编译正则表达式
|
||||
PATH_PATTERN = re.compile(r"\{([^}]+)\}")
|
||||
|
||||
|
||||
def get_nested_value(
|
||||
data: Union[Dict, List, Any],
|
||||
path_str: str,
|
||||
default: Any = None,
|
||||
raise_error: bool = False,
|
||||
) -> Any:
|
||||
"""从嵌套字典或列表中根据路径字符串获取值
|
||||
|
||||
支持以下格式的路径字符串:
|
||||
1. 简单路径: {res.data.token}
|
||||
2. 带前缀: Bearer {res.data.token}
|
||||
3. 带索引: {res.data.items[0].name}
|
||||
4. 带引号的键: {res.data."user.name"}
|
||||
|
||||
Args:
|
||||
data: 要查询的数据,可以是字典、列表或其他类型
|
||||
path_str: 路径字符串,支持{prefix} {path}或{path}格式
|
||||
default: 当路径不存在时返回的默认值
|
||||
raise_error: 是否在出错时抛出异常,默认为False
|
||||
|
||||
Returns:
|
||||
查询到的值或默认值
|
||||
|
||||
Raises:
|
||||
ValueError: 当raise_error为True且路径格式无效时
|
||||
KeyError: 当raise_error为True且路径不存在时
|
||||
TypeError: 当raise_error为True且类型错误时
|
||||
"""
|
||||
try:
|
||||
if not path_str:
|
||||
logger.warning("路径字符串为空")
|
||||
return default
|
||||
|
||||
if data is None:
|
||||
logger.warning("输入数据为None")
|
||||
return default
|
||||
|
||||
if not isinstance(data, (dict, list)):
|
||||
logger.warning(f"输入数据类型不支持: {type(data)}")
|
||||
return default
|
||||
|
||||
# 使用预编译的正则表达式匹配路径
|
||||
matches = PATH_PATTERN.findall(path_str)
|
||||
|
||||
if not matches:
|
||||
logger.warning(f"路径字符串格式无效: {path_str}")
|
||||
if raise_error:
|
||||
raise ValueError(f"路径字符串格式无效: {path_str}")
|
||||
return default
|
||||
|
||||
# 处理路径
|
||||
if len(matches) == 1:
|
||||
actual_path = matches[0]
|
||||
# 获取{...}之前的所有文本作为前缀
|
||||
prefix = path_str[: path_str.find("{")].strip()
|
||||
elif len(matches) == 2:
|
||||
prefix = matches[0]
|
||||
actual_path = matches[1]
|
||||
else:
|
||||
logger.warning(f"路径字符串包含过多匹配项: {path_str}")
|
||||
if raise_error:
|
||||
raise ValueError(f"路径字符串包含过多匹配项: {path_str}")
|
||||
return default
|
||||
|
||||
# 解析路径
|
||||
current = data
|
||||
# 使用更智能的路径分割
|
||||
path_parts = []
|
||||
current_part = ""
|
||||
in_quotes = False
|
||||
|
||||
for char in actual_path:
|
||||
if char == '"':
|
||||
in_quotes = not in_quotes
|
||||
current_part += char
|
||||
elif char == "." and not in_quotes:
|
||||
path_parts.append(current_part)
|
||||
current_part = ""
|
||||
else:
|
||||
current_part += char
|
||||
|
||||
if current_part:
|
||||
path_parts.append(current_part)
|
||||
|
||||
for part in path_parts:
|
||||
# 处理数组索引
|
||||
if "[" in part and part.endswith("]"):
|
||||
key, index_str = part.split("[", 1)
|
||||
index_str = index_str.rstrip("]")
|
||||
|
||||
# 获取键值
|
||||
if key and isinstance(current, dict):
|
||||
current = current.get(key)
|
||||
elif not key and isinstance(current, list):
|
||||
pass
|
||||
else:
|
||||
logger.warning(f"无效的键: {key}")
|
||||
if raise_error:
|
||||
raise KeyError(f"无效的键: {key}")
|
||||
return default
|
||||
|
||||
# 获取索引
|
||||
try:
|
||||
index = int(index_str)
|
||||
if not isinstance(current, list) or not (0 <= index < len(current)):
|
||||
logger.warning(f"无效的索引: {index}")
|
||||
if raise_error:
|
||||
raise IndexError(f"无效的索引: {index}")
|
||||
return default
|
||||
current = current[index]
|
||||
except ValueError:
|
||||
logger.warning(f"无效的索引格式: {index_str}")
|
||||
if raise_error:
|
||||
raise ValueError(f"无效的索引格式: {index_str}")
|
||||
return default
|
||||
else:
|
||||
# 处理普通键
|
||||
if isinstance(current, dict):
|
||||
# 处理带引号的键
|
||||
if part.startswith('"') and part.endswith('"'):
|
||||
part = part[1:-1]
|
||||
if part not in current:
|
||||
logger.warning(f"键不存在: {part}")
|
||||
if raise_error:
|
||||
raise KeyError(f"键不存在: {part}")
|
||||
return default
|
||||
current = current[part]
|
||||
elif isinstance(current, list):
|
||||
try:
|
||||
index = int(part)
|
||||
if not (0 <= index < len(current)):
|
||||
logger.warning(f"索引越界: {index}")
|
||||
if raise_error:
|
||||
raise IndexError(f"索引越界: {index}")
|
||||
return default
|
||||
current = current[index]
|
||||
except ValueError:
|
||||
logger.warning(f"无效的列表索引: {part}")
|
||||
if raise_error:
|
||||
raise ValueError(f"无效的列表索引: {part}")
|
||||
return default
|
||||
else:
|
||||
logger.warning(f"无法访问键: {part}")
|
||||
if raise_error:
|
||||
raise TypeError(f"无法访问键: {part}")
|
||||
return default
|
||||
|
||||
# 处理前缀
|
||||
if prefix and current is not None:
|
||||
return f"{prefix} {current}"
|
||||
|
||||
return current
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌套值失败: {str(e)}")
|
||||
if raise_error:
|
||||
raise
|
||||
return default
|
||||
Reference in New Issue
Block a user