# Taken from https://gist.github.com/kevinastone/a6a62db57577b3f24e8a6865ed311463
# Context: https://github.com/encode/starlette/pull/1090
from __future__ import annotations

import os
import re
import stat
from typing import NamedTuple
from urllib.parse import quote

import aiofiles
from aiofiles.os import stat as aio_stat
from starlette.datastructures import Headers
from starlette.exceptions import HTTPException
from starlette.responses import Response, guess_type  # type: ignore
from starlette.staticfiles import StaticFiles
from starlette.types import Receive, Scope, Send

# Matches a single byte-range spec of the form ``bytes=<start>-[<end>]``:
# ``start`` is one or more digits (required), ``end`` is zero or more digits
# (an empty ``end`` means "until the end of the file").  Suffix ranges such as
# ``bytes=-500`` and comma-separated multi-range lists do not match.
RANGE_REGEX = re.compile(r"^bytes=(?P<start>\d+)-(?P<end>\d*)$")


class ClosedRange(NamedTuple):
    """An inclusive byte range ``[start, end]``.

    Both bounds are byte offsets and both are part of the range, so
    ``ClosedRange(0, 0)`` covers exactly one byte.

    NOTE: ``__len__`` is overridden, so ``len()`` reports the number of bytes
    covered by the range, not the number of tuple fields.
    """

    start: int
    end: int

    def __len__(self) -> int:
        """Return the number of bytes covered by the range (never negative).

        An inverted range (``end < start``) is reported as empty; returning a
        negative value here would make ``len()`` and ``bool()`` raise
        ``ValueError`` instead.
        """
        return max(0, self.end - self.start + 1)

    def __bool__(self) -> bool:
        """Return ``True`` when the range covers at least one byte."""
        return len(self) > 0


class OpenRange(NamedTuple):
    """A byte range whose inclusive end may be left unspecified.

    ``end is None`` means "up to whatever upper bound is supplied to
    :meth:`clamp`".
    """

    start: int
    end: int | None = None

    def clamp(self, start: int, end: int) -> ClosedRange:
        """Restrict this range to the inclusive bounds ``[start, end]``.

        Args:
            start: Lowest offset the resulting range may begin at.
            end: Highest offset the resulting range may finish at.

        Returns:
            A ``ClosedRange`` lying within ``[start, end]``.  If this range
            begins beyond the resulting end, the beginning is pulled back so
            that the result collapses onto the end offset.
        """
        # Test against None rather than truthiness: an explicit end of 0
        # (e.g. ``bytes=0-0``) is a valid bound and must not be discarded.
        upper = end if self.end is None else min(self.end, end)
        lower = min(max(self.start, start), upper)
        return ClosedRange(lower, upper)


class RangedFileResponse(Response):
    """ASGI response that streams one byte range of a file.

    A satisfiable range is answered with ``206 Partial Content`` together
    with matching ``Content-Range`` and ``Content-Length`` headers.  A range
    that cannot be satisfied (it starts at or beyond the end of the file, the
    file is empty, or its end precedes its start) is answered with
    ``416 Range Not Satisfiable`` and ``Content-Range: bytes */<size>``.
    """

    # Maximum number of bytes read from disk and sent per body message.
    chunk_size = 4096

    def __init__(
        self,
        path: str | os.PathLike,
        range: OpenRange,
        headers: dict[str, str] | None = None,
        media_type: str | None = None,
        filename: str | None = None,
        stat_result: os.stat_result | None = None,
        method: str | None = None,
    ) -> None:
        """Prepare the response; no file I/O happens until it is called.

        Args:
            path: Location of the file to serve.
            range: Requested byte range; it is clamped to the file size when
                the response is sent.
            headers: Extra response headers.
            media_type: Content type; guessed from ``filename`` or ``path``
                when omitted, falling back to ``text/plain``.
            filename: When given, a ``Content-Disposition: attachment``
                header carrying this name is added unless already present.
            stat_result: Pre-computed ``os.stat`` result for ``path``; when
                omitted the file is stat'ed while sending.
            method: HTTP method of the request; ``HEAD`` suppresses the body.
        """
        if aiofiles is None:
            raise ModuleNotFoundError(
                "'aiofiles' must be installed to use FileResponse"
            )
        self.path = path
        self.range = range
        self.filename = filename
        self.background = None
        # Exposed for code that inspects the response object; updated if the
        # range later turns out to be unsatisfiable.
        self.status_code = 206
        self.send_header_only = method is not None and method.upper() == "HEAD"
        if media_type is None:
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.init_headers(headers or {})
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                # Non-ASCII / reserved characters: use the extended form.
                content_disposition = (
                    f"attachment; filename*=utf-8''{content_disposition_filename}"
                )
            else:
                content_disposition = f'attachment; filename="{self.filename}"'
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result

    def set_range_headers(self, range: ClosedRange) -> None:
        """Set ``Content-Range`` and ``Content-Length`` for ``range``.

        Raises:
            ValueError: If no stat result is available yet.
        """
        if self.stat_result is None:
            raise ValueError("No stat result to set range headers with")
        total_length = self.stat_result.st_size
        self.headers["content-range"] = (
            f"bytes {range.start}-{range.end}/{total_length}"
        )
        self.headers["content-length"] = str(len(range))

    async def _send_range_not_satisfiable(self, send: Send, total_length: int) -> None:
        """Send a bodiless 416 response advertising the file's total length."""
        self.status_code = 416
        self.headers["content-range"] = f"bytes */{total_length}"
        self.headers["content-length"] = "0"
        await send(
            {
                "type": "http.response.start",
                "status": 416,
                "headers": self.raw_headers,
            }
        )
        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:  # noqa: ARG002
        """Send the response over the ASGI ``send`` channel.

        Raises:
            RuntimeError: If the file does not exist or is not a regular file
                (only checked when no ``stat_result`` was supplied).
        """
        if self.stat_result is None:
            try:
                stat_result = await aio_stat(self.path)
            except FileNotFoundError as fnfe:
                raise RuntimeError(
                    f"File at path {self.path} does not exist."
                ) from fnfe
            if not stat.S_ISREG(stat_result.st_mode):
                raise RuntimeError(f"File at path {self.path} is not a file.")
            self.stat_result = stat_result

        size = self.stat_result.st_size
        requested = self.range
        inverted = requested.end is not None and requested.end < requested.start
        if size <= 0 or requested.start >= size or inverted:
            await self._send_range_not_satisfiable(send, size)
            return

        # clamp() is handed the file size as its upper bound; because the end
        # of a byte range is inclusive, it is then capped at the index of the
        # last byte (size - 1) so the advertised length never exceeds the file.
        clamped = requested.clamp(0, size)
        byte_range = ClosedRange(clamped.start, min(clamped.end, size - 1))
        self.set_range_headers(byte_range)

        # Open the file before sending the status line so that an unreadable
        # file surfaces as an error instead of a half-sent response.
        async with aiofiles.open(self.path, mode="rb") as file:
            await file.seek(byte_range.start)
            await send(
                {
                    "type": "http.response.start",
                    "status": 206,
                    "headers": self.raw_headers,
                }
            )
            if self.send_header_only:
                await send(
                    {"type": "http.response.body", "body": b"", "more_body": False}
                )
                return

            remaining_bytes = len(byte_range)
            more_body = True
            while more_body:
                chunk = await file.read(min(self.chunk_size, remaining_bytes))
                remaining_bytes -= len(chunk)
                # An empty read means the file ended early (e.g. it was
                # truncated after being stat'ed); finish the response rather
                # than looping forever on empty chunks.
                more_body = bool(chunk) and remaining_bytes > 0
                await send(
                    {
                        "type": "http.response.body",
                        "body": chunk,
                        "more_body": more_body,
                    }
                )


class RangedStaticFiles(StaticFiles):
    """``StaticFiles`` variant that honours single-range ``Range`` requests.

    Every file response additionally advertises ``Accept-Ranges: bytes``.
    """

    def file_response(
        self,
        full_path: str | os.PathLike,
        stat_result: os.stat_result,
        scope: Scope,
        status_code: int = 200,
    ) -> Response:
        """Build the response for ``full_path``.

        A request carrying a ``Range`` header gets a ranged response; any
        other request is delegated to ``StaticFiles.file_response``.

        Args:
            full_path: Location of the file to serve.
            stat_result: ``os.stat`` result for ``full_path``.
            scope: ASGI scope of the request.
            status_code: Status the caller wants for a full response.

        Raises:
            HTTPException: 400 when the ``Range`` header is not supported.
        """
        request_headers = Headers(scope=scope)

        # Only a plain 200 may be turned into a partial response: a
        # caller-supplied non-200 status (for instance an error page) must
        # not be silently replaced by 206.
        wants_range = status_code == 200 and bool(request_headers.get("range"))
        if wants_range:
            response = self.ranged_file_response(
                full_path, stat_result=stat_result, scope=scope
            )
        else:
            response = super().file_response(
                full_path, stat_result=stat_result, scope=scope, status_code=status_code
            )
        response.headers["accept-ranges"] = "bytes"
        return response

    def ranged_file_response(
        self,
        full_path: str | os.PathLike,
        stat_result: os.stat_result,
        scope: Scope,
    ) -> Response:
        """Parse the request's ``Range`` header and build a ranged response.

        Args:
            full_path: Location of the file to serve.
            stat_result: ``os.stat`` result for ``full_path``.
            scope: ASGI scope of the request; must carry a ``Range`` header.

        Raises:
            HTTPException: 400 when the header does not match ``RANGE_REGEX``.
        """
        method = scope["method"]
        range_header = Headers(scope=scope)["range"]

        match = RANGE_REGEX.search(range_header)
        if match is None:
            raise HTTPException(400)

        start, end = match.group("start"), match.group("end")
        # An empty ``end`` group means the range is open-ended.
        requested = OpenRange(int(start), int(end) if end else None)

        return RangedFileResponse(
            full_path, requested, stat_result=stat_result, method=method
        )
