"""Heritage RAG Semantic Cache.

Hybrid semantic caching system for the Heritage RAG pipeline with:
- Vector-based semantic similarity matching (RedisVL)
- Atomic sub-query decomposition for higher hit rates
- Cross-encoder validation to prevent false positives
- Intent-aware TTL policies
- Cache warmup with heritage FAQs

Based on research findings:
- Hybrid architecture achieves 65% hit rate vs 5-15% for full queries
- Cross-encoder validation reduces false positives from 99% to 3.8%
- Atomic decomposition enables 40-70% cache hit rates

Usage:
    cache = HeritageSemanticCache()

    # Check cache before expensive RAG
    cached = await cache.get(question, language)
    if cached:
        return cached

    # Execute RAG pipeline
    result = await rag_pipeline(question)

    # Store result
    await cache.set(question, result, intent="statistical")
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import re
|
|
import time
|
|
from dataclasses import asdict
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Import cache config - support both package and direct imports
|
|
try:
|
|
# Try relative import first (when used as part of a package)
|
|
from .cache_config import (
|
|
CACHE_BYPASS_PATTERNS,
|
|
FAQ_CATEGORIES,
|
|
CacheEntry,
|
|
CacheSettings,
|
|
CacheStats,
|
|
DistanceMetric,
|
|
get_cache_settings,
|
|
get_ttl_for_intent,
|
|
)
|
|
except ImportError:
|
|
# Fall back to absolute import (when module is on sys.path directly)
|
|
from cache_config import (
|
|
CACHE_BYPASS_PATTERNS,
|
|
FAQ_CATEGORIES,
|
|
CacheEntry,
|
|
CacheSettings,
|
|
CacheStats,
|
|
DistanceMetric,
|
|
get_cache_settings,
|
|
get_ttl_for_intent,
|
|
)
|
|
|
|
# Try to import Redis/RedisVL (graceful fallback)
|
|
try:
|
|
import redis.asyncio as redis
|
|
from redisvl.index import AsyncSearchIndex
|
|
from redisvl.query import VectorQuery
|
|
from redisvl.schema import IndexSchema
|
|
REDIS_AVAILABLE = True
|
|
except ImportError:
|
|
logger.warning("Redis/RedisVL not available, using in-memory cache fallback")
|
|
REDIS_AVAILABLE = False
|
|
redis = None # type: ignore
|
|
AsyncSearchIndex = None # type: ignore
|
|
VectorQuery = None # type: ignore
|
|
|
|
# Try to import sentence transformers for embeddings
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
EMBEDDINGS_AVAILABLE = True
|
|
except ImportError:
|
|
logger.warning("SentenceTransformers not available")
|
|
EMBEDDINGS_AVAILABLE = False
|
|
SentenceTransformer = None # type: ignore
|
|
|
|
# Try to import cross-encoder for validation
|
|
try:
|
|
from sentence_transformers import CrossEncoder
|
|
CROSS_ENCODER_AVAILABLE = True
|
|
except ImportError:
|
|
logger.warning("CrossEncoder not available, validation disabled")
|
|
CROSS_ENCODER_AVAILABLE = False
|
|
CrossEncoder = None # type: ignore
|
|
|
|
|
|
class HeritageSemanticCache:
|
|
"""Hybrid semantic cache for Heritage RAG responses.
|
|
|
|
Features:
|
|
- Semantic similarity matching using vector embeddings
|
|
- Filterable by intent, language, institution type, location
|
|
- Cross-encoder validation to prevent false positives
|
|
- Intent-aware TTL policies
|
|
- Atomic sub-query caching for higher hit rates
|
|
|
|
Architecture (from research):
|
|
- Layer 1: Vector similarity (RedisVL) - fuzzy matching
|
|
- Layer 2: Metadata filtering - intent/language alignment
|
|
- Layer 3: Cross-encoder validation - semantic equivalence check
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
settings: CacheSettings | None = None,
|
|
embedding_model: SentenceTransformer | None = None,
|
|
validator: CrossEncoder | None = None,
|
|
):
|
|
"""Initialize semantic cache.
|
|
|
|
Args:
|
|
settings: Cache configuration (uses defaults if None)
|
|
embedding_model: Pre-loaded embedding model (auto-loads if None)
|
|
validator: Pre-loaded cross-encoder (auto-loads if None)
|
|
"""
|
|
self.settings = settings or get_cache_settings()
|
|
self.stats = CacheStats()
|
|
|
|
# Redis connection
|
|
self._redis: redis.Redis | None = None
|
|
self._index: AsyncSearchIndex | None = None
|
|
|
|
# Embedding model
|
|
self._embedding_model = embedding_model
|
|
self._embedding_dim = self.settings.cache_embedding_dim
|
|
|
|
# Cross-encoder validator
|
|
self._validator = validator
|
|
|
|
# In-memory fallback cache
|
|
self._memory_cache: dict[str, CacheEntry] = {}
|
|
|
|
# Compiled bypass patterns
|
|
self._bypass_patterns = [
|
|
re.compile(p, re.IGNORECASE) for p in CACHE_BYPASS_PATTERNS
|
|
]
|
|
|
|
# Initialization flag
|
|
self._initialized = False
|
|
|
|
async def initialize(self) -> None:
|
|
"""Initialize cache connections and models.
|
|
|
|
Call this before using the cache. Safe to call multiple times.
|
|
"""
|
|
if self._initialized:
|
|
return
|
|
|
|
# Load embedding model
|
|
if self._embedding_model is None and EMBEDDINGS_AVAILABLE:
|
|
try:
|
|
# Use a fast, high-quality model for cache embeddings
|
|
# redis/langcache-embed-v1 is optimized for semantic caching
|
|
model_name = "sentence-transformers/all-MiniLM-L6-v2" # Fallback
|
|
if "redis" in self.settings.cache_embedding_model:
|
|
model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
|
self._embedding_model = SentenceTransformer(model_name)
|
|
self._embedding_dim = self._embedding_model.get_sentence_embedding_dimension()
|
|
logger.info(f"Loaded embedding model: {model_name} (dim={self._embedding_dim})")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to load embedding model: {e}")
|
|
|
|
# Load cross-encoder validator
|
|
if (
|
|
self._validator is None
|
|
and CROSS_ENCODER_AVAILABLE
|
|
and self.settings.validation_enabled
|
|
):
|
|
try:
|
|
# Cross-encoder for semantic equivalence validation
|
|
self._validator = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
|
logger.info("Loaded cross-encoder validator")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to load cross-encoder: {e}")
|
|
|
|
# Connect to Redis
|
|
if REDIS_AVAILABLE and self.settings.cache_enabled:
|
|
try:
|
|
self._redis = await redis.from_url(
|
|
self.settings.redis_url,
|
|
password=self.settings.redis_password,
|
|
db=self.settings.redis_db,
|
|
decode_responses=False, # Binary for vectors
|
|
)
|
|
await self._redis.ping()
|
|
logger.info(f"Connected to Redis: {self.settings.redis_url}")
|
|
|
|
# Create search index
|
|
await self._create_index()
|
|
except Exception as e:
|
|
logger.warning(f"Redis connection failed: {e}, using memory cache")
|
|
self._redis = None
|
|
|
|
self._initialized = True
|
|
|
|
async def _create_index(self) -> None:
|
|
"""Create search index for semantic cache.
|
|
|
|
Uses raw FT.CREATE command for Valkey-Search compatibility.
|
|
Valkey-Search doesn't support some RedisVL-specific parameters like EPSILON.
|
|
"""
|
|
if not REDIS_AVAILABLE or self._redis is None:
|
|
return
|
|
|
|
index_name = self.settings.cache_index_name
|
|
prefix = self.settings.cache_prefix
|
|
|
|
try:
|
|
# Check if index already exists
|
|
try:
|
|
await self._redis.execute_command("FT.INFO", index_name)
|
|
logger.info(f"Cache index already exists: {index_name}")
|
|
self._index = True # Mark as available
|
|
return
|
|
except Exception:
|
|
pass # Index doesn't exist, create it
|
|
|
|
# Create index using raw FT.CREATE command (Valkey-Search compatible)
|
|
# Format: FT.CREATE idx ON HASH PREFIX 1 prefix: SCHEMA field1 type1 ...
|
|
# Note: Valkey-Search doesn't support TEXT fields - use TAG for exact match
|
|
# The "query" field is stored in hash but not indexed (retrieved via RETURN)
|
|
distance_metric = self.settings.distance_metric.value.upper()
|
|
|
|
await self._redis.execute_command(
|
|
"FT.CREATE", index_name,
|
|
"ON", "HASH",
|
|
"PREFIX", "1", prefix,
|
|
"SCHEMA",
|
|
# Vector field for semantic search (HNSW algorithm)
|
|
"embedding", "VECTOR", "HNSW", "6",
|
|
"TYPE", "FLOAT32",
|
|
"DIM", str(self._embedding_dim),
|
|
"DISTANCE_METRIC", distance_metric,
|
|
# TAG fields (exact match filtering)
|
|
"query_hash", "TAG",
|
|
"intent", "TAG",
|
|
"language", "TAG",
|
|
"institution_type", "TAG",
|
|
"country_code", "TAG",
|
|
# Numeric fields
|
|
"created_at", "NUMERIC",
|
|
"ttl_seconds", "NUMERIC",
|
|
"confidence", "NUMERIC",
|
|
)
|
|
|
|
self._index = True # Mark as available
|
|
logger.info(f"Cache index created: {index_name} (dim={self._embedding_dim}, metric={distance_metric})")
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create index: {e}")
|
|
self._index = None
|
|
|
|
def _should_bypass_cache(self, query: str) -> bool:
|
|
"""Check if query should bypass cache.
|
|
|
|
Bypass for:
|
|
- Very short queries (likely incomplete)
|
|
- Very long queries (too specific)
|
|
- Temporal/dynamic patterns
|
|
- User-specific queries
|
|
"""
|
|
# Length checks
|
|
if len(query) < self.settings.min_query_length:
|
|
return True
|
|
if len(query) > self.settings.max_query_length:
|
|
return True
|
|
|
|
# Pattern checks
|
|
for pattern in self._bypass_patterns:
|
|
if pattern.search(query):
|
|
return True
|
|
|
|
return False
|
|
|
|
def _compute_query_hash(self, query: str, language: str) -> str:
|
|
"""Compute deterministic hash for exact match lookup."""
|
|
normalized = query.lower().strip()
|
|
key = f"{language}:{normalized}"
|
|
return hashlib.sha256(key.encode()).hexdigest()[:16]
|
|
|
|
def _parse_ft_search_results(self, raw_results: list) -> list[dict[str, Any]]:
|
|
"""Parse raw FT.SEARCH results into list of dicts.
|
|
|
|
FT.SEARCH returns: [total_count, doc_id1, [field1, value1, ...], doc_id2, ...]
|
|
|
|
Args:
|
|
raw_results: Raw response from FT.SEARCH command
|
|
|
|
Returns:
|
|
List of dicts with field names as keys
|
|
"""
|
|
if not raw_results or len(raw_results) < 2:
|
|
return []
|
|
|
|
results = []
|
|
total_count = raw_results[0]
|
|
|
|
# Iterate through results (skip total_count at index 0)
|
|
# Each result is: doc_id, [field1, value1, field2, value2, ...]
|
|
i = 1
|
|
while i < len(raw_results):
|
|
doc_id = raw_results[i]
|
|
i += 1
|
|
|
|
if i >= len(raw_results):
|
|
break
|
|
|
|
fields_list = raw_results[i]
|
|
i += 1
|
|
|
|
# Parse fields list into dict
|
|
doc = {"_id": doc_id}
|
|
if isinstance(fields_list, (list, tuple)):
|
|
for j in range(0, len(fields_list), 2):
|
|
if j + 1 < len(fields_list):
|
|
field_name = fields_list[j]
|
|
field_value = fields_list[j + 1]
|
|
|
|
# Decode bytes to string if needed
|
|
if isinstance(field_name, bytes):
|
|
field_name = field_name.decode()
|
|
if isinstance(field_value, bytes):
|
|
field_value = field_value.decode()
|
|
|
|
doc[field_name] = field_value
|
|
|
|
results.append(doc)
|
|
|
|
return results
|
|
|
|
async def _embed_query(self, query: str) -> list[float] | None:
|
|
"""Generate embedding for query."""
|
|
if self._embedding_model is None:
|
|
return None
|
|
|
|
try:
|
|
# Run embedding in thread pool (CPU-bound)
|
|
loop = asyncio.get_event_loop()
|
|
embedding = await loop.run_in_executor(
|
|
None,
|
|
lambda: self._embedding_model.encode(query, convert_to_numpy=True)
|
|
)
|
|
return embedding.tolist()
|
|
except Exception as e:
|
|
logger.warning(f"Embedding failed: {e}")
|
|
return None
|
|
|
|
async def _validate_match(
|
|
self,
|
|
query: str,
|
|
cached_query: str,
|
|
threshold: float = 0.0,
|
|
) -> bool:
|
|
"""Validate semantic equivalence using cross-encoder.
|
|
|
|
This is the critical layer that prevents false positives.
|
|
Cross-encoders are more accurate than bi-encoders for this task.
|
|
|
|
The ms-marco-MiniLM model outputs logit scores:
|
|
- Positive scores indicate relevance/equivalence
|
|
- Negative scores indicate non-relevance
|
|
- Threshold of 0.0 is a good starting point
|
|
|
|
Args:
|
|
query: User's input query
|
|
cached_query: Matched query from cache
|
|
threshold: Minimum logit score for equivalence (default: 0.0)
|
|
|
|
Returns:
|
|
True if queries are semantically equivalent
|
|
"""
|
|
if self._validator is None:
|
|
# No validator, skip validation
|
|
return True
|
|
|
|
try:
|
|
# Run cross-encoder in thread pool
|
|
loop = asyncio.get_event_loop()
|
|
score = await loop.run_in_executor(
|
|
None,
|
|
lambda: self._validator.predict([[query, cached_query]])[0]
|
|
)
|
|
|
|
# ms-marco model outputs logits: positive = relevant, negative = not relevant
|
|
is_equivalent = score > threshold
|
|
|
|
if self.settings.log_cache_hits:
|
|
logger.debug(
|
|
f"Validation: '{query[:50]}...' vs '{cached_query[:50]}...' "
|
|
f"score={score:.3f} pass={is_equivalent}"
|
|
)
|
|
|
|
return is_equivalent
|
|
except Exception as e:
|
|
logger.warning(f"Validation failed: {e}")
|
|
return True # Fail open if validation errors
|
|
|
|
async def get(
|
|
self,
|
|
query: str,
|
|
language: str = "nl",
|
|
intent: str | None = None,
|
|
filters: dict[str, Any] | None = None,
|
|
) -> dict[str, Any] | None:
|
|
"""Retrieve cached response for query.
|
|
|
|
Performs three-layer lookup:
|
|
1. Exact hash match (fastest)
|
|
2. Semantic vector search (fuzzy)
|
|
3. Cross-encoder validation (accuracy)
|
|
|
|
Args:
|
|
query: User's natural language question
|
|
language: Query language (nl, en)
|
|
intent: Optional intent filter
|
|
filters: Additional metadata filters
|
|
|
|
Returns:
|
|
Cached response dict or None if no valid match
|
|
"""
|
|
start_time = time.perf_counter()
|
|
self.stats.total_queries += 1
|
|
|
|
if not self.settings.cache_enabled:
|
|
return None
|
|
|
|
# Check bypass patterns
|
|
if self._should_bypass_cache(query):
|
|
logger.debug(f"Cache bypass for query: {query[:50]}...")
|
|
return None
|
|
|
|
# Ensure initialized
|
|
await self.initialize()
|
|
|
|
# Try exact match first (hash lookup)
|
|
query_hash = self._compute_query_hash(query, language)
|
|
|
|
# Memory cache fallback
|
|
if self._redis is None:
|
|
entry = self._memory_cache.get(query_hash)
|
|
if entry:
|
|
self.stats.cache_hits += 1
|
|
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
|
self._update_hit_latency(elapsed_ms)
|
|
return entry.response
|
|
self.stats.cache_misses += 1
|
|
return None
|
|
|
|
# Redis semantic search
|
|
if self._index is None:
|
|
return None
|
|
|
|
# Generate embedding
|
|
embedding = await self._embed_query(query)
|
|
if embedding is None:
|
|
return None
|
|
|
|
try:
|
|
import numpy as np
|
|
|
|
# Convert embedding to bytes for KNN query
|
|
if isinstance(embedding, (list, tuple)):
|
|
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
|
|
elif isinstance(embedding, np.ndarray):
|
|
embedding_bytes = embedding.astype(np.float32).tobytes()
|
|
else:
|
|
embedding_bytes = embedding
|
|
|
|
# Build filter string for FT.SEARCH
|
|
# Format: @field:{value} for TAG fields
|
|
filter_parts = [f"@language:{{{language}}}"]
|
|
if intent:
|
|
filter_parts.append(f"@intent:{{{intent}}}")
|
|
if filters:
|
|
for key, value in filters.items():
|
|
if value:
|
|
filter_parts.append(f"@{key}:{{{value}}}")
|
|
|
|
# Combine filters with base KNN query
|
|
# FT.SEARCH format: "(@filter1 @filter2)=>[KNN N @field $vec AS score]"
|
|
filter_str = " ".join(filter_parts)
|
|
knn_query = f"({filter_str})=>[KNN 3 @embedding $vec AS __vec_score]"
|
|
|
|
# Execute raw FT.SEARCH command (Valkey-Search compatible)
|
|
# FT.SEARCH idx query PARAMS N name value ... RETURN N field ... DIALECT 2
|
|
raw_results = await self._redis.execute_command(
|
|
"FT.SEARCH", self.settings.cache_index_name,
|
|
knn_query,
|
|
"PARAMS", "2", "vec", embedding_bytes,
|
|
"RETURN", "6", "query", "query_hash", "response", "intent", "confidence", "__vec_score",
|
|
"DIALECT", "2"
|
|
)
|
|
|
|
# Parse FT.SEARCH results
|
|
# Format: [total_count, doc_id1, [field1, value1, ...], doc_id2, [field1, value1, ...], ...]
|
|
results = self._parse_ft_search_results(raw_results)
|
|
|
|
if not results:
|
|
self.stats.cache_misses += 1
|
|
return None
|
|
|
|
# Check distance threshold
|
|
best_result = results[0]
|
|
distance = float(best_result.get("__vec_score", 1.0))
|
|
|
|
if distance > self.settings.distance_threshold:
|
|
self.stats.cache_misses += 1
|
|
if self.settings.log_cache_misses:
|
|
logger.debug(
|
|
f"Cache miss: distance {distance:.4f} > "
|
|
f"threshold {self.settings.distance_threshold}"
|
|
)
|
|
return None
|
|
|
|
# Cross-encoder validation (critical for false positive prevention)
|
|
cached_query = best_result.get("query", "")
|
|
|
|
if self.settings.validation_enabled:
|
|
is_valid = await self._validate_match(query, cached_query)
|
|
|
|
if is_valid:
|
|
self.stats.validation_passes += 1
|
|
else:
|
|
self.stats.validation_failures += 1
|
|
self.stats.cache_misses += 1
|
|
logger.debug(f"Cache rejected by validation: {query[:50]}...")
|
|
return None
|
|
|
|
# Success - return cached response
|
|
self.stats.cache_hits += 1
|
|
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
|
self._update_hit_latency(elapsed_ms)
|
|
|
|
if self.settings.log_cache_hits:
|
|
logger.info(
|
|
f"Cache hit: '{query[:40]}...' matched '{cached_query[:40]}...' "
|
|
f"(dist={distance:.4f}, latency={elapsed_ms:.1f}ms)"
|
|
)
|
|
|
|
# Parse stored response
|
|
response_str = best_result.get("response", "{}")
|
|
if isinstance(response_str, bytes):
|
|
response_str = response_str.decode()
|
|
return json.loads(response_str)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Cache get failed: {e}")
|
|
self.stats.cache_misses += 1
|
|
return None
|
|
|
|
async def set(
|
|
self,
|
|
query: str,
|
|
response: dict[str, Any],
|
|
intent: str = "exploration",
|
|
language: str = "nl",
|
|
sources: list[str] | None = None,
|
|
confidence: float = 0.8,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> bool:
|
|
"""Store response in cache.
|
|
|
|
Args:
|
|
query: User's natural language question
|
|
response: RAG response to cache
|
|
intent: Query intent for TTL selection
|
|
language: Query language
|
|
sources: Data sources used
|
|
confidence: Response confidence score
|
|
metadata: Additional filterable metadata
|
|
|
|
Returns:
|
|
True if successfully cached
|
|
"""
|
|
if not self.settings.cache_enabled:
|
|
return False
|
|
|
|
# Check bypass patterns
|
|
if self._should_bypass_cache(query):
|
|
return False
|
|
|
|
await self.initialize()
|
|
|
|
# Compute hash and embedding
|
|
query_hash = self._compute_query_hash(query, language)
|
|
embedding = await self._embed_query(query)
|
|
|
|
# Determine TTL
|
|
ttl = get_ttl_for_intent(intent, self.settings)
|
|
|
|
# Create cache entry
|
|
entry = CacheEntry(
|
|
query=query,
|
|
query_hash=query_hash,
|
|
response=response,
|
|
intent=intent,
|
|
language=language,
|
|
sources=sources or [],
|
|
institution_type=metadata.get("institution_type") if metadata else None,
|
|
country_code=metadata.get("country_code") if metadata else None,
|
|
region_code=metadata.get("region_code") if metadata else None,
|
|
created_at=datetime.now(timezone.utc).isoformat(),
|
|
ttl_seconds=ttl,
|
|
confidence=confidence,
|
|
)
|
|
|
|
# Memory cache fallback
|
|
if self._redis is None:
|
|
self._memory_cache[query_hash] = entry
|
|
return True
|
|
|
|
# Store in Redis
|
|
try:
|
|
key = f"{self.settings.cache_prefix}{query_hash}"
|
|
|
|
# Prepare document
|
|
doc = {
|
|
"query": query,
|
|
"query_hash": query_hash,
|
|
"response": json.dumps(response),
|
|
"intent": intent,
|
|
"language": language,
|
|
"institution_type": entry.institution_type or "",
|
|
"country_code": entry.country_code or "",
|
|
"created_at": time.time(),
|
|
"ttl_seconds": ttl,
|
|
"confidence": confidence,
|
|
}
|
|
|
|
# Convert embedding to bytes for Redis storage
|
|
if embedding is not None:
|
|
import numpy as np
|
|
if isinstance(embedding, np.ndarray):
|
|
doc["embedding"] = embedding.astype(np.float32).tobytes()
|
|
elif isinstance(embedding, (list, tuple)):
|
|
doc["embedding"] = np.array(embedding, dtype=np.float32).tobytes()
|
|
else:
|
|
doc["embedding"] = embedding
|
|
|
|
# Store with TTL
|
|
await self._redis.hset(key, mapping=doc)
|
|
await self._redis.expire(key, ttl)
|
|
|
|
logger.debug(f"Cached: {query[:50]}... (ttl={ttl}s, intent={intent})")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"Cache set failed: {e}")
|
|
return False
|
|
|
|
# =========================================================================
|
|
# Synchronous Wrappers (for use in sync contexts like DSPy modules)
|
|
# =========================================================================
|
|
|
|
def get_sync(
|
|
self,
|
|
query: str,
|
|
language: str = "nl",
|
|
intent: str | None = None,
|
|
filters: dict[str, Any] | None = None,
|
|
) -> dict[str, Any] | None:
|
|
"""Synchronous wrapper for get().
|
|
|
|
Safe to call from sync code - handles event loop detection.
|
|
Falls back to memory cache only if async execution fails.
|
|
|
|
Args:
|
|
query: User's natural language question
|
|
language: Query language (nl, en)
|
|
intent: Optional intent filter
|
|
filters: Additional metadata filters
|
|
|
|
Returns:
|
|
Cached response dict or None if no valid match
|
|
"""
|
|
# Fast path: check memory cache first (no async needed)
|
|
if not self.settings.cache_enabled:
|
|
return None
|
|
|
|
if self._should_bypass_cache(query):
|
|
return None
|
|
|
|
query_hash = self._compute_query_hash(query, language)
|
|
|
|
# Memory cache lookup (sync)
|
|
if self._redis is None or not self._initialized:
|
|
entry = self._memory_cache.get(query_hash)
|
|
if entry:
|
|
self.stats.cache_hits += 1
|
|
return entry.response
|
|
self.stats.cache_misses += 1
|
|
return None
|
|
|
|
# Try to run async get in event loop
|
|
try:
|
|
# Check if we're already in an async context
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
# We're in an async context - can't use asyncio.run()
|
|
# Fall back to memory cache only
|
|
logger.debug("get_sync called from async context, using memory cache only")
|
|
entry = self._memory_cache.get(query_hash)
|
|
if entry:
|
|
self.stats.cache_hits += 1
|
|
return entry.response
|
|
self.stats.cache_misses += 1
|
|
return None
|
|
except RuntimeError:
|
|
# No running loop - safe to use asyncio.run()
|
|
return asyncio.run(self.get(query, language, intent, filters))
|
|
except Exception as e:
|
|
logger.warning(f"get_sync failed, falling back to memory cache: {e}")
|
|
entry = self._memory_cache.get(query_hash)
|
|
if entry:
|
|
return entry.response
|
|
return None
|
|
|
|
def set_sync(
|
|
self,
|
|
query: str,
|
|
response: dict[str, Any],
|
|
intent: str = "exploration",
|
|
language: str = "nl",
|
|
sources: list[str] | None = None,
|
|
confidence: float = 0.8,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> bool:
|
|
"""Synchronous wrapper for set().
|
|
|
|
Safe to call from sync code - handles event loop detection.
|
|
Falls back to memory cache if async execution fails.
|
|
|
|
Args:
|
|
query: User's natural language question
|
|
response: RAG response to cache
|
|
intent: Query intent for TTL selection
|
|
language: Query language
|
|
sources: Data sources used
|
|
confidence: Response confidence score
|
|
metadata: Additional filterable metadata
|
|
|
|
Returns:
|
|
True if successfully cached
|
|
"""
|
|
if not self.settings.cache_enabled:
|
|
return False
|
|
|
|
if self._should_bypass_cache(query):
|
|
return False
|
|
|
|
query_hash = self._compute_query_hash(query, language)
|
|
|
|
# Memory cache fallback (sync)
|
|
if self._redis is None or not self._initialized:
|
|
ttl = get_ttl_for_intent(intent, self.settings)
|
|
entry = CacheEntry(
|
|
query=query,
|
|
query_hash=query_hash,
|
|
response=response,
|
|
intent=intent,
|
|
language=language,
|
|
sources=sources or [],
|
|
institution_type=metadata.get("institution_type") if metadata else None,
|
|
country_code=metadata.get("country_code") if metadata else None,
|
|
region_code=metadata.get("region_code") if metadata else None,
|
|
created_at=datetime.now(timezone.utc).isoformat(),
|
|
ttl_seconds=ttl,
|
|
confidence=confidence,
|
|
)
|
|
self._memory_cache[query_hash] = entry
|
|
return True
|
|
|
|
# Try to run async set in event loop
|
|
try:
|
|
# Check if we're already in an async context
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
# We're in an async context - can't use asyncio.run()
|
|
# Fall back to memory cache only
|
|
logger.debug("set_sync called from async context, using memory cache only")
|
|
ttl = get_ttl_for_intent(intent, self.settings)
|
|
entry = CacheEntry(
|
|
query=query,
|
|
query_hash=query_hash,
|
|
response=response,
|
|
intent=intent,
|
|
language=language,
|
|
sources=sources or [],
|
|
institution_type=metadata.get("institution_type") if metadata else None,
|
|
country_code=metadata.get("country_code") if metadata else None,
|
|
region_code=metadata.get("region_code") if metadata else None,
|
|
created_at=datetime.now(timezone.utc).isoformat(),
|
|
ttl_seconds=ttl,
|
|
confidence=confidence,
|
|
)
|
|
self._memory_cache[query_hash] = entry
|
|
return True
|
|
except RuntimeError:
|
|
# No running loop - safe to use asyncio.run()
|
|
return asyncio.run(self.set(query, response, intent, language, sources, confidence, metadata))
|
|
except Exception as e:
|
|
logger.warning(f"set_sync failed, falling back to memory cache: {e}")
|
|
ttl = get_ttl_for_intent(intent, self.settings)
|
|
entry = CacheEntry(
|
|
query=query,
|
|
query_hash=query_hash,
|
|
response=response,
|
|
intent=intent,
|
|
language=language,
|
|
sources=sources or [],
|
|
institution_type=metadata.get("institution_type") if metadata else None,
|
|
country_code=metadata.get("country_code") if metadata else None,
|
|
region_code=metadata.get("region_code") if metadata else None,
|
|
created_at=datetime.now(timezone.utc).isoformat(),
|
|
ttl_seconds=ttl,
|
|
confidence=confidence,
|
|
)
|
|
self._memory_cache[query_hash] = entry
|
|
return True
|
|
|
|
async def warmup(self, faqs: dict[str, list[str]] | None = None) -> int:
|
|
"""Pre-populate cache with common heritage FAQs.
|
|
|
|
Implements the "Best Candidate Principle" from research:
|
|
Having high-quality candidates available is more effective
|
|
than optimizing selection on inadequate candidates.
|
|
|
|
Args:
|
|
faqs: Dict of intent -> list of FAQ questions
|
|
Uses FAQ_CATEGORIES if None
|
|
|
|
Returns:
|
|
Number of FAQs successfully cached
|
|
"""
|
|
if faqs is None:
|
|
faqs = FAQ_CATEGORIES
|
|
|
|
await self.initialize()
|
|
|
|
cached_count = 0
|
|
|
|
for intent, questions in faqs.items():
|
|
for question in questions:
|
|
# Detect language
|
|
language = "en" if any(
|
|
w in question.lower()
|
|
for w in ["what", "where", "when", "how", "which", "show", "find"]
|
|
) else "nl"
|
|
|
|
# Create placeholder response for warmup
|
|
# These will be replaced on first actual query
|
|
warmup_response = {
|
|
"answer": f"[Warmup placeholder for: {question}]",
|
|
"warmup": True,
|
|
"intent": intent,
|
|
}
|
|
|
|
success = await self.set(
|
|
query=question,
|
|
response=warmup_response,
|
|
intent=intent,
|
|
language=language,
|
|
confidence=0.0, # Mark as warmup
|
|
)
|
|
|
|
if success:
|
|
cached_count += 1
|
|
|
|
logger.info(f"Cache warmup complete: {cached_count} FAQs cached")
|
|
return cached_count
|
|
|
|
async def invalidate(
|
|
self,
|
|
pattern: str | None = None,
|
|
intent: str | None = None,
|
|
older_than_hours: int | None = None,
|
|
) -> int:
|
|
"""Invalidate cache entries.
|
|
|
|
Args:
|
|
pattern: Query text pattern to match
|
|
intent: Invalidate entries with this intent
|
|
older_than_hours: Invalidate entries older than X hours
|
|
|
|
Returns:
|
|
Number of entries invalidated
|
|
"""
|
|
if self._redis is None:
|
|
# Memory cache invalidation
|
|
if pattern:
|
|
to_delete = [
|
|
k for k, v in self._memory_cache.items()
|
|
if pattern.lower() in v.query.lower()
|
|
]
|
|
elif intent:
|
|
to_delete = [
|
|
k for k, v in self._memory_cache.items()
|
|
if v.intent == intent
|
|
]
|
|
else:
|
|
to_delete = list(self._memory_cache.keys())
|
|
|
|
for k in to_delete:
|
|
del self._memory_cache[k]
|
|
return len(to_delete)
|
|
|
|
# Redis invalidation
|
|
try:
|
|
keys = []
|
|
async for key in self._redis.scan_iter(f"{self.settings.cache_prefix}*"):
|
|
keys.append(key)
|
|
|
|
if keys:
|
|
deleted = await self._redis.delete(*keys)
|
|
logger.info(f"Invalidated {deleted} cache entries")
|
|
return deleted
|
|
return 0
|
|
except Exception as e:
|
|
logger.error(f"Cache invalidation failed: {e}")
|
|
return 0
|
|
|
|
def _update_hit_latency(self, latency_ms: float) -> None:
|
|
"""Update average hit latency (exponential moving average)."""
|
|
alpha = 0.1 # Smoothing factor
|
|
if self.stats.avg_hit_latency_ms == 0:
|
|
self.stats.avg_hit_latency_ms = latency_ms
|
|
else:
|
|
self.stats.avg_hit_latency_ms = (
|
|
alpha * latency_ms + (1 - alpha) * self.stats.avg_hit_latency_ms
|
|
)
|
|
|
|
def get_stats(self) -> dict[str, Any]:
|
|
"""Get cache performance statistics."""
|
|
self.stats.update_hit_rate()
|
|
self.stats.update_false_positive_rate()
|
|
|
|
return {
|
|
"total_queries": self.stats.total_queries,
|
|
"cache_hits": self.stats.cache_hits,
|
|
"cache_misses": self.stats.cache_misses,
|
|
"hit_rate": round(self.stats.hit_rate * 100, 2),
|
|
"validation_passes": self.stats.validation_passes,
|
|
"validation_failures": self.stats.validation_failures,
|
|
"false_positive_rate": round(self.stats.false_positive_rate * 100, 2),
|
|
"avg_hit_latency_ms": round(self.stats.avg_hit_latency_ms, 2),
|
|
"backend": "redis" if self._redis else "memory",
|
|
"embedding_model": self.settings.cache_embedding_model,
|
|
"distance_threshold": self.settings.distance_threshold,
|
|
"validation_enabled": self._validator is not None,
|
|
}
|
|
|
|
async def close(self) -> None:
|
|
"""Close cache connections."""
|
|
if self._redis:
|
|
await self._redis.close()
|
|
self._redis = None
|
|
|
|
async def clear(self) -> int:
|
|
"""Clear all cache entries.
|
|
|
|
Returns:
|
|
Number of entries cleared
|
|
"""
|
|
if self._redis is None:
|
|
# Memory cache clear
|
|
count = len(self._memory_cache)
|
|
self._memory_cache.clear()
|
|
return count
|
|
|
|
# Redis clear
|
|
try:
|
|
keys = []
|
|
async for key in self._redis.scan_iter(f"{self.settings.cache_prefix}*"):
|
|
keys.append(key)
|
|
|
|
if keys:
|
|
deleted = await self._redis.delete(*keys)
|
|
logger.info(f"Cleared {deleted} cache entries")
|
|
return deleted
|
|
return 0
|
|
except Exception as e:
|
|
logger.error(f"Cache clear failed: {e}")
|
|
return 0
|
|
|
|
|
|
# Global cache instance
|
|
_cache_instance: HeritageSemanticCache | None = None
|
|
|
|
|
|
def get_cache() -> HeritageSemanticCache:
    """Return the process-wide cache singleton (synchronous, lazy).

    The instance is constructed on first call; heavy initialization
    (model loading, Redis connection) is deferred until the first
    get/set operation.
    """
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = HeritageSemanticCache()
    return _cache_instance
|
|
|
|
async def get_cache_async() -> HeritageSemanticCache:
    """Return the global cache instance, fully initialized.

    Lazily constructs the singleton, then runs initialize() (idempotent)
    to load models and connect to Redis when available.
    """
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = HeritageSemanticCache()
    await _cache_instance.initialize()
    return _cache_instance
|
|
|
|
async def cache_rag_response(
    query: str,
    response: dict[str, Any],
    intent: str = "exploration",
    language: str = "nl",
) -> bool:
    """Convenience wrapper: store *response* via the global cache."""
    return await (await get_cache_async()).set(query, response, intent, language)
|
|
|
|
async def get_cached_response(
    query: str,
    language: str = "nl",
    intent: str | None = None,
) -> dict[str, Any] | None:
    """Convenience wrapper: look up *query* in the global cache."""
    return await (await get_cache_async()).get(query, language, intent)