import asyncio
import socket
import weakref
from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union

from .abc import AbstractResolver, ResolveResult

__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")


try:
    import aiodns

    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
except ImportError:  # pragma: no cover
    aiodns = None  # type: ignore[assignment]
    aiodns_default = False


# Flags reported on every ResolveResult: host and port are numeric values.
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# Flags for getnameinfo(): request numeric host and service strings.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
# getaddrinfo() flags; restricted to AI_MASK on platforms that define it.
if hasattr(socket, "AI_MASK"):
    _AI_ADDRCONFIG = socket.AI_ADDRCONFIG & socket.AI_MASK
else:
    _AI_ADDRCONFIG = socket.AI_ADDRCONFIG


class ThreadedResolver(AbstractResolver):
    """Threaded resolver.

    Uses an Executor for synchronous getaddrinfo() calls.
    concurrent.futures.ThreadPoolExecutor is used by default.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        # Without an explicit loop, the running loop is used, so construction
        # must then happen while an event loop is running.
        self._loop = loop or asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* into numeric stream-socket addresses.

        Args:
            host: Host name or address to look up.
            port: Port passed through to getaddrinfo().
            family: Address family requested from getaddrinfo().

        Returns:
            One ResolveResult per usable IPv4/IPv6 address; IPv6 entries with
            an incomplete sockaddr are skipped.

        Raises:
            Whatever loop.getaddrinfo()/loop.getnameinfo() raise is propagated
            unchanged (typically socket.gaierror, an OSError subclass).
        """
        infos = await self._loop.getaddrinfo(
            host,
            port,
            type=socket.SOCK_STREAM,
            family=family,
            flags=_AI_ADDRCONFIG,
        )

        hosts: List[ResolveResult] = []
        for addr_family, _, proto, _, address in infos:
            if addr_family == socket.AF_INET6:
                if len(address) < 4:
                    # A complete IPv6 sockaddr is (host, port, flowinfo,
                    # scope_id). Anything shorter means IPv6 is not supported
                    # by the Python build or not enabled on the host; reading
                    # address[3] below would raise IndexError.
                    continue
                if address[3]:
                    # A non-zero scope_id is essential for link-local IPv6
                    # addresses. LL IPv6 is a VERY rare case. Strictly speaking,
                    # we should use getnameinfo() unconditionally, but
                    # performance makes sense.
                    resolved_host, service = await self._loop.getnameinfo(
                        address, _NAME_SOCKET_FLAGS
                    )
                    resolved_port = int(service)
                else:
                    resolved_host, resolved_port = address[:2]
            else:  # IPv4
                assert addr_family == socket.AF_INET
                resolved_host, resolved_port = address  # type: ignore[misc]
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=resolved_port,
                    family=addr_family,
                    proto=proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return hosts

    async def close(self) -> None:
        """Nothing to release: the loop's default executor is not owned here."""


class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups.

    Instances created without extra ``args``/``kwargs`` borrow a shared
    ``aiodns.DNSResolver`` for their event loop from ``_DNSResolverManager``;
    instances created with custom arguments own a dedicated resolver.
    """

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the resolver.

        Args:
            loop: Event loop to use; defaults to the running loop.
            *args, **kwargs: Forwarded to ``aiodns.DNSResolver``. When any are
                given, a dedicated resolver is created for this instance.

        Raises:
            RuntimeError: aiodns is not installed.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = loop or asyncio.get_running_loop()
        self._manager: Optional[_DNSResolverManager] = None
        if args or kwargs:
            # Custom args: each such AsyncResolver gets its own
            # aiodns.DNSResolver instance, cancelled in close().
            self._resolver = aiodns.DNSResolver(*args, **kwargs)
        else:
            # Default args: use the shared resolver from the manager.
            self._manager = _DNSResolverManager()
            self._resolver = self._manager.get_resolver(self, self._loop)

        # resolve() is built on DNSResolver.getaddrinfo(); aiodns releases
        # without it fall back to DNSResolver.query(). The check applies to
        # both dedicated and shared resolvers.
        if not hasattr(self._resolver, "getaddrinfo"):
            self.resolve = self._resolve_with_query  # type: ignore

    @staticmethod
    def _dns_error_to_os_error(exc: Exception) -> OSError:
        """Build the OSError raised in place of an aiodns DNSError.

        The human-readable message is expected at ``exc.args[1]``; a generic
        message is used when the args tuple is too short to contain it.
        """
        msg = exc.args[1] if len(exc.args) > 1 else "DNS lookup failed"
        return OSError(None, msg)

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* into numeric stream-socket addresses via aiodns.

        Args:
            host: Host name or address to look up.
            port: Port passed through to getaddrinfo().
            family: Address family requested from getaddrinfo().

        Returns:
            One ResolveResult per returned address node.

        Raises:
            OSError: the lookup failed or returned no addresses.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            raise self._dns_error_to_os_error(exc) from exc

        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            node_family = node.family
            if node_family == socket.AF_INET6:
                resolved_host = address[0].decode("ascii")
                if len(address) > 3 and address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but performance makes sense.
                    try:
                        result = await self._resolver.getnameinfo(
                            (resolved_host, *address[1:]),
                            _NAME_SOCKET_FLAGS,
                        )
                    except aiodns.error.DNSError as exc:
                        # Keep the OSError contract used for getaddrinfo().
                        raise self._dns_error_to_os_error(exc) from exc
                    resolved_host = result.node
            else:  # IPv4
                assert node_family == socket.AF_INET
                resolved_host = address[0].decode("ascii")
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    # Take the port from the returned address on every path
                    # (including link-local), as ThreadedResolver does.
                    port=address[1],
                    family=node_family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Fallback resolve() for aiodns builds lacking DNSResolver.getaddrinfo().

        Issues a single ``AAAA`` query for AF_INET6 and an ``A`` query
        otherwise; *port* and *family* are echoed back unchanged.

        Raises:
            OSError: the query failed or returned no records.
        """
        qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            raise self._dns_error_to_os_error(exc) from exc

        hosts: List[Dict[str, Any]] = [
            {
                "hostname": host,
                "host": rr.host,
                "port": port,
                "family": family,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            }
            for rr in resp
        ]

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Release this instance's aiodns resolver.

        A shared resolver is handed back to the manager. A dedicated resolver
        is cancelled here. After the first call, both references are cleared,
        so repeated calls do nothing further.
        """
        if self._manager:
            # Release the resolver from the manager if using the shared resolver
            self._manager.release_resolver(self, self._loop)
            self._manager = None  # Clear reference to manager
            self._resolver = None  # type: ignore[assignment] # Clear reference to resolver
            return
        # Otherwise cancel our dedicated resolver
        if self._resolver is not None:
            self._resolver.cancel()
        self._resolver = None  # type: ignore[assignment] # Clear reference


class _DNSResolverManager:
    """Process-wide registry of shared aiodns.DNSResolver objects.

    One resolver is kept per event loop for AsyncResolver instances created
    without custom arguments. A loop's resolver is cancelled and forgotten
    once the last AsyncResolver registered for that loop releases it.
    """

    _instance: Optional["_DNSResolverManager"] = None

    def __new__(cls) -> "_DNSResolverManager":
        # Singleton: the first call creates and initialises the instance,
        # and every later call returns that same object.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        cls._instance._init()
        return cls._instance

    def _init(self) -> None:
        # Loops are held as weak keys so an abandoned loop (and its entry)
        # can still be garbage collected.
        self._loop_data: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop,
            tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
        ] = weakref.WeakKeyDictionary()

    def get_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> "aiodns.DNSResolver":
        """Return the shared aiodns.DNSResolver for *loop*, creating it if needed.

        Args:
            client: The AsyncResolver instance requesting the resolver. It is
                recorded so that releasing the last client frees the resolver.
            loop: The event loop to use for the resolver.
        """
        clients: weakref.WeakSet["AsyncResolver"]
        if loop in self._loop_data:
            # Reuse the existing resolver and client set for this loop.
            resolver, clients = self._loop_data[loop]
        else:
            # First client on this loop: create its resolver and client set.
            resolver = aiodns.DNSResolver(loop=loop)
            clients = weakref.WeakSet()
            self._loop_data[loop] = (resolver, clients)

        clients.add(client)
        return resolver

    def release_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> None:
        """Unregister *client*; cancel the loop's resolver when none remain.

        Args:
            client: The AsyncResolver instance being closed.
            loop: The event loop the resolver was created for.
        """
        if loop in self._loop_data:
            resolver, clients = self._loop_data[loop]
            clients.discard(client)
            if len(clients) == 0:
                # Last client gone: cancel and drop this loop's resolver.
                if resolver is not None:
                    resolver.cancel()
                del self._loop_data[loop]


# The resolver class exported as DefaultResolver.
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
# Use the aiodns-backed resolver only when aiodns_default is true (aiodns is
# importable and its DNSResolver has getaddrinfo); otherwise use threads.
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver
