"""This module contains the EndpointV3Compatibility class, which is used to connect to Gradio apps running 3.x.x versions of Gradio."""

from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import httpx
import huggingface_hub
import websockets
from packaging import version

from gradio_client import serializing, utils
from gradio_client.exceptions import SerializationSetupError
from gradio_client.utils import (
    Communicator,
)

if TYPE_CHECKING:
    from gradio_client import Client


class EndpointV3Compatibility:
    """Endpoint class for connecting to v3 endpoints. Backwards compatibility."""

    def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
        self.client: Client = client
        self.fn_index = fn_index
        self.dependency = dependency
        api_name = dependency.get("api_name")
        self.api_name: str | Literal[False] | None = (
            "/" + api_name if isinstance(api_name, str) else api_name
        )
        self.use_ws = self._use_websocket(self.dependency)
        self.protocol = "ws" if self.use_ws else "http"
        self.input_component_types = []
        self.output_component_types = []
        self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
        try:
            # Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
            # and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
            self.serializers, self.deserializers = self._setup_serializers()
            self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
        except SerializationSetupError:
            self.is_valid = False
        self.backend_fn = dependency.get("backend_fn")
        self.show_api = True

    def __repr__(self):
        return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"

    def __str__(self):
        return self.__repr__()

    def make_end_to_end_fn(self, helper: Communicator | None = None):
        _predict = self.make_predict(helper)

        def _inner(*data):
            if not self.is_valid:
                raise utils.InvalidAPIEndpointError()
            data = self.insert_state(*data)
            data = self.serialize(*data)
            predictions = _predict(*data)
            predictions = self.process_predictions(*predictions)
            # Append final output only if not already present
            # for consistency between generators and not generators
            if helper:
                with helper.lock:
                    if not helper.job.outputs:
                        helper.job.outputs.append(predictions)
            return predictions

        return _inner

    def make_cancel(self, helper: Communicator | None = None):  # noqa: ARG002 (needed so that both endpoints classes have the same api)
        return None

    def make_predict(self, helper: Communicator | None = None):
        def _predict(*data) -> tuple:
            data = json.dumps(
                {
                    "data": data,
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            hash_data = json.dumps(
                {
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            if self.use_ws:
                result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
                if "error" in result:
                    raise ValueError(result["error"])
            else:
                response = httpx.post(
                    self.client.api_url,
                    headers=self.client.headers,
                    json=data,
                    verify=self.client.ssl_verify,
                    **self.client.httpx_kwargs,
                )
                result = json.loads(response.content.decode("utf-8"))
            try:
                output = result["data"]
            except KeyError as ke:
                is_public_space = (
                    self.client.space_id
                    and not huggingface_hub.space_info(self.client.space_id).private
                )
                if "error" in result and "429" in result["error"] and is_public_space:
                    raise utils.TooManyRequestsError(
                        f"Too many requests to the API, please try again later. To avoid being rate-limited, "
                        f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
                        f"and pass in your Hugging Face token."
                    ) from None
                elif "error" in result:
                    raise ValueError(result["error"]) from None
                raise KeyError(
                    f"Could not find 'data' key in response. Response received: {result}"
                ) from ke
            return tuple(output)

        return _predict

    def _predict_resolve(self, *data) -> Any:
        """Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
        outputs = self.make_predict()(*data)
        if len(self.dependency["outputs"]) == 1:
            return outputs[0]
        return outputs

    def _upload(
        self, file_paths: list[str | list[str]]
    ) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
        if not file_paths:
            return []
        # Put all the filepaths in one file
        # but then keep track of which index in the
        # original list they came from so we can recreate
        # the original structure
        files = []
        indices = []
        for i, fs in enumerate(file_paths):
            if not isinstance(fs, list):
                fs = [fs]
            for f in fs:
                files.append(("files", (Path(f).name, open(f, "rb"))))  # noqa: SIM115
                indices.append(i)
        r = httpx.post(
            self.client.upload_url,
            headers=self.client.headers,
            files=files,
            verify=self.client.ssl_verify,
            **self.client.httpx_kwargs,
        )
        if r.status_code != 200:
            uploaded = file_paths
        else:
            uploaded = []
            result = r.json()
            for i, fs in enumerate(file_paths):
                if isinstance(fs, list):
                    output = [o for ix, o in enumerate(result) if indices[ix] == i]
                    res = [
                        {
                            "is_file": True,
                            "name": o,
                            "orig_name": Path(f).name,
                            "data": None,
                        }
                        for f, o in zip(fs, output, strict=False)
                    ]
                else:
                    o = next(o for ix, o in enumerate(result) if indices[ix] == i)
                    res = {
                        "is_file": True,
                        "name": o,
                        "orig_name": Path(fs).name,
                        "data": None,
                    }
                uploaded.append(res)
        return uploaded

    def _add_uploaded_files_to_data(
        self,
        files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
        data: list[Any],
    ) -> None:
        """Helper function to modify the input data with the uploaded files."""
        file_counter = 0
        for i, t in enumerate(self.input_component_types):
            if t in ["file", "uploadbutton"]:
                data[i] = files[file_counter]
                file_counter += 1

    def insert_state(self, *data) -> tuple:
        data = list(data)
        for i, input_component_type in enumerate(self.input_component_types):
            if input_component_type == utils.STATE_COMPONENT:
                data.insert(i, None)
        return tuple(data)

    def remove_skipped_components(self, *data) -> tuple:
        data = [
            d
            for d, oct in zip(data, self.output_component_types, strict=False)
            if oct not in utils.SKIP_COMPONENTS
        ]
        return tuple(data)

    def reduce_singleton_output(self, *data) -> Any:
        if (
            len(
                [
                    oct
                    for oct in self.output_component_types
                    if oct not in utils.SKIP_COMPONENTS
                ]
            )
            == 1
        ):
            return data[0]
        else:
            return data

    def serialize(self, *data) -> tuple:
        if len(data) != len(self.serializers):
            raise ValueError(
                f"Expected {len(self.serializers)} arguments, got {len(data)}"
            )

        files = [
            f
            for f, t in zip(data, self.input_component_types, strict=False)
            if t in ["file", "uploadbutton"]
        ]
        uploaded_files = self._upload(files)
        data = list(data)
        self._add_uploaded_files_to_data(uploaded_files, data)
        o = tuple(
            [s.serialize(d) for s, d in zip(self.serializers, data, strict=False)]
        )
        return o

    def deserialize(self, *data) -> tuple:
        if len(data) != len(self.deserializers):
            raise ValueError(
                f"Expected {len(self.deserializers)} outputs, got {len(data)}"
            )
        outputs = tuple(
            [
                s.deserialize(
                    d,
                    save_dir=self.client.output_dir,
                    hf_token=self.client.hf_token,
                    root_url=self.root_url,
                )
                for s, d in zip(self.deserializers, data, strict=False)
            ]
        )
        return outputs

    def process_predictions(self, *predictions):
        if self.client.download_files:
            predictions = self.deserialize(*predictions)
        predictions = self.remove_skipped_components(*predictions)
        predictions = self.reduce_singleton_output(*predictions)
        return predictions

    def _setup_serializers(
        self,
    ) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
        inputs = self.dependency["inputs"]
        serializers = []

        for i in inputs:
            for component in self.client.config["components"]:
                if component["id"] == i:
                    component_name = component["type"]
                    self.input_component_types.append(component_name)
                    if component.get("serializer"):
                        serializer_name = component["serializer"]
                        if serializer_name not in serializing.SERIALIZER_MAPPING:
                            raise SerializationSetupError(
                                f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                            )
                        serializer = serializing.SERIALIZER_MAPPING[serializer_name]
                    elif component_name in serializing.COMPONENT_MAPPING:
                        serializer = serializing.COMPONENT_MAPPING[component_name]
                    else:
                        raise SerializationSetupError(
                            f"Unknown component: {component_name}, you may need to update your gradio_client version."
                        )
                    serializers.append(serializer())  # type: ignore

        outputs = self.dependency["outputs"]
        deserializers = []
        for i in outputs:
            for component in self.client.config["components"]:
                if component["id"] == i:
                    component_name = component["type"]
                    self.output_component_types.append(component_name)
                    if component.get("serializer"):
                        serializer_name = component["serializer"]
                        if serializer_name not in serializing.SERIALIZER_MAPPING:
                            raise SerializationSetupError(
                                f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                            )
                        deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
                    elif component_name in utils.SKIP_COMPONENTS:
                        deserializer = serializing.SimpleSerializable
                    elif component_name in serializing.COMPONENT_MAPPING:
                        deserializer = serializing.COMPONENT_MAPPING[component_name]
                    else:
                        raise SerializationSetupError(
                            f"Unknown component: {component_name}, you may need to update your gradio_client version."
                        )
                    deserializers.append(deserializer())  # type: ignore

        return serializers, deserializers

    def _use_websocket(self, dependency: dict) -> bool:
        queue_enabled = self.client.config.get("enable_queue", False)
        queue_uses_websocket = version.parse(
            self.client.config.get("version", "2.0")
        ) >= version.Version("3.2")
        dependency_uses_queue = dependency.get("queue", False) is not False
        return queue_enabled and queue_uses_websocket and dependency_uses_queue

    async def _ws_fn(self, data, hash_data, helper: Communicator):
        async with websockets.connect(  # type: ignore
            self.client.ws_url,
            open_timeout=10,
            extra_headers=self.client.headers,
            max_size=1024 * 1024 * 1024,
        ) as websocket:
            return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
