from __future__ import annotations

import contextvars
import logging
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from uuid import uuid4

logger = logging.getLogger(__name__)

# Search id associated with the current execution context. It is stored in a
# ContextVar, so each context (e.g. an asyncio task) sees its own value.
# Unset -> None.
_current_search_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_search_id",
    default=None,
)
# Module-wide tracking service; stays None until set_tracker() installs one.
_tracker: TrackingService | None = None


def set_current_search_id(search_id: str | None) -> None:
    """Bind *search_id* as the active search for the current execution context.

    Passing ``None`` clears the association. The value lives in a ``ContextVar``,
    so each context (e.g. an asyncio task) sees its own value.
    """
    context_var = _current_search_id
    context_var.set(search_id)


def get_current_search_id() -> str | None:
    """Return the search id bound in the current context, or ``None`` if unset."""
    current = _current_search_id.get()
    return current


def set_tracker(tracker: TrackingService | None) -> None:
    """Install *tracker* as the module-wide tracking service.

    Pass ``None`` to clear it; ``get_tracker`` then returns ``None``.
    """
    # Rebind the module-level name rather than creating a local.
    global _tracker
    _tracker = tracker


def get_tracker() -> TrackingService | None:
    """Return the tracker installed via ``set_tracker``, or ``None`` if none is set."""
    tracker = _tracker
    return tracker


class TrackingService:
    """SQLite-backed log of searches and the outbound API calls they trigger.

    Two tables are maintained:

    * ``search_logs``: one row per search, inserted as ``pending`` by
      ``create_search`` and finalized by ``complete_search``.
    * ``api_call_logs``: one row per API call, optionally linked to a search
      through ``search_id`` (``NULL`` for calls made outside any search).

    The write methods (``create_search``, ``complete_search``, ``log_api_call``)
    are best-effort: any exception is logged as a warning and swallowed. The read
    methods (``list_searches``, ``get_search_detail``, ``get_dashboard``) let
    database errors propagate.

    NOTE(review): the connection is opened with ``check_same_thread=False`` so it
    may be used from other threads, but nothing here serializes transactions
    across threads. Confirm callers do not overlap writes on one instance.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open (creating if needed) the database at *db_path* and ensure its schema.

        Args:
            db_path: Location of the SQLite database file. Missing parent
                directories are created.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        # Rows support both index and column-name access, and convert via dict(row).
        self.conn.row_factory = sqlite3.Row
        self._init_db()

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ`` (second precision)."""
        return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape ``\\``, ``%`` and ``_`` so *value* matches literally in ``LIKE ... ESCAPE '\\'``."""
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    def _column_names(self, table: str) -> set[str]:
        """Return the column names of *table* (empty if the table does not exist).

        *table* is interpolated into the PRAGMA, so only pass internal constants.
        """
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        return {row[1] for row in self.conn.execute(f"PRAGMA table_info({table})")}

    def _init_db(self) -> None:
        """Create the schema if missing and upgrade databases written by older versions.

        Steps, in order:

        1. Set the WAL journal mode and ``synchronous=NORMAL``.
        2. ``CREATE TABLE IF NOT EXISTS`` for both tables.
        3. Add the ``service`` column if it is missing.
        4. Make ``search_id`` nullable if it is still NOT NULL.
        5. Create the secondary indexes.

        The indexes come last, so any index lost when step 4 drops the old table
        is restored.
        """
        with self.conn:
            self.conn.execute("PRAGMA journal_mode=WAL")
            self.conn.execute("PRAGMA synchronous=NORMAL")
            self.conn.execute(
                """
                CREATE TABLE IF NOT EXISTS search_logs (
                    id              TEXT PRIMARY KEY,
                    created_at      TEXT NOT NULL,
                    environment     TEXT NOT NULL,
                    source          TEXT NOT NULL,
                    search_type     TEXT NOT NULL,
                    commune_1       TEXT,
                    commune_1_slug  TEXT,
                    commune_2       TEXT,
                    commune_2_slug  TEXT,
                    latitude        REAL,
                    longitude       REAL,
                    start_date      TEXT,
                    end_date        TEXT,
                    status          TEXT NOT NULL DEFAULT 'pending',
                    total_api_calls INTEGER NOT NULL DEFAULT 0,
                    duration_ms     INTEGER,
                    error_message   TEXT
                )
                """
            )
            self.conn.execute(
                """
                CREATE TABLE IF NOT EXISTS api_call_logs (
                    id              TEXT PRIMARY KEY,
                    search_id       TEXT,
                    created_at      TEXT NOT NULL,
                    service         TEXT NOT NULL DEFAULT 'weather',
                    provider        TEXT NOT NULL,
                    endpoint        TEXT NOT NULL,
                    params_summary  TEXT,
                    cache_key       TEXT,
                    cache_status    TEXT NOT NULL,
                    status_code     INTEGER,
                    duration_ms     INTEGER NOT NULL,
                    success         INTEGER NOT NULL DEFAULT 1,
                    error_message   TEXT,
                    FOREIGN KEY (search_id) REFERENCES search_logs(id)
                )
                """
            )
            # Tables created before the ``service`` column existed need it added.
            # Checking the schema explicitly, rather than swallowing any
            # OperationalError from ALTER TABLE, keeps unrelated failures visible
            # (e.g. a locked database).
            if "service" not in self._column_names("api_call_logs"):
                self.conn.execute(
                    "ALTER TABLE api_call_logs ADD COLUMN service TEXT NOT NULL DEFAULT 'weather'"
                )
            # Must run after the column backfill: the rebuild copies ``service``.
            self._migrate_search_id_nullable()
            index_statements = (
                "CREATE INDEX IF NOT EXISTS idx_search_created ON search_logs(created_at)",
                "CREATE INDEX IF NOT EXISTS idx_search_status ON search_logs(status)",
                "CREATE INDEX IF NOT EXISTS idx_search_env ON search_logs(environment)",
                "CREATE INDEX IF NOT EXISTS idx_search_commune ON search_logs(commune_1_slug)",
                "CREATE INDEX IF NOT EXISTS idx_api_search ON api_call_logs(search_id)",
                "CREATE INDEX IF NOT EXISTS idx_api_created ON api_call_logs(created_at)",
                "CREATE INDEX IF NOT EXISTS idx_api_cache ON api_call_logs(cache_status)",
                "CREATE INDEX IF NOT EXISTS idx_api_service ON api_call_logs(service)",
            )
            for statement in index_statements:
                self.conn.execute(statement)

    def _migrate_search_id_nullable(self) -> None:
        """Rebuild ``api_call_logs`` with a nullable ``search_id`` if it is still NOT NULL.

        SQLite cannot drop a NOT NULL constraint in place. The rows are therefore
        copied into a new table with the relaxed schema, the old table is dropped,
        and the copy is renamed. The whole rebuild runs in one explicit
        transaction, so a failure part-way through leaves the original table
        intact instead of half-migrated.

        Raises:
            sqlite3.Error: If the rebuild fails; the transaction is rolled back first.
        """
        columns = self.conn.execute("PRAGMA table_info(api_call_logs)").fetchall()
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        if not any(col[1] == "search_id" and col[3] for col in columns):
            return
        try:
            # executescript() issues no transaction control of its own (beyond
            # committing a pending one), so BEGIN/COMMIT are spelled out here.
            self.conn.executescript(
                """
                BEGIN;
                DROP TABLE IF EXISTS api_call_logs_v23;
                CREATE TABLE api_call_logs_v23 (
                    id TEXT PRIMARY KEY,
                    search_id TEXT,
                    created_at TEXT NOT NULL,
                    service TEXT NOT NULL DEFAULT 'weather',
                    provider TEXT NOT NULL,
                    endpoint TEXT NOT NULL,
                    params_summary TEXT,
                    cache_key TEXT,
                    cache_status TEXT NOT NULL,
                    status_code INTEGER,
                    duration_ms INTEGER NOT NULL,
                    success INTEGER NOT NULL DEFAULT 1,
                    error_message TEXT,
                    FOREIGN KEY (search_id) REFERENCES search_logs(id)
                );
                INSERT INTO api_call_logs_v23 (
                    id, search_id, created_at, service, provider, endpoint,
                    params_summary, cache_key, cache_status, status_code,
                    duration_ms, success, error_message
                )
                SELECT
                    id, search_id, created_at, service, provider, endpoint,
                    params_summary, cache_key, cache_status, status_code,
                    duration_ms, success, error_message
                FROM api_call_logs;
                DROP TABLE api_call_logs;
                ALTER TABLE api_call_logs_v23 RENAME TO api_call_logs;
                CREATE INDEX IF NOT EXISTS idx_api_search ON api_call_logs(search_id);
                CREATE INDEX IF NOT EXISTS idx_api_service ON api_call_logs(service);
                COMMIT;
                """
            )
        except sqlite3.Error:
            # A failing statement stops the script with BEGIN still in effect.
            if self.conn.in_transaction:
                self.conn.rollback()
            raise

    def create_search(
        self,
        *,
        search_id: str,
        environment: str,
        source: str,
        search_type: str,
        commune_1: str | None = None,
        commune_1_slug: str | None = None,
        commune_2: str | None = None,
        commune_2_slug: str | None = None,
        latitude: float | None = None,
        longitude: float | None = None,
        start_date: str | None = None,
        end_date: str | None = None,
    ) -> None:
        """Insert a new ``pending`` search row stamped with the current UTC time.

        Best-effort: any failure (e.g. a duplicate *search_id*) is logged as a
        warning and not raised.
        """
        try:
            with self.conn:
                self.conn.execute(
                    """
                    INSERT INTO search_logs (
                        id, created_at, environment, source, search_type,
                        commune_1, commune_1_slug, commune_2, commune_2_slug,
                        latitude, longitude, start_date, end_date, status
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending')
                    """,
                    (
                        search_id,
                        self._utc_now_iso(),
                        environment,
                        source,
                        search_type,
                        commune_1,
                        commune_1_slug,
                        commune_2,
                        commune_2_slug,
                        latitude,
                        longitude,
                        start_date,
                        end_date,
                    ),
                )
        except Exception as exc:
            logger.warning("tracking create_search failed: %s", exc)

    def _count_api_calls(self, search_id: str) -> int:
        """Count the cache-miss API calls logged for *search_id*."""
        row = self.conn.execute(
            (
                "SELECT COUNT(*) AS c FROM api_call_logs "
                "WHERE search_id = ? AND cache_status = 'miss'"
            ),
            (search_id,),
        ).fetchone()
        return int(row["c"]) if row else 0

    def complete_search(
        self,
        *,
        search_id: str,
        status: str,
        duration_ms: int,
        total_api_calls: int | None = None,
        error_message: str | None = None,
    ) -> None:
        """Finalize a search row with its status, duration and API-call count.

        Args:
            search_id: Id of the row to update. An unknown id updates nothing.
            status: New status value stored verbatim.
            duration_ms: Total search duration in milliseconds.
            total_api_calls: Explicit call count. When ``None``, the number of
                cache-miss calls logged for the search is used.
            error_message: Optional error text to store.

        Best-effort: failures are logged as a warning and not raised.
        """
        try:
            api_calls = total_api_calls
            if api_calls is None:
                api_calls = self._count_api_calls(search_id)
            with self.conn:
                self.conn.execute(
                    """
                    UPDATE search_logs
                    SET status = ?, total_api_calls = ?, duration_ms = ?, error_message = ?
                    WHERE id = ?
                    """,
                    (status, api_calls, duration_ms, error_message, search_id),
                )
        except Exception as exc:
            logger.warning("tracking complete_search failed: %s", exc)

    def log_api_call(
        self,
        *,
        search_id: str | None,
        service: str,
        provider: str,
        endpoint: str,
        params_summary: str | None,
        cache_key: str | None,
        cache_status: str,
        status_code: int | None,
        duration_ms: int,
        success: bool,
        error_message: str | None,
    ) -> None:
        """Insert one API-call row with a fresh UUID and the current UTC timestamp.

        *search_id* may be ``None`` for calls not tied to a search. *success* is
        stored as 1/0. Best-effort: failures are logged as a warning and not raised.
        """
        try:
            with self.conn:
                self.conn.execute(
                    """
                    INSERT INTO api_call_logs (
                        id, search_id, created_at, service, provider, endpoint,
                        params_summary, cache_key, cache_status, status_code,
                        duration_ms, success, error_message
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        str(uuid4()),
                        search_id,
                        self._utc_now_iso(),
                        service,
                        provider,
                        endpoint,
                        params_summary,
                        cache_key,
                        cache_status,
                        status_code,
                        duration_ms,
                        1 if success else 0,
                        error_message,
                    ),
                )
        except Exception as exc:
            logger.warning("tracking log_api_call failed: %s", exc)

    def list_searches(
        self,
        *,
        limit: int = 50,
        offset: int = 0,
        status: str | None = None,
        environment: str | None = None,
        commune: str | None = None,
        sort: str = "date",
        order: str = "desc",
    ) -> dict[str, Any]:
        """Return one page of searches plus the total matching count.

        Args:
            limit: Page size, clamped to ``[1, 200]``.
            offset: Rows to skip, clamped to ``>= 0``.
            status: Filter on ``success``/``error``/``pending``; other values are ignored.
            environment: Exact-match filter on the environment column.
            commune: Case-insensitive substring match on ``commune_1_slug``.
                Wildcard characters in the input are matched literally.
            sort: ``date``, ``duration`` or ``api_calls``; anything else sorts by date.
            order: ``asc`` (case-insensitive) for ascending, otherwise descending.

        Returns:
            ``{"searches": [...], "total": int, "limit": int, "offset": int}``.
        """
        limit = max(1, min(limit, 200))
        offset = max(0, offset)

        # Whitelisted column names: safe to interpolate into ORDER BY.
        sort_mapping = {
            "date": "created_at",
            "duration": "duration_ms",
            "api_calls": "total_api_calls",
        }
        sort_column = sort_mapping.get(sort, "created_at")
        order_sql = "ASC" if order.lower() == "asc" else "DESC"

        where_parts: list[str] = []
        params: list[Any] = []

        if status in {"success", "error", "pending"}:
            where_parts.append("status = ?")
            params.append(status)

        if environment:
            where_parts.append("environment = ?")
            params.append(environment)

        if commune:
            # Escape LIKE metacharacters so input such as "saint_" or "50%" is
            # matched literally instead of acting as wildcards.
            where_parts.append("LOWER(COALESCE(commune_1_slug, '')) LIKE ? ESCAPE '\\'")
            params.append(f"%{self._escape_like(commune.lower())}%")

        where_sql = ""
        if where_parts:
            where_sql = "WHERE " + " AND ".join(where_parts)

        total_row = self.conn.execute(
            f"SELECT COUNT(*) AS c FROM search_logs {where_sql}", tuple(params)
        ).fetchone()
        total = int(total_row["c"]) if total_row else 0

        # created_at has one-second resolution, so rowid is the final tiebreaker.
        # It keeps offset pagination stable across pages.
        rows = self.conn.execute(
            f"""
            SELECT
                id, created_at, environment, source, search_type,
                commune_1, commune_1_slug, commune_2, commune_2_slug,
                start_date, end_date, status, total_api_calls, duration_ms
            FROM search_logs
            {where_sql}
            ORDER BY {sort_column} {order_sql}, created_at DESC, rowid DESC
            LIMIT ? OFFSET ?
            """,
            tuple([*params, limit, offset]),
        ).fetchall()

        return {
            "searches": [dict(row) for row in rows],
            "total": total,
            "limit": limit,
            "offset": offset,
        }

    def get_search_detail(self, search_id: str) -> dict[str, Any] | None:
        """Return a search row with its API calls, or ``None`` if the id is unknown.

        API calls are returned in logging order, with ``success`` converted to ``bool``.
        """
        search_row = self.conn.execute(
            "SELECT * FROM search_logs WHERE id = ?", (search_id,)
        ).fetchone()
        if search_row is None:
            return None

        # Several calls usually share the same second-resolution created_at.
        # rowid preserves insertion order among them.
        api_rows = self.conn.execute(
            """
            SELECT
                id, created_at, service, provider, endpoint, params_summary,
                cache_key, cache_status, status_code, duration_ms,
                success, error_message
            FROM api_call_logs
            WHERE search_id = ?
            ORDER BY created_at ASC, rowid ASC
            """,
            (search_id,),
        ).fetchall()

        api_calls: list[dict[str, Any]] = []
        for row in api_rows:
            item = dict(row)
            item["success"] = bool(item["success"])
            api_calls.append(item)

        return {"search": dict(search_row), "api_calls": api_calls}

    def get_dashboard(self, date_value: str | None = None) -> dict[str, Any]:
        """Aggregate daily metrics for *date_value* (``YYYY-MM-DD``; default: today in UTC).

        A row belongs to the day when the first 10 characters of its
        ``created_at`` equal *date_value*. Only cache-miss API calls are counted.

        Returns:
            A dict with these keys:

            * ``searches_today``
            * ``api_calls_today``: total, per-service counts, and attributed vs
              standalone counts.
            * ``cache_hit_ratio``: share of finished searches that succeeded
              with zero API calls.
            * ``error_rate``: errors divided by all searches of the day.
            * ``top_communes``: top 10.
            * ``most_expensive_searches``: top 10 by duration.
        """
        if date_value is None:
            date_value = datetime.now(timezone.utc).date().isoformat()

        searches_today_row = self.conn.execute(
            "SELECT COUNT(*) AS c FROM search_logs WHERE substr(created_at, 1, 10) = ?",
            (date_value,),
        ).fetchone()
        searches_today = int(searches_today_row["c"]) if searches_today_row else 0

        api_calls_today_row = self.conn.execute(
            """
            SELECT COALESCE(service, 'weather') AS svc, COUNT(*) AS c
            FROM api_call_logs
            WHERE substr(created_at, 1, 10) = ? AND cache_status = 'miss'
            GROUP BY svc
            """,
            (date_value,),
        ).fetchall()
        # Known services always appear (default 0); any other service found is added.
        api_calls_by_service = {"weather": 0, "normals": 0, "communes": 0}
        for row in api_calls_today_row:
            api_calls_by_service[row["svc"]] = int(row["c"] or 0)
        api_calls_today = sum(api_calls_by_service.values())

        api_calls_breakdown_row = self.conn.execute(
            """
            SELECT
                COUNT(*) AS total,
                SUM(CASE WHEN search_id IS NOT NULL THEN 1 ELSE 0 END) AS attributed,
                SUM(CASE WHEN search_id IS NULL THEN 1 ELSE 0 END) AS standalone
            FROM api_call_logs
            WHERE substr(created_at, 1, 10) = ? AND cache_status = 'miss'
            """,
            (date_value,),
        ).fetchone()
        # SUM() is NULL when no rows match, hence the "or 0".
        attributed_api_calls = (
            int(api_calls_breakdown_row["attributed"] or 0) if api_calls_breakdown_row else 0
        )
        standalone_api_calls = (
            int(api_calls_breakdown_row["standalone"] or 0) if api_calls_breakdown_row else 0
        )

        cache_row = self.conn.execute(
            """
            SELECT
                SUM(CASE WHEN total_api_calls = 0 AND status = 'success' THEN 1 ELSE 0 END) AS zero_calls,
                COUNT(*) AS total
            FROM search_logs
            WHERE substr(created_at, 1, 10) = ? AND status != 'pending'
            """,
            (date_value,),
        ).fetchone()
        zero_calls = int(cache_row["zero_calls"] or 0) if cache_row else 0
        total_cache = int(cache_row["total"] or 0) if cache_row else 0
        cache_hit_ratio = round(zero_calls / total_cache, 2) if total_cache else 0.0

        error_row = self.conn.execute(
            """
            SELECT COUNT(*) AS c
            FROM search_logs
            WHERE substr(created_at, 1, 10) = ? AND status = 'error'
            """,
            (date_value,),
        ).fetchone()
        errors_today = int(error_row["c"]) if error_row else 0
        error_rate = round(errors_today / searches_today, 2) if searches_today else 0.0

        top_rows = self.conn.execute(
            """
            SELECT commune_1_slug AS slug, commune_1 AS commune, COUNT(*) AS count
            FROM search_logs
            WHERE substr(created_at, 1, 10) = ? AND commune_1_slug IS NOT NULL
            GROUP BY commune_1_slug, commune_1
            ORDER BY count DESC, slug ASC
            LIMIT 10
            """,
            (date_value,),
        ).fetchall()

        expensive_rows = self.conn.execute(
            """
            SELECT id, commune_1, start_date, end_date, total_api_calls, duration_ms
            FROM search_logs
            WHERE substr(created_at, 1, 10) = ?
            ORDER BY duration_ms DESC, created_at DESC
            LIMIT 10
            """,
            (date_value,),
        ).fetchall()

        return {
            "date": date_value,
            "searches_today": searches_today,
            "api_calls_today": {
                "total": api_calls_today,
                "by_service": api_calls_by_service,
                "attributed": attributed_api_calls,
                "standalone": standalone_api_calls,
            },
            "cache_hit_ratio": cache_hit_ratio,
            "error_rate": error_rate,
            "top_communes": [dict(row) for row in top_rows],
            "most_expensive_searches": [dict(row) for row in expensive_rows],
        }

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()
