from __future__ import annotations """ Unified RAG Backend for Heritage Custodian Data Multi-source retrieval-augmented generation system that combines: - Qdrant vector search (semantic similarity) - Oxigraph SPARQL (knowledge graph queries) - TypeDB (relationship traversal) - PostGIS (geospatial queries) - Valkey (semantic caching) Architecture: User Query → Query Analysis ↓ ┌─────┴─────┐ │ Router │ └─────┬─────┘ ┌─────┬─────┼─────┬─────┐ ↓ ↓ ↓ ↓ ↓ Qdrant SPARQL TypeDB PostGIS Cache │ │ │ │ │ └─────┴─────┴─────┴─────┘ ↓ ┌─────┴─────┐ │ Merger │ └─────┬─────┘ ↓ DSPy Generator ↓ Visualization Selector ↓ Response (JSON/Streaming) Features: - Intelligent query routing to appropriate data sources - Score fusion for multi-source results - Semantic caching via Valkey API - Streaming responses for long-running queries - DSPy assertions for output validation Endpoints: - POST /api/rag/query - Main RAG query endpoint - POST /api/rag/sparql - Generate SPARQL with RAG context - POST /api/rag/typedb/search - Direct TypeDB search - GET /api/rag/health - Health check for all services - GET /api/rag/stats - Retriever statistics """ import asyncio import hashlib import json import logging import os import time from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum from typing import Any, AsyncIterator, TYPE_CHECKING import httpx from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field # Rule 46: Epistemic Provenance Tracking from provenance import ( EpistemicProvenance, EpistemicDataSource, DataTier, RetrievalSource, SourceAttribution, infer_data_tier, build_derivation_chain, aggregate_data_tier, ) # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # 
# Type hints for optional imports (only used during type checking)
if TYPE_CHECKING:
    from glam_extractor.api.hybrid_retriever import HybridRetriever
    from glam_extractor.api.typedb_retriever import TypeDBRetriever
    from glam_extractor.api.visualization import VisualizationSelector

# Import retrievers (with graceful fallbacks).
# Every optional name is bound to a placeholder first so the module still
# imports when the glam_extractor package is unavailable; the try block below
# overwrites the placeholders when the import succeeds.
RETRIEVERS_AVAILABLE = False
create_hybrid_retriever: Any = None
HeritageCustodianRetriever: Any = None
create_typedb_retriever: Any = None
select_visualization: Any = None
VisualizationSelector: Any = None  # type: ignore[no-redef]
generate_sparql: Any = None
configure_dspy: Any = None
get_province_code: Any = None  # Province name to ISO 3166-2 code converter
# JSON-LD context for person results. Without this placeholder the name would
# only exist after a successful import, so any later reference would raise
# NameError instead of degrading like the other optional names.
PERSON_JSONLD_CONTEXT: Any = None

try:
    import sys

    # Make the repository's ``src`` directory (../../src relative to this file)
    # importable so ``glam_extractor`` can be resolved.
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
    from glam_extractor.api.hybrid_retriever import (
        HybridRetriever as _HybridRetriever,
        create_hybrid_retriever as _create_hybrid_retriever,
        get_province_code as _get_province_code,
        PERSON_JSONLD_CONTEXT,
    )
    from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever as _HeritageCustodianRetriever
    from glam_extractor.api.typedb_retriever import TypeDBRetriever as _TypeDBRetriever, create_typedb_retriever as _create_typedb_retriever
    from glam_extractor.api.visualization import select_visualization as _select_visualization, VisualizationSelector as _VisualizationSelector

    # Assign to module-level variables (only reached when every import above succeeded)
    create_hybrid_retriever = _create_hybrid_retriever
    HeritageCustodianRetriever = _HeritageCustodianRetriever
    create_typedb_retriever = _create_typedb_retriever
    select_visualization = _select_visualization
    VisualizationSelector = _VisualizationSelector
    get_province_code = _get_province_code
    RETRIEVERS_AVAILABLE = True
except ImportError as e:
    logger.warning(f"Core retrievers not available: {e}")

    # Provide a fallback get_province_code that returns None. It is defined
    # inside the except block on purpose: at module level it would replace the
    # real converter even after a successful import.
    def get_province_code(province_name: str | None) -> str | None:
        """Fallback when hybrid_retriever is not available."""
        return None
DSPy is optional - don't block retrievers if it's missing try: from glam_extractor.api.dspy_sparql import generate_sparql as _generate_sparql, configure_dspy as _configure_dspy generate_sparql = _generate_sparql configure_dspy = _configure_dspy except ImportError as e: logger.warning(f"DSPy SPARQL not available: {e}") # Atomic query decomposition for geographic/type filtering and sub-task caching decompose_query: Any = None AtomicCacheManager: Any = None DECOMPOSER_AVAILABLE = False ATOMIC_CACHE_AVAILABLE = False try: from atomic_decomposer import ( decompose_query as _decompose_query, AtomicCacheManager as _AtomicCacheManager, ) decompose_query = _decompose_query AtomicCacheManager = _AtomicCacheManager DECOMPOSER_AVAILABLE = True ATOMIC_CACHE_AVAILABLE = True logger.info("Query decomposer and AtomicCacheManager loaded successfully") except ImportError as e: logger.info(f"Query decomposer not available: {e}") # Cost tracker is optional - gracefully degrades if unavailable COST_TRACKER_AVAILABLE = False get_tracker = None reset_tracker = None try: from cost_tracker import get_tracker as _get_tracker, reset_tracker as _reset_tracker get_tracker = _get_tracker reset_tracker = _reset_tracker COST_TRACKER_AVAILABLE = True logger.info("Cost tracker module loaded successfully") except ImportError as e: logger.info(f"Cost tracker not available (optional): {e}") # Session manager for multi-turn conversation state SESSION_MANAGER_AVAILABLE = False get_session_manager = None shutdown_session_manager = None ConversationState = None try: from session_manager import ( get_session_manager as _get_session_manager, shutdown_session_manager as _shutdown_session_manager, ConversationState as _ConversationState, ) get_session_manager = _get_session_manager shutdown_session_manager = _shutdown_session_manager ConversationState = _ConversationState SESSION_MANAGER_AVAILABLE = True logger.info("Session manager module loaded successfully") except ImportError as e: logger.info(f"Session 
manager not available (optional): {e}") # Template-based SPARQL pipeline (deterministic, validated queries) # This provides 65% precision vs 10% for LLM-only SPARQL generation TEMPLATE_SPARQL_AVAILABLE = False TemplateSPARQLPipeline: Any = None get_template_pipeline: Any = None _template_pipeline_instance: Any = None # Singleton for reuse try: from template_sparql import ( TemplateSPARQLPipeline as _TemplateSPARQLPipeline, get_template_pipeline as _get_template_pipeline, ) TemplateSPARQLPipeline = _TemplateSPARQLPipeline get_template_pipeline = _get_template_pipeline TEMPLATE_SPARQL_AVAILABLE = True logger.info("Template SPARQL pipeline loaded successfully") except ImportError as e: logger.info(f"Template SPARQL pipeline not available (optional): {e}") # Prometheus metrics for monitoring template hit rate, latency, etc. METRICS_AVAILABLE = False record_query = None record_template_matching = None record_template_tier = None record_atomic_cache = None record_atomic_subtask_cached = None record_connection_pool = None record_embedding_warmup = None record_template_embedding_warmup = None set_warmup_status = None set_active_sessions = None create_metrics_endpoint = None get_template_hit_rate = None get_all_performance_stats = None try: from metrics import ( record_query as _record_query, record_template_matching as _record_template_matching, record_template_tier as _record_template_tier, record_atomic_cache as _record_atomic_cache, record_atomic_subtask_cached as _record_atomic_subtask_cached, record_connection_pool as _record_connection_pool, record_embedding_warmup as _record_embedding_warmup, record_template_embedding_warmup as _record_template_embedding_warmup, set_warmup_status as _set_warmup_status, set_active_sessions as _set_active_sessions, create_metrics_endpoint as _create_metrics_endpoint, get_template_hit_rate as _get_template_hit_rate, get_all_performance_stats as _get_all_performance_stats, PROMETHEUS_AVAILABLE, ) record_query = _record_query 
record_template_matching = _record_template_matching record_template_tier = _record_template_tier record_atomic_cache = _record_atomic_cache record_atomic_subtask_cached = _record_atomic_subtask_cached record_connection_pool = _record_connection_pool record_embedding_warmup = _record_embedding_warmup record_template_embedding_warmup = _record_template_embedding_warmup set_warmup_status = _set_warmup_status set_active_sessions = _set_active_sessions create_metrics_endpoint = _create_metrics_endpoint get_template_hit_rate = _get_template_hit_rate get_all_performance_stats = _get_all_performance_stats METRICS_AVAILABLE = PROMETHEUS_AVAILABLE logger.info(f"Metrics module loaded (prometheus={PROMETHEUS_AVAILABLE})") except ImportError as e: logger.info(f"Metrics module not available (optional): {e}") # Province detection for geographic filtering DUTCH_PROVINCES = { "noord-holland", "noordholland", "north holland", "north-holland", "zuid-holland", "zuidholland", "south holland", "south-holland", "utrecht", "gelderland", "noord-brabant", "noordbrabant", "brabant", "north brabant", "limburg", "overijssel", "friesland", "fryslân", "fryslan", "groningen", "drenthe", "flevoland", "zeeland", } def infer_location_level(location: str) -> str: """Infer whether location is city, province, or region. Returns: 'province' if location is a Dutch province 'region' if location is a sub-provincial region 'city' otherwise """ location_lower = location.lower().strip() if location_lower in DUTCH_PROVINCES: return "province" # Sub-provincial regions regions = {"randstad", "veluwe", "achterhoek", "twente", "de betuwe", "betuwe"} if location_lower in regions: return "region" return "city" def extract_geographic_filters(question: str) -> dict[str, list[str] | None]: """Extract geographic filters from a question using query decomposition. 
Returns: dict with keys: region_codes, cities, institution_types """ filters: dict[str, list[str] | None] = { "region_codes": None, "cities": None, "institution_types": None, } if not DECOMPOSER_AVAILABLE or not decompose_query: return filters # Check for explicit city markers BEFORE decomposition # This overrides province disambiguation when user explicitly says "de stad" question_lower = question.lower() explicit_city_markers = [ "de stad ", "in de stad", "stad van", "gemeente ", "the city of", "city of ", "in the city" ] force_city = any(marker in question_lower for marker in explicit_city_markers) try: decomposed = decompose_query(question) # Extract location and determine if it's a province or city if decomposed.location: location = decomposed.location # If user explicitly said "de stad", treat as city even if it's a province name if force_city: filters["cities"] = [location] logger.info(f"City filter (explicit): {location}") else: level = infer_location_level(location) if level == "province": # Convert province name to ISO 3166-2 code for Qdrant filtering # e.g., "Noord-Holland" → "NH" province_code = get_province_code(location) if province_code: filters["region_codes"] = [province_code] logger.info(f"Province filter: {location} → {province_code}") elif level == "city": filters["cities"] = [location] logger.info(f"City filter: {location}") # Extract institution type if decomposed.institution_type: # Map common types to enum values type_mapping = { "archive": "ARCHIVE", "archief": "ARCHIVE", "archieven": "ARCHIVE", "museum": "MUSEUM", "musea": "MUSEUM", "museums": "MUSEUM", "library": "LIBRARY", "bibliotheek": "LIBRARY", "bibliotheken": "LIBRARY", "gallery": "GALLERY", "galerie": "GALLERY", } inst_type = decomposed.institution_type.lower() mapped_type = type_mapping.get(inst_type, inst_type.upper()) filters["institution_types"] = [mapped_type] logger.info(f"Institution type filter: {mapped_type}") except Exception as e: logger.warning(f"Failed to extract 
geographic filters: {e}") return filters # Configuration class Settings: """Application settings from environment variables.""" # API Configuration api_title: str = "Heritage RAG API" api_version: str = "1.0.0" debug: bool = os.getenv("DEBUG", "false").lower() == "true" # Valkey Cache valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache") cache_ttl: int = int(os.getenv("CACHE_TTL", "900")) # 15 minutes # Qdrant Vector DB # Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy qdrant_host: str = os.getenv("QDRANT_HOST", "localhost") qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333")) qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true" qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant") # Multi-Embedding Support # Enable to use named vectors with multiple embedding models (OpenAI 1536, MiniLM 384, BGE 768) use_multi_embedding: bool = os.getenv("USE_MULTI_EMBEDDING", "true").lower() == "true" preferred_embedding_model: str | None = os.getenv("PREFERRED_EMBEDDING_MODEL", None) # e.g., "minilm_384" or "openai_1536" # Oxigraph SPARQL # Production: Use bronhouder.nl/sparql reverse proxy sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql") # TypeDB # Note: TypeDB not exposed via reverse proxy - always use localhost typedb_host: str = os.getenv("TYPEDB_HOST", "localhost") typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729")) typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians") typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true" # Default off # PostGIS/Geo API # Production: Use bronhouder.nl/api/geo reverse proxy postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo") # NOTE: DuckLake removed from RAG - it's for offline analytics only, not real-time retrieval # RAG uses only Qdrant (vectors) and Oxigraph (SPARQL) for retrieval 
# LLM Configuration anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "") or os.getenv("CLAUDE_API_KEY", "") openai_api_key: str = os.getenv("OPENAI_API_KEY", "") huggingface_api_key: str = os.getenv("HUGGINGFACE_API_KEY", "") groq_api_key: str = os.getenv("GROQ_API_KEY", "") zai_api_token: str = os.getenv("ZAI_API_TOKEN", "") default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101") # LLM Provider: "anthropic", "openai", "huggingface", "zai" (FREE), or "groq" (FREE) llm_provider: str = os.getenv("LLM_PROVIDER", "anthropic") # LLM Model: Specific model to use. Defaults depend on provider. # For Z.AI: "glm-4.5-flash" (fast, recommended) or "glm-4.6" (reasoning, slow) llm_model: str = os.getenv("LLM_MODEL", "glm-4.5-flash") # Fast LM Provider for routing/extraction: "openai" (fast ~1-2s) or "zai" (FREE but slow ~13s) # Default to openai for speed. Set to "zai" to save costs (free but adds ~12s latency) fast_lm_provider: str = os.getenv("FAST_LM_PROVIDER", "openai") # Retrieval weights vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5")) graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3")) typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2")) settings = Settings() # Enums and Models class QueryIntent(str, Enum): """Detected query intent for routing.""" GEOGRAPHIC = "geographic" # Location-based queries STATISTICAL = "statistical" # Counts, aggregations RELATIONAL = "relational" # Relationships between entities TEMPORAL = "temporal" # Historical, timeline queries SEARCH = "search" # General text search DETAIL = "detail" # Specific entity lookup class DataSource(str, Enum): """Available data sources for RAG retrieval. NOTE: DuckLake removed - it's for offline analytics only, not real-time RAG retrieval. RAG uses Qdrant (vectors) and Oxigraph (SPARQL) as primary backends. 
""" QDRANT = "qdrant" SPARQL = "sparql" TYPEDB = "typedb" POSTGIS = "postgis" CACHE = "cache" # DUCKLAKE removed - use DuckLake separately for offline analytics/dashboards @dataclass class RetrievalResult: """Result from a single retriever.""" source: DataSource items: list[dict[str, Any]] score: float = 0.0 query_time_ms: float = 0.0 metadata: dict[str, Any] = field(default_factory=dict) class QueryRequest(BaseModel): """RAG query request.""" question: str = Field(..., description="Natural language question") language: str = Field(default="nl", description="Language code (nl or en)") context: list[dict[str, Any]] = Field(default=[], description="Conversation history") sources: list[DataSource] | None = Field( default=None, description="Data sources to query. If None, auto-routes based on query intent.", ) k: int = Field(default=10, description="Number of results per source") include_visualization: bool = Field(default=True, description="Include visualization config") embedding_model: str | None = Field( default=None, description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available." 
) stream: bool = Field(default=False, description="Stream response") class QueryResponse(BaseModel): """RAG query response.""" question: str sparql: str | None = None results: list[dict[str, Any]] visualization: dict[str, Any] | None = None sources_used: list[DataSource] cache_hit: bool = False query_time_ms: float result_count: int class SPARQLRequest(BaseModel): """SPARQL generation request.""" question: str language: str = "nl" context: list[dict[str, Any]] = [] use_rag: bool = True class SPARQLResponse(BaseModel): """SPARQL generation response.""" sparql: str explanation: str rag_used: bool retrieved_passages: list[str] = [] class SPARQLExecuteRequest(BaseModel): """Execute a SPARQL query directly against the knowledge graph.""" sparql_query: str = Field(..., description="SPARQL query to execute") timeout: float = Field(default=30.0, ge=1.0, le=120.0, description="Query timeout in seconds") class SPARQLExecuteResponse(BaseModel): """Response from direct SPARQL execution.""" results: list[dict[str, Any]] = Field(default=[], description="Query results as list of dicts") result_count: int = Field(default=0, description="Number of results") query_time_ms: float = Field(default=0.0, description="Query execution time in milliseconds") error: str | None = Field(default=None, description="Error message if query failed") class SPARQLRerunRequest(BaseModel): """Re-run RAG pipeline with modified SPARQL results injected into context.""" sparql_query: str = Field(..., description="Modified SPARQL query to execute") original_question: str = Field(..., description="Original user question") conversation_history: list[dict[str, Any]] = Field( default=[], description="Previous conversation turns" ) language: str = Field(default="nl", description="Language code (nl or en)") llm_provider: str | None = Field(default=None, description="LLM provider to use") llm_model: str | None = Field(default=None, description="Specific LLM model") class SPARQLRerunResponse(BaseModel): """Response 
from re-running RAG with modified SPARQL context.""" results: list[dict[str, Any]] = Field(default=[], description="SPARQL query results") answer: str = Field(default="", description="Re-generated answer based on modified SPARQL results") sparql_result_count: int = Field(default=0, description="Number of SPARQL results") query_time_ms: float = Field(default=0.0, description="Total processing time") class TypeDBSearchRequest(BaseModel): """TypeDB search request.""" query: str = Field(..., description="Search query (name, type, or location)") search_type: str = Field( default="semantic", description="Search type: semantic, name, type, or location" ) k: int = Field(default=10, ge=1, le=100, description="Number of results") class TypeDBSearchResponse(BaseModel): """TypeDB search response.""" query: str search_type: str results: list[dict[str, Any]] result_count: int query_time_ms: float class PersonSearchRequest(BaseModel): """Person/staff search request.""" query: str = Field(..., description="Search query for person/staff (e.g., 'Wie werkt er in het Nationaal Archief?')") k: int = Field(default=10, ge=1, le=100, description="Number of results to return") filter_custodian: str | None = Field(default=None, description="Filter by custodian slug (e.g., 'nationaal-archief')") only_heritage_relevant: bool = Field(default=False, description="Only return heritage-relevant staff") only_wcms: bool = Field(default=False, description="Only return WCMS-registered profiles (heritage sector users)") embedding_model: str | None = Field( default=None, description="Embedding model to use (e.g., 'minilm_384', 'openai_1536'). If None, auto-selects best available." 
) class PersonSearchResponse(BaseModel): """Person/staff search response with JSON-LD linked data.""" context: dict[str, Any] | None = Field( default=None, alias="@context", description="JSON-LD context for linked data semantic interoperability" ) query: str results: list[dict[str, Any]] result_count: int query_time_ms: float collection_stats: dict[str, Any] | None = None embedding_model_used: str | None = None model_config = {"populate_by_name": True} class DSPyQueryRequest(BaseModel): """DSPy RAG query request with conversation support.""" question: str = Field(..., description="Natural language question") language: str = Field(default="nl", description="Language code (nl or en)") context: list[dict[str, Any]] = Field( default=[], description="Conversation history as list of {question, answer} dicts" ) session_id: str | None = Field( default=None, description="Session ID for multi-turn conversations. If provided, session state is used to resolve follow-up questions like 'En in Enschede?'. If None, a new session is created and returned." ) include_visualization: bool = Field(default=True, description="Include visualization config") embedding_model: str | None = Field( default=None, description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available." ) llm_provider: str | None = Field( default=None, description="LLM provider to use for this request: 'zai', 'anthropic', 'huggingface', or 'openai'. If None, uses server default (LLM_PROVIDER env)." ) llm_model: str | None = Field( default=None, description="Specific LLM model to use (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929', 'gpt-4o'). If None, uses provider default." ) skip_cache: bool = Field( default=False, description="Bypass cache lookup and force fresh LLM query. Useful for debugging." ) class LLMResponseMetadata(BaseModel): """LLM response provenance metadata (aligned with LinkML LLMResponse schema). 
Captures GLM 4.7 Interleaved Thinking chain-of-thought reasoning and full API response metadata for audit trails and debugging. See: schemas/20251121/linkml/modules/classes/LLMResponse.yaml """ # Core response content content: str | None = None # The final LLM response text reasoning_content: str | None = None # GLM 4.7 Interleaved Thinking chain-of-thought # Model identification model: str | None = None # Model identifier (e.g., 'glm-4.7', 'claude-3-opus') provider: str | None = None # Provider enum: zai, anthropic, openai, huggingface, groq # Request tracking request_id: str | None = None # Provider-assigned request ID created: str | None = None # ISO 8601 timestamp of response generation # Token usage (for cost estimation and monitoring) prompt_tokens: int | None = None # Tokens in input prompt completion_tokens: int | None = None # Tokens in response (content + reasoning) total_tokens: int | None = None # Total tokens used cached_tokens: int | None = None # Tokens served from provider cache # Response metadata finish_reason: str | None = None # stop, length, tool_calls, content_filter latency_ms: int | None = None # Response latency in milliseconds # GLM 4.7 Thinking Mode configuration thinking_mode: str | None = None # enabled, disabled, interleaved, preserved clear_thinking: bool | None = None # False = Preserved Thinking enabled class DSPyQueryResponse(BaseModel): """DSPy RAG query response.""" question: str resolved_question: str | None = None answer: str sources_used: list[str] = [] visualization: dict[str, Any] | None = None retrieved_results: list[dict[str, Any]] | None = None # Raw retrieved data for frontend visualization query_type: str | None = None # "person" or "institution" - helps frontend choose visualization query_time_ms: float = 0.0 conversation_turn: int = 0 embedding_model_used: str | None = None # Which embedding model was used for the search llm_provider_used: str | None = None # Which LLM provider handled this request (zai, anthropic, 
huggingface, openai) llm_model_used: str | None = None # Which specific LLM model was used (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929') # Cost tracking fields (from cost_tracker module) timing_ms: float | None = None # Total pipeline timing from cost tracker cost_usd: float | None = None # Estimated LLM cost in USD timing_breakdown: dict[str, float] | None = None # Per-stage timing breakdown # Cache tracking cache_hit: bool = False # Whether response was served from cache # LLM response provenance (GLM 4.7 Thinking Mode support) llm_response: LLMResponseMetadata | None = None # Full LLM response metadata including reasoning_content # Session management for multi-turn conversations session_id: str | None = None # Session ID for continuing conversation (returned even if not provided in request) # Template SPARQL tracking (for monitoring template hit rate vs LLM fallback) template_used: bool = False # Whether template-based SPARQL was used (vs LLM generation) template_id: str | None = None # Which template was used (e.g., "institution_by_city", "person_by_name") # Factual query mode - skip LLM generation for count/list queries factual_result: bool = False # True if this is a direct SPARQL result (no LLM prose generation) sparql_query: str | None = None # The SPARQL query that was executed (for transparency) # Rule 46: Epistemic Provenance Tracking epistemic_provenance: dict[str, Any] | None = None # Full provenance chain for transparency def extract_llm_response_metadata( lm: Any, provider: str | None = None, latency_ms: int | None = None, ) -> LLMResponseMetadata | None: """Extract LLM response metadata from DSPy LM history. 
DSPy stores the raw API response in lm.history[-1]["response"], which includes: - choices[0].message.content (final response text) - choices[0].message.reasoning_content (GLM 4.7 Interleaved Thinking) - usage.prompt_tokens, completion_tokens, total_tokens - model, created, id, finish_reason This enables capturing GLM 4.7's chain-of-thought reasoning for provenance. Args: lm: DSPy LM instance with history attribute provider: LLM provider name (zai, anthropic, openai, etc.) latency_ms: Response latency in milliseconds Returns: LLMResponseMetadata or None if history is empty """ try: # Check if LM has history if not hasattr(lm, "history") or not lm.history: logger.debug("No LM history available for metadata extraction") return None # Get the last history entry (most recent LLM call) last_entry = lm.history[-1] response = last_entry.get("response") if response is None: logger.debug("No response in LM history entry") return None # Extract content and reasoning_content from the response content = None reasoning_content = None finish_reason = None if hasattr(response, "choices") and response.choices: choice = response.choices[0] if hasattr(choice, "message"): message = choice.message content = getattr(message, "content", None) # GLM 4.7 Interleaved Thinking - check for reasoning_content reasoning_content = getattr(message, "reasoning_content", None) elif isinstance(choice, dict): content = choice.get("text") or choice.get("message", {}).get("content") reasoning_content = choice.get("message", {}).get("reasoning_content") # Extract finish_reason finish_reason = getattr(choice, "finish_reason", None) if finish_reason is None and isinstance(choice, dict): finish_reason = choice.get("finish_reason") # Extract usage statistics - handle both dict and object types # (DSPy/OpenAI SDK may return CompletionUsage objects instead of dicts) usage = last_entry.get("usage") prompt_tokens = None completion_tokens = None total_tokens = None cached_tokens = None if usage is not None: if 
hasattr(usage, "prompt_tokens"): # It's an object (e.g., CompletionUsage from OpenAI SDK) prompt_tokens = getattr(usage, "prompt_tokens", None) completion_tokens = getattr(usage, "completion_tokens", None) total_tokens = getattr(usage, "total_tokens", None) prompt_details = getattr(usage, "prompt_tokens_details", None) if prompt_details is not None: cached_tokens = getattr(prompt_details, "cached_tokens", None) elif isinstance(usage, dict): # It's a plain dict prompt_tokens = usage.get("prompt_tokens") completion_tokens = usage.get("completion_tokens") total_tokens = usage.get("total_tokens") prompt_details = usage.get("prompt_tokens_details") if isinstance(prompt_details, dict): cached_tokens = prompt_details.get("cached_tokens") # Extract model info model = last_entry.get("response_model") or last_entry.get("model") request_id = getattr(response, "id", None) created = getattr(response, "created", None) # Convert unix timestamp to ISO 8601 if needed created_str = None if created: if isinstance(created, (int, float)): import datetime created_str = datetime.datetime.fromtimestamp(created, tz=datetime.timezone.utc).isoformat() else: created_str = str(created) # Determine thinking mode (GLM 4.7 specific) thinking_mode = None if reasoning_content: # If we got reasoning_content, the model used interleaved thinking thinking_mode = "interleaved" metadata = LLMResponseMetadata( content=content, reasoning_content=reasoning_content, model=model, provider=provider, request_id=request_id, created=created_str, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, cached_tokens=cached_tokens, finish_reason=finish_reason, latency_ms=latency_ms, thinking_mode=thinking_mode, ) if reasoning_content: logger.info( f"Captured GLM 4.7 reasoning_content ({len(reasoning_content)} chars) " f"from {provider}/{model}" ) return metadata except Exception as e: logger.warning(f"Failed to extract LLM response metadata: {e}") return None # Cache Client class 
ValkeyClient: """Client for Valkey semantic cache API.""" def __init__(self, base_url: str = settings.valkey_api_url): self.base_url = base_url.rstrip("/") self._client: httpx.AsyncClient | None = None @property async def client(self) -> httpx.AsyncClient: """Get or create async HTTP client.""" if self._client is None or self._client.is_closed: self._client = httpx.AsyncClient(timeout=30.0) return self._client def _cache_key(self, question: str, sources: list[DataSource] | None) -> str: """Generate cache key from question and sources.""" if sources: sources_str = ",".join(sorted(s.value for s in sources)) else: sources_str = "auto" key_str = f"{question.lower().strip()}:{sources_str}" return hashlib.sha256(key_str.encode()).hexdigest()[:32] async def get(self, question: str, sources: list[DataSource] | None) -> dict[str, Any] | None: """Get cached response using semantic cache lookup.""" try: client = await self.client response = await client.post( f"{self.base_url}/cache/lookup", json={ "query": question, # Higher threshold (0.97) to avoid false cache hits on semantically # similar but geographically different queries "similarity_threshold": 0.97, }, ) if response.status_code == 200: data = response.json() if data.get("found") and data.get("entry"): logger.info(f"Cache hit for question: {question[:50]}... 
(similarity: {data.get('similarity', 0):.3f})") return data["entry"].get("response") # type: ignore[no-any-return] return None except Exception as e: logger.warning(f"Cache get failed: {e}") return None async def set( self, question: str, sources: list[DataSource] | None, response: dict[str, Any], ttl: int = settings.cache_ttl, ) -> bool: """Cache response using semantic cache store.""" try: client = await self.client # Build CachedResponse schema cached_response = { "answer": response.get("answer", ""), "sparql_query": response.get("sparql_query"), "typeql_query": response.get("typeql_query"), "visualization_type": response.get("visualization_type"), "visualization_data": response.get("visualization_data"), "sources": response.get("sources", []), "confidence": response.get("confidence", 0.0), "context": response.get("context"), } await client.post( f"{self.base_url}/cache/store", json={ "query": question, "response": cached_response, "language": response.get("language", "nl"), "model": response.get("llm_model", "unknown"), }, ) logger.debug(f"Cached response for: {question[:50]}...") return True except Exception as e: logger.warning(f"Cache set failed: {e}") return False def _dspy_cache_key( self, question: str, language: str, llm_provider: str | None, embedding_model: str | None, context_hash: str | None = None, ) -> str: """Generate cache key for DSPy query responses. 
Cache key components: - Question text (normalized) - Language code - LLM provider (different providers give different answers) - Embedding model (affects retrieval results) - Context hash (for multi-turn conversations) """ components = [ question.lower().strip(), language, llm_provider or "default", embedding_model or "auto", context_hash or "no_context", ] key_str = ":".join(components) return f"dspy:{hashlib.sha256(key_str.encode()).hexdigest()[:32]}" async def get_dspy( self, question: str, language: str, llm_provider: str | None, embedding_model: str | None, context: list[dict[str, Any]] | None = None, ) -> dict[str, Any] | None: """Get cached DSPy response using semantic cache lookup. Cache hits are filtered by LLM provider to ensure responses from different providers (e.g., anthropic vs huggingface) are cached separately. """ try: client = await self.client response = await client.post( f"{self.base_url}/cache/lookup", json={ "query": question, "language": language, # Higher threshold (0.97) to avoid false cache hits on semantically # similar but geographically different queries like # "archieven in Groningen" vs "archieven in de stad Groningen" "similarity_threshold": 0.97, }, ) if response.status_code == 200: data = response.json() if data.get("found") and data.get("entry"): cached_response = data["entry"].get("response") # Verify the cached response matches the requested LLM provider # The model field in cache contains the provider (e.g., "anthropic", "huggingface") cached_model = data["entry"].get("model") requested_provider = llm_provider or settings.llm_provider if cached_model and cached_model != requested_provider: logger.info( f"DSPy cache miss (provider mismatch): cached={cached_model}, requested={requested_provider}" ) return None similarity = data.get("similarity", 0) method = data.get("method", "unknown") logger.info(f"DSPy cache hit for question: {question[:50]}... 
(similarity: {similarity:.3f}, method: {method}, provider: {cached_model})") return cached_response # type: ignore[no-any-return] return None except Exception as e: logger.warning(f"DSPy cache get failed: {e}") return None async def set_dspy( self, question: str, language: str, llm_provider: str | None, embedding_model: str | None, response: dict[str, Any], context: list[dict[str, Any]] | None = None, ttl: int = settings.cache_ttl, ) -> bool: """Cache DSPy response using semantic cache store. Maps DSPyQueryResponse fields to CachedResponse schema: - sources_used -> sources - visualization -> visualization_type + visualization_data - Additional context from query_type, resolved_question, etc. """ try: client = await self.client # Extract visualization components if present visualization = response.get("visualization") viz_type = None viz_data = None if visualization: viz_type = visualization.get("type") viz_data = visualization.get("data") # Build CachedResponse schema matching the Valkey API # Maps DSPyQueryResponse fields to CachedResponse expected fields # # IMPORTANT: Include llm_response metadata (GLM 4.7 reasoning_content) in cache # so that cached responses also return the chain-of-thought reasoning. 
llm_response_data = None if response.get("llm_response"): llm_resp = response["llm_response"] # Handle both dict and LLMResponseMetadata object if hasattr(llm_resp, "model_dump"): llm_response_data = llm_resp.model_dump() elif isinstance(llm_resp, dict): llm_response_data = llm_resp cached_response = { "answer": response.get("answer", ""), "sparql_query": None, # DSPy doesn't generate SPARQL "typeql_query": None, # DSPy doesn't generate TypeQL "visualization_type": viz_type, "visualization_data": viz_data, "sources": response.get("sources_used", []), # DSPy uses sources_used "confidence": 0.95, # DSPy responses are generally high confidence "context": { "query_type": response.get("query_type"), "resolved_question": response.get("resolved_question"), "retrieved_results": response.get("retrieved_results"), "embedding_model": response.get("embedding_model_used"), "llm_model": response.get("llm_model_used"), "original_context": context, "llm_response": llm_response_data, # GLM 4.7 reasoning_content }, } result = await client.post( f"{self.base_url}/cache/store", json={ "query": question, "response": cached_response, "language": language, "model": llm_provider or "unknown", }, ) # Check if store was successful if result.status_code == 200: logger.info(f"✓ Cached DSPy response for: {question[:50]}...") return True else: logger.warning(f"Cache store returned {result.status_code}: {result.text[:200]}") return False except Exception as e: logger.warning(f"DSPy cache set failed: {e}") return False async def close(self) -> None: """Close HTTP client.""" if self._client: await self._client.aclose() self._client = None # Query Router class QueryRouter: """Routes queries to appropriate data sources based on intent.""" def __init__(self) -> None: self.intent_keywords = { QueryIntent.GEOGRAPHIC: [ "map", "kaart", "where", "waar", "location", "locatie", "city", "stad", "country", "land", "region", "gebied", "coordinates", "coördinaten", "near", "nearby", "in de buurt", ], 
QueryIntent.STATISTICAL: [ "how many", "hoeveel", "count", "aantal", "total", "totaal", "average", "gemiddeld", "distribution", "verdeling", "percentage", "statistics", "statistiek", "most", "meest", ], QueryIntent.RELATIONAL: [ "related", "gerelateerd", "connected", "verbonden", "relationship", "relatie", "network", "netwerk", "parent", "child", "merged", "fusie", "member of", ], QueryIntent.TEMPORAL: [ "history", "geschiedenis", "timeline", "tijdlijn", "when", "wanneer", "founded", "opgericht", "closed", "gesloten", "over time", "evolution", "change", "verandering", ], QueryIntent.DETAIL: [ "details", "information", "informatie", "about", "over", "specific", "specifiek", "what is", "wat is", ], } # NOTE: DuckLake removed from RAG - it's for offline analytics only # Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY) self.source_routing = { QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL], QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT], # SPARQL aggregations QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL], QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL], QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL], QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT], } def detect_intent(self, question: str) -> QueryIntent: """Detect query intent from question text.""" import re question_lower = question.lower() intent_scores = {intent: 0 for intent in QueryIntent} for intent, keywords in self.intent_keywords.items(): for keyword in keywords: # Use word boundary matching to avoid partial matches # e.g., "land" should not match "netherlands" pattern = r'\b' + re.escape(keyword) + r'\b' if re.search(pattern, question_lower): intent_scores[intent] += 1 max_intent = max(intent_scores, key=intent_scores.get) # type: ignore if intent_scores[max_intent] == 0: return QueryIntent.SEARCH return max_intent def get_sources( self, question: str, requested_sources: 
list[DataSource] | None = None, ) -> tuple[QueryIntent, list[DataSource]]: """Get optimal sources for a query. Args: question: User's question requested_sources: Explicitly requested sources (overrides routing) Returns: Tuple of (detected_intent, list_of_sources) """ intent = self.detect_intent(question) if requested_sources: return intent, requested_sources return intent, self.source_routing.get(intent, [DataSource.QDRANT]) # Multi-Source Retriever class MultiSourceRetriever: """Orchestrates retrieval across multiple data sources.""" def __init__(self) -> None: self.cache = ValkeyClient() self.router = QueryRouter() # Initialize retrievers lazily self._qdrant: HybridRetriever | None = None self._typedb: TypeDBRetriever | None = None self._sparql_client: httpx.AsyncClient | None = None self._postgis_client: httpx.AsyncClient | None = None # NOTE: DuckLake client removed - DuckLake is for offline analytics only @property def qdrant(self) -> HybridRetriever | None: """Lazy-load Qdrant hybrid retriever with multi-embedding support.""" if self._qdrant is None and RETRIEVERS_AVAILABLE: try: self._qdrant = create_hybrid_retriever( use_production=settings.qdrant_use_production, use_multi_embedding=settings.use_multi_embedding, preferred_embedding_model=settings.preferred_embedding_model, ) except Exception as e: logger.warning(f"Failed to initialize Qdrant: {e}") return self._qdrant @property def typedb(self) -> TypeDBRetriever | None: """Lazy-load TypeDB retriever.""" if self._typedb is None and RETRIEVERS_AVAILABLE: try: self._typedb = create_typedb_retriever( use_production=settings.typedb_use_production # Use TypeDB-specific setting ) except Exception as e: logger.warning(f"Failed to initialize TypeDB: {e}") return self._typedb async def _get_sparql_client(self) -> httpx.AsyncClient: """Get SPARQL HTTP client with connection pooling. Connection pooling improves performance by reusing TCP connections instead of creating new ones for each request. 
""" if self._sparql_client is None or self._sparql_client.is_closed: self._sparql_client = httpx.AsyncClient( timeout=30.0, limits=httpx.Limits( max_connections=20, # Max total connections max_keepalive_connections=10, # Keep-alive connections in pool keepalive_expiry=30.0, # Seconds to keep idle connections ), ) # Record connection pool metrics if record_connection_pool: record_connection_pool(client="sparql", pool_size=20, available=20) return self._sparql_client async def _get_postgis_client(self) -> httpx.AsyncClient: """Get PostGIS HTTP client with connection pooling.""" if self._postgis_client is None or self._postgis_client.is_closed: self._postgis_client = httpx.AsyncClient( timeout=30.0, limits=httpx.Limits( max_connections=10, max_keepalive_connections=5, keepalive_expiry=30.0, ), ) # Record connection pool metrics if record_connection_pool: record_connection_pool(client="postgis", pool_size=10, available=10) return self._postgis_client # NOTE: _get_ducklake_client removed - DuckLake is for offline analytics only, not RAG retrieval async def retrieve_from_qdrant( self, query: str, k: int = 10, embedding_model: str | None = None, region_codes: list[str] | None = None, cities: list[str] | None = None, institution_types: list[str] | None = None, ) -> RetrievalResult: """Retrieve from Qdrant vector + SPARQL hybrid search. 
Args: query: Search query k: Number of results to return embedding_model: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536') region_codes: Filter by province/region codes (e.g., ['NH', 'ZH']) cities: Filter by city names (e.g., ['Amsterdam', 'Rotterdam']) institution_types: Filter by institution types (e.g., ['ARCHIVE', 'MUSEUM']) """ start = asyncio.get_event_loop().time() items = [] if self.qdrant: try: results = self.qdrant.search( query, k=k, using=embedding_model, region_codes=region_codes, cities=cities, institution_types=institution_types, ) items = [r.to_dict() for r in results] except Exception as e: logger.error(f"Qdrant retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.QDRANT, items=items, score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0), query_time_ms=elapsed, ) async def retrieve_from_sparql( self, query: str, k: int = 10, ) -> RetrievalResult: """Retrieve from SPARQL endpoint. Uses TEMPLATE-FIRST approach: 1. Try template-based SPARQL generation (deterministic, validated) 2. Fall back to LLM-based generation only if no template matches Template approach provides 65% precision vs 10% for LLM-only. 
""" global _template_pipeline_instance start = asyncio.get_event_loop().time() items = [] sparql_query = "" template_used = False try: # =================================================================== # STEP 1: Try TEMPLATE-BASED SPARQL generation (preferred) # =================================================================== if TEMPLATE_SPARQL_AVAILABLE and get_template_pipeline: try: # Get or create singleton pipeline instance if _template_pipeline_instance is None: _template_pipeline_instance = get_template_pipeline() logger.info("[SPARQL] Template pipeline initialized for MultiSourceRetriever") # Run template matching in thread pool (DSPy is synchronous) template_result = await asyncio.to_thread( _template_pipeline_instance, question=query, conversation_state=None, # No conversation state in simple retriever language="nl" ) if template_result.matched and template_result.sparql: sparql_query = template_result.sparql template_used = True logger.info(f"[SPARQL] Template match: '{template_result.template_id}' " f"(confidence={template_result.confidence:.2f})") else: logger.info(f"[SPARQL] No template match: {template_result.reasoning}") except Exception as e: logger.warning(f"[SPARQL] Template pipeline failed: {e}") # =================================================================== # STEP 2: Fall back to LLM-BASED SPARQL generation # =================================================================== if not template_used and RETRIEVERS_AVAILABLE and generate_sparql: logger.info("[SPARQL] Falling back to LLM-based SPARQL generation") sparql_result = generate_sparql(query, language="nl", use_rag=False) sparql_query = sparql_result.get("sparql", "") # =================================================================== # STEP 3: Execute the SPARQL query # =================================================================== if sparql_query: logger.debug(f"[SPARQL] Executing query:\n{sparql_query[:500]}...") client = await self._get_sparql_client() response = 
await client.post( settings.sparql_endpoint, data={"query": sparql_query}, headers={"Accept": "application/sparql-results+json"}, ) if response.status_code == 200: data = response.json() bindings = data.get("results", {}).get("bindings", []) items = [ {key: val.get("value") for key, val in b.items()} for b in bindings[:k] ] logger.info(f"[SPARQL] Query returned {len(items)} results " f"(template={template_used})") else: logger.warning(f"[SPARQL] Query failed with status {response.status_code}: " f"{response.text[:200]}") except Exception as e: logger.error(f"SPARQL retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.SPARQL, items=items, score=1.0 if items else 0.0, query_time_ms=elapsed, ) async def retrieve_from_typedb( self, query: str, k: int = 10, ) -> RetrievalResult: """Retrieve from TypeDB knowledge graph.""" start = asyncio.get_event_loop().time() items = [] if self.typedb: try: results = self.typedb.semantic_search(query, k=k) items = [r.to_dict() for r in results] except Exception as e: logger.error(f"TypeDB retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.TYPEDB, items=items, score=max((r.get("relevance_score", 0) for r in items), default=0), query_time_ms=elapsed, ) async def retrieve_from_postgis( self, query: str, k: int = 10, ) -> RetrievalResult: """Retrieve from PostGIS geospatial database.""" start = asyncio.get_event_loop().time() # Extract location from query for geospatial search # This is a simplified implementation items = [] try: client = await self._get_postgis_client() # Try to detect city name for bbox search query_lower = query.lower() # Simple city detection cities = { "amsterdam": {"lat": 52.3676, "lon": 4.9041}, "rotterdam": {"lat": 51.9244, "lon": 4.4777}, "den haag": {"lat": 52.0705, "lon": 4.3007}, "utrecht": {"lat": 52.0907, "lon": 5.1214}, } for city, coords in cities.items(): if 
city in query_lower: # Query PostGIS for nearby institutions response = await client.get( f"{settings.postgis_url}/api/institutions/nearby", params={ "lat": coords["lat"], "lon": coords["lon"], "radius_km": 10, "limit": k, }, ) if response.status_code == 200: items = response.json() break except Exception as e: logger.error(f"PostGIS retrieval failed: {e}") elapsed = (asyncio.get_event_loop().time() - start) * 1000 return RetrievalResult( source=DataSource.POSTGIS, items=items, score=1.0 if items else 0.0, query_time_ms=elapsed, ) # NOTE: retrieve_from_ducklake removed - DuckLake is for offline analytics only, not RAG retrieval # Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY) on Oxigraph async def retrieve( self, question: str, sources: list[DataSource], k: int = 10, embedding_model: str | None = None, region_codes: list[str] | None = None, cities: list[str] | None = None, institution_types: list[str] | None = None, ) -> list[RetrievalResult]: """Retrieve from multiple sources concurrently. 
Args: question: User's question sources: Data sources to query k: Number of results per source embedding_model: Optional embedding model for Qdrant (e.g., 'minilm_384', 'openai_1536') region_codes: Filter by province/region codes (e.g., ['NH', 'ZH']) - Qdrant only cities: Filter by city names (e.g., ['Amsterdam']) - Qdrant only institution_types: Filter by institution types (e.g., ['ARCHIVE']) - Qdrant only Returns: List of RetrievalResult from each source """ tasks = [] for source in sources: if source == DataSource.QDRANT: tasks.append(self.retrieve_from_qdrant( question, k, embedding_model, region_codes=region_codes, cities=cities, institution_types=institution_types, )) elif source == DataSource.SPARQL: tasks.append(self.retrieve_from_sparql(question, k)) elif source == DataSource.TYPEDB: tasks.append(self.retrieve_from_typedb(question, k)) elif source == DataSource.POSTGIS: tasks.append(self.retrieve_from_postgis(question, k)) # NOTE: DuckLake case removed - DuckLake is for offline analytics only results = await asyncio.gather(*tasks, return_exceptions=True) # Filter out exceptions valid_results = [] for r in results: if isinstance(r, RetrievalResult): valid_results.append(r) elif isinstance(r, Exception): logger.error(f"Retrieval task failed: {r}") return valid_results def merge_results( self, results: list[RetrievalResult], max_results: int = 20, template_used: bool = False, template_id: str | None = None, ) -> tuple[list[dict[str, Any]], EpistemicProvenance]: """Merge and deduplicate results from multiple sources. Uses reciprocal rank fusion for score combination. Returns merged items AND epistemic provenance tracking. 
Rule 46: Epistemic Provenance Tracking """ from datetime import datetime, timezone # Track items by GHCID for deduplication merged: dict[str, dict[str, Any]] = {} # Initialize provenance tracking tier_counts: dict[DataTier, int] = {} sources_queried = [r.source.value for r in results] total_retrieved = sum(len(r.items) for r in results) for result in results: # Map DataSource to RetrievalSource source_map = { DataSource.QDRANT: RetrievalSource.QDRANT, DataSource.SPARQL: RetrievalSource.SPARQL, DataSource.TYPEDB: RetrievalSource.TYPEDB, DataSource.POSTGIS: RetrievalSource.POSTGIS, DataSource.CACHE: RetrievalSource.CACHE, } retrieval_source = source_map.get(result.source, RetrievalSource.LLM_SYNTHESIS) for rank, item in enumerate(result.items): ghcid = item.get("ghcid", item.get("id", f"unknown_{rank}")) if ghcid not in merged: merged[ghcid] = item.copy() merged[ghcid]["_sources"] = [] merged[ghcid]["_rrf_score"] = 0.0 merged[ghcid]["_data_tier"] = None # Infer data tier for this item item_tier = infer_data_tier(item, retrieval_source) tier_counts[item_tier] = tier_counts.get(item_tier, 0) + 1 # Track best (lowest) tier for each item if merged[ghcid]["_data_tier"] is None: merged[ghcid]["_data_tier"] = item_tier.value else: merged[ghcid]["_data_tier"] = min(merged[ghcid]["_data_tier"], item_tier.value) # Reciprocal Rank Fusion rrf_score = 1.0 / (60 + rank) # k=60 is standard # Weight by source source_weights = { DataSource.QDRANT: settings.vector_weight, DataSource.SPARQL: settings.graph_weight, DataSource.TYPEDB: settings.typedb_weight, DataSource.POSTGIS: 0.3, } weight = source_weights.get(result.source, 0.5) merged[ghcid]["_rrf_score"] += rrf_score * weight merged[ghcid]["_sources"].append(result.source.value) # Sort by RRF score sorted_items = sorted( merged.values(), key=lambda x: x.get("_rrf_score", 0), reverse=True, ) final_items = sorted_items[:max_results] # Build epistemic provenance provenance = EpistemicProvenance( 
dataSource=EpistemicDataSource.RAG_PIPELINE, dataTier=aggregate_data_tier(tier_counts), sourceTimestamp=datetime.now(timezone.utc).isoformat(), derivationChain=build_derivation_chain( sources_used=sources_queried, template_used=template_used, template_id=template_id, ), revalidationPolicy="weekly", sourcesQueried=sources_queried, totalRetrieved=total_retrieved, totalAfterFusion=len(final_items), dataTierBreakdown={ f"tier_{tier.value}": count for tier, count in tier_counts.items() }, templateUsed=template_used, templateId=template_id, ) return final_items, provenance async def close(self) -> None: """Clean up resources.""" await self.cache.close() if self._sparql_client: await self._sparql_client.aclose() if self._postgis_client: await self._postgis_client.aclose() if self._qdrant: self._qdrant.close() if self._typedb: self._typedb.close() def search_persons( self, query: str, k: int = 10, filter_custodian: str | None = None, only_heritage_relevant: bool = False, only_wcms: bool = False, using: str | None = None, ) -> list[Any]: """Search for persons/staff in the heritage_persons collection. Delegates to HybridRetriever.search_persons() if available. Args: query: Search query k: Number of results filter_custodian: Optional custodian slug to filter by only_heritage_relevant: Only return heritage-relevant staff only_wcms: Only return WCMS-registered profiles using: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536') Returns: List of RetrievedPerson objects """ if self.qdrant: try: return self.qdrant.search_persons( # type: ignore[no-any-return] query=query, k=k, filter_custodian=filter_custodian, only_heritage_relevant=only_heritage_relevant, only_wcms=only_wcms, using=using, ) except Exception as e: logger.error(f"Person search failed: {e}") return [] def get_stats(self) -> dict[str, Any]: """Get statistics from all retrievers. Returns combined stats from Qdrant (including persons collection) and TypeDB. 
""" stats = {} if self.qdrant: try: qdrant_stats = self.qdrant.get_stats() stats.update(qdrant_stats) except Exception as e: logger.warning(f"Failed to get Qdrant stats: {e}") if self.typedb: try: typedb_stats = self.typedb.get_stats() stats["typedb"] = typedb_stats except Exception as e: logger.warning(f"Failed to get TypeDB stats: {e}") return stats # Global instances retriever: MultiSourceRetriever | None = None viz_selector: VisualizationSelector | None = None dspy_pipeline: Any = None # HeritageRAGPipeline instance (loaded with optimized model) atomic_cache_manager: Any = None # AtomicCacheManager for sub-task caching (40-70% hit rate vs 5-15% for full queries) @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[None]: """Application lifespan manager.""" global retriever, viz_selector, dspy_pipeline, atomic_cache_manager # Startup logger.info("Starting Heritage RAG API...") retriever = MultiSourceRetriever() if RETRIEVERS_AVAILABLE: # Check for any available LLM API key (Anthropic preferred, OpenAI fallback) has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key) # VisualizationSelector requires DSPy - make it optional try: viz_selector = VisualizationSelector(use_dspy=has_llm_key) except RuntimeError as e: logger.warning(f"VisualizationSelector not available: {e}") viz_selector = None # Configure DSPy based on LLM_PROVIDER setting # Respect user's provider preference, with fallback chain import dspy llm_provider = settings.llm_provider.lower() logger.info(f"LLM_PROVIDER configured as: {llm_provider}") dspy_configured = False # Try Z.AI GLM if configured as provider (FREE!) 
if llm_provider == "zai" and settings.zai_api_token: try: # Z.AI uses OpenAI-compatible API format # Use LLM_MODEL from settings (default: glm-4.5-flash for speed) zai_model = settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash" lm = dspy.LM( f"openai/{zai_model}", api_key=settings.zai_api_token, api_base="https://api.z.ai/api/coding/paas/v4", ) dspy.configure(lm=lm) logger.info(f"Configured DSPy with Z.AI {zai_model} (FREE)") dspy_configured = True except Exception as e: logger.warning(f"Failed to configure DSPy with Z.AI: {e}") # Try HuggingFace if configured as provider if not dspy_configured and llm_provider == "huggingface" and settings.huggingface_api_key: try: lm = dspy.LM("huggingface/utter-project/EuroLLM-9B-Instruct", api_key=settings.huggingface_api_key) dspy.configure(lm=lm) logger.info("Configured DSPy with HuggingFace EuroLLM-9B-Instruct") dspy_configured = True except Exception as e: logger.warning(f"Failed to configure DSPy with HuggingFace: {e}") # Try Anthropic if not yet configured (either as primary or fallback) if not dspy_configured and (llm_provider == "anthropic" or (llm_provider == "huggingface" and settings.anthropic_api_key)): if settings.anthropic_api_key and configure_dspy: try: configure_dspy( provider="anthropic", model=settings.default_model, api_key=settings.anthropic_api_key, ) dspy_configured = True except Exception as e: logger.warning(f"Failed to configure DSPy with Anthropic: {e}") # Try OpenAI as final fallback if not dspy_configured and settings.openai_api_key and configure_dspy: try: configure_dspy( provider="openai", model="gpt-4o-mini", api_key=settings.openai_api_key, ) dspy_configured = True except Exception as e: logger.warning(f"Failed to configure DSPy with OpenAI: {e}") if not dspy_configured: logger.warning("No LLM provider configured - DSPy queries will fail") # Initialize optimized HeritageRAGPipeline (if DSPy is configured) if dspy_configured: try: from dspy_heritage_rag import 
HeritageRAGPipeline from pathlib import Path # Create pipeline with Qdrant retriever qdrant_retriever = retriever.qdrant if retriever else None dspy_pipeline = HeritageRAGPipeline(retriever=qdrant_retriever) # Load optimized model (BootstrapFewShot: 14.3% quality improvement) # Note: load() may fail if new modules were added that aren't in the saved state optimized_model_path = Path(__file__).parent / "optimized_models" / "heritage_rag_bootstrap_latest.json" if optimized_model_path.exists(): try: dspy_pipeline.load(str(optimized_model_path)) logger.info(f"Loaded optimized DSPy pipeline from {optimized_model_path}") except Exception as load_err: # Pipeline still works, just without optimized demos for new modules logger.warning(f"Could not load optimized model (new modules may need re-optimization): {load_err}") logger.info("Pipeline initialized without optimized demos - will work but may be less accurate") else: logger.warning(f"Optimized model not found at {optimized_model_path}, using unoptimized pipeline") except Exception as e: logger.warning(f"Failed to initialize DSPy pipeline: {e}") dspy_pipeline = None # === HOT LOADING: Warmup embedding model to avoid cold-start latency === # The sentence-transformers model takes 3-15 seconds to load on first use. # By loading it eagerly at startup, we eliminate this delay for users. 
if retriever.qdrant: logger.info("Warming up embedding model (this takes 3-15 seconds on first startup)...") warmup_start = time.perf_counter() try: # Trigger model load with a dummy embedding request _ = retriever.qdrant._get_embedding("archief warmup query") warmup_duration = time.perf_counter() - warmup_start logger.info(f"✅ Embedding model warmed up in {warmup_duration:.2f}s - ready for fast queries!") # Record warmup metrics if record_embedding_warmup: record_embedding_warmup( model="sentence-transformers/all-MiniLM-L6-v2", duration_seconds=warmup_duration, success=True, ) except Exception as e: warmup_duration = time.perf_counter() - warmup_start logger.warning(f"Failed to warm up embedding model: {e}") if record_embedding_warmup: record_embedding_warmup( model="sentence-transformers/all-MiniLM-L6-v2", duration_seconds=warmup_duration, success=False, ) # === TEMPLATE EMBEDDING WARMUP: Pre-compute embeddings for template patterns === # The TemplateEmbeddingMatcher computes embeddings on first query (~2-5 seconds). # By pre-computing at startup, we eliminate this delay for users. 
template_warmup_start = time.perf_counter() template_count = 0 try: from template_sparql import get_template_embedding_matcher, TemplateClassifier logger.info("Pre-computing template pattern embeddings...") classifier = TemplateClassifier() templates = classifier._load_templates() if templates: template_count = len(templates) matcher = get_template_embedding_matcher() if matcher._ensure_embeddings_computed(templates): template_warmup_duration = time.perf_counter() - template_warmup_start logger.info(f"✅ Template embeddings pre-computed ({template_count} templates) in {template_warmup_duration:.2f}s") # Record template warmup metrics if record_template_embedding_warmup: record_template_embedding_warmup( duration_seconds=template_warmup_duration, template_count=template_count, success=True, ) else: logger.warning("Template embedding computation skipped (model not available)") if set_warmup_status: set_warmup_status("template_embeddings", False) else: logger.warning("No templates found for embedding warmup") if set_warmup_status: set_warmup_status("template_embeddings", False) except Exception as e: template_warmup_duration = time.perf_counter() - template_warmup_start logger.warning(f"Failed to pre-compute template embeddings: {e}") if record_template_embedding_warmup: record_template_embedding_warmup( duration_seconds=template_warmup_duration, template_count=template_count, success=False, ) # === ATOMIC CACHE MANAGER: Sub-task caching for higher hit rates === # Research shows 40-70% cache hit rates with atomic decomposition vs 5-15% for full queries. # Initialize AtomicCacheManager with retriever's semantic cache for persistence. 
if ATOMIC_CACHE_AVAILABLE and AtomicCacheManager: try: semantic_cache = retriever.cache if retriever else None atomic_cache_manager = AtomicCacheManager(semantic_cache=semantic_cache) logger.info("✅ AtomicCacheManager initialized for sub-task caching") except Exception as e: logger.warning(f"Failed to initialize AtomicCacheManager: {e}") # === ONTOLOGY CACHE WARMUP: Pre-load KG values to avoid cold-start latency === # The OntologyLoader queries the Knowledge Graph for valid slot values (cities, regions, types). # These queries can take 1-3 seconds each on first access. # By pre-loading at startup, we eliminate this delay for users. ontology_warmup_start = time.perf_counter() try: from template_sparql import get_ontology_loader logger.info("Warming up ontology cache (pre-loading KG values)...") ontology = get_ontology_loader() ontology.load() # Triggers KG queries for institution_types, subregions, cities, etc. ontology_warmup_duration = time.perf_counter() - ontology_warmup_start cache_stats = ontology.get_kg_cache_stats() logger.info( f"✅ Ontology cache warmed up in {ontology_warmup_duration:.2f}s " f"({cache_stats['cache_size']} KG queries cached, TTL={cache_stats['ttl_seconds']}s)" ) except Exception as e: ontology_warmup_duration = time.perf_counter() - ontology_warmup_start logger.warning(f"Failed to warm up ontology cache: {e}") logger.info("Heritage RAG API started") yield # Shutdown logger.info("Shutting down Heritage RAG API...") if retriever: await retriever.close() logger.info("Heritage RAG API stopped") # Create FastAPI app app = FastAPI( title=settings.api_title, version=settings.api_version, lifespan=lifespan, ) # CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Prometheus metrics endpoint if METRICS_AVAILABLE and create_metrics_endpoint: app.include_router(create_metrics_endpoint(), prefix="/api/rag") # API Endpoints @app.get("/api/rag/health") async def 
health_check() -> dict[str, Any]: """Health check for all services.""" health: dict[str, Any] = { "status": "ok", "timestamp": datetime.now(timezone.utc).isoformat(), "services": {}, } # Check Qdrant if retriever and retriever.qdrant: try: stats = retriever.qdrant.get_stats() health["services"]["qdrant"] = { "status": "ok", "vectors": stats.get("qdrant", {}).get("vectors_count", 0), } except Exception as e: health["services"]["qdrant"] = {"status": "error", "error": str(e)} # Check SPARQL try: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"{settings.sparql_endpoint.replace('/query', '')}") health["services"]["sparql"] = { "status": "ok" if response.status_code < 500 else "error" } except Exception as e: health["services"]["sparql"] = {"status": "error", "error": str(e)} # Check TypeDB if retriever and retriever.typedb: try: stats = retriever.typedb.get_stats() health["services"]["typedb"] = { "status": "ok", "entities": stats.get("entities", {}), } except Exception as e: health["services"]["typedb"] = {"status": "error", "error": str(e)} # Overall status services = health["services"] errors = sum(1 for s in services.values() if isinstance(s, dict) and s.get("status") == "error") health["status"] = "ok" if errors == 0 else "degraded" if errors < 3 else "error" return health @app.get("/api/rag/stats") async def get_stats() -> dict[str, Any]: """Get retriever statistics.""" stats: dict[str, Any] = { "timestamp": datetime.now(timezone.utc).isoformat(), "retrievers": {}, } if retriever: if retriever.qdrant: stats["retrievers"]["qdrant"] = retriever.qdrant.get_stats() if retriever.typedb: stats["retrievers"]["typedb"] = retriever.typedb.get_stats() return stats @app.get("/api/rag/stats/costs") async def get_cost_stats() -> dict[str, Any]: """Get cost tracking session statistics. 
Returns cumulative statistics for the current session including: - Total LLM calls and token usage - Total retrieval operations and latencies - Estimated costs by model - Pipeline timing statistics Returns: Dict with cost tracker statistics or unavailable message """ if not COST_TRACKER_AVAILABLE or not get_tracker: return { "available": False, "message": "Cost tracker module not available", } tracker = get_tracker() return { "available": True, "timestamp": datetime.now(timezone.utc).isoformat(), "session": tracker.get_session_summary(), } @app.post("/api/rag/stats/costs/reset") async def reset_cost_stats() -> dict[str, Any]: """Reset cost tracking statistics. Clears all accumulated statistics and starts a fresh session. Useful for per-conversation or per-session cost tracking. Returns: Confirmation message """ if not COST_TRACKER_AVAILABLE or not reset_tracker: return { "available": False, "message": "Cost tracker module not available", } reset_tracker() return { "available": True, "message": "Cost tracking statistics reset", "timestamp": datetime.now(timezone.utc).isoformat(), } @app.get("/api/rag/stats/templates") async def get_template_stats() -> dict[str, Any]: """Get template SPARQL usage statistics. Returns metrics about template-based SPARQL query generation, including hit rate and breakdown by template ID. 
This is useful for: - Monitoring template coverage (what % of queries use templates) - Identifying which templates are most used - Tuning template slot extraction parameters Returns: Dict with template hit rate, breakdown by template_id, and timestamp """ if not METRICS_AVAILABLE: return { "available": False, "message": "Metrics module not available", } # Import the metrics functions try: from metrics import get_template_hit_rate, get_template_breakdown except ImportError: return { "available": False, "message": "Metrics module import failed", } return { "available": True, "hit_rate": get_template_hit_rate(), "breakdown": get_template_breakdown(), "timestamp": datetime.now(timezone.utc).isoformat(), } @app.get("/api/rag/embedding/models") async def get_embedding_models() -> dict[str, Any]: """List available embedding models for the Qdrant collections. Returns information about which embedding models are available in each collection's named vectors, helping clients choose the right model for their use case. 
Returns: Dict with available models per collection, current settings, and recommendations """ result: dict[str, Any] = { "timestamp": datetime.now(timezone.utc).isoformat(), "multi_embedding_enabled": settings.use_multi_embedding, "preferred_model": settings.preferred_embedding_model, "collections": {}, "models": { "openai_1536": { "description": "OpenAI text-embedding-3-small (1536 dimensions)", "quality": "high", "cost": "paid API", "recommended_for": "production, high-quality semantic search", }, "minilm_384": { "description": "sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)", "quality": "good", "cost": "free (local)", "recommended_for": "development, cost-sensitive deployments", }, "bge_768": { "description": "BAAI/bge-small-en-v1.5 (768 dimensions)", "quality": "very good", "cost": "free (local)", "recommended_for": "balanced quality/cost, multilingual support", }, }, } if retriever and retriever.qdrant: qdrant = retriever.qdrant # Check if multi-embedding is enabled and get available models if hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding: if hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever: multi = qdrant.multi_retriever # Get available models for institutions collection try: inst_models = multi.get_available_models("heritage_custodians") selected = multi.select_model("heritage_custodians") result["collections"]["heritage_custodians"] = { "available_models": [m.value for m in inst_models], "uses_named_vectors": multi.uses_named_vectors("heritage_custodians"), "recommended": selected.value if selected else None, } except Exception as e: result["collections"]["heritage_custodians"] = {"error": str(e)} # Get available models for persons collection try: person_models = multi.get_available_models("heritage_persons") selected = multi.select_model("heritage_persons") result["collections"]["heritage_persons"] = { "available_models": [m.value for m in person_models], "uses_named_vectors": 
multi.uses_named_vectors("heritage_persons"), "recommended": selected.value if selected else None, } except Exception as e: result["collections"]["heritage_persons"] = {"error": str(e)} else: # Single embedding mode - detect dimension stats = qdrant.get_stats() result["single_embedding_mode"] = True result["note"] = "Collections use single embedding vectors. Enable USE_MULTI_EMBEDDING=true to use named vectors." return result class EmbeddingCompareRequest(BaseModel): """Request for comparing embedding models.""" query: str = Field(..., description="Query to search with") collection: str = Field(default="heritage_persons", description="Collection to search") k: int = Field(default=5, ge=1, le=20, description="Number of results per model") @app.post("/api/rag/embedding/compare") async def compare_embedding_models(request: EmbeddingCompareRequest) -> dict[str, Any]: """Compare search results across different embedding models. Performs the same search query using each available embedding model, allowing A/B testing of embedding quality. This endpoint is useful for: - Evaluating which embedding model works best for your queries - Understanding differences in semantic similarity between models - Making informed decisions about which model to use in production Returns: Dict with results from each embedding model, including scores and overlap analysis """ import time start_time = time.time() if not retriever or not retriever.qdrant: raise HTTPException(status_code=503, detail="Qdrant retriever not available") qdrant = retriever.qdrant # Check if multi-embedding is available if not (hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding): raise HTTPException( status_code=400, detail="Multi-embedding mode not enabled. Set USE_MULTI_EMBEDDING=true to use this endpoint." 
) if not (hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever): raise HTTPException(status_code=503, detail="Multi-embedding retriever not initialized") multi = qdrant.multi_retriever try: # Use the compare_models method from MultiEmbeddingRetriever comparison = multi.compare_models( query=request.query, collection=request.collection, k=request.k, ) elapsed_ms = (time.time() - start_time) * 1000 return { "query": request.query, "collection": request.collection, "k": request.k, "query_time_ms": round(elapsed_ms, 2), "comparison": comparison, } except Exception as e: logger.exception(f"Embedding comparison failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/rag/query", response_model=QueryResponse) async def query_rag(request: QueryRequest) -> QueryResponse: """Main RAG query endpoint. Orchestrates retrieval from multiple sources, merges results, and optionally generates visualization configuration. """ if not retriever: raise HTTPException(status_code=503, detail="Retriever not initialized") start_time = asyncio.get_event_loop().time() # Check cache first cached = await retriever.cache.get(request.question, request.sources) if cached: # Transform cached data to QueryResponse schema # Cache stores: answer, sparql_query, sources, confidence, context # QueryResponse needs: question, sparql, results, sources_used, query_time_ms, result_count try: # Get sources from cached data (may be strings or DataSource enums) cached_sources = cached.get("sources", []) sources_used = [] for s in cached_sources: if isinstance(s, str): try: sources_used.append(DataSource(s)) except ValueError: # Skip invalid source values pass elif isinstance(s, DataSource): sources_used.append(s) # Get results from context if available results = cached.get("results", []) if not results and cached.get("context"): results = cached["context"].get("retrieved_results", []) or [] return QueryResponse( question=request.question, sparql=cached.get("sparql_query") or 
cached.get("sparql"), results=results, visualization=cached.get("visualization"), sources_used=sources_used or [DataSource.QDRANT], # Default if none cache_hit=True, query_time_ms=cached.get("query_time_ms", 0.0), result_count=cached.get("result_count", len(results)), ) except Exception as e: logger.warning(f"Failed to transform cached response: {e}, skipping cache") # Fall through to normal query processing # Route query to appropriate sources intent, sources = retriever.router.get_sources(request.question, request.sources) logger.info(f"Query intent: {intent}, sources: {sources}") # Extract geographic filters from question (province, city, institution type) geo_filters = extract_geographic_filters(request.question) if any(geo_filters.values()): logger.info(f"Geographic filters extracted: {geo_filters}") # Retrieve from all sources results = await retriever.retrieve( request.question, sources, request.k, embedding_model=request.embedding_model, region_codes=geo_filters["region_codes"], cities=geo_filters["cities"], institution_types=geo_filters["institution_types"], ) # Merge results with provenance tracking merged_items, retrieval_provenance = retriever.merge_results(results, max_results=request.k * 2) # Generate visualization config if requested visualization = None if request.include_visualization and viz_selector and merged_items: # Extract schema from first result schema_fields = list(merged_items[0].keys()) if merged_items else [] schema_str = ", ".join(f for f in schema_fields if not f.startswith("_")) visualization = viz_selector.select( request.question, schema_str, len(merged_items), ) elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000 response_data = { "question": request.question, "sparql": None, # Could be populated from SPARQL result "results": merged_items, "visualization": visualization, "sources_used": [s for s in sources], "cache_hit": False, "query_time_ms": round(elapsed_ms, 2), "result_count": len(merged_items), } # Cache the 
response await retriever.cache.set(request.question, request.sources, response_data) return QueryResponse(**response_data) # type: ignore[arg-type] @app.post("/api/rag/sparql", response_model=SPARQLResponse) async def generate_sparql_endpoint(request: SPARQLRequest) -> SPARQLResponse: """Generate SPARQL query from natural language. Uses TEMPLATE-FIRST approach: 1. Try template-based SPARQL generation (deterministic, validated) 2. Fall back to LLM-based generation only if no template matches Template approach provides 65% precision vs 10% for LLM-only (Formica et al. 2023). """ global _template_pipeline_instance template_used = False sparql_query = "" explanation = "" try: # =================================================================== # STEP 1: Try TEMPLATE-BASED SPARQL generation (preferred) # =================================================================== if TEMPLATE_SPARQL_AVAILABLE and get_template_pipeline: try: # Get or create singleton pipeline instance if _template_pipeline_instance is None: _template_pipeline_instance = get_template_pipeline() logger.info("[SPARQL] Template pipeline initialized for /api/rag/sparql endpoint") # Run template matching in thread pool (DSPy is synchronous) template_result = await asyncio.to_thread( _template_pipeline_instance, question=request.question, conversation_state=None, language=request.language ) if template_result.matched and template_result.sparql: sparql_query = template_result.sparql template_used = True explanation = ( f"Template '{template_result.template_id}' matched with " f"confidence {template_result.confidence:.2f}. " f"Slots: {template_result.slots}. 
" f"{template_result.reasoning}" ) logger.info(f"[SPARQL] Template match: '{template_result.template_id}' " f"(confidence={template_result.confidence:.2f})") else: logger.info(f"[SPARQL] No template match: {template_result.reasoning}") except Exception as e: logger.warning(f"[SPARQL] Template pipeline failed: {e}") # =================================================================== # STEP 2: Fall back to LLM-BASED SPARQL generation # =================================================================== if not template_used: if not RETRIEVERS_AVAILABLE: raise HTTPException(status_code=503, detail="SPARQL generator not available") logger.info("[SPARQL] Falling back to LLM-based SPARQL generation") result = generate_sparql( request.question, language=request.language, context=request.context, use_rag=request.use_rag, ) sparql_query = result["sparql"] explanation = result.get("explanation", "") return SPARQLResponse( sparql=sparql_query, explanation=explanation, rag_used=not template_used, # RAG only used if LLM fallback retrieved_passages=[], ) except HTTPException: raise except Exception as e: logger.exception("SPARQL generation failed") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/rag/sparql/execute", response_model=SPARQLExecuteResponse) async def execute_sparql_query(request: SPARQLExecuteRequest) -> SPARQLExecuteResponse: """Execute a SPARQL query directly against the knowledge graph. This endpoint allows users to run modified SPARQL queries and see the results without regenerating an answer. Useful for exploration and debugging. 
""" import time start_time = time.time() try: # Use pooled SPARQL client from retriever for better connection reuse if retriever: client = await retriever._get_sparql_client() response = await client.post( settings.sparql_endpoint, data={"query": request.sparql_query}, headers={"Accept": "application/sparql-results+json"}, timeout=request.timeout, # Override timeout per-request if specified ) else: # Fallback to creating a new client if retriever not initialized async with httpx.AsyncClient(timeout=request.timeout) as client: response = await client.post( settings.sparql_endpoint, data={"query": request.sparql_query}, headers={"Accept": "application/sparql-results+json"}, ) if response.status_code != 200: return SPARQLExecuteResponse( results=[], result_count=0, query_time_ms=(time.time() - start_time) * 1000, error=f"SPARQL endpoint returned {response.status_code}: {response.text[:500]}", ) data = response.json() bindings = data.get("results", {}).get("bindings", []) # Convert bindings to simple dicts results = [ {k: v.get("value") for k, v in binding.items()} for binding in bindings ] return SPARQLExecuteResponse( results=results, result_count=len(results), query_time_ms=(time.time() - start_time) * 1000, ) except httpx.TimeoutException: return SPARQLExecuteResponse( results=[], result_count=0, query_time_ms=(time.time() - start_time) * 1000, error=f"Query timed out after {request.timeout}s", ) except Exception as e: logger.exception("SPARQL execution failed") return SPARQLExecuteResponse( results=[], result_count=0, query_time_ms=(time.time() - start_time) * 1000, error=str(e), ) @app.post("/api/rag/sparql/rerun", response_model=SPARQLRerunResponse) async def rerun_rag_with_sparql(request: SPARQLRerunRequest) -> SPARQLRerunResponse: """Re-run the RAG pipeline with modified SPARQL results injected into context. This endpoint allows users to: 1. Execute a modified SPARQL query 2. Inject those results into the DSPy RAG context 3. 
Generate a new answer based on the modified knowledge graph results This affects the entire conversation through DSPy by providing new factual context that the LLM uses to generate its response. """ import time import dspy start_time = time.time() # Step 1: Execute the modified SPARQL query using pooled client sparql_results: list[dict[str, Any]] = [] try: # Use pooled SPARQL client from retriever for better connection reuse if retriever: client = await retriever._get_sparql_client() response = await client.post( settings.sparql_endpoint, data={"query": request.sparql_query}, headers={"Accept": "application/sparql-results+json"}, ) else: # Fallback to creating a new client if retriever not initialized async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post( settings.sparql_endpoint, data={"query": request.sparql_query}, headers={"Accept": "application/sparql-results+json"}, ) if response.status_code == 200: data = response.json() bindings = data.get("results", {}).get("bindings", []) sparql_results = [ {k: v.get("value") for k, v in binding.items()} for binding in bindings[:50] # Limit to 50 results for context ] logger.info(f"SPARQL rerun: got {len(sparql_results)} results") else: logger.warning(f"SPARQL rerun: endpoint returned {response.status_code}") except Exception as e: logger.exception(f"SPARQL rerun: execution failed: {e}") # Step 2: Format SPARQL results as context for DSPy sparql_context = "" if sparql_results: sparql_context = "\n[KENNISGRAAF RESULTATEN (aangepaste SPARQL query)]:\n" for i, result in enumerate(sparql_results[:20], 1): entry = " | ".join(f"{k}: {v}" for k, v in result.items() if v) sparql_context += f" {i}. 
{entry}\n" # Step 3: Run DSPy answer generation with injected SPARQL context answer = "" try: from dspy_heritage_rag import HeritageRAGPipeline # Get LLM configuration lm = None provider = request.llm_provider or "zai" model = request.llm_model if provider == "zai" and settings.zai_api_token: model = model or "glm-4.5-flash" lm = dspy.LM( f"openai/{model}", api_key=settings.zai_api_token, api_base="https://api.z.ai/api/coding/paas/v4", ) elif provider == "groq" and settings.groq_api_key: model = model or "llama-3.1-8b-instant" lm = dspy.LM(f"groq/{model}", api_key=settings.groq_api_key) elif provider == "openai" and settings.openai_api_key: model = model or "gpt-4o-mini" lm = dspy.LM(f"openai/{model}", api_key=settings.openai_api_key) elif provider == "anthropic" and settings.anthropic_api_key: model = model or "claude-sonnet-4-20250514" lm = dspy.LM(f"anthropic/{model}", api_key=settings.anthropic_api_key) if lm: with dspy.settings.context(lm=lm): # Create a simple answer generator that uses the SPARQL context generate_answer = dspy.ChainOfThought( "question, sparql_context, language -> answer" ) result = generate_answer( question=request.original_question, sparql_context=sparql_context, language=request.language, ) answer = result.answer else: answer = f"LLM niet beschikbaar. SPARQL resultaten: {len(sparql_results)} gevonden." 
except Exception as e: logger.exception(f"SPARQL rerun: answer generation failed: {e}") answer = f"Fout bij het genereren van antwoord: {str(e)}" return SPARQLRerunResponse( results=sparql_results, answer=answer, sparql_result_count=len(sparql_results), query_time_ms=(time.time() - start_time) * 1000, ) @app.post("/api/rag/visualize") async def get_visualization_config( question: str = Query(..., description="User's question"), schema: str = Query(..., description="Comma-separated field names"), result_count: int = Query(default=0, description="Number of results"), ) -> dict[str, Any]: """Get visualization configuration for a query.""" if not viz_selector: raise HTTPException(status_code=503, detail="Visualization selector not available") config = viz_selector.select(question, schema, result_count) return config # type: ignore[no-any-return] @app.post("/api/rag/typedb/search", response_model=TypeDBSearchResponse) async def typedb_search(request: TypeDBSearchRequest) -> TypeDBSearchResponse: """Direct TypeDB search endpoint. Search heritage custodians in TypeDB using various strategies: - semantic: Natural language search (combines type + location patterns) - name: Search by institution name - type: Search by institution type (museum, archive, library, gallery) - location: Search by city/location name Examples: - {"query": "museums in Amsterdam", "search_type": "semantic"} - {"query": "Rijksmuseum", "search_type": "name"} - {"query": "archive", "search_type": "type"} - {"query": "Rotterdam", "search_type": "location"} """ import time start_time = time.time() # Check if TypeDB retriever is available if not retriever or not retriever.typedb: raise HTTPException( status_code=503, detail="TypeDB retriever not available. Ensure TypeDB is running." 
) try: typedb_retriever = retriever.typedb # Route to appropriate search method if request.search_type == "name": results = typedb_retriever.search_by_name(request.query, k=request.k) elif request.search_type == "type": results = typedb_retriever.search_by_type(request.query, k=request.k) elif request.search_type == "location": results = typedb_retriever.search_by_location(city=request.query, k=request.k) else: # semantic (default) results = typedb_retriever.semantic_search(request.query, k=request.k) # Convert results to dicts result_dicts = [] seen_names = set() # Deduplicate by name for r in results: # Handle both dict and object results if hasattr(r, 'to_dict'): item = r.to_dict() elif isinstance(r, dict): item = r else: item = {"name": str(r)} # Deduplicate by name name = item.get("name") or item.get("observed_name", "") if name and name not in seen_names: seen_names.add(name) result_dicts.append(item) elapsed_ms = (time.time() - start_time) * 1000 return TypeDBSearchResponse( query=request.query, search_type=request.search_type, results=result_dicts, result_count=len(result_dicts), query_time_ms=round(elapsed_ms, 2), ) except Exception as e: logger.exception(f"TypeDB search failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/rag/persons/search", response_model=PersonSearchResponse) async def person_search(request: PersonSearchRequest) -> PersonSearchResponse: """Search for persons/staff in heritage institutions. Search the heritage_persons Qdrant collection for staff members at heritage custodian institutions. Examples: - {"query": "Wie werkt er in het Nationaal Archief?"} - {"query": "archivist at Rijksmuseum", "k": 20} - {"query": "conservator", "filter_custodian": "rijksmuseum"} - {"query": "digital preservation", "only_heritage_relevant": true} The search uses semantic vector similarity to find relevant staff members based on their name, role, headline, and custodian affiliation. 
""" import time start_time = time.time() # Check if retriever is available if not retriever: raise HTTPException( status_code=503, detail="Hybrid retriever not available. Ensure Qdrant is running." ) try: # Use the hybrid retriever's person search results = retriever.search_persons( query=request.query, k=request.k, filter_custodian=request.filter_custodian, only_heritage_relevant=request.only_heritage_relevant, only_wcms=request.only_wcms, using=request.embedding_model, # Pass embedding model ) # Determine which embedding model was actually used embedding_model_used = None qdrant = retriever.qdrant if qdrant and hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding: if request.embedding_model: embedding_model_used = request.embedding_model elif hasattr(qdrant, '_selected_multi_model') and qdrant._selected_multi_model: embedding_model_used = qdrant._selected_multi_model.value # Convert results to dicts using to_dict() method if available result_dicts = [] for r in results: if hasattr(r, 'to_dict'): item = r.to_dict() elif hasattr(r, '__dict__'): item = { "name": getattr(r, 'name', 'Unknown'), "headline": getattr(r, 'headline', None), "custodian_name": getattr(r, 'custodian_name', None), "custodian_slug": getattr(r, 'custodian_slug', None), "linkedin_url": getattr(r, 'linkedin_url', None), "heritage_relevant": getattr(r, 'heritage_relevant', None), "heritage_type": getattr(r, 'heritage_type', None), "location": getattr(r, 'location', None), "has_wcms": getattr(r, 'has_wcms', None), # WCMS fields "wcms_user_id": getattr(r, 'wcms_user_id', None), "wcms_abs_id": getattr(r, 'wcms_abs_id', None), "wcms_crm_id": getattr(r, 'wcms_crm_id', None), "wcms_username": getattr(r, 'wcms_username', None), "wcms_username_url": getattr(r, 'wcms_username_url', None), "wcms_status": getattr(r, 'wcms_status', None), "wcms_roles": getattr(r, 'wcms_roles', None), "wcms_registered_since": getattr(r, 'wcms_registered_since', None), "wcms_last_access": getattr(r, 
'wcms_last_access', None), # Contact details "email": getattr(r, 'email', None), "email_domain": getattr(r, 'email_domain', None), "score": getattr(r, 'combined_score', getattr(r, 'vector_score', None)), } elif isinstance(r, dict): item = r else: item = {"name": str(r)} result_dicts.append(item) elapsed_ms = (time.time() - start_time) * 1000 # Get collection stats stats = None try: stats = retriever.get_stats() # Only include person collection stats if available if stats and 'persons' in stats: stats = {'persons': stats['persons']} except Exception: pass return PersonSearchResponse( context=PERSON_JSONLD_CONTEXT, # JSON-LD context for linked data query=request.query, results=result_dicts, result_count=len(result_dicts), query_time_ms=round(elapsed_ms, 2), collection_stats=stats, embedding_model_used=embedding_model_used, ) except Exception as e: logger.exception(f"Person search failed: {e}") raise HTTPException(status_code=500, detail=str(e)) def _extract_subtask_result(task: Any, pipeline_result: Any, response: Any) -> Any: """Extract the relevant result portion for an atomic sub-task. Maps sub-task types to corresponding data from the pipeline result. This enables caching individual components for reuse in similar queries. 
Args: task: AtomicSubTask with task_type and parameters pipeline_result: Raw result from HeritageRAGPipeline response: DSPyQueryResponse object Returns: Cacheable result for this sub-task, or None if not extractable """ from atomic_decomposer import SubTaskType task_type = task.task_type # Intent classification - cache the detected intent if task_type == SubTaskType.INTENT_CLASSIFICATION: return { "intent": task.parameters.get("intent"), "query_type": getattr(pipeline_result, "query_type", None), } # Type filter - cache institution type filtering results if task_type == SubTaskType.TYPE_FILTER: inst_type = task.parameters.get("institution_type") retrieved = getattr(pipeline_result, "retrieved_results", None) if retrieved and isinstance(retrieved, list): # Filter results to just this institution type filtered = [r for r in retrieved if r.get("institution_type") == inst_type] return { "institution_type": inst_type, "count": len(filtered), "sample_ids": [r.get("id") for r in filtered[:10]], # Cache IDs not full records } # Location filter - cache geographic filtering results if task_type == SubTaskType.LOCATION_FILTER: location = task.parameters.get("location") retrieved = getattr(pipeline_result, "retrieved_results", None) if retrieved and isinstance(retrieved, list): # Count results in this location location_lower = location.lower() if location else "" in_location = [ r for r in retrieved if location_lower in str(r.get("city", "")).lower() or location_lower in str(r.get("region", "")).lower() ] return { "location": location, "level": task.parameters.get("level"), "count": len(in_location), } # Aggregation - cache aggregate statistics if task_type == SubTaskType.AGGREGATION: agg_type = task.parameters.get("aggregation") retrieved = getattr(pipeline_result, "retrieved_results", None) if agg_type == "count" and retrieved: return { "aggregation": "count", "value": len(retrieved) if isinstance(retrieved, list) else 0, } # Identifier filter - cache identifier lookup 
results if task_type == SubTaskType.IDENTIFIER_FILTER: id_type = task.parameters.get("identifier_type") retrieved = getattr(pipeline_result, "retrieved_results", None) if retrieved and isinstance(retrieved, list): # Count entities with this identifier type has_id = [ r for r in retrieved if r.get(f"has_{id_type}") or r.get(id_type) ] return { "identifier_type": id_type, "has_identifier_count": len(has_id), } # Default: don't cache if we can't extract meaningful sub-result return None @app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse) async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse: """DSPy RAG query endpoint with multi-turn conversation support. Uses the HeritageRAGPipeline for conversation-aware question answering. Follow-up questions like "Welke daarvan behoren archieven?" will be resolved using previous conversation context. Args: request: Query request with question, language, and conversation context Returns: DSPyQueryResponse with answer, resolved question, and optional visualization """ import time start_time = time.time() # Session management for multi-turn conversations # Get or create session state that enables follow-up question resolution session_id = request.session_id conversation_state = None session_mgr = None if SESSION_MANAGER_AVAILABLE and get_session_manager: try: session_mgr = await get_session_manager() session_id, conversation_state = await session_mgr.get_or_create(request.session_id) logger.debug(f"Session {session_id}: {len(conversation_state.turns)} previous turns") except Exception as e: logger.warning(f"Session manager error (continuing without session): {e}") # Generate a new session_id even if session manager failed import uuid session_id = str(uuid.uuid4()) else: # No session manager - generate session_id for tracking purposes import uuid session_id = request.session_id or str(uuid.uuid4()) # Resolve the provider BEFORE cache lookup to ensure consistent cache keys # This is critical: cache GET and 
SET must use the same provider value resolved_provider = (request.llm_provider or settings.llm_provider).lower() # Check cache first (before expensive LLM configuration) unless skip_cache is True if retriever and not request.skip_cache: cached = await retriever.cache.get_dspy( question=request.question, language=request.language, llm_provider=resolved_provider, # Use resolved provider, not request.llm_provider embedding_model=request.embedding_model, context=request.context if request.context else None, ) if cached: elapsed_ms = (time.time() - start_time) * 1000 logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms") # Transform CachedResponse format back to DSPyQueryResponse format # CachedResponse has: sources, visualization_type, visualization_data, context # DSPyQueryResponse needs: sources_used, visualization, query_type, etc. cached_context = cached.get("context") or {} visualization = None if cached.get("visualization_type") or cached.get("visualization_data"): visualization = { "type": cached.get("visualization_type"), "data": cached.get("visualization_data"), } # Restore llm_response metadata (GLM 4.7 reasoning_content) from cache llm_response_cached = cached_context.get("llm_response") llm_response_obj = None if llm_response_cached: try: llm_response_obj = LLMResponseMetadata(**llm_response_cached) except Exception: # Fall back to dict if LLMResponseMetadata fails llm_response_obj = llm_response_cached # type: ignore[assignment] # Rule 46: Build provenance for cache hit responses cached_sources = cached.get("sources", []) cached_template_used = cached_context.get("template_used", False) cached_template_id = cached_context.get("template_id") cached_llm_provider = cached_context.get("llm_provider") cached_llm_model = cached_context.get("llm_model") # Infer data tier - prioritize cached provenance if present cached_provenance = cached_context.get("epistemic_provenance") if cached_provenance: # Use the cached provenance, but mark it 
as coming from cache cache_provenance = cached_provenance.copy() if "CACHE" not in cache_provenance.get("derivationChain", []): cache_provenance.setdefault("derivationChain", []).insert(0, "CACHE:hit") else: # Build fresh provenance for older cache entries cache_tier = DataTier.TIER_3_CROWD_SOURCED.value if cached_template_used: cache_tier = DataTier.TIER_1_AUTHORITATIVE.value elif any(s.lower() in ["sparql", "typedb"] for s in cached_sources): cache_tier = DataTier.TIER_1_AUTHORITATIVE.value cache_provenance = EpistemicProvenance( dataSource=EpistemicDataSource.CACHE_AGGREGATION, dataTier=cache_tier, derivationChain=["CACHE:hit"] + build_derivation_chain( sources_used=cached_sources, template_used=cached_template_used, template_id=cached_template_id, llm_provider=cached_llm_provider, ), sourcesQueried=cached_sources, templateUsed=cached_template_used, templateId=cached_template_id, llmProvider=cached_llm_provider, llmModel=cached_llm_model, ).model_dump() response_data = { "question": request.question, "answer": cached.get("answer", ""), "sources_used": cached_sources, "visualization": visualization, "resolved_question": cached_context.get("resolved_question"), "retrieved_results": cached_context.get("retrieved_results"), "query_type": cached_context.get("query_type"), "embedding_model_used": cached_context.get("embedding_model"), "llm_model_used": cached_llm_model, "query_time_ms": round(elapsed_ms, 2), "cache_hit": True, "llm_response": llm_response_obj, # GLM 4.7 reasoning_content from cache # Session management - return session_id for follow-up queries "session_id": session_id, # Template tracking from cache "template_used": cached_template_used, "template_id": cached_template_id, # Rule 46: Epistemic provenance for transparency "epistemic_provenance": cache_provenance, } # Record cache hit metrics if METRICS_AVAILABLE and record_query: try: record_query( endpoint="dspy_query", template_used=cached_context.get("template_used", False), 
template_id=cached_context.get("template_id"), cache_hit=True, status="success", duration_seconds=elapsed_ms / 1000, intent=cached_context.get("query_type"), ) except Exception as e: logger.warning(f"Failed to record cache hit metrics: {e}") return DSPyQueryResponse(**response_data) # === ATOMIC SUB-TASK CACHING === # Full query cache miss - try atomic decomposition for partial cache hits # Research shows 40-70% cache hit rates with atomic decomposition decomposed_query = None cached_subtasks: dict[str, Any] = {} if ATOMIC_CACHE_AVAILABLE and atomic_cache_manager: try: decomposed_query, cached_subtasks = await atomic_cache_manager.process_query( query=request.question, language=request.language, ) # Record atomic cache metrics subtask_hits = decomposed_query.partial_cache_hits if decomposed_query else 0 subtask_misses = len(decomposed_query.sub_tasks) - subtask_hits if decomposed_query else 0 if decomposed_query.fully_cached: # All sub-tasks are cached - can potentially skip LLM logger.info(f"Atomic cache: fully cached ({len(decomposed_query.sub_tasks)} sub-tasks)") if record_atomic_cache: record_atomic_cache( query_hit=True, subtask_hits=subtask_hits, subtask_misses=0, fully_assembled=True, ) elif decomposed_query.partial_cache_hits > 0: # Partial cache hit - some sub-tasks cached logger.info( f"Atomic cache: partial hit ({decomposed_query.partial_cache_hits}/" f"{len(decomposed_query.sub_tasks)} sub-tasks cached)" ) if record_atomic_cache: record_atomic_cache( query_hit=False, subtask_hits=subtask_hits, subtask_misses=subtask_misses, fully_assembled=False, ) else: logger.debug(f"Atomic cache: miss (0/{len(decomposed_query.sub_tasks)} sub-tasks)") if record_atomic_cache: record_atomic_cache( query_hit=False, subtask_hits=0, subtask_misses=subtask_misses, fully_assembled=False, ) except Exception as e: logger.warning(f"Atomic decomposition failed: {e}") # ========================================================================== # FACTUAL QUERY FAST PATH: Skip LLM 
for count/list queries # ========================================================================== # For factual queries (counts, lists, comparisons), the SPARQL results ARE # the answer. No need for expensive LLM prose generation - just return the # table directly. This can reduce latency from ~15s to ~2s. # ========================================================================== try: from template_sparql import get_template_pipeline template_pipeline = get_template_pipeline() # Try template matching (this handles follow-up resolution internally) # Note: conversation_state already contains history from request.context # Run in thread pool to avoid blocking the event loop (DSPy is synchronous) template_result = await asyncio.to_thread( template_pipeline, question=request.question, language=request.language, conversation_state=conversation_state, ) # Check if this is a factual query that can skip LLM (template-driven, not hardcoded) # Fast path rule: If "prose" is NOT in response_modes, LLM generation is skipped if template_result.matched and not template_result.requires_llm(): # Log database routing decision databases_used = template_result.databases if hasattr(template_result, 'databases') else ["oxigraph", "qdrant"] qdrant_skipped = "qdrant" not in databases_used logger.info( f"[FAST-PATH] Template '{template_result.template_id}' uses response_modes={template_result.response_modes}, " f"databases={databases_used} - skipping LLM generation{', Qdrant skipped' if qdrant_skipped else ''} " f"(confidence={template_result.confidence:.2f})" ) # Execute SPARQL directly sparql_query = template_result.sparql sparql_results: list[dict[str, Any]] = [] sparql_error: str | None = None try: if retriever: client = await retriever._get_sparql_client() response = await client.post( settings.sparql_endpoint, data={"query": sparql_query}, headers={"Accept": "application/sparql-results+json"}, timeout=30.0, ) if response.status_code == 200: data = response.json() bindings = 
data.get("results", {}).get("bindings", []) raw_results = [ {k: v.get("value") for k, v in binding.items()} for binding in bindings ] # Check if this is a COUNT query (raw_results has 'count' key) # COUNT queries return [{"count": "10"}] - don't transform these is_count_query = raw_results and "count" in raw_results[0] if is_count_query: # For COUNT queries, preserve raw results with count value # Convert count string to int for template rendering sparql_results = [] for row in raw_results: count_val = row.get("count", "0") try: count_int = int(count_val) except (ValueError, TypeError): count_int = 0 sparql_results.append({ "count": count_int, "metadata": { "institution_type": template_result.slots.get("institution_type"), }, "scores": {"combined": 1.0}, }) logger.debug(f"[FACTUAL-QUERY] COUNT query result: {sparql_results[0].get('count') if sparql_results else 0}") # Execute companion query if available to get entity results for map/list # This fetches the actual institution records that were counted companion_query = getattr(template_result, 'companion_query', None) if companion_query: try: companion_response = await client.post( settings.sparql_endpoint, data={"query": companion_query}, headers={"Accept": "application/sparql-results+json"}, timeout=30.0, ) if companion_response.status_code == 200: companion_data = companion_response.json() companion_bindings = companion_data.get("results", {}).get("bindings", []) companion_raw = [ {k: v.get("value") for k, v in binding.items()} for binding in companion_bindings ] # Transform companion results to frontend format companion_results = [] for row in companion_raw: lat = None lon = None if row.get("lat"): try: lat = float(row["lat"]) except (ValueError, TypeError): pass if row.get("lon"): try: lon = float(row["lon"]) except (ValueError, TypeError): pass companion_results.append({ "name": row.get("name"), "institution_uri": row.get("institution"), "metadata": { "latitude": lat, "longitude": lon, "city": row.get("city") 
or template_result.slots.get("city"), "institution_type": template_result.slots.get("institution_type"), }, "scores": {"combined": 1.0}, }) # Store companion results - these will be used for map/list display # while sparql_results contains the count for the answer text if companion_results: logger.info(f"[COMPANION-QUERY] Fetched {len(companion_results)} entities for display, {sum(1 for r in companion_results if r['metadata'].get('latitude'))} with coordinates") # Replace sparql_results with companion results for display # but preserve the count value for answer rendering count_value = sparql_results[0].get("count", 0) if sparql_results else 0 sparql_results = companion_results # Add count to first result so it's available for ui_template if sparql_results: sparql_results[0]["count"] = count_value else: logger.warning(f"[COMPANION-QUERY] Failed with status {companion_response.status_code}") except Exception as ce: logger.warning(f"[COMPANION-QUERY] Execution failed: {ce}") else: # Transform SPARQL results to match frontend expected format # Frontend expects: {name, website, metadata: {latitude, longitude, city, ...}} # SPARQL returns: {name, website, lat, lon, city, ...} sparql_results = [] for row in raw_results: # Parse lat/lon to float if present lat = None lon = None if row.get("lat"): try: lat = float(row["lat"]) except (ValueError, TypeError): pass if row.get("lon"): try: lon = float(row["lon"]) except (ValueError, TypeError): pass transformed = { "name": row.get("name"), "website": row.get("website"), "metadata": { "latitude": lat, "longitude": lon, "city": row.get("city") or template_result.slots.get("city"), "country": row.get("country") or template_result.slots.get("country"), "region": row.get("region") or template_result.slots.get("region"), "institution_type": row.get("type") or template_result.slots.get("institution_type"), }, "scores": {"combined": 1.0}, # SPARQL results are exact matches } sparql_results.append(transformed) 
logger.debug(f"[FACTUAL-QUERY] Transformed {len(sparql_results)} results, {sum(1 for r in sparql_results if r['metadata']['latitude'])} with coordinates") else: sparql_error = f"SPARQL returned {response.status_code}" else: sparql_error = "Retriever not available" except Exception as e: sparql_error = str(e) logger.warning(f"[FACTUAL-QUERY] SPARQL execution failed: {e}") elapsed_ms = (time.time() - start_time) * 1000 # Generate answer using ui_template if available, otherwise fallback if sparql_error: answer = f"Er is een fout opgetreden bij het uitvoeren van de query: {sparql_error}" elif not sparql_results: answer = "Geen resultaten gevonden." elif template_result.ui_template: # Use template-defined UI template (template-driven answer formatting) lang = request.language if request.language in template_result.ui_template else "nl" ui_tmpl = template_result.ui_template.get(lang, template_result.ui_template.get("nl", "")) # Build context for Jinja2 template rendering with human-readable labels # The slots have resolved codes (M, NL-NH) but ui_template expects labels (musea, Noord-Holland) template_context = { "result_count": len(sparql_results), "count": sparql_results[0].get("count", len(sparql_results)) if sparql_results else 0, **template_result.slots # Include resolved slot values (codes) } # Add human-readable labels for common slot types # Labels loaded from schema/reference files per Rule 41 (no hardcoding) try: from schema_labels import get_label_resolver label_resolver = get_label_resolver() INSTITUTION_TYPE_LABELS_NL = label_resolver.get_all_institution_type_labels("nl") INSTITUTION_TYPE_LABELS_EN = label_resolver.get_all_institution_type_labels("en") SUBREGION_LABELS = label_resolver.get_all_subregion_labels("nl") except ImportError: # Fallback if schema_labels module not available (shouldn't happen in prod) logger.warning("schema_labels module not available, using inline fallback") INSTITUTION_TYPE_LABELS_NL = { "M": "musea", "L": "bibliotheken", "A": 
"archieven", "G": "galerijen", "O": "overheidsinstellingen", "R": "onderzoekscentra", "C": "bedrijfsarchieven", "U": "instellingen", "B": "botanische tuinen en dierentuinen", "E": "onderwijsinstellingen", "S": "heemkundige kringen", "F": "monumenten", "I": "immaterieel erfgoedgroepen", "X": "gecombineerde instellingen", "P": "privéverzamelingen", "H": "religieuze erfgoedsites", "D": "digitale platforms", "N": "erfgoedorganisaties", "T": "culinair erfgoed" } INSTITUTION_TYPE_LABELS_EN = { "M": "museums", "L": "libraries", "A": "archives", "G": "galleries", "O": "official institutions", "R": "research centers", "C": "corporate archives", "U": "institutions", "B": "botanical gardens and zoos", "E": "education providers", "S": "heritage societies", "F": "features", "I": "intangible heritage groups", "X": "mixed institutions", "P": "personal collections", "H": "holy sites", "D": "digital platforms", "N": "heritage NGOs", "T": "taste/smell heritage" } SUBREGION_LABELS = { "NL-DR": "Drenthe", "NL-FR": "Friesland", "NL-GE": "Gelderland", "NL-GR": "Groningen", "NL-LI": "Limburg", "NL-NB": "Noord-Brabant", "NL-NH": "Noord-Holland", "NL-OV": "Overijssel", "NL-UT": "Utrecht", "NL-ZE": "Zeeland", "NL-ZH": "Zuid-Holland", "NL-FL": "Flevoland" } # Add institution_type_nl and institution_type_en labels if "institution_type" in template_result.slots: type_code = template_result.slots["institution_type"] template_context["institution_type_nl"] = INSTITUTION_TYPE_LABELS_NL.get(type_code, type_code) template_context["institution_type_en"] = INSTITUTION_TYPE_LABELS_EN.get(type_code, type_code) # Add human-readable location label if "location" in template_result.slots: loc_code = template_result.slots["location"] # Check if it's a subregion code if loc_code in SUBREGION_LABELS: template_context["location"] = SUBREGION_LABELS[loc_code] # Otherwise keep the original (might already be a city name) # Simple Jinja2-style replacement (avoids importing Jinja2) answer = ui_tmpl for key, value 
in template_context.items(): answer = answer.replace("{{ " + key + " }}", str(value)) answer = answer.replace("{{" + key + "}}", str(value)) elif "count" in template_result.response_modes: # Count query - format as count count_value = sparql_results[0].get("count", len(sparql_results)) answer = f"Aantal: {count_value}" else: # List/table query - just indicate result count answer = f"Gevonden: {len(sparql_results)} resultaten. Zie de tabel hieronder." # Determine visualization type from response_modes viz_types = [] if "table" in template_result.response_modes: viz_types.append("table") if "chart" in template_result.response_modes: viz_types.append("chart") if "map" in template_result.response_modes: viz_types.append("map") # Build response with factual_result=True factual_response = DSPyQueryResponse( question=request.question, resolved_question=getattr(template_result, "resolved_question", None), answer=answer, sources_used=["SPARQL Knowledge Graph"], visualization={ "types": viz_types, "primary_type": viz_types[0] if viz_types else "table", "sparql_query": sparql_query, "response_modes": template_result.response_modes, "databases_used": databases_used, # For transparency/debugging }, retrieved_results=sparql_results, query_type="factual", query_time_ms=round(elapsed_ms, 2), conversation_turn=len(request.context), cache_hit=False, session_id=session_id, template_used=True, template_id=template_result.template_id, factual_result=True, sparql_query=sparql_query, ) # Update session with this turn if session_mgr and session_id: try: await session_mgr.add_turn_to_session( session_id=session_id, question=request.question, answer=answer, resolved_question=getattr(template_result, "resolved_question", None), template_id=template_result.template_id, slots=template_result.slots or {}, ) except Exception as e: logger.warning(f"Failed to update session: {e}") # Record metrics if METRICS_AVAILABLE and record_query: try: record_query( endpoint="dspy_query", template_used=True, 
template_id=template_result.template_id, cache_hit=False, status="success", duration_seconds=elapsed_ms / 1000, intent="factual", ) except Exception as e: logger.warning(f"Failed to record metrics: {e}") # Cache the response if retriever: await retriever.cache.set_dspy( question=request.question, language=request.language, llm_provider="none", # No LLM used embedding_model=request.embedding_model, response=factual_response.model_dump(), context=request.context if request.context else None, ) logger.info(f"[FACTUAL-QUERY] Returned {len(sparql_results)} results in {elapsed_ms:.2f}ms (LLM skipped)") return factual_response except ImportError as e: logger.debug(f"Template SPARQL not available for factual query detection: {e}") except Exception as e: logger.warning(f"Factual query detection failed (continuing with full pipeline): {e}") # ========================================================================== # FULL RAG PIPELINE: For non-factual queries or when factual detection fails # ========================================================================== try: # Import DSPy pipeline and History import dspy from dspy import History from dspy_heritage_rag import HeritageRAGPipeline # Configure DSPy LM per-request based on request.llm_provider (or server default) # This allows frontend to switch LLM providers dynamically # # IMPORTANT: We use dspy.settings.context() instead of dspy.configure() because # configure() can only be called from the same async task that initially configured DSPy. # context() provides thread-local overrides that work correctly in async request handlers. 
requested_provider = resolved_provider # Already resolved above llm_provider_used: str | None = None llm_model_used: str | None = None lm = None logger.info(f"LLM provider requested: {requested_provider} (request.llm_provider={request.llm_provider}, server default={settings.llm_provider})") # Check if requested provider has API key configured - fail early if not provider_api_keys = { "zai": settings.zai_api_token, "groq": settings.groq_api_key, "anthropic": settings.anthropic_api_key, "openai": settings.openai_api_key, "huggingface": settings.huggingface_api_key, } if requested_provider in provider_api_keys and not provider_api_keys[requested_provider]: raise ValueError( f"LLM provider '{requested_provider}' was requested but its API key is not configured. " f"Please set the appropriate environment variable (e.g., ANTHROPIC_API_KEY or CLAUDE_API_KEY for anthropic)." ) # Provider configuration priority: requested provider first, then fallback chain providers_to_try = [requested_provider] # Add fallback chain (but not duplicates) for fallback in ["zai", "groq", "anthropic", "openai"]: if fallback not in providers_to_try: providers_to_try.append(fallback) for provider in providers_to_try: if lm is not None: break # Default models per provider (used if request.llm_model is not specified) # Use LLM_MODEL from settings when it matches the provider prefix default_models = { "zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash", "groq": "llama-3.1-8b-instant", "anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514", "openai": "gpt-4o-mini", # Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference # Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2 "huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct", } # HuggingFace models use org/model format (e.g., 
meta-llama/Llama-3.1-8B-Instruct) # Groq models use simple names (e.g., llama-3.1-8b-instant) model_prefixes = { "glm-": "zai", "llama-3.1-": "groq", "llama-3.3-": "groq", "claude-": "anthropic", "gpt-": "openai", # HuggingFace organization prefixes "mistralai/": "huggingface", "google/": "huggingface", "Qwen/": "huggingface", "deepseek-ai/": "huggingface", "meta-llama/": "huggingface", "utter-project/": "huggingface", "microsoft/": "huggingface", "tiiuae/": "huggingface", } # Determine which model to use: requested model (if valid for this provider) or default requested_model = request.llm_model model_to_use = default_models.get(provider, "") # Check if requested model matches this provider if requested_model: for prefix, model_provider in model_prefixes.items(): if requested_model.startswith(prefix) and model_provider == provider: model_to_use = requested_model break if provider == "zai" and settings.zai_api_token: try: lm = dspy.LM( f"openai/{model_to_use}", api_key=settings.zai_api_token, api_base="https://api.z.ai/api/coding/paas/v4", ) llm_provider_used = "zai" llm_model_used = model_to_use logger.info(f"Using Z.AI {model_to_use} (FREE) for this request") except Exception as e: logger.warning(f"Failed to create Z.AI LM: {e}") elif provider == "groq" and settings.groq_api_key: try: lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key) llm_provider_used = "groq" llm_model_used = model_to_use logger.info(f"Using Groq {model_to_use} (FREE) for this request") except Exception as e: logger.warning(f"Failed to create Groq LM: {e}") elif provider == "huggingface" and settings.huggingface_api_key: try: lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key) llm_provider_used = "huggingface" llm_model_used = model_to_use logger.info(f"Using HuggingFace {model_to_use} for this request") except Exception as e: logger.warning(f"Failed to create HuggingFace LM: {e}") elif provider == "anthropic" and settings.anthropic_api_key: try: 
lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key) llm_provider_used = "anthropic" llm_model_used = model_to_use logger.info(f"Using Anthropic {model_to_use} for this request") except Exception as e: logger.warning(f"Failed to create Anthropic LM: {e}") elif provider == "openai" and settings.openai_api_key: try: lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key) llm_provider_used = "openai" llm_model_used = model_to_use logger.info(f"Using OpenAI {model_to_use} for this request") except Exception as e: logger.warning(f"Failed to create OpenAI LM: {e}") # No LM could be configured if lm is None: raise ValueError( f"No LLM could be configured. Requested provider: {requested_provider}. " "Ensure the appropriate API key is set: ZAI_API_TOKEN, GROQ_API_KEY, ANTHROPIC_API_KEY, HUGGINGFACE_API_KEY, or OPENAI_API_KEY." ) logger.info(f"LLM provider for this request: {llm_provider_used}") # ================================================================= # PERFORMANCE OPTIMIZATION: Create fast LM for routing/extraction # Use a fast, cheap model (glm-4.5-flash FREE, gpt-4o-mini $0.15/1M) # for routing, entity extraction, and SPARQL generation. # The quality_lm (lm) is used only for final answer generation. # This can reduce total latency by 2-3x (from ~20s to ~7s). # ================================================================= fast_lm = None # Try to create fast_lm based on FAST_LM_PROVIDER setting # Options: "openai" (fast ~1-2s, $0.15/1M) or "zai" (FREE but slow ~13s) # Default: openai for speed. Override with FAST_LM_PROVIDER=zai to save costs. 
if settings.fast_lm_provider == "openai" and settings.openai_api_key: try: fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key) logger.info("Using OpenAI GPT-4o-mini as fast_lm for routing/extraction (~1-2s)") except Exception as e: logger.warning(f"Failed to create fast OpenAI LM: {e}") if fast_lm is None and settings.fast_lm_provider == "zai" and settings.zai_api_token: try: fast_lm = dspy.LM( "openai/glm-4.5-flash", api_key=settings.zai_api_token, api_base="https://api.z.ai/api/coding/paas/v4", ) logger.info("Using Z.AI GLM-4.5-flash (FREE) as fast_lm for routing/extraction (~13s)") except Exception as e: logger.warning(f"Failed to create fast Z.AI LM: {e}") # Fallback: try the other provider if preferred one failed if fast_lm is None and settings.openai_api_key: try: fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key) logger.info("Fallback: Using OpenAI GPT-4o-mini as fast_lm") except Exception as e: logger.warning(f"Fallback failed - no fast_lm available: {e}") if fast_lm is None: logger.info("No fast_lm available - all stages will use quality_lm (slower but works)") # Convert context to DSPy History format # Context comes as [{question: "...", answer: "..."}, ...] # History expects messages in the same format: [{question: "...", answer: "..."}, ...] # (NOT role/content format - that was a bug!) 
history_messages = [] for turn in request.context: # Only include turns that have both question AND answer if turn.get("question") and turn.get("answer"): history_messages.append({ "question": turn["question"], "answer": turn["answer"] }) history = History(messages=history_messages) if history_messages else None # Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality) # Falls back to creating a new pipeline if global not available if dspy_pipeline is not None: pipeline = dspy_pipeline logger.debug("Using global optimized DSPy pipeline") else: # Fallback: create pipeline without optimized weights qdrant_retriever = retriever.qdrant if retriever else None pipeline = HeritageRAGPipeline( retriever=qdrant_retriever, fast_lm=fast_lm, quality_lm=lm, ) logger.debug("Using fallback (unoptimized) DSPy pipeline") # Execute query with conversation history # Retry logic for transient API errors (e.g., Anthropic "Overloaded" errors) # # IMPORTANT: We use dspy.settings.context(lm=lm) to set the LLM for this request. # This provides thread-local overrides that work correctly in async request handlers, # unlike dspy.configure() which can only be called from the main async task. 
max_retries = 3 last_error: Exception | None = None result = None # Helper function to run pipeline synchronously (for asyncio.to_thread) def run_pipeline_sync(): """Run DSPy pipeline in sync context with retry logic.""" nonlocal last_error, result with dspy.settings.context(lm=lm): for attempt in range(max_retries): try: # Use pipeline() instead of pipeline.forward() per DSPy 3.0 best practices return pipeline( embedding_model=request.embedding_model, question=request.question, language=request.language, history=history, include_viz=request.include_visualization, conversation_state=conversation_state, # Pass session state for template SPARQL ) except Exception as e: last_error = e error_str = str(e).lower() # Check for retryable errors (API overload, rate limits, temporary failures) is_retryable = any(keyword in error_str for keyword in [ "overloaded", "rate_limit", "rate limit", "too many requests", "529", "503", "502", "504", # HTTP status codes "temporarily unavailable", "service unavailable", "connection reset", "connection refused", "timeout" ]) if is_retryable and attempt < max_retries - 1: wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s logger.warning( f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}. " f"Retrying in {wait_time}s..." 
) time.sleep(wait_time) # OK to block in thread pool continue else: # Non-retryable error or max retries reached raise return None # Run DSPy pipeline in thread pool to avoid blocking the event loop result = await asyncio.to_thread(run_pipeline_sync) # If we get here without a result (all retries exhausted), raise the last error if result is None: if last_error: raise last_error raise HTTPException(status_code=500, detail="Pipeline execution failed with no result") elapsed_ms = (time.time() - start_time) * 1000 # Extract retrieved results for frontend visualization (tables, graphs) retrieved_results = getattr(result, "retrieved_results", None) query_type = getattr(result, "query_type", None) # Extract visualization if present visualization = None if request.include_visualization and hasattr(result, "visualization"): viz = result.visualization if viz: # Now showing SPARQL for all query types including person queries # Person queries use HeritagePersonSPARQLGenerator (schema:Person predicates) # Institution queries use HeritageSPARQLGenerator (crm:E39_Actor predicates) sparql_to_show = getattr(result, "sparql", None) visualization = { "type": getattr(viz, "viz_type", "table"), "sparql_query": sparql_to_show, } # Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support) llm_response_metadata = extract_llm_response_metadata( lm=lm, provider=llm_provider_used, latency_ms=int(elapsed_ms), ) # Extract template SPARQL info from result template_used = getattr(result, "template_used", False) template_id = getattr(result, "template_id", None) # Rule 46: Build epistemic provenance for transparency # This tracks WHERE, WHEN, and HOW the response data originated sources_used_list = getattr(result, "sources_used", []) # Infer data tier from sources - SPARQL/TypeDB are authoritative, Qdrant may include scraped data inferred_tier = DataTier.TIER_3_CROWD_SOURCED.value # Default if template_used: # Template-based SPARQL uses curated Oxigraph data 
inferred_tier = DataTier.TIER_1_AUTHORITATIVE.value elif any(s.lower() in ["sparql", "typedb"] for s in sources_used_list): inferred_tier = DataTier.TIER_1_AUTHORITATIVE.value elif any(s.lower() == "qdrant" for s in sources_used_list): inferred_tier = DataTier.TIER_3_CROWD_SOURCED.value # Build provenance object response_provenance = EpistemicProvenance( dataSource=EpistemicDataSource.RAG_PIPELINE, dataTier=inferred_tier, derivationChain=build_derivation_chain( sources_used=sources_used_list, template_used=template_used, template_id=template_id, llm_provider=llm_provider_used, ), sourcesQueried=sources_used_list, totalRetrieved=len(retrieved_results) if retrieved_results else 0, totalAfterFusion=len(retrieved_results) if retrieved_results else 0, templateUsed=template_used, templateId=template_id, llmProvider=llm_provider_used, llmModel=llm_model_used, ) # Build response object response = DSPyQueryResponse( question=request.question, resolved_question=getattr(result, "resolved_question", None), answer=getattr(result, "answer", "Geen antwoord gevonden."), sources_used=sources_used_list, visualization=visualization, retrieved_results=retrieved_results, # Raw data for frontend visualization query_type=query_type, # "person" or "institution" query_time_ms=round(elapsed_ms, 2), conversation_turn=len(request.context), embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model), # Cost tracking fields timing_ms=getattr(result, "timing_ms", None), cost_usd=getattr(result, "cost_usd", None), timing_breakdown=getattr(result, "timing_breakdown", None), # LLM provider tracking llm_provider_used=llm_provider_used, llm_model_used=llm_model_used, cache_hit=False, # LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought) llm_response=llm_response_metadata, # Session management - return session_id for follow-up queries session_id=session_id, # Template SPARQL tracking template_used=template_used, template_id=template_id, # Rule 46: Epistemic 
provenance for transparency epistemic_provenance=response_provenance.model_dump(), ) # Update session with this turn for multi-turn conversation support if session_mgr and session_id: try: await session_mgr.add_turn_to_session( session_id=session_id, question=request.question, answer=response.answer, resolved_question=response.resolved_question, template_id=template_id, slots=getattr(result, "slots", {}), # Extracted slots for follow-up inheritance ) logger.debug(f"Session {session_id} updated with new turn") except Exception as e: logger.warning(f"Failed to update session {session_id}: {e}") # Record Prometheus metrics for monitoring if METRICS_AVAILABLE and record_query: try: record_query( endpoint="dspy_query", template_used=template_used, template_id=template_id, cache_hit=False, status="success", duration_seconds=elapsed_ms / 1000, intent=query_type, ) except Exception as e: logger.warning(f"Failed to record metrics: {e}") # Cache the successful response for future requests if retriever: await retriever.cache.set_dspy( question=request.question, language=request.language, llm_provider=llm_provider_used, # Use actual provider, not requested embedding_model=request.embedding_model, response=response.model_dump(), context=request.context if request.context else None, ) # === CACHE ATOMIC SUB-TASKS FOR FUTURE QUERIES === # Cache individual sub-tasks for higher hit rates on similar queries # E.g., "musea in Amsterdam" sub-task can be reused for # "Hoeveel musea in Amsterdam hebben een website?" 
if ATOMIC_CACHE_AVAILABLE and atomic_cache_manager and decomposed_query: try: subtasks_cached = 0 for task in decomposed_query.sub_tasks: if not task.cache_hit: # Extract relevant result for this sub-task type subtask_result = _extract_subtask_result(task, result, response) if subtask_result is not None: await atomic_cache_manager.cache_subtask_result( task=task, result=subtask_result, language=request.language, ttl=3600, # 1 hour TTL ) subtasks_cached += 1 # Record subtasks cached metric if subtasks_cached > 0 and record_atomic_subtask_cached: record_atomic_subtask_cached(subtasks_cached) # Log atomic cache stats periodically stats = atomic_cache_manager.get_stats() if stats["queries_decomposed"] % 10 == 0: logger.info( f"Atomic cache stats: {stats['subtask_hit_rate']}% hit rate, " f"{stats['queries_decomposed']} queries, " f"{stats['full_query_reassemblies']} fully cached" ) except Exception as e: logger.warning(f"Failed to cache atomic sub-tasks: {e}") return response except ImportError as e: logger.warning(f"DSPy pipeline not available: {e}") # Fallback to simple response return DSPyQueryResponse( question=request.question, answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.", query_time_ms=0, conversation_turn=len(request.context), embedding_model_used=request.embedding_model, session_id=session_id, # Still return session_id even on error ) except Exception as e: logger.exception("DSPy query failed") raise HTTPException(status_code=500, detail=str(e)) async def stream_dspy_query_response( request: DSPyQueryRequest, ) -> AsyncIterator[str]: """Stream DSPy query response with progress updates for long-running queries. 
Yields NDJSON lines with status updates at each pipeline stage: - {"type": "status", "stage": "cache", "message": "🔍 Cache controleren..."} - {"type": "status", "stage": "config", "message": "⚙️ LLM configureren..."} - {"type": "status", "stage": "routing", "message": "🧭 Vraag analyseren..."} - {"type": "status", "stage": "retrieval", "message": "📊 Database doorzoeken..."} - {"type": "status", "stage": "generation", "message": "💡 Antwoord genereren..."} - {"type": "complete", "data": {...DSPyQueryResponse...}} """ import time start_time = time.time() # Session management for multi-turn conversations # Get or create session state that enables follow-up question resolution session_id = request.session_id conversation_state = None session_mgr = None if SESSION_MANAGER_AVAILABLE and get_session_manager: try: session_mgr = await get_session_manager() session_id, conversation_state = await session_mgr.get_or_create(request.session_id) logger.debug(f"Stream session {session_id}: {len(conversation_state.turns)} previous turns") except Exception as e: logger.warning(f"Session manager error (continuing without session): {e}") import uuid session_id = str(uuid.uuid4()) else: import uuid session_id = request.session_id or str(uuid.uuid4()) def emit_status(stage: str, message: str) -> str: """Helper to emit status JSON line.""" return json.dumps({ "type": "status", "stage": stage, "message": message, "elapsed_ms": round((time.time() - start_time) * 1000, 2), }) + "\n" def emit_error(error: str, details: str | None = None) -> str: """Helper to emit error JSON line.""" return json.dumps({ "type": "error", "error": error, "details": details, "elapsed_ms": round((time.time() - start_time) * 1000, 2), }) + "\n" def extract_user_friendly_error(exception: Exception) -> tuple[str, str | None]: """Extract a user-friendly error message from various exception types. 
Returns: tuple: (user_message, technical_details) """ error_str = str(exception) error_lower = error_str.lower() # HuggingFace / LiteLLM specific errors if "huggingface" in error_lower or "hf" in error_lower: if "model_not_supported" in error_lower or "not a chat model" in error_lower: # Extract model name if present import re model_match = re.search(r"model['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", error_str) model_name = model_match.group(1) if model_match else "geselecteerde model" return ( f"Het model '{model_name}' wordt niet ondersteund door HuggingFace. Kies een ander model.", error_str ) if "rate limit" in error_lower or "too many requests" in error_lower: return ( "HuggingFace API limiet bereikt. Probeer het over een minuut opnieuw.", error_str ) if "unauthorized" in error_lower or "invalid api key" in error_lower: return ( "HuggingFace API sleutel ongeldig. Neem contact op met de beheerder.", error_str ) if "model is loading" in error_lower or "loading" in error_lower and "model" in error_lower: return ( "Het HuggingFace model wordt geladen. Probeer het over 30 seconden opnieuw.", error_str ) # Anthropic errors if "anthropic" in error_lower: if "rate limit" in error_lower or "overloaded" in error_lower: return ( "Anthropic API is overbelast. Probeer het over een minuut opnieuw.", error_str ) if "invalid api key" in error_lower or "unauthorized" in error_lower: return ( "Anthropic API sleutel ongeldig. Neem contact op met de beheerder.", error_str ) # OpenAI errors if "openai" in error_lower: if "rate limit" in error_lower: return ( "OpenAI API limiet bereikt. Probeer het over een minuut opnieuw.", error_str ) if "invalid api key" in error_lower: return ( "OpenAI API sleutel ongeldig. Neem contact op met de beheerder.", error_str ) # Z.AI errors if "z.ai" in error_lower or "zai" in error_lower: if "rate limit" in error_lower or "quota" in error_lower: return ( "Z.AI API limiet bereikt. 
Probeer het over een minuut opnieuw.", error_str ) # Generic network/connection errors if "connection" in error_lower or "timeout" in error_lower: return ( "Verbindingsfout met de AI service. Controleer uw internetverbinding en probeer het opnieuw.", error_str ) if "503" in error_str or "service unavailable" in error_lower: return ( "De AI service is tijdelijk niet beschikbaar. Probeer het over een minuut opnieuw.", error_str ) # Qdrant/retrieval errors if "qdrant" in error_lower: return ( "Fout bij het doorzoeken van de database. Probeer het later opnieuw.", error_str ) # Default: return the raw error but in a nicer format return ( f"Er is een fout opgetreden: {error_str[:200]}{'...' if len(error_str) > 200 else ''}", error_str if len(error_str) > 200 else None ) # Resolve the provider BEFORE cache lookup to ensure consistent cache keys # This is critical: cache GET and SET must use the same provider value resolved_provider = (request.llm_provider or settings.llm_provider).lower() # Stage 1: Check cache yield emit_status("cache", "🔍 Cache controleren...") if retriever: cached = await retriever.cache.get_dspy( question=request.question, language=request.language, llm_provider=resolved_provider, # Use resolved provider, not request.llm_provider embedding_model=request.embedding_model, context=request.context if request.context else None, ) if cached: elapsed_ms = (time.time() - start_time) * 1000 logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms") # Transform CachedResponse format back to DSPyQueryResponse format cached_context = cached.get("context") or {} visualization = None if cached.get("visualization_type") or cached.get("visualization_data"): visualization = { "type": cached.get("visualization_type"), "data": cached.get("visualization_data"), } # Rule 46: Build provenance for streaming cache hit responses stream_cached_sources = cached.get("sources", []) stream_cached_template_used = cached_context.get("template_used", False) 
stream_cached_template_id = cached_context.get("template_id") stream_cached_llm_provider = cached_context.get("llm_provider") stream_cached_llm_model = cached_context.get("llm_model") # Infer data tier - prioritize cached provenance if present stream_cached_prov = cached_context.get("epistemic_provenance") if stream_cached_prov: # Use the cached provenance, but mark it as coming from cache stream_cache_provenance = stream_cached_prov.copy() if "CACHE" not in stream_cache_provenance.get("derivationChain", []): stream_cache_provenance.setdefault("derivationChain", []).insert(0, "CACHE:hit") else: # Build fresh provenance for older cache entries stream_cache_tier = DataTier.TIER_3_CROWD_SOURCED.value if stream_cached_template_used: stream_cache_tier = DataTier.TIER_1_AUTHORITATIVE.value elif any(s.lower() in ["sparql", "typedb"] for s in stream_cached_sources): stream_cache_tier = DataTier.TIER_1_AUTHORITATIVE.value stream_cache_provenance = EpistemicProvenance( dataSource=EpistemicDataSource.CACHE_AGGREGATION, dataTier=stream_cache_tier, derivationChain=["CACHE:hit"] + build_derivation_chain( sources_used=stream_cached_sources, template_used=stream_cached_template_used, template_id=stream_cached_template_id, llm_provider=stream_cached_llm_provider, ), sourcesQueried=stream_cached_sources, templateUsed=stream_cached_template_used, templateId=stream_cached_template_id, llmProvider=stream_cached_llm_provider, llmModel=stream_cached_llm_model, ).model_dump() response_data = { "question": request.question, "answer": cached.get("answer", ""), "sources_used": stream_cached_sources, "visualization": visualization, "resolved_question": cached_context.get("resolved_question"), "retrieved_results": cached_context.get("retrieved_results"), "query_type": cached_context.get("query_type"), "embedding_model_used": cached_context.get("embedding_model"), "llm_model_used": stream_cached_llm_model, "query_time_ms": round(elapsed_ms, 2), "cache_hit": True, # Session management 
"session_id": session_id, # Template tracking from cache "template_used": stream_cached_template_used, "template_id": stream_cached_template_id, # Rule 46: Epistemic provenance for transparency "epistemic_provenance": stream_cache_provenance, } # Record cache hit metrics for streaming endpoint if METRICS_AVAILABLE and record_query: try: record_query( endpoint="dspy_query_stream", template_used=cached_context.get("template_used", False), template_id=cached_context.get("template_id"), cache_hit=True, status="success", duration_seconds=elapsed_ms / 1000, intent=cached_context.get("query_type"), ) except Exception as e: logger.warning(f"Failed to record streaming cache hit metrics: {e}") yield emit_status("cache", "✅ Antwoord gevonden in cache!") yield json.dumps({"type": "complete", "data": response_data}) + "\n" return try: # Stage 2: Configure LLM yield emit_status("config", "⚙️ LLM configureren...") import dspy from dspy import History from dspy_heritage_rag import HeritageRAGPipeline requested_provider = resolved_provider # Already resolved above llm_provider_used: str | None = None llm_model_used: str | None = None lm = None # Check if requested provider has API key configured - fail early if not provider_api_keys = { "zai": settings.zai_api_token, "groq": settings.groq_api_key, "anthropic": settings.anthropic_api_key, "openai": settings.openai_api_key, "huggingface": settings.huggingface_api_key, } if requested_provider in provider_api_keys and not provider_api_keys[requested_provider]: yield emit_error( f"LLM provider '{requested_provider}' was requested but its API key is not configured. " f"Please set the appropriate environment variable (e.g., ANTHROPIC_API_KEY or CLAUDE_API_KEY for anthropic)." 
) return providers_to_try = [requested_provider] for fallback in ["zai", "groq", "anthropic", "openai"]: if fallback not in providers_to_try: providers_to_try.append(fallback) for provider in providers_to_try: if lm is not None: break # Default models per provider (used if request.llm_model is not specified) # Use LLM_MODEL from settings when it matches the provider prefix default_models = { "zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash", "groq": "llama-3.1-8b-instant", "anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514", "openai": "gpt-4o-mini", # Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference # Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2 "huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct", } # HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct) # Groq models use simple names (e.g., llama-3.1-8b-instant) model_prefixes = { "glm-": "zai", "llama-3.1-": "groq", "llama-3.3-": "groq", "claude-": "anthropic", "gpt-": "openai", # HuggingFace organization prefixes "mistralai/": "huggingface", "google/": "huggingface", "Qwen/": "huggingface", "deepseek-ai/": "huggingface", "meta-llama/": "huggingface", "utter-project/": "huggingface", "microsoft/": "huggingface", "tiiuae/": "huggingface", } # Determine which model to use: requested model (if valid for this provider) or default requested_model = request.llm_model model_to_use = default_models.get(provider, "") # Check if requested model matches this provider if requested_model: for prefix, model_provider in model_prefixes.items(): if requested_model.startswith(prefix) and model_provider == provider: model_to_use = requested_model break if provider == "zai" and settings.zai_api_token: try: lm = dspy.LM( f"openai/{model_to_use}", api_key=settings.zai_api_token, 
api_base="https://api.z.ai/api/coding/paas/v4", ) llm_provider_used = "zai" llm_model_used = model_to_use except Exception as e: logger.warning(f"Failed to create Z.AI LM: {e}") elif provider == "groq" and settings.groq_api_key: try: lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key) llm_provider_used = "groq" llm_model_used = model_to_use logger.info(f"Using Groq {model_to_use} (FREE) for streaming request") except Exception as e: logger.warning(f"Failed to create Groq LM: {e}") elif provider == "huggingface" and settings.huggingface_api_key: try: lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key) llm_provider_used = "huggingface" llm_model_used = model_to_use except Exception as e: logger.warning(f"Failed to create HuggingFace LM: {e}") elif provider == "anthropic" and settings.anthropic_api_key: try: lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key) llm_provider_used = "anthropic" llm_model_used = model_to_use except Exception as e: logger.warning(f"Failed to create Anthropic LM: {e}") elif provider == "openai" and settings.openai_api_key: try: lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key) llm_provider_used = "openai" llm_model_used = model_to_use except Exception as e: logger.warning(f"Failed to create OpenAI LM: {e}") if lm is None: yield emit_error(f"Geen LLM beschikbaar. 
Controleer API keys.") return yield emit_status("config", f"✅ LLM geconfigureerd ({llm_provider_used})") # Stage 3: Prepare conversation history yield emit_status("routing", "🧭 Vraag analyseren...") history_messages = [] for turn in request.context: if turn.get("question") and turn.get("answer"): history_messages.append({ "question": turn["question"], "answer": turn["answer"] }) history = History(messages=history_messages) if history_messages else None # Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality) if dspy_pipeline is not None: pipeline = dspy_pipeline logger.debug("Using global optimized DSPy pipeline (streaming)") else: # Fallback: create pipeline without optimized weights qdrant_retriever = retriever.qdrant if retriever else None pipeline = HeritageRAGPipeline(retriever=qdrant_retriever) logger.debug("Using fallback (unoptimized) DSPy pipeline (streaming)") # Stage 4: Execute pipeline with STREAMING answer generation yield emit_status("retrieval", "📊 Database doorzoeken...") result = None # Check if pipeline supports streaming if hasattr(pipeline, 'forward_streaming'): # Use streaming mode - tokens arrive as they're generated try: with dspy.settings.context(lm=lm): async for event in pipeline.forward_streaming( embedding_model=request.embedding_model, question=request.question, language=request.language, history=history, include_viz=request.include_visualization, conversation_state=conversation_state, # Pass session state for template SPARQL ): event_type = event.get("type") if event_type == "cache_hit": # Cache hit - return immediately result = event["prediction"] yield emit_status("complete", "✅ Klaar! 
(cache)") break elif event_type == "retrieval_complete": # Retrieval done, now generating answer yield emit_status("generation", "💡 Antwoord genereren...") elif event_type == "token": # Stream token to frontend yield json.dumps({"type": "token", "content": event["content"]}) + "\n" elif event_type == "status": # Status message from pipeline yield emit_status("generation", event.get("message", "...")) elif event_type == "answer_complete": # Final prediction ready result = event["prediction"] except Exception as e: logger.exception(f"Streaming pipeline execution failed: {e}") user_msg, details = extract_user_friendly_error(e) yield emit_error(user_msg, details) return else: # Fallback: Non-streaming mode (original behavior) max_retries = 3 last_error: Exception | None = None with dspy.settings.context(lm=lm): for attempt in range(max_retries): try: if attempt > 0: yield emit_status("retrieval", f"🔄 Opnieuw proberen ({attempt + 1}/{max_retries})...") result = pipeline( embedding_model=request.embedding_model, question=request.question, language=request.language, history=history, include_viz=request.include_visualization, conversation_state=conversation_state, ) break except Exception as e: last_error = e error_str = str(e).lower() is_retryable = any(keyword in error_str for keyword in [ "overloaded", "rate_limit", "rate limit", "too many requests", "529", "503", "502", "504", "temporarily unavailable", "service unavailable", "connection reset", "connection refused", "timeout" ]) if is_retryable and attempt < max_retries - 1: wait_time = 2 ** attempt logger.warning(f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}") yield emit_status("retrieval", f"⏳ API overbelast, wachten {wait_time}s...") await asyncio.sleep(wait_time) continue else: logger.exception(f"Pipeline execution failed after {attempt + 1} attempts") user_msg, details = extract_user_friendly_error(e) yield emit_error(user_msg, details) return if result is None: if last_error: user_msg, details 
= extract_user_friendly_error(last_error) yield emit_error(user_msg, details) return yield emit_error("Pipeline uitvoering mislukt zonder resultaat") return # Stage 5: Generate response (only for non-streaming fallback) yield emit_status("generation", "💡 Antwoord genereren...") elapsed_ms = (time.time() - start_time) * 1000 # Extract query_type first - needed for SPARQL visibility decision query_type = getattr(result, "query_type", None) visualization = None if request.include_visualization and hasattr(result, "visualization"): viz = result.visualization if viz: # Now showing SPARQL for all query types including person queries # Person queries use HeritagePersonSPARQLGenerator (schema:Person predicates) # Institution queries use HeritageSPARQLGenerator (crm:E39_Actor predicates) sparql_to_show = getattr(result, "sparql", None) # viz can be either an object (with .viz_type attr) or a dict (with "type" key) # Handle both cases for compatibility with streaming and non-streaming modes if isinstance(viz, dict): viz_type = viz.get("type", "table") else: viz_type = getattr(viz, "viz_type", "table") visualization = { "type": viz_type, "sparql_query": sparql_to_show, } logger.info(f"[DEBUG] Built visualization: type={viz_type}, sparql_len={len(sparql_to_show) if sparql_to_show else 0}") retrieved_results = getattr(result, "retrieved_results", None) # Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support) llm_response_metadata = extract_llm_response_metadata( lm=lm, provider=llm_provider_used, latency_ms=int(elapsed_ms), ) # Rule 46: Build epistemic provenance for streaming endpoint stream_sources_used = getattr(result, "sources_used", []) stream_template_used = getattr(result, "template_used", False) stream_template_id = getattr(result, "template_id", None) # Infer data tier from sources stream_tier = DataTier.TIER_3_CROWD_SOURCED.value if stream_template_used: stream_tier = DataTier.TIER_1_AUTHORITATIVE.value elif any(s.lower() in ["sparql", 
"typedb"] for s in stream_sources_used): stream_tier = DataTier.TIER_1_AUTHORITATIVE.value stream_provenance = EpistemicProvenance( dataSource=EpistemicDataSource.RAG_PIPELINE, dataTier=stream_tier, derivationChain=build_derivation_chain( sources_used=stream_sources_used, template_used=stream_template_used, template_id=stream_template_id, llm_provider=llm_provider_used, ), sourcesQueried=stream_sources_used, totalRetrieved=len(retrieved_results) if retrieved_results else 0, totalAfterFusion=len(retrieved_results) if retrieved_results else 0, templateUsed=stream_template_used, templateId=stream_template_id, llmProvider=llm_provider_used, llmModel=llm_model_used, ) response = DSPyQueryResponse( question=request.question, resolved_question=getattr(result, "resolved_question", None), answer=getattr(result, "answer", "Geen antwoord gevonden."), sources_used=stream_sources_used, visualization=visualization, retrieved_results=retrieved_results, query_type=query_type, query_time_ms=round(elapsed_ms, 2), conversation_turn=len(request.context), embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model), timing_ms=getattr(result, "timing_ms", None), cost_usd=getattr(result, "cost_usd", None), timing_breakdown=getattr(result, "timing_breakdown", None), llm_provider_used=llm_provider_used, llm_model_used=llm_model_used, cache_hit=False, # LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought) llm_response=llm_response_metadata, # Session management fields for multi-turn conversations session_id=session_id, template_used=stream_template_used, template_id=stream_template_id, # Rule 46: Epistemic provenance for transparency epistemic_provenance=stream_provenance.model_dump(), ) # Update session with this turn (before caching) if session_mgr and session_id and conversation_state is not None: try: await session_mgr.add_turn_to_session( session_id=session_id, question=request.question, answer=response.answer, 
resolved_question=response.resolved_question, template_id=getattr(result, "template_id", None), slots=getattr(result, "slots", {}), ) logger.debug(f"Updated session {session_id} with new turn") except Exception as e: logger.warning(f"Failed to update session {session_id}: {e}") # Record Prometheus metrics for monitoring if METRICS_AVAILABLE and record_query: try: record_query( endpoint="dspy_query_stream", template_used=getattr(result, "template_used", False), template_id=getattr(result, "template_id", None), cache_hit=False, status="success", duration_seconds=elapsed_ms / 1000, intent=query_type, ) except Exception as e: logger.warning(f"Failed to record streaming metrics: {e}") # Cache the response if retriever: await retriever.cache.set_dspy( question=request.question, language=request.language, llm_provider=llm_provider_used, embedding_model=request.embedding_model, response=response.model_dump(), context=request.context if request.context else None, ) yield emit_status("complete", "✅ Klaar!") yield json.dumps({"type": "complete", "data": response.model_dump()}) + "\n" except ImportError as e: logger.warning(f"DSPy pipeline not available: {e}") yield emit_error("DSPy pipeline is niet beschikbaar.") except Exception as e: logger.exception("DSPy streaming query failed") user_msg, details = extract_user_friendly_error(e) yield emit_error(user_msg, details) @app.post("/api/rag/dspy/query/stream") async def dspy_query_stream(request: DSPyQueryRequest) -> StreamingResponse: """Streaming version of DSPy RAG query endpoint. Returns NDJSON stream with status updates at each pipeline stage, allowing the frontend to show progress during long-running queries. Status stages: - cache: Checking for cached response - config: Configuring LLM provider - routing: Analyzing query intent - retrieval: Searching databases (Qdrant, SPARQL, etc.) 
- generation: Generating answer with LLM - complete: Final response ready """ return StreamingResponse( stream_dspy_query_response(request), media_type="application/x-ndjson", ) async def stream_query_response( request: QueryRequest, ) -> AsyncIterator[str]: """Stream query response for long-running queries.""" if not retriever: yield json.dumps({"error": "Retriever not initialized"}) return start_time = asyncio.get_event_loop().time() # Route query intent, sources = retriever.router.get_sources(request.question, request.sources) # Extract geographic filters from question (province, city, institution type) geo_filters = extract_geographic_filters(request.question) yield json.dumps({ "type": "status", "message": f"Routing query to {len(sources)} sources...", "intent": intent.value, "geo_filters": {k: v for k, v in geo_filters.items() if v}, }) + "\n" # Retrieve from sources and stream progress results = [] for source in sources: yield json.dumps({ "type": "status", "message": f"Querying {source.value}...", }) + "\n" source_results = await retriever.retrieve( request.question, [source], request.k, embedding_model=request.embedding_model, region_codes=geo_filters["region_codes"], cities=geo_filters["cities"], institution_types=geo_filters["institution_types"], ) results.extend(source_results) yield json.dumps({ "type": "partial", "source": source.value, "count": len(source_results[0].items) if source_results else 0, }) + "\n" # Merge and finalize with provenance merged, stream_provenance = retriever.merge_results(results) elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000 yield json.dumps({ "type": "complete", "results": merged, "query_time_ms": round(elapsed_ms, 2), "result_count": len(merged), "epistemic_provenance": stream_provenance.model_dump() if stream_provenance else None, }) + "\n" @app.post("/api/rag/query/stream") async def query_rag_stream(request: QueryRequest) -> StreamingResponse: """Streaming version of RAG query endpoint.""" return 
StreamingResponse( stream_query_response(request), media_type="application/x-ndjson", )


# =============================================================================
# SEMANTIC CACHE ENDPOINTS (Qdrant-backed)
# =============================================================================
#
# High-performance semantic cache using Qdrant's HNSW vector index.
# Replaces slow client-side cosine similarity with server-side ANN search.
#
# Performance target:
#   - Cache lookup: <20ms (vs 500-2000ms with client-side scan)
#   - Cache store: <50ms
#
# Architecture:
#   Frontend → /api/cache/lookup → Qdrant ANN search → cached response
#              /api/cache/store  → embed + upsert to Qdrant
#
# Every stored entry records its store time ("timestamp", epoch milliseconds)
# and its "ttl_seconds"; lookups skip (and best-effort evict) entries that have
# outlived their TTL.
# =============================================================================

# Lazily created singletons, initialised on first use by the getters below.
_cache_qdrant_client: Any = None
_cache_embedding_model: Any = None

CACHE_COLLECTION_NAME = "query_cache"
CACHE_EMBEDDING_DIM = 384  # all-MiniLM-L6-v2

# Number of nearest neighbours fetched per lookup, so that an expired best
# match can fall through to the next-best live entry instead of a miss.
_CACHE_LOOKUP_CANDIDATES = 5


def get_cache_qdrant_client() -> Any:
    """Get or create the shared Qdrant client used for the cache collection.

    Connects to ``settings.qdrant_host``/``settings.qdrant_port``. The cache is
    meant to be co-located with the RAG backend, so this is normally a direct
    local connection that avoids reverse-proxy overhead.

    Returns:
        The cached ``QdrantClient`` instance, or ``None`` if ``qdrant-client`` is
        not installed or the client could not be created. Failures are not
        memoised, so a later call retries.
    """
    global _cache_qdrant_client
    if _cache_qdrant_client is not None:
        return _cache_qdrant_client

    try:
        from qdrant_client import QdrantClient
    except ImportError:
        logger.error("qdrant-client not installed")
        return None

    try:
        _cache_qdrant_client = QdrantClient(
            host=settings.qdrant_host,
            port=settings.qdrant_port,
            timeout=30,
        )
    except Exception as e:
        logger.error(f"Failed to create Qdrant cache client: {e}")
        return None

    logger.info(f"Qdrant cache client: {settings.qdrant_host}:{settings.qdrant_port}")
    return _cache_qdrant_client


def get_cache_embedding_model() -> Any:
    """Get or load the cache embedding model (all-MiniLM-L6-v2, 384-dim).

    Returns:
        The shared ``SentenceTransformer`` instance, or ``None`` if
        ``sentence-transformers`` is missing or loading failed (logged; a later
        call retries).
    """
    global _cache_embedding_model
    if _cache_embedding_model is not None:
        return _cache_embedding_model

    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        logger.error("sentence-transformers not installed")
        return None

    try:
        _cache_embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    except Exception as e:
        logger.error(f"Failed to load cache embedding model: {e}")
        return None

    logger.info("Loaded cache embedding model: all-MiniLM-L6-v2")
    return _cache_embedding_model


def ensure_cache_collection_exists() -> bool:
    """Ensure the ``query_cache`` collection exists in Qdrant, creating it if needed.

    A newly created collection uses ``CACHE_EMBEDDING_DIM``-dimensional vectors
    with cosine distance.

    Returns:
        True if the collection exists (or was just created); False if no client
        is available or a Qdrant call failed (the error is logged).
    """
    client = get_cache_qdrant_client()
    if client is None:
        return False

    try:
        from qdrant_client.models import Distance, VectorParams

        existing = client.get_collections().collections
        if any(c.name == CACHE_COLLECTION_NAME for c in existing):
            return True

        client.create_collection(
            collection_name=CACHE_COLLECTION_NAME,
            vectors_config=VectorParams(
                size=CACHE_EMBEDDING_DIM,
                distance=Distance.COSINE,
            ),
        )
        logger.info(f"Created Qdrant collection: {CACHE_COLLECTION_NAME}")
        return True
    except Exception as e:
        logger.error(f"Failed to ensure cache collection: {e}")
        return False


def _is_cache_entry_expired(payload: dict[str, Any], now_ms: int) -> bool:
    """Return True if a cached payload has outlived its stored TTL.

    ``timestamp`` is the store time in epoch milliseconds and ``ttl_seconds``
    the TTL, both as written by ``cache_store``. Entries missing either field
    (or holding non-numeric values), or with a non-positive TTL, never expire.
    """
    timestamp = payload.get("timestamp")
    ttl_seconds = payload.get("ttl_seconds")
    if not isinstance(timestamp, (int, float)) or not isinstance(ttl_seconds, (int, float)):
        return False
    if ttl_seconds <= 0:
        return False
    return now_ms >= timestamp + ttl_seconds * 1000


def _evict_cache_points(client: Any, point_ids: list[Any]) -> None:
    """Best-effort delete of the given point ids from the cache collection.

    Failures are logged as warnings and otherwise ignored: eviction is only
    housekeeping and must never fail a lookup.
    """
    if not point_ids:
        return
    try:
        from qdrant_client.models import PointIdsList

        client.delete(
            collection_name=CACHE_COLLECTION_NAME,
            points_selector=PointIdsList(points=point_ids),
        )
    except Exception as e:
        logger.warning(f"Failed to evict expired cache entries: {e}")


# Request/Response Models for Cache API


class CacheLookupRequest(BaseModel):
    """Cache lookup request."""

    query: str = Field(..., description="Query text to look up")
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    similarity_threshold: float = Field(default=0.92, description="Minimum similarity for match")
    language: str = Field(default="nl", description="Language filter")


class CacheLookupResponse(BaseModel):
    """Cache lookup response."""

    found: bool
    entry: dict[str, Any] | None = None
    similarity: float = 0.0
    # "semantic" for a completed search (hit or miss), "error" on failure,
    # "none" by default.
    method: str = "none"
    lookup_time_ms: float = 0.0


class CacheStoreRequest(BaseModel):
    """Cache store request."""

    query: str = Field(..., description="Query text")
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    response: dict[str, Any] = Field(..., description="Response to cache")
    language: str = Field(default="nl", description="Language")
    model: str = Field(default="unknown", description="LLM model used")
    # Enforced at lookup time; a non-positive value means "never expires".
    ttl_seconds: int = Field(default=86400, description="Time-to-live in seconds")


class CacheStoreResponse(BaseModel):
    """Cache store response."""

    success: bool
    id: str | None = None
    message: str = ""


class CacheStatsResponse(BaseModel):
    """Cache statistics response."""

    total_entries: int = 0
    collection_name: str = CACHE_COLLECTION_NAME
    embedding_dim: int = CACHE_EMBEDDING_DIM
    backend: str = "qdrant"
    status: str = "ok"


@app.post("/api/cache/lookup", response_model=CacheLookupResponse)
async def cache_lookup(request: CacheLookupRequest) -> CacheLookupResponse:
    """Look up a query in the semantic cache using Qdrant ANN search.

    Uses Qdrant's HNSW index (server-side) instead of a client-side cosine
    scan. Candidates are restricted to ``request.language`` and to similarity
    >= ``request.similarity_threshold``. The best candidate that has not
    outlived its ``ttl_seconds`` is returned; expired candidates ranked above
    it are skipped and evicted best-effort.

    Infrastructure failures (no client, no model, Qdrant/encoding errors) are
    reported as ``found=False`` with ``method="error"``.
    """
    start_time = time.perf_counter()

    def elapsed_ms() -> float:
        return (time.perf_counter() - start_time) * 1000

    def error_response() -> CacheLookupResponse:
        return CacheLookupResponse(
            found=False,
            similarity=0.0,
            method="error",
            lookup_time_ms=elapsed_ms(),
        )

    if not ensure_cache_collection_exists():
        return error_response()

    client = get_cache_qdrant_client()
    if client is None:
        return error_response()

    embedding = request.embedding
    model = None
    if embedding is None:
        model = get_cache_embedding_model()
        if model is None:
            return error_response()

    try:
        if embedding is None:
            embedding = model.encode(request.query).tolist()

        from qdrant_client.models import FieldCondition, Filter, MatchValue

        search_filter = Filter(
            must=[
                FieldCondition(
                    key="language",
                    match=MatchValue(value=request.language),
                )
            ]
        )

        # ANN search via query_points (qdrant-client >= 1.7). Results come back
        # ordered by descending score.
        candidates = client.query_points(
            collection_name=CACHE_COLLECTION_NAME,
            query=embedding,
            query_filter=search_filter,
            limit=_CACHE_LOOKUP_CANDIDATES,
            score_threshold=request.similarity_threshold,
        ).points

        now_ms = int(time.time() * 1000)
        best = None
        expired_ids: list[Any] = []
        for point in candidates:
            if _is_cache_entry_expired(point.payload or {}, now_ms):
                expired_ids.append(point.id)
                continue
            best = point
            break
        _evict_cache_points(client, expired_ids)

        if best is None:
            return CacheLookupResponse(
                found=False,
                similarity=0.0,
                method="semantic",
                lookup_time_ms=elapsed_ms(),
            )

        payload = best.payload or {}
        return CacheLookupResponse(
            found=True,
            entry={
                "id": str(best.id),
                "query": payload.get("query", ""),
                "query_normalized": payload.get("query_normalized", ""),
                "response": payload.get("response", {}),
                "timestamp": payload.get("timestamp", 0),
                "hit_count": payload.get("hit_count", 0),
                "last_accessed": payload.get("last_accessed", 0),
                "language": payload.get("language", "nl"),
                "model": payload.get("model", "unknown"),
            },
            similarity=best.score,
            method="semantic",
            lookup_time_ms=elapsed_ms(),
        )
    except Exception as e:
        logger.error(f"Cache lookup error: {e}")
        return error_response()


@app.post("/api/cache/store", response_model=CacheStoreResponse)
async def cache_store(request: CacheStoreRequest) -> CacheStoreResponse:
    """Store a query/response pair in the semantic cache.

    Embeds the query when no embedding is supplied and upserts a new point
    (random UUID id) whose payload carries the query, its lower-cased/stripped
    form, the response, language, model, store timestamp (epoch ms), hit
    bookkeeping fields and ``ttl_seconds``.

    Failures are reported as ``success=False`` with a message.
    """
    import uuid

    if not ensure_cache_collection_exists():
        return CacheStoreResponse(
            success=False,
            message="Failed to ensure cache collection exists",
        )

    client = get_cache_qdrant_client()
    if client is None:
        return CacheStoreResponse(
            success=False,
            message="Qdrant client not available",
        )

    embedding = request.embedding
    model = None
    if embedding is None:
        model = get_cache_embedding_model()
        if model is None:
            return CacheStoreResponse(
                success=False,
                message="Embedding model not available",
            )

    try:
        if embedding is None:
            embedding = model.encode(request.query).tolist()

        from qdrant_client.models import PointStruct

        point_id = str(uuid.uuid4())
        timestamp = int(time.time() * 1000)

        point = PointStruct(
            id=point_id,
            vector=embedding,
            payload={
                "query": request.query,
                # Normalised form kept for exact-match use.
                "query_normalized": request.query.lower().strip(),
                "response": request.response,
                "language": request.language,
                "model": request.model,
                "timestamp": timestamp,
                "hit_count": 0,
                "last_accessed": timestamp,
                "ttl_seconds": request.ttl_seconds,
            },
        )

        client.upsert(
            collection_name=CACHE_COLLECTION_NAME,
            points=[point],
        )

        logger.debug(f"Cached query: {request.query[:50]}...")
        return CacheStoreResponse(
            success=True,
            id=point_id,
            message="Stored successfully",
        )
    except Exception as e:
        logger.error(f"Cache store error: {e}")
        return CacheStoreResponse(
            success=False,
            message=str(e),
        )


@app.get("/api/cache/stats", response_model=CacheStatsResponse)
async def cache_stats() -> CacheStatsResponse:
    """Get cache statistics (entry count of the Qdrant cache collection).

    ``status`` is "ok", "no_collection", "error" (no client), or
    "error: <detail>" when a Qdrant call fails.
    """
    client = get_cache_qdrant_client()
    if client is None:
        return CacheStatsResponse(
            status="error",
            total_entries=0,
        )

    try:
        collections = client.get_collections().collections
        if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
            return CacheStatsResponse(
                status="no_collection",
                total_entries=0,
            )

        info = client.get_collection(CACHE_COLLECTION_NAME)
        return CacheStatsResponse(
            # points_count is Optional in qdrant-client's CollectionInfo; a None
            # would fail validation of the int field, so coalesce it to 0.
            total_entries=info.points_count or 0,
            collection_name=CACHE_COLLECTION_NAME,
            embedding_dim=CACHE_EMBEDDING_DIM,
            backend="qdrant",
            status="ok",
        )
    except Exception as e:
        logger.error(f"Cache stats error: {e}")
        return CacheStatsResponse(
            status=f"error: {e}",
            total_entries=0,
        )


def _clear_qdrant_query_cache() -> dict[str, Any]:
    """Drop and recreate the Qdrant cache collection.

    Returns:
        A status dict ``{"success", "deleted", "message"}``; ``deleted`` is the
        number of points that were in the collection (0 if unknown).
    """
    client = get_cache_qdrant_client()
    if client is None:
        return {"success": False, "deleted": 0, "message": "Qdrant client not available"}

    try:
        collections = client.get_collections().collections
        if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
            return {"success": True, "deleted": 0, "message": "Collection does not exist"}

        # points_count is Optional in qdrant-client; a None here previously made
        # the max(0, ...) in cache_clear raise TypeError.
        count = client.get_collection(CACHE_COLLECTION_NAME).points_count or 0
        client.delete_collection(CACHE_COLLECTION_NAME)
        ensure_cache_collection_exists()
        return {"success": True, "deleted": count, "message": f"Cleared {count} entries"}
    except Exception as e:
        logger.error(f"Qdrant cache clear error: {e}")
        return {"success": False, "deleted": 0, "message": str(e)}


async def _flush_valkey_query_cache() -> dict[str, Any]:
    """Flush Valkey/Redis by running ``FLUSHALL`` through redis-cli, then valkey-cli.

    The CLI runs in a worker thread so the (up to 5 s per CLI) subprocess does
    not block the event loop.

    NOTE(review): FLUSHALL wipes every database on the server that the CLI
    reaches with its default connection settings, not only this cache. Confirm
    that is intended (FLUSHDB would limit it to one database).

    Returns:
        A status dict ``{"success", "deleted", "message"}``; ``deleted`` is -1
        on success because FLUSHALL does not report a count.
    """
    import subprocess

    status: dict[str, Any] = {"success": False, "deleted": 0, "message": ""}
    failures: list[str] = []
    try:
        for cli in ("redis-cli", "valkey-cli"):
            try:
                proc = await asyncio.to_thread(
                    subprocess.run,
                    [cli, "FLUSHALL"],
                    capture_output=True,
                    text=True,
                    timeout=5,
                )
            except FileNotFoundError:
                continue
            if proc.returncode == 0 and "OK" in proc.stdout:
                return {"success": True, "deleted": -1, "message": f"Flushed via {cli}"}
            # The CLI exists but the flush failed (e.g. connection refused);
            # keep the reason instead of claiming the CLI is missing.
            detail = (proc.stderr or proc.stdout or "").strip() or f"exit code {proc.returncode}"
            failures.append(f"{cli}: {detail}")

        status["message"] = "; ".join(failures) if failures else "Neither redis-cli nor valkey-cli available"
    except subprocess.TimeoutExpired:
        status["message"] = "Cache flush timed out"
    except Exception as e:
        logger.error(f"Valkey cache clear error: {e}")
        status["message"] = str(e)
    return status


@app.delete("/api/cache/clear")
async def cache_clear() -> dict[str, Any]:
    """Clear all cache entries (both Qdrant semantic cache and Valkey/Redis cache).

    ``success`` is True if at least one backend was cleared. ``deleted`` sums
    the non-negative per-backend counts; Valkey reports -1 ("unknown"), which
    contributes 0.
    """
    results: dict[str, dict[str, Any]] = {
        "qdrant": _clear_qdrant_query_cache(),
        "valkey": await _flush_valkey_query_cache(),
    }

    overall_success = results["qdrant"]["success"] or results["valkey"]["success"]
    total_deleted = max(0, results["qdrant"]["deleted"]) + max(0, results["valkey"]["deleted"])

    return {
        "success": overall_success,
        "message": f"Qdrant: {results['qdrant']['message']}, Valkey: {results['valkey']['message']}",
        "deleted": total_deleted,
        "details": results,
    }


# Main entry point
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8003,
        reload=settings.debug,
        log_level="info",
    )