"""
API Key Management System for SIM Analytics API
Handles API key generation, validation, and management
"""

import os
import secrets
import hashlib
import psycopg2
import psycopg2.extras
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional

class APIKeyType(Enum):
    """Kind of API key that can be issued.

    The member ``value`` strings are what gets written to the
    ``api_keys.key_type`` column, so they must stay in sync with that
    column's ``CHECK (key_type IN ('admin', 'user'))`` constraint.
    """

    # Keys of this type are generated with the "ska_admin_" prefix.
    ADMIN = "admin"
    # Keys of this type are generated with the "ska_user_" prefix.
    USER = "user"

def get_db_connection():
    """Open and return a new PostgreSQL connection.

    The DSN is read from the ``DATABASE_URL`` environment variable on every
    call, and a brand-new connection is returned each time. The caller owns
    the connection and is responsible for closing it.

    Raises:
        KeyError: if ``DATABASE_URL`` is not set.
    """
    dsn = os.environ['DATABASE_URL']
    return psycopg2.connect(dsn)

def init_api_tables():
    """Create the ``api_keys`` and ``api_usage_logs`` tables if they are missing.

    ``api_keys`` holds one row per issued key. The schema stores a hash of the
    key (``api_key_hash``) and its prefix, with no column for the plaintext key.
    ``api_usage_logs`` holds one row per logged request and references
    ``api_keys(id)``. Both statements use ``CREATE TABLE IF NOT EXISTS``, so
    calling this function repeatedly is harmless.

    Any exception raised while connecting or running the DDL propagates to
    the caller. The connection is closed in every case.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            # API Keys table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS api_keys (
                    id SERIAL PRIMARY KEY,
                    key_name VARCHAR(255) NOT NULL,
                    api_key_hash VARCHAR(255) NOT NULL UNIQUE,
                    api_key_prefix VARCHAR(10) NOT NULL,
                    key_type VARCHAR(20) NOT NULL CHECK (key_type IN ('admin', 'user')),
                    created_by VARCHAR(255) NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_used TIMESTAMP,
                    is_active BOOLEAN DEFAULT TRUE,
                    usage_count INTEGER DEFAULT 0,
                    rate_limit_per_hour INTEGER DEFAULT 1000,
                    description TEXT
                )
            """)

            # API Usage Logs table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS api_usage_logs (
                    id SERIAL PRIMARY KEY,
                    api_key_id INTEGER REFERENCES api_keys(id),
                    endpoint VARCHAR(255) NOT NULL,
                    method VARCHAR(10) NOT NULL,
                    ip_address INET,
                    user_agent TEXT,
                    request_data TEXT,
                    response_status INTEGER,
                    response_time_ms INTEGER,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
        conn.commit()
    finally:
        # Close even when a DDL statement fails, so the connection is not leaked.
        conn.close()

def generate_api_key(key_type: APIKeyType):
    """Produce a fresh plaintext API key together with its storage hash.

    The key is the type-specific prefix followed by ``secrets.token_urlsafe(32)``
    (32 random bytes, URL-safe base64 encoded).

    Args:
        key_type: ``APIKeyType.ADMIN`` yields the ``"ska_admin_"`` prefix;
            any other value yields ``"ska_user_"``.

    Returns:
        A tuple ``(api_key, key_hash, prefix)``. ``api_key`` is the plaintext key,
        ``key_hash`` is its SHA-256 hex digest (the only form persisted), and
        ``prefix`` is the prefix that was used.
    """
    if key_type == APIKeyType.ADMIN:
        prefix = "ska_admin_"
    else:
        prefix = "ska_user_"

    plaintext = prefix + secrets.token_urlsafe(32)
    digest = hashlib.sha256(plaintext.encode()).hexdigest()
    return plaintext, digest, prefix

def create_api_key(key_name: str, key_type: APIKeyType, created_by: str, description: Optional[str] = None, rate_limit: int = 1000):
    """Issue a new API key and persist its hash.

    The plaintext key appears only in the returned dict. Only its SHA-256
    hash, prefix and metadata are inserted into ``api_keys``.

    Args:
        key_name: Human-readable name stored in ``key_name``.
        key_type: Kind of key; its ``.value`` is stored in ``key_type``.
        created_by: Stored in ``created_by``.
        description: Optional free-text description.
        rate_limit: Stored in ``rate_limit_per_hour``.

    Returns:
        On success, ``{'success': True, 'api_key': <plaintext>, 'key_id': <new id or None>,
        'prefix': <prefix>}``.
        On any exception, ``{'success': False, 'error': str(exc)}``.
    """
    conn = None
    try:
        api_key, key_hash, prefix = generate_api_key(key_type)

        conn = get_db_connection()
        with conn.cursor() as cursor:
            cursor.execute("""
                INSERT INTO api_keys (key_name, api_key_hash, api_key_prefix, key_type, created_by, description, rate_limit_per_hour)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
                RETURNING id
            """, (key_name, key_hash, prefix, key_type.value, created_by, description, rate_limit))
            result = cursor.fetchone()
        conn.commit()

        return {
            'success': True,
            'api_key': api_key,
            'key_id': result[0] if result else None,
            'prefix': prefix
        }

    except Exception as e:
        return {
            'success': False,
            'error': str(e)
        }
    finally:
        # Release the connection on failure too (e.g. a hash UNIQUE violation).
        if conn is not None:
            conn.close()

def validate_api_key(api_key: str):
    """Look up an active API key by its hash and record that it was used.

    If a matching active key is found, its ``last_used`` is set to the
    current timestamp and its ``usage_count`` is incremented.

    Args:
        api_key: Plaintext key presented by the client.

    Returns:
        A row mapping with ``id``, ``key_name``, ``key_type``, ``created_by``,
        ``is_active``, ``usage_count``, ``rate_limit_per_hour`` and ``last_used``.
        ``usage_count`` and ``last_used`` hold the values read before this
        call's update.
        Returns ``None`` in three cases: the key is empty or None, no active
        key matches, or any error occurs. Errors fail closed and are treated
        as an invalid key.
    """
    # Missing/empty credentials can never match a generated (prefixed) key.
    if not api_key:
        return None

    conn = None
    try:
        key_hash = hashlib.sha256(api_key.encode()).hexdigest()

        conn = get_db_connection()
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
            cursor.execute("""
                SELECT id, key_name, key_type, created_by, is_active, usage_count, rate_limit_per_hour, last_used
                FROM api_keys 
                WHERE api_key_hash = %s AND is_active = TRUE
            """, (key_hash,))
            key_info = cursor.fetchone()

            if key_info:
                # Update last used timestamp and usage count
                cursor.execute("""
                    UPDATE api_keys 
                    SET last_used = CURRENT_TIMESTAMP, usage_count = usage_count + 1
                    WHERE id = %s
                """, (key_info['id'],))
                conn.commit()

        return key_info

    except Exception as e:
        # Fail closed, but leave a trace instead of swallowing the error silently.
        print(f"Error validating API key: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()

def get_all_api_keys():
    """Return metadata for every API key, newest first.

    Hashes and plaintext keys are never selected; only ``api_key_prefix``
    identifies the key material.

    Returns:
        List of row mappings with ``id``, ``key_name``, ``api_key_prefix``,
        ``key_type``, ``created_by``, ``created_at``, ``last_used``,
        ``is_active``, ``usage_count``, ``rate_limit_per_hour`` and
        ``description``, ordered by ``created_at`` descending.

    Database errors propagate to the caller. The connection is always closed.
    """
    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
            cursor.execute("""
                SELECT id, key_name, api_key_prefix, key_type, created_by, created_at, 
                       last_used, is_active, usage_count, rate_limit_per_hour, description
                FROM api_keys 
                ORDER BY created_at DESC
            """)
            return cursor.fetchall()
    finally:
        conn.close()

def toggle_api_key_status(key_id: int, is_active: bool):
    """Set the ``is_active`` flag of an API key.

    Args:
        key_id: ``api_keys.id`` of the key to update.
        is_active: New value. Inactive keys are rejected by the
            ``is_active = TRUE`` filter in key validation.

    Returns:
        ``True`` unconditionally, including when no row has ``id = key_id``.

    Database errors propagate to the caller. The connection is always closed.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute("""
                UPDATE api_keys 
                SET is_active = %s
                WHERE id = %s
            """, (is_active, key_id))
        conn.commit()
    finally:
        conn.close()

    return True

def delete_api_key(key_id: int):
    """Permanently delete an API key together with its usage logs.

    Both DELETE statements run in one transaction and are committed together.
    If either fails, nothing is committed and the exception propagates.

    Args:
        key_id: ``api_keys.id`` of the key to delete.

    Returns:
        ``True`` unconditionally, including when no row had ``id = key_id``.
    """
    conn = get_db_connection()
    try:
        with conn.cursor() as cursor:
            # Logs go first: api_usage_logs.api_key_id REFERENCES api_keys(id)
            # without ON DELETE CASCADE, so the key row cannot be removed while
            # log rows still point at it.
            cursor.execute("DELETE FROM api_usage_logs WHERE api_key_id = %s", (key_id,))

            # Delete the API key
            cursor.execute("DELETE FROM api_keys WHERE id = %s", (key_id,))
        conn.commit()
    finally:
        conn.close()

    return True

def log_api_usage(api_key_id: int, endpoint: str, method: str, ip_address: Optional[str] = None, 
                 user_agent: Optional[str] = None, request_data: Optional[str] = None, 
                 response_status: Optional[int] = None, response_time_ms: Optional[int] = None):
    """Insert one row into ``api_usage_logs`` (best effort).

    These rows are also what ``check_rate_limit`` counts. Any error is
    printed and swallowed so that logging never fails the API request
    itself. The connection is closed in every case.

    Args:
        api_key_id: ``api_keys.id`` of the calling key.
        endpoint: Request path, stored in ``endpoint`` (VARCHAR(255)).
        method: HTTP method, stored in ``method`` (VARCHAR(10)).
        ip_address: Client address, stored in an ``INET`` column.
        user_agent: Optional client user agent.
        request_data: Optional serialized request payload.
        response_status: Optional response status code.
        response_time_ms: Optional response time in milliseconds.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cursor:
            cursor.execute("""
                INSERT INTO api_usage_logs 
                (api_key_id, endpoint, method, ip_address, user_agent, request_data, response_status, response_time_ms)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """, (api_key_id, endpoint, method, ip_address, user_agent, request_data, response_status, response_time_ms))
        conn.commit()

    except Exception as e:
        # Log error but don't fail the API request
        print(f"Error logging API usage: {e}")
    finally:
        # Previously leaked when the INSERT failed (e.g. invalid INET value).
        if conn is not None:
            conn.close()

def check_rate_limit(api_key_id: int, rate_limit: int):
    """Report whether an API key has reached its hourly request limit.

    Counts the ``api_usage_logs`` rows for ``api_key_id`` created within the
    last hour and compares the count with ``rate_limit``.

    Args:
        api_key_id: ``api_keys.id`` to check.
        rate_limit: Maximum requests allowed per rolling hour.

    Returns:
        ``True`` if the count is ``>= rate_limit``, meaning the request should
        be rejected. Otherwise ``False``. Errors fail open: if the count cannot
        be obtained, ``False`` is returned so the request is allowed.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cursor:
            # Check usage in the last hour
            cursor.execute("""
                SELECT COUNT(*) 
                FROM api_usage_logs 
                WHERE api_key_id = %s 
                AND created_at > NOW() - INTERVAL '1 hour'
            """, (api_key_id,))
            row = cursor.fetchone()

        usage_count = row[0] if row else 0
        return usage_count >= rate_limit

    except Exception as e:
        # If we can't check, allow the request -- but record why.
        print(f"Error checking rate limit: {e}")
        return False
    finally:
        if conn is not None:
            conn.close()

def get_api_usage_stats(api_key_id: Optional[int] = None, days: int = 30):
    """Return per-day request statistics from ``api_usage_logs``.

    Args:
        api_key_id: Restrict the stats to this key. A falsy value (None or 0)
            means all keys; SERIAL ids start at 1, so 0 never names a real key.
        days: Size of the look-back window, counted back from ``NOW()``.

    Returns:
        List of row mappings with ``date``, ``requests``, ``unique_endpoints``
        and ``avg_response_time``, one per calendar day, newest first.

    Database errors propagate to the caller. The connection is always closed.
    """
    # The day count is bound as a real parameter and multiplied by a 1-day
    # interval. The old form placed %s inside the quoted literal
    # INTERVAL '%s days', which only worked for ints by accident and broke
    # the quoting for any string argument.
    query = """
        SELECT 
            DATE(created_at) as date,
            COUNT(*) as requests,
            COUNT(DISTINCT endpoint) as unique_endpoints,
            AVG(response_time_ms) as avg_response_time
        FROM api_usage_logs 
        WHERE created_at > NOW() - (%s * INTERVAL '1 day')
    """
    params = (days,)

    if api_key_id:
        query += " AND api_key_id = %s"
        params += (api_key_id,)

    query += " GROUP BY DATE(created_at) ORDER BY date DESC"

    conn = get_db_connection()
    try:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cursor:
            cursor.execute(query, params)
            return cursor.fetchall()
    finally:
        conn.close()

# Initialize tables when module is imported.
# NOTE(review): this is an import-time side effect. Importing this module
# requires DATABASE_URL to be set and the database to be reachable; otherwise
# the import itself raises. Consider calling init_api_tables() from
# application startup instead.
init_api_tables()