"""
认证依赖注入模块
提供JWT认证、权限验证等依赖注入函数
"""

from typing import Optional, List
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy import select
from loguru import logger

from auth import jwt_manager
from database import get_database, DatabaseManager, users_table
from models import UserResponse


# HTTP Bearer token 安全方案
security = HTTPBearer()


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: DatabaseManager = Depends(get_database)
) -> UserResponse:
    """
    获取当前认证用户
    
    Args:
        credentials: HTTP Bearer token凭据
        db: 数据库管理器
        
    Returns:
        当前用户信息
        
    Raises:
        HTTPException: 认证失败时抛出异常
    """
    try:
        # 验证token
        payload = jwt_manager.verify_token(credentials.credentials, "access")
        user_id = payload.get("user_id")

        if user_id is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token中缺少用户ID"
            )

        # 从数据库获取用户信息
        query = users_table.select().where(users_table.c.id == user_id)
        user = await db.fetch_one(query)

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户不存在"
            )

        # 返回用户信息
        return UserResponse(
            id=user["id"],
            phone=user["phone"],
            name=user["name"],
            role=user["role"],
            institutions=user["institutions"] or [],
            created_at=user["created_at"],
            updated_at=user["updated_at"]
        )
        
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取当前用户失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="认证失败"
        )


async def get_current_active_user(
    current_user: UserResponse = Depends(get_current_user)
) -> UserResponse:
    """
    获取当前活跃用户（可扩展为检查用户状态）
    
    Args:
        current_user: 当前用户
        
    Returns:
        当前活跃用户信息
    """
    # 这里可以添加用户状态检查逻辑
    # 例如检查用户是否被禁用、是否需要重新验证等
    return current_user


def require_roles(allowed_roles: List[str]):
    """
    角色权限验证装饰器工厂
    
    Args:
        allowed_roles: 允许的角色列表
        
    Returns:
        依赖注入函数
    """
    async def role_checker(
        current_user: UserResponse = Depends(get_current_active_user)
    ) -> UserResponse:
        """
        检查用户角色权限
        
        Args:
            current_user: 当前用户
            
        Returns:
            当前用户信息
            
        Raises:
            HTTPException: 权限不足时抛出异常
        """
        if current_user.role not in allowed_roles:
            logger.warning(f"用户 {current_user.name} (角色: {current_user.role}) 尝试访问需要角色 {allowed_roles} 的资源")
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"权限不足，需要角色: {', '.join(allowed_roles)}"
            )
        
        return current_user
    
    return role_checker


# 常用的角色权限依赖
require_admin = require_roles(["admin"])
require_user_or_admin = require_roles(["user", "admin"])


async def get_optional_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(HTTPBearer(auto_error=False)),
    db: DatabaseManager = Depends(get_database)
) -> Optional[UserResponse]:
    """
    获取可选的当前用户（不强制要求认证）
    
    Args:
        credentials: 可选的HTTP Bearer token凭据
        db: 数据库管理器
        
    Returns:
        当前用户信息或None
    """
    if not credentials:
        return None
    
    try:
        # 验证token
        payload = jwt_manager.verify_token(credentials.credentials, "access")
        user_id = payload.get("user_id")
        
        if user_id is None:
            return None
        
        # 从数据库获取用户信息
        query = users_table.select().where(users_table.c.id == user_id)
        user = await db.fetch_one(query)
        
        if not user:
            return None
        
        # 返回用户信息
        return UserResponse(
            id=user["id"],
            phone=user["phone"],
            name=user["name"],
            role=user["role"],
            institutions=user["institutions"] or [],
            created_at=user["created_at"],
            updated_at=user["updated_at"]
        )
        
    except Exception as e:
        logger.warning(f"可选用户认证失败: {e}")
        return None


class TokenBlacklist:
    """Token黑名单管理（简单内存实现，生产环境建议使用Redis）"""
    
    def __init__(self):
        self._blacklisted_tokens = set()
    
    def add_token(self, token: str):
        """将token添加到黑名单"""
        self._blacklisted_tokens.add(token)
        logger.info("Token已添加到黑名单")
    
    def is_blacklisted(self, token: str) -> bool:
        """检查token是否在黑名单中"""
        return token in self._blacklisted_tokens
    
    def clear_expired_tokens(self):
        """清理过期的token（这里简化处理，实际应该根据token过期时间清理）"""
        # 在实际应用中，应该解析token获取过期时间，然后清理过期的token
        pass


# 创建全局token黑名单实例
token_blacklist = TokenBlacklist()


async def verify_token_not_blacklisted(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> HTTPAuthorizationCredentials:
    """
    验证token不在黑名单中
    
    Args:
        credentials: HTTP Bearer token凭据
        
    Returns:
        验证通过的凭据
        
    Raises:
        HTTPException: token在黑名单中时抛出异常
    """
    if token_blacklist.is_blacklisted(credentials.credentials):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token已失效"
        )
    
    return credentials
