mirror of
https://github.com/rnvm9wjdtj-bot/myaps_api.git
synced 2026-06-02 05:54:40 +00:00
fix(security,stability): 完成API安全与稳定性修复
- 安全: 修复鉴权失败返回码(HTTP 401/403替代200) - 安全: 新增SafeQueryBuilder封堵SQL注入入口 - 安全: 移除Pydantic json_encoders弃用配置 - 稳定: 统一后台任务托管与生命周期管理 - 稳定: 新增TaskManager统一管理后台任务 - 文档: 更新README.md与.env.example - 重构: routers.py使用安全SQL构建器替代字符串拼接
This commit is contained in:
@@ -10,6 +10,10 @@ TIMEZONE=+8
|
|||||||
# 项目目录配置(必填)
|
# 项目目录配置(必填)
|
||||||
PROJECT_DIR=YOUR_PROJECT_DIR
|
PROJECT_DIR=YOUR_PROJECT_DIR
|
||||||
|
|
||||||
|
# 安全配置(可选)
|
||||||
|
# API_KEY=your-api-key-here # API密钥,设置后非公开接口需要认证
|
||||||
|
# IP_WHITELIST=127.0.0.1,192.168.1.* # IP白名单,支持通配符、范围和CIDR
|
||||||
|
|
||||||
# 数据库配置
|
# 数据库配置
|
||||||
MYAPS_DB_HOST=localhost
|
MYAPS_DB_HOST=localhost
|
||||||
MYAPS_DB_PORT=3333
|
MYAPS_DB_PORT=3333
|
||||||
|
|||||||
@@ -159,12 +159,92 @@ MYAPS_DB_PASSWORD
|
|||||||
- `STAGING_DB_NAME` 默认为 `--s`
|
- `STAGING_DB_NAME` 默认为 `--s`
|
||||||
- 清洗模式与主业务数据库配置分离
|
- 清洗模式与主业务数据库配置分离
|
||||||
|
|
||||||
|
## 安全配置说明
|
||||||
|
|
||||||
|
### API 认证
|
||||||
|
|
||||||
|
项目支持基于 API Key 的认证机制:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 设置 API Key(可选)
|
||||||
|
API_KEY=your-api-key-here
|
||||||
|
```
|
||||||
|
|
||||||
|
- 未设置 `API_KEY` 时,所有接口可自由访问
|
||||||
|
- 设置后,非公开接口需要在请求头中携带 `X-API-Key`
|
||||||
|
- 认证失败返回 HTTP 401(不再返回 200)
|
||||||
|
|
||||||
|
### 公开接口
|
||||||
|
|
||||||
|
以下接口无需认证即可访问:
|
||||||
|
|
||||||
|
- `/health` - 健康检查(用于 K8s/负载均衡)
|
||||||
|
- `/health/database` - 数据库健康检查
|
||||||
|
- `/static/*` - 静态资源
|
||||||
|
- `/mds/*` - MDS 页面
|
||||||
|
- `/docs`, `/redoc` - API 文档(仅限内网访问)
|
||||||
|
|
||||||
|
### IP 白名单
|
||||||
|
|
||||||
|
支持多种格式的 IP 白名单配置:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
IP_WHITELIST=127.0.0.1,192.168.1.*,10.0.0.0/8,192.168.1.100-200
|
||||||
|
```
|
||||||
|
|
||||||
|
支持的格式:
|
||||||
|
- 精确 IP: `192.168.1.100`
|
||||||
|
- 通配符: `192.168.1.*`
|
||||||
|
- IP 范围: `192.168.1.100-200`
|
||||||
|
- CIDR 表示法: `10.0.0.0/8`
|
||||||
|
|
||||||
|
### SQL 注入防护
|
||||||
|
|
||||||
|
项目已实现 SQL 注入防护机制:
|
||||||
|
|
||||||
|
- 使用安全 SQL 构建函数(`escape_sql_value`, `build_safe_condition`)
|
||||||
|
- 自动转义用户输入
|
||||||
|
- 标识符验证防止注入攻击
|
||||||
|
- 所有外部输入都经过安全处理
|
||||||
|
|
||||||
## 监控与日志
|
## 监控与日志
|
||||||
|
|
||||||
- 实时日志与历史日志页面位于 `/monitor`
|
- 实时日志与历史日志页面位于 `/monitor`
|
||||||
- 统一日志系统位于 `globalobjects/logger/`
|
- 统一日志系统位于 `globalobjects/logger/`
|
||||||
- 开发期可通过 `./scripts/dev_server.sh logs -f` 查看实时日志
|
- 开发期可通过 `./scripts/dev_server.sh logs -f` 查看实时日志
|
||||||
|
|
||||||
|
## 生命周期管理
|
||||||
|
|
||||||
|
### 后台任务管理
|
||||||
|
|
||||||
|
项目使用统一的后台任务管理器(`core/task_manager.py`):
|
||||||
|
|
||||||
|
- 所有后台任务统一注册和跟踪
|
||||||
|
- 支持优雅关闭和超时保护
|
||||||
|
- 自动清理已完成的任务
|
||||||
|
|
||||||
|
### 关闭顺序
|
||||||
|
|
||||||
|
应用关闭时按照以下顺序执行,确保资源正确释放:
|
||||||
|
|
||||||
|
```
|
||||||
|
阶段1: 取消所有后台任务(优先执行)
|
||||||
|
阶段2: 停止服务和监控器
|
||||||
|
阶段3: 释放资源和连接(数据库最后关闭)
|
||||||
|
阶段4: 关闭日志系统
|
||||||
|
```
|
||||||
|
|
||||||
|
关闭时会输出明确提示:
|
||||||
|
|
||||||
|
```
|
||||||
|
==================应用关闭完成==================
|
||||||
|
所有资源已释放,服务已完全停止
|
||||||
|
==================================================
|
||||||
|
MyAPS API 应用已完全关闭
|
||||||
|
感谢使用,再见!
|
||||||
|
==================================================
|
||||||
|
```
|
||||||
|
|
||||||
## 当前验证状态
|
## 当前验证状态
|
||||||
|
|
||||||
在当前仓库环境下,以下命令已验证通过:
|
在当前仓库环境下,以下命令已验证通过:
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ def get_host_from_url(url: str) -> Optional[str]:
|
|||||||
|
|
||||||
class APIRequest(Model):
|
class APIRequest(Model):
|
||||||
"""API 请求记录模型"""
|
"""API 请求记录模型"""
|
||||||
id = fields.IntField(pk=True, auto_generate=True)
|
id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
request_id = fields.CharField(max_length=36, null=True, description="请求唯一标识(UUID)")
|
request_id = fields.CharField(max_length=36, null=True, description="请求唯一标识(UUID)")
|
||||||
timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间")
|
timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间")
|
||||||
method = fields.CharField(max_length=10, description="HTTP 方法")
|
method = fields.CharField(max_length=10, description="HTTP 方法")
|
||||||
@@ -144,7 +144,7 @@ class APIRequest(Model):
|
|||||||
|
|
||||||
class OutboundAPIRequest(Model):
|
class OutboundAPIRequest(Model):
|
||||||
"""对外 HTTP 请求记录模型"""
|
"""对外 HTTP 请求记录模型"""
|
||||||
id = fields.IntField(pk=True, auto_generate=True)
|
id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间")
|
timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="请求时间")
|
||||||
method = fields.CharField(max_length=10, description="HTTP 方法")
|
method = fields.CharField(max_length=10, description="HTTP 方法")
|
||||||
url = fields.TextField(description="请求 URL")
|
url = fields.TextField(description="请求 URL")
|
||||||
@@ -177,7 +177,7 @@ class OutboundAPIRequest(Model):
|
|||||||
|
|
||||||
class SystemLog(Model):
|
class SystemLog(Model):
|
||||||
"""系统日志模型"""
|
"""系统日志模型"""
|
||||||
id = fields.IntField(pk=True, auto_generate=True)
|
id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
||||||
level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL")
|
level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL")
|
||||||
module = fields.CharField(max_length=255, description="模块名称")
|
module = fields.CharField(max_length=255, description="模块名称")
|
||||||
@@ -205,7 +205,7 @@ class SystemLog(Model):
|
|||||||
class BinlogPosition(Model):
|
class BinlogPosition(Model):
|
||||||
"""Binlog 位置记录模型"""
|
"""Binlog 位置记录模型"""
|
||||||
|
|
||||||
id = fields.IntField(pk=True, auto_generate=True)
|
id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
server_id = fields.CharField(max_length=255, description="MySQL 服务器标识")
|
server_id = fields.CharField(max_length=255, description="MySQL 服务器标识")
|
||||||
log_file = fields.CharField(max_length=255, description="Binlog 文件名")
|
log_file = fields.CharField(max_length=255, description="Binlog 文件名")
|
||||||
log_pos = fields.BigIntField(description="Binlog 位置")
|
log_pos = fields.BigIntField(description="Binlog 位置")
|
||||||
@@ -223,7 +223,7 @@ class BinlogPosition(Model):
|
|||||||
class ProcessedEvent(Model):
|
class ProcessedEvent(Model):
|
||||||
"""已处理的事件记录模型(用于去重)"""
|
"""已处理的事件记录模型(用于去重)"""
|
||||||
|
|
||||||
id = fields.IntField(pk=True, auto_generate=True)
|
id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
event_id = fields.CharField(max_length=512, unique=True, description="事件唯一标识")
|
event_id = fields.CharField(max_length=512, unique=True, description="事件唯一标识")
|
||||||
log_file = fields.CharField(max_length=255, description="Binlog 文件名")
|
log_file = fields.CharField(max_length=255, description="Binlog 文件名")
|
||||||
log_pos = fields.BigIntField(description="Binlog 位置")
|
log_pos = fields.BigIntField(description="Binlog 位置")
|
||||||
@@ -246,7 +246,7 @@ class FailedOperation(Model):
|
|||||||
"""失败的数据库操作持久化模型"""
|
"""失败的数据库操作持久化模型"""
|
||||||
|
|
||||||
# 基础信息
|
# 基础信息
|
||||||
id = fields.IntField(pk=True, auto_generate=True)
|
id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
operation_id = fields.CharField(max_length=64, unique=True, description="操作唯一ID (UUID)")
|
operation_id = fields.CharField(max_length=64, unique=True, description="操作唯一ID (UUID)")
|
||||||
timestamp = fields.DatetimeField(index=True, description="失败时间")
|
timestamp = fields.DatetimeField(index=True, description="失败时间")
|
||||||
|
|
||||||
@@ -286,7 +286,7 @@ class FailedOperation(Model):
|
|||||||
|
|
||||||
# class APILog(Model):
|
# class APILog(Model):
|
||||||
# """API 相关日志模型"""
|
# """API 相关日志模型"""
|
||||||
# id = fields.IntField(pk=True, auto_generate=True)
|
# id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
# timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
# timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
||||||
# level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL")
|
# level = fields.CharField(max_length=10, description="日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL")
|
||||||
# api_request = fields.ForeignKeyField("monitor_models.APIRequest", null=True, description="关联的内部API请求")
|
# api_request = fields.ForeignKeyField("monitor_models.APIRequest", null=True, description="关联的内部API请求")
|
||||||
@@ -308,7 +308,7 @@ class FailedOperation(Model):
|
|||||||
|
|
||||||
# class PerformanceLog(Model):
|
# class PerformanceLog(Model):
|
||||||
# """性能日志模型"""
|
# """性能日志模型"""
|
||||||
# id = fields.IntField(pk=True, auto_generate=True)
|
# id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
# timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
# timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
||||||
# operation = fields.CharField(max_length=255, description="操作名称")
|
# operation = fields.CharField(max_length=255, description="操作名称")
|
||||||
# duration = fields.FloatField(description="执行时间(毫秒)")
|
# duration = fields.FloatField(description="执行时间(毫秒)")
|
||||||
@@ -332,7 +332,7 @@ class FailedOperation(Model):
|
|||||||
|
|
||||||
# class SecurityLog(Model):
|
# class SecurityLog(Model):
|
||||||
# """安全日志模型"""
|
# """安全日志模型"""
|
||||||
# id = fields.IntField(pk=True, auto_generate=True)
|
# id = fields.IntField(primary_key=True, auto_generate=True)
|
||||||
# timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
# timestamp = fields.DatetimeField(default=lambda: datetime.now(timezone.utc), description="日志时间")
|
||||||
# event_type = fields.CharField(max_length=50, description="事件类型:登录、登出、权限变更等")
|
# event_type = fields.CharField(max_length=50, description="事件类型:登录、登出、权限变更等")
|
||||||
# user = fields.CharField(max_length=255, null=True, description="用户标识")
|
# user = fields.CharField(max_length=255, null=True, description="用户标识")
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from tortoise import fields
|
|||||||
|
|
||||||
|
|
||||||
class ProtoBatchLog(TortoiseBaseModel):
|
class ProtoBatchLog(TortoiseBaseModel):
|
||||||
id = fields.IntField(source_field='ID', pk=True)
|
id = fields.IntField(source_field='ID', primary_key=True)
|
||||||
systime = fields.DatetimeField(source_field='SysTime', null=True) # Field name made lowercase.
|
systime = fields.DatetimeField(source_field='SysTime', null=True) # Field name made lowercase.
|
||||||
pidno = fields.CharField(source_field='PIDNO', max_length=32, null=True, description="ProfileNo") # Field name made lowercase.
|
pidno = fields.CharField(source_field='PIDNO', max_length=32, null=True, description="ProfileNo") # Field name made lowercase.
|
||||||
task = fields.CharField(source_field='Task', max_length=1000, null=True, description="任务") # Field name made lowercase.
|
task = fields.CharField(source_field='Task', max_length=1000, null=True, description="任务") # Field name made lowercase.
|
||||||
|
|||||||
+26
-13
@@ -380,8 +380,10 @@ async def get_material(
|
|||||||
db_name = db_name.replace(" ", "")
|
db_name = db_name.replace(" ", "")
|
||||||
try:
|
try:
|
||||||
if materialnos != "...":
|
if materialnos != "...":
|
||||||
materialnos = ",".join([f"'{_}'" for _ in materialnos.split(",")])
|
# 使用安全的SQL条件构建器,防止注入
|
||||||
filter_string = f"`MaterialNo` IN ({materialnos})"
|
materialno_list = materialnos.split(",")
|
||||||
|
from globalobjects.db_manager import build_safe_condition
|
||||||
|
filter_string = build_safe_condition("MaterialNo", "IN", materialno_list)
|
||||||
else:
|
else:
|
||||||
filter_string = ""
|
filter_string = ""
|
||||||
result = await db_query(db_name=db_name, model_or_tablename="t_material", filter_string=filter_string)
|
result = await db_query(db_name=db_name, model_or_tablename="t_material", filter_string=filter_string)
|
||||||
@@ -1178,12 +1180,14 @@ async def get_mo_page(
|
|||||||
log_api_request(request)
|
log_api_request(request)
|
||||||
db_name = db_name.replace(" ", "")
|
db_name = db_name.replace(" ", "")
|
||||||
|
|
||||||
filter = []
|
from globalobjects.db_manager import build_safe_filter
|
||||||
|
|
||||||
|
conditions = []
|
||||||
if start_time:
|
if start_time:
|
||||||
filter.append(f"`DT_OrdStart` >= '{start_time}'")
|
conditions.append(("DT_OrdStart", ">=", start_time))
|
||||||
if end_time:
|
if end_time:
|
||||||
filter.append(f"`DT_OrdEnd` <= '{end_time}'")
|
conditions.append(("DT_OrdEnd", "<=", end_time))
|
||||||
filter_string = " AND ".join(filter)
|
filter_string = build_safe_filter(conditions)
|
||||||
try:
|
try:
|
||||||
result = await db_query(db_name=db_name, model_or_tablename="v_supply_mo", filter_string=filter_string, page_size=page_size, page_index=page_index)
|
result = await db_query(db_name=db_name, model_or_tablename="v_supply_mo", filter_string=filter_string, page_size=page_size, page_index=page_index)
|
||||||
return standard_response(
|
return standard_response(
|
||||||
@@ -1273,12 +1277,15 @@ async def get_orderwc_page(
|
|||||||
):
|
):
|
||||||
log_api_request(request)
|
log_api_request(request)
|
||||||
db_name = db_name.replace(" ", "")
|
db_name = db_name.replace(" ", "")
|
||||||
filter = []
|
|
||||||
|
from globalobjects.db_manager import build_safe_filter
|
||||||
|
|
||||||
|
conditions = []
|
||||||
if start_time:
|
if start_time:
|
||||||
filter.append(f"`DT_Start` >= '{start_time}'")
|
conditions.append(("DT_Start", ">=", start_time))
|
||||||
if end_time:
|
if end_time:
|
||||||
filter.append(f"`DT_End` <= '{end_time}'")
|
conditions.append(("DT_End", "<=", end_time))
|
||||||
filter_string = " AND ".join(filter)
|
filter_string = build_safe_filter(conditions)
|
||||||
try:
|
try:
|
||||||
result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string, page_size=page_size, page_index=page_index)
|
result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string, page_size=page_size, page_index=page_index)
|
||||||
return standard_response(
|
return standard_response(
|
||||||
@@ -1313,7 +1320,9 @@ async def get_orderwc(
|
|||||||
):
|
):
|
||||||
log_api_request(request)
|
log_api_request(request)
|
||||||
db_name = db_name.replace(" ", "")
|
db_name = db_name.replace(" ", "")
|
||||||
filter_string = f"`SupplyNo` = '{supplyno}'"
|
|
||||||
|
from globalobjects.db_manager import build_safe_condition
|
||||||
|
filter_string = build_safe_condition("SupplyNo", "=", supplyno)
|
||||||
try:
|
try:
|
||||||
result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string)
|
result = await db_query(db_name=db_name, model_or_tablename="v_orderwc", filter_string=filter_string)
|
||||||
return standard_response(
|
return standard_response(
|
||||||
@@ -1478,9 +1487,13 @@ async def delete_workreport(
|
|||||||
):
|
):
|
||||||
log_api_request(request)
|
log_api_request(request)
|
||||||
db_name = db_name.replace(" ", "")
|
db_name = db_name.replace(" ", "")
|
||||||
filter_string = f"`SupplyNo`='{supplyno}'"
|
|
||||||
|
from globalobjects.db_manager import build_safe_filter
|
||||||
|
|
||||||
|
conditions = [("SupplyNo", "=", supplyno)]
|
||||||
if not itemno == "...":
|
if not itemno == "...":
|
||||||
filter_string += f" AND `ItemNo`='{itemno}'"
|
conditions.append(("ItemNo", "=", itemno))
|
||||||
|
filter_string = build_safe_filter(conditions)
|
||||||
try:
|
try:
|
||||||
result = await db_delete(db_names=db_name, model_or_tablename="t_confirm", filter_string=filter_string)
|
result = await db_delete(db_names=db_name, model_or_tablename="t_confirm", filter_string=filter_string)
|
||||||
return standard_response(
|
return standard_response(
|
||||||
|
|||||||
@@ -238,17 +238,18 @@ async def drop_matched_data(data: List[Any], db_names: str, table_name: str, mat
|
|||||||
batch_size = 100
|
batch_size = 100
|
||||||
combinations_list = list(unique_combinations)
|
combinations_list = list(unique_combinations)
|
||||||
|
|
||||||
|
from globalobjects.db_manager import build_safe_filter
|
||||||
|
|
||||||
for i in range(0, len(combinations_list), batch_size):
|
for i in range(0, len(combinations_list), batch_size):
|
||||||
batch = combinations_list[i:i+batch_size]
|
batch = combinations_list[i:i+batch_size]
|
||||||
conditions = []
|
batch_conditions = []
|
||||||
for values in batch:
|
for values in batch:
|
||||||
# 构建条件,支持任意数量的字段
|
# 构建条件,支持任意数量的字段
|
||||||
field_conditions = []
|
field_conditions = []
|
||||||
for db_field, value in zip(db_fields, values):
|
for db_field, value in zip(db_fields, values):
|
||||||
field_conditions.append(f"`{db_field}`='{value}'")
|
field_conditions.append((db_field, "=", value))
|
||||||
condition = " AND ".join(field_conditions)
|
batch_conditions.append(build_safe_filter(field_conditions))
|
||||||
conditions.append(f"({condition})")
|
filter_string = " OR ".join(f"({cond})" for cond in batch_conditions)
|
||||||
filter_string = " OR ".join(conditions)
|
|
||||||
try:
|
try:
|
||||||
await db_delete(db_names=db_names, model_or_tablename=table_name, filter_string=filter_string)
|
await db_delete(db_names=db_names, model_or_tablename=table_name, filter_string=filter_string)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+133
-134
@@ -18,6 +18,7 @@ from apps.common.monitor.log_stream_service import start_log_stream, stop_log_st
|
|||||||
from globalobjects import EVENT_AGGREGATOR
|
from globalobjects import EVENT_AGGREGATOR
|
||||||
from core.settings import TURNON_BINLOG_LISTENER, TRUNON_SCHEDULER, MAX_EVENTS_BATCH_SIZE
|
from core.settings import TURNON_BINLOG_LISTENER, TRUNON_SCHEDULER, MAX_EVENTS_BATCH_SIZE
|
||||||
from core.database import check_db_connections, warmup_connections, start_pool_monitoring, db_init_manager
|
from core.database import check_db_connections, warmup_connections, start_pool_monitoring, db_init_manager
|
||||||
|
from core.task_manager import get_task_manager
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -118,12 +119,19 @@ async def lifespan(app):
|
|||||||
|
|
||||||
# 启动数据库连接检查任务(从原startup_event迁移)
|
# 启动数据库连接检查任务(从原startup_event迁移)
|
||||||
log_config.info("启动数据库连接检查任务...")
|
log_config.info("启动数据库连接检查任务...")
|
||||||
db_check_task = asyncio.create_task(schedule_db_checks())
|
task_manager = get_task_manager()
|
||||||
|
db_check_task = task_manager.create_and_register(
|
||||||
|
"db_check_task",
|
||||||
|
schedule_db_checks()
|
||||||
|
)
|
||||||
log_config.info("数据库连接检查任务已启动")
|
log_config.info("数据库连接检查任务已启动")
|
||||||
|
|
||||||
# 启动连接池监控任务
|
# 启动连接池监控任务
|
||||||
log_config.info("启动连接池监控任务...")
|
log_config.info("启动连接池监控任务...")
|
||||||
pool_monitor_task = asyncio.create_task(start_pool_monitoring())
|
pool_monitor_task = task_manager.create_and_register(
|
||||||
|
"pool_monitor_task",
|
||||||
|
start_pool_monitoring()
|
||||||
|
)
|
||||||
log_config.info("连接池监控任务已启动")
|
log_config.info("连接池监控任务已启动")
|
||||||
|
|
||||||
# 启动日志数据库批次刷新任务
|
# 启动日志数据库批次刷新任务
|
||||||
@@ -145,7 +153,10 @@ async def lifespan(app):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
log_config.info("启动日志数据库批次刷新任务...")
|
log_config.info("启动日志数据库批次刷新任务...")
|
||||||
log_db_flush_task = asyncio.create_task(schedule_log_db_flush())
|
log_db_flush_task = task_manager.create_and_register(
|
||||||
|
"log_db_flush_task",
|
||||||
|
schedule_log_db_flush()
|
||||||
|
)
|
||||||
log_config.info("日志数据库批次刷新任务已启动")
|
log_config.info("日志数据库批次刷新任务已启动")
|
||||||
|
|
||||||
# 启动数据库健康检查器(独立后台任务,不依赖前端访问)
|
# 启动数据库健康检查器(独立后台任务,不依赖前端访问)
|
||||||
@@ -218,7 +229,10 @@ async def lifespan(app):
|
|||||||
await asyncio.sleep(90)
|
await asyncio.sleep(90)
|
||||||
|
|
||||||
# 创建 Redis 健康检查任务
|
# 创建 Redis 健康检查任务
|
||||||
redis_check_task = asyncio.create_task(schedule_redis_checks())
|
redis_check_task = task_manager.create_and_register(
|
||||||
|
"redis_check_task",
|
||||||
|
schedule_redis_checks()
|
||||||
|
)
|
||||||
log_config.info("Redis 健康检查任务已启动")
|
log_config.info("Redis 健康检查任务已启动")
|
||||||
|
|
||||||
# 启动 Redis 消息消费者(处理来自数据库监听器的事件)
|
# 启动 Redis 消息消费者(处理来自数据库监听器的事件)
|
||||||
@@ -383,11 +397,17 @@ async def lifespan(app):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_config.error(f"Redis 事件清理任务启动失败: {e}")
|
log_config.error(f"Redis 事件清理任务启动失败: {e}")
|
||||||
|
|
||||||
asyncio.create_task(start_redis_consumer())
|
task_manager.create_and_register(
|
||||||
|
"redis_consumer_task",
|
||||||
|
start_redis_consumer()
|
||||||
|
)
|
||||||
log_config.info("Redis 消息消费者已启动")
|
log_config.info("Redis 消息消费者已启动")
|
||||||
|
|
||||||
# 启动 Redis 事件清理任务
|
# 启动 Redis 事件清理任务
|
||||||
asyncio.create_task(cleanup_expired_events())
|
task_manager.create_and_register(
|
||||||
|
"redis_cleanup_task",
|
||||||
|
cleanup_expired_events()
|
||||||
|
)
|
||||||
log_config.info("Redis 事件清理任务已启动")
|
log_config.info("Redis 事件清理任务已启动")
|
||||||
|
|
||||||
# 等待一段时间,确保所有服务正常启动
|
# 等待一段时间,确保所有服务正常启动
|
||||||
@@ -396,140 +416,119 @@ async def lifespan(app):
|
|||||||
|
|
||||||
yield # 应用运行期间
|
yield # 应用运行期间
|
||||||
|
|
||||||
# 应用关闭时执行的操作
|
# ============ 应用关闭阶段 ============
|
||||||
log_config.info("应用关闭中...")
|
# 关闭顺序:任务 -> 服务 -> 资源
|
||||||
|
log_config.info("==================应用开始关闭==================")
|
||||||
|
|
||||||
# 0. 关闭数据库连接
|
# 阶段1: 取消所有后台任务(优先执行)
|
||||||
log_config.info("正在关闭数据库连接...")
|
log_config.info("【阶段1】取消所有后台任务...")
|
||||||
|
task_manager = get_task_manager()
|
||||||
|
await task_manager.cancel_all(timeout=10.0)
|
||||||
|
log_config.info("✅ 所有后台任务已取消")
|
||||||
|
|
||||||
|
# 阶段2: 停止各服务和监控器
|
||||||
|
log_config.info("【阶段2】停止服务和监控器...")
|
||||||
|
|
||||||
|
# 2.1 停止实时日志流服务
|
||||||
|
log_config.info("停止实时日志流服务...")
|
||||||
|
await stop_log_stream()
|
||||||
|
log_config.info("✅ 实时日志流服务已停止")
|
||||||
|
|
||||||
|
# 2.2 停止数据库健康检查器
|
||||||
|
log_config.info("停止数据库健康检查器...")
|
||||||
|
await stop_db_health_checker()
|
||||||
|
log_config.info("✅ 数据库健康检查器已停止")
|
||||||
|
|
||||||
|
# 2.3 停止失败操作恢复管理器
|
||||||
|
log_config.info("停止失败操作恢复管理器...")
|
||||||
|
await stop_failed_operation_recovery()
|
||||||
|
log_config.info("✅ 失败操作恢复管理器已停止")
|
||||||
|
|
||||||
|
# 2.4 停止 MySQL Binlog 监控
|
||||||
|
if TURNON_BINLOG_LISTENER:
|
||||||
|
log_config.info("停止 MySQL Binlog 监控...")
|
||||||
|
binlog_listener.stop_monitoring()
|
||||||
|
log_config.info("✅ MySQL Binlog监控已停止")
|
||||||
|
|
||||||
|
# 2.5 停止资源监控
|
||||||
|
log_config.info("停止资源监控...")
|
||||||
|
resource_monitor.stop_monitoring()
|
||||||
|
log_config.info("✅ 系统资源监控已停止")
|
||||||
|
|
||||||
|
# 2.6 停止事件聚合器
|
||||||
|
log_config.info("停止事件聚合器...")
|
||||||
|
EVENT_AGGREGATOR.stop()
|
||||||
|
log_config.info("✅ 事件聚合器已停止")
|
||||||
|
|
||||||
|
# 2.7 关闭事件线程池管理器
|
||||||
|
log_config.info("关闭事件线程池...")
|
||||||
|
from globalobjects.event_aggregator import get_event_pool_manager
|
||||||
|
get_event_pool_manager().shutdown_all()
|
||||||
|
log_config.info("✅ 事件线程池已关闭")
|
||||||
|
|
||||||
|
# 2.8 关闭调度器
|
||||||
|
if TRUNON_SCHEDULER:
|
||||||
|
log_config.info("关闭调度器...")
|
||||||
|
scheduler_manager.shutdown()
|
||||||
|
log_config.info("✅ 定时任务管理器已关闭")
|
||||||
|
|
||||||
|
log_config.info("✅ 所有服务已停止")
|
||||||
|
|
||||||
|
# 阶段3: 释放资源和连接(最后执行)
|
||||||
|
log_config.info("【阶段3】释放资源和连接...")
|
||||||
|
|
||||||
|
# 3.1 刷新 Redis 缓冲
|
||||||
|
log_config.info("刷新 Redis 缓冲...")
|
||||||
|
import concurrent.futures
|
||||||
|
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def get_buffer_size():
|
||||||
|
from apps.common.utils.redis_pool_manager import get_redis_pool_manager
|
||||||
|
return get_redis_pool_manager().get_buffer_size()
|
||||||
|
|
||||||
|
try:
|
||||||
|
buffer_size = await loop.run_in_executor(executor, get_buffer_size)
|
||||||
|
if buffer_size > 0:
|
||||||
|
log_config.info(f"发现 {buffer_size} 个事件在本地缓冲中,准备刷新...")
|
||||||
|
|
||||||
|
def flush_buffer():
|
||||||
|
from apps.common.utils.redis_pool_manager import flush_event_buffer
|
||||||
|
return flush_event_buffer('db_events')
|
||||||
|
|
||||||
|
flushed = await loop.run_in_executor(executor, flush_buffer)
|
||||||
|
log_config.info(f"✅ 缓冲刷新完成,成功刷新 {flushed} 个事件")
|
||||||
|
except Exception as e:
|
||||||
|
log_config.warning(f"⚠️ 刷新Redis缓冲失败: {e}")
|
||||||
|
|
||||||
|
# 3.2 关闭事件辅助模块
|
||||||
|
log_config.info("关闭事件辅助模块...")
|
||||||
|
try:
|
||||||
|
from apps.common.utils.event_helpers import shutdown_event_helpers
|
||||||
|
shutdown_event_helpers()
|
||||||
|
log_config.info("✅ 事件辅助模块已关闭")
|
||||||
|
except Exception as e:
|
||||||
|
log_config.warning(f"⚠️ 关闭事件辅助模块失败: {e}")
|
||||||
|
|
||||||
|
# 3.3 关闭数据库连接(最后关闭)
|
||||||
|
log_config.info("关闭数据库连接...")
|
||||||
try:
|
try:
|
||||||
from tortoise import Tortoise
|
from tortoise import Tortoise
|
||||||
await Tortoise.close_connections()
|
await Tortoise.close_connections()
|
||||||
log_config.info("✅ 数据库连接已关闭")
|
log_config.info("✅ 数据库连接已关闭")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_config.warning(f"⚠️ 关闭数据库连接时出错: {e}")
|
log_config.warning(f"⚠️ 关闭数据库连接失败: {e}")
|
||||||
|
|
||||||
# 1. 先停止 MySQL Binlog 监控(最依赖数据库)
|
# 阶段4: 关闭日志系统(最后)
|
||||||
if TURNON_BINLOG_LISTENER:
|
log_config.info("【阶段4】关闭日志系统...")
|
||||||
log_config.info("正在停止 MySQL Binlog 监控...")
|
|
||||||
binlog_listener.stop_monitoring()
|
|
||||||
log_config.info("==================MySQL Binlog监控已停止==================")
|
|
||||||
else:
|
|
||||||
log_config.debug("⚠️ MySQL Binlog监控未启动,无需停止")
|
|
||||||
|
|
||||||
# 2. 停止 Redis 相关任务
|
|
||||||
log_config.info("正在停止 Redis 相关任务...")
|
|
||||||
# 在线程池中执行缓冲刷新,避免阻塞事件循环
|
|
||||||
import concurrent.futures
|
|
||||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
def get_buffer_size():
|
|
||||||
from apps.common.utils.redis_pool_manager import get_redis_pool_manager
|
|
||||||
return get_redis_pool_manager().get_buffer_size()
|
|
||||||
|
|
||||||
buffer_size = await loop.run_in_executor(executor, get_buffer_size)
|
|
||||||
if buffer_size > 0:
|
|
||||||
log_config.info(f"发现 {buffer_size} 个事件在本地缓冲中,准备刷新...")
|
|
||||||
|
|
||||||
def flush_buffer():
|
|
||||||
from apps.common.utils.redis_pool_manager import flush_event_buffer
|
|
||||||
return flush_event_buffer('db_events')
|
|
||||||
|
|
||||||
flushed = await loop.run_in_executor(executor, flush_buffer)
|
|
||||||
log_config.info(f"缓冲刷新完成,成功刷新 {flushed} 个事件")
|
|
||||||
log_config.info("==================Redis 相关任务已停止==================")
|
|
||||||
|
|
||||||
# 2.1 关闭事件辅助模块(DeadLetter队列等)
|
|
||||||
log_config.info("正在关闭事件辅助模块...")
|
|
||||||
from apps.common.utils.event_helpers import shutdown_event_helpers
|
|
||||||
shutdown_event_helpers()
|
|
||||||
log_config.info("==================事件辅助模块已关闭==================")
|
|
||||||
|
|
||||||
# 3. 等待一段时间,确保所有任务完成
|
|
||||||
log_config.info("⏳ 等待所有后台任务完成...")
|
|
||||||
await asyncio.sleep(5) # 等待5秒,让所有任务完成
|
|
||||||
|
|
||||||
# 4. 关闭调度器
|
|
||||||
if TRUNON_SCHEDULER:
|
|
||||||
log_config.info("正在关闭调度器...")
|
|
||||||
scheduler_manager.shutdown()
|
|
||||||
log_config.info("==================定时任务管理器已关闭==================")
|
|
||||||
else:
|
|
||||||
log_config.debug("⚠️ 定时任务管理器未启动,无需关闭")
|
|
||||||
|
|
||||||
# 5. 停止资源监控
|
# 在关闭日志系统前输出最终提示
|
||||||
log_config.info("正在停止资源监控...")
|
|
||||||
resource_monitor.stop_monitoring()
|
|
||||||
log_config.info("==================系统资源监控已停止==================")
|
|
||||||
|
|
||||||
# 6. 停止事件聚合器
|
|
||||||
log_config.info("正在停止事件聚合器...")
|
|
||||||
log_config.info("==================事件聚合器已停止==================")
|
|
||||||
EVENT_AGGREGATOR.stop()
|
|
||||||
log_config.info("==================事件聚合器已停止==================")
|
|
||||||
|
|
||||||
# 6.1 关闭事件线程池管理器
|
|
||||||
log_config.info("正在关闭事件线程池...")
|
|
||||||
from globalobjects.event_aggregator import get_event_pool_manager
|
|
||||||
get_event_pool_manager().shutdown_all()
|
|
||||||
log_config.info("==================事件线程池已关闭==================")
|
|
||||||
|
|
||||||
# 7. 停止数据库健康检查器
|
|
||||||
log_config.info("正在停止数据库健康检查器...")
|
|
||||||
await stop_db_health_checker()
|
|
||||||
log_config.info("==================数据库健康检查器已停止==================")
|
|
||||||
|
|
||||||
# 8. 停止失败操作恢复管理器
|
|
||||||
log_config.info("正在停止OperationRecovery管理器...")
|
|
||||||
await stop_failed_operation_recovery()
|
|
||||||
log_config.info("==================OperationRecovery管理器已停止==================")
|
|
||||||
|
|
||||||
# 10. 取消后台任务
|
|
||||||
if 'db_check_task' in locals():
|
|
||||||
log_config.info("正在取消数据库连接检查任务...")
|
|
||||||
db_check_task.cancel()
|
|
||||||
try:
|
|
||||||
await db_check_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
log_config.info("==================数据库连接检查任务已取消==================")
|
|
||||||
|
|
||||||
if 'log_db_flush_task' in locals():
|
|
||||||
log_config.info("正在取消日志数据库批次刷新任务...")
|
|
||||||
log_db_flush_task.cancel()
|
|
||||||
try:
|
|
||||||
await log_db_flush_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
log_config.info("==================日志数据库批次刷新任务已取消==================")
|
|
||||||
|
|
||||||
if 'pool_monitor_task' in locals():
|
|
||||||
log_config.info("正在取消连接池监控任务...")
|
|
||||||
pool_monitor_task.cancel()
|
|
||||||
try:
|
|
||||||
await pool_monitor_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
log_config.info("==================连接池监控任务已取消==================")
|
|
||||||
|
|
||||||
# 取消 Redis 健康检查任务
|
|
||||||
if 'redis_check_task' in locals():
|
|
||||||
log_config.info("正在取消 Redis 健康检查任务...")
|
|
||||||
redis_check_task.cancel()
|
|
||||||
try:
|
|
||||||
await redis_check_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
log_config.info("==================Redis 健康检查任务已取消==================")
|
|
||||||
|
|
||||||
# 11. 等待一段时间,确保所有任务真正完成
|
|
||||||
log_config.info("⏳ 等待所有任务彻底完成...")
|
|
||||||
await asyncio.sleep(3) # 再等待3秒
|
|
||||||
|
|
||||||
log_config.info("==================应用关闭完成==================")
|
log_config.info("==================应用关闭完成==================")
|
||||||
|
log_config.info("所有资源已释放,服务已完全停止")
|
||||||
# 12. 关闭统一日志系统
|
|
||||||
await shutdown_logging()
|
await shutdown_logging()
|
||||||
|
|
||||||
# 13. 停止实时日志流服务
|
# 使用print确保关闭后的提示能输出(日志系统已关闭)
|
||||||
await stop_log_stream()
|
print("=" * 50)
|
||||||
|
print("MyAPS API 应用已完全关闭")
|
||||||
|
print("=" * 50)
|
||||||
|
|||||||
+29
-3
@@ -15,6 +15,18 @@ DOC_PREFIXES = ["/static/swagger"]
|
|||||||
MDS_PATHS = ["/mds", "/mds/material", "/mds/workcenter", "/mds/mat-ver",
|
MDS_PATHS = ["/mds", "/mds/material", "/mds/workcenter", "/mds/mat-ver",
|
||||||
"/mds/mat-wc", "/mds/mat-wc-bom", "/mds/mold", "/mds/mat-wc-mold"]
|
"/mds/mat-wc", "/mds/mat-wc-bom", "/mds/mold", "/mds/mat-wc-mold"]
|
||||||
|
|
||||||
|
# 公开的GET接口(不需要认证)
|
||||||
|
# 用于健康检查、静态资源等
|
||||||
|
PUBLIC_GET_PATHS = [
|
||||||
|
"/health", # K8s/负载均衡健康检查
|
||||||
|
"/health/database", # 数据库健康检查
|
||||||
|
]
|
||||||
|
|
||||||
|
# 公开的路径前缀
|
||||||
|
PUBLIC_GET_PREFIXES = [
|
||||||
|
"/static/", # 静态资源
|
||||||
|
]
|
||||||
|
|
||||||
# 缓存已注册的路由信息,避免每次请求都重新解析
|
# 缓存已注册的路由信息,避免每次请求都重新解析
|
||||||
REGISTERED_ROUTES = []
|
REGISTERED_ROUTES = []
|
||||||
|
|
||||||
@@ -235,10 +247,20 @@ def create_security_middleware():
|
|||||||
content={"status_code": 404, "success": 0, "meta": {}, "message": "Not Found"}
|
content={"status_code": 404, "success": 0, "meta": {}, "message": "Not Found"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 对GET和OPTIONS方法直接放行
|
# OPTIONS方法直接放行(CORS预检)
|
||||||
if request_method in ["GET", "OPTIONS"]:
|
if request_method == "OPTIONS":
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
# GET方法需要检查是否在公开路径列表
|
||||||
|
if request_method == "GET":
|
||||||
|
is_public_path = (
|
||||||
|
url_path in PUBLIC_GET_PATHS or
|
||||||
|
any(url_path.startswith(prefix) for prefix in PUBLIC_GET_PREFIXES)
|
||||||
|
)
|
||||||
|
if is_public_path:
|
||||||
|
return await call_next(request)
|
||||||
|
# 非公开GET路径需要继续鉴权
|
||||||
|
|
||||||
# 检查IP是否在白名单中
|
# 检查IP是否在白名单中
|
||||||
client_ip = request.client.host
|
client_ip = request.client.host
|
||||||
if is_ip_allowed(client_ip):
|
if is_ip_allowed(client_ip):
|
||||||
@@ -248,6 +270,10 @@ def create_security_middleware():
|
|||||||
if not API_KEY or request.headers.get("X-API-Key") == API_KEY:
|
if not API_KEY or request.headers.get("X-API-Key") == API_KEY:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
return JSONResponse(status_code=200, content={"status_code": 403, "success": 0, "meta": {}, "message": "Forbidden: Invalid or missing API Key"})
|
# 未授权请求返回真实HTTP 401状态码
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"status_code": 401, "success": 0, "meta": {}, "message": "Unauthorized: Invalid or missing API Key"}
|
||||||
|
)
|
||||||
|
|
||||||
return security_middleware
|
return security_middleware
|
||||||
|
|||||||
@@ -0,0 +1,181 @@
|
|||||||
|
"""
|
||||||
|
后台任务管理器
|
||||||
|
统一管理所有后台任务的生命周期
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Set, Optional, Any
|
||||||
|
from globalobjects import logger as log_config
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundTaskManager:
|
||||||
|
"""
|
||||||
|
后台任务统一管理器
|
||||||
|
|
||||||
|
功能:
|
||||||
|
1. 统一注册所有后台任务
|
||||||
|
2. 提供优雅关闭机制
|
||||||
|
3. 确保关闭顺序正确(先取消任务,再释放资源)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks: Dict[str, asyncio.Task] = {}
|
||||||
|
self._shutdown_timeout: float = 10.0
|
||||||
|
|
||||||
|
def register(self, name: str, task: asyncio.Task) -> asyncio.Task:
|
||||||
|
"""
|
||||||
|
注册后台任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 任务名称(用于日志和调试)
|
||||||
|
task: asyncio.Task实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
返回传入的task,方便链式调用
|
||||||
|
"""
|
||||||
|
self._tasks[name] = task
|
||||||
|
log_config.debug(f"后台任务已注册: {name}")
|
||||||
|
|
||||||
|
# 添加完成回调,自动清理
|
||||||
|
def _on_task_done(t: asyncio.Task):
|
||||||
|
task_name = None
|
||||||
|
for k, v in self._tasks.items():
|
||||||
|
if v == t:
|
||||||
|
task_name = k
|
||||||
|
break
|
||||||
|
if task_name:
|
||||||
|
self._tasks.pop(task_name, None)
|
||||||
|
if not t.cancelled():
|
||||||
|
exc = t.exception()
|
||||||
|
if exc:
|
||||||
|
log_config.error(f"后台任务异常退出: {task_name}", exc_info=exc)
|
||||||
|
else:
|
||||||
|
log_config.debug(f"后台任务正常完成: {task_name}")
|
||||||
|
|
||||||
|
task.add_done_callback(_on_task_done)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def create_and_register(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
coro,
|
||||||
|
*,
|
||||||
|
delay: float = 0.0
|
||||||
|
) -> asyncio.Task:
|
||||||
|
"""
|
||||||
|
创建并注册后台任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 任务名称
|
||||||
|
coro: 协程对象
|
||||||
|
delay: 启动延迟(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的Task实例
|
||||||
|
"""
|
||||||
|
async def _wrapped_coro():
|
||||||
|
if delay > 0:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
await coro
|
||||||
|
|
||||||
|
task = asyncio.create_task(_wrapped_coro())
|
||||||
|
return self.register(name, task)
|
||||||
|
|
||||||
|
async def cancel_all(self, timeout: Optional[float] = None) -> Dict[str, bool]:
|
||||||
|
"""
|
||||||
|
取消所有后台任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: 超时时间(秒),None使用默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务取消结果字典 {任务名: 是否成功}
|
||||||
|
"""
|
||||||
|
if not self._tasks:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
timeout = timeout or self._shutdown_timeout
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
log_config.info(f"开始取消 {len(self._tasks)} 个后台任务...")
|
||||||
|
|
||||||
|
# 第一阶段:发送取消信号
|
||||||
|
for name, task in list(self._tasks.items()):
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
log_config.debug(f"已发送取消信号: {name}")
|
||||||
|
|
||||||
|
# 第二阶段:等待任务完成
|
||||||
|
async def wait_for_task(name: str, task: asyncio.Task) -> bool:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(task, timeout=timeout)
|
||||||
|
return True
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return True # 正常取消
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
log_config.warning(f"任务取消超时: {name}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
log_config.error(f"任务取消异常: {name}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 并行等待所有任务
|
||||||
|
wait_tasks = {
|
||||||
|
name: asyncio.create_task(wait_for_task(name, task))
|
||||||
|
for name, task in list(self._tasks.items())
|
||||||
|
}
|
||||||
|
|
||||||
|
# 等待所有等待任务完成
|
||||||
|
for name, wait_task in wait_tasks.items():
|
||||||
|
try:
|
||||||
|
results[name] = await wait_task
|
||||||
|
except Exception as e:
|
||||||
|
log_config.error(f"等待任务完成失败: {name}, {e}")
|
||||||
|
results[name] = False
|
||||||
|
|
||||||
|
# 清理已完成的任务
|
||||||
|
self._tasks.clear()
|
||||||
|
|
||||||
|
# 统计结果
|
||||||
|
success_count = sum(1 for v in results.values() if v)
|
||||||
|
fail_count = len(results) - success_count
|
||||||
|
|
||||||
|
if fail_count > 0:
|
||||||
|
log_config.warning(f"任务取消完成: 成功 {success_count}, 失败 {fail_count}")
|
||||||
|
else:
|
||||||
|
log_config.info(f"所有 {success_count} 个后台任务已成功取消")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_task_names(self) -> Set[str]:
|
||||||
|
"""获取所有已注册任务名称"""
|
||||||
|
return set(self._tasks.keys())
|
||||||
|
|
||||||
|
def get_task_count(self) -> int:
|
||||||
|
"""获取已注册任务数量"""
|
||||||
|
return len(self._tasks)
|
||||||
|
|
||||||
|
def has_task(self, name: str) -> bool:
|
||||||
|
"""检查任务是否已注册"""
|
||||||
|
return name in self._tasks
|
||||||
|
|
||||||
|
def set_shutdown_timeout(self, timeout: float):
|
||||||
|
"""设置关闭超时时间"""
|
||||||
|
self._shutdown_timeout = timeout
|
||||||
|
|
||||||
|
|
||||||
|
# 全局任务管理器实例
|
||||||
|
_task_manager: Optional[BackgroundTaskManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_manager() -> BackgroundTaskManager:
|
||||||
|
"""获取全局任务管理器"""
|
||||||
|
global _task_manager
|
||||||
|
if _task_manager is None:
|
||||||
|
_task_manager = BackgroundTaskManager()
|
||||||
|
return _task_manager
|
||||||
|
|
||||||
|
|
||||||
|
def reset_task_manager():
|
||||||
|
"""重置任务管理器(用于测试)"""
|
||||||
|
global _task_manager
|
||||||
|
_task_manager = None
|
||||||
+261
-1
@@ -4,6 +4,7 @@ from datetime import datetime
|
|||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
import re
|
||||||
|
|
||||||
from tortoise import Tortoise
|
from tortoise import Tortoise
|
||||||
from tortoise.connection import connections
|
from tortoise.connection import connections
|
||||||
@@ -16,10 +17,269 @@ from globalobjects import logger as log_config
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
LOG_LEVEL = os.getenv("LOG_LEVEL") or "INFO"
|
LOG_LEVEL = os.getenv("LOG_LEVEL") or "INFO"
|
||||||
# 获取统一日志器
|
|
||||||
logger = log_config.get_logger(__name__, level=LOG_LEVEL)
|
logger = log_config.get_logger(__name__, level=LOG_LEVEL)
|
||||||
|
|
||||||
|
|
||||||
|
def escape_sql_value(value: Any) -> str:
|
||||||
|
"""
|
||||||
|
安全转义SQL值,防止SQL注入
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 要转义的值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转义后的安全字符串
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return "NULL"
|
||||||
|
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return "1" if value else "0"
|
||||||
|
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
|
||||||
|
|
||||||
|
# 字符串处理:转义单引号
|
||||||
|
str_value = str(value)
|
||||||
|
# 将单引号转义为两个单引号(SQL标准)
|
||||||
|
escaped = str_value.replace("'", "''")
|
||||||
|
return f"'{escaped}'"
|
||||||
|
|
||||||
|
|
||||||
|
def validate_identifier(identifier: str) -> str:
|
||||||
|
"""
|
||||||
|
验证并安全化SQL标识符(表名、字段名)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier: 标识符
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的标识符(用反引号包裹)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果标识符包含危险字符
|
||||||
|
"""
|
||||||
|
# 移除首尾空格
|
||||||
|
identifier = identifier.strip()
|
||||||
|
|
||||||
|
# 检查危险字符
|
||||||
|
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\x00', '\n', '\r']
|
||||||
|
for char in dangerous_chars:
|
||||||
|
if char in identifier:
|
||||||
|
raise ValueError(f"Invalid identifier: contains dangerous character '{char}'")
|
||||||
|
|
||||||
|
# 用反引号包裹
|
||||||
|
return f"`{identifier}`"
|
||||||
|
|
||||||
|
|
||||||
|
def build_safe_condition(field: str, operator: str, value: Any) -> str:
|
||||||
|
"""
|
||||||
|
构建安全的SQL条件表达式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: 字段名
|
||||||
|
operator: 操作符 (=, !=, >, <, >=, <=, LIKE, IN)
|
||||||
|
value: 值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的SQL条件字符串
|
||||||
|
"""
|
||||||
|
safe_field = validate_identifier(field)
|
||||||
|
operator = operator.upper().strip()
|
||||||
|
|
||||||
|
if operator == "IN":
|
||||||
|
if not isinstance(value, (list, tuple)):
|
||||||
|
raise ValueError("IN operator requires a list or tuple of values")
|
||||||
|
escaped_values = [escape_sql_value(v) for v in value]
|
||||||
|
return f"{safe_field} IN ({', '.join(escaped_values)})"
|
||||||
|
|
||||||
|
if operator == "LIKE":
|
||||||
|
return f"{safe_field} LIKE {escape_sql_value(value)}"
|
||||||
|
|
||||||
|
# 标准比较操作符
|
||||||
|
valid_operators = ['=', '!=', '<>', '>', '<', '>=', '<=', 'IS', 'IS NOT']
|
||||||
|
if operator not in valid_operators:
|
||||||
|
raise ValueError(f"Invalid operator: {operator}")
|
||||||
|
|
||||||
|
if operator in ('IS', 'IS NOT'):
|
||||||
|
if value is None:
|
||||||
|
return f"{safe_field} {operator} NULL"
|
||||||
|
raise ValueError(f"{operator} operator only accepts None value")
|
||||||
|
|
||||||
|
return f"{safe_field} {operator} {escape_sql_value(value)}"
|
||||||
|
|
||||||
|
|
||||||
|
def build_safe_filter(conditions: List[Tuple[str, str, Any]], logic: str = "AND") -> str:
|
||||||
|
"""
|
||||||
|
构建安全的WHERE条件字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conditions: 条件列表,每个条件为 (字段名, 操作符, 值) 元组
|
||||||
|
logic: 逻辑连接符 (AND/OR)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的WHERE条件字符串
|
||||||
|
"""
|
||||||
|
if not conditions:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
logic = logic.upper().strip()
|
||||||
|
if logic not in ("AND", "OR"):
|
||||||
|
raise ValueError(f"Invalid logic operator: {logic}")
|
||||||
|
|
||||||
|
safe_conditions = [build_safe_condition(*cond) for cond in conditions]
|
||||||
|
return f" {logic} ".join(safe_conditions)
|
||||||
|
|
||||||
|
|
||||||
|
def build_safe_order_by(fields: List[Tuple[str, str]]) -> str:
|
||||||
|
"""
|
||||||
|
构建安全的ORDER BY子句
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fields: 排序字段列表,每个元素为 (字段名, 方向) 元组
|
||||||
|
方向为 'ASC' 或 'DESC'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的ORDER BY字符串
|
||||||
|
"""
|
||||||
|
if not fields:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
order_parts = []
|
||||||
|
for field, direction in fields:
|
||||||
|
safe_field = validate_identifier(field)
|
||||||
|
direction = direction.upper().strip()
|
||||||
|
if direction not in ("ASC", "DESC"):
|
||||||
|
raise ValueError(f"Invalid order direction: {direction}")
|
||||||
|
order_parts.append(f"{safe_field} {direction}")
|
||||||
|
|
||||||
|
return ", ".join(order_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def build_safe_select(fields: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
构建安全的SELECT字段列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fields: 字段名列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
安全的SELECT字段字符串
|
||||||
|
"""
|
||||||
|
if not fields:
|
||||||
|
return "*"
|
||||||
|
|
||||||
|
return ", ".join(validate_identifier(f) for f in fields)
|
||||||
|
|
||||||
|
|
||||||
|
class SafeQueryBuilder:
|
||||||
|
"""
|
||||||
|
安全SQL查询构建器
|
||||||
|
|
||||||
|
提供链式调用的SQL构建接口,自动处理转义和验证
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, table_name: str):
|
||||||
|
"""
|
||||||
|
初始化查询构建器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
"""
|
||||||
|
self._table = validate_identifier(table_name)
|
||||||
|
self._select_fields = "*"
|
||||||
|
self._conditions = []
|
||||||
|
self._order_fields = []
|
||||||
|
self._limit = None
|
||||||
|
self._offset = None
|
||||||
|
|
||||||
|
def select(self, *fields: str) -> 'SafeQueryBuilder':
|
||||||
|
"""设置SELECT字段"""
|
||||||
|
if fields:
|
||||||
|
self._select_fields = build_safe_select(list(fields))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def where(self, field: str, operator: str, value: Any) -> 'SafeQueryBuilder':
|
||||||
|
"""添加WHERE条件"""
|
||||||
|
self._conditions.append((field, operator, value))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def where_in(self, field: str, values: List[Any]) -> 'SafeQueryBuilder':
|
||||||
|
"""添加IN条件"""
|
||||||
|
self._conditions.append((field, "IN", values))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def where_like(self, field: str, pattern: str) -> 'SafeQueryBuilder':
|
||||||
|
"""添加LIKE条件"""
|
||||||
|
self._conditions.append((field, "LIKE", pattern))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def where_between(self, field: str, start: Any, end: Any) -> 'SafeQueryBuilder':
|
||||||
|
"""添加BETWEEN条件"""
|
||||||
|
safe_field = validate_identifier(field)
|
||||||
|
self._conditions.append((f"{safe_field} >= {escape_sql_value(start)}", "=", True))
|
||||||
|
self._conditions.append((f"{safe_field} <= {escape_sql_value(end)}", "=", True))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def order_by(self, field: str, direction: str = "ASC") -> 'SafeQueryBuilder':
|
||||||
|
"""添加排序"""
|
||||||
|
self._order_fields.append((field, direction))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def limit(self, count: int) -> 'SafeQueryBuilder':
|
||||||
|
"""设置LIMIT"""
|
||||||
|
if count > 0:
|
||||||
|
self._limit = count
|
||||||
|
return self
|
||||||
|
|
||||||
|
def offset(self, count: int) -> 'SafeQueryBuilder':
|
||||||
|
"""设置OFFSET"""
|
||||||
|
if count >= 0:
|
||||||
|
self._offset = count
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build_select_sql(self) -> str:
|
||||||
|
"""构建SELECT SQL语句"""
|
||||||
|
sql = f"SELECT {self._select_fields} FROM {self._table}"
|
||||||
|
|
||||||
|
if self._conditions:
|
||||||
|
where_clause = build_safe_filter(self._conditions)
|
||||||
|
sql += f" WHERE {where_clause}"
|
||||||
|
|
||||||
|
if self._order_fields:
|
||||||
|
order_clause = build_safe_order_by(self._order_fields)
|
||||||
|
sql += f" ORDER BY {order_clause}"
|
||||||
|
|
||||||
|
if self._limit is not None:
|
||||||
|
sql += f" LIMIT {self._limit}"
|
||||||
|
|
||||||
|
if self._offset is not None:
|
||||||
|
sql += f" OFFSET {self._offset}"
|
||||||
|
|
||||||
|
return sql
|
||||||
|
|
||||||
|
def build_count_sql(self) -> str:
|
||||||
|
"""构建COUNT SQL语句"""
|
||||||
|
sql = f"SELECT COUNT(*) as total FROM {self._table}"
|
||||||
|
|
||||||
|
if self._conditions:
|
||||||
|
where_clause = build_safe_filter(self._conditions)
|
||||||
|
sql += f" WHERE {where_clause}"
|
||||||
|
|
||||||
|
return sql
|
||||||
|
|
||||||
|
def build_delete_sql(self) -> str:
|
||||||
|
"""构建DELETE SQL语句"""
|
||||||
|
if not self._conditions:
|
||||||
|
raise ValueError("DELETE operation requires WHERE conditions for safety")
|
||||||
|
|
||||||
|
where_clause = build_safe_filter(self._conditions)
|
||||||
|
return f"DELETE FROM {self._table} WHERE {where_clause}"
|
||||||
|
|
||||||
|
|
||||||
def dict_to_lower_keys(d: dict) -> dict:
|
def dict_to_lower_keys(d: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
将字典的键转换为小写
|
将字典的键转换为小写
|
||||||
|
|||||||
@@ -36,12 +36,6 @@ class LogRecord(BaseModel):
|
|||||||
|
|
||||||
extra: Optional[Dict[str, Any]] = None
|
extra: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
model_config = {
|
|
||||||
"json_encoders": {
|
|
||||||
datetime: lambda v: v.isoformat()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@field_validator('level_name', mode='before')
|
@field_validator('level_name', mode='before')
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_level_name(cls, v: str) -> str:
|
def validate_level_name(cls, v: str) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user