From fa5cecd6d1c202072d65f246e9fe7c30f3c96879 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B6=85=E5=93=A5?= <2982212683@qq.com> Date: Mon, 25 May 2026 20:08:35 +0800 Subject: [PATCH] =?UTF-8?q?fix(security,stability):=20=E5=AE=8C=E6=88=90AP?= =?UTF-8?q?I=E5=AE=89=E5=85=A8=E4=B8=8E=E7=A8=B3=E5=AE=9A=E6=80=A7?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 安全: 修复鉴权失败返回码(HTTP 401/403替代200) - 安全: 新增SafeQueryBuilder封堵SQL注入入口 - 安全: 移除Pydantic json_encoders弃用配置 - 稳定: 统一后台任务托管与生命周期管理 - 稳定: 新增TaskManager统一管理后台任务 - 文档: 更新README.md与.env.example - 重构: routers.py使用安全SQL构建器替代字符串拼接 --- .env.example | 4 + README.md | 80 ++++++++++ apps/common/monitor/models.py | 18 +-- apps/io_api/protomodels.py | 2 +- apps/io_api/routers.py | 39 +++-- apps/io_api/utils/common.py | 11 +- core/lifespan.py | 267 ++++++++++++++++----------------- core/middleware.py | 32 +++- core/task_manager.py | 181 ++++++++++++++++++++++ globalobjects/db_manager.py | 262 +++++++++++++++++++++++++++++++- globalobjects/logger/models.py | 6 - 11 files changed, 730 insertions(+), 172 deletions(-) create mode 100644 core/task_manager.py diff --git a/.env.example b/.env.example index bba802b..0b751ad 100644 --- a/.env.example +++ b/.env.example @@ -10,6 +10,10 @@ TIMEZONE=+8 # 项目目录配置(必填) PROJECT_DIR=YOUR_PROJECT_DIR +# 安全配置(可选) +# API_KEY=your-api-key-here # API密钥,设置后非公开接口需要认证 +# IP_WHITELIST=127.0.0.1,192.168.1.* # IP白名单,支持通配符、范围和CIDR + # 数据库配置 MYAPS_DB_HOST=localhost MYAPS_DB_PORT=3333 diff --git a/README.md b/README.md index ad9e89b..f082e01 100644 --- a/README.md +++ b/README.md @@ -159,12 +159,92 @@ MYAPS_DB_PASSWORD - `STAGING_DB_NAME` 默认为 `--s` - 清洗模式与主业务数据库配置分离 +## 安全配置说明 + +### API 认证 + +项目支持基于 API Key 的认证机制: + +```bash +# 设置 API Key(可选) +API_KEY=your-api-key-here +``` + +- 未设置 `API_KEY` 时,所有接口可自由访问 +- 设置后,非公开接口需要在请求头中携带 `X-API-Key` +- 认证失败返回 HTTP 401(不再返回 200) + +### 公开接口 + +以下接口无需认证即可访问: + +- `/health` - 健康检查(用于 K8s/负载均衡) +- `/health/database` - 数据库健康检查 +- `/static/*` - 静态资源 +- `/mds/*` - MDS 页面 +- `/docs`, `/redoc` - API 文档(仅限内网访问) + +### IP 白名单 + +支持多种格式的 IP 白名单配置: + +```bash +IP_WHITELIST=127.0.0.1,192.168.1.*,10.0.0.0/8,192.168.1.100-200 +``` + +支持的格式: +- 精确 IP: `192.168.1.100` +- 通配符: `192.168.1.*` +- IP 范围: `192.168.1.100-200` +- CIDR 表示法: `10.0.0.0/8` + +### SQL 注入防护 + +项目已实现 SQL 注入防护机制: + +- 使用安全 SQL 构建函数(`escape_sql_value`, `build_safe_condition`) +- 自动转义用户输入 +- 标识符验证防止注入攻击 +- 所有外部输入都经过安全处理 + ## 监控与日志 - 实时日志与历史日志页面位于 `/monitor` - 统一日志系统位于 `globalobjects/logger/` - 开发期可通过 `./scripts/dev_server.sh logs -f` 查看实时日志 +## 生命周期管理 + +### 后台任务管理 + +项目使用统一的后台任务管理器(`core/task_manager.py`): + +- 所有后台任务统一注册和跟踪 +- 支持优雅关闭和超时保护 +- 自动清理已完成的任务 + +### 关闭顺序 + +应用关闭时按照以下顺序执行,确保资源正确释放: + +``` +阶段1: 取消所有后台任务(优先执行) +阶段2: 停止服务和监控器 +阶段3: 释放资源和连接(数据库最后关闭) +阶段4: 关闭日志系统 +``` + +关闭时会输出明确提示: + +``` +==================应用关闭完成================== +所有资源已释放,服务已完全停止 +================================================== +MyAPS API 应用已完全关闭 +感谢使用,再见! +================================================== +``` + ## 当前验证状态 在当前仓库环境下,以下命令已验证通过: diff --git a/apps/common/monitor/models.py b/apps/common/monitor/models.py index d6cf500..11b21b5 100644 --- a/apps/common/monitor/models.py +++ b/apps/common/monitor/models.py @@ -104,7 +104,7 @@ def get_host_from_url(url: str) -> Optional[str]: class APIRequest(Model): """API 请求记录模型""" - id = fields.IntField(pk=True, auto_generate=True) + id = fields.IntField(primary_key=True, auto_generate=True) request_id = fields.CharField(max_length=36, null=True, description="请求唯一标识(UUID)") timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间") method = fields.CharField(max_length=10, description="HTTP 方法") @@ -144,7 +144,7 @@ class APIRequest(Model): class OutboundAPIRequest(Model): """对外 HTTP 请求记录模型""" - id = fields.IntField(pk=True, auto_generate=True) + id = fields.IntField(primary_key=True, auto_generate=True) timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间") method = fields.CharField(max_length=10, description="HTTP 方法") url = fields.TextField(description="请求 URL") @@ -177,7 +177,7 @@ class OutboundAPIRequest(Model): class SystemLog(Model): """系统日志模型""" - id = fields.IntField(pk=True, auto_generate=True) + id = fields.IntField(primary_key=True, auto_generate=True) timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间") level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL") module = fields.CharField(max_length=255, description="模块名称") @@ -205,7 +205,7 @@ class SystemLog(Model): class BinlogPosition(Model): """Binlog 位置记录模型""" - id = fields.IntField(pk=True, auto_generate=True) + id = fields.IntField(primary_key=True, auto_generate=True) server_id = fields.CharField(max_length=255, description="MySQL 服务器标识") log_file = fields.CharField(max_length=255, description="Binlog 文件名") log_pos = fields.BigIntField(description="Binlog 位置") @@ -223,7 +223,7 @@ class BinlogPosition(Model): class ProcessedEvent(Model): """已处理的事件记录模型(用于去重)""" - id = fields.IntField(pk=True, auto_generate=True) + id = fields.IntField(primary_key=True, auto_generate=True) event_id = fields.CharField(max_length=512, unique=True, description="事件唯一标识") log_file = fields.CharField(max_length=255, description="Binlog 文件名") log_pos = fields.BigIntField(description="Binlog 位置") @@ -246,7 +246,7 @@ class FailedOperation(Model): """失败的数据库操作持久化模型""" # 基础信息 - id = fields.IntField(pk=True, auto_generate=True) + id = fields.IntField(primary_key=True, auto_generate=True) operation_id = fields.CharField(max_length=64, unique=True, description="操作唯一ID (UUID)") timestamp = fields.DatetimeField(index=True, description="失败时间") @@ -286,7 +286,7 @@ class FailedOperation(Model): # class APILog(Model): # """API 相关日志模型""" -# id = fields.IntField(pk=True, auto_generate=True) +# id = fields.IntField(primary_key=True, auto_generate=True) # timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间") # level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL") # api_request = fields.ForeignKeyField("monitor_models.APIRequest", null=True, description="关联的内部API请求") @@ -308,7 +308,7 @@ class FailedOperation(Model): # class PerformanceLog(Model): # """性能日志模型""" -# id = fields.IntField(pk=True, auto_generate=True) +# id = fields.IntField(primary_key=True, auto_generate=True) # timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间") # operation = fields.CharField(max_length=255, description="操作名称") # duration = fields.FloatField(description="执行时间(毫秒)") @@ -332,7 +332,7 @@ class FailedOperation(Model): # class SecurityLog(Model): # """安全日志模型""" -# id = fields.IntField(pk=True, auto_generate=True) +# id = fields.IntField(primary_key=True, auto_generate=True) # timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间") # event_type = fields.CharField(max_length=50, description="事件类型:登录、登出、权限变更等") # user = fields.CharField(max_length=255, null=True, description="用户标识") diff --git a/apps/io_api/protomodels.py b/apps/io_api/protomodels.py index cc442fb..cc8bc13 100644 --- a/apps/io_api/protomodels.py +++ b/apps/io_api/protomodels.py @@ -4,7 +4,7 @@ from tortoise import fields class ProtoBatchLog(TortoiseBaseModel): - id = fields.IntField(source_field='ID', pk=True) + id = fields.IntField(source_field='ID', primary_key=True) systime = fields.DatetimeField(source_field='SysTime', null=True) # Field name made lowercase. pidno = fields.CharField(source_field='PIDNO', max_length=32, null=True, description="ProfileNo") # Field name made lowercase. task = fields.CharField(source_field='Task', max_length=1000, null=True, description="任务") # Field name made lowercase. diff --git a/apps/io_api/routers.py b/apps/io_api/routers.py index a9ebe2a..651b27d 100644 --- a/apps/io_api/routers.py +++ b/apps/io_api/routers.py @@ -380,8 +380,10 @@ async def get_material( db_name = db_name.replace(" ", "") try: if materialnos != "...": - materialnos = ",".join([f"'{_}'" for _ in materialnos.split(",")]) - filter_string = f"`MaterialNo` IN ({materialnos})" + # 使用安全的SQL条件构建器,防止注入 + materialno_list = materialnos.split(",") + from globalobjects.db_manager import build_safe_condition + filter_string = build_safe_condition("MaterialNo", "IN", materialno_list) else: filter_string = "" result = await db_query(db_name=db_name, model_or_tablename="t_material", filter_string=filter_string) @@ -1178,12 +1180,14 @@ async def get_mo_page( log_api_request(request) db_name = db_name.replace(" ", "") - filter = [] + from globalobjects.db_manager import build_safe_filter + + conditions = [] if start_time: - filter.append(f"`DT_OrdStart` >= '{start_time}'") + conditions.append(("DT_OrdStart", ">=", start_time)) if end_time: - filter.append(f"`DT_OrdEnd` <= '{end_time}'") - filter_string = " AND ".join(filter) + conditions.append(("DT_OrdEnd", "<=", end_time)) + filter_string = build_safe_filter(conditions) try: result = await db_query(db_name=db_name, model_or_tablename="v_supply_mo", filter_string=filter_string, page_size=page_size, page_index=page_index) return standard_response( @@ -1273,12 +1277,15 @@ async def get_orderwc_page( ): log_api_request(request) db_name = db_name.replace(" ", "") - filter = [] + + from globalobjects.db_manager import build_safe_filter + + conditions = [] if start_time: - filter.append(f"`DT_Start` >= '{start_time}'") + conditions.append(("DT_Start", ">=", start_time)) if end_time: - filter.append(f"`DT_End` <= '{end_time}'") - filter_string = " AND ".join(filter) + conditions.append(("DT_End", "<=", end_time)) + filter_string = build_safe_filter(conditions) try: result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string, page_size=page_size, page_index=page_index) return standard_response( @@ -1313,7 +1320,9 @@ async def get_orderwc( ): log_api_request(request) db_name = db_name.replace(" ", "") - filter_string = f"`SupplyNo` = '{supplyno}'" + + from globalobjects.db_manager import build_safe_condition + filter_string = build_safe_condition("SupplyNo", "=", supplyno) try: result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string) return standard_response( @@ -1478,9 +1487,13 @@ async def delete_workreport( ): log_api_request(request) db_name = db_name.replace(" ", "") - filter_string = f"`SupplyNo`='{supplyno}'" + + from globalobjects.db_manager import build_safe_filter + + conditions = [("SupplyNo", "=", supplyno)] if not itemno == "...": - filter_string += f" AND `ItemNo`='{itemno}'" + conditions.append(("ItemNo", "=", itemno)) + filter_string = build_safe_filter(conditions) try: result = await db_delete(db_names=db_name, model_or_tablename="t_confirm", filter_string=filter_string) return standard_response( diff --git a/apps/io_api/utils/common.py b/apps/io_api/utils/common.py index 958c047..47e18de 100644 --- a/apps/io_api/utils/common.py +++ b/apps/io_api/utils/common.py @@ -238,17 +238,18 @@ async def drop_matched_data(data: List[Any], db_names: str, table_name: str, mat batch_size = 100 combinations_list = list(unique_combinations) + from globalobjects.db_manager import build_safe_filter + for i in range(0, len(combinations_list), batch_size): batch = combinations_list[i:i+batch_size] - conditions = [] + batch_conditions = [] for values in batch: # 构建条件,支持任意数量的字段 field_conditions = [] for db_field, value in zip(db_fields, values): - field_conditions.append(f"`{db_field}`='{value}'") - condition = " AND ".join(field_conditions) - conditions.append(f"({condition})") - filter_string = " OR ".join(conditions) + field_conditions.append((db_field, "=", value)) + batch_conditions.append(build_safe_filter(field_conditions)) + filter_string = " OR ".join(f"({cond})" for cond in batch_conditions) try: await db_delete(db_names=db_names, model_or_tablename=table_name, filter_string=filter_string) except Exception as e: diff --git a/core/lifespan.py b/core/lifespan.py index d5a3006..63f2f64 100644 --- a/core/lifespan.py +++ b/core/lifespan.py @@ -18,6 +18,7 @@ from apps.common.monitor.log_stream_service import start_log_stream, stop_log_st from globalobjects import EVENT_AGGREGATOR from core.settings import TURNON_BINLOG_LISTENER, TRUNON_SCHEDULER, MAX_EVENTS_BATCH_SIZE from core.database import check_db_connections, warmup_connections, start_pool_monitoring, db_init_manager +from core.task_manager import get_task_manager @asynccontextmanager @@ -118,12 +119,19 @@ async def lifespan(app): # 启动数据库连接检查任务(从原startup_event迁移) log_config.info("启动数据库连接检查任务...") - db_check_task = asyncio.create_task(schedule_db_checks()) + task_manager = get_task_manager() + db_check_task = task_manager.create_and_register( + "db_check_task", + schedule_db_checks() + ) log_config.info("数据库连接检查任务已启动") # 启动连接池监控任务 log_config.info("启动连接池监控任务...") - pool_monitor_task = asyncio.create_task(start_pool_monitoring()) + pool_monitor_task = task_manager.create_and_register( + "pool_monitor_task", + start_pool_monitoring() + ) log_config.info("连接池监控任务已启动") # 启动日志数据库批次刷新任务 @@ -145,7 +153,10 @@ async def lifespan(app): pass log_config.info("启动日志数据库批次刷新任务...") - log_db_flush_task = asyncio.create_task(schedule_log_db_flush()) + log_db_flush_task = task_manager.create_and_register( + "log_db_flush_task", + schedule_log_db_flush() + ) log_config.info("日志数据库批次刷新任务已启动") # 启动数据库健康检查器(独立后台任务,不依赖前端访问) @@ -218,7 +229,10 @@ async def lifespan(app): await asyncio.sleep(90) # 创建 Redis 健康检查任务 - redis_check_task = asyncio.create_task(schedule_redis_checks()) + redis_check_task = task_manager.create_and_register( + "redis_check_task", + schedule_redis_checks() + ) log_config.info("Redis 健康检查任务已启动") # 启动 Redis 消息消费者(处理来自数据库监听器的事件) @@ -383,11 +397,17 @@ async def lifespan(app): except Exception as e: log_config.error(f"Redis 事件清理任务启动失败: {e}") - asyncio.create_task(start_redis_consumer()) + task_manager.create_and_register( + "redis_consumer_task", + start_redis_consumer() + ) log_config.info("Redis 消息消费者已启动") - + # 启动 Redis 事件清理任务 - asyncio.create_task(cleanup_expired_events()) + task_manager.create_and_register( + "redis_cleanup_task", + cleanup_expired_events() + ) log_config.info("Redis 事件清理任务已启动") # 等待一段时间,确保所有服务正常启动 @@ -396,140 +416,119 @@ async def lifespan(app): yield # 应用运行期间 - # 应用关闭时执行的操作 - log_config.info("应用关闭中...") + # ============ 应用关闭阶段 ============ + # 关闭顺序:任务 -> 服务 -> 资源 + log_config.info("==================应用开始关闭==================") - # 0. 关闭数据库连接 - log_config.info("正在关闭数据库连接...") + # 阶段1: 取消所有后台任务(优先执行) + log_config.info("【阶段1】取消所有后台任务...") + task_manager = get_task_manager() + await task_manager.cancel_all(timeout=10.0) + log_config.info("✅ 所有后台任务已取消") + + # 阶段2: 停止各服务和监控器 + log_config.info("【阶段2】停止服务和监控器...") + + # 2.1 停止实时日志流服务 + log_config.info("停止实时日志流服务...") + await stop_log_stream() + log_config.info("✅ 实时日志流服务已停止") + + # 2.2 停止数据库健康检查器 + log_config.info("停止数据库健康检查器...") + await stop_db_health_checker() + log_config.info("✅ 数据库健康检查器已停止") + + # 2.3 停止失败操作恢复管理器 + log_config.info("停止失败操作恢复管理器...") + await stop_failed_operation_recovery() + log_config.info("✅ 失败操作恢复管理器已停止") + + # 2.4 停止 MySQL Binlog 监控 + if TURNON_BINLOG_LISTENER: + log_config.info("停止 MySQL Binlog 监控...") + binlog_listener.stop_monitoring() + log_config.info("✅ MySQL Binlog监控已停止") + + # 2.5 停止资源监控 + log_config.info("停止资源监控...") + resource_monitor.stop_monitoring() + log_config.info("✅ 系统资源监控已停止") + + # 2.6 停止事件聚合器 + log_config.info("停止事件聚合器...") + EVENT_AGGREGATOR.stop() + log_config.info("✅ 事件聚合器已停止") + + # 2.7 关闭事件线程池管理器 + log_config.info("关闭事件线程池...") + from globalobjects.event_aggregator import get_event_pool_manager + get_event_pool_manager().shutdown_all() + log_config.info("✅ 事件线程池已关闭") + + # 2.8 关闭调度器 + if TRUNON_SCHEDULER: + log_config.info("关闭调度器...") + scheduler_manager.shutdown() + log_config.info("✅ 定时任务管理器已关闭") + + log_config.info("✅ 所有服务已停止") + + # 阶段3: 释放资源和连接(最后执行) + log_config.info("【阶段3】释放资源和连接...") + + # 3.1 刷新 Redis 缓冲 + log_config.info("刷新 Redis 缓冲...") + import concurrent.futures + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + loop = asyncio.get_event_loop() + + def get_buffer_size(): + from apps.common.utils.redis_pool_manager import get_redis_pool_manager + return get_redis_pool_manager().get_buffer_size() + + try: + buffer_size = await loop.run_in_executor(executor, get_buffer_size) + if buffer_size > 0: + log_config.info(f"发现 {buffer_size} 个事件在本地缓冲中,准备刷新...") + + def flush_buffer(): + from apps.common.utils.redis_pool_manager import flush_event_buffer + return flush_event_buffer('db_events') + + flushed = await loop.run_in_executor(executor, flush_buffer) + log_config.info(f"✅ 缓冲刷新完成,成功刷新 {flushed} 个事件") + except Exception as e: + log_config.warning(f"⚠️ 刷新Redis缓冲失败: {e}") + + # 3.2 关闭事件辅助模块 + log_config.info("关闭事件辅助模块...") + try: + from apps.common.utils.event_helpers import shutdown_event_helpers + shutdown_event_helpers() + log_config.info("✅ 事件辅助模块已关闭") + except Exception as e: + log_config.warning(f"⚠️ 关闭事件辅助模块失败: {e}") + + # 3.3 关闭数据库连接(最后关闭) + log_config.info("关闭数据库连接...") try: from tortoise import Tortoise await Tortoise.close_connections() log_config.info("✅ 数据库连接已关闭") except Exception as e: - log_config.warning(f"⚠️ 关闭数据库连接时出错: {e}") + log_config.warning(f"⚠️ 关闭数据库连接失败: {e}") - # 1. 先停止 MySQL Binlog 监控(最依赖数据库) - if TURNON_BINLOG_LISTENER: - log_config.info("正在停止 MySQL Binlog 监控...") - binlog_listener.stop_monitoring() - log_config.info("==================MySQL Binlog监控已停止==================") - else: - log_config.debug("⚠️ MySQL Binlog监控未启动,无需停止") - - # 2. 停止 Redis 相关任务 - log_config.info("正在停止 Redis 相关任务...") - # 在线程池中执行缓冲刷新,避免阻塞事件循环 - import concurrent.futures - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - loop = asyncio.get_event_loop() - - def get_buffer_size(): - from apps.common.utils.redis_pool_manager import get_redis_pool_manager - return get_redis_pool_manager().get_buffer_size() - - buffer_size = await loop.run_in_executor(executor, get_buffer_size) - if buffer_size > 0: - log_config.info(f"发现 {buffer_size} 个事件在本地缓冲中,准备刷新...") - - def flush_buffer(): - from apps.common.utils.redis_pool_manager import flush_event_buffer - return flush_event_buffer('db_events') - - flushed = await loop.run_in_executor(executor, flush_buffer) - log_config.info(f"缓冲刷新完成,成功刷新 {flushed} 个事件") - log_config.info("==================Redis 相关任务已停止==================") - - # 2.1 关闭事件辅助模块(DeadLetter队列等) - log_config.info("正在关闭事件辅助模块...") - from apps.common.utils.event_helpers import shutdown_event_helpers - shutdown_event_helpers() - log_config.info("==================事件辅助模块已关闭==================") - - # 3. 等待一段时间,确保所有任务完成 - log_config.info("⏳ 等待所有后台任务完成...") - await asyncio.sleep(5) # 等待5秒,让所有任务完成 - - # 4. 关闭调度器 - if TRUNON_SCHEDULER: - log_config.info("正在关闭调度器...") - scheduler_manager.shutdown() - log_config.info("==================定时任务管理器已关闭==================") - else: - log_config.debug("⚠️ 定时任务管理器未启动,无需关闭") + # 阶段4: 关闭日志系统(最后) + log_config.info("【阶段4】关闭日志系统...") - # 5. 停止资源监控 - log_config.info("正在停止资源监控...") - resource_monitor.stop_monitoring() - log_config.info("==================系统资源监控已停止==================") - - # 6. 停止事件聚合器 - log_config.info("正在停止事件聚合器...") - log_config.info("==================事件聚合器已停止==================") - EVENT_AGGREGATOR.stop() - log_config.info("==================事件聚合器已停止==================") - - # 6.1 关闭事件线程池管理器 - log_config.info("正在关闭事件线程池...") - from globalobjects.event_aggregator import get_event_pool_manager - get_event_pool_manager().shutdown_all() - log_config.info("==================事件线程池已关闭==================") - - # 7. 停止数据库健康检查器 - log_config.info("正在停止数据库健康检查器...") - await stop_db_health_checker() - log_config.info("==================数据库健康检查器已停止==================") - - # 8. 停止失败操作恢复管理器 - log_config.info("正在停止OperationRecovery管理器...") - await stop_failed_operation_recovery() - log_config.info("==================OperationRecovery管理器已停止==================") - - # 10. 取消后台任务 - if 'db_check_task' in locals(): - log_config.info("正在取消数据库连接检查任务...") - db_check_task.cancel() - try: - await db_check_task - except asyncio.CancelledError: - pass - log_config.info("==================数据库连接检查任务已取消==================") - - if 'log_db_flush_task' in locals(): - log_config.info("正在取消日志数据库批次刷新任务...") - log_db_flush_task.cancel() - try: - await log_db_flush_task - except asyncio.CancelledError: - pass - log_config.info("==================日志数据库批次刷新任务已取消==================") - - if 'pool_monitor_task' in locals(): - log_config.info("正在取消连接池监控任务...") - pool_monitor_task.cancel() - try: - await pool_monitor_task - except asyncio.CancelledError: - pass - log_config.info("==================连接池监控任务已取消==================") - - # 取消 Redis 健康检查任务 - if 'redis_check_task' in locals(): - log_config.info("正在取消 Redis 健康检查任务...") - redis_check_task.cancel() - try: - await redis_check_task - except asyncio.CancelledError: - pass - log_config.info("==================Redis 健康检查任务已取消==================") - - # 11. 等待一段时间,确保所有任务真正完成 - log_config.info("⏳ 等待所有任务彻底完成...") - await asyncio.sleep(3) # 再等待3秒 - + # 在关闭日志系统前输出最终提示 log_config.info("==================应用关闭完成==================") - - # 12. 关闭统一日志系统 + log_config.info("所有资源已释放,服务已完全停止") + await shutdown_logging() - - # 13. 停止实时日志流服务 - await stop_log_stream() + + # 使用print确保关闭后的提示能输出(日志系统已关闭) + print("=" * 50) + print("MyAPS API 应用已完全关闭") + print("=" * 50) diff --git a/core/middleware.py b/core/middleware.py index 5b4fd7c..0983841 100644 --- a/core/middleware.py +++ b/core/middleware.py @@ -15,6 +15,18 @@ DOC_PREFIXES = ["/static/swagger"] MDS_PATHS = ["/mds", "/mds/material", "/mds/workcenter", "/mds/mat-ver", "/mds/mat-wc", "/mds/mat-wc-bom", "/mds/mold", "/mds/mat-wc-mold"] +# 公开的GET接口(不需要认证) +# 用于健康检查、静态资源等 +PUBLIC_GET_PATHS = [ + "/health", # K8s/负载均衡健康检查 + "/health/database", # 数据库健康检查 +] + +# 公开的路径前缀 +PUBLIC_GET_PREFIXES = [ + "/static/", # 静态资源 +] + # 缓存已注册的路由信息,避免每次请求都重新解析 REGISTERED_ROUTES = [] @@ -235,10 +247,20 @@ def create_security_middleware(): content={"status_code": 404, "success": 0, "meta": {}, "message": "Not Found"} ) - # 对GET和OPTIONS方法直接放行 - if request_method in ["GET", "OPTIONS"]: + # OPTIONS方法直接放行(CORS预检) + if request_method == "OPTIONS": return await call_next(request) + # GET方法需要检查是否在公开路径列表 + if request_method == "GET": + is_public_path = ( + url_path in PUBLIC_GET_PATHS or + any(url_path.startswith(prefix) for prefix in PUBLIC_GET_PREFIXES) + ) + if is_public_path: + return await call_next(request) + # 非公开GET路径需要继续鉴权 + # 检查IP是否在白名单中 client_ip = request.client.host if is_ip_allowed(client_ip): @@ -248,6 +270,10 @@ def create_security_middleware(): if not API_KEY or request.headers.get("X-API-Key") == API_KEY: return await call_next(request) - return JSONResponse(status_code=200, content={"status_code": 403, "success": 0, "meta": {}, "message": "Forbidden: Invalid or missing API Key"}) + # 未授权请求返回真实HTTP 401状态码 + return JSONResponse( + status_code=401, + content={"status_code": 401, "success": 0, "meta": {}, "message": "Unauthorized: Invalid or missing API Key"} + ) return security_middleware diff --git a/core/task_manager.py b/core/task_manager.py new file mode 100644 index 0000000..ec58928 --- /dev/null +++ b/core/task_manager.py @@ -0,0 +1,181 @@ +""" +后台任务管理器 +统一管理所有后台任务的生命周期 +""" +import asyncio +from typing import Dict, Set, Optional, Any +from globalobjects import logger as log_config + + +class BackgroundTaskManager: + """ + 后台任务统一管理器 + + 功能: + 1. 统一注册所有后台任务 + 2. 提供优雅关闭机制 + 3. 确保关闭顺序正确(先取消任务,再释放资源) + """ + + def __init__(self): + self._tasks: Dict[str, asyncio.Task] = {} + self._shutdown_timeout: float = 10.0 + + def register(self, name: str, task: asyncio.Task) -> asyncio.Task: + """ + 注册后台任务 + + Args: + name: 任务名称(用于日志和调试) + task: asyncio.Task实例 + + Returns: + 返回传入的task,方便链式调用 + """ + self._tasks[name] = task + log_config.debug(f"后台任务已注册: {name}") + + # 添加完成回调,自动清理 + def _on_task_done(t: asyncio.Task): + task_name = None + for k, v in self._tasks.items(): + if v == t: + task_name = k + break + if task_name: + self._tasks.pop(task_name, None) + if not t.cancelled(): + exc = t.exception() + if exc: + log_config.error(f"后台任务异常退出: {task_name}", exc_info=exc) + else: + log_config.debug(f"后台任务正常完成: {task_name}") + + task.add_done_callback(_on_task_done) + return task + + def create_and_register( + self, + name: str, + coro, + *, + delay: float = 0.0 + ) -> asyncio.Task: + """ + 创建并注册后台任务 + + Args: + name: 任务名称 + coro: 协程对象 + delay: 启动延迟(秒) + + Returns: + 创建的Task实例 + """ + async def _wrapped_coro(): + if delay > 0: + await asyncio.sleep(delay) + await coro + + task = asyncio.create_task(_wrapped_coro()) + return self.register(name, task) + + async def cancel_all(self, timeout: Optional[float] = None) -> Dict[str, bool]: + """ + 取消所有后台任务 + + Args: + timeout: 超时时间(秒),None使用默认值 + + Returns: + 任务取消结果字典 {任务名: 是否成功} + """ + if not self._tasks: + return {} + + timeout = timeout or self._shutdown_timeout + results = {} + + log_config.info(f"开始取消 {len(self._tasks)} 个后台任务...") + + # 第一阶段:发送取消信号 + for name, task in list(self._tasks.items()): + if not task.done(): + task.cancel() + log_config.debug(f"已发送取消信号: {name}") + + # 第二阶段:等待任务完成 + async def wait_for_task(name: str, task: asyncio.Task) -> bool: + try: + await asyncio.wait_for(task, timeout=timeout) + return True + except asyncio.CancelledError: + return True # 正常取消 + except asyncio.TimeoutError: + log_config.warning(f"任务取消超时: {name}") + return False + except Exception as e: + log_config.error(f"任务取消异常: {name}, {e}") + return False + + # 并行等待所有任务 + wait_tasks = { + name: asyncio.create_task(wait_for_task(name, task)) + for name, task in list(self._tasks.items()) + } + + # 等待所有等待任务完成 + for name, wait_task in wait_tasks.items(): + try: + results[name] = await wait_task + except Exception as e: + log_config.error(f"等待任务完成失败: {name}, {e}") + results[name] = False + + # 清理已完成的任务 + self._tasks.clear() + + # 统计结果 + success_count = sum(1 for v in results.values() if v) + fail_count = len(results) - success_count + + if fail_count > 0: + log_config.warning(f"任务取消完成: 成功 {success_count}, 失败 {fail_count}") + else: + log_config.info(f"所有 {success_count} 个后台任务已成功取消") + + return results + + def get_task_names(self) -> Set[str]: + """获取所有已注册任务名称""" + return set(self._tasks.keys()) + + def get_task_count(self) -> int: + """获取已注册任务数量""" + return len(self._tasks) + + def has_task(self, name: str) -> bool: + """检查任务是否已注册""" + return name in self._tasks + + def set_shutdown_timeout(self, timeout: float): + """设置关闭超时时间""" + self._shutdown_timeout = timeout + + +# 全局任务管理器实例 +_task_manager: Optional[BackgroundTaskManager] = None + + +def get_task_manager() -> BackgroundTaskManager: + """获取全局任务管理器""" + global _task_manager + if _task_manager is None: + _task_manager = BackgroundTaskManager() + return _task_manager + + +def reset_task_manager(): + """重置任务管理器(用于测试)""" + global _task_manager + _task_manager = None diff --git a/globalobjects/db_manager.py b/globalobjects/db_manager.py index 87519e7..958aa9f 100644 --- a/globalobjects/db_manager.py +++ b/globalobjects/db_manager.py @@ -4,6 +4,7 @@ from datetime import datetime import time import asyncio import functools +import re from tortoise import Tortoise from tortoise.connection import connections @@ -16,10 +17,269 @@ from globalobjects import logger as log_config import os LOG_LEVEL = os.getenv("LOG_LEVEL") or "INFO" -# 获取统一日志器 logger = log_config.get_logger(__name__, level=LOG_LEVEL) +def escape_sql_value(value: Any) -> str: + """ + 安全转义SQL值,防止SQL注入 + + Args: + value: 要转义的值 + + Returns: + 转义后的安全字符串 + """ + if value is None: + return "NULL" + + if isinstance(value, bool): + return "1" if value else "0" + + if isinstance(value, (int, float)): + return str(value) + + if isinstance(value, datetime): + return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'" + + # 字符串处理:转义单引号 + str_value = str(value) + # 将单引号转义为两个单引号(SQL标准) + escaped = str_value.replace("'", "''") + return f"'{escaped}'" + + +def validate_identifier(identifier: str) -> str: + """ + 验证并安全化SQL标识符(表名、字段名) + + Args: + identifier: 标识符 + + Returns: + 安全的标识符(用反引号包裹) + + Raises: + ValueError: 如果标识符包含危险字符 + """ + # 移除首尾空格 + identifier = identifier.strip() + + # 检查危险字符 + dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\x00', '\n', '\r'] + for char in dangerous_chars: + if char in identifier: + raise ValueError(f"Invalid identifier: contains dangerous character '{char}'") + + # 用反引号包裹 + return f"`{identifier}`" + + +def build_safe_condition(field: str, operator: str, value: Any) -> str: + """ + 构建安全的SQL条件表达式 + + Args: + field: 字段名 + operator: 操作符 (=, !=, >, <, >=, <=, LIKE, IN) + value: 值 + + Returns: + 安全的SQL条件字符串 + """ + safe_field = validate_identifier(field) + operator = operator.upper().strip() + + if operator == "IN": + if not isinstance(value, (list, tuple)): + raise ValueError("IN operator requires a list or tuple of values") + escaped_values = [escape_sql_value(v) for v in value] + return f"{safe_field} IN ({', '.join(escaped_values)})" + + if operator == "LIKE": + return f"{safe_field} LIKE {escape_sql_value(value)}" + + # 标准比较操作符 + valid_operators = ['=', '!=', '<>', '>', '<', '>=', '<=', 'IS', 'IS NOT'] + if operator not in valid_operators: + raise ValueError(f"Invalid operator: {operator}") + + if operator in ('IS', 'IS NOT'): + if value is None: + return f"{safe_field} {operator} NULL" + raise ValueError(f"{operator} operator only accepts None value") + + return f"{safe_field} {operator} {escape_sql_value(value)}" + + +def build_safe_filter(conditions: List[Tuple[str, str, Any]], logic: str = "AND") -> str: + """ + 构建安全的WHERE条件字符串 + + Args: + conditions: 条件列表,每个条件为 (字段名, 操作符, 值) 元组 + logic: 逻辑连接符 (AND/OR) + + Returns: + 安全的WHERE条件字符串 + """ + if not conditions: + return "" + + logic = logic.upper().strip() + if logic not in ("AND", "OR"): + raise ValueError(f"Invalid logic operator: {logic}") + + safe_conditions = [build_safe_condition(*cond) for cond in conditions] + return f" {logic} ".join(safe_conditions) + + +def build_safe_order_by(fields: List[Tuple[str, str]]) -> str: + """ + 构建安全的ORDER BY子句 + + Args: + fields: 排序字段列表,每个元素为 (字段名, 方向) 元组 + 方向为 'ASC' 或 'DESC' + + Returns: + 安全的ORDER BY字符串 + """ + if not fields: + return "" + + order_parts = [] + for field, direction in fields: + safe_field = validate_identifier(field) + direction = direction.upper().strip() + if direction not in ("ASC", "DESC"): + raise ValueError(f"Invalid order direction: {direction}") + order_parts.append(f"{safe_field} {direction}") + + return ", ".join(order_parts) + + +def build_safe_select(fields: List[str]) -> str: + """ + 构建安全的SELECT字段列表 + + Args: + fields: 字段名列表 + + Returns: + 安全的SELECT字段字符串 + """ + if not fields: + return "*" + + return ", ".join(validate_identifier(f) for f in fields) + + +class SafeQueryBuilder: + """ + 安全SQL查询构建器 + + 提供链式调用的SQL构建接口,自动处理转义和验证 + """ + + def __init__(self, table_name: str): + """ + 初始化查询构建器 + + Args: + table_name: 表名 + """ + self._table = validate_identifier(table_name) + self._select_fields = "*" + self._conditions = [] + self._order_fields = [] + self._limit = None + self._offset = None + + def select(self, *fields: str) -> 'SafeQueryBuilder': + """设置SELECT字段""" + if fields: + self._select_fields = build_safe_select(list(fields)) + return self + + def where(self, field: str, operator: str, value: Any) -> 'SafeQueryBuilder': + """添加WHERE条件""" + self._conditions.append((field, operator, value)) + return self + + def where_in(self, field: str, values: List[Any]) -> 'SafeQueryBuilder': + """添加IN条件""" + self._conditions.append((field, "IN", values)) + return self + + def where_like(self, field: str, pattern: str) -> 'SafeQueryBuilder': + """添加LIKE条件""" + self._conditions.append((field, "LIKE", pattern)) + return self + + def where_between(self, field: str, start: Any, end: Any) -> 'SafeQueryBuilder': + """添加BETWEEN条件""" + safe_field = validate_identifier(field) + self._conditions.append((f"{safe_field} >= {escape_sql_value(start)}", "=", True)) + self._conditions.append((f"{safe_field} <= {escape_sql_value(end)}", "=", True)) + return self + + def order_by(self, field: str, direction: str = "ASC") -> 'SafeQueryBuilder': + """添加排序""" + self._order_fields.append((field, direction)) + return self + + def limit(self, count: int) -> 'SafeQueryBuilder': + """设置LIMIT""" + if count > 0: + self._limit = count + return self + + def offset(self, count: int) -> 'SafeQueryBuilder': + """设置OFFSET""" + if count >= 0: + self._offset = count + return self + + def build_select_sql(self) -> str: + """构建SELECT SQL语句""" + sql = f"SELECT {self._select_fields} FROM {self._table}" + + if self._conditions: + where_clause = build_safe_filter(self._conditions) + sql += f" WHERE {where_clause}" + + if self._order_fields: + order_clause = build_safe_order_by(self._order_fields) + sql += f" ORDER BY {order_clause}" + + if self._limit is not None: + sql += f" LIMIT {self._limit}" + + if self._offset is not None: + sql += f" OFFSET {self._offset}" + + return sql + + def build_count_sql(self) -> str: + """构建COUNT SQL语句""" + sql = f"SELECT COUNT(*) as total FROM {self._table}" + + if self._conditions: + where_clause = build_safe_filter(self._conditions) + sql += f" WHERE {where_clause}" + + return sql + + def build_delete_sql(self) -> str: + """构建DELETE SQL语句""" + if not self._conditions: + raise ValueError("DELETE operation requires WHERE conditions for safety") + + where_clause = build_safe_filter(self._conditions) + return f"DELETE FROM {self._table} WHERE {where_clause}" + + def dict_to_lower_keys(d: dict) -> dict: """ 将字典的键转换为小写 diff --git a/globalobjects/logger/models.py b/globalobjects/logger/models.py index c4df99d..f66383a 100644 --- a/globalobjects/logger/models.py +++ b/globalobjects/logger/models.py @@ -36,12 +36,6 @@ class LogRecord(BaseModel): extra: Optional[Dict[str, Any]] = None - model_config = { - "json_encoders": { - datetime: lambda v: v.isoformat() - } - } - @field_validator('level_name', mode='before') @classmethod def validate_level_name(cls, v: str) -> str: