mirror of
https://github.com/rnvm9wjdtj-bot/myaps_api.git
synced 2026-06-02 05:54:40 +00:00
fa5cecd6d1
- 安全: 修复鉴权失败返回码(HTTP 401/403替代200) - 安全: 新增SafeQueryBuilder封堵SQL注入入口 - 安全: 移除Pydantic json_encoders弃用配置 - 稳定: 统一后台任务托管与生命周期管理 - 稳定: 新增TaskManager统一管理后台任务 - 文档: 更新README.md与.env.example - 重构: routers.py使用安全SQL构建器替代字符串拼接
2335 lines
96 KiB
Python
2335 lines
96 KiB
Python
from typing import List, Dict, Any, Tuple, Optional, Union, Literal
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime
|
||
import time
|
||
import asyncio
|
||
import functools
|
||
import re
|
||
|
||
from tortoise import Tortoise
|
||
from tortoise.connection import connections
|
||
from tortoise.expressions import Q
|
||
from tortoise.transactions import in_transaction
|
||
from tortoise.exceptions import IntegrityError
|
||
|
||
from core.settings import MYAPS_DBSET_LIST
|
||
from globalobjects import logger as log_config
|
||
import os
|
||
|
||
# Log level is read from the environment; an unset or empty value falls back to INFO.
LOG_LEVEL = os.environ.get("LOG_LEVEL") or "INFO"

# Module-level logger created through the project's logging helper.
logger = log_config.get_logger(__name__, level=LOG_LEVEL)
|
||
|
||
|
||
def escape_sql_value(value: Any) -> str:
    """Render a Python value as a MySQL SQL literal.

    Use this only where driver-side parameter binding is not possible.

    Args:
        value: Value to render.
            ``None`` -> ``NULL``; ``bool`` -> ``1``/``0``;
            finite ``int``/``float`` -> the number text;
            ``datetime`` -> ``'YYYY-MM-DD HH:MM:SS'`` (microseconds are dropped);
            anything else -> ``str(value)`` escaped and wrapped in single quotes.

    Returns:
        The SQL literal text.

    Raises:
        ValueError: If ``value`` is a NaN or infinite float (no SQL literal exists).
    """
    import math  # local import so this block does not depend on a header change

    if value is None:
        return "NULL"

    # bool must be handled before int because bool is a subclass of int.
    if isinstance(value, bool):
        return "1" if value else "0"

    if isinstance(value, (int, float)):
        # str() of these would be the bare words nan/inf, which MySQL would parse
        # as column references.
        if isinstance(value, float) and not math.isfinite(value):
            raise ValueError(f"Cannot render non-finite float as SQL literal: {value!r}")
        return str(value)

    if isinstance(value, datetime):
        return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"

    # Backslash is an escape character inside MySQL string literals (default
    # sql_mode), so it must be escaped BEFORE the quotes are doubled. Otherwise
    # input like "\' OR 1=1 -- " becomes '\'' OR 1=1 -- ' and the doubled quote
    # terminates the literal again (SQL injection). NUL bytes are escaped as \0.
    escaped = (
        str(value)
        .replace("\\", "\\\\")
        .replace("\0", "\\0")
        .replace("'", "''")  # SQL-standard quote doubling
    )
    return f"'{escaped}'"
|
||
|
||
|
||
def validate_identifier(identifier: str) -> str:
    """Validate a SQL identifier (table or column name) and quote it with backticks.

    Args:
        identifier: Raw identifier; surrounding whitespace is stripped.

    Returns:
        The identifier wrapped in backticks, e.g. ``name`` -> ``\`name\```.

    Raises:
        ValueError: If the identifier is empty or contains a character sequence
            that could break out of the quoted identifier or start a comment or
            a new statement.
    """
    identifier = identifier.strip()

    if not identifier:
        # "``" is not a valid identifier.
        raise ValueError("Invalid identifier: empty")

    # The backtick is the quoting character itself: allowing it would let input
    # such as "id` = 1 OR `id" close the quoted name and inject SQL.
    dangerous_chars = ["`", "'", '"', ';', '--', '/*', '*/', '\x00', '\n', '\r']
    for char in dangerous_chars:
        if char in identifier:
            raise ValueError(f"Invalid identifier: contains dangerous character '{char}'")

    return f"`{identifier}`"
|
||
|
||
|
||
def build_safe_condition(field: str, operator: str, value: Any) -> str:
    """Build one SQL condition ``<field> <operator> <value>`` with quoting/escaping.

    Args:
        field: Column name, validated and backtick-quoted by ``validate_identifier``.
        operator: One of ``=, !=, <>, >, <, >=, <=, IS, IS NOT, LIKE, NOT LIKE,
            IN, NOT IN`` (case-insensitive; internal whitespace is normalized).
        value: Value rendered with ``escape_sql_value``.
            For ``IN``/``NOT IN`` it must be a list, tuple or set.
            For ``IS``/``IS NOT`` it must be ``None``.

    Returns:
        The condition text. Special cases:
            An empty ``IN`` set yields ``1 = 0`` (matches nothing).
            An empty ``NOT IN`` set yields ``1 = 1`` (matches everything),
            because ``x IN ()`` is a MySQL syntax error.
            ``=``/``!=``/``<>`` with ``None`` become ``IS NULL``/``IS NOT NULL``.

    Raises:
        ValueError: For an invalid identifier, an unsupported operator, or a
            value of the wrong kind for the operator.
    """
    safe_field = validate_identifier(field)
    # Collapse internal whitespace so "is  not" and "IS NOT" are treated alike.
    operator = " ".join(operator.upper().split())

    if operator in ("IN", "NOT IN"):
        if not isinstance(value, (list, tuple, set, frozenset)):
            raise ValueError(f"{operator} operator requires a list, tuple or set of values")
        if not value:
            return "1 = 0" if operator == "IN" else "1 = 1"
        escaped_values = ", ".join(escape_sql_value(v) for v in value)
        return f"{safe_field} {operator} ({escaped_values})"

    if operator in ("LIKE", "NOT LIKE"):
        return f"{safe_field} {operator} {escape_sql_value(value)}"

    valid_operators = ('=', '!=', '<>', '>', '<', '>=', '<=', 'IS', 'IS NOT')
    if operator not in valid_operators:
        raise ValueError(f"Invalid operator: {operator}")

    # "col = NULL" is never true in SQL; comparisons with NULL need IS / IS NOT.
    if value is None and operator in ('=', '!=', '<>'):
        operator = 'IS' if operator == '=' else 'IS NOT'

    if operator in ('IS', 'IS NOT'):
        if value is None:
            return f"{safe_field} {operator} NULL"
        raise ValueError(f"{operator} operator only accepts None value")

    return f"{safe_field} {operator} {escape_sql_value(value)}"
|
||
|
||
|
||
def build_safe_filter(conditions: List[Tuple[str, str, Any]], logic: str = "AND") -> str:
    """Join several ``(field, operator, value)`` conditions into one WHERE expression.

    Args:
        conditions: Condition tuples, each rendered by ``build_safe_condition``.
        logic: Connector placed between conditions: ``AND`` or ``OR``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The combined expression without the ``WHERE`` keyword, or ``""`` when
        ``conditions`` is empty.

    Raises:
        ValueError: If ``logic`` is not AND/OR or any condition is invalid.
    """
    if not conditions:
        return ""

    connector = logic.upper().strip()
    if connector not in ("AND", "OR"):
        raise ValueError(f"Invalid logic operator: {connector}")

    rendered = [build_safe_condition(*condition) for condition in conditions]
    separator = " " + connector + " "
    return separator.join(rendered)
|
||
|
||
|
||
def build_safe_order_by(fields: List[Tuple[str, str]]) -> str:
    """Build the body of an ORDER BY clause (without the keywords).

    Args:
        fields: ``(column, direction)`` pairs. Each column goes through
            ``validate_identifier``. Each direction must be ``ASC`` or ``DESC``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        Comma-separated ``\`column\` DIRECTION`` items, or ``""`` if ``fields`` is empty.

    Raises:
        ValueError: For an invalid column name or direction.
    """
    if not fields:
        return ""

    def render(column: str, way: str) -> str:
        quoted = validate_identifier(column)
        normalized = way.upper().strip()
        if normalized not in ("ASC", "DESC"):
            raise ValueError(f"Invalid order direction: {normalized}")
        return f"{quoted} {normalized}"

    return ", ".join([render(column, way) for column, way in fields])
|
||
|
||
|
||
def build_safe_select(fields: List[str]) -> str:
    """Build the column list of a SELECT statement.

    Args:
        fields: Column names, each validated and backtick-quoted by
            ``validate_identifier``. A bare ``"*"`` entry is kept as the
            all-columns wildcard.

    Returns:
        Comma-separated column list, or ``"*"`` if ``fields`` is empty.

    Raises:
        ValueError: If a column name is invalid.
    """
    if not fields:
        return "*"

    # Quoting "*" would produce `*`, i.e. a column literally named "*".
    return ", ".join(
        "*" if field.strip() == "*" else validate_identifier(field)
        for field in fields
    )
|
||
|
||
|
||
class SafeQueryBuilder:
    """Chainable builder for simple SELECT / COUNT / DELETE statements on one table.

    Identifiers are validated and backtick-quoted with ``validate_identifier``.
    Values are rendered with ``escape_sql_value`` via the ``build_safe_*`` helpers.
    Conditions are stored as ``(field, operator, value)`` tuples, validated when
    SQL is built, and always combined with AND.

    Example::

        sql = (SafeQueryBuilder("orders")
               .where("status", "=", "paid")
               .order_by("id", "DESC")
               .limit(10)
               .build_select_sql())
    """

    # MySQL does not accept OFFSET without LIMIT; its manual recommends this
    # maximum row count to mean "all remaining rows".
    _NO_LIMIT = 18446744073709551615

    def __init__(self, table_name: str):
        """Create a builder for ``table_name``.

        Raises:
            ValueError: If ``table_name`` is not a valid identifier.
        """
        self._table = validate_identifier(table_name)
        self._select_fields = "*"
        self._conditions: List[Tuple[str, str, Any]] = []
        self._order_fields: List[Tuple[str, str]] = []
        self._limit: Optional[int] = None
        self._offset: Optional[int] = None

    def select(self, *fields: str) -> 'SafeQueryBuilder':
        """Set the SELECT list (validated immediately).

        Calling with no arguments keeps the current list.
        """
        if fields:
            self._select_fields = build_safe_select(list(fields))
        return self

    def where(self, field: str, operator: str, value: Any) -> 'SafeQueryBuilder':
        """Add a ``field <operator> value`` condition (see ``build_safe_condition``)."""
        self._conditions.append((field, operator, value))
        return self

    def where_in(self, field: str, values: List[Any]) -> 'SafeQueryBuilder':
        """Add a ``field IN (values...)`` condition."""
        self._conditions.append((field, "IN", values))
        return self

    def where_like(self, field: str, pattern: str) -> 'SafeQueryBuilder':
        """Add a ``field LIKE pattern`` condition (wildcards in ``pattern`` are kept)."""
        self._conditions.append((field, "LIKE", pattern))
        return self

    def where_between(self, field: str, start: Any, end: Any) -> 'SafeQueryBuilder':
        """Add an inclusive range condition: ``field >= start AND field <= end``.

        Raises:
            ValueError: If ``field`` is not a valid identifier.
        """
        # Validate now so a bad field name fails at the call site.
        validate_identifier(field)
        # Store two ordinary conditions. The previous approach stored a
        # pre-rendered SQL fragment as the "field", which build_safe_condition
        # then re-validated as a column name, producing invalid SQL.
        self._conditions.append((field, ">=", start))
        self._conditions.append((field, "<=", end))
        return self

    def order_by(self, field: str, direction: str = "ASC") -> 'SafeQueryBuilder':
        """Append a sort key; ``direction`` is ``ASC`` or ``DESC``."""
        self._order_fields.append((field, direction))
        return self

    def limit(self, count: int) -> 'SafeQueryBuilder':
        """Set LIMIT.

        Values <= 0 are ignored, and any previously set limit is kept. ``count``
        is converted with ``int()`` so only an integer reaches the SQL text.
        """
        count = int(count)
        if count > 0:
            self._limit = count
        return self

    def offset(self, count: int) -> 'SafeQueryBuilder':
        """Set OFFSET.

        Negative values are ignored. ``count`` is converted with ``int()``.
        """
        count = int(count)
        if count >= 0:
            self._offset = count
        return self

    def _where_clause(self) -> str:
        """Return ``" WHERE ..."`` for the stored conditions, or ``""`` if there are none."""
        if not self._conditions:
            return ""
        return f" WHERE {build_safe_filter(self._conditions)}"

    def build_select_sql(self) -> str:
        """Build the SELECT statement."""
        sql = f"SELECT {self._select_fields} FROM {self._table}{self._where_clause()}"

        if self._order_fields:
            sql += f" ORDER BY {build_safe_order_by(self._order_fields)}"

        if self._limit is not None:
            sql += f" LIMIT {self._limit}"
        elif self._offset is not None:
            # OFFSET alone is a MySQL syntax error.
            sql += f" LIMIT {self._NO_LIMIT}"

        if self._offset is not None:
            sql += f" OFFSET {self._offset}"

        return sql

    def build_count_sql(self) -> str:
        """Build ``SELECT COUNT(*) as total`` with the stored conditions."""
        return f"SELECT COUNT(*) as total FROM {self._table}{self._where_clause()}"

    def build_delete_sql(self) -> str:
        """Build the DELETE statement.

        Raises:
            ValueError: If no conditions are set (refuses to delete every row).
        """
        if not self._conditions:
            raise ValueError("DELETE operation requires WHERE conditions for safety")
        return f"DELETE FROM {self._table}{self._where_clause()}"
|
||
|
||
|
||
def dict_to_lower_keys(d: dict) -> dict:
    """Return a new dict whose keys are the lower-cased keys of ``d``.

    Keys that differ only by case collapse to a single entry. The value that
    comes last in ``d``'s iteration order wins. Values are not copied.
    """
    lowered = {}
    for key, val in d.items():
        lowered[key.lower()] = val
    return lowered
|
||
|
||
|
||
class DbManagerError(Exception):
    """Base class for database-manager errors, carrying context about the failure.

    Attributes:
        message: Error message (also the exception's ``str()``).
        operation: Name of the failing operation, if known.
        table: Table involved, if known.
        connection: Connection name, if known.
    """

    def __init__(self, message, operation=None, table=None, connection=None):
        super().__init__(message)
        self.message = message
        self.operation = operation
        self.table = table
        self.connection = connection

    def to_dict(self):
        """Return the error type name and context attributes as a plain dict."""
        return dict(
            error=type(self).__name__,
            message=self.message,
            operation=self.operation,
            table=self.table,
            connection=self.connection,
        )
|
||
|
||
|
||
class DbConnectionError(DbManagerError):
    """Database connection error."""
|
||
|
||
|
||
class DbQueryError(DbManagerError):
    """Database query error."""
|
||
|
||
|
||
class DbTransactionError(DbManagerError):
    """Database transaction error."""
|
||
|
||
|
||
class DbDeadlockError(DbManagerError):
    """Database deadlock error."""
|
||
|
||
|
||
class DbCircuitBreakerError(DbManagerError):
    """Database circuit-breaker error (the request was rejected while the breaker is open)."""
|
||
|
||
|
||
class DeadlockCircuitBreaker:
    """Circuit breaker that trips after too many deadlocks within a time window.

    Deadlocks are counted in a fixed (tumbling) window of ``window_seconds``.
    When the count reaches ``threshold``, the breaker opens and ``is_available()``
    returns False for ``cooldown_seconds``.

    After the cooldown the breaker becomes half-open. In that state
    ``is_available()`` returns True again (requests are not limited). Once
    ``half_open_attempts`` successes are recorded, the breaker closes. Any
    deadlock recorded while half-open re-opens it.

    States:
        closed    -> normal operation
        open      -> tripped; requests are rejected
        half_open -> probing for recovery

    All timing uses ``time.monotonic()``, so wall-clock changes (NTP steps,
    manual adjustments) cannot shorten or extend the window or the cooldown.
    """

    STATE_CLOSED = "closed"
    STATE_OPEN = "open"
    STATE_HALF_OPEN = "half_open"

    def __init__(
        self,
        threshold: int = 5,
        window_seconds: int = 60,
        cooldown_seconds: int = 30,
        half_open_attempts: int = 2
    ):
        """Create a breaker in the closed state.

        Args:
            threshold: Deadlock count within one window that trips the breaker
                (it trips when the count reaches this value).
            window_seconds: Length of the counting window in seconds.
            cooldown_seconds: How long the breaker stays open before going half-open.
            half_open_attempts: Successes required while half-open to close again.
        """
        self.threshold = threshold
        self.window_seconds = window_seconds
        self.cooldown_seconds = cooldown_seconds
        self.half_open_attempts = half_open_attempts

        self._state = self.STATE_CLOSED
        self._deadlock_count = 0

        # Monotonic timestamps: only differences between them are meaningful.
        self._window_start_time = time.monotonic()
        self._open_time: Optional[float] = None

        # Counters for the half-open phase.
        self._half_open_success_count = 0
        self._half_open_failure_count = 0

    @property
    def state(self):
        """Current state, after applying any time-based transition."""
        self._check_state_transition()
        return self._state

    def _check_state_transition(self):
        """Apply time-based transitions.

        open -> half_open after the cooldown; closed: reset an expired window.
        """
        now = time.monotonic()

        if self._state == self.STATE_OPEN and self._open_time is not None:
            if now - self._open_time >= self.cooldown_seconds:
                self._transition_to_half_open()

        if self._state == self.STATE_CLOSED:
            if now - self._window_start_time >= self.window_seconds:
                self._reset_window()

    def _transition_to_open(self):
        """Trip the breaker (open) and start the cooldown."""
        self._state = self.STATE_OPEN
        self._open_time = time.monotonic()
        self._deadlock_count = 0
        logger.warning(
            "CircuitBreaker",
            "TRIGGERED",
            f"死锁熔断已触发,将在 {self.cooldown_seconds} 秒后尝试恢复"
        )

    def _transition_to_half_open(self):
        """Enter the half-open state and reset the half-open counters."""
        self._state = self.STATE_HALF_OPEN
        self._half_open_success_count = 0
        self._half_open_failure_count = 0
        logger.warning(
            "CircuitBreaker",
            "HALF_OPEN",
            "熔断进入半开状态,开始尝试恢复"
        )

    def _transition_to_closed(self):
        """Return to normal operation with a fresh counting window."""
        self._state = self.STATE_CLOSED
        self._deadlock_count = 0
        self._window_start_time = time.monotonic()
        self._open_time = None
        logger.info(
            "CircuitBreaker",
            "RESET",
            "熔断已恢复,系统正常运行"
        )

    def _reset_window(self):
        """Start a new counting window with a zero deadlock count."""
        self._deadlock_count = 0
        self._window_start_time = time.monotonic()

    def record_deadlock(self):
        """Record one deadlock.

        Returns:
            True if this call opened the breaker, otherwise False.
        """
        self._check_state_transition()

        if self._state == self.STATE_CLOSED:
            self._deadlock_count += 1
            if self._deadlock_count >= self.threshold:
                self._transition_to_open()
                return True

        elif self._state == self.STATE_HALF_OPEN:
            # Any deadlock while probing re-opens the breaker immediately.
            self._half_open_failure_count += 1
            self._transition_to_open()
            return True

        return False

    def record_success(self):
        """Record one successful operation.

        Only has an effect in the half-open state, where enough successes
        close the breaker.
        """
        self._check_state_transition()

        if self._state == self.STATE_HALF_OPEN:
            self._half_open_success_count += 1
            if self._half_open_success_count >= self.half_open_attempts:
                self._transition_to_closed()

    def is_available(self) -> bool:
        """Return False only while the breaker is open."""
        self._check_state_transition()
        return self._state != self.STATE_OPEN

    def get_wait_time(self) -> float:
        """Seconds left until the open breaker may go half-open (0.0 if not open)."""
        if self._state != self.STATE_OPEN or self._open_time is None:
            return 0.0
        elapsed = time.monotonic() - self._open_time
        return max(0.0, self.cooldown_seconds - elapsed)

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the breaker's state and counters."""
        self._check_state_transition()
        return {
            "state": self._state,
            "deadlock_count": self._deadlock_count,
            "threshold": self.threshold,
            "window_seconds": self.window_seconds,
            "cooldown_remaining": self.get_wait_time(),
            "half_open_success_count": self._half_open_success_count,
            "half_open_failure_count": self._half_open_failure_count
        }
|
||
|
||
|
||
def with_transaction(func):
    """Decorator: optionally run an async DbManager method inside a transaction.

    The wrapped method may receive a ``use_transaction`` keyword. It is removed
    from ``kwargs`` before the call. ``None`` (or omitting it) falls back to
    ``self.use_transaction``. When the effective value is truthy, the call runs
    inside ``in_transaction(self.connection_name)``.

    Error handling:
        * ``bulk_upsert`` without a transaction: any exception is logged and
          returned as a failure dict instead of being raised.
        * ``DbManagerError`` instances (e.g. raised by ``handle_db_errors``) are
          re-raised unchanged, so their type and message are preserved.
        * Any other exception is classified by type name and message and
          re-raised as ``DbDeadlockError``, ``DbConnectionError`` or
          ``DbQueryError``, chained to the original.

    Args:
        func: Async method to decorate, called as ``func(self, *args, **kwargs)``.

    Returns:
        The wrapped coroutine function.
    """
    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        transaction_mode = kwargs.pop('use_transaction', None)
        if transaction_mode is None:
            # Decorated methods declare use_transaction=None as "use the instance
            # default"; an explicit None must not silently disable the transaction.
            transaction_mode = self.use_transaction

        try:
            if transaction_mode:
                async with in_transaction(self.connection_name):
                    return await func(self, *args, **kwargs)
            return await func(self, *args, **kwargs)
        except Exception as e:
            if func.__name__ == 'bulk_upsert' and not transaction_mode:
                # Non-transactional bulk_upsert reports failure as a result dict.
                db_table = None
                if 'model_class' in kwargs:
                    db_table = getattr(kwargs['model_class']._meta, 'db_table', None)
                logger.fail("批量upsert", f"{db_table}@{self.connection_name}", str(e))
                data_list = kwargs.get('data_list', [])
                return {
                    "success": False,
                    "error": str(e),
                    "total_records": len(data_list),
                    "inserted": 0,
                    "updated": 0
                }

            if isinstance(e, DbManagerError):
                # Already classified (e.g. by handle_db_errors). Re-wrapping would
                # double the message prefix and could turn a circuit-breaker or
                # integrity error into a generic DbQueryError.
                raise

            operation = func.__name__
            table = None
            if 'table_name' in kwargs:
                table = kwargs['table_name']
            elif 'model_class' in kwargs:
                table = getattr(kwargs['model_class']._meta, 'db_table', None)

            message = str(e)
            error_str = message.upper()
            type_name = str(type(e))

            is_deadlock = 'DEADLOCK' in error_str or '1213' in message or '死锁' in message
            is_connection_closed = "Cannot acquire connection after closing pool" in message
            is_connection_error = "Connection" in type_name
            # An OperationalError counts as a connection problem only if its
            # message mentions connection-related terms.
            is_operational_connection_error = "OperationalError" in type_name and any(
                token in error_str
                for token in ('CONNECTION', 'CONNECT', 'POOL', 'TIMEOUT',
                              'NETWORK', 'HOST', 'PORT', 'SERVER')
            )

            logger.fail(operation, f"@{self.connection_name}", message)
            if is_deadlock:
                raise DbDeadlockError(f"数据库死锁: {message}", operation, table, self.connection_name) from e
            if is_connection_error or is_operational_connection_error or is_connection_closed:
                raise DbConnectionError(f"数据库连接错误: {message}", operation, table, self.connection_name) from e
            raise DbQueryError(f"数据库操作错误: {message}", operation, table, self.connection_name) from e

    return wrapper
|
||
|
||
|
||
def handle_db_errors(max_retries: int = 3):
    """Decorator factory: retry transient DB errors with backoff, guarded by a circuit breaker.

    The decorated coroutine is called as ``func(self, *args, **kwargs)``.

    Circuit breaker:
        Before every attempt, if ``self.circuit_breaker`` is set and its
        ``is_available()`` is False, ``DbCircuitBreakerError`` is raised.
        A successful attempt calls ``circuit_breaker.record_success()``.
        A deadlock calls ``record_deadlock()``.

    Retryable errors:
        An error is retryable when its type name or message indicates a
        deadlock, a connection/pool/timeout/network problem, or it mentions
        ``NoneType``. These are retried up to ``max_retries`` times.
        Each retry sleeps ``base * 2 ** (n - 1)`` seconds, capped at 20.
        ``base`` is 3 s for a closed pool, 2 s for a deadlock, 1 s otherwise.
        Before non-deadlock retries, ``self.refresh_connection`` is awaited;
        if that refresh fails, the failure is logged and the retry still happens.

    Final failure:
        Non-retryable or exhausted errors are re-raised (chained) as
        ``DbDeadlockError``, ``DbConnectionError``, ``DbTransactionError``
        (for ``IntegrityError``) or ``DbQueryError``.
        ``DbCircuitBreakerError`` raised by the wrapped call propagates unchanged.

    Args:
        max_retries: Maximum retries after the first attempt; negative values
            behave like 0 (a single attempt).

    Returns:
        The decorator.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            retry_count = 0

            # Every iteration returns, raises or retries. `while True` (rather
            # than `retry_count <= max_retries`) guarantees at least one attempt
            # even for a negative max_retries, instead of silently returning None.
            while True:
                circuit_breaker = getattr(self, 'circuit_breaker', None)
                if circuit_breaker and not circuit_breaker.is_available():
                    wait_time = circuit_breaker.get_wait_time()
                    logger.warning(
                        func.__name__,
                        f"@{self.connection_name}",
                        f"熔断器已触发,拒绝请求,剩余熔断时间: {wait_time:.1f}秒"
                    )
                    raise DbCircuitBreakerError(
                        f"数据库熔断中,剩余{wait_time:.1f}秒后恢复",
                        operation=func.__name__,
                        connection=self.connection_name
                    )

                try:
                    result = await func(self, *args, **kwargs)
                except DbCircuitBreakerError:
                    # A nested call was rejected by the breaker: never retry it,
                    # and keep its type instead of re-wrapping it as DbQueryError.
                    raise
                except Exception as e:
                    message = str(e)
                    error_str = message.upper()
                    type_name = str(type(e))

                    is_deadlock = 'DEADLOCK' in error_str or '1213' in message or '死锁' in message
                    is_connection_closed = "Cannot acquire connection after closing pool" in message
                    is_none_type_error = "NoneType" in message
                    is_connection_error = "Connection" in type_name
                    is_timeout_error = "Timeout" in type_name
                    is_network_error = "Network" in type_name or "网络" in message
                    is_pool_error = "Pool" in type_name
                    # An OperationalError is treated as a connection problem only
                    # if its message mentions connection-related terms.
                    is_operational_connection_error = "OperationalError" in type_name and any(
                        token in error_str
                        for token in ('CONNECTION', 'CONNECT', 'POOL', 'TIMEOUT',
                                      'NETWORK', 'HOST', 'PORT', 'SERVER')
                    )

                    is_retryable = (is_deadlock or is_operational_connection_error or is_connection_closed or
                                    is_none_type_error or is_connection_error or is_timeout_error or
                                    is_network_error or is_pool_error)

                    # Feed deadlocks to the breaker; it may trip on this one.
                    if is_deadlock and circuit_breaker and circuit_breaker.record_deadlock():
                        logger.warning(
                            func.__name__,
                            f"@{self.connection_name}",
                            "熔断器已触发,立即拒绝后续请求"
                        )

                    if is_retryable and retry_count < max_retries:
                        retry_count += 1

                        if is_connection_closed:
                            base_delay, error_type = 3.0, "连接池关闭"
                        elif is_deadlock:
                            base_delay, error_type = 2.0, "死锁"
                        else:
                            base_delay = 1.0
                            if is_timeout_error:
                                error_type = "超时错误"
                            elif is_network_error:
                                error_type = "网络错误"
                            else:
                                error_type = "连接错误"

                        # Exponential backoff, capped at 20 seconds.
                        current_delay = min(base_delay * (2 ** (retry_count - 1)), 20.0)

                        logger.warning_msg(
                            func.__name__,
                            f"@{self.connection_name} 检测到{error_type},第{retry_count}次重试中...",
                            f"等待{current_delay:.1f}秒后重试"
                        )

                        # A deadlock is not a connection problem, so no refresh.
                        if not is_deadlock:
                            try:
                                await self.refresh_connection(fast_mode=is_connection_closed)
                            except Exception as refresh_error:
                                # Do not let a failed refresh escape unclassified;
                                # the next attempt will surface any persistent problem.
                                logger.warning(f"刷新连接失败,继续重试: {self.connection_name}: {refresh_error}")

                        await asyncio.sleep(current_delay)
                        continue

                    # Not retryable, or retries exhausted: raise a typed error.
                    operation = func.__name__
                    table = None
                    if 'table_name' in kwargs:
                        table = kwargs['table_name']
                    elif 'model_class' in kwargs:
                        table = getattr(kwargs['model_class']._meta, 'db_table', None)

                    logger.fail(operation, f"@{self.connection_name}", message)
                    if is_deadlock:
                        raise DbDeadlockError(f"数据库死锁: {message}", operation, table, self.connection_name) from e
                    if is_connection_error or is_operational_connection_error or is_connection_closed:
                        raise DbConnectionError(f"数据库连接错误: {message}", operation, table, self.connection_name) from e
                    if isinstance(e, IntegrityError):
                        raise DbTransactionError(f"数据完整性错误: {message}", operation, table, self.connection_name) from e
                    raise DbQueryError(f"数据库操作错误: {message}", operation, table, self.connection_name) from e
                else:
                    # Success: lets a half-open breaker count towards recovery.
                    if circuit_breaker:
                        circuit_breaker.record_success()
                    return result

        return wrapper

    return decorator
|
||
|
||
|
||
class DbManager:
|
||
"""数据库操作管理器"""
|
||
|
||
# 类级别的熔断器,每个数据库连接共享一个熔断器
|
||
_circuit_breakers: Dict[str, DeadlockCircuitBreaker] = {}
|
||
|
||
def __init__(self, connection_name: str, batch_size: int = 1000, use_transaction: bool = True):
    """Create a manager bound to one Tortoise database connection.

    Args:
        connection_name: Name of the Tortoise connection to use.
        batch_size: Batch size; inputs larger than this are processed in batches.
        use_transaction: Default for whether decorated operations run in a transaction.
    """
    self.connection_name = connection_name
    self.batch_size = batch_size
    self.use_transaction = use_transaction
    self.stats = dict(total_processed=0, batches_executed=0, last_execution_time=None)

    # State for dynamic batch-size tuning.
    self.min_batch_size, self.max_batch_size = 500, 5000
    self.optimal_batch_size = batch_size
    self.batch_size_history = []
    self.batch_size_adjustment_interval = 10  # adjust once every 10 operations

    # Make sure the circuit breaker shared by this connection name exists.
    self._get_or_create_circuit_breaker()
|
||
|
||
def _get_or_create_circuit_breaker(self):
    """Register a DeadlockCircuitBreaker for this connection name if none exists yet.

    Breakers are stored in the class-level ``DbManager._circuit_breakers`` dict,
    so every manager using the same connection name shares one breaker.
    """
    registry = DbManager._circuit_breakers
    if self.connection_name in registry:
        return
    registry[self.connection_name] = DeadlockCircuitBreaker(
        threshold=5,           # trip after 5 deadlocks ...
        window_seconds=60,     # ... within a 60-second window
        cooldown_seconds=30,   # stay open for 30 seconds
        half_open_attempts=2,  # 2 successes while half-open close it again
    )
|
||
|
||
@property
def circuit_breaker(self) -> Optional[DeadlockCircuitBreaker]:
    """Breaker shared by all managers on this connection name, or None if none is registered."""
    breakers = DbManager._circuit_breakers
    return breakers.get(self.connection_name)
|
||
|
||
async def _get_valid_connection(self):
    """Return a Tortoise connection for ``self.connection_name`` that answers ``SELECT 1``.

    Up to three attempts are made. An attempt is abandoned, and
    ``refresh_connection(fast_mode=True)`` is awaited, when the connection:
        * is None,
        * reports ``closed``,
        * lacks ``execute_query``, or
        * has ``_execute_command`` set to None.
    If obtaining or probing the connection raises, the method waits one second
    and refreshes before the next attempt.

    Returns:
        A connection object whose ``SELECT 1`` probe succeeded.

    Raises:
        DbConnectionError: If no usable connection is obtained after all attempts
            (previously the method could fall through and return None).
    """
    max_attempts = 3

    for attempt in range(max_attempts):
        try:
            # Tortoise is initialized by the application lifespan; never re-init here.
            if not hasattr(Tortoise, '_inited') or not Tortoise._inited:
                raise RuntimeError("Tortoise ORM未初始化,请等待应用启动完成")

            conn = Tortoise.get_connection(self.connection_name)

            # Determine why this connection is unusable, if it is.
            if conn is None:
                problem = "连接为None"
            elif hasattr(conn, 'closed') and conn.closed:
                problem = "连接已关闭"
            elif not hasattr(conn, 'execute_query'):
                problem = "连接不支持execute_query"
            elif hasattr(conn, '_execute_command') and conn._execute_command is None:
                # Guards against later 'NoneType' errors during execution.
                problem = "连接的_execute_command为None"
            else:
                problem = None

            if problem is not None:
                logger.warning(f"{problem},尝试刷新连接: {self.connection_name}")
                await self.refresh_connection(fast_mode=True)
                continue

            # Probe the connection with a trivial query.
            await conn.execute_query("SELECT 1")
            return conn

        except Exception as e:
            logger.warning(f"获取连接时出错 (尝试 {attempt+1}/{max_attempts}): {e}")
            if attempt < max_attempts - 1:
                await asyncio.sleep(1)
                await self.refresh_connection(fast_mode=True)
            else:
                raise DbConnectionError(
                    f"获取数据库连接失败:{str(e)}",
                    operation="_get_valid_connection",
                    connection=self.connection_name,
                ) from e

    # Every attempt ended with an unusable connection (the `continue` paths).
    raise DbConnectionError(
        "获取数据库连接失败:多次刷新后连接仍不可用",
        operation="_get_valid_connection",
        connection=self.connection_name,
    )
|
||
|
||
|
||
@asynccontextmanager
async def get_connection(self):
    """Async context manager yielding the Tortoise connection for this manager.

    The connection is not closed on exit; this relies on Tortoise to manage the
    connection's lifecycle.

    Yields:
        The object returned by ``Tortoise.get_connection(self.connection_name)``.
    """
    yield Tortoise.get_connection(self.connection_name)
|
||
|
||
|
||
@classmethod
def _get_conflict_fields(cls, model_class, conflict_fields: Optional[Tuple[str, ...]]=None) -> Tuple[str, ...]:
    """Return the fields used to detect upsert conflicts for ``model_class``.

    Resolution order:
        1. ``conflict_fields``, returned unchanged when given.
        2. The first entry of ``model_class._meta.unique_together``.
        3. The primary-key attribute name from ``_meta.pk_attr``.

    Args:
        model_class: Tortoise model class.
        conflict_fields: Explicit conflict fields (optional).

    Returns:
        Tuple of conflict field names.

    Raises:
        ValueError: If neither a unique_together entry nor a pk attribute exists.
    """
    if conflict_fields is not None:
        return conflict_fields

    meta = model_class._meta
    together = getattr(meta, 'unique_together', [])
    if together:
        return together[0]

    pk_name = getattr(meta, 'pk_attr', None)
    if not pk_name:
        raise ValueError(f"模型 {model_class.__name__} 没有定义主键或唯一约束")
    return (pk_name,)
|
||
|
||
|
||
@with_transaction
@handle_db_errors(max_retries=5)
async def call_stored_procedure(
    self,
    procedure_name: str,
    params_list: Optional[List[List[Any]]] = None,
    use_transaction: Optional[bool] = None
) -> Dict[str, Any]:
    """Call a stored procedure once per parameter list (deadlocks are retried).

    Args:
        procedure_name: Procedure name. It is placed inside backticks in the
            CALL statement, so it must be non-empty and must not contain a backtick.
        params_list: One list of argument values per call. Defaults to ``[[]]``
            (a single call without arguments). Values are passed to
            ``execute_query`` as bound parameters.
        use_transaction: Consumed by ``with_transaction``; None means the
            instance default.

    Returns:
        Dict with the following keys:
            ``success``, ``procedure_name``, ``execution_time`` (seconds),
            ``total_calls``;
            ``affected_rows``: the sum of the first element of each
            ``execute_query`` result;
            ``results``: the raw result of each call.

    Raises:
        DbManagerError: If validation or execution fails (raised via the decorators).
    """
    # Only values can be bound; the name is interpolated into the SQL text, so a
    # backtick would let the caller close the quoted identifier and inject SQL.
    if not procedure_name or "`" in procedure_name:
        raise ValueError(f"非法存储过程名: {procedure_name!r}")

    if params_list is None:
        params_list = [[]]

    start_time = datetime.now()

    conn = await self._get_valid_connection()
    affect_count = 0
    results = []

    for params in params_list:
        placeholders = ", ".join(["%s"] * len(params))
        result = await conn.execute_query(f'CALL `{procedure_name}`({placeholders})', params)
        affect_count += result[0] if result else 0
        results.append(result)

    execution_time = (datetime.now() - start_time).total_seconds()

    self.stats['total_processed'] += len(params_list)
    self.stats['batches_executed'] += len(params_list)
    self.stats['last_execution_time'] = execution_time

    response = {
        "success": True,
        "procedure_name": procedure_name,
        "execution_time": execution_time,
        "total_calls": len(params_list),
        "affected_rows": affect_count,
        "results": results
    }

    logger.success("存储过程调用", f"{procedure_name}@{self.connection_name}", f"执行时间{execution_time:.3f}秒,影响记录数{affect_count}条")
    return response
|
||
|
||
|
||
@handle_db_errors(max_retries=3)
async def query_data(self, table_name: str, select_fields: str = '*', filter_string: str = '', order_string: str = '', page_size: int = 1000, page_index: int = 0) -> Dict[str, Any]:
    """Query rows from ``table_name``: COUNT first, then LIMIT/OFFSET pages.

    ``select_fields``, ``filter_string`` and ``order_string`` are inserted into
    the SQL verbatim. They must come from trusted code (e.g. the ``build_safe_*``
    helpers), never directly from user input.

    Args:
        table_name: Table to query. It is quoted with backticks, so it must be
            non-empty and must not contain a backtick.
        select_fields: SELECT list (default ``*``).
        filter_string: WHERE expression without the keyword (optional).
        order_string: ORDER BY expression without the keywords (optional).
        page_size: Rows per page, clamped to 1..1000.
        page_index: 0 (or negative) fetches all rows page by page; N > 0 fetches
            only page N (1-based).

    Returns:
        Dict with the following keys:
            ``success``, ``table_name``, ``filter``, ``order``,
            ``execution_time``, ``total``, ``page_size``, ``page_index``,
            ``performance_metrics``;
            ``data``: the rows, with lower-cased keys.

    Raises:
        DbManagerError: If the query fails (raised via ``handle_db_errors``).
    """
    if not table_name or "`" in table_name:
        raise ValueError(f"非法表名: {table_name!r}")

    # Clamp to 1..1000. A page size <= 0 would never advance the OFFSET in the
    # fetch-all loop (infinite loop) and would divide by zero in the stats below.
    page_size = max(1, min(page_size, 1000))
    start_time = datetime.now()

    conn = await self._get_valid_connection()

    where = f" WHERE {filter_string}" if filter_string else ''
    order = f" ORDER BY {order_string}" if order_string else ''

    self._check_index_usage(table_name, filter_string, order_string)

    count_sql = f'SELECT COUNT(*) as total FROM `{table_name}` {where}'
    count_result = await conn.execute_query(count_sql)
    total = count_result[1][0].get('total', 0)

    all_data = []
    # page_index > 0 selects a single 1-based page; otherwise start at 0 and fetch all.
    offset = max((page_index - 1) * page_size, 0)
    while offset < total:
        sql = f'SELECT {select_fields} FROM `{table_name}` {where} {order} LIMIT {page_size} OFFSET {offset}'
        _, batch_data = await conn.execute_query(sql)
        all_data.extend(batch_data)
        if page_index > 0:
            break
        offset += page_size
        # A short batch means the end of the data was reached.
        if len(batch_data) < page_size:
            break

    execution_time = (datetime.now() - start_time).total_seconds()

    self.stats['total_processed'] += total
    self.stats['batches_executed'] += (total + page_size - 1) // page_size
    self.stats['last_execution_time'] = execution_time

    self._record_query_performance(table_name, len(all_data), execution_time)

    response = {
        "success": True,
        "table_name": table_name,
        "filter": filter_string,
        "order": order_string,
        "execution_time": execution_time,
        "total": total,
        "page_size": page_size,
        "page_index": page_index,
        "data": [dict_to_lower_keys(item) for item in all_data],
        "performance_metrics": {
            "records_per_second": len(all_data) / execution_time if execution_time > 0 else 0,
            "query_type": "paginated" if page_index > 0 else "full"
        }
    }

    logger.success("数据查询", f"{table_name}@{self.connection_name}", f"执行时间{execution_time:.3f}秒,返回{len(all_data)}条记录")
    if execution_time > 1.0:
        logger.warning("数据查询", f"{table_name}@{self.connection_name}", f"查询执行时间较长: {execution_time:.3f}秒")
    logger.debug(f"数据查询完成:{response}")
    return response
|
||
|
||
|
||
@with_transaction
@handle_db_errors(max_retries=3)
async def delete_data(self, table_name: str, filter_string: str = '', use_transaction: Optional[bool] = None) -> Dict[str, Any]:
    """Delete rows from ``table_name`` that match ``filter_string``.

    Warning: an empty ``filter_string`` produces a DELETE without a WHERE
    clause, i.e. every row of the table is deleted. ``filter_string`` is
    inserted into the SQL verbatim and must come from trusted code.

    Args:
        table_name: Target table. It is quoted with backticks, so it must be
            non-empty and must not contain a backtick.
        filter_string: WHERE expression without the keyword (optional).
        use_transaction: Consumed by ``with_transaction``; None means the
            instance default.

    Returns:
        Dict with ``success``, ``table_name``, ``filter``, ``execution_time``,
        ``affected_rows`` and ``connection_name``.

    Raises:
        DbManagerError: If the delete fails (raised via the decorators).
    """
    if not table_name or "`" in table_name:
        raise ValueError(f"非法表名: {table_name!r}")

    start_time = datetime.now()

    conn = await self._get_valid_connection()

    where = f" WHERE {filter_string}" if filter_string else ''
    delete_sql = f'DELETE FROM `{table_name}` {where}'

    affected_rows, _ = await conn.execute_query(delete_sql)

    execution_time = (datetime.now() - start_time).total_seconds()

    self.stats['total_processed'] += affected_rows
    self.stats['batches_executed'] += 1
    self.stats['last_execution_time'] = execution_time

    response = {
        "success": True,
        "table_name": table_name,
        "filter": filter_string,
        "execution_time": execution_time,
        "affected_rows": affected_rows,
        "connection_name": self.connection_name
    }

    logger.success("数据删除", f"{table_name}@{self.connection_name}", f"影响{affected_rows}行")
    return response
|
||
|
||
|
||
@handle_db_errors(max_retries=3)
async def _execute_native_sql(self, sql: str, params: List[Any], description: str = "") -> tuple:
    """Execute one raw SQL statement, passing ``params`` to ``execute_query``.

    Args:
        sql: SQL text.
        params: Parameter values passed alongside ``sql`` to ``conn.execute_query``.
        description: Optional label; when given, the execution time is logged at
            debug level.

    Returns:
        ``(count, rows)``:
            ``count`` is the first element of ``execute_query``'s result,
            or 0 if falsy.
            ``rows`` is the second element, with row keys lower-cased when
            non-empty.

    Raises:
        DbManagerError: If ``sql`` is empty/None or execution fails (raised via
            ``handle_db_errors``).
    """
    # Validate before acquiring a connection. A None sql would otherwise raise an
    # AttributeError mentioning 'NoneType', which handle_db_errors classifies as
    # retryable, causing pointless retries and connection refreshes.
    if not sql or not sql.strip():
        raise DbQueryError("SQL语句为空")

    start_time = datetime.now()

    conn = await self._get_valid_connection()

    count, data_list = await conn.execute_query(sql, params)
    if data_list:
        data_list = [dict_to_lower_keys(row) for row in data_list]

    execution_time = (datetime.now() - start_time).total_seconds()
    if description:
        logger.debug(f"{description} - 执行时间:{execution_time:.3f}秒")

    return (count if count else 0, data_list)
|
||
|
||
|
||
async def _bulk_upsert_native_sql(
    self,
    model_class,
    data_list: List[Dict[str, Any]],
    update_fields: Optional[List[str]] = None,
    exclude_fields: Optional[List[str]] = None,
    conflict_fields: Optional[Tuple[str, ...]] = None
) -> Dict[str, int]:
    """
    Batch upsert rows with raw MySQL ``INSERT ... ON DUPLICATE KEY UPDATE``.

    Rows are written in batches of at most ``min(self.batch_size, 500)`` to
    limit lock contention. A batch with no applicable update column falls
    back to ``INSERT IGNORE``.

    Args:
        model_class: Tortoise model class; ``_meta`` supplies the table name
            and primary-key information.
        data_list: Row dicts keyed by column name. The INSERT column list is
            the union of all keys (first-seen order); a row missing one of
            those columns is written with NULL for it. An empty list is a
            no-op.
        update_fields: Columns overwritten on a key conflict. Only columns
            present in the current batch AND in the INSERT column list are
            used, so an excluded column is never reset via ``VALUES()``.
            None/empty means insert-only (``INSERT IGNORE``).
        exclude_fields: Keys dropped from every row. Defaults to none;
            conflict columns are NOT excluded automatically because they may
            be required primary-key values.
        conflict_fields: Unique-key columns; resolved through
            ``self._get_conflict_fields`` (the original contract allows
            passing None).

    Returns:
        ``{'inserted', 'updated', 'total', 'failed_batches'}``.
        The inserted/updated split is derived from MySQL's affected-row
        count (1 per inserted row, 2 per updated row, 0 per unchanged row).
        It is an estimate when some conflicting rows were left unchanged.

    Raises:
        ValueError: If no insertable column remains, or a conflict column
            that is not an auto-generated primary key is missing from the data.

    A failing batch is logged and counted in ``failed_batches``; the
    remaining batches are still attempted (best effort).
    """
    def quote(identifier: Any) -> str:
        # MySQL identifier quoting: wrap in backticks and double any embedded
        # backtick so a data key cannot break out of the quoted name.
        return "`" + str(identifier).replace("`", "``") + "`"

    if not data_list:
        return {'inserted': 0, 'updated': 0, 'total': 0, 'failed_batches': 0}

    conflict_fields = self._get_conflict_fields(model_class, conflict_fields)
    excluded = set(exclude_fields or ())
    table_name = model_class._meta.db_table

    # Union of keys over all rows, in first-seen order, minus excluded keys.
    all_fields = [
        field for field in dict.fromkeys(key for row in data_list for key in row)
        if field not in excluded
    ]
    if not all_fields:
        raise ValueError("没有可插入的字段")
    inserted_columns = set(all_fields)

    # Every conflict column must be supplied, except an auto-generated PK.
    pk_attr = getattr(model_class._meta, 'pk_attr', None)
    for field in conflict_fields:
        if field in inserted_columns:
            continue
        pk_field = model_class._meta.fields_map.get(field) if field == pk_attr else None
        if not (pk_field and pk_field.generated):
            raise ValueError(f"字段 {field} 不在数据字段中")

    quoted_table = quote(table_name)
    fields_str = ', '.join(quote(field) for field in all_fields)
    row_placeholder = '(' + ', '.join(['%s'] * len(all_fields)) + ')'

    # Smaller batches reduce lock waits on large upserts.
    batch_size = max(1, min(self.batch_size, 500))
    batch_starts = range(0, len(data_list), batch_size)
    total_batches = len(batch_starts)

    total_inserted = 0
    total_updated = 0
    failed_batches = 0

    for batch_no, start in enumerate(batch_starts, start=1):
        batch = data_list[start:start + batch_size]
        values = [row.get(field) for row in batch for field in all_fields]
        placeholders = ', '.join([row_placeholder] * len(batch))

        batch_update_fields: List[str] = []
        if update_fields:
            batch_keys = set().union(*batch)
            batch_update_fields = [
                field for field in update_fields
                if field in batch_keys and field in inserted_columns
            ]

        if batch_update_fields:
            update_str = ', '.join(
                f"{quote(field)} = VALUES({quote(field)})" for field in batch_update_fields
            )
            sql = (
                f"INSERT INTO {quoted_table} ({fields_str}) VALUES {placeholders} "
                f"ON DUPLICATE KEY UPDATE {update_str}"
            )
        else:
            sql = f"INSERT IGNORE INTO {quoted_table} ({fields_str}) VALUES {placeholders}"

        try:
            affected, _ = await self._execute_native_sql(
                sql,
                values,
                description=f"批量 upsert 批次 {batch_no}"
            )
        except Exception as e:
            # Best effort: record the failure and continue with the next batch.
            failed_batches += 1
            logger.error(f"执行批量upsert批次失败: {e}")
            logger.error(f"SQL语句: {sql}")
            continue

        affected = affected or 0
        if batch_update_fields:
            # With n rows: affected - n == updated - unchanged.
            updated = max(0, affected - len(batch))
            inserted = max(0, affected - 2 * updated)
        else:
            # INSERT IGNORE: 1 per inserted row, 0 per ignored conflict.
            inserted = affected
            updated = 0

        logger.info(f"批量upsert执行成功: 表={table_name}, 批次={batch_no}, 影响行数={affected}, 插入={inserted}, 更新={updated}")

        total_inserted += inserted
        total_updated += updated
        self.stats['batches_executed'] += 1

        # Brief pause between batches to avoid monopolising the database;
        # pointless after the final batch.
        if batch_no < total_batches:
            await asyncio.sleep(0.1)

    if failed_batches:
        logger.warning(f"批量upsert有 {failed_batches} 个批次失败: 表={table_name}")

    return {
        'inserted': total_inserted,
        'updated': total_updated,
        'total': total_inserted + total_updated,
        'failed_batches': failed_batches
    }
|
||
|
||
|
||
@handle_db_errors(max_retries=3)
async def _bulk_upsert_orm(
    self,
    model_class,
    data_list: List[Dict[str, Any]],
    update_fields: Optional[List[str]] = None,
    exclude_fields: Optional[List[str]] = None,
    conflict_fields: Optional[Tuple[str, ...]] = None
) -> Dict[str, int]:
    """
    Batch upsert through the Tortoise ORM; intended for small batches.

    Existing records are first looked up by ``conflict_fields``. That lookup
    drives both the returned counts and the write path:

    * a single input row that matches an existing record: the record is
      updated in place and saved on the same connection;
    * otherwise, when matches exist and at least one applicable update
      column remains: ``bulk_create`` with ``on_conflict``/``update_fields``;
    * otherwise: ``bulk_create(ignore_conflicts=True)`` (insert only).

    Args:
        model_class: Tortoise model class.
        data_list: Row dicts keyed by field name.
        update_fields: Fields overwritten on conflict. The primary key and
            fields absent from the (filtered) data are ignored.
        exclude_fields: Keys dropped from every row (optional, default none).
            Conflict/PK fields are deliberately NOT dropped automatically,
            because they identify the records.
        conflict_fields: Conflict-key fields; when None they are derived via
            ``self._get_conflict_fields(model_class)``.

    Returns:
        ``{'inserted': int, 'updated': int, 'total': int}``, based on which
        rows' conflict keys matched an existing record before the write.

    Raises:
        DbManagerError: If execution fails (via ``handle_db_errors``).
    """
    db = await self._get_valid_connection()
    if conflict_fields is None:
        conflict_fields = self._get_conflict_fields(model_class)

    pk_field = getattr(model_class._meta, 'pk_attr', None)
    excluded = set(exclude_fields or ())

    filtered_data = [
        {key: value for key, value in row.items() if key not in excluded}
        for row in data_list
    ]
    present_fields = set().union(*filtered_data)

    # Look up records that already hold one of the incoming conflict keys.
    existing_records = []
    key_filters = [
        {field: row[field] for field in conflict_fields if field in row}
        for row in filtered_data
    ] if conflict_fields else []
    key_filters = [key_filter for key_filter in key_filters if key_filter]
    if key_filters:
        query = Q(**key_filters[0])
        for key_filter in key_filters[1:]:
            query |= Q(**key_filter)
        existing_records = await model_class.filter(query).only(*conflict_fields).using_db(db).all()

    instances = [model_class(**row) for row in filtered_data]

    if existing_records and len(filtered_data) == 1:
        record = existing_records[0]
        # update_from_dict is synchronous in Tortoise and returns the instance.
        record.update_from_dict(filtered_data[0])
        fields_to_save = [field for field in filtered_data[0] if field != pk_field]
        if fields_to_save:
            # Save on the connection the record was read from, so the write
            # joins the caller's connection/transaction rather than the
            # model's default connection.
            await record.save(update_fields=fields_to_save, using_db=db)
    elif instances:
        applicable_updates = [
            field for field in (update_fields or ())
            if field != pk_field and field in present_fields
        ]
        if existing_records and applicable_updates:
            await model_class.bulk_create(
                instances,
                on_conflict=conflict_fields,
                update_fields=applicable_updates,
                using_db=db
            )
        else:
            # Insert only: rows colliding with an existing key are skipped
            # instead of failing the whole statement with IntegrityError.
            await model_class.bulk_create(instances, ignore_conflicts=True, using_db=db)

    existing_keys = {
        tuple(getattr(record, field) for field in conflict_fields)
        for record in existing_records
    }
    updated_count = sum(
        1 for row in filtered_data
        if tuple(row.get(field) for field in conflict_fields) in existing_keys
    )
    inserted_count = len(filtered_data) - updated_count

    return {
        'inserted': inserted_count,
        'updated': updated_count,
        'total': inserted_count + updated_count
    }
|
||
|
||
|
||
@with_transaction
@handle_db_errors(max_retries=3)
async def bulk_upsert(
    self,
    model_class,
    data_list: List[Dict[str, Any]],
    update_fields: Optional[List[str]] = None,
    exclude_fields: Optional[List[str]] = None,
    conflict_fields: Optional[Tuple[str, ...]] = None,
    use_orm_or_sql: Literal["orm", "sql", "auto"] = "sql",
    use_transaction: Optional[bool] = None
) -> Dict[str, Any]:
    """
    Batch upsert entry point: choose the ORM or raw-SQL strategy and report stats.

    Args:
        model_class: Tortoise model class.
        data_list: Row dicts keyed by column name.
        update_fields: Columns updated on conflict. When None, every key that
            appears in ``data_list`` and is neither a conflict field nor
            excluded is used, in first-seen order.
        exclude_fields: Keys dropped from every row (default: none).
        conflict_fields: Conflict-key columns; derived via
            ``self._get_conflict_fields`` when None.
        use_orm_or_sql: ``"orm"`` for the ORM path, ``"sql"`` (default) for
            raw SQL, ``"auto"`` for ORM below 100 rows and SQL otherwise.
            Any other value uses raw SQL.
        use_transaction: Transaction toggle. Presumably consumed by the
            ``with_transaction`` decorator; it is not read in this body.

    Returns:
        Execution summary. ``success`` is False when the chosen strategy
        reported failed (skipped) batches under ``failed_batches``.

    Raises:
        DbManagerError: If execution fails.
    """
    start_time = datetime.now()
    db_table = model_class._meta.db_table

    # Conflict fields are needed before computing the default update fields.
    if conflict_fields is None:
        conflict_fields = self._get_conflict_fields(model_class, conflict_fields)

    # Conflict fields are not excluded by default: they may be required PKs.
    if exclude_fields is None:
        exclude_fields = []

    if update_fields is None and data_list:
        # Conflict fields are never updated, even if exclude_fields is empty.
        excluded_set = set(conflict_fields) | set(exclude_fields)
        # dict.fromkeys keeps first-seen order, so the generated SQL and the
        # returned update_fields are deterministic.
        update_fields = [
            field for field in dict.fromkeys(key for row in data_list for key in row)
            if field not in excluded_set
        ]

    use_orm = use_orm_or_sql == "orm" or (use_orm_or_sql == "auto" and len(data_list) < 100)
    if use_orm:
        method = "orm"
        result = await self._bulk_upsert_orm(
            model_class, data_list, update_fields,
            exclude_fields, conflict_fields
        )
    else:
        method = "native_sql"
        result = await self._bulk_upsert_native_sql(
            model_class, data_list, update_fields,
            exclude_fields, conflict_fields
        )

    execution_time = (datetime.now() - start_time).total_seconds()

    self.stats['total_processed'] += len(data_list)
    self.stats['last_execution_time'] = execution_time

    # Feed throughput samples to the adaptive batch-size logic.
    self._record_batch_performance(len(data_list), execution_time)

    # A strategy that skips failing batches may report them here; absent means none.
    failed_batches = result.get('failed_batches', 0)
    if failed_batches:
        logger.warning(f"批量upsert存在失败批次: 表={db_table}, 失败批次={failed_batches}")

    response = {
        "success": failed_batches == 0,
        "method": method,
        "total_records": len(data_list),
        "affected_rows": result['total'],
        "inserted": result['inserted'],
        "updated": result['updated'],
        "execution_time": execution_time,
        "batch_size": len(data_list),
        "optimal_batch_size": self.optimal_batch_size,
        "conflict_fields": conflict_fields,
        "update_fields": update_fields,
        "failed_batches": failed_batches
    }

    logger.success("批量upsert", f"{db_table}@{self.connection_name}", f"插入{result['inserted']}条,更新{result['updated']}条,执行时间{execution_time:.3f}秒")
    return response
|
||
|
||
|
||
@with_transaction
@handle_db_errors(max_retries=3)
async def single_upsert(
    self,
    model_class,
    data: Dict[str, Any],
    conflict_fields: Optional[Tuple[str, ...]] = None,
    # update_fields: Optional[List[str]] = None,
    # exclude_fields: Optional[List[str]] = None,
    use_transaction: Optional[bool] = None
) -> Dict[str, Any]:
    """
    Insert a single row, or update it when a row with the same key exists.

    Args:
        model_class: Tortoise model class.
        data: Column -> value for the row.
        conflict_fields: Columns identifying the row.
            When None, the model's default conflict fields
            (``self._get_conflict_fields``) are used, restricted to those
            present in ``data``. Example: ``t_supply`` is keyed by
            ``supplyno + materialno``, but a single-row PATCH from the
            frontend may omit ``materialno``, so the row is then located by
            ``supplyno`` alone.
            When given explicitly, every listed field must be in ``data``.
        use_transaction: Transaction toggle. Presumably consumed by the
            ``with_transaction`` decorator; it is not read in this body.

    Returns:
        Dict with ``success``, ``operation_type``
        (``'inserted'`` | ``'updated'`` | ``'no_change'``), ``affected_rows``,
        ``conflict_fields``, ``update_fields``, ``inserted`` and ``updated``.

    Raises:
        ValueError: If no usable conflict field exists, explicit conflict
            fields are missing from ``data``, or several rows match the key.
        DbManagerError: If execution fails.
    """
    def quote(identifier: Any) -> str:
        # MySQL identifier quoting; doubling embedded backticks prevents keys
        # from request payloads from breaking out of the quoted name.
        return "`" + str(identifier).replace("`", "``") + "`"

    if conflict_fields is None:
        default_fields = tuple(self._get_conflict_fields(model_class))
        conflict_fields = tuple(field for field in default_fields if field in data)
        if not conflict_fields:
            # Without any key column the lookup would be "WHERE <nothing>".
            raise ValueError(f"数据中不包含任何冲突字段: {', '.join(map(str, default_fields))}")
    else:
        missing_fields = [field for field in conflict_fields if field not in data]
        if missing_fields:
            raise ValueError(f"数据中缺少必要的冲突字段: {', '.join(missing_fields)}")
        conflict_fields = tuple(conflict_fields)
        if not conflict_fields:
            raise ValueError("冲突字段不能为空")

    table = quote(model_class._meta.db_table)
    where_sql = ' AND '.join(f"{quote(field)} = %s" for field in conflict_fields)
    key_values = [data[field] for field in conflict_fields]

    # One pooled connection serves both the existence check and the write.
    # NOTE(review): the check and the write are separate statements without
    # a row lock; concurrent callers could both take the INSERT path.
    conn = await self._get_valid_connection()
    _, rows = await conn.execute_query(
        f"SELECT COUNT(*) as count FROM {table} WHERE {where_sql}",
        key_values
    )
    count = rows[0]['count']

    if count > 1:
        raise ValueError(f"检测到多个 {', '.join(conflict_fields)} 匹配的记录,无法执行单条 upsert 操作")

    update_fields = tuple(field for field in data if field not in conflict_fields)

    if count == 0:
        columns = list(data)
        insert_sql = (
            f"INSERT INTO {table} ({', '.join(quote(column) for column in columns)}) "
            f"VALUES ({', '.join(['%s'] * len(columns))})"
        )
        affected_rows, _ = await conn.execute_query(insert_sql, [data[column] for column in columns])
        operation_type = 'inserted'
    elif not update_fields:
        # The matching row exists and there is nothing besides the key to write.
        affected_rows = 0
        operation_type = 'no_change'
    else:
        set_sql = ', '.join(f"{quote(field)} = %s" for field in update_fields)
        update_sql = f"UPDATE {table} SET {set_sql} WHERE {where_sql}"
        # Parameter order: SET values first, then the WHERE key values.
        params = [data[field] for field in update_fields] + key_values
        affected_rows, _ = await conn.execute_query(update_sql, params)
        operation_type = 'updated'

    return {
        'success': True,
        'operation_type': operation_type,
        'affected_rows': affected_rows,
        'conflict_fields': conflict_fields,
        'update_fields': update_fields,
        'inserted': 1 if operation_type == 'inserted' else 0,
        'updated': 1 if operation_type == 'updated' else 0
    }
|
||
|
||
|
||
@with_transaction
@handle_db_errors(max_retries=3)
async def conditional_bulk_upsert(
    self,
    model_class,
    data_list: List[Dict[str, Any]],
    update_rules: Dict[str, str],
    condition_field: str,
    condition_value: Any,
    conflict_fields: Optional[Tuple[str, ...]] = None,
    use_transaction: Optional[bool] = None
) -> Dict[str, int]:
    """
    Batch upsert where conflicting rows are updated only if a condition holds.

    Every row is inserted. On a unique-key conflict, each column in
    ``update_rules`` is set to its rule expression only when the EXISTING
    row satisfies ``condition_field = condition_value``; otherwise the
    column keeps its value. MySQL's ``ON DUPLICATE KEY UPDATE`` accepts no
    WHERE clause, so the condition is applied per column with ``IF(...)``.

    Args:
        model_class: Tortoise model class.
        data_list: Row dicts; an empty list is a no-op. The column list is the
            union of all keys (first-seen order); a row missing a column is
            written with NULL for it.
        update_rules: Column -> SQL expression. Every expression must contain
            ``VALUES``, e.g. ``{'quantity': 'quantity + VALUES(quantity)',
            'price': 'VALUES(price)'}``. Expressions are embedded verbatim, so
            they must never come from untrusted input.
        condition_field: Column tested on the existing row (required).
        condition_value: Value that column must equal (required).
        conflict_fields: Conflict-key columns used to count pre-existing rows;
            resolved via ``self._get_conflict_fields`` (None allowed).
        use_transaction: Transaction toggle. Presumably consumed by the
            ``with_transaction`` decorator; it is not read in this body.

    Returns:
        ``{'inserted': int, 'updated': int, 'total': int}``.
        ``inserted`` is the batch size minus the rows that existed before the
        write. ``updated`` assumes MySQL's 2-affected-rows-per-update
        reporting.

    Raises:
        ValueError: If ``update_rules`` is empty or an expression lacks ``VALUES``.
        DbManagerError: If execution fails.

    Note:
        MySQL evaluates the assignments left to right. If ``condition_field``
        is itself a key of ``update_rules``, later ``IF()`` tests see its new
        value.
    """
    def quote(identifier: Any) -> str:
        # MySQL identifier quoting with embedded backticks doubled.
        return "`" + str(identifier).replace("`", "``") + "`"

    if not data_list:
        return {'inserted': 0, 'updated': 0, 'total': 0}
    if not update_rules:
        raise ValueError("update_rules 不能为空")
    # Validate every rule before anything is written.
    for field, expression in update_rules.items():
        if 'VALUES' not in expression:
            raise ValueError(f"更新规则表达式必须包含VALUES: {field} = {expression}")

    conflict_fields = self._get_conflict_fields(model_class, conflict_fields)
    table = quote(model_class._meta.db_table)

    all_fields = list(dict.fromkeys(key for row in data_list for key in row))
    fields_str = ', '.join(quote(field) for field in all_fields)
    row_placeholder = '(' + ', '.join(['%s'] * len(all_fields)) + ')'

    # Inside ON DUPLICATE KEY UPDATE an unqualified column refers to the
    # existing row, so IF(cond, new, col) leaves the column untouched
    # whenever the existing row does not satisfy the condition.
    condition_sql = f"{quote(condition_field)} = %s"
    update_str = ', '.join(
        f"{quote(field)} = IF({condition_sql}, {expression}, {quote(field)})"
        for field, expression in update_rules.items()
    )
    # One condition parameter per IF(), placed after all VALUES parameters.
    condition_params = [condition_value] * len(update_rules)

    conn = await self._get_valid_connection()
    batch_size = max(1, self.batch_size)
    total_inserted = 0
    total_updated = 0

    for batch_no, start in enumerate(range(0, len(data_list), batch_size), start=1):
        batch = data_list[start:start + batch_size]

        # Count already-existing rows BEFORE writing; afterwards every row
        # of the batch would exist and look like an update.
        existing_count = 0
        key_filters = [
            {field: row[field] for field in conflict_fields if field in row}
            for row in batch
        ]
        key_filters = [key_filter for key_filter in key_filters if key_filter]
        if key_filters:
            query = Q(**key_filters[0])
            for key_filter in key_filters[1:]:
                query |= Q(**key_filter)
            existing_count = await model_class.filter(query).using_db(conn).count()

        values = [row.get(field) for row in batch for field in all_fields]
        values.extend(condition_params)
        sql = (
            f"INSERT INTO {table} ({fields_str}) "
            f"VALUES {', '.join([row_placeholder] * len(batch))} "
            f"ON DUPLICATE KEY UPDATE {update_str}"
        )
        affected, _ = await self._execute_native_sql(
            sql, values, f"条件批量 upsert 批次 {batch_no}"
        )

        # Affected rows = inserted + 2 * updated; unchanged rows contribute 0.
        inserted = max(0, len(batch) - existing_count)
        updated = max(0, (affected - inserted) // 2)
        total_inserted += inserted
        total_updated += updated

    return {
        'inserted': total_inserted,
        'updated': total_updated,
        'total': total_inserted + total_updated
    }
|
||
|
||
|
||
def get_stats(self) -> Dict[str, Any]:
    """
    Return a snapshot of the statistics counters.

    The snapshot is a shallow copy, so callers may modify it without
    affecting ``self.stats``.
    """
    snapshot = self.stats.copy()
    return snapshot
|
||
|
||
|
||
def reset_stats(self):
    """Replace ``self.stats`` with freshly zeroed counters."""
    fresh: Dict[str, Any] = dict(
        total_processed=0,
        batches_executed=0,
        last_execution_time=None,
    )
    self.stats = fresh
|
||
|
||
def _record_batch_performance(self, batch_size: int, execution_time: float):
|
||
"""
|
||
记录批量操作性能,用于动态调整批量大小
|
||
|
||
Args:
|
||
batch_size: 批量大小
|
||
execution_time: 执行时间(秒)
|
||
"""
|
||
if execution_time > 0:
|
||
# 计算每秒处理记录数
|
||
records_per_second = batch_size / execution_time
|
||
self.batch_size_history.append({
|
||
'batch_size': batch_size,
|
||
'execution_time': execution_time,
|
||
'records_per_second': records_per_second
|
||
})
|
||
|
||
# 限制历史记录数量
|
||
if len(self.batch_size_history) > 50:
|
||
self.batch_size_history = self.batch_size_history[-50:]
|
||
|
||
# 每10次操作调整一次批量大小
|
||
if len(self.batch_size_history) % self.batch_size_adjustment_interval == 0:
|
||
self._adjust_batch_size()
|
||
|
||
def _adjust_batch_size(self):
|
||
"""
|
||
根据历史性能数据调整批量大小
|
||
"""
|
||
if len(self.batch_size_history) < 5:
|
||
return
|
||
|
||
# 分析最近的性能数据
|
||
recent_history = self.batch_size_history[-10:]
|
||
|
||
# 计算不同批量大小的平均性能
|
||
batch_performance = {}
|
||
for record in recent_history:
|
||
batch_size = record['batch_size']
|
||
if batch_size not in batch_performance:
|
||
batch_performance[batch_size] = []
|
||
batch_performance[batch_size].append(record['records_per_second'])
|
||
|
||
# 计算每个批量大小的平均性能
|
||
avg_performance = {}
|
||
for batch_size, performances in batch_performance.items():
|
||
avg_performance[batch_size] = sum(performances) / len(performances)
|
||
|
||
# 找出性能最好的批量大小
|
||
best_batch_size = max(avg_performance, key=avg_performance.get)
|
||
best_performance = avg_performance[best_batch_size]
|
||
|
||
# 计算当前批量大小的性能
|
||
current_performance = avg_performance.get(self.batch_size, 0)
|
||
|
||
# 如果最佳批量大小的性能比当前高20%以上,则调整
|
||
if best_performance > current_performance * 1.2:
|
||
# 调整批量大小,在最佳批量大小基础上进行微调
|
||
new_batch_size = int(best_batch_size * 1.1) # 稍微增加一点
|
||
new_batch_size = max(self.min_batch_size, min(self.max_batch_size, new_batch_size))
|
||
|
||
if new_batch_size != self.batch_size:
|
||
# 避免除以零的错误
|
||
if current_performance > 0:
|
||
performance_improvement = ((best_performance / current_performance) - 1) * 100
|
||
logger.info(f"调整批量大小: {self.batch_size} -> {new_batch_size}, 性能提升: {performance_improvement:.1f}%")
|
||
else:
|
||
logger.info(f"调整批量大小: {self.batch_size} -> {new_batch_size}")
|
||
self.batch_size = new_batch_size
|
||
self.optimal_batch_size = new_batch_size
|
||
|
||
def _check_index_usage(self, table_name: str, filter_string: str, order_string: str):
|
||
"""
|
||
检查索引使用情况,提供索引优化建议
|
||
|
||
Args:
|
||
table_name: 表名
|
||
filter_string: WHERE条件字符串
|
||
order_string: ORDER BY排序字符串
|
||
"""
|
||
# 简单的索引使用检查
|
||
# 实际项目中可以根据数据库类型执行EXPLAIN查询来分析索引使用情况
|
||
if filter_string:
|
||
# 检查是否使用了索引字段
|
||
# 这里只是简单的检查,实际项目中可以根据具体表结构和索引进行更详细的分析
|
||
pass
|
||
|
||
if order_string:
|
||
# 检查排序字段是否有索引
|
||
pass
|
||
|
||
def _record_query_performance(self, table_name: str, record_count: int, execution_time: float):
|
||
"""
|
||
记录查询性能,用于监控和优化
|
||
|
||
Args:
|
||
table_name: 表名
|
||
record_count: 返回记录数
|
||
execution_time: 执行时间(秒)
|
||
"""
|
||
# 记录查询性能数据
|
||
if execution_time > 0:
|
||
records_per_second = record_count / execution_time
|
||
# 可以将性能数据存储到监控系统中
|
||
logger.debug(f"查询性能: {table_name} - {records_per_second:.2f} 条/秒, 执行时间: {execution_time:.3f}秒")
|
||
|
||
|
||
def switch_connection(self, connection_name: str):
    """
    Make this manager use a different Tortoise connection alias from now on.

    Only ``self.connection_name`` is reassigned; no connection is opened,
    closed or validated here.

    Args:
        connection_name: Alias of the connection to switch to.
    """
    setattr(self, 'connection_name', connection_name)
    message = f"已切换数据库连接至:{connection_name}"
    logger.info(message)
|
||
|
||
async def get_connection_pool_status(self) -> Dict[str, Any]:
    """
    Report the connection-pool state of ``self.connection_name``.

    Pool attributes differ between database drivers, so each metric is read
    from the first existing attribute in a list of known names. When none
    exists, a placeholder is reported; these are NOT measurements:
    current_size 10, max_size 30, min_size 10, idle = current_size,
    used = current_size - idle.

    Returns:
        On success: ``connection_name``, ``pool_available``, ``timestamp``,
        ``warnings`` and ``alerts``. When a pool is found, also
        ``current_size``, ``max_size``, ``min_size``, ``idle_connections``,
        ``used_connections`` and, for a positive numeric ``max_size``,
        ``usage_rate`` (percent).
        On any error: ``connection_name``, ``pool_available=False``,
        ``error`` and ``timestamp``.
    """
    # (attribute name, True if the attribute is a collection whose len() is the metric)
    size_attrs = (('_size', False), ('size', False), ('current_size', False),
                  ('_connections', True), ('connections', True))
    max_attrs = (('_maxsize', False), ('maxsize', False), ('max_size', False),
                 ('maximum_size', False))
    min_attrs = (('_minsize', False), ('minsize', False), ('min_size', False),
                 ('minimum_size', False))
    idle_attrs = (('_idle', True), ('idle', True), ('idle_connections', False),
                  ('free', True), ('free_connections', False), ('_free', True))
    used_attrs = (('_used', True), ('used', True), ('used_connections', False),
                  ('in_use', True), ('busy', True), ('busy_connections', False),
                  ('_busy', True))
    not_found = object()

    def probe(pool, candidates, default):
        """Return the metric from the first attribute that exists, else ``default``."""
        for attr_name, is_collection in candidates:
            if hasattr(pool, attr_name):
                attr_value = getattr(pool, attr_name)
                return len(attr_value) if is_collection else attr_value
        return default

    try:
        conn = Tortoise.get_connection(self.connection_name)
        pool = conn._pool if hasattr(conn, '_pool') else None

        status = {
            'connection_name': self.connection_name,
            'pool_available': pool is not None,
            'timestamp': time.time(),
            'warnings': [],
            'alerts': []
        }

        if pool:
            status['current_size'] = probe(pool, size_attrs, 10)
            status['max_size'] = probe(pool, max_attrs, 30)
            status['min_size'] = probe(pool, min_attrs, 10)
            status['idle_connections'] = probe(pool, idle_attrs, status['current_size'])

            used = probe(pool, used_attrs, not_found)
            if used is not_found:
                # No direct counter: derive it from total minus idle.
                used = status['current_size'] - status['idle_connections']
            status['used_connections'] = used

            max_size = status['max_size']
            # A driver may report an unbounded pool's max size as None.
            # Comparing that with 0 used to raise TypeError and discard the
            # whole status, so only a positive number yields a usage rate.
            if isinstance(max_size, (int, float)) and not isinstance(max_size, bool) and max_size > 0:
                status['usage_rate'] = status['used_connections'] / max_size * 100
                if status['usage_rate'] >= 90:
                    status['alerts'].append(f"连接池使用率过高: {status['usage_rate']:.1f}%")
                elif status['usage_rate'] >= 80:
                    status['warnings'].append(f"连接池使用率较高: {status['usage_rate']:.1f}%")

            if status['idle_connections'] == 0:
                status['warnings'].append("连接池没有空闲连接")

        return status

    except Exception as e:
        logger.error(f"获取连接池状态失败: {e}")
        return {
            'connection_name': self.connection_name,
            'pool_available': False,
            'error': str(e),
            'timestamp': time.time()
        }
|
||
|
||
async def check_connection_health(self, timeout: int = 5, fast_mode: bool = False) -> bool:
    """
    Probe the connection with ``SELECT 1``, refreshing it once on failure.

    If the first probe times out or raises, ``self.refresh_connection`` is
    awaited and the probe is repeated once.

    Args:
        timeout: Seconds allowed for each ``SELECT 1`` probe.
        fast_mode: Forwarded to ``refresh_connection``, which then uses fewer
            retries and shorter waits.

    Returns:
        True if a probe succeeded, False if the probe after the refresh failed too.
    """
    async def ping() -> None:
        connection = Tortoise.get_connection(self.connection_name)
        await asyncio.wait_for(connection.execute_query("SELECT 1"), timeout=timeout)

    try:
        await ping()
        return True
    except asyncio.TimeoutError:
        logger.warning(f"数据库连接健康检查超时: {self.connection_name}")
    except Exception as e:
        logger.warning(f"数据库连接健康检查失败: {e}")

    # The first probe failed: rebuild the connection and try exactly once more.
    await self.refresh_connection(fast_mode=fast_mode)
    try:
        await ping()
        return True
    except Exception as e:
        logger.warning(f"刷新连接后健康检查仍失败: {e}")
        return False
|
||
|
||
async def refresh_connection(self, fast_mode: bool = False) -> bool:
    """
    Recreate this manager's Tortoise connection and verify it with ``SELECT 1``.

    Meant for protocol-level failures such as MySQL "Packet sequence number
    wrong". Each attempt:

    1. closes the current client (best effort);
    2. removes it from Tortoise's connection storage, so
       ``connections.get`` builds a new one;
    3. waits;
    4. creates the new client and validates it.

    Args:
        fast_mode: Use 3 attempts instead of 5, with shorter waits and back-off.

    Returns:
        True once a new connection answers ``SELECT 1``. False if Tortoise is
        not initialised (checked before anything is closed) or every attempt
        failed. Errors are logged, not raised.
    """
    import inspect

    def is_loop_conflict(error: BaseException) -> bool:
        return "bound to a different event loop" in str(error)

    async def close_quietly(target: Any, label: str) -> None:
        """Close ``target`` and log the outcome; never raises."""
        try:
            outcome = target.close()
            # Tortoise clients close asynchronously; some raw driver
            # connections close synchronously and return None.
            if inspect.isawaitable(outcome):
                await outcome
            logger.info(f"已关闭{label}: {self.connection_name}")
        except Exception as close_error:
            if is_loop_conflict(close_error):
                logger.warning(f"关闭{label}时出现事件循环冲突,跳过关闭操作: {self.connection_name}")
            else:
                logger.warning(f"关闭{label}时出错: {close_error}")

    async def close_current_connection() -> None:
        """Close the client currently registered under ``self.connection_name``."""
        conn = Tortoise.get_connection(self.connection_name)
        if not (conn and hasattr(conn, 'close')):
            return
        if hasattr(conn, '_pool'):
            # Pooled client: closing it closes its pool.
            await close_quietly(conn, "旧连接")
        else:
            # Presumably a TransactionWrapper: close the raw inner connection.
            inner_conn = getattr(conn, 'connection', None)
            if inner_conn and hasattr(inner_conn, 'close'):
                await close_quietly(inner_conn, "TransactionWrapper内部连接")
            logger.info(f"处理TransactionWrapper连接: {self.connection_name}")

    # Tortoise is initialised by the application lifespan. Re-initialising at
    # runtime would break the TortoiseContext, so refuse before closing or
    # discarding anything instead of retrying in vain.
    if not getattr(Tortoise, '_inited', False):
        logger.error("Tortoise未初始化,无法刷新连接。请等待应用启动完成")
        return False

    retry_count, retry_delay = (3, 1.0) if fast_mode else (5, 2.0)

    for attempt in range(retry_count):
        try:
            start_time = time.time()
            logger.info(f"尝试刷新连接 {self.connection_name} (尝试 {attempt + 1}/{retry_count})")

            # 1. Best-effort close of the current client.
            try:
                await close_current_connection()
            except Exception as get_conn_error:
                logger.warning(f"获取当前连接时出错: {get_conn_error}")

            # 2. Drop it from storage so the next connections.get() creates a new one.
            try:
                connections.discard(self.connection_name)
                logger.info(f"已从连接存储中移除: {self.connection_name}")
            except Exception as discard_error:
                logger.warning(f"移除连接时出错: {discard_error}")

            # 3. Give the old connection time to shut down completely.
            wait_time = (0.5 + attempt * 0.5) if fast_mode else (1.0 + attempt * 1.0)
            logger.info(f"等待 {wait_time:.1f} 秒,确保连接完全关闭...")
            await asyncio.sleep(wait_time)

            # 4. Create and validate the new client.
            try:
                logger.info(f"尝试创建新连接: {self.connection_name}")
                new_conn = connections.get(self.connection_name)
                if new_conn:
                    logger.info(f"验证新连接: {self.connection_name}")
                    await new_conn.execute_query("SELECT 1")
                    elapsed_time = time.time() - start_time
                    logger.info(f"数据库连接已刷新并验证: {self.connection_name} (耗时: {elapsed_time:.2f}秒)")
                    return True
                logger.error("刷新数据库连接失败: 无法获取新连接")
            except Exception as get_conn_error:
                logger.error(f"获取新连接时出错: {get_conn_error}")
                if "Packet sequence number wrong" in str(get_conn_error):
                    logger.error("检测到数据包序列号错误,这是 MySQL 协议层问题")
                    logger.error("将进行更彻底的连接重置...")
                    # Extra pause before the next attempt for this protocol error.
                    await asyncio.sleep(1.0 if fast_mode else 2.0)
        except Exception as e:
            logger.error(f"刷新数据库连接失败: {e}")

        if attempt < retry_count - 1:
            logger.info(f"将在 {retry_delay} 秒后重试...")
            await asyncio.sleep(retry_delay)
            retry_delay *= 1.5  # exponential back-off

    logger.error(f"所有尝试都失败: 无法刷新数据库连接 {self.connection_name}")
    return False
|
||
|
||
|
||
@with_transaction
@handle_db_errors(max_retries=3)
async def update_by_index(
    self,
    model_class,
    index_dict: Dict[str, Any],
    new_values_dict: Dict[str, Any],
    not_found_behavior: Literal["insert", "error", "skip"] = "error",
    use_transaction: Optional[bool] = None
) -> Dict[str, Any]:
    """
    Update the record(s) matched by ``index_dict``; key (composite primary key)
    columns may themselves be updated.

    Args:
        model_class: Tortoise model class; its ``_meta.db_table`` is the target table.
        index_dict: Column -> current value mapping used to locate the record(s).
            Must not be empty. A ``None`` value matches SQL ``NULL`` (``IS NULL``).
        new_values_dict: Column -> new value mapping written by the UPDATE; may
            include key columns.
        not_found_behavior: Action when no row matches:
            "insert" inserts a row built from both dicts (``new_values_dict`` wins
            on overlapping columns), "error" raises ``ValueError``, "skip" does
            nothing (any other value is also treated as "skip").
        use_transaction: Whether to use a transaction (defaults to the instance
            setting). Not read in this body; presumably consumed by the
            ``with_transaction`` decorator — verify.

    Returns:
        Dict with ``success``, ``operation_type`` ("inserted", "skipped",
        "no_change" or "updated"), ``affected_rows``, ``index_dict``,
        ``updated_fields`` and ``execution_time`` (seconds).

    Raises:
        ValueError: If ``index_dict`` is empty, or if no row matches and
            ``not_found_behavior`` is "error".
        DbQueryError: When a database operation fails.
        DbConnectionError: When the database connection fails.
    """
    if not index_dict:
        # An empty index would produce a WHERE-less COUNT/UPDATE that touches
        # every row of the table; refuse before hitting the database.
        raise ValueError("index_dict 不能为空: 缺少索引条件将导致整表更新")

    def quote_ident(name: str) -> str:
        # MySQL identifier quoting: wrap in backticks and double any embedded
        # backtick so a crafted column name cannot escape the quotes.
        return "`" + str(name).replace("`", "``") + "`"

    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    start_time = time.perf_counter()

    # Check connection health first; refresh it if it is unhealthy.
    is_healthy = await self.check_connection_health(fast_mode=True)
    if not is_healthy:
        logger.warning(f"数据库连接不健康,将刷新连接: {self.connection_name}")
        await self.refresh_connection(fast_mode=True)

    table_name = model_class._meta.db_table
    quoted_table = quote_ident(table_name)
    conn = await self._get_valid_connection()

    # WHERE clause built from the current (old) key values. None must use
    # IS NULL because "col = NULL" is never true in SQL.
    where_parts = []
    where_values = []
    for field, value in index_dict.items():
        if value is None:
            where_parts.append(f"{quote_ident(field)} IS NULL")
        else:
            where_parts.append(f"{quote_ident(field)} = %s")
            where_values.append(value)
    where_clause = " WHERE " + " AND ".join(where_parts)

    # Does a matching record exist?
    check_sql = f"SELECT COUNT(*) as count FROM {quoted_table}{where_clause}"
    result = await conn.execute_query(check_sql, where_values)
    count = result[1][0]['count']

    if count == 0:
        if not_found_behavior == "error":
            raise ValueError(f"未找到匹配记录: {index_dict}")
        elif not_found_behavior == "insert":
            # Merge key columns with new values (new values win on overlap).
            # dict keeps insertion order, so the generated SQL is deterministic.
            row = {**index_dict, **new_values_dict}
            fields_str = ', '.join(quote_ident(field) for field in row)
            placeholders = ', '.join(['%s'] * len(row))
            insert_sql = f"INSERT INTO {quoted_table} ({fields_str}) VALUES ({placeholders})"
            affected_rows, _ = await self._execute_native_sql(
                insert_sql,
                list(row.values()),
                description="基于索引更新 - 新增记录"
            )
            operation_type = 'inserted'
        else:  # skip
            affected_rows = 0
            operation_type = 'skipped'
    else:
        if not new_values_dict:
            # Nothing to write: the record exists but no new values were given.
            affected_rows = 0
            operation_type = 'no_change'
        else:
            set_clause = " SET " + ", ".join(
                f"{quote_ident(field)} = %s" for field in new_values_dict
            )
            update_sql = f"UPDATE {quoted_table}{set_clause}{where_clause}"
            # Placeholder order: SET values first, then WHERE values.
            affected_rows, _ = await self._execute_native_sql(
                update_sql,
                list(new_values_dict.values()) + where_values,
                description="基于索引更新 - 更新记录"
            )
            operation_type = 'updated'

    execution_time = time.perf_counter() - start_time

    # Update per-instance statistics.
    self.stats['total_processed'] += 1
    self.stats['batches_executed'] += 1
    self.stats['last_execution_time'] = execution_time

    response = {
        "success": True,
        "operation_type": operation_type,
        "affected_rows": affected_rows,
        "index_dict": index_dict,
        "updated_fields": list(new_values_dict.keys()),
        "execution_time": execution_time
    }

    logger.success("索引更新", f"{table_name}@{self.connection_name}", f"影响{affected_rows}行")
    return response
|
||
|
||
|
||
# Lazily-initialised cache of DbManager instances, keyed by database name.
_db_managers = None


def get_db_managers():
    """
    Return the shared dict of database managers, creating it on first use.

    The same dict object is returned on every call, so statistics kept on each
    ``DbManager`` persist across callers. One ``DbManager`` is created per
    entry in ``MYAPS_DBSET_LIST``, keyed by that entry.

    The dict is published only after every manager has been constructed: if a
    ``DbManager(...)`` constructor raises, the cache stays ``None`` and the
    next call retries, instead of handing out a partially filled dict forever.
    """
    global _db_managers
    if _db_managers is None:
        # The comprehension is fully evaluated before the global is assigned.
        _db_managers = {db: DbManager(db) for db in MYAPS_DBSET_LIST}
    return _db_managers
|
||
|
||
# Backward-compatible accessor kept for older callers. Although it is named
# like a variable, this is a function; new code should call
# get_db_managers() directly.
def db_managers():
    """
    Return the shared database-manager dict (alias of ``get_db_managers()``).

    This only delegates to ``get_db_managers()``, so every call returns that
    function's cached dict object; it does not rebuild managers or pick
    connections per event loop.
    """
    managers = get_db_managers()
    return managers