from __future__ import annotations """ Unified RAG Backend for Heritage Custodian Data Multi-source retrieval-augmented generation system that combines: - Qdrant vector search (semantic similarity) - Oxigraph SPARQL (knowledge graph queries) - TypeDB (relationship traversal) - PostGIS (geospatial queries) - Valkey (semantic caching) Architecture: User Query → Query Analysis ↓ ┌─────┴─────┐ │ Router │ └─────┬─────┘ ┌─────┬─────┼─────┬─────┐ ↓ ↓ ↓ ↓ ↓ Qdrant SPARQL TypeDB PostGIS Cache │ │ │ │ │ └─────┴─────┴─────┴─────┘ ↓ ┌─────┴─────┐ │ Merger │ └─────┬─────┘ ↓ DSPy Generator ↓ Visualization Selector ↓ Response (JSON/Streaming) Features: - Intelligent query routing to appropriate data sources - Score fusion for multi-source results - Semantic caching via Valkey API - Streaming responses for long-running queries - DSPy assertions for output validation Endpoints: - POST /api/rag/query - Main RAG query endpoint - POST /api/rag/sparql - Generate SPARQL with RAG context - POST /api/rag/typedb/search - Direct TypeDB search - GET /api/rag/health - Health check for all services - GET /api/rag/stats - Retriever statistics """ import asyncio import hashlib import json import logging import os from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum from typing import Any, AsyncIterator, TYPE_CHECKING import httpx from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Type hints for optional imports (only used during type checking) if TYPE_CHECKING: from glam_extractor.api.hybrid_retriever import HybridRetriever from glam_extractor.api.typedb_retriever import TypeDBRetriever from glam_extractor.api.visualization import VisualizationSelector # Import retrievers (with graceful fallbacks) RETRIEVERS_AVAILABLE = False create_hybrid_retriever: Any = None HeritageCustodianRetriever: Any = None create_typedb_retriever: Any = None select_visualization: Any = None VisualizationSelector: Any = None # type: ignore[no-redef] generate_sparql: Any = None configure_dspy: Any = None try: import sys sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src"))) from glam_extractor.api.hybrid_retriever import HybridRetriever as _HybridRetriever, create_hybrid_retriever as _create_hybrid_retriever from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever as _HeritageCustodianRetriever from glam_extractor.api.typedb_retriever import TypeDBRetriever as _TypeDBRetriever, create_typedb_retriever as _create_typedb_retriever from glam_extractor.api.visualization import select_visualization as _select_visualization, VisualizationSelector as _VisualizationSelector # Assign to module-level variables create_hybrid_retriever = _create_hybrid_retriever HeritageCustodianRetriever = _HeritageCustodianRetriever create_typedb_retriever = _create_typedb_retriever select_visualization = _select_visualization VisualizationSelector = _VisualizationSelector RETRIEVERS_AVAILABLE = True except ImportError as e: logger.warning(f"Core retrievers not available: {e}") # DSPy is optional - don't block retrievers if it's missing try: from glam_extractor.api.dspy_sparql import generate_sparql as _generate_sparql, configure_dspy as _configure_dspy generate_sparql = _generate_sparql configure_dspy = _configure_dspy except ImportError as e: logger.warning(f"DSPy SPARQL not available: {e}") # Cost tracker is optional - gracefully degrades if unavailable COST_TRACKER_AVAILABLE = False get_tracker = None reset_tracker = None try: from cost_tracker import get_tracker as _get_tracker, reset_tracker as _reset_tracker get_tracker = _get_tracker reset_tracker = _reset_tracker COST_TRACKER_AVAILABLE = True logger.info("Cost tracker module loaded successfully") except ImportError as e: logger.info(f"Cost tracker not available (optional): {e}") # Configuration class Settings: """Application settings from environment variables.""" # API Configuration api_title: str = "Heritage RAG API" api_version: str = "1.0.0" debug: bool = os.getenv("DEBUG", "false").lower() == "true" # Valkey Cache valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache") cache_ttl: int = int(os.getenv("CACHE_TTL", "900")) # 15 minutes # Qdrant Vector DB # Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy qdrant_host: str = os.getenv("QDRANT_HOST", "localhost") qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333")) qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true" qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant") # Multi-Embedding Support # Enable to use named vectors with multiple embedding models (OpenAI 1536, MiniLM 384, BGE 768) use_multi_embedding: bool = os.getenv("USE_MULTI_EMBEDDING", "true").lower() == "true" preferred_embedding_model: str | None = os.getenv("PREFERRED_EMBEDDING_MODEL", None) # e.g., "minilm_384" or "openai_1536" # Oxigraph SPARQL # Production: Use bronhouder.nl/sparql reverse proxy sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql") # TypeDB # Note: TypeDB not exposed via reverse proxy - always use localhost typedb_host: str = os.getenv("TYPEDB_HOST", "localhost") typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729")) typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians") typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true" # Default off # PostGIS/Geo API # Production: Use bronhouder.nl/api/geo reverse proxy postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo") # LLM Configuration anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "") openai_api_key: str = os.getenv("OPENAI_API_KEY", "") default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101") # Retrieval weights vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5")) graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3")) typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2")) settings = Settings() # Enums and Models class QueryIntent(str, Enum): """Detected query intent for routing.""" GEOGRAPHIC = "geographic" # Location-based queries STATISTICAL = "statistical" # Counts, aggregations RELATIONAL = "relational" # Relationships between entities TEMPORAL = "temporal" # Historical, timeline queries SEARCH = "search" # General text search DETAIL = "detail" # Specific entity lookup class DataSource(str, Enum): """Available data sources.""" QDRANT = "qdrant" SPARQL = "sparql" TYPEDB = "typedb" POSTGIS = "postgis" CACHE = "cache" @dataclass class RetrievalResult: """Result from a single retriever.""" source: DataSource items: list[dict[str, Any]] score: float = 0.0 query_time_ms: float = 0.0 metadata: dict[str, Any] = field(default_factory=dict) class QueryRequest(BaseModel): """RAG query request.""" question: str = Field(..., description="Natural language question") language: str = Field(default="nl", description="Language code (nl or en)") context: list[dict[str, Any]] = Field(default=[], description="Conversation history") sources: list[DataSource] = Field( default=[DataSource.QDRANT, DataSource.SPARQL], description="Data sources to query", ) k: int = Field(default=10, description="Number of results per source") include_visualization: bool = Field(default=True, description="Include visualization config") embedding_model: str | None = Field( default=None, description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available." ) stream: bool = Field(default=False, description="Stream response") embedding_model: str | None = Field( default=None, description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available." ) class QueryResponse(BaseModel): """RAG query response.""" question: str sparql: str | None = None results: list[dict[str, Any]] visualization: dict[str, Any] | None = None sources_used: list[DataSource] cache_hit: bool = False query_time_ms: float result_count: int class SPARQLRequest(BaseModel): """SPARQL generation request.""" question: str language: str = "nl" context: list[dict[str, Any]] = [] use_rag: bool = True class SPARQLResponse(BaseModel): """SPARQL generation response.""" sparql: str explanation: str rag_used: bool retrieved_passages: list[str] = [] class TypeDBSearchRequest(BaseModel): """TypeDB search request.""" query: str = Field(..., description="Search query (name, type, or location)") search_type: str = Field( default="semantic", description="Search type: semantic, name, type, or location" ) k: int = Field(default=10, ge=1, le=100, description="Number of results") class TypeDBSearchResponse(BaseModel): """TypeDB search response.""" query: str search_type: str results: list[dict[str, Any]] result_count: int query_time_ms: float class PersonSearchRequest(BaseModel): """Person/staff search request.""" query: str = Field(..., description="Search query for person/staff (e.g., 'Wie werkt er in het Nationaal Archief?')") k: int = Field(default=10, ge=1, le=100, description="Number of results to return") filter_custodian: str | None = Field(default=None, description="Filter by custodian slug (e.g., 'nationaal-archief')") only_heritage_relevant: bool = Field(default=False, description="Only return heritage-relevant staff") embedding_model: str | None = Field( default=None, description="Embedding model to use (e.g., 'minilm_384', 'openai_1536'). If None, auto-selects best available." ) class PersonSearchResponse(BaseModel): """Person/staff search response.""" query: str results: list[dict[str, Any]] result_count: int query_time_ms: float collection_stats: dict[str, Any] | None = None embedding_model_used: str | None = None class DSPyQueryRequest(BaseModel): """DSPy RAG query request with conversation support.""" question: str = Field(..., description="Natural language question") language: str = Field(default="nl", description="Language code (nl or en)") context: list[dict[str, Any]] = Field( default=[], description="Conversation history as list of {question, answer} dicts" ) include_visualization: bool = Field(default=True, description="Include visualization config") embedding_model: str | None = Field( default=None, description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available." ) class DSPyQueryResponse(BaseModel): """DSPy RAG query response.""" question: str resolved_question: str | None = None answer: str sources_used: list[str] = [] visualization: dict[str, Any] | None = None retrieved_results: list[dict[str, Any]] | None = None # Raw retrieved data for frontend visualization query_type: str | None = None # "person" or "institution" - helps frontend choose visualization query_time_ms: float = 0.0 conversation_turn: int = 0 embedding_model_used: str | None = None # Which embedding model was used for the search # Cost tracking fields (from cost_tracker module) timing_ms: float | None = None # Total pipeline timing from cost tracker cost_usd: float | None = None # Estimated LLM cost in USD timing_breakdown: dict[str, float] | None = None # Per-stage timing breakdown # Cache Client class ValkeyClient: """Client for Valkey semantic cache API.""" def __init__(self, base_url: str = settings.valkey_api_url): self.base_url = base_url.rstrip("/") self._client: httpx.AsyncClient | None = None @property async def client(self) -> httpx.AsyncClient: """Get or create async HTTP client.""" if self._client is None or self._client.is_closed: self._client = httpx.AsyncClient(timeout=30.0) return self._client def _cache_key(self, question: str, sources: list[DataSource]) -> str: """Generate cache key from question and sources.""" sources_str = ",".join(sorted(s.value for s in sources)) key_str = f"{question.lower().strip()}:{sources_str}" return hashlib.sha256(key_str.encode()).hexdigest()[:32] async def get(self, question: str, sources: list[DataSource]) -> dict[str, Any] | None: """Get cached response.""" try: key = self._cache_key(question, sources) client = await self.client response = await client.get(f"{self.base_url}/get/{key}") if response.status_code == 200: data = response.json() if data.get("value"): logger.info(f"Cache hit for question: {question[:50]}...") return json.loads(data["value"]) # type: ignore[no-any-return] return None except Exception as e: logger.warning(f"Cache get failed: {e}") return None async def set( self, question: str, sources: list[DataSource], response: dict[str, Any], ttl: int = settings.cache_ttl, ) -> bool: """Cache response.""" try: key = self._cache_key(question, sources) client = await self.client await client.post( f"{self.base_url}/set", json={ "key": key, "value": json.dumps(response), "ttl": ttl, }, ) logger.debug(f"Cached response for: {question[:50]}...") return True except Exception as e: logger.warning(f"Cache set failed: {e}") return False async def close(self) -> None: """Close HTTP client.""" if self._client: await self._client.aclose() self._client = None # Query Router class QueryRouter: """Routes queries to appropriate data sources based on intent.""" def __init__(self) -> None: self.intent_keywords = { QueryIntent.GEOGRAPHIC: [ "map", "kaart", "where", "waar", "location", "locatie", "city", "stad", "country", "land", "region", "gebied", "coordinates", "coördinaten", "near", "nearby", "in de buurt", ], QueryIntent.STATISTICAL: [ "how many", "hoeveel", "count", "aantal", "total", "totaal", "average", "gemiddeld", "distribution", "verdeling", "percentage", "statistics", "statistiek", "most", "meest", ], QueryIntent.RELATIONAL: [ "related", "gerelateerd", "connected", "verbonden", "relationship", "relatie", "network", "netwerk", "parent", "child", "merged", "fusie", "member of", ], QueryIntent.TEMPORAL: [ "history", "geschiedenis", "timeline", "tijdlijn", "when", "wanneer", "founded", "opgericht", "closed", "gesloten", "over time", "evolution", "change", "verandering", ], QueryIntent.DETAIL: [ "details", "information", "informatie", "about", "over", "specific", "specifiek", "what is", "wat is", ], } self.source_routing = { QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL], QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT], QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL], QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL], QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL], QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT], } def detect_intent(self, question: str) -> QueryIntent: """Detect query intent from question text.""" question_lower = question.lower() intent_scores = {intent: 0 for intent in QueryIntent} for intent, keywords in self.intent_keywords.items(): for keyword in keywords: if keyword in question_lower: intent_scores[intent] += 1 max_intent = max(intent_scores, key=intent_scores.get) # type: ignore if intent_scores[max_intent] == 0: return QueryIntent.SEARCH return max_intent def get_sources( self, question: str, requested_sources: list[DataSource] | None = None, ) -> tuple[QueryIntent, list[DataSource]]: """Get optimal sources for a query. Args: question: User's question requested_sources: Explicitly requested sources (overrides routing) Returns: Tuple of (detected_intent, list_of_sources) """ intent = self.detect_intent(question) if requested_sources: return intent, requested_sources return intent, self.source_routing.get(intent, [DataSource.QDRANT]) # Multi-Source Retriever class MultiSourceRetriever: """Orchestrates retrieval across multiple data sources.""" def __init__(self) -> None: self.cache = ValkeyClient() self.router = QueryRouter() # Initialize retrievers lazily self._qdrant: HybridRetriever | None = None self._typedb: TypeDBRetriever | None = None self._sparql_client: httpx.AsyncClient | None = None self._postgis_client: httpx.AsyncClient | None = None @property def qdrant(self) -> HybridRetriever | None: """Lazy-load Qdrant hybrid retriever with multi-embedding support.""" if self._qdrant is None and RETRIEVERS_AVAILABLE: try: self._qdrant = create_hybrid_retriever( use_production=settings.qdrant_use_production, use_multi_embedding=settings.use_multi_embedding, preferred_embedding_model=settings.preferred_embedding_model, ) except Exception as e: logger.warning(f"Failed to initialize Qdrant: {e}") return self._qdrant @property def typedb(self) -> TypeDBRetriever | None: """Lazy-load TypeDB retriever.""" if self._typedb is None and RETRIEVERS_AVAILABLE: try: self._typedb = create_typedb_retriever( use_production=settings.typedb_use_production # Use TypeDB-specific setting ) except Exception as e: logger.warning(f"Failed to initialize TypeDB: {e}") return self._typedb async def _get_sparql_client(self) -> httpx.AsyncClient: """Get SPARQL HTTP client.""" if self._sparql_client is None or self._sparql_client.is_closed: self._sparql_client = httpx.AsyncClient(timeout=30.0) return self._sparql_client async def _get_postgis_client(self) -> httpx.AsyncClient: """Get PostGIS HTTP client.""" if self._postgis_client is None or self._postgis_client.is_closed: self._postgis_client = httpx.AsyncClient(timeout=30.0) return self._postgis_client async def retrieve_from_qdrant( self, query: str, k: int = 10, embedding_model: str | None = None, ) -> RetrievalResult: """Retrieve from Qdrant vector + SPARQL hybrid search. Args: query: Search query k: Number of results to return embedding_model: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536') """ start = asyncio.get_event_loop().time() items = [] if self.qdrant: try: results = self.qdrant.search(query, k=k, using=embedding_model) items = [r.to_dict() for r in results] except Exception as e: logger.error(f"Qdrant retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.QDRANT, items=items, score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0), query_time_ms=elapsed, ) async def retrieve_from_sparql( self, query: str, k: int = 10, ) -> RetrievalResult: """Retrieve from SPARQL endpoint.""" start = asyncio.get_event_loop().time() # Use DSPy to generate SPARQL items = [] try: if RETRIEVERS_AVAILABLE: sparql_result = generate_sparql(query, language="nl", use_rag=False) sparql_query = sparql_result.get("sparql", "") if sparql_query: client = await self._get_sparql_client() response = await client.post( settings.sparql_endpoint, data={"query": sparql_query}, headers={"Accept": "application/sparql-results+json"}, ) if response.status_code == 200: data = response.json() bindings = data.get("results", {}).get("bindings", []) items = [ {k: v.get("value") for k, v in b.items()} for b in bindings[:k] ] except Exception as e: logger.error(f"SPARQL retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.SPARQL, items=items, score=1.0 if items else 0.0, query_time_ms=elapsed, ) async def retrieve_from_typedb( self, query: str, k: int = 10, ) -> RetrievalResult: """Retrieve from TypeDB knowledge graph.""" start = asyncio.get_event_loop().time() items = [] if self.typedb: try: results = self.typedb.semantic_search(query, k=k) items = [r.to_dict() for r in results] except Exception as e: logger.error(f"TypeDB retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.TYPEDB, items=items, score=max((r.get("relevance_score", 0) for r in items), default=0), query_time_ms=elapsed, ) async def retrieve_from_postgis( self, query: str, k: int = 10, ) -> RetrievalResult: """Retrieve from PostGIS geospatial database.""" start = asyncio.get_event_loop().time() # Extract location from query for geospatial search # This is a simplified implementation items = [] try: client = await self._get_postgis_client() # Try to detect city name for bbox search query_lower = query.lower() # Simple city detection cities = { "amsterdam": {"lat": 52.3676, "lon": 4.9041}, "rotterdam": {"lat": 51.9244, "lon": 4.4777}, "den haag": {"lat": 52.0705, "lon": 4.3007}, "utrecht": {"lat": 52.0907, "lon": 5.1214}, } for city, coords in cities.items(): if city in query_lower: # Query PostGIS for nearby institutions response = await client.get( f"{settings.postgis_url}/api/institutions/nearby", params={ "lat": coords["lat"], "lon": coords["lon"], "radius_km": 10, "limit": k, }, ) if response.status_code == 200: items = response.json() break except Exception as e: logger.error(f"PostGIS retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.POSTGIS, items=items, score=1.0 if items else 0.0, query_time_ms=elapsed, ) async def retrieve( self, question: str, sources: list[DataSource], k: int = 10, embedding_model: str | None = None, ) -> list[RetrievalResult]: """Retrieve from multiple sources concurrently. Args: question: User's question sources: Data sources to query k: Number of results per source embedding_model: Optional embedding model for Qdrant (e.g., 'minilm_384', 'openai_1536') Returns: List of RetrievalResult from each source """ tasks = [] for source in sources: if source == DataSource.QDRANT: tasks.append(self.retrieve_from_qdrant(question, k, embedding_model)) elif source == DataSource.SPARQL: tasks.append(self.retrieve_from_sparql(question, k)) elif source == DataSource.TYPEDB: tasks.append(self.retrieve_from_typedb(question, k)) elif source == DataSource.POSTGIS: tasks.append(self.retrieve_from_postgis(question, k)) results = await asyncio.gather(*tasks, return_exceptions=True) # Filter out exceptions valid_results = [] for r in results: if isinstance(r, RetrievalResult): valid_results.append(r) elif isinstance(r, Exception): logger.error(f"Retrieval task failed: {r}") return valid_results def merge_results( self, results: list[RetrievalResult], max_results: int = 20, ) -> list[dict[str, Any]]: """Merge and deduplicate results from multiple sources. Uses reciprocal rank fusion for score combination. """ # Track items by GHCID for deduplication merged: dict[str, dict[str, Any]] = {} for result in results: for rank, item in enumerate(result.items): ghcid = item.get("ghcid", item.get("id", f"unknown_{rank}")) if ghcid not in merged: merged[ghcid] = item.copy() merged[ghcid]["_sources"] = [] merged[ghcid]["_rrf_score"] = 0.0 # Reciprocal Rank Fusion rrf_score = 1.0 / (60 + rank) # k=60 is standard # Weight by source source_weights = { DataSource.QDRANT: settings.vector_weight, DataSource.SPARQL: settings.graph_weight, DataSource.TYPEDB: settings.typedb_weight, DataSource.POSTGIS: 0.3, } weight = source_weights.get(result.source, 0.5) merged[ghcid]["_rrf_score"] += rrf_score * weight merged[ghcid]["_sources"].append(result.source.value) # Sort by RRF score sorted_items = sorted( merged.values(), key=lambda x: x.get("_rrf_score", 0), reverse=True, ) return sorted_items[:max_results] async def close(self) -> None: """Clean up resources.""" await self.cache.close() if self._sparql_client: await self._sparql_client.aclose() if self._postgis_client: await self._postgis_client.aclose() if self._qdrant: self._qdrant.close() if self._typedb: self._typedb.close() def search_persons( self, query: str, k: int = 10, filter_custodian: str | None = None, only_heritage_relevant: bool = False, using: str | None = None, ) -> list[Any]: """Search for persons/staff in the heritage_persons collection. Delegates to HybridRetriever.search_persons() if available. Args: query: Search query k: Number of results filter_custodian: Optional custodian slug to filter by only_heritage_relevant: Only return heritage-relevant staff using: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536') Returns: List of RetrievedPerson objects """ if self.qdrant: try: return self.qdrant.search_persons( # type: ignore[no-any-return] query=query, k=k, filter_custodian=filter_custodian, only_heritage_relevant=only_heritage_relevant, using=using, ) except Exception as e: logger.error(f"Person search failed: {e}") return [] def get_stats(self) -> dict[str, Any]: """Get statistics from all retrievers. Returns combined stats from Qdrant (including persons collection) and TypeDB. """ stats = {} if self.qdrant: try: qdrant_stats = self.qdrant.get_stats() stats.update(qdrant_stats) except Exception as e: logger.warning(f"Failed to get Qdrant stats: {e}") if self.typedb: try: typedb_stats = self.typedb.get_stats() stats["typedb"] = typedb_stats except Exception as e: logger.warning(f"Failed to get TypeDB stats: {e}") return stats # Global instances retriever: MultiSourceRetriever | None = None viz_selector: VisualizationSelector | None = None @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[None]: """Application lifespan manager.""" global retriever, viz_selector # Startup logger.info("Starting Heritage RAG API...") retriever = MultiSourceRetriever() if RETRIEVERS_AVAILABLE: # Check for any available LLM API key (Anthropic preferred, OpenAI fallback) has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key) # VisualizationSelector requires DSPy - make it optional try: viz_selector = VisualizationSelector(use_dspy=has_llm_key) except RuntimeError as e: logger.warning(f"VisualizationSelector not available: {e}") viz_selector = None # Configure DSPy if API key available if configure_dspy and settings.anthropic_api_key: try: configure_dspy( provider="anthropic", model=settings.default_model, api_key=settings.anthropic_api_key, ) except Exception as e: logger.warning(f"Failed to configure DSPy with Anthropic: {e}") elif configure_dspy and settings.openai_api_key: try: configure_dspy( provider="openai", model="gpt-4o-mini", api_key=settings.openai_api_key, ) except Exception as e: logger.warning(f"Failed to configure DSPy with OpenAI: {e}") logger.info("Heritage RAG API started") yield # Shutdown logger.info("Shutting down Heritage RAG API...") if retriever: await retriever.close() logger.info("Heritage RAG API stopped") # Create FastAPI app app = FastAPI( title=settings.api_title, version=settings.api_version, lifespan=lifespan, ) # CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # API Endpoints @app.get("/api/rag/health") async def health_check() -> dict[str, Any]: """Health check for all services.""" health: dict[str, Any] = { "status": "ok", "timestamp": datetime.now(timezone.utc).isoformat(), "services": {}, } # Check Qdrant if retriever and retriever.qdrant: try: stats = retriever.qdrant.get_stats() health["services"]["qdrant"] = { "status": "ok", "vectors": stats.get("qdrant", {}).get("vectors_count", 0), } except Exception as e: health["services"]["qdrant"] = {"status": "error", "error": str(e)} # Check SPARQL try: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"{settings.sparql_endpoint.replace('/query', '')}") health["services"]["sparql"] = { "status": "ok" if response.status_code < 500 else "error" } except Exception as e: health["services"]["sparql"] = {"status": "error", "error": str(e)} # Check TypeDB if retriever and retriever.typedb: try: stats = retriever.typedb.get_stats() health["services"]["typedb"] = { "status": "ok", "entities": stats.get("entities", {}), } except Exception as e: health["services"]["typedb"] = {"status": "error", "error": str(e)} # Overall status services = health["services"] errors = sum(1 for s in services.values() if isinstance(s, dict) and s.get("status") == "error") health["status"] = "ok" if errors == 0 else "degraded" if errors < 3 else "error" return health @app.get("/api/rag/stats") async def get_stats() -> dict[str, Any]: """Get retriever statistics.""" stats: dict[str, Any] = { "timestamp": datetime.now(timezone.utc).isoformat(), "retrievers": {}, } if retriever: if retriever.qdrant: stats["retrievers"]["qdrant"] = retriever.qdrant.get_stats() if retriever.typedb: stats["retrievers"]["typedb"] = retriever.typedb.get_stats() return stats @app.get("/api/rag/stats/costs") async def get_cost_stats() -> dict[str, Any]: """Get cost tracking session statistics. Returns cumulative statistics for the current session including: - Total LLM calls and token usage - Total retrieval operations and latencies - Estimated costs by model - Pipeline timing statistics Returns: Dict with cost tracker statistics or unavailable message """ if not COST_TRACKER_AVAILABLE or not get_tracker: return { "available": False, "message": "Cost tracker module not available", } tracker = get_tracker() return { "available": True, "timestamp": datetime.now(timezone.utc).isoformat(), "session": tracker.get_session_summary(), } @app.post("/api/rag/stats/costs/reset") async def reset_cost_stats() -> dict[str, Any]: """Reset cost tracking statistics. Clears all accumulated statistics and starts a fresh session. Useful for per-conversation or per-session cost tracking. Returns: Confirmation message """ if not COST_TRACKER_AVAILABLE or not reset_tracker: return { "available": False, "message": "Cost tracker module not available", } reset_tracker() return { "available": True, "message": "Cost tracking statistics reset", "timestamp": datetime.now(timezone.utc).isoformat(), } @app.get("/api/rag/embedding/models") async def get_embedding_models() -> dict[str, Any]: """List available embedding models for the Qdrant collections. Returns information about which embedding models are available in each collection's named vectors, helping clients choose the right model for their use case. Returns: Dict with available models per collection, current settings, and recommendations """ result: dict[str, Any] = { "timestamp": datetime.now(timezone.utc).isoformat(), "multi_embedding_enabled": settings.use_multi_embedding, "preferred_model": settings.preferred_embedding_model, "collections": {}, "models": { "openai_1536": { "description": "OpenAI text-embedding-3-small (1536 dimensions)", "quality": "high", "cost": "paid API", "recommended_for": "production, high-quality semantic search", }, "minilm_384": { "description": "sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)", "quality": "good", "cost": "free (local)", "recommended_for": "development, cost-sensitive deployments", }, "bge_768": { "description": "BAAI/bge-small-en-v1.5 (768 dimensions)", "quality": "very good", "cost": "free (local)", "recommended_for": "balanced quality/cost, multilingual support", }, }, } if retriever and retriever.qdrant: qdrant = retriever.qdrant # Check if multi-embedding is enabled and get available models if hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding: if hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever: multi = qdrant.multi_retriever # Get available models for institutions collection try: inst_models = multi.get_available_models("heritage_custodians") selected = multi.select_model("heritage_custodians") result["collections"]["heritage_custodians"] = { "available_models": [m.value for m in inst_models], "uses_named_vectors": multi.uses_named_vectors("heritage_custodians"), "recommended": selected.value if selected else None, } except Exception as e: result["collections"]["heritage_custodians"] = {"error": str(e)} # Get available models for persons collection try: person_models = multi.get_available_models("heritage_persons") selected = multi.select_model("heritage_persons") result["collections"]["heritage_persons"] = { "available_models": [m.value for m in person_models], "uses_named_vectors": multi.uses_named_vectors("heritage_persons"), "recommended": selected.value if selected else None, } except Exception as e: result["collections"]["heritage_persons"] = {"error": str(e)} else: # Single embedding mode - detect dimension stats = qdrant.get_stats() result["single_embedding_mode"] = True result["note"] = "Collections use single embedding vectors. Enable USE_MULTI_EMBEDDING=true to use named vectors." return result class EmbeddingCompareRequest(BaseModel): """Request for comparing embedding models.""" query: str = Field(..., description="Query to search with") collection: str = Field(default="heritage_persons", description="Collection to search") k: int = Field(default=5, ge=1, le=20, description="Number of results per model") @app.post("/api/rag/embedding/compare") async def compare_embedding_models(request: EmbeddingCompareRequest) -> dict[str, Any]: """Compare search results across different embedding models. Performs the same search query using each available embedding model, allowing A/B testing of embedding quality. This endpoint is useful for: - Evaluating which embedding model works best for your queries - Understanding differences in semantic similarity between models - Making informed decisions about which model to use in production Returns: Dict with results from each embedding model, including scores and overlap analysis """ import time start_time = time.time() if not retriever or not retriever.qdrant: raise HTTPException(status_code=503, detail="Qdrant retriever not available") qdrant = retriever.qdrant # Check if multi-embedding is available if not (hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding): raise HTTPException( status_code=400, detail="Multi-embedding mode not enabled. Set USE_MULTI_EMBEDDING=true to use this endpoint." ) if not (hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever): raise HTTPException(status_code=503, detail="Multi-embedding retriever not initialized") multi = qdrant.multi_retriever try: # Use the compare_models method from MultiEmbeddingRetriever comparison = multi.compare_models( query=request.query, collection=request.collection, k=request.k, ) elapsed_ms = (time.time() - start_time) * 1000 return { "query": request.query, "collection": request.collection, "k": request.k, "query_time_ms": round(elapsed_ms, 2), "comparison": comparison, } except Exception as e: logger.exception(f"Embedding comparison failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/rag/query", response_model=QueryResponse) async def query_rag(request: QueryRequest) -> QueryResponse: """Main RAG query endpoint. Orchestrates retrieval from multiple sources, merges results, and optionally generates visualization configuration. """ if not retriever: raise HTTPException(status_code=503, detail="Retriever not initialized") start_time = asyncio.get_event_loop().time() # Check cache first cached = await retriever.cache.get(request.question, request.sources) if cached: cached["cache_hit"] = True return QueryResponse(**cached) # Route query to appropriate sources intent, sources = retriever.router.get_sources(request.question, request.sources) logger.info(f"Query intent: {intent}, sources: {sources}") # Retrieve from all sources results = await retriever.retrieve( request.question, sources, request.k, embedding_model=request.embedding_model, ) # Merge results merged_items = retriever.merge_results(results, max_results=request.k * 2) # Generate visualization config if requested visualization = None if request.include_visualization and viz_selector and merged_items: # Extract schema from first result schema_fields = list(merged_items[0].keys()) if merged_items else [] schema_str = ", ".join(f for f in schema_fields if not f.startswith("_")) visualization = viz_selector.select( request.question, schema_str, len(merged_items), ) elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000 response_data = { "question": request.question, "sparql": None, # Could be populated from SPARQL result "results": merged_items, "visualization": visualization, "sources_used": [s for s in sources], "cache_hit": False, "query_time_ms": round(elapsed_ms, 2), "result_count": len(merged_items), } # Cache the response await retriever.cache.set(request.question, request.sources, response_data) return QueryResponse(**response_data) # type: ignore[arg-type] @app.post("/api/rag/sparql", response_model=SPARQLResponse) async def generate_sparql_endpoint(request: SPARQLRequest) -> SPARQLResponse: """Generate SPARQL query from natural language. Uses DSPy with optional RAG enhancement for context. """ if not RETRIEVERS_AVAILABLE: raise HTTPException(status_code=503, detail="SPARQL generator not available") try: result = generate_sparql( request.question, language=request.language, context=request.context, use_rag=request.use_rag, ) return SPARQLResponse( sparql=result["sparql"], explanation=result.get("explanation", ""), rag_used=result.get("rag_used", False), retrieved_passages=result.get("retrieved_passages", []), ) except Exception as e: logger.exception("SPARQL generation failed") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/rag/visualize") async def get_visualization_config( question: str = Query(..., description="User's question"), schema: str = Query(..., description="Comma-separated field names"), result_count: int = Query(default=0, description="Number of results"), ) -> dict[str, Any]: """Get visualization configuration for a query.""" if not viz_selector: raise HTTPException(status_code=503, detail="Visualization selector not available") config = viz_selector.select(question, schema, result_count) return config # type: ignore[no-any-return] @app.post("/api/rag/typedb/search", response_model=TypeDBSearchResponse) async def typedb_search(request: TypeDBSearchRequest) -> TypeDBSearchResponse: """Direct TypeDB search endpoint. Search heritage custodians in TypeDB using various strategies: - semantic: Natural language search (combines type + location patterns) - name: Search by institution name - type: Search by institution type (museum, archive, library, gallery) - location: Search by city/location name Examples: - {"query": "museums in Amsterdam", "search_type": "semantic"} - {"query": "Rijksmuseum", "search_type": "name"} - {"query": "archive", "search_type": "type"} - {"query": "Rotterdam", "search_type": "location"} """ import time start_time = time.time() # Check if TypeDB retriever is available if not retriever or not retriever.typedb: raise HTTPException( status_code=503, detail="TypeDB retriever not available. Ensure TypeDB is running." ) try: typedb_retriever = retriever.typedb # Route to appropriate search method if request.search_type == "name": results = typedb_retriever.search_by_name(request.query, k=request.k) elif request.search_type == "type": results = typedb_retriever.search_by_type(request.query, k=request.k) elif request.search_type == "location": results = typedb_retriever.search_by_location(city=request.query, k=request.k) else: # semantic (default) results = typedb_retriever.semantic_search(request.query, k=request.k) # Convert results to dicts result_dicts = [] seen_names = set() # Deduplicate by name for r in results: # Handle both dict and object results if hasattr(r, 'to_dict'): item = r.to_dict() elif isinstance(r, dict): item = r else: item = {"name": str(r)} # Deduplicate by name name = item.get("name") or item.get("observed_name", "") if name and name not in seen_names: seen_names.add(name) result_dicts.append(item) elapsed_ms = (time.time() - start_time) * 1000 return TypeDBSearchResponse( query=request.query, search_type=request.search_type, results=result_dicts, result_count=len(result_dicts), query_time_ms=round(elapsed_ms, 2), ) except Exception as e: logger.exception(f"TypeDB search failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/rag/persons/search", response_model=PersonSearchResponse) async def person_search(request: PersonSearchRequest) -> PersonSearchResponse: """Search for persons/staff in heritage institutions. Search the heritage_persons Qdrant collection for staff members at heritage custodian institutions. Examples: - {"query": "Wie werkt er in het Nationaal Archief?"} - {"query": "archivist at Rijksmuseum", "k": 20} - {"query": "conservator", "filter_custodian": "rijksmuseum"} - {"query": "digital preservation", "only_heritage_relevant": true} The search uses semantic vector similarity to find relevant staff members based on their name, role, headline, and custodian affiliation. """ import time start_time = time.time() # Check if retriever is available if not retriever: raise HTTPException( status_code=503, detail="Hybrid retriever not available. Ensure Qdrant is running." ) try: # Use the hybrid retriever's person search results = retriever.search_persons( query=request.query, k=request.k, filter_custodian=request.filter_custodian, only_heritage_relevant=request.only_heritage_relevant, using=request.embedding_model, # Pass embedding model ) # Determine which embedding model was actually used embedding_model_used = None qdrant = retriever.qdrant if qdrant and hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding: if request.embedding_model: embedding_model_used = request.embedding_model elif hasattr(qdrant, '_selected_multi_model') and qdrant._selected_multi_model: embedding_model_used = qdrant._selected_multi_model.value # Convert results to dicts using to_dict() method if available result_dicts = [] for r in results: if hasattr(r, 'to_dict'): item = r.to_dict() elif hasattr(r, '__dict__'): item = { "name": getattr(r, 'name', 'Unknown'), "headline": getattr(r, 'headline', None), "custodian_name": getattr(r, 'custodian_name', None), "custodian_slug": getattr(r, 'custodian_slug', None), "linkedin_url": getattr(r, 'linkedin_url', None), "heritage_relevant": getattr(r, 'heritage_relevant', None), "heritage_type": getattr(r, 'heritage_type', None), "location": getattr(r, 'location', None), "score": getattr(r, 'combined_score', getattr(r, 'vector_score', None)), } elif isinstance(r, dict): item = r else: item = {"name": str(r)} result_dicts.append(item) elapsed_ms = (time.time() - start_time) * 1000 # Get collection stats stats = None try: stats = retriever.get_stats() # Only include person collection stats if available if stats and 'persons' in stats: stats = {'persons': stats['persons']} except Exception: pass return PersonSearchResponse( query=request.query, results=result_dicts, result_count=len(result_dicts), query_time_ms=round(elapsed_ms, 2), collection_stats=stats, embedding_model_used=embedding_model_used, ) except Exception as e: logger.exception(f"Person search failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse) async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse: """DSPy RAG query endpoint with multi-turn conversation support. Uses the HeritageRAGPipeline for conversation-aware question answering. Follow-up questions like "Welke daarvan behoren archieven?" will be resolved using previous conversation context. Args: request: Query request with question, language, and conversation context Returns: DSPyQueryResponse with answer, resolved question, and optional visualization """ import time start_time = time.time() try: # Import DSPy pipeline and History import dspy from dspy import History from dspy_heritage_rag import HeritageRAGPipeline # Ensure DSPy has an LM configured # Check if LM is already configured by testing if we can get the settings try: current_lm = dspy.settings.lm if current_lm is None: raise ValueError("No LM configured") except (AttributeError, ValueError): # No LM configured yet - try to configure one api_key = settings.anthropic_api_key or os.getenv("ANTHROPIC_API_KEY", "") if api_key: lm = dspy.LM("anthropic/claude-sonnet-4-20250514", api_key=api_key) dspy.configure(lm=lm) logger.info("Configured DSPy with Anthropic Claude") else: # Try OpenAI as fallback openai_key = os.getenv("OPENAI_API_KEY", "") if openai_key: lm = dspy.LM("openai/gpt-4o-mini", api_key=openai_key) dspy.configure(lm=lm) logger.info("Configured DSPy with OpenAI GPT-4o-mini") else: raise ValueError( "No LLM API key found. Set ANTHROPIC_API_KEY or OPENAI_API_KEY environment variable." ) # Convert context to DSPy History format # Context comes as [{question: "...", answer: "..."}, ...] # History expects messages in the same format: [{question: "...", answer: "..."}, ...] # (NOT role/content format - that was a bug!) history_messages = [] for turn in request.context: # Only include turns that have both question AND answer if turn.get("question") and turn.get("answer"): history_messages.append({ "question": turn["question"], "answer": turn["answer"] }) history = History(messages=history_messages) if history_messages else None # Initialize pipeline with retriever for actual data retrieval # Pass the qdrant retriever (HybridRetriever) for person/institution searches qdrant_retriever = retriever.qdrant if retriever else None pipeline = HeritageRAGPipeline(retriever=qdrant_retriever) # Execute query with conversation history # Retry logic for transient API errors (e.g., Anthropic "Overloaded" errors) max_retries = 3 last_error: Exception | None = None result = None for attempt in range(max_retries): try: result = pipeline.forward( embedding_model=request.embedding_model, question=request.question, language=request.language, history=history, include_viz=request.include_visualization, ) break # Success, exit retry loop except Exception as e: last_error = e error_str = str(e).lower() # Check for retryable errors (API overload, rate limits, temporary failures) is_retryable = any(keyword in error_str for keyword in [ "overloaded", "rate_limit", "rate limit", "too many requests", "529", "503", "502", "504", # HTTP status codes "temporarily unavailable", "service unavailable", "connection reset", "connection refused", "timeout" ]) if is_retryable and attempt < max_retries - 1: wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s logger.warning( f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}. " f"Retrying in {wait_time}s..." ) time.sleep(wait_time) continue else: # Non-retryable error or max retries reached raise # If we get here without a result (all retries exhausted), raise the last error if result is None: if last_error: raise last_error raise HTTPException(status_code=500, detail="Pipeline execution failed with no result") elapsed_ms = (time.time() - start_time) * 1000 # Extract visualization if present visualization = None if request.include_visualization and hasattr(result, "visualization"): viz = result.visualization if viz: visualization = { "type": getattr(viz, "viz_type", "table"), "sparql_query": getattr(result, "sparql", None), } # Extract retrieved results for frontend visualization (tables, graphs) retrieved_results = getattr(result, "retrieved_results", None) query_type = getattr(result, "query_type", None) return DSPyQueryResponse( question=request.question, resolved_question=getattr(result, "resolved_question", None), answer=getattr(result, "answer", "Geen antwoord gevonden."), sources_used=getattr(result, "sources_used", []), visualization=visualization, retrieved_results=retrieved_results, # Raw data for frontend visualization query_type=query_type, # "person" or "institution" query_time_ms=round(elapsed_ms, 2), conversation_turn=len(request.context), embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model), # Cost tracking fields timing_ms=getattr(result, "timing_ms", None), cost_usd=getattr(result, "cost_usd", None), timing_breakdown=getattr(result, "timing_breakdown", None), ) except ImportError as e: logger.warning(f"DSPy pipeline not available: {e}") # Fallback to simple response return DSPyQueryResponse( question=request.question, answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.", query_time_ms=0, conversation_turn=len(request.context), embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model), ) except Exception as e: logger.exception("DSPy query failed") raise HTTPException(status_code=500, detail=str(e)) async def stream_query_response( request: QueryRequest, ) -> AsyncIterator[str]: """Stream query response for long-running queries.""" if not retriever: yield json.dumps({"error": "Retriever not initialized"}) return start_time = asyncio.get_event_loop().time() # Route query intent, sources = retriever.router.get_sources(request.question, request.sources) yield json.dumps({ "type": "status", "message": f"Routing query to {len(sources)} sources...", "intent": intent.value, }) + "\n" # Retrieve from sources and stream progress results = [] for source in sources: yield json.dumps({ "type": "status", "message": f"Querying {source.value}...", }) + "\n" source_results = await retriever.retrieve( request.question, [source], request.k, embedding_model=request.embedding_model, ) results.extend(source_results) yield json.dumps({ "type": "partial", "source": source.value, "count": len(source_results[0].items) if source_results else 0, }) + "\n" # Merge and finalize merged = retriever.merge_results(results) elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000 yield json.dumps({ "type": "complete", "results": merged, "query_time_ms": round(elapsed_ms, 2), "result_count": len(merged), }) + "\n" @app.post("/api/rag/query/stream") async def query_rag_stream(request: QueryRequest) -> StreamingResponse: """Streaming version of RAG query endpoint.""" return StreamingResponse( stream_query_response(request), media_type="application/x-ndjson", ) # Main entry point if __name__ == "__main__": import uvicorn uvicorn.run( "main:app", host="0.0.0.0", port=8003, reload=settings.debug, log_level="info", )