from __future__ import annotations

import asyncio
from datetime import date, datetime, timedelta
import time
from typing import Any

import httpx

from src.cache import TTLCache
from src.config import (
    HTTP_TIMEOUT_SECONDS,
    MONTH_NAMES_FR,
    NORMALS_CACHE_MAX_ENTRIES,
    NORMALS_CACHE_TTL_SECONDS,
    NORMALS_CHUNK_YEARS,
    NORMALS_REFERENCE_END,
    NORMALS_REFERENCE_PERIOD,
    NORMALS_REFERENCE_START,
    OPEN_METEO_ARCHIVE_URL,
    TIMEZONE,
)
from src.tracking_service import get_current_search_id, get_tracker


class NormalsValidationError(Exception):
    """Raised when request parameters (coordinates or dates) are invalid."""


class NormalsUpstreamError(Exception):
    """Raised when the upstream archive service cannot provide normals."""


class NormalsService:
    """Serve climate normals computed from the Open-Meteo archive.

    For a location, daily temperature and precipitation values covering the
    reference period (``NORMALS_REFERENCE_START``..``NORMALS_REFERENCE_END``)
    are downloaded in chunks, averaged per calendar day (``"MM-DD"``) and
    stored in a ``TTLCache`` keyed by the coordinates rounded to two
    decimals. The public methods aggregate those per-day normals over a date
    range (:meth:`get_normals`) or over the whole year
    (:meth:`get_annual_normals`).
    """

    def __init__(self, client: httpx.AsyncClient | None = None) -> None:
        """Create the service.

        Args:
            client: HTTP client used for upstream calls. When omitted, the
                service creates its own client, which :meth:`aclose` closes.
        """
        # Remember ownership so aclose() never closes a caller-supplied client.
        self._owns_client = client is None
        self.client = (
            client
            if client is not None
            else httpx.AsyncClient(timeout=HTTP_TIMEOUT_SECONDS)
        )
        self.cache: TTLCache[dict[str, Any]] = TTLCache(
            ttl_seconds=NORMALS_CACHE_TTL_SECONDS,
            max_entries=NORMALS_CACHE_MAX_ENTRIES,
        )

    async def aclose(self) -> None:
        """Close the HTTP client if this service created it."""
        if self._owns_client:
            await self.client.aclose()

    async def get_normals(
        self, latitude: float, longitude: float, start: str, end: str
    ) -> dict[str, Any] | None:
        """Return climate normals for every day from ``start`` to ``end``.

        Args:
            latitude: Latitude in degrees, within [-90, 90].
            longitude: Longitude in degrees, within [-180, 180].
            start: First day (inclusive), ``YYYY-MM-DD``.
            end: Last day (inclusive), ``YYYY-MM-DD``; not before ``start``.

        Returns:
            ``None`` when no per-day normal exists for any day of the range;
            otherwise a dict with ``elevation``, ``reference_period``,
            ``period_normals`` (aggregates over the range), ``daily_normals``
            (one entry per date in the range that has a normal) and
            ``month_normals`` (aggregates for the month of ``start`` only).

        Raises:
            NormalsValidationError: Invalid coordinates or dates.
            NormalsUpstreamError: No upstream chunk could be fetched.
        """
        self._validate_coordinates(latitude, longitude)
        start_date, end_date = self._validate_dates(start, end)

        reference = await self._get_reference_normals(latitude, longitude)
        days = reference.get("days") or {}

        daily_normals: list[dict[str, Any]] = []
        # Step by offset rather than incrementing a cursor date: a cursor
        # would overflow past date.max when ``end`` is 9999-12-31.
        for offset in range((end_date - start_date).days + 1):
            month_day = (start_date + timedelta(days=offset)).strftime("%m-%d")
            day_values = days.get(month_day)
            if day_values:
                daily_normals.append(
                    {
                        "month_day": month_day,
                        "temp_avg": day_values.get("temp_avg"),
                        "temp_max": day_values.get("temp_max"),
                        "temp_min": day_values.get("temp_min"),
                        "precipitation": day_values.get("precipitation"),
                    }
                )

        if not daily_normals:
            return None

        precipitation_values = self._present_values(daily_normals, "precipitation")
        period_normals = {
            "temp_avg": self._average(
                self._present_values(daily_normals, "temp_avg")
            ),
            "temp_max_avg": self._average(
                self._present_values(daily_normals, "temp_max")
            ),
            "temp_min_avg": self._average(
                self._present_values(daily_normals, "temp_min")
            ),
            "precipitation_daily_avg": self._average(precipitation_values),
            "precipitation_total": self._total(precipitation_values),
        }

        return {
            "elevation": reference.get("elevation"),
            "reference_period": NORMALS_REFERENCE_PERIOD,
            "period_normals": period_normals,
            "daily_normals": daily_normals,
            # Re-serialize the parsed date: strptime accepts unpadded input
            # such as "2024-3-05", so the raw ``start`` string is not
            # guaranteed to be zero-padded.
            "month_normals": self._compute_month_normals(
                days, start_date.isoformat()
            ),
        }

    async def get_annual_normals(
        self, latitude: float, longitude: float
    ) -> dict[str, Any] | None:
        """Return monthly normals and annual aggregates for a location.

        Returns:
            ``None`` when no month has data; otherwise a dict with
            ``elevation``, ``reference_period``, ``annual_avg_temp`` (mean of
            the monthly mean temperatures), ``annual_precipitation`` (sum of
            the monthly totals) and ``months`` (per-month aggregates).

        Raises:
            NormalsValidationError: Invalid coordinates.
            NormalsUpstreamError: No upstream chunk could be fetched.
        """
        self._validate_coordinates(latitude, longitude)

        reference = await self._get_reference_normals(latitude, longitude)
        days = reference.get("days") or {}

        months = []
        for month in range(1, 13):
            # Only the month part of this placeholder date is used.
            month_data = self._compute_month_normals(days, f"2020-{month:02d}-01")
            if month_data:
                months.append(month_data)

        if not months:
            return None

        return {
            "elevation": reference.get("elevation"),
            "reference_period": NORMALS_REFERENCE_PERIOD,
            "annual_avg_temp": self._average(
                self._present_values(months, "temp_avg")
            ),
            "annual_precipitation": self._total(
                self._present_values(months, "precipitation_total")
            ),
            "months": months,
        }

    async def _get_reference_normals(
        self, latitude: float, longitude: float
    ) -> dict[str, Any]:
        """Return the per-day reference normals for a location, via the cache.

        The cache key rounds coordinates to two decimals, so nearby points
        share the entry fetched for whichever point was requested first.
        """
        cache_key = f"normals:{latitude:.2f}:{longitude:.2f}"
        cached = self.cache.get(cache_key)
        if cached is None:
            cached = await self._fetch_reference_normals(latitude, longitude)
            self.cache.set(cache_key, cached)
        return cached

    @staticmethod
    def _compute_month_normals(
        all_days: dict[str, dict[str, float | None]], start_date: str
    ) -> dict[str, Any] | None:
        """Aggregate per-day normals for the month of ``start_date``.

        Args:
            all_days: Per-day normals keyed by ``"MM-DD"``.
            start_date: A ``YYYY-MM-DD`` date string; only its month is used.
                Unpadded months (``"2024-3-05"``) are accepted.

        Returns:
            ``None`` if no day of that month is present; otherwise the month
            number, its name from ``MONTH_NAMES_FR``, averaged temperatures
            and the sum of the daily precipitation normals.
        """
        month = int(start_date.split("-")[1])
        prefix = f"{month:02d}-"

        month_days = [
            day_values
            for month_day, day_values in all_days.items()
            if month_day.startswith(prefix)
        ]

        if not month_days:
            return None

        # NOTE(review): the "02-29" normal (built from leap years only) is
        # summed as a full day, so February precipitation totals run slightly
        # high relative to an average February; confirm this is intended.
        return {
            "month": month,
            "month_name": MONTH_NAMES_FR[month],
            "temp_avg": NormalsService._average(
                NormalsService._present_values(month_days, "temp_avg")
            ),
            "temp_max_avg": NormalsService._average(
                NormalsService._present_values(month_days, "temp_max")
            ),
            "temp_min_avg": NormalsService._average(
                NormalsService._present_values(month_days, "temp_min")
            ),
            "precipitation_total": NormalsService._total(
                NormalsService._present_values(month_days, "precipitation")
            ),
        }

    async def _fetch_reference_normals(
        self, latitude: float, longitude: float
    ) -> dict[str, Any]:
        """Download the reference period and reduce it to per-day normals.

        Chunks are fetched concurrently. Failed chunks are skipped, so the
        result may be built from fewer years than the full reference period,
        and such a partial result is returned (and cached by the caller)
        just like a complete one.

        Returns:
            ``{"elevation": float | None, "days": {"MM-DD": {metric: avg}}}``
            where the metrics are ``temp_avg``, ``temp_max``, ``temp_min``
            and ``precipitation``.

        Raises:
            NormalsUpstreamError: No chunk returned a usable payload.
        """
        results = await asyncio.gather(
            *(
                self._fetch_chunk(latitude, longitude, start, end)
                for start, end in self._build_reference_chunks()
            ),
            return_exceptions=True,
        )

        # Keep only real payloads. Filtering on ``dict`` (instead of "not an
        # Exception") also drops BaseException results such as CancelledError,
        # which would otherwise crash on ``.get`` below.
        valid_payloads = [result for result in results if isinstance(result, dict)]

        if not valid_payloads:
            raise NormalsUpstreamError(
                "Le service de normales climatiques est temporairement indisponible."
            )

        days_accumulator: dict[str, dict[str, list[float]]] = {}
        elevation: float | None = None

        for payload in valid_payloads:
            if elevation is None:
                maybe_elevation = payload.get("elevation")
                if isinstance(maybe_elevation, (int, float)):
                    elevation = round(float(maybe_elevation), 1)

            daily = payload.get("daily") or {}
            # zip() stops at the shortest column, ignoring ragged tails.
            rows = zip(
                daily.get("time") or [],
                daily.get("temperature_2m_mean") or [],
                daily.get("temperature_2m_max") or [],
                daily.get("temperature_2m_min") or [],
                daily.get("precipitation_sum") or [],
            )

            for date_str, temp_avg, temp_max, temp_min, precipitation in rows:
                try:
                    month_day = datetime.strptime(date_str, "%Y-%m-%d").strftime("%m-%d")
                except (TypeError, ValueError):
                    # Skip rows whose date is missing or malformed.
                    continue

                bucket = days_accumulator.setdefault(
                    month_day,
                    {
                        "temp_avg": [],
                        "temp_max": [],
                        "temp_min": [],
                        "precipitation": [],
                    },
                )

                self._append_if_numeric(bucket["temp_avg"], temp_avg)
                self._append_if_numeric(bucket["temp_max"], temp_max)
                self._append_if_numeric(bucket["temp_min"], temp_min)
                self._append_if_numeric(bucket["precipitation"], precipitation)

        days: dict[str, dict[str, float | None]] = {
            month_day: {
                metric: self._average(values)
                for metric, values in metric_values.items()
            }
            for month_day, metric_values in days_accumulator.items()
        }

        return {"elevation": elevation, "days": days}

    async def _fetch_chunk(
        self, latitude: float, longitude: float, start_date: str, end_date: str
    ) -> dict[str, Any]:
        """Fetch one date range of daily archive data for a location.

        Each attempt, successful or not, is reported to the tracker returned
        by ``get_tracker()`` when one is available.

        Returns:
            The decoded JSON object.

        Raises:
            NormalsUpstreamError: Transport failure, an error status
                (``raise_for_status``), undecodable JSON, or JSON that is not
                an object.
        """
        params = {
            "latitude": latitude,
            "longitude": longitude,
            "start_date": start_date,
            "end_date": end_date,
            "daily": (
                "temperature_2m_mean,temperature_2m_max,"
                "temperature_2m_min,precipitation_sum"
            ),
            "timezone": TIMEZONE,
        }

        tracker = get_tracker()
        search_id = get_current_search_id()
        params_summary = (
            f"lat={latitude:.2f},lon={longitude:.2f},{start_date}/{end_date}"
        )
        api_start = time.monotonic()

        def record(
            status_code: int | None, success: bool, error_message: str | None
        ) -> None:
            """Report this call to the tracker, if tracking is enabled."""
            if not tracker:
                return
            tracker.log_api_call(
                search_id=search_id,
                service="normals",
                provider="open-meteo",
                endpoint=OPEN_METEO_ARCHIVE_URL,
                params_summary=params_summary,
                cache_key=None,
                cache_status="miss",
                status_code=status_code,
                duration_ms=int((time.monotonic() - api_start) * 1000),
                success=success,
                error_message=error_message,
            )

        try:
            response = await self.client.get(OPEN_METEO_ARCHIVE_URL, params=params)
            response.raise_for_status()
            payload = response.json()
            if not isinstance(payload, dict):
                raise ValueError(
                    f"unexpected payload type: {type(payload).__name__}"
                )
        except (httpx.HTTPError, ValueError) as exc:
            # Errors without a ``response`` attribute (transport errors, JSON
            # decoding failures) are reported without a status code.
            status_code = getattr(getattr(exc, "response", None), "status_code", None)
            record(status_code, False, str(exc))
            raise NormalsUpstreamError from exc

        record(response.status_code, True, None)
        return payload

    @staticmethod
    def _build_reference_chunks() -> list[tuple[str, str]]:
        """Split the reference period into ISO ``(start, end)`` date ranges.

        Each chunk ends on 31 December of the year ``NORMALS_CHUNK_YEARS - 1``
        after its start year, clamped to ``NORMALS_REFERENCE_END``; the next
        chunk starts the following day.
        """
        # A non-positive setting would put the chunk end before its start and
        # the loop would never terminate.
        years_per_chunk = max(1, NORMALS_CHUNK_YEARS)
        chunks: list[tuple[str, str]] = []
        chunk_start = NORMALS_REFERENCE_START

        while chunk_start <= NORMALS_REFERENCE_END:
            chunk_end = min(
                date(chunk_start.year + years_per_chunk - 1, 12, 31),
                NORMALS_REFERENCE_END,
            )
            chunks.append((chunk_start.isoformat(), chunk_end.isoformat()))
            chunk_start = chunk_end + timedelta(days=1)

        return chunks

    @staticmethod
    def _validate_coordinates(latitude: float, longitude: float) -> None:
        """Raise NormalsValidationError unless both coordinates are in range."""
        # Negated range checks: NaN fails every comparison, so it is rejected
        # here instead of slipping through "< min or > max" tests.
        if not -90 <= latitude <= 90:
            raise NormalsValidationError(
                "La latitude doit être comprise entre -90 et 90."
            )
        if not -180 <= longitude <= 180:
            raise NormalsValidationError(
                "La longitude doit être comprise entre -180 et 180."
            )

    @staticmethod
    def _validate_dates(start: str, end: str) -> tuple[date, date]:
        """Parse ``YYYY-MM-DD`` strings and check that ``end`` >= ``start``.

        Raises:
            NormalsValidationError: Unparseable date or inverted range.
        """
        try:
            start_date = datetime.strptime(start, "%Y-%m-%d").date()
            end_date = datetime.strptime(end, "%Y-%m-%d").date()
        except ValueError as exc:
            raise NormalsValidationError(
                "Le format de date doit être YYYY-MM-DD."
            ) from exc

        if end_date < start_date:
            raise NormalsValidationError(
                "La date de fin doit être postérieure ou égale à la date de début."
            )

        return start_date, end_date

    @staticmethod
    def _present_values(rows: list[dict[str, Any]], key: str) -> list[Any]:
        """Return the non-``None`` values stored under ``key`` in ``rows``."""
        return [row[key] for row in rows if row.get(key) is not None]

    @staticmethod
    def _append_if_numeric(target: list[float], value: Any) -> None:
        """Append ``value`` as a float when it is an int or float."""
        if isinstance(value, (int, float)):
            target.append(float(value))

    @staticmethod
    def _total(values: list[float]) -> float | None:
        """Return the sum rounded to one decimal, or ``None`` when empty."""
        return round(sum(values), 1) if values else None

    @staticmethod
    def _average(values: list[float]) -> float | None:
        """Return the mean rounded to one decimal, or ``None`` when empty."""
        if not values:
            return None
        return round(sum(values) / len(values), 1)
