"""Pydantic data models and other data classes (TypedDicts, enums). Pydantic model fields
in this file use Optional[] typing syntax instead of | None syntax, to work with pydantic."""

from __future__ import annotations

import pathlib
import secrets
import shutil
from abc import ABC, abstractmethod
from collections.abc import Iterator
from enum import Enum, auto
from typing import (
    Annotated,
    Any,
    Literal,
    NewType,
    Optional,
    Union,
)

from fastapi import Request
from gradio_client.data_classes import ParameterInfo
from gradio_client.documentation import document
from gradio_client.utils import is_file_obj_with_meta, traverse
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
    RootModel,
    ValidationError,
    ValidationInfo,
    model_validator,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import NotRequired, TypedDict

from gradio.i18n import I18nData

try:
    from pydantic import JsonValue
except ImportError:
    JsonValue = Any

# NewType aliases of `str`: at runtime these are plain strings, but a type
# checker treats them as distinct types, so the two kinds of path cannot be
# mixed up silently. Judging by the names, one marks paths chosen by the app
# developer and the other paths supplied by end users — confirm against callers.
DeveloperPath = NewType("DeveloperPath", str)
UserProvidedPath = NewType("UserProvidedPath", str)


class CancelBody(BaseModel):
    # Request body naming one event by session hash, function index and event id.
    # Presumably consumed by a cancel endpoint — confirm against the route.
    # (Comments rather than a docstring, so the generated JSON schema is unchanged.)
    session_hash: str
    fn_index: int
    event_id: str


class SimplePredictBody(BaseModel):
    # Minimal prediction request body: a list of input values plus an optional
    # session hash. Compare PredictBody, which carries many more optional fields.
    data: list[Any]
    session_hash: Optional[str] = None


class _StarletteRequestPydanticAnnotation:
    """Annotation marker that lets pydantic accept a Starlette ``Request`` as-is.

    Used as ``Annotated[Request, _StarletteRequestPydanticAnnotation]``. Values that
    are ``Request`` instances pass through unchanged and anything else is rejected.
    The JSON schema is a fixed, titled placeholder object.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        def _check_is_request(candidate: Any) -> Request:
            # Guard clause: reject non-Request values; pydantic reports the
            # ValueError as a validation error.
            if not isinstance(candidate, Request):
                raise ValueError("Input must be a Starlette Request object")
            return candidate

        return core_schema.no_info_plain_validator_function(_check_is_request)

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        # A live Request has no meaningful JSON schema; publish a titled stub.
        placeholder: JsonSchemaValue = {"type": "object"}
        placeholder["title"] = "StarletteRequest"
        return placeholder


# A Starlette `Request` usable as a pydantic field type: validation accepts a
# value only if it is a `Request` instance (see _StarletteRequestPydanticAnnotation).
PydanticStarletteRequest = Annotated[Request, _StarletteRequestPydanticAnnotation]


class PredictBody(BaseModel):
    session_hash: Optional[str] = None
    event_id: Optional[str] = None
    data: list[Any]
    event_data: Optional[Any] = None
    fn_index: Optional[int] = None
    trigger_id: Optional[int] = None
    simple_format: bool = False
    # Whether the data is a batch of samples (i.e. called from the queue if
    # batch=True) or a single sample (i.e. called from the UI).
    batched: Optional[bool] = False

    @classmethod
    def __get_pydantic_json_schema__(cls, core_schema, handler):
        # Hand-written schema in place of pydantic's generated one. `data` items
        # and `event_data` are `Any` on the model but are published as JSON
        # objects; only `data` is required. The title stays "PredictBody" even
        # for subclasses.
        # NOTE(review): this table must be kept in sync with the fields by hand.
        field_types = (
            ("session_hash", "string"),
            ("event_id", "string"),
            ("data", "array"),
            ("event_data", "object"),
            ("fn_index", "integer"),
            ("trigger_id", "integer"),
            ("simple_format", "boolean"),
            ("batched", "boolean"),
        )
        properties = {name: {"type": json_type} for name, json_type in field_types}
        properties["data"]["items"] = {"type": "object"}
        return {
            "title": "PredictBody",
            "type": "object",
            "properties": properties,
            "required": ["data"],
        }


class PredictBodyInternal(PredictBody):
    "Separate class to avoid exposing PydanticStarletteRequest in the API validation"

    # The incoming Starlette `Request` object itself (it exposes headers, query
    # parameters, url, etc.); validated only by an isinstance check.
    request: Optional[PydanticStarletteRequest] = (
        None  # Starlette Request (not a dict); used to pass in the request for queuing
    )


class ResetBody(BaseModel):
    # Request body identifying a single event by its id. Presumably consumed by
    # a reset endpoint — confirm against the route that uses it.
    event_id: str


class ComponentServerJSONBody(BaseModel):
    # Request body addressed to a component function: session, component id,
    # function name, and an arbitrary JSON payload in `data`. Compare
    # ComponentServerBlobBody, whose payload also carries raw files.
    session_hash: str
    component_id: int
    fn_name: str
    data: Any


class DataWithFiles(BaseModel):
    # Arbitrary payload plus raw file contents. Each `files` entry is a
    # (str, bytes) pair — presumably (filename, content); confirm with producers.
    data: Any
    files: list[tuple[str, bytes]]


class ComponentServerBlobBody(BaseModel):
    # Same addressing as ComponentServerJSONBody (session, component id,
    # function name), but the payload is a DataWithFiles so raw file bytes can
    # travel alongside the data.
    session_hash: str
    component_id: int
    fn_name: str
    data: DataWithFiles


class InterfaceTypes(Enum):
    """Kinds of interface layout.

    Members are numbered 1-4 in declaration order, exactly what ``auto()`` would
    assign. Judging by the names, INPUT_ONLY / OUTPUT_ONLY have only one side,
    and UNIFIED presumably shares components between input and output —
    confirm against the code that consumes this enum.
    """

    STANDARD = 1
    INPUT_ONLY = 2
    OUTPUT_ONLY = 3
    UNIFIED = 4


class GradioBaseModel(ABC):
    """Mixin for Gradio pydantic models that can be rebuilt from plain data and
    whose embedded files can be copied into another directory."""

    def copy_to_dir(self, dir: str | pathlib.Path) -> GradioDataModel:
        """Return a copy of this model whose files have been copied under ``dir``.

        The model is dumped to plain data. Every node accepted by
        ``FileData.is_file_data`` (as walked by gradio_client's ``traverse``) is
        copied into its own sub-directory of ``dir``, named with
        ``secrets.token_hex(10)``. The result is rebuilt with ``from_json``.

        Raises:
            TypeError: if the concrete class is not a pydantic model.
        """
        if not isinstance(self, (BaseModel, RootModel)):
            raise TypeError("must be used in a Pydantic model")
        dest_root = pathlib.Path(dir)

        # TODO: Making sure path is unique should be done in caller
        def _relocate(file_dict: dict):
            source = FileData(**file_dict)
            subdir = dest_root / secrets.token_hex(10)
            return source._copy_to_dir(str(subdir)).model_dump()

        rewritten = traverse(self.model_dump(), _relocate, FileData.is_file_data)
        return self.__class__.from_json(x=rewritten)

    @classmethod
    @abstractmethod
    def from_json(cls, x) -> GradioDataModel:
        """Build an instance of ``cls`` from plain (JSON-like) data ``x``."""


class JsonData(RootModel):
    """JSON data returned from a component that should not be modified further."""

    # pydantic's JsonValue when the installed pydantic provides it, else `Any`.
    root: JsonValue


class GradioModel(GradioBaseModel, BaseModel):
    # Base for Gradio data models with named fields (see GradioRootModel for
    # single-root-value models).

    @classmethod
    def from_json(cls, x) -> GradioModel:
        """Build the model from a mapping of field names to values.

        ``x`` is unpacked as keyword arguments, so pydantic validates each field.
        """
        return cls(**x)


class GradioRootModel(GradioBaseModel, RootModel):
    # Base for Gradio data models that wrap a single `root` value (e.g. a list).

    @classmethod
    def from_json(cls, x) -> GradioRootModel:
        """Build the model with ``x`` as its ``root`` value."""
        return cls(root=x)


# Either kind of Gradio data model: field-based (GradioModel) or root-value
# based (GradioRootModel). Return type of from_json / copy_to_dir.
GradioDataModel = Union[GradioModel, GradioRootModel]


class FileDataDict(TypedDict):
    """Plain-dict counterpart of the ``FileData`` model, using the same key names.

    Differences from ``FileData``: ``is_stream`` is a required key here, and
    ``meta`` is typed as a plain ``dict``.
    """

    path: str  # server filepath
    url: NotRequired[Optional[str]]  # normalised server url
    size: NotRequired[Optional[int]]  # size in bytes
    orig_name: NotRequired[Optional[str]]  # original filename
    mime_type: NotRequired[Optional[str]]
    is_stream: bool
    meta: NotRequired[dict]


class FileDataMeta(TypedDict):
    # Marker stored in FileData.meta. Its single `_type` key tags a dict as
    # serialized FileData; FileData's validator checks for it (via
    # is_file_obj_with_meta) when the `validate_meta` context flag is set.
    _type: Literal["gradio.FileData"]


@document()
class FileData(GradioModel):
    """
    The FileData class is a subclass of the GradioModel class that represents a file object within a Gradio interface. It is used to store file data and metadata when a file is uploaded.

    Attributes:
        path: The server file path where the file is stored.
        url: The normalized server URL pointing to the file.
        size: The size of the file in bytes.
        orig_name: The original filename before upload.
        mime_type: The MIME type of the file.
        is_stream: Indicates whether the file is a stream.
        meta: Additional metadata used internally (should not be changed).
    """

    path: str  # server filepath
    url: Optional[str] = None  # normalised server url
    size: Optional[int] = None  # size in bytes
    orig_name: Optional[str] = None  # original filename
    mime_type: Optional[str] = None
    is_stream: bool = False
    meta: FileDataMeta = Field(default_factory=lambda: {"_type": "gradio.FileData"})

    @model_validator(mode="before")
    @classmethod
    def validate_model(cls, v, info: ValidationInfo):
        # Runs on the raw input before field validation. Only when the caller
        # opts in with validation context {"validate_meta": True} must the input
        # carry the FileData `meta` marker (checked by gradio_client's
        # is_file_obj_with_meta). Otherwise the input passes through untouched.
        if (
            info.context
            and info.context.get("validate_meta")
            and not is_file_obj_with_meta(v)
        ):
            raise ValueError(
                "The 'meta' field must be explicitly provided in the input data and be equal to {'_type': 'gradio.FileData'}."
            )
        return v

    @property
    def is_none(self) -> bool:
        """
        Checks if the FileData object is empty, i.e., all attributes are None.

        Returns:
            bool: True if all attributes (except 'is_stream' and 'meta') are None, False otherwise.
        """
        # NOTE: `path` is a required str field, so on a validated instance this
        # is normally False; it can only be True for unvalidated instances.
        return all(
            f is None
            for f in [
                self.path,
                self.url,
                self.size,
                self.orig_name,
                self.mime_type,
            ]
        )

    @classmethod
    def from_path(cls, path: str) -> FileData:
        """
        Creates a FileData object from a given file path.

        Args:
            path: The file path.

        Returns:
            FileData: An instance of FileData representing the file at the specified path.
        """
        return cls(path=path)

    def _copy_to_dir(self, dir: str) -> FileData:
        """
        Copies the file to a specified directory and returns a new FileData object representing the copied file.

        The destination directory, including any missing parent directories, is
        created if it does not already exist.

        Args:
            dir: The destination directory.

        Returns:
            FileData: A new FileData object representing the copied file. Every field
            except `path` is carried over from this object.

        Raises:
            ValueError: If the source file path is not set.
        """
        # Validate before touching the filesystem so a failed call does not
        # leave an empty directory behind.
        if not self.path:
            raise ValueError("Source file path is not set")
        # parents=True: copy_to_dir passes `<dir>/<random hex>`, which would
        # otherwise raise FileNotFoundError whenever `<dir>` does not exist yet.
        pathlib.Path(dir).mkdir(parents=True, exist_ok=True)
        new_obj = dict(self)
        # shutil.copy returns the path of the newly created file.
        new_obj["path"] = shutil.copy(self.path, dir)
        return self.__class__(**new_obj)

    @classmethod
    def is_file_data(cls, obj: Any) -> bool:
        """
        Checks if an object is a valid FileData instance.

        Args:
            obj: The object to check.

        Returns:
            bool: True if the object is a valid FileData instance, False otherwise.
        """
        if isinstance(obj, dict):
            try:
                return not FileData(**obj).is_none
            # TypeError covers dicts whose keys cannot be keyword arguments
            # (e.g. non-str keys); ValidationError covers bad field values.
            except (TypeError, ValidationError):
                return False
        return False


class ListFiles(GradioRootModel):
    # Root model wrapping a list of FileData, with indexing and iteration
    # delegated straight to `root`.
    root: list[FileData]

    def __getitem__(self, index):
        """Return ``self.root[index]`` (an int index or a slice)."""
        return self.root[index]

    def __iter__(self) -> Iterator[FileData]:  # type: ignore[override]
        # Yields the contained FileData objects rather than pydantic's default
        # (field name, value) pairs, hence the override type-ignore.
        return iter(self.root)


class _StaticFiles:
    """
    Class to hold all static files for an app
    """

    all_paths = []

    def __init__(self, paths: list[str | pathlib.Path]) -> None:
        self.paths = paths
        self.all_paths = [pathlib.Path(p).resolve() for p in paths]

    @classmethod
    def clear(cls):
        cls.all_paths = []


class BodyCSS(TypedDict):
    """Page body background fill and text colour, for the default and ``_dark``
    variants. Values are strings — presumably CSS colour/fill values."""

    body_background_fill: str
    body_text_color: str
    body_background_fill_dark: str
    body_text_color_dark: str


class Layout(TypedDict):
    """A node of the recursive layout tree.

    ``id`` identifies the node (presumably a component id — confirm). The
    optional ``children`` list holds either bare ints or nested ``Layout`` nodes.
    """

    id: int
    children: NotRequired[list[int | Layout]]


class Page(TypedDict):
    """Per-page section of a Blocks config (see ``BlocksConfigDict.page``).

    ``components`` and ``dependencies`` are ints — presumably ids referring to
    entries of the app-wide lists; confirm. ``layout`` is the page's layout tree.
    """

    components: list[int]
    dependencies: list[int]
    layout: Layout


class BlocksConfigDict(TypedDict):
    """Typed shape of the configuration dictionary for a Blocks app.

    Keys wrapped in ``NotRequired`` may be absent. NOTE(review): the consumer of
    this dict is not visible here — confirm field meanings there before relying
    on the hedged comments below.
    """

    version: str
    deep_link_state: NotRequired[Literal["valid", "invalid", "none"]]
    mode: str
    app_id: int
    dev_mode: bool
    analytics_enabled: bool
    components: list[dict[str, Any]]
    css: str | None
    connect_heartbeat: bool
    js: str | Literal[True] | None
    head: str | None
    title: str | I18nData  # plain string or an I18nData (translatable) value
    space_id: str | None
    enable_queue: bool
    show_error: bool
    show_api: bool
    is_colab: bool
    max_file_size: int | None  # presumably bytes, None meaning no limit — confirm
    stylesheets: list[str]
    theme: str | None
    protocol: Literal["ws", "sse", "sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
    body_css: BodyCSS
    fill_height: bool
    fill_width: bool
    theme_hash: str
    layout: NotRequired[Layout]
    dependencies: NotRequired[list[dict[str, Any]]]
    root: NotRequired[str | None]
    username: NotRequired[str | None]
    api_prefix: str
    pwa: NotRequired[bool]
    page: dict[str, Page]  # per-page config; keys presumably page routes — confirm
    pages: list[tuple[str, str]]  # string pairs, presumably (route, name) — confirm
    current_page: NotRequired[str]
    i18n_translations: NotRequired[dict[str, dict[str, str]] | None]  # presumably locale -> {key: text}
    mcp_server: NotRequired[bool]


class MediaStreamChunk(TypedDict):
    """One chunk of streamed media.

    ``data`` is the raw chunk bytes, ``duration`` its length (units not shown
    here — presumably seconds), ``extension`` the file extension, and ``id`` an
    optional identifier.
    """

    data: bytes
    duration: float
    extension: str
    id: NotRequired[str]


class ImageData(GradioModel):
    # Image payload for the API schema. It has the same field names as FileData,
    # but every field is optional so an image can be given by `path` or `url`
    # (see the schema description in model_config below).
    path: Optional[str] = Field(default=None, description="Path to a local file")
    url: Optional[str] = Field(
        default=None, description="Publicly available url or base64 encoded image"
    )
    size: Optional[int] = Field(default=None, description="Size of image in bytes")
    orig_name: Optional[str] = Field(default=None, description="Original filename")
    mime_type: Optional[str] = Field(default=None, description="mime type of image")
    is_stream: bool = Field(default=False, description="Can always be set to False")
    # NOTE(review): a mutable dict literal default (FileData uses default_factory).
    # This relies on pydantic copying non-hashable defaults per instance.
    # Switching to default_factory would likely drop "default" from the
    # generated JSON schema, so change it deliberately.
    meta: dict = {"_type": "gradio.FileData"}

    model_config = ConfigDict(
        json_schema_extra={
            "description": "For input, either path or url must be provided. For output, path is always provided."
        }
    )


class Base64ImageData(GradioModel):
    # Image supplied only as a base64-encoded string in `url`, which is required
    # here (unlike ImageData, where every field has a default).
    url: str = Field(description="base64 encoded image")


class APIReturnInfo(TypedDict):
    """Description of one value returned by an API endpoint.

    ``type`` is a free-form dict (presumably a JSON schema — confirm);
    ``python_type`` maps strings to strings; ``component`` names the producing
    component.
    """

    label: str
    type: dict[str, Any]
    python_type: dict[str, str]
    component: str


class APIEndpointInfo(TypedDict):
    """API description of one endpoint.

    It holds the endpoint's parameters (gradio_client ``ParameterInfo``), its
    return values, whether it is shown in the API docs, and an optional
    description.
    """

    description: NotRequired[str]
    parameters: list[ParameterInfo]
    returns: list[APIReturnInfo]
    show_api: bool


class APIInfo(TypedDict):
    """API description of an app, split into named and unnamed endpoints.

    Both maps are keyed by strings — presumably the endpoint name and the
    stringified function index respectively; confirm with the producer.
    """

    named_endpoints: dict[str, APIEndpointInfo]
    unnamed_endpoints: dict[str, APIEndpointInfo]
