mirror of
https://github.com/rnvm9wjdtj-bot/myaps_api.git
synced 2026-06-02 05:54:40 +00:00
fix: 移除不正确的gitignore
This commit is contained in:
+2
-2
@@ -96,8 +96,8 @@ storage/
|
||||
storage/**/
|
||||
|
||||
|
||||
migrations/
|
||||
migrations/**/
|
||||
# 只忽略根目录下的 migrations(通常是 Alembic/Tortoise 自动生成的迁移文件)
|
||||
/migrations/
|
||||
|
||||
offline_packages/
|
||||
offline_packages/**/
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""MDS数据库迁移模块"""
|
||||
from .version_manager import SchemaVersionManager
|
||||
from .model_diff import ModelDiffer
|
||||
|
||||
__all__ = ["SchemaVersionManager", "ModelDiffer"]
|
||||
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
数据库迁移 API 路由
|
||||
提供差异检测、执行迁移、权限校验等接口
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Query, Body, Request
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .model_diff import ModelDiffer
|
||||
from apps.data_opt.mds.staging_cleaner import STAGING_MODEL_MAPPING, ensure_config_initialized
|
||||
from apps.io_api.utils.common import standard_response
|
||||
from core.database import get_db_connection_safely
|
||||
from core.settings import THIS_DB_NAME
|
||||
from globalobjects import logger as log_config
|
||||
|
||||
logger = log_config.get_logger(__name__)
|
||||
|
||||
migrate_router = APIRouter(prefix="/migrate", tags=["数据库迁移"])
|
||||
|
||||
|
||||
class ExecuteMigrateRequest(BaseModel):
|
||||
"""执行迁移请求"""
|
||||
tables: List[str] = []
|
||||
force: bool = False
|
||||
|
||||
|
||||
async def execute_migration_sql(sql_statement: str, db_name: str) -> bool:
|
||||
"""
|
||||
执行单条 ALTER TABLE 语句
|
||||
|
||||
Args:
|
||||
sql_statement: SQL 语句
|
||||
db_name: 数据库名
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
try:
|
||||
conn = await get_db_connection_safely(db_name)
|
||||
await conn.execute_query(sql_statement)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"执行迁移失败: {sql_statement}, 错误: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
@migrate_router.get("/check-permission", summary="校验数据库迁移权限")
|
||||
async def check_migration_permission(
|
||||
request: Request,
|
||||
db_name: str = Query(None, description="数据库名称,默认当前账套")
|
||||
):
|
||||
"""
|
||||
校验当前用户是否有数据库迁移权限
|
||||
|
||||
通过尝试执行一个无害的查询来验证数据库连接和权限
|
||||
|
||||
Returns:
|
||||
{"success": 1, "data": {"has_permission": true/false, "message": "..."}}
|
||||
"""
|
||||
try:
|
||||
target_db = db_name or THIS_DB_NAME
|
||||
conn = await get_db_connection_safely(target_db)
|
||||
|
||||
result = await conn.execute_query("SELECT current_user, current_database()")
|
||||
|
||||
if result and result[1]:
|
||||
row = result[1][0]
|
||||
current_user = row.get('current_user', 'unknown')
|
||||
current_db = row.get('current_database', 'unknown')
|
||||
|
||||
test_query = """
|
||||
SELECT has_schema_privilege(current_user, 'public', 'CREATE') as can_create,
|
||||
has_schema_privilege(current_user, 'public', 'USAGE') as can_usage
|
||||
"""
|
||||
priv_result = await conn.execute_query(test_query)
|
||||
|
||||
if priv_result and priv_result[1]:
|
||||
priv_row = priv_result[1][0]
|
||||
can_create = priv_row.get('can_create', False)
|
||||
can_usage = priv_row.get('can_usage', False)
|
||||
|
||||
if can_create and can_usage:
|
||||
return standard_response(
|
||||
success=1,
|
||||
message="权限校验通过",
|
||||
data={
|
||||
"has_permission": True,
|
||||
"user": current_user,
|
||||
"database": current_db,
|
||||
"can_create": can_create,
|
||||
"can_usage": can_usage
|
||||
}
|
||||
)
|
||||
else:
|
||||
return standard_response(
|
||||
success=1,
|
||||
message="权限不足:需要 CREATE 和 USAGE 权限",
|
||||
data={
|
||||
"has_permission": False,
|
||||
"user": current_user,
|
||||
"database": current_db,
|
||||
"can_create": can_create,
|
||||
"can_usage": can_usage
|
||||
}
|
||||
)
|
||||
|
||||
return standard_response(
|
||||
success=1,
|
||||
message="无法获取权限信息",
|
||||
data={"has_permission": False}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"权限校验失败: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return standard_response(
|
||||
success=1,
|
||||
message=f"权限校验异常: {str(e)}",
|
||||
data={"has_permission": False, "error": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@migrate_router.get("/diff", summary="获取模型与数据库差异")
|
||||
async def get_migration_diff(
|
||||
request: Request,
|
||||
db_name: str = Query(None, description="数据库名称,默认当前账套")
|
||||
):
|
||||
"""
|
||||
检测 ORM 模型与数据库表结构的差异
|
||||
|
||||
Returns:
|
||||
差异列表、统计信息
|
||||
"""
|
||||
try:
|
||||
ensure_config_initialized()
|
||||
|
||||
target_db = db_name or THIS_DB_NAME
|
||||
differ = ModelDiffer(target_db)
|
||||
|
||||
result = await differ.diff_all(STAGING_MODEL_MAPPING)
|
||||
|
||||
return standard_response(
|
||||
success=1,
|
||||
message=f"检测完成,发现 {result['total_fields']} 个待迁移字段",
|
||||
data=result
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"差异检测失败: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return standard_response(success=0, message=str(e))
|
||||
|
||||
|
||||
@migrate_router.post("/execute", summary="执行数据库迁移")
|
||||
async def execute_migration(
|
||||
request: Request,
|
||||
data: ExecuteMigrateRequest = Body(...),
|
||||
db_name: str = Query(None, description="数据库名称,默认当前账套")
|
||||
):
|
||||
"""
|
||||
执行数据库迁移
|
||||
|
||||
Args:
|
||||
tables: 要迁移的表列表,空数组表示全部
|
||||
force: 是否强制执行(跳过幂等检查)
|
||||
|
||||
Returns:
|
||||
迁移结果
|
||||
"""
|
||||
try:
|
||||
ensure_config_initialized()
|
||||
|
||||
target_db = db_name or THIS_DB_NAME
|
||||
differ = ModelDiffer(target_db)
|
||||
|
||||
diff_result = await differ.diff_all(STAGING_MODEL_MAPPING)
|
||||
all_differences = diff_result.get("differences", [])
|
||||
|
||||
if data.tables:
|
||||
all_differences = [
|
||||
d for d in all_differences
|
||||
if d.get("table_key") in data.tables
|
||||
]
|
||||
|
||||
if not all_differences:
|
||||
return standard_response(
|
||||
success=1,
|
||||
message="无需迁移,模型与数据库一致",
|
||||
data={
|
||||
"version": None,
|
||||
"applied_count": 0,
|
||||
"failed_count": 0,
|
||||
"skipped_count": 0,
|
||||
"changes": []
|
||||
}
|
||||
)
|
||||
|
||||
version = datetime.now().strftime("V%Y%m%d%H%M%S")
|
||||
|
||||
applied_count = 0
|
||||
failed_count = 0
|
||||
changes = []
|
||||
|
||||
for diff in all_differences:
|
||||
sql = diff.get("sql")
|
||||
table = diff.get("table")
|
||||
field = diff.get("field")
|
||||
db_field = diff.get("db_field")
|
||||
|
||||
success = await execute_migration_sql(sql, target_db)
|
||||
|
||||
change_record = {
|
||||
"table": table,
|
||||
"field": field,
|
||||
"db_field": db_field,
|
||||
"sql": sql,
|
||||
"success": success,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
changes.append(change_record)
|
||||
|
||||
if success:
|
||||
applied_count += 1
|
||||
logger.info(f"迁移成功: {table}.{db_field}")
|
||||
else:
|
||||
failed_count += 1
|
||||
logger.error(f"迁移失败: {table}.{db_field}")
|
||||
|
||||
return standard_response(
|
||||
success=1 if failed_count == 0 else 0,
|
||||
message=f"迁移完成,版本 {version},成功 {applied_count} 个,失败 {failed_count} 个",
|
||||
data={
|
||||
"version": version,
|
||||
"applied_count": applied_count,
|
||||
"failed_count": failed_count,
|
||||
"skipped_count": 0,
|
||||
"changes": changes
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"执行迁移失败: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return standard_response(success=0, message=str(e))
|
||||
|
||||
|
||||
@migrate_router.get("/versions", summary="获取迁移历史")
|
||||
async def get_migration_versions(
|
||||
request: Request,
|
||||
limit: int = Query(20, description="返回数量限制")
|
||||
):
|
||||
"""
|
||||
获取迁移历史记录(简化版本,返回最近执行的迁移信息)
|
||||
|
||||
Note: 完整版本需要 version_manager.py 配合数据库存储
|
||||
"""
|
||||
try:
|
||||
return standard_response(
|
||||
success=1,
|
||||
message="查询成功",
|
||||
data={
|
||||
"versions": [],
|
||||
"note": "完整版本记录需要启用 version_manager"
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询迁移历史失败: {str(e)}")
|
||||
return standard_response(success=0, message=str(e))
|
||||
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
ORM 模型与数据库表结构差异检测模块
|
||||
对比 Tortoise ORM 模型定义与实际数据库表结构,生成 ALTER TABLE 语句
|
||||
"""
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from tortoise import Tortoise
|
||||
from tortoise.fields import (
|
||||
CharField, IntField, FloatField, DecimalField, BooleanField,
|
||||
DatetimeField, DateField, TextField, JSONField
|
||||
)
|
||||
|
||||
from core.database import get_db_connection_safely
|
||||
from core.settings import THIS_DB_NAME
|
||||
from globalobjects import logger as log_config
|
||||
|
||||
logger = log_config.get_logger(__name__)
|
||||
|
||||
|
||||
class AlterStmt:
|
||||
"""ALTER TABLE 语句封装"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table_name: str,
|
||||
field_name: str,
|
||||
db_field_name: str,
|
||||
sql_type: str,
|
||||
sql_statement: str,
|
||||
is_nullable: bool = True,
|
||||
default_value: Any = None
|
||||
):
|
||||
self.table_name = table_name
|
||||
self.field_name = field_name
|
||||
self.db_field_name = db_field_name
|
||||
self.sql_type = sql_type
|
||||
self.sql_statement = sql_statement
|
||||
self.is_nullable = is_nullable
|
||||
self.default_value = default_value
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"table": self.table_name,
|
||||
"field": self.field_name,
|
||||
"db_field": self.db_field_name,
|
||||
"sql_type": self.sql_type,
|
||||
"sql": self.sql_statement,
|
||||
"nullable": self.is_nullable,
|
||||
"default": self.default_value
|
||||
}
|
||||
|
||||
|
||||
class ModelDiffer:
|
||||
"""
|
||||
ORM 模型与数据库表结构差异检测器
|
||||
|
||||
检测 Tortoise ORM 模型定义中新增的字段,生成对应的 ALTER TABLE ADD COLUMN 语句
|
||||
"""
|
||||
|
||||
FIELD_TYPE_MAPPING = {
|
||||
'CharField': 'VARCHAR',
|
||||
'IntField': 'INTEGER',
|
||||
'FloatField': 'DOUBLE PRECISION',
|
||||
'DecimalField': 'DECIMAL',
|
||||
'BooleanField': 'BOOLEAN',
|
||||
'DatetimeField': 'TIMESTAMP',
|
||||
'DateField': 'DATE',
|
||||
'TextField': 'TEXT',
|
||||
'JSONField': 'JSONB',
|
||||
}
|
||||
|
||||
def __init__(self, db_name: str = None):
|
||||
"""
|
||||
初始化差异检测器
|
||||
|
||||
Args:
|
||||
db_name: 数据库名称,默认使用 THIS_DB_NAME
|
||||
"""
|
||||
self.db_name = db_name or THIS_DB_NAME
|
||||
|
||||
async def get_db_columns(self, table_name: str) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
查询数据库表的字段信息
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
字段信息字典 {字段名: {type, nullable, default}}
|
||||
"""
|
||||
conn = await get_db_connection_safely(self.db_name)
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_default,
|
||||
character_maximum_length,
|
||||
numeric_precision,
|
||||
numeric_scale
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = $1 AND table_schema = 'public'
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await conn.execute_query(query, (table_name,))
|
||||
|
||||
columns = {}
|
||||
for row in (result[1] or []):
|
||||
col_name = row['column_name']
|
||||
columns[col_name] = {
|
||||
'type': row['data_type'],
|
||||
'nullable': row['is_nullable'] == 'YES',
|
||||
'default': row['column_default'],
|
||||
'max_length': row['character_maximum_length'],
|
||||
'precision': row['numeric_precision'],
|
||||
'scale': row['numeric_scale']
|
||||
}
|
||||
|
||||
return columns
|
||||
|
||||
def get_model_fields(self, model_class) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 ORM 模型的字段定义
|
||||
|
||||
Args:
|
||||
model_class: Tortoise ORM 模型类
|
||||
|
||||
Returns:
|
||||
字段映射字典 {Python字段名: field对象}
|
||||
"""
|
||||
return dict(model_class._meta.fields_map)
|
||||
|
||||
def _map_field_type_to_sql(self, field) -> Tuple[str, str]:
|
||||
"""
|
||||
将 Tortoise ORM 字段类型映射为 SQL 类型
|
||||
|
||||
Args:
|
||||
field: Tortoise 字段对象
|
||||
|
||||
Returns:
|
||||
(SQL类型字符串, 完整类型定义字符串)
|
||||
"""
|
||||
field_type_name = type(field).__name__
|
||||
|
||||
if field_type_name == 'CharField':
|
||||
max_length = getattr(field, 'max_length', 255)
|
||||
return 'VARCHAR', f'VARCHAR({max_length})'
|
||||
|
||||
elif field_type_name == 'DecimalField':
|
||||
precision = getattr(field, 'max_digits', 10)
|
||||
scale = getattr(field, 'decimal_places', 2)
|
||||
return 'DECIMAL', f'DECIMAL({precision}, {scale})'
|
||||
|
||||
elif field_type_name in self.FIELD_TYPE_MAPPING:
|
||||
return self.FIELD_TYPE_MAPPING[field_type_name], self.FIELD_TYPE_MAPPING[field_type_name]
|
||||
|
||||
else:
|
||||
return 'VARCHAR', 'VARCHAR(255)'
|
||||
|
||||
def _generate_alter_sql(
|
||||
self,
|
||||
table_name: str,
|
||||
field_name: str,
|
||||
field
|
||||
) -> AlterStmt:
|
||||
"""
|
||||
生成 ALTER TABLE ADD COLUMN 语句
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
field_name: Python 字段名
|
||||
field: Tortoise 字段对象
|
||||
|
||||
Returns:
|
||||
AlterStmt 对象
|
||||
"""
|
||||
db_field_name = getattr(field, 'source_field', None) or field_name
|
||||
_, sql_type_def = self._map_field_type_to_sql(field)
|
||||
|
||||
is_nullable = getattr(field, 'null', True)
|
||||
default_value = getattr(field, 'default', None)
|
||||
|
||||
has_default = default_value is not None and str(default_value) != 'PydanticUndefined'
|
||||
|
||||
parts = [f'ALTER TABLE "{table_name}" ADD COLUMN "{db_field_name}" {sql_type_def}']
|
||||
|
||||
if not is_nullable:
|
||||
parts.append('NOT NULL')
|
||||
|
||||
if has_default:
|
||||
if isinstance(default_value, str):
|
||||
parts.append(f"DEFAULT '{default_value}'")
|
||||
elif isinstance(default_value, bool):
|
||||
parts.append(f"DEFAULT {'TRUE' if default_value else 'FALSE'}")
|
||||
else:
|
||||
parts.append(f'DEFAULT {default_value}')
|
||||
|
||||
sql_statement = ' '.join(parts)
|
||||
|
||||
return AlterStmt(
|
||||
table_name=table_name,
|
||||
field_name=field_name,
|
||||
db_field_name=db_field_name,
|
||||
sql_type=sql_type_def,
|
||||
sql_statement=sql_statement,
|
||||
is_nullable=is_nullable,
|
||||
default_value=default_value if has_default else None
|
||||
)
|
||||
|
||||
async def diff(self, model_class) -> List[AlterStmt]:
|
||||
"""
|
||||
对比单个模型与数据库表的差异
|
||||
|
||||
Args:
|
||||
model_class: Tortoise ORM 模型类
|
||||
|
||||
Returns:
|
||||
差异列表(需要新增的字段)
|
||||
"""
|
||||
table_name = model_class._meta.db_table
|
||||
|
||||
db_columns = await self.get_db_columns(table_name)
|
||||
model_fields = self.get_model_fields(model_class)
|
||||
|
||||
differences = []
|
||||
|
||||
for field_name, field in model_fields.items():
|
||||
if field_name.startswith('_') and field_name not in (
|
||||
'_staging_id', '_source_system', '_source_id', '_status',
|
||||
'_error_msg', '_transform_rules', '_retry_count',
|
||||
'_createtime', '_updatetime', '_synced_id', '_synced_time'
|
||||
):
|
||||
continue
|
||||
|
||||
db_field_name = getattr(field, 'source_field', None) or field_name
|
||||
|
||||
if db_field_name not in db_columns:
|
||||
alter_stmt = self._generate_alter_sql(table_name, field_name, field)
|
||||
differences.append(alter_stmt)
|
||||
logger.debug(f"检测到新字段: {table_name}.{db_field_name}")
|
||||
|
||||
return differences
|
||||
|
||||
async def diff_all(self, model_mapping: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
批量检测所有表的差异
|
||||
|
||||
Args:
|
||||
model_mapping: 模型映射字典 {表键: 模型类}
|
||||
|
||||
Returns:
|
||||
汇总结果 {
|
||||
"differences": [...],
|
||||
"total_tables": int,
|
||||
"total_fields": int
|
||||
}
|
||||
"""
|
||||
all_differences = []
|
||||
tables_with_diff = set()
|
||||
|
||||
for table_key, model_class in model_mapping.items():
|
||||
try:
|
||||
diffs = await self.diff(model_class)
|
||||
if diffs:
|
||||
tables_with_diff.add(table_key)
|
||||
for diff in diffs:
|
||||
all_differences.append({
|
||||
"table_key": table_key,
|
||||
**diff.to_dict()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"检测差异失败 [{table_key}]: {str(e)}")
|
||||
|
||||
return {
|
||||
"differences": all_differences,
|
||||
"total_tables": len(tables_with_diff),
|
||||
"total_fields": len(all_differences)
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
SchemaVersionManager - 迁移版本记录管理
|
||||
负责记录每次迁移的版本和变更内容
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from tortoise import Tortoise
|
||||
|
||||
class SchemaVersionManager:
|
||||
"""版本管理器"""
|
||||
|
||||
VERSION_TABLE = "t_schema_version"
|
||||
|
||||
async def ensure_table_exists(self, db_name: str):
|
||||
"""确保版本表存在"""
|
||||
conn = Tortoise.get_connection(db_name)
|
||||
await conn.execute_query('''
|
||||
CREATE TABLE IF NOT EXISTS t_schema_version (
|
||||
id SERIAL PRIMARY KEY,
|
||||
version VARCHAR(16) UNIQUE NOT NULL,
|
||||
applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
|
||||
description TEXT,
|
||||
sql_scripts TEXT,
|
||||
status VARCHAR(16) DEFAULT 'applied'
|
||||
)
|
||||
''')
|
||||
|
||||
async def get_applied_versions(self, db_name: str) -> List[str]:
|
||||
"""获取已应用的版本列表"""
|
||||
conn = Tortoise.get_connection(db_name)
|
||||
result = await conn.execute_query(
|
||||
"SELECT version FROM t_schema_version WHERE status='applied' ORDER BY version"
|
||||
)
|
||||
return [row[0] for row in result[1]] if result[1] else []
|
||||
|
||||
async def get_latest_version(self, db_name: str) -> Optional[str]:
|
||||
"""获取最新版本号"""
|
||||
conn = Tortoise.get_connection(db_name)
|
||||
result = await conn.execute_query(
|
||||
"SELECT version FROM t_schema_version WHERE status='applied' ORDER BY version DESC LIMIT 1"
|
||||
)
|
||||
return result[1][0][0] if result[1] else None
|
||||
|
||||
async def generate_next_version(self, db_name: str) -> str:
|
||||
"""生成下一个版本号"""
|
||||
latest = await self.get_latest_version(db_name)
|
||||
if latest:
|
||||
version_num = int(latest.replace('V', '')) + 1
|
||||
else:
|
||||
version_num = 1
|
||||
return f"V{version_num:03d}"
|
||||
|
||||
async def record_version(self, db_name: str, version: str, description: str, sql_scripts: str):
|
||||
"""记录新版本"""
|
||||
conn = Tortoise.get_connection(db_name)
|
||||
await conn.execute_query(
|
||||
"""INSERT INTO t_schema_version (version, description, sql_scripts, status)
|
||||
VALUES ($1, $2, $3, 'applied')""",
|
||||
(version, description, sql_scripts)
|
||||
)
|
||||
|
||||
async def version_exists(self, db_name: str, version: str) -> bool:
|
||||
"""检查版本是否已存在"""
|
||||
conn = Tortoise.get_connection(db_name)
|
||||
result = await conn.execute_query(
|
||||
"SELECT COUNT(*) FROM t_schema_version WHERE version = $1",
|
||||
(version,)
|
||||
)
|
||||
return result[1][0][0] > 0
|
||||
Reference in New Issue
Block a user