mirror of
https://github.com/rnvm9wjdtj-bot/myaps_api.git
synced 2026-06-02 05:54:40 +00:00
fa5cecd6d1
- 安全: 修复鉴权失败返回码(HTTP 401/403替代200) - 安全: 新增SafeQueryBuilder封堵SQL注入入口 - 安全: 移除Pydantic json_encoders弃用配置 - 稳定: 统一后台任务托管与生命周期管理 - 稳定: 新增TaskManager统一管理后台任务 - 文档: 更新README.md与.env.example - 重构: routers.py使用安全SQL构建器替代字符串拼接
326 lines
11 KiB
Python
326 lines
11 KiB
Python
from typing import Dict, Any, List, Tuple, Literal
|
|
import enum
|
|
from datetime import datetime
|
|
|
|
from fastapi import status, HTTPException, status, Request
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.exceptions import RequestValidationError
|
|
from pydantic import BaseModel as PydanticSchema
|
|
|
|
# 从 globalobjects.db_manager 导入 dict_to_lower_keys 函数
|
|
from globalobjects.db_manager import dict_to_lower_keys
|
|
|
|
# from core.settings import MYAPS_MAIN_DB
|
|
# from globalobjects.globalconst import SupplyTypeEnum
|
|
|
|
def format_query_result(d: dict) -> dict:
    """
    Format a single query-result row.

    1. Lower-case every key. If two keys differ only by case, the value of the
       one appearing later in ``d`` wins.
    2. Render ``datetime`` values as ``"%Y-%m-%d %H:%M:%S"`` strings.

    Only real ``datetime`` objects are converted; string values (including
    ISO 8601 strings) and every other type are passed through unchanged.

    Args:
        d: Mapping of column name to value.

    Returns:
        A new dict with lower-cased keys and formatted datetime values.
    """
    return {
        key.lower(): (
            value.strftime("%Y-%m-%d %H:%M:%S") if isinstance(value, datetime) else value
        )
        for key, value in d.items()
    }
|
|
|
|
# Shared response-envelope helpers used by the routers.
def standard_response(
    status_code: int = status.HTTP_200_OK,
    success: int = 1,
    message: str = "success",
    data: Any = None,
    meta: Dict[str, Any] | None = None,
):
    """
    Build the uniform JSON envelope returned by the API.

    If ``data`` is a ``DbResult`` or ``MultiDbResult``, the envelope takes its
    ``success``, ``message``, ``meta`` and ``data`` from that object's
    attributes, and the ``success``/``message``/``meta`` arguments are ignored.
    ``status_code`` is always taken from the argument.

    Args:
        status_code: Value placed in the envelope's ``status_code`` key.
        success: 1 for success, 0 for failure (as used by callers in this module).
        message: Human-readable message.
        data: Payload, or a ``DbResult``/``MultiDbResult`` to unwrap.
        meta: Optional extra metadata (e.g. pagination or error details).

    Returns:
        A dict with the keys ``status_code``, ``success``, ``message``, ``meta``
        and ``data``.
    """
    # Imported lazily to avoid a circular import with .db_operation.
    from .db_operation import DbResult, MultiDbResult

    if isinstance(data, (DbResult, MultiDbResult)):
        return {
            "status_code": status_code,
            "success": data.success,
            "message": data.message,
            "meta": data.meta,
            "data": data.data,
        }

    return {
        "status_code": status_code,
        "success": success,
        "message": message,
        "meta": meta,
        "data": data,
    }
|
|
|
|
# # url - 公共参数
|
|
# common_params = {
|
|
# "db_name": Query(MYAPS_MAIN_DB, examples={"default": {"value": MYAPS_MAIN_DB}}, description="账套"),
|
|
# "page_size": Query(1000, description="每页数量", gt=0, le=10000),
|
|
# "page_index": Query(0, description="分页页码,从0开始", ge=0),
|
|
# "supply_type": Query(..., description="供应类型", openapi_examples={key: {"value": key, "summary": value.value} for key, value in SupplyTypeEnum.__members__.items()}),
|
|
# "x_api_key": Header(None, description="API密钥")
|
|
# }
|
|
|
|
|
|
def get_raw_input_data(data_item: PydanticSchema | Dict[str, Any]) -> Dict:
    """
    Return the raw input a model received before its ``model_validator`` ran.

    Args:
        data_item: A Pydantic model instance or a plain dict.

    Returns:
        * For a model that has a ``_raw_input_data`` attribute (stashed during
          validation), that stashed value.
        * For any other model, ``model_dump(exclude_none=False)``.
        * For a non-model value, the value itself, unchanged.
    """
    if not isinstance(data_item, PydanticSchema):
        # Plain dicts (or anything else) are already raw input.
        return data_item

    # hasattr() is False when the private attribute was never set, e.g. the
    # before-validator did not run or did not stash the input.
    if hasattr(data_item, "_raw_input_data"):
        return data_item._raw_input_data

    # Private (underscore) attributes are never part of model_dump() output,
    # so a full dump of the validated fields is the only available fallback.
    return data_item.model_dump(exclude_none=False)
|
|
|
|
|
|
def convert_to_dict(data_item: PydanticSchema | Dict[str, Any], exclude_none: bool = True) -> dict:
    """
    Normalize a data item to a dict.

    Args:
        data_item: A Pydantic model instance or a plain dict.
        exclude_none: Forwarded to ``model_dump``; when True, fields whose
            value is None are omitted. Ignored for non-model inputs.

    Returns:
        ``data_item.model_dump(exclude_none=exclude_none)`` for a model;
        otherwise ``data_item`` itself (returned as-is, not copied).
    """
    is_model = isinstance(data_item, PydanticSchema)
    return data_item.model_dump(exclude_none=exclude_none) if is_model else data_item
|
|
|
|
|
|
def format_data_for_logging(data):
    """
    Recursively convert a value into a log-friendly form.

    Conversion rules:
      * dict    -> new dict with each value converted (keys unchanged)
      * list    -> new list with each element converted
      * tuple   -> new tuple with each element converted
      * Enum    -> its ``.value``
      * Decimal -> ``str(value)``
      * anything else -> returned unchanged

    Args:
        data: The value to convert.

    Returns:
        The converted value. Containers are rebuilt, never mutated in place.
    """
    from decimal import Decimal

    if isinstance(data, dict):
        return {k: format_data_for_logging(v) for k, v in data.items()}
    if isinstance(data, list):
        return [format_data_for_logging(item) for item in data]
    if isinstance(data, tuple):
        # Tuples (e.g. composite keys) may also hold Decimal/Enum members;
        # convert element-wise while keeping a tuple.
        return tuple(format_data_for_logging(item) for item in data)
    if isinstance(data, enum.Enum):
        return data.value
    if isinstance(data, Decimal):
        return str(data)
    return data
|
|
|
|
|
|
# Uniform error format for Pydantic validation failures.
class CustomValidationError(HTTPException):
    """
    A 422 ``HTTPException`` whose ``detail`` is a ``standard_response`` envelope.

    The envelope carries ``success=0``, the message "数据验证错误"
    ("data validation error") and ``meta={"error_details": errors}``.
    """

    def __init__(self, errors: List[Dict[str, Any]]):
        code = status.HTTP_422_UNPROCESSABLE_ENTITY
        envelope = standard_response(
            status_code=code,
            success=0,
            message="数据验证错误",
            meta={"error_details": errors},
        )
        super().__init__(status_code=code, detail=envelope)
|
|
|
|
|
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    Convert a ``RequestValidationError`` into the standard 422 error response.

    Each error from ``exc.errors()`` is flattened to a dict with these keys:
      * ``field``: the ``loc`` parts joined with "->"
      * ``message``: the error's ``msg``
      * ``type``: the error's ``type``
    The list ends up under ``meta.error_details`` of the envelope.

    The response is built and returned here rather than raising
    ``CustomValidationError`` from inside this handler. Re-catching an
    exception raised by a handler depends on an outer Starlette
    exception-handling layer and is version-dependent.

    Args:
        request: The incoming request (forwarded to the HTTP handler).
        exc: The validation error raised by FastAPI.

    Returns:
        A ``JSONResponse`` with status 422 and the standard error envelope.
    """
    errors = [
        {
            "field": "->".join(str(loc) for loc in error["loc"]),
            "message": error["msg"],
            "type": error["type"],
        }
        for error in exc.errors()
    ]
    # Reuse the HTTPException handler so the envelope is rendered in one place.
    return await http_exception_handler(request, CustomValidationError(errors=errors))
|
|
|
|
|
|
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Render an ``HTTPException`` as the standard error envelope.

    ``exc.detail`` may be:
      * a dict: its ``"message"`` (default "Error occurred") and ``"meta"``
        (default None) keys are used, so an envelope produced by
        ``standard_response`` round-trips;
      * anything else: ``str(detail)`` becomes the message and meta is None.

    The HTTP status and any ``exc.headers`` (e.g. ``WWW-Authenticate`` on a
    401) are carried over to the response.

    Args:
        request: The incoming request (unused).
        exc: The exception being handled.

    Returns:
        A ``JSONResponse`` whose status equals ``exc.status_code``.
    """
    if isinstance(exc.detail, dict):
        detail = exc.detail
    else:
        detail = {"message": str(exc.detail), "meta": None}

    return JSONResponse(
        status_code=exc.status_code,
        content=standard_response(
            status_code=exc.status_code,
            success=0,
            message=detail.get("message", "Error occurred"),
            data=None,
            meta=detail.get("meta", None),
        ),
        # Forward custom headers; FastAPI's default handler does the same.
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
|
|
def register_exception_handlers(app):
    """
    Install this module's exception handlers on ``app``.

    Registers:
      * ``RequestValidationError`` -> ``validation_exception_handler``
      * ``HTTPException`` (FastAPI's class) -> ``http_exception_handler``

    NOTE(review): errors raised by Starlette itself (e.g. routing 404/405) use
    Starlette's base HTTPException class and presumably do not reach
    ``http_exception_handler`` through this registration. Confirm whether
    that is intended.
    """
    handlers = (
        (RequestValidationError, validation_exception_handler),
        (HTTPException, http_exception_handler),
    )
    for exc_class, handler in handlers:
        app.add_exception_handler(exc_class, handler)
|
|
|
|
|
|
|
|
async def drop_matched_data(data: List[Any], db_names: str, table_name: str, match_on: Tuple[str, ...], db_fields: Tuple[str, ...] | None = None):
    """
    Delete existing rows whose composite key matches an item in ``data``.

    For every item, the values of the ``match_on`` fields form a key. Dicts
    are read with ``.get``; other objects with ``getattr``. Items where any
    key value is None or "" are skipped. Distinct keys are deleted in batches
    of 100 through ``db_delete``, using an OR of per-key AND filters built by
    ``build_safe_filter``.

    Args:
        data: New data items (dicts or objects with attributes).
        db_names: Database name(s), passed through to ``db_delete``.
        table_name: Target table name.
        match_on: Item fields that together identify a row,
            e.g. ("materialno", "matver").
        db_fields: Corresponding DB column names, e.g. ("MaterialNo", "MatVer").
            Defaults to ``match_on``. Must be the same length as ``match_on``.

    Raises:
        ValueError: If ``match_on`` is empty or ``db_fields`` and ``match_on``
            differ in length.
        Exception: Anything raised by ``db_delete`` is logged and re-raised.
    """
    from .db_operation import db_delete
    from globalobjects import logger as log_config
    from globalobjects.db_manager import build_safe_filter

    logger = log_config.get_logger(__name__)

    if not match_on:
        raise ValueError("match_on must contain at least one field")
    db_fields = db_fields or match_on
    # zip() would silently truncate a mismatched pair and delete on a partial
    # key, removing far more rows than intended.
    if len(db_fields) != len(match_on):
        raise ValueError(
            f"db_fields has {len(db_fields)} entries but match_on has {len(match_on)}"
        )

    # Collect distinct key combinations.
    unique_combinations = set()
    for item in data:
        if isinstance(item, dict):
            field_values = tuple(item.get(field) for field in match_on)
        else:
            field_values = tuple(getattr(item, field, None) for field in match_on)
        # Treat only None and "" as missing, so falsy-but-valid keys such as
        # 0 or False still match (same rule as mark_as_removing).
        if all(v is not None and v != "" for v in field_values):
            unique_combinations.add(field_values)

    if not unique_combinations:
        return

    # Delete in batches to keep each filter string bounded.
    batch_size = 100
    combinations_list = list(unique_combinations)
    for start in range(0, len(combinations_list), batch_size):
        batch = combinations_list[start:start + batch_size]
        batch_conditions = [
            build_safe_filter([(db_field, "=", value) for db_field, value in zip(db_fields, values)])
            for values in batch
        ]
        filter_string = " OR ".join(f"({cond})" for cond in batch_conditions)
        try:
            await db_delete(db_names=db_names, model_or_tablename=table_name, filter_string=filter_string)
        except Exception as e:
            logger.error(f"删除数据失败: {str(e)}")
            raise
|
|
|
|
|
|
async def mark_as_removing(
    model_class,
    table_name: str,
    drop: Literal["all", "matched"],
    data_list: List[Dict] | None = None,
    drop_fields: List[str] | None = None,
) -> int:
    """
    Mark staging-table rows as ``StagingStatus.REMOVING`` (staging-mode drop).

    Args:
        model_class: Tortoise ORM model class of the staging table.
        table_name: Table name; used only in log messages.
        drop: "all" marks every row. "matched" marks rows whose
            ``drop_fields`` values equal those of some item in ``data_list``.
        data_list: New data items, required for "matched". Mappings are read
            with ``.get``; other objects (e.g. models) with ``getattr``.
        drop_fields: Field names to match on, required for "matched".

    Returns:
        Number of rows marked (sum of the affected-row counts returned by
        ``update``). Returns 0 when nothing was marked, required inputs are
        missing, or ``drop`` is not a recognised mode.
    """
    from collections.abc import Mapping

    from globalobjects import logger as log_config
    from apps.data_opt.mds._base import StagingStatus

    logger = log_config.get_logger(__name__)

    if drop == "all":
        count = await model_class.all().update(_status=StagingStatus.REMOVING)
        logger.info(f"drop=all: 已将 {table_name} 全部 {count} 条记录标记为removing")
        return count

    if drop != "matched":
        logger.warning(f"drop={drop!r}: 不支持的drop模式,跳过")
        return 0

    if not data_list or not drop_fields:
        logger.warning("drop=matched: 缺少data_list或drop_fields,跳过")
        return 0

    # Collect distinct key combinations. Only None and "" count as missing.
    unique_combinations = set()
    for item in data_list:
        values = tuple(
            item.get(field) if isinstance(item, Mapping) else getattr(item, field, None)
            for field in drop_fields
        )
        if all(v is not None and v != "" for v in values):
            unique_combinations.add(values)

    if not unique_combinations:
        logger.info("drop=matched: 无有效匹配值,跳过")
        return 0

    total_marked = 0
    batch_size = 100
    combinations_list = list(unique_combinations)

    if len(drop_fields) == 1:
        # Single-field key: one UPDATE ... WHERE field IN (...) per batch
        # instead of one UPDATE per value.
        lookup = f"{drop_fields[0]}__in"
        for start in range(0, len(combinations_list), batch_size):
            batch_values = [values[0] for values in combinations_list[start:start + batch_size]]
            total_marked += await model_class.filter(**{lookup: batch_values}).update(
                _status=StagingStatus.REMOVING
            )
    else:
        # Composite key: keyword filters are AND-ed, so each combination needs
        # its own UPDATE.
        for values in combinations_list:
            condition = dict(zip(drop_fields, values))
            total_marked += await model_class.filter(**condition).update(
                _status=StagingStatus.REMOVING
            )

    logger.info(f"drop=matched: 已将 {table_name} 的 {total_marked} 条记录标记为removing (匹配{len(unique_combinations)}组)")
    return total_marked