import pandas as pd
from datetime import datetime, timedelta
import json
import os
from contextlib import contextmanager
from sqlalchemy import create_engine, text
import psycopg2

class DatabaseManager:
    """PostgreSQL-backed storage for clients, SIM numbers and usage records.

    The connection URL is read from the ``DATABASE_URL`` environment variable
    and the schema is created on construction. Read helpers return pandas
    DataFrames; write helpers commit their own transaction.

    Placeholder convention: statements wrapped in ``text()`` use ``:name``
    bind parameters, while raw SQL strings handed to ``pd.read_sql_query``
    use the driver's ``%(name)s`` style.
    """

    def __init__(self):
        """Create the pooled SQLAlchemy engine and ensure the schema exists.

        Raises:
            ValueError: if ``DATABASE_URL`` is not set.
        """
        self.database_url = os.getenv('DATABASE_URL')
        if not self.database_url:
            raise ValueError("DATABASE_URL environment variable not set")

        # Create engine with proper connection pooling and SSL configuration.
        # pool_pre_ping checks a pooled connection before handing it out;
        # pool_recycle replaces connections older than 300 seconds.
        self.engine = create_engine(
            self.database_url,
            pool_size=5,
            max_overflow=10,
            pool_timeout=30,
            pool_recycle=300,
            pool_pre_ping=True,
            connect_args={"sslmode": "prefer", "connect_timeout": 10}
        )
        self.init_database()

    @contextmanager
    def get_connection(self, *, max_retries=3, retry_delay=1.0):
        """Context manager for database connections with retry logic.

        Only *acquiring* the connection is retried: up to ``max_retries``
        attempts (at least one), sleeping ``retry_delay`` seconds between
        failures. Exceptions raised inside the caller's ``with`` body are never
        retried; they propagate unchanged. The connection is always closed on
        exit. Write helpers in this class call ``conn.commit()`` explicitly
        before leaving the block.

        Args:
            max_retries: total number of connection attempts.
            retry_delay: seconds to wait between failed attempts.

        Yields:
            An open connection from ``self.engine``.

        Raises:
            Exception: the last connection error once every attempt has failed.
        """
        import time

        attempts = max(1, max_retries)
        for attempt in range(1, attempts + 1):
            try:
                conn = self.engine.connect()
                break
            except Exception:
                if attempt == attempts:
                    raise
                # Wait before retrying
                time.sleep(retry_delay)

        # The yield must sit outside the retry loop: a @contextmanager
        # generator may yield only once, and errors from the with-body are
        # thrown in here.
        try:
            yield conn
        finally:
            conn.close()

    def init_database(self):
        """Create all tables if they do not already exist.

        Every statement uses ``CREATE TABLE IF NOT EXISTS``, so existing
        tables are left as they are. Changes to these definitions only take
        effect on tables created after the change.
        """
        with self.get_connection() as conn:
            # Clients table
            conn.execute(text("""
                CREATE TABLE IF NOT EXISTS clients (
                    client_id VARCHAR(255) PRIMARY KEY,
                    client_name VARCHAR(255) NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """))

            # SIM numbers table with MSISDN. Each column is already UNIQUE on
            # its own, but PostgreSQL requires a unique constraint on exactly
            # (msisdn, sim_number) before usage_records can declare its
            # composite foreign key below.
            # NOTE(review): a sim_numbers table created before this constraint
            # was added will not gain it from IF NOT EXISTS.
            conn.execute(text("""
                CREATE TABLE IF NOT EXISTS sim_numbers (
                    id SERIAL PRIMARY KEY,
                    client_id VARCHAR(255) NOT NULL,
                    sim_number VARCHAR(255) NOT NULL UNIQUE,
                    msisdn VARCHAR(255) NOT NULL UNIQUE,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE (msisdn, sim_number),
                    FOREIGN KEY (client_id) REFERENCES clients (client_id) ON DELETE CASCADE
                )
            """))

            # Usage records table with MSISDN and SIM number
            conn.execute(text("""
                CREATE TABLE IF NOT EXISTS usage_records (
                    id SERIAL PRIMARY KEY,
                    client_id VARCHAR(255) NOT NULL,
                    msisdn VARCHAR(255) NOT NULL,
                    sim_number VARCHAR(255) NOT NULL,
                    total_usage_gb DECIMAL(10,3) NOT NULL,
                    date_from DATE NOT NULL,
                    date_to DATE NOT NULL,
                    fup_reached BOOLEAN NOT NULL,
                    daily_breakdown TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (client_id) REFERENCES clients (client_id),
                    FOREIGN KEY (msisdn, sim_number) REFERENCES sim_numbers (msisdn, sim_number)
                )
            """))

            # Daily usage breakdown table
            conn.execute(text("""
                CREATE TABLE IF NOT EXISTS daily_usage (
                    id SERIAL PRIMARY KEY,
                    usage_record_id INTEGER NOT NULL,
                    date DATE NOT NULL,
                    usage_gb DECIMAL(10,3) NOT NULL,
                    FOREIGN KEY (usage_record_id) REFERENCES usage_records (id)
                )
            """))

            # Protocol usage table
            conn.execute(text("""
                CREATE TABLE IF NOT EXISTS protocol_usage (
                    id SERIAL PRIMARY KEY,
                    usage_record_id INTEGER NOT NULL,
                    protocol_name VARCHAR(255) NOT NULL,
                    category VARCHAR(255) NOT NULL,
                    usage_gb DECIMAL(10,3) NOT NULL,
                    percentage DECIMAL(5,2) NOT NULL,
                    FOREIGN KEY (usage_record_id) REFERENCES usage_records (id)
                )
            """))

            conn.commit()

    def add_client(self, client_id, client_name):
        """Insert a client, or rename it if ``client_id`` already exists."""
        with self.get_connection() as conn:
            conn.execute(text("""
                INSERT INTO clients (client_id, client_name)
                VALUES (:client_id, :client_name)
                ON CONFLICT (client_id) DO UPDATE SET
                    client_name = EXCLUDED.client_name
            """), {"client_id": client_id, "client_name": client_name})
            conn.commit()

    def add_sim_number(self, client_id, sim_number, msisdn):
        """Attach a SIM number and its MSISDN to a client.

        A duplicate ``sim_number`` is silently ignored. A duplicate ``msisdn``
        on a different SIM is not covered by the ON CONFLICT target, so it
        still raises a unique-violation error.
        """
        with self.get_connection() as conn:
            conn.execute(text("""
                INSERT INTO sim_numbers (client_id, sim_number, msisdn) 
                VALUES (:client_id, :sim_number, :msisdn)
                ON CONFLICT (sim_number) DO NOTHING
            """), {"client_id": client_id, "sim_number": sim_number, "msisdn": msisdn})
            conn.commit()

    def get_client_sim_numbers(self, client_id):
        """Return a client's SIM numbers (id, sim_number, msisdn, created_at), oldest first."""
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT id, sim_number, msisdn, created_at 
                FROM sim_numbers 
                WHERE client_id = %(client_id)s 
                ORDER BY created_at
            """, conn, params={"client_id": client_id})

    def delete_sim_number(self, sim_id):
        """Delete the sim_numbers row whose primary key is ``sim_id``.

        With the schema from ``init_database``, PostgreSQL rejects the delete
        while usage_records still reference this SIM. That foreign key has no
        ON DELETE action.
        """
        with self.get_connection() as conn:
            # text() binds ":name" parameters; a "%(name)s" placeholder here
            # would reach the server verbatim and fail to parse.
            conn.execute(text("""
                DELETE FROM sim_numbers WHERE id = :sim_id
            """), {"sim_id": sim_id})
            conn.commit()

    def get_all_sim_numbers(self):
        """Return every SIM number joined with its client's name, ordered by client then SIM."""
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT s.id, s.sim_number, s.msisdn, s.client_id, c.client_name, s.created_at
                FROM sim_numbers s
                JOIN clients c ON s.client_id = c.client_id
                ORDER BY c.client_name, s.sim_number
            """, conn)

    def add_usage_record(self, client_id, msisdn, sim_number, total_usage_gb, date_from, date_to, fup_reached, daily_breakdown):
        """Insert a usage record and its per-day rows in a single transaction.

        Args:
            daily_breakdown: mapping of date to usage in GB. It is stored as
                JSON text on the record and also as one daily_usage row per
                entry. Keys are presumably ISO date strings; confirm against
                callers.

        Returns:
            The id of the new usage_records row.
        """
        with self.get_connection() as conn:
            # Insert usage record
            result = conn.execute(text("""
                INSERT INTO usage_records 
                (client_id, msisdn, sim_number, total_usage_gb, date_from, date_to, fup_reached, daily_breakdown)
                VALUES (:client_id, :msisdn, :sim_number, :total_usage_gb, :date_from, :date_to, :fup_reached, :daily_breakdown)
                RETURNING id
            """), {
                "client_id": client_id,
                "msisdn": msisdn,
                "sim_number": sim_number,
                "total_usage_gb": total_usage_gb, 
                "date_from": date_from, 
                "date_to": date_to, 
                "fup_reached": fup_reached, 
                "daily_breakdown": json.dumps(daily_breakdown)
            })

            row = result.fetchone()
            usage_record_id = row[0] if row else None

            # Insert the daily breakdown with a single executemany call. The
            # guard is needed because an empty parameter list is not an
            # executemany.
            daily_rows = [
                {"usage_record_id": usage_record_id, "date": date_str, "usage_gb": usage_gb}
                for date_str, usage_gb in daily_breakdown.items()
            ]
            if daily_rows:
                conn.execute(text("""
                    INSERT INTO daily_usage (usage_record_id, date, usage_gb)
                    VALUES (:usage_record_id, :date, :usage_gb)
                """), daily_rows)

            conn.commit()
            return usage_record_id

    def add_protocol_usage(self, usage_record_id, protocol_name, category, usage_gb, percentage):
        """Insert one protocol-usage row for a usage record.

        Values are coerced to plain Python int/str/float before binding.
        Presumably this is so that numpy scalars coming from DataFrames bind
        cleanly; confirm against callers.
        """
        with self.get_connection() as conn:
            conn.execute(text("""
                INSERT INTO protocol_usage 
                (usage_record_id, protocol_name, category, usage_gb, percentage)
                VALUES (:usage_record_id, :protocol_name, :category, :usage_gb, :percentage)
            """), {
                "usage_record_id": int(usage_record_id),
                "protocol_name": str(protocol_name),
                "category": str(category),
                "usage_gb": float(usage_gb),
                "percentage": float(percentage)
            })
            conn.commit()

    def get_all_clients(self):
        """Return all clients ordered by name."""
        with self.get_connection() as conn:
            return pd.read_sql_query("SELECT * FROM clients ORDER BY client_name", conn)

    def get_client_usage_records(self, client_id, date_from=None, date_to=None):
        """Return a client's usage records, newest period first.

        Args:
            date_from: if given, keep records whose date_from is on or after it.
            date_to: if given, keep records whose date_to is on or before it.
        """
        with self.get_connection() as conn:
            query = """
                SELECT ur.*, c.client_name 
                FROM usage_records ur
                JOIN clients c ON ur.client_id = c.client_id
                WHERE ur.client_id = %(client_id)s
            """
            params = {"client_id": client_id}

            # Filters are appended as fixed SQL fragments; values stay bound.
            if date_from:
                query += " AND ur.date_from >= %(date_from)s"
                params["date_from"] = date_from

            if date_to:
                query += " AND ur.date_to <= %(date_to)s"
                params["date_to"] = date_to

            query += " ORDER BY ur.date_from DESC"

            return pd.read_sql_query(query, conn, params=params)

    def get_daily_breakdown(self, usage_record_id):
        """Return (date, usage_gb) rows for a usage record, in date order."""
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT date, usage_gb 
                FROM daily_usage 
                WHERE usage_record_id = %(usage_record_id)s 
                ORDER BY date
            """, conn, params={"usage_record_id": usage_record_id})

    def get_protocol_usage(self, usage_record_id):
        """Return protocol usage rows for a usage record, largest usage first."""
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT protocol_name, category, usage_gb, percentage 
                FROM protocol_usage 
                WHERE usage_record_id = %(usage_record_id)s
                ORDER BY usage_gb DESC
            """, conn, params={"usage_record_id": usage_record_id})

    def get_total_clients(self):
        """Return the number of clients."""
        with self.get_connection() as conn:
            result = conn.execute(text("SELECT COUNT(*) FROM clients"))
            row = result.fetchone()
            return row[0] if row else 0

    def get_total_usage_gb(self):
        """Return total usage in GB across all records (0.0 when there are none)."""
        with self.get_connection() as conn:
            result = conn.execute(text("SELECT COALESCE(SUM(total_usage_gb), 0) FROM usage_records"))
            row = result.fetchone()
            return float(row[0]) if row else 0.0

    def get_active_connections(self):
        """Return the number of distinct clients with a record ending in the last 30 days."""
        with self.get_connection() as conn:
            thirty_days_ago = (datetime.now() - timedelta(days=30)).date()
            result = conn.execute(text("""
                SELECT COUNT(DISTINCT client_id) 
                FROM usage_records 
                WHERE date_to >= :thirty_days_ago
            """), {"thirty_days_ago": thirty_days_ago})
            row = result.fetchone()
            return row[0] if row else 0

    def get_fup_exceeded_count(self):
        """Return the number of usage records with the FUP flag set."""
        with self.get_connection() as conn:
            result = conn.execute(text("SELECT COUNT(*) FROM usage_records WHERE fup_reached = true"))
            row = result.fetchone()
            return row[0] if row else 0

    def get_recent_usage_records(self, limit=10):
        """Return the ``limit`` most recently created usage records with client names."""
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT ur.client_id, c.client_name, ur.total_usage_gb, 
                       ur.date_from, ur.date_to, ur.fup_reached, ur.created_at
                FROM usage_records ur
                JOIN clients c ON ur.client_id = c.client_id
                ORDER BY ur.created_at DESC
                LIMIT %(limit)s
            """, conn, params={"limit": limit})

    def get_daily_usage_trend(self, days=30):
        """Return total usage per day for the last ``days`` days, oldest first."""
        with self.get_connection() as conn:
            start_date = (datetime.now() - timedelta(days=days)).date()
            return pd.read_sql_query("""
                SELECT du.date, SUM(du.usage_gb) as total_usage_gb
                FROM daily_usage du
                JOIN usage_records ur ON du.usage_record_id = ur.id
                WHERE du.date >= %(start_date)s
                GROUP BY du.date
                ORDER BY du.date
            """, conn, params={"start_date": start_date})

    def get_fup_status_distribution(self):
        """Return record counts labelled 'Exceeded' / 'Within Limit' by FUP flag."""
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT 
                    CASE WHEN fup_reached = true THEN 'Exceeded' ELSE 'Within Limit' END as fup_status,
                    COUNT(*) as count
                FROM usage_records
                GROUP BY fup_reached
            """, conn)

    def get_client_usage_summary(self, client_id):
        """Return one summary row for a client: record count, total usage and date span.

        A client with no records still gets a row: total_records is 0,
        total_usage_gb is 0 and the dates are NULL. (The clients table has no
        sim_number column, so SIM details must come from
        ``get_client_sim_numbers``.)
        """
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT 
                    c.client_id,
                    c.client_name,
                    COUNT(ur.id) as total_records,
                    COALESCE(SUM(ur.total_usage_gb), 0) as total_usage_gb,
                    MIN(ur.date_from) as first_usage_date,
                    MAX(ur.date_to) as last_usage_date,
                    SUM(CASE WHEN ur.fup_reached = true THEN 1 ELSE 0 END) as fup_exceeded_count
                FROM clients c
                LEFT JOIN usage_records ur ON c.client_id = ur.client_id
                WHERE c.client_id = %(client_id)s
                GROUP BY c.client_id, c.client_name
            """, conn, params={"client_id": client_id})

    def get_all_protocol_usage(self):
        """Return every protocol usage row with its record's client id and name, largest first."""
        with self.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT pu.*, ur.client_id, c.client_name
                FROM protocol_usage pu
                JOIN usage_records ur ON pu.usage_record_id = ur.id
                JOIN clients c ON ur.client_id = c.client_id
                ORDER BY pu.usage_gb DESC
            """, conn)
