"""
Valkey/Redis Semantic Cache Backend Service

Provides a shared cache layer for RAG query responses across all users.
Uses vector similarity search for semantic matching.

Architecture:
- Two-tier caching: Client (IndexedDB) -> Server (Valkey)
- Embeddings stored as binary vectors for efficient similarity search
- TTL-based expiration with LRU eviction
- Optional: Use Redis Stack's vector search (RediSearch) for native similarity

Endpoints:
- POST /cache/lookup - Find semantically similar cached queries
- POST /cache/store - Store a query/response pair
- DELETE /cache/clear - Clear all cache entries
- GET /cache/stats - Get cache statistics
- GET /health - Health check

@author TextPast / NDE
@version 1.0.0
"""
|
|
|
|
import hashlib
import json
import os
import re
import struct
import time
from contextlib import asynccontextmanager
from typing import Optional, List, Dict, Any

import numpy as np
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import redis.asyncio as redis
|
|
|
# =============================================================================
# Configuration
# =============================================================================

# Valkey/Redis connection settings (all overridable via environment variables).
VALKEY_HOST = os.getenv("VALKEY_HOST", "localhost")
VALKEY_PORT = int(os.getenv("VALKEY_PORT", "6379"))
VALKEY_PASSWORD = os.getenv("VALKEY_PASSWORD", None)  # None = no AUTH
VALKEY_DB = int(os.getenv("VALKEY_DB", "0"))

# Cache settings
CACHE_PREFIX = "glam:semantic_cache:"  # namespace prefix for every key this service writes
CACHE_TTL_SECONDS = int(os.getenv("CACHE_TTL_SECONDS", "86400"))  # 24 hours
MAX_CACHE_ENTRIES = int(os.getenv("MAX_CACHE_ENTRIES", "10000"))  # LRU eviction applies above this count
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", "0.92"))  # min cosine similarity for a semantic hit

# Embedding dimension (for validation)
# NOTE(review): not referenced anywhere in this file — presumably used by
# callers or future validation; confirm before removing.
EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", "1536"))  # OpenAI ada-002 default
|
|
|
|
# =============================================================================
|
|
# Models
|
|
# =============================================================================
|
|
|
|
class CachedResponse(BaseModel):
    """The RAG response to cache.

    Mirrors the payload the RAG pipeline produced, so a cache hit can be
    served to the client without re-running the pipeline.
    """
    answer: str  # generated natural-language answer
    sparql_query: Optional[str] = None  # SPARQL used to produce the answer, if any
    typeql_query: Optional[str] = None  # TypeQL used to produce the answer, if any
    visualization_type: Optional[str] = None  # client-side visualization hint — values defined by callers
    visualization_data: Optional[Any] = None  # data backing the visualization, shape defined by callers
    sources: List[Any] = Field(default_factory=list)  # provenance/source documents
    confidence: float = 0.0  # pipeline confidence score
    context: Optional[Dict[str, Any]] = None  # extra pipeline context, schema defined by callers
|
|
|
|
|
|
class CacheLookupRequest(BaseModel):
    """Request to look up a query in cache."""
    query: str  # raw user query text
    embedding: Optional[List[float]] = None  # query embedding; enables semantic matching when present
    language: str = "nl"  # language tag; currently not used by the lookup path
    similarity_threshold: Optional[float] = None  # overrides the server default when provided
|
|
|
|
|
|
class CacheStoreRequest(BaseModel):
    """Request to store a query/response in cache."""
    query: str  # raw user query text
    embedding: Optional[List[float]] = None  # query embedding stored for later semantic matching
    response: CachedResponse  # the RAG response to cache
    language: str = "nl"  # language tag stored alongside the entry
    model: str = "unknown"  # identifier of the model that produced the response
|
|
|
|
|
|
class CacheLookupResponse(BaseModel):
    """Response from cache lookup."""
    found: bool  # True when a sufficiently similar entry was found
    similarity: float = 0.0  # best similarity score seen (1.0 for exact matches)
    method: str = "none"  # 'semantic', 'fuzzy', 'exact', 'none'
    lookup_time_ms: float = 0.0  # server-side lookup latency in milliseconds
    entry: Optional[Dict[str, Any]] = None  # matched cache entry, if found
|
|
|
|
|
|
class CacheStats(BaseModel):
    """Cache statistics."""
    total_entries: int  # number of cached query keys currently stored
    total_hits: int  # process-lifetime hit count (resets on restart)
    total_misses: int  # process-lifetime miss count (resets on restart)
    hit_rate: float  # hits / (hits + misses); 0.0 when no lookups yet
    storage_used_bytes: int  # summed size of the stored JSON payloads
    oldest_entry: Optional[int] = None  # epoch-ms timestamp of the oldest entry
    newest_entry: Optional[int] = None  # epoch-ms timestamp of the newest entry
|
|
|
|
|
|
# =============================================================================
|
|
# Utility Functions
|
|
# =============================================================================
|
|
|
|
def normalize_query(query: str) -> str:
    """Normalize query text for comparison.

    Lowercases, replaces punctuation (anything that is not a word character
    or whitespace) with spaces, and collapses whitespace runs to single
    spaces so trivially different spellings map to the same cache key.

    Fix: the `re` import used to happen inside this function on every call;
    it is now a normal module-level import.
    """
    normalized = query.lower().strip()
    normalized = re.sub(r'[^\w\s]', ' ', normalized)  # punctuation -> space
    normalized = re.sub(r'\s+', ' ', normalized)      # collapse whitespace
    return normalized.strip()
|
|
|
|
|
|
def generate_cache_key(query: str) -> str:
    """Derive the deterministic cache key for *query*.

    The key is the cache prefix plus the first 16 hex digits of the
    SHA-256 digest of the normalized query text, so equivalent spellings
    of a query share one key.
    """
    digest = hashlib.sha256(normalize_query(query).encode()).hexdigest()
    return f"{CACHE_PREFIX}query:{digest[:16]}"
|
|
|
|
|
|
def embedding_to_bytes(embedding: List[float]) -> bytes:
    """Serialize an embedding into a compact binary blob (4 bytes/float32)."""
    packer = struct.Struct(f"{len(embedding)}f")
    return packer.pack(*embedding)
|
|
|
|
|
|
def bytes_to_embedding(data: bytes) -> List[float]:
    """Deserialize a binary blob (4 bytes per float32) back into a list."""
    return [value for (value,) in struct.iter_unpack("f", data)]
|
|
|
|
|
|
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two equal-length vectors.

    Defensive by design: mismatched lengths, empty vectors, and zero-norm
    vectors all yield 0.0 instead of raising.
    """
    if len(a) == 0 or len(a) != len(b):
        return 0.0

    va = np.asarray(a)
    vb = np.asarray(b)

    norm_va = float(np.linalg.norm(va))
    norm_vb = float(np.linalg.norm(vb))
    if norm_va == 0.0 or norm_vb == 0.0:
        return 0.0

    return float(np.dot(va, vb) / (norm_va * norm_vb))
|
|
|
|
|
|
def jaccard_similarity(a: str, b: str) -> float:
    """Word-level Jaccard similarity of two strings.

    Both strings are normalized first; the score is |intersection| /
    |union| of their word sets, or 0.0 when either side has no words.
    """
    words_a = set(normalize_query(a).split())
    words_b = set(normalize_query(b).split())

    if not words_a or not words_b:
        return 0.0

    overlap = words_a & words_b
    combined = words_a | words_b

    return len(overlap) / len(combined) if combined else 0.0
|
|
|
|
|
|
# =============================================================================
|
|
# Redis/Valkey Client
|
|
# =============================================================================
|
|
|
|
class ValkeyClient:
    """Async Valkey/Redis client wrapper for the semantic cache.

    Holds one async Redis connection plus in-process hit/miss counters
    (counters reset on restart and are not shared across replicas).
    Entries are JSON blobs under ``{CACHE_PREFIX}query:<hash>`` written
    with a TTL, so Valkey expires stale entries; `_enforce_max_entries`
    additionally caps the total entry count LRU-style.

    Fixes vs. previous revision:
    - the exact-match lookup path now strips the stored embedding from the
      returned entry, as the semantic path already did;
    - keyspace iteration uses SCAN (scan_iter) instead of KEYS, which
      blocks the server on large keyspaces.
    """

    def __init__(self):
        # Connection is established lazily via connect(); None = not connected.
        self.client: Optional[redis.Redis] = None
        self.stats = {
            "hits": 0,
            "misses": 0,
        }

    @staticmethod
    def _strip_embedding(entry: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *entry* without the (large) embedding vector."""
        return {k: v for k, v in entry.items() if k != 'embedding'}

    async def connect(self):
        """Initialize the connection to Valkey and verify it with a PING.

        Raises whatever redis raises on connection failure; the caller
        (lifespan) decides whether that is fatal.
        """
        self.client = redis.Redis(
            host=VALKEY_HOST,
            port=VALKEY_PORT,
            password=VALKEY_PASSWORD,
            db=VALKEY_DB,
            decode_responses=False,  # We handle encoding ourselves
        )
        # Test connection
        await self.client.ping()
        print(f"[ValkeyCache] Connected to {VALKEY_HOST}:{VALKEY_PORT}")

    async def disconnect(self):
        """Close the connection if one was opened."""
        if self.client:
            await self.client.close()
            print("[ValkeyCache] Disconnected")

    async def lookup(
        self,
        query: str,
        embedding: Optional[List[float]] = None,
        similarity_threshold: float = SIMILARITY_THRESHOLD,
    ) -> CacheLookupResponse:
        """Look up a query in the cache.

        Strategy, in order:
          1. exact match on the normalized-query hash key,
          2. cosine similarity of embeddings (when both sides have one),
          3. word-level Jaccard similarity (>= 0.85) as a fuzzy fallback.

        A hit refreshes the entry's TTL and hit counters. The returned
        entry never includes the stored embedding (too large for clients).

        Raises:
            HTTPException(503): when no connection has been established.
        """
        start_time = time.time()

        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")

        normalized = normalize_query(query)

        # 1. Check for exact match first (fastest)
        exact_key = generate_cache_key(query)
        exact_match = await self.client.get(exact_key)

        if exact_match:
            entry = json.loads(exact_match.decode('utf-8'))
            # Update access bookkeeping and refresh the TTL.
            entry['last_accessed'] = int(time.time() * 1000)
            entry['hit_count'] = entry.get('hit_count', 0) + 1
            await self.client.setex(exact_key, CACHE_TTL_SECONDS, json.dumps(entry))

            self.stats["hits"] += 1
            lookup_time = (time.time() - start_time) * 1000

            return CacheLookupResponse(
                found=True,
                similarity=1.0,
                method="exact",
                lookup_time_ms=lookup_time,
                # BUGFIX: strip the embedding here too (the semantic path
                # already did); it is large and useless to the client.
                entry=self._strip_embedding(entry),
            )

        # 2. Semantic similarity search over every cached entry.
        best_match = None
        best_similarity = 0.0
        match_method = "none"

        # SCAN instead of KEYS: KEYS is O(N) and blocks the server.
        async for key in self.client.scan_iter(match=f"{CACHE_PREFIX}query:*"):
            entry_data = await self.client.get(key)
            if not entry_data:
                continue  # entry expired between SCAN and GET

            entry = json.loads(entry_data.decode('utf-8'))

            # Semantic similarity (if embeddings available)
            if embedding and entry.get('embedding'):
                stored_embedding = entry['embedding']
                similarity = cosine_similarity(embedding, stored_embedding)

                if similarity > best_similarity and similarity >= similarity_threshold:
                    best_similarity = similarity
                    best_match = entry
                    match_method = "semantic"

            # Fuzzy text fallback (only considered while nothing matched yet)
            if not best_match:
                text_similarity = jaccard_similarity(normalized, entry.get('query_normalized', ''))
                if text_similarity > best_similarity and text_similarity >= 0.85:
                    best_similarity = text_similarity
                    best_match = entry
                    match_method = "fuzzy"

        lookup_time = (time.time() - start_time) * 1000

        if best_match:
            # Refresh access stats and TTL on the matched entry.
            best_match['last_accessed'] = int(time.time() * 1000)
            best_match['hit_count'] = best_match.get('hit_count', 0) + 1
            match_key = generate_cache_key(best_match['query'])
            await self.client.setex(match_key, CACHE_TTL_SECONDS, json.dumps(best_match))

            self.stats["hits"] += 1

            # Don't send embedding back to client (too large)
            return CacheLookupResponse(
                found=True,
                similarity=best_similarity,
                method=match_method,
                lookup_time_ms=lookup_time,
                entry=self._strip_embedding(best_match),
            )

        self.stats["misses"] += 1

        return CacheLookupResponse(
            found=False,
            similarity=best_similarity,
            method="none",
            lookup_time_ms=lookup_time,
        )

    async def store(
        self,
        query: str,
        embedding: Optional[List[float]],
        response: CachedResponse,
        language: str = "nl",
        model: str = "unknown",
    ) -> str:
        """Store a query/response pair in the cache.

        The embedding (if any) is stored inline as a JSON list so lookups
        can score it. Returns the cache key the entry was stored under.

        Raises:
            HTTPException(503): when no connection has been established.
        """
        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")

        cache_key = generate_cache_key(query)
        timestamp = int(time.time() * 1000)

        entry = {
            "id": cache_key,
            "query": query,
            "query_normalized": normalize_query(query),
            "embedding": embedding,
            "response": response.model_dump(),
            "timestamp": timestamp,
            "last_accessed": timestamp,
            "hit_count": 0,
            "language": language,
            "model": model,
        }

        await self.client.setex(
            cache_key,
            CACHE_TTL_SECONDS,
            json.dumps(entry),
        )

        # Enforce max entries (simple LRU)
        await self._enforce_max_entries()

        print(f"[ValkeyCache] Stored: {query[:50]}...")
        return cache_key

    async def _enforce_max_entries(self):
        """Evict the least-recently/least-frequently used entries over the cap."""
        # SCAN instead of KEYS to avoid blocking the server.
        all_keys = [key async for key in self.client.scan_iter(match=f"{CACHE_PREFIX}query:*")]

        if len(all_keys) <= MAX_CACHE_ENTRIES:
            return

        # Collect access metadata for every entry.
        entries = []
        for key in all_keys:
            entry_data = await self.client.get(key)
            if entry_data:
                entry = json.loads(entry_data.decode('utf-8'))
                entries.append({
                    "key": key,
                    "last_accessed": entry.get("last_accessed", 0),
                    "hit_count": entry.get("hit_count", 0),
                })

        # Sort ascending by an LRU score that also rewards popularity, so
        # frequently hit entries survive even when not accessed recently.
        entries.sort(key=lambda x: x["last_accessed"] + x["hit_count"] * 1000)

        # Remove the lowest-scoring entries until back under the cap.
        to_remove = len(entries) - MAX_CACHE_ENTRIES
        for entry in entries[:to_remove]:
            await self.client.delete(entry["key"])

        print(f"[ValkeyCache] Evicted {to_remove} entries")

    async def clear(self):
        """Delete every key under CACHE_PREFIX and reset the counters.

        Raises:
            HTTPException(503): when no connection has been established.
        """
        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")

        # SCAN instead of KEYS to avoid blocking the server.
        all_keys = [key async for key in self.client.scan_iter(match=f"{CACHE_PREFIX}*")]
        if all_keys:
            await self.client.delete(*all_keys)

        self.stats = {"hits": 0, "misses": 0}
        print("[ValkeyCache] Cache cleared")

    async def get_stats(self) -> CacheStats:
        """Aggregate entry count, payload bytes, hit rate, and entry age range.

        Raises:
            HTTPException(503): when no connection has been established.
        """
        if not self.client:
            raise HTTPException(status_code=503, detail="Cache not connected")

        total_entries = 0
        total_size = 0
        oldest = None
        newest = None

        # SCAN instead of KEYS to avoid blocking the server.
        async for key in self.client.scan_iter(match=f"{CACHE_PREFIX}query:*"):
            total_entries += 1
            entry_data = await self.client.get(key)
            if entry_data:
                total_size += len(entry_data)
                entry = json.loads(entry_data.decode('utf-8'))
                timestamp = entry.get("timestamp", 0)

                if oldest is None or timestamp < oldest:
                    oldest = timestamp
                if newest is None or timestamp > newest:
                    newest = timestamp

        total = self.stats["hits"] + self.stats["misses"]
        hit_rate = self.stats["hits"] / total if total > 0 else 0.0

        return CacheStats(
            total_entries=total_entries,
            total_hits=self.stats["hits"],
            total_misses=self.stats["misses"],
            hit_rate=hit_rate,
            storage_used_bytes=total_size,
            oldest_entry=oldest,
            newest_entry=newest,
        )
|
|
|
|
|
|
# =============================================================================
# FastAPI Application
# =============================================================================

# Single module-level client instance shared by all request handlers.
valkey_client = ValkeyClient()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: connect to Valkey on startup, disconnect on shutdown.

    A failed connection is logged but deliberately not fatal, so the API
    still comes up (in degraded mode) when Valkey is unreachable.
    """
    try:
        await valkey_client.connect()
    except Exception as e:
        print(f"[ValkeyCache] WARNING: Could not connect to Valkey: {e}")
        print("[ValkeyCache] Service will run without cache persistence")

    yield  # the application serves requests here

    await valkey_client.disconnect()
|
|
|
|
|
|
# Application instance; lifespan handles Valkey connect/disconnect.
app = FastAPI(
    title="GLAM Semantic Cache API",
    description="Shared semantic cache backend for RAG query responses",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS configuration: local Vite dev servers plus the production domains.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:5174",
        "https://bronhouder.nl",
        "https://www.bronhouder.nl",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# =============================================================================
|
|
# API Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/health")
async def health_check():
    """Report service health and the effective cache configuration.

    Status is "healthy" when a PING to Valkey succeeds, otherwise
    "degraded" (the service still answers, just without persistence).
    """
    connected = valkey_client.client is not None
    if connected:
        try:
            await valkey_client.client.ping()
        except Exception:
            connected = False

    return {
        "status": "healthy" if connected else "degraded",
        "valkey_connected": connected,
        "config": {
            "host": VALKEY_HOST,
            "port": VALKEY_PORT,
            "ttl_seconds": CACHE_TTL_SECONDS,
            "max_entries": MAX_CACHE_ENTRIES,
            "similarity_threshold": SIMILARITY_THRESHOLD,
        }
    }
|
|
|
|
|
|
@app.post("/cache/lookup", response_model=CacheLookupResponse)
async def cache_lookup(request: CacheLookupRequest):
    """
    Look up a query in the shared cache.

    Returns the most similar cached response if above the similarity threshold.
    Supports both semantic (embedding) and fuzzy (text) matching.
    """
    # BUGFIX: `or` silently discarded an explicit threshold of 0.0; only
    # fall back to the server default when the client sent no threshold.
    threshold = (
        request.similarity_threshold
        if request.similarity_threshold is not None
        else SIMILARITY_THRESHOLD
    )

    return await valkey_client.lookup(
        query=request.query,
        embedding=request.embedding,
        similarity_threshold=threshold,
    )
|
|
|
|
|
|
@app.post("/cache/store")
async def cache_store(request: CacheStoreRequest):
    """
    Store a query/response pair in the shared cache.

    The entry will be available to all users for semantic matching.
    """
    key = await valkey_client.store(
        query=request.query,
        embedding=request.embedding,
        response=request.response,
        language=request.language,
        model=request.model,
    )

    return {
        "success": True,
        "cache_key": key,
        "ttl_seconds": CACHE_TTL_SECONDS,
    }
|
|
|
|
|
|
@app.delete("/cache/clear")
async def cache_clear(
    confirm: bool = Query(False, description="Must be true to clear cache")
):
    """
    Clear all entries from the shared cache.

    Requires confirmation parameter to prevent accidental clearing.
    """
    if not confirm:
        raise HTTPException(
            status_code=400,
            detail="Must pass confirm=true to clear cache",
        )

    await valkey_client.clear()
    return {"success": True, "message": "Cache cleared"}
|
|
|
|
|
|
@app.get("/cache/stats", response_model=CacheStats)
async def cache_stats():
    """
    Get statistics about the shared cache.

    Returns entry counts, hit rates, and storage usage.
    """
    stats = await valkey_client.get_stats()
    return stats
|
|
|
|
|
|
@app.get("/cache/entries")
async def cache_entries(
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """
    List cached entries (for debugging/admin).

    Returns entries without embeddings to reduce payload size.
    """
    if not valkey_client.client:
        raise HTTPException(status_code=503, detail="Cache not connected")

    all_keys = await valkey_client.client.keys(f"{CACHE_PREFIX}query:*")
    # BUGFIX: "total" used to be computed AFTER slicing, so it always
    # reported the page size (<= limit) instead of the full entry count.
    total = len(all_keys)
    page_keys = sorted(all_keys)[offset:offset + limit]

    entries = []
    for key in page_keys:
        entry_data = await valkey_client.client.get(key)
        if entry_data:
            entry = json.loads(entry_data.decode('utf-8'))
            # Remove embedding from response (too large for listings)
            entry.pop('embedding', None)
            entries.append(entry)

    return {
        "total": total,
        "offset": offset,
        "limit": limit,
        "entries": entries,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Main
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    # Local development entry point; in production run via an ASGI server.
    import uvicorn
    uvicorn.run(
        "main:app",
        host="0.0.0.0",  # listen on all interfaces
        port=8090,
        reload=True,  # auto-reload on code changes (development only)
    )
|