
import socket
import struct
import hashlib
import secrets
from django.conf import settings
from django.utils import timezone
from .models import ClientConnection
from customers.models import Customer
import logging

# Module-level logger named after this module (``__name__``); used by both the
# RADIUS server and client classes below.
logger = logging.getLogger(__name__)

class RADIUSServer:
    """Simple RADIUS server implementation for PPPoE authentication.

    Handles PAP-style Access-Requests (User-Name + hidden User-Password,
    RFC 2865) and acknowledges Accounting-Requests (RFC 2866).  CHAP-Password
    and Message-Authenticator are not supported.  Requests are read and
    answered one at a time on a single UDP socket.
    """

    # RADIUS packet types
    ACCESS_REQUEST = 1
    ACCESS_ACCEPT = 2
    ACCESS_REJECT = 3
    ACCOUNTING_REQUEST = 4
    ACCOUNTING_RESPONSE = 5

    # RADIUS attributes
    USER_NAME = 1
    USER_PASSWORD = 2
    NAS_IP_ADDRESS = 4
    SERVICE_TYPE = 6
    FRAMED_PROTOCOL = 7
    FRAMED_IP_ADDRESS = 8
    CALLING_STATION_ID = 31
    NAS_IDENTIFIER = 32

    # Code(1) + Identifier(1) + Length(2) + Authenticator(16)
    HEADER_LENGTH = 20
    # RFC 2865 section 3: largest legal RADIUS packet
    MAX_PACKET_LENGTH = 4096

    def __init__(self, host='0.0.0.0', port=1812, secret=None):
        """Configure (but do not open) the server.

        Args:
            host: Address to bind the UDP socket to.
            port: UDP port to listen on.
            secret: Shared secret as ``str`` or ``bytes``.  When empty/None the
                Django setting ``RADIUS_SECRET`` is used, falling back to
                ``'radiussecret'``.
        """
        self.host = host
        self.port = port
        if not secret:
            # NOTE(review): 'radiussecret' is a guessable default; production
            # deployments should set RADIUS_SECRET.
            secret = getattr(settings, 'RADIUS_SECRET', 'radiussecret')
        # All MD5 computations concatenate bytes, so normalise str secrets here.
        self.secret = secret.encode('utf-8') if isinstance(secret, str) else bytes(secret)
        self.socket = None

    def start_server(self):
        """Bind the UDP socket and serve requests forever (blocks).

        Errors while handling a single datagram are logged and the loop keeps
        running; the socket is closed if the loop is ever left (e.g. on
        KeyboardInterrupt).
        """
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self.socket.bind((self.host, self.port))
            logger.info(f"RADIUS server started on {self.host}:{self.port}")

            while True:
                try:
                    data, addr = self.socket.recvfrom(self.MAX_PACKET_LENGTH)
                    self.handle_request(data, addr)
                except Exception as e:
                    # Top-level boundary: one bad packet must not stop the server.
                    logger.exception(f"RADIUS server error: {e}")
        finally:
            self.socket.close()

    def handle_request(self, data, addr):
        """Validate the RADIUS header of one datagram and dispatch it.

        Packets shorter than the header, or whose Length field is below 20 or
        larger than the datagram, are silently discarded (RFC 2865 section 3).
        Bytes beyond Length are padding and are dropped before dispatch.
        Unknown packet codes are ignored.
        """
        if len(data) < self.HEADER_LENGTH:
            return

        code, identifier, length = struct.unpack('!BBH', data[:4])
        if length < self.HEADER_LENGTH or length > len(data):
            return
        data = data[:length]
        authenticator = data[4:self.HEADER_LENGTH]

        if code == self.ACCESS_REQUEST:
            response = self.handle_access_request(data, identifier, authenticator)
        elif code == self.ACCOUNTING_REQUEST:
            response = self.handle_accounting_request(data, identifier, authenticator)
        else:
            return

        if response:
            self.socket.sendto(response, addr)

    def handle_access_request(self, data, identifier, authenticator):
        """Authenticate an Access-Request and build the Accept/Reject reply.

        Args:
            data: The full request packet (header included).
            identifier: Request Identifier, echoed in the reply.
            authenticator: 16-byte Request Authenticator.

        Returns:
            The encoded Access-Accept or Access-Reject packet.
        """
        attributes = self.parse_attributes(data[self.HEADER_LENGTH:])

        try:
            username = attributes.get(self.USER_NAME, b'').decode('utf-8')
        except UnicodeDecodeError:
            return self.create_access_reject(identifier, authenticator)
        calling_station = attributes.get(self.CALLING_STATION_ID, b'').decode(
            'utf-8', errors='replace'
        )
        # User-Password travels hidden (RFC 2865 section 5.2); recover it first.
        password = self._decrypt_password(attributes.get(self.USER_PASSWORD, b''), authenticator)

        if username and password is not None and self.authenticate_user(
            username, password, calling_station
        ):
            return self.create_access_accept(identifier, authenticator, username)
        return self.create_access_reject(identifier, authenticator)

    def authenticate_user(self, username, password, mac_address=None):
        """Check PPPoE credentials against ``ClientConnection`` and record the login.

        Args:
            username: Account name; only ``is_active=True`` connections match.
            password: Plain-text password as ``bytes`` (``str`` also accepted).
            mac_address: Calling-Station-Id reported by the NAS; when falsy the
                stored MAC address is kept.

        Returns:
            True when the password matches (the connection is then marked
            online and a ``ConnectionLog`` entry is written); False otherwise,
            including on any database error, which is logged.
        """
        if isinstance(password, str):
            password = password.encode('utf-8')
        try:
            connection = ClientConnection.objects.get(
                username=username,
                is_active=True
            )

            # NOTE(review): passwords are stored in plain text on the model and
            # should be hashed.  Until then, compare in constant time and never
            # accept an empty password.
            stored = (connection.password or '').encode('utf-8')
            if not password or not secrets.compare_digest(stored, password):
                logger.warning(f"RADIUS authentication failed - bad password: {username}")
                return False

            # Update connection info
            connection.is_online = True
            connection.mac_address = mac_address or connection.mac_address
            connection.last_seen = timezone.now()
            connection.save()

            # Log successful authentication
            from .models import ConnectionLog
            ConnectionLog.objects.create(
                connection=connection,
                action='radius_auth_success',
                details=f'RADIUS authentication successful from MAC: {mac_address}'
            )

            logger.info(f"RADIUS authentication successful for {username}")
            return True

        except ClientConnection.DoesNotExist:
            logger.warning(f"RADIUS authentication failed - user not found: {username}")
        except Exception as e:
            logger.error(f"RADIUS authentication error for {username}: {e}")

        return False

    def parse_attributes(self, data):
        """Parse a RADIUS attribute list into ``{type: value_bytes}``.

        Parsing stops at the first malformed attribute (length < 2 or running
        past the end of ``data``); attributes parsed before it are kept.  When
        a type occurs more than once, the last occurrence wins.
        """
        attributes = {}
        pos = 0
        end = len(data)

        while pos + 2 <= end:
            attr_type = data[pos]
            attr_length = data[pos + 1]
            # Length counts the 2-byte type/length prefix; anything smaller
            # would never advance ``pos`` (0 previously looped forever).
            if attr_length < 2 or pos + attr_length > end:
                break
            attributes[attr_type] = data[pos + 2:pos + attr_length]
            pos += attr_length

        return attributes

    def create_access_accept(self, identifier, request_auth, username):
        """Create an Access-Accept, adding Framed-IP-Address when one is assigned."""
        try:
            ip_address = ClientConnection.objects.get(username=username).ip_address
        except ClientConnection.DoesNotExist:
            ip_address = None
        except Exception as e:
            logger.error(f"RADIUS could not load assigned IP for {username}: {e}")
            ip_address = None

        attrs = b''
        if ip_address:
            try:
                # inet_pton is strict and IPv4-only, matching Framed-IP-Address.
                ip_bytes = socket.inet_pton(socket.AF_INET, str(ip_address))
            except OSError:
                logger.warning(f"RADIUS skipping non-IPv4 Framed-IP-Address {ip_address!r} for {username}")
            else:
                attrs += struct.pack('!BB', self.FRAMED_IP_ADDRESS, 6) + ip_bytes

        return self._build_response(self.ACCESS_ACCEPT, identifier, request_auth, attrs)

    def create_access_reject(self, identifier, request_auth):
        """Create an Access-Reject with no attributes."""
        return self._build_response(self.ACCESS_REJECT, identifier, request_auth)

    def handle_accounting_request(self, data, identifier, authenticator):
        """Acknowledge an Accounting-Request.

        Returns None (packet silently discarded) when the Request
        Authenticator does not verify against the shared secret
        (RFC 2866 section 3); otherwise an Accounting-Response.
        """
        if not self._accounting_authenticator_ok(data):
            logger.warning("RADIUS accounting request with invalid authenticator discarded")
            return None

        attributes = self.parse_attributes(data[self.HEADER_LENGTH:])
        username = attributes.get(self.USER_NAME, b'').decode('utf-8', errors='replace')

        response = self._build_response(self.ACCOUNTING_RESPONSE, identifier, authenticator)

        # Log accounting info
        logger.info(f"RADIUS accounting for user: {username}")

        return response

    def _build_response(self, code, identifier, request_auth, attrs=b''):
        """Encode a reply packet with correct Length and Response Authenticator.

        ResponseAuth = MD5(Code + ID + Length + RequestAuth + Attributes + Secret)
        (RFC 2865 section 3, RFC 2866 section 3).
        """
        header = struct.pack('!BBH', code, identifier, self.HEADER_LENGTH + len(attrs))
        response_auth = hashlib.md5(header + request_auth + attrs + self.secret).digest()
        return header + response_auth + attrs

    def _decrypt_password(self, hidden, authenticator):
        """Recover a User-Password hidden per RFC 2865 section 5.2.

        Returns the plain-text password bytes (trailing NUL padding removed),
        or None when the attribute is missing or not 16..128 bytes in
        multiples of 16.
        """
        if not hidden or len(hidden) % 16 or len(hidden) > 128:
            return None
        plain = bytearray()
        previous = authenticator
        for offset in range(0, len(hidden), 16):
            chunk = hidden[offset:offset + 16]
            key = hashlib.md5(self.secret + previous).digest()
            plain += bytes(c ^ k for c, k in zip(chunk, key))
            # Each block's key is chained from the previous *ciphertext* block.
            previous = chunk
        return bytes(plain).rstrip(b'\x00')

    def _accounting_authenticator_ok(self, data):
        """Verify an Accounting-Request authenticator (RFC 2866 section 3).

        Expected = MD5(Code + ID + Length + 16 zero octets + Attributes + Secret).
        """
        if len(data) < self.HEADER_LENGTH:
            return False
        packet = data[:struct.unpack('!H', data[2:4])[0]]
        expected = hashlib.md5(
            packet[:4] + bytes(16) + packet[self.HEADER_LENGTH:] + self.secret
        ).digest()
        return secrets.compare_digest(expected, packet[4:self.HEADER_LENGTH])

class RADIUSClient:
    """RADIUS client for connecting to external RADIUS servers.

    Sends a PAP-style Access-Request (User-Name plus a User-Password hidden
    per RFC 2865 section 5.2) and reports whether an authentic Access-Accept
    was received.
    """

    ACCESS_REQUEST = 1
    ACCESS_ACCEPT = 2
    USER_NAME = 1
    USER_PASSWORD = 2
    HEADER_LENGTH = 20

    def __init__(self, server_host, server_port=1812, secret=None, timeout=5):
        """Configure the client.

        Args:
            server_host: IPv4 host name/address of the RADIUS server.
            server_port: UDP port of the RADIUS server.
            secret: Shared secret as ``str`` or ``bytes``; when empty/None the
                Django setting ``RADIUS_SECRET`` (default ``'radiussecret'``)
                is used.
            timeout: Seconds to wait for the server's reply.
        """
        self.server_host = server_host
        self.server_port = server_port
        if not secret:
            secret = getattr(settings, 'RADIUS_SECRET', 'radiussecret')
        # MD5 inputs are concatenated as bytes, so normalise str secrets here.
        self.secret = secret.encode('utf-8') if isinstance(secret, str) else bytes(secret)
        self.timeout = timeout

    def authenticate(self, username, password):
        """Send an Access-Request and return True only for a verified Access-Accept.

        Any error (timeout, oversized credentials, network failure) is logged
        and yields False.
        """
        try:
            identifier = secrets.randbits(8)
            authenticator = secrets.token_bytes(16)

            attrs = self._attribute(self.USER_NAME, username.encode('utf-8'))
            attrs += self._attribute(
                self.USER_PASSWORD,
                self._hide_password(password.encode('utf-8'), authenticator),
            )
            packet = (
                struct.pack('!BBH', self.ACCESS_REQUEST, identifier, self.HEADER_LENGTH + len(attrs))
                + authenticator
                + attrs
            )

            # Context manager guarantees the socket is closed, even on timeout.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.settimeout(self.timeout)
                sock.sendto(packet, (self.server_host, self.server_port))
                response, _ = sock.recvfrom(4096)

            return self._is_valid_accept(response, identifier, authenticator)

        except Exception as e:
            logger.error(f"RADIUS client error: {e}")

        return False

    @staticmethod
    def _attribute(attr_type, value):
        """Encode one type-length-value attribute; values are limited to 253 bytes."""
        if len(value) > 253:
            raise ValueError(f"RADIUS attribute {attr_type} too long ({len(value)} bytes)")
        return struct.pack('!BB', attr_type, len(value) + 2) + value

    def _hide_password(self, password, authenticator):
        """Hide ``password`` as required for User-Password (RFC 2865 section 5.2).

        The password is NUL-padded to a multiple of 16 bytes (minimum 16) and
        XORed block-wise with MD5(secret + previous), where ``previous`` is the
        Request Authenticator for the first block and the prior ciphertext
        block afterwards.
        """
        if len(password) > 128:
            raise ValueError("RADIUS User-Password is limited to 128 bytes")
        padded_len = max(16, (len(password) + 15) // 16 * 16)
        padded = password.ljust(padded_len, b'\x00')
        hidden = bytearray()
        previous = authenticator
        for offset in range(0, padded_len, 16):
            key = hashlib.md5(self.secret + previous).digest()
            chunk = bytes(p ^ k for p, k in zip(padded[offset:offset + 16], key))
            hidden += chunk
            previous = chunk
        return bytes(hidden)

    def _is_valid_accept(self, response, identifier, request_auth):
        """Return True iff ``response`` is an authentic Access-Accept for our request.

        Checks the echoed Identifier, the Length field, and the Response
        Authenticator MD5(Code + ID + Length + RequestAuth + Attributes + Secret),
        so spoofed or stray datagrams are not treated as success.
        """
        if len(response) < self.HEADER_LENGTH:
            return False
        code, response_id, length = struct.unpack('!BBH', response[:4])
        if response_id != identifier or length < self.HEADER_LENGTH or length > len(response):
            return False
        packet = response[:length]
        expected = hashlib.md5(
            packet[:4] + request_auth + packet[self.HEADER_LENGTH:] + self.secret
        ).digest()
        if not secrets.compare_digest(expected, packet[4:self.HEADER_LENGTH]):
            return False
        return code == self.ACCESS_ACCEPT
