import base64
import hashlib
import hmac
import secrets
import struct
import time
from typing import Iterator, Optional

from django.core.cache import cache

from allauth.core import context
from allauth.mfa import app_settings
from allauth.mfa.models import Authenticator
from allauth.mfa.utils import decrypt, encrypt


# Session key under which get_totp_secret() stores the generated TOTP secret
# so the same secret is returned on subsequent calls unless regeneration is
# requested. The "nosec" marker presumably silences a static-analysis
# (bandit) hardcoded-secret false positive — the value is a key name, not a secret.
SECRET_SESSION_KEY = "mfa.totp.secret"  # nosec


def generate_totp_secret(length: int = 20) -> str:
    """Create a new random TOTP shared secret.

    :param length: Number of random bytes to draw from the ``secrets`` CSPRNG.
    :return: The random bytes encoded as base32 text (``=``-padded when
        ``length`` is not a multiple of 5).
    """
    raw_key = secrets.token_bytes(length)
    encoded_key = base64.b32encode(raw_key)
    return encoded_key.decode("utf-8")


def get_totp_secret(regenerate: bool = False) -> str:
    """Return the TOTP secret kept in the current request's session.

    A stored secret is reused unless ``regenerate`` is true or no (truthy)
    secret is stored yet; in those cases a fresh secret is generated, written
    to the session under ``SECRET_SESSION_KEY``, and returned.
    """
    stored = None if regenerate else context.request.session.get(SECRET_SESSION_KEY)
    if stored:
        return stored
    fresh = generate_totp_secret()
    context.request.session[SECRET_SESSION_KEY] = fresh
    return fresh


def yield_hotp_counters_from_time() -> Iterator[int]:
    """Yield the HOTP counters to try for the current moment.

    The central counter is the number of whole ``TOTP_PERIOD`` steps since the
    Unix epoch; the counters ``TOTP_TOLERANCE`` steps before and after it are
    yielded as well, in ascending order.
    """
    time_step = int(time.time()) // app_settings.TOTP_PERIOD
    drift = app_settings.TOTP_TOLERANCE
    yield from (time_step + delta for delta in range(-drift, drift + 1))


def hotp_value(secret: str, counter: int, digits: Optional[int] = None) -> int:
    """Compute the HOTP value (RFC 4226, HMAC-SHA1) for ``secret`` at ``counter``.

    :param secret: Base32-encoded shared secret. Decoding is case-insensitive,
        and missing trailing ``=`` padding is tolerated.
    :param counter: Moving factor, packed as an unsigned 64-bit big-endian
        integer.
    :param digits: Number of decimal digits to keep. Defaults to
        ``app_settings.TOTP_DIGITS``.
    :return: The truncated value as a non-negative integer below
        ``10 ** digits`` (not zero-padded).
    :raises binascii.Error: If ``secret`` is not valid base32.
    """
    if digits is None:
        digits = app_settings.TOTP_DIGITS
    # b32decode requires input padded to a multiple of 8 characters. Secrets
    # are often shared with the "=" padding stripped, so restore it here.
    # Already-padded secrets have a length divisible by 8 and are unchanged.
    padding = "=" * (-len(secret) % 8)
    key = base64.b32decode((secret + padding).encode("ascii"), casefold=True)
    counter_bytes = struct.pack(">Q", counter)
    digest = hmac.new(key, counter_bytes, hashlib.sha1).digest()
    # Dynamic truncation: the low 4 bits of the last byte select a 4-byte
    # window. The window is read as a big-endian integer with its top bit
    # masked off, which yields a 31-bit value.
    offset = digest[-1] & 0x0F
    (window,) = struct.unpack(">I", digest[offset : offset + 4])
    truncated = window & 0x7FFFFFFF
    # Keep only the requested number of decimal digits.
    return truncated % (10**digits)


def format_hotp_value(value: int) -> str:
    """Render an HOTP value as text, zero-padded to ``TOTP_DIGITS`` characters."""
    width = app_settings.TOTP_DIGITS
    return format(value, f"0{width}")


def _is_insecure_bypass(code: str) -> bool:
    """Tell whether ``code`` equals the configured ``TOTP_INSECURE_BYPASS_CODE``.

    An empty or otherwise falsy ``code`` never counts as the bypass code.
    """
    if not code:
        return False
    return bool(app_settings.TOTP_INSECURE_BYPASS_CODE == code)


def validate_totp_code(secret: str, code: str) -> bool:
    """Check ``code`` against the TOTP values of ``secret`` around the current time.

    Every counter produced by ``yield_hotp_counters_from_time`` is tried, so
    codes from up to ``TOTP_TOLERANCE`` periods before or after the current one
    are accepted. The configured insecure bypass code, if any, is always
    accepted.

    :param secret: Base32-encoded shared secret.
    :param code: The code submitted by the user.
    :return: True if ``code`` matches one of the candidate values.
    """
    if _is_insecure_bypass(code):
        return True
    # Generated codes are plain ASCII digits, so other input can never match.
    # Rejecting it up front also keeps hmac.compare_digest, which accepts only
    # ASCII str, from raising TypeError on arbitrary user input.
    if not isinstance(code, str) or not code.isascii():
        return False
    for counter in yield_hotp_counters_from_time():
        expected = format_hotp_value(hotp_value(secret, counter))
        # Use a constant-time comparison so response timing does not leak how
        # much of the expected, secret-derived code was guessed correctly.
        if hmac.compare_digest(code, expected):
            return True
    return False


class TOTP:
    """TOTP behaviour for a single TOTP ``Authenticator`` record."""

    def __init__(self, instance: Authenticator) -> None:
        self.instance = instance

    @classmethod
    def activate(cls, user, secret: str) -> "TOTP":
        """Create and save a TOTP authenticator for ``user``.

        The secret is stored encrypted (via ``encrypt``) under
        ``data["secret"]``.

        :return: A ``TOTP`` wrapping the newly saved authenticator.
        """
        instance = Authenticator(
            user=user, type=Authenticator.Type.TOTP, data={"secret": encrypt(secret)}
        )
        instance.save()
        return cls(instance)

    def validate_code(self, code: str) -> bool:
        """Return True if ``code`` is a valid TOTP code that has not been used yet.

        The insecure bypass code (if configured) is always accepted. A code
        that validates successfully is recorded in the cache and rejected if
        it is submitted again while still inside the acceptance window.
        """
        if _is_insecure_bypass(code):
            return True
        if self._is_code_used(code):
            return False

        secret = decrypt(self.instance.data["secret"])
        valid = validate_totp_code(secret, code)
        if valid:
            self._mark_code_used(code)
        return valid

    def _get_used_cache_key(self, code: str) -> str:
        # Include the user id so that one user's used code does not block
        # the same digits for another user.
        return f"allauth.mfa.totp.used?user={self.instance.user_id}&code={code}"

    def _is_code_used(self, code: str) -> bool:
        return cache.get(self._get_used_cache_key(code)) == "y"

    def _mark_code_used(self, code: str) -> None:
        # A code is accepted for every counter in
        # [current - TOTP_TOLERANCE, current + TOTP_TOLERANCE], so it can stay
        # valid for up to (2 * TOTP_TOLERANCE + 1) periods. Remember it for that
        # whole window. With a single-period timeout, the code could be
        # replayed after its cache entry expired while still inside the window.
        timeout = app_settings.TOTP_PERIOD * (2 * app_settings.TOTP_TOLERANCE + 1)
        cache.set(self._get_used_cache_key(code), "y", timeout=timeout)
