""" Unified RAG Backend for Heritage Custodian Data Multi-source retrieval-augmented generation system that combines: - Qdrant vector search (semantic similarity) - Oxigraph SPARQL (knowledge graph queries) - TypeDB (relationship traversal) - PostGIS (geospatial queries) - Valkey (semantic caching) Architecture: User Query → Query Analysis ↓ ┌─────┴─────┐ │ Router │ └─────┬─────┘ ┌─────┬─────┼─────┬─────┐ ↓ ↓ ↓ ↓ ↓ Qdrant SPARQL TypeDB PostGIS Cache │ │ │ │ │ └─────┴─────┴─────┴─────┘ ↓ ┌─────┴─────┐ │ Merger │ └─────┬─────┘ ↓ DSPy Generator ↓ Visualization Selector ↓ Response (JSON/Streaming) Features: - Intelligent query routing to appropriate data sources - Score fusion for multi-source results - Semantic caching via Valkey API - Streaming responses for long-running queries - DSPy assertions for output validation Endpoints: - POST /api/rag/query - Main RAG query endpoint - POST /api/rag/sparql - Generate SPARQL with RAG context - GET /api/rag/health - Health check for all services - GET /api/rag/stats - Retriever statistics """ import asyncio import hashlib import json import logging import os from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum from typing import Any, AsyncIterator import httpx from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Import retrievers (with graceful fallbacks) try: import sys sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src"))) from glam_extractor.api.hybrid_retriever import HybridRetriever, create_hybrid_retriever from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever from 
glam_extractor.api.typedb_retriever import TypeDBRetriever, create_typedb_retriever from glam_extractor.api.visualization import select_visualization, VisualizationSelector from glam_extractor.api.dspy_sparql import generate_sparql, configure_dspy RETRIEVERS_AVAILABLE = True except ImportError as e: logger.warning(f"Some retrievers not available: {e}") RETRIEVERS_AVAILABLE = False # Configuration class Settings: """Application settings from environment variables.""" # API Configuration api_title: str = "Heritage RAG API" api_version: str = "1.0.0" debug: bool = os.getenv("DEBUG", "false").lower() == "true" # Valkey Cache valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache") cache_ttl: int = int(os.getenv("CACHE_TTL", "900")) # 15 minutes # Qdrant Vector DB qdrant_host: str = os.getenv("QDRANT_HOST", "localhost") qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333")) qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "false").lower() == "true" # Oxigraph SPARQL sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "http://localhost:7878/query") # TypeDB typedb_host: str = os.getenv("TYPEDB_HOST", "localhost") typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729")) typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians") # PostGIS postgis_url: str = os.getenv("POSTGIS_URL", "http://localhost:8001") # LLM Configuration anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "") openai_api_key: str = os.getenv("OPENAI_API_KEY", "") default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101") # Retrieval weights vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5")) graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3")) typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2")) settings = Settings() # Enums and Models class QueryIntent(str, Enum): """Detected query intent for routing.""" GEOGRAPHIC = "geographic" # Location-based queries STATISTICAL = "statistical" # 
# Enums and Models


class QueryIntent(str, Enum):
    """Detected query intent for routing."""

    GEOGRAPHIC = "geographic"    # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"    # Relationships between entities
    TEMPORAL = "temporal"        # Historical, timeline queries
    SEARCH = "search"            # General text search
    DETAIL = "detail"            # Specific entity lookup


class DataSource(str, Enum):
    """Available data sources."""

    QDRANT = "qdrant"
    SPARQL = "sparql"
    TYPEDB = "typedb"
    POSTGIS = "postgis"
    CACHE = "cache"


@dataclass
class RetrievalResult:
    """Result from a single retriever."""

    source: DataSource           # which backend produced the items
    items: list[dict[str, Any]]  # raw result rows/documents
    score: float = 0.0           # best relevance score in the batch
    query_time_ms: float = 0.0   # wall-clock retrieval time
    metadata: dict[str, Any] = field(default_factory=dict)


class QueryRequest(BaseModel):
    """RAG query request."""

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    # default_factory is the recommended Pydantic idiom for mutable defaults
    context: list[dict[str, Any]] = Field(
        default_factory=list, description="Conversation history"
    )
    sources: list[DataSource] = Field(
        default_factory=lambda: [DataSource.QDRANT, DataSource.SPARQL],
        description="Data sources to query",
    )
    k: int = Field(default=10, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    stream: bool = Field(default=False, description="Stream response")


class QueryResponse(BaseModel):
    """RAG query response."""

    question: str
    sparql: str | None = None
    results: list[dict[str, Any]]
    visualization: dict[str, Any] | None = None
    sources_used: list[DataSource]
    cache_hit: bool = False
    query_time_ms: float
    result_count: int


class SPARQLRequest(BaseModel):
    """SPARQL generation request."""

    question: str
    language: str = "nl"
    context: list[dict[str, Any]] = Field(default_factory=list)
    use_rag: bool = True


class SPARQLResponse(BaseModel):
    """SPARQL generation response."""

    sparql: str
    explanation: str
    rag_used: bool
    retrieved_passages: list[str] = Field(default_factory=list)
# Cache Client


class ValkeyClient:
    """Client for the Valkey semantic cache HTTP API.

    All operations are best-effort: failures are logged and reported as
    cache misses so the caller never breaks on a cache outage.
    """

    def __init__(self, base_url: str = settings.valkey_api_url):
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    @property
    async def client(self) -> httpx.AsyncClient:
        """Return the shared async HTTP client, (re)creating it when needed."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=30.0)
        return self._client

    def _cache_key(self, question: str, sources: list[DataSource]) -> str:
        """Derive a deterministic cache key from the question and source set."""
        source_part = ",".join(sorted(s.value for s in sources))
        raw = f"{question.lower().strip()}:{source_part}"
        return hashlib.sha256(raw.encode()).hexdigest()[:32]

    async def get(self, question: str, sources: list[DataSource]) -> dict[str, Any] | None:
        """Fetch a cached response; None on miss or any error."""
        try:
            http = await self.client
            response = await http.get(
                f"{self.base_url}/get/{self._cache_key(question, sources)}"
            )
            if response.status_code != 200:
                return None
            payload = response.json()
            if not payload.get("value"):
                return None
            logger.info(f"Cache hit for question: {question[:50]}...")
            return json.loads(payload["value"])
        except Exception as e:
            logger.warning(f"Cache get failed: {e}")
            return None

    async def set(
        self,
        question: str,
        sources: list[DataSource],
        response: dict[str, Any],
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Store a response in the cache; False when the write failed."""
        try:
            http = await self.client
            await http.post(
                f"{self.base_url}/set",
                json={
                    "key": self._cache_key(question, sources),
                    "value": json.dumps(response),
                    "ttl": ttl,
                },
            )
            logger.debug(f"Cached response for: {question[:50]}...")
            return True
        except Exception as e:
            logger.warning(f"Cache set failed: {e}")
            return False

    async def close(self):
        """Close the underlying HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None
# Query Router


class QueryRouter:
    """Routes queries to data sources based on keyword-detected intent."""

    def __init__(self):
        # Keyword lists are bilingual (English + Dutch); a substring match
        # on any keyword counts one hit for that intent.
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart", "where", "waar", "location", "locatie",
                "city", "stad", "country", "land", "region", "gebied",
                "coordinates", "coördinaten", "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel", "count", "aantal", "total", "totaal",
                "average", "gemiddeld", "distribution", "verdeling",
                "percentage", "statistics", "statistiek", "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd", "connected", "verbonden",
                "relationship", "relatie", "network", "netwerk",
                "parent", "child", "merged", "fusie", "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis", "timeline", "tijdlijn",
                "when", "wanneer", "founded", "opgericht", "closed", "gesloten",
                "over time", "evolution", "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie", "about", "over",
                "specific", "specifiek", "what is", "wat is",
            ],
        }
        # Preferred source order per intent (first = most relevant).
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT],
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Detect intent by counting keyword hits; SEARCH when nothing matches."""
        text = question.lower()
        hits = {
            intent: sum(1 for keyword in keywords if keyword in text)
            for intent, keywords in self.intent_keywords.items()
        }
        best = max(hits, key=hits.get)  # type: ignore[arg-type]
        return best if hits[best] > 0 else QueryIntent.SEARCH

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Get optimal sources for a query.

        Args:
            question: User's question
            requested_sources: Explicitly requested sources (overrides routing)

        Returns:
            Tuple of (detected_intent, list_of_sources)
        """
        intent = self.detect_intent(question)
        if requested_sources:
            return intent, requested_sources
        return intent, self.source_routing.get(intent, [DataSource.QDRANT])
# Multi-Source Retriever


class MultiSourceRetriever:
    """Orchestrates retrieval across multiple data sources.

    Retrievers and HTTP clients are created lazily so the API can start even
    when some backends are down; every retrieval method degrades to an empty
    result set on failure instead of raising.
    """

    def __init__(self):
        self.cache = ValkeyClient()
        self.router = QueryRouter()
        # Lazily initialized backends (see properties/getters below).
        self._qdrant: HybridRetriever | None = None
        self._typedb: TypeDBRetriever | None = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None

    @property
    def qdrant(self) -> HybridRetriever | None:
        """Lazy-load the Qdrant hybrid retriever (None when unavailable)."""
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant

    @property
    def typedb(self) -> TypeDBRetriever | None:
        """Lazy-load the TypeDB retriever (None when unavailable)."""
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                self._typedb = create_typedb_retriever(
                    use_production=settings.qdrant_use_production
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb

    async def _get_sparql_client(self) -> httpx.AsyncClient:
        """Get (or recreate) the SPARQL HTTP client."""
        if self._sparql_client is None or self._sparql_client.is_closed:
            self._sparql_client = httpx.AsyncClient(timeout=30.0)
        return self._sparql_client

    async def _get_postgis_client(self) -> httpx.AsyncClient:
        """Get (or recreate) the PostGIS HTTP client."""
        if self._postgis_client is None or self._postgis_client.is_closed:
            self._postgis_client = httpx.AsyncClient(timeout=30.0)
        return self._postgis_client

    async def retrieve_from_qdrant(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from Qdrant vector + SPARQL hybrid search."""
        start = asyncio.get_running_loop().time()
        items: list[dict[str, Any]] = []
        if self.qdrant:
            try:
                # NOTE(review): .search() appears to be synchronous and will
                # block the event loop while it runs — confirm, and consider
                # asyncio.to_thread if it is slow.
                results = self.qdrant.search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"Qdrant retrieval failed: {e}")
        elapsed = (asyncio.get_running_loop().time() - start) * 1000
        return RetrievalResult(
            source=DataSource.QDRANT,
            items=items,
            score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_sparql(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Generate SPARQL via DSPy and execute it against the endpoint."""
        start = asyncio.get_running_loop().time()
        items: list[dict[str, Any]] = []
        try:
            if RETRIEVERS_AVAILABLE:
                sparql_result = generate_sparql(query, language="nl", use_rag=False)
                sparql_query = sparql_result.get("sparql", "")
                if sparql_query:
                    client = await self._get_sparql_client()
                    response = await client.post(
                        settings.sparql_endpoint,
                        data={"query": sparql_query},
                        headers={"Accept": "application/sparql-results+json"},
                    )
                    if response.status_code == 200:
                        bindings = (
                            response.json().get("results", {}).get("bindings", [])
                        )
                        # Flatten SPARQL JSON bindings to plain {var: value} dicts.
                        items = [
                            {var: cell.get("value") for var, cell in row.items()}
                            for row in bindings[:k]
                        ]
        except Exception as e:
            logger.error(f"SPARQL retrieval failed: {e}")
        elapsed = (asyncio.get_running_loop().time() - start) * 1000
        return RetrievalResult(
            source=DataSource.SPARQL,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve_from_typedb(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from the TypeDB knowledge graph."""
        start = asyncio.get_running_loop().time()
        items: list[dict[str, Any]] = []
        if self.typedb:
            try:
                results = self.typedb.semantic_search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"TypeDB retrieval failed: {e}")
        elapsed = (asyncio.get_running_loop().time() - start) * 1000
        return RetrievalResult(
            source=DataSource.TYPEDB,
            items=items,
            score=max((r.get("relevance_score", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_postgis(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from the PostGIS geospatial database.

        Simplified implementation: looks for a known city name in the query
        and asks PostGIS for institutions within 10 km of its center.
        """
        start = asyncio.get_running_loop().time()
        items: list[dict[str, Any]] = []
        # Hard-coded gazetteer for coarse city detection.
        cities = {
            "amsterdam": {"lat": 52.3676, "lon": 4.9041},
            "rotterdam": {"lat": 51.9244, "lon": 4.4777},
            "den haag": {"lat": 52.0705, "lon": 4.3007},
            "utrecht": {"lat": 52.0907, "lon": 5.1214},
        }
        try:
            client = await self._get_postgis_client()
            query_lower = query.lower()
            for city, coords in cities.items():
                if city not in query_lower:
                    continue
                response = await client.get(
                    f"{settings.postgis_url}/api/institutions/nearby",
                    params={
                        "lat": coords["lat"],
                        "lon": coords["lon"],
                        "radius_km": 10,
                        "limit": k,
                    },
                )
                if response.status_code == 200:
                    items = response.json()
                    break
        except Exception as e:
            logger.error(f"PostGIS retrieval failed: {e}")
        elapsed = (asyncio.get_running_loop().time() - start) * 1000
        return RetrievalResult(
            source=DataSource.POSTGIS,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve(
        self,
        question: str,
        sources: list[DataSource],
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Retrieve from multiple sources concurrently.

        Args:
            question: User's question
            sources: Data sources to query
            k: Number of results per source

        Returns:
            List of RetrievalResult from each source that succeeded.
        """
        # Map each source to its retrieval coroutine; sources without a
        # fetcher (e.g. CACHE) are skipped.
        fetchers = {
            DataSource.QDRANT: self.retrieve_from_qdrant,
            DataSource.SPARQL: self.retrieve_from_sparql,
            DataSource.TYPEDB: self.retrieve_from_typedb,
            DataSource.POSTGIS: self.retrieve_from_postgis,
        }
        tasks = [fetchers[s](question, k) for s in sources if s in fetchers]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Keep successful results; log (but do not propagate) failures.
        valid_results: list[RetrievalResult] = []
        for r in results:
            if isinstance(r, RetrievalResult):
                valid_results.append(r)
            elif isinstance(r, Exception):
                logger.error(f"Retrieval task failed: {r}")
        return valid_results

    def merge_results(
        self,
        results: list[RetrievalResult],
        max_results: int = 20,
    ) -> list[dict[str, Any]]:
        """Merge and deduplicate results from multiple sources.

        Uses weighted Reciprocal Rank Fusion: each appearance of an item
        contributes weight / (60 + rank) to its fused score.
        """
        # Per-source fusion weights (hoisted out of the item loop; POSTGIS
        # has no configurable weight).
        source_weights = {
            DataSource.QDRANT: settings.vector_weight,
            DataSource.SPARQL: settings.graph_weight,
            DataSource.TYPEDB: settings.typedb_weight,
            DataSource.POSTGIS: 0.3,
        }

        # Track items by GHCID for deduplication.
        merged: dict[str, dict[str, Any]] = {}
        for result in results:
            weight = source_weights.get(result.source, 0.5)
            for rank, item in enumerate(result.items):
                # Fall back to a source-qualified key when no id is present,
                # so unidentified items from different sources are never
                # fused together by accident.
                ghcid = (
                    item.get("ghcid")
                    or item.get("id")
                    or f"{result.source.value}:unknown:{rank}"
                )
                if ghcid not in merged:
                    merged[ghcid] = item.copy()
                    merged[ghcid]["_sources"] = []
                    merged[ghcid]["_rrf_score"] = 0.0

                rrf_score = 1.0 / (60 + rank)  # k=60 is the standard RRF constant
                merged[ghcid]["_rrf_score"] += rrf_score * weight
                merged[ghcid]["_sources"].append(result.source.value)

        # Sort by fused RRF score, best first.
        ranked = sorted(
            merged.values(),
            key=lambda x: x.get("_rrf_score", 0),
            reverse=True,
        )
        return ranked[:max_results]

    async def close(self):
        """Release all cached clients and retrievers."""
        await self.cache.close()
        if self._sparql_client:
            await self._sparql_client.aclose()
        if self._postgis_client:
            await self._postgis_client.aclose()
        if self._qdrant:
            self._qdrant.close()
        if self._typedb:
            self._typedb.close()


# Global instances
retriever: MultiSourceRetriever | None = None
viz_selector: VisualizationSelector | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager: build retrievers on startup, tear down on shutdown."""
    global retriever, viz_selector

    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()

    if RETRIEVERS_AVAILABLE:
        # Visualization selection can use DSPy only when an LLM key is set.
        viz_selector = VisualizationSelector(use_dspy=bool(settings.anthropic_api_key))

        # Configure DSPy if API key available
        if settings.anthropic_api_key:
            try:
                configure_dspy(
                    provider="anthropic",
                    model=settings.default_model,
                    api_key=settings.anthropic_api_key,
                )
            except Exception as e:
                logger.warning(f"Failed to configure DSPy: {e}")

    logger.info("Heritage RAG API started")

    yield

    # Shutdown
    logger.info("Shutting down Heritage RAG API...")
    if retriever:
        await retriever.close()
    logger.info("Heritage RAG API stopped")
# Create FastAPI app
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)

# CORS middleware (wide open; tighten for production deployments)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# API Endpoints


@app.get("/api/rag/health")
async def health_check():
    """Health check for all services."""
    health: dict[str, Any] = {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": {},
    }
    services = health["services"]

    # Qdrant: report vector count when reachable.
    if retriever and retriever.qdrant:
        try:
            stats = retriever.qdrant.get_stats()
            services["qdrant"] = {
                "status": "ok",
                "vectors": stats.get("qdrant", {}).get("vectors_count", 0),
            }
        except Exception as e:
            services["qdrant"] = {"status": "error", "error": str(e)}

    # SPARQL: probe the endpoint root (strip the /query suffix).
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{settings.sparql_endpoint.replace('/query', '')}")
            services["sparql"] = {
                "status": "ok" if response.status_code < 500 else "error"
            }
    except Exception as e:
        services["sparql"] = {"status": "error", "error": str(e)}

    # TypeDB: report entity stats when reachable.
    if retriever and retriever.typedb:
        try:
            stats = retriever.typedb.get_stats()
            services["typedb"] = {
                "status": "ok",
                "entities": stats.get("entities", {}),
            }
        except Exception as e:
            services["typedb"] = {"status": "error", "error": str(e)}

    # Overall status: ok (no errors), degraded (1-2), error (3+).
    errors = sum(1 for s in services.values() if s.get("status") == "error")
    health["status"] = "ok" if errors == 0 else "degraded" if errors < 3 else "error"

    return health


@app.get("/api/rag/stats")
async def get_stats():
    """Get retriever statistics."""
    stats: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "retrievers": {},
    }
    if retriever:
        if retriever.qdrant:
            stats["retrievers"]["qdrant"] = retriever.qdrant.get_stats()
        if retriever.typedb:
            stats["retrievers"]["typedb"] = retriever.typedb.get_stats()
    return stats
@app.post("/api/rag/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Main RAG query endpoint.

    Orchestrates retrieval from multiple sources, merges results, and
    optionally generates visualization configuration.

    Raises:
        HTTPException: 503 when the retriever has not been initialized.
    """
    if not retriever:
        raise HTTPException(status_code=503, detail="Retriever not initialized")

    start_time = asyncio.get_running_loop().time()

    # Serve from cache when possible. A stale/incompatible cached payload
    # must not break the endpoint, so fall through to a fresh query if it
    # no longer fits the response model.
    cached = await retriever.cache.get(request.question, request.sources)
    if cached:
        try:
            cached["cache_hit"] = True
            return QueryResponse(**cached)
        except Exception as e:
            logger.warning(f"Ignoring invalid cached payload: {e}")

    # Route query to appropriate sources
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    logger.info(f"Query intent: {intent}, sources: {sources}")

    # Retrieve from all sources concurrently, then fuse with RRF.
    results = await retriever.retrieve(request.question, sources, request.k)
    merged_items = retriever.merge_results(results, max_results=request.k * 2)

    # Generate visualization config if requested
    visualization = None
    if request.include_visualization and viz_selector and merged_items:
        # Describe the result schema by the first item's public fields
        # (internal bookkeeping fields start with "_").
        schema_fields = list(merged_items[0].keys())
        schema_str = ", ".join(f for f in schema_fields if not f.startswith("_"))
        visualization = viz_selector.select(
            request.question,
            schema_str,
            len(merged_items),
        )

    elapsed_ms = (asyncio.get_running_loop().time() - start_time) * 1000

    response_data = {
        "question": request.question,
        "sparql": None,  # Could be populated from SPARQL result
        "results": merged_items,
        "visualization": visualization,
        "sources_used": list(sources),
        "cache_hit": False,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged_items),
    }

    # Cache the response (best effort; failures are logged by the client).
    await retriever.cache.set(request.question, request.sources, response_data)

    return QueryResponse(**response_data)
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
async def generate_sparql_endpoint(request: SPARQLRequest):
    """Generate SPARQL query from natural language.

    Uses DSPy with optional RAG enhancement for context.

    Raises:
        HTTPException: 503 when the generator is unavailable, 500 on failure.
    """
    if not RETRIEVERS_AVAILABLE:
        raise HTTPException(status_code=503, detail="SPARQL generator not available")

    try:
        result = generate_sparql(
            request.question,
            language=request.language,
            context=request.context,
            use_rag=request.use_rag,
        )
        return SPARQLResponse(
            sparql=result["sparql"],
            explanation=result.get("explanation", ""),
            rag_used=result.get("rag_used", False),
            retrieved_passages=result.get("retrieved_passages", []),
        )
    except Exception as e:
        logger.exception("SPARQL generation failed")
        # Chain the cause so tracebacks stay informative.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/api/rag/visualize")
async def get_visualization_config(
    question: str = Query(..., description="User's question"),
    schema: str = Query(..., description="Comma-separated field names"),
    result_count: int = Query(default=0, description="Number of results"),
):
    """Get visualization configuration for a query."""
    if not viz_selector:
        raise HTTPException(status_code=503, detail="Visualization selector not available")
    return viz_selector.select(question, schema, result_count)


async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Stream an NDJSON query response for long-running queries.

    Yields status lines, one "partial" line per source, and a final
    "complete" line with the merged results.
    """
    if not retriever:
        yield json.dumps({"error": "Retriever not initialized"})
        return

    start_time = asyncio.get_running_loop().time()

    # Route query
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    yield json.dumps({
        "type": "status",
        "message": f"Routing query to {len(sources)} sources...",
        "intent": intent.value,
    }) + "\n"

    # Query sources one at a time so progress can be streamed.
    results = []
    for source in sources:
        yield json.dumps({
            "type": "status",
            "message": f"Querying {source.value}...",
        }) + "\n"

        source_results = await retriever.retrieve(request.question, [source], request.k)
        results.extend(source_results)

        yield json.dumps({
            "type": "partial",
            "source": source.value,
            "count": len(source_results[0].items) if source_results else 0,
        }) + "\n"

    # Merge and finalize
    merged = retriever.merge_results(results)
    elapsed_ms = (asyncio.get_running_loop().time() - start_time) * 1000

    yield json.dumps({
        "type": "complete",
        "results": merged,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged),
    }) + "\n"
@app.post("/api/rag/query/stream")
async def query_rag_stream(request: QueryRequest):
    """Streaming (NDJSON) version of the RAG query endpoint."""
    stream = stream_query_response(request)
    return StreamingResponse(stream, media_type="application/x-ndjson")


# Main entry point
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8002,
        reload=settings.debug,
        log_level="info",
    )