glam/backend/rag/main.py
kempersc 6c19ef8661 feat(rag): add Rule 46 epistemic provenance tracking
Track full lineage of RAG responses: WHERE data comes from, WHEN it was
retrieved, HOW it was processed (SPARQL/vector/LLM).

Backend changes:
- Add provenance.py with EpistemicProvenance, DataTier, SourceAttribution
- Integrate provenance into MultiSourceRetriever.merge_results()
- Return epistemic_provenance in DSPyQueryResponse

Frontend changes:
- Pass EpistemicProvenance through useMultiDatabaseRAG hook
- Display provenance in ConversationPage (for cache transparency)

Schema fixes:
- Fix truncated example in has_observation.yaml slot definition

References:
- Pavlyshyn's Context Graphs and Data Traces paper
- LinkML ProvenanceBlock schema pattern
2026-01-10 18:42:43 +01:00

5095 lines
219 KiB
Python

from __future__ import annotations
"""
Unified RAG Backend for Heritage Custodian Data
Multi-source retrieval-augmented generation system that combines:
- Qdrant vector search (semantic similarity)
- Oxigraph SPARQL (knowledge graph queries)
- TypeDB (relationship traversal)
- PostGIS (geospatial queries)
- Valkey (semantic caching)
Architecture:
                  User Query → Query Analysis
                               │
                         ┌─────┴─────┐
                         │  Router   │
                         └─────┬─────┘
           ┌─────────┬─────────┼─────────┬─────────┐
           ↓         ↓         ↓         ↓         ↓
        Qdrant    SPARQL    TypeDB    PostGIS    Cache
           │         │         │         │         │
           └─────────┴─────────┼─────────┴─────────┘
                         ┌─────┴─────┐
                         │  Merger   │
                         └─────┬─────┘
                               ↓
                        DSPy Generator
                               ↓
                    Visualization Selector
                               ↓
                   Response (JSON/Streaming)
Features:
- Intelligent query routing to appropriate data sources
- Score fusion for multi-source results
- Semantic caching via Valkey API
- Streaming responses for long-running queries
- DSPy assertions for output validation
Endpoints:
- POST /api/rag/query - Main RAG query endpoint
- POST /api/rag/sparql - Generate SPARQL with RAG context
- POST /api/rag/typedb/search - Direct TypeDB search
- GET /api/rag/health - Health check for all services
- GET /api/rag/stats - Retriever statistics
"""
import asyncio
import hashlib
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, AsyncIterator, TYPE_CHECKING
import httpx
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# Rule 46: Epistemic Provenance Tracking
from .provenance import (
EpistemicProvenance,
EpistemicDataSource,
DataTier,
RetrievalSource,
SourceAttribution,
infer_data_tier,
build_derivation_chain,
aggregate_data_tier,
)
# Configure logging for the process. logging.basicConfig only acts when the
# root logger has no handlers yet; the module logger propagates to the root.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Type hints for optional imports (only used during type checking)
if TYPE_CHECKING:
from glam_extractor.api.hybrid_retriever import HybridRetriever
from glam_extractor.api.typedb_retriever import TypeDBRetriever
from glam_extractor.api.visualization import VisualizationSelector
# Import retrievers (with graceful fallbacks)
RETRIEVERS_AVAILABLE = False
create_hybrid_retriever: Any = None
HeritageCustodianRetriever: Any = None
create_typedb_retriever: Any = None
select_visualization: Any = None
VisualizationSelector: Any = None # type: ignore[no-redef]
generate_sparql: Any = None
configure_dspy: Any = None
get_province_code: Any = None # Province name to ISO 3166-2 code converter
try:
import sys
sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src")))
from glam_extractor.api.hybrid_retriever import (
HybridRetriever as _HybridRetriever,
create_hybrid_retriever as _create_hybrid_retriever,
get_province_code as _get_province_code,
PERSON_JSONLD_CONTEXT,
)
from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever as _HeritageCustodianRetriever
from glam_extractor.api.typedb_retriever import TypeDBRetriever as _TypeDBRetriever, create_typedb_retriever as _create_typedb_retriever
from glam_extractor.api.visualization import select_visualization as _select_visualization, VisualizationSelector as _VisualizationSelector
# Assign to module-level variables
create_hybrid_retriever = _create_hybrid_retriever
HeritageCustodianRetriever = _HeritageCustodianRetriever
create_typedb_retriever = _create_typedb_retriever
select_visualization = _select_visualization
VisualizationSelector = _VisualizationSelector
get_province_code = _get_province_code
RETRIEVERS_AVAILABLE = True
except ImportError as e:
logger.warning(f"Core retrievers not available: {e}")
# Provide a fallback get_province_code that returns None
def get_province_code(province_name: str | None) -> str | None:
"""Fallback when hybrid_retriever is not available."""
return None
# DSPy is optional - don't block retrievers if it's missing
try:
from glam_extractor.api.dspy_sparql import generate_sparql as _generate_sparql, configure_dspy as _configure_dspy
generate_sparql = _generate_sparql
configure_dspy = _configure_dspy
except ImportError as e:
logger.warning(f"DSPy SPARQL not available: {e}")
# Atomic query decomposition for geographic/type filtering and sub-task caching
decompose_query: Any = None
AtomicCacheManager: Any = None
DECOMPOSER_AVAILABLE = False
ATOMIC_CACHE_AVAILABLE = False
try:
from atomic_decomposer import (
decompose_query as _decompose_query,
AtomicCacheManager as _AtomicCacheManager,
)
decompose_query = _decompose_query
AtomicCacheManager = _AtomicCacheManager
DECOMPOSER_AVAILABLE = True
ATOMIC_CACHE_AVAILABLE = True
logger.info("Query decomposer and AtomicCacheManager loaded successfully")
except ImportError as e:
logger.info(f"Query decomposer not available: {e}")
# Cost tracker is optional - gracefully degrades if unavailable
COST_TRACKER_AVAILABLE = False
get_tracker = None
reset_tracker = None
try:
from cost_tracker import get_tracker as _get_tracker, reset_tracker as _reset_tracker
get_tracker = _get_tracker
reset_tracker = _reset_tracker
COST_TRACKER_AVAILABLE = True
logger.info("Cost tracker module loaded successfully")
except ImportError as e:
logger.info(f"Cost tracker not available (optional): {e}")
# Session manager for multi-turn conversation state
SESSION_MANAGER_AVAILABLE = False
get_session_manager = None
shutdown_session_manager = None
ConversationState = None
try:
from session_manager import (
get_session_manager as _get_session_manager,
shutdown_session_manager as _shutdown_session_manager,
ConversationState as _ConversationState,
)
get_session_manager = _get_session_manager
shutdown_session_manager = _shutdown_session_manager
ConversationState = _ConversationState
SESSION_MANAGER_AVAILABLE = True
logger.info("Session manager module loaded successfully")
except ImportError as e:
logger.info(f"Session manager not available (optional): {e}")
# Template-based SPARQL pipeline (deterministic, validated queries)
# This provides 65% precision vs 10% for LLM-only SPARQL generation
TEMPLATE_SPARQL_AVAILABLE = False
TemplateSPARQLPipeline: Any = None
get_template_pipeline: Any = None
_template_pipeline_instance: Any = None # Singleton for reuse
try:
from template_sparql import (
TemplateSPARQLPipeline as _TemplateSPARQLPipeline,
get_template_pipeline as _get_template_pipeline,
)
TemplateSPARQLPipeline = _TemplateSPARQLPipeline
get_template_pipeline = _get_template_pipeline
TEMPLATE_SPARQL_AVAILABLE = True
logger.info("Template SPARQL pipeline loaded successfully")
except ImportError as e:
logger.info(f"Template SPARQL pipeline not available (optional): {e}")
# Prometheus metrics for monitoring template hit rate, latency, etc.
METRICS_AVAILABLE = False
record_query = None
record_template_matching = None
record_template_tier = None
record_atomic_cache = None
record_atomic_subtask_cached = None
record_connection_pool = None
record_embedding_warmup = None
record_template_embedding_warmup = None
set_warmup_status = None
set_active_sessions = None
create_metrics_endpoint = None
get_template_hit_rate = None
get_all_performance_stats = None
try:
from metrics import (
record_query as _record_query,
record_template_matching as _record_template_matching,
record_template_tier as _record_template_tier,
record_atomic_cache as _record_atomic_cache,
record_atomic_subtask_cached as _record_atomic_subtask_cached,
record_connection_pool as _record_connection_pool,
record_embedding_warmup as _record_embedding_warmup,
record_template_embedding_warmup as _record_template_embedding_warmup,
set_warmup_status as _set_warmup_status,
set_active_sessions as _set_active_sessions,
create_metrics_endpoint as _create_metrics_endpoint,
get_template_hit_rate as _get_template_hit_rate,
get_all_performance_stats as _get_all_performance_stats,
PROMETHEUS_AVAILABLE,
)
record_query = _record_query
record_template_matching = _record_template_matching
record_template_tier = _record_template_tier
record_atomic_cache = _record_atomic_cache
record_atomic_subtask_cached = _record_atomic_subtask_cached
record_connection_pool = _record_connection_pool
record_embedding_warmup = _record_embedding_warmup
record_template_embedding_warmup = _record_template_embedding_warmup
set_warmup_status = _set_warmup_status
set_active_sessions = _set_active_sessions
create_metrics_endpoint = _create_metrics_endpoint
get_template_hit_rate = _get_template_hit_rate
get_all_performance_stats = _get_all_performance_stats
METRICS_AVAILABLE = PROMETHEUS_AVAILABLE
logger.info(f"Metrics module loaded (prometheus={PROMETHEUS_AVAILABLE})")
except ImportError as e:
logger.info(f"Metrics module not available (optional): {e}")
# Province detection for geographic filtering
DUTCH_PROVINCES = {
    "noord-holland", "noordholland", "north holland", "north-holland",
    "zuid-holland", "zuidholland", "south holland", "south-holland",
    "utrecht", "gelderland", "noord-brabant", "noordbrabant", "brabant",
    "north brabant", "limburg", "overijssel", "friesland", "fryslân",
    "fryslan", "groningen", "drenthe", "flevoland", "zeeland",
}

# Sub-provincial regions recognised by infer_location_level() (already normalized).
_SUB_PROVINCIAL_REGIONS = frozenset(
    {"randstad", "veluwe", "achterhoek", "twente", "de betuwe", "betuwe"}
)


def _normalize_location_name(name: str) -> str:
    """Lower-case *name*, treating hyphens and any run of whitespace as one space."""
    return " ".join(name.lower().replace("-", " ").split())


# DUTCH_PROVINCES in normalized form, so spelling variants such as
# "Noord Holland", "noord-holland" and "noord  holland" all hit the same entry.
_NORMALIZED_PROVINCES = frozenset(_normalize_location_name(p) for p in DUTCH_PROVINCES)


def infer_location_level(location: str) -> str:
    """Infer whether location is city, province, or region.

    Matching is case-insensitive and ignores the difference between hyphens
    and spaces (e.g. "Noord Holland" == "noord-holland").

    Args:
        location: Place name as extracted from the user question.

    Returns:
        'province' if location is a Dutch province
        'region' if location is a sub-provincial region
        'city' otherwise (including the empty string)
    """
    normalized = _normalize_location_name(location)
    if normalized in _NORMALIZED_PROVINCES:
        return "province"
    if normalized in _SUB_PROVINCIAL_REGIONS:
        return "region"
    return "city"
# Surface forms (English/Dutch, singular/plural) of institution types mapped to
# the enum values used for filtering. Unknown forms fall back to ``.upper()``.
_INSTITUTION_TYPE_MAPPING: dict[str, str] = {
    "archive": "ARCHIVE",
    "archives": "ARCHIVE",
    "archief": "ARCHIVE",
    "archieven": "ARCHIVE",
    "museum": "MUSEUM",
    "musea": "MUSEUM",
    "museums": "MUSEUM",
    "library": "LIBRARY",
    "libraries": "LIBRARY",
    "bibliotheek": "LIBRARY",
    "bibliotheken": "LIBRARY",
    "gallery": "GALLERY",
    "galleries": "GALLERY",
    "galerie": "GALLERY",
    "galeries": "GALLERY",
    "galerieën": "GALLERY",
}


def extract_geographic_filters(question: str) -> dict[str, list[str] | None]:
    """Extract geographic filters from a question using query decomposition.

    Uses the optional ``decompose_query`` helper to pull a location and an
    institution type out of the question, then maps them to filter lists:
    provinces become ISO 3166-2 region codes (via ``get_province_code``),
    cities are passed through by name, and sub-provincial regions produce no
    location filter.

    Args:
        question: Natural-language user question.

    Returns:
        dict with keys: region_codes, cities, institution_types. Each value is
        a single-element list, or None when that filter does not apply. All
        values stay None when the decomposer is unavailable or fails.
    """
    filters: dict[str, list[str] | None] = {
        "region_codes": None,
        "cities": None,
        "institution_types": None,
    }
    if not DECOMPOSER_AVAILABLE or not decompose_query:
        return filters
    # Check for explicit city markers BEFORE decomposition.
    # This overrides province disambiguation when the user explicitly says
    # "de stad" (e.g. the city of Utrecht vs. the province of Utrecht).
    question_lower = question.lower()
    explicit_city_markers = [
        "de stad ", "in de stad", "stad van", "gemeente ",
        "the city of", "city of ", "in the city"
    ]
    force_city = any(marker in question_lower for marker in explicit_city_markers)
    try:
        decomposed = decompose_query(question)
        # Extract location and determine if it's a province or city
        location = decomposed.location
        if location:
            if force_city:
                filters["cities"] = [location]
                logger.info(f"City filter (explicit): {location}")
            else:
                level = infer_location_level(location)
                if level == "province":
                    # Convert province name to ISO 3166-2 code for Qdrant filtering
                    # e.g., "Noord-Holland" → "NH"
                    province_code = get_province_code(location)
                    if province_code:
                        filters["region_codes"] = [province_code]
                        logger.info(f"Province filter: {location}{province_code}")
                    else:
                        # Previously dropped silently; make the missing filter visible.
                        logger.info(f"No province code for '{location}'; province filter skipped")
                elif level == "city":
                    filters["cities"] = [location]
                    logger.info(f"City filter: {location}")
                # level == "region": no location filter is applied.
        # Extract institution type
        raw_type = decomposed.institution_type
        if raw_type:
            inst_type = raw_type.lower().strip()
            mapped_type = _INSTITUTION_TYPE_MAPPING.get(inst_type, inst_type.upper())
            filters["institution_types"] = [mapped_type]
            logger.info(f"Institution type filter: {mapped_type}")
    except Exception as e:
        # Best effort: filtering is an optimisation, never a reason to fail the query.
        logger.warning(f"Failed to extract geographic filters: {e}")
    return filters
# Configuration
class Settings:
"""Application settings from environment variables."""
# API Configuration
api_title: str = "Heritage RAG API"
api_version: str = "1.0.0"
debug: bool = os.getenv("DEBUG", "false").lower() == "true"
# Valkey Cache
valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache")
cache_ttl: int = int(os.getenv("CACHE_TTL", "900")) # 15 minutes
# Qdrant Vector DB
# Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy
qdrant_host: str = os.getenv("QDRANT_HOST", "localhost")
qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333"))
qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true"
qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant")
# Multi-Embedding Support
# Enable to use named vectors with multiple embedding models (OpenAI 1536, MiniLM 384, BGE 768)
use_multi_embedding: bool = os.getenv("USE_MULTI_EMBEDDING", "true").lower() == "true"
preferred_embedding_model: str | None = os.getenv("PREFERRED_EMBEDDING_MODEL", None) # e.g., "minilm_384" or "openai_1536"
# Oxigraph SPARQL
# Production: Use bronhouder.nl/sparql reverse proxy
sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql")
# TypeDB
# Note: TypeDB not exposed via reverse proxy - always use localhost
typedb_host: str = os.getenv("TYPEDB_HOST", "localhost")
typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729"))
typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians")
typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true" # Default off
# PostGIS/Geo API
# Production: Use bronhouder.nl/api/geo reverse proxy
postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo")
# NOTE: DuckLake removed from RAG - it's for offline analytics only, not real-time retrieval
# RAG uses only Qdrant (vectors) and Oxigraph (SPARQL) for retrieval
# LLM Configuration
anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
huggingface_api_key: str = os.getenv("HUGGINGFACE_API_KEY", "")
groq_api_key: str = os.getenv("GROQ_API_KEY", "")
zai_api_token: str = os.getenv("ZAI_API_TOKEN", "")
default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101")
# LLM Provider: "anthropic", "openai", "huggingface", "zai" (FREE), or "groq" (FREE)
llm_provider: str = os.getenv("LLM_PROVIDER", "anthropic")
# LLM Model: Specific model to use. Defaults depend on provider.
# For Z.AI: "glm-4.5-flash" (fast, recommended) or "glm-4.6" (reasoning, slow)
llm_model: str = os.getenv("LLM_MODEL", "glm-4.5-flash")
# Fast LM Provider for routing/extraction: "openai" (fast ~1-2s) or "zai" (FREE but slow ~13s)
# Default to openai for speed. Set to "zai" to save costs (free but adds ~12s latency)
fast_lm_provider: str = os.getenv("FAST_LM_PROVIDER", "openai")
# Retrieval weights
vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5"))
graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3"))
typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2"))
settings = Settings()
# Enums and Models
class QueryIntent(str, Enum):
    """Detected query intent for routing.

    Because the enum mixes in ``str``, each member compares equal to its
    string value (e.g. ``QueryIntent.SEARCH == "search"``).
    """
    GEOGRAPHIC = "geographic"  # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"  # Relationships between entities
    TEMPORAL = "temporal"  # Historical, timeline queries
    SEARCH = "search"  # General text search
    DETAIL = "detail"  # Specific entity lookup
class DataSource(str, Enum):
    """Available data sources for RAG retrieval.

    Values are the lowercase backend names; as a ``str`` enum, members compare
    equal to those strings.

    NOTE: DuckLake removed - it's for offline analytics only, not real-time RAG retrieval.
    RAG uses Qdrant (vectors) and Oxigraph (SPARQL) as primary backends.
    """
    QDRANT = "qdrant"  # Vector similarity search
    SPARQL = "sparql"  # Oxigraph knowledge graph
    TYPEDB = "typedb"
    POSTGIS = "postgis"
    CACHE = "cache"  # Valkey semantic cache
    # DUCKLAKE removed - use DuckLake separately for offline analytics/dashboards
@dataclass
class RetrievalResult:
    """Result from a single retriever.

    Attributes:
        source: Backend that produced the items.
        items: Retrieved records, one dict per item.
        score: Score attached to this result set (defaults to 0.0).
        query_time_ms: Query time in milliseconds (defaults to 0.0).
        metadata: Extra free-form information; each instance gets its own dict.
    """
    source: DataSource
    items: list[dict[str, Any]]
    score: float = 0.0
    query_time_ms: float = 0.0
    metadata: dict[str, Any] = field(default_factory=dict)
class QueryRequest(BaseModel):
    """RAG query request.

    ``sources=None`` asks the backend to auto-route based on query intent;
    ``k`` is the number of results requested from each source.
    """
    question: str = Field(..., description="Natural language question")  # required
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(default=[], description="Conversation history")
    sources: list[DataSource] | None = Field(
        default=None,
        description="Data sources to query. If None, auto-routes based on query intent.",
    )
    # NOTE(review): unlike TypeDBSearchRequest / PersonSearchRequest, k has no
    # ge/le bounds here, so zero, negative or very large values pass validation.
    k: int = Field(default=10, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    stream: bool = Field(default=False, description="Stream response")
class QueryResponse(BaseModel):
    """RAG query response.

    Carries the retrieved results plus optional generated SPARQL and
    visualization config, the sources consulted and timing information.
    """
    question: str
    sparql: str | None = None  # Generated SPARQL, if any
    results: list[dict[str, Any]]
    visualization: dict[str, Any] | None = None
    sources_used: list[DataSource]
    cache_hit: bool = False
    query_time_ms: float
    result_count: int
class SPARQLRequest(BaseModel):
    """SPARQL generation request.

    ``use_rag`` toggles whether retrieved context is used while generating.
    """
    question: str
    language: str = "nl"
    context: list[dict[str, Any]] = []  # Conversation history
    use_rag: bool = True
class SPARQLResponse(BaseModel):
    """SPARQL generation response.

    Returns the generated query with an explanation, whether RAG context was
    used, and the passages that were retrieved for it.
    """
    sparql: str
    explanation: str
    rag_used: bool
    retrieved_passages: list[str] = []
class SPARQLExecuteRequest(BaseModel):
    """Execute a SPARQL query directly against the knowledge graph."""
    sparql_query: str = Field(..., description="SPARQL query to execute")
    # Validated to the range 1-120 seconds (default 30).
    timeout: float = Field(default=30.0, ge=1.0, le=120.0, description="Query timeout in seconds")
class SPARQLExecuteResponse(BaseModel):
    """Response from direct SPARQL execution.

    Every field has a default, so an error-only response can be built by
    setting just ``error``.
    """
    results: list[dict[str, Any]] = Field(default=[], description="Query results as list of dicts")
    result_count: int = Field(default=0, description="Number of results")
    query_time_ms: float = Field(default=0.0, description="Query execution time in milliseconds")
    error: str | None = Field(default=None, description="Error message if query failed")
class SPARQLRerunRequest(BaseModel):
    """Re-run RAG pipeline with modified SPARQL results injected into context.

    ``llm_provider`` / ``llm_model`` are optional per-request overrides.
    """
    sparql_query: str = Field(..., description="Modified SPARQL query to execute")
    original_question: str = Field(..., description="Original user question")
    conversation_history: list[dict[str, Any]] = Field(
        default=[],
        description="Previous conversation turns"
    )
    language: str = Field(default="nl", description="Language code (nl or en)")
    llm_provider: str | None = Field(default=None, description="LLM provider to use")
    llm_model: str | None = Field(default=None, description="Specific LLM model")
class SPARQLRerunResponse(BaseModel):
    """Response from re-running RAG with modified SPARQL context.

    Holds the raw SPARQL rows and the answer regenerated from them.
    """
    results: list[dict[str, Any]] = Field(default=[], description="SPARQL query results")
    answer: str = Field(default="", description="Re-generated answer based on modified SPARQL results")
    sparql_result_count: int = Field(default=0, description="Number of SPARQL results")
    query_time_ms: float = Field(default=0.0, description="Total processing time")
class TypeDBSearchRequest(BaseModel):
    """TypeDB search request."""
    query: str = Field(..., description="Search query (name, type, or location)")
    # Free-form string; the accepted values are listed only in the description.
    search_type: str = Field(
        default="semantic",
        description="Search type: semantic, name, type, or location"
    )
    k: int = Field(default=10, ge=1, le=100, description="Number of results")  # validated to 1-100
class TypeDBSearchResponse(BaseModel):
    """TypeDB search response.

    Echoes the query and search type alongside the results and timing.
    """
    query: str
    search_type: str
    results: list[dict[str, Any]]
    result_count: int
    query_time_ms: float
class PersonSearchRequest(BaseModel):
    """Person/staff search request."""
    query: str = Field(..., description="Search query for person/staff (e.g., 'Wie werkt er in het Nationaal Archief?')")
    k: int = Field(default=10, ge=1, le=100, description="Number of results to return")  # validated to 1-100
    filter_custodian: str | None = Field(default=None, description="Filter by custodian slug (e.g., 'nationaal-archief')")
    only_heritage_relevant: bool = Field(default=False, description="Only return heritage-relevant staff")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use (e.g., 'minilm_384', 'openai_1536'). If None, auto-selects best available."
    )
class PersonSearchResponse(BaseModel):
    """Person/staff search response with JSON-LD linked data.

    The ``context`` field carries the alias ``@context`` (the JSON-LD key).
    """
    context: dict[str, Any] | None = Field(
        default=None,
        alias="@context",
        description="JSON-LD context for linked data semantic interoperability"
    )
    query: str
    results: list[dict[str, Any]]
    result_count: int
    query_time_ms: float
    collection_stats: dict[str, Any] | None = None
    embedding_model_used: str | None = None
    # Allow populating ``context`` by its field name as well as by the "@context" alias.
    model_config = {"populate_by_name": True}
class DSPyQueryRequest(BaseModel):
    """DSPy RAG query request with conversation support.

    Optional fields override server defaults per request: embedding model,
    LLM provider/model, and cache bypass (``skip_cache``).
    """
    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(
        default=[],
        description="Conversation history as list of {question, answer} dicts"
    )
    session_id: str | None = Field(
        default=None,
        description="Session ID for multi-turn conversations. If provided, session state is used to resolve follow-up questions like 'En in Enschede?'. If None, a new session is created and returned."
    )
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    # NOTE(review): this description omits 'groq', which the Settings comments
    # list as a supported provider — confirm whether groq is accepted per request.
    llm_provider: str | None = Field(
        default=None,
        description="LLM provider to use for this request: 'zai', 'anthropic', 'huggingface', or 'openai'. If None, uses server default (LLM_PROVIDER env)."
    )
    llm_model: str | None = Field(
        default=None,
        description="Specific LLM model to use (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929', 'gpt-4o'). If None, uses provider default."
    )
    skip_cache: bool = Field(
        default=False,
        description="Bypass cache lookup and force fresh LLM query. Useful for debugging."
    )
class LLMResponseMetadata(BaseModel):
    """LLM response provenance metadata (aligned with LinkML LLMResponse schema).

    Captures GLM 4.7 Interleaved Thinking chain-of-thought reasoning and
    full API response metadata for audit trails and debugging. Every field is
    optional because providers report different subsets of this information.

    See: schemas/20251121/linkml/modules/classes/LLMResponse.yaml
    """
    # Core response content
    content: str | None = None  # The final LLM response text
    reasoning_content: str | None = None  # GLM 4.7 Interleaved Thinking chain-of-thought
    # Model identification
    model: str | None = None  # Model identifier (e.g., 'glm-4.7', 'claude-3-opus')
    provider: str | None = None  # Provider: zai, anthropic, openai, huggingface, groq
    # Request tracking
    request_id: str | None = None  # Provider-assigned request ID
    created: str | None = None  # ISO 8601 timestamp of response generation
    # Token usage (for cost estimation and monitoring)
    prompt_tokens: int | None = None  # Tokens in input prompt
    completion_tokens: int | None = None  # Tokens in response (content + reasoning)
    total_tokens: int | None = None  # Total tokens used
    cached_tokens: int | None = None  # Tokens served from provider cache
    # Response metadata
    finish_reason: str | None = None  # stop, length, tool_calls, content_filter
    latency_ms: int | None = None  # Response latency in milliseconds
    # GLM 4.7 Thinking Mode configuration
    thinking_mode: str | None = None  # enabled, disabled, interleaved, preserved
    clear_thinking: bool | None = None  # False = Preserved Thinking enabled
class DSPyQueryResponse(BaseModel):
    """DSPy RAG query response.

    Besides the answer, it reports which sources and models were used, cost
    and timing data, cache/session state, template-SPARQL usage and the
    Rule 46 epistemic provenance chain.
    """
    question: str
    resolved_question: str | None = None  # Question after follow-up resolution, if rewritten
    answer: str
    sources_used: list[str] = []
    visualization: dict[str, Any] | None = None
    retrieved_results: list[dict[str, Any]] | None = None  # Raw retrieved data for frontend visualization
    query_type: str | None = None  # "person" or "institution" - helps frontend choose visualization
    query_time_ms: float = 0.0
    conversation_turn: int = 0
    embedding_model_used: str | None = None  # Which embedding model was used for the search
    llm_provider_used: str | None = None  # Which LLM provider handled this request (zai, anthropic, huggingface, openai)
    llm_model_used: str | None = None  # Which specific LLM model was used (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929')
    # Cost tracking fields (from cost_tracker module)
    timing_ms: float | None = None  # Total pipeline timing from cost tracker
    cost_usd: float | None = None  # Estimated LLM cost in USD
    timing_breakdown: dict[str, float] | None = None  # Per-stage timing breakdown
    # Cache tracking
    cache_hit: bool = False  # Whether response was served from cache
    # LLM response provenance (GLM 4.7 Thinking Mode support)
    llm_response: LLMResponseMetadata | None = None  # Full LLM response metadata including reasoning_content
    # Session management for multi-turn conversations
    session_id: str | None = None  # Session ID for continuing conversation (returned even if not provided in request)
    # Template SPARQL tracking (for monitoring template hit rate vs LLM fallback)
    template_used: bool = False  # Whether template-based SPARQL was used (vs LLM generation)
    template_id: str | None = None  # Which template was used (e.g., "institution_by_city", "person_by_name")
    # Factual query mode - skip LLM generation for count/list queries
    factual_result: bool = False  # True if this is a direct SPARQL result (no LLM prose generation)
    sparql_query: str | None = None  # The SPARQL query that was executed (for transparency)
    # Rule 46: Epistemic Provenance Tracking
    epistemic_provenance: dict[str, Any] | None = None  # Full provenance chain for transparency
def _llm_created_to_iso(created: Any) -> str | None:
    """Render a provider ``created`` value as text: unix seconds become ISO 8601 (UTC)."""
    if not created:
        return None
    if isinstance(created, (int, float)):
        try:
            return datetime.fromtimestamp(created, tz=timezone.utc).isoformat()
        except (OverflowError, OSError, ValueError):
            # An out-of-range value (e.g. milliseconds) must not discard the whole record.
            return str(created)
    return str(created)


def _llm_choice_fields(response: Any) -> tuple[Any, Any, Any]:
    """Return ``(content, reasoning_content, finish_reason)`` from the first choice, if any."""
    content = None
    reasoning_content = None
    choices = getattr(response, "choices", None)
    if not choices:
        return None, None, None
    choice = choices[0]
    if hasattr(choice, "message"):
        message = choice.message
        content = getattr(message, "content", None)
        # GLM 4.7 Interleaved Thinking - check for reasoning_content
        reasoning_content = getattr(message, "reasoning_content", None)
    elif isinstance(choice, dict):
        # ``or {}`` also covers an explicit ``"message": None``.
        message_dict = choice.get("message") or {}
        content = choice.get("text") or message_dict.get("content")
        reasoning_content = message_dict.get("reasoning_content")
    finish_reason = getattr(choice, "finish_reason", None)
    if finish_reason is None and isinstance(choice, dict):
        finish_reason = choice.get("finish_reason")
    return content, reasoning_content, finish_reason


def _llm_usage_fields(usage: Any) -> tuple[Any, Any, Any, Any]:
    """Return ``(prompt, completion, total, cached)`` token counts from a usage object or dict."""
    if usage is None:
        return None, None, None, None
    if hasattr(usage, "prompt_tokens"):
        # Object form (e.g. CompletionUsage from the OpenAI SDK)
        details = getattr(usage, "prompt_tokens_details", None)
        cached = getattr(details, "cached_tokens", None) if details is not None else None
        return (
            getattr(usage, "prompt_tokens", None),
            getattr(usage, "completion_tokens", None),
            getattr(usage, "total_tokens", None),
            cached,
        )
    if isinstance(usage, dict):
        details = usage.get("prompt_tokens_details")
        cached = details.get("cached_tokens") if isinstance(details, dict) else None
        return (
            usage.get("prompt_tokens"),
            usage.get("completion_tokens"),
            usage.get("total_tokens"),
            cached,
        )
    return None, None, None, None


def extract_llm_response_metadata(
    lm: Any,
    provider: str | None = None,
    latency_ms: int | None = None,
) -> LLMResponseMetadata | None:
    """Extract LLM response metadata from DSPy LM history.

    DSPy stores the raw API response in lm.history[-1]["response"], which includes:
    - choices[0].message.content (final response text)
    - choices[0].message.reasoning_content (GLM 4.7 Interleaved Thinking)
    - usage.prompt_tokens, completion_tokens, total_tokens
    - model, created, id, finish_reason
    This enables capturing GLM 4.7's chain-of-thought reasoning for provenance.

    Token usage is taken from the history entry's ``usage``; when that key is
    absent, the raw response's own ``usage`` attribute is used instead.

    Args:
        lm: DSPy LM instance with history attribute
        provider: LLM provider name (zai, anthropic, openai, etc.)
        latency_ms: Response latency in milliseconds

    Returns:
        LLMResponseMetadata, or None if history is empty, the last entry has
        no response, or extraction fails (failures are logged, never raised).
    """
    try:
        history = getattr(lm, "history", None)
        if not history:
            logger.debug("No LM history available for metadata extraction")
            return None
        # Most recent LLM call
        last_entry = history[-1]
        response = last_entry.get("response")
        if response is None:
            logger.debug("No response in LM history entry")
            return None
        content, reasoning_content, finish_reason = _llm_choice_fields(response)
        usage = last_entry.get("usage")
        if usage is None:
            usage = getattr(response, "usage", None)
        prompt_tokens, completion_tokens, total_tokens, cached_tokens = _llm_usage_fields(usage)
        model = last_entry.get("response_model") or last_entry.get("model")
        request_id = getattr(response, "id", None)
        created_str = _llm_created_to_iso(getattr(response, "created", None))
        # Receiving reasoning_content means the model used interleaved thinking (GLM 4.7).
        thinking_mode = "interleaved" if reasoning_content else None
        metadata = LLMResponseMetadata(
            content=content,
            reasoning_content=reasoning_content,
            model=model,
            provider=provider,
            request_id=request_id,
            created=created_str,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cached_tokens=cached_tokens,
            finish_reason=finish_reason,
            latency_ms=latency_ms,
            thinking_mode=thinking_mode,
        )
        if reasoning_content:
            logger.info(
                f"Captured GLM 4.7 reasoning_content ({len(reasoning_content)} chars) "
                f"from {provider}/{model}"
            )
        return metadata
    except Exception as e:
        # Provenance capture is best effort and must never break a query.
        logger.warning(f"Failed to extract LLM response metadata: {e}")
        return None
# Cache Client
class ValkeyClient:
"""Client for Valkey semantic cache API."""
def __init__(self, base_url: str = settings.valkey_api_url):
    """Store the cache API root (trailing slashes removed); the HTTP client is created lazily."""
    self._client: httpx.AsyncClient | None = None
    normalized_url = base_url.rstrip("/")
    self.base_url = normalized_url
@property
async def client(self) -> httpx.AsyncClient:
    """Return the shared AsyncClient, creating a new one if none exists or it was closed.

    Must be awaited: ``client = await self.client``.
    """
    current = self._client
    if current is not None and not current.is_closed:
        return current
    self._client = httpx.AsyncClient(timeout=30.0)
    return self._client
def _cache_key(self, question: str, sources: list[DataSource] | None) -> str:
    """Build a 32-hex-char key from the normalized question and the selected sources.

    Sources are sorted so their order does not matter; no sources means "auto".
    """
    source_part = ",".join(sorted(s.value for s in sources)) if sources else "auto"
    raw_key = f"{question.lower().strip()}:{source_part}"
    full_digest = hashlib.sha256(raw_key.encode()).hexdigest()
    return full_digest[:32]
async def get(self, question: str, sources: list[DataSource] | None) -> dict[str, Any] | None:
    """Get cached response using semantic cache lookup.

    Posts the question to ``{base_url}/cache/lookup`` with a 0.97 similarity
    threshold. Any failure (network error, non-JSON body, unexpected shape)
    is logged and treated as a cache miss.

    Args:
        question: User question to look up.
        sources: Accepted for symmetry with ``set``; it is not sent to the
            lookup endpoint, so hits are not partitioned by data source.

    Returns:
        The matching entry's ``response`` payload, or None on miss/failure.
    """
    try:
        client = await self.client
        response = await client.post(
            f"{self.base_url}/cache/lookup",
            json={
                "query": question,
                # Higher threshold (0.97) to avoid false cache hits on semantically
                # similar but geographically different queries
                "similarity_threshold": 0.97,
            },
        )
        if response.status_code != 200:
            # Record non-200 answers (e.g. cache service errors) instead of
            # letting them pass as ordinary misses.
            logger.debug(f"Cache lookup returned HTTP {response.status_code}")
            return None
        data = response.json()
        entry = data.get("entry")
        if not (data.get("found") and entry):
            return None
        # A null similarity would make the :.3f format raise and turn a real
        # hit into a logged failure.
        similarity = data.get("similarity") or 0.0
        logger.info(f"Cache hit for question: {question[:50]}... (similarity: {similarity:.3f})")
        return entry.get("response")  # type: ignore[no-any-return]
    except Exception as e:
        logger.warning(f"Cache get failed: {e}")
        return None
async def set(
self,
question: str,
sources: list[DataSource] | None,
response: dict[str, Any],
ttl: int = settings.cache_ttl,
) -> bool:
"""Cache response using semantic cache store."""
try:
client = await self.client
# Build CachedResponse schema
cached_response = {
"answer": response.get("answer", ""),
"sparql_query": response.get("sparql_query"),
"typeql_query": response.get("typeql_query"),
"visualization_type": response.get("visualization_type"),
"visualization_data": response.get("visualization_data"),
"sources": response.get("sources", []),
"confidence": response.get("confidence", 0.0),
"context": response.get("context"),
}
await client.post(
f"{self.base_url}/cache/store",
json={
"query": question,
"response": cached_response,
"language": response.get("language", "nl"),
"model": response.get("llm_model", "unknown"),
},
)
logger.debug(f"Cached response for: {question[:50]}...")
return True
except Exception as e:
logger.warning(f"Cache set failed: {e}")
return False
def _dspy_cache_key(
self,
question: str,
language: str,
llm_provider: str | None,
embedding_model: str | None,
context_hash: str | None = None,
) -> str:
"""Generate cache key for DSPy query responses.
Cache key components:
- Question text (normalized)
- Language code
- LLM provider (different providers give different answers)
- Embedding model (affects retrieval results)
- Context hash (for multi-turn conversations)
"""
components = [
question.lower().strip(),
language,
llm_provider or "default",
embedding_model or "auto",
context_hash or "no_context",
]
key_str = ":".join(components)
return f"dspy:{hashlib.sha256(key_str.encode()).hexdigest()[:32]}"
async def get_dspy(
self,
question: str,
language: str,
llm_provider: str | None,
embedding_model: str | None,
context: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Get cached DSPy response using semantic cache lookup.
Cache hits are filtered by LLM provider to ensure responses from different
providers (e.g., anthropic vs huggingface) are cached separately.
"""
try:
client = await self.client
response = await client.post(
f"{self.base_url}/cache/lookup",
json={
"query": question,
"language": language,
# Higher threshold (0.97) to avoid false cache hits on semantically
# similar but geographically different queries like
# "archieven in Groningen" vs "archieven in de stad Groningen"
"similarity_threshold": 0.97,
},
)
if response.status_code == 200:
data = response.json()
if data.get("found") and data.get("entry"):
cached_response = data["entry"].get("response")
# Verify the cached response matches the requested LLM provider
# The model field in cache contains the provider (e.g., "anthropic", "huggingface")
cached_model = data["entry"].get("model")
requested_provider = llm_provider or settings.llm_provider
if cached_model and cached_model != requested_provider:
logger.info(
f"DSPy cache miss (provider mismatch): cached={cached_model}, requested={requested_provider}"
)
return None
similarity = data.get("similarity", 0)
method = data.get("method", "unknown")
logger.info(f"DSPy cache hit for question: {question[:50]}... (similarity: {similarity:.3f}, method: {method}, provider: {cached_model})")
return cached_response # type: ignore[no-any-return]
return None
except Exception as e:
logger.warning(f"DSPy cache get failed: {e}")
return None
async def set_dspy(
self,
question: str,
language: str,
llm_provider: str | None,
embedding_model: str | None,
response: dict[str, Any],
context: list[dict[str, Any]] | None = None,
ttl: int = settings.cache_ttl,
) -> bool:
"""Cache DSPy response using semantic cache store.
Maps DSPyQueryResponse fields to CachedResponse schema:
- sources_used -> sources
- visualization -> visualization_type + visualization_data
- Additional context from query_type, resolved_question, etc.
"""
try:
client = await self.client
# Extract visualization components if present
visualization = response.get("visualization")
viz_type = None
viz_data = None
if visualization:
viz_type = visualization.get("type")
viz_data = visualization.get("data")
# Build CachedResponse schema matching the Valkey API
# Maps DSPyQueryResponse fields to CachedResponse expected fields
#
# IMPORTANT: Include llm_response metadata (GLM 4.7 reasoning_content) in cache
# so that cached responses also return the chain-of-thought reasoning.
llm_response_data = None
if response.get("llm_response"):
llm_resp = response["llm_response"]
# Handle both dict and LLMResponseMetadata object
if hasattr(llm_resp, "model_dump"):
llm_response_data = llm_resp.model_dump()
elif isinstance(llm_resp, dict):
llm_response_data = llm_resp
cached_response = {
"answer": response.get("answer", ""),
"sparql_query": None, # DSPy doesn't generate SPARQL
"typeql_query": None, # DSPy doesn't generate TypeQL
"visualization_type": viz_type,
"visualization_data": viz_data,
"sources": response.get("sources_used", []), # DSPy uses sources_used
"confidence": 0.95, # DSPy responses are generally high confidence
"context": {
"query_type": response.get("query_type"),
"resolved_question": response.get("resolved_question"),
"retrieved_results": response.get("retrieved_results"),
"embedding_model": response.get("embedding_model_used"),
"llm_model": response.get("llm_model_used"),
"original_context": context,
"llm_response": llm_response_data, # GLM 4.7 reasoning_content
},
}
result = await client.post(
f"{self.base_url}/cache/store",
json={
"query": question,
"response": cached_response,
"language": language,
"model": llm_provider or "unknown",
},
)
# Check if store was successful
if result.status_code == 200:
logger.info(f"✓ Cached DSPy response for: {question[:50]}...")
return True
else:
logger.warning(f"Cache store returned {result.status_code}: {result.text[:200]}")
return False
except Exception as e:
logger.warning(f"DSPy cache set failed: {e}")
return False
async def close(self) -> None:
"""Close HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
# Query Router
class QueryRouter:
    """Keyword-based router that maps a question to an intent and data sources."""

    def __init__(self) -> None:
        # Bilingual (English/Dutch) trigger phrases per intent. Every phrase
        # found as a whole word/phrase in the question is one vote for its intent.
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart", "where", "waar", "location", "locatie",
                "city", "stad", "country", "land", "region", "gebied",
                "coordinates", "coördinaten", "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel", "count", "aantal", "total", "totaal",
                "average", "gemiddeld", "distribution", "verdeling",
                "percentage", "statistics", "statistiek", "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd", "connected", "verbonden",
                "relationship", "relatie", "network", "netwerk",
                "parent", "child", "merged", "fusie", "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis", "timeline", "tijdlijn",
                "when", "wanneer", "founded", "opgericht", "closed", "gesloten",
                "over time", "evolution", "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie", "about", "over",
                "specific", "specifiek", "what is", "wat is",
            ],
        }

        # Preferred sources per intent, in priority order.
        # NOTE: DuckLake removed from RAG - it's for offline analytics only
        # Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY)
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT],  # SPARQL aggregations
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Return the intent whose keywords occur most often in ``question``.

        Matching is case-insensitive and anchored on word boundaries, so
        "land" does not fire inside "netherlands". Ties go to whichever intent
        comes first in ``QueryIntent`` iteration order. A question with no
        keyword hits falls back to ``QueryIntent.SEARCH``.
        """
        import re

        text = question.lower()
        votes = dict.fromkeys(QueryIntent, 0)
        for intent, phrases in self.intent_keywords.items():
            votes[intent] += sum(
                1
                for phrase in phrases
                if re.search(rf"\b{re.escape(phrase)}\b", text)
            )

        best = max(votes, key=votes.__getitem__)
        return best if votes[best] else QueryIntent.SEARCH

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Return the detected intent and the data sources to query.

        Args:
            question: User's question.
            requested_sources: Explicit source list; when non-empty it is
                returned as-is instead of the intent-based routing.

        Returns:
            ``(intent, sources)``; sources default to ``[DataSource.QDRANT]``
            for an intent with no routing entry.
        """
        intent = self.detect_intent(question)
        chosen = requested_sources or self.source_routing.get(intent, [DataSource.QDRANT])
        return intent, chosen
# Multi-Source Retriever
class MultiSourceRetriever:
    """Orchestrates retrieval across multiple data sources.

    Backends (Qdrant hybrid retriever, TypeDB retriever, SPARQL and PostGIS
    HTTP clients) are created lazily on first use. ``retrieve`` queries the
    selected sources concurrently; ``merge_results`` fuses the per-source
    rankings with weighted reciprocal rank fusion and attaches epistemic
    provenance (Rule 46).
    """

    def __init__(self) -> None:
        self.cache = ValkeyClient()
        self.router = QueryRouter()
        # Initialize retrievers lazily
        self._qdrant: HybridRetriever | None = None
        self._typedb: TypeDBRetriever | None = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None
        # NOTE: DuckLake client removed - DuckLake is for offline analytics only

    @property
    def qdrant(self) -> HybridRetriever | None:
        """Lazy-load Qdrant hybrid retriever with multi-embedding support.

        Returns None when the retriever modules are unavailable or creation
        failed; a failed creation is retried on the next access.
        """
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production,
                    use_multi_embedding=settings.use_multi_embedding,
                    preferred_embedding_model=settings.preferred_embedding_model,
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant

    @property
    def typedb(self) -> TypeDBRetriever | None:
        """Lazy-load TypeDB retriever (None if unavailable; retried on next access)."""
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                self._typedb = create_typedb_retriever(
                    use_production=settings.typedb_use_production  # Use TypeDB-specific setting
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb

    async def _get_sparql_client(self) -> httpx.AsyncClient:
        """Get SPARQL HTTP client with connection pooling.

        Connection pooling improves performance by reusing TCP connections
        instead of creating new ones for each request.
        """
        if self._sparql_client is None or self._sparql_client.is_closed:
            self._sparql_client = httpx.AsyncClient(
                timeout=30.0,
                limits=httpx.Limits(
                    max_connections=20,  # Max total connections
                    max_keepalive_connections=10,  # Keep-alive connections in pool
                    keepalive_expiry=30.0,  # Seconds to keep idle connections
                ),
            )
            # Record connection pool metrics
            if record_connection_pool:
                record_connection_pool(client="sparql", pool_size=20, available=20)
        return self._sparql_client

    async def _get_postgis_client(self) -> httpx.AsyncClient:
        """Get PostGIS HTTP client with connection pooling."""
        if self._postgis_client is None or self._postgis_client.is_closed:
            self._postgis_client = httpx.AsyncClient(
                timeout=30.0,
                limits=httpx.Limits(
                    max_connections=10,
                    max_keepalive_connections=5,
                    keepalive_expiry=30.0,
                ),
            )
            # Record connection pool metrics
            if record_connection_pool:
                record_connection_pool(client="postgis", pool_size=10, available=10)
        return self._postgis_client

    # NOTE: _get_ducklake_client removed - DuckLake is for offline analytics only, not RAG retrieval

    async def retrieve_from_qdrant(
        self,
        query: str,
        k: int = 10,
        embedding_model: str | None = None,
        region_codes: list[str] | None = None,
        cities: list[str] | None = None,
        institution_types: list[str] | None = None,
    ) -> RetrievalResult:
        """Retrieve from Qdrant vector + SPARQL hybrid search.

        Args:
            query: Search query
            k: Number of results to return
            embedding_model: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')
            region_codes: Filter by province/region codes (e.g., ['NH', 'ZH'])
            cities: Filter by city names (e.g., ['Amsterdam', 'Rotterdam'])
            institution_types: Filter by institution types (e.g., ['ARCHIVE', 'MUSEUM'])

        Returns:
            RetrievalResult whose score is the best ``scores.combined`` value
            among the items (0 when empty).
        """
        loop = asyncio.get_running_loop()
        start = loop.time()
        items: list[dict[str, Any]] = []
        if self.qdrant:
            try:
                # NOTE(review): search() is synchronous and blocks the event loop;
                # consider asyncio.to_thread once HybridRetriever thread-safety is confirmed.
                results = self.qdrant.search(
                    query,
                    k=k,
                    using=embedding_model,
                    region_codes=region_codes,
                    cities=cities,
                    institution_types=institution_types,
                )
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"Qdrant retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.QDRANT,
            items=items,
            # Tolerate a missing/null "scores" dict or "combined" value.
            score=max(((r.get("scores") or {}).get("combined") or 0 for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_sparql(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from SPARQL endpoint.

        Uses TEMPLATE-FIRST approach:
        1. Try template-based SPARQL generation (deterministic, validated)
        2. Fall back to LLM-based generation only if no template matches

        Template approach provides 65% precision vs 10% for LLM-only.

        Returns:
            RetrievalResult with at most ``k`` binding rows (variable -> value);
            score is 1.0 if any rows were returned, else 0.0.
        """
        global _template_pipeline_instance
        loop = asyncio.get_running_loop()
        start = loop.time()
        items: list[dict[str, Any]] = []
        sparql_query = ""
        template_used = False
        try:
            # ===================================================================
            # STEP 1: Try TEMPLATE-BASED SPARQL generation (preferred)
            # ===================================================================
            if TEMPLATE_SPARQL_AVAILABLE and get_template_pipeline:
                try:
                    # Get or create singleton pipeline instance
                    if _template_pipeline_instance is None:
                        _template_pipeline_instance = get_template_pipeline()
                        logger.info("[SPARQL] Template pipeline initialized for MultiSourceRetriever")
                    # Run template matching in thread pool (DSPy is synchronous)
                    template_result = await asyncio.to_thread(
                        _template_pipeline_instance,
                        question=query,
                        conversation_state=None,  # No conversation state in simple retriever
                        language="nl"
                    )
                    if template_result.matched and template_result.sparql:
                        sparql_query = template_result.sparql
                        template_used = True
                        logger.info(f"[SPARQL] Template match: '{template_result.template_id}' "
                                    f"(confidence={template_result.confidence:.2f})")
                    else:
                        logger.info(f"[SPARQL] No template match: {template_result.reasoning}")
                except Exception as e:
                    logger.warning(f"[SPARQL] Template pipeline failed: {e}")

            # ===================================================================
            # STEP 2: Fall back to LLM-BASED SPARQL generation
            # ===================================================================
            if not template_used and RETRIEVERS_AVAILABLE and generate_sparql:
                logger.info("[SPARQL] Falling back to LLM-based SPARQL generation")
                # generate_sparql is synchronous (LLM call); run it in a worker
                # thread, like the template pipeline above, so concurrent
                # retrievals gathered in retrieve() are not blocked.
                sparql_result = await asyncio.to_thread(
                    generate_sparql, query, language="nl", use_rag=False
                )
                sparql_query = sparql_result.get("sparql", "")

            # ===================================================================
            # STEP 3: Execute the SPARQL query
            # ===================================================================
            if sparql_query:
                logger.debug(f"[SPARQL] Executing query:\n{sparql_query[:500]}...")
                client = await self._get_sparql_client()
                response = await client.post(
                    settings.sparql_endpoint,
                    data={"query": sparql_query},
                    headers={"Accept": "application/sparql-results+json"},
                )
                if response.status_code == 200:
                    data = response.json()
                    bindings = data.get("results", {}).get("bindings", [])
                    # Flatten each binding to {variable: value}, keeping at most k rows.
                    items = [
                        {key: val.get("value") for key, val in b.items()}
                        for b in bindings[:k]
                    ]
                    logger.info(f"[SPARQL] Query returned {len(items)} results "
                                f"(template={template_used})")
                else:
                    logger.warning(f"[SPARQL] Query failed with status {response.status_code}: "
                                   f"{response.text[:200]}")
        except Exception as e:
            logger.error(f"SPARQL retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.SPARQL,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve_from_typedb(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from TypeDB knowledge graph.

        Score is the best ``relevance_score`` among the items (0 when empty).
        """
        loop = asyncio.get_running_loop()
        start = loop.time()
        items: list[dict[str, Any]] = []
        if self.typedb:
            try:
                # NOTE(review): synchronous call; blocks the event loop while it runs.
                results = self.typedb.semantic_search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"TypeDB retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.TYPEDB,
            items=items,
            score=max((r.get("relevance_score", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_postgis(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from PostGIS geospatial database.

        Simplified implementation: if the query mentions one of a few known
        Dutch cities, institutions within 10 km of that city centre are
        requested from the PostGIS API; otherwise no items are returned.
        """
        loop = asyncio.get_running_loop()
        start = loop.time()
        # Extract location from query for geospatial search
        # This is a simplified implementation
        items: Any = []
        try:
            client = await self._get_postgis_client()
            # Try to detect city name for bbox search
            query_lower = query.lower()
            # Simple city detection
            cities = {
                "amsterdam": {"lat": 52.3676, "lon": 4.9041},
                "rotterdam": {"lat": 51.9244, "lon": 4.4777},
                "den haag": {"lat": 52.0705, "lon": 4.3007},
                "utrecht": {"lat": 52.0907, "lon": 5.1214},
            }
            for city, coords in cities.items():
                if city in query_lower:
                    # Query PostGIS for nearby institutions
                    response = await client.get(
                        f"{settings.postgis_url}/api/institutions/nearby",
                        params={
                            "lat": coords["lat"],
                            "lon": coords["lon"],
                            "radius_km": 10,
                            "limit": k,
                        },
                    )
                    if response.status_code == 200:
                        items = response.json()
                    # Only the first matching city is queried.
                    break
        except Exception as e:
            logger.error(f"PostGIS retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.POSTGIS,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    # NOTE: retrieve_from_ducklake removed - DuckLake is for offline analytics only, not RAG retrieval
    # Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY) on Oxigraph

    async def retrieve(
        self,
        question: str,
        sources: list[DataSource],
        k: int = 10,
        embedding_model: str | None = None,
        region_codes: list[str] | None = None,
        cities: list[str] | None = None,
        institution_types: list[str] | None = None,
    ) -> list[RetrievalResult]:
        """Retrieve from multiple sources concurrently.

        Args:
            question: User's question
            sources: Data sources to query (unsupported sources are ignored)
            k: Number of results per source
            embedding_model: Optional embedding model for Qdrant (e.g., 'minilm_384', 'openai_1536')
            region_codes: Filter by province/region codes (e.g., ['NH', 'ZH']) - Qdrant only
            cities: Filter by city names (e.g., ['Amsterdam']) - Qdrant only
            institution_types: Filter by institution types (e.g., ['ARCHIVE']) - Qdrant only

        Returns:
            List of RetrievalResult from each source; sources whose task
            raised are logged and omitted.
        """
        tasks = []
        for source in sources:
            if source == DataSource.QDRANT:
                tasks.append(self.retrieve_from_qdrant(
                    question,
                    k,
                    embedding_model,
                    region_codes=region_codes,
                    cities=cities,
                    institution_types=institution_types,
                ))
            elif source == DataSource.SPARQL:
                tasks.append(self.retrieve_from_sparql(question, k))
            elif source == DataSource.TYPEDB:
                tasks.append(self.retrieve_from_typedb(question, k))
            elif source == DataSource.POSTGIS:
                tasks.append(self.retrieve_from_postgis(question, k))
            # NOTE: DuckLake case removed - DuckLake is for offline analytics only
        results = await asyncio.gather(*tasks, return_exceptions=True)
        # Filter out exceptions
        valid_results = []
        for r in results:
            if isinstance(r, RetrievalResult):
                valid_results.append(r)
            elif isinstance(r, Exception):
                logger.error(f"Retrieval task failed: {r}")
        return valid_results

    def merge_results(
        self,
        results: list[RetrievalResult],
        max_results: int = 20,
        template_used: bool = False,
        template_id: str | None = None,
    ) -> tuple[list[dict[str, Any]], EpistemicProvenance]:
        """Merge and deduplicate results from multiple sources.

        Uses weighted reciprocal rank fusion: an item at 0-based ``rank`` in a
        source contributes ``weight / (60 + rank)`` to its ``_rrf_score``.
        Items are deduplicated on ``ghcid`` (falling back to ``id``); each
        merged item is annotated with ``_sources``, ``_rrf_score`` and
        ``_data_tier`` (lowest tier value seen across sources).

        Returns merged items AND epistemic provenance tracking.

        Rule 46: Epistemic Provenance Tracking

        Returns:
            ``(items, provenance)``: the top ``max_results`` items by
            ``_rrf_score`` (descending) and the provenance record.
        """
        from datetime import datetime, timezone

        # Loop-invariant lookups, built once per call.
        # Map DataSource to RetrievalSource (unknown sources -> LLM synthesis).
        source_map = {
            DataSource.QDRANT: RetrievalSource.QDRANT,
            DataSource.SPARQL: RetrievalSource.SPARQL,
            DataSource.TYPEDB: RetrievalSource.TYPEDB,
            DataSource.POSTGIS: RetrievalSource.POSTGIS,
            DataSource.CACHE: RetrievalSource.CACHE,
        }
        # Weight by source (unknown sources get 0.5).
        source_weights = {
            DataSource.QDRANT: settings.vector_weight,
            DataSource.SPARQL: settings.graph_weight,
            DataSource.TYPEDB: settings.typedb_weight,
            DataSource.POSTGIS: 0.3,
        }

        # Track items by GHCID for deduplication
        merged: dict[str, dict[str, Any]] = {}
        # Initialize provenance tracking
        tier_counts: dict[DataTier, int] = {}
        sources_queried = [r.source.value for r in results]
        total_retrieved = sum(len(r.items) for r in results)

        for result in results:
            retrieval_source = source_map.get(result.source, RetrievalSource.LLM_SYNTHESIS)
            weight = source_weights.get(result.source, 0.5)
            for rank, item in enumerate(result.items):
                # Items without a ghcid/id get a key namespaced by source, so the
                # first anonymous row of one source is not merged into (and
                # silently overwritten by) the first anonymous row of another.
                ghcid = (
                    item.get("ghcid")
                    or item.get("id")
                    or f"unknown_{result.source.value}_{rank}"
                )
                if ghcid not in merged:
                    merged[ghcid] = item.copy()
                    merged[ghcid]["_sources"] = []
                    merged[ghcid]["_rrf_score"] = 0.0
                    merged[ghcid]["_data_tier"] = None
                # Infer data tier for this item
                item_tier = infer_data_tier(item, retrieval_source)
                tier_counts[item_tier] = tier_counts.get(item_tier, 0) + 1
                # Track best (lowest) tier for each item
                if merged[ghcid]["_data_tier"] is None:
                    merged[ghcid]["_data_tier"] = item_tier.value
                else:
                    merged[ghcid]["_data_tier"] = min(merged[ghcid]["_data_tier"], item_tier.value)
                # Reciprocal Rank Fusion
                rrf_score = 1.0 / (60 + rank)  # k=60 is standard
                merged[ghcid]["_rrf_score"] += rrf_score * weight
                merged[ghcid]["_sources"].append(result.source.value)

        # Sort by RRF score
        sorted_items = sorted(
            merged.values(),
            key=lambda x: x.get("_rrf_score", 0),
            reverse=True,
        )
        final_items = sorted_items[:max_results]

        # Build epistemic provenance
        provenance = EpistemicProvenance(
            dataSource=EpistemicDataSource.RAG_PIPELINE,
            dataTier=aggregate_data_tier(tier_counts),
            sourceTimestamp=datetime.now(timezone.utc).isoformat(),
            derivationChain=build_derivation_chain(
                sources_used=sources_queried,
                template_used=template_used,
                template_id=template_id,
            ),
            revalidationPolicy="weekly",
            sourcesQueried=sources_queried,
            totalRetrieved=total_retrieved,
            totalAfterFusion=len(final_items),
            dataTierBreakdown={
                f"tier_{tier.value}": count
                for tier, count in tier_counts.items()
            },
            templateUsed=template_used,
            templateId=template_id,
        )
        return final_items, provenance

    async def close(self) -> None:
        """Clean up resources.

        Every resource is closed independently, so a failure in one (e.g. a
        dropped TypeDB connection) is logged and does not leak the others.
        Closed handles are reset so the lazy getters can recreate them.
        """
        try:
            await self.cache.close()
        except Exception as e:
            logger.warning(f"Failed to close cache client: {e}")

        for attr in ("_sparql_client", "_postgis_client"):
            http_client = getattr(self, attr)
            if http_client is not None:
                try:
                    await http_client.aclose()
                except Exception as e:
                    logger.warning(f"Failed to close {attr}: {e}")
                setattr(self, attr, None)

        for attr in ("_qdrant", "_typedb"):
            backend = getattr(self, attr)
            if backend is not None:
                try:
                    backend.close()
                except Exception as e:
                    logger.warning(f"Failed to close {attr}: {e}")
                setattr(self, attr, None)

    def search_persons(
        self,
        query: str,
        k: int = 10,
        filter_custodian: str | None = None,
        only_heritage_relevant: bool = False,
        using: str | None = None,
    ) -> list[Any]:
        """Search for persons/staff in the heritage_persons collection.

        Delegates to HybridRetriever.search_persons() if available.

        Args:
            query: Search query
            k: Number of results
            filter_custodian: Optional custodian slug to filter by
            only_heritage_relevant: Only return heritage-relevant staff
            using: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')

        Returns:
            List of RetrievedPerson objects (empty if Qdrant is unavailable or
            the search fails).
        """
        if self.qdrant:
            try:
                return self.qdrant.search_persons(  # type: ignore[no-any-return]
                    query=query,
                    k=k,
                    filter_custodian=filter_custodian,
                    only_heritage_relevant=only_heritage_relevant,
                    using=using,
                )
            except Exception as e:
                logger.error(f"Person search failed: {e}")
        return []

    def get_stats(self) -> dict[str, Any]:
        """Get statistics from all retrievers.

        Returns combined stats from Qdrant (including persons collection) and TypeDB.
        Qdrant stats are merged at the top level; TypeDB stats go under "typedb".
        """
        stats: dict[str, Any] = {}
        if self.qdrant:
            try:
                qdrant_stats = self.qdrant.get_stats()
                stats.update(qdrant_stats)
            except Exception as e:
                logger.warning(f"Failed to get Qdrant stats: {e}")
        if self.typedb:
            try:
                typedb_stats = self.typedb.get_stats()
                stats["typedb"] = typedb_stats
            except Exception as e:
                logger.warning(f"Failed to get TypeDB stats: {e}")
        return stats
# Global instances -- module-level singletons assigned by ``lifespan`` at startup.

# Shared multi-source retriever (Qdrant / SPARQL / TypeDB / PostGIS + Valkey cache)
retriever: MultiSourceRetriever | None = None

# Visualization selector; stays None when it cannot be constructed
viz_selector: VisualizationSelector | None = None

# HeritageRAGPipeline instance (loaded with optimized model)
dspy_pipeline: Any = None

# AtomicCacheManager for sub-task caching (40-70% hit rate vs 5-15% for full queries)
atomic_cache_manager: Any = None
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan manager.

    Startup: creates the global MultiSourceRetriever and, when retrievers are
    available, builds the VisualizationSelector and configures a DSPy LM via
    a provider fallback chain. It then builds the HeritageRAGPipeline, warms up
    the embedding model, template embeddings and ontology cache, and creates
    the AtomicCacheManager. Every optional step logs a warning on failure
    instead of aborting startup.

    Shutdown: closes the retriever (and the clients it owns).
    """
    global retriever, viz_selector, dspy_pipeline, atomic_cache_manager

    # Startup
    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()

    if RETRIEVERS_AVAILABLE:
        # Check for any available LLM API key (Anthropic preferred, OpenAI fallback)
        # NOTE(review): Z.AI / HuggingFace credentials are not considered here, so the
        # VisualizationSelector runs with use_dspy=False when only those are set -- confirm intended.
        has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key)

        # VisualizationSelector requires DSPy - make it optional
        try:
            viz_selector = VisualizationSelector(use_dspy=has_llm_key)
        except RuntimeError as e:
            logger.warning(f"VisualizationSelector not available: {e}")
            viz_selector = None

        # Configure DSPy based on LLM_PROVIDER setting
        # Respect user's provider preference, with fallback chain:
        #   zai (if selected) -> huggingface (if selected) ->
        #   anthropic (if selected, or as fallback for huggingface) -> openai (any provider)
        import dspy
        llm_provider = settings.llm_provider.lower()
        logger.info(f"LLM_PROVIDER configured as: {llm_provider}")
        dspy_configured = False

        # Try Z.AI GLM if configured as provider (FREE!)
        if llm_provider == "zai" and settings.zai_api_token:
            try:
                # Z.AI uses OpenAI-compatible API format
                # Use LLM_MODEL from settings (default: glm-4.5-flash for speed)
                zai_model = settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash"
                lm = dspy.LM(
                    f"openai/{zai_model}",
                    api_key=settings.zai_api_token,
                    api_base="https://api.z.ai/api/coding/paas/v4",
                )
                dspy.configure(lm=lm)
                logger.info(f"Configured DSPy with Z.AI {zai_model} (FREE)")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with Z.AI: {e}")

        # Try HuggingFace if configured as provider
        if not dspy_configured and llm_provider == "huggingface" and settings.huggingface_api_key:
            try:
                lm = dspy.LM("huggingface/utter-project/EuroLLM-9B-Instruct", api_key=settings.huggingface_api_key)
                dspy.configure(lm=lm)
                logger.info("Configured DSPy with HuggingFace EuroLLM-9B-Instruct")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with HuggingFace: {e}")

        # Try Anthropic if not yet configured (either as primary or fallback)
        # NOTE(review): when llm_provider == "zai" and Z.AI setup failed, Anthropic is
        # skipped by this condition and only OpenAI is tried -- confirm intended.
        if not dspy_configured and (llm_provider == "anthropic" or (llm_provider == "huggingface" and settings.anthropic_api_key)):
            if settings.anthropic_api_key and configure_dspy:
                try:
                    configure_dspy(
                        provider="anthropic",
                        model=settings.default_model,
                        api_key=settings.anthropic_api_key,
                    )
                    dspy_configured = True
                except Exception as e:
                    logger.warning(f"Failed to configure DSPy with Anthropic: {e}")

        # Try OpenAI as final fallback (regardless of llm_provider, if a key is set)
        if not dspy_configured and settings.openai_api_key and configure_dspy:
            try:
                configure_dspy(
                    provider="openai",
                    model="gpt-4o-mini",
                    api_key=settings.openai_api_key,
                )
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with OpenAI: {e}")

        if not dspy_configured:
            logger.warning("No LLM provider configured - DSPy queries will fail")

        # Initialize optimized HeritageRAGPipeline (if DSPy is configured)
        if dspy_configured:
            try:
                from dspy_heritage_rag import HeritageRAGPipeline
                from pathlib import Path

                # Create pipeline with Qdrant retriever
                qdrant_retriever = retriever.qdrant if retriever else None
                dspy_pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)

                # Load optimized model (BootstrapFewShot: 14.3% quality improvement)
                # Note: load() may fail if new modules were added that aren't in the saved state
                optimized_model_path = Path(__file__).parent / "optimized_models" / "heritage_rag_bootstrap_latest.json"
                if optimized_model_path.exists():
                    try:
                        dspy_pipeline.load(str(optimized_model_path))
                        logger.info(f"Loaded optimized DSPy pipeline from {optimized_model_path}")
                    except Exception as load_err:
                        # Pipeline still works, just without optimized demos for new modules
                        logger.warning(f"Could not load optimized model (new modules may need re-optimization): {load_err}")
                        logger.info("Pipeline initialized without optimized demos - will work but may be less accurate")
                else:
                    logger.warning(f"Optimized model not found at {optimized_model_path}, using unoptimized pipeline")
            except Exception as e:
                logger.warning(f"Failed to initialize DSPy pipeline: {e}")
                dspy_pipeline = None

        # === HOT LOADING: Warmup embedding model to avoid cold-start latency ===
        # The sentence-transformers model takes 3-15 seconds to load on first use.
        # By loading it eagerly at startup, we eliminate this delay for users.
        if retriever.qdrant:
            logger.info("Warming up embedding model (this takes 3-15 seconds on first startup)...")
            warmup_start = time.perf_counter()
            try:
                # Trigger model load with a dummy embedding request
                _ = retriever.qdrant._get_embedding("archief warmup query")
                warmup_duration = time.perf_counter() - warmup_start
                logger.info(f"✅ Embedding model warmed up in {warmup_duration:.2f}s - ready for fast queries!")

                # Record warmup metrics
                # NOTE(review): the model label is hard-coded and may not match
                # settings.preferred_embedding_model -- confirm.
                if record_embedding_warmup:
                    record_embedding_warmup(
                        model="sentence-transformers/all-MiniLM-L6-v2",
                        duration_seconds=warmup_duration,
                        success=True,
                    )
            except Exception as e:
                warmup_duration = time.perf_counter() - warmup_start
                logger.warning(f"Failed to warm up embedding model: {e}")
                if record_embedding_warmup:
                    record_embedding_warmup(
                        model="sentence-transformers/all-MiniLM-L6-v2",
                        duration_seconds=warmup_duration,
                        success=False,
                    )

        # === TEMPLATE EMBEDDING WARMUP: Pre-compute embeddings for template patterns ===
        # The TemplateEmbeddingMatcher computes embeddings on first query (~2-5 seconds).
        # By pre-computing at startup, we eliminate this delay for users.
        template_warmup_start = time.perf_counter()
        template_count = 0
        try:
            from template_sparql import get_template_embedding_matcher, TemplateClassifier

            logger.info("Pre-computing template pattern embeddings...")
            classifier = TemplateClassifier()
            templates = classifier._load_templates()

            if templates:
                template_count = len(templates)
                matcher = get_template_embedding_matcher()
                # Falsy return is logged as "model not available" below.
                if matcher._ensure_embeddings_computed(templates):
                    template_warmup_duration = time.perf_counter() - template_warmup_start
                    logger.info(f"✅ Template embeddings pre-computed ({template_count} templates) in {template_warmup_duration:.2f}s")

                    # Record template warmup metrics
                    if record_template_embedding_warmup:
                        record_template_embedding_warmup(
                            duration_seconds=template_warmup_duration,
                            template_count=template_count,
                            success=True,
                        )
                else:
                    logger.warning("Template embedding computation skipped (model not available)")
                    if set_warmup_status:
                        set_warmup_status("template_embeddings", False)
            else:
                logger.warning("No templates found for embedding warmup")
                if set_warmup_status:
                    set_warmup_status("template_embeddings", False)
        except Exception as e:
            template_warmup_duration = time.perf_counter() - template_warmup_start
            logger.warning(f"Failed to pre-compute template embeddings: {e}")
            if record_template_embedding_warmup:
                record_template_embedding_warmup(
                    duration_seconds=template_warmup_duration,
                    template_count=template_count,
                    success=False,
                )

        # === ATOMIC CACHE MANAGER: Sub-task caching for higher hit rates ===
        # Research shows 40-70% cache hit rates with atomic decomposition vs 5-15% for full queries.
        # Initialize AtomicCacheManager with retriever's semantic cache for persistence.
        if ATOMIC_CACHE_AVAILABLE and AtomicCacheManager:
            try:
                semantic_cache = retriever.cache if retriever else None
                atomic_cache_manager = AtomicCacheManager(semantic_cache=semantic_cache)
                logger.info("✅ AtomicCacheManager initialized for sub-task caching")
            except Exception as e:
                logger.warning(f"Failed to initialize AtomicCacheManager: {e}")

        # === ONTOLOGY CACHE WARMUP: Pre-load KG values to avoid cold-start latency ===
        # The OntologyLoader queries the Knowledge Graph for valid slot values (cities, regions, types).
        # These queries can take 1-3 seconds each on first access.
        # By pre-loading at startup, we eliminate this delay for users.
        ontology_warmup_start = time.perf_counter()
        try:
            from template_sparql import get_ontology_loader

            logger.info("Warming up ontology cache (pre-loading KG values)...")
            ontology = get_ontology_loader()
            ontology.load()  # Triggers KG queries for institution_types, subregions, cities, etc.

            ontology_warmup_duration = time.perf_counter() - ontology_warmup_start
            cache_stats = ontology.get_kg_cache_stats()
            logger.info(
                f"✅ Ontology cache warmed up in {ontology_warmup_duration:.2f}s "
                f"({cache_stats['cache_size']} KG queries cached, TTL={cache_stats['ttl_seconds']}s)"
            )
        except Exception as e:
            # Duration is computed but not recorded anywhere on the failure path.
            ontology_warmup_duration = time.perf_counter() - ontology_warmup_start
            logger.warning(f"Failed to warm up ontology cache: {e}")

    logger.info("Heritage RAG API started")

    # Application serves requests while suspended here.
    yield

    # Shutdown
    logger.info("Shutting down Heritage RAG API...")
    if retriever:
        await retriever.close()
    logger.info("Heritage RAG API stopped")
# Create FastAPI app.
# `lifespan` (defined above) runs the startup warmups and closes the shared
# retriever on shutdown.
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)
# CORS middleware: accepts every origin, method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive. Browsers refuse a literal "*" on credentialed responses, so the
# middleware may reflect the caller's Origin instead. Confirm this is intended,
# and consider an explicit origin allow-list for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Prometheus metrics endpoint, mounted under /api/rag.
# Only mounted when METRICS_AVAILABLE is truthy and a router factory is present
# (presumably set by an optional import earlier in the file — verify).
if METRICS_AVAILABLE and create_metrics_endpoint:
    app.include_router(create_metrics_endpoint(), prefix="/api/rag")
# API Endpoints
@app.get("/api/rag/health")
async def health_check() -> dict[str, Any]:
    """Health check for all services.

    Qdrant and TypeDB are probed through the shared retriever, and only when
    they are configured. The SPARQL endpoint is pinged over HTTP.

    Returns:
        Dict with these keys:

        - ``status``: "ok" when no checked service failed, "error" when every
          checked service failed, and "degraded" otherwise.
        - ``timestamp``: an ISO-8601 UTC timestamp.
        - ``services``: a per-service status map.
    """
    health: dict[str, Any] = {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": {},
    }
    services: dict[str, Any] = health["services"]

    # Check Qdrant (only when a vector backend is configured).
    if retriever and retriever.qdrant:
        try:
            stats = retriever.qdrant.get_stats()
            services["qdrant"] = {
                "status": "ok",
                "vectors": stats.get("qdrant", {}).get("vectors_count", 0),
            }
        except Exception as e:
            services["qdrant"] = {"status": "error", "error": str(e)}

    # Check SPARQL by pinging the service root. Only a *trailing* "/query"
    # segment is stripped. str.replace would also mangle hosts that contain
    # "/query", e.g. "http://query-server/query" -> "http:/-server".
    try:
        sparql_base = settings.sparql_endpoint.rstrip("/").removesuffix("/query")
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(sparql_base)
            services["sparql"] = {
                "status": "ok" if response.status_code < 500 else "error"
            }
    except Exception as e:
        services["sparql"] = {"status": "error", "error": str(e)}

    # Check TypeDB (only when configured).
    if retriever and retriever.typedb:
        try:
            stats = retriever.typedb.get_stats()
            services["typedb"] = {
                "status": "ok",
                "entities": stats.get("entities", {}),
            }
        except Exception as e:
            services["typedb"] = {"status": "error", "error": str(e)}

    # Overall status is relative to the number of services actually checked.
    # A hard-coded threshold misreports when Qdrant/TypeDB are not configured.
    errors = sum(
        1 for s in services.values()
        if isinstance(s, dict) and s.get("status") == "error"
    )
    if errors == 0:
        health["status"] = "ok"
    elif errors < len(services):
        health["status"] = "degraded"
    else:
        health["status"] = "error"
    return health
@app.get("/api/rag/stats")
async def get_stats() -> dict[str, Any]:
    """Return per-backend statistics gathered from the shared retriever.

    Backends that are not configured are left out of the ``retrievers`` map.
    If the retriever itself was never initialised, the whole map is left out.
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    per_backend: dict[str, Any] = {}
    if retriever:
        for backend_name in ("qdrant", "typedb"):
            backend = getattr(retriever, backend_name)
            if backend:
                per_backend[backend_name] = backend.get_stats()
    return {
        "timestamp": timestamp,
        "retrievers": per_backend,
    }
@app.get("/api/rag/stats/costs")
async def get_cost_stats() -> dict[str, Any]:
    """Expose the cost tracker's cumulative statistics for the current session.

    The session summary covers:

    - LLM call counts and token usage
    - retrieval operations and their latencies
    - estimated costs per model
    - pipeline timing statistics

    Returns:
        ``{"available": False, "message": ...}`` when the cost tracker module
        cannot be used. Otherwise ``available``, ``timestamp`` and the
        tracker's ``session`` summary.
    """
    tracker_ready = COST_TRACKER_AVAILABLE and get_tracker
    if tracker_ready:
        tracker = get_tracker()
        return {
            "available": True,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "session": tracker.get_session_summary(),
        }
    return {
        "available": False,
        "message": "Cost tracker module not available",
    }
@app.post("/api/rag/stats/costs/reset")
async def reset_cost_stats() -> dict[str, Any]:
    """Clear the accumulated cost statistics and begin a fresh session.

    Handy when costs should be tracked per conversation or per session.

    Returns:
        An ``available``/``message`` pair. When the reset happened, the
        response also includes a ``timestamp``.
    """
    tracker_ready = COST_TRACKER_AVAILABLE and reset_tracker
    if tracker_ready:
        reset_tracker()
        return {
            "available": True,
            "message": "Cost tracking statistics reset",
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    return {
        "available": False,
        "message": "Cost tracker module not available",
    }
@app.get("/api/rag/stats/templates")
async def get_template_stats() -> dict[str, Any]:
    """Report how often template-based SPARQL generation was used.

    Useful for:

    - watching template coverage (the share of queries answered from a template)
    - spotting which templates are used most
    - tuning slot-extraction parameters

    Returns:
        When the metrics module can be used: ``available``, ``hit_rate``, a
        per-template ``breakdown`` and a ``timestamp``. Otherwise an
        ``available``/``message`` pair explaining why not.
    """
    if METRICS_AVAILABLE:
        # Imported lazily so the endpoint degrades gracefully if the import breaks.
        try:
            from metrics import get_template_breakdown, get_template_hit_rate
        except ImportError:
            return {
                "available": False,
                "message": "Metrics module import failed",
            }
        return {
            "available": True,
            "hit_rate": get_template_hit_rate(),
            "breakdown": get_template_breakdown(),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    return {
        "available": False,
        "message": "Metrics module not available",
    }
@app.get("/api/rag/embedding/models")
async def get_embedding_models() -> dict[str, Any]:
    """List available embedding models for the Qdrant collections.

    Reports which embedding models each collection exposes as named vectors,
    so clients can choose a model for their use case.

    Returns:
        Dict with the current settings, a static model catalogue, and
        per-collection availability when multi-embedding mode is enabled.
        If multi-embedding is disabled, ``single_embedding_mode`` and a
        ``note`` are set instead.
    """
    result: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "multi_embedding_enabled": settings.use_multi_embedding,
        "preferred_model": settings.preferred_embedding_model,
        "collections": {},
        "models": {
            "openai_1536": {
                "description": "OpenAI text-embedding-3-small (1536 dimensions)",
                "quality": "high",
                "cost": "paid API",
                "recommended_for": "production, high-quality semantic search",
            },
            "minilm_384": {
                "description": "sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)",
                "quality": "good",
                "cost": "free (local)",
                "recommended_for": "development, cost-sensitive deployments",
            },
            # NOTE(review): the published BAAI/bge-small-en-v1.5 model card lists
            # 384-d, English-only vectors, while this key says 768. Confirm which
            # model actually backs the "bge_768" vector (bge-base-en-v1.5?) and
            # fix the description.
            "bge_768": {
                "description": "BAAI/bge-small-en-v1.5 (768 dimensions)",
                "quality": "very good",
                "cost": "free (local)",
                "recommended_for": "balanced quality/cost, multilingual support",
            },
        },
    }

    if not (retriever and retriever.qdrant):
        return result
    qdrant = retriever.qdrant

    if not getattr(qdrant, "use_multi_embedding", False):
        # Single embedding mode: each collection stores one unnamed vector.
        result["single_embedding_mode"] = True
        result["note"] = "Collections use single embedding vectors. Enable USE_MULTI_EMBEDDING=true to use named vectors."
        return result

    multi = getattr(qdrant, "multi_retriever", None)
    if not multi:
        return result

    # Same report for each collection. A failure in one collection does not
    # hide the other.
    for collection in ("heritage_custodians", "heritage_persons"):
        try:
            available = multi.get_available_models(collection)
            selected = multi.select_model(collection)
            result["collections"][collection] = {
                "available_models": [m.value for m in available],
                "uses_named_vectors": multi.uses_named_vectors(collection),
                "recommended": selected.value if selected else None,
            }
        except Exception as e:
            result["collections"][collection] = {"error": str(e)}
    return result
class EmbeddingCompareRequest(BaseModel):
    """Request for comparing embedding models.

    Body of ``POST /api/rag/embedding/compare``. The same query is searched
    once per available embedding model so the result lists can be compared.
    """
    # Free-text query to search with (required).
    query: str = Field(..., description="Query to search with")
    # Name of the collection to search; defaults to the person collection.
    collection: str = Field(default="heritage_persons", description="Collection to search")
    # Results returned per model; validated to the range 1..20.
    k: int = Field(default=5, ge=1, le=20, description="Number of results per model")
@app.post("/api/rag/embedding/compare")
async def compare_embedding_models(request: EmbeddingCompareRequest) -> dict[str, Any]:
    """Compare search results across different embedding models.

    Runs the same search with each available embedding model, which allows
    A/B testing of embedding quality. Useful for:

    - evaluating which embedding model works best for your queries
    - understanding how semantic similarity differs between models
    - deciding which model to use in production

    Returns:
        Dict with the query, collection, ``k``, elapsed ``query_time_ms`` and
        the ``comparison`` produced by the multi-embedding retriever.

    Raises:
        HTTPException: 503 when Qdrant or the multi-embedding retriever is
            unavailable; 400 when multi-embedding mode is disabled; 500 when
            the comparison itself fails.
    """
    import time

    # perf_counter is monotonic, so the reported duration cannot go negative or
    # jump when the system clock is adjusted (time.time() can).
    start_time = time.perf_counter()

    if not retriever or not retriever.qdrant:
        raise HTTPException(status_code=503, detail="Qdrant retriever not available")
    qdrant = retriever.qdrant

    # Check if multi-embedding is available.
    if not getattr(qdrant, "use_multi_embedding", False):
        raise HTTPException(
            status_code=400,
            detail="Multi-embedding mode not enabled. Set USE_MULTI_EMBEDDING=true to use this endpoint."
        )
    multi = getattr(qdrant, "multi_retriever", None)
    if not multi:
        raise HTTPException(status_code=503, detail="Multi-embedding retriever not initialized")

    try:
        # Use the compare_models method from MultiEmbeddingRetriever.
        comparison = multi.compare_models(
            query=request.query,
            collection=request.collection,
            k=request.k,
        )
    except Exception as e:
        logger.exception(f"Embedding comparison failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    elapsed_ms = (time.perf_counter() - start_time) * 1000
    return {
        "query": request.query,
        "collection": request.collection,
        "k": request.k,
        "query_time_ms": round(elapsed_ms, 2),
        "comparison": comparison,
    }
@app.post("/api/rag/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest) -> QueryResponse:
    """Main RAG query endpoint.

    Orchestrates retrieval from multiple sources, merges the results, and
    optionally generates a visualization configuration. The flow is:

    1. semantic-cache lookup
    2. source routing
    3. geographic filter extraction
    4. multi-source retrieval
    5. merge
    6. optional visualization
    7. cache store

    ``query_time_ms`` always reports the latency of *this* request, including
    on cache hits.

    Raises:
        HTTPException: 503 when the retriever has not been initialized.
    """
    if not retriever:
        raise HTTPException(status_code=503, detail="Retriever not initialized")

    # get_running_loop() is the supported way to reach the loop inside a coroutine.
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    def _elapsed_ms() -> float:
        # Milliseconds since the request started, rounded for the response.
        return round((loop.time() - start_time) * 1000, 2)

    # Check cache first
    cached = await retriever.cache.get(request.question, request.sources)
    if cached:
        # Transform cached data to the QueryResponse schema.
        # The cache stores: answer, sparql_query, sources, confidence, context.
        # QueryResponse needs: question, sparql, results, sources_used,
        # query_time_ms, result_count.
        try:
            # Cached sources may be plain strings or DataSource enums.
            sources_used = []
            for s in cached.get("sources", []):
                if isinstance(s, str):
                    try:
                        sources_used.append(DataSource(s))
                    except ValueError:
                        # Skip invalid source values
                        pass
                elif isinstance(s, DataSource):
                    sources_used.append(s)

            # Prefer top-level results, falling back to the retrieval context.
            results = cached.get("results", [])
            if not results and cached.get("context"):
                results = cached["context"].get("retrieved_results", []) or []

            return QueryResponse(
                question=request.question,
                sparql=cached.get("sparql_query") or cached.get("sparql"),
                results=results,
                visualization=cached.get("visualization"),
                sources_used=sources_used or [DataSource.QDRANT],  # Default if none
                cache_hit=True,
                # Report this request's latency, not the stale latency of the
                # query that originally populated the cache (matches dspy_query).
                query_time_ms=_elapsed_ms(),
                result_count=cached.get("result_count", len(results)),
            )
        except Exception as e:
            logger.warning(f"Failed to transform cached response: {e}, skipping cache")
            # Fall through to normal query processing

    # Route query to appropriate sources
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    logger.info(f"Query intent: {intent}, sources: {sources}")

    # Extract geographic filters from question (province, city, institution type)
    geo_filters = extract_geographic_filters(request.question)
    if any(geo_filters.values()):
        logger.info(f"Geographic filters extracted: {geo_filters}")

    # Retrieve from all sources
    results = await retriever.retrieve(
        request.question,
        sources,
        request.k,
        embedding_model=request.embedding_model,
        region_codes=geo_filters["region_codes"],
        cities=geo_filters["cities"],
        institution_types=geo_filters["institution_types"],
    )

    # Merge results. merge_results also returns provenance, which QueryResponse
    # does not carry (yet).
    merged_items, _retrieval_provenance = retriever.merge_results(results, max_results=request.k * 2)

    # Generate visualization config if requested
    visualization = None
    if request.include_visualization and viz_selector and merged_items:
        # Describe the result schema from the first item, hiding internal "_" fields.
        schema_str = ", ".join(f for f in merged_items[0].keys() if not f.startswith("_"))
        visualization = viz_selector.select(
            request.question,
            schema_str,
            len(merged_items),
        )

    response_data = {
        "question": request.question,
        "sparql": None,  # Could be populated from SPARQL result
        "results": merged_items,
        "visualization": visualization,
        "sources_used": list(sources),
        "cache_hit": False,
        "query_time_ms": _elapsed_ms(),
        "result_count": len(merged_items),
    }

    # Cache the response
    await retriever.cache.set(request.question, request.sources, response_data)
    return QueryResponse(**response_data)  # type: ignore[arg-type]
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
async def generate_sparql_endpoint(request: SPARQLRequest) -> SPARQLResponse:
    """Generate SPARQL query from natural language.

    Uses a TEMPLATE-FIRST approach:

    1. Try template-based SPARQL generation (deterministic, validated).
    2. Fall back to LLM-based generation only if no template matches.

    The template approach provides 65% precision vs 10% for LLM-only
    (Formica et al. 2023).

    Both generators are synchronous, so each runs in a worker thread to keep
    the event loop responsive. ``rag_used`` is True only when the LLM fallback
    ran with ``use_rag`` requested.

    Raises:
        HTTPException: 503 when no template matched and the LLM generator is
            unavailable; 500 when generation fails.
    """
    global _template_pipeline_instance
    template_used = False
    rag_used = False
    sparql_query = ""
    explanation = ""
    try:
        # ===================================================================
        # STEP 1: Try TEMPLATE-BASED SPARQL generation (preferred)
        # ===================================================================
        if TEMPLATE_SPARQL_AVAILABLE and get_template_pipeline:
            try:
                # Get or create singleton pipeline instance
                if _template_pipeline_instance is None:
                    _template_pipeline_instance = get_template_pipeline()
                    logger.info("[SPARQL] Template pipeline initialized for /api/rag/sparql endpoint")
                # Run template matching in thread pool (DSPy is synchronous)
                template_result = await asyncio.to_thread(
                    _template_pipeline_instance,
                    question=request.question,
                    conversation_state=None,
                    language=request.language
                )
                if template_result.matched and template_result.sparql:
                    sparql_query = template_result.sparql
                    template_used = True
                    explanation = (
                        f"Template '{template_result.template_id}' matched with "
                        f"confidence {template_result.confidence:.2f}. "
                        f"Slots: {template_result.slots}. "
                        f"{template_result.reasoning}"
                    )
                    logger.info(f"[SPARQL] Template match: '{template_result.template_id}' "
                                f"(confidence={template_result.confidence:.2f})")
                else:
                    logger.info(f"[SPARQL] No template match: {template_result.reasoning}")
            except Exception as e:
                # Best effort: any template failure falls through to the LLM path.
                logger.warning(f"[SPARQL] Template pipeline failed: {e}")

        # ===================================================================
        # STEP 2: Fall back to LLM-BASED SPARQL generation
        # ===================================================================
        if not template_used:
            if not RETRIEVERS_AVAILABLE:
                raise HTTPException(status_code=503, detail="SPARQL generator not available")
            logger.info("[SPARQL] Falling back to LLM-based SPARQL generation")
            # generate_sparql is a blocking LLM call; keep it off the event loop
            # like the template pipeline above.
            result = await asyncio.to_thread(
                generate_sparql,
                request.question,
                language=request.language,
                context=request.context,
                use_rag=request.use_rag,
            )
            sparql_query = result["sparql"]
            explanation = result.get("explanation", "")
            # RAG context is involved only on the LLM path, and only when requested.
            rag_used = bool(request.use_rag)

        return SPARQLResponse(
            sparql=sparql_query,
            explanation=explanation,
            rag_used=rag_used,
            retrieved_passages=[],
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("SPARQL generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/rag/sparql/execute", response_model=SPARQLExecuteResponse)
async def execute_sparql_query(request: SPARQLExecuteRequest) -> SPARQLExecuteResponse:
    """Run a caller-supplied SPARQL query directly against the knowledge graph.

    Users can execute an edited query and inspect the raw bindings without
    regenerating an answer, which helps with exploration and debugging.
    Failures (non-200 status, timeout, any other exception) are reported
    through the response's ``error`` field, not as HTTP errors.
    """
    import time

    began = time.time()

    def _elapsed_ms() -> float:
        return (time.time() - began) * 1000

    def _failure(message: str) -> SPARQLExecuteResponse:
        # Uniform empty-result response carrying an error message.
        return SPARQLExecuteResponse(
            results=[],
            result_count=0,
            query_time_ms=_elapsed_ms(),
            error=message,
        )

    payload = {"query": request.sparql_query}
    accept = {"Accept": "application/sparql-results+json"}
    try:
        if retriever:
            # Reuse the retriever's pooled client; the timeout is applied per request.
            pooled = await retriever._get_sparql_client()
            response = await pooled.post(
                settings.sparql_endpoint,
                data=payload,
                headers=accept,
                timeout=request.timeout,
            )
        else:
            # No retriever yet: use a short-lived client built with the requested timeout.
            async with httpx.AsyncClient(timeout=request.timeout) as one_off:
                response = await one_off.post(
                    settings.sparql_endpoint,
                    data=payload,
                    headers=accept,
                )

        if response.status_code != 200:
            return _failure(
                f"SPARQL endpoint returned {response.status_code}: {response.text[:500]}"
            )

        bindings = response.json().get("results", {}).get("bindings", [])
        # Flatten each SPARQL JSON binding to {variable: value}.
        rows = [
            {variable: cell.get("value") for variable, cell in binding.items()}
            for binding in bindings
        ]
        return SPARQLExecuteResponse(
            results=rows,
            result_count=len(rows),
            query_time_ms=_elapsed_ms(),
        )
    except httpx.TimeoutException:
        return _failure(f"Query timed out after {request.timeout}s")
    except Exception as e:
        logger.exception("SPARQL execution failed")
        return _failure(str(e))
@app.post("/api/rag/sparql/rerun", response_model=SPARQLRerunResponse)
async def rerun_rag_with_sparql(request: SPARQLRerunRequest) -> SPARQLRerunResponse:
    """Re-run answer generation with modified SPARQL results injected into context.

    This endpoint lets users:

    1. Execute a modified SPARQL query (at most 50 rows are kept).
    2. Format up to 20 of those rows as a text context block.
    3. Generate a new answer from that context with a DSPy ChainOfThought
       module, using the requested (or default "zai") LLM provider.

    SPARQL and LLM failures are logged and degrade gracefully. The response
    then carries fewer (or no) results and/or an explanatory answer string.
    """
    import time

    import dspy

    # Monotonic clock for the reported duration (time.time() can jump).
    start_time = time.perf_counter()

    # Step 1: Execute the modified SPARQL query using the pooled client.
    sparql_results: list[dict[str, Any]] = []
    try:
        if retriever:
            # Use pooled SPARQL client from retriever for better connection reuse
            client = await retriever._get_sparql_client()
            response = await client.post(
                settings.sparql_endpoint,
                data={"query": request.sparql_query},
                headers={"Accept": "application/sparql-results+json"},
            )
        else:
            # Fallback to creating a new client if retriever not initialized
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    settings.sparql_endpoint,
                    data={"query": request.sparql_query},
                    headers={"Accept": "application/sparql-results+json"},
                )
        if response.status_code == 200:
            data = response.json()
            bindings = data.get("results", {}).get("bindings", [])
            sparql_results = [
                {k: v.get("value") for k, v in binding.items()}
                for binding in bindings[:50]  # Limit to 50 results for context
            ]
            logger.info(f"SPARQL rerun: got {len(sparql_results)} results")
        else:
            logger.warning(f"SPARQL rerun: endpoint returned {response.status_code}")
    except Exception as e:
        logger.exception(f"SPARQL rerun: execution failed: {e}")

    # Step 2: Format SPARQL results as context for DSPy (first 20 rows, empty values dropped).
    sparql_context = ""
    if sparql_results:
        context_parts = ["\n[KENNISGRAAF RESULTATEN (aangepaste SPARQL query)]:\n"]
        for i, result in enumerate(sparql_results[:20], 1):
            entry = " | ".join(f"{k}: {v}" for k, v in result.items() if v)
            context_parts.append(f" {i}. {entry}\n")
        sparql_context = "".join(context_parts)

    # Step 3: Run DSPy answer generation with the injected SPARQL context.
    answer = ""
    try:
        lm = None
        # Provider names are matched case-insensitively ("Groq" == "groq").
        provider = (request.llm_provider or "zai").lower()
        model = request.llm_model
        if provider == "zai" and settings.zai_api_token:
            model = model or "glm-4.5-flash"
            lm = dspy.LM(
                f"openai/{model}",
                api_key=settings.zai_api_token,
                api_base="https://api.z.ai/api/coding/paas/v4",
            )
        elif provider == "groq" and settings.groq_api_key:
            model = model or "llama-3.1-8b-instant"
            lm = dspy.LM(f"groq/{model}", api_key=settings.groq_api_key)
        elif provider == "openai" and settings.openai_api_key:
            model = model or "gpt-4o-mini"
            lm = dspy.LM(f"openai/{model}", api_key=settings.openai_api_key)
        elif provider == "anthropic" and settings.anthropic_api_key:
            model = model or "claude-sonnet-4-20250514"
            lm = dspy.LM(f"anthropic/{model}", api_key=settings.anthropic_api_key)

        if lm:
            def _generate_answer() -> str:
                # Blocking DSPy call. The LM override is entered inside the worker
                # thread so that it applies to the thread that actually runs the call.
                with dspy.settings.context(lm=lm):
                    generate_answer = dspy.ChainOfThought(
                        "question, sparql_context, language -> answer"
                    )
                    prediction = generate_answer(
                        question=request.original_question,
                        sparql_context=sparql_context,
                        language=request.language,
                    )
                    return prediction.answer

            # DSPy is synchronous; keep it off the event loop (as the template pipeline does).
            answer = await asyncio.to_thread(_generate_answer)
        else:
            answer = f"LLM niet beschikbaar. SPARQL resultaten: {len(sparql_results)} gevonden."
    except Exception as e:
        logger.exception(f"SPARQL rerun: answer generation failed: {e}")
        answer = f"Fout bij het genereren van antwoord: {str(e)}"

    return SPARQLRerunResponse(
        results=sparql_results,
        answer=answer,
        sparql_result_count=len(sparql_results),
        query_time_ms=(time.perf_counter() - start_time) * 1000,
    )
@app.post("/api/rag/visualize")
async def get_visualization_config(
    question: str = Query(..., description="User's question"),
    schema: str = Query(..., description="Comma-separated field names"),
    result_count: int = Query(default=0, description="Number of results"),
) -> dict[str, Any]:
    """Ask the visualization selector for a chart configuration.

    Raises:
        HTTPException: 503 when no visualization selector is configured.
    """
    if viz_selector:
        return viz_selector.select(question, schema, result_count)  # type: ignore[no-any-return]
    raise HTTPException(status_code=503, detail="Visualization selector not available")
@app.post("/api/rag/typedb/search", response_model=TypeDBSearchResponse)
async def typedb_search(request: TypeDBSearchRequest) -> TypeDBSearchResponse:
    """Search heritage custodians directly in TypeDB.

    Supported ``search_type`` values:

    - semantic: natural-language search (the default for any other value)
    - name: search by institution name
    - type: search by institution type (museum, archive, library, gallery)
    - location: search by city/location name

    Examples:
    - {"query": "museums in Amsterdam", "search_type": "semantic"}
    - {"query": "Rijksmuseum", "search_type": "name"}
    - {"query": "archive", "search_type": "type"}
    - {"query": "Rotterdam", "search_type": "location"}

    Results are de-duplicated by ``name`` (or ``observed_name``). Entries
    without either name are dropped.
    """
    import time

    started = time.time()

    if not retriever or not retriever.typedb:
        raise HTTPException(
            status_code=503,
            detail="TypeDB retriever not available. Ensure TypeDB is running."
        )

    try:
        typedb_retriever = retriever.typedb
        query, k = request.query, request.k

        # Pick the search strategy; anything unrecognised falls back to semantic.
        if request.search_type == "name":
            raw_results = typedb_retriever.search_by_name(query, k=k)
        elif request.search_type == "type":
            raw_results = typedb_retriever.search_by_type(query, k=k)
        elif request.search_type == "location":
            raw_results = typedb_retriever.search_by_location(city=query, k=k)
        else:
            raw_results = typedb_retriever.semantic_search(query, k=k)

        def _as_dict(hit: Any) -> Any:
            # Normalise object, dict and scalar results to a dict.
            if hasattr(hit, "to_dict"):
                return hit.to_dict()
            if isinstance(hit, dict):
                return hit
            return {"name": str(hit)}

        unique_items = []
        seen_labels = set()
        for item in map(_as_dict, raw_results):
            label = item.get("name") or item.get("observed_name", "")
            if not label or label in seen_labels:
                continue
            seen_labels.add(label)
            unique_items.append(item)

        return TypeDBSearchResponse(
            query=request.query,
            search_type=request.search_type,
            results=unique_items,
            result_count=len(unique_items),
            query_time_ms=round((time.time() - started) * 1000, 2),
        )
    except Exception as e:
        logger.exception(f"TypeDB search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/rag/persons/search", response_model=PersonSearchResponse)
async def person_search(request: PersonSearchRequest) -> PersonSearchResponse:
    """Search for staff members of heritage institutions.

    Queries the heritage_persons Qdrant collection through the hybrid
    retriever. The hybrid retriever matches on vector similarity over name,
    role, headline and custodian affiliation.

    Examples:
    - {"query": "Wie werkt er in het Nationaal Archief?"}
    - {"query": "archivist at Rijksmuseum", "k": 20}
    - {"query": "conservator", "filter_custodian": "rijksmuseum"}
    - {"query": "digital preservation", "only_heritage_relevant": true}
    """
    import time

    started = time.time()

    if not retriever:
        raise HTTPException(
            status_code=503,
            detail="Hybrid retriever not available. Ensure Qdrant is running."
        )

    try:
        hits = retriever.search_persons(
            query=request.query,
            k=request.k,
            filter_custodian=request.filter_custodian,
            only_heritage_relevant=request.only_heritage_relevant,
            using=request.embedding_model,
        )

        # Report which embedding model served the search (multi-embedding mode only).
        embedding_model_used = None
        qdrant = retriever.qdrant
        if qdrant and getattr(qdrant, "use_multi_embedding", False):
            if request.embedding_model:
                embedding_model_used = request.embedding_model
            elif getattr(qdrant, "_selected_multi_model", None):
                embedding_model_used = qdrant._selected_multi_model.value

        def _person_to_dict(hit: Any) -> Any:
            # Normalise a search hit: to_dict() if offered, else pick attributes.
            if hasattr(hit, "to_dict"):
                return hit.to_dict()
            if hasattr(hit, "__dict__"):
                record = {"name": getattr(hit, "name", "Unknown")}
                for field in (
                    "headline",
                    "custodian_name",
                    "custodian_slug",
                    "linkedin_url",
                    "heritage_relevant",
                    "heritage_type",
                    "location",
                ):
                    record[field] = getattr(hit, field, None)
                record["score"] = getattr(hit, "combined_score", getattr(hit, "vector_score", None))
                return record
            if isinstance(hit, dict):
                return hit
            return {"name": str(hit)}

        result_dicts = [_person_to_dict(hit) for hit in hits]
        elapsed_ms = (time.time() - started) * 1000

        # Collection stats are best-effort and narrowed to the person collection
        # when a "persons" entry exists.
        # NOTE(review): when no "persons" key is present, the retriever's full
        # stats dict is returned unchanged — confirm that is intended.
        collection_stats = None
        try:
            collection_stats = retriever.get_stats()
            if collection_stats and "persons" in collection_stats:
                collection_stats = {"persons": collection_stats["persons"]}
        except Exception:
            pass

        return PersonSearchResponse(
            context=PERSON_JSONLD_CONTEXT,  # JSON-LD context for linked data
            query=request.query,
            results=result_dicts,
            result_count=len(result_dicts),
            query_time_ms=round(elapsed_ms, 2),
            collection_stats=collection_stats,
            embedding_model_used=embedding_model_used,
        )
    except Exception as e:
        logger.exception(f"Person search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
def _extract_subtask_result(task: Any, pipeline_result: Any, response: Any) -> Any:
    """Extract the relevant result portion for an atomic sub-task.

    Maps sub-task types to the corresponding data in the pipeline result, so
    that individual components can be cached and reused by similar queries.

    Args:
        task: AtomicSubTask with ``task_type`` and a ``parameters`` mapping.
        pipeline_result: Raw result from HeritageRAGPipeline. Its optional
            ``retrieved_results`` attribute is read; entries are accessed with
            ``.get`` (assumed to be dicts — TODO confirm).
        response: DSPyQueryResponse object. Not read by this function; kept
            for call-site compatibility.

    Returns:
        A small summary dict for this sub-task, or None when nothing
        meaningful can be extracted (the caller should not cache it).
    """
    from atomic_decomposer import SubTaskType

    task_type = task.task_type
    params = task.parameters

    # Intent classification - cache the detected intent
    if task_type == SubTaskType.INTENT_CLASSIFICATION:
        return {
            "intent": params.get("intent"),
            "query_type": getattr(pipeline_result, "query_type", None),
        }

    retrieved = getattr(pipeline_result, "retrieved_results", None)
    has_result_list = bool(retrieved) and isinstance(retrieved, list)

    # Type filter - cache institution type filtering results
    if task_type == SubTaskType.TYPE_FILTER:
        inst_type = params.get("institution_type")
        if has_result_list:
            filtered = [r for r in retrieved if r.get("institution_type") == inst_type]
            return {
                "institution_type": inst_type,
                "count": len(filtered),
                "sample_ids": [r.get("id") for r in filtered[:10]],  # Cache IDs not full records
            }

    # Location filter - cache geographic filtering results
    if task_type == SubTaskType.LOCATION_FILTER:
        location = params.get("location")
        if not location:
            # An empty needle is a substring of every string, which would cache
            # a bogus "every result is in this location" count.
            return None
        if has_result_list:
            needle = location.lower()
            # `or ""` prevents a None city/region from becoming the text "None".
            in_location = [
                r for r in retrieved
                if needle in str(r.get("city") or "").lower()
                or needle in str(r.get("region") or "").lower()
            ]
            return {
                "location": location,
                "level": params.get("level"),
                "count": len(in_location),
            }

    # Aggregation - cache aggregate statistics
    if task_type == SubTaskType.AGGREGATION:
        if params.get("aggregation") == "count" and retrieved:
            return {
                "aggregation": "count",
                "value": len(retrieved) if isinstance(retrieved, list) else 0,
            }

    # Identifier filter - cache identifier lookup results
    if task_type == SubTaskType.IDENTIFIER_FILTER:
        id_type = params.get("identifier_type")
        if not id_type:
            # Without an identifier type there is nothing meaningful to count
            # (the lookup key would degenerate to "has_None").
            return None
        if has_result_list:
            has_id = [r for r in retrieved if r.get(f"has_{id_type}") or r.get(id_type)]
            return {
                "identifier_type": id_type,
                "has_identifier_count": len(has_id),
            }

    # Default: don't cache if we can't extract meaningful sub-result
    return None
@app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse)
async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse:
"""DSPy RAG query endpoint with multi-turn conversation support.
Uses the HeritageRAGPipeline for conversation-aware question answering.
Follow-up questions like "Welke daarvan behoren archieven?" will be
resolved using previous conversation context.
Args:
request: Query request with question, language, and conversation context
Returns:
DSPyQueryResponse with answer, resolved question, and optional visualization
"""
import time
start_time = time.time()
# Session management for multi-turn conversations
# Get or create session state that enables follow-up question resolution
session_id = request.session_id
conversation_state = None
session_mgr = None
if SESSION_MANAGER_AVAILABLE and get_session_manager:
try:
session_mgr = await get_session_manager()
session_id, conversation_state = await session_mgr.get_or_create(request.session_id)
logger.debug(f"Session {session_id}: {len(conversation_state.turns)} previous turns")
except Exception as e:
logger.warning(f"Session manager error (continuing without session): {e}")
# Generate a new session_id even if session manager failed
import uuid
session_id = str(uuid.uuid4())
else:
# No session manager - generate session_id for tracking purposes
import uuid
session_id = request.session_id or str(uuid.uuid4())
# Resolve the provider BEFORE cache lookup to ensure consistent cache keys
# This is critical: cache GET and SET must use the same provider value
resolved_provider = (request.llm_provider or settings.llm_provider).lower()
# Check cache first (before expensive LLM configuration) unless skip_cache is True
if retriever and not request.skip_cache:
cached = await retriever.cache.get_dspy(
question=request.question,
language=request.language,
llm_provider=resolved_provider, # Use resolved provider, not request.llm_provider
embedding_model=request.embedding_model,
context=request.context if request.context else None,
)
if cached:
elapsed_ms = (time.time() - start_time) * 1000
logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")
# Transform CachedResponse format back to DSPyQueryResponse format
# CachedResponse has: sources, visualization_type, visualization_data, context
# DSPyQueryResponse needs: sources_used, visualization, query_type, etc.
cached_context = cached.get("context") or {}
visualization = None
if cached.get("visualization_type") or cached.get("visualization_data"):
visualization = {
"type": cached.get("visualization_type"),
"data": cached.get("visualization_data"),
}
# Restore llm_response metadata (GLM 4.7 reasoning_content) from cache
llm_response_cached = cached_context.get("llm_response")
llm_response_obj = None
if llm_response_cached:
try:
llm_response_obj = LLMResponseMetadata(**llm_response_cached)
except Exception:
# Fall back to dict if LLMResponseMetadata fails
llm_response_obj = llm_response_cached # type: ignore[assignment]
# Rule 46: Build provenance for cache hit responses
cached_sources = cached.get("sources", [])
cached_template_used = cached_context.get("template_used", False)
cached_template_id = cached_context.get("template_id")
cached_llm_provider = cached_context.get("llm_provider")
cached_llm_model = cached_context.get("llm_model")
# Infer data tier - prioritize cached provenance if present
cached_provenance = cached_context.get("epistemic_provenance")
if cached_provenance:
# Use the cached provenance, but mark it as coming from cache
cache_provenance = cached_provenance.copy()
if "CACHE" not in cache_provenance.get("derivationChain", []):
cache_provenance.setdefault("derivationChain", []).insert(0, "CACHE:hit")
else:
# Build fresh provenance for older cache entries
cache_tier = DataTier.TIER_3_CROWD_SOURCED.value
if cached_template_used:
cache_tier = DataTier.TIER_1_AUTHORITATIVE.value
elif any(s.lower() in ["sparql", "typedb"] for s in cached_sources):
cache_tier = DataTier.TIER_1_AUTHORITATIVE.value
cache_provenance = EpistemicProvenance(
dataSource=EpistemicDataSource.CACHE_AGGREGATION,
dataTier=cache_tier,
derivationChain=["CACHE:hit"] + build_derivation_chain(
sources_used=cached_sources,
template_used=cached_template_used,
template_id=cached_template_id,
llm_provider=cached_llm_provider,
),
sourcesQueried=cached_sources,
templateUsed=cached_template_used,
templateId=cached_template_id,
llmProvider=cached_llm_provider,
llmModel=cached_llm_model,
).model_dump()
response_data = {
"question": request.question,
"answer": cached.get("answer", ""),
"sources_used": cached_sources,
"visualization": visualization,
"resolved_question": cached_context.get("resolved_question"),
"retrieved_results": cached_context.get("retrieved_results"),
"query_type": cached_context.get("query_type"),
"embedding_model_used": cached_context.get("embedding_model"),
"llm_model_used": cached_llm_model,
"query_time_ms": round(elapsed_ms, 2),
"cache_hit": True,
"llm_response": llm_response_obj, # GLM 4.7 reasoning_content from cache
# Session management - return session_id for follow-up queries
"session_id": session_id,
# Template tracking from cache
"template_used": cached_template_used,
"template_id": cached_template_id,
# Rule 46: Epistemic provenance for transparency
"epistemic_provenance": cache_provenance,
}
# Record cache hit metrics
if METRICS_AVAILABLE and record_query:
try:
record_query(
endpoint="dspy_query",
template_used=cached_context.get("template_used", False),
template_id=cached_context.get("template_id"),
cache_hit=True,
status="success",
duration_seconds=elapsed_ms / 1000,
intent=cached_context.get("query_type"),
)
except Exception as e:
logger.warning(f"Failed to record cache hit metrics: {e}")
return DSPyQueryResponse(**response_data)
# === ATOMIC SUB-TASK CACHING ===
# Full query cache miss - try atomic decomposition for partial cache hits
# Research shows 40-70% cache hit rates with atomic decomposition
decomposed_query = None
cached_subtasks: dict[str, Any] = {}
if ATOMIC_CACHE_AVAILABLE and atomic_cache_manager:
try:
decomposed_query, cached_subtasks = await atomic_cache_manager.process_query(
query=request.question,
language=request.language,
)
# Record atomic cache metrics
subtask_hits = decomposed_query.partial_cache_hits if decomposed_query else 0
subtask_misses = len(decomposed_query.sub_tasks) - subtask_hits if decomposed_query else 0
if decomposed_query.fully_cached:
# All sub-tasks are cached - can potentially skip LLM
logger.info(f"Atomic cache: fully cached ({len(decomposed_query.sub_tasks)} sub-tasks)")
if record_atomic_cache:
record_atomic_cache(
query_hit=True,
subtask_hits=subtask_hits,
subtask_misses=0,
fully_assembled=True,
)
elif decomposed_query.partial_cache_hits > 0:
# Partial cache hit - some sub-tasks cached
logger.info(
f"Atomic cache: partial hit ({decomposed_query.partial_cache_hits}/"
f"{len(decomposed_query.sub_tasks)} sub-tasks cached)"
)
if record_atomic_cache:
record_atomic_cache(
query_hit=False,
subtask_hits=subtask_hits,
subtask_misses=subtask_misses,
fully_assembled=False,
)
else:
logger.debug(f"Atomic cache: miss (0/{len(decomposed_query.sub_tasks)} sub-tasks)")
if record_atomic_cache:
record_atomic_cache(
query_hit=False,
subtask_hits=0,
subtask_misses=subtask_misses,
fully_assembled=False,
)
except Exception as e:
logger.warning(f"Atomic decomposition failed: {e}")
# ==========================================================================
# FACTUAL QUERY FAST PATH: Skip LLM for count/list queries
# ==========================================================================
# For factual queries (counts, lists, comparisons), the SPARQL results ARE
# the answer. No need for expensive LLM prose generation - just return the
# table directly. This can reduce latency from ~15s to ~2s.
# ==========================================================================
try:
from template_sparql import get_template_pipeline
template_pipeline = get_template_pipeline()
# Try template matching (this handles follow-up resolution internally)
# Note: conversation_state already contains history from request.context
# Run in thread pool to avoid blocking the event loop (DSPy is synchronous)
template_result = await asyncio.to_thread(
template_pipeline,
question=request.question,
language=request.language,
conversation_state=conversation_state,
)
# Check if this is a factual query that can skip LLM (template-driven, not hardcoded)
# Fast path rule: If "prose" is NOT in response_modes, LLM generation is skipped
if template_result.matched and not template_result.requires_llm():
# Log database routing decision
databases_used = template_result.databases if hasattr(template_result, 'databases') else ["oxigraph", "qdrant"]
qdrant_skipped = "qdrant" not in databases_used
logger.info(
f"[FAST-PATH] Template '{template_result.template_id}' uses response_modes={template_result.response_modes}, "
f"databases={databases_used} - skipping LLM generation{', Qdrant skipped' if qdrant_skipped else ''} "
f"(confidence={template_result.confidence:.2f})"
)
# Execute SPARQL directly
sparql_query = template_result.sparql
sparql_results: list[dict[str, Any]] = []
sparql_error: str | None = None
try:
if retriever:
client = await retriever._get_sparql_client()
response = await client.post(
settings.sparql_endpoint,
data={"query": sparql_query},
headers={"Accept": "application/sparql-results+json"},
timeout=30.0,
)
if response.status_code == 200:
data = response.json()
bindings = data.get("results", {}).get("bindings", [])
raw_results = [
{k: v.get("value") for k, v in binding.items()}
for binding in bindings
]
# Check if this is a COUNT query (raw_results has 'count' key)
# COUNT queries return [{"count": "10"}] - don't transform these
is_count_query = raw_results and "count" in raw_results[0]
if is_count_query:
# For COUNT queries, preserve raw results with count value
# Convert count string to int for template rendering
sparql_results = []
for row in raw_results:
count_val = row.get("count", "0")
try:
count_int = int(count_val)
except (ValueError, TypeError):
count_int = 0
sparql_results.append({
"count": count_int,
"metadata": {
"institution_type": template_result.slots.get("institution_type"),
},
"scores": {"combined": 1.0},
})
logger.debug(f"[FACTUAL-QUERY] COUNT query result: {sparql_results[0].get('count') if sparql_results else 0}")
# Execute companion query if available to get entity results for map/list
# This fetches the actual institution records that were counted
companion_query = getattr(template_result, 'companion_query', None)
if companion_query:
try:
companion_response = await client.post(
settings.sparql_endpoint,
data={"query": companion_query},
headers={"Accept": "application/sparql-results+json"},
timeout=30.0,
)
if companion_response.status_code == 200:
companion_data = companion_response.json()
companion_bindings = companion_data.get("results", {}).get("bindings", [])
companion_raw = [
{k: v.get("value") for k, v in binding.items()}
for binding in companion_bindings
]
# Transform companion results to frontend format
companion_results = []
for row in companion_raw:
lat = None
lon = None
if row.get("lat"):
try:
lat = float(row["lat"])
except (ValueError, TypeError):
pass
if row.get("lon"):
try:
lon = float(row["lon"])
except (ValueError, TypeError):
pass
companion_results.append({
"name": row.get("name"),
"institution_uri": row.get("institution"),
"metadata": {
"latitude": lat,
"longitude": lon,
"city": row.get("city") or template_result.slots.get("city"),
"institution_type": template_result.slots.get("institution_type"),
},
"scores": {"combined": 1.0},
})
# Store companion results - these will be used for map/list display
# while sparql_results contains the count for the answer text
if companion_results:
logger.info(f"[COMPANION-QUERY] Fetched {len(companion_results)} entities for display, {sum(1 for r in companion_results if r['metadata'].get('latitude'))} with coordinates")
# Replace sparql_results with companion results for display
# but preserve the count value for answer rendering
count_value = sparql_results[0].get("count", 0) if sparql_results else 0
sparql_results = companion_results
# Add count to first result so it's available for ui_template
if sparql_results:
sparql_results[0]["count"] = count_value
else:
logger.warning(f"[COMPANION-QUERY] Failed with status {companion_response.status_code}")
except Exception as ce:
logger.warning(f"[COMPANION-QUERY] Execution failed: {ce}")
else:
# Transform SPARQL results to match frontend expected format
# Frontend expects: {name, website, metadata: {latitude, longitude, city, ...}}
# SPARQL returns: {name, website, lat, lon, city, ...}
sparql_results = []
for row in raw_results:
# Parse lat/lon to float if present
lat = None
lon = None
if row.get("lat"):
try:
lat = float(row["lat"])
except (ValueError, TypeError):
pass
if row.get("lon"):
try:
lon = float(row["lon"])
except (ValueError, TypeError):
pass
transformed = {
"name": row.get("name"),
"website": row.get("website"),
"metadata": {
"latitude": lat,
"longitude": lon,
"city": row.get("city") or template_result.slots.get("city"),
"country": row.get("country") or template_result.slots.get("country"),
"region": row.get("region") or template_result.slots.get("region"),
"institution_type": row.get("type") or template_result.slots.get("institution_type"),
},
"scores": {"combined": 1.0}, # SPARQL results are exact matches
}
sparql_results.append(transformed)
logger.debug(f"[FACTUAL-QUERY] Transformed {len(sparql_results)} results, {sum(1 for r in sparql_results if r['metadata']['latitude'])} with coordinates")
else:
sparql_error = f"SPARQL returned {response.status_code}"
else:
sparql_error = "Retriever not available"
except Exception as e:
sparql_error = str(e)
logger.warning(f"[FACTUAL-QUERY] SPARQL execution failed: {e}")
elapsed_ms = (time.time() - start_time) * 1000
# Generate answer using ui_template if available, otherwise fallback
if sparql_error:
answer = f"Er is een fout opgetreden bij het uitvoeren van de query: {sparql_error}"
elif not sparql_results:
answer = "Geen resultaten gevonden."
elif template_result.ui_template:
# Use template-defined UI template (template-driven answer formatting)
lang = request.language if request.language in template_result.ui_template else "nl"
ui_tmpl = template_result.ui_template.get(lang, template_result.ui_template.get("nl", ""))
# Build context for Jinja2 template rendering with human-readable labels
# The slots have resolved codes (M, NL-NH) but ui_template expects labels (musea, Noord-Holland)
template_context = {
"result_count": len(sparql_results),
"count": sparql_results[0].get("count", len(sparql_results)) if sparql_results else 0,
**template_result.slots # Include resolved slot values (codes)
}
# Add human-readable labels for common slot types
# Labels loaded from schema/reference files per Rule 41 (no hardcoding)
try:
from schema_labels import get_label_resolver
label_resolver = get_label_resolver()
INSTITUTION_TYPE_LABELS_NL = label_resolver.get_all_institution_type_labels("nl")
INSTITUTION_TYPE_LABELS_EN = label_resolver.get_all_institution_type_labels("en")
SUBREGION_LABELS = label_resolver.get_all_subregion_labels("nl")
except ImportError:
# Fallback if schema_labels module not available (shouldn't happen in prod)
logger.warning("schema_labels module not available, using inline fallback")
INSTITUTION_TYPE_LABELS_NL = {
"M": "musea", "L": "bibliotheken", "A": "archieven", "G": "galerijen",
"O": "overheidsinstellingen", "R": "onderzoekscentra", "C": "bedrijfsarchieven",
"U": "instellingen", "B": "botanische tuinen en dierentuinen",
"E": "onderwijsinstellingen", "S": "heemkundige kringen", "F": "monumenten",
"I": "immaterieel erfgoedgroepen", "X": "gecombineerde instellingen",
"P": "privéverzamelingen", "H": "religieuze erfgoedsites",
"D": "digitale platforms", "N": "erfgoedorganisaties", "T": "culinair erfgoed"
}
INSTITUTION_TYPE_LABELS_EN = {
"M": "museums", "L": "libraries", "A": "archives", "G": "galleries",
"O": "official institutions", "R": "research centers", "C": "corporate archives",
"U": "institutions", "B": "botanical gardens and zoos",
"E": "education providers", "S": "heritage societies", "F": "features",
"I": "intangible heritage groups", "X": "mixed institutions",
"P": "personal collections", "H": "holy sites",
"D": "digital platforms", "N": "heritage NGOs", "T": "taste/smell heritage"
}
SUBREGION_LABELS = {
"NL-DR": "Drenthe", "NL-FR": "Friesland", "NL-GE": "Gelderland",
"NL-GR": "Groningen", "NL-LI": "Limburg", "NL-NB": "Noord-Brabant",
"NL-NH": "Noord-Holland", "NL-OV": "Overijssel", "NL-UT": "Utrecht",
"NL-ZE": "Zeeland", "NL-ZH": "Zuid-Holland", "NL-FL": "Flevoland"
}
# Add institution_type_nl and institution_type_en labels
if "institution_type" in template_result.slots:
type_code = template_result.slots["institution_type"]
template_context["institution_type_nl"] = INSTITUTION_TYPE_LABELS_NL.get(type_code, type_code)
template_context["institution_type_en"] = INSTITUTION_TYPE_LABELS_EN.get(type_code, type_code)
# Add human-readable location label
if "location" in template_result.slots:
loc_code = template_result.slots["location"]
# Check if it's a subregion code
if loc_code in SUBREGION_LABELS:
template_context["location"] = SUBREGION_LABELS[loc_code]
# Otherwise keep the original (might already be a city name)
# Simple Jinja2-style replacement (avoids importing Jinja2)
answer = ui_tmpl
for key, value in template_context.items():
answer = answer.replace("{{ " + key + " }}", str(value))
answer = answer.replace("{{" + key + "}}", str(value))
elif "count" in template_result.response_modes:
# Count query - format as count
count_value = sparql_results[0].get("count", len(sparql_results))
answer = f"Aantal: {count_value}"
else:
# List/table query - just indicate result count
answer = f"Gevonden: {len(sparql_results)} resultaten. Zie de tabel hieronder."
# Determine visualization type from response_modes
viz_types = []
if "table" in template_result.response_modes:
viz_types.append("table")
if "chart" in template_result.response_modes:
viz_types.append("chart")
if "map" in template_result.response_modes:
viz_types.append("map")
# Build response with factual_result=True
factual_response = DSPyQueryResponse(
question=request.question,
resolved_question=getattr(template_result, "resolved_question", None),
answer=answer,
sources_used=["SPARQL Knowledge Graph"],
visualization={
"types": viz_types,
"primary_type": viz_types[0] if viz_types else "table",
"sparql_query": sparql_query,
"response_modes": template_result.response_modes,
"databases_used": databases_used, # For transparency/debugging
},
retrieved_results=sparql_results,
query_type="factual",
query_time_ms=round(elapsed_ms, 2),
conversation_turn=len(request.context),
cache_hit=False,
session_id=session_id,
template_used=True,
template_id=template_result.template_id,
factual_result=True,
sparql_query=sparql_query,
)
# Update session with this turn
if session_mgr and session_id:
try:
await session_mgr.add_turn_to_session(
session_id=session_id,
question=request.question,
answer=answer,
resolved_question=getattr(template_result, "resolved_question", None),
template_id=template_result.template_id,
slots=template_result.slots or {},
)
except Exception as e:
logger.warning(f"Failed to update session: {e}")
# Record metrics
if METRICS_AVAILABLE and record_query:
try:
record_query(
endpoint="dspy_query",
template_used=True,
template_id=template_result.template_id,
cache_hit=False,
status="success",
duration_seconds=elapsed_ms / 1000,
intent="factual",
)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
# Cache the response
if retriever:
await retriever.cache.set_dspy(
question=request.question,
language=request.language,
llm_provider="none", # No LLM used
embedding_model=request.embedding_model,
response=factual_response.model_dump(),
context=request.context if request.context else None,
)
logger.info(f"[FACTUAL-QUERY] Returned {len(sparql_results)} results in {elapsed_ms:.2f}ms (LLM skipped)")
return factual_response
except ImportError as e:
logger.debug(f"Template SPARQL not available for factual query detection: {e}")
except Exception as e:
logger.warning(f"Factual query detection failed (continuing with full pipeline): {e}")
# ==========================================================================
# FULL RAG PIPELINE: For non-factual queries or when factual detection fails
# ==========================================================================
try:
# Import DSPy pipeline and History
import dspy
from dspy import History
from dspy_heritage_rag import HeritageRAGPipeline
# Configure DSPy LM per-request based on request.llm_provider (or server default)
# This allows frontend to switch LLM providers dynamically
#
# IMPORTANT: We use dspy.settings.context() instead of dspy.configure() because
# configure() can only be called from the same async task that initially configured DSPy.
# context() provides thread-local overrides that work correctly in async request handlers.
requested_provider = resolved_provider # Already resolved above
llm_provider_used: str | None = None
llm_model_used: str | None = None
lm = None
logger.info(f"LLM provider requested: {requested_provider} (request.llm_provider={request.llm_provider}, server default={settings.llm_provider})")
# Provider configuration priority: requested provider first, then fallback chain
providers_to_try = [requested_provider]
# Add fallback chain (but not duplicates)
for fallback in ["zai", "groq", "anthropic", "openai"]:
if fallback not in providers_to_try:
providers_to_try.append(fallback)
for provider in providers_to_try:
if lm is not None:
break
# Default models per provider (used if request.llm_model is not specified)
# Use LLM_MODEL from settings when it matches the provider prefix
default_models = {
"zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash",
"groq": "llama-3.1-8b-instant",
"anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514",
"openai": "gpt-4o-mini",
# Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference
# Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2
"huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct",
}
# HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct)
# Groq models use simple names (e.g., llama-3.1-8b-instant)
model_prefixes = {
"glm-": "zai",
"llama-3.1-": "groq",
"llama-3.3-": "groq",
"claude-": "anthropic",
"gpt-": "openai",
# HuggingFace organization prefixes
"mistralai/": "huggingface",
"google/": "huggingface",
"Qwen/": "huggingface",
"deepseek-ai/": "huggingface",
"meta-llama/": "huggingface",
"utter-project/": "huggingface",
"microsoft/": "huggingface",
"tiiuae/": "huggingface",
}
# Determine which model to use: requested model (if valid for this provider) or default
requested_model = request.llm_model
model_to_use = default_models.get(provider, "")
# Check if requested model matches this provider
if requested_model:
for prefix, model_provider in model_prefixes.items():
if requested_model.startswith(prefix) and model_provider == provider:
model_to_use = requested_model
break
if provider == "zai" and settings.zai_api_token:
try:
lm = dspy.LM(
f"openai/{model_to_use}",
api_key=settings.zai_api_token,
api_base="https://api.z.ai/api/coding/paas/v4",
)
llm_provider_used = "zai"
llm_model_used = model_to_use
logger.info(f"Using Z.AI {model_to_use} (FREE) for this request")
except Exception as e:
logger.warning(f"Failed to create Z.AI LM: {e}")
elif provider == "groq" and settings.groq_api_key:
try:
lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
llm_provider_used = "groq"
llm_model_used = model_to_use
logger.info(f"Using Groq {model_to_use} (FREE) for this request")
except Exception as e:
logger.warning(f"Failed to create Groq LM: {e}")
elif provider == "huggingface" and settings.huggingface_api_key:
try:
lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
llm_provider_used = "huggingface"
llm_model_used = model_to_use
logger.info(f"Using HuggingFace {model_to_use} for this request")
except Exception as e:
logger.warning(f"Failed to create HuggingFace LM: {e}")
elif provider == "anthropic" and settings.anthropic_api_key:
try:
lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
llm_provider_used = "anthropic"
llm_model_used = model_to_use
logger.info(f"Using Anthropic {model_to_use} for this request")
except Exception as e:
logger.warning(f"Failed to create Anthropic LM: {e}")
elif provider == "openai" and settings.openai_api_key:
try:
lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
llm_provider_used = "openai"
llm_model_used = model_to_use
logger.info(f"Using OpenAI {model_to_use} for this request")
except Exception as e:
logger.warning(f"Failed to create OpenAI LM: {e}")
# No LM could be configured
if lm is None:
raise ValueError(
f"No LLM could be configured. Requested provider: {requested_provider}. "
"Ensure the appropriate API key is set: ZAI_API_TOKEN, GROQ_API_KEY, ANTHROPIC_API_KEY, HUGGINGFACE_API_KEY, or OPENAI_API_KEY."
)
logger.info(f"LLM provider for this request: {llm_provider_used}")
# =================================================================
# PERFORMANCE OPTIMIZATION: Create fast LM for routing/extraction
# Use a fast, cheap model (glm-4.5-flash FREE, gpt-4o-mini $0.15/1M)
# for routing, entity extraction, and SPARQL generation.
# The quality_lm (lm) is used only for final answer generation.
# This can reduce total latency by 2-3x (from ~20s to ~7s).
# =================================================================
fast_lm = None
# Try to create fast_lm based on FAST_LM_PROVIDER setting
# Options: "openai" (fast ~1-2s, $0.15/1M) or "zai" (FREE but slow ~13s)
# Default: openai for speed. Override with FAST_LM_PROVIDER=zai to save costs.
if settings.fast_lm_provider == "openai" and settings.openai_api_key:
try:
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
logger.info("Using OpenAI GPT-4o-mini as fast_lm for routing/extraction (~1-2s)")
except Exception as e:
logger.warning(f"Failed to create fast OpenAI LM: {e}")
if fast_lm is None and settings.fast_lm_provider == "zai" and settings.zai_api_token:
try:
fast_lm = dspy.LM(
"openai/glm-4.5-flash",
api_key=settings.zai_api_token,
api_base="https://api.z.ai/api/coding/paas/v4",
)
logger.info("Using Z.AI GLM-4.5-flash (FREE) as fast_lm for routing/extraction (~13s)")
except Exception as e:
logger.warning(f"Failed to create fast Z.AI LM: {e}")
# Fallback: try the other provider if preferred one failed
if fast_lm is None and settings.openai_api_key:
try:
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
logger.info("Fallback: Using OpenAI GPT-4o-mini as fast_lm")
except Exception as e:
logger.warning(f"Fallback failed - no fast_lm available: {e}")
if fast_lm is None:
logger.info("No fast_lm available - all stages will use quality_lm (slower but works)")
# Convert context to DSPy History format
# Context comes as [{question: "...", answer: "..."}, ...]
# History expects messages in the same format: [{question: "...", answer: "..."}, ...]
# (NOT role/content format - that was a bug!)
history_messages = []
for turn in request.context:
# Only include turns that have both question AND answer
if turn.get("question") and turn.get("answer"):
history_messages.append({
"question": turn["question"],
"answer": turn["answer"]
})
history = History(messages=history_messages) if history_messages else None
# Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
# Falls back to creating a new pipeline if global not available
if dspy_pipeline is not None:
pipeline = dspy_pipeline
logger.debug("Using global optimized DSPy pipeline")
else:
# Fallback: create pipeline without optimized weights
qdrant_retriever = retriever.qdrant if retriever else None
pipeline = HeritageRAGPipeline(
retriever=qdrant_retriever,
fast_lm=fast_lm,
quality_lm=lm,
)
logger.debug("Using fallback (unoptimized) DSPy pipeline")
# Execute query with conversation history
# Retry logic for transient API errors (e.g., Anthropic "Overloaded" errors)
#
# IMPORTANT: We use dspy.settings.context(lm=lm) to set the LLM for this request.
# This provides thread-local overrides that work correctly in async request handlers,
# unlike dspy.configure() which can only be called from the main async task.
max_retries = 3
last_error: Exception | None = None
result = None
# Helper function to run pipeline synchronously (for asyncio.to_thread)
def run_pipeline_sync():
"""Run DSPy pipeline in sync context with retry logic."""
nonlocal last_error, result
with dspy.settings.context(lm=lm):
for attempt in range(max_retries):
try:
# Use pipeline() instead of pipeline.forward() per DSPy 3.0 best practices
return pipeline(
embedding_model=request.embedding_model,
question=request.question,
language=request.language,
history=history,
include_viz=request.include_visualization,
conversation_state=conversation_state, # Pass session state for template SPARQL
)
except Exception as e:
last_error = e
error_str = str(e).lower()
# Check for retryable errors (API overload, rate limits, temporary failures)
is_retryable = any(keyword in error_str for keyword in [
"overloaded", "rate_limit", "rate limit", "too many requests",
"529", "503", "502", "504", # HTTP status codes
"temporarily unavailable", "service unavailable",
"connection reset", "connection refused", "timeout"
])
if is_retryable and attempt < max_retries - 1:
wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
logger.warning(
f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}. "
f"Retrying in {wait_time}s..."
)
time.sleep(wait_time) # OK to block in thread pool
continue
else:
# Non-retryable error or max retries reached
raise
return None
# Run DSPy pipeline in thread pool to avoid blocking the event loop
result = await asyncio.to_thread(run_pipeline_sync)
# If we get here without a result (all retries exhausted), raise the last error
if result is None:
if last_error:
raise last_error
raise HTTPException(status_code=500, detail="Pipeline execution failed with no result")
elapsed_ms = (time.time() - start_time) * 1000
# Extract retrieved results for frontend visualization (tables, graphs)
retrieved_results = getattr(result, "retrieved_results", None)
query_type = getattr(result, "query_type", None)
# Extract visualization if present
visualization = None
if request.include_visualization and hasattr(result, "visualization"):
viz = result.visualization
if viz:
# Now showing SPARQL for all query types including person queries
# Person queries use HeritagePersonSPARQLGenerator (schema:Person predicates)
# Institution queries use HeritageSPARQLGenerator (crm:E39_Actor predicates)
sparql_to_show = getattr(result, "sparql", None)
visualization = {
"type": getattr(viz, "viz_type", "table"),
"sparql_query": sparql_to_show,
}
# Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support)
llm_response_metadata = extract_llm_response_metadata(
lm=lm,
provider=llm_provider_used,
latency_ms=int(elapsed_ms),
)
# Extract template SPARQL info from result
template_used = getattr(result, "template_used", False)
template_id = getattr(result, "template_id", None)
# Rule 46: Build epistemic provenance for transparency
# This tracks WHERE, WHEN, and HOW the response data originated
sources_used_list = getattr(result, "sources_used", [])
# Infer data tier from sources - SPARQL/TypeDB are authoritative, Qdrant may include scraped data
inferred_tier = DataTier.TIER_3_CROWD_SOURCED.value # Default
if template_used:
# Template-based SPARQL uses curated Oxigraph data
inferred_tier = DataTier.TIER_1_AUTHORITATIVE.value
elif any(s.lower() in ["sparql", "typedb"] for s in sources_used_list):
inferred_tier = DataTier.TIER_1_AUTHORITATIVE.value
elif any(s.lower() == "qdrant" for s in sources_used_list):
inferred_tier = DataTier.TIER_3_CROWD_SOURCED.value
# Build provenance object
response_provenance = EpistemicProvenance(
dataSource=EpistemicDataSource.RAG_PIPELINE,
dataTier=inferred_tier,
derivationChain=build_derivation_chain(
sources_used=sources_used_list,
template_used=template_used,
template_id=template_id,
llm_provider=llm_provider_used,
),
sourcesQueried=sources_used_list,
totalRetrieved=len(retrieved_results) if retrieved_results else 0,
totalAfterFusion=len(retrieved_results) if retrieved_results else 0,
templateUsed=template_used,
templateId=template_id,
llmProvider=llm_provider_used,
llmModel=llm_model_used,
)
# Build response object
response = DSPyQueryResponse(
question=request.question,
resolved_question=getattr(result, "resolved_question", None),
answer=getattr(result, "answer", "Geen antwoord gevonden."),
sources_used=sources_used_list,
visualization=visualization,
retrieved_results=retrieved_results, # Raw data for frontend visualization
query_type=query_type, # "person" or "institution"
query_time_ms=round(elapsed_ms, 2),
conversation_turn=len(request.context),
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
# Cost tracking fields
timing_ms=getattr(result, "timing_ms", None),
cost_usd=getattr(result, "cost_usd", None),
timing_breakdown=getattr(result, "timing_breakdown", None),
# LLM provider tracking
llm_provider_used=llm_provider_used,
llm_model_used=llm_model_used,
cache_hit=False,
# LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought)
llm_response=llm_response_metadata,
# Session management - return session_id for follow-up queries
session_id=session_id,
# Template SPARQL tracking
template_used=template_used,
template_id=template_id,
# Rule 46: Epistemic provenance for transparency
epistemic_provenance=response_provenance.model_dump(),
)
# Update session with this turn for multi-turn conversation support
if session_mgr and session_id:
try:
await session_mgr.add_turn_to_session(
session_id=session_id,
question=request.question,
answer=response.answer,
resolved_question=response.resolved_question,
template_id=template_id,
slots=getattr(result, "slots", {}), # Extracted slots for follow-up inheritance
)
logger.debug(f"Session {session_id} updated with new turn")
except Exception as e:
logger.warning(f"Failed to update session {session_id}: {e}")
# Record Prometheus metrics for monitoring
if METRICS_AVAILABLE and record_query:
try:
record_query(
endpoint="dspy_query",
template_used=template_used,
template_id=template_id,
cache_hit=False,
status="success",
duration_seconds=elapsed_ms / 1000,
intent=query_type,
)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
# Cache the successful response for future requests
if retriever:
await retriever.cache.set_dspy(
question=request.question,
language=request.language,
llm_provider=llm_provider_used, # Use actual provider, not requested
embedding_model=request.embedding_model,
response=response.model_dump(),
context=request.context if request.context else None,
)
# === CACHE ATOMIC SUB-TASKS FOR FUTURE QUERIES ===
# Cache individual sub-tasks for higher hit rates on similar queries
# E.g., "musea in Amsterdam" sub-task can be reused for
# "Hoeveel musea in Amsterdam hebben een website?"
if ATOMIC_CACHE_AVAILABLE and atomic_cache_manager and decomposed_query:
try:
subtasks_cached = 0
for task in decomposed_query.sub_tasks:
if not task.cache_hit:
# Extract relevant result for this sub-task type
subtask_result = _extract_subtask_result(task, result, response)
if subtask_result is not None:
await atomic_cache_manager.cache_subtask_result(
task=task,
result=subtask_result,
language=request.language,
ttl=3600, # 1 hour TTL
)
subtasks_cached += 1
# Record subtasks cached metric
if subtasks_cached > 0 and record_atomic_subtask_cached:
record_atomic_subtask_cached(subtasks_cached)
# Log atomic cache stats periodically
stats = atomic_cache_manager.get_stats()
if stats["queries_decomposed"] % 10 == 0:
logger.info(
f"Atomic cache stats: {stats['subtask_hit_rate']}% hit rate, "
f"{stats['queries_decomposed']} queries, "
f"{stats['full_query_reassemblies']} fully cached"
)
except Exception as e:
logger.warning(f"Failed to cache atomic sub-tasks: {e}")
return response
except ImportError as e:
logger.warning(f"DSPy pipeline not available: {e}")
# Fallback to simple response
return DSPyQueryResponse(
question=request.question,
answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.",
query_time_ms=0,
conversation_turn=len(request.context),
embedding_model_used=request.embedding_model,
session_id=session_id, # Still return session_id even on error
)
except Exception as e:
logger.exception("DSPy query failed")
raise HTTPException(status_code=500, detail=str(e))
async def stream_dspy_query_response(
    request: DSPyQueryRequest,
) -> AsyncIterator[str]:
    """Stream DSPy query response with progress updates for long-running queries.

    Yields NDJSON lines with status updates at each pipeline stage:
    - {"type": "status", "stage": "cache", "message": "🔍 Cache controleren..."}
    - {"type": "status", "stage": "config", "message": "⚙️ LLM configureren..."}
    - {"type": "status", "stage": "routing", "message": "🧭 Vraag analyseren..."}
    - {"type": "status", "stage": "retrieval", "message": "📊 Database doorzoeken..."}
    - {"type": "status", "stage": "generation", "message": "💡 Antwoord genereren..."}
    - {"type": "token", "content": "..."}  (only when the pipeline streams tokens)
    - {"type": "complete", "data": {...DSPyQueryResponse...}}

    On failure a single {"type": "error", "error": ..., "details": ...} line is
    emitted and the stream ends.

    Args:
        request: The DSPy query request (question, language, LLM provider/model
            selection, conversation context and optional session id).

    Yields:
        Newline-terminated JSON strings (NDJSON).
    """
    import time
    import uuid

    start_time = time.time()

    # Session management for multi-turn conversations: get or create the
    # session state that enables follow-up question resolution.
    session_id = request.session_id
    conversation_state = None
    session_mgr = None
    if SESSION_MANAGER_AVAILABLE and get_session_manager:
        try:
            session_mgr = await get_session_manager()
            session_id, conversation_state = await session_mgr.get_or_create(request.session_id)
            logger.debug(f"Stream session {session_id}: {len(conversation_state.turns)} previous turns")
        except Exception as e:
            logger.warning(f"Session manager error (continuing without session): {e}")
            session_id = str(uuid.uuid4())
    else:
        session_id = request.session_id or str(uuid.uuid4())

    def emit_status(stage: str, message: str) -> str:
        """Serialize a progress status event as one NDJSON line."""
        return json.dumps({
            "type": "status",
            "stage": stage,
            "message": message,
            "elapsed_ms": round((time.time() - start_time) * 1000, 2),
        }) + "\n"

    def emit_error(error: str, details: str | None = None) -> str:
        """Serialize an error event as one NDJSON line."""
        return json.dumps({
            "type": "error",
            "error": error,
            "details": details,
            "elapsed_ms": round((time.time() - start_time) * 1000, 2),
        }) + "\n"

    def extract_user_friendly_error(exception: Exception) -> tuple[str, str | None]:
        """Map an exception to a Dutch user-facing message plus technical details.

        Matching is a case-insensitive substring heuristic on ``str(exception)``.

        Returns:
            tuple: (user_message, technical_details). ``technical_details`` is the
            raw error string, or ``None`` when an unrecognized error is short
            enough (<= 200 chars) to be shown in full in the user message.
        """
        error_str = str(exception)
        error_lower = error_str.lower()

        # HuggingFace / LiteLLM specific errors
        if "huggingface" in error_lower or "hf" in error_lower:
            if "model_not_supported" in error_lower or "not a chat model" in error_lower:
                # Extract model name if present
                import re
                model_match = re.search(r"model['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", error_str)
                model_name = model_match.group(1) if model_match else "geselecteerde model"
                return (
                    f"Het model '{model_name}' wordt niet ondersteund door HuggingFace. Kies een ander model.",
                    error_str
                )
            if "rate limit" in error_lower or "too many requests" in error_lower:
                return (
                    "HuggingFace API limiet bereikt. Probeer het over een minuut opnieuw.",
                    error_str
                )
            if "unauthorized" in error_lower or "invalid api key" in error_lower:
                return (
                    "HuggingFace API sleutel ongeldig. Neem contact op met de beheerder.",
                    error_str
                )
            # Parenthesized to make the (unchanged) and/or precedence explicit.
            if "model is loading" in error_lower or ("loading" in error_lower and "model" in error_lower):
                return (
                    "Het HuggingFace model wordt geladen. Probeer het over 30 seconden opnieuw.",
                    error_str
                )

        # Anthropic errors
        if "anthropic" in error_lower:
            if "rate limit" in error_lower or "overloaded" in error_lower:
                return (
                    "Anthropic API is overbelast. Probeer het over een minuut opnieuw.",
                    error_str
                )
            if "invalid api key" in error_lower or "unauthorized" in error_lower:
                return (
                    "Anthropic API sleutel ongeldig. Neem contact op met de beheerder.",
                    error_str
                )

        # OpenAI errors
        if "openai" in error_lower:
            if "rate limit" in error_lower:
                return (
                    "OpenAI API limiet bereikt. Probeer het over een minuut opnieuw.",
                    error_str
                )
            if "invalid api key" in error_lower:
                return (
                    "OpenAI API sleutel ongeldig. Neem contact op met de beheerder.",
                    error_str
                )

        # Z.AI errors
        if "z.ai" in error_lower or "zai" in error_lower:
            if "rate limit" in error_lower or "quota" in error_lower:
                return (
                    "Z.AI API limiet bereikt. Probeer het over een minuut opnieuw.",
                    error_str
                )

        # Generic network/connection errors
        if "connection" in error_lower or "timeout" in error_lower:
            return (
                "Verbindingsfout met de AI service. Controleer uw internetverbinding en probeer het opnieuw.",
                error_str
            )
        if "503" in error_str or "service unavailable" in error_lower:
            return (
                "De AI service is tijdelijk niet beschikbaar. Probeer het over een minuut opnieuw.",
                error_str
            )

        # Qdrant/retrieval errors
        if "qdrant" in error_lower:
            return (
                "Fout bij het doorzoeken van de database. Probeer het later opnieuw.",
                error_str
            )

        # Default: return the raw error but in a nicer format
        return (
            f"Er is een fout opgetreden: {error_str[:200]}{'...' if len(error_str) > 200 else ''}",
            error_str if len(error_str) > 200 else None
        )

    # Resolve the provider BEFORE cache lookup to ensure consistent cache keys.
    # This is critical: cache GET and SET must use the same provider value.
    resolved_provider = (request.llm_provider or settings.llm_provider).lower()

    # Stage 1: Check cache
    yield emit_status("cache", "🔍 Cache controleren...")
    if retriever:
        cached = await retriever.cache.get_dspy(
            question=request.question,
            language=request.language,
            llm_provider=resolved_provider,  # Use resolved provider, not request.llm_provider
            embedding_model=request.embedding_model,
            context=request.context if request.context else None,
        )
        if cached:
            elapsed_ms = (time.time() - start_time) * 1000
            logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")

            # Transform CachedResponse format back to DSPyQueryResponse format
            cached_context = cached.get("context") or {}
            visualization = None
            if cached.get("visualization_type") or cached.get("visualization_data"):
                visualization = {
                    "type": cached.get("visualization_type"),
                    "data": cached.get("visualization_data"),
                }

            # Rule 46: Build provenance for streaming cache hit responses.
            # `or []` guards against an explicit None stored in the cache entry.
            stream_cached_sources = cached.get("sources") or []
            stream_cached_template_used = cached_context.get("template_used", False)
            stream_cached_template_id = cached_context.get("template_id")
            stream_cached_llm_provider = cached_context.get("llm_provider")
            stream_cached_llm_model = cached_context.get("llm_model")

            # Prefer provenance stored with the cache entry when present.
            stream_cached_prov = cached_context.get("epistemic_provenance")
            if stream_cached_prov:
                # Copy both the dict and its chain list so the cached entry is
                # never mutated, and prepend the cache marker only once (the
                # previous check looked for "CACHE" but inserted "CACHE:hit",
                # so the marker was duplicated on every hit).
                stream_cache_provenance = dict(stream_cached_prov)
                derivation_chain = list(stream_cache_provenance.get("derivationChain") or [])
                if not any(str(step).startswith("CACHE") for step in derivation_chain):
                    derivation_chain.insert(0, "CACHE:hit")
                stream_cache_provenance["derivationChain"] = derivation_chain
            else:
                # Build fresh provenance for older cache entries
                stream_cache_tier = DataTier.TIER_3_CROWD_SOURCED.value
                if stream_cached_template_used:
                    stream_cache_tier = DataTier.TIER_1_AUTHORITATIVE.value
                elif any(s.lower() in ["sparql", "typedb"] for s in stream_cached_sources):
                    stream_cache_tier = DataTier.TIER_1_AUTHORITATIVE.value

                stream_cache_provenance = EpistemicProvenance(
                    dataSource=EpistemicDataSource.CACHE_AGGREGATION,
                    dataTier=stream_cache_tier,
                    derivationChain=["CACHE:hit"] + build_derivation_chain(
                        sources_used=stream_cached_sources,
                        template_used=stream_cached_template_used,
                        template_id=stream_cached_template_id,
                        llm_provider=stream_cached_llm_provider,
                    ),
                    sourcesQueried=stream_cached_sources,
                    templateUsed=stream_cached_template_used,
                    templateId=stream_cached_template_id,
                    llmProvider=stream_cached_llm_provider,
                    llmModel=stream_cached_llm_model,
                ).model_dump()

            response_data = {
                "question": request.question,
                "answer": cached.get("answer", ""),
                "sources_used": stream_cached_sources,
                "visualization": visualization,
                "resolved_question": cached_context.get("resolved_question"),
                "retrieved_results": cached_context.get("retrieved_results"),
                "query_type": cached_context.get("query_type"),
                "embedding_model_used": cached_context.get("embedding_model"),
                "llm_model_used": stream_cached_llm_model,
                "query_time_ms": round(elapsed_ms, 2),
                "cache_hit": True,
                # Session management
                "session_id": session_id,
                # Template tracking from cache
                "template_used": stream_cached_template_used,
                "template_id": stream_cached_template_id,
                # Rule 46: Epistemic provenance for transparency
                "epistemic_provenance": stream_cache_provenance,
            }

            # Record cache hit metrics for streaming endpoint (best-effort)
            if METRICS_AVAILABLE and record_query:
                try:
                    record_query(
                        endpoint="dspy_query_stream",
                        template_used=stream_cached_template_used,
                        template_id=stream_cached_template_id,
                        cache_hit=True,
                        status="success",
                        duration_seconds=elapsed_ms / 1000,
                        intent=cached_context.get("query_type"),
                    )
                except Exception as e:
                    logger.warning(f"Failed to record streaming cache hit metrics: {e}")

            yield emit_status("cache", "✅ Antwoord gevonden in cache!")
            yield json.dumps({"type": "complete", "data": response_data}) + "\n"
            return

    try:
        # Stage 2: Configure LLM
        yield emit_status("config", "⚙️ LLM configureren...")

        import dspy
        from dspy import History
        from dspy_heritage_rag import HeritageRAGPipeline

        requested_provider = resolved_provider  # Already resolved above
        llm_provider_used: str | None = None
        llm_model_used: str | None = None
        lm = None

        # Try the requested provider first, then fall back in a fixed order.
        providers_to_try = [requested_provider]
        for fallback in ["zai", "groq", "anthropic", "openai"]:
            if fallback not in providers_to_try:
                providers_to_try.append(fallback)

        # Default models per provider (used if request.llm_model is not specified
        # or does not belong to the provider being tried). LLM_MODEL from settings
        # is used when it matches the provider's naming convention.
        # Loop-invariant, so built once instead of on every provider attempt.
        default_models = {
            "zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash",
            "groq": "llama-3.1-8b-instant",
            "anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514",
            "openai": "gpt-4o-mini",
            # Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference
            # Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2
            "huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct",
        }

        # Maps a model-name prefix to the provider that serves it.
        # HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct);
        # Groq models use simple names (e.g., llama-3.1-8b-instant).
        model_prefixes = {
            "glm-": "zai",
            "llama-3.1-": "groq",
            "llama-3.3-": "groq",
            "claude-": "anthropic",
            "gpt-": "openai",
            # HuggingFace organization prefixes
            "mistralai/": "huggingface",
            "google/": "huggingface",
            "Qwen/": "huggingface",
            "deepseek-ai/": "huggingface",
            "meta-llama/": "huggingface",
            "utter-project/": "huggingface",
            "microsoft/": "huggingface",
            "tiiuae/": "huggingface",
        }

        requested_model = request.llm_model

        for provider in providers_to_try:
            if lm is not None:
                break

            # Use the requested model only if its prefix belongs to this provider.
            model_to_use = default_models.get(provider, "")
            if requested_model:
                for prefix, model_provider in model_prefixes.items():
                    if requested_model.startswith(prefix) and model_provider == provider:
                        model_to_use = requested_model
                        break

            if provider == "zai" and settings.zai_api_token:
                try:
                    lm = dspy.LM(
                        f"openai/{model_to_use}",
                        api_key=settings.zai_api_token,
                        api_base="https://api.z.ai/api/coding/paas/v4",
                    )
                    llm_provider_used = "zai"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create Z.AI LM: {e}")
            elif provider == "groq" and settings.groq_api_key:
                try:
                    lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
                    llm_provider_used = "groq"
                    llm_model_used = model_to_use
                    logger.info(f"Using Groq {model_to_use} (FREE) for streaming request")
                except Exception as e:
                    logger.warning(f"Failed to create Groq LM: {e}")
            elif provider == "huggingface" and settings.huggingface_api_key:
                try:
                    lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
                    llm_provider_used = "huggingface"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create HuggingFace LM: {e}")
            elif provider == "anthropic" and settings.anthropic_api_key:
                try:
                    lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
                    llm_provider_used = "anthropic"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create Anthropic LM: {e}")
            elif provider == "openai" and settings.openai_api_key:
                try:
                    lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
                    llm_provider_used = "openai"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create OpenAI LM: {e}")

        if lm is None:
            yield emit_error("Geen LLM beschikbaar. Controleer API keys.")
            return

        yield emit_status("config", f"✅ LLM geconfigureerd ({llm_provider_used})")

        # Stage 3: Prepare conversation history (only complete question/answer turns)
        yield emit_status("routing", "🧭 Vraag analyseren...")

        history_messages = []
        for turn in request.context:
            if turn.get("question") and turn.get("answer"):
                history_messages.append({
                    "question": turn["question"],
                    "answer": turn["answer"]
                })

        history = History(messages=history_messages) if history_messages else None

        # Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
        if dspy_pipeline is not None:
            pipeline = dspy_pipeline
            logger.debug("Using global optimized DSPy pipeline (streaming)")
        else:
            # Fallback: create pipeline without optimized weights
            qdrant_retriever = retriever.qdrant if retriever else None
            pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)
            logger.debug("Using fallback (unoptimized) DSPy pipeline (streaming)")

        # Stage 4: Execute pipeline with STREAMING answer generation
        yield emit_status("retrieval", "📊 Database doorzoeken...")

        result = None

        if hasattr(pipeline, 'forward_streaming'):
            # Streaming mode - tokens arrive as they're generated
            try:
                with dspy.settings.context(lm=lm):
                    async for event in pipeline.forward_streaming(
                        embedding_model=request.embedding_model,
                        question=request.question,
                        language=request.language,
                        history=history,
                        include_viz=request.include_visualization,
                        conversation_state=conversation_state,  # Pass session state for template SPARQL
                    ):
                        event_type = event.get("type")

                        if event_type == "cache_hit":
                            # Cache hit - return immediately
                            result = event["prediction"]
                            yield emit_status("complete", "✅ Klaar! (cache)")
                            break
                        elif event_type == "retrieval_complete":
                            # Retrieval done, now generating answer
                            yield emit_status("generation", "💡 Antwoord genereren...")
                        elif event_type == "token":
                            # Stream token to frontend
                            yield json.dumps({"type": "token", "content": event["content"]}) + "\n"
                        elif event_type == "status":
                            # Status message from pipeline
                            yield emit_status("generation", event.get("message", "..."))
                        elif event_type == "answer_complete":
                            # Final prediction ready
                            result = event["prediction"]

            except Exception as e:
                logger.exception(f"Streaming pipeline execution failed: {e}")
                user_msg, details = extract_user_friendly_error(e)
                yield emit_error(user_msg, details)
                return

            if result is None:
                # The stream ended without "cache_hit"/"answer_complete". Without
                # this guard an empty default answer would be returned AND cached.
                logger.warning("Streaming pipeline finished without a final prediction")
                yield emit_error("Pipeline uitvoering mislukt zonder resultaat")
                return
        else:
            # Fallback: Non-streaming mode with retries on transient API errors
            max_retries = 3
            last_error: Exception | None = None

            with dspy.settings.context(lm=lm):
                for attempt in range(max_retries):
                    try:
                        if attempt > 0:
                            yield emit_status("retrieval", f"🔄 Opnieuw proberen ({attempt + 1}/{max_retries})...")

                        result = pipeline(
                            embedding_model=request.embedding_model,
                            question=request.question,
                            language=request.language,
                            history=history,
                            include_viz=request.include_visualization,
                            conversation_state=conversation_state,
                        )
                        break
                    except Exception as e:
                        last_error = e
                        error_str = str(e).lower()
                        is_retryable = any(keyword in error_str for keyword in [
                            "overloaded", "rate_limit", "rate limit", "too many requests",
                            "529", "503", "502", "504",
                            "temporarily unavailable", "service unavailable",
                            "connection reset", "connection refused", "timeout"
                        ])

                        if is_retryable and attempt < max_retries - 1:
                            # Exponential backoff: 1s, 2s, ...
                            wait_time = 2 ** attempt
                            logger.warning(f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}")
                            yield emit_status("retrieval", f"⏳ API overbelast, wachten {wait_time}s...")
                            await asyncio.sleep(wait_time)
                            continue
                        else:
                            logger.exception(f"Pipeline execution failed after {attempt + 1} attempts")
                            user_msg, details = extract_user_friendly_error(e)
                            yield emit_error(user_msg, details)
                            return

            if result is None:
                if last_error:
                    user_msg, details = extract_user_friendly_error(last_error)
                    yield emit_error(user_msg, details)
                    return
                yield emit_error("Pipeline uitvoering mislukt zonder resultaat")
                return

            # Stage 5: Generate response (only for non-streaming fallback)
            yield emit_status("generation", "💡 Antwoord genereren...")

        elapsed_ms = (time.time() - start_time) * 1000

        # Extract query_type first - needed for SPARQL visibility decision
        query_type = getattr(result, "query_type", None)

        visualization = None
        if request.include_visualization and hasattr(result, "visualization"):
            viz = result.visualization
            if viz:
                # SPARQL is shown for all query types including person queries.
                # Person queries use HeritagePersonSPARQLGenerator (schema:Person predicates);
                # institution queries use HeritageSPARQLGenerator (crm:E39_Actor predicates).
                sparql_to_show = getattr(result, "sparql", None)

                # viz can be either an object (with .viz_type attr) or a dict (with "type" key)
                if isinstance(viz, dict):
                    viz_type = viz.get("type", "table")
                else:
                    viz_type = getattr(viz, "viz_type", "table")

                visualization = {
                    "type": viz_type,
                    "sparql_query": sparql_to_show,
                }
                logger.info(f"[DEBUG] Built visualization: type={viz_type}, sparql_len={len(sparql_to_show) if sparql_to_show else 0}")

        retrieved_results = getattr(result, "retrieved_results", None)

        # Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support)
        llm_response_metadata = extract_llm_response_metadata(
            lm=lm,
            provider=llm_provider_used,
            latency_ms=int(elapsed_ms),
        )

        # Rule 46: Build epistemic provenance for streaming endpoint.
        # `or []` guards against a pipeline reporting sources_used=None.
        stream_sources_used = getattr(result, "sources_used", None) or []
        stream_template_used = getattr(result, "template_used", False)
        stream_template_id = getattr(result, "template_id", None)

        # Infer data tier from sources: template SPARQL or SPARQL/TypeDB => tier 1
        stream_tier = DataTier.TIER_3_CROWD_SOURCED.value
        if stream_template_used:
            stream_tier = DataTier.TIER_1_AUTHORITATIVE.value
        elif any(s.lower() in ["sparql", "typedb"] for s in stream_sources_used):
            stream_tier = DataTier.TIER_1_AUTHORITATIVE.value

        stream_provenance = EpistemicProvenance(
            dataSource=EpistemicDataSource.RAG_PIPELINE,
            dataTier=stream_tier,
            derivationChain=build_derivation_chain(
                sources_used=stream_sources_used,
                template_used=stream_template_used,
                template_id=stream_template_id,
                llm_provider=llm_provider_used,
            ),
            sourcesQueried=stream_sources_used,
            totalRetrieved=len(retrieved_results) if retrieved_results else 0,
            totalAfterFusion=len(retrieved_results) if retrieved_results else 0,
            templateUsed=stream_template_used,
            templateId=stream_template_id,
            llmProvider=llm_provider_used,
            llmModel=llm_model_used,
        )

        response = DSPyQueryResponse(
            question=request.question,
            resolved_question=getattr(result, "resolved_question", None),
            answer=getattr(result, "answer", "Geen antwoord gevonden."),
            sources_used=stream_sources_used,
            visualization=visualization,
            retrieved_results=retrieved_results,
            query_type=query_type,
            query_time_ms=round(elapsed_ms, 2),
            conversation_turn=len(request.context),
            embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
            timing_ms=getattr(result, "timing_ms", None),
            cost_usd=getattr(result, "cost_usd", None),
            timing_breakdown=getattr(result, "timing_breakdown", None),
            llm_provider_used=llm_provider_used,
            llm_model_used=llm_model_used,
            cache_hit=False,
            # LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought)
            llm_response=llm_response_metadata,
            # Session management fields for multi-turn conversations
            session_id=session_id,
            template_used=stream_template_used,
            template_id=stream_template_id,
            # Rule 46: Epistemic provenance for transparency
            epistemic_provenance=stream_provenance.model_dump(),
        )

        # Update session with this turn (before caching); best-effort
        if session_mgr and session_id and conversation_state is not None:
            try:
                await session_mgr.add_turn_to_session(
                    session_id=session_id,
                    question=request.question,
                    answer=response.answer,
                    resolved_question=response.resolved_question,
                    template_id=stream_template_id,
                    slots=getattr(result, "slots", {}),
                )
                logger.debug(f"Updated session {session_id} with new turn")
            except Exception as e:
                logger.warning(f"Failed to update session {session_id}: {e}")

        # Record Prometheus metrics for monitoring; best-effort
        if METRICS_AVAILABLE and record_query:
            try:
                record_query(
                    endpoint="dspy_query_stream",
                    template_used=stream_template_used,
                    template_id=stream_template_id,
                    cache_hit=False,
                    status="success",
                    duration_seconds=elapsed_ms / 1000,
                    intent=query_type,
                )
            except Exception as e:
                logger.warning(f"Failed to record streaming metrics: {e}")

        # Cache the response. Best-effort: a cache write failure must not turn a
        # successfully generated answer into an error for the user.
        if retriever:
            try:
                await retriever.cache.set_dspy(
                    question=request.question,
                    language=request.language,
                    llm_provider=llm_provider_used,
                    embedding_model=request.embedding_model,
                    response=response.model_dump(),
                    context=request.context if request.context else None,
                )
            except Exception as e:
                logger.warning(f"Failed to cache streaming response: {e}")

        yield emit_status("complete", "✅ Klaar!")
        yield json.dumps({"type": "complete", "data": response.model_dump()}) + "\n"

    except ImportError as e:
        logger.warning(f"DSPy pipeline not available: {e}")
        yield emit_error("DSPy pipeline is niet beschikbaar.")
    except Exception as e:
        logger.exception("DSPy streaming query failed")
        user_msg, details = extract_user_friendly_error(e)
        yield emit_error(user_msg, details)
@app.post("/api/rag/dspy/query/stream")
async def dspy_query_stream(request: DSPyQueryRequest) -> StreamingResponse:
    """Stream a DSPy RAG answer as NDJSON progress events.

    A thin endpoint wrapper around ``stream_dspy_query_response``. The
    per-stage status lines let the frontend show progress while long-running
    queries execute.

    Status stages:
    - cache: Checking for cached response
    - config: Configuring LLM provider
    - routing: Analyzing query intent
    - retrieval: Searching databases (Qdrant, SPARQL, etc.)
    - generation: Generating answer with LLM
    - complete: Final response ready
    """
    ndjson_events = stream_dspy_query_response(request)
    return StreamingResponse(ndjson_events, media_type="application/x-ndjson")
async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Stream a multi-source RAG query as NDJSON progress events.

    Every emitted line is a newline-terminated JSON object:
    - {"type": "error", "error": ...} if the retriever is not initialized
    - {"type": "status", ...} for the routing decision and per-source progress
    - {"type": "error", "source": ..., "error": ...} if one source fails
      (the remaining sources are still queried)
    - {"type": "partial", "source": ..., "count": ...} after each source
    - {"type": "complete", "results": ..., "query_time_ms": ...,
      "result_count": ..., "epistemic_provenance": ...} at the end

    Args:
        request: Query request with question, optional source selection, ``k``
            and embedding model.

    Yields:
        NDJSON lines.
    """
    if not retriever:
        # Keep the legacy "error" key, but make this a valid, typed NDJSON line
        # (newline-terminated) like every other event this generator emits.
        yield json.dumps({"type": "error", "error": "Retriever not initialized"}) + "\n"
        return

    # get_running_loop() is the supported way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    # Route query
    intent, sources = retriever.router.get_sources(request.question, request.sources)

    # Extract geographic filters from question (province, city, institution type)
    geo_filters = extract_geographic_filters(request.question)

    yield json.dumps({
        "type": "status",
        "message": f"Routing query to {len(sources)} sources...",
        "intent": intent.value,
        "geo_filters": {k: v for k, v in geo_filters.items() if v},
    }) + "\n"

    # Retrieve from sources one at a time so progress can be streamed
    results = []
    for source in sources:
        yield json.dumps({
            "type": "status",
            "message": f"Querying {source.value}...",
        }) + "\n"

        try:
            source_results = await retriever.retrieve(
                request.question,
                [source],
                request.k,
                embedding_model=request.embedding_model,
                region_codes=geo_filters["region_codes"],
                cities=geo_filters["cities"],
                institution_types=geo_filters["institution_types"],
            )
        except Exception as e:
            # A single failing backend should not abort the whole stream:
            # report it and continue with the remaining sources.
            logger.exception(f"Streaming retrieval from {source.value} failed")
            yield json.dumps({
                "type": "error",
                "source": source.value,
                "error": str(e),
            }) + "\n"
            continue

        results.extend(source_results)

        yield json.dumps({
            "type": "partial",
            "source": source.value,
            # Count items across all returned result groups, not just the first.
            "count": sum(len(r.items) for r in source_results),
        }) + "\n"

    # Merge and finalize with provenance
    merged, stream_provenance = retriever.merge_results(results)
    elapsed_ms = (loop.time() - start_time) * 1000

    yield json.dumps({
        "type": "complete",
        "results": merged,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged),
        "epistemic_provenance": stream_provenance.model_dump() if stream_provenance else None,
    }) + "\n"
@app.post("/api/rag/query/stream")
async def query_rag_stream(request: QueryRequest) -> StreamingResponse:
    """Streaming version of RAG query endpoint.

    Wraps ``stream_query_response`` in an NDJSON ``StreamingResponse``
    (status, partial and complete lines).
    """
    event_stream = stream_query_response(request)
    return StreamingResponse(content=event_stream, media_type="application/x-ndjson")
# =============================================================================
# SEMANTIC CACHE ENDPOINTS (Qdrant-backed)
# =============================================================================
#
# High-performance semantic cache using Qdrant's HNSW vector index.
# Replaces slow client-side cosine similarity with server-side ANN search.
#
# Performance target:
# - Cache lookup: <20ms (vs 500-2000ms with client-side scan)
# - Cache store: <50ms
#
# Architecture:
# Frontend → /api/cache/lookup → Qdrant ANN search → cached response
# /api/cache/store → embed + upsert to Qdrant
# =============================================================================
# Semantic cache configuration.
CACHE_COLLECTION_NAME = "query_cache"
CACHE_EMBEDDING_DIM = 384  # output dimension of all-MiniLM-L6-v2

# Lazily created singletons: populated on first use by the cache getters and
# reused afterwards.
_cache_qdrant_client: Any = None
_cache_embedding_model: Any = None
def get_cache_qdrant_client() -> Any:
    """Get or create the Qdrant client used by the semantic cache collection.

    The client connects to ``settings.qdrant_host``:``settings.qdrant_port``.
    The host and port are configurable; localhost is not forced. The cache is
    intended to be co-located with the RAG backend, so these settings normally
    point at the local Qdrant instance. That avoids reverse-proxy overhead.
    NOTE(review): the defaults are believed to be localhost:6333; confirm in
    the settings definition.

    The client is memoized in the module-level ``_cache_qdrant_client`` after
    the first successful construction. Failures are not memoized, so a later
    call will retry.

    Returns:
        A ``QdrantClient`` instance, or ``None`` if ``qdrant-client`` is not
        installed or client construction raised.
    """
    global _cache_qdrant_client

    if _cache_qdrant_client is not None:
        return _cache_qdrant_client

    try:
        from qdrant_client import QdrantClient

        # Host/port come from settings (see docstring); not hard-coded here.
        _cache_qdrant_client = QdrantClient(
            host=settings.qdrant_host,
            port=settings.qdrant_port,
            timeout=30,
        )
        logger.info(f"Qdrant cache client: {settings.qdrant_host}:{settings.qdrant_port}")
        return _cache_qdrant_client
    except ImportError:
        logger.error("qdrant-client not installed")
        return None
    except Exception as e:
        logger.error(f"Failed to create Qdrant cache client: {e}")
        return None
def get_cache_embedding_model() -> Any:
    """Return the memoized sentence-transformers model used to embed cache queries.

    Loads ``sentence-transformers/all-MiniLM-L6-v2`` (384-dim) on the first
    successful call and stores it in ``_cache_embedding_model``.

    Returns:
        The loaded model, or ``None`` if ``sentence-transformers`` is missing
        or loading fails. Failures are logged and not memoized.
    """
    global _cache_embedding_model

    if _cache_embedding_model is None:
        try:
            from sentence_transformers import SentenceTransformer

            _cache_embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
            logger.info("Loaded cache embedding model: all-MiniLM-L6-v2")
        except ImportError:
            logger.error("sentence-transformers not installed")
            return None
        except Exception as e:
            logger.error(f"Failed to load cache embedding model: {e}")
            return None

    return _cache_embedding_model
def ensure_cache_collection_exists() -> bool:
    """Make sure the semantic-cache collection is present in Qdrant.

    When ``CACHE_COLLECTION_NAME`` is missing, it is created with
    ``CACHE_EMBEDDING_DIM``-sized vectors and cosine distance. Qdrant builds
    its default HNSW index for it.

    Returns:
        True if the collection already existed or was just created. False when
        no Qdrant client is available or any Qdrant call raised (logged).
    """
    client = get_cache_qdrant_client()
    if client is None:
        return False

    try:
        from qdrant_client.models import Distance, VectorParams

        existing_names = {collection.name for collection in client.get_collections().collections}
        if CACHE_COLLECTION_NAME not in existing_names:
            client.create_collection(
                collection_name=CACHE_COLLECTION_NAME,
                vectors_config=VectorParams(size=CACHE_EMBEDDING_DIM, distance=Distance.COSINE),
            )
            logger.info(f"Created Qdrant collection: {CACHE_COLLECTION_NAME}")
        return True
    except Exception as e:
        logger.error(f"Failed to ensure cache collection: {e}")
        return False
# Request/response models for the semantic cache API
class CacheLookupRequest(BaseModel):
    """Cache lookup request.

    Clients may send a pre-computed ``embedding``; otherwise the server
    embeds ``query`` itself.
    """

    query: str = Field(..., description="Query text to look up")
    embedding: list[float] | None = Field(None, description="Pre-computed embedding (optional)")
    similarity_threshold: float = Field(0.92, description="Minimum similarity for match")
    language: str = Field("nl", description="Language filter")
class CacheLookupResponse(BaseModel):
    """Cache lookup response (match flag, matched entry and lookup diagnostics)."""

    found: bool
    # The matched cache entry payload, or None when nothing matched.
    entry: dict[str, Any] | None = None
    # Similarity score of the best match (0.0 when not found).
    similarity: float = 0.0
    # How the lookup concluded, e.g. "semantic" or "error".
    method: str = "none"
    # Wall-clock time spent in the lookup, in milliseconds.
    lookup_time_ms: float = 0.0
class CacheStoreRequest(BaseModel):
    """Cache store request: a query, its response, and optional embedding and metadata."""

    query: str = Field(..., description="Query text")
    embedding: list[float] | None = Field(None, description="Pre-computed embedding (optional)")
    response: dict[str, Any] = Field(..., description="Response to cache")
    language: str = Field("nl", description="Language")
    model: str = Field("unknown", description="LLM model used")
    ttl_seconds: int = Field(86400, description="Time-to-live in seconds")
class CacheStoreResponse(BaseModel):
    """Cache store response reporting success and the stored entry id."""

    success: bool
    # Identifier of the stored entry, if one was created.
    id: str | None = None
    # Human-readable status message.
    message: str = ""
class CacheStatsResponse(BaseModel):
    """Cache statistics response describing the Qdrant-backed cache collection."""

    # Number of entries currently stored.
    total_entries: int = 0
    # Static description of the cache backend configuration.
    collection_name: str = CACHE_COLLECTION_NAME
    embedding_dim: int = CACHE_EMBEDDING_DIM
    backend: str = "qdrant"
    status: str = "ok"
# How many nearest neighbours a lookup inspects. More than one is needed
# because cache_store always inserts a fresh point, so an expired entry can
# rank first while a still-valid entry for the same query sits right behind it.
_CACHE_LOOKUP_CANDIDATES = 5


def _cache_entry_expired(payload: dict[str, Any], now_ms: int) -> bool:
    """Return True when a cached payload has outlived its ``ttl_seconds``.

    ``payload["timestamp"]`` is the store time in epoch milliseconds and
    ``payload["ttl_seconds"]`` the requested lifetime, both as written by
    ``cache_store``. Payloads lacking either value, or carrying a non-positive
    TTL, are treated as never expiring so older points remain usable.
    """
    ttl_seconds = payload.get("ttl_seconds")
    stored_at_ms = payload.get("timestamp")
    if not isinstance(ttl_seconds, (int, float)) or ttl_seconds <= 0:
        return False
    if not isinstance(stored_at_ms, (int, float)):
        return False
    return now_ms > stored_at_ms + ttl_seconds * 1000


@app.post("/api/cache/lookup", response_model=CacheLookupResponse)
async def cache_lookup(request: CacheLookupRequest) -> CacheLookupResponse:
    """Look up a query in the semantic cache using Qdrant ANN search.

    The nearest cached entries for ``request.language`` scoring at least
    ``request.similarity_threshold`` are fetched, and the best one whose TTL
    has not elapsed is returned. This replaces slow client-side cosine
    similarity scans with Qdrant's vector index.

    Failures are reported as a miss with ``method="error"``. A completed
    search that yields no usable entry is a miss with ``method="semantic"``.
    """
    import time

    start_time = time.perf_counter()

    def _elapsed_ms() -> float:
        return (time.perf_counter() - start_time) * 1000

    def _miss(method: str) -> CacheLookupResponse:
        return CacheLookupResponse(
            found=False,
            similarity=0.0,
            method=method,
            lookup_time_ms=_elapsed_ms(),
        )

    # Ensure collection exists
    if not ensure_cache_collection_exists():
        return _miss("error")

    client = get_cache_qdrant_client()
    if client is None:
        return _miss("error")

    try:
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        # Get or generate embedding. This is inside the try so an encoder
        # failure becomes an "error" miss instead of an unhandled 500.
        embedding = request.embedding
        if embedding is None:
            model = get_cache_embedding_model()
            if model is None:
                return _miss("error")
            embedding = model.encode(request.query).tolist()

        # Only entries cached for the requested language are candidates.
        search_filter = Filter(
            must=[
                FieldCondition(
                    key="language",
                    match=MatchValue(value=request.language),
                )
            ]
        )

        # Perform ANN search using query_points (qdrant-client >= 1.10).
        # Results come back ordered by descending score.
        results = client.query_points(
            collection_name=CACHE_COLLECTION_NAME,
            query=embedding,
            query_filter=search_filter,
            limit=_CACHE_LOOKUP_CANDIDATES,
            score_threshold=request.similarity_threshold,
        ).points

        # Enforce the per-entry TTL recorded by cache_store.
        now_ms = int(time.time() * 1000)
        best = next(
            (p for p in results if not _cache_entry_expired(p.payload or {}, now_ms)),
            None,
        )
        if best is None:
            return _miss("semantic")

        payload = best.payload or {}
        return CacheLookupResponse(
            found=True,
            entry={
                "id": str(best.id),
                "query": payload.get("query", ""),
                "query_normalized": payload.get("query_normalized", ""),
                "response": payload.get("response", {}),
                "timestamp": payload.get("timestamp", 0),
                "hit_count": payload.get("hit_count", 0),
                "last_accessed": payload.get("last_accessed", 0),
                "language": payload.get("language", "nl"),
                "model": payload.get("model", "unknown"),
            },
            similarity=best.score,
            method="semantic",
            lookup_time_ms=_elapsed_ms(),
        )
    except Exception as e:
        logger.error(f"Cache lookup error: {e}")
        return _miss("error")
@app.post("/api/cache/store", response_model=CacheStoreResponse)
async def cache_store(request: CacheStoreRequest) -> CacheStoreResponse:
    """Store a query/response pair in the semantic cache.

    Encodes ``request.query`` when no embedding is supplied, then upserts a
    new point with a random UUID id into the cache collection. The payload
    records the store time in epoch milliseconds together with
    ``ttl_seconds``, the normalized query, language, and model.

    Errors from embedding generation and the upsert are reported as
    ``success=False`` with the error text in ``message``.
    """
    import time
    import uuid

    # Ensure collection exists
    if not ensure_cache_collection_exists():
        return CacheStoreResponse(
            success=False,
            message="Failed to ensure cache collection exists",
        )

    client = get_cache_qdrant_client()
    if client is None:
        return CacheStoreResponse(
            success=False,
            message="Qdrant client not available",
        )

    try:
        from qdrant_client.models import PointStruct

        # Get or generate embedding. This is inside the try so an encoder
        # failure is reported in the response instead of raising a 500.
        embedding = request.embedding
        if embedding is None:
            model = get_cache_embedding_model()
            if model is None:
                return CacheStoreResponse(
                    success=False,
                    message="Embedding model not available",
                )
            embedding = model.encode(request.query).tolist()

        # Generate unique ID
        point_id = str(uuid.uuid4())
        timestamp = int(time.time() * 1000)

        # Normalize query for exact matching
        query_normalized = request.query.lower().strip()

        point = PointStruct(
            id=point_id,
            vector=embedding,
            payload={
                "query": request.query,
                "query_normalized": query_normalized,
                "response": request.response,
                "language": request.language,
                "model": request.model,
                "timestamp": timestamp,
                "hit_count": 0,
                "last_accessed": timestamp,
                "ttl_seconds": request.ttl_seconds,
            },
        )

        client.upsert(
            collection_name=CACHE_COLLECTION_NAME,
            points=[point],
        )

        logger.debug(f"Cached query: {request.query[:50]}...")

        return CacheStoreResponse(
            success=True,
            id=point_id,
            message="Stored successfully",
        )
    except Exception as e:
        logger.error(f"Cache store error: {e}")
        return CacheStoreResponse(
            success=False,
            message=str(e),
        )
@app.get("/api/cache/stats", response_model=CacheStatsResponse)
async def cache_stats() -> CacheStatsResponse:
    """Get cache statistics.

    Returns ``status="error"`` when no Qdrant client is available,
    ``"no_collection"`` when the cache collection has not been created yet,
    ``"error: <detail>"`` on any Qdrant failure, and otherwise ``"ok"`` with
    the collection's point count.
    """
    client = get_cache_qdrant_client()
    if client is None:
        return CacheStatsResponse(
            status="error",
            total_entries=0,
        )

    try:
        # Check if collection exists
        collections = client.get_collections().collections
        if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
            return CacheStatsResponse(
                status="no_collection",
                total_entries=0,
            )

        info = client.get_collection(CACHE_COLLECTION_NAME)
        return CacheStatsResponse(
            # qdrant-client types points_count as Optional[int]; a None would
            # fail int validation and be mis-reported as an error status.
            total_entries=info.points_count or 0,
            collection_name=CACHE_COLLECTION_NAME,
            embedding_dim=CACHE_EMBEDDING_DIM,
            backend="qdrant",
            status="ok",
        )
    except Exception as e:
        logger.error(f"Cache stats error: {e}")
        return CacheStatsResponse(
            status=f"error: {e}",
            total_entries=0,
        )
def _clear_qdrant_semantic_cache() -> dict[str, Any]:
    """Drop and re-create the Qdrant cache collection.

    Returns a ``{"success", "deleted", "message"}`` status dict; ``deleted``
    is always an int.
    """
    client = get_cache_qdrant_client()
    if client is None:
        return {"success": False, "deleted": 0, "message": "Qdrant client not available"}
    try:
        collections = client.get_collections().collections
        if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
            return {"success": True, "deleted": 0, "message": "Collection does not exist"}
        # qdrant-client types points_count as Optional[int]; fall back to 0 so
        # the caller's arithmetic on "deleted" never sees None.
        count = client.get_collection(CACHE_COLLECTION_NAME).points_count or 0
        client.delete_collection(CACHE_COLLECTION_NAME)
        message = f"Cleared {count} entries"
        # The entries are gone either way; surface a failed re-create rather
        # than silently reporting a fully healthy cache.
        if not ensure_cache_collection_exists():
            logger.warning("Cache collection could not be re-created after clear")
            message += " (collection re-creation failed)"
        return {"success": True, "deleted": count, "message": message}
    except Exception as e:
        logger.error(f"Qdrant cache clear error: {e}")
        return {"success": False, "deleted": 0, "message": str(e)}


def _flush_valkey_cache() -> dict[str, Any]:
    """Run ``FLUSHALL`` via the first ``redis-cli``/``valkey-cli`` that succeeds.

    Blocking: each CLI invocation may take up to its 5 s timeout, so call this
    from a worker thread when running inside the event loop. A timeout aborts
    the attempt without trying the next CLI.

    NOTE(review): FLUSHALL wipes every database on the server the CLI connects
    to (its default host/port), not only cache keys; confirm the instance is
    dedicated to caching.
    """
    import subprocess

    status: dict[str, Any] = {"success": False, "deleted": 0, "message": ""}
    last_failure = ""
    try:
        for cli in ("redis-cli", "valkey-cli"):
            try:
                completed = subprocess.run(
                    [cli, "FLUSHALL"],
                    capture_output=True,
                    text=True,
                    timeout=5,
                )
            except FileNotFoundError:
                continue
            if completed.returncode == 0 and "OK" in completed.stdout:
                # FLUSHALL reports no key count; -1 is clamped to 0 in the total.
                return {"success": True, "deleted": -1, "message": f"Flushed via {cli}"}
            # The CLI exists but the flush failed: remember why, try the next one.
            detail = (completed.stderr or completed.stdout).strip() or f"exit code {completed.returncode}"
            last_failure = f"{cli} FLUSHALL failed: {detail}"
        status["message"] = last_failure or "Neither redis-cli nor valkey-cli available"
    except subprocess.TimeoutExpired:
        status["message"] = "Cache flush timed out"
    except Exception as e:
        logger.error(f"Valkey cache clear error: {e}")
        status["message"] = str(e)
    return status


@app.delete("/api/cache/clear")
async def cache_clear() -> dict[str, Any]:
    """Clear all cache entries (both Qdrant semantic cache and Valkey/Redis cache).

    Returns overall ``success`` (true if either backend was cleared), a combined
    ``message``, the number of Qdrant entries ``deleted``, and per-backend
    ``details``.
    """
    # 1. Clear Qdrant semantic cache
    qdrant_result = _clear_qdrant_semantic_cache()
    # 2. Clear Valkey/Redis cache. The CLI subprocess can block for seconds,
    #    so it runs in a worker thread to keep the event loop responsive.
    valkey_result = await asyncio.to_thread(_flush_valkey_cache)

    results = {"qdrant": qdrant_result, "valkey": valkey_result}
    overall_success = qdrant_result["success"] or valkey_result["success"]
    total_deleted = max(0, qdrant_result["deleted"]) + max(0, valkey_result["deleted"])

    return {
        "success": overall_success,
        "message": f"Qdrant: {qdrant_result['message']}, Valkey: {valkey_result['message']}",
        "deleted": total_deleted,
        "details": results,
    }
# Main entry point
if __name__ == "__main__":
    import uvicorn

    # The app is passed as an import string (not the object) because uvicorn
    # needs one to re-import the module when reload is enabled via settings.debug.
    server_options = {
        "host": "0.0.0.0",
        "port": 8003,
        "reload": settings.debug,
        "log_level": "info",
    }
    uvicorn.run("main:app", **server_options)