"""
JWT 认证工具模块
提供 JWT token 生成、验证、刷新等功能
以及密码加密和验证功能
"""

from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import HTTPException, status
from loguru import logger

from config import settings


# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class JWTManager:
    """JWT 管理器"""
    
    def __init__(self):
        self.secret_key = settings.JWT_SECRET_KEY
        self.algorithm = settings.JWT_ALGORITHM
        self.access_token_expire_hours = settings.JWT_EXPIRE_HOURS
        self.refresh_token_expire_days = 7  # 刷新token有效期7天
    
    def create_access_token(self, data: Dict[str, Any]) -> str:
        """
        创建访问token
        
        Args:
            data: 要编码到token中的数据
            
        Returns:
            JWT access token字符串
        """
        to_encode = data.copy()
        expire = datetime.now(timezone.utc) + timedelta(hours=self.access_token_expire_hours)
        to_encode.update({
            "exp": expire,
            "type": "access"
        })
        
        try:
            encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
            logger.info(f"创建访问token成功，用户: {data.get('sub')}")
            return encoded_jwt
        except Exception as e:
            logger.error(f"创建访问token失败: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Token创建失败"
            )
    
    def create_refresh_token(self, data: Dict[str, Any]) -> str:
        """
        创建刷新token
        
        Args:
            data: 要编码到token中的数据
            
        Returns:
            JWT refresh token字符串
        """
        to_encode = data.copy()
        expire = datetime.now(timezone.utc) + timedelta(days=self.refresh_token_expire_days)
        to_encode.update({
            "exp": expire,
            "type": "refresh"
        })
        
        try:
            encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
            logger.info(f"创建刷新token成功，用户: {data.get('sub')}")
            return encoded_jwt
        except Exception as e:
            logger.error(f"创建刷新token失败: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="刷新Token创建失败"
            )
    
    def verify_token(self, token: str, token_type: str = "access") -> Dict[str, Any]:
        """
        验证token
        
        Args:
            token: JWT token字符串
            token_type: token类型 ("access" 或 "refresh")
            
        Returns:
            解码后的token数据
            
        Raises:
            HTTPException: token无效时抛出异常
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
            
            # 检查token类型
            if payload.get("type") != token_type:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=f"无效的token类型，期望: {token_type}"
                )
            
            # 检查过期时间
            exp = payload.get("exp")
            if exp is None:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Token缺少过期时间"
                )
            
            if datetime.now(timezone.utc) > datetime.fromtimestamp(exp, timezone.utc):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Token已过期"
                )
            
            return payload
            
        except JWTError as e:
            logger.warning(f"Token验证失败: {e}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的token"
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Token验证异常: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Token验证失败"
            )
    
    def refresh_access_token(self, refresh_token: str) -> str:
        """
        使用刷新token生成新的访问token
        
        Args:
            refresh_token: 刷新token字符串
            
        Returns:
            新的访问token
        """
        # 验证刷新token
        payload = self.verify_token(refresh_token, "refresh")
        
        # 创建新的访问token数据
        new_data = {
            "sub": payload.get("sub"),
            "user_id": payload.get("user_id"),
            "role": payload.get("role"),
            "phone": payload.get("phone")
        }
        
        return self.create_access_token(new_data)


class PasswordManager:
    """密码管理器"""
    
    @staticmethod
    def hash_password(password: str) -> str:
        """
        加密密码
        
        Args:
            password: 明文密码
            
        Returns:
            加密后的密码哈希
        """
        try:
            hashed = pwd_context.hash(password)
            logger.info("密码加密成功")
            return hashed
        except Exception as e:
            logger.error(f"密码加密失败: {e}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="密码加密失败"
            )
    
    @staticmethod
    def verify_password(plain_password: str, hashed_password: str) -> bool:
        """
        验证密码
        
        Args:
            plain_password: 明文密码
            hashed_password: 加密后的密码哈希
            
        Returns:
            密码是否匹配
        """
        try:
            result = pwd_context.verify(plain_password, hashed_password)
            logger.info(f"密码验证结果: {'成功' if result else '失败'}")
            return result
        except Exception as e:
            logger.error(f"密码验证异常: {e}")
            return False
    
    @staticmethod
    def is_hashed_password(password: str) -> bool:
        """
        检查密码是否已经是哈希格式
        
        Args:
            password: 密码字符串
            
        Returns:
            是否为哈希密码
        """
        # bcrypt哈希通常以$2b$开头，长度为60字符
        return password.startswith("$2b$") and len(password) == 60


# 创建全局实例
jwt_manager = JWTManager()
password_manager = PasswordManager()
