import pandas as pd
import numpy as np
import random
import io
from typing import Dict, List, Tuple

class ProtocolAnalyzer:
    def __init__(self, database_manager):
        """
        Bind the analyzer to a database manager and build the category table.

        ``self.protocol_categories`` maps each traffic category name to a dict
        with three keys: ``'protocols'`` (the protocol names belonging to the
        category), ``'base_percentage'`` (its typical share of total traffic)
        and ``'variation'`` (the largest random offset, in either direction,
        applied to that share when a distribution is generated).

        Args:
            database_manager: Object stored as ``self.db`` and used by the
                other methods for all database access.
        """
        self.db = database_manager

        # One row per category: (name, member protocols, base share, variation).
        # Row order is significant: shares are drawn in this order when a
        # distribution is generated.
        category_rows = (
            ('Web Browsing', ['HTTP', 'HTTPS', 'DNS'], 35, 10),
            ('Video Streaming', ['YouTube', 'Netflix', 'Prime Video', 'Video Streaming'], 25, 8),
            ('Social Media', ['Facebook', 'Instagram', 'Twitter', 'TikTok', 'WhatsApp'], 15, 5),
            ('Gaming', ['Online Gaming', 'Steam', 'Game Updates'], 8, 4),
            ('File Transfer', ['FTP', 'BitTorrent', 'Cloud Sync', 'Email Attachments'], 6, 3),
            ('Audio Streaming', ['Spotify', 'Apple Music', 'Audio Streaming'], 4, 2),
            ('Adult Content', ['Adult Websites', 'Adult Streaming'], 3, 2),
            ('System Updates', ['OS Updates', 'App Updates', 'Security Updates'], 2, 1),
            ('Other', ['VPN', 'Remote Desktop', 'IoT Traffic', 'Unknown'], 2, 1),
        )
        self.protocol_categories = {
            name: {
                'protocols': members,
                'base_percentage': base_share,
                'variation': spread,
            }
            for name, members, base_share, spread in category_rows
        }
    
    def generate_protocol_usage(self, usage_record_id: int, total_usage_gb: float):
        """
        Create a simulated per-protocol breakdown for one usage record and persist it.

        The breakdown comes from ``_generate_protocol_distribution``. Each
        protocol in it is written with one ``self.db.add_protocol_usage`` call,
        passing positionally: the usage record id, the protocol name, the
        category resolved through ``_get_protocol_category``, the usage in GB
        and the percentage.

        Args:
            usage_record_id: Identifier passed through to every stored row.
            total_usage_gb: Total usage that the breakdown is generated for.
        """
        breakdown = self._generate_protocol_distribution(total_usage_gb)

        for name in breakdown:
            entry = breakdown[name]
            # The category is looked up by protocol name rather than taken from
            # the generated entry.
            resolved_category = self._get_protocol_category(name)
            self.db.add_protocol_usage(
                usage_record_id,
                name,
                resolved_category,
                entry['usage_gb'],
                entry['percentage']
            )
    
    def _generate_protocol_distribution(self, total_usage_gb: float) -> Dict:
        """
        Build a randomized per-protocol split of ``total_usage_gb``.

        Steps:
          1. Draw a share for each category, in table order, as its base
             percentage plus a uniform offset within +/- its variation, capped
             by what is left of a 100% budget. Drawing stops once the budget is
             used up.
          2. Rescale the drawn shares so that they total 100.
          3. Split every positive category share across that category's
             protocols via ``_distribute_category_percentage``.
          4. If the per-protocol usage does not add up to ``total_usage_gb``,
             assign the difference to the protocol with the largest usage and
             recompute that protocol's percentage.

        Args:
            total_usage_gb: Usage to distribute.

        Returns:
            Dict mapping protocol name to a dict with the keys ``'usage_gb'``,
            ``'percentage'`` and ``'category'``.
        """
        # Step 1: raw category shares, consuming a shared 100% budget.
        budget_left = 100.0
        raw_shares = {}
        for name, spec in self.protocol_categories.items():
            if budget_left <= 0:
                break
            offset = random.uniform(-spec['variation'], spec['variation'])
            share = max(0, min(budget_left, spec['base_percentage'] + offset))
            raw_shares[name] = share
            budget_left -= share

        # Step 2: normalize so the category shares total 100%.
        raw_total = sum(raw_shares.values())
        if raw_total > 0:
            scale = 100.0 / raw_total
            shares = {name: share * scale for name, share in raw_shares.items()}
        else:
            shares = raw_shares

        # Step 3: per-protocol entries.
        breakdown = {}
        for name, share in shares.items():
            if share <= 0:
                continue
            members = self.protocol_categories[name]['protocols']
            split = self._distribute_category_percentage(members, share)
            breakdown.update({
                protocol: {
                    'usage_gb': (pct / 100.0) * total_usage_gb,
                    'percentage': pct,
                    'category': name
                }
                for protocol, pct in split.items()
            })

        # Step 4: make the usage total match by correcting the largest entry.
        allocated_gb = sum(entry['usage_gb'] for entry in breakdown.values())
        if breakdown and allocated_gb != total_usage_gb:
            top_name = max(breakdown, key=lambda protocol: breakdown[protocol]['usage_gb'])
            top_entry = breakdown[top_name]
            top_entry['usage_gb'] += total_usage_gb - allocated_gb
            top_entry['percentage'] = (top_entry['usage_gb'] / total_usage_gb) * 100

        return breakdown
    
    def _distribute_category_percentage(self, protocols: List[str], category_percentage: float) -> Dict[str, float]:
        """
        Distribute category percentage among individual protocols

        A random subset of ``protocols`` is chosen (each one is included with a
        70% chance, and at least one is always kept). The whole of
        ``category_percentage`` is then split across that subset in proportion
        to random weights drawn from ``[0.5, 2.0]``, so the returned shares add
        up to ``category_percentage`` apart from floating point rounding.

        Previously the weights were normalized over *all* protocols, so every
        protocol that was skipped took its share with it and the result could
        even be empty.

        Args:
            protocols: Protocol names belonging to one category.
            category_percentage: Percentage to split between them.

        Returns:
            Dict mapping each selected protocol to its percentage. Empty only
            when ``protocols`` is empty. A single-protocol category receives
            the full percentage.
        """
        if not protocols:
            return {}
        if len(protocols) == 1:
            return {protocols[0]: category_percentage}

        weighted = [(protocol, random.uniform(0.5, 2.0)) for protocol in protocols]

        # Not every category uses all of its protocols: include each with 70% chance.
        chosen = [pair for pair in weighted if random.random() < 0.7]
        if not chosen:
            # Never lose the category entirely; fall back to one random member.
            chosen = [random.choice(weighted)]

        # Shares of 0.01% or less are treated as noise: drop the lightest member
        # and re-split until every remaining share is above that, or one is left.
        while len(chosen) > 1:
            total_weight = sum(weight for _, weight in chosen)
            lightest = min(chosen, key=lambda pair: pair[1])
            if (lightest[1] / total_weight) * category_percentage > 0.01:
                break
            chosen.remove(lightest)

        total_weight = sum(weight for _, weight in chosen)
        return {
            protocol: (weight / total_weight) * category_percentage
            for protocol, weight in chosen
        }
    
    def _get_protocol_category(self, protocol_name: str) -> str:
        """
        Resolve the category a protocol belongs to.

        Categories are searched in table order and the first one whose
        ``'protocols'`` list contains ``protocol_name`` wins.

        Args:
            protocol_name: Protocol to look up.

        Returns:
            The matching category name, or ``'Other'`` when no category lists
            the protocol.
        """
        matching_categories = (
            category
            for category, config in self.protocol_categories.items()
            if protocol_name in config['protocols']
        )
        return next(matching_categories, 'Other')
    
    def get_protocol_distribution(self):
        """
        Aggregate the ``protocol_usage`` table per protocol.

        Returns:
            DataFrame with the columns ``protocol_name``, ``usage_gb`` (summed
            usage) and ``avg_percentage`` (mean of the stored percentages),
            ordered by ``usage_gb`` descending.
        """
        with self.db.get_connection() as conn:
            sql = """
                SELECT protocol_name, SUM(usage_gb) as usage_gb, AVG(percentage) as avg_percentage
                FROM protocol_usage
                GROUP BY protocol_name
                ORDER BY usage_gb DESC
            """
            distribution = pd.read_sql_query(sql, conn)
        return distribution
    
    def get_category_usage(self):
        """
        Aggregate the ``protocol_usage`` table per category.

        Returns:
            DataFrame with the columns ``category``, ``usage_gb`` (summed
            usage) and ``record_count`` (number of ``protocol_usage`` rows in
            the category), ordered by ``usage_gb`` descending.
        """
        with self.db.get_connection() as conn:
            sql = """
                SELECT category, SUM(usage_gb) as usage_gb, COUNT(*) as record_count
                FROM protocol_usage
                GROUP BY category
                ORDER BY usage_gb DESC
            """
            per_category = pd.read_sql_query(sql, conn)
        return per_category
    
    def get_detailed_protocol_usage(self):
        """
        Fetch every stored protocol usage row.

        Returns:
            Whatever ``self.db.get_all_protocol_usage()`` returns, unchanged.
            ``export_protocol_analysis`` in this class treats that value as a
            DataFrame.
        """
        all_usage = self.db.get_all_protocol_usage()
        return all_usage
    
    def get_client_protocol_usage(self, client_id: str):
        """
        Summarize one client's traffic per protocol.

        Joins ``protocol_usage`` to ``usage_records`` on the usage record id,
        keeps the rows whose ``usage_records.client_id`` equals ``client_id``
        and groups them by protocol name and category.

        Args:
            client_id: Client identifier bound into the query as a parameter.

        Returns:
            DataFrame with the columns ``protocol_name``, ``category``,
            ``total_usage_gb`` (summed usage) and ``avg_percentage`` (mean of
            the stored percentages), ordered by ``total_usage_gb`` descending.
        """
        bind_values = {'client_id': client_id}
        with self.db.get_connection() as conn:
            # NOTE(review): %(name)s is the DB-API "pyformat" placeholder style;
            # confirm the driver behind self.db supports it.
            sql = """
                SELECT pu.protocol_name, pu.category, SUM(pu.usage_gb) as total_usage_gb,
                       AVG(pu.percentage) as avg_percentage
                FROM protocol_usage pu
                JOIN usage_records ur ON pu.usage_record_id = ur.id
                WHERE ur.client_id = %(client_id)s
                GROUP BY pu.protocol_name, pu.category
                ORDER BY total_usage_gb DESC
            """
            summary = pd.read_sql_query(sql, conn, params=bind_values)
        return summary
    
    def get_client_protocol_usage_detailed(self, client_id: str, date_from=None, date_to=None):
        """
        Summarize one client's traffic per protocol, optionally within a date range.

        Same join and grouping as ``get_client_protocol_usage``, plus a row
        count per group. A date bound is only added to the query when the
        corresponding argument is truthy.

        Args:
            client_id: Client identifier bound into the query as a parameter.
            date_from: Optional lower bound, compared as
                ``usage_records.date_from >= date_from``.
            date_to: Optional upper bound, compared as
                ``usage_records.date_to <= date_to``.

        Returns:
            DataFrame with the columns ``protocol_name``, ``category``,
            ``total_usage_gb``, ``avg_percentage`` and ``record_count``,
            ordered by ``total_usage_gb`` descending.
        """
        with self.db.get_connection() as conn:
            select_clause = """
                SELECT pu.protocol_name, pu.category, SUM(pu.usage_gb) as total_usage_gb,
                       AVG(pu.percentage) as avg_percentage, COUNT(*) as record_count
                FROM protocol_usage pu
                JOIN usage_records ur ON pu.usage_record_id = ur.id
                WHERE ur.client_id = %(client_id)s
            """
            grouping_clause = """
                GROUP BY pu.protocol_name, pu.category
                ORDER BY total_usage_gb DESC
            """

            bind_values = {'client_id': client_id}
            date_filters = []

            # Each bound contributes one extra WHERE condition and one parameter.
            if date_from:
                date_filters.append(" AND ur.date_from >= %(date_from)s")
                bind_values['date_from'] = date_from

            if date_to:
                date_filters.append(" AND ur.date_to <= %(date_to)s")
                bind_values['date_to'] = date_to

            sql = select_clause + "".join(date_filters) + grouping_clause
            return pd.read_sql_query(sql, conn, params=bind_values)
    
    def analyze_protocol_trends(self, days=30):
        """
        Analyze protocol usage trends over time

        Sums ``usage_gb`` per protocol, category and ``usage_records.date_from``
        for the usage records whose ``date_from`` is on or after a cutoff date.
        The cutoff is the current UTC date minus ``days`` days.

        Args:
            days: Size of the look-back window in whole days. The value is
                converted with ``int()`` before use.

        Returns:
            DataFrame with the columns ``protocol_name``, ``category``,
            ``date_from`` and ``usage_gb``, ordered by ``date_from`` and then
            by ``usage_gb`` descending.

        Raises:
            ValueError, TypeError: If ``days`` cannot be converted to ``int``.
        """
        # Imported locally so this method does not depend on the module's
        # import block.
        from datetime import datetime, timedelta, timezone

        # Coerce first: caller-supplied text must never reach the SQL string.
        window_days = int(days)

        # Equivalent to SQLite's date('now', '-N days') (UTC, YYYY-MM-DD), but
        # computed here so the query contains no engine-specific date function.
        # The inlined value is produced by date.isoformat(), never by the caller.
        cutoff = (datetime.now(timezone.utc).date() - timedelta(days=window_days)).isoformat()

        with self.db.get_connection() as conn:
            return pd.read_sql_query("""
                SELECT pu.protocol_name, pu.category, ur.date_from, SUM(pu.usage_gb) as usage_gb
                FROM protocol_usage pu
                JOIN usage_records ur ON pu.usage_record_id = ur.id
                WHERE ur.date_from >= '{}'
                GROUP BY pu.protocol_name, pu.category, ur.date_from
                ORDER BY ur.date_from, usage_gb DESC
            """.format(cutoff), conn)
    
    def get_security_insights(self):
        """
        Get security-related insights from protocol usage

        Runs three per-client aggregate queries on one database connection:

        * usage in the ``'Adult Content'`` category,
        * usage of the ``'VPN'`` protocol,
        * clients whose summed ``usage_records.total_usage_gb`` exceeds 50.

        Returns:
            Dict with the keys ``'adult_content_users'``, ``'vpn_users'`` and
            ``'high_bandwidth_users'``. Each value is a DataFrame holding
            ``client_id``, ``client_name`` and the summed usage column
            (``adult_usage_gb``, ``vpn_usage_gb`` and ``total_usage``
            respectively), ordered by that sum descending.
        """
        with self.db.get_connection() as conn:
            # Get adult content usage
            adult_content = pd.read_sql_query("""
                SELECT ur.client_id, c.client_name, SUM(pu.usage_gb) as adult_usage_gb
                FROM protocol_usage pu
                JOIN usage_records ur ON pu.usage_record_id = ur.id
                JOIN clients c ON ur.client_id = c.client_id
                WHERE pu.category = 'Adult Content'
                GROUP BY ur.client_id, c.client_name
                ORDER BY adult_usage_gb DESC
            """, conn)

            # Get VPN usage
            vpn_usage = pd.read_sql_query("""
                SELECT ur.client_id, c.client_name, SUM(pu.usage_gb) as vpn_usage_gb
                FROM protocol_usage pu
                JOIN usage_records ur ON pu.usage_record_id = ur.id
                JOIN clients c ON ur.client_id = c.client_id
                WHERE pu.protocol_name = 'VPN'
                GROUP BY ur.client_id, c.client_name
                ORDER BY vpn_usage_gb DESC
            """, conn)

            # Get high bandwidth users.
            # HAVING repeats the aggregate instead of using the SELECT alias:
            # referencing an output alias in HAVING is a non-standard extension
            # that some engines (e.g. PostgreSQL) reject.
            high_bandwidth = pd.read_sql_query("""
                SELECT ur.client_id, c.client_name, SUM(ur.total_usage_gb) as total_usage
                FROM usage_records ur
                JOIN clients c ON ur.client_id = c.client_id
                GROUP BY ur.client_id, c.client_name
                HAVING SUM(ur.total_usage_gb) > 50
                ORDER BY total_usage DESC
            """, conn)

            return {
                'adult_content_users': adult_content,
                'vpn_users': vpn_usage,
                'high_bandwidth_users': high_bandwidth
            }
    
    def export_protocol_analysis(self):
        """
        Export comprehensive protocol analysis to CSV

        Reads the rows returned by ``get_detailed_protocol_usage`` and writes
        one CSV line per row with the columns ``Client ID``, ``Client Name``,
        ``Protocol``, ``Category``, ``Usage (GB)`` (rounded to 3 decimals),
        ``Percentage`` (rounded to 2 decimals) and ``Record ID``.

        Returns:
            The CSV text including a header line, or an empty string when
            there is no protocol data.
        """
        protocol_data = self.get_detailed_protocol_usage()

        if protocol_data.empty:
            return ""

        # Build the export column by column instead of copying row by row.
        # tolist() yields plain Python scalars, so round() behaves exactly as it
        # does on individual values.
        export = pd.DataFrame({
            'Client ID': protocol_data['client_id'].tolist(),
            'Client Name': protocol_data['client_name'].tolist(),
            'Protocol': protocol_data['protocol_name'].tolist(),
            'Category': protocol_data['category'].tolist(),
            'Usage (GB)': [round(value, 3) for value in protocol_data['usage_gb'].tolist()],
            'Percentage': [round(value, 2) for value in protocol_data['percentage'].tolist()],
            'Record ID': protocol_data['usage_record_id'].tolist(),
        })

        # With no target given, to_csv returns the CSV text directly.
        return export.to_csv(index=False)
    
    def get_protocol_recommendations(self, client_id: str):
        """
        Derive advice for a client from its per-category usage totals.

        Usage is taken from ``get_client_protocol_usage`` and summed per
        category. Three rules are checked, in this order:

        * ``'Video Streaming'`` above 10 GB  -> ``'optimization'`` entry
        * ``'Adult Content'`` above 0 GB     -> ``'security'`` entry
        * ``'Gaming'`` above 5 GB            -> ``'performance'`` entry

        Args:
            client_id: Client whose usage is evaluated.

        Returns:
            List of dicts with the keys ``'type'`` and ``'message'``; empty
            when the client has no usage rows or no rule applies.
        """
        usage = self.get_client_protocol_usage(client_id)
        if usage.empty:
            return []

        def category_total(category_name):
            # Summed usage of all rows belonging to one category.
            rows = usage[usage['category'] == category_name]
            return rows['total_usage_gb'].sum()

        video_usage = category_total('Video Streaming')
        adult_usage = category_total('Adult Content')
        gaming_usage = category_total('Gaming')

        advice = []

        if video_usage > 10:  # more than 10GB of video streaming
            advice.append({
                'type': 'optimization',
                'message': f'High video streaming usage detected ({video_usage:.1f}GB). Consider video quality optimization.'
            })

        if adult_usage > 0:  # any adult content at all
            advice.append({
                'type': 'security',
                'message': f'Adult content usage detected ({adult_usage:.1f}GB). Consider content filtering policies.'
            })

        if gaming_usage > 5:  # more than 5GB of gaming
            advice.append({
                'type': 'performance',
                'message': f'Significant gaming usage ({gaming_usage:.1f}GB). Ensure low-latency connectivity.'
            })

        return advice
    
    def export_client_protocol_analysis(self, client_id: str, date_from=None, date_to=None):
        """
        Render one client's per-protocol usage summary as CSV text.

        The data comes from ``get_client_protocol_usage_detailed``. Every
        output line repeats ``client_id`` and the requested date range, the
        latter converted with ``str()`` (so an omitted bound appears as
        ``None``). Usage is rounded to 3 decimals and the average percentage
        to 2.

        Args:
            client_id: Client to export.
            date_from: Optional lower date bound forwarded to the query.
            date_to: Optional upper date bound forwarded to the query.

        Returns:
            CSV text with a header line, or an empty string when the query
            yields no rows.
        """
        usage = self.get_client_protocol_usage_detailed(client_id, date_from, date_to)
        if usage.empty:
            return ""

        export_rows = [
            {
                'Client ID': client_id,
                'Protocol': record['protocol_name'],
                'Category': record['category'],
                'Total Usage (GB)': round(float(record['total_usage_gb']), 3),
                'Average Percentage': round(float(record['avg_percentage']), 2),
                'Record Count': int(record['record_count']),
                'Date From': str(date_from),
                'Date To': str(date_to)
            }
            for _, record in usage.iterrows()
        ]

        # Without a target, to_csv hands back the CSV text itself.
        return pd.DataFrame(export_rows).to_csv(index=False)
