mirror of
https://github.com/rnvm9wjdtj-bot/myaps_api.git
synced 2026-06-02 05:54:40 +00:00
拆分main
This commit is contained in:
+31
@@ -0,0 +1,31 @@
|
||||
from fastapi import FastAPI
|
||||
from config.settings import PORT
|
||||
|
||||
def create_app(lifespan=None):
|
||||
app = FastAPI(
|
||||
title="MyAPS API",
|
||||
description="MyAPS API系统接口文档,提供物料、工作中心、工序、BOM等主数据管理,以及供应需求等生产数据管理功能。",
|
||||
version="1.0.0",
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
swagger_js_url="/static/swagger/swagger-ui-bundle.js",
|
||||
swagger_css_url="/static/swagger/swagger-ui.css",
|
||||
swagger_favicon_url="/static/swagger/favicon-32x32.png",
|
||||
swagger_ui_parameters={
|
||||
"configUrl": None,
|
||||
"defaultModelsExpandDepth": 2,
|
||||
"defaultModelExpandDepth": 3,
|
||||
"displayRequestDuration": True,
|
||||
"docExpansion": "list",
|
||||
"tryItOutEnabled": True,
|
||||
"jsonEditor": True,
|
||||
"showCommonExtensions": True,
|
||||
"showExtensions": True,
|
||||
"showMutatedRequest": True
|
||||
},
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
app.openapi_version = "3.0.2"
|
||||
|
||||
return app
|
||||
@@ -0,0 +1,30 @@
|
||||
from tortoise.contrib.fastapi import register_tortoise
|
||||
from config.settings import TORTOISE_ORM_CONFIG
|
||||
from globalobjects import logger as log_config
|
||||
|
||||
def register_database(app):
|
||||
register_tortoise(
|
||||
app=app,
|
||||
config=TORTOISE_ORM_CONFIG,
|
||||
# modules={"models": ["project_code.models"]},
|
||||
# generate_schemas=True, # 生产环境不要开,若数据库为空则自动生成对应表单
|
||||
# add_exception_handlers=True, # 生产环境不要开,会泄露调试信息
|
||||
)
|
||||
|
||||
async def check_db_connections():
|
||||
"""定期检查数据库连接状态"""
|
||||
try:
|
||||
from globalobjects.db_manager import get_db_managers
|
||||
db_managers = get_db_managers()
|
||||
for db_name, manager in db_managers.items():
|
||||
# 检查连接健康状态
|
||||
is_healthy = await manager.check_connection_health()
|
||||
if not is_healthy:
|
||||
log_config.warning(f"数据库连接 {db_name} 不健康,尝试刷新连接")
|
||||
await manager.refresh_connection()
|
||||
# 获取连接池状态
|
||||
pool_status = await manager.get_connection_pool_status()
|
||||
log_config.debug(f"连接池状态 - {db_name}: {pool_status}")
|
||||
log_config.debug("数据库连接检查完成")
|
||||
except Exception as e:
|
||||
log_config.error(f"数据库连接检查异常: {e}")
|
||||
@@ -0,0 +1,72 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
import time
|
||||
from globalobjects import logger as log_config
|
||||
from apps.data_opt.utils.scheduler import scheduler_manager, get_scheduler_status
|
||||
from apps.data_opt.utils.mysqlmonitor import mysql_monitor
|
||||
from apps.common.utils.resource_monitor import resource_monitor
|
||||
from globalobjects import EVENT_AGGREGATOR
|
||||
from config.settings import TURNON_DBMONITOR, TRUNON_SCHEDULER
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app):
|
||||
"""应用生命周期管理器"""
|
||||
# 应用启动时执行的操作
|
||||
log_config.initialize_logging()
|
||||
|
||||
# 将主应用事件循环传递给调度器
|
||||
main_loop = asyncio.get_running_loop()
|
||||
scheduler_manager.set_main_loop(main_loop)
|
||||
log_config.info(f"已将主应用事件循环传递给调度器: {main_loop}")
|
||||
|
||||
# 初始化并启动定时任务管理器
|
||||
if TRUNON_SCHEDULER:
|
||||
scheduler_manager.init_scheduler()
|
||||
scheduler_manager.start()
|
||||
log_config.info(f"定时任务管理器状态: {get_scheduler_status()}")
|
||||
else:
|
||||
log_config.warning("⚠️ 定时任务管理器未启动")
|
||||
|
||||
if TURNON_DBMONITOR:
|
||||
mysql_monitor.start_monitoring()
|
||||
log_config.info("MySQL Binlog监控已启动")
|
||||
else:
|
||||
log_config.warning("⚠️ MySQL Binlog监控未启动")
|
||||
|
||||
# 启动资源监控
|
||||
log_config.info("开始启动资源监控...")
|
||||
resource_monitor.start_monitoring(interval=30)
|
||||
log_config.info("系统资源监控已启动")
|
||||
|
||||
# 等待一段时间,确保资源监控线程正常启动
|
||||
time.sleep(1)
|
||||
log_config.info("应用启动完成,开始运行")
|
||||
|
||||
yield # 应用运行期间
|
||||
|
||||
# 应用关闭时执行的操作
|
||||
log_config.info("应用关闭中...")
|
||||
|
||||
if TURNON_DBMONITOR:
|
||||
mysql_monitor.stop_monitoring()
|
||||
log_config.info("MySQL Binlog监控已停止")
|
||||
else:
|
||||
log_config.debug("⚠️ MySQL Binlog监控未启动,无需停止")
|
||||
|
||||
# 关闭调度器
|
||||
if TRUNON_SCHEDULER:
|
||||
scheduler_manager.shutdown()
|
||||
log_config.info("定时任务管理器已关闭")
|
||||
else:
|
||||
log_config.debug("⚠️ 定时任务管理器未启动,无需关闭")
|
||||
|
||||
# 停止资源监控
|
||||
resource_monitor.stop_monitoring()
|
||||
log_config.info("系统资源监控已停止")
|
||||
|
||||
# 停止事件聚合器
|
||||
EVENT_AGGREGATOR.stop()
|
||||
log_config.info("事件聚合器已停止")
|
||||
|
||||
# 关闭统一日志系统
|
||||
log_config.shutdown_logging()
|
||||
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
IP_WHITELIST = os.getenv("IP_WHITELIST", "").split(",")
|
||||
API_KEY = os.getenv("API_KEY", "")
|
||||
|
||||
def create_security_middleware():
|
||||
async def security_middleware(request: Request, call_next):
|
||||
# 对GET和OPTIONS方法直接放行
|
||||
if request.method in ["GET", "OPTIONS"]:
|
||||
return await call_next(request)
|
||||
|
||||
# 允许查阅文档等无需认证的请求
|
||||
url_path = request.url.path
|
||||
if url_path in ["/docs", "/redoc", "/openapi.json"] or url_path.startswith("/static/swagger"):
|
||||
return await call_next(request)
|
||||
|
||||
# 检查IP是否在白名单中
|
||||
client_ip = request.client.host
|
||||
if client_ip in ["127.0.0.1", "localhost"] or client_ip in IP_WHITELIST:
|
||||
return await call_next(request)
|
||||
|
||||
# 若不在IP白名单则需要认证请求头X-API-Key
|
||||
if request.headers.get("X-API-Key") == API_KEY:
|
||||
return await call_next(request)
|
||||
|
||||
return JSONResponse(status_code=200, content={"status_code": 403, "success": 0, "meta": {}, "message": "Forbidden: Invalid or missing API Key"})
|
||||
|
||||
return security_middleware
|
||||
@@ -0,0 +1,34 @@
|
||||
def setup_custom_openapi(app):
|
||||
original_openapi = app.openapi
|
||||
|
||||
def openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
# 调用原始的openapi方法获取schema
|
||||
openapi_schema = original_openapi()
|
||||
|
||||
# 确保所有schemas都有详细的描述和示例
|
||||
for schema_name, schema in openapi_schema.get("components", {}).get("schemas", {}).items():
|
||||
# 添加更多描述信息
|
||||
if "properties" in schema:
|
||||
for prop_name, prop in schema["properties"].items():
|
||||
# 确保每个属性都有描述
|
||||
if "description" not in prop and "title" not in prop:
|
||||
prop["description"] = f"字段: {prop_name}"
|
||||
# 确保每个属性都有示例值
|
||||
if "example" not in prop and "examples" not in prop:
|
||||
# 根据类型设置默认示例值
|
||||
if prop.get("type") == "string":
|
||||
prop["example"] = f"示例{prop_name}"
|
||||
elif prop.get("type") == "integer":
|
||||
prop["example"] = 1
|
||||
elif prop.get("type") == "number":
|
||||
prop["example"] = 1.0
|
||||
elif prop.get("type") == "boolean":
|
||||
prop["example"] = True
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = openapi
|
||||
@@ -0,0 +1,15 @@
|
||||
from apps.io_api.routers import rt as io_rt
|
||||
from apps.data_opt.routers import rt as do_rt
|
||||
|
||||
def register_routes(app):
|
||||
app.include_router(io_rt, prefix="/api", tags=[])
|
||||
app.include_router(do_rt, prefix="/do", tags=[])
|
||||
|
||||
# 根路由
|
||||
@app.get("/")
|
||||
async def read_root():
|
||||
return {
|
||||
"message": "Welcome to MyAPI",
|
||||
"version": "1.0.0",
|
||||
"status": "running"
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
import asyncio
|
||||
from globalobjects import logger as log_config
|
||||
from apps.data_opt.utils.scheduler import initialize_scheduler
|
||||
from core.database import check_db_connections
|
||||
|
||||
async def startup_event():
|
||||
"""应用启动事件"""
|
||||
# 初始化定时任务管理器
|
||||
await initialize_scheduler()
|
||||
|
||||
# 设置定期检查数据库连接的任务
|
||||
async def schedule_db_checks():
|
||||
"""定期执行数据库连接检查"""
|
||||
while True:
|
||||
await check_db_connections()
|
||||
# 每300秒(5分钟)检查一次
|
||||
await asyncio.sleep(300)
|
||||
|
||||
# 启动数据库连接检查任务
|
||||
asyncio.create_task(schedule_db_checks())
|
||||
log_config.info("数据库连接检查任务已启动")
|
||||
@@ -0,0 +1,144 @@
|
||||
import asyncio
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from globalobjects import logger as log_config
|
||||
|
||||
class WebSocketConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections = set()
|
||||
self.ip_connection_count = {}
|
||||
self.max_connections_per_ip = 10 # 每IP最大连接数
|
||||
self.max_total_connections = 100 # 总最大连接数
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
client_ip = websocket.client.host
|
||||
|
||||
# 检查连接数限制
|
||||
if len(self.active_connections) >= self.max_total_connections:
|
||||
log_config.warning(f"WebSocket 连接数达到上限: {self.max_total_connections}")
|
||||
return False
|
||||
|
||||
# 检查每IP连接数限制
|
||||
if self.ip_connection_count.get(client_ip, 0) >= self.max_connections_per_ip:
|
||||
log_config.warning(f"IP {client_ip} 连接数达到上限: {self.max_connections_per_ip}")
|
||||
return False
|
||||
|
||||
# 接受连接
|
||||
await websocket.accept()
|
||||
|
||||
# 记录连接
|
||||
self.active_connections.add(websocket)
|
||||
self.ip_connection_count[client_ip] = self.ip_connection_count.get(client_ip, 0) + 1
|
||||
|
||||
return True
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
client_ip = websocket.client.host
|
||||
if client_ip in self.ip_connection_count:
|
||||
self.ip_connection_count[client_ip] -= 1
|
||||
if self.ip_connection_count[client_ip] <= 0:
|
||||
del self.ip_connection_count[client_ip]
|
||||
|
||||
def get_active_connections_count(self):
|
||||
return len(self.active_connections)
|
||||
|
||||
def get_ip_connection_count(self, client_ip):
|
||||
return self.ip_connection_count.get(client_ip, 0)
|
||||
|
||||
# 初始化 WebSocket 连接管理器
|
||||
websocket_manager = WebSocketConnectionManager()
|
||||
|
||||
async def websocket_endpoint(websocket: WebSocket, path: str = ""):
|
||||
"""
|
||||
通用 WebSocket 端点,捕获所有 WebSocket 连接请求。
|
||||
使用 {path:path} 通配符匹配任意路径,避免 "Unsupported upgrade request" 警告。
|
||||
"""
|
||||
# 尝试连接,检查连接数限制
|
||||
connected = await websocket_manager.connect(websocket)
|
||||
if not connected:
|
||||
# 连接被拒绝,直接关闭
|
||||
try:
|
||||
await websocket.close(code=1008, reason="Connection limit reached")
|
||||
except:
|
||||
pass
|
||||
return
|
||||
|
||||
full_path = f"/{path}" if path else "/"
|
||||
client_info = {
|
||||
"client": f"{websocket.client.host}:{websocket.client.port}",
|
||||
"path": full_path,
|
||||
"query_params": dict(websocket.query_params),
|
||||
"headers": {k: v for k, v in websocket.headers.items() if k.lower() not in ['cookie', 'authorization']},
|
||||
}
|
||||
log_config.info(f"WebSocket 连接请求: {client_info}")
|
||||
|
||||
try:
|
||||
# 保持连接并接收消息,3秒超时后自动关闭(缩短超时时间)
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(websocket.receive_text(), timeout=3.0)
|
||||
log_config.info(f"WebSocket 收到消息 [{full_path}]: {message}")
|
||||
await websocket.send_json({
|
||||
"status": "received",
|
||||
"path": full_path,
|
||||
"message": message
|
||||
})
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
log_config.info(f"WebSocket 客户端断开连接: {client_info['client']} - {full_path}")
|
||||
except Exception as e:
|
||||
log_config.warning(f"WebSocket 异常 [{full_path}]: {e}")
|
||||
finally:
|
||||
# 确保断开连接并清理资源
|
||||
try:
|
||||
await websocket.close()
|
||||
except:
|
||||
pass
|
||||
# 从连接管理器中移除
|
||||
websocket_manager.disconnect(websocket)
|
||||
log_config.info(f"WebSocket 连接已关闭: {client_info['client']} - {full_path}")
|
||||
# 记录当前活跃连接数
|
||||
log_config.debug(f"当前活跃 WebSocket 连接数: {websocket_manager.get_active_connections_count()}")
|
||||
|
||||
async def websocket_root(websocket: WebSocket):
|
||||
"""
|
||||
根路径 WebSocket 端点,捕获对根路径的 WebSocket 连接请求。
|
||||
"""
|
||||
# 尝试连接,检查连接数限制
|
||||
connected = await websocket_manager.connect(websocket)
|
||||
if not connected:
|
||||
# 连接被拒绝,直接关闭
|
||||
try:
|
||||
await websocket.close(code=1008, reason="Connection limit reached")
|
||||
except:
|
||||
pass
|
||||
return
|
||||
|
||||
client_info = {
|
||||
"client": f"{websocket.client.host}:{websocket.client.port}",
|
||||
"path": "/",
|
||||
"query_params": dict(websocket.query_params),
|
||||
"headers": {k: v for k, v in websocket.headers.items() if k.lower() not in ['cookie', 'authorization']},
|
||||
}
|
||||
log_config.info(f"WebSocket 根路径连接请求: {client_info}")
|
||||
|
||||
try:
|
||||
# 3秒超时后自动关闭(缩短超时时间)
|
||||
await asyncio.wait_for(websocket.receive_text(), timeout=3.0)
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||
pass
|
||||
except Exception as e:
|
||||
log_config.warning(f"WebSocket 根路径异常: {e}")
|
||||
finally:
|
||||
# 确保断开连接并清理资源
|
||||
try:
|
||||
await websocket.close()
|
||||
except:
|
||||
pass
|
||||
# 从连接管理器中移除
|
||||
websocket_manager.disconnect(websocket)
|
||||
log_config.info(f"WebSocket 根路径连接已关闭: {client_info['client']}")
|
||||
# 记录当前活跃连接数
|
||||
log_config.debug(f"当前活跃 WebSocket 连接数: {websocket_manager.get_active_connections_count()}")
|
||||
@@ -1,6 +1,8 @@
|
||||
import os, uvicorn, asyncio#, hashlib
|
||||
import os, uvicorn
|
||||
from dotenv import load_dotenv
|
||||
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
|
||||
# 加载环境变量
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -8,196 +10,28 @@ env_file = os.path.join(BASE_DIR, '.env')
|
||||
os.environ.setdefault('ENV_FILE', env_file)
|
||||
load_dotenv(env_file)
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
# from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
from tortoise.contrib.fastapi import register_tortoise
|
||||
|
||||
from config.settings import TORTOISE_ORM_CONFIG, PORT, BASE_DIR, TURNON_DBMONITOR, TRUNON_SCHEDULER
|
||||
from globalobjects import logger as log_config
|
||||
from apps.io_api.routers import rt as io_rt
|
||||
# 导入模块
|
||||
from core.app import create_app
|
||||
from core.lifespan import lifespan
|
||||
from core.openapi import setup_custom_openapi
|
||||
from core.middleware import create_security_middleware, IP_WHITELIST, API_KEY
|
||||
from core.websocket import websocket_endpoint, websocket_root
|
||||
from core.routes_register import register_routes
|
||||
from core.database import register_database
|
||||
from core.tasks import startup_event
|
||||
from apps.io_api.utils.common import register_exception_handlers
|
||||
from apps.data_opt.routers import rt as do_rt
|
||||
|
||||
|
||||
# 导入全局MySQL监控实例
|
||||
from apps.data_opt.utils.mysqlmonitor import mysql_monitor
|
||||
from apps.data_opt.utils.scheduler import scheduler_manager, get_scheduler_status
|
||||
# 导入资源监控
|
||||
from apps.common.utils.resource_monitor import resource_monitor
|
||||
# 导入事件聚合器
|
||||
from globalobjects import EVENT_AGGREGATOR
|
||||
|
||||
# 定义生命周期事件处理器
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理器"""
|
||||
# 应用启动时执行的操作
|
||||
# 初始化统一日志系统
|
||||
log_config.initialize_logging()
|
||||
|
||||
# 将主应用事件循环传递给调度器(在 uvicorn 启动后获取正确的事件循环)
|
||||
main_loop = asyncio.get_running_loop()
|
||||
scheduler_manager.set_main_loop(main_loop)
|
||||
log_config.info(f"已将主应用事件循环传递给调度器: {main_loop}")
|
||||
|
||||
# 初始化并启动定时任务管理器
|
||||
if TRUNON_SCHEDULER:
|
||||
scheduler_manager.init_scheduler()
|
||||
scheduler_manager.start()
|
||||
log_config.info(f"定时任务管理器状态: {get_scheduler_status()}")
|
||||
else:
|
||||
log_config.warning("⚠️ 定时任务管理器未启动")
|
||||
|
||||
if TURNON_DBMONITOR:
|
||||
mysql_monitor.start_monitoring()
|
||||
log_config.info("MySQL Binlog监控已启动")
|
||||
else:
|
||||
log_config.warning("⚠️ MySQL Binlog监控未启动")
|
||||
|
||||
# 启动资源监控
|
||||
log_config.info("开始启动资源监控...")
|
||||
resource_monitor.start_monitoring(interval=30) # 每30秒监控一次
|
||||
log_config.info("系统资源监控已启动")
|
||||
|
||||
# 等待一段时间,确保资源监控线程正常启动
|
||||
import time
|
||||
time.sleep(1)
|
||||
log_config.info("应用启动完成,开始运行")
|
||||
|
||||
yield # 应用运行期间
|
||||
|
||||
# 应用关闭时执行的操作
|
||||
log_config.info("应用关闭中...")
|
||||
|
||||
if TURNON_DBMONITOR:
|
||||
mysql_monitor.stop_monitoring()
|
||||
log_config.info("MySQL Binlog监控已停止")
|
||||
else:
|
||||
log_config.debug("⚠️ MySQL Binlog监控未启动,无需停止")
|
||||
|
||||
# 关闭调度器
|
||||
if TRUNON_SCHEDULER:
|
||||
scheduler_manager.shutdown()
|
||||
log_config.info("定时任务管理器已关闭")
|
||||
else:
|
||||
log_config.debug("⚠️ 定时任务管理器未启动,无需关闭")
|
||||
|
||||
# 停止资源监控
|
||||
resource_monitor.stop_monitoring()
|
||||
log_config.info("系统资源监控已停止")
|
||||
|
||||
# 停止事件聚合器
|
||||
EVENT_AGGREGATOR.stop()
|
||||
log_config.info("事件聚合器已停止")
|
||||
|
||||
# 关闭统一日志系统
|
||||
log_config.shutdown_logging()
|
||||
|
||||
|
||||
|
||||
# 创建FastAPI应用实例
|
||||
app = FastAPI(
|
||||
title="MyAPS API",
|
||||
description="MyAPS API系统接口文档,提供物料、工作中心、工序、BOM等主数据管理,以及供应需求等生产数据管理功能。",
|
||||
version="1.0.0",
|
||||
# 配置文档页面URL,禁用默认的 Swagger UI,防止CDN资源不稳定导致无法访问文档页
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
swagger_js_url="/static/swagger/swagger-ui-bundle.js",
|
||||
swagger_css_url="/static/swagger/swagger-ui.css",
|
||||
swagger_favicon_url="/static/swagger/favicon-32x32.png",
|
||||
swagger_ui_parameters={
|
||||
"configUrl": None,
|
||||
"defaultModelsExpandDepth": 2, # 默认展开模型深度
|
||||
"defaultModelExpandDepth": 3, # 默认展开模型属性深度
|
||||
"displayRequestDuration": True, # 显示请求持续时间
|
||||
"docExpansion": "list", # 文档展开方式: 'list', 'full', 'none'
|
||||
"tryItOutEnabled": True, # 启用"Try it out"功能
|
||||
"jsonEditor": True, # 使用JSON编辑器编辑请求体
|
||||
"showCommonExtensions": True, # 显示扩展字段
|
||||
"showExtensions": True, # 显示OpenAPI扩展
|
||||
"showMutatedRequest": True # 显示修改后的请求
|
||||
},
|
||||
lifespan=lifespan # 使用新的生命周期事件处理器
|
||||
)
|
||||
|
||||
# 增强OpenAPI schema配置
|
||||
app.openapi_version = "3.0.2"
|
||||
|
||||
# 保存原始的openapi方法
|
||||
original_openapi = app.openapi
|
||||
|
||||
# 自定义OpenAPI schema生成
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
# 调用原始的openapi方法获取schema
|
||||
openapi_schema = original_openapi()
|
||||
|
||||
# 确保所有schemas都有详细的描述和示例
|
||||
for schema_name, schema in openapi_schema.get("components", {}).get("schemas", {}).items():
|
||||
# 添加更多描述信息
|
||||
if "properties" in schema:
|
||||
for prop_name, prop in schema["properties"].items():
|
||||
# 确保每个属性都有描述
|
||||
if "description" not in prop and "title" not in prop:
|
||||
prop["description"] = f"字段: {prop_name}"
|
||||
# 确保每个属性都有示例值
|
||||
if "example" not in prop and "examples" not in prop:
|
||||
# 根据类型设置默认示例值
|
||||
if prop.get("type") == "string":
|
||||
prop["example"] = f"示例{prop_name}"
|
||||
elif prop.get("type") == "integer":
|
||||
prop["example"] = 1
|
||||
elif prop.get("type") == "number":
|
||||
prop["example"] = 1.0
|
||||
elif prop.get("type") == "boolean":
|
||||
prop["example"] = True
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
from config.settings import PORT
|
||||
|
||||
# 创建应用实例
|
||||
app = create_app(lifespan=lifespan)
|
||||
# 设置自定义的OpenAPI schema生成函数
|
||||
app.openapi = custom_openapi
|
||||
|
||||
|
||||
# 定义安全验证验证中间件
|
||||
|
||||
IP_WHITELIST = os.getenv("IP_WHITELIST", "")
|
||||
API_KEY = os.getenv("API_KEY", "")
|
||||
setup_custom_openapi(app)
|
||||
|
||||
# 配置安全中间件
|
||||
if IP_WHITELIST or API_KEY:
|
||||
IP_WHITELIST = os.getenv("IP_WHITELIST", "").split(",")
|
||||
@app.middleware("http")
|
||||
async def security_middleware(request: Request, call_next):
|
||||
# 对GET和OPTIONS方法直接放行
|
||||
if request.method in ["GET", "OPTIONS"]:
|
||||
return await call_next(request)
|
||||
app.middleware("http")(create_security_middleware())
|
||||
|
||||
# 允许查阅文档等无需认证的请求
|
||||
url_path = request.url.path
|
||||
if url_path in ["/docs", "/redoc", "/openapi.json"] or url_path.startswith("/static/swagger"):
|
||||
return await call_next(request)
|
||||
|
||||
# 检查IP是否在白名单中
|
||||
client_ip = request.client.host
|
||||
if client_ip in ["127.0.0.1", "localhost"] or client_ip in IP_WHITELIST:
|
||||
return await call_next(request)
|
||||
|
||||
# 若不在IP白名单则需要认证请求头X-API-Key
|
||||
if request.headers.get("X-API-Key") == API_KEY:
|
||||
return await call_next(request)
|
||||
|
||||
return JSONResponse(status_code=200, content={"status_code": 403, "success": 0, "meta": {}, "message": "Forbidden: Invalid or missing API Key"})
|
||||
|
||||
|
||||
# 配置CORS中间件解决跨域访问问题
|
||||
# 配置CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 在生产环境中应该设置具体的域名列表
|
||||
@@ -206,12 +40,10 @@ app.add_middleware(
|
||||
allow_headers=["*"], # 允许所有请求头
|
||||
)
|
||||
|
||||
|
||||
|
||||
# 挂载静态文件
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
# 覆写原有文档页面路由函数,所有静态资源采用本地文件
|
||||
# 覆写原有文档页面路由函数
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
async def custom_swagger_ui_html():
|
||||
return get_swagger_ui_html(
|
||||
@@ -223,224 +55,21 @@ async def custom_swagger_ui_html():
|
||||
swagger_ui_parameters=app.swagger_ui_parameters
|
||||
)
|
||||
|
||||
|
||||
# 注册自定义的异常处理器
|
||||
# 注册异常处理器
|
||||
register_exception_handlers(app)
|
||||
# register_data_manager_exception_handlers(app)
|
||||
|
||||
# WebSocket 连接管理器,用于跟踪和管理活跃的 WebSocket 连接
|
||||
class WebSocketConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections = set()
|
||||
self.ip_connection_count = {}
|
||||
self.max_connections_per_ip = 10 # 每IP最大连接数
|
||||
self.max_total_connections = 100 # 总最大连接数
|
||||
# 注册WebSocket路由
|
||||
app.websocket("/{path:path}")(websocket_endpoint)
|
||||
app.websocket("/")(websocket_root)
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
client_ip = websocket.client.host
|
||||
|
||||
# 检查连接数限制
|
||||
if len(self.active_connections) >= self.max_total_connections:
|
||||
log_config.warning(f"WebSocket 连接数达到上限: {self.max_total_connections}")
|
||||
return False
|
||||
|
||||
# 检查每IP连接数限制
|
||||
if self.ip_connection_count.get(client_ip, 0) >= self.max_connections_per_ip:
|
||||
log_config.warning(f"IP {client_ip} 连接数达到上限: {self.max_connections_per_ip}")
|
||||
return False
|
||||
|
||||
# 接受连接
|
||||
await websocket.accept()
|
||||
|
||||
# 记录连接
|
||||
self.active_connections.add(websocket)
|
||||
self.ip_connection_count[client_ip] = self.ip_connection_count.get(client_ip, 0) + 1
|
||||
|
||||
return True
|
||||
# 注册路由
|
||||
register_routes(app)
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
client_ip = websocket.client.host
|
||||
if client_ip in self.ip_connection_count:
|
||||
self.ip_connection_count[client_ip] -= 1
|
||||
if self.ip_connection_count[client_ip] <= 0:
|
||||
del self.ip_connection_count[client_ip]
|
||||
|
||||
def get_active_connections_count(self):
|
||||
return len(self.active_connections)
|
||||
|
||||
def get_ip_connection_count(self, client_ip):
|
||||
return self.ip_connection_count.get(client_ip, 0)
|
||||
|
||||
# 初始化 WebSocket 连接管理器
|
||||
websocket_manager = WebSocketConnectionManager()
|
||||
|
||||
# WebSocket 路由 - 捕获并记录所有客户端升级请求,避免 "Unsupported upgrade request" 警告
|
||||
# 使用路径参数 {path:path} 匹配任意路径(包括嵌套路径)
|
||||
@app.websocket("/{path:path}")
|
||||
async def websocket_endpoint(websocket: WebSocket, path: str = ""):
|
||||
"""
|
||||
通用 WebSocket 端点,捕获所有 WebSocket 连接请求。
|
||||
使用 {path:path} 通配符匹配任意路径,避免 "Unsupported upgrade request" 警告。
|
||||
"""
|
||||
# 尝试连接,检查连接数限制
|
||||
connected = await websocket_manager.connect(websocket)
|
||||
if not connected:
|
||||
# 连接被拒绝,直接关闭
|
||||
try:
|
||||
await websocket.close(code=1008, reason="Connection limit reached")
|
||||
except:
|
||||
pass
|
||||
return
|
||||
|
||||
full_path = f"/{path}" if path else "/"
|
||||
client_info = {
|
||||
"client": f"{websocket.client.host}:{websocket.client.port}",
|
||||
"path": full_path,
|
||||
"query_params": dict(websocket.query_params),
|
||||
"headers": {k: v for k, v in websocket.headers.items() if k.lower() not in ['cookie', 'authorization']},
|
||||
}
|
||||
log_config.info(f"WebSocket 连接请求: {client_info}")
|
||||
|
||||
try:
|
||||
# 保持连接并接收消息,3秒超时后自动关闭(缩短超时时间)
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(websocket.receive_text(), timeout=3.0)
|
||||
log_config.info(f"WebSocket 收到消息 [{full_path}]: {message}")
|
||||
await websocket.send_json({
|
||||
"status": "received",
|
||||
"path": full_path,
|
||||
"message": message
|
||||
})
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
log_config.info(f"WebSocket 客户端断开连接: {client_info['client']} - {full_path}")
|
||||
except Exception as e:
|
||||
log_config.warning(f"WebSocket 异常 [{full_path}]: {e}")
|
||||
finally:
|
||||
# 确保断开连接并清理资源
|
||||
try:
|
||||
await websocket.close()
|
||||
except:
|
||||
pass
|
||||
# 从连接管理器中移除
|
||||
websocket_manager.disconnect(websocket)
|
||||
log_config.info(f"WebSocket 连接已关闭: {client_info['client']} - {full_path}")
|
||||
# 记录当前活跃连接数
|
||||
log_config.debug(f"当前活跃 WebSocket 连接数: {websocket_manager.get_active_connections_count()}")
|
||||
|
||||
# 为根路径添加单独的 WebSocket 路由
|
||||
@app.websocket("/")
|
||||
async def websocket_root(websocket: WebSocket):
|
||||
"""
|
||||
根路径 WebSocket 端点,捕获对根路径的 WebSocket 连接请求。
|
||||
"""
|
||||
# 尝试连接,检查连接数限制
|
||||
connected = await websocket_manager.connect(websocket)
|
||||
if not connected:
|
||||
# 连接被拒绝,直接关闭
|
||||
try:
|
||||
await websocket.close(code=1008, reason="Connection limit reached")
|
||||
except:
|
||||
pass
|
||||
return
|
||||
|
||||
client_info = {
|
||||
"client": f"{websocket.client.host}:{websocket.client.port}",
|
||||
"path": "/",
|
||||
"query_params": dict(websocket.query_params),
|
||||
"headers": {k: v for k, v in websocket.headers.items() if k.lower() not in ['cookie', 'authorization']},
|
||||
}
|
||||
log_config.info(f"WebSocket 根路径连接请求: {client_info}")
|
||||
|
||||
try:
|
||||
# 3秒超时后自动关闭(缩短超时时间)
|
||||
await asyncio.wait_for(websocket.receive_text(), timeout=3.0)
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||
pass
|
||||
except Exception as e:
|
||||
log_config.warning(f"WebSocket 根路径异常: {e}")
|
||||
finally:
|
||||
# 确保断开连接并清理资源
|
||||
try:
|
||||
await websocket.close()
|
||||
except:
|
||||
pass
|
||||
# 从连接管理器中移除
|
||||
websocket_manager.disconnect(websocket)
|
||||
log_config.info(f"WebSocket 根路径连接已关闭: {client_info['client']}")
|
||||
# 记录当前活跃连接数
|
||||
log_config.debug(f"当前活跃 WebSocket 连接数: {websocket_manager.get_active_connections_count()}")
|
||||
|
||||
# 包含子路由
|
||||
app.include_router(io_rt, prefix="/api", tags=[])
|
||||
app.include_router(do_rt, prefix="/do", tags=[])
|
||||
|
||||
|
||||
# 根路由
|
||||
@app.get("/")
|
||||
async def read_root():
|
||||
return {
|
||||
"message": "Welcome to MyAPI",
|
||||
"version": "1.0.0",
|
||||
"status": "running"
|
||||
}
|
||||
|
||||
# 注册Tortoise ORM
|
||||
|
||||
register_tortoise(
|
||||
app = app,
|
||||
config=TORTOISE_ORM_CONFIG,
|
||||
# modules={"models": ["project_code.models"]},
|
||||
# generate_schemas=True, # 生产环境不要开,若数据库为空则自动生成对应表单
|
||||
# add_exception_handlers=True, # 生产环境不要开,会泄露调试信息
|
||||
)
|
||||
|
||||
# 初始化定时任务管理器
|
||||
from apps.data_opt.utils.scheduler import initialize_scheduler, get_scheduler_status, scheduler_manager
|
||||
|
||||
# 定期检查数据库连接状态
|
||||
async def check_db_connections():
|
||||
"""定期检查数据库连接状态"""
|
||||
try:
|
||||
from globalobjects.db_manager import get_db_managers
|
||||
db_managers = get_db_managers()
|
||||
for db_name, manager in db_managers.items():
|
||||
# 检查连接健康状态
|
||||
is_healthy = await manager.check_connection_health()
|
||||
if not is_healthy:
|
||||
log_config.warning(f"数据库连接 {db_name} 不健康,尝试刷新连接")
|
||||
await manager.refresh_connection()
|
||||
# 获取连接池状态
|
||||
pool_status = await manager.get_connection_pool_status()
|
||||
log_config.debug(f"连接池状态 - {db_name}: {pool_status}")
|
||||
log_config.debug("数据库连接检查完成")
|
||||
except Exception as e:
|
||||
log_config.error(f"数据库连接检查异常: {e}")
|
||||
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""应用启动事件"""
|
||||
# 初始化定时任务管理器
|
||||
await initialize_scheduler()
|
||||
|
||||
# 设置定期检查数据库连接的任务
|
||||
import asyncio
|
||||
async def schedule_db_checks():
|
||||
"""定期执行数据库连接检查"""
|
||||
while True:
|
||||
await check_db_connections()
|
||||
# 每300秒(5分钟)检查一次
|
||||
await asyncio.sleep(300)
|
||||
|
||||
# 启动数据库连接检查任务
|
||||
asyncio.create_task(schedule_db_checks())
|
||||
log_config.info("数据库连接检查任务已启动")
|
||||
# 注册数据库
|
||||
register_database(app)
|
||||
|
||||
# 注册启动事件
|
||||
app.on_event("startup")(startup_event)
|
||||
|
||||
# 启动说明:
|
||||
# 使用命令: uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
Reference in New Issue
Block a user