# glam/backend/valkey/main.py
"""
Valkey/Redis Semantic Cache Backend Service
Provides a shared cache layer for RAG query responses across all users.
Uses vector similarity search for semantic matching.
Architecture:
- Two-tier caching: Client (IndexedDB) -> Server (Valkey)
- Embeddings stored as binary vectors for efficient similarity search
- TTL-based expiration with LRU eviction
- Optional: Use Redis Stack's vector search (RediSearch) for native similarity
Endpoints:
- POST /cache/lookup - Find semantically similar cached queries
- POST /cache/store - Store a query/response pair
- DELETE /cache/clear - Clear all cache entries
- GET /cache/stats - Get cache statistics
- GET /health - Health check
@author TextPast / NDE
@version 1.0.0
"""
import os
import json
import time
import hashlib
import struct
from typing import Optional, List, Dict, Any
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import redis.asyncio as redis
# =============================================================================
# Configuration
# =============================================================================
# Connection settings — all overridable via environment variables.
VALKEY_HOST = os.getenv("VALKEY_HOST", "localhost")
VALKEY_PORT = int(os.getenv("VALKEY_PORT", "6379"))
VALKEY_PASSWORD = os.getenv("VALKEY_PASSWORD", None)  # None disables AUTH
VALKEY_DB = int(os.getenv("VALKEY_DB", "0"))
# Cache settings
CACHE_PREFIX = "glam:semantic_cache:"  # namespace for every key this service writes
CACHE_TTL_SECONDS = int(os.getenv("CACHE_TTL_SECONDS", "86400")) # 24 hours
MAX_CACHE_ENTRIES = int(os.getenv("MAX_CACHE_ENTRIES", "10000"))
# Minimum cosine similarity for a semantic cache hit (see ValkeyClient.lookup).
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.92"))
# Embedding dimension (for validation)
# NOTE(review): declared but not enforced anywhere in this file — confirm intended use.
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1536")) # OpenAI ada-002 default
# =============================================================================
# Models
# =============================================================================
class CachedResponse(BaseModel):
    """The RAG response to cache"""
    answer: str  # the generated answer text
    sparql_query: Optional[str] = None  # generated SPARQL, if any
    typeql_query: Optional[str] = None  # generated TypeQL, if any
    visualization_type: Optional[str] = None
    visualization_data: Optional[Any] = None  # shape depends on visualization_type — defined by the RAG service
    sources: List[Any] = Field(default_factory=list)  # retrieved source docs; schema owned by the RAG service
    confidence: float = 0.0
    context: Optional[Dict[str, Any]] = None  # free-form extra context from the producer
class CacheLookupRequest(BaseModel):
    """Request to look up a query in cache"""
    query: str  # raw user query text
    embedding: Optional[List[float]] = None  # pre-computed query embedding; enables semantic matching
    language: str = "nl"  # language code; Dutch by default
    similarity_threshold: Optional[float] = None  # overrides SIMILARITY_THRESHOLD when provided
class CacheStoreRequest(BaseModel):
    """Request to store a query/response in cache"""
    query: str  # raw user query text
    embedding: Optional[List[float]] = None  # stored alongside the entry for later semantic lookups
    response: CachedResponse  # the RAG response payload to cache
    language: str = "nl"
    model: str = "unknown"  # identifier of the model that produced the response
class CacheLookupResponse(BaseModel):
    """Response from cache lookup"""
    found: bool
    similarity: float = 0.0  # similarity of the best match (1.0 for exact hits)
    method: str = "none" # 'semantic', 'fuzzy', 'exact', 'none'
    lookup_time_ms: float = 0.0  # server-side lookup latency
    entry: Optional[Dict[str, Any]] = None  # matched cache entry (embedding stripped), or None on miss
class CacheStats(BaseModel):
    """Cache statistics"""
    total_entries: int  # number of cached query entries currently in Valkey
    total_hits: int  # process-lifetime hits (in-memory counter; resets on restart)
    total_misses: int  # process-lifetime misses (in-memory counter)
    hit_rate: float  # hits / (hits + misses); 0.0 when there has been no traffic
    storage_used_bytes: int  # sum of serialized entry sizes
    oldest_entry: Optional[int] = None  # creation timestamp (epoch ms); None when cache is empty
    newest_entry: Optional[int] = None  # creation timestamp (epoch ms); None when cache is empty
# =============================================================================
# Utility Functions
# =============================================================================
def normalize_query(query: str) -> str:
    """Canonicalize a query: lowercase, strip punctuation, collapse whitespace."""
    import re
    # Lowercase first, replace every non-word/non-space character with a
    # space, then squeeze whitespace runs down to single spaces.
    cleaned = re.sub(r'[^\w\s]', ' ', query.lower().strip())
    collapsed = re.sub(r'\s+', ' ', cleaned)
    return collapsed.strip()
def generate_cache_key(query: str) -> str:
    """Derive a stable, prefixed cache key from the normalized query text."""
    digest = hashlib.sha256(normalize_query(query).encode()).hexdigest()
    # 16 hex chars (64 bits) keeps keys short while collisions stay unlikely.
    return f"{CACHE_PREFIX}query:{digest[:16]}"
def embedding_to_bytes(embedding: List[float]) -> bytes:
    """Serialize an embedding as a flat run of native 32-bit floats."""
    return b''.join(struct.pack('f', value) for value in embedding)
def bytes_to_embedding(data: bytes) -> List[float]:
    """Deserialize a packed float32 buffer back into a list of floats."""
    # iter_unpack walks the buffer in 4-byte steps, yielding 1-tuples.
    return [value for (value,) in struct.iter_unpack('f', data)]
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two equal-length vectors.

    Returns 0.0 for degenerate input: mismatched lengths, empty vectors,
    or a zero-magnitude vector on either side.
    """
    if not a or len(a) != len(b):
        return 0.0
    vec_a = np.array(a)
    vec_b = np.array(b)
    mag_a = np.linalg.norm(vec_a)
    mag_b = np.linalg.norm(vec_b)
    # Guard each magnitude separately to avoid dividing by zero.
    if mag_a == 0 or mag_b == 0:
        return 0.0
    return float(np.dot(vec_a, vec_b) / (mag_a * mag_b))
def jaccard_similarity(a: str, b: str) -> float:
    """Word-level Jaccard similarity between two strings.

    Both strings are normalized (lowercased, punctuation stripped) before
    being split into word sets. Returns 0.0 when either set is empty.
    """
    words_a = set(normalize_query(a).split())
    words_b = set(normalize_query(b).split())
    if not words_a or not words_b:
        return 0.0
    overlap = words_a & words_b
    union = words_a | words_b
    return len(overlap) / len(union) if union else 0.0
# =============================================================================
# Redis/Valkey Client
# =============================================================================
class ValkeyClient:
    """Async Valkey/Redis client wrapper for the shared semantic cache.

    Responsibilities:
    - Connection lifecycle (connect / disconnect).
    - Lookups: exact hash-key hit first, then a linear scan over all
      entries for semantic (cosine) or fuzzy (Jaccard) matches.
    - Stores with TTL, plus simple LRU-style eviction above MAX_CACHE_ENTRIES.

    NOTE(review): hit/miss stats are process-local only — they reset on
    restart and are not shared across replicas.
    NOTE(review): lookup, eviction, and stats all use KEYS plus per-key GET,
    which is O(n) and blocks the server; fine for small caches, but consider
    SCAN or RediSearch vector search as the cache grows.
    """
    def __init__(self):
        # None means "not connected"; methods translate that into HTTP 503.
        self.client: Optional[redis.Redis] = None
        # In-process counters only (never persisted to Valkey).
        self.stats = {
            "hits": 0,
            "misses": 0,
        }

    async def connect(self):
        """Initialize the connection to Valkey and verify it with PING.

        Raises:
            Exception: whatever the client raises when the server is
                unreachable; self.client is reset to None in that case.
        """
        self.client = redis.Redis(
            host=VALKEY_HOST,
            port=VALKEY_PORT,
            password=VALKEY_PASSWORD,
            db=VALKEY_DB,
            decode_responses=False,  # We handle encoding ourselves
        )
        try:
            # redis.Redis() connects lazily; PING forces a real round-trip.
            await self.client.ping()
        except Exception:
            # BUG FIX: leave the wrapper in a clean "not connected" state so
            # later calls fail with an explicit 503 instead of surfacing raw
            # connection errors from a half-initialized client.
            self.client = None
            raise
        print(f"[ValkeyCache] Connected to {VALKEY_HOST}:{VALKEY_PORT}")

    async def disconnect(self):
        """Close the connection if one was established."""
        if self.client:
            await self.client.close()
            print("[ValkeyCache] Disconnected")

    async def lookup(
        self,
        query: str,
        embedding: Optional[List[float]] = None,
        similarity_threshold: float = SIMILARITY_THRESHOLD,
    ) -> CacheLookupResponse:
        """Look up a query in the cache.

        Strategy, cheapest first:
        1. Exact match on the normalized-query hash key.
        2. Linear scan of all entries: cosine similarity when both sides
           have embeddings, otherwise word-level Jaccard with a fixed
           0.85 bar.

        Args:
            query: Raw user query text.
            embedding: Optional embedding of `query` for semantic matching.
            similarity_threshold: Minimum cosine similarity for a semantic hit.

        Returns:
            CacheLookupResponse; `entry` never includes the stored embedding.

        Raises:
            HTTPException: 503 when not connected.
        """
        start_time = time.time()
        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")
        normalized = normalize_query(query)
        # 1. Check for exact match first (fastest)
        exact_key = generate_cache_key(query)
        exact_match = await self.client.get(exact_key)
        if exact_match:
            entry = json.loads(exact_match.decode('utf-8'))
            # Refresh access metadata and re-arm the TTL.
            entry['last_accessed'] = int(time.time() * 1000)
            entry['hit_count'] = entry.get('hit_count', 0) + 1
            await self.client.setex(exact_key, CACHE_TTL_SECONDS, json.dumps(entry))
            self.stats["hits"] += 1
            lookup_time = (time.time() - start_time) * 1000
            # BUG FIX: strip the embedding from the exact-hit payload too —
            # previously only the scan path did this, so exact hits shipped
            # the full (large) vector back to the client.
            return_entry = {k: v for k, v in entry.items() if k != 'embedding'}
            return CacheLookupResponse(
                found=True,
                similarity=1.0,
                method="exact",
                lookup_time_ms=lookup_time,
                entry=return_entry,
            )
        # 2. Semantic similarity search over every cached entry.
        best_match = None
        best_similarity = 0.0
        match_method = "none"
        # KEYS is O(n) and blocking; acceptable while the cache stays small.
        all_keys = await self.client.keys(f"{CACHE_PREFIX}query:*")
        for key in all_keys:
            entry_data = await self.client.get(key)
            if not entry_data:
                continue  # entry expired between KEYS and GET
            entry = json.loads(entry_data.decode('utf-8'))
            # Semantic similarity (only when both embeddings are available).
            if embedding and entry.get('embedding'):
                stored_embedding = entry['embedding']
                similarity = cosine_similarity(embedding, stored_embedding)
                if similarity > best_similarity and similarity >= similarity_threshold:
                    best_similarity = similarity
                    best_match = entry
                    match_method = "semantic"
            # Fuzzy text fallback — skipped for the rest of the scan once any
            # match (semantic or fuzzy) has been recorded.
            if not best_match:
                text_similarity = jaccard_similarity(normalized, entry.get('query_normalized', ''))
                if text_similarity > best_similarity and text_similarity >= 0.85:
                    best_similarity = text_similarity
                    best_match = entry
                    match_method = "fuzzy"
        lookup_time = (time.time() - start_time) * 1000
        if best_match:
            # Refresh access metadata on the matched entry and re-arm its TTL.
            best_match['last_accessed'] = int(time.time() * 1000)
            best_match['hit_count'] = best_match.get('hit_count', 0) + 1
            match_key = generate_cache_key(best_match['query'])
            await self.client.setex(match_key, CACHE_TTL_SECONDS, json.dumps(best_match))
            self.stats["hits"] += 1
            # Don't send embedding back to client (too large)
            return_entry = {k: v for k, v in best_match.items() if k != 'embedding'}
            return CacheLookupResponse(
                found=True,
                similarity=best_similarity,
                method=match_method,
                lookup_time_ms=lookup_time,
                entry=return_entry,
            )
        self.stats["misses"] += 1
        return CacheLookupResponse(
            found=False,
            similarity=best_similarity,
            method="none",
            lookup_time_ms=lookup_time,
        )

    async def store(
        self,
        query: str,
        embedding: Optional[List[float]],
        response: CachedResponse,
        language: str = "nl",
        model: str = "unknown",
    ) -> str:
        """Store a query/response pair in the cache with a TTL.

        Args:
            query: Raw user query (also stored normalized for fuzzy matching).
            embedding: Optional query embedding, kept for semantic lookups.
            response: The RAG response payload to cache.
            language: Language code of the query.
            model: Identifier of the model that produced the response.

        Returns:
            The cache key the entry was stored under.

        Raises:
            HTTPException: 503 when not connected.
        """
        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")
        cache_key = generate_cache_key(query)
        timestamp = int(time.time() * 1000)  # epoch milliseconds
        entry = {
            "id": cache_key,
            "query": query,
            "query_normalized": normalize_query(query),
            "embedding": embedding,
            "response": response.model_dump(),
            "timestamp": timestamp,
            "last_accessed": timestamp,
            "hit_count": 0,
            "language": language,
            "model": model,
        }
        await self.client.setex(
            cache_key,
            CACHE_TTL_SECONDS,
            json.dumps(entry),
        )
        # Enforce max entries (simple LRU)
        await self._enforce_max_entries()
        print(f"[ValkeyCache] Stored: {query[:50]}...")
        return cache_key

    async def _enforce_max_entries(self):
        """Evict lowest-scoring entries once the cache exceeds MAX_CACHE_ENTRIES.

        Score = last_accessed (epoch ms) + hit_count * 1000; lowest scores
        are evicted first. NOTE(review): with millisecond timestamps the
        hit_count term (1 second per hit) barely moves the score, so this is
        effectively pure LRU — confirm the intended weighting.
        """
        all_keys = await self.client.keys(f"{CACHE_PREFIX}query:*")
        if len(all_keys) <= MAX_CACHE_ENTRIES:
            return
        # Collect access metadata for every entry so we can rank them.
        entries = []
        for key in all_keys:
            entry_data = await self.client.get(key)
            if entry_data:
                entry = json.loads(entry_data.decode('utf-8'))
                entries.append({
                    "key": key,
                    "last_accessed": entry.get("last_accessed", 0),
                    "hit_count": entry.get("hit_count", 0),
                })
        # Sort by LRU score (recent access + hit count)
        entries.sort(key=lambda x: x["last_accessed"] + x["hit_count"] * 1000)
        # Remove oldest
        to_remove = len(entries) - MAX_CACHE_ENTRIES
        for entry in entries[:to_remove]:
            await self.client.delete(entry["key"])
        print(f"[ValkeyCache] Evicted {to_remove} entries")

    async def clear(self):
        """Delete every key under CACHE_PREFIX and reset in-process stats.

        Raises:
            HTTPException: 503 when not connected.
        """
        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")
        all_keys = await self.client.keys(f"{CACHE_PREFIX}*")
        if all_keys:
            await self.client.delete(*all_keys)
        self.stats = {"hits": 0, "misses": 0}
        print("[ValkeyCache] Cache cleared")

    async def get_stats(self) -> CacheStats:
        """Aggregate cache statistics: entry count, hit rate, storage size.

        Walks every entry, so cost grows linearly with cache size.

        Raises:
            HTTPException: 503 when not connected.
        """
        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")
        all_keys = await self.client.keys(f"{CACHE_PREFIX}query:*")
        total_size = 0
        oldest = None
        newest = None
        for key in all_keys:
            entry_data = await self.client.get(key)
            if entry_data:
                total_size += len(entry_data)
                entry = json.loads(entry_data.decode('utf-8'))
                timestamp = entry.get("timestamp", 0)
                if oldest is None or timestamp < oldest:
                    oldest = timestamp
                if newest is None or timestamp > newest:
                    newest = timestamp
        total = self.stats["hits"] + self.stats["misses"]
        hit_rate = self.stats["hits"] / total if total > 0 else 0.0
        return CacheStats(
            total_entries=len(all_keys),
            total_hits=self.stats["hits"],
            total_misses=self.stats["misses"],
            hit_rate=hit_rate,
            storage_used_bytes=total_size,
            oldest_entry=oldest,
            newest_entry=newest,
        )
# =============================================================================
# FastAPI Application
# =============================================================================
# Module-level singleton shared by every endpoint handler below.
valkey_client = ValkeyClient()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup and shutdown events"""
    # Startup: connect eagerly, but tolerate an unreachable Valkey so the
    # service still boots (endpoints then respond 503 / degraded).
    try:
        await valkey_client.connect()
    except Exception as e:
        print(f"[ValkeyCache] WARNING: Could not connect to Valkey: {e}")
        print("[ValkeyCache] Service will run without cache persistence")
    yield
    # Shutdown
    await valkey_client.disconnect()
app = FastAPI(
    title="GLAM Semantic Cache API",
    description="Shared semantic cache backend for RAG query responses",
    version="1.0.0",
    lifespan=lifespan,  # manages the Valkey connection lifecycle
)
# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        # Local Vite dev servers
        "http://localhost:5173",
        "http://localhost:5174",
        # Production origins
        "https://bronhouder.nl",
        "https://www.bronhouder.nl",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =============================================================================
# API Endpoints
# =============================================================================
@app.get("/health")
async def health_check():
    """Report service liveness plus the effective cache configuration.

    "healthy" requires a successful PING round-trip to Valkey; anything
    else (no client, or a failed ping) reports "degraded".
    """
    alive = valkey_client.client is not None
    try:
        if alive:
            await valkey_client.client.ping()
    except Exception:
        alive = False
    return {
        "status": "healthy" if alive else "degraded",
        "valkey_connected": alive,
        "config": {
            "host": VALKEY_HOST,
            "port": VALKEY_PORT,
            "ttl_seconds": CACHE_TTL_SECONDS,
            "max_entries": MAX_CACHE_ENTRIES,
            "similarity_threshold": SIMILARITY_THRESHOLD,
        }
    }
@app.post("/cache/lookup", response_model=CacheLookupResponse)
async def cache_lookup(request: CacheLookupRequest):
    """
    Look up a query in the shared cache.
    Returns the most similar cached response if above the similarity threshold.
    Supports both semantic (embedding) and fuzzy (text) matching.
    """
    # BUG FIX: use an explicit None check instead of `or`, which silently
    # replaced a caller-supplied threshold of 0.0 with the default.
    if request.similarity_threshold is not None:
        threshold = request.similarity_threshold
    else:
        threshold = SIMILARITY_THRESHOLD
    return await valkey_client.lookup(
        query=request.query,
        embedding=request.embedding,
        similarity_threshold=threshold,
    )
@app.post("/cache/store")
async def cache_store(request: CacheStoreRequest):
    """
    Store a query/response pair in the shared cache.
    The entry will be available to all users for semantic matching.
    """
    stored_key = await valkey_client.store(
        query=request.query,
        embedding=request.embedding,
        response=request.response,
        language=request.language,
        model=request.model,
    )
    # Echo the key and TTL so the client can mirror the entry locally.
    return {
        "success": True,
        "cache_key": stored_key,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
@app.delete("/cache/clear")
async def cache_clear(
    confirm: bool = Query(False, description="Must be true to clear cache")
):
    """
    Clear all entries from the shared cache.
    Requires confirmation parameter to prevent accidental clearing.
    """
    # Guard clause: refuse unless the caller explicitly opted in.
    if confirm:
        await valkey_client.clear()
        return {"success": True, "message": "Cache cleared"}
    raise HTTPException(
        status_code=400,
        detail="Must pass confirm=true to clear cache"
    )
@app.get("/cache/stats", response_model=CacheStats)
async def cache_stats():
    """
    Get statistics about the shared cache.
    Returns entry counts, hit rates, and storage usage.
    """
    stats = await valkey_client.get_stats()
    return stats
@app.get("/cache/entries")
async def cache_entries(
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """
    List cached entries (for debugging/admin).
    Returns entries without embeddings to reduce payload size.

    Args:
        limit: Maximum number of entries to return (1-1000).
        offset: Number of entries to skip (pagination).
    """
    if not valkey_client.client:
        raise HTTPException(status_code=503, detail="Cache not connected")
    all_keys = await valkey_client.client.keys(f"{CACHE_PREFIX}query:*")
    # BUG FIX: capture the total BEFORE slicing — the original reassigned
    # all_keys to the page and reported the page size as "total", breaking
    # pagination for clients.
    total = len(all_keys)
    page_keys = sorted(all_keys)[offset:offset + limit]
    entries = []
    for key in page_keys:
        entry_data = await valkey_client.client.get(key)
        if entry_data:
            entry = json.loads(entry_data.decode('utf-8'))
            # Remove embedding from response
            entry.pop('embedding', None)
            entries.append(entry)
    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "entries": entries,
    }
# =============================================================================
# Main
# =============================================================================
if __name__ == "__main__":
    import uvicorn
    # Development entry point; deploy behind a proper ASGI server instead.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",  # listen on all interfaces
        port=8090,
        reload=True,  # auto-reload on code change (development only)
    )