Track full lineage of RAG responses: WHERE data comes from, WHEN it was retrieved, HOW it was processed (SPARQL/vector/LLM). Backend changes: - Add provenance.py with EpistemicProvenance, DataTier, SourceAttribution - Integrate provenance into MultiSourceRetriever.merge_results() - Return epistemic_provenance in DSPyQueryResponse Frontend changes: - Pass EpistemicProvenance through useMultiDatabaseRAG hook - Display provenance in ConversationPage (for cache transparency) Schema fixes: - Fix truncated example in has_observation.yaml slot definition References: - Pavlyshyn's Context Graphs and Data Traces paper - LinkML ProvenanceBlock schema pattern
5095 lines
219 KiB
Python
from __future__ import annotations
|
|
|
|
"""
|
|
Unified RAG Backend for Heritage Custodian Data
|
|
|
|
Multi-source retrieval-augmented generation system that combines:
|
|
- Qdrant vector search (semantic similarity)
|
|
- Oxigraph SPARQL (knowledge graph queries)
|
|
- TypeDB (relationship traversal)
|
|
- PostGIS (geospatial queries)
|
|
- Valkey (semantic caching)
|
|
|
|
Architecture:
|
|
User Query → Query Analysis
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Router │
|
|
└─────┬─────┘
|
|
┌─────┬─────┼─────┬─────┐
|
|
↓ ↓ ↓ ↓ ↓
|
|
Qdrant SPARQL TypeDB PostGIS Cache
|
|
│ │ │ │ │
|
|
└─────┴─────┴─────┴─────┘
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Merger │
|
|
└─────┬─────┘
|
|
↓
|
|
DSPy Generator
|
|
↓
|
|
Visualization Selector
|
|
↓
|
|
Response (JSON/Streaming)
|
|
|
|
Features:
|
|
- Intelligent query routing to appropriate data sources
|
|
- Score fusion for multi-source results
|
|
- Semantic caching via Valkey API
|
|
- Streaming responses for long-running queries
|
|
- DSPy assertions for output validation
|
|
|
|
Endpoints:
|
|
- POST /api/rag/query - Main RAG query endpoint
|
|
- POST /api/rag/sparql - Generate SPARQL with RAG context
|
|
- POST /api/rag/typedb/search - Direct TypeDB search
|
|
- GET /api/rag/health - Health check for all services
|
|
- GET /api/rag/stats - Retriever statistics
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any, AsyncIterator, TYPE_CHECKING
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Rule 46: Epistemic Provenance Tracking
|
|
from .provenance import (
|
|
EpistemicProvenance,
|
|
EpistemicDataSource,
|
|
DataTier,
|
|
RetrievalSource,
|
|
SourceAttribution,
|
|
infer_data_tier,
|
|
build_derivation_chain,
|
|
aggregate_data_tier,
|
|
)
|
|
|
|
# Configure logging
# Module-wide logging setup; runs once at import time so every logger created
# in this module inherits the format below.  NOTE(review): basicConfig is a
# no-op if the hosting application already configured the root logger —
# confirm that is acceptable for this deployment.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
|
|
|
|
# Type hints for optional imports (only used during type checking)
|
|
if TYPE_CHECKING:
|
|
from glam_extractor.api.hybrid_retriever import HybridRetriever
|
|
from glam_extractor.api.typedb_retriever import TypeDBRetriever
|
|
from glam_extractor.api.visualization import VisualizationSelector
|
|
|
|
# Import retrievers (with graceful fallbacks)
# All retriever entry points default to None so the API can still start
# (with reduced functionality) when the optional glam_extractor package or
# its dependencies are missing.  Callers must check RETRIEVERS_AVAILABLE
# or the individual callables before use.
RETRIEVERS_AVAILABLE = False
create_hybrid_retriever: Any = None
HeritageCustodianRetriever: Any = None
create_typedb_retriever: Any = None
select_visualization: Any = None
VisualizationSelector: Any = None  # type: ignore[no-redef]
generate_sparql: Any = None
configure_dspy: Any = None
get_province_code: Any = None  # Province name to ISO 3166-2 code converter

try:
    import sys
    # Make the sibling src/ tree importable when running from a source checkout.
    sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src")))
    from glam_extractor.api.hybrid_retriever import (
        HybridRetriever as _HybridRetriever,
        create_hybrid_retriever as _create_hybrid_retriever,
        get_province_code as _get_province_code,
        PERSON_JSONLD_CONTEXT,
    )
    from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever as _HeritageCustodianRetriever
    from glam_extractor.api.typedb_retriever import TypeDBRetriever as _TypeDBRetriever, create_typedb_retriever as _create_typedb_retriever
    from glam_extractor.api.visualization import select_visualization as _select_visualization, VisualizationSelector as _VisualizationSelector
    # Assign to module-level variables
    create_hybrid_retriever = _create_hybrid_retriever
    HeritageCustodianRetriever = _HeritageCustodianRetriever
    create_typedb_retriever = _create_typedb_retriever
    select_visualization = _select_visualization
    VisualizationSelector = _VisualizationSelector
    get_province_code = _get_province_code
    RETRIEVERS_AVAILABLE = True
except ImportError as e:
    logger.warning(f"Core retrievers not available: {e}")
    # Provide a fallback get_province_code that returns None
    def get_province_code(province_name: str | None) -> str | None:
        """Fallback when hybrid_retriever is not available."""
        return None


# DSPy is optional - don't block retrievers if it's missing
try:
    from glam_extractor.api.dspy_sparql import generate_sparql as _generate_sparql, configure_dspy as _configure_dspy
    generate_sparql = _generate_sparql
    configure_dspy = _configure_dspy
except ImportError as e:
    logger.warning(f"DSPy SPARQL not available: {e}")


# Atomic query decomposition for geographic/type filtering and sub-task caching
decompose_query: Any = None
AtomicCacheManager: Any = None
DECOMPOSER_AVAILABLE = False
ATOMIC_CACHE_AVAILABLE = False
try:
    from atomic_decomposer import (
        decompose_query as _decompose_query,
        AtomicCacheManager as _AtomicCacheManager,
    )
    decompose_query = _decompose_query
    AtomicCacheManager = _AtomicCacheManager
    DECOMPOSER_AVAILABLE = True
    ATOMIC_CACHE_AVAILABLE = True
    logger.info("Query decomposer and AtomicCacheManager loaded successfully")
except ImportError as e:
    logger.info(f"Query decomposer not available: {e}")
|
|
|
|
# Cost tracker is optional - gracefully degrades if unavailable
# Same pattern as the retriever imports above: placeholder None, try-import,
# availability flag for callers to check.
COST_TRACKER_AVAILABLE = False
get_tracker = None
reset_tracker = None
try:
    from cost_tracker import get_tracker as _get_tracker, reset_tracker as _reset_tracker
    get_tracker = _get_tracker
    reset_tracker = _reset_tracker
    COST_TRACKER_AVAILABLE = True
    logger.info("Cost tracker module loaded successfully")
except ImportError as e:
    logger.info(f"Cost tracker not available (optional): {e}")


# Session manager for multi-turn conversation state
SESSION_MANAGER_AVAILABLE = False
get_session_manager = None
shutdown_session_manager = None
ConversationState = None
try:
    from session_manager import (
        get_session_manager as _get_session_manager,
        shutdown_session_manager as _shutdown_session_manager,
        ConversationState as _ConversationState,
    )
    get_session_manager = _get_session_manager
    shutdown_session_manager = _shutdown_session_manager
    ConversationState = _ConversationState
    SESSION_MANAGER_AVAILABLE = True
    logger.info("Session manager module loaded successfully")
except ImportError as e:
    logger.info(f"Session manager not available (optional): {e}")


# Template-based SPARQL pipeline (deterministic, validated queries)
# This provides 65% precision vs 10% for LLM-only SPARQL generation
TEMPLATE_SPARQL_AVAILABLE = False
TemplateSPARQLPipeline: Any = None
get_template_pipeline: Any = None
_template_pipeline_instance: Any = None  # Singleton for reuse

try:
    from template_sparql import (
        TemplateSPARQLPipeline as _TemplateSPARQLPipeline,
        get_template_pipeline as _get_template_pipeline,
    )
    TemplateSPARQLPipeline = _TemplateSPARQLPipeline
    get_template_pipeline = _get_template_pipeline
    TEMPLATE_SPARQL_AVAILABLE = True
    logger.info("Template SPARQL pipeline loaded successfully")
except ImportError as e:
    logger.info(f"Template SPARQL pipeline not available (optional): {e}")
|
|
|
|
# Prometheus metrics for monitoring template hit rate, latency, etc.
# All recorder callables default to None; callers must check
# METRICS_AVAILABLE (or the callable itself) before invoking any of them.
METRICS_AVAILABLE = False
record_query = None
record_template_matching = None
record_template_tier = None
record_atomic_cache = None
record_atomic_subtask_cached = None
record_connection_pool = None
record_embedding_warmup = None
record_template_embedding_warmup = None
set_warmup_status = None
set_active_sessions = None
create_metrics_endpoint = None
get_template_hit_rate = None
get_all_performance_stats = None
try:
    from metrics import (
        record_query as _record_query,
        record_template_matching as _record_template_matching,
        record_template_tier as _record_template_tier,
        record_atomic_cache as _record_atomic_cache,
        record_atomic_subtask_cached as _record_atomic_subtask_cached,
        record_connection_pool as _record_connection_pool,
        record_embedding_warmup as _record_embedding_warmup,
        record_template_embedding_warmup as _record_template_embedding_warmup,
        set_warmup_status as _set_warmup_status,
        set_active_sessions as _set_active_sessions,
        create_metrics_endpoint as _create_metrics_endpoint,
        get_template_hit_rate as _get_template_hit_rate,
        get_all_performance_stats as _get_all_performance_stats,
        PROMETHEUS_AVAILABLE,
    )
    record_query = _record_query
    record_template_matching = _record_template_matching
    record_template_tier = _record_template_tier
    record_atomic_cache = _record_atomic_cache
    record_atomic_subtask_cached = _record_atomic_subtask_cached
    record_connection_pool = _record_connection_pool
    record_embedding_warmup = _record_embedding_warmup
    record_template_embedding_warmup = _record_template_embedding_warmup
    set_warmup_status = _set_warmup_status
    set_active_sessions = _set_active_sessions
    create_metrics_endpoint = _create_metrics_endpoint
    get_template_hit_rate = _get_template_hit_rate
    get_all_performance_stats = _get_all_performance_stats
    # NOTE(review): presumably PROMETHEUS_AVAILABLE reflects whether the
    # prometheus client library is importable inside the metrics module —
    # confirm against that module.
    METRICS_AVAILABLE = PROMETHEUS_AVAILABLE
    logger.info(f"Metrics module loaded (prometheus={PROMETHEUS_AVAILABLE})")
except ImportError as e:
    logger.info(f"Metrics module not available (optional): {e}")
|
|
|
|
|
|
# Province detection for geographic filtering.
# All 12 Dutch provinces, including common spelling variants and English
# translations, normalised to lowercase for membership tests.
DUTCH_PROVINCES = {
    "noord-holland", "noordholland", "north holland", "north-holland",
    "zuid-holland", "zuidholland", "south holland", "south-holland",
    "utrecht", "gelderland", "noord-brabant", "noordbrabant", "brabant",
    "north brabant", "limburg", "overijssel", "friesland", "fryslân",
    "fryslan", "groningen", "drenthe", "flevoland", "zeeland",
}

# Known sub-provincial regions (e.g. Randstad, Veluwe).
_SUBPROVINCIAL_REGIONS = frozenset(
    {"randstad", "veluwe", "achterhoek", "twente", "de betuwe", "betuwe"}
)


def infer_location_level(location: str) -> str:
    """Classify a location string as city, province, or region.

    The input is lowercased and stripped before lookup; anything that is
    neither a known Dutch province nor a known sub-provincial region is
    assumed to be a city.

    Returns:
        'province' if location is a Dutch province
        'region' if location is a sub-provincial region
        'city' otherwise
    """
    normalised = location.lower().strip()
    if normalised in DUTCH_PROVINCES:
        return "province"
    if normalised in _SUBPROVINCIAL_REGIONS:
        return "region"
    return "city"
|
|
# Markers that force a location to be treated as a city even when the name
# also matches a province (e.g. "de stad Utrecht" vs. the province Utrecht).
_EXPLICIT_CITY_MARKERS = (
    "de stad ", "in de stad", "stad van", "gemeente ",
    "the city of", "city of ", "in the city",
)

# Dutch/English institution-type synonyms mapped to canonical enum values.
_INSTITUTION_TYPE_MAP = {
    "archive": "ARCHIVE",
    "archief": "ARCHIVE",
    "archieven": "ARCHIVE",
    "museum": "MUSEUM",
    "musea": "MUSEUM",
    "museums": "MUSEUM",
    "library": "LIBRARY",
    "bibliotheek": "LIBRARY",
    "bibliotheken": "LIBRARY",
    "gallery": "GALLERY",
    "galerie": "GALLERY",
}


def extract_geographic_filters(question: str) -> dict[str, list[str] | None]:
    """Extract geographic filters from a question using query decomposition.

    Provinces are converted to ISO 3166-2 region codes; explicit city
    markers ("de stad ...", "the city of ...") override province
    disambiguation.  All failures degrade to empty (all-None) filters.

    Returns:
        dict with keys: region_codes, cities, institution_types
    """
    filters: dict[str, list[str] | None] = {
        "region_codes": None,
        "cities": None,
        "institution_types": None,
    }

    # Decomposition is optional infrastructure; without it no filtering applies.
    if not DECOMPOSER_AVAILABLE or not decompose_query:
        return filters

    # Detect explicit city markers BEFORE decomposition so the user's wording
    # ("de stad") can override a province-name match.
    lowered = question.lower()
    force_city = any(marker in lowered for marker in _EXPLICIT_CITY_MARKERS)

    try:
        decomposed = decompose_query(question)

        location = decomposed.location
        if location:
            if force_city:
                filters["cities"] = [location]
                logger.info(f"City filter (explicit): {location}")
            else:
                level = infer_location_level(location)
                if level == "province":
                    # Convert province name to ISO 3166-2 code for Qdrant
                    # filtering, e.g. "Noord-Holland" → "NH".
                    province_code = get_province_code(location)
                    if province_code:
                        filters["region_codes"] = [province_code]
                        logger.info(f"Province filter: {location} → {province_code}")
                elif level == "city":
                    filters["cities"] = [location]
                    logger.info(f"City filter: {location}")

        raw_type = decomposed.institution_type
        if raw_type:
            # Unknown types fall through as their uppercased form.
            normalised = raw_type.lower()
            mapped_type = _INSTITUTION_TYPE_MAP.get(normalised, normalised.upper())
            filters["institution_types"] = [mapped_type]
            logger.info(f"Institution type filter: {mapped_type}")

    except Exception as e:
        logger.warning(f"Failed to extract geographic filters: {e}")

    return filters
|
|
|
|
|
|
# Configuration
class Settings:
    """Application settings from environment variables.

    NOTE: every value is read once, at class-definition (import) time;
    changing environment variables after import has no effect.
    """

    # API Configuration
    api_title: str = "Heritage RAG API"
    api_version: str = "1.0.0"
    debug: bool = os.getenv("DEBUG", "false").lower() == "true"

    # Valkey Cache
    valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache")
    cache_ttl: int = int(os.getenv("CACHE_TTL", "900"))  # 15 minutes

    # Qdrant Vector DB
    # Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy
    qdrant_host: str = os.getenv("QDRANT_HOST", "localhost")
    qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333"))
    qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true"
    qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant")

    # Multi-Embedding Support
    # Enable to use named vectors with multiple embedding models (OpenAI 1536, MiniLM 384, BGE 768)
    use_multi_embedding: bool = os.getenv("USE_MULTI_EMBEDDING", "true").lower() == "true"
    preferred_embedding_model: str | None = os.getenv("PREFERRED_EMBEDDING_MODEL", None)  # e.g., "minilm_384" or "openai_1536"

    # Oxigraph SPARQL
    # Production: Use bronhouder.nl/sparql reverse proxy
    sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql")

    # TypeDB
    # Note: TypeDB not exposed via reverse proxy - always use localhost
    typedb_host: str = os.getenv("TYPEDB_HOST", "localhost")
    typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729"))
    typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians")
    typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true"  # Default off

    # PostGIS/Geo API
    # Production: Use bronhouder.nl/api/geo reverse proxy
    postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo")

    # NOTE: DuckLake removed from RAG - it's for offline analytics only, not real-time retrieval
    # RAG uses only Qdrant (vectors) and Oxigraph (SPARQL) for retrieval

    # LLM Configuration
    anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
    openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
    huggingface_api_key: str = os.getenv("HUGGINGFACE_API_KEY", "")
    groq_api_key: str = os.getenv("GROQ_API_KEY", "")
    zai_api_token: str = os.getenv("ZAI_API_TOKEN", "")
    default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101")
    # LLM Provider: "anthropic", "openai", "huggingface", "zai" (FREE), or "groq" (FREE)
    llm_provider: str = os.getenv("LLM_PROVIDER", "anthropic")
    # LLM Model: Specific model to use. Defaults depend on provider.
    # For Z.AI: "glm-4.5-flash" (fast, recommended) or "glm-4.6" (reasoning, slow)
    llm_model: str = os.getenv("LLM_MODEL", "glm-4.5-flash")
    # Fast LM Provider for routing/extraction: "openai" (fast ~1-2s) or "zai" (FREE but slow ~13s)
    # Default to openai for speed. Set to "zai" to save costs (free but adds ~12s latency)
    fast_lm_provider: str = os.getenv("FAST_LM_PROVIDER", "openai")

    # Retrieval weights (relative contribution of each backend in score fusion)
    vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5"))
    graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3"))
    typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2"))


# Module-wide singleton settings instance.
settings = Settings()
|
|
|
|
|
|
# Enums and Models
class QueryIntent(str, Enum):
    """Detected query intent for routing.

    Inherits ``str`` so members serialise naturally in JSON payloads.
    """

    GEOGRAPHIC = "geographic"  # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"  # Relationships between entities
    TEMPORAL = "temporal"  # Historical, timeline queries
    SEARCH = "search"  # General text search
    DETAIL = "detail"  # Specific entity lookup
|
|
|
|
|
class DataSource(str, Enum):
    """Available data sources for RAG retrieval.

    Inherits ``str`` so members serialise naturally in JSON payloads.

    NOTE: DuckLake removed - it's for offline analytics only, not real-time RAG retrieval.
    RAG uses Qdrant (vectors) and Oxigraph (SPARQL) as primary backends.
    """

    QDRANT = "qdrant"  # Vector search (semantic similarity)
    SPARQL = "sparql"  # Oxigraph knowledge-graph queries
    TYPEDB = "typedb"  # Relationship traversal
    POSTGIS = "postgis"  # Geospatial queries
    CACHE = "cache"  # Valkey semantic cache
    # DUCKLAKE removed - use DuckLake separately for offline analytics/dashboards
|
|
|
|
|
@dataclass
class RetrievalResult:
    """Result from a single retriever."""

    source: DataSource  # Which backend produced these items
    items: list[dict[str, Any]]  # Raw result records from the backend
    score: float = 0.0  # Relevance/fusion score for this source's results
    query_time_ms: float = 0.0  # Backend query latency in milliseconds
    metadata: dict[str, Any] = field(default_factory=dict)  # Source-specific extras
|
|
|
|
|
class QueryRequest(BaseModel):
    """RAG query request (body model for POST /api/rag/query)."""

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    # Pydantic copies mutable defaults per instance, so the [] literal is safe.
    context: list[dict[str, Any]] = Field(default=[], description="Conversation history")
    sources: list[DataSource] | None = Field(
        default=None,
        description="Data sources to query. If None, auto-routes based on query intent.",
    )
    k: int = Field(default=10, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    stream: bool = Field(default=False, description="Stream response")
|
|
|
|
|
class QueryResponse(BaseModel):
    """RAG query response."""

    question: str  # Echo of the original question
    sparql: str | None = None  # Generated SPARQL, when graph retrieval was used
    results: list[dict[str, Any]]  # Merged results from the queried sources
    visualization: dict[str, Any] | None = None  # Optional visualization config
    sources_used: list[DataSource]  # Which backends contributed results
    cache_hit: bool = False  # True when served from the semantic cache
    query_time_ms: float  # Total processing time in milliseconds
    result_count: int  # Number of results returned
|
|
|
|
|
class SPARQLRequest(BaseModel):
    """SPARQL generation request (body model for POST /api/rag/sparql)."""

    question: str  # Natural-language question to translate into SPARQL
    language: str = "nl"  # Language code (nl or en)
    context: list[dict[str, Any]] = []  # Conversation history
    use_rag: bool = True  # Use RAG context during generation (see SPARQLResponse.retrieved_passages)
|
|
|
|
|
class SPARQLResponse(BaseModel):
    """SPARQL generation response."""

    sparql: str  # Generated SPARQL query text
    explanation: str  # Human-readable explanation of the generated query
    rag_used: bool  # Whether RAG context informed generation
    retrieved_passages: list[str] = []  # Context passages used (empty if none)
|
|
|
|
|
class SPARQLExecuteRequest(BaseModel):
    """Execute a SPARQL query directly against the knowledge graph."""

    sparql_query: str = Field(..., description="SPARQL query to execute")
    # Bounded to 1-120s so a runaway query cannot hold a worker indefinitely.
    timeout: float = Field(default=30.0, ge=1.0, le=120.0, description="Query timeout in seconds")
|
|
|
|
|
class SPARQLExecuteResponse(BaseModel):
    """Response from direct SPARQL execution."""

    results: list[dict[str, Any]] = Field(default=[], description="Query results as list of dicts")
    result_count: int = Field(default=0, description="Number of results")
    query_time_ms: float = Field(default=0.0, description="Query execution time in milliseconds")
    # None on success; populated instead of raising so clients get a 200 with details.
    error: str | None = Field(default=None, description="Error message if query failed")
|
|
|
|
|
class SPARQLRerunRequest(BaseModel):
    """Re-run RAG pipeline with modified SPARQL results injected into context."""

    sparql_query: str = Field(..., description="Modified SPARQL query to execute")
    original_question: str = Field(..., description="Original user question")
    conversation_history: list[dict[str, Any]] = Field(
        default=[],
        description="Previous conversation turns"
    )
    language: str = Field(default="nl", description="Language code (nl or en)")
    # None means "use the server defaults" for both provider and model.
    llm_provider: str | None = Field(default=None, description="LLM provider to use")
    llm_model: str | None = Field(default=None, description="Specific LLM model")
|
|
|
|
|
class SPARQLRerunResponse(BaseModel):
    """Response from re-running RAG with modified SPARQL context."""

    results: list[dict[str, Any]] = Field(default=[], description="SPARQL query results")
    answer: str = Field(default="", description="Re-generated answer based on modified SPARQL results")
    sparql_result_count: int = Field(default=0, description="Number of SPARQL results")
    query_time_ms: float = Field(default=0.0, description="Total processing time")
|
|
|
|
|
class TypeDBSearchRequest(BaseModel):
    """TypeDB search request (body model for POST /api/rag/typedb/search)."""

    query: str = Field(..., description="Search query (name, type, or location)")
    search_type: str = Field(
        default="semantic",
        description="Search type: semantic, name, type, or location"
    )
    k: int = Field(default=10, ge=1, le=100, description="Number of results")
|
|
|
|
|
class TypeDBSearchResponse(BaseModel):
    """TypeDB search response."""

    query: str  # Echo of the search query
    search_type: str  # Echo of the requested search type
    results: list[dict[str, Any]]  # Matching records from TypeDB
    result_count: int  # Number of results returned
    query_time_ms: float  # Query latency in milliseconds
|
|
|
|
|
class PersonSearchRequest(BaseModel):
    """Person/staff search request."""

    query: str = Field(..., description="Search query for person/staff (e.g., 'Wie werkt er in het Nationaal Archief?')")
    k: int = Field(default=10, ge=1, le=100, description="Number of results to return")
    filter_custodian: str | None = Field(default=None, description="Filter by custodian slug (e.g., 'nationaal-archief')")
    only_heritage_relevant: bool = Field(default=False, description="Only return heritage-relevant staff")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use (e.g., 'minilm_384', 'openai_1536'). If None, auto-selects best available."
    )
|
|
|
|
|
|
class PersonSearchResponse(BaseModel):
    """Person/staff search response with JSON-LD linked data."""

    # Serialised under the JSON-LD "@context" key via the alias below.
    context: dict[str, Any] | None = Field(
        default=None,
        alias="@context",
        description="JSON-LD context for linked data semantic interoperability"
    )
    query: str  # Echo of the search query
    results: list[dict[str, Any]]  # Matching person/staff records
    result_count: int  # Number of results returned
    query_time_ms: float  # Query latency in milliseconds
    collection_stats: dict[str, Any] | None = None  # Optional collection-level statistics
    embedding_model_used: str | None = None  # Which embedding model served the search

    # Allow constructing by field name ("context") as well as alias ("@context").
    model_config = {"populate_by_name": True}
|
|
|
|
|
class DSPyQueryRequest(BaseModel):
    """DSPy RAG query request with conversation support.

    ``session_id`` and ``context`` are alternative mechanisms for carrying
    conversation state; see the individual field descriptions.
    """

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(
        default=[],
        description="Conversation history as list of {question, answer} dicts"
    )
    session_id: str | None = Field(
        default=None,
        description="Session ID for multi-turn conversations. If provided, session state is used to resolve follow-up questions like 'En in Enschede?'. If None, a new session is created and returned."
    )
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    llm_provider: str | None = Field(
        default=None,
        description="LLM provider to use for this request: 'zai', 'anthropic', 'huggingface', or 'openai'. If None, uses server default (LLM_PROVIDER env)."
    )
    llm_model: str | None = Field(
        default=None,
        description="Specific LLM model to use (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929', 'gpt-4o'). If None, uses provider default."
    )
    skip_cache: bool = Field(
        default=False,
        description="Bypass cache lookup and force fresh LLM query. Useful for debugging."
    )
|
|
|
|
|
|
class LLMResponseMetadata(BaseModel):
    """LLM response provenance metadata (aligned with LinkML LLMResponse schema).

    Captures GLM 4.7 Interleaved Thinking chain-of-thought reasoning and
    full API response metadata for audit trails and debugging.  All fields
    are optional so partial extraction never fails validation.

    See: schemas/20251121/linkml/modules/classes/LLMResponse.yaml
    """

    # Core response content
    content: str | None = None  # The final LLM response text
    reasoning_content: str | None = None  # GLM 4.7 Interleaved Thinking chain-of-thought

    # Model identification
    model: str | None = None  # Model identifier (e.g., 'glm-4.7', 'claude-3-opus')
    provider: str | None = None  # Provider enum: zai, anthropic, openai, huggingface, groq

    # Request tracking
    request_id: str | None = None  # Provider-assigned request ID
    created: str | None = None  # ISO 8601 timestamp of response generation

    # Token usage (for cost estimation and monitoring)
    prompt_tokens: int | None = None  # Tokens in input prompt
    completion_tokens: int | None = None  # Tokens in response (content + reasoning)
    total_tokens: int | None = None  # Total tokens used
    cached_tokens: int | None = None  # Tokens served from provider cache

    # Response metadata
    finish_reason: str | None = None  # stop, length, tool_calls, content_filter
    latency_ms: int | None = None  # Response latency in milliseconds

    # GLM 4.7 Thinking Mode configuration
    thinking_mode: str | None = None  # enabled, disabled, interleaved, preserved
    clear_thinking: bool | None = None  # False = Preserved Thinking enabled
|
|
|
|
|
class DSPyQueryResponse(BaseModel):
    """DSPy RAG query response.

    Carries the generated answer plus observability metadata: timing, cost,
    cache/template usage, LLM provenance, and the epistemic provenance chain.
    """

    question: str  # Original question as asked by the user
    resolved_question: str | None = None  # Question after follow-up resolution, if any
    answer: str  # Generated (or factual) answer text
    sources_used: list[str] = []  # Names of backends that contributed
    visualization: dict[str, Any] | None = None  # Optional visualization config
    retrieved_results: list[dict[str, Any]] | None = None  # Raw retrieved data for frontend visualization
    query_type: str | None = None  # "person" or "institution" - helps frontend choose visualization
    query_time_ms: float = 0.0  # Total request latency
    conversation_turn: int = 0  # Turn index within the conversation
    embedding_model_used: str | None = None  # Which embedding model was used for the search
    llm_provider_used: str | None = None  # Which LLM provider handled this request (zai, anthropic, huggingface, openai)
    llm_model_used: str | None = None  # Which specific LLM model was used (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929')

    # Cost tracking fields (from cost_tracker module)
    timing_ms: float | None = None  # Total pipeline timing from cost tracker
    cost_usd: float | None = None  # Estimated LLM cost in USD
    timing_breakdown: dict[str, float] | None = None  # Per-stage timing breakdown

    # Cache tracking
    cache_hit: bool = False  # Whether response was served from cache

    # LLM response provenance (GLM 4.7 Thinking Mode support)
    llm_response: LLMResponseMetadata | None = None  # Full LLM response metadata including reasoning_content

    # Session management for multi-turn conversations
    session_id: str | None = None  # Session ID for continuing conversation (returned even if not provided in request)

    # Template SPARQL tracking (for monitoring template hit rate vs LLM fallback)
    template_used: bool = False  # Whether template-based SPARQL was used (vs LLM generation)
    template_id: str | None = None  # Which template was used (e.g., "institution_by_city", "person_by_name")

    # Factual query mode - skip LLM generation for count/list queries
    factual_result: bool = False  # True if this is a direct SPARQL result (no LLM prose generation)
    sparql_query: str | None = None  # The SPARQL query that was executed (for transparency)

    # Rule 46: Epistemic Provenance Tracking
    epistemic_provenance: dict[str, Any] | None = None  # Full provenance chain for transparency
|
|
|
|
|
|
def _extract_choice_fields(response: Any) -> tuple[Any, Any, Any]:
    """Return (content, reasoning_content, finish_reason) from response.choices[0].

    Handles both attribute-style SDK objects and plain-dict choices; every
    field degrades to None when absent.
    """
    content = None
    reasoning_content = None
    finish_reason = None
    if hasattr(response, "choices") and response.choices:
        choice = response.choices[0]
        if hasattr(choice, "message"):
            message = choice.message
            content = getattr(message, "content", None)
            # GLM 4.7 Interleaved Thinking - check for reasoning_content
            reasoning_content = getattr(message, "reasoning_content", None)
        elif isinstance(choice, dict):
            content = choice.get("text") or choice.get("message", {}).get("content")
            reasoning_content = choice.get("message", {}).get("reasoning_content")
        finish_reason = getattr(choice, "finish_reason", None)
        if finish_reason is None and isinstance(choice, dict):
            finish_reason = choice.get("finish_reason")
    return content, reasoning_content, finish_reason


def _extract_usage_fields(usage: Any) -> tuple[int | None, int | None, int | None, int | None]:
    """Return (prompt, completion, total, cached) token counts from a usage payload.

    DSPy/OpenAI SDK may return either a CompletionUsage object or a plain
    dict; both shapes are supported and missing counts come back as None.
    """
    prompt_tokens = None
    completion_tokens = None
    total_tokens = None
    cached_tokens = None
    if usage is None:
        return prompt_tokens, completion_tokens, total_tokens, cached_tokens
    if hasattr(usage, "prompt_tokens"):
        # It's an object (e.g., CompletionUsage from OpenAI SDK)
        prompt_tokens = getattr(usage, "prompt_tokens", None)
        completion_tokens = getattr(usage, "completion_tokens", None)
        total_tokens = getattr(usage, "total_tokens", None)
        prompt_details = getattr(usage, "prompt_tokens_details", None)
        if prompt_details is not None:
            cached_tokens = getattr(prompt_details, "cached_tokens", None)
    elif isinstance(usage, dict):
        # It's a plain dict
        prompt_tokens = usage.get("prompt_tokens")
        completion_tokens = usage.get("completion_tokens")
        total_tokens = usage.get("total_tokens")
        prompt_details = usage.get("prompt_tokens_details")
        if isinstance(prompt_details, dict):
            cached_tokens = prompt_details.get("cached_tokens")
    return prompt_tokens, completion_tokens, total_tokens, cached_tokens


def _created_to_iso(created: Any) -> str | None:
    """Normalise a 'created' field (unix timestamp or string) to an ISO 8601 string.

    Uses the module-level ``datetime``/``timezone`` imports; the original
    implementation re-imported the datetime module locally, shadowing them.
    """
    if not created:
        return None
    if isinstance(created, (int, float)):
        return datetime.fromtimestamp(created, tz=timezone.utc).isoformat()
    return str(created)


def extract_llm_response_metadata(
    lm: Any,
    provider: str | None = None,
    latency_ms: int | None = None,
) -> LLMResponseMetadata | None:
    """Extract LLM response metadata from DSPy LM history.

    DSPy stores the raw API response in lm.history[-1]["response"], which includes:
    - choices[0].message.content (final response text)
    - choices[0].message.reasoning_content (GLM 4.7 Interleaved Thinking)
    - usage.prompt_tokens, completion_tokens, total_tokens
    - model, created, id, finish_reason

    This enables capturing GLM 4.7's chain-of-thought reasoning for provenance.
    Extraction is best-effort: any failure is logged and yields None rather
    than propagating into the request path.

    Args:
        lm: DSPy LM instance with history attribute
        provider: LLM provider name (zai, anthropic, openai, etc.)
        latency_ms: Response latency in milliseconds

    Returns:
        LLMResponseMetadata or None if history is empty or extraction fails
    """
    try:
        # Check if LM has history
        if not hasattr(lm, "history") or not lm.history:
            logger.debug("No LM history available for metadata extraction")
            return None

        # Get the last history entry (most recent LLM call)
        last_entry = lm.history[-1]
        response = last_entry.get("response")
        if response is None:
            logger.debug("No response in LM history entry")
            return None

        content, reasoning_content, finish_reason = _extract_choice_fields(response)
        prompt_tokens, completion_tokens, total_tokens, cached_tokens = (
            _extract_usage_fields(last_entry.get("usage"))
        )

        # Model identification and request tracking
        model = last_entry.get("response_model") or last_entry.get("model")
        request_id = getattr(response, "id", None)
        created_str = _created_to_iso(getattr(response, "created", None))

        # GLM 4.7 specific: a present reasoning_content means the model ran
        # in interleaved-thinking mode.
        thinking_mode = "interleaved" if reasoning_content else None

        metadata = LLMResponseMetadata(
            content=content,
            reasoning_content=reasoning_content,
            model=model,
            provider=provider,
            request_id=request_id,
            created=created_str,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cached_tokens=cached_tokens,
            finish_reason=finish_reason,
            latency_ms=latency_ms,
            thinking_mode=thinking_mode,
        )

        if reasoning_content:
            logger.info(
                f"Captured GLM 4.7 reasoning_content ({len(reasoning_content)} chars) "
                f"from {provider}/{model}"
            )

        return metadata

    except Exception as e:
        logger.warning(f"Failed to extract LLM response metadata: {e}")
        return None
|
|
|
|
|
|
# Cache Client
|
|
class ValkeyClient:
    """Client for the Valkey semantic cache API.

    Talks to an external HTTP service (``settings.valkey_api_url``) that
    performs embedding-based similarity lookup (``/cache/lookup``) and
    storage (``/cache/store``).  All cache failures are swallowed and
    logged — a cache outage must never break a query.
    """

    def __init__(self, base_url: str = settings.valkey_api_url):
        # NOTE(review): the default is evaluated once at import time, so a
        # later change to settings.valkey_api_url is not picked up — confirm
        # this is intended.
        self.base_url = base_url.rstrip("/")
        # Lazily created, reused across calls (see the `client` property).
        self._client: httpx.AsyncClient | None = None

    @property
    async def client(self) -> httpx.AsyncClient:
        """Get or create async HTTP client.

        Unusual pattern: an *async* property — callers must use
        ``await self.client``.  Recreates the client if it was closed.
        """
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=30.0)
        return self._client

    def _cache_key(self, question: str, sources: list[DataSource] | None) -> str:
        """Generate a deterministic exact-match cache key from question and sources.

        NOTE(review): the semantic lookup/store paths below do not use this
        key — presumably a legacy helper; verify before removing.
        """
        if sources:
            # Sort so that the same source set always yields the same key.
            sources_str = ",".join(sorted(s.value for s in sources))
        else:
            sources_str = "auto"
        key_str = f"{question.lower().strip()}:{sources_str}"
        return hashlib.sha256(key_str.encode()).hexdigest()[:32]

    async def get(self, question: str, sources: list[DataSource] | None) -> dict[str, Any] | None:
        """Get cached response using semantic cache lookup.

        Returns the cached ``response`` payload on a hit, ``None`` on a
        miss or on any error (errors are logged, never raised).

        NOTE(review): ``sources`` is accepted but not sent to the lookup
        endpoint — hits are matched on question similarity only; confirm
        this cannot cross-contaminate source-specific answers.
        """
        try:
            client = await self.client
            response = await client.post(
                f"{self.base_url}/cache/lookup",
                json={
                    "query": question,
                    # Higher threshold (0.97) to avoid false cache hits on semantically
                    # similar but geographically different queries
                    "similarity_threshold": 0.97,
                },
            )

            if response.status_code == 200:
                data = response.json()
                if data.get("found") and data.get("entry"):
                    logger.info(f"Cache hit for question: {question[:50]}... (similarity: {data.get('similarity', 0):.3f})")
                    return data["entry"].get("response")  # type: ignore[no-any-return]
            return None

        except Exception as e:
            logger.warning(f"Cache get failed: {e}")
            return None

    async def set(
        self,
        question: str,
        sources: list[DataSource] | None,
        response: dict[str, Any],
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Cache response using semantic cache store.

        Returns True on (assumed) success, False on any error.

        NOTE(review): ``ttl`` and ``sources`` are accepted but never
        forwarded to the store API — confirm whether the Valkey service
        applies a server-side TTL, otherwise entries may never expire.
        """
        try:
            client = await self.client

            # Build CachedResponse schema expected by the Valkey API.
            cached_response = {
                "answer": response.get("answer", ""),
                "sparql_query": response.get("sparql_query"),
                "typeql_query": response.get("typeql_query"),
                "visualization_type": response.get("visualization_type"),
                "visualization_data": response.get("visualization_data"),
                "sources": response.get("sources", []),
                "confidence": response.get("confidence", 0.0),
                "context": response.get("context"),
            }

            await client.post(
                f"{self.base_url}/cache/store",
                json={
                    "query": question,
                    "response": cached_response,
                    "language": response.get("language", "nl"),
                    "model": response.get("llm_model", "unknown"),
                },
            )
            logger.debug(f"Cached response for: {question[:50]}...")
            # NOTE(review): the store response status is not checked here
            # (unlike set_dspy below) — a non-200 still returns True.
            return True

        except Exception as e:
            logger.warning(f"Cache set failed: {e}")
            return False

    def _dspy_cache_key(
        self,
        question: str,
        language: str,
        llm_provider: str | None,
        embedding_model: str | None,
        context_hash: str | None = None,
    ) -> str:
        """Generate cache key for DSPy query responses.

        Cache key components:
        - Question text (normalized)
        - Language code
        - LLM provider (different providers give different answers)
        - Embedding model (affects retrieval results)
        - Context hash (for multi-turn conversations)
        """
        components = [
            question.lower().strip(),
            language,
            llm_provider or "default",
            embedding_model or "auto",
            context_hash or "no_context",
        ]
        key_str = ":".join(components)
        return f"dspy:{hashlib.sha256(key_str.encode()).hexdigest()[:32]}"

    async def get_dspy(
        self,
        question: str,
        language: str,
        llm_provider: str | None,
        embedding_model: str | None,
        context: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any] | None:
        """Get cached DSPy response using semantic cache lookup.

        Cache hits are filtered by LLM provider to ensure responses from different
        providers (e.g., anthropic vs huggingface) are cached separately.

        NOTE(review): ``embedding_model`` and ``context`` are accepted but
        not part of the lookup request — confirm whether that can return a
        stale hit for a different embedding model or conversation context.
        """
        try:
            client = await self.client
            response = await client.post(
                f"{self.base_url}/cache/lookup",
                json={
                    "query": question,
                    "language": language,
                    # Higher threshold (0.97) to avoid false cache hits on semantically
                    # similar but geographically different queries like
                    # "archieven in Groningen" vs "archieven in de stad Groningen"
                    "similarity_threshold": 0.97,
                },
            )

            if response.status_code == 200:
                data = response.json()
                if data.get("found") and data.get("entry"):
                    cached_response = data["entry"].get("response")

                    # Verify the cached response matches the requested LLM provider
                    # The model field in cache contains the provider (e.g., "anthropic", "huggingface")
                    cached_model = data["entry"].get("model")
                    requested_provider = llm_provider or settings.llm_provider

                    if cached_model and cached_model != requested_provider:
                        logger.info(
                            f"DSPy cache miss (provider mismatch): cached={cached_model}, requested={requested_provider}"
                        )
                        return None

                    similarity = data.get("similarity", 0)
                    method = data.get("method", "unknown")
                    logger.info(f"DSPy cache hit for question: {question[:50]}... (similarity: {similarity:.3f}, method: {method}, provider: {cached_model})")
                    return cached_response  # type: ignore[no-any-return]
            return None

        except Exception as e:
            logger.warning(f"DSPy cache get failed: {e}")
            return None

    async def set_dspy(
        self,
        question: str,
        language: str,
        llm_provider: str | None,
        embedding_model: str | None,
        response: dict[str, Any],
        context: list[dict[str, Any]] | None = None,
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Cache DSPy response using semantic cache store.

        Maps DSPyQueryResponse fields to CachedResponse schema:
        - sources_used -> sources
        - visualization -> visualization_type + visualization_data
        - Additional context from query_type, resolved_question, etc.

        NOTE(review): ``ttl`` and ``embedding_model`` are accepted but not
        forwarded to the store API — verify server-side expiry policy.
        """
        try:
            client = await self.client

            # Extract visualization components if present
            visualization = response.get("visualization")
            viz_type = None
            viz_data = None
            if visualization:
                viz_type = visualization.get("type")
                viz_data = visualization.get("data")

            # Build CachedResponse schema matching the Valkey API
            # Maps DSPyQueryResponse fields to CachedResponse expected fields
            #
            # IMPORTANT: Include llm_response metadata (GLM 4.7 reasoning_content) in cache
            # so that cached responses also return the chain-of-thought reasoning.
            llm_response_data = None
            if response.get("llm_response"):
                llm_resp = response["llm_response"]
                # Handle both dict and LLMResponseMetadata object
                if hasattr(llm_resp, "model_dump"):
                    llm_response_data = llm_resp.model_dump()
                elif isinstance(llm_resp, dict):
                    llm_response_data = llm_resp

            cached_response = {
                "answer": response.get("answer", ""),
                "sparql_query": None,  # DSPy doesn't generate SPARQL
                "typeql_query": None,  # DSPy doesn't generate TypeQL
                "visualization_type": viz_type,
                "visualization_data": viz_data,
                "sources": response.get("sources_used", []),  # DSPy uses sources_used
                "confidence": 0.95,  # DSPy responses are generally high confidence
                "context": {
                    "query_type": response.get("query_type"),
                    "resolved_question": response.get("resolved_question"),
                    "retrieved_results": response.get("retrieved_results"),
                    "embedding_model": response.get("embedding_model_used"),
                    "llm_model": response.get("llm_model_used"),
                    "original_context": context,
                    "llm_response": llm_response_data,  # GLM 4.7 reasoning_content
                },
            }

            result = await client.post(
                f"{self.base_url}/cache/store",
                json={
                    "query": question,
                    "response": cached_response,
                    "language": language,
                    "model": llm_provider or "unknown",
                },
            )

            # Check if store was successful
            if result.status_code == 200:
                logger.info(f"✓ Cached DSPy response for: {question[:50]}...")
                return True
            else:
                logger.warning(f"Cache store returned {result.status_code}: {result.text[:200]}")
                return False

        except Exception as e:
            logger.warning(f"DSPy cache set failed: {e}")
            return False

    async def close(self) -> None:
        """Close HTTP client.

        Safe to call multiple times; a subsequent access to ``client``
        transparently recreates the connection.
        """
        if self._client:
            await self._client.aclose()
            self._client = None
|
|
|
|
|
|
# Query Router
|
|
class QueryRouter:
    """Routes queries to appropriate data sources based on intent.

    Intent detection is keyword-based: each ``QueryIntent`` has a bilingual
    (English/Dutch) keyword list, and the intent with the most whole-word
    matches wins (ties broken by ``QueryIntent`` declaration order).
    The detected intent then maps to an ordered list of ``DataSource``s.
    """

    def __init__(self) -> None:
        # Function-scope import keeps this block self-contained; `re` is
        # only needed to precompile the keyword patterns below.
        import re

        # Bilingual (en/nl) keyword lists per intent.  Kept as a public
        # attribute for backward compatibility / introspection.
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart", "where", "waar", "location", "locatie",
                "city", "stad", "country", "land", "region", "gebied",
                "coordinates", "coördinaten", "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel", "count", "aantal", "total", "totaal",
                "average", "gemiddeld", "distribution", "verdeling",
                "percentage", "statistics", "statistiek", "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd", "connected", "verbonden",
                "relationship", "relatie", "network", "netwerk",
                "parent", "child", "merged", "fusie", "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis", "timeline", "tijdlijn",
                "when", "wanneer", "founded", "opgericht", "closed", "gesloten",
                "over time", "evolution", "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie", "about", "over",
                "specific", "specifiek", "what is", "wat is",
            ],
        }

        # NOTE: DuckLake removed from RAG - it's for offline analytics only
        # Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY)
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT],  # SPARQL aggregations
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }

        # Precompile word-boundary patterns ONCE (hoisted out of
        # detect_intent, which previously re-escaped every keyword on
        # every call).  Word-boundary matching avoids partial matches,
        # e.g. "land" must not match inside "netherlands".
        self._intent_patterns = {
            intent: [re.compile(r'\b' + re.escape(keyword) + r'\b') for keyword in keywords]
            for intent, keywords in self.intent_keywords.items()
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Detect query intent from question text.

        Counts whole-word keyword matches per intent and returns the
        highest-scoring intent; falls back to ``QueryIntent.SEARCH`` when
        nothing matches at all.
        """
        question_lower = question.lower()

        intent_scores = {intent: 0 for intent in QueryIntent}

        for intent, patterns in self._intent_patterns.items():
            for pattern in patterns:
                if pattern.search(question_lower):
                    intent_scores[intent] += 1

        max_intent = max(intent_scores, key=intent_scores.get)  # type: ignore

        if intent_scores[max_intent] == 0:
            return QueryIntent.SEARCH

        return max_intent

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Get optimal sources for a query.

        Args:
            question: User's question
            requested_sources: Explicitly requested sources (overrides routing)

        Returns:
            Tuple of (detected_intent, list_of_sources)
        """
        intent = self.detect_intent(question)

        if requested_sources:
            # Caller override: keep the detected intent but honour the
            # explicitly requested source list verbatim.
            return intent, requested_sources

        # Default to Qdrant if an intent somehow has no routing entry.
        return intent, self.source_routing.get(intent, [DataSource.QDRANT])
|
|
|
|
|
|
# Multi-Source Retriever
|
|
class MultiSourceRetriever:
    """Orchestrates retrieval across multiple data sources.

    Owns lazily-initialized clients for Qdrant, TypeDB, SPARQL and PostGIS,
    fans a question out to the selected sources concurrently, and fuses the
    results with Reciprocal Rank Fusion while tracking epistemic provenance.
    """

    def __init__(self) -> None:
        self.cache = ValkeyClient()
        self.router = QueryRouter()

        # Initialize retrievers lazily — they are created on first access so
        # that an unavailable backend does not break startup.
        self._qdrant: HybridRetriever | None = None
        self._typedb: TypeDBRetriever | None = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None
        # NOTE: DuckLake client removed - DuckLake is for offline analytics only

    @property
    def qdrant(self) -> HybridRetriever | None:
        """Lazy-load Qdrant hybrid retriever with multi-embedding support.

        Returns None (and logs) when retrievers are unavailable or
        initialization fails; callers must handle the None case.
        """
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production,
                    use_multi_embedding=settings.use_multi_embedding,
                    preferred_embedding_model=settings.preferred_embedding_model,
                )
            except Exception as e:
                # NOTE(review): a failed init is retried on every access —
                # confirm this is intended (no backoff / negative caching).
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant

    @property
    def typedb(self) -> TypeDBRetriever | None:
        """Lazy-load TypeDB retriever.

        Same lazy/None semantics as the ``qdrant`` property.
        """
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                self._typedb = create_typedb_retriever(
                    use_production=settings.typedb_use_production  # Use TypeDB-specific setting
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb

    async def _get_sparql_client(self) -> httpx.AsyncClient:
        """Get SPARQL HTTP client with connection pooling.

        Connection pooling improves performance by reusing TCP connections
        instead of creating new ones for each request.
        """
        if self._sparql_client is None or self._sparql_client.is_closed:
            self._sparql_client = httpx.AsyncClient(
                timeout=30.0,
                limits=httpx.Limits(
                    max_connections=20,  # Max total connections
                    max_keepalive_connections=10,  # Keep-alive connections in pool
                    keepalive_expiry=30.0,  # Seconds to keep idle connections
                ),
            )
            # Record connection pool metrics (no-op when metrics are disabled)
            if record_connection_pool:
                record_connection_pool(client="sparql", pool_size=20, available=20)
        return self._sparql_client

    async def _get_postgis_client(self) -> httpx.AsyncClient:
        """Get PostGIS HTTP client with connection pooling.

        Smaller pool than the SPARQL client; recreated if closed.
        """
        if self._postgis_client is None or self._postgis_client.is_closed:
            self._postgis_client = httpx.AsyncClient(
                timeout=30.0,
                limits=httpx.Limits(
                    max_connections=10,
                    max_keepalive_connections=5,
                    keepalive_expiry=30.0,
                ),
            )
            # Record connection pool metrics (no-op when metrics are disabled)
            if record_connection_pool:
                record_connection_pool(client="postgis", pool_size=10, available=10)
        return self._postgis_client

    # NOTE: _get_ducklake_client removed - DuckLake is for offline analytics only, not RAG retrieval

    async def retrieve_from_qdrant(
        self,
        query: str,
        k: int = 10,
        embedding_model: str | None = None,
        region_codes: list[str] | None = None,
        cities: list[str] | None = None,
        institution_types: list[str] | None = None,
    ) -> RetrievalResult:
        """Retrieve from Qdrant vector + SPARQL hybrid search.

        Args:
            query: Search query
            k: Number of results to return
            embedding_model: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')
            region_codes: Filter by province/region codes (e.g., ['NH', 'ZH'])
            cities: Filter by city names (e.g., ['Amsterdam', 'Rotterdam'])
            institution_types: Filter by institution types (e.g., ['ARCHIVE', 'MUSEUM'])

        Returns:
            RetrievalResult with items (empty on failure) and elapsed time.
        """
        start = asyncio.get_event_loop().time()

        items = []
        if self.qdrant:
            try:
                results = self.qdrant.search(
                    query,
                    k=k,
                    using=embedding_model,
                    region_codes=region_codes,
                    cities=cities,
                    institution_types=institution_types,
                )
                items = [r.to_dict() for r in results]
            except Exception as e:
                # Degrade gracefully: an empty result set, not an exception.
                logger.error(f"Qdrant retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.QDRANT,
            items=items,
            # Best combined score across items; 0 when nothing was found.
            score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_sparql(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from SPARQL endpoint.

        Uses TEMPLATE-FIRST approach:
        1. Try template-based SPARQL generation (deterministic, validated)
        2. Fall back to LLM-based generation only if no template matches

        Template approach provides 65% precision vs 10% for LLM-only.
        """
        global _template_pipeline_instance
        start = asyncio.get_event_loop().time()

        items = []
        sparql_query = ""
        template_used = False

        try:
            # ===================================================================
            # STEP 1: Try TEMPLATE-BASED SPARQL generation (preferred)
            # ===================================================================
            if TEMPLATE_SPARQL_AVAILABLE and get_template_pipeline:
                try:
                    # Get or create singleton pipeline instance
                    if _template_pipeline_instance is None:
                        _template_pipeline_instance = get_template_pipeline()
                        logger.info("[SPARQL] Template pipeline initialized for MultiSourceRetriever")

                    # Run template matching in thread pool (DSPy is synchronous)
                    template_result = await asyncio.to_thread(
                        _template_pipeline_instance,
                        question=query,
                        conversation_state=None,  # No conversation state in simple retriever
                        language="nl"
                    )

                    if template_result.matched and template_result.sparql:
                        sparql_query = template_result.sparql
                        template_used = True
                        logger.info(f"[SPARQL] Template match: '{template_result.template_id}' "
                                    f"(confidence={template_result.confidence:.2f})")
                    else:
                        logger.info(f"[SPARQL] No template match: {template_result.reasoning}")
                except Exception as e:
                    # Template failure is non-fatal; fall through to LLM path.
                    logger.warning(f"[SPARQL] Template pipeline failed: {e}")

            # ===================================================================
            # STEP 2: Fall back to LLM-BASED SPARQL generation
            # ===================================================================
            if not template_used and RETRIEVERS_AVAILABLE and generate_sparql:
                logger.info("[SPARQL] Falling back to LLM-based SPARQL generation")
                sparql_result = generate_sparql(query, language="nl", use_rag=False)
                sparql_query = sparql_result.get("sparql", "")

            # ===================================================================
            # STEP 3: Execute the SPARQL query
            # ===================================================================
            if sparql_query:
                logger.debug(f"[SPARQL] Executing query:\n{sparql_query[:500]}...")
                client = await self._get_sparql_client()
                response = await client.post(
                    settings.sparql_endpoint,
                    data={"query": sparql_query},
                    headers={"Accept": "application/sparql-results+json"},
                )

                if response.status_code == 200:
                    data = response.json()
                    bindings = data.get("results", {}).get("bindings", [])
                    # Flatten SPARQL JSON bindings to simple {var: value} dicts,
                    # truncated to the first k rows.
                    items = [
                        {key: val.get("value") for key, val in b.items()}
                        for b in bindings[:k]
                    ]
                    logger.info(f"[SPARQL] Query returned {len(items)} results "
                                f"(template={template_used})")
                else:
                    logger.warning(f"[SPARQL] Query failed with status {response.status_code}: "
                                   f"{response.text[:200]}")
        except Exception as e:
            logger.error(f"SPARQL retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.SPARQL,
            items=items,
            # SPARQL has no per-item relevance score: binary any/none.
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve_from_typedb(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from TypeDB knowledge graph.

        Delegates to the TypeDB retriever's semantic search; returns an
        empty result set on any failure.
        """
        start = asyncio.get_event_loop().time()

        items = []
        if self.typedb:
            try:
                results = self.typedb.semantic_search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"TypeDB retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.TYPEDB,
            items=items,
            score=max((r.get("relevance_score", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_postgis(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from PostGIS geospatial database.

        Detects a known city name in the query and asks PostGIS for
        institutions within a 10 km radius of that city's centre.
        """
        start = asyncio.get_event_loop().time()

        # Extract location from query for geospatial search
        # This is a simplified implementation
        items = []
        try:
            client = await self._get_postgis_client()

            # Try to detect city name for bbox search
            query_lower = query.lower()

            # Simple city detection — hard-coded centres of the four largest
            # Dutch cities.  NOTE(review): any other location is silently
            # ignored by this retriever.
            cities = {
                "amsterdam": {"lat": 52.3676, "lon": 4.9041},
                "rotterdam": {"lat": 51.9244, "lon": 4.4777},
                "den haag": {"lat": 52.0705, "lon": 4.3007},
                "utrecht": {"lat": 52.0907, "lon": 5.1214},
            }

            for city, coords in cities.items():
                if city in query_lower:
                    # Query PostGIS for nearby institutions
                    response = await client.get(
                        f"{settings.postgis_url}/api/institutions/nearby",
                        params={
                            "lat": coords["lat"],
                            "lon": coords["lon"],
                            "radius_km": 10,
                            "limit": k,
                        },
                    )

                    if response.status_code == 200:
                        items = response.json()
                        # First matching city wins.
                        break

        except Exception as e:
            logger.error(f"PostGIS retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.POSTGIS,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    # NOTE: retrieve_from_ducklake removed - DuckLake is for offline analytics only, not RAG retrieval
    # Statistical queries now use SPARQL aggregations (COUNT, SUM, AVG, GROUP BY) on Oxigraph

    async def retrieve(
        self,
        question: str,
        sources: list[DataSource],
        k: int = 10,
        embedding_model: str | None = None,
        region_codes: list[str] | None = None,
        cities: list[str] | None = None,
        institution_types: list[str] | None = None,
    ) -> list[RetrievalResult]:
        """Retrieve from multiple sources concurrently.

        Args:
            question: User's question
            sources: Data sources to query
            k: Number of results per source
            embedding_model: Optional embedding model for Qdrant (e.g., 'minilm_384', 'openai_1536')
            region_codes: Filter by province/region codes (e.g., ['NH', 'ZH']) - Qdrant only
            cities: Filter by city names (e.g., ['Amsterdam']) - Qdrant only
            institution_types: Filter by institution types (e.g., ['ARCHIVE']) - Qdrant only

        Returns:
            List of RetrievalResult from each source
        """
        tasks = []

        for source in sources:
            if source == DataSource.QDRANT:
                tasks.append(self.retrieve_from_qdrant(
                    question,
                    k,
                    embedding_model,
                    region_codes=region_codes,
                    cities=cities,
                    institution_types=institution_types,
                ))
            elif source == DataSource.SPARQL:
                tasks.append(self.retrieve_from_sparql(question, k))
            elif source == DataSource.TYPEDB:
                tasks.append(self.retrieve_from_typedb(question, k))
            elif source == DataSource.POSTGIS:
                tasks.append(self.retrieve_from_postgis(question, k))
            # NOTE: DuckLake case removed - DuckLake is for offline analytics only

        # return_exceptions=True so one failing source cannot cancel the rest.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out exceptions
        valid_results = []
        for r in results:
            if isinstance(r, RetrievalResult):
                valid_results.append(r)
            elif isinstance(r, Exception):
                logger.error(f"Retrieval task failed: {r}")

        return valid_results

    def merge_results(
        self,
        results: list[RetrievalResult],
        max_results: int = 20,
        template_used: bool = False,
        template_id: str | None = None,
    ) -> tuple[list[dict[str, Any]], EpistemicProvenance]:
        """Merge and deduplicate results from multiple sources.

        Uses reciprocal rank fusion for score combination.
        Returns merged items AND epistemic provenance tracking.

        Rule 46: Epistemic Provenance Tracking
        """
        from datetime import datetime, timezone

        # Track items by GHCID for deduplication
        merged: dict[str, dict[str, Any]] = {}

        # Initialize provenance tracking
        tier_counts: dict[DataTier, int] = {}
        sources_queried = [r.source.value for r in results]
        # Raw count BEFORE deduplication/fusion (duplicates counted once per source).
        total_retrieved = sum(len(r.items) for r in results)

        for result in results:
            # Map DataSource to RetrievalSource
            source_map = {
                DataSource.QDRANT: RetrievalSource.QDRANT,
                DataSource.SPARQL: RetrievalSource.SPARQL,
                DataSource.TYPEDB: RetrievalSource.TYPEDB,
                DataSource.POSTGIS: RetrievalSource.POSTGIS,
                DataSource.CACHE: RetrievalSource.CACHE,
            }
            retrieval_source = source_map.get(result.source, RetrievalSource.LLM_SYNTHESIS)

            for rank, item in enumerate(result.items):
                # Dedup key: GHCID preferred, then generic id, then a
                # rank-based placeholder.  NOTE(review): the placeholder is
                # per-rank, so unidentified items from different sources at
                # the same rank would collide — confirm items always carry
                # a ghcid or id in practice.
                ghcid = item.get("ghcid", item.get("id", f"unknown_{rank}"))

                if ghcid not in merged:
                    merged[ghcid] = item.copy()
                    merged[ghcid]["_sources"] = []
                    merged[ghcid]["_rrf_score"] = 0.0
                    merged[ghcid]["_data_tier"] = None

                # Infer data tier for this item
                item_tier = infer_data_tier(item, retrieval_source)
                tier_counts[item_tier] = tier_counts.get(item_tier, 0) + 1

                # Track best (lowest) tier for each item
                if merged[ghcid]["_data_tier"] is None:
                    merged[ghcid]["_data_tier"] = item_tier.value
                else:
                    merged[ghcid]["_data_tier"] = min(merged[ghcid]["_data_tier"], item_tier.value)

                # Reciprocal Rank Fusion
                rrf_score = 1.0 / (60 + rank)  # k=60 is standard

                # Weight by source
                source_weights = {
                    DataSource.QDRANT: settings.vector_weight,
                    DataSource.SPARQL: settings.graph_weight,
                    DataSource.TYPEDB: settings.typedb_weight,
                    DataSource.POSTGIS: 0.3,
                }

                weight = source_weights.get(result.source, 0.5)
                merged[ghcid]["_rrf_score"] += rrf_score * weight
                merged[ghcid]["_sources"].append(result.source.value)

        # Sort by RRF score
        sorted_items = sorted(
            merged.values(),
            key=lambda x: x.get("_rrf_score", 0),
            reverse=True,
        )

        final_items = sorted_items[:max_results]

        # Build epistemic provenance
        provenance = EpistemicProvenance(
            dataSource=EpistemicDataSource.RAG_PIPELINE,
            dataTier=aggregate_data_tier(tier_counts),
            sourceTimestamp=datetime.now(timezone.utc).isoformat(),
            derivationChain=build_derivation_chain(
                sources_used=sources_queried,
                template_used=template_used,
                template_id=template_id,
            ),
            revalidationPolicy="weekly",
            sourcesQueried=sources_queried,
            totalRetrieved=total_retrieved,
            totalAfterFusion=len(final_items),
            dataTierBreakdown={
                f"tier_{tier.value}": count
                for tier, count in tier_counts.items()
            },
            templateUsed=template_used,
            templateId=template_id,
        )

        return final_items, provenance

    async def close(self) -> None:
        """Clean up resources.

        Closes the cache, both pooled HTTP clients, and any
        lazily-created retrievers.
        """
        await self.cache.close()

        if self._sparql_client:
            await self._sparql_client.aclose()
        if self._postgis_client:
            await self._postgis_client.aclose()
        if self._qdrant:
            self._qdrant.close()
        if self._typedb:
            self._typedb.close()

    def search_persons(
        self,
        query: str,
        k: int = 10,
        filter_custodian: str | None = None,
        only_heritage_relevant: bool = False,
        using: str | None = None,
    ) -> list[Any]:
        """Search for persons/staff in the heritage_persons collection.

        Delegates to HybridRetriever.search_persons() if available.

        Args:
            query: Search query
            k: Number of results
            filter_custodian: Optional custodian slug to filter by
            only_heritage_relevant: Only return heritage-relevant staff
            using: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')

        Returns:
            List of RetrievedPerson objects (empty when Qdrant is
            unavailable or the search fails).
        """
        if self.qdrant:
            try:
                return self.qdrant.search_persons(  # type: ignore[no-any-return]
                    query=query,
                    k=k,
                    filter_custodian=filter_custodian,
                    only_heritage_relevant=only_heritage_relevant,
                    using=using,
                )
            except Exception as e:
                logger.error(f"Person search failed: {e}")
        return []

    def get_stats(self) -> dict[str, Any]:
        """Get statistics from all retrievers.

        Returns combined stats from Qdrant (including persons collection) and TypeDB.
        Stats fetch failures are logged and the corresponding section omitted.
        """
        stats = {}

        if self.qdrant:
            try:
                qdrant_stats = self.qdrant.get_stats()
                # Qdrant stats are merged at the top level (no nesting).
                stats.update(qdrant_stats)
            except Exception as e:
                logger.warning(f"Failed to get Qdrant stats: {e}")

        if self.typedb:
            try:
                typedb_stats = self.typedb.get_stats()
                stats["typedb"] = typedb_stats
            except Exception as e:
                logger.warning(f"Failed to get TypeDB stats: {e}")

        return stats
|
|
|
|
|
|
# Global instances
|
|
# Module-level singletons, populated by the FastAPI lifespan handler below.
retriever: MultiSourceRetriever | None = None
viz_selector: VisualizationSelector | None = None
dspy_pipeline: Any = None  # HeritageRAGPipeline instance (loaded with optimized model)
atomic_cache_manager: Any = None  # AtomicCacheManager for sub-task caching (40-70% hit rate vs 5-15% for full queries)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan manager.

    Startup sequence:
      1. Build the MultiSourceRetriever (always, even without retrievers).
      2. Configure DSPy with the first working LLM provider, following the
         fallback chain Z.AI -> HuggingFace -> Anthropic -> OpenAI.
      3. Initialize the optimized HeritageRAGPipeline (if DSPy configured).
      4. Warm up caches/models to avoid cold-start latency: embedding model,
         template pattern embeddings, atomic cache manager, ontology cache.

    Shutdown closes the retriever's pooled connections. All failures during
    startup are logged and degrade gracefully rather than aborting boot.
    """
    global retriever, viz_selector, dspy_pipeline, atomic_cache_manager

    # Startup
    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()

    if RETRIEVERS_AVAILABLE:
        # Check for any available LLM API key (Anthropic preferred, OpenAI fallback)
        has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key)

        # VisualizationSelector requires DSPy - make it optional
        try:
            viz_selector = VisualizationSelector(use_dspy=has_llm_key)
        except RuntimeError as e:
            logger.warning(f"VisualizationSelector not available: {e}")
            viz_selector = None

        # Configure DSPy based on LLM_PROVIDER setting
        # Respect user's provider preference, with fallback chain
        import dspy
        llm_provider = settings.llm_provider.lower()
        logger.info(f"LLM_PROVIDER configured as: {llm_provider}")
        dspy_configured = False

        # Try Z.AI GLM if configured as provider (FREE!)
        if llm_provider == "zai" and settings.zai_api_token:
            try:
                # Z.AI uses OpenAI-compatible API format
                # Use LLM_MODEL from settings (default: glm-4.5-flash for speed)
                zai_model = settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash"
                lm = dspy.LM(
                    f"openai/{zai_model}",
                    api_key=settings.zai_api_token,
                    api_base="https://api.z.ai/api/coding/paas/v4",
                )
                dspy.configure(lm=lm)
                logger.info(f"Configured DSPy with Z.AI {zai_model} (FREE)")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with Z.AI: {e}")

        # Try HuggingFace if configured as provider
        if not dspy_configured and llm_provider == "huggingface" and settings.huggingface_api_key:
            try:
                lm = dspy.LM("huggingface/utter-project/EuroLLM-9B-Instruct", api_key=settings.huggingface_api_key)
                dspy.configure(lm=lm)
                logger.info("Configured DSPy with HuggingFace EuroLLM-9B-Instruct")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with HuggingFace: {e}")

        # Try Anthropic if not yet configured (either as primary or fallback)
        if not dspy_configured and (llm_provider == "anthropic" or (llm_provider == "huggingface" and settings.anthropic_api_key)):
            if settings.anthropic_api_key and configure_dspy:
                try:
                    configure_dspy(
                        provider="anthropic",
                        model=settings.default_model,
                        api_key=settings.anthropic_api_key,
                    )
                    dspy_configured = True
                except Exception as e:
                    logger.warning(f"Failed to configure DSPy with Anthropic: {e}")

        # Try OpenAI as final fallback
        if not dspy_configured and settings.openai_api_key and configure_dspy:
            try:
                configure_dspy(
                    provider="openai",
                    model="gpt-4o-mini",
                    api_key=settings.openai_api_key,
                )
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with OpenAI: {e}")

        if not dspy_configured:
            logger.warning("No LLM provider configured - DSPy queries will fail")

        # Initialize optimized HeritageRAGPipeline (if DSPy is configured)
        if dspy_configured:
            try:
                from dspy_heritage_rag import HeritageRAGPipeline
                from pathlib import Path

                # Create pipeline with Qdrant retriever
                qdrant_retriever = retriever.qdrant if retriever else None
                dspy_pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)

                # Load optimized model (BootstrapFewShot: 14.3% quality improvement)
                # Note: load() may fail if new modules were added that aren't in the saved state
                optimized_model_path = Path(__file__).parent / "optimized_models" / "heritage_rag_bootstrap_latest.json"
                if optimized_model_path.exists():
                    try:
                        dspy_pipeline.load(str(optimized_model_path))
                        logger.info(f"Loaded optimized DSPy pipeline from {optimized_model_path}")
                    except Exception as load_err:
                        # Pipeline still works, just without optimized demos for new modules
                        logger.warning(f"Could not load optimized model (new modules may need re-optimization): {load_err}")
                        logger.info("Pipeline initialized without optimized demos - will work but may be less accurate")
                else:
                    logger.warning(f"Optimized model not found at {optimized_model_path}, using unoptimized pipeline")
            except Exception as e:
                logger.warning(f"Failed to initialize DSPy pipeline: {e}")
                dspy_pipeline = None

    # === HOT LOADING: Warmup embedding model to avoid cold-start latency ===
    # The sentence-transformers model takes 3-15 seconds to load on first use.
    # By loading it eagerly at startup, we eliminate this delay for users.
    if retriever.qdrant:
        logger.info("Warming up embedding model (this takes 3-15 seconds on first startup)...")
        warmup_start = time.perf_counter()
        try:
            # Trigger model load with a dummy embedding request
            _ = retriever.qdrant._get_embedding("archief warmup query")
            warmup_duration = time.perf_counter() - warmup_start
            logger.info(f"✅ Embedding model warmed up in {warmup_duration:.2f}s - ready for fast queries!")

            # Record warmup metrics
            if record_embedding_warmup:
                record_embedding_warmup(
                    model="sentence-transformers/all-MiniLM-L6-v2",
                    duration_seconds=warmup_duration,
                    success=True,
                )
        except Exception as e:
            warmup_duration = time.perf_counter() - warmup_start
            logger.warning(f"Failed to warm up embedding model: {e}")
            if record_embedding_warmup:
                record_embedding_warmup(
                    model="sentence-transformers/all-MiniLM-L6-v2",
                    duration_seconds=warmup_duration,
                    success=False,
                )

    # === TEMPLATE EMBEDDING WARMUP: Pre-compute embeddings for template patterns ===
    # The TemplateEmbeddingMatcher computes embeddings on first query (~2-5 seconds).
    # By pre-computing at startup, we eliminate this delay for users.
    template_warmup_start = time.perf_counter()
    template_count = 0
    try:
        from template_sparql import get_template_embedding_matcher, TemplateClassifier

        logger.info("Pre-computing template pattern embeddings...")
        classifier = TemplateClassifier()
        templates = classifier._load_templates()

        if templates:
            template_count = len(templates)
            matcher = get_template_embedding_matcher()
            if matcher._ensure_embeddings_computed(templates):
                template_warmup_duration = time.perf_counter() - template_warmup_start
                logger.info(f"✅ Template embeddings pre-computed ({template_count} templates) in {template_warmup_duration:.2f}s")

                # Record template warmup metrics
                if record_template_embedding_warmup:
                    record_template_embedding_warmup(
                        duration_seconds=template_warmup_duration,
                        template_count=template_count,
                        success=True,
                    )
            else:
                logger.warning("Template embedding computation skipped (model not available)")
                if set_warmup_status:
                    set_warmup_status("template_embeddings", False)
        else:
            logger.warning("No templates found for embedding warmup")
            if set_warmup_status:
                set_warmup_status("template_embeddings", False)
    except Exception as e:
        template_warmup_duration = time.perf_counter() - template_warmup_start
        logger.warning(f"Failed to pre-compute template embeddings: {e}")
        if record_template_embedding_warmup:
            record_template_embedding_warmup(
                duration_seconds=template_warmup_duration,
                template_count=template_count,
                success=False,
            )

    # === ATOMIC CACHE MANAGER: Sub-task caching for higher hit rates ===
    # Research shows 40-70% cache hit rates with atomic decomposition vs 5-15% for full queries.
    # Initialize AtomicCacheManager with retriever's semantic cache for persistence.
    if ATOMIC_CACHE_AVAILABLE and AtomicCacheManager:
        try:
            semantic_cache = retriever.cache if retriever else None
            atomic_cache_manager = AtomicCacheManager(semantic_cache=semantic_cache)
            logger.info("✅ AtomicCacheManager initialized for sub-task caching")
        except Exception as e:
            logger.warning(f"Failed to initialize AtomicCacheManager: {e}")

    # === ONTOLOGY CACHE WARMUP: Pre-load KG values to avoid cold-start latency ===
    # The OntologyLoader queries the Knowledge Graph for valid slot values (cities, regions, types).
    # These queries can take 1-3 seconds each on first access.
    # By pre-loading at startup, we eliminate this delay for users.
    ontology_warmup_start = time.perf_counter()
    try:
        from template_sparql import get_ontology_loader

        logger.info("Warming up ontology cache (pre-loading KG values)...")
        ontology = get_ontology_loader()
        ontology.load()  # Triggers KG queries for institution_types, subregions, cities, etc.

        ontology_warmup_duration = time.perf_counter() - ontology_warmup_start
        cache_stats = ontology.get_kg_cache_stats()
        logger.info(
            f"✅ Ontology cache warmed up in {ontology_warmup_duration:.2f}s "
            f"({cache_stats['cache_size']} KG queries cached, TTL={cache_stats['ttl_seconds']}s)"
        )
    except Exception as e:
        ontology_warmup_duration = time.perf_counter() - ontology_warmup_start
        logger.warning(f"Failed to warm up ontology cache: {e}")

    logger.info("Heritage RAG API started")

    yield

    # Shutdown: release pooled HTTP/DB connections held by the retriever.
    logger.info("Shutting down Heritage RAG API...")
    if retriever:
        await retriever.close()
    logger.info("Heritage RAG API stopped")
|
|
|
|
|
|
# Create FastAPI app; the lifespan context above handles startup/shutdown.
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject wildcard origins when
# credentials are sent) — confirm whether credentialed requests are needed,
# and if so list explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prometheus metrics endpoint (mounted only when the metrics module imported)
if METRICS_AVAILABLE and create_metrics_endpoint:
    app.include_router(create_metrics_endpoint(), prefix="/api/rag")
|
|
|
|
|
|
# API Endpoints
|
|
@app.get("/api/rag/health")
async def health_check() -> dict[str, Any]:
    """Probe Qdrant, SPARQL and TypeDB and summarize their health.

    Each service that responds contributes a per-service entry; the overall
    status is "ok" with zero failures, "degraded" with one or two, and
    "error" with three or more.
    """
    checked_at = datetime.now(timezone.utc).isoformat()
    services: dict[str, Any] = {}

    # Qdrant vector store (probed only when the retriever was initialized)
    if retriever and retriever.qdrant:
        try:
            qdrant_stats = retriever.qdrant.get_stats()
            vector_count = qdrant_stats.get("qdrant", {}).get("vectors_count", 0)
            services["qdrant"] = {"status": "ok", "vectors": vector_count}
        except Exception as e:
            services["qdrant"] = {"status": "error", "error": str(e)}

    # SPARQL endpoint: hit the base URL (query path stripped); any response
    # below HTTP 500 counts as healthy.
    try:
        base_url = settings.sparql_endpoint.replace("/query", "")
        async with httpx.AsyncClient(timeout=5.0) as client:
            sparql_response = await client.get(base_url)
        sparql_ok = sparql_response.status_code < 500
        services["sparql"] = {"status": "ok" if sparql_ok else "error"}
    except Exception as e:
        services["sparql"] = {"status": "error", "error": str(e)}

    # TypeDB graph store (probed only when the retriever was initialized)
    if retriever and retriever.typedb:
        try:
            typedb_stats = retriever.typedb.get_stats()
            services["typedb"] = {
                "status": "ok",
                "entities": typedb_stats.get("entities", {}),
            }
        except Exception as e:
            services["typedb"] = {"status": "error", "error": str(e)}

    # Aggregate the per-service verdicts into a single status.
    failures = sum(
        1
        for entry in services.values()
        if isinstance(entry, dict) and entry.get("status") == "error"
    )
    if failures == 0:
        overall = "ok"
    elif failures < 3:
        overall = "degraded"
    else:
        overall = "error"

    return {
        "status": overall,
        "timestamp": checked_at,
        "services": services,
    }
|
|
|
|
|
|
@app.get("/api/rag/stats")
async def get_stats() -> dict[str, Any]:
    """Return per-retriever statistics (Qdrant and TypeDB, when available)."""
    collected_at = datetime.now(timezone.utc).isoformat()
    retrievers: dict[str, Any] = {}

    if retriever:
        # Only include sources that were successfully initialized.
        for name, source in (("qdrant", retriever.qdrant), ("typedb", retriever.typedb)):
            if source:
                retrievers[name] = source.get_stats()

    return {"timestamp": collected_at, "retrievers": retrievers}
|
|
|
|
|
|
@app.get("/api/rag/stats/costs")
async def get_cost_stats() -> dict[str, Any]:
    """Expose cumulative cost-tracking statistics for the current session.

    The session summary covers LLM calls and token usage, retrieval
    operations and latencies, estimated costs by model, and pipeline
    timing statistics.

    Returns:
        Dict with tracker statistics, or an availability message when the
        cost tracker module is not loaded.
    """
    if COST_TRACKER_AVAILABLE and get_tracker:
        session_summary = get_tracker().get_session_summary()
        return {
            "available": True,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "session": session_summary,
        }

    return {
        "available": False,
        "message": "Cost tracker module not available",
    }
|
|
|
|
|
|
@app.post("/api/rag/stats/costs/reset")
async def reset_cost_stats() -> dict[str, Any]:
    """Clear accumulated cost-tracking statistics and start a fresh session.

    Useful for per-conversation or per-session cost accounting.

    Returns:
        Confirmation payload, or an availability message when the cost
        tracker module is not loaded.
    """
    if COST_TRACKER_AVAILABLE and reset_tracker:
        reset_tracker()
        return {
            "available": True,
            "message": "Cost tracking statistics reset",
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

    return {
        "available": False,
        "message": "Cost tracker module not available",
    }
|
|
|
|
|
|
@app.get("/api/rag/stats/templates")
async def get_template_stats() -> dict[str, Any]:
    """Report template-based SPARQL generation usage.

    Exposes the overall template hit rate plus a per-template breakdown,
    which helps monitor coverage (what share of queries use templates),
    identify the most-used templates, and tune slot-extraction parameters.

    Returns:
        Dict with hit rate, per-template breakdown and timestamp, or an
        availability message when the metrics module cannot be used.
    """
    if not METRICS_AVAILABLE:
        return {"available": False, "message": "Metrics module not available"}

    # The metrics helpers are imported lazily so a broken metrics module
    # degrades to an "unavailable" answer instead of failing the endpoint.
    try:
        from metrics import get_template_hit_rate, get_template_breakdown
    except ImportError:
        return {"available": False, "message": "Metrics module import failed"}

    payload: dict[str, Any] = {"available": True}
    payload["hit_rate"] = get_template_hit_rate()
    payload["breakdown"] = get_template_breakdown()
    payload["timestamp"] = datetime.now(timezone.utc).isoformat()
    return payload
|
|
|
|
|
|
@app.get("/api/rag/embedding/models")
async def get_embedding_models() -> dict[str, Any]:
    """List available embedding models for the Qdrant collections.

    Returns information about which embedding models are available in each
    collection's named vectors, helping clients choose the right model for
    their use case.

    Returns:
        Dict with available models per collection, current settings, and
        static descriptions/recommendations for each supported model.
    """
    result: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "multi_embedding_enabled": settings.use_multi_embedding,
        "preferred_model": settings.preferred_embedding_model,
        "collections": {},
        # Static catalog of supported models; per-collection availability is
        # filled in below when the multi-embedding retriever is active.
        "models": {
            "openai_1536": {
                "description": "OpenAI text-embedding-3-small (1536 dimensions)",
                "quality": "high",
                "cost": "paid API",
                "recommended_for": "production, high-quality semantic search",
            },
            "minilm_384": {
                "description": "sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)",
                "quality": "good",
                "cost": "free (local)",
                "recommended_for": "development, cost-sensitive deployments",
            },
            "bge_768": {
                "description": "BAAI/bge-small-en-v1.5 (768 dimensions)",
                "quality": "very good",
                "cost": "free (local)",
                "recommended_for": "balanced quality/cost, multilingual support",
            },
        },
    }

    if retriever and retriever.qdrant:
        qdrant = retriever.qdrant

        # Check if multi-embedding is enabled and get available models
        if hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding:
            if hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever:
                multi = qdrant.multi_retriever

                # Get available models for institutions collection; errors are
                # reported per-collection instead of failing the whole endpoint.
                try:
                    inst_models = multi.get_available_models("heritage_custodians")
                    selected = multi.select_model("heritage_custodians")
                    result["collections"]["heritage_custodians"] = {
                        "available_models": [m.value for m in inst_models],
                        "uses_named_vectors": multi.uses_named_vectors("heritage_custodians"),
                        "recommended": selected.value if selected else None,
                    }
                except Exception as e:
                    result["collections"]["heritage_custodians"] = {"error": str(e)}

                # Get available models for persons collection
                try:
                    person_models = multi.get_available_models("heritage_persons")
                    selected = multi.select_model("heritage_persons")
                    result["collections"]["heritage_persons"] = {
                        "available_models": [m.value for m in person_models],
                        "uses_named_vectors": multi.uses_named_vectors("heritage_persons"),
                        "recommended": selected.value if selected else None,
                    }
                except Exception as e:
                    result["collections"]["heritage_persons"] = {"error": str(e)}
        else:
            # Single embedding mode.
            # FIX: removed an unused `stats = qdrant.get_stats()` call here —
            # the value was never read and the call only added latency.
            result["single_embedding_mode"] = True
            result["note"] = "Collections use single embedding vectors. Enable USE_MULTI_EMBEDDING=true to use named vectors."

    return result
|
|
|
|
|
|
class EmbeddingCompareRequest(BaseModel):
    """Request body for POST /api/rag/embedding/compare.

    The same query is executed once per available embedding model so the
    result sets can be compared side by side.
    """
    # Free-text search query; embedded separately by each available model.
    query: str = Field(..., description="Query to search with")
    # Qdrant collection to search; defaults to the persons collection.
    collection: str = Field(default="heritage_persons", description="Collection to search")
    # Result count per model, bounded 1-20 to keep comparisons fast.
    k: int = Field(default=5, ge=1, le=20, description="Number of results per model")
|
|
|
|
|
|
@app.post("/api/rag/embedding/compare")
async def compare_embedding_models(request: EmbeddingCompareRequest) -> dict[str, Any]:
    """Compare search results across different embedding models.

    Performs the same search query using each available embedding model,
    allowing A/B testing of embedding quality.

    This endpoint is useful for:
    - Evaluating which embedding model works best for your queries
    - Understanding differences in semantic similarity between models
    - Making informed decisions about which model to use in production

    Args:
        request: Query text, target collection, and per-model result count.

    Returns:
        Dict with results from each embedding model, including scores and timing.

    Raises:
        HTTPException: 503 when Qdrant or the multi-embedding retriever is
            unavailable, 400 when multi-embedding mode is disabled, 500 when
            the comparison itself fails.
    """
    # FIX: removed a redundant function-local `import time`; the module-level
    # `time` import is used instead (it is already relied on elsewhere).
    start_time = time.time()

    if not retriever or not retriever.qdrant:
        raise HTTPException(status_code=503, detail="Qdrant retriever not available")

    qdrant = retriever.qdrant

    # Check if multi-embedding is available
    if not (hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding):
        raise HTTPException(
            status_code=400,
            detail="Multi-embedding mode not enabled. Set USE_MULTI_EMBEDDING=true to use this endpoint."
        )

    if not (hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever):
        raise HTTPException(status_code=503, detail="Multi-embedding retriever not initialized")

    multi = qdrant.multi_retriever

    try:
        # Use the compare_models method from MultiEmbeddingRetriever
        comparison = multi.compare_models(
            query=request.query,
            collection=request.collection,
            k=request.k,
        )

        elapsed_ms = (time.time() - start_time) * 1000

        return {
            "query": request.query,
            "collection": request.collection,
            "k": request.k,
            "query_time_ms": round(elapsed_ms, 2),
            "comparison": comparison,
        }

    except Exception as e:
        logger.exception(f"Embedding comparison failed: {e}")
        # FIX: chain the original exception so the root cause is preserved
        # in tracebacks (PEP 3134).
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/api/rag/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest) -> QueryResponse:
    """Main RAG query endpoint.

    Orchestrates retrieval from multiple sources, merges results,
    and optionally generates visualization configuration.

    Flow: semantic-cache lookup -> intent routing -> geographic filter
    extraction -> multi-source retrieval -> score-fused merge -> optional
    visualization selection -> cache write -> response.

    Raises:
        HTTPException: 503 when the retriever was not initialized at startup.
    """
    if not retriever:
        raise HTTPException(status_code=503, detail="Retriever not initialized")

    start_time = asyncio.get_event_loop().time()

    # Check cache first
    cached = await retriever.cache.get(request.question, request.sources)
    if cached:
        # Transform cached data to QueryResponse schema
        # Cache stores: answer, sparql_query, sources, confidence, context
        # QueryResponse needs: question, sparql, results, sources_used, query_time_ms, result_count
        try:
            # Get sources from cached data (may be strings or DataSource enums)
            cached_sources = cached.get("sources", [])
            sources_used: list[DataSource] = []
            for s in cached_sources:
                if isinstance(s, str):
                    try:
                        sources_used.append(DataSource(s))
                    except ValueError:
                        # Skip invalid source values
                        pass
                elif isinstance(s, DataSource):
                    sources_used.append(s)

            # Get results from context if available
            results = cached.get("results", [])
            if not results and cached.get("context"):
                results = cached["context"].get("retrieved_results", []) or []

            return QueryResponse(
                question=request.question,
                # Older cache entries may use either key for the query text.
                sparql=cached.get("sparql_query") or cached.get("sparql"),
                results=results,
                visualization=cached.get("visualization"),
                sources_used=sources_used or [DataSource.QDRANT],  # Default if none
                cache_hit=True,
                query_time_ms=cached.get("query_time_ms", 0.0),
                result_count=cached.get("result_count", len(results)),
            )
        except Exception as e:
            logger.warning(f"Failed to transform cached response: {e}, skipping cache")
            # Fall through to normal query processing

    # Route query to appropriate sources
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    logger.info(f"Query intent: {intent}, sources: {sources}")

    # Extract geographic filters from question (province, city, institution type)
    geo_filters = extract_geographic_filters(request.question)
    if any(geo_filters.values()):
        logger.info(f"Geographic filters extracted: {geo_filters}")

    # Retrieve from all sources
    results = await retriever.retrieve(
        request.question,
        sources,
        request.k,
        embedding_model=request.embedding_model,
        region_codes=geo_filters["region_codes"],
        cities=geo_filters["cities"],
        institution_types=geo_filters["institution_types"],
    )

    # Merge results with provenance tracking
    # NOTE(review): retrieval_provenance is computed but not yet included in
    # QueryResponse — confirm whether this endpoint should surface it like the
    # DSPy endpoint does.
    merged_items, retrieval_provenance = retriever.merge_results(results, max_results=request.k * 2)

    # Generate visualization config if requested
    visualization = None
    if request.include_visualization and viz_selector and merged_items:
        # Extract schema from first result; private "_"-prefixed fields are
        # internal bookkeeping and excluded from the schema string.
        schema_fields = list(merged_items[0].keys()) if merged_items else []
        schema_str = ", ".join(f for f in schema_fields if not f.startswith("_"))

        visualization = viz_selector.select(
            request.question,
            schema_str,
            len(merged_items),
        )

    elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000

    response_data = {
        "question": request.question,
        "sparql": None,  # Could be populated from SPARQL result
        "results": merged_items,
        "visualization": visualization,
        "sources_used": [s for s in sources],
        "cache_hit": False,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged_items),
    }

    # Cache the response
    await retriever.cache.set(request.question, request.sources, response_data)

    return QueryResponse(**response_data)  # type: ignore[arg-type]
|
|
|
|
|
|
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
async def generate_sparql_endpoint(request: SPARQLRequest) -> SPARQLResponse:
    """Generate SPARQL query from natural language.

    Uses TEMPLATE-FIRST approach:
    1. Try template-based SPARQL generation (deterministic, validated)
    2. Fall back to LLM-based generation only if no template matches

    Template approach provides 65% precision vs 10% for LLM-only (Formica et al. 2023).

    Raises:
        HTTPException: 503 when no template matched and the LLM-based SPARQL
            generator is unavailable; 500 on unexpected failures.
    """
    # Module-level singleton so the template pipeline is built only once.
    global _template_pipeline_instance

    template_used = False
    sparql_query = ""
    explanation = ""

    try:
        # ===================================================================
        # STEP 1: Try TEMPLATE-BASED SPARQL generation (preferred)
        # ===================================================================
        if TEMPLATE_SPARQL_AVAILABLE and get_template_pipeline:
            try:
                # Get or create singleton pipeline instance
                if _template_pipeline_instance is None:
                    _template_pipeline_instance = get_template_pipeline()
                    logger.info("[SPARQL] Template pipeline initialized for /api/rag/sparql endpoint")

                # Run template matching in thread pool (DSPy is synchronous)
                template_result = await asyncio.to_thread(
                    _template_pipeline_instance,
                    question=request.question,
                    conversation_state=None,
                    language=request.language
                )

                if template_result.matched and template_result.sparql:
                    sparql_query = template_result.sparql
                    template_used = True
                    explanation = (
                        f"Template '{template_result.template_id}' matched with "
                        f"confidence {template_result.confidence:.2f}. "
                        f"Slots: {template_result.slots}. "
                        f"{template_result.reasoning}"
                    )
                    logger.info(f"[SPARQL] Template match: '{template_result.template_id}' "
                                f"(confidence={template_result.confidence:.2f})")
                else:
                    logger.info(f"[SPARQL] No template match: {template_result.reasoning}")
            except Exception as e:
                # Template failure is non-fatal; fall through to the LLM path.
                logger.warning(f"[SPARQL] Template pipeline failed: {e}")

        # ===================================================================
        # STEP 2: Fall back to LLM-BASED SPARQL generation
        # ===================================================================
        if not template_used:
            if not RETRIEVERS_AVAILABLE:
                raise HTTPException(status_code=503, detail="SPARQL generator not available")

            logger.info("[SPARQL] Falling back to LLM-based SPARQL generation")
            result = generate_sparql(
                request.question,
                language=request.language,
                context=request.context,
                use_rag=request.use_rag,
            )
            sparql_query = result["sparql"]
            explanation = result.get("explanation", "")

        return SPARQLResponse(
            sparql=sparql_query,
            explanation=explanation,
            rag_used=not template_used,  # RAG only used if LLM fallback
            retrieved_passages=[],
        )

    except HTTPException:
        # Re-raise HTTP errors (e.g. the 503 above) without wrapping them.
        raise
    except Exception as e:
        logger.exception("SPARQL generation failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/api/rag/sparql/execute", response_model=SPARQLExecuteResponse)
async def execute_sparql_query(request: SPARQLExecuteRequest) -> SPARQLExecuteResponse:
    """Execute a SPARQL query directly against the knowledge graph.

    This endpoint allows users to run modified SPARQL queries and see the
    results without regenerating an answer. Useful for exploration and
    debugging. Failures (non-200 responses, timeouts, transport errors) are
    reported via the response's ``error`` field rather than raised.

    Args:
        request: SPARQL query text plus a per-request timeout.

    Returns:
        SPARQLExecuteResponse with flattened bindings, count, timing, and an
        optional error message.
    """
    # FIX: removed a redundant function-local `import time`; the module-level
    # `time` import is used instead (it is already relied on elsewhere).
    start_time = time.time()

    try:
        # Use pooled SPARQL client from retriever for better connection reuse
        if retriever:
            client = await retriever._get_sparql_client()
            response = await client.post(
                settings.sparql_endpoint,
                data={"query": request.sparql_query},
                headers={"Accept": "application/sparql-results+json"},
                timeout=request.timeout,  # Override timeout per-request if specified
            )
        else:
            # Fallback to creating a new client if retriever not initialized
            async with httpx.AsyncClient(timeout=request.timeout) as client:
                response = await client.post(
                    settings.sparql_endpoint,
                    data={"query": request.sparql_query},
                    headers={"Accept": "application/sparql-results+json"},
                )

        if response.status_code != 200:
            return SPARQLExecuteResponse(
                results=[],
                result_count=0,
                query_time_ms=(time.time() - start_time) * 1000,
                # Truncate the endpoint's error body to keep responses small.
                error=f"SPARQL endpoint returned {response.status_code}: {response.text[:500]}",
            )

        data = response.json()
        bindings = data.get("results", {}).get("bindings", [])

        # Convert SPARQL JSON bindings ({var: {type, value}}) to simple dicts
        results = [
            {k: v.get("value") for k, v in binding.items()}
            for binding in bindings
        ]

        return SPARQLExecuteResponse(
            results=results,
            result_count=len(results),
            query_time_ms=(time.time() - start_time) * 1000,
        )

    except httpx.TimeoutException:
        return SPARQLExecuteResponse(
            results=[],
            result_count=0,
            query_time_ms=(time.time() - start_time) * 1000,
            error=f"Query timed out after {request.timeout}s",
        )
    except Exception as e:
        logger.exception("SPARQL execution failed")
        return SPARQLExecuteResponse(
            results=[],
            result_count=0,
            query_time_ms=(time.time() - start_time) * 1000,
            error=str(e),
        )
|
|
|
|
|
|
@app.post("/api/rag/sparql/rerun", response_model=SPARQLRerunResponse)
async def rerun_rag_with_sparql(request: SPARQLRerunRequest) -> SPARQLRerunResponse:
    """Re-run the RAG pipeline with modified SPARQL results injected into context.

    This endpoint allows users to:
    1. Execute a modified SPARQL query
    2. Inject those results into the DSPy RAG context
    3. Generate a new answer based on the modified knowledge graph results

    This affects the entire conversation through DSPy by providing new factual
    context that the LLM uses to generate its response.

    Errors in either step are captured and reported in the response (the
    answer text carries a failure message) rather than raised.
    """
    # NOTE(review): `time` also appears to be imported at module level (used
    # elsewhere without a local import) — this local import looks redundant.
    import time
    import dspy

    start_time = time.time()

    # Step 1: Execute the modified SPARQL query using pooled client
    sparql_results: list[dict[str, Any]] = []
    try:
        # Use pooled SPARQL client from retriever for better connection reuse
        if retriever:
            client = await retriever._get_sparql_client()
            response = await client.post(
                settings.sparql_endpoint,
                data={"query": request.sparql_query},
                headers={"Accept": "application/sparql-results+json"},
            )
        else:
            # Fallback to creating a new client if retriever not initialized
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    settings.sparql_endpoint,
                    data={"query": request.sparql_query},
                    headers={"Accept": "application/sparql-results+json"},
                )

        if response.status_code == 200:
            data = response.json()
            bindings = data.get("results", {}).get("bindings", [])
            # Flatten SPARQL JSON bindings into plain {var: value} dicts.
            sparql_results = [
                {k: v.get("value") for k, v in binding.items()}
                for binding in bindings[:50]  # Limit to 50 results for context
            ]
            logger.info(f"SPARQL rerun: got {len(sparql_results)} results")
        else:
            logger.warning(f"SPARQL rerun: endpoint returned {response.status_code}")
    except Exception as e:
        logger.exception(f"SPARQL rerun: execution failed: {e}")

    # Step 2: Format SPARQL results as context for DSPy
    # (the bracketed header below is a Dutch runtime string consumed by the
    # LLM prompt and must stay as-is)
    sparql_context = ""
    if sparql_results:
        sparql_context = "\n[KENNISGRAAF RESULTATEN (aangepaste SPARQL query)]:\n"
        for i, result in enumerate(sparql_results[:20], 1):
            entry = " | ".join(f"{k}: {v}" for k, v in result.items() if v)
            sparql_context += f" {i}. {entry}\n"

    # Step 3: Run DSPy answer generation with injected SPARQL context
    answer = ""
    try:
        # NOTE(review): HeritageRAGPipeline is imported but not referenced in
        # this function — confirm whether the import is still needed.
        from dspy_heritage_rag import HeritageRAGPipeline

        # Get LLM configuration (provider/model may be overridden per request)
        lm = None
        provider = request.llm_provider or "zai"
        model = request.llm_model

        if provider == "zai" and settings.zai_api_token:
            model = model or "glm-4.5-flash"
            lm = dspy.LM(
                f"openai/{model}",
                api_key=settings.zai_api_token,
                api_base="https://api.z.ai/api/coding/paas/v4",
            )
        elif provider == "groq" and settings.groq_api_key:
            model = model or "llama-3.1-8b-instant"
            lm = dspy.LM(f"groq/{model}", api_key=settings.groq_api_key)
        elif provider == "openai" and settings.openai_api_key:
            model = model or "gpt-4o-mini"
            lm = dspy.LM(f"openai/{model}", api_key=settings.openai_api_key)
        elif provider == "anthropic" and settings.anthropic_api_key:
            model = model or "claude-sonnet-4-20250514"
            lm = dspy.LM(f"anthropic/{model}", api_key=settings.anthropic_api_key)

        if lm:
            # Scope the LM override to this request only.
            with dspy.settings.context(lm=lm):
                # Create a simple answer generator that uses the SPARQL context
                generate_answer = dspy.ChainOfThought(
                    "question, sparql_context, language -> answer"
                )

                result = generate_answer(
                    question=request.original_question,
                    sparql_context=sparql_context,
                    language=request.language,
                )
                answer = result.answer
        else:
            answer = f"LLM niet beschikbaar. SPARQL resultaten: {len(sparql_results)} gevonden."

    except Exception as e:
        logger.exception(f"SPARQL rerun: answer generation failed: {e}")
        answer = f"Fout bij het genereren van antwoord: {str(e)}"

    return SPARQLRerunResponse(
        results=sparql_results,
        answer=answer,
        sparql_result_count=len(sparql_results),
        query_time_ms=(time.time() - start_time) * 1000,
    )
|
|
|
|
|
|
@app.post("/api/rag/visualize")
async def get_visualization_config(
    question: str = Query(..., description="User's question"),
    schema: str = Query(..., description="Comma-separated field names"),
    result_count: int = Query(default=0, description="Number of results"),
) -> dict[str, Any]:
    """Return the visualization configuration selected for a query.

    Delegates the decision to the module-level ``viz_selector``; answers
    503 when the selector has not been initialised at startup.
    """
    if not viz_selector:
        raise HTTPException(status_code=503, detail="Visualization selector not available")

    selection = viz_selector.select(question, schema, result_count)
    return selection  # type: ignore[no-any-return]
|
|
|
|
|
|
@app.post("/api/rag/typedb/search", response_model=TypeDBSearchResponse)
async def typedb_search(request: TypeDBSearchRequest) -> TypeDBSearchResponse:
    """Direct TypeDB search endpoint.

    Dispatches to one of the TypeDB retriever's search strategies:

    - semantic: Natural language search (combines type + location patterns)
    - name: Search by institution name
    - type: Search by institution type (museum, archive, library, gallery)
    - location: Search by city/location name

    Examples:
        - {"query": "museums in Amsterdam", "search_type": "semantic"}
        - {"query": "Rijksmuseum", "search_type": "name"}
        - {"query": "archive", "search_type": "type"}
        - {"query": "Rotterdam", "search_type": "location"}
    """
    import time

    start_time = time.time()

    # The TypeDB backend is optional; fail fast with a clear 503 when absent.
    if not retriever or not retriever.typedb:
        raise HTTPException(
            status_code=503,
            detail="TypeDB retriever not available. Ensure TypeDB is running."
        )

    try:
        typedb = retriever.typedb

        # Pick the search method matching the requested strategy;
        # anything unrecognised falls back to semantic search.
        strategy = request.search_type
        if strategy == "name":
            hits = typedb.search_by_name(request.query, k=request.k)
        elif strategy == "type":
            hits = typedb.search_by_type(request.query, k=request.k)
        elif strategy == "location":
            hits = typedb.search_by_location(city=request.query, k=request.k)
        else:
            hits = typedb.semantic_search(request.query, k=request.k)

        # Normalise each hit to a plain dict and drop entries whose
        # display name was already seen (dedup by name).
        deduped: list[dict[str, Any]] = []
        names_seen: set[str] = set()
        for hit in hits:
            if hasattr(hit, 'to_dict'):
                record = hit.to_dict()
            elif isinstance(hit, dict):
                record = hit
            else:
                record = {"name": str(hit)}

            display_name = record.get("name") or record.get("observed_name", "")
            if display_name and display_name not in names_seen:
                names_seen.add(display_name)
                deduped.append(record)

        elapsed_ms = (time.time() - start_time) * 1000

        return TypeDBSearchResponse(
            query=request.query,
            search_type=request.search_type,
            results=deduped,
            result_count=len(deduped),
            query_time_ms=round(elapsed_ms, 2),
        )

    except Exception as e:
        logger.exception(f"TypeDB search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/api/rag/persons/search", response_model=PersonSearchResponse)
async def person_search(request: PersonSearchRequest) -> PersonSearchResponse:
    """Search for persons/staff in heritage institutions.

    Queries the heritage_persons Qdrant collection for staff members at
    heritage custodian institutions, using semantic vector similarity over
    their name, role, headline, and custodian affiliation.

    Examples:
        - {"query": "Wie werkt er in het Nationaal Archief?"}
        - {"query": "archivist at Rijksmuseum", "k": 20}
        - {"query": "conservator", "filter_custodian": "rijksmuseum"}
        - {"query": "digital preservation", "only_heritage_relevant": true}
    """
    import time

    start_time = time.time()

    # The hybrid retriever (and hence Qdrant) is required for this endpoint.
    if not retriever:
        raise HTTPException(
            status_code=503,
            detail="Hybrid retriever not available. Ensure Qdrant is running."
        )

    try:
        hits = retriever.search_persons(
            query=request.query,
            k=request.k,
            filter_custodian=request.filter_custodian,
            only_heritage_relevant=request.only_heritage_relevant,
            using=request.embedding_model,  # Pass embedding model
        )

        # Work out which embedding model the retriever actually used.
        model_used = None
        qdrant_retriever = retriever.qdrant
        if qdrant_retriever and getattr(qdrant_retriever, 'use_multi_embedding', False):
            if request.embedding_model:
                model_used = request.embedding_model
            elif getattr(qdrant_retriever, '_selected_multi_model', None):
                model_used = qdrant_retriever._selected_multi_model.value

        # Normalise every hit to a plain dict (prefer to_dict(), then
        # attribute access, then dict passthrough, then str fallback).
        person_dicts = []
        for hit in hits:
            if hasattr(hit, 'to_dict'):
                record = hit.to_dict()
            elif hasattr(hit, '__dict__'):
                record = {
                    "name": getattr(hit, 'name', 'Unknown'),
                    "headline": getattr(hit, 'headline', None),
                    "custodian_name": getattr(hit, 'custodian_name', None),
                    "custodian_slug": getattr(hit, 'custodian_slug', None),
                    "linkedin_url": getattr(hit, 'linkedin_url', None),
                    "heritage_relevant": getattr(hit, 'heritage_relevant', None),
                    "heritage_type": getattr(hit, 'heritage_type', None),
                    "location": getattr(hit, 'location', None),
                    "score": getattr(hit, 'combined_score', getattr(hit, 'vector_score', None)),
                }
            elif isinstance(hit, dict):
                record = hit
            else:
                record = {"name": str(hit)}

            person_dicts.append(record)

        elapsed_ms = (time.time() - start_time) * 1000

        # Collection stats are best-effort; never fail the request on them.
        collection_stats = None
        try:
            collection_stats = retriever.get_stats()
            # Narrow down to just the persons collection when present.
            if collection_stats and 'persons' in collection_stats:
                collection_stats = {'persons': collection_stats['persons']}
        except Exception:
            pass

        return PersonSearchResponse(
            context=PERSON_JSONLD_CONTEXT,  # JSON-LD context for linked data
            query=request.query,
            results=person_dicts,
            result_count=len(person_dicts),
            query_time_ms=round(elapsed_ms, 2),
            collection_stats=collection_stats,
            embedding_model_used=model_used,
        )

    except Exception as e:
        logger.exception(f"Person search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
def _extract_subtask_result(task: Any, pipeline_result: Any, response: Any) -> Any:
    """Extract the relevant result portion for an atomic sub-task.

    Maps sub-task types to corresponding data from the pipeline result.
    This enables caching individual components for reuse in similar queries.

    Args:
        task: AtomicSubTask with task_type and parameters
        pipeline_result: Raw result from HeritageRAGPipeline
        response: DSPyQueryResponse object (currently unused; retained for
            interface stability with callers)

    Returns:
        Cacheable result for this sub-task, or None if not extractable
    """
    from atomic_decomposer import SubTaskType

    task_type = task.task_type
    retrieved = getattr(pipeline_result, "retrieved_results", None)
    # Most branches only cache something when there is a non-empty result list.
    has_results = isinstance(retrieved, list) and bool(retrieved)

    # Intent classification - cache the detected intent
    if task_type == SubTaskType.INTENT_CLASSIFICATION:
        return {
            "intent": task.parameters.get("intent"),
            "query_type": getattr(pipeline_result, "query_type", None),
        }

    # Type filter - cache institution type filtering results
    if task_type == SubTaskType.TYPE_FILTER and has_results:
        inst_type = task.parameters.get("institution_type")
        filtered = [r for r in retrieved if r.get("institution_type") == inst_type]
        return {
            "institution_type": inst_type,
            "count": len(filtered),
            "sample_ids": [r.get("id") for r in filtered[:10]],  # Cache IDs not full records
        }

    # Location filter - cache geographic filtering results
    if task_type == SubTaskType.LOCATION_FILTER and has_results:
        location = task.parameters.get("location")
        location_lower = location.lower() if location else ""
        # BUGFIX: a missing/empty location previously matched every record,
        # because "" is a substring of any string. Count nothing instead.
        if location_lower:
            count = sum(
                1 for r in retrieved
                if location_lower in str(r.get("city", "")).lower()
                or location_lower in str(r.get("region", "")).lower()
            )
        else:
            count = 0
        return {
            "location": location,
            "level": task.parameters.get("level"),
            "count": count,
        }

    # Aggregation - cache aggregate statistics
    if task_type == SubTaskType.AGGREGATION:
        agg_type = task.parameters.get("aggregation")
        if agg_type == "count" and retrieved:
            return {
                "aggregation": "count",
                "value": len(retrieved) if isinstance(retrieved, list) else 0,
            }

    # Identifier filter - cache identifier lookup results
    if task_type == SubTaskType.IDENTIFIER_FILTER and has_results:
        id_type = task.parameters.get("identifier_type")
        # Count entities that carry this identifier under either key form.
        has_id_count = sum(
            1 for r in retrieved
            if r.get(f"has_{id_type}") or r.get(id_type)
        )
        return {
            "identifier_type": id_type,
            "has_identifier_count": has_id_count,
        }

    # Default: don't cache if we can't extract meaningful sub-result
    return None
|
|
|
|
|
|
@app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse)
|
|
async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse:
|
|
"""DSPy RAG query endpoint with multi-turn conversation support.
|
|
|
|
Uses the HeritageRAGPipeline for conversation-aware question answering.
|
|
Follow-up questions like "Welke daarvan behoren archieven?" will be
|
|
resolved using previous conversation context.
|
|
|
|
Args:
|
|
request: Query request with question, language, and conversation context
|
|
|
|
Returns:
|
|
DSPyQueryResponse with answer, resolved question, and optional visualization
|
|
"""
|
|
import time
|
|
start_time = time.time()
|
|
|
|
# Session management for multi-turn conversations
|
|
# Get or create session state that enables follow-up question resolution
|
|
session_id = request.session_id
|
|
conversation_state = None
|
|
session_mgr = None
|
|
|
|
if SESSION_MANAGER_AVAILABLE and get_session_manager:
|
|
try:
|
|
session_mgr = await get_session_manager()
|
|
session_id, conversation_state = await session_mgr.get_or_create(request.session_id)
|
|
logger.debug(f"Session {session_id}: {len(conversation_state.turns)} previous turns")
|
|
except Exception as e:
|
|
logger.warning(f"Session manager error (continuing without session): {e}")
|
|
# Generate a new session_id even if session manager failed
|
|
import uuid
|
|
session_id = str(uuid.uuid4())
|
|
else:
|
|
# No session manager - generate session_id for tracking purposes
|
|
import uuid
|
|
session_id = request.session_id or str(uuid.uuid4())
|
|
|
|
# Resolve the provider BEFORE cache lookup to ensure consistent cache keys
|
|
# This is critical: cache GET and SET must use the same provider value
|
|
resolved_provider = (request.llm_provider or settings.llm_provider).lower()
|
|
|
|
# Check cache first (before expensive LLM configuration) unless skip_cache is True
|
|
if retriever and not request.skip_cache:
|
|
cached = await retriever.cache.get_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=resolved_provider, # Use resolved provider, not request.llm_provider
|
|
embedding_model=request.embedding_model,
|
|
context=request.context if request.context else None,
|
|
)
|
|
if cached:
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")
|
|
|
|
# Transform CachedResponse format back to DSPyQueryResponse format
|
|
# CachedResponse has: sources, visualization_type, visualization_data, context
|
|
# DSPyQueryResponse needs: sources_used, visualization, query_type, etc.
|
|
cached_context = cached.get("context") or {}
|
|
visualization = None
|
|
if cached.get("visualization_type") or cached.get("visualization_data"):
|
|
visualization = {
|
|
"type": cached.get("visualization_type"),
|
|
"data": cached.get("visualization_data"),
|
|
}
|
|
|
|
# Restore llm_response metadata (GLM 4.7 reasoning_content) from cache
|
|
llm_response_cached = cached_context.get("llm_response")
|
|
llm_response_obj = None
|
|
if llm_response_cached:
|
|
try:
|
|
llm_response_obj = LLMResponseMetadata(**llm_response_cached)
|
|
except Exception:
|
|
# Fall back to dict if LLMResponseMetadata fails
|
|
llm_response_obj = llm_response_cached # type: ignore[assignment]
|
|
|
|
# Rule 46: Build provenance for cache hit responses
|
|
cached_sources = cached.get("sources", [])
|
|
cached_template_used = cached_context.get("template_used", False)
|
|
cached_template_id = cached_context.get("template_id")
|
|
cached_llm_provider = cached_context.get("llm_provider")
|
|
cached_llm_model = cached_context.get("llm_model")
|
|
|
|
# Infer data tier - prioritize cached provenance if present
|
|
cached_provenance = cached_context.get("epistemic_provenance")
|
|
if cached_provenance:
|
|
# Use the cached provenance, but mark it as coming from cache
|
|
cache_provenance = cached_provenance.copy()
|
|
if "CACHE" not in cache_provenance.get("derivationChain", []):
|
|
cache_provenance.setdefault("derivationChain", []).insert(0, "CACHE:hit")
|
|
else:
|
|
# Build fresh provenance for older cache entries
|
|
cache_tier = DataTier.TIER_3_CROWD_SOURCED.value
|
|
if cached_template_used:
|
|
cache_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
elif any(s.lower() in ["sparql", "typedb"] for s in cached_sources):
|
|
cache_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
|
|
cache_provenance = EpistemicProvenance(
|
|
dataSource=EpistemicDataSource.CACHE_AGGREGATION,
|
|
dataTier=cache_tier,
|
|
derivationChain=["CACHE:hit"] + build_derivation_chain(
|
|
sources_used=cached_sources,
|
|
template_used=cached_template_used,
|
|
template_id=cached_template_id,
|
|
llm_provider=cached_llm_provider,
|
|
),
|
|
sourcesQueried=cached_sources,
|
|
templateUsed=cached_template_used,
|
|
templateId=cached_template_id,
|
|
llmProvider=cached_llm_provider,
|
|
llmModel=cached_llm_model,
|
|
).model_dump()
|
|
|
|
response_data = {
|
|
"question": request.question,
|
|
"answer": cached.get("answer", ""),
|
|
"sources_used": cached_sources,
|
|
"visualization": visualization,
|
|
"resolved_question": cached_context.get("resolved_question"),
|
|
"retrieved_results": cached_context.get("retrieved_results"),
|
|
"query_type": cached_context.get("query_type"),
|
|
"embedding_model_used": cached_context.get("embedding_model"),
|
|
"llm_model_used": cached_llm_model,
|
|
"query_time_ms": round(elapsed_ms, 2),
|
|
"cache_hit": True,
|
|
"llm_response": llm_response_obj, # GLM 4.7 reasoning_content from cache
|
|
# Session management - return session_id for follow-up queries
|
|
"session_id": session_id,
|
|
# Template tracking from cache
|
|
"template_used": cached_template_used,
|
|
"template_id": cached_template_id,
|
|
# Rule 46: Epistemic provenance for transparency
|
|
"epistemic_provenance": cache_provenance,
|
|
}
|
|
|
|
# Record cache hit metrics
|
|
if METRICS_AVAILABLE and record_query:
|
|
try:
|
|
record_query(
|
|
endpoint="dspy_query",
|
|
template_used=cached_context.get("template_used", False),
|
|
template_id=cached_context.get("template_id"),
|
|
cache_hit=True,
|
|
status="success",
|
|
duration_seconds=elapsed_ms / 1000,
|
|
intent=cached_context.get("query_type"),
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to record cache hit metrics: {e}")
|
|
|
|
return DSPyQueryResponse(**response_data)
|
|
|
|
# === ATOMIC SUB-TASK CACHING ===
|
|
# Full query cache miss - try atomic decomposition for partial cache hits
|
|
# Research shows 40-70% cache hit rates with atomic decomposition
|
|
decomposed_query = None
|
|
cached_subtasks: dict[str, Any] = {}
|
|
|
|
if ATOMIC_CACHE_AVAILABLE and atomic_cache_manager:
|
|
try:
|
|
decomposed_query, cached_subtasks = await atomic_cache_manager.process_query(
|
|
query=request.question,
|
|
language=request.language,
|
|
)
|
|
|
|
# Record atomic cache metrics
|
|
subtask_hits = decomposed_query.partial_cache_hits if decomposed_query else 0
|
|
subtask_misses = len(decomposed_query.sub_tasks) - subtask_hits if decomposed_query else 0
|
|
|
|
if decomposed_query.fully_cached:
|
|
# All sub-tasks are cached - can potentially skip LLM
|
|
logger.info(f"Atomic cache: fully cached ({len(decomposed_query.sub_tasks)} sub-tasks)")
|
|
if record_atomic_cache:
|
|
record_atomic_cache(
|
|
query_hit=True,
|
|
subtask_hits=subtask_hits,
|
|
subtask_misses=0,
|
|
fully_assembled=True,
|
|
)
|
|
elif decomposed_query.partial_cache_hits > 0:
|
|
# Partial cache hit - some sub-tasks cached
|
|
logger.info(
|
|
f"Atomic cache: partial hit ({decomposed_query.partial_cache_hits}/"
|
|
f"{len(decomposed_query.sub_tasks)} sub-tasks cached)"
|
|
)
|
|
if record_atomic_cache:
|
|
record_atomic_cache(
|
|
query_hit=False,
|
|
subtask_hits=subtask_hits,
|
|
subtask_misses=subtask_misses,
|
|
fully_assembled=False,
|
|
)
|
|
else:
|
|
logger.debug(f"Atomic cache: miss (0/{len(decomposed_query.sub_tasks)} sub-tasks)")
|
|
if record_atomic_cache:
|
|
record_atomic_cache(
|
|
query_hit=False,
|
|
subtask_hits=0,
|
|
subtask_misses=subtask_misses,
|
|
fully_assembled=False,
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Atomic decomposition failed: {e}")
|
|
|
|
# ==========================================================================
|
|
# FACTUAL QUERY FAST PATH: Skip LLM for count/list queries
|
|
# ==========================================================================
|
|
# For factual queries (counts, lists, comparisons), the SPARQL results ARE
|
|
# the answer. No need for expensive LLM prose generation - just return the
|
|
# table directly. This can reduce latency from ~15s to ~2s.
|
|
# ==========================================================================
|
|
try:
|
|
from template_sparql import get_template_pipeline
|
|
|
|
template_pipeline = get_template_pipeline()
|
|
|
|
# Try template matching (this handles follow-up resolution internally)
|
|
# Note: conversation_state already contains history from request.context
|
|
# Run in thread pool to avoid blocking the event loop (DSPy is synchronous)
|
|
template_result = await asyncio.to_thread(
|
|
template_pipeline,
|
|
question=request.question,
|
|
language=request.language,
|
|
conversation_state=conversation_state,
|
|
)
|
|
|
|
# Check if this is a factual query that can skip LLM (template-driven, not hardcoded)
|
|
# Fast path rule: If "prose" is NOT in response_modes, LLM generation is skipped
|
|
if template_result.matched and not template_result.requires_llm():
|
|
# Log database routing decision
|
|
databases_used = template_result.databases if hasattr(template_result, 'databases') else ["oxigraph", "qdrant"]
|
|
qdrant_skipped = "qdrant" not in databases_used
|
|
logger.info(
|
|
f"[FAST-PATH] Template '{template_result.template_id}' uses response_modes={template_result.response_modes}, "
|
|
f"databases={databases_used} - skipping LLM generation{', Qdrant skipped' if qdrant_skipped else ''} "
|
|
f"(confidence={template_result.confidence:.2f})"
|
|
)
|
|
|
|
# Execute SPARQL directly
|
|
sparql_query = template_result.sparql
|
|
sparql_results: list[dict[str, Any]] = []
|
|
sparql_error: str | None = None
|
|
|
|
try:
|
|
if retriever:
|
|
client = await retriever._get_sparql_client()
|
|
response = await client.post(
|
|
settings.sparql_endpoint,
|
|
data={"query": sparql_query},
|
|
headers={"Accept": "application/sparql-results+json"},
|
|
timeout=30.0,
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
bindings = data.get("results", {}).get("bindings", [])
|
|
raw_results = [
|
|
{k: v.get("value") for k, v in binding.items()}
|
|
for binding in bindings
|
|
]
|
|
|
|
# Check if this is a COUNT query (raw_results has 'count' key)
|
|
# COUNT queries return [{"count": "10"}] - don't transform these
|
|
is_count_query = raw_results and "count" in raw_results[0]
|
|
|
|
if is_count_query:
|
|
# For COUNT queries, preserve raw results with count value
|
|
# Convert count string to int for template rendering
|
|
sparql_results = []
|
|
for row in raw_results:
|
|
count_val = row.get("count", "0")
|
|
try:
|
|
count_int = int(count_val)
|
|
except (ValueError, TypeError):
|
|
count_int = 0
|
|
sparql_results.append({
|
|
"count": count_int,
|
|
"metadata": {
|
|
"institution_type": template_result.slots.get("institution_type"),
|
|
},
|
|
"scores": {"combined": 1.0},
|
|
})
|
|
logger.debug(f"[FACTUAL-QUERY] COUNT query result: {sparql_results[0].get('count') if sparql_results else 0}")
|
|
|
|
# Execute companion query if available to get entity results for map/list
|
|
# This fetches the actual institution records that were counted
|
|
companion_query = getattr(template_result, 'companion_query', None)
|
|
if companion_query:
|
|
try:
|
|
companion_response = await client.post(
|
|
settings.sparql_endpoint,
|
|
data={"query": companion_query},
|
|
headers={"Accept": "application/sparql-results+json"},
|
|
timeout=30.0,
|
|
)
|
|
|
|
if companion_response.status_code == 200:
|
|
companion_data = companion_response.json()
|
|
companion_bindings = companion_data.get("results", {}).get("bindings", [])
|
|
companion_raw = [
|
|
{k: v.get("value") for k, v in binding.items()}
|
|
for binding in companion_bindings
|
|
]
|
|
|
|
# Transform companion results to frontend format
|
|
companion_results = []
|
|
for row in companion_raw:
|
|
lat = None
|
|
lon = None
|
|
if row.get("lat"):
|
|
try:
|
|
lat = float(row["lat"])
|
|
except (ValueError, TypeError):
|
|
pass
|
|
if row.get("lon"):
|
|
try:
|
|
lon = float(row["lon"])
|
|
except (ValueError, TypeError):
|
|
pass
|
|
|
|
companion_results.append({
|
|
"name": row.get("name"),
|
|
"institution_uri": row.get("institution"),
|
|
"metadata": {
|
|
"latitude": lat,
|
|
"longitude": lon,
|
|
"city": row.get("city") or template_result.slots.get("city"),
|
|
"institution_type": template_result.slots.get("institution_type"),
|
|
},
|
|
"scores": {"combined": 1.0},
|
|
})
|
|
|
|
# Store companion results - these will be used for map/list display
|
|
# while sparql_results contains the count for the answer text
|
|
if companion_results:
|
|
logger.info(f"[COMPANION-QUERY] Fetched {len(companion_results)} entities for display, {sum(1 for r in companion_results if r['metadata'].get('latitude'))} with coordinates")
|
|
# Replace sparql_results with companion results for display
|
|
# but preserve the count value for answer rendering
|
|
count_value = sparql_results[0].get("count", 0) if sparql_results else 0
|
|
sparql_results = companion_results
|
|
# Add count to first result so it's available for ui_template
|
|
if sparql_results:
|
|
sparql_results[0]["count"] = count_value
|
|
else:
|
|
logger.warning(f"[COMPANION-QUERY] Failed with status {companion_response.status_code}")
|
|
except Exception as ce:
|
|
logger.warning(f"[COMPANION-QUERY] Execution failed: {ce}")
|
|
else:
|
|
# Transform SPARQL results to match frontend expected format
|
|
# Frontend expects: {name, website, metadata: {latitude, longitude, city, ...}}
|
|
# SPARQL returns: {name, website, lat, lon, city, ...}
|
|
sparql_results = []
|
|
for row in raw_results:
|
|
# Parse lat/lon to float if present
|
|
lat = None
|
|
lon = None
|
|
if row.get("lat"):
|
|
try:
|
|
lat = float(row["lat"])
|
|
except (ValueError, TypeError):
|
|
pass
|
|
if row.get("lon"):
|
|
try:
|
|
lon = float(row["lon"])
|
|
except (ValueError, TypeError):
|
|
pass
|
|
|
|
transformed = {
|
|
"name": row.get("name"),
|
|
"website": row.get("website"),
|
|
"metadata": {
|
|
"latitude": lat,
|
|
"longitude": lon,
|
|
"city": row.get("city") or template_result.slots.get("city"),
|
|
"country": row.get("country") or template_result.slots.get("country"),
|
|
"region": row.get("region") or template_result.slots.get("region"),
|
|
"institution_type": row.get("type") or template_result.slots.get("institution_type"),
|
|
},
|
|
"scores": {"combined": 1.0}, # SPARQL results are exact matches
|
|
}
|
|
sparql_results.append(transformed)
|
|
|
|
logger.debug(f"[FACTUAL-QUERY] Transformed {len(sparql_results)} results, {sum(1 for r in sparql_results if r['metadata']['latitude'])} with coordinates")
|
|
else:
|
|
sparql_error = f"SPARQL returned {response.status_code}"
|
|
else:
|
|
sparql_error = "Retriever not available"
|
|
|
|
except Exception as e:
|
|
sparql_error = str(e)
|
|
logger.warning(f"[FACTUAL-QUERY] SPARQL execution failed: {e}")
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
# Generate answer using ui_template if available, otherwise fallback
|
|
if sparql_error:
|
|
answer = f"Er is een fout opgetreden bij het uitvoeren van de query: {sparql_error}"
|
|
elif not sparql_results:
|
|
answer = "Geen resultaten gevonden."
|
|
elif template_result.ui_template:
|
|
# Use template-defined UI template (template-driven answer formatting)
|
|
lang = request.language if request.language in template_result.ui_template else "nl"
|
|
ui_tmpl = template_result.ui_template.get(lang, template_result.ui_template.get("nl", ""))
|
|
|
|
# Build context for Jinja2 template rendering with human-readable labels
|
|
# The slots have resolved codes (M, NL-NH) but ui_template expects labels (musea, Noord-Holland)
|
|
template_context = {
|
|
"result_count": len(sparql_results),
|
|
"count": sparql_results[0].get("count", len(sparql_results)) if sparql_results else 0,
|
|
**template_result.slots # Include resolved slot values (codes)
|
|
}
|
|
|
|
# Add human-readable labels for common slot types
|
|
# Labels loaded from schema/reference files per Rule 41 (no hardcoding)
|
|
try:
|
|
from schema_labels import get_label_resolver
|
|
label_resolver = get_label_resolver()
|
|
INSTITUTION_TYPE_LABELS_NL = label_resolver.get_all_institution_type_labels("nl")
|
|
INSTITUTION_TYPE_LABELS_EN = label_resolver.get_all_institution_type_labels("en")
|
|
SUBREGION_LABELS = label_resolver.get_all_subregion_labels("nl")
|
|
except ImportError:
|
|
# Fallback if schema_labels module not available (shouldn't happen in prod)
|
|
logger.warning("schema_labels module not available, using inline fallback")
|
|
INSTITUTION_TYPE_LABELS_NL = {
|
|
"M": "musea", "L": "bibliotheken", "A": "archieven", "G": "galerijen",
|
|
"O": "overheidsinstellingen", "R": "onderzoekscentra", "C": "bedrijfsarchieven",
|
|
"U": "instellingen", "B": "botanische tuinen en dierentuinen",
|
|
"E": "onderwijsinstellingen", "S": "heemkundige kringen", "F": "monumenten",
|
|
"I": "immaterieel erfgoedgroepen", "X": "gecombineerde instellingen",
|
|
"P": "privéverzamelingen", "H": "religieuze erfgoedsites",
|
|
"D": "digitale platforms", "N": "erfgoedorganisaties", "T": "culinair erfgoed"
|
|
}
|
|
INSTITUTION_TYPE_LABELS_EN = {
|
|
"M": "museums", "L": "libraries", "A": "archives", "G": "galleries",
|
|
"O": "official institutions", "R": "research centers", "C": "corporate archives",
|
|
"U": "institutions", "B": "botanical gardens and zoos",
|
|
"E": "education providers", "S": "heritage societies", "F": "features",
|
|
"I": "intangible heritage groups", "X": "mixed institutions",
|
|
"P": "personal collections", "H": "holy sites",
|
|
"D": "digital platforms", "N": "heritage NGOs", "T": "taste/smell heritage"
|
|
}
|
|
SUBREGION_LABELS = {
|
|
"NL-DR": "Drenthe", "NL-FR": "Friesland", "NL-GE": "Gelderland",
|
|
"NL-GR": "Groningen", "NL-LI": "Limburg", "NL-NB": "Noord-Brabant",
|
|
"NL-NH": "Noord-Holland", "NL-OV": "Overijssel", "NL-UT": "Utrecht",
|
|
"NL-ZE": "Zeeland", "NL-ZH": "Zuid-Holland", "NL-FL": "Flevoland"
|
|
}
|
|
|
|
# Add institution_type_nl and institution_type_en labels
|
|
if "institution_type" in template_result.slots:
|
|
type_code = template_result.slots["institution_type"]
|
|
template_context["institution_type_nl"] = INSTITUTION_TYPE_LABELS_NL.get(type_code, type_code)
|
|
template_context["institution_type_en"] = INSTITUTION_TYPE_LABELS_EN.get(type_code, type_code)
|
|
|
|
# Add human-readable location label
|
|
if "location" in template_result.slots:
|
|
loc_code = template_result.slots["location"]
|
|
# Check if it's a subregion code
|
|
if loc_code in SUBREGION_LABELS:
|
|
template_context["location"] = SUBREGION_LABELS[loc_code]
|
|
# Otherwise keep the original (might already be a city name)
|
|
|
|
# Simple Jinja2-style replacement (avoids importing Jinja2)
|
|
answer = ui_tmpl
|
|
for key, value in template_context.items():
|
|
answer = answer.replace("{{ " + key + " }}", str(value))
|
|
answer = answer.replace("{{" + key + "}}", str(value))
|
|
elif "count" in template_result.response_modes:
|
|
# Count query - format as count
|
|
count_value = sparql_results[0].get("count", len(sparql_results))
|
|
answer = f"Aantal: {count_value}"
|
|
else:
|
|
# List/table query - just indicate result count
|
|
answer = f"Gevonden: {len(sparql_results)} resultaten. Zie de tabel hieronder."
|
|
|
|
# Determine visualization type from response_modes
|
|
viz_types = []
|
|
if "table" in template_result.response_modes:
|
|
viz_types.append("table")
|
|
if "chart" in template_result.response_modes:
|
|
viz_types.append("chart")
|
|
if "map" in template_result.response_modes:
|
|
viz_types.append("map")
|
|
|
|
# Build response with factual_result=True
|
|
factual_response = DSPyQueryResponse(
|
|
question=request.question,
|
|
resolved_question=getattr(template_result, "resolved_question", None),
|
|
answer=answer,
|
|
sources_used=["SPARQL Knowledge Graph"],
|
|
visualization={
|
|
"types": viz_types,
|
|
"primary_type": viz_types[0] if viz_types else "table",
|
|
"sparql_query": sparql_query,
|
|
"response_modes": template_result.response_modes,
|
|
"databases_used": databases_used, # For transparency/debugging
|
|
},
|
|
retrieved_results=sparql_results,
|
|
query_type="factual",
|
|
query_time_ms=round(elapsed_ms, 2),
|
|
conversation_turn=len(request.context),
|
|
cache_hit=False,
|
|
session_id=session_id,
|
|
template_used=True,
|
|
template_id=template_result.template_id,
|
|
factual_result=True,
|
|
sparql_query=sparql_query,
|
|
)
|
|
|
|
# Update session with this turn
|
|
if session_mgr and session_id:
|
|
try:
|
|
await session_mgr.add_turn_to_session(
|
|
session_id=session_id,
|
|
question=request.question,
|
|
answer=answer,
|
|
resolved_question=getattr(template_result, "resolved_question", None),
|
|
template_id=template_result.template_id,
|
|
slots=template_result.slots or {},
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to update session: {e}")
|
|
|
|
# Record metrics
|
|
if METRICS_AVAILABLE and record_query:
|
|
try:
|
|
record_query(
|
|
endpoint="dspy_query",
|
|
template_used=True,
|
|
template_id=template_result.template_id,
|
|
cache_hit=False,
|
|
status="success",
|
|
duration_seconds=elapsed_ms / 1000,
|
|
intent="factual",
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to record metrics: {e}")
|
|
|
|
# Cache the response
|
|
if retriever:
|
|
await retriever.cache.set_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider="none", # No LLM used
|
|
embedding_model=request.embedding_model,
|
|
response=factual_response.model_dump(),
|
|
context=request.context if request.context else None,
|
|
)
|
|
|
|
logger.info(f"[FACTUAL-QUERY] Returned {len(sparql_results)} results in {elapsed_ms:.2f}ms (LLM skipped)")
|
|
return factual_response
|
|
|
|
except ImportError as e:
|
|
logger.debug(f"Template SPARQL not available for factual query detection: {e}")
|
|
except Exception as e:
|
|
logger.warning(f"Factual query detection failed (continuing with full pipeline): {e}")
|
|
|
|
# ==========================================================================
|
|
# FULL RAG PIPELINE: For non-factual queries or when factual detection fails
|
|
# ==========================================================================
|
|
try:
|
|
# Import DSPy pipeline and History
|
|
import dspy
|
|
from dspy import History
|
|
from dspy_heritage_rag import HeritageRAGPipeline
|
|
|
|
# Configure DSPy LM per-request based on request.llm_provider (or server default)
|
|
# This allows frontend to switch LLM providers dynamically
|
|
#
|
|
# IMPORTANT: We use dspy.settings.context() instead of dspy.configure() because
|
|
# configure() can only be called from the same async task that initially configured DSPy.
|
|
# context() provides thread-local overrides that work correctly in async request handlers.
|
|
requested_provider = resolved_provider # Already resolved above
|
|
llm_provider_used: str | None = None
|
|
llm_model_used: str | None = None
|
|
lm = None
|
|
|
|
logger.info(f"LLM provider requested: {requested_provider} (request.llm_provider={request.llm_provider}, server default={settings.llm_provider})")
|
|
|
|
# Provider configuration priority: requested provider first, then fallback chain
|
|
providers_to_try = [requested_provider]
|
|
# Add fallback chain (but not duplicates)
|
|
for fallback in ["zai", "groq", "anthropic", "openai"]:
|
|
if fallback not in providers_to_try:
|
|
providers_to_try.append(fallback)
|
|
|
|
for provider in providers_to_try:
|
|
if lm is not None:
|
|
break
|
|
|
|
# Default models per provider (used if request.llm_model is not specified)
|
|
# Use LLM_MODEL from settings when it matches the provider prefix
|
|
default_models = {
|
|
"zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash",
|
|
"groq": "llama-3.1-8b-instant",
|
|
"anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514",
|
|
"openai": "gpt-4o-mini",
|
|
# Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference
|
|
# Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2
|
|
"huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct",
|
|
}
|
|
# HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct)
|
|
# Groq models use simple names (e.g., llama-3.1-8b-instant)
|
|
model_prefixes = {
|
|
"glm-": "zai",
|
|
"llama-3.1-": "groq",
|
|
"llama-3.3-": "groq",
|
|
"claude-": "anthropic",
|
|
"gpt-": "openai",
|
|
# HuggingFace organization prefixes
|
|
"mistralai/": "huggingface",
|
|
"google/": "huggingface",
|
|
"Qwen/": "huggingface",
|
|
"deepseek-ai/": "huggingface",
|
|
"meta-llama/": "huggingface",
|
|
"utter-project/": "huggingface",
|
|
"microsoft/": "huggingface",
|
|
"tiiuae/": "huggingface",
|
|
}
|
|
|
|
# Determine which model to use: requested model (if valid for this provider) or default
|
|
requested_model = request.llm_model
|
|
model_to_use = default_models.get(provider, "")
|
|
|
|
# Check if requested model matches this provider
|
|
if requested_model:
|
|
for prefix, model_provider in model_prefixes.items():
|
|
if requested_model.startswith(prefix) and model_provider == provider:
|
|
model_to_use = requested_model
|
|
break
|
|
|
|
if provider == "zai" and settings.zai_api_token:
|
|
try:
|
|
lm = dspy.LM(
|
|
f"openai/{model_to_use}",
|
|
api_key=settings.zai_api_token,
|
|
api_base="https://api.z.ai/api/coding/paas/v4",
|
|
)
|
|
llm_provider_used = "zai"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Z.AI {model_to_use} (FREE) for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Z.AI LM: {e}")
|
|
|
|
elif provider == "groq" and settings.groq_api_key:
|
|
try:
|
|
lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
|
|
llm_provider_used = "groq"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Groq {model_to_use} (FREE) for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Groq LM: {e}")
|
|
|
|
elif provider == "huggingface" and settings.huggingface_api_key:
|
|
try:
|
|
lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
|
|
llm_provider_used = "huggingface"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using HuggingFace {model_to_use} for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create HuggingFace LM: {e}")
|
|
|
|
elif provider == "anthropic" and settings.anthropic_api_key:
|
|
try:
|
|
lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
|
|
llm_provider_used = "anthropic"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Anthropic {model_to_use} for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Anthropic LM: {e}")
|
|
|
|
elif provider == "openai" and settings.openai_api_key:
|
|
try:
|
|
lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
|
|
llm_provider_used = "openai"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using OpenAI {model_to_use} for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create OpenAI LM: {e}")
|
|
|
|
# No LM could be configured
|
|
if lm is None:
|
|
raise ValueError(
|
|
f"No LLM could be configured. Requested provider: {requested_provider}. "
|
|
"Ensure the appropriate API key is set: ZAI_API_TOKEN, GROQ_API_KEY, ANTHROPIC_API_KEY, HUGGINGFACE_API_KEY, or OPENAI_API_KEY."
|
|
)
|
|
|
|
logger.info(f"LLM provider for this request: {llm_provider_used}")
|
|
|
|
# =================================================================
|
|
# PERFORMANCE OPTIMIZATION: Create fast LM for routing/extraction
|
|
# Use a fast, cheap model (glm-4.5-flash FREE, gpt-4o-mini $0.15/1M)
|
|
# for routing, entity extraction, and SPARQL generation.
|
|
# The quality_lm (lm) is used only for final answer generation.
|
|
# This can reduce total latency by 2-3x (from ~20s to ~7s).
|
|
# =================================================================
|
|
fast_lm = None
|
|
|
|
# Try to create fast_lm based on FAST_LM_PROVIDER setting
|
|
# Options: "openai" (fast ~1-2s, $0.15/1M) or "zai" (FREE but slow ~13s)
|
|
# Default: openai for speed. Override with FAST_LM_PROVIDER=zai to save costs.
|
|
|
|
if settings.fast_lm_provider == "openai" and settings.openai_api_key:
|
|
try:
|
|
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
|
|
logger.info("Using OpenAI GPT-4o-mini as fast_lm for routing/extraction (~1-2s)")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create fast OpenAI LM: {e}")
|
|
|
|
if fast_lm is None and settings.fast_lm_provider == "zai" and settings.zai_api_token:
|
|
try:
|
|
fast_lm = dspy.LM(
|
|
"openai/glm-4.5-flash",
|
|
api_key=settings.zai_api_token,
|
|
api_base="https://api.z.ai/api/coding/paas/v4",
|
|
)
|
|
logger.info("Using Z.AI GLM-4.5-flash (FREE) as fast_lm for routing/extraction (~13s)")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create fast Z.AI LM: {e}")
|
|
|
|
# Fallback: try the other provider if preferred one failed
|
|
if fast_lm is None and settings.openai_api_key:
|
|
try:
|
|
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
|
|
logger.info("Fallback: Using OpenAI GPT-4o-mini as fast_lm")
|
|
except Exception as e:
|
|
logger.warning(f"Fallback failed - no fast_lm available: {e}")
|
|
|
|
if fast_lm is None:
|
|
logger.info("No fast_lm available - all stages will use quality_lm (slower but works)")
|
|
|
|
# Convert context to DSPy History format
|
|
# Context comes as [{question: "...", answer: "..."}, ...]
|
|
# History expects messages in the same format: [{question: "...", answer: "..."}, ...]
|
|
# (NOT role/content format - that was a bug!)
|
|
history_messages = []
|
|
for turn in request.context:
|
|
# Only include turns that have both question AND answer
|
|
if turn.get("question") and turn.get("answer"):
|
|
history_messages.append({
|
|
"question": turn["question"],
|
|
"answer": turn["answer"]
|
|
})
|
|
|
|
history = History(messages=history_messages) if history_messages else None
|
|
|
|
# Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
|
|
# Falls back to creating a new pipeline if global not available
|
|
if dspy_pipeline is not None:
|
|
pipeline = dspy_pipeline
|
|
logger.debug("Using global optimized DSPy pipeline")
|
|
else:
|
|
# Fallback: create pipeline without optimized weights
|
|
qdrant_retriever = retriever.qdrant if retriever else None
|
|
pipeline = HeritageRAGPipeline(
|
|
retriever=qdrant_retriever,
|
|
fast_lm=fast_lm,
|
|
quality_lm=lm,
|
|
)
|
|
logger.debug("Using fallback (unoptimized) DSPy pipeline")
|
|
|
|
# Execute query with conversation history
|
|
# Retry logic for transient API errors (e.g., Anthropic "Overloaded" errors)
|
|
#
|
|
# IMPORTANT: We use dspy.settings.context(lm=lm) to set the LLM for this request.
|
|
# This provides thread-local overrides that work correctly in async request handlers,
|
|
# unlike dspy.configure() which can only be called from the main async task.
|
|
max_retries = 3
|
|
last_error: Exception | None = None
|
|
result = None
|
|
|
|
# Helper function to run pipeline synchronously (for asyncio.to_thread)
|
|
def run_pipeline_sync():
|
|
"""Run DSPy pipeline in sync context with retry logic."""
|
|
nonlocal last_error, result
|
|
with dspy.settings.context(lm=lm):
|
|
for attempt in range(max_retries):
|
|
try:
|
|
# Use pipeline() instead of pipeline.forward() per DSPy 3.0 best practices
|
|
return pipeline(
|
|
embedding_model=request.embedding_model,
|
|
question=request.question,
|
|
language=request.language,
|
|
history=history,
|
|
include_viz=request.include_visualization,
|
|
conversation_state=conversation_state, # Pass session state for template SPARQL
|
|
)
|
|
except Exception as e:
|
|
last_error = e
|
|
error_str = str(e).lower()
|
|
# Check for retryable errors (API overload, rate limits, temporary failures)
|
|
is_retryable = any(keyword in error_str for keyword in [
|
|
"overloaded", "rate_limit", "rate limit", "too many requests",
|
|
"529", "503", "502", "504", # HTTP status codes
|
|
"temporarily unavailable", "service unavailable",
|
|
"connection reset", "connection refused", "timeout"
|
|
])
|
|
|
|
if is_retryable and attempt < max_retries - 1:
|
|
wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
|
|
logger.warning(
|
|
f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}. "
|
|
f"Retrying in {wait_time}s..."
|
|
)
|
|
time.sleep(wait_time) # OK to block in thread pool
|
|
continue
|
|
else:
|
|
# Non-retryable error or max retries reached
|
|
raise
|
|
return None
|
|
|
|
# Run DSPy pipeline in thread pool to avoid blocking the event loop
|
|
result = await asyncio.to_thread(run_pipeline_sync)
|
|
|
|
# If we get here without a result (all retries exhausted), raise the last error
|
|
if result is None:
|
|
if last_error:
|
|
raise last_error
|
|
raise HTTPException(status_code=500, detail="Pipeline execution failed with no result")
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
# Extract retrieved results for frontend visualization (tables, graphs)
|
|
retrieved_results = getattr(result, "retrieved_results", None)
|
|
query_type = getattr(result, "query_type", None)
|
|
|
|
# Extract visualization if present
|
|
visualization = None
|
|
if request.include_visualization and hasattr(result, "visualization"):
|
|
viz = result.visualization
|
|
if viz:
|
|
# Now showing SPARQL for all query types including person queries
|
|
# Person queries use HeritagePersonSPARQLGenerator (schema:Person predicates)
|
|
# Institution queries use HeritageSPARQLGenerator (crm:E39_Actor predicates)
|
|
sparql_to_show = getattr(result, "sparql", None)
|
|
|
|
visualization = {
|
|
"type": getattr(viz, "viz_type", "table"),
|
|
"sparql_query": sparql_to_show,
|
|
}
|
|
|
|
# Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support)
|
|
llm_response_metadata = extract_llm_response_metadata(
|
|
lm=lm,
|
|
provider=llm_provider_used,
|
|
latency_ms=int(elapsed_ms),
|
|
)
|
|
|
|
# Extract template SPARQL info from result
|
|
template_used = getattr(result, "template_used", False)
|
|
template_id = getattr(result, "template_id", None)
|
|
|
|
# Rule 46: Build epistemic provenance for transparency
|
|
# This tracks WHERE, WHEN, and HOW the response data originated
|
|
sources_used_list = getattr(result, "sources_used", [])
|
|
|
|
# Infer data tier from sources - SPARQL/TypeDB are authoritative, Qdrant may include scraped data
|
|
inferred_tier = DataTier.TIER_3_CROWD_SOURCED.value # Default
|
|
if template_used:
|
|
# Template-based SPARQL uses curated Oxigraph data
|
|
inferred_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
elif any(s.lower() in ["sparql", "typedb"] for s in sources_used_list):
|
|
inferred_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
elif any(s.lower() == "qdrant" for s in sources_used_list):
|
|
inferred_tier = DataTier.TIER_3_CROWD_SOURCED.value
|
|
|
|
# Build provenance object
|
|
response_provenance = EpistemicProvenance(
|
|
dataSource=EpistemicDataSource.RAG_PIPELINE,
|
|
dataTier=inferred_tier,
|
|
derivationChain=build_derivation_chain(
|
|
sources_used=sources_used_list,
|
|
template_used=template_used,
|
|
template_id=template_id,
|
|
llm_provider=llm_provider_used,
|
|
),
|
|
sourcesQueried=sources_used_list,
|
|
totalRetrieved=len(retrieved_results) if retrieved_results else 0,
|
|
totalAfterFusion=len(retrieved_results) if retrieved_results else 0,
|
|
templateUsed=template_used,
|
|
templateId=template_id,
|
|
llmProvider=llm_provider_used,
|
|
llmModel=llm_model_used,
|
|
)
|
|
|
|
# Build response object
|
|
response = DSPyQueryResponse(
|
|
question=request.question,
|
|
resolved_question=getattr(result, "resolved_question", None),
|
|
answer=getattr(result, "answer", "Geen antwoord gevonden."),
|
|
sources_used=sources_used_list,
|
|
visualization=visualization,
|
|
retrieved_results=retrieved_results, # Raw data for frontend visualization
|
|
query_type=query_type, # "person" or "institution"
|
|
query_time_ms=round(elapsed_ms, 2),
|
|
conversation_turn=len(request.context),
|
|
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
|
|
# Cost tracking fields
|
|
timing_ms=getattr(result, "timing_ms", None),
|
|
cost_usd=getattr(result, "cost_usd", None),
|
|
timing_breakdown=getattr(result, "timing_breakdown", None),
|
|
# LLM provider tracking
|
|
llm_provider_used=llm_provider_used,
|
|
llm_model_used=llm_model_used,
|
|
cache_hit=False,
|
|
# LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought)
|
|
llm_response=llm_response_metadata,
|
|
# Session management - return session_id for follow-up queries
|
|
session_id=session_id,
|
|
# Template SPARQL tracking
|
|
template_used=template_used,
|
|
template_id=template_id,
|
|
# Rule 46: Epistemic provenance for transparency
|
|
epistemic_provenance=response_provenance.model_dump(),
|
|
)
|
|
|
|
# Update session with this turn for multi-turn conversation support
|
|
if session_mgr and session_id:
|
|
try:
|
|
await session_mgr.add_turn_to_session(
|
|
session_id=session_id,
|
|
question=request.question,
|
|
answer=response.answer,
|
|
resolved_question=response.resolved_question,
|
|
template_id=template_id,
|
|
slots=getattr(result, "slots", {}), # Extracted slots for follow-up inheritance
|
|
)
|
|
logger.debug(f"Session {session_id} updated with new turn")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to update session {session_id}: {e}")
|
|
|
|
# Record Prometheus metrics for monitoring
|
|
if METRICS_AVAILABLE and record_query:
|
|
try:
|
|
record_query(
|
|
endpoint="dspy_query",
|
|
template_used=template_used,
|
|
template_id=template_id,
|
|
cache_hit=False,
|
|
status="success",
|
|
duration_seconds=elapsed_ms / 1000,
|
|
intent=query_type,
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to record metrics: {e}")
|
|
|
|
# Cache the successful response for future requests
|
|
if retriever:
|
|
await retriever.cache.set_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=llm_provider_used, # Use actual provider, not requested
|
|
embedding_model=request.embedding_model,
|
|
response=response.model_dump(),
|
|
context=request.context if request.context else None,
|
|
)
|
|
|
|
# === CACHE ATOMIC SUB-TASKS FOR FUTURE QUERIES ===
|
|
# Cache individual sub-tasks for higher hit rates on similar queries
|
|
# E.g., "musea in Amsterdam" sub-task can be reused for
|
|
# "Hoeveel musea in Amsterdam hebben een website?"
|
|
if ATOMIC_CACHE_AVAILABLE and atomic_cache_manager and decomposed_query:
|
|
try:
|
|
subtasks_cached = 0
|
|
for task in decomposed_query.sub_tasks:
|
|
if not task.cache_hit:
|
|
# Extract relevant result for this sub-task type
|
|
subtask_result = _extract_subtask_result(task, result, response)
|
|
if subtask_result is not None:
|
|
await atomic_cache_manager.cache_subtask_result(
|
|
task=task,
|
|
result=subtask_result,
|
|
language=request.language,
|
|
ttl=3600, # 1 hour TTL
|
|
)
|
|
subtasks_cached += 1
|
|
|
|
# Record subtasks cached metric
|
|
if subtasks_cached > 0 and record_atomic_subtask_cached:
|
|
record_atomic_subtask_cached(subtasks_cached)
|
|
|
|
# Log atomic cache stats periodically
|
|
stats = atomic_cache_manager.get_stats()
|
|
if stats["queries_decomposed"] % 10 == 0:
|
|
logger.info(
|
|
f"Atomic cache stats: {stats['subtask_hit_rate']}% hit rate, "
|
|
f"{stats['queries_decomposed']} queries, "
|
|
f"{stats['full_query_reassemblies']} fully cached"
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to cache atomic sub-tasks: {e}")
|
|
|
|
return response
|
|
|
|
except ImportError as e:
|
|
logger.warning(f"DSPy pipeline not available: {e}")
|
|
# Fallback to simple response
|
|
return DSPyQueryResponse(
|
|
question=request.question,
|
|
answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.",
|
|
query_time_ms=0,
|
|
conversation_turn=len(request.context),
|
|
embedding_model_used=request.embedding_model,
|
|
session_id=session_id, # Still return session_id even on error
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.exception("DSPy query failed")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
async def stream_dspy_query_response(
|
|
request: DSPyQueryRequest,
|
|
) -> AsyncIterator[str]:
|
|
"""Stream DSPy query response with progress updates for long-running queries.
|
|
|
|
Yields NDJSON lines with status updates at each pipeline stage:
|
|
- {"type": "status", "stage": "cache", "message": "🔍 Cache controleren..."}
|
|
- {"type": "status", "stage": "config", "message": "⚙️ LLM configureren..."}
|
|
- {"type": "status", "stage": "routing", "message": "🧭 Vraag analyseren..."}
|
|
- {"type": "status", "stage": "retrieval", "message": "📊 Database doorzoeken..."}
|
|
- {"type": "status", "stage": "generation", "message": "💡 Antwoord genereren..."}
|
|
- {"type": "complete", "data": {...DSPyQueryResponse...}}
|
|
"""
|
|
import time
|
|
start_time = time.time()
|
|
|
|
# Session management for multi-turn conversations
|
|
# Get or create session state that enables follow-up question resolution
|
|
session_id = request.session_id
|
|
conversation_state = None
|
|
session_mgr = None
|
|
|
|
if SESSION_MANAGER_AVAILABLE and get_session_manager:
|
|
try:
|
|
session_mgr = await get_session_manager()
|
|
session_id, conversation_state = await session_mgr.get_or_create(request.session_id)
|
|
logger.debug(f"Stream session {session_id}: {len(conversation_state.turns)} previous turns")
|
|
except Exception as e:
|
|
logger.warning(f"Session manager error (continuing without session): {e}")
|
|
import uuid
|
|
session_id = str(uuid.uuid4())
|
|
else:
|
|
import uuid
|
|
session_id = request.session_id or str(uuid.uuid4())
|
|
|
|
def emit_status(stage: str, message: str) -> str:
    """Serialize one pipeline progress update as an NDJSON line.

    The elapsed time is measured against the request's ``start_time`` from
    the enclosing scope so the frontend can show per-stage latency.
    """
    payload = {
        "type": "status",
        "stage": stage,
        "message": message,
        "elapsed_ms": round((time.time() - start_time) * 1000, 2),
    }
    return json.dumps(payload) + "\n"
|
|
|
|
def emit_error(error: str, details: str | None = None) -> str:
    """Serialize one error event as an NDJSON line.

    ``details`` carries optional technical context (e.g. the raw exception
    text) alongside the user-facing ``error`` message.
    """
    payload = {
        "type": "error",
        "error": error,
        "details": details,
        "elapsed_ms": round((time.time() - start_time) * 1000, 2),
    }
    return json.dumps(payload) + "\n"
|
|
|
|
def extract_user_friendly_error(exception: Exception) -> tuple[str, str | None]:
|
|
"""Extract a user-friendly error message from various exception types.
|
|
|
|
Returns:
|
|
tuple: (user_message, technical_details)
|
|
"""
|
|
error_str = str(exception)
|
|
error_lower = error_str.lower()
|
|
|
|
# HuggingFace / LiteLLM specific errors
|
|
if "huggingface" in error_lower or "hf" in error_lower:
|
|
if "model_not_supported" in error_lower or "not a chat model" in error_lower:
|
|
# Extract model name if present
|
|
import re
|
|
model_match = re.search(r"model['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", error_str)
|
|
model_name = model_match.group(1) if model_match else "geselecteerde model"
|
|
return (
|
|
f"Het model '{model_name}' wordt niet ondersteund door HuggingFace. Kies een ander model.",
|
|
error_str
|
|
)
|
|
if "rate limit" in error_lower or "too many requests" in error_lower:
|
|
return (
|
|
"HuggingFace API limiet bereikt. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
if "unauthorized" in error_lower or "invalid api key" in error_lower:
|
|
return (
|
|
"HuggingFace API sleutel ongeldig. Neem contact op met de beheerder.",
|
|
error_str
|
|
)
|
|
if "model is loading" in error_lower or "loading" in error_lower and "model" in error_lower:
|
|
return (
|
|
"Het HuggingFace model wordt geladen. Probeer het over 30 seconden opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Anthropic errors
|
|
if "anthropic" in error_lower:
|
|
if "rate limit" in error_lower or "overloaded" in error_lower:
|
|
return (
|
|
"Anthropic API is overbelast. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
if "invalid api key" in error_lower or "unauthorized" in error_lower:
|
|
return (
|
|
"Anthropic API sleutel ongeldig. Neem contact op met de beheerder.",
|
|
error_str
|
|
)
|
|
|
|
# OpenAI errors
|
|
if "openai" in error_lower:
|
|
if "rate limit" in error_lower:
|
|
return (
|
|
"OpenAI API limiet bereikt. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
if "invalid api key" in error_lower:
|
|
return (
|
|
"OpenAI API sleutel ongeldig. Neem contact op met de beheerder.",
|
|
error_str
|
|
)
|
|
|
|
# Z.AI errors
|
|
if "z.ai" in error_lower or "zai" in error_lower:
|
|
if "rate limit" in error_lower or "quota" in error_lower:
|
|
return (
|
|
"Z.AI API limiet bereikt. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Generic network/connection errors
|
|
if "connection" in error_lower or "timeout" in error_lower:
|
|
return (
|
|
"Verbindingsfout met de AI service. Controleer uw internetverbinding en probeer het opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
if "503" in error_str or "service unavailable" in error_lower:
|
|
return (
|
|
"De AI service is tijdelijk niet beschikbaar. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Qdrant/retrieval errors
|
|
if "qdrant" in error_lower:
|
|
return (
|
|
"Fout bij het doorzoeken van de database. Probeer het later opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Default: return the raw error but in a nicer format
|
|
return (
|
|
f"Er is een fout opgetreden: {error_str[:200]}{'...' if len(error_str) > 200 else ''}",
|
|
error_str if len(error_str) > 200 else None
|
|
)
|
|
|
|
# Resolve the provider BEFORE cache lookup to ensure consistent cache keys
|
|
# This is critical: cache GET and SET must use the same provider value
|
|
resolved_provider = (request.llm_provider or settings.llm_provider).lower()
|
|
|
|
# Stage 1: Check cache
|
|
yield emit_status("cache", "🔍 Cache controleren...")
|
|
|
|
if retriever:
|
|
cached = await retriever.cache.get_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=resolved_provider, # Use resolved provider, not request.llm_provider
|
|
embedding_model=request.embedding_model,
|
|
context=request.context if request.context else None,
|
|
)
|
|
if cached:
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")
|
|
|
|
# Transform CachedResponse format back to DSPyQueryResponse format
|
|
cached_context = cached.get("context") or {}
|
|
visualization = None
|
|
if cached.get("visualization_type") or cached.get("visualization_data"):
|
|
visualization = {
|
|
"type": cached.get("visualization_type"),
|
|
"data": cached.get("visualization_data"),
|
|
}
|
|
|
|
# Rule 46: Build provenance for streaming cache hit responses
|
|
stream_cached_sources = cached.get("sources", [])
|
|
stream_cached_template_used = cached_context.get("template_used", False)
|
|
stream_cached_template_id = cached_context.get("template_id")
|
|
stream_cached_llm_provider = cached_context.get("llm_provider")
|
|
stream_cached_llm_model = cached_context.get("llm_model")
|
|
|
|
# Infer data tier - prioritize cached provenance if present
|
|
stream_cached_prov = cached_context.get("epistemic_provenance")
|
|
if stream_cached_prov:
|
|
# Use the cached provenance, but mark it as coming from cache
|
|
stream_cache_provenance = stream_cached_prov.copy()
|
|
if "CACHE" not in stream_cache_provenance.get("derivationChain", []):
|
|
stream_cache_provenance.setdefault("derivationChain", []).insert(0, "CACHE:hit")
|
|
else:
|
|
# Build fresh provenance for older cache entries
|
|
stream_cache_tier = DataTier.TIER_3_CROWD_SOURCED.value
|
|
if stream_cached_template_used:
|
|
stream_cache_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
elif any(s.lower() in ["sparql", "typedb"] for s in stream_cached_sources):
|
|
stream_cache_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
|
|
stream_cache_provenance = EpistemicProvenance(
|
|
dataSource=EpistemicDataSource.CACHE_AGGREGATION,
|
|
dataTier=stream_cache_tier,
|
|
derivationChain=["CACHE:hit"] + build_derivation_chain(
|
|
sources_used=stream_cached_sources,
|
|
template_used=stream_cached_template_used,
|
|
template_id=stream_cached_template_id,
|
|
llm_provider=stream_cached_llm_provider,
|
|
),
|
|
sourcesQueried=stream_cached_sources,
|
|
templateUsed=stream_cached_template_used,
|
|
templateId=stream_cached_template_id,
|
|
llmProvider=stream_cached_llm_provider,
|
|
llmModel=stream_cached_llm_model,
|
|
).model_dump()
|
|
|
|
response_data = {
|
|
"question": request.question,
|
|
"answer": cached.get("answer", ""),
|
|
"sources_used": stream_cached_sources,
|
|
"visualization": visualization,
|
|
"resolved_question": cached_context.get("resolved_question"),
|
|
"retrieved_results": cached_context.get("retrieved_results"),
|
|
"query_type": cached_context.get("query_type"),
|
|
"embedding_model_used": cached_context.get("embedding_model"),
|
|
"llm_model_used": stream_cached_llm_model,
|
|
"query_time_ms": round(elapsed_ms, 2),
|
|
"cache_hit": True,
|
|
# Session management
|
|
"session_id": session_id,
|
|
# Template tracking from cache
|
|
"template_used": stream_cached_template_used,
|
|
"template_id": stream_cached_template_id,
|
|
# Rule 46: Epistemic provenance for transparency
|
|
"epistemic_provenance": stream_cache_provenance,
|
|
}
|
|
|
|
# Record cache hit metrics for streaming endpoint
|
|
if METRICS_AVAILABLE and record_query:
|
|
try:
|
|
record_query(
|
|
endpoint="dspy_query_stream",
|
|
template_used=cached_context.get("template_used", False),
|
|
template_id=cached_context.get("template_id"),
|
|
cache_hit=True,
|
|
status="success",
|
|
duration_seconds=elapsed_ms / 1000,
|
|
intent=cached_context.get("query_type"),
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to record streaming cache hit metrics: {e}")
|
|
|
|
yield emit_status("cache", "✅ Antwoord gevonden in cache!")
|
|
yield json.dumps({"type": "complete", "data": response_data}) + "\n"
|
|
return
|
|
|
|
try:
|
|
# Stage 2: Configure LLM
|
|
yield emit_status("config", "⚙️ LLM configureren...")
|
|
|
|
import dspy
|
|
from dspy import History
|
|
from dspy_heritage_rag import HeritageRAGPipeline
|
|
|
|
requested_provider = resolved_provider # Already resolved above
|
|
llm_provider_used: str | None = None
|
|
llm_model_used: str | None = None
|
|
lm = None
|
|
|
|
providers_to_try = [requested_provider]
|
|
for fallback in ["zai", "groq", "anthropic", "openai"]:
|
|
if fallback not in providers_to_try:
|
|
providers_to_try.append(fallback)
|
|
|
|
for provider in providers_to_try:
|
|
if lm is not None:
|
|
break
|
|
|
|
# Default models per provider (used if request.llm_model is not specified)
|
|
# Use LLM_MODEL from settings when it matches the provider prefix
|
|
default_models = {
|
|
"zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash",
|
|
"groq": "llama-3.1-8b-instant",
|
|
"anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514",
|
|
"openai": "gpt-4o-mini",
|
|
# Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference
|
|
# Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2
|
|
"huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct",
|
|
}
|
|
# HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct)
|
|
# Groq models use simple names (e.g., llama-3.1-8b-instant)
|
|
model_prefixes = {
|
|
"glm-": "zai",
|
|
"llama-3.1-": "groq",
|
|
"llama-3.3-": "groq",
|
|
"claude-": "anthropic",
|
|
"gpt-": "openai",
|
|
# HuggingFace organization prefixes
|
|
"mistralai/": "huggingface",
|
|
"google/": "huggingface",
|
|
"Qwen/": "huggingface",
|
|
"deepseek-ai/": "huggingface",
|
|
"meta-llama/": "huggingface",
|
|
"utter-project/": "huggingface",
|
|
"microsoft/": "huggingface",
|
|
"tiiuae/": "huggingface",
|
|
}
|
|
|
|
# Determine which model to use: requested model (if valid for this provider) or default
|
|
requested_model = request.llm_model
|
|
model_to_use = default_models.get(provider, "")
|
|
|
|
# Check if requested model matches this provider
|
|
if requested_model:
|
|
for prefix, model_provider in model_prefixes.items():
|
|
if requested_model.startswith(prefix) and model_provider == provider:
|
|
model_to_use = requested_model
|
|
break
|
|
|
|
if provider == "zai" and settings.zai_api_token:
|
|
try:
|
|
lm = dspy.LM(
|
|
f"openai/{model_to_use}",
|
|
api_key=settings.zai_api_token,
|
|
api_base="https://api.z.ai/api/coding/paas/v4",
|
|
)
|
|
llm_provider_used = "zai"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Z.AI LM: {e}")
|
|
|
|
elif provider == "groq" and settings.groq_api_key:
|
|
try:
|
|
lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
|
|
llm_provider_used = "groq"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Groq {model_to_use} (FREE) for streaming request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Groq LM: {e}")
|
|
|
|
elif provider == "huggingface" and settings.huggingface_api_key:
|
|
try:
|
|
lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
|
|
llm_provider_used = "huggingface"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create HuggingFace LM: {e}")
|
|
|
|
elif provider == "anthropic" and settings.anthropic_api_key:
|
|
try:
|
|
lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
|
|
llm_provider_used = "anthropic"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Anthropic LM: {e}")
|
|
|
|
elif provider == "openai" and settings.openai_api_key:
|
|
try:
|
|
lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
|
|
llm_provider_used = "openai"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create OpenAI LM: {e}")
|
|
|
|
if lm is None:
|
|
yield emit_error(f"Geen LLM beschikbaar. Controleer API keys.")
|
|
return
|
|
|
|
yield emit_status("config", f"✅ LLM geconfigureerd ({llm_provider_used})")
|
|
|
|
# Stage 3: Prepare conversation history
|
|
yield emit_status("routing", "🧭 Vraag analyseren...")
|
|
|
|
history_messages = []
|
|
for turn in request.context:
|
|
if turn.get("question") and turn.get("answer"):
|
|
history_messages.append({
|
|
"question": turn["question"],
|
|
"answer": turn["answer"]
|
|
})
|
|
|
|
history = History(messages=history_messages) if history_messages else None
|
|
|
|
# Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
|
|
if dspy_pipeline is not None:
|
|
pipeline = dspy_pipeline
|
|
logger.debug("Using global optimized DSPy pipeline (streaming)")
|
|
else:
|
|
# Fallback: create pipeline without optimized weights
|
|
qdrant_retriever = retriever.qdrant if retriever else None
|
|
pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)
|
|
logger.debug("Using fallback (unoptimized) DSPy pipeline (streaming)")
|
|
|
|
# Stage 4: Execute pipeline with STREAMING answer generation
|
|
yield emit_status("retrieval", "📊 Database doorzoeken...")
|
|
|
|
result = None
|
|
|
|
# Check if pipeline supports streaming
|
|
if hasattr(pipeline, 'forward_streaming'):
|
|
# Use streaming mode - tokens arrive as they're generated
|
|
try:
|
|
with dspy.settings.context(lm=lm):
|
|
async for event in pipeline.forward_streaming(
|
|
embedding_model=request.embedding_model,
|
|
question=request.question,
|
|
language=request.language,
|
|
history=history,
|
|
include_viz=request.include_visualization,
|
|
conversation_state=conversation_state, # Pass session state for template SPARQL
|
|
):
|
|
event_type = event.get("type")
|
|
|
|
if event_type == "cache_hit":
|
|
# Cache hit - return immediately
|
|
result = event["prediction"]
|
|
yield emit_status("complete", "✅ Klaar! (cache)")
|
|
break
|
|
|
|
elif event_type == "retrieval_complete":
|
|
# Retrieval done, now generating answer
|
|
yield emit_status("generation", "💡 Antwoord genereren...")
|
|
|
|
elif event_type == "token":
|
|
# Stream token to frontend
|
|
yield json.dumps({"type": "token", "content": event["content"]}) + "\n"
|
|
|
|
elif event_type == "status":
|
|
# Status message from pipeline
|
|
yield emit_status("generation", event.get("message", "..."))
|
|
|
|
elif event_type == "answer_complete":
|
|
# Final prediction ready
|
|
result = event["prediction"]
|
|
|
|
except Exception as e:
|
|
logger.exception(f"Streaming pipeline execution failed: {e}")
|
|
user_msg, details = extract_user_friendly_error(e)
|
|
yield emit_error(user_msg, details)
|
|
return
|
|
else:
|
|
# Fallback: Non-streaming mode (original behavior)
|
|
max_retries = 3
|
|
last_error: Exception | None = None
|
|
|
|
with dspy.settings.context(lm=lm):
|
|
for attempt in range(max_retries):
|
|
try:
|
|
if attempt > 0:
|
|
yield emit_status("retrieval", f"🔄 Opnieuw proberen ({attempt + 1}/{max_retries})...")
|
|
|
|
result = pipeline(
|
|
embedding_model=request.embedding_model,
|
|
question=request.question,
|
|
language=request.language,
|
|
history=history,
|
|
include_viz=request.include_visualization,
|
|
conversation_state=conversation_state,
|
|
)
|
|
break
|
|
except Exception as e:
|
|
last_error = e
|
|
error_str = str(e).lower()
|
|
is_retryable = any(keyword in error_str for keyword in [
|
|
"overloaded", "rate_limit", "rate limit", "too many requests",
|
|
"529", "503", "502", "504",
|
|
"temporarily unavailable", "service unavailable",
|
|
"connection reset", "connection refused", "timeout"
|
|
])
|
|
|
|
if is_retryable and attempt < max_retries - 1:
|
|
wait_time = 2 ** attempt
|
|
logger.warning(f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}")
|
|
yield emit_status("retrieval", f"⏳ API overbelast, wachten {wait_time}s...")
|
|
await asyncio.sleep(wait_time)
|
|
continue
|
|
else:
|
|
logger.exception(f"Pipeline execution failed after {attempt + 1} attempts")
|
|
user_msg, details = extract_user_friendly_error(e)
|
|
yield emit_error(user_msg, details)
|
|
return
|
|
|
|
if result is None:
|
|
if last_error:
|
|
user_msg, details = extract_user_friendly_error(last_error)
|
|
yield emit_error(user_msg, details)
|
|
return
|
|
yield emit_error("Pipeline uitvoering mislukt zonder resultaat")
|
|
return
|
|
|
|
# Stage 5: Generate response (only for non-streaming fallback)
|
|
yield emit_status("generation", "💡 Antwoord genereren...")
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
# Extract query_type first - needed for SPARQL visibility decision
|
|
query_type = getattr(result, "query_type", None)
|
|
|
|
visualization = None
|
|
if request.include_visualization and hasattr(result, "visualization"):
|
|
viz = result.visualization
|
|
if viz:
|
|
# Now showing SPARQL for all query types including person queries
|
|
# Person queries use HeritagePersonSPARQLGenerator (schema:Person predicates)
|
|
# Institution queries use HeritageSPARQLGenerator (crm:E39_Actor predicates)
|
|
sparql_to_show = getattr(result, "sparql", None)
|
|
|
|
# viz can be either an object (with .viz_type attr) or a dict (with "type" key)
|
|
# Handle both cases for compatibility with streaming and non-streaming modes
|
|
if isinstance(viz, dict):
|
|
viz_type = viz.get("type", "table")
|
|
else:
|
|
viz_type = getattr(viz, "viz_type", "table")
|
|
|
|
visualization = {
|
|
"type": viz_type,
|
|
"sparql_query": sparql_to_show,
|
|
}
|
|
logger.info(f"[DEBUG] Built visualization: type={viz_type}, sparql_len={len(sparql_to_show) if sparql_to_show else 0}")
|
|
|
|
retrieved_results = getattr(result, "retrieved_results", None)
|
|
|
|
# Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support)
|
|
llm_response_metadata = extract_llm_response_metadata(
|
|
lm=lm,
|
|
provider=llm_provider_used,
|
|
latency_ms=int(elapsed_ms),
|
|
)
|
|
|
|
# Rule 46: Build epistemic provenance for streaming endpoint
|
|
stream_sources_used = getattr(result, "sources_used", [])
|
|
stream_template_used = getattr(result, "template_used", False)
|
|
stream_template_id = getattr(result, "template_id", None)
|
|
|
|
# Infer data tier from sources
|
|
stream_tier = DataTier.TIER_3_CROWD_SOURCED.value
|
|
if stream_template_used:
|
|
stream_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
elif any(s.lower() in ["sparql", "typedb"] for s in stream_sources_used):
|
|
stream_tier = DataTier.TIER_1_AUTHORITATIVE.value
|
|
|
|
stream_provenance = EpistemicProvenance(
|
|
dataSource=EpistemicDataSource.RAG_PIPELINE,
|
|
dataTier=stream_tier,
|
|
derivationChain=build_derivation_chain(
|
|
sources_used=stream_sources_used,
|
|
template_used=stream_template_used,
|
|
template_id=stream_template_id,
|
|
llm_provider=llm_provider_used,
|
|
),
|
|
sourcesQueried=stream_sources_used,
|
|
totalRetrieved=len(retrieved_results) if retrieved_results else 0,
|
|
totalAfterFusion=len(retrieved_results) if retrieved_results else 0,
|
|
templateUsed=stream_template_used,
|
|
templateId=stream_template_id,
|
|
llmProvider=llm_provider_used,
|
|
llmModel=llm_model_used,
|
|
)
|
|
|
|
response = DSPyQueryResponse(
|
|
question=request.question,
|
|
resolved_question=getattr(result, "resolved_question", None),
|
|
answer=getattr(result, "answer", "Geen antwoord gevonden."),
|
|
sources_used=stream_sources_used,
|
|
visualization=visualization,
|
|
retrieved_results=retrieved_results,
|
|
query_type=query_type,
|
|
query_time_ms=round(elapsed_ms, 2),
|
|
conversation_turn=len(request.context),
|
|
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
|
|
timing_ms=getattr(result, "timing_ms", None),
|
|
cost_usd=getattr(result, "cost_usd", None),
|
|
timing_breakdown=getattr(result, "timing_breakdown", None),
|
|
llm_provider_used=llm_provider_used,
|
|
llm_model_used=llm_model_used,
|
|
cache_hit=False,
|
|
# LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought)
|
|
llm_response=llm_response_metadata,
|
|
# Session management fields for multi-turn conversations
|
|
session_id=session_id,
|
|
template_used=stream_template_used,
|
|
template_id=stream_template_id,
|
|
# Rule 46: Epistemic provenance for transparency
|
|
epistemic_provenance=stream_provenance.model_dump(),
|
|
)
|
|
|
|
# Update session with this turn (before caching)
|
|
if session_mgr and session_id and conversation_state is not None:
|
|
try:
|
|
await session_mgr.add_turn_to_session(
|
|
session_id=session_id,
|
|
question=request.question,
|
|
answer=response.answer,
|
|
resolved_question=response.resolved_question,
|
|
template_id=getattr(result, "template_id", None),
|
|
slots=getattr(result, "slots", {}),
|
|
)
|
|
logger.debug(f"Updated session {session_id} with new turn")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to update session {session_id}: {e}")
|
|
|
|
# Record Prometheus metrics for monitoring
|
|
if METRICS_AVAILABLE and record_query:
|
|
try:
|
|
record_query(
|
|
endpoint="dspy_query_stream",
|
|
template_used=getattr(result, "template_used", False),
|
|
template_id=getattr(result, "template_id", None),
|
|
cache_hit=False,
|
|
status="success",
|
|
duration_seconds=elapsed_ms / 1000,
|
|
intent=query_type,
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to record streaming metrics: {e}")
|
|
|
|
# Cache the response
|
|
if retriever:
|
|
await retriever.cache.set_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=llm_provider_used,
|
|
embedding_model=request.embedding_model,
|
|
response=response.model_dump(),
|
|
context=request.context if request.context else None,
|
|
)
|
|
|
|
yield emit_status("complete", "✅ Klaar!")
|
|
yield json.dumps({"type": "complete", "data": response.model_dump()}) + "\n"
|
|
|
|
except ImportError as e:
|
|
logger.warning(f"DSPy pipeline not available: {e}")
|
|
yield emit_error("DSPy pipeline is niet beschikbaar.")
|
|
|
|
except Exception as e:
|
|
logger.exception("DSPy streaming query failed")
|
|
user_msg, details = extract_user_friendly_error(e)
|
|
yield emit_error(user_msg, details)
|
|
|
|
|
|
@app.post("/api/rag/dspy/query/stream")
async def dspy_query_stream(request: DSPyQueryRequest) -> StreamingResponse:
    """Streaming version of DSPy RAG query endpoint.

    Emits an NDJSON stream so the frontend can render progress while the
    pipeline runs; each line is one JSON status object.

    Status stages:
    - cache: Checking for cached response
    - config: Configuring LLM provider
    - routing: Analyzing query intent
    - retrieval: Searching databases (Qdrant, SPARQL, etc.)
    - generation: Generating answer with LLM
    - complete: Final response ready
    """
    event_stream = stream_dspy_query_response(request)
    return StreamingResponse(event_stream, media_type="application/x-ndjson")
|
|
|
|
|
|
async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Stream an NDJSON progress feed for a multi-source RAG query.

    Yields one JSON line per event: an initial routing status, a status and
    partial-count line per data source, and a final line with the merged
    results plus epistemic provenance.
    """
    if not retriever:
        # Cannot serve anything without the global retriever; emit a single
        # error object (no trailing newline, matching the existing contract).
        yield json.dumps({"error": "Retriever not initialized"})
        return

    def ndjson(payload: dict[str, Any]) -> str:
        # One NDJSON line per event.
        return json.dumps(payload) + "\n"

    t0 = asyncio.get_event_loop().time()

    # Decide which backends to hit for this question.
    intent, sources = retriever.router.get_sources(request.question, request.sources)

    # Pull province/city/institution-type filters out of the question text.
    filters = extract_geographic_filters(request.question)

    yield ndjson({
        "type": "status",
        "message": f"Routing query to {len(sources)} sources...",
        "intent": intent.value,
        "geo_filters": {key: val for key, val in filters.items() if val},
    })

    # Query each source sequentially, streaming progress as we go.
    collected = []
    for source in sources:
        yield ndjson({
            "type": "status",
            "message": f"Querying {source.value}...",
        })

        batch = await retriever.retrieve(
            request.question,
            [source],
            request.k,
            embedding_model=request.embedding_model,
            region_codes=filters["region_codes"],
            cities=filters["cities"],
            institution_types=filters["institution_types"],
        )
        collected.extend(batch)

        yield ndjson({
            "type": "partial",
            "source": source.value,
            "count": len(batch[0].items) if batch else 0,
        })

    # Fuse scores across sources and attach provenance metadata.
    fused, provenance = retriever.merge_results(collected)
    duration_ms = (asyncio.get_event_loop().time() - t0) * 1000

    yield ndjson({
        "type": "complete",
        "results": fused,
        "query_time_ms": round(duration_ms, 2),
        "result_count": len(fused),
        "epistemic_provenance": provenance.model_dump() if provenance else None,
    })
|
|
|
|
|
|
@app.post("/api/rag/query/stream")
async def query_rag_stream(request: QueryRequest) -> StreamingResponse:
    """Streaming (NDJSON) variant of the main RAG query endpoint."""
    body = stream_query_response(request)
    return StreamingResponse(body, media_type="application/x-ndjson")
|
|
|
|
|
|
# =============================================================================
|
|
# SEMANTIC CACHE ENDPOINTS (Qdrant-backed)
|
|
# =============================================================================
|
|
#
|
|
# High-performance semantic cache using Qdrant's HNSW vector index.
|
|
# Replaces slow client-side cosine similarity with server-side ANN search.
|
|
#
|
|
# Performance target:
|
|
# - Cache lookup: <20ms (vs 500-2000ms with client-side scan)
|
|
# - Cache store: <50ms
|
|
#
|
|
# Architecture:
|
|
# Frontend → /api/cache/lookup → Qdrant ANN search → cached response
|
|
# /api/cache/store → embed + upsert to Qdrant
|
|
# =============================================================================
|
|
|
|
# Lazy-loaded Qdrant client for cache (created on first use; see get_cache_qdrant_client)
_cache_qdrant_client: Any = None
# Lazy-loaded sentence-transformers model for cache embeddings (see get_cache_embedding_model)
_cache_embedding_model: Any = None

# Qdrant collection that stores cached query/response pairs
CACHE_COLLECTION_NAME = "query_cache"
# Vector dimensionality of the cache embedding model
CACHE_EMBEDDING_DIM = 384  # all-MiniLM-L6-v2
|
|
|
|
|
|
def get_cache_qdrant_client() -> Any:
    """Return the shared Qdrant client for the cache collection, creating it lazily.

    Always connects via the configured host/port (default localhost:6333)
    because the cache is co-located with the RAG backend; this avoids reverse
    proxy overhead and ensures a direct local connection.

    Returns:
        A ``QdrantClient`` instance, or ``None`` when the client library is
        missing or the connection could not be created.
    """
    global _cache_qdrant_client

    if _cache_qdrant_client is None:
        try:
            from qdrant_client import QdrantClient

            # Cache always uses localhost - co-located with RAG backend.
            # Uses settings.qdrant_host/port which default to localhost:6333.
            client = QdrantClient(
                host=settings.qdrant_host,
                port=settings.qdrant_port,
                timeout=30,
            )
            logger.info(f"Qdrant cache client: {settings.qdrant_host}:{settings.qdrant_port}")
            _cache_qdrant_client = client
        except ImportError:
            logger.error("qdrant-client not installed")
            return None
        except Exception as e:
            logger.error(f"Failed to create Qdrant cache client: {e}")
            return None

    return _cache_qdrant_client
|
|
|
|
|
|
def get_cache_embedding_model() -> Any:
    """Return the lazily-created embedding model for the cache (MiniLM-L6-v2, 384-dim).

    Returns:
        A ``SentenceTransformer`` instance, or ``None`` when the library is
        unavailable or the model fails to load.
    """
    global _cache_embedding_model

    if _cache_embedding_model is None:
        try:
            from sentence_transformers import SentenceTransformer

            _cache_embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
            logger.info("Loaded cache embedding model: all-MiniLM-L6-v2")
        except ImportError:
            logger.error("sentence-transformers not installed")
            return None
        except Exception as e:
            logger.error(f"Failed to load cache embedding model: {e}")
            return None

    return _cache_embedding_model
|
|
|
|
|
|
def ensure_cache_collection_exists() -> bool:
    """Create the query_cache collection in Qdrant if it is not already present.

    Returns:
        True when the collection exists (or was just created), False when the
        client is unavailable or a Qdrant call fails.
    """
    client = get_cache_qdrant_client()
    if client is None:
        return False

    try:
        from qdrant_client.models import Distance, VectorParams

        # Already there? Nothing to do.
        existing = {c.name for c in client.get_collections().collections}
        if CACHE_COLLECTION_NAME in existing:
            return True

        # Not found: create it with a cosine-distance HNSW index.
        client.create_collection(
            collection_name=CACHE_COLLECTION_NAME,
            vectors_config=VectorParams(
                size=CACHE_EMBEDDING_DIM,
                distance=Distance.COSINE,
            ),
        )
        logger.info(f"Created Qdrant collection: {CACHE_COLLECTION_NAME}")
        return True
    except Exception as e:
        logger.error(f"Failed to ensure cache collection: {e}")
        return False
|
|
|
|
|
|
# Request/Response Models for Cache API
|
|
|
|
class CacheLookupRequest(BaseModel):
    """Request body for POST /api/cache/lookup.

    When ``embedding`` is omitted, the backend encodes ``query`` itself with
    the MiniLM cache model (384-dim).
    """
    query: str = Field(..., description="Query text to look up")
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    # Cosine-similarity cutoff; matches scoring below this are treated as misses.
    similarity_threshold: float = Field(default=0.92, description="Minimum similarity for match")
    language: str = Field(default="nl", description="Language filter")
|
|
|
|
|
|
class CacheLookupResponse(BaseModel):
    """Response body for POST /api/cache/lookup.

    ``method`` is "semantic" after a vector-search lookup (hit or miss) and
    "error" when the cache backend was unavailable or the search failed.
    """
    found: bool
    # Cached point payload; only populated when found is True
    entry: dict[str, Any] | None = None
    # Cosine score of the best match (0.0 on miss/error)
    similarity: float = 0.0
    method: str = "none"
    # Wall-clock lookup latency in milliseconds
    lookup_time_ms: float = 0.0
|
|
|
|
|
|
class CacheStoreRequest(BaseModel):
    """Request body for POST /api/cache/store.

    ``response`` is stored verbatim in the Qdrant point payload; the embedding
    is computed server-side when not supplied.
    """
    query: str = Field(..., description="Query text")
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    response: dict[str, Any] = Field(..., description="Response to cache")
    language: str = Field(default="nl", description="Language")
    model: str = Field(default="unknown", description="LLM model used")
    # 86400s = 24h default lifetime for a cached entry
    ttl_seconds: int = Field(default=86400, description="Time-to-live in seconds")
|
|
|
|
|
|
class CacheStoreResponse(BaseModel):
    """Response body for POST /api/cache/store."""
    success: bool
    # Qdrant point id of the stored entry (set on success)
    id: str | None = None
    # Human-readable status or error detail
    message: str = ""
|
|
|
|
|
|
class CacheStatsResponse(BaseModel):
    """Response body for GET /api/cache/stats."""
    total_entries: int = 0
    collection_name: str = CACHE_COLLECTION_NAME
    embedding_dim: int = CACHE_EMBEDDING_DIM
    backend: str = "qdrant"
    # "ok", "no_collection", or "error..." on failure
    status: str = "ok"
|
|
|
|
|
|
@app.post("/api/cache/lookup", response_model=CacheLookupResponse)
async def cache_lookup(request: CacheLookupRequest) -> CacheLookupResponse:
    """Look up a query in the semantic cache using Qdrant ANN search.

    This endpoint performs sub-millisecond vector similarity search using
    Qdrant's HNSW index, replacing slow client-side cosine similarity scans.

    Fix: entries are stored with a ``ttl_seconds`` payload field (see
    /api/cache/store) that was previously never checked, so stale entries were
    returned as hits forever. An expired entry is now treated as a miss and is
    deleted best-effort so it stops matching.

    Returns:
        CacheLookupResponse with ``found=True`` and the cached entry on a hit;
        otherwise ``found=False`` with ``method`` "semantic" (clean miss) or
        "error" (backend unavailable / search failure).
    """
    import time
    start_time = time.perf_counter()

    def _miss(method: str) -> CacheLookupResponse:
        # Uniform miss/error response with elapsed time filled in.
        return CacheLookupResponse(
            found=False,
            similarity=0.0,
            method=method,
            lookup_time_ms=(time.perf_counter() - start_time) * 1000,
        )

    # Ensure collection exists before searching.
    if not ensure_cache_collection_exists():
        return _miss("error")

    client = get_cache_qdrant_client()
    if client is None:
        return _miss("error")

    # Use the caller-provided embedding, or compute one server-side.
    embedding = request.embedding
    if embedding is None:
        model = get_cache_embedding_model()
        if model is None:
            return _miss("error")
        embedding = model.encode(request.query).tolist()

    try:
        from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList

        # Restrict matches to the requested language.
        search_filter = Filter(
            must=[
                FieldCondition(
                    key="language",
                    match=MatchValue(value=request.language),
                )
            ]
        )

        # Perform ANN search using query_points (qdrant-client >= 1.7).
        results = client.query_points(
            collection_name=CACHE_COLLECTION_NAME,
            query=embedding,
            query_filter=search_filter,
            limit=1,
            score_threshold=request.similarity_threshold,
        ).points

        if not results:
            return _miss("semantic")

        # Extract best match.
        best = results[0]
        payload = best.payload or {}

        # Enforce TTL: cache_store records `timestamp` (epoch ms) and
        # `ttl_seconds`. An expired entry is a miss; delete it best-effort.
        stored_at = payload.get("timestamp", 0)
        ttl_seconds = payload.get("ttl_seconds")
        if ttl_seconds and stored_at:
            now_ms = int(time.time() * 1000)
            if now_ms > stored_at + int(ttl_seconds) * 1000:
                try:
                    client.delete(
                        collection_name=CACHE_COLLECTION_NAME,
                        points_selector=PointIdsList(points=[best.id]),
                    )
                except Exception as cleanup_err:
                    # Cleanup is best-effort; a failed delete must not break lookup.
                    logger.debug(f"Failed to delete expired cache entry {best.id}: {cleanup_err}")
                return _miss("semantic")

        elapsed_ms = (time.perf_counter() - start_time) * 1000

        return CacheLookupResponse(
            found=True,
            entry={
                "id": str(best.id),
                "query": payload.get("query", ""),
                "query_normalized": payload.get("query_normalized", ""),
                "response": payload.get("response", {}),
                "timestamp": payload.get("timestamp", 0),
                "hit_count": payload.get("hit_count", 0),
                "last_accessed": payload.get("last_accessed", 0),
                "language": payload.get("language", "nl"),
                "model": payload.get("model", "unknown"),
            },
            similarity=best.score,
            method="semantic",
            lookup_time_ms=elapsed_ms,
        )
    except Exception as e:
        logger.error(f"Cache lookup error: {e}")
        return _miss("error")
|
|
|
|
|
|
@app.post("/api/cache/store", response_model=CacheStoreResponse)
async def cache_store(request: CacheStoreRequest) -> CacheStoreResponse:
    """Store a query/response pair in the semantic cache.

    Embeds the query with the MiniLM cache model when no embedding is supplied,
    then upserts a single point into the Qdrant cache collection.
    """
    import time
    import uuid

    # Make sure the target collection is in place before writing.
    if not ensure_cache_collection_exists():
        return CacheStoreResponse(
            success=False,
            message="Failed to ensure cache collection exists",
        )

    client = get_cache_qdrant_client()
    if client is None:
        return CacheStoreResponse(
            success=False,
            message="Qdrant client not available",
        )

    # Use the caller-provided embedding, or compute one server-side.
    embedding = request.embedding
    if embedding is None:
        model = get_cache_embedding_model()
        if model is None:
            return CacheStoreResponse(
                success=False,
                message="Embedding model not available",
            )
        embedding = model.encode(request.query).tolist()

    try:
        from qdrant_client.models import PointStruct

        point_id = str(uuid.uuid4())
        now_ms = int(time.time() * 1000)

        payload = {
            "query": request.query,
            # Lowercased/stripped copy kept for exact-match comparisons.
            "query_normalized": request.query.lower().strip(),
            "response": request.response,
            "language": request.language,
            "model": request.model,
            "timestamp": now_ms,
            "hit_count": 0,
            "last_accessed": now_ms,
            "ttl_seconds": request.ttl_seconds,
        }

        # Upsert a single point into Qdrant.
        client.upsert(
            collection_name=CACHE_COLLECTION_NAME,
            points=[PointStruct(id=point_id, vector=embedding, payload=payload)],
        )

        logger.debug(f"Cached query: {request.query[:50]}...")

        return CacheStoreResponse(
            success=True,
            id=point_id,
            message="Stored successfully",
        )
    except Exception as e:
        logger.error(f"Cache store error: {e}")
        return CacheStoreResponse(
            success=False,
            message=str(e),
        )
|
|
|
|
|
|
@app.get("/api/cache/stats", response_model=CacheStatsResponse)
async def cache_stats() -> CacheStatsResponse:
    """Get cache statistics.

    Reports the number of cached entries plus collection metadata. ``status``
    is "ok", "no_collection" (collection not created yet), or "error..." when
    the client is unavailable or a Qdrant call fails.
    """
    client = get_cache_qdrant_client()
    if client is None:
        return CacheStatsResponse(
            status="error",
            total_entries=0,
        )

    try:
        # Check if collection exists.
        collections = client.get_collections().collections
        if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
            return CacheStatsResponse(
                status="no_collection",
                total_entries=0,
            )

        # Get collection info.
        info = client.get_collection(CACHE_COLLECTION_NAME)

        return CacheStatsResponse(
            # Fix: points_count is Optional in qdrant-client models; coerce
            # None to 0 so the int-typed field never fails validation.
            total_entries=info.points_count or 0,
            collection_name=CACHE_COLLECTION_NAME,
            embedding_dim=CACHE_EMBEDDING_DIM,
            backend="qdrant",
            status="ok",
        )
    except Exception as e:
        logger.error(f"Cache stats error: {e}")
        return CacheStatsResponse(
            status=f"error: {e}",
            total_entries=0,
        )
|
|
|
|
|
|
@app.delete("/api/cache/clear")
async def cache_clear() -> dict[str, Any]:
    """Clear all cache entries (both Qdrant semantic cache and Valkey/Redis cache).

    The Qdrant collection is dropped and recreated empty; the Valkey/Redis
    cache is flushed via the local ``redis-cli``/``valkey-cli`` binary.

    Returns:
        Aggregate dict with per-backend details. ``deleted`` counts Qdrant
        entries exactly; Valkey reports -1 ("unknown count"), which the final
        sum clamps to 0.
    """
    results = {
        "qdrant": {"success": False, "deleted": 0, "message": ""},
        "valkey": {"success": False, "deleted": 0, "message": ""},
    }

    # 1. Clear Qdrant semantic cache by dropping and recreating the collection.
    client = get_cache_qdrant_client()
    if client is None:
        results["qdrant"]["message"] = "Qdrant client not available"
    else:
        try:
            collections = client.get_collections().collections
            if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
                results["qdrant"] = {"success": True, "deleted": 0, "message": "Collection does not exist"}
            else:
                info = client.get_collection(CACHE_COLLECTION_NAME)
                # Fix: points_count is Optional in qdrant-client; coerce None
                # to 0 so max(0, deleted) below never sees a non-int.
                count = info.points_count or 0
                client.delete_collection(CACHE_COLLECTION_NAME)
                ensure_cache_collection_exists()
                results["qdrant"] = {"success": True, "deleted": count, "message": f"Cleared {count} entries"}
        except Exception as e:
            logger.error(f"Qdrant cache clear error: {e}")
            results["qdrant"]["message"] = str(e)

    # 2. Clear Valkey/Redis cache using redis-cli FLUSHALL.
    # NOTE(review): FLUSHALL wipes every database on the local Redis/Valkey
    # instance, not just cache keys — confirm the instance is cache-dedicated.
    import subprocess
    try:
        # Try redis-cli first, then valkey-cli
        for cli in ["redis-cli", "valkey-cli"]:
            try:
                result = subprocess.run(
                    [cli, "FLUSHALL"],
                    capture_output=True,
                    text=True,
                    timeout=5,
                )
                if result.returncode == 0 and "OK" in result.stdout:
                    results["valkey"] = {"success": True, "deleted": -1, "message": f"Flushed via {cli}"}
                    break
            except FileNotFoundError:
                # CLI binary not installed; try the next candidate.
                continue

        if not results["valkey"]["success"]:
            results["valkey"]["message"] = "Neither redis-cli nor valkey-cli available"
    except subprocess.TimeoutExpired:
        results["valkey"]["message"] = "Cache flush timed out"
    except Exception as e:
        logger.error(f"Valkey cache clear error: {e}")
        results["valkey"]["message"] = str(e)

    overall_success = results["qdrant"]["success"] or results["valkey"]["success"]
    total_deleted = max(0, results["qdrant"]["deleted"]) + max(0, results["valkey"]["deleted"])

    return {
        "success": overall_success,
        "message": f"Qdrant: {results['qdrant']['message']}, Valkey: {results['valkey']['message']}",
        "deleted": total_deleted,
        "details": results,
    }
|
|
|
|
|
|
# Main entry point
if __name__ == "__main__":
    import uvicorn

    # Dev-server launch; binds on all interfaces, port 8003.
    # Auto-reload follows the debug flag from settings.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8003,
        reload=settings.debug,
        log_level="info",
    )
|