"""Utility function for gradio/external.py, designed for internal use."""

from __future__ import annotations

import base64
import math
import re
import warnings

import httpx
import yaml
from huggingface_hub import HfApi, ImageClassificationOutputElement, InferenceClient

from gradio import components
from gradio.exceptions import Error, TooManyRequestsError


def get_model_info(model_name, hf_token=None):
    """Look up a model on the Hugging Face Hub.

    Args:
        model_name: Repository id of the model on the Hub.
        hf_token: Optional token handed to ``HfApi``.

    Returns:
        A ``(pipeline_tag, tags)`` tuple taken from the model info object
        returned by ``HfApi.model_info``.
    """
    api = HfApi(token=hf_token)
    print(f"Fetching model from: https://huggingface.co/{model_name}")

    info = api.model_info(model_name)
    return info.pipeline_tag, info.tags


##################
# Helper functions for processing tabular data
##################


def _structured_data_from_front_matter(front_matter) -> dict:
    """Return the ``widget.structuredData`` mapping of parsed front matter, or ``{}``."""
    if not isinstance(front_matter, dict):
        # Empty front matter parses to None; anything that is not a mapping
        # cannot carry a ``widget`` section.
        return {}
    widget = front_matter.get("widget")
    # NOTE(review): ``widget`` is handled both as a single mapping and as a
    # list of example mappings - confirm the list form against the Hub
    # widget documentation.
    candidates = widget if isinstance(widget, list) else [widget]
    for candidate in candidates:
        if isinstance(candidate, dict):
            structured = candidate.get("structuredData")
            if isinstance(structured, dict) and structured:
                return structured
    return {}


def get_tabular_examples(model_name: str) -> dict[str, list[float]]:
    """Fetch the example table declared in a model's README front matter.

    Downloads ``README.md`` of ``model_name`` from huggingface.co, parses the
    leading YAML block delimited by ``---`` lines and reads its
    ``widget.structuredData`` section.

    Args:
        model_name: Repository id of the model on the Hub.

    Returns:
        The example data as a mapping of column name to column values. Float
        NaN values are replaced in place by the string ``"NaN"``.

    Raises:
        ValueError: If the README cannot be fetched or contains no example
            data.
    """
    readme = httpx.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
    example_data = {}
    if readme.status_code != 200:
        warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
    else:
        yaml_regex = re.search(
            "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
        )
        if yaml_regex is not None:
            # Only the text up to the end of the front matter is parsed; the
            # first YAML document in it is the front matter itself.
            example_yaml = next(
                yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]), None
            )
            example_data = _structured_data_from_front_matter(example_yaml)
    if not example_data:
        raise ValueError(
            f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
            "See the README.md here: https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md "
            "for a reference on how to provide example data to your model."
        )
    # replace nan with string NaN for inference Endpoints
    for column in example_data.values():
        if column is None:
            # A column declared without values parses to None; cols_to_rows
            # already treats such a column as empty, so leave it untouched.
            continue
        for i, val in enumerate(column):
            if isinstance(val, float) and math.isnan(val):
                column[i] = "NaN"
    return example_data


def cols_to_rows(
    example_data: dict[str, list[float | str] | None],
) -> tuple[list[str], list[list[float]]]:
    headers = list(example_data.keys())
    n_rows = max(len(example_data[header] or []) for header in headers)
    data = []
    for row_index in range(n_rows):
        row_data = []
        for header in headers:
            col = example_data[header] or []
            if row_index >= len(col):
                row_data.append("NaN")
            else:
                row_data.append(col[row_index])
        data.append(row_data)
    return headers, data


def rows_to_cols(incoming_data: dict) -> dict[str, dict[str, dict[str, list[str]]]]:
    """Convert row-oriented table data into a column-oriented request payload.

    Args:
        incoming_data: Mapping with a ``"headers"`` list of column names and
            a ``"data"`` list of rows.

    Returns:
        ``{"inputs": {"data": columns}}`` where ``columns`` maps each header
        to the stringified values found at that header's position in every
        row.
    """
    columns = {
        name: [str(record[position]) for record in incoming_data["data"]]
        for position, name in enumerate(incoming_data["headers"])
    }
    return {"inputs": {"data": columns}}


##################
# Helper functions for processing other kinds of data
##################


def postprocess_label(scores: list[ImageClassificationOutputElement]) -> dict:
    """Map each classification element's ``label`` to its ``score``."""
    confidences = {}
    for element in scores:
        confidences[element.label] = element.score
    return confidences


def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict:
    return {c["token_str"]: c["score"] for c in scores}


def postprocess_question_answering(answer: dict) -> tuple[str, dict]:
    """Return the answer text together with a ``{text: score}`` mapping."""
    text = answer["answer"]
    confidence = {text: answer["score"]}
    return text, confidence


def postprocess_visual_question_answering(scores: list[dict[str, str | float]]) -> dict:
    return {c["answer"]: c["score"] for c in scores}


def zero_shot_classification_wrapper(client: InferenceClient):
    """Build a zero-shot classification function bound to ``client``.

    The returned function takes the text, a comma-separated string of
    candidate labels and a ``multi_label`` flag, and forwards them to
    ``client.zero_shot_classification`` with the labels split on ``","``.
    """

    def zero_shot_classification_inner(input: str, labels: str, multi_label: bool):
        candidate_labels = labels.split(",")
        return client.zero_shot_classification(
            input, candidate_labels, multi_label=multi_label
        )

    return zero_shot_classification_inner


def sentence_similarity_wrapper(client: InferenceClient):
    """Build a sentence-similarity function bound to ``client``.

    The returned function takes a source sentence and a newline-separated
    string of sentences to compare against, and forwards them to
    ``client.sentence_similarity`` with the second argument split on ``"\\n"``.
    """

    def sentence_similarity_inner(input: str, sentences: str):
        other_sentences = sentences.split("\n")
        return client.sentence_similarity(input, other_sentences)

    return sentence_similarity_inner


def text_generation_wrapper(client: InferenceClient):
    """Build a text-generation function bound to ``client``.

    The returned function sends the prompt to ``client.text_generation`` and
    returns the prompt with the generated continuation appended.
    """

    def text_generation_inner(input: str):
        continuation = client.text_generation(input)
        return input + continuation

    return text_generation_inner


def conversational_wrapper(client: InferenceClient):
    """Build a streaming chat function bound to ``client``.

    The returned generator function takes the new user ``message`` and the
    prior ``history`` (a list of ``{"role": ..., "content": ...}`` messages,
    or a falsy value for none). It sends the history plus the new user
    message to ``client.chat_completion(stream=True)`` and yields the reply
    accumulated so far after each chunk that carries a choice.

    Any exception raised while streaming is passed to ``handle_hf_error``.
    """

    def chat_fn(message, history):
        # Build a fresh list so the caller's history object is never mutated.
        messages = [*(history or []), {"role": "user", "content": message}]
        try:
            out = ""
            for chunk in client.chat_completion(messages=messages, stream=True):
                if not chunk.choices:
                    # A chunk without choices carries no text to append;
                    # skip it rather than failing on ``choices[0]``.
                    continue
                out += chunk.choices[0].delta.content or ""
                yield out
        except Exception as e:
            handle_hf_error(e)

    return chat_fn


def encode_to_base64(r: httpx.Response) -> str:
    """Turn an HTTP response into a ``data:<type>;base64,<payload>`` string.

    Handles the different ways the HF API returns the prediction:

    1. The response body is already a complete data URI: it is returned
       as-is (surrounding whitespace stripped).
    2. The response is ``application/json``: the first element of the JSON
       list supplies the ``"content-type"`` and the base64 ``"blob"``.
    3. Otherwise the raw body is base64-encoded and labelled with the
       ``content-type`` response header.

    Raises:
        ValueError: If a JSON response does not have the expected
            ``[{"content-type": ..., "blob": ...}]`` structure.
    """
    data_prefix = ";base64,"
    raw = r.content

    # Case 1: the body already is a data URI. This has to be checked on the
    # raw body: base64 output only contains [A-Za-z0-9+/=], so the encoded
    # form can never contain ";base64,".
    if raw[:64].lstrip().startswith(b"data:") and data_prefix.encode("utf-8") in raw:
        try:
            return raw.decode("utf-8").strip()
        except UnicodeDecodeError:
            # Not text after all; fall through and treat it as binary.
            pass

    content_type = r.headers.get("content-type")
    if content_type == "application/json":
        # Case 2: content type and payload are fields of the JSON response.
        try:
            data = r.json()[0]
            content_type = data["content-type"]
            base64_repr = data["blob"]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError(
                "Cannot determine content type returned by external API."
            ) from err
    else:
        # Case 3: the content type comes from the response headers.
        base64_repr = base64.b64encode(raw).decode("utf-8")
    return f"data:{content_type};base64,{base64_repr}"


def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
    if len(ner_groups) == 0:
        return [(input_string, None)]

    output = []
    end = 0
    prev_end = 0

    for group in ner_groups:
        entity, start, end = group["entity_group"], group["start"], group["end"]
        output.append((input_string[prev_end:start], None))
        output.append((input_string[start:end], entity))
        prev_end = end

    output.append((input_string[end:], None))
    return output


def token_classification_wrapper(client: InferenceClient):
    """Build a token-classification function bound to ``client``.

    The returned function runs ``client.token_classification`` on the text
    and formats the result with ``format_ner_list``.
    """

    def token_classification_inner(input: str):
        entities = client.token_classification(input)
        return format_ner_list(input, entities)  # type: ignore

    return token_classification_inner


def object_detection_wrapper(client: InferenceClient):
    """Build an object-detection function bound to ``client``.

    The returned function runs ``client.object_detection`` on its input and
    returns ``(input, annotations)``, where each annotation is a
    ``((xmin, ymin, xmax, ymax), label)`` tuple built from one detection.
    """

    def object_detection_inner(input: str):
        boxes_with_labels = []
        for detection in client.object_detection(input):
            corners = (
                detection["box"]["xmin"],
                detection["box"]["ymin"],
                detection["box"]["xmax"],
                detection["box"]["ymax"],
            )
            boxes_with_labels.append((corners, detection["label"]))
        return (input, boxes_with_labels)

    return object_detection_inner


def chatbot_preprocess(text, state):
    """Unpack conversation state into ``(text, generated, past_inputs)``.

    With a falsy ``state`` both lists are empty. Otherwise they are read from
    ``state["conversation"]`` as ``"generated_responses"`` and
    ``"past_user_inputs"``, in that order.
    """
    if state:
        return (
            text,
            state["conversation"]["generated_responses"],
            state["conversation"]["past_user_inputs"],
        )
    return text, [], []


def chatbot_postprocess(response):
    chatbot_history = list(
        zip(
            response["conversation"]["past_user_inputs"],
            response["conversation"]["generated_responses"],
            strict=False,
        )
    )
    return chatbot_history, response


def tabular_wrapper(client: InferenceClient, pipeline: str):
    """Build a tabular inference function bound to ``client`` and ``pipeline``.

    This wrapper is needed to handle an issue in the InferenceClient where the
    model name is not automatically loaded when using the
    tabular_classification and tabular_regression methods.
    See: https://github.com/huggingface/huggingface_hub/issues/2015

    Args:
        client: Client whose ``model`` attribute names the model to query.
        pipeline: Either ``"tabular_classification"`` or
            ``"tabular_regression"``.

    Returns:
        A function taking the table data and returning the result of the
        matching client method, called with ``model=client.model``.

    The returned function raises ``TypeError`` for any other pipeline and
    ``ValueError`` when ``client.model`` is not set.
    """

    def tabular_inner(data):
        if pipeline not in ("tabular_classification", "tabular_regression"):
            raise TypeError(f"pipeline type {pipeline!r} not supported")
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not client.model:
            raise ValueError(
                "The InferenceClient must have a model set to run tabular inference."
            )
        if pipeline == "tabular_classification":
            return client.tabular_classification(data, model=client.model)
        return client.tabular_regression(data, model=client.model)

    return tabular_inner


##################
# Helper function for cleaning up an Interface loaded from HF Spaces
##################


def streamline_spaces_interface(config: dict) -> dict:
    """Streamlines the interface config dictionary to remove unnecessary keys.

    Component instances are created with ``components.get_component_instance``
    from the entries of ``config["input_components"]`` and
    ``config["output_components"]``.

    Args:
        config: Interface config. It is read but not modified.

    Returns:
        A new dict holding only ``article``, ``description``,
        ``flagging_options``, ``inputs``, ``outputs`` and ``title``, in that
        order.

    Raises:
        KeyError: If ``config`` lacks one of the keys read here.
    """
    inputs = [
        components.get_component_instance(component)
        for component in config["input_components"]
    ]
    outputs = [
        components.get_component_instance(component)
        for component in config["output_components"]
    ]
    # Build a fresh dict so the caller's config is left untouched, and use a
    # fixed key sequence so the result does not depend on set iteration order.
    return {
        "article": config["article"],
        "description": config["description"],
        "flagging_options": config["flagging_options"],
        "inputs": inputs,
        "outputs": outputs,
        "title": config["title"],
    }


def handle_hf_error(e: Exception):
    """Re-raise ``e`` as a gradio error chosen from its message text.

    Raises:
        TooManyRequestsError: If the message contains ``"429"``.
        Error: With an "Unauthorized" message if the text contains ``"401"``
            or ``"You must provide an api_key"``; otherwise with the original
            message.

    This function always raises; the original exception is chained as the
    cause.
    """
    message = str(e)
    if "429" in message:
        raise TooManyRequestsError() from e
    if "401" in message or "You must provide an api_key" in message:
        raise Error("Unauthorized, please make sure you are signed in.") from e
    raise Error(message) from e
