"""
迁移基础类和工具
提供迁移系统的核心抽象类和异常定义
"""
import asyncio
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List
from datetime import datetime
from sqlalchemy import text
from loguru import logger
import json

class MigrationError(Exception):
    """迁移异常类"""
    
    def __init__(self, message: str, migration_version: str = None, original_error: Exception = None):
        self.message = message
        self.migration_version = migration_version
        self.original_error = original_error
        super().__init__(self.message)
    
    def __str__(self):
        if self.migration_version:
            return f"Migration {self.migration_version}: {self.message}"
        return self.message

class Migration(ABC):
    """迁移基础抽象类
    
    所有具体的迁移都应该继承这个类并实现必要的方法
    """
    
    def __init__(self, version: str, description: str, dependencies: List[str] = None):
        """
        初始化迁移
        
        Args:
            version: 迁移版本号，建议使用语义化版本号如 "1.0.1"
            description: 迁移描述，简要说明这个迁移做了什么
            dependencies: 依赖的迁移版本列表，确保迁移顺序正确
        """
        self.version = version
        self.description = description
        self.dependencies = dependencies or []
        self.executed_at: Optional[datetime] = None
        self.execution_time: Optional[float] = None
        
        # 验证版本号格式
        if not version or not isinstance(version, str):
            raise MigrationError("迁移版本号不能为空且必须是字符串")
        
        # 验证描述
        if not description or not isinstance(description, str):
            raise MigrationError("迁移描述不能为空且必须是字符串")
    
    @abstractmethod
    async def up(self, db) -> bool:
        """
        执行迁移 - 必须实现
        
        Args:
            db: 数据库管理器实例
            
        Returns:
            bool: 迁移是否成功执行
            
        Raises:
            MigrationError: 迁移执行失败时抛出
        """
        pass
    
    @abstractmethod
    async def down(self, db) -> bool:
        """
        回滚迁移 - 必须实现
        
        Args:
            db: 数据库管理器实例
            
        Returns:
            bool: 回滚是否成功执行
            
        Raises:
            MigrationError: 回滚执行失败时抛出
        """
        pass
    
    async def validate_before_up(self, db) -> bool:
        """
        迁移前验证 - 可选重写
        
        在执行迁移前进行必要的验证，如检查表是否存在、字段是否已存在等
        
        Args:
            db: 数据库管理器实例
            
        Returns:
            bool: 验证是否通过
        """
        return True
    
    async def validate_after_up(self, db) -> bool:
        """
        迁移后验证 - 可选重写
        
        在执行迁移后进行验证，确保迁移结果符合预期
        
        Args:
            db: 数据库管理器实例
            
        Returns:
            bool: 验证是否通过
        """
        return True
    
    async def get_rollback_sql(self, db) -> Optional[str]:
        """
        获取回滚SQL - 可选重写
        
        返回用于回滚的SQL语句，用于记录和紧急回滚
        
        Args:
            db: 数据库管理器实例
            
        Returns:
            Optional[str]: 回滚SQL语句，如果不需要则返回None
        """
        return None
    
    def get_checksum(self) -> str:
        """
        获取迁移校验和
        
        用于验证迁移文件是否被修改，确保迁移的一致性
        
        Returns:
            str: 迁移内容的校验和
        """
        import hashlib
        content = f"{self.version}:{self.description}:{str(self.dependencies)}"
        return hashlib.md5(content.encode()).hexdigest()
    
    def __str__(self):
        return f"Migration {self.version}: {self.description}"
    
    def __repr__(self):
        return f"<Migration(version='{self.version}', description='{self.description}')>"
    
    def __eq__(self, other):
        if not isinstance(other, Migration):
            return False
        return self.version == other.version
    
    def __hash__(self):
        return hash(self.version)

def version_to_tuple(version: str) -> tuple:
    """
    将版本号转换为可比较的元组
    
    Args:
        version: 版本号字符串，如 "1.0.1"
        
    Returns:
        tuple: 可比较的版本元组，如 (1, 0, 1)
    """
    try:
        # 尝试解析语义化版本号
        parts = version.split('.')
        return tuple(int(part) for part in parts)
    except (ValueError, AttributeError):
        # 如果不是标准版本号格式，按字符串排序
        return (version,)

def compare_versions(version1: str, version2: str) -> int:
    """
    比较两个版本号
    
    Args:
        version1: 第一个版本号
        version2: 第二个版本号
        
    Returns:
        int: -1 如果 version1 < version2, 0 如果相等, 1 如果 version1 > version2
    """
    tuple1 = version_to_tuple(version1)
    tuple2 = version_to_tuple(version2)
    
    if tuple1 < tuple2:
        return -1
    elif tuple1 > tuple2:
        return 1
    else:
        return 0
