Files
myaps_api/apps/io_api/utils/common.py
T
chaoge fa5cecd6d1 fix(security,stability): 完成API安全与稳定性修复
- 安全: 修复鉴权失败返回码(HTTP 401/403替代200)
- 安全: 新增SafeQueryBuilder封堵SQL注入入口
- 安全: 移除Pydantic json_encoders弃用配置
- 稳定: 统一后台任务托管与生命周期管理
- 稳定: 新增TaskManager统一管理后台任务
- 文档: 更新README.md与.env.example
- 重构: routers.py使用安全SQL构建器替代字符串拼接
2026-05-25 20:08:35 +08:00

326 lines
11 KiB
Python

from typing import Dict, Any, List, Tuple, Literal
import enum
from datetime import datetime
from fastapi import status, HTTPException, status, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel as PydanticSchema
# 从 globalobjects.db_manager 导入 dict_to_lower_keys 函数
from globalobjects.db_manager import dict_to_lower_keys
# from core.settings import MYAPS_MAIN_DB
# from globalobjects.globalconst import SupplyTypeEnum
def format_query_result(d: dict) -> dict:
    """Normalize a single query-result row.

    1. Lower-cases every key (keys are expected to be strings).
    2. Renders ``datetime`` values as ``"%Y-%m-%d %H:%M:%S"`` strings.

    Every other value -- including ISO 8601 *strings* -- is returned
    unchanged; string parsing is intentionally not performed.

    If two keys differ only by case, the later one wins.

    Args:
        d: the row to normalize; it is not modified.

    Returns:
        A new dict with lower-cased keys and formatted datetime values.
    """
    # A string-parsing branch used to be sketched here (commented out). It was
    # never enabled and would have truncated dates by splitting on '-'.
    return {
        key.lower(): (
            value.strftime("%Y-%m-%d %H:%M:%S") if isinstance(value, datetime) else value
        )
        for key, value in d.items()
    }
# Shared response-envelope helpers used by the routes.
def standard_response(
    status_code: int = status.HTTP_200_OK,
    success: int = 1,
    message: str = "success",
    data: Any = None,
    meta: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """Build the uniform ``{status_code, success, message, meta, data}`` envelope.

    When ``data`` is a ``DbResult`` or ``MultiDbResult``, that object's own
    ``success``, ``message``, ``meta`` and ``data`` attributes are used and the
    corresponding arguments are ignored; ``status_code`` is always taken from
    the call.

    Args:
        status_code: HTTP-style status code placed in the envelope.
        success: success flag (1 / 0 by convention in this module).
        message: human-readable message.
        data: payload, or a ``DbResult``/``MultiDbResult`` to unwrap.
        meta: optional metadata dict.

    Returns:
        The envelope dict.
    """
    # Imported lazily to avoid a circular import with ``.db_operation``.
    from .db_operation import DbResult, MultiDbResult

    if isinstance(data, (DbResult, MultiDbResult)):
        success, message, meta, data = data.success, data.message, data.meta, data.data
    return {
        "status_code": status_code,
        "success": success,
        "message": message,
        "meta": meta,
        "data": data,
    }
# # url - 公共参数
# common_params = {
# "db_name": Query(MYAPS_MAIN_DB, examples={"default": {"value": MYAPS_MAIN_DB}}, description="账套"),
# "page_size": Query(1000, description="每页数量", gt=0, le=10000),
# "page_index": Query(0, description="分页页码,从0开始", ge=0),
# "supply_type": Query(..., description="供应类型", openapi_examples={key: {"value": key, "summary": value.value} for key, value in SupplyTypeEnum.__members__.items()}),
# "x_api_key": Header(None, description="API密钥")
# }
def get_raw_input_data(data_item: PydanticSchema | Dict[str, Any]) -> Dict:
    """Return the payload a model received before its model validators ran.

    Args:
        data_item: a pydantic model instance or a plain dict.

    Returns:
        * For a model carrying a ``_raw_input_data`` attribute (stashed by the
          model's own validators), that stashed value.
        * For any other model, ``model_dump(exclude_none=False)`` as a fallback.
        * For a non-model input, the input itself (not copied).
    """
    if not isinstance(data_item, PydanticSchema):
        return data_item
    if hasattr(data_item, "_raw_input_data"):
        return data_item._raw_input_data
    # The stash was never set (no before-validator ran, or it stored nothing).
    # Note: ``model_dump`` never serializes private attributes, so trying
    # ``model_dump(include={'_raw_input_data'})`` first could not find it; the
    # plain dump is the only meaningful fallback.
    return data_item.model_dump(exclude_none=False)
def convert_to_dict(data_item: PydanticSchema | Dict[str, Any], exclude_none: bool = True) -> dict:
    """Return ``data_item`` as a plain dict.

    Args:
        data_item: a pydantic model instance or an already-plain dict.
        exclude_none: forwarded to ``model_dump``; when true, fields whose
            value is ``None`` are omitted. Ignored for non-model input.

    Returns:
        ``model_dump(exclude_none=exclude_none)`` for a model; otherwise the
        input object itself (no copy is made).
    """
    if not isinstance(data_item, PydanticSchema):
        return data_item
    return data_item.model_dump(exclude_none=exclude_none)
def format_data_for_logging(data):
    """Make ``data`` log-friendly by unwrapping types loggers render poorly.

    Recurses into dict values and list elements. ``Enum`` members become their
    ``.value`` and ``Decimal`` instances become their ``str``. Anything else --
    including tuples, sets and whatever they contain -- is returned as-is.

    Args:
        data: value to format.

    Returns:
        The formatted value (fresh dict/list objects for containers).
    """
    from decimal import Decimal
    import enum

    if isinstance(data, dict):
        return {key: format_data_for_logging(value) for key, value in data.items()}
    if isinstance(data, list):
        return list(map(format_data_for_logging, data))
    if isinstance(data, enum.Enum):
        return data.value
    if isinstance(data, Decimal):
        return str(data)
    return data
# Unified response format for pydantic validation errors.
class CustomValidationError(HTTPException):
    """HTTP 422 error whose ``detail`` is a ``standard_response`` envelope.

    Args:
        errors: per-field validation error records, stored under
            ``meta["error_details"]`` of the envelope.
    """

    def __init__(self, errors: List[Dict[str, Any]]):
        code = status.HTTP_422_UNPROCESSABLE_ENTITY
        envelope = standard_response(
            status_code=code,
            success=0,
            message="数据验证错误",
            meta={"error_details": errors},
        )
        super().__init__(status_code=code, detail=envelope)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn a request-validation failure into the standard 422 JSON envelope.

    Each pydantic error is flattened to ``{"field", "message", "type"}``. The
    ``field`` value joins the error location parts with ``"->"``.

    Args:
        request: the incoming request (forwarded to the HTTP handler).
        exc: the validation error raised by FastAPI.

    Returns:
        The JSON response produced by ``http_exception_handler`` for a
        ``CustomValidationError`` built from the flattened errors.
    """
    errors = [
        {
            "field": "->".join(str(loc) for loc in error["loc"]),
            "message": error["msg"],
            "type": error["type"],
        }
        for error in exc.errors()
    ]
    # Exception handlers are expected to *return* a response. Raising a new
    # exception from inside a handler only works if an outer exception layer
    # happens to catch and re-dispatch it, which depends on the Starlette
    # version; otherwise the client receives a 500. Render the 422 directly.
    return await http_exception_handler(request, CustomValidationError(errors=errors))
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render any ``HTTPException`` as a ``standard_response`` JSON envelope.

    A dict ``detail`` (e.g. one built by ``standard_response``) contributes its
    ``message`` and ``meta``. Any other ``detail`` is stringified into the
    message. Headers attached to the exception are forwarded to the response.

    Args:
        request: the incoming request (unused).
        exc: the exception to render.

    Returns:
        A ``JSONResponse`` carrying ``exc.status_code`` and a failure envelope
        (``success=0``, ``data=None``).
    """
    if isinstance(exc.detail, dict):
        detail = exc.detail
    else:
        detail = {"message": str(exc.detail), "meta": None}
    return JSONResponse(
        status_code=exc.status_code,
        content=standard_response(
            status_code=exc.status_code,
            success=0,
            message=detail.get("message", "Error occurred"),
            data=None,
            meta=detail.get("meta", None),
        ),
        # Keep headers such as WWW-Authenticate (401) or Retry-After, exactly as
        # Starlette's default HTTPException handler does.
        headers=getattr(exc, "headers", None),
    )
def register_exception_handlers(app):
    """Install this module's JSON-envelope exception handlers on ``app``.

    Registers ``validation_exception_handler`` for ``RequestValidationError``
    and ``http_exception_handler`` for ``HTTPException``, in that order.

    Args:
        app: the FastAPI/Starlette application.
    """
    handler_table = (
        (RequestValidationError, validation_exception_handler),
        (HTTPException, http_exception_handler),
    )
    for exception_class, handler in handler_table:
        app.add_exception_handler(exception_class, handler)
async def drop_matched_data(
    data: List[Any],
    db_names: str,
    table_name: str,
    match_on: Tuple[str, ...],
    db_fields: Tuple[str, ...] | None = None,
):
    """Delete existing rows whose key columns match any item in ``data``.

    For every item, the values of the ``match_on`` keys/attributes form a key
    tuple. Items with a ``None`` or empty-string value in any key field are
    skipped. The distinct key tuples are deleted from ``table_name`` in batches
    of 100. Each batch is one ``db_delete`` call whose filter ORs together one
    ``build_safe_filter`` condition per key tuple.

    Args:
        data: new data items (dicts, or objects exposing the fields as attributes).
        db_names: database name(s), passed through to ``db_delete``.
        table_name: table to delete from.
        match_on: item field names forming the composite key,
            e.g. ``("materialno", "matver")``.
        db_fields: matching column names in the table, in the same order,
            e.g. ``("MaterialNo", "MatVer")``; defaults to ``match_on``.

    Raises:
        ValueError: if ``match_on`` is empty, or ``db_fields`` and ``match_on``
            differ in length.
        Exception: any error from ``db_delete`` is logged and re-raised.
    """
    from .db_operation import db_delete
    from globalobjects import logger as log_config
    from globalobjects.db_manager import build_safe_filter

    logger = log_config.get_logger(__name__)

    db_fields = db_fields or match_on
    # zip() below silently truncates a length mismatch, which would delete on a
    # subset of the key columns -- i.e. far more rows than intended.
    if not match_on:
        raise ValueError("match_on 不能为空")
    if len(db_fields) != len(match_on):
        raise ValueError(
            f"db_fields 与 match_on 长度不一致: {len(db_fields)} != {len(match_on)}"
        )

    # Collect the distinct key tuples.
    unique_combinations = set()
    for item in data:
        if isinstance(item, dict):
            field_values = tuple(item.get(field) for field in match_on)
        else:
            field_values = tuple(getattr(item, field, None) for field in match_on)
        # Only None / "" count as missing. A plain truthiness test would also
        # reject legitimate keys such as 0 or False, leaving those rows in place.
        if all(value is not None and value != "" for value in field_values):
            unique_combinations.add(field_values)

    if not unique_combinations:
        return

    batch_size = 100
    combinations_list = list(unique_combinations)
    for start in range(0, len(combinations_list), batch_size):
        batch = combinations_list[start:start + batch_size]
        batch_conditions = []
        for values in batch:
            field_conditions = [
                (db_field, "=", value) for db_field, value in zip(db_fields, values)
            ]
            batch_conditions.append(build_safe_filter(field_conditions))
        filter_string = " OR ".join(f"({cond})" for cond in batch_conditions)
        try:
            await db_delete(db_names=db_names, model_or_tablename=table_name, filter_string=filter_string)
        except Exception as e:
            logger.error(f"删除数据失败: {str(e)}")
            raise
async def mark_as_removing(
    model_class,
    table_name: str,
    drop: Literal["all", "matched"],
    data_list: List[Dict] | None = None,
    drop_fields: List[str] | None = None,
) -> int:
    """Flag staging-table rows as ``StagingStatus.REMOVING`` (staging-mode drop).

    Args:
        model_class: Tortoise ORM model class of the staging table.
        table_name: table name; used only in log messages.
        drop: ``"all"`` flags every row. ``"matched"`` flags rows whose
            ``drop_fields`` equal those of some item in ``data_list``.
        data_list: new data items (required for ``"matched"``). Mapping items
            are read with ``.get``; other objects are read with ``getattr``.
        drop_fields: key field names (required for ``"matched"``). Items with
            a ``None`` or empty-string value in any of them are ignored.

    Returns:
        Number of rows flagged, summed from the ORM ``update`` results.
        ``0`` when nothing was flagged or ``drop`` is not a supported value.
    """
    from collections.abc import Mapping

    from globalobjects import logger as log_config
    from apps.data_opt.mds._base import StagingStatus

    logger = log_config.get_logger(__name__)

    if drop == "all":
        count = await model_class.all().update(_status=StagingStatus.REMOVING)
        logger.info(f"drop=all: 已将 {table_name} 全部 {count} 条记录标记为removing")
        return count

    if drop != "matched":
        # Previously an unsupported mode fell through and silently returned 0.
        logger.warning(f"mark_as_removing: 不支持的drop参数 {drop!r},跳过")
        return 0

    if not data_list or not drop_fields:
        logger.warning("drop=matched: 缺少data_list或drop_fields,跳过")
        return 0

    # Collect distinct key tuples; an item with any missing key field is skipped.
    unique_combinations = set()
    for item in data_list:
        combo = []
        for field in drop_fields:
            if isinstance(item, Mapping):
                value = item.get(field)
            else:
                value = getattr(item, field, None)
            if value is None or value == "":
                break
            combo.append(value)
        else:
            unique_combinations.add(tuple(combo))

    if not unique_combinations:
        logger.info("drop=matched: 无有效匹配值,跳过")
        return 0

    # One filtered UPDATE per distinct key combination. (The former batch_size
    # slicing never grouped the queries, so it has been removed.)
    total_marked = 0
    for values in unique_combinations:
        condition = dict(zip(drop_fields, values))
        total_marked += await model_class.filter(**condition).update(_status=StagingStatus.REMOVING)
    logger.info(f"drop=matched: 已将 {table_name} 中 {total_marked} 条记录标记为removing (匹配{len(unique_combinations)}组)")
    return total_marked