""" Valkey/Redis Semantic Cache Backend Service Provides a shared cache layer for RAG query responses across all users. Uses vector similarity search for semantic matching. Architecture: - Two-tier caching: Client (IndexedDB) -> Server (Valkey) - Embeddings stored as binary vectors for efficient similarity search - TTL-based expiration with LRU eviction - Optional: Use Redis Stack's vector search (RediSearch) for native similarity Endpoints: - POST /cache/lookup - Find semantically similar cached queries - POST /cache/store - Store a query/response pair - DELETE /cache/clear - Clear all cache entries - GET /cache/stats - Get cache statistics - GET /health - Health check @author TextPast / NDE @version 1.0.0 """ import os import json import time import hashlib import struct from typing import Optional, List, Dict, Any from contextlib import asynccontextmanager import numpy as np from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import redis.asyncio as redis # ============================================================================= # Configuration # ============================================================================= VALKEY_HOST = os.getenv("VALKEY_HOST", "localhost") VALKEY_PORT = int(os.getenv("VALKEY_PORT", "6379")) VALKEY_PASSWORD = os.getenv("VALKEY_PASSWORD", None) VALKEY_DB = int(os.getenv("VALKEY_DB", "0")) # Cache settings CACHE_PREFIX = "glam:semantic_cache:" CACHE_TTL_SECONDS = int(os.getenv("CACHE_TTL_SECONDS", "86400")) # 24 hours MAX_CACHE_ENTRIES = int(os.getenv("MAX_CACHE_ENTRIES", "10000")) SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.92")) # Embedding dimension (for validation) EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1536")) # OpenAI ada-002 default # ============================================================================= # Models # ============================================================================= class CachedResponse(BaseModel): """The RAG response to cache""" answer: str sparql_query: Optional[str] = None typeql_query: Optional[str] = None visualization_type: Optional[str] = None visualization_data: Optional[Any] = None sources: List[Any] = Field(default_factory=list) confidence: float = 0.0 context: Optional[Dict[str, Any]] = None class CacheLookupRequest(BaseModel): """Request to look up a query in cache""" query: str embedding: Optional[List[float]] = None language: str = "nl" similarity_threshold: Optional[float] = None class CacheStoreRequest(BaseModel): """Request to store a query/response in cache""" query: str embedding: Optional[List[float]] = None response: CachedResponse language: str = "nl" model: str = "unknown" class CacheLookupResponse(BaseModel): """Response from cache lookup""" found: bool similarity: float = 0.0 method: str = "none" # 'semantic', 'fuzzy', 'exact', 'none' lookup_time_ms: float = 0.0 entry: Optional[Dict[str, Any]] = None class CacheStats(BaseModel): """Cache statistics""" total_entries: int total_hits: int total_misses: int hit_rate: float storage_used_bytes: int oldest_entry: Optional[int] = None newest_entry: Optional[int] = None # ============================================================================= # Utility Functions # ============================================================================= def normalize_query(query: str) -> str: """Normalize query text for comparison""" import re normalized = query.lower().strip() normalized = re.sub(r'[^\w\s]', ' ', normalized) normalized = re.sub(r'\s+', ' ', normalized) return normalized.strip() def generate_cache_key(query: str) -> str: """Generate a unique cache key from normalized query""" normalized = normalize_query(query) hash_val = hashlib.sha256(normalized.encode()).hexdigest()[:16] return f"{CACHE_PREFIX}query:{hash_val}" def embedding_to_bytes(embedding: List[float]) -> bytes: """Convert embedding list to compact binary format""" return struct.pack(f'{len(embedding)}f', *embedding) def bytes_to_embedding(data: bytes) -> List[float]: """Convert binary format back to embedding list""" count = len(data) // 4 # 4 bytes per float return list(struct.unpack(f'{count}f', data)) def cosine_similarity(a: List[float], b: List[float]) -> float: """Compute cosine similarity between two vectors""" if len(a) != len(b) or len(a) == 0: return 0.0 a_np = np.array(a) b_np = np.array(b) dot_product = np.dot(a_np, b_np) norm_a = np.linalg.norm(a_np) norm_b = np.linalg.norm(b_np) if norm_a == 0 or norm_b == 0: return 0.0 return float(dot_product / (norm_a * norm_b)) def jaccard_similarity(a: str, b: str) -> float: """Compute Jaccard similarity between two strings (word-level)""" set_a = set(normalize_query(a).split()) set_b = set(normalize_query(b).split()) if not set_a or not set_b: return 0.0 intersection = len(set_a & set_b) union = len(set_a | set_b) return intersection / union if union > 0 else 0.0 # ============================================================================= # Redis/Valkey Client # ============================================================================= class ValkeyClient: """Async Valkey/Redis client wrapper""" def __init__(self): self.client: Optional[redis.Redis] = None self.stats = { "hits": 0, "misses": 0, } async def connect(self): """Initialize connection to Valkey""" self.client = redis.Redis( host=VALKEY_HOST, port=VALKEY_PORT, password=VALKEY_PASSWORD, db=VALKEY_DB, decode_responses=False, # We handle encoding ourselves ) # Test connection await self.client.ping() print(f"[ValkeyCache] Connected to {VALKEY_HOST}:{VALKEY_PORT}") async def disconnect(self): """Close connection""" if self.client: await self.client.close() print("[ValkeyCache] Disconnected") async def lookup( self, query: str, embedding: Optional[List[float]] = None, similarity_threshold: float = SIMILARITY_THRESHOLD, ) -> CacheLookupResponse: """Look up a query in the cache""" start_time = time.time() if not self.client: raise HTTPException(status_code=503, detail="Cache not connected") normalized = normalize_query(query) # 1. Check for exact match first (fastest) exact_key = generate_cache_key(query) exact_match = await self.client.get(exact_key) if exact_match: entry = json.loads(exact_match.decode('utf-8')) # Update access time entry['last_accessed'] = int(time.time() * 1000) entry['hit_count'] = entry.get('hit_count', 0) + 1 await self.client.setex(exact_key, CACHE_TTL_SECONDS, json.dumps(entry)) self.stats["hits"] += 1 lookup_time = (time.time() - start_time) * 1000 return CacheLookupResponse( found=True, similarity=1.0, method="exact", lookup_time_ms=lookup_time, entry=entry, ) # 2. Semantic similarity search best_match = None best_similarity = 0.0 match_method = "none" # Get all cache keys all_keys = await self.client.keys(f"{CACHE_PREFIX}query:*") for key in all_keys: entry_data = await self.client.get(key) if not entry_data: continue entry = json.loads(entry_data.decode('utf-8')) # Semantic similarity (if embeddings available) if embedding and entry.get('embedding'): stored_embedding = entry['embedding'] similarity = cosine_similarity(embedding, stored_embedding) if similarity > best_similarity and similarity >= similarity_threshold: best_similarity = similarity best_match = entry match_method = "semantic" # Fuzzy text fallback if not best_match: text_similarity = jaccard_similarity(normalized, entry.get('query_normalized', '')) if text_similarity > best_similarity and text_similarity >= 0.85: best_similarity = text_similarity best_match = entry match_method = "fuzzy" lookup_time = (time.time() - start_time) * 1000 if best_match: # Update stats best_match['last_accessed'] = int(time.time() * 1000) best_match['hit_count'] = best_match.get('hit_count', 0) + 1 match_key = generate_cache_key(best_match['query']) await self.client.setex(match_key, CACHE_TTL_SECONDS, json.dumps(best_match)) self.stats["hits"] += 1 # Don't send embedding back to client (too large) return_entry = {k: v for k, v in best_match.items() if k != 'embedding'} return CacheLookupResponse( found=True, similarity=best_similarity, method=match_method, lookup_time_ms=lookup_time, entry=return_entry, ) self.stats["misses"] += 1 return CacheLookupResponse( found=False, similarity=best_similarity, method="none", lookup_time_ms=lookup_time, ) async def store( self, query: str, embedding: Optional[List[float]], response: CachedResponse, language: str = "nl", model: str = "unknown", ) -> str: """Store a query/response pair in the cache""" if not self.client: raise HTTPException(status_code=503, detail="Cache not connected") cache_key = generate_cache_key(query) timestamp = int(time.time() * 1000) entry = { "id": cache_key, "query": query, "query_normalized": normalize_query(query), "embedding": embedding, "response": response.model_dump(), "timestamp": timestamp, "last_accessed": timestamp, "hit_count": 0, "language": language, "model": model, } await self.client.setex( cache_key, CACHE_TTL_SECONDS, json.dumps(entry), ) # Enforce max entries (simple LRU) await self._enforce_max_entries() print(f"[ValkeyCache] Stored: {query[:50]}...") return cache_key async def _enforce_max_entries(self): """Remove oldest entries if over limit""" all_keys = await self.client.keys(f"{CACHE_PREFIX}query:*") if len(all_keys) <= MAX_CACHE_ENTRIES: return # Get all entries with timestamps entries = [] for key in all_keys: entry_data = await self.client.get(key) if entry_data: entry = json.loads(entry_data.decode('utf-8')) entries.append({ "key": key, "last_accessed": entry.get("last_accessed", 0), "hit_count": entry.get("hit_count", 0), }) # Sort by LRU score (recent access + hit count) entries.sort(key=lambda x: x["last_accessed"] + x["hit_count"] * 1000) # Remove oldest to_remove = len(entries) - MAX_CACHE_ENTRIES for entry in entries[:to_remove]: await self.client.delete(entry["key"]) print(f"[ValkeyCache] Evicted {to_remove} entries") async def clear(self): """Clear all cache entries""" if not self.client: raise HTTPException(status_code=503, detail="Cache not connected") all_keys = await self.client.keys(f"{CACHE_PREFIX}*") if all_keys: await self.client.delete(*all_keys) self.stats = {"hits": 0, "misses": 0} print("[ValkeyCache] Cache cleared") async def get_stats(self) -> CacheStats: """Get cache statistics""" if not self.client: raise HTTPException(status_code=503, detail="Cache not connected") all_keys = await self.client.keys(f"{CACHE_PREFIX}query:*") total_size = 0 oldest = None newest = None for key in all_keys: entry_data = await self.client.get(key) if entry_data: total_size += len(entry_data) entry = json.loads(entry_data.decode('utf-8')) timestamp = entry.get("timestamp", 0) if oldest is None or timestamp < oldest: oldest = timestamp if newest is None or timestamp > newest: newest = timestamp total = self.stats["hits"] + self.stats["misses"] hit_rate = self.stats["hits"] / total if total > 0 else 0.0 return CacheStats( total_entries=len(all_keys), total_hits=self.stats["hits"], total_misses=self.stats["misses"], hit_rate=hit_rate, storage_used_bytes=total_size, oldest_entry=oldest, newest_entry=newest, ) # ============================================================================= # FastAPI Application # ============================================================================= valkey_client = ValkeyClient() @asynccontextmanager async def lifespan(app: FastAPI): """Startup and shutdown events""" # Startup try: await valkey_client.connect() except Exception as e: print(f"[ValkeyCache] WARNING: Could not connect to Valkey: {e}") print("[ValkeyCache] Service will run without cache persistence") yield # Shutdown await valkey_client.disconnect() app = FastAPI( title="GLAM Semantic Cache API", description="Shared semantic cache backend for RAG query responses", version="1.0.0", lifespan=lifespan, ) # CORS configuration app.add_middleware( CORSMiddleware, allow_origins=[ "http://localhost:5173", "http://localhost:5174", "https://bronhouder.nl", "https://www.bronhouder.nl", ], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============================================================================= # API Endpoints # ============================================================================= @app.get("/health") async def health_check(): """Health check endpoint""" connected = valkey_client.client is not None try: if connected: await valkey_client.client.ping() except Exception: connected = False return { "status": "healthy" if connected else "degraded", "valkey_connected": connected, "config": { "host": VALKEY_HOST, "port": VALKEY_PORT, "ttl_seconds": CACHE_TTL_SECONDS, "max_entries": MAX_CACHE_ENTRIES, "similarity_threshold": SIMILARITY_THRESHOLD, } } @app.post("/cache/lookup", response_model=CacheLookupResponse) async def cache_lookup(request: CacheLookupRequest): """ Look up a query in the shared cache. Returns the most similar cached response if above the similarity threshold. Supports both semantic (embedding) and fuzzy (text) matching. """ threshold = request.similarity_threshold or SIMILARITY_THRESHOLD return await valkey_client.lookup( query=request.query, embedding=request.embedding, similarity_threshold=threshold, ) @app.post("/cache/store") async def cache_store(request: CacheStoreRequest): """ Store a query/response pair in the shared cache. The entry will be available to all users for semantic matching. """ cache_key = await valkey_client.store( query=request.query, embedding=request.embedding, response=request.response, language=request.language, model=request.model, ) return { "success": True, "cache_key": cache_key, "ttl_seconds": CACHE_TTL_SECONDS, } @app.delete("/cache/clear") async def cache_clear( confirm: bool = Query(False, description="Must be true to clear cache") ): """ Clear all entries from the shared cache. Requires confirmation parameter to prevent accidental clearing. """ if not confirm: raise HTTPException( status_code=400, detail="Must pass confirm=true to clear cache" ) await valkey_client.clear() return {"success": True, "message": "Cache cleared"} @app.get("/cache/stats", response_model=CacheStats) async def cache_stats(): """ Get statistics about the shared cache. Returns entry counts, hit rates, and storage usage. """ return await valkey_client.get_stats() @app.get("/cache/entries") async def cache_entries( limit: int = Query(100, ge=1, le=1000), offset: int = Query(0, ge=0), ): """ List cached entries (for debugging/admin). Returns entries without embeddings to reduce payload size. """ if not valkey_client.client: raise HTTPException(status_code=503, detail="Cache not connected") all_keys = await valkey_client.client.keys(f"{CACHE_PREFIX}query:*") all_keys = sorted(all_keys)[offset:offset + limit] entries = [] for key in all_keys: entry_data = await valkey_client.client.get(key) if entry_data: entry = json.loads(entry_data.decode('utf-8')) # Remove embedding from response entry.pop('embedding', None) entries.append(entry) return { "total": len(all_keys), "offset": offset, "limit": limit, "entries": entries, } # ============================================================================= # Main # ============================================================================= if __name__ == "__main__": import uvicorn uvicorn.run( "main:app", host="0.0.0.0", port=8090, reload=True, )