glam/backend/rag/metrics.py
kempersc 99dc608826 Refactor RAG to template-based SPARQL generation
Major architectural changes based on Formica et al. (2023) research:
- Add TemplateClassifier for deterministic SPARQL template matching
- Add SlotExtractor with synonym resolution for slot values
- Add TemplateInstantiator using Jinja2 for query rendering
- Refactor dspy_heritage_rag.py to use template system
- Update main.py with streamlined pipeline
- Fix semantic_router.py ordering issues
- Add comprehensive metrics tracking

Template-based approach achieves 65% precision vs 10% LLM-only
per Formica et al. research on SPARQL generation.
2026-01-07 22:04:43 +01:00

719 lines
24 KiB
Python

"""
Prometheus Metrics for Heritage RAG API
Exposes metrics for monitoring template-based SPARQL generation,
session management, caching, and overall API performance.
Metrics exposed:
- rag_queries_total: Total queries by type (template/llm), status, endpoint
- rag_template_hits_total: Template SPARQL hits by template_id
- rag_template_tier_total: Template matching by tier (pattern/embedding/rag/llm)
- rag_query_duration_seconds: Query latency histogram
- rag_session_active: Active sessions gauge
- rag_cache_hits_total: Cache hit/miss counter
- rag_atomic_cache_total: Atomic sub-task cache hits/misses
- rag_atomic_subtasks_total: Sub-task cache operations
- rag_connection_pool_size: Connection pool utilization gauge
- rag_embedding_warmup_seconds: Embedding model warmup time
Usage:
from backend.rag.metrics import (
record_query, record_atomic_cache, record_template_tier,
create_metrics_endpoint, PROMETHEUS_AVAILABLE
)
# Record a query
record_query(
endpoint="dspy_query",
template_used=True,
template_id="count_by_province",
cache_hit=False,
status="success",
duration_seconds=1.5
)
# Record template tier
record_template_tier(tier="pattern", template_id="list_by_city")
# Record atomic cache stats
record_atomic_cache(
query_hit=False,
subtask_hits=3,
subtask_misses=1,
fully_assembled=False
)
"""
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Any
logger = logging.getLogger(__name__)
# ============================================================================
# Prometheus Client Import (Lazy/Optional)
# ============================================================================
PROMETHEUS_AVAILABLE = False
_prometheus_client = None
try:
import prometheus_client as _prometheus_client
PROMETHEUS_AVAILABLE = True
logger.info("Prometheus metrics enabled")
except ImportError:
logger.warning("prometheus_client not installed - metrics disabled")
# ============================================================================
# Metric Initialization
# ============================================================================
def _init_metrics():
    """Create and register all Prometheus metric objects.

    Runs once at module import. Returns an empty dict when the
    prometheus_client package is unavailable so every recording helper
    can cheaply no-op.

    Returns:
        Dict mapping internal metric keys to Prometheus collector objects.
    """
    if not PROMETHEUS_AVAILABLE or _prometheus_client is None:
        return {}
    pc = _prometheus_client
    metrics = {}

    # --- Query-level metrics ---------------------------------------------
    metrics["query_counter"] = pc.Counter(
        "rag_queries_total",
        "Total RAG queries processed",
        labelnames=["endpoint", "method", "status"],
    )
    metrics["template_hit_counter"] = pc.Counter(
        "rag_template_hits_total",
        "Template SPARQL hits by template ID",
        labelnames=["template_id", "intent"],
    )
    metrics["cache_counter"] = pc.Counter(
        "rag_cache_total",
        "Cache hits and misses",
        labelnames=["result"],
    )
    metrics["query_duration"] = pc.Histogram(
        "rag_query_duration_seconds",
        "Query processing time in seconds",
        labelnames=["endpoint", "method"],
        buckets=(0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0),
    )

    # --- Template matching tier metrics ----------------------------------
    metrics["template_tier_counter"] = pc.Counter(
        "rag_template_tier_total",
        "Template matching attempts by tier",
        labelnames=["tier", "matched"],  # tier: pattern, embedding, rag, llm
    )
    metrics["template_matching_duration"] = pc.Histogram(
        "rag_template_matching_seconds",
        "Time to match query to template",
        labelnames=["tier", "matched"],
        buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0),
    )

    # --- Atomic sub-task cache metrics ------------------------------------
    metrics["atomic_query_counter"] = pc.Counter(
        "rag_atomic_queries_total",
        "Atomic decomposition query attempts",
        labelnames=["result"],  # full_hit, partial_hit, miss
    )
    metrics["atomic_subtask_counter"] = pc.Counter(
        "rag_atomic_subtasks_total",
        "Atomic sub-task cache operations",
        labelnames=["operation"],  # hit, miss, cached
    )
    metrics["atomic_reassembly_counter"] = pc.Counter(
        "rag_atomic_reassemblies_total",
        "Full query reassemblies from cached sub-tasks",
    )
    metrics["atomic_subtask_hit_rate"] = pc.Gauge(
        "rag_atomic_subtask_hit_rate",
        "Current atomic sub-task cache hit rate (0-1)",
    )

    # --- Connection pool metrics ------------------------------------------
    metrics["connection_pool_size"] = pc.Gauge(
        "rag_connection_pool_size",
        "Current connection pool size by client type",
        labelnames=["client"],  # sparql, postgis (ducklake removed from RAG)
    )
    metrics["connection_pool_available"] = pc.Gauge(
        "rag_connection_pool_available",
        "Available connections in pool by client type",
        labelnames=["client"],
    )

    # --- Warmup / initialization metrics ----------------------------------
    metrics["embedding_warmup_duration"] = pc.Gauge(
        "rag_embedding_warmup_seconds",
        "Time taken to warm up embedding model",
        labelnames=["model"],
    )
    metrics["template_embedding_warmup_duration"] = pc.Gauge(
        "rag_template_embedding_warmup_seconds",
        "Time taken to pre-compute template embeddings",
    )
    metrics["warmup_status"] = pc.Gauge(
        "rag_warmup_complete",
        "Whether warmup is complete (1) or not (0)",
        labelnames=["component"],  # embedding_model, template_embeddings
    )

    # --- Session metrics ---------------------------------------------------
    metrics["active_sessions_gauge"] = pc.Gauge(
        "rag_sessions_active",
        "Number of active conversation sessions",
    )
    return metrics
# Initialize metrics at module load
_metrics = _init_metrics()
# ============================================================================
# Helper Functions
# ============================================================================
def record_query(
    endpoint: str,
    template_used: bool,
    template_id: str | None,
    cache_hit: bool,
    status: str,
    duration_seconds: float,
    intent: str | None = None,
) -> None:
    """Record metrics for a completed query.

    Args:
        endpoint: API endpoint name (e.g., "dspy_query", "dspy_query_stream")
        template_used: Whether template SPARQL was used vs LLM generation
        template_id: Template ID if template was used
        cache_hit: Whether response was served from cache
        status: Query status ("success", "error", "timeout")
        duration_seconds: Total query duration in seconds
        intent: Query intent classification if available
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return

    generation_method = "template" if template_used else "llm"

    # Query-level counter and latency histogram share the same labels.
    _metrics["query_counter"].labels(
        endpoint=endpoint, method=generation_method, status=status
    ).inc()
    _metrics["query_duration"].labels(
        endpoint=endpoint, method=generation_method
    ).observe(duration_seconds)

    # Per-template counter, only when a template actually fired.
    if template_used and template_id:
        _metrics["template_hit_counter"].labels(
            template_id=template_id, intent=intent or "unknown"
        ).inc()

    cache_result = "hit" if cache_hit else "miss"
    _metrics["cache_counter"].labels(result=cache_result).inc()
def record_template_matching(matched: bool, duration_seconds: float, tier: str = "unknown") -> None:
    """Record a template matching attempt.

    Thin compatibility wrapper around record_template_tier(), which
    updates the same two collectors (rag_template_tier_total counter and
    rag_template_matching_seconds histogram) with identical labels. The
    previous body duplicated that logic verbatim; delegating keeps the
    two entry points from drifting apart.

    Args:
        matched: Whether a template was successfully matched
        duration_seconds: Time taken to attempt template matching
        tier: Which matching tier was used (pattern, embedding, rag, llm)
    """
    record_template_tier(tier=tier, matched=matched, duration_seconds=duration_seconds)
def set_active_sessions(count: int) -> None:
    """Set the active-sessions gauge to an absolute value.

    Args:
        count: Current number of active sessions
    """
    if PROMETHEUS_AVAILABLE and _metrics:
        _metrics["active_sessions_gauge"].set(count)
def increment_active_sessions() -> None:
    """Bump the active-sessions gauge up by one."""
    if PROMETHEUS_AVAILABLE and _metrics:
        _metrics["active_sessions_gauge"].inc()
def decrement_active_sessions() -> None:
    """Bump the active-sessions gauge down by one."""
    if PROMETHEUS_AVAILABLE and _metrics:
        _metrics["active_sessions_gauge"].dec()
# ============================================================================
# Template Tier Metrics (NEW)
# ============================================================================
def record_template_tier(
    tier: str,
    matched: bool,
    template_id: str | None = None,
    duration_seconds: float | None = None,
) -> None:
    """Record which template matching tier handled a query.

    Args:
        tier: Matching tier - "pattern", "embedding", "rag", or "llm"
        matched: Whether the tier successfully matched
        template_id: Template ID if matched. NOTE(review): currently
            accepted but not fed into any metric — presumably reserved
            for a future per-template breakdown; confirm before removal.
        duration_seconds: Optional time taken for this tier
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return

    # Both collectors share the same (tier, matched) label pair.
    tier_labels = {"tier": tier, "matched": "true" if matched else "false"}
    _metrics["template_tier_counter"].labels(**tier_labels).inc()
    if duration_seconds is not None:
        _metrics["template_matching_duration"].labels(**tier_labels).observe(duration_seconds)
# ============================================================================
# Atomic Sub-task Cache Metrics (NEW)
# ============================================================================
def record_atomic_cache(
    query_hit: bool,
    subtask_hits: int = 0,
    subtask_misses: int = 0,
    fully_assembled: bool = False,
) -> None:
    """Record atomic sub-task cache metrics.

    Args:
        query_hit: Whether full query was reassembled from cache.
            NOTE(review): this argument is not read — the query-level
            outcome is derived from fully_assembled/subtask_hits instead;
            kept for interface compatibility.
        subtask_hits: Number of sub-task cache hits
        subtask_misses: Number of sub-task cache misses
        fully_assembled: Whether all sub-tasks were cached (full reassembly)
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return

    # Classify the query-level outcome.
    if fully_assembled:
        outcome = "full_hit"
    else:
        outcome = "partial_hit" if subtask_hits > 0 else "miss"
    _metrics["atomic_query_counter"].labels(result=outcome).inc()

    # Sub-task level counters; skip zero increments.
    for operation, count in (("hit", subtask_hits), ("miss", subtask_misses)):
        if count > 0:
            _metrics["atomic_subtask_counter"].labels(operation=operation).inc(count)

    if fully_assembled:
        _metrics["atomic_reassembly_counter"].inc()

    # Refresh the rolling hit-rate gauge when anything was attempted.
    attempted = subtask_hits + subtask_misses
    if attempted > 0:
        _metrics["atomic_subtask_hit_rate"].set(subtask_hits / attempted)
def record_atomic_subtask_cached(count: int = 1) -> None:
    """Record that sub-tasks were cached for future use.

    Args:
        count: Number of sub-tasks cached
    """
    if PROMETHEUS_AVAILABLE and _metrics:
        _metrics["atomic_subtask_counter"].labels(operation="cached").inc(count)
# ============================================================================
# Connection Pool Metrics (NEW)
# ============================================================================
def record_connection_pool(
    client: str,
    pool_size: int,
    available: int | None = None,
) -> None:
    """Record connection pool utilization.

    Args:
        client: Client type - "sparql", "postgis" (ducklake removed from RAG)
        pool_size: Current total pool size
        available: Number of available connections (if known)
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return
    _metrics["connection_pool_size"].labels(client=client).set(pool_size)
    # The "available" gauge is optional; only set when the caller knows it.
    if available is not None:
        _metrics["connection_pool_available"].labels(client=client).set(available)
# ============================================================================
# Warmup Metrics (NEW)
# ============================================================================
def record_embedding_warmup(
    model: str,
    duration_seconds: float,
    success: bool = True,
) -> None:
    """Record embedding model warmup time.

    Args:
        model: Model name/identifier
        duration_seconds: Time taken to warm up
        success: Whether warmup completed successfully
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return
    _metrics["embedding_warmup_duration"].labels(model=model).set(duration_seconds)
    status_flag = 1 if success else 0
    _metrics["warmup_status"].labels(component="embedding_model").set(status_flag)
def record_template_embedding_warmup(
    duration_seconds: float,
    template_count: int = 0,
    success: bool = True,
) -> None:
    """Record template embedding pre-computation time.

    Args:
        duration_seconds: Time taken to compute template embeddings
        template_count: Number of templates processed. NOTE(review): not
            currently exported to any metric; kept for interface
            compatibility (callers may pass it for logging).
        success: Whether warmup completed successfully
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return
    _metrics["template_embedding_warmup_duration"].set(duration_seconds)
    status_flag = 1 if success else 0
    _metrics["warmup_status"].labels(component="template_embeddings").set(status_flag)
def set_warmup_status(component: str, complete: bool) -> None:
    """Set warmup status for a component.

    Args:
        component: Component name - "embedding_model", "template_embeddings"
        complete: Whether warmup is complete
    """
    if PROMETHEUS_AVAILABLE and _metrics:
        _metrics["warmup_status"].labels(component=component).set(1 if complete else 0)
# ============================================================================
# Metrics Endpoint
# ============================================================================
def _get_metrics_bytes() -> tuple[bytes, str]:
    """Render the current Prometheus registry to wire format.

    Returns:
        Tuple of (metrics_bytes, content_type); a plain-text placeholder
        when prometheus_client is not installed.
    """
    if not PROMETHEUS_AVAILABLE or _prometheus_client is None:
        return b"# Prometheus metrics not available\n", "text/plain"
    return (
        _prometheus_client.generate_latest(_prometheus_client.REGISTRY),
        _prometheus_client.CONTENT_TYPE_LATEST,
    )


def get_metrics_response() -> tuple[bytes, str]:
    """Generate a fresh Prometheus metrics response.

    The previous implementation wrapped _get_metrics_bytes in
    lru_cache(maxsize=1) and then cleared that cache on every call, so
    the cache never served a hit (pure overhead) and any direct caller
    of _get_metrics_bytes could receive a stale snapshot. The cache has
    been removed; each call renders the registry anew.

    Returns:
        Tuple of (metrics_bytes, content_type)
    """
    return _get_metrics_bytes()
def create_metrics_endpoint():
    """Build a FastAPI router exposing GET /metrics for scraping.

    Usage:
        from backend.rag.metrics import create_metrics_endpoint
        app.include_router(create_metrics_endpoint())

    Returns:
        FastAPI APIRouter with /metrics endpoint
    """
    # Imported lazily so this module works without FastAPI installed
    # when the endpoint is never created.
    from fastapi import APIRouter
    from fastapi.responses import Response

    metrics_router = APIRouter(tags=["monitoring"])

    @metrics_router.get("/metrics")
    async def metrics():
        """Prometheus metrics endpoint for scraping."""
        payload, media_type = get_metrics_response()
        return Response(content=payload, media_type=media_type)

    return metrics_router
# ============================================================================
# Metric Summary Helpers (for logging/debugging)
# ============================================================================
def get_template_hit_rate() -> dict[str, Any]:
    """Summarize template vs LLM query counts from the live counters.

    Returns:
        Dict with hit rate statistics
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return {"available": False}

    # Sum counter samples per generation method across all other labels.
    totals = {"template": 0.0, "llm": 0.0}
    for metric in _metrics["query_counter"].collect():
        for sample in metric.samples:
            if sample.name != "rag_queries_total":
                continue
            method = sample.labels.get("method")
            if method in totals:
                totals[method] += sample.value

    overall = totals["template"] + totals["llm"]
    rate = totals["template"] / overall if overall > 0 else 0.0
    return {
        "available": True,
        "total_queries": int(overall),
        "template_queries": int(totals["template"]),
        "llm_queries": int(totals["llm"]),
        "template_hit_rate": round(rate, 4),
        "template_hit_rate_percent": round(rate * 100, 2),
    }
def get_template_breakdown() -> dict[str, int]:
    """Get breakdown of template usage by template_id.

    The underlying counter is labelled by (template_id, intent), so one
    template can appear in several samples; values are summed per
    template_id. (The previous version assigned instead of summing,
    which silently dropped counts for any template used under more than
    one intent.)

    Returns:
        Dict mapping template_id to hit count
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return {}
    template_counter = _metrics["template_hit_counter"]
    breakdown: dict[str, int] = {}
    for metric in template_counter.collect():
        for sample in metric.samples:
            if sample.name == "rag_template_hits_total":
                template_id = sample.labels.get("template_id", "unknown")
                breakdown[template_id] = breakdown.get(template_id, 0) + int(sample.value)
    return breakdown
def get_template_tier_stats() -> dict[str, Any]:
    """Get template matching tier statistics.

    Returns:
        Dict with tier breakdown and hit rates
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return {"available": False}

    # "rag" is Tier 2.5: RAG-enhanced matching between embedding and llm.
    known_tiers = ("pattern", "embedding", "rag", "llm")
    stats: dict[str, dict[str, int]] = {
        name: {"matched": 0, "unmatched": 0} for name in known_tiers
    }

    # Each (tier, matched) label combination yields one counter sample.
    for metric in _metrics["template_tier_counter"].collect():
        for sample in metric.samples:
            if sample.name != "rag_template_tier_total":
                continue
            tier = sample.labels.get("tier", "unknown")
            if tier not in stats:
                continue
            bucket = "matched" if sample.labels.get("matched") == "true" else "unmatched"
            stats[tier][bucket] = int(sample.value)

    result: dict[str, Any] = {"available": True, "tiers": {}}
    for tier, counts in stats.items():
        attempts = counts["matched"] + counts["unmatched"]
        rate = counts["matched"] / attempts if attempts > 0 else 0.0
        result["tiers"][tier] = {
            "matched": counts["matched"],
            "unmatched": counts["unmatched"],
            "total": attempts,
            "hit_rate": round(rate, 4),
        }

    # Roll-up across all tiers.
    total_matched = sum(c["matched"] for c in stats.values())
    total_attempts = sum(c["matched"] + c["unmatched"] for c in stats.values())
    result["total_matched"] = total_matched
    result["total_attempts"] = total_attempts
    result["overall_hit_rate"] = (
        round(total_matched / total_attempts, 4) if total_attempts > 0 else 0.0
    )
    return result
def get_atomic_cache_stats() -> dict[str, Any]:
    """Get atomic sub-task cache statistics.

    Returns:
        Dict with cache hit rates and operation counts
    """
    if not PROMETHEUS_AVAILABLE or not _metrics:
        return {"available": False}

    def _read_counter(metric_key: str, sample_name: str, label: str, buckets: dict[str, int]) -> None:
        # Fill `buckets` in place from the counter's current samples.
        for metric in _metrics[metric_key].collect():
            for sample in metric.samples:
                if sample.name == sample_name:
                    value = sample.labels.get(label, "unknown")
                    if value in buckets:
                        buckets[value] = int(sample.value)

    query_stats = {"full_hit": 0, "partial_hit": 0, "miss": 0}
    _read_counter("atomic_query_counter", "rag_atomic_queries_total", "result", query_stats)

    subtask_stats = {"hit": 0, "miss": 0, "cached": 0}
    _read_counter("atomic_subtask_counter", "rag_atomic_subtasks_total", "operation", subtask_stats)

    # Unlabelled counter: a single sample carries the reassembly total.
    reassemblies = 0
    for metric in _metrics["atomic_reassembly_counter"].collect():
        for sample in metric.samples:
            if sample.name == "rag_atomic_reassemblies_total":
                reassemblies = int(sample.value)

    total_queries = sum(query_stats.values())
    total_subtasks = subtask_stats["hit"] + subtask_stats["miss"]
    subtask_hit_rate = subtask_stats["hit"] / total_subtasks if total_subtasks > 0 else 0.0
    full_rate = query_stats["full_hit"] / total_queries if total_queries > 0 else 0.0
    any_rate = (
        (query_stats["full_hit"] + query_stats["partial_hit"]) / total_queries
        if total_queries > 0
        else 0.0
    )

    return {
        "available": True,
        "queries": {
            "full_hits": query_stats["full_hit"],
            "partial_hits": query_stats["partial_hit"],
            "misses": query_stats["miss"],
            "total": total_queries,
            "full_hit_rate": round(full_rate, 4),
            "any_hit_rate": round(any_rate, 4),
        },
        "subtasks": {
            "hits": subtask_stats["hit"],
            "misses": subtask_stats["miss"],
            "cached": subtask_stats["cached"],
            "total": total_subtasks,
            "hit_rate": round(subtask_hit_rate, 4),
            "hit_rate_percent": round(subtask_hit_rate * 100, 2),
        },
        "reassemblies": reassemblies,
    }
def get_all_performance_stats() -> dict[str, Any]:
    """Get comprehensive performance statistics.

    Returns:
        Dict with all performance metrics for monitoring dashboards
    """
    sections = {
        "template_hit_rate": get_template_hit_rate,
        "template_breakdown": get_template_breakdown,
        "template_tiers": get_template_tier_stats,
        "atomic_cache": get_atomic_cache_stats,
    }
    return {section: fetch() for section, fetch in sections.items()}