fix: 移除不正确的gitignore

This commit is contained in:
2026-05-26 15:37:10 +08:00
parent c0b0707c24
commit 50d6af9a9a
5 changed files with 626 additions and 2 deletions
+2 -2
View File
@@ -96,8 +96,8 @@ storage/
storage/**/
migrations/
migrations/**/
# 只忽略根目录下的 migrations(通常是 Alembic/Tortoise 自动生成的迁移文件)
/migrations/
offline_packages/
offline_packages/**/
+5
View File
@@ -0,0 +1,5 @@
"""MDS数据库迁移模块"""
from .version_manager import SchemaVersionManager
from .model_diff import ModelDiffer
__all__ = ["SchemaVersionManager", "ModelDiffer"]
@@ -0,0 +1,271 @@
"""
数据库迁移 API 路由
提供差异检测、执行迁移、权限校验等接口
"""
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Query, Body, Request
from pydantic import BaseModel
from .model_diff import ModelDiffer
from apps.data_opt.mds.staging_cleaner import STAGING_MODEL_MAPPING, ensure_config_initialized
from apps.io_api.utils.common import standard_response
from core.database import get_db_connection_safely
from core.settings import THIS_DB_NAME
from globalobjects import logger as log_config
logger = log_config.get_logger(__name__)
migrate_router = APIRouter(prefix="/migrate", tags=["数据库迁移"])
class ExecuteMigrateRequest(BaseModel):
"""执行迁移请求"""
tables: List[str] = []
force: bool = False
async def execute_migration_sql(sql_statement: str, db_name: str) -> bool:
"""
执行单条 ALTER TABLE 语句
Args:
sql_statement: SQL 语句
db_name: 数据库名
Returns:
是否成功
"""
try:
conn = await get_db_connection_safely(db_name)
await conn.execute_query(sql_statement)
return True
except Exception as e:
logger.error(f"执行迁移失败: {sql_statement}, 错误: {str(e)}")
return False
@migrate_router.get("/check-permission", summary="校验数据库迁移权限")
async def check_migration_permission(
request: Request,
db_name: str = Query(None, description="数据库名称,默认当前账套")
):
"""
校验当前用户是否有数据库迁移权限
通过尝试执行一个无害的查询来验证数据库连接和权限
Returns:
{"success": 1, "data": {"has_permission": true/false, "message": "..."}}
"""
try:
target_db = db_name or THIS_DB_NAME
conn = await get_db_connection_safely(target_db)
result = await conn.execute_query("SELECT current_user, current_database()")
if result and result[1]:
row = result[1][0]
current_user = row.get('current_user', 'unknown')
current_db = row.get('current_database', 'unknown')
test_query = """
SELECT has_schema_privilege(current_user, 'public', 'CREATE') as can_create,
has_schema_privilege(current_user, 'public', 'USAGE') as can_usage
"""
priv_result = await conn.execute_query(test_query)
if priv_result and priv_result[1]:
priv_row = priv_result[1][0]
can_create = priv_row.get('can_create', False)
can_usage = priv_row.get('can_usage', False)
if can_create and can_usage:
return standard_response(
success=1,
message="权限校验通过",
data={
"has_permission": True,
"user": current_user,
"database": current_db,
"can_create": can_create,
"can_usage": can_usage
}
)
else:
return standard_response(
success=1,
message="权限不足:需要 CREATE 和 USAGE 权限",
data={
"has_permission": False,
"user": current_user,
"database": current_db,
"can_create": can_create,
"can_usage": can_usage
}
)
return standard_response(
success=1,
message="无法获取权限信息",
data={"has_permission": False}
)
except Exception as e:
import traceback
logger.error(f"权限校验失败: {str(e)}")
logger.error(traceback.format_exc())
return standard_response(
success=1,
message=f"权限校验异常: {str(e)}",
data={"has_permission": False, "error": str(e)}
)
@migrate_router.get("/diff", summary="获取模型与数据库差异")
async def get_migration_diff(
request: Request,
db_name: str = Query(None, description="数据库名称,默认当前账套")
):
"""
检测 ORM 模型与数据库表结构的差异
Returns:
差异列表、统计信息
"""
try:
ensure_config_initialized()
target_db = db_name or THIS_DB_NAME
differ = ModelDiffer(target_db)
result = await differ.diff_all(STAGING_MODEL_MAPPING)
return standard_response(
success=1,
message=f"检测完成,发现 {result['total_fields']} 个待迁移字段",
data=result
)
except Exception as e:
import traceback
logger.error(f"差异检测失败: {str(e)}")
logger.error(traceback.format_exc())
return standard_response(success=0, message=str(e))
@migrate_router.post("/execute", summary="执行数据库迁移")
async def execute_migration(
request: Request,
data: ExecuteMigrateRequest = Body(...),
db_name: str = Query(None, description="数据库名称,默认当前账套")
):
"""
执行数据库迁移
Args:
tables: 要迁移的表列表,空数组表示全部
force: 是否强制执行(跳过幂等检查)
Returns:
迁移结果
"""
try:
ensure_config_initialized()
target_db = db_name or THIS_DB_NAME
differ = ModelDiffer(target_db)
diff_result = await differ.diff_all(STAGING_MODEL_MAPPING)
all_differences = diff_result.get("differences", [])
if data.tables:
all_differences = [
d for d in all_differences
if d.get("table_key") in data.tables
]
if not all_differences:
return standard_response(
success=1,
message="无需迁移,模型与数据库一致",
data={
"version": None,
"applied_count": 0,
"failed_count": 0,
"skipped_count": 0,
"changes": []
}
)
version = datetime.now().strftime("V%Y%m%d%H%M%S")
applied_count = 0
failed_count = 0
changes = []
for diff in all_differences:
sql = diff.get("sql")
table = diff.get("table")
field = diff.get("field")
db_field = diff.get("db_field")
success = await execute_migration_sql(sql, target_db)
change_record = {
"table": table,
"field": field,
"db_field": db_field,
"sql": sql,
"success": success,
"timestamp": datetime.now().isoformat()
}
changes.append(change_record)
if success:
applied_count += 1
logger.info(f"迁移成功: {table}.{db_field}")
else:
failed_count += 1
logger.error(f"迁移失败: {table}.{db_field}")
return standard_response(
success=1 if failed_count == 0 else 0,
message=f"迁移完成,版本 {version},成功 {applied_count} 个,失败 {failed_count}",
data={
"version": version,
"applied_count": applied_count,
"failed_count": failed_count,
"skipped_count": 0,
"changes": changes
}
)
except Exception as e:
import traceback
logger.error(f"执行迁移失败: {str(e)}")
logger.error(traceback.format_exc())
return standard_response(success=0, message=str(e))
@migrate_router.get("/versions", summary="获取迁移历史")
async def get_migration_versions(
request: Request,
limit: int = Query(20, description="返回数量限制")
):
"""
获取迁移历史记录(简化版本,返回最近执行的迁移信息)
Note: 完整版本需要 version_manager.py 配合数据库存储
"""
try:
return standard_response(
success=1,
message="查询成功",
data={
"versions": [],
"note": "完整版本记录需要启用 version_manager"
}
)
except Exception as e:
logger.error(f"查询迁移历史失败: {str(e)}")
return standard_response(success=0, message=str(e))
+279
View File
@@ -0,0 +1,279 @@
"""
ORM 模型与数据库表结构差异检测模块
对比 Tortoise ORM 模型定义与实际数据库表结构,生成 ALTER TABLE 语句
"""
from typing import List, Dict, Any, Optional, Tuple
from tortoise import Tortoise
from tortoise.fields import (
CharField, IntField, FloatField, DecimalField, BooleanField,
DatetimeField, DateField, TextField, JSONField
)
from core.database import get_db_connection_safely
from core.settings import THIS_DB_NAME
from globalobjects import logger as log_config
logger = log_config.get_logger(__name__)
class AlterStmt:
"""ALTER TABLE 语句封装"""
def __init__(
self,
table_name: str,
field_name: str,
db_field_name: str,
sql_type: str,
sql_statement: str,
is_nullable: bool = True,
default_value: Any = None
):
self.table_name = table_name
self.field_name = field_name
self.db_field_name = db_field_name
self.sql_type = sql_type
self.sql_statement = sql_statement
self.is_nullable = is_nullable
self.default_value = default_value
def to_dict(self) -> Dict[str, Any]:
return {
"table": self.table_name,
"field": self.field_name,
"db_field": self.db_field_name,
"sql_type": self.sql_type,
"sql": self.sql_statement,
"nullable": self.is_nullable,
"default": self.default_value
}
class ModelDiffer:
"""
ORM 模型与数据库表结构差异检测器
检测 Tortoise ORM 模型定义中新增的字段,生成对应的 ALTER TABLE ADD COLUMN 语句
"""
FIELD_TYPE_MAPPING = {
'CharField': 'VARCHAR',
'IntField': 'INTEGER',
'FloatField': 'DOUBLE PRECISION',
'DecimalField': 'DECIMAL',
'BooleanField': 'BOOLEAN',
'DatetimeField': 'TIMESTAMP',
'DateField': 'DATE',
'TextField': 'TEXT',
'JSONField': 'JSONB',
}
def __init__(self, db_name: str = None):
"""
初始化差异检测器
Args:
db_name: 数据库名称,默认使用 THIS_DB_NAME
"""
self.db_name = db_name or THIS_DB_NAME
async def get_db_columns(self, table_name: str) -> Dict[str, Dict[str, Any]]:
"""
查询数据库表的字段信息
Args:
table_name: 表名
Returns:
字段信息字典 {字段名: {type, nullable, default}}
"""
conn = await get_db_connection_safely(self.db_name)
query = """
SELECT
column_name,
data_type,
is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale
FROM information_schema.columns
WHERE table_name = $1 AND table_schema = 'public'
ORDER BY ordinal_position
"""
result = await conn.execute_query(query, (table_name,))
columns = {}
for row in (result[1] or []):
col_name = row['column_name']
columns[col_name] = {
'type': row['data_type'],
'nullable': row['is_nullable'] == 'YES',
'default': row['column_default'],
'max_length': row['character_maximum_length'],
'precision': row['numeric_precision'],
'scale': row['numeric_scale']
}
return columns
def get_model_fields(self, model_class) -> Dict[str, Any]:
"""
获取 ORM 模型的字段定义
Args:
model_class: Tortoise ORM 模型类
Returns:
字段映射字典 {Python字段名: field对象}
"""
return dict(model_class._meta.fields_map)
def _map_field_type_to_sql(self, field) -> Tuple[str, str]:
"""
将 Tortoise ORM 字段类型映射为 SQL 类型
Args:
field: Tortoise 字段对象
Returns:
(SQL类型字符串, 完整类型定义字符串)
"""
field_type_name = type(field).__name__
if field_type_name == 'CharField':
max_length = getattr(field, 'max_length', 255)
return 'VARCHAR', f'VARCHAR({max_length})'
elif field_type_name == 'DecimalField':
precision = getattr(field, 'max_digits', 10)
scale = getattr(field, 'decimal_places', 2)
return 'DECIMAL', f'DECIMAL({precision}, {scale})'
elif field_type_name in self.FIELD_TYPE_MAPPING:
return self.FIELD_TYPE_MAPPING[field_type_name], self.FIELD_TYPE_MAPPING[field_type_name]
else:
return 'VARCHAR', 'VARCHAR(255)'
def _generate_alter_sql(
self,
table_name: str,
field_name: str,
field
) -> AlterStmt:
"""
生成 ALTER TABLE ADD COLUMN 语句
Args:
table_name: 表名
field_name: Python 字段名
field: Tortoise 字段对象
Returns:
AlterStmt 对象
"""
db_field_name = getattr(field, 'source_field', None) or field_name
_, sql_type_def = self._map_field_type_to_sql(field)
is_nullable = getattr(field, 'null', True)
default_value = getattr(field, 'default', None)
has_default = default_value is not None and str(default_value) != 'PydanticUndefined'
parts = [f'ALTER TABLE "{table_name}" ADD COLUMN "{db_field_name}" {sql_type_def}']
if not is_nullable:
parts.append('NOT NULL')
if has_default:
if isinstance(default_value, str):
parts.append(f"DEFAULT '{default_value}'")
elif isinstance(default_value, bool):
parts.append(f"DEFAULT {'TRUE' if default_value else 'FALSE'}")
else:
parts.append(f'DEFAULT {default_value}')
sql_statement = ' '.join(parts)
return AlterStmt(
table_name=table_name,
field_name=field_name,
db_field_name=db_field_name,
sql_type=sql_type_def,
sql_statement=sql_statement,
is_nullable=is_nullable,
default_value=default_value if has_default else None
)
async def diff(self, model_class) -> List[AlterStmt]:
"""
对比单个模型与数据库表的差异
Args:
model_class: Tortoise ORM 模型类
Returns:
差异列表(需要新增的字段)
"""
table_name = model_class._meta.db_table
db_columns = await self.get_db_columns(table_name)
model_fields = self.get_model_fields(model_class)
differences = []
for field_name, field in model_fields.items():
if field_name.startswith('_') and field_name not in (
'_staging_id', '_source_system', '_source_id', '_status',
'_error_msg', '_transform_rules', '_retry_count',
'_createtime', '_updatetime', '_synced_id', '_synced_time'
):
continue
db_field_name = getattr(field, 'source_field', None) or field_name
if db_field_name not in db_columns:
alter_stmt = self._generate_alter_sql(table_name, field_name, field)
differences.append(alter_stmt)
logger.debug(f"检测到新字段: {table_name}.{db_field_name}")
return differences
async def diff_all(self, model_mapping: Dict[str, Any]) -> Dict[str, Any]:
"""
批量检测所有表的差异
Args:
model_mapping: 模型映射字典 {表键: 模型类}
Returns:
汇总结果 {
"differences": [...],
"total_tables": int,
"total_fields": int
}
"""
all_differences = []
tables_with_diff = set()
for table_key, model_class in model_mapping.items():
try:
diffs = await self.diff(model_class)
if diffs:
tables_with_diff.add(table_key)
for diff in diffs:
all_differences.append({
"table_key": table_key,
**diff.to_dict()
})
except Exception as e:
logger.error(f"检测差异失败 [{table_key}]: {str(e)}")
return {
"differences": all_differences,
"total_tables": len(tables_with_diff),
"total_fields": len(all_differences)
}
@@ -0,0 +1,69 @@
"""
SchemaVersionManager - 迁移版本记录管理
负责记录每次迁移的版本和变更内容
"""
from typing import List, Optional, Dict, Any
from datetime import datetime
from tortoise import Tortoise
class SchemaVersionManager:
"""版本管理器"""
VERSION_TABLE = "t_schema_version"
async def ensure_table_exists(self, db_name: str):
"""确保版本表存在"""
conn = Tortoise.get_connection(db_name)
await conn.execute_query('''
CREATE TABLE IF NOT EXISTS t_schema_version (
id SERIAL PRIMARY KEY,
version VARCHAR(16) UNIQUE NOT NULL,
applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
description TEXT,
sql_scripts TEXT,
status VARCHAR(16) DEFAULT 'applied'
)
''')
async def get_applied_versions(self, db_name: str) -> List[str]:
"""获取已应用的版本列表"""
conn = Tortoise.get_connection(db_name)
result = await conn.execute_query(
"SELECT version FROM t_schema_version WHERE status='applied' ORDER BY version"
)
return [row[0] for row in result[1]] if result[1] else []
async def get_latest_version(self, db_name: str) -> Optional[str]:
"""获取最新版本号"""
conn = Tortoise.get_connection(db_name)
result = await conn.execute_query(
"SELECT version FROM t_schema_version WHERE status='applied' ORDER BY version DESC LIMIT 1"
)
return result[1][0][0] if result[1] else None
async def generate_next_version(self, db_name: str) -> str:
"""生成下一个版本号"""
latest = await self.get_latest_version(db_name)
if latest:
version_num = int(latest.replace('V', '')) + 1
else:
version_num = 1
return f"V{version_num:03d}"
async def record_version(self, db_name: str, version: str, description: str, sql_scripts: str):
"""记录新版本"""
conn = Tortoise.get_connection(db_name)
await conn.execute_query(
"""INSERT INTO t_schema_version (version, description, sql_scripts, status)
VALUES ($1, $2, $3, 'applied')""",
(version, description, sql_scripts)
)
async def version_exists(self, db_name: str, version: str) -> bool:
"""检查版本是否已存在"""
conn = Tortoise.get_connection(db_name)
result = await conn.execute_query(
"SELECT COUNT(*) FROM t_schema_version WHERE version = $1",
(version,)
)
return result[1][0][0] > 0