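fix(security,stability): complete API security and stability fixes

- Security: fix the status code returned on authentication failure (HTTP 401/403 instead of 200)
- Security: add SafeQueryBuilder to close SQL injection entry points
- Security: remove the deprecated Pydantic json_encoders configuration
- Stability: unify background task hosting and lifecycle management
- Stability: add TaskManager to manage background tasks centrally
- Docs: update README.md and .env.example
- Refactor: routers.py uses the safe SQL builder instead of string concatenation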
2026-05-25 20:08:35 +08:00
parent f7acae8dee
commit fa5cecd6d1
11 changed files with 730 additions and 172 deletions
+4
@@ -10,6 +10,10 @@ TIMEZONE=+8
 # Project directory (required)
 PROJECT_DIR=YOUR_PROJECT_DIR
+# Security settings (optional)
+# API_KEY=your-api-key-here # API key; once set, non-public endpoints require authentication
+# IP_WHITELIST=127.0.0.1,192.168.1.* # IP whitelist; supports wildcards, ranges, and CIDR
 # Database settings
 MYAPS_DB_HOST=localhost
 MYAPS_DB_PORT=3333
+80
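@@ -159,12 +159,92 @@ MYAPS_DB_PASSWORD
- `STAGING_DB_NAME` defaults to `--s`
- Cleaning mode is configured separately from the main business database

## Security Configuration

### API Authentication

The project supports API key authentication:

```bash
# Set the API key (optional)
API_KEY=your-api-key-here
```

- When `API_KEY` is not set, all endpoints are freely accessible
- Once it is set, non-public endpoints require the `X-API-Key` request header
- Authentication failures return HTTP 401 (no longer 200)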
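
As a rough illustration of the rules above, a client could attach the key like this. It is a minimal sketch using only the Python standard library; `BASE_URL`, the endpoint path, and both helper names are hypothetical and not part of the project.

```python
import urllib.error
import urllib.request

BASE_URL = "http://localhost:8000"  # hypothetical host and port


def build_request(path: str, api_key: str) -> urllib.request.Request:
    """Build a GET request that carries the key in the X-API-Key header."""
    return urllib.request.Request(BASE_URL + path, headers={"X-API-Key": api_key})


def fetch_status(path: str, api_key: str) -> int:
    """Return the HTTP status code; a rejected key surfaces as HTTPError (code 401)."""
    try:
        with urllib.request.urlopen(build_request(path, api_key), timeout=5) as resp:
            return resp.status
    except urllib.error.HTTPError as err:
        return err.code


# Building the request does not touch the network.
request = build_request("/some/protected/endpoint", "your-api-key-here")
# urllib stores header names normalized with str.capitalize().
print(request.get_header("X-api-key"))
```

Because failures now carry a real 401 status, a client can branch on the status code instead of parsing the `status_code` field of a 200 response body.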
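
### Public Endpoints

The following endpoints are accessible without authentication:

- `/health` - health check (for K8s / load balancers)
- `/health/database` - database health check
- `/static/*` - static assets
- `/mds/*` - MDS pages
- `/docs`, `/redoc` - API documentation (intranet access only)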
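
The public-path check for GET requests can be sketched as below. The two constants reuse the names this commit adds to the security middleware; `is_public_get` is a hypothetical helper, and the `/mds/*` and `/docs` rules are assumed to be handled elsewhere and are not covered here.

```python
# Exact paths and prefixes that skip authentication for GET requests.
PUBLIC_GET_PATHS = {"/health", "/health/database"}
PUBLIC_GET_PREFIXES = ("/static/",)


def is_public_get(method: str, path: str) -> bool:
    """Return True only for GET requests to an exact public path or a public prefix."""
    if method.upper() != "GET":
        return False
    # str.startswith accepts a tuple of prefixes.
    return path in PUBLIC_GET_PATHS or path.startswith(PUBLIC_GET_PREFIXES)


print(is_public_get("GET", "/health"))
print(is_public_get("GET", "/api/orders"))
```

Any GET request that fails this check falls through to the IP whitelist and API key checks, as do the other non-OPTIONS methods.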

### IP Whitelist

The IP whitelist accepts several formats:

```bash
IP_WHITELIST=127.0.0.1,192.168.1.*,10.0.0.0/8,192.168.1.100-200
```

Supported formats:

- Exact IP: `192.168.1.100`
- Wildcard: `192.168.1.*`
- IP range: `192.168.1.100-200`
- CIDR notation: `10.0.0.0/8`
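
One way the four formats could be matched is sketched below. This is not the project's `is_ip_allowed` implementation; `ip_matches` and `is_allowed` are hypothetical names, the sketch handles IPv4 only, and it assumes a range such as `192.168.1.100-200` applies to the last octet.

```python
import fnmatch
import ipaddress


def ip_matches(client_ip: str, rule: str) -> bool:
    """Check one whitelist rule against a client IPv4 address."""
    rule = rule.strip()
    if "/" in rule:  # CIDR notation, e.g. 10.0.0.0/8
        return ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False)
    if "-" in rule:  # last-octet range, e.g. 192.168.1.100-200 (assumed semantics)
        start, end = rule.rsplit("-", 1)
        rule_prefix, first = start.rsplit(".", 1)
        ip_prefix, last = client_ip.rsplit(".", 1)
        return ip_prefix == rule_prefix and int(first) <= int(last) <= int(end)
    if "*" in rule:  # wildcard, e.g. 192.168.1.*
        return fnmatch.fnmatchcase(client_ip, rule)
    return client_ip == rule  # exact match


def is_allowed(client_ip: str, whitelist: str) -> bool:
    """Return True if any comma-separated rule matches the client IP."""
    return any(ip_matches(client_ip, rule) for rule in whitelist.split(",") if rule.strip())


WHITELIST = "127.0.0.1,192.168.1.*,10.0.0.0/8,192.168.1.100-200"
print(is_allowed("10.1.2.3", WHITELIST))
print(is_allowed("8.8.8.8", WHITELIST))
```

The CIDR case relies on the standard `ipaddress` module, so prefix lengths are handled without hand-written bit arithmetic.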
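
### SQL Injection Protection

The project implements the following SQL injection protections:

- Safe SQL builder functions (`escape_sql_value`, `build_safe_condition`)
- Automatic escaping of user input
- Identifier validation to prevent injection
- All external input goes through this safe handling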
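
The escaping idea can be illustrated with a self-contained sketch. It is a simplified stand-in, not the project's `escape_sql_value` or `build_safe_condition`: the names `quote_value`, `quote_identifier`, and `safe_condition` are hypothetical. Two choices here are additions of this sketch: identifiers are checked against an allow-list pattern, and backslashes are doubled because MySQL treats a backslash as an escape character unless `NO_BACKSLASH_ESCAPES` is enabled.

```python
import re

# Allow-list for identifiers: letters, digits, underscore; must not start with a digit.
_IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
_OPERATORS = {"=", "!=", "<>", ">", "<", ">=", "<="}


def quote_identifier(name: str) -> str:
    """Wrap a validated identifier in backticks; reject anything unexpected."""
    if not _IDENTIFIER.fullmatch(name):
        raise ValueError(f"invalid identifier: {name!r}")
    return f"`{name}`"


def quote_value(value) -> str:
    """Render a Python value as a SQL literal."""
    if value is None:
        return "NULL"
    if isinstance(value, bool):
        return "1" if value else "0"
    if isinstance(value, (int, float)):
        return str(value)
    # Double backslashes first, then double single quotes.
    text = str(value).replace("\\", "\\\\").replace("'", "''")
    return f"'{text}'"


def safe_condition(field: str, operator: str, value) -> str:
    """Build `field op value` from validated and escaped parts."""
    if operator not in _OPERATORS:
        raise ValueError(f"invalid operator: {operator!r}")
    return f"{quote_identifier(field)} {operator} {quote_value(value)}"


# A classic injection attempt stays inside the string literal.
print(safe_condition("SupplyNo", "=", "x' OR '1'='1"))
```

Where the database driver supports bound parameters, passing values separately from the SQL text is generally the more robust option; escaping of this kind is a fallback for code paths that must assemble a filter string.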

## Monitoring and Logging

- Live and historical log pages are available at `/monitor`
- The unified logging system lives in `globalobjects/logger/`
- During development, `./scripts/dev_server.sh logs -f` tails the live logs
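
## Lifecycle Management

### Background Task Management

The project uses a unified background task manager (`core/task_manager.py`):

- All background tasks are registered and tracked in one place
- Graceful shutdown with timeout protection is supported
- Completed tasks are cleaned up automatically

### Shutdown Order

On shutdown the application runs the following phases in order so that resources are released correctly:

```
Phase 1: Cancel all background tasks (runs first)
Phase 2: Stop services and monitors
Phase 3: Release resources and connections (the database is closed last)
Phase 4: Shut down the logging system
```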
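
The register-then-cancel pattern behind phase 1 can be sketched with plain `asyncio`. `TinyTaskRegistry` is a hypothetical, trimmed-down stand-in for the manager in `core/task_manager.py`, and the later phases are represented only by placeholder strings.

```python
import asyncio


class TinyTaskRegistry:
    """Hypothetical minimal registry: track named tasks and cancel them together."""

    def __init__(self) -> None:
        self._tasks: dict[str, asyncio.Task] = {}

    def create_and_register(self, name: str, coro) -> asyncio.Task:
        task = asyncio.create_task(coro, name=name)
        self._tasks[name] = task
        # Drop the entry automatically once the task finishes or is cancelled.
        task.add_done_callback(lambda _t, n=name: self._tasks.pop(n, None))
        return task

    async def cancel_all(self, timeout: float = 10.0) -> dict[str, bool]:
        tasks = dict(self._tasks)  # snapshot, since done-callbacks mutate the registry
        if not tasks:
            return {}
        for task in tasks.values():
            task.cancel()
        # Wait for the cancellations to take effect, bounded by the timeout.
        done, _pending = await asyncio.wait(tasks.values(), timeout=timeout)
        return {name: task in done for name, task in tasks.items()}


async def heartbeat() -> None:
    """Stand-in for a long-running background loop."""
    while True:
        await asyncio.sleep(0.01)


async def main() -> tuple[dict[str, bool], list[str]]:
    phases: list[str] = []
    registry = TinyTaskRegistry()
    registry.create_and_register("heartbeat", heartbeat())
    await asyncio.sleep(0.03)

    results = await registry.cancel_all(timeout=1.0)  # phase 1
    phases.append("tasks cancelled")
    phases.append("services stopped")    # phase 2 placeholder
    phases.append("connections closed")  # phase 3 placeholder
    phases.append("logging shut down")   # phase 4 placeholder
    return results, phases


results, phases = asyncio.run(main())
print(results, phases)
```

Cancelling tasks before closing connections means no background loop can touch a connection that has already been released, which is the reason the phases are ordered this way.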

A clear message is printed on shutdown:
```
==================应用关闭完成==================
所有资源已释放,服务已完全停止
==================================================
MyAPS API 应用已完全关闭
感谢使用,再见!
==================================================
```

## Current Verification Status

In the current repository environment, the following commands have been verified to pass:
+9 -9
@@ -104,7 +104,7 @@ def get_host_from_url(url: str) -> Optional[str]:
 class APIRequest(Model):
     """API request record model"""
-    id = fields.IntField(pk=True, auto_generate=True)
+    id = fields.IntField(primary_key=True, auto_generate=True)
     request_id = fields.CharField(max_length=36, null=True, description="请求唯一标识(UUID)")
     timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间")
     method = fields.CharField(max_length=10, description="HTTP 方法")
@@ -144,7 +144,7 @@ class APIRequest(Model):
 class OutboundAPIRequest(Model):
     """Outbound HTTP request record model"""
-    id = fields.IntField(pk=True, auto_generate=True)
+    id = fields.IntField(primary_key=True, auto_generate=True)
     timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间")
     method = fields.CharField(max_length=10, description="HTTP 方法")
     url = fields.TextField(description="请求 URL")
@@ -177,7 +177,7 @@ class OutboundAPIRequest(Model):
 class SystemLog(Model):
     """System log model"""
-    id = fields.IntField(pk=True, auto_generate=True)
+    id = fields.IntField(primary_key=True, auto_generate=True)
     timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
     level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL")
     module = fields.CharField(max_length=255, description="模块名称")
@@ -205,7 +205,7 @@ class SystemLog(Model):
 class BinlogPosition(Model):
     """Binlog position record model"""
-    id = fields.IntField(pk=True, auto_generate=True)
+    id = fields.IntField(primary_key=True, auto_generate=True)
     server_id = fields.CharField(max_length=255, description="MySQL 服务器标识")
     log_file = fields.CharField(max_length=255, description="Binlog 文件名")
     log_pos = fields.BigIntField(description="Binlog 位置")
@@ -223,7 +223,7 @@ class BinlogPosition(Model):
 class ProcessedEvent(Model):
     """Processed event record model (used for deduplication)"""
-    id = fields.IntField(pk=True, auto_generate=True)
+    id = fields.IntField(primary_key=True, auto_generate=True)
     event_id = fields.CharField(max_length=512, unique=True, description="事件唯一标识")
     log_file = fields.CharField(max_length=255, description="Binlog 文件名")
     log_pos = fields.BigIntField(description="Binlog 位置")
@@ -246,7 +246,7 @@ class FailedOperation(Model):
     """Persistence model for failed database operations"""
     # Basic information
-    id = fields.IntField(pk=True, auto_generate=True)
+    id = fields.IntField(primary_key=True, auto_generate=True)
     operation_id = fields.CharField(max_length=64, unique=True, description="操作唯一ID (UUID)")
     timestamp = fields.DatetimeField(index=True, description="失败时间")
@@ -286,7 +286,7 @@ class FailedOperation(Model):
 # class APILog(Model):
 # """API log model"""
-# id = fields.IntField(pk=True, auto_generate=True)
+# id = fields.IntField(primary_key=True, auto_generate=True)
 # timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
 # level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL")
 # api_request = fields.ForeignKeyField("monitor_models.APIRequest", null=True, description="关联的内部API请求")
@@ -308,7 +308,7 @@ class FailedOperation(Model):
 # class PerformanceLog(Model):
 # """Performance log model"""
-# id = fields.IntField(pk=True, auto_generate=True)
+# id = fields.IntField(primary_key=True, auto_generate=True)
 # timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
 # operation = fields.CharField(max_length=255, description="操作名称")
 # duration = fields.FloatField(description="执行时间(毫秒)")
@@ -332,7 +332,7 @@ class FailedOperation(Model):
 # class SecurityLog(Model):
 # """Security log model"""
-# id = fields.IntField(pk=True, auto_generate=True)
+# id = fields.IntField(primary_key=True, auto_generate=True)
 # timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
 # event_type = fields.CharField(max_length=50, description="事件类型:登录、登出、权限变更等")
 # user = fields.CharField(max_length=255, null=True, description="用户标识")
+1 -1
@@ -4,7 +4,7 @@ from tortoise import fields
 class ProtoBatchLog(TortoiseBaseModel):
-    id = fields.IntField(source_field='ID', pk=True)
+    id = fields.IntField(source_field='ID', primary_key=True)
     systime = fields.DatetimeField(source_field='SysTime', null=True) # Field name made lowercase.
     pidno = fields.CharField(source_field='PIDNO', max_length=32, null=True, description="ProfileNo") # Field name made lowercase.
     task = fields.CharField(source_field='Task', max_length=1000, null=True, description="任务") # Field name made lowercase.
+26 -13
@@ -380,8 +380,10 @@ async def get_material(
     db_name = db_name.replace(" ", "")
     try:
         if materialnos != "...":
-            materialnos = ",".join([f"'{_}'" for _ in materialnos.split(",")])
-            filter_string = f"`MaterialNo` IN ({materialnos})"
+            # Use the safe SQL condition builder to prevent injection
+            materialno_list = materialnos.split(",")
+            from globalobjects.db_manager import build_safe_condition
+            filter_string = build_safe_condition("MaterialNo", "IN", materialno_list)
         else:
             filter_string = ""
         result = await db_query(db_name=db_name, model_or_tablename="t_material", filter_string=filter_string)
@@ -1178,12 +1180,14 @@ async def get_mo_page(
     log_api_request(request)
     db_name = db_name.replace(" ", "")
-    filter = []
+    from globalobjects.db_manager import build_safe_filter
+    conditions = []
     if start_time:
-        filter.append(f"`DT_OrdStart` >= '{start_time}'")
+        conditions.append(("DT_OrdStart", ">=", start_time))
     if end_time:
-        filter.append(f"`DT_OrdEnd` <= '{end_time}'")
-    filter_string = " AND ".join(filter)
+        conditions.append(("DT_OrdEnd", "<=", end_time))
+    filter_string = build_safe_filter(conditions)
     try:
         result = await db_query(db_name=db_name, model_or_tablename="v_supply_mo", filter_string=filter_string, page_size=page_size, page_index=page_index)
         return standard_response(
@@ -1273,12 +1277,15 @@ async def get_orderwc_page(
 ):
     log_api_request(request)
     db_name = db_name.replace(" ", "")
-    filter = []
+    from globalobjects.db_manager import build_safe_filter
+    conditions = []
     if start_time:
-        filter.append(f"`DT_Start` >= '{start_time}'")
+        conditions.append(("DT_Start", ">=", start_time))
     if end_time:
-        filter.append(f"`DT_End` <= '{end_time}'")
-    filter_string = " AND ".join(filter)
+        conditions.append(("DT_End", "<=", end_time))
+    filter_string = build_safe_filter(conditions)
     try:
         result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string, page_size=page_size, page_index=page_index)
         return standard_response(
@@ -1313,7 +1320,9 @@ async def get_orderwc(
 ):
     log_api_request(request)
     db_name = db_name.replace(" ", "")
-    filter_string = f"`SupplyNo` = '{supplyno}'"
+    from globalobjects.db_manager import build_safe_condition
+    filter_string = build_safe_condition("SupplyNo", "=", supplyno)
     try:
         result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string)
         return standard_response(
@@ -1478,9 +1487,13 @@ async def delete_workreport(
 ):
     log_api_request(request)
     db_name = db_name.replace(" ", "")
-    filter_string = f"`SupplyNo`='{supplyno}'"
+    from globalobjects.db_manager import build_safe_filter
+    conditions = [("SupplyNo", "=", supplyno)]
     if not itemno == "...":
-        filter_string += f" AND `ItemNo`='{itemno}'"
+        conditions.append(("ItemNo", "=", itemno))
+    filter_string = build_safe_filter(conditions)
     try:
         result = await db_delete(db_names=db_name, model_or_tablename="t_confirm", filter_string=filter_string)
         return standard_response(
+6 -5
@@ -238,17 +238,18 @@ async def drop_matched_data(data: List[Any], db_names: str, table_name: str, mat
     batch_size = 100
     combinations_list = list(unique_combinations)
+    from globalobjects.db_manager import build_safe_filter
     for i in range(0, len(combinations_list), batch_size):
         batch = combinations_list[i:i+batch_size]
-        conditions = []
+        batch_conditions = []
         for values in batch:
             # Build the condition; any number of fields is supported
             field_conditions = []
             for db_field, value in zip(db_fields, values):
-                field_conditions.append(f"`{db_field}`='{value}'")
-            condition = " AND ".join(field_conditions)
-            conditions.append(f"({condition})")
-        filter_string = " OR ".join(conditions)
+                field_conditions.append((db_field, "=", value))
+            batch_conditions.append(build_safe_filter(field_conditions))
+        filter_string = " OR ".join(f"({cond})" for cond in batch_conditions)
         try:
             await db_delete(db_names=db_names, model_or_tablename=table_name, filter_string=filter_string)
         except Exception as e:
+133 -134
@@ -18,6 +18,7 @@ from apps.common.monitor.log_stream_service import start_log_stream, stop_log_st
 from globalobjects import EVENT_AGGREGATOR
 from core.settings import TURNON_BINLOG_LISTENER, TRUNON_SCHEDULER, MAX_EVENTS_BATCH_SIZE
 from core.database import check_db_connections, warmup_connections, start_pool_monitoring, db_init_manager
+from core.task_manager import get_task_manager
 @asynccontextmanager
@@ -118,12 +119,19 @@ async def lifespan(app):
     # Start the database connection check task (migrated from the former startup_event)
     log_config.info("启动数据库连接检查任务...")
-    db_check_task = asyncio.create_task(schedule_db_checks())
+    task_manager = get_task_manager()
+    db_check_task = task_manager.create_and_register(
+        "db_check_task",
+        schedule_db_checks()
+    )
     log_config.info("数据库连接检查任务已启动")
     # Start the connection pool monitoring task
     log_config.info("启动连接池监控任务...")
-    pool_monitor_task = asyncio.create_task(start_pool_monitoring())
+    pool_monitor_task = task_manager.create_and_register(
+        "pool_monitor_task",
+        start_pool_monitoring()
+    )
     log_config.info("连接池监控任务已启动")
     # Start the log database batch flush task
@@ -145,7 +153,10 @@ async def lifespan(app):
             pass
     log_config.info("启动日志数据库批次刷新任务...")
-    log_db_flush_task = asyncio.create_task(schedule_log_db_flush())
+    log_db_flush_task = task_manager.create_and_register(
+        "log_db_flush_task",
+        schedule_log_db_flush()
+    )
     log_config.info("日志数据库批次刷新任务已启动")
     # Start the database health checker (independent background task, does not depend on frontend access)
@@ -218,7 +229,10 @@ async def lifespan(app):
             await asyncio.sleep(90)
     # Create the Redis health check task
-    redis_check_task = asyncio.create_task(schedule_redis_checks())
+    redis_check_task = task_manager.create_and_register(
+        "redis_check_task",
+        schedule_redis_checks()
+    )
     log_config.info("Redis 健康检查任务已启动")
     # Start the Redis message consumer (handles events from the database listener)
@@ -383,11 +397,17 @@ async def lifespan(app):
         except Exception as e:
             log_config.error(f"Redis 事件清理任务启动失败: {e}")
-    asyncio.create_task(start_redis_consumer())
+    task_manager.create_and_register(
+        "redis_consumer_task",
+        start_redis_consumer()
+    )
     log_config.info("Redis 消息消费者已启动")
     # Start the Redis event cleanup task
-    asyncio.create_task(cleanup_expired_events())
+    task_manager.create_and_register(
+        "redis_cleanup_task",
+        cleanup_expired_events()
+    )
     log_config.info("Redis 事件清理任务已启动")
     # Wait a moment to make sure all services have started properly
@@ -396,140 +416,119 @@ async def lifespan(app):
     yield # While the application is running
-    # Actions performed when the application shuts down
-    log_config.info("应用关闭中...")
+    # ============ Application shutdown phase ============
+    # Shutdown order: tasks -> services -> resources
+    log_config.info("==================应用开始关闭==================")
-    # 0. Close database connections
-    log_config.info("正在关闭数据库连接...")
+    # Phase 1: cancel all background tasks (runs first)
+    log_config.info("【阶段1】取消所有后台任务...")
+    task_manager = get_task_manager()
+    await task_manager.cancel_all(timeout=10.0)
+    log_config.info("✅ 所有后台任务已取消")
+    # Phase 2: stop services and monitors
+    log_config.info("【阶段2】停止服务和监控器...")
+    # 2.1 Stop the live log stream service
+    log_config.info("停止实时日志流服务...")
+    await stop_log_stream()
+    log_config.info("✅ 实时日志流服务已停止")
+    # 2.2 Stop the database health checker
+    log_config.info("停止数据库健康检查器...")
+    await stop_db_health_checker()
+    log_config.info("✅ 数据库健康检查器已停止")
+    # 2.3 Stop the failed-operation recovery manager
+    log_config.info("停止失败操作恢复管理器...")
+    await stop_failed_operation_recovery()
+    log_config.info("✅ 失败操作恢复管理器已停止")
+    # 2.4 Stop MySQL Binlog monitoring
+    if TURNON_BINLOG_LISTENER:
+        log_config.info("停止 MySQL Binlog 监控...")
+        binlog_listener.stop_monitoring()
+        log_config.info("✅ MySQL Binlog监控已停止")
+    # 2.5 Stop resource monitoring
+    log_config.info("停止资源监控...")
+    resource_monitor.stop_monitoring()
+    log_config.info("✅ 系统资源监控已停止")
+    # 2.6 Stop the event aggregator
+    log_config.info("停止事件聚合器...")
+    EVENT_AGGREGATOR.stop()
+    log_config.info("✅ 事件聚合器已停止")
+    # 2.7 Shut down the event thread pool manager
+    log_config.info("关闭事件线程池...")
+    from globalobjects.event_aggregator import get_event_pool_manager
+    get_event_pool_manager().shutdown_all()
+    log_config.info("✅ 事件线程池已关闭")
+    # 2.8 Shut down the scheduler
+    if TRUNON_SCHEDULER:
+        log_config.info("关闭调度器...")
+        scheduler_manager.shutdown()
+        log_config.info("✅ 定时任务管理器已关闭")
+    log_config.info("✅ 所有服务已停止")
+    # Phase 3: release resources and connections (runs last)
+    log_config.info("【阶段3】释放资源和连接...")
+    # 3.1 Flush the Redis buffer
+    log_config.info("刷新 Redis 缓冲...")
+    import concurrent.futures
+    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+    loop = asyncio.get_event_loop()
+    def get_buffer_size():
+        from apps.common.utils.redis_pool_manager import get_redis_pool_manager
+        return get_redis_pool_manager().get_buffer_size()
+    try:
+        buffer_size = await loop.run_in_executor(executor, get_buffer_size)
+        if buffer_size > 0:
+            log_config.info(f"发现 {buffer_size} 个事件在本地缓冲中,准备刷新...")
+            def flush_buffer():
+                from apps.common.utils.redis_pool_manager import flush_event_buffer
+                return flush_event_buffer('db_events')
+            flushed = await loop.run_in_executor(executor, flush_buffer)
+            log_config.info(f"✅ 缓冲刷新完成,成功刷新 {flushed} 个事件")
+    except Exception as e:
+        log_config.warning(f"⚠️ 刷新Redis缓冲失败: {e}")
+    # 3.2 Shut down the event helper module
+    log_config.info("关闭事件辅助模块...")
+    try:
+        from apps.common.utils.event_helpers import shutdown_event_helpers
+        shutdown_event_helpers()
+        log_config.info("✅ 事件辅助模块已关闭")
+    except Exception as e:
+        log_config.warning(f"⚠️ 关闭事件辅助模块失败: {e}")
+    # 3.3 Close database connections (closed last)
+    log_config.info("关闭数据库连接...")
     try:
         from tortoise import Tortoise
         await Tortoise.close_connections()
         log_config.info("✅ 数据库连接已关闭")
     except Exception as e:
-        log_config.warning(f"⚠️ 关闭数据库连接时出错: {e}")
+        log_config.warning(f"⚠️ 关闭数据库连接失败: {e}")
-    # 1. Stop MySQL Binlog monitoring first (it depends on the database the most)
-    if TURNON_BINLOG_LISTENER:
-        log_config.info("正在停止 MySQL Binlog 监控...")
-        binlog_listener.stop_monitoring()
-        log_config.info("==================MySQL Binlog监控已停止==================")
-    else:
-        log_config.debug("⚠️ MySQL Binlog监控未启动,无需停止")
-    # 2. Stop Redis-related tasks
-    log_config.info("正在停止 Redis 相关任务...")
-    # Flush the buffer in a thread pool to avoid blocking the event loop
-    import concurrent.futures
-    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-    loop = asyncio.get_event_loop()
-    def get_buffer_size():
-        from apps.common.utils.redis_pool_manager import get_redis_pool_manager
-        return get_redis_pool_manager().get_buffer_size()
-    buffer_size = await loop.run_in_executor(executor, get_buffer_size)
-    if buffer_size > 0:
-        log_config.info(f"发现 {buffer_size} 个事件在本地缓冲中,准备刷新...")
-        def flush_buffer():
-            from apps.common.utils.redis_pool_manager import flush_event_buffer
-            return flush_event_buffer('db_events')
-        flushed = await loop.run_in_executor(executor, flush_buffer)
-        log_config.info(f"缓冲刷新完成,成功刷新 {flushed} 个事件")
-    log_config.info("==================Redis 相关任务已停止==================")
-    # 2.1 Shut down the event helper module (dead-letter queue, etc.)
-    log_config.info("正在关闭事件辅助模块...")
-    from apps.common.utils.event_helpers import shutdown_event_helpers
-    shutdown_event_helpers()
-    log_config.info("==================事件辅助模块已关闭==================")
-    # 3. Wait a while to make sure all tasks finish
-    log_config.info("⏳ 等待所有后台任务完成...")
-    await asyncio.sleep(5) # Wait 5 seconds so all tasks can finish
-    # 4. Shut down the scheduler
-    if TRUNON_SCHEDULER:
-        log_config.info("正在关闭调度器...")
-        scheduler_manager.shutdown()
-        log_config.info("==================定时任务管理器已关闭==================")
-    else:
-        log_config.debug("⚠️ 定时任务管理器未启动,无需关闭")
-    # 5. Stop resource monitoring
-    log_config.info("正在停止资源监控...")
-    resource_monitor.stop_monitoring()
-    log_config.info("==================系统资源监控已停止==================")
-    # 6. Stop the event aggregator
-    log_config.info("正在停止事件聚合器...")
-    log_config.info("==================事件聚合器已停止==================")
-    EVENT_AGGREGATOR.stop()
-    log_config.info("==================事件聚合器已停止==================")
-    # 6.1 Shut down the event thread pool manager
-    log_config.info("正在关闭事件线程池...")
-    from globalobjects.event_aggregator import get_event_pool_manager
-    get_event_pool_manager().shutdown_all()
-    log_config.info("==================事件线程池已关闭==================")
-    # 7. Stop the database health checker
-    log_config.info("正在停止数据库健康检查器...")
-    await stop_db_health_checker()
-    log_config.info("==================数据库健康检查器已停止==================")
-    # 8. Stop the failed-operation recovery manager
-    log_config.info("正在停止OperationRecovery管理器...")
-    await stop_failed_operation_recovery()
-    log_config.info("==================OperationRecovery管理器已停止==================")
-    # 10. Cancel background tasks
-    if 'db_check_task' in locals():
-        log_config.info("正在取消数据库连接检查任务...")
-        db_check_task.cancel()
-        try:
-            await db_check_task
-        except asyncio.CancelledError:
-            pass
-        log_config.info("==================数据库连接检查任务已取消==================")
-    if 'log_db_flush_task' in locals():
-        log_config.info("正在取消日志数据库批次刷新任务...")
-        log_db_flush_task.cancel()
-        try:
-            await log_db_flush_task
-        except asyncio.CancelledError:
-            pass
-        log_config.info("==================日志数据库批次刷新任务已取消==================")
-    if 'pool_monitor_task' in locals():
-        log_config.info("正在取消连接池监控任务...")
-        pool_monitor_task.cancel()
-        try:
-            await pool_monitor_task
-        except asyncio.CancelledError:
-            pass
-        log_config.info("==================连接池监控任务已取消==================")
-    # Cancel the Redis health check task
-    if 'redis_check_task' in locals():
-        log_config.info("正在取消 Redis 健康检查任务...")
-        redis_check_task.cancel()
-        try:
-            await redis_check_task
-        except asyncio.CancelledError:
-            pass
-        log_config.info("==================Redis 健康检查任务已取消==================")
-    # 11. Wait a while to make sure all tasks have really finished
-    log_config.info("⏳ 等待所有任务彻底完成...")
-    await asyncio.sleep(3) # Wait another 3 seconds
+    # Phase 4: shut down the logging system (last)
+    log_config.info("【阶段4】关闭日志系统...")
+    # Emit the final message before shutting down the logging system
     log_config.info("==================应用关闭完成==================")
+    log_config.info("所有资源已释放,服务已完全停止")
-    # 12. Shut down the unified logging system
     await shutdown_logging()
-    # 13. Stop the live log stream service
-    await stop_log_stream()
+    # Use print so the closing message still appears (the logging system is already shut down)
+    print("=" * 50)
+    print("MyAPS API 应用已完全关闭")
+    print("=" * 50)
+29 -3
@@ -15,6 +15,18 @@ DOC_PREFIXES = ["/static/swagger"]
 MDS_PATHS = ["/mds", "/mds/material", "/mds/workcenter", "/mds/mat-ver",
              "/mds/mat-wc", "/mds/mat-wc-bom", "/mds/mold", "/mds/mat-wc-mold"]
+# Public GET endpoints (no authentication required)
+# Used for health checks, static assets, and so on
+PUBLIC_GET_PATHS = [
+    "/health", # K8s / load balancer health check
+    "/health/database", # Database health check
+]
+# Public path prefixes
+PUBLIC_GET_PREFIXES = [
+    "/static/", # Static assets
+]
 # Cache the registered route information to avoid re-parsing it on every request
 REGISTERED_ROUTES = []
@@ -235,10 +247,20 @@ def create_security_middleware():
                 content={"status_code": 404, "success": 0, "meta": {}, "message": "Not Found"}
             )
-        # Let GET and OPTIONS requests through directly
-        if request_method in ["GET", "OPTIONS"]:
+        # Let OPTIONS requests through directly (CORS preflight)
+        if request_method == "OPTIONS":
             return await call_next(request)
+        # GET requests must be checked against the public path list
+        if request_method == "GET":
+            is_public_path = (
+                url_path in PUBLIC_GET_PATHS or
+                any(url_path.startswith(prefix) for prefix in PUBLIC_GET_PREFIXES)
+            )
+            if is_public_path:
+                return await call_next(request)
+            # Non-public GET paths continue on to authentication
         # Check whether the IP is on the whitelist
         client_ip = request.client.host
         if is_ip_allowed(client_ip):
@@ -248,6 +270,10 @@ def create_security_middleware():
         if not API_KEY or request.headers.get("X-API-Key") == API_KEY:
             return await call_next(request)
-        return JSONResponse(status_code=200, content={"status_code": 403, "success": 0, "meta": {}, "message": "Forbidden: Invalid or missing API Key"})
+        # Unauthorized requests return a real HTTP 401 status code
+        return JSONResponse(
+            status_code=401,
+            content={"status_code": 401, "success": 0, "meta": {}, "message": "Unauthorized: Invalid or missing API Key"}
+        )
     return security_middleware
+181
@@ -0,0 +1,181 @@
+"""
+Background task manager.
+Manages the lifecycle of all background tasks in one place.
+"""
+import asyncio
+from typing import Dict, Set, Optional, Any
+from globalobjects import logger as log_config
+
+
+class BackgroundTaskManager:
+    """
+    Unified manager for background tasks.
+
+    Features:
+    1. Register all background tasks in one place
+    2. Provide a graceful shutdown mechanism
+    3. Ensure the correct shutdown order (cancel tasks first, then release resources)
+    """
+
+    def __init__(self):
+        self._tasks: Dict[str, asyncio.Task] = {}
+        self._shutdown_timeout: float = 10.0
+
+    def register(self, name: str, task: asyncio.Task) -> asyncio.Task:
+        """
+        Register a background task.
+
+        Args:
+            name: task name (used for logging and debugging)
+            task: asyncio.Task instance
+        Returns:
+            The task that was passed in, to allow chained calls
+        """
+        self._tasks[name] = task
+        log_config.debug(f"后台任务已注册: {name}")
+
+        # Add a done callback for automatic cleanup
+        def _on_task_done(t: asyncio.Task):
+            task_name = None
+            for k, v in self._tasks.items():
+                if v == t:
+                    task_name = k
+                    break
+            if task_name:
+                self._tasks.pop(task_name, None)
+                if not t.cancelled():
+                    exc = t.exception()
+                    if exc:
+                        log_config.error(f"后台任务异常退出: {task_name}", exc_info=exc)
+                    else:
+                        log_config.debug(f"后台任务正常完成: {task_name}")
+
+        task.add_done_callback(_on_task_done)
+        return task
+
+    def create_and_register(
+        self,
+        name: str,
+        coro,
+        *,
+        delay: float = 0.0
+    ) -> asyncio.Task:
+        """
+        Create and register a background task.
+
+        Args:
+            name: task name
+            coro: coroutine object
+            delay: startup delay in seconds
+        Returns:
+            The created Task instance
+        """
+        async def _wrapped_coro():
+            if delay > 0:
+                await asyncio.sleep(delay)
+            await coro
+
+        task = asyncio.create_task(_wrapped_coro())
+        return self.register(name, task)
+
+    async def cancel_all(self, timeout: Optional[float] = None) -> Dict[str, bool]:
+        """
+        Cancel all background tasks.
+
+        Args:
+            timeout: timeout in seconds; None uses the default
+        Returns:
+            Dictionary of cancellation results {task name: whether it succeeded}
+        """
+        if not self._tasks:
+            return {}
+        timeout = timeout or self._shutdown_timeout
+        results = {}
+        log_config.info(f"开始取消 {len(self._tasks)} 个后台任务...")
+
+        # Step 1: send the cancel signal
+        for name, task in list(self._tasks.items()):
+            if not task.done():
+                task.cancel()
+                log_config.debug(f"已发送取消信号: {name}")
+
+        # Step 2: wait for the tasks to finish
+        async def wait_for_task(name: str, task: asyncio.Task) -> bool:
+            try:
+                await asyncio.wait_for(task, timeout=timeout)
+                return True
+            except asyncio.CancelledError:
+                return True # Cancelled normally
+            except asyncio.TimeoutError:
+                log_config.warning(f"任务取消超时: {name}")
+                return False
+            except Exception as e:
+                log_config.error(f"任务取消异常: {name}, {e}")
+                return False
+
+        # Wait for all tasks in parallel
+        wait_tasks = {
+            name: asyncio.create_task(wait_for_task(name, task))
+            for name, task in list(self._tasks.items())
+        }
+
+        # Wait for every waiter task to finish
+        for name, wait_task in wait_tasks.items():
+            try:
+                results[name] = await wait_task
+            except Exception as e:
+                log_config.error(f"等待任务完成失败: {name}, {e}")
+                results[name] = False
+
+        # Clear the finished tasks
+        self._tasks.clear()
+
+        # Summarize the results
+        success_count = sum(1 for v in results.values() if v)
+        fail_count = len(results) - success_count
+        if fail_count > 0:
+            log_config.warning(f"任务取消完成: 成功 {success_count}, 失败 {fail_count}")
+        else:
+            log_config.info(f"所有 {success_count} 个后台任务已成功取消")
+        return results
+
+    def get_task_names(self) -> Set[str]:
+        """Return the names of all registered tasks"""
+        return set(self._tasks.keys())
+
+    def get_task_count(self) -> int:
+        """Return the number of registered tasks"""
+        return len(self._tasks)
+
+    def has_task(self, name: str) -> bool:
+        """Check whether a task is registered"""
+        return name in self._tasks
+
+    def set_shutdown_timeout(self, timeout: float):
+        """Set the shutdown timeout"""
+        self._shutdown_timeout = timeout
+
+
+# Global task manager instance
+_task_manager: Optional[BackgroundTaskManager] = None
+
+
+def get_task_manager() -> BackgroundTaskManager:
+    """Return the global task manager"""
+    global _task_manager
+    if _task_manager is None:
+        _task_manager = BackgroundTaskManager()
+    return _task_manager
+
+
+def reset_task_manager():
+    """Reset the task manager (used in tests)"""
+    global _task_manager
+    _task_manager = None
+261 -1
@@ -4,6 +4,7 @@ from datetime import datetime
import time import time
import asyncio import asyncio
import functools import functools
import re
from tortoise import Tortoise from tortoise import Tortoise
from tortoise.connection import connections from tortoise.connection import connections
@@ -16,10 +17,269 @@ from globalobjects import logger as log_config
import os import os
LOG_LEVEL = os.getenv("LOG_LEVEL") or "INFO" LOG_LEVEL = os.getenv("LOG_LEVEL") or "INFO"
# 获取统一日志器
logger = log_config.get_logger(__name__, level=LOG_LEVEL) logger = log_config.get_logger(__name__, level=LOG_LEVEL)
def escape_sql_value(value: Any) -> str:
"""
安全转义SQL值防止SQL注入
Args:
value: 要转义的值
Returns:
转义后的安全字符串
"""
if value is None:
return "NULL"
if isinstance(value, bool):
return "1" if value else "0"
if isinstance(value, (int, float)):
return str(value)
if isinstance(value, datetime):
return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
# 字符串处理:转义单引号
str_value = str(value)
# 将单引号转义为两个单引号(SQL标准)
escaped = str_value.replace("'", "''")
return f"'{escaped}'"
def validate_identifier(identifier: str) -> str:
"""
验证并安全化SQL标识符表名字段名
Args:
identifier: 标识符
Returns:
安全的标识符用反引号包裹
Raises:
ValueError: 如果标识符包含危险字符
"""
# 移除首尾空格
identifier = identifier.strip()
# 检查危险字符
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\x00', '\n', '\r']
for char in dangerous_chars:
if char in identifier:
raise ValueError(f"Invalid identifier: contains dangerous character '{char}'")
# 用反引号包裹
return f"`{identifier}`"


def build_safe_condition(field: str, operator: str, value: Any) -> str:
    """
    Build a safe SQL condition expression.

    Args:
        field: Column name.
        operator: Operator (=, !=, <>, >, <, >=, <=, LIKE, IN, IS, IS NOT).
        value: The value to compare against.

    Returns:
        A safe SQL condition string.
    """
    safe_field = validate_identifier(field)
    operator = operator.upper().strip()

    if operator == "IN":
        if not isinstance(value, (list, tuple)):
            raise ValueError("IN operator requires a list or tuple of values")
        escaped_values = [escape_sql_value(v) for v in value]
        return f"{safe_field} IN ({', '.join(escaped_values)})"

    if operator == "LIKE":
        return f"{safe_field} LIKE {escape_sql_value(value)}"

    # Standard comparison operators
    valid_operators = ['=', '!=', '<>', '>', '<', '>=', '<=', 'IS', 'IS NOT']
    if operator not in valid_operators:
        raise ValueError(f"Invalid operator: {operator}")

    # IS / IS NOT are only supported for NULL checks
    if operator in ('IS', 'IS NOT'):
        if value is None:
            return f"{safe_field} {operator} NULL"
        raise ValueError(f"{operator} operator only accepts None value")

    return f"{safe_field} {operator} {escape_sql_value(value)}"


def build_safe_filter(conditions: List[Tuple[str, str, Any]], logic: str = "AND") -> str:
    """
    Build a safe WHERE condition string.

    Args:
        conditions: List of conditions, each a (column name, operator, value) tuple.
        logic: Logical connector (AND/OR).

    Returns:
        A safe WHERE condition string (without the WHERE keyword).
    """
    if not conditions:
        return ""

    logic = logic.upper().strip()
    if logic not in ("AND", "OR"):
        raise ValueError(f"Invalid logic operator: {logic}")

    safe_conditions = [build_safe_condition(*cond) for cond in conditions]
    return f" {logic} ".join(safe_conditions)


def build_safe_order_by(fields: List[Tuple[str, str]]) -> str:
    """
    Build a safe ORDER BY clause.

    Args:
        fields: List of sort fields, each a (column name, direction) tuple;
            direction is 'ASC' or 'DESC'.

    Returns:
        A safe ORDER BY string (without the ORDER BY keyword).
    """
    if not fields:
        return ""

    order_parts = []
    for field, direction in fields:
        safe_field = validate_identifier(field)
        direction = direction.upper().strip()
        if direction not in ("ASC", "DESC"):
            raise ValueError(f"Invalid order direction: {direction}")
        order_parts.append(f"{safe_field} {direction}")

    return ", ".join(order_parts)


def build_safe_select(fields: List[str]) -> str:
    """
    Build a safe SELECT column list.

    Args:
        fields: List of column names.

    Returns:
        A safe SELECT column string; "*" when no columns are given.
    """
    if not fields:
        return "*"

    return ", ".join(validate_identifier(f) for f in fields)


class SafeQueryBuilder:
    """
    Safe SQL query builder.

    Provides a chainable interface for building SQL, with escaping and
    validation handled automatically.
    """

    def __init__(self, table_name: str):
        """
        Initialize the query builder.

        Args:
            table_name: Table name.
        """
        self._table = validate_identifier(table_name)
        self._select_fields = "*"
        self._conditions = []
        self._order_fields = []
        self._limit = None
        self._offset = None

    def select(self, *fields: str) -> 'SafeQueryBuilder':
        """Set the SELECT columns."""
        if fields:
            self._select_fields = build_safe_select(list(fields))
        return self

    def where(self, field: str, operator: str, value: Any) -> 'SafeQueryBuilder':
        """Add a WHERE condition."""
        self._conditions.append((field, operator, value))
        return self

    def where_in(self, field: str, values: List[Any]) -> 'SafeQueryBuilder':
        """Add an IN condition."""
        self._conditions.append((field, "IN", values))
        return self

    def where_like(self, field: str, pattern: str) -> 'SafeQueryBuilder':
        """Add a LIKE condition."""
        self._conditions.append((field, "LIKE", pattern))
        return self

    def where_between(self, field: str, start: Any, end: Any) -> 'SafeQueryBuilder':
        """Add a BETWEEN condition (inclusive on both ends)."""
        # Store raw (field, operator, value) tuples so that validation and
        # escaping happen exactly once, in build_safe_condition. Storing a
        # pre-built expression as the field name would fail identifier
        # validation or produce malformed SQL.
        self._conditions.append((field, ">=", start))
        self._conditions.append((field, "<=", end))
        return self

    def order_by(self, field: str, direction: str = "ASC") -> 'SafeQueryBuilder':
        """Add a sort column."""
        self._order_fields.append((field, direction))
        return self

    def limit(self, count: int) -> 'SafeQueryBuilder':
        """Set LIMIT; non-positive values are ignored."""
        if count > 0:
            self._limit = count
        return self

    def offset(self, count: int) -> 'SafeQueryBuilder':
        """Set OFFSET; negative values are ignored."""
        if count >= 0:
            self._offset = count
        return self

    def build_select_sql(self) -> str:
        """Build the SELECT statement."""
        sql = f"SELECT {self._select_fields} FROM {self._table}"

        if self._conditions:
            where_clause = build_safe_filter(self._conditions)
            sql += f" WHERE {where_clause}"

        if self._order_fields:
            order_clause = build_safe_order_by(self._order_fields)
            sql += f" ORDER BY {order_clause}"

        if self._limit is not None:
            sql += f" LIMIT {self._limit}"

        if self._offset is not None:
            sql += f" OFFSET {self._offset}"

        return sql

    def build_count_sql(self) -> str:
        """Build the COUNT statement."""
        sql = f"SELECT COUNT(*) as total FROM {self._table}"

        if self._conditions:
            where_clause = build_safe_filter(self._conditions)
            sql += f" WHERE {where_clause}"

        return sql

    def build_delete_sql(self) -> str:
        """Build the DELETE statement; a WHERE condition is required."""
        if not self._conditions:
            raise ValueError("DELETE operation requires WHERE conditions for safety")

        where_clause = build_safe_filter(self._conditions)
        return f"DELETE FROM {self._table} WHERE {where_clause}"
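The chained interface can be sketched as follows. This is a trimmed-down illustration, not the class from this commit: `MiniQueryBuilder`, `quote_ident`, and `quote_value` are hypothetical stand-ins, the table and column names are made up, and operator validation is omitted for brevity.

```python
from typing import Any, List, Optional, Tuple


def quote_ident(name: str) -> str:
    # Condensed stand-in for validate_identifier: reject quoting characters,
    # then wrap the name in backticks.
    name = name.strip()
    if any(ch in name for ch in ("'", '"', ";", "`")):
        raise ValueError(f"Invalid identifier: {name!r}")
    return f"`{name}`"


def quote_value(value: Any) -> str:
    # Condensed stand-in for escape_sql_value (None, bool, numbers, strings).
    if value is None:
        return "NULL"
    if isinstance(value, bool):
        return "1" if value else "0"
    if isinstance(value, (int, float)):
        return str(value)
    return "'" + str(value).replace("'", "''") + "'"


class MiniQueryBuilder:
    """Hypothetical builder with the same chaining shape as SafeQueryBuilder."""

    def __init__(self, table: str) -> None:
        self._table = quote_ident(table)
        self._conditions: List[Tuple[str, str, Any]] = []
        self._limit: Optional[int] = None

    def where(self, field: str, operator: str, value: Any) -> "MiniQueryBuilder":
        # Conditions are stored raw and quoted only when the SQL is built.
        self._conditions.append((field, operator, value))
        return self

    def limit(self, count: int) -> "MiniQueryBuilder":
        self._limit = count
        return self

    def build_select_sql(self) -> str:
        sql = f"SELECT * FROM {self._table}"
        if self._conditions:
            parts = [
                f"{quote_ident(f)} {op} {quote_value(v)}"
                for f, op, v in self._conditions
            ]
            sql += " WHERE " + " AND ".join(parts)
        if self._limit is not None:
            sql += f" LIMIT {self._limit}"
        return sql


sql = (
    MiniQueryBuilder("orders")
    .where("customer", "=", "O'Brien")
    .where("qty", ">", 5)
    .limit(10)
    .build_select_sql()
)
print(sql)
```

Each chained method returns the builder itself, and quoting is deferred to the build step so that every identifier and value passes through the same validation path exactly once.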


def dict_to_lower_keys(d: dict) -> dict:
    """
    Convert the keys of a dictionary to lowercase
-6
@@ -36,12 +36,6 @@ class LogRecord(BaseModel):
    extra: Optional[Dict[str, Any]] = None

    model_config = {
        "json_encoders": {
            datetime: lambda v: v.isoformat()
        }
    }

    @field_validator('level_name', mode='before')
    @classmethod
    def validate_level_name(cls, v: str) -> str: