# promptlooper/backend/engine/cache.py
#
# 146 lines
# 4.1 KiB
# Python

"""Response cache layer for PromptLooper.
Caches LLM responses by a SHA-256 hash of the full configuration
(prompt + model + params + input_data) to avoid redundant API calls.
"""
import hashlib
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import func, text
from sqlalchemy.orm import Session
from models import ResponseCache
@dataclass
class CachedResponse:
"""A cached LLM response retrieved from the database."""
config_hash: str
response: str
model: str
tokens_in: int | None
tokens_out: int | None
latency_ms: int | None
created_at: datetime
@dataclass
class CacheStats:
    """Summary figures describing the response cache."""

    # Number of entries currently stored in the cache.
    total_entries: int
    # Hits divided by total lookups; 0.0 when nothing has been looked up.
    hit_rate: float
    # Approximate size of the stored responses (summed text length).
    storage_size_bytes: int
def _canonical_fallback(value: Any) -> Any:
    """Convert a value ``json`` cannot encode into a deterministic stand-in.

    Sets and frozensets are emitted as a tagged list of their members'
    JSON encodings in sorted order, because ``str(some_set)`` follows the
    set's iteration order, which is not stable for equal sets. Every other
    type falls back to ``str(value)``.
    """
    if isinstance(value, (set, frozenset)):
        # Sort by each member's JSON text so mixed-type members (which are
        # not mutually comparable) still get a total, repeatable order.
        members = sorted(
            json.dumps(
                item,
                sort_keys=True,
                ensure_ascii=True,
                default=_canonical_fallback,
            )
            for item in value
        )
        return {"__set__": members}
    return str(value)


def compute_config_hash(
    prompt: str,
    model: str,
    params: dict[str, Any],
    input_data: Any = None,
) -> str:
    """Compute a deterministic SHA-256 hash for a given configuration.

    The hash covers the full config so that any parameter change produces
    a different key. Dict keys are sorted for determinism, and sets /
    frozensets are serialized in sorted order so equal sets always hash
    the same.

    Args:
        prompt: Prompt text.
        model: Model identifier.
        params: Generation parameters; may be nested.
        input_data: Optional extra input that is part of the cache key.

    Returns:
        The lowercase hexadecimal SHA-256 digest (64 characters).

    Raises:
        TypeError: If a dict inside the payload has keys that cannot be
            ordered against each other (e.g. ``int`` mixed with ``str``).

    Note:
        Objects that ``json`` cannot encode (other than sets) are hashed via
        ``str(obj)``. Types whose ``str()`` is not stable between runs, such
        as the default ``object`` repr containing a memory address, will not
        produce a stable key.
    """
    payload = {
        "prompt": prompt,
        "model": model,
        "params": params,
        "input_data": input_data,
    }
    canonical = json.dumps(
        payload,
        sort_keys=True,
        ensure_ascii=True,
        default=_canonical_fallback,
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
class ResponseCacheLayer:
    """Database-backed response cache.

    Works with both SQLite and PostgreSQL — the caller provides a
    SQLAlchemy session on every call.

    Hit and miss counters live in memory on this instance, so the hit rate
    reported by ``cache_stats`` only reflects lookups made through this
    object since it was created.
    """

    def __init__(self) -> None:
        self._hits: int = 0
        self._misses: int = 0

    def get(self, db: Session, config_hash: str) -> CachedResponse | None:
        """Look up a cached response by config hash.

        Each call is counted as a hit or a miss for ``cache_stats``.

        Args:
            db: SQLAlchemy session used for the lookup.
            config_hash: Cache key, looked up as the row's primary key.

        Returns:
            The cached entry, or ``None`` on a cache miss.
        """
        row = db.get(ResponseCache, config_hash)
        if row is None:
            self._misses += 1
            return None

        self._hits += 1
        return CachedResponse(
            config_hash=row.config_hash,
            response=row.response,
            model=row.model,
            tokens_in=row.tokens_in,
            tokens_out=row.tokens_out,
            latency_ms=row.latency_ms,
            created_at=row.created_at,
        )

    def put(
        self,
        db: Session,
        config_hash: str,
        response: str,
        model: str,
        tokens_in: int | None = None,
        tokens_out: int | None = None,
        latency_ms: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Store a response in the cache and commit the session.

        If the config_hash already exists, the entry is updated (upsert).
        The lookup and the insert are separate steps, so the upsert is not
        atomic; if the commit fails for any reason the session is rolled
        back and the error is re-raised.

        Args:
            db: SQLAlchemy session; it is committed by this call.
            config_hash: Cache key for the entry.
            response: Response text to store.
            model: Model identifier to store.
            tokens_in: Optional input token count.
            tokens_out: Optional output token count.
            latency_ms: Optional latency in milliseconds.
            metadata: Accepted for interface compatibility; currently not
                persisted.
        """
        existing = db.get(ResponseCache, config_hash)
        if existing is None:
            db.add(
                ResponseCache(
                    config_hash=config_hash,
                    response=response,
                    model=model,
                    tokens_in=tokens_in,
                    tokens_out=tokens_out,
                    latency_ms=latency_ms,
                )
            )
        else:
            existing.response = response
            existing.model = model
            existing.tokens_in = tokens_in
            existing.tokens_out = tokens_out
            existing.latency_ms = latency_ms

        try:
            db.commit()
        except Exception:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; restore it before handing the error to the caller.
            db.rollback()
            raise

    def cache_stats(self, db: Session) -> CacheStats:
        """Return cache statistics: hit rate, total entries, storage size.

        Args:
            db: SQLAlchemy session used to query the cache table.

        Returns:
            Entry count and approximate storage size read from the
            database, plus the in-memory hit rate of this instance (0.0
            when no lookups have been made).
        """
        # Approximate storage: sum of response text lengths.
        # For SQLite, length() returns character count; for Postgres, octet_length
        # would be more accurate, but length() works everywhere.
        # Both aggregates are fetched in one SELECT so they describe the
        # same snapshot of the table.
        entry_count, char_total = db.query(
            func.count(ResponseCache.config_hash),
            func.sum(func.length(ResponseCache.response)),
        ).one()

        total_lookups = self._hits + self._misses
        hit_rate = self._hits / total_lookups if total_lookups > 0 else 0.0

        return CacheStats(
            total_entries=entry_count or 0,
            hit_rate=hit_rate,
            # SUM over an empty table yields NULL/None.
            storage_size_bytes=char_total or 0,
        )