172 lines
5.7 KiB
Python
172 lines
5.7 KiB
Python
"""Tests for the response cache layer."""
|
|
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import Session
|
|
|
|
from engine.cache import (
|
|
CachedResponse,
|
|
CacheStats,
|
|
ResponseCacheLayer,
|
|
compute_config_hash,
|
|
)
|
|
from models import Base
|
|
|
|
|
|
def _engine():
    """Return a new in-memory SQLite engine with all ``Base`` tables created."""
    memory_db = create_engine("sqlite:///:memory:")
    # Create the schema up front so each test starts with empty tables.
    Base.metadata.create_all(memory_db)
    return memory_db
|
|
|
|
|
|
def _session(engine):
    """Return a new ORM ``Session`` constructed from *engine*."""
    db = Session(engine)
    return db
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_config_hash tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeConfigHash:
    """Checks on the value returned by ``compute_config_hash``."""

    def test_deterministic(self):
        """Identical arguments hash to the same value."""
        first = compute_config_hash("hello", "gpt-4", {"temperature": 0.7})
        second = compute_config_hash("hello", "gpt-4", {"temperature": 0.7})
        assert first == second

    def test_different_prompt_different_hash(self):
        """Changing only the prompt changes the hash."""
        hello_hash = compute_config_hash("hello", "gpt-4", {"temperature": 0.7})
        world_hash = compute_config_hash("world", "gpt-4", {"temperature": 0.7})
        assert hello_hash != world_hash

    def test_different_model_different_hash(self):
        """Changing only the model changes the hash."""
        gpt4_hash = compute_config_hash("hello", "gpt-4", {"temperature": 0.7})
        gpt35_hash = compute_config_hash("hello", "gpt-3.5", {"temperature": 0.7})
        assert gpt4_hash != gpt35_hash

    def test_different_params_different_hash(self):
        """Changing only a parameter value changes the hash."""
        cooler = compute_config_hash("hello", "gpt-4", {"temperature": 0.7})
        warmer = compute_config_hash("hello", "gpt-4", {"temperature": 0.9})
        assert cooler != warmer

    def test_different_input_data_different_hash(self):
        """Changing only ``input_data`` changes the hash."""
        with_data1 = compute_config_hash("hello", "gpt-4", {}, input_data="data1")
        with_data2 = compute_config_hash("hello", "gpt-4", {}, input_data="data2")
        assert with_data1 != with_data2

    def test_param_order_irrelevant(self):
        """Key order inside the params dict does not affect the hash."""
        a_then_b = compute_config_hash("p", "m", {"a": 1, "b": 2})
        b_then_a = compute_config_hash("p", "m", {"b": 2, "a": 1})
        assert a_then_b == b_then_a

    def test_returns_hex_string_64_chars(self):
        """The hash is 64 characters long and uses only lowercase hex digits."""
        digest = compute_config_hash("test", "model", {})
        assert len(digest) == 64
        assert set(digest).issubset("0123456789abcdef")

    def test_none_input_data_is_default(self):
        """Omitting ``input_data`` is the same as passing ``None``."""
        omitted = compute_config_hash("p", "m", {})
        explicit_none = compute_config_hash("p", "m", {}, input_data=None)
        assert omitted == explicit_none
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ResponseCacheLayer tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestResponseCacheLayer:
    """Exercises ``ResponseCacheLayer`` against a fresh in-memory database per test."""

    def test_get_miss_returns_none(self):
        """Looking up a hash that was never stored returns ``None``."""
        with _session(_engine()) as db:
            layer = ResponseCacheLayer()
            assert layer.get(db, "nonexistent_hash") is None

    def test_put_and_get(self):
        """A stored entry is returned as a ``CachedResponse`` with the same fields."""
        with _session(_engine()) as db:
            layer = ResponseCacheLayer()
            config_hash = compute_config_hash("hello", "gpt-4", {"temp": 0.5})
            stored_fields = {
                "response": "Hello world!",
                "model": "gpt-4",
                "tokens_in": 10,
                "tokens_out": 5,
                "latency_ms": 150,
            }

            layer.put(db, config_hash=config_hash, **stored_fields)

            fetched = layer.get(db, config_hash)
            assert fetched is not None
            assert isinstance(fetched, CachedResponse)
            # Every field handed to put() must come back unchanged.
            for field_name, expected in stored_fields.items():
                assert getattr(fetched, field_name) == expected
            assert fetched.config_hash == config_hash

    def test_put_upsert(self):
        """A second put under the same hash replaces the first entry's values."""
        with _session(_engine()) as db:
            layer = ResponseCacheLayer()
            key = "a" * 64

            layer.put(db, key, response="first", model="m1")
            layer.put(db, key, response="second", model="m2")

            fetched = layer.get(db, key)
            assert fetched is not None
            assert fetched.response == "second"
            assert fetched.model == "m2"

    def test_hit_rate_tracking(self):
        """One hit followed by one miss yields a hit rate of 0.5."""
        with _session(_engine()) as db:
            layer = ResponseCacheLayer()
            key = compute_config_hash("p", "m", {})

            layer.put(db, key, response="r", model="m")

            layer.get(db, key)  # hit
            layer.get(db, "missing")  # miss

            snapshot = layer.cache_stats(db)
            assert snapshot.hit_rate == 0.5
            assert snapshot.total_entries == 1

    def test_cache_stats_empty(self):
        """Stats on an untouched cache are all zero."""
        with _session(_engine()) as db:
            layer = ResponseCacheLayer()
            snapshot = layer.cache_stats(db)
            assert isinstance(snapshot, CacheStats)
            assert snapshot.total_entries == 0
            assert snapshot.hit_rate == 0.0
            assert snapshot.storage_size_bytes == 0

    def test_cache_stats_storage_size(self):
        """Reported storage size equals the combined length of the stored responses."""
        with _session(_engine()) as db:
            layer = ResponseCacheLayer()
            for key_char, body in (("a", "hello"), ("b", "world!")):
                layer.put(db, key_char * 64, response=body, model="m")

            snapshot = layer.cache_stats(db)
            assert snapshot.total_entries == 2
            # len("hello") is 5 and len("world!") is 6, so 11 in total.
            assert snapshot.storage_size_bytes == 11

    def test_multiple_entries(self):
        """Five puts under distinct hashes produce five entries."""
        with _session(_engine()) as db:
            layer = ResponseCacheLayer()
            for index in range(5):
                key = compute_config_hash(f"prompt_{index}", "model", {})
                layer.put(db, key, response=f"resp_{index}", model="model")

            snapshot = layer.cache_stats(db)
            assert snapshot.total_entries == 5
|