from __future__ import annotations

import threading
import time

from src.cache import TTLCache


def test_cache_insert_and_get() -> None:
    """A value stored with ``set`` is returned unchanged by ``get``."""
    store = TTLCache[str](ttl_seconds=30, max_entries=10)
    store.set("key", "value")

    fetched = store.get("key")
    assert fetched == "value"


def test_cache_expiration() -> None:
    """An entry is readable before its TTL elapses and gone afterwards.

    The pre-sleep read matters: without it, a cache whose ``set`` silently
    dropped values would also return ``None`` after the sleep and the test
    would pass vacuously.
    """
    cache = TTLCache[str](ttl_seconds=1, max_entries=10)
    cache.set("expired", "value")

    # Sanity check: the entry really was stored before we wait it out.
    assert cache.get("expired") == "value"

    # Sleep slightly past the 1-second TTL.
    time.sleep(1.05)
    assert cache.get("expired") is None


def test_cache_fifo_eviction() -> None:
    """Inserting past ``max_entries`` drops the oldest-inserted key.

    With capacity 2, writing "a", "b", "c" in order must evict "a" while
    keeping the two most recent insertions.
    """
    store = TTLCache[str](ttl_seconds=30, max_entries=2)
    for name, payload in (("a", "1"), ("b", "2"), ("c", "3")):
        store.set(name, payload)

    # The first key inserted is the one that gets evicted.
    assert store.get("a") is None
    for name, payload in (("b", "2"), ("c", "3")):
        assert store.get(name) == payload


def test_cache_missing_key_returns_none() -> None:
    """Looking up a key that was never stored yields ``None``."""
    empty = TTLCache[str](ttl_seconds=30, max_entries=2)
    result = empty.get("missing")
    assert result is None


def test_cache_thread_safety() -> None:
    """Concurrent writers must neither lose entries nor raise.

    Several threads each write a disjoint range of keys. The cache is sized so
    that every key fits (``max_entries`` equals the total key count). Afterwards
    every key is checked, not just a sample, so a single lost write fails the
    test. A barrier releases all writers at once to increase interleaving.
    Exceptions inside workers are collected and asserted on. ``threading.Thread``
    would otherwise only print them, and the test could still pass.
    """
    writers = 5
    per_writer = 100
    total = writers * per_writer
    cache = TTLCache[int](ttl_seconds=30, max_entries=total)

    # All writers block here until every one has started, then run together.
    start_gate = threading.Barrier(writers)
    errors: list[Exception] = []

    def writer(start: int) -> None:
        try:
            start_gate.wait(timeout=5)
            for i in range(start, start + per_writer):
                cache.set(f"k{i}", i)
        except Exception as exc:  # report to the main thread instead of losing it
            errors.append(exc)

    threads = [
        threading.Thread(target=writer, args=(n * per_writer,))
        for n in range(writers)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    assert errors == []
    for i in range(total):
        assert cache.get(f"k{i}") == i


def test_ttlcache_hit_count() -> None:
    """Each successful ``get`` bumps ``hit_count``; re-``set`` resets it to 0."""
    store = TTLCache[str](ttl_seconds=30, max_entries=10)
    store.set("key", "value")

    for _ in range(3):
        assert store.get("key") == "value"

    entries = store.snapshot()
    assert len(entries) == 1
    only = entries[0]
    assert only["key"] == "key"
    assert only["hit_count"] == 3

    # Overwriting the key should start its hit counter over.
    store.set("key", "new-value")
    refreshed = store.snapshot()
    assert len(refreshed) == 1
    assert refreshed[0]["hit_count"] == 0


def test_ttlcache_snapshot() -> None:
    """``snapshot`` lists every live key with metadata but never the value."""
    store = TTLCache[int](ttl_seconds=30, max_entries=10)
    store.set("a", 1)
    store.set("b", 2)

    entries = store.snapshot()
    assert {entry["key"] for entry in entries} == {"a", "b"}

    required_fields = ("created_at", "expires_at", "hit_count")
    for entry in entries:
        for field_name in required_fields:
            assert field_name in entry
        # Cached payloads must not leak through the snapshot.
        assert "value" not in entry


def test_ttlcache_snapshot_excludes_expired() -> None:
    """``snapshot`` includes a live entry but omits it once its TTL elapses.

    The pre-sleep check guards against a vacuous pass: an empty snapshot is
    also what a cache that never stored the entry would return.
    """
    cache = TTLCache[str](ttl_seconds=1, max_entries=10)
    cache.set("expired", "value")

    # Sanity check: the entry is visible before it expires.
    assert [item["key"] for item in cache.snapshot()] == ["expired"]

    # Sleep slightly past the 1-second TTL.
    time.sleep(1.05)
    assert cache.snapshot() == []
