146 lines
4.1 KiB
Python
146 lines
4.1 KiB
Python
"""Response cache layer for PromptLooper.

Caches LLM responses keyed by a SHA-256 hash of the full configuration
(prompt + model + params + input_data) to avoid redundant API calls.
"""
|
|
|
|
import hashlib
|
|
import json
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import func, text
|
|
from sqlalchemy.orm import Session
|
|
|
|
from models import ResponseCache
|
|
|
|
|
|
@dataclass
|
|
class CachedResponse:
|
|
"""A cached LLM response retrieved from the database."""
|
|
|
|
config_hash: str
|
|
response: str
|
|
model: str
|
|
tokens_in: int | None
|
|
tokens_out: int | None
|
|
latency_ms: int | None
|
|
created_at: datetime
|
|
|
|
|
|
@dataclass()
class CacheStats:
    """Aggregate figures describing the response cache."""

    # Number of rows currently stored in the cache table.
    total_entries: int
    # Fraction of lookups that found an entry (0.0 when nothing was looked up).
    hit_rate: float
    # Approximate storage used by the cached response text.
    storage_size_bytes: int
|
|
|
|
|
|
def _json_default(obj: Any) -> Any:
    """Fallback encoder for ``json.dumps`` when hashing configurations.

    Sets and frozensets are emitted as lists sorted by each element's
    canonical JSON text. The hash then never depends on set iteration order,
    which for str elements changes between interpreter runs under hash
    randomization. Any other object JSON cannot encode falls back to
    ``str(obj)``.
    """
    if isinstance(obj, (set, frozenset)):
        return sorted(
            obj,
            key=lambda item: json.dumps(
                item, sort_keys=True, ensure_ascii=True, default=_json_default
            ),
        )
    return str(obj)


def compute_config_hash(
    prompt: str,
    model: str,
    params: dict[str, Any],
    input_data: Any = None,
) -> str:
    """Compute a deterministic SHA-256 hash for a given configuration.

    The hash covers the full config so that any parameter change produces
    a different key. Dict keys are sorted and set values are put in a
    canonical order, so equal configurations always yield the same key.

    Args:
        prompt: Prompt text.
        model: Model identifier.
        params: Generation parameters. Values JSON cannot encode natively
            are handled by ``_json_default`` (sets sorted, others ``str()``).
        input_data: Optional extra input folded into the key.

    Returns:
        The 64-character hexadecimal SHA-256 digest of the canonical JSON
        encoding of the configuration.

    Raises:
        TypeError: If a nested dict mixes key types that cannot be ordered
            against each other (``sort_keys=True`` must sort them).
    """
    payload = {
        "prompt": prompt,
        "model": model,
        "params": params,
        "input_data": input_data,
    }
    canonical = json.dumps(
        payload, sort_keys=True, ensure_ascii=True, default=_json_default
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
|
|
|
|
|
class ResponseCacheLayer:
    """Database-backed response cache.

    Works with both SQLite and PostgreSQL — the caller provides a
    SQLAlchemy session. Hit/miss counters live on the instance in memory,
    so the reported hit rate only covers lookups made through this object
    since it was created.
    """

    def __init__(self) -> None:
        # In-memory lookup counters; they are never written to the database.
        self._hits: int = 0
        self._misses: int = 0

    def get(self, db: Session, config_hash: str) -> CachedResponse | None:
        """Look up a cached response by config hash.

        Args:
            db: Session used for the primary-key lookup.
            config_hash: Key the response was stored under.

        Returns:
            A ``CachedResponse`` copy of the stored row, or None on a cache
            miss. Each call increments either the hit or the miss counter.
        """
        row = db.get(ResponseCache, config_hash)
        if row is None:
            self._misses += 1
            return None

        self._hits += 1
        return CachedResponse(
            config_hash=row.config_hash,
            response=row.response,
            model=row.model,
            tokens_in=row.tokens_in,
            tokens_out=row.tokens_out,
            latency_ms=row.latency_ms,
            created_at=row.created_at,
        )

    def put(
        self,
        db: Session,
        config_hash: str,
        response: str,
        model: str,
        tokens_in: int | None = None,
        tokens_out: int | None = None,
        latency_ms: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Store a response in the cache and commit the session.

        If the config_hash already exists, the entry is updated (upsert):
        its response, model, token and latency columns are overwritten.
        Otherwise a new row is added.

        Args:
            db: Session to write through; it is committed before returning.
            config_hash: Key to store the response under.
            response: Response text to cache.
            model: Model identifier recorded with the response.
            tokens_in: Optional input token count.
            tokens_out: Optional output token count.
            latency_ms: Optional latency value.
            metadata: Accepted but not stored anywhere by this method.
                NOTE(review): confirm whether ResponseCache should persist it.

        Raises:
            Whatever ``db.commit()`` raises — e.g. a primary-key conflict if
            another writer inserted the same key between the lookup and the
            commit. The session is rolled back first so it remains usable.
        """
        columns = {
            "response": response,
            "model": model,
            "tokens_in": tokens_in,
            "tokens_out": tokens_out,
            "latency_ms": latency_ms,
        }

        existing = db.get(ResponseCache, config_hash)
        if existing is None:
            db.add(ResponseCache(config_hash=config_hash, **columns))
        else:
            for name, value in columns.items():
                setattr(existing, name, value)

        try:
            db.commit()
        except Exception:
            # A failed commit leaves the session's transaction inactive until it
            # is rolled back; reset it so the caller can keep using the session.
            db.rollback()
            raise

    def cache_stats(self, db: Session) -> CacheStats:
        """Return cache statistics: hit rate, total entries, storage size.

        ``storage_size_bytes`` is the summed encoded byte length of the stored
        ``response`` texts (other columns and database overhead excluded).
        It uses ``octet_length()`` on PostgreSQL and
        ``length(CAST(... AS BLOB))`` on SQLite. Other dialects fall back to
        ``length()``, which may count characters rather than bytes.
        """
        from sqlalchemy import LargeBinary, cast

        total: int = db.query(func.count(ResponseCache.config_hash)).scalar() or 0

        total_lookups = self._hits + self._misses
        hit_rate = self._hits / total_lookups if total_lookups > 0 else 0.0

        dialect = db.get_bind(ResponseCache).dialect.name
        if dialect == "postgresql":
            byte_length = func.octet_length(ResponseCache.response)
        elif dialect == "sqlite":
            # SQLite's length() counts characters for TEXT but bytes for BLOB.
            byte_length = func.length(cast(ResponseCache.response, LargeBinary))
        else:
            byte_length = func.length(ResponseCache.response)

        size: int = db.query(func.sum(byte_length)).scalar() or 0

        return CacheStats(
            total_entries=total,
            hit_rate=hit_rate,
            storage_size_bytes=size,
        )
|