# glam/backend/rag/main.py
# Snapshot: 2025-12-21 00:01:54 +01:00 — 3144 lines, 121 KiB, Python
from __future__ import annotations
"""
Unified RAG Backend for Heritage Custodian Data
Multi-source retrieval-augmented generation system that combines:
- Qdrant vector search (semantic similarity)
- Oxigraph SPARQL (knowledge graph queries)
- TypeDB (relationship traversal)
- PostGIS (geospatial queries)
- Valkey (semantic caching)
Architecture:
User Query → Query Analysis
┌─────┴─────┐
│ Router │
└─────┬─────┘
┌─────┬─────┼─────┬─────┐
↓ ↓ ↓ ↓ ↓
Qdrant SPARQL TypeDB PostGIS Cache
│ │ │ │ │
└─────┴─────┴─────┴─────┘
┌─────┴─────┐
│ Merger │
└─────┬─────┘
DSPy Generator
Visualization Selector
Response (JSON/Streaming)
Features:
- Intelligent query routing to appropriate data sources
- Score fusion for multi-source results
- Semantic caching via Valkey API
- Streaming responses for long-running queries
- DSPy assertions for output validation
Endpoints:
- POST /api/rag/query - Main RAG query endpoint
- POST /api/rag/sparql - Generate SPARQL with RAG context
- POST /api/rag/typedb/search - Direct TypeDB search
- GET /api/rag/health - Health check for all services
- GET /api/rag/stats - Retriever statistics
"""
import asyncio
import hashlib
import json
import logging
import os
import re
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, AsyncIterator, TYPE_CHECKING

import httpx
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# Configure logging once at import time; all module loggers inherit this format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger, keyed to this module's import path (standard practice).
logger = logging.getLogger(__name__)
# Type hints for optional imports (only used during type checking)
if TYPE_CHECKING:
    from glam_extractor.api.hybrid_retriever import HybridRetriever
    from glam_extractor.api.typedb_retriever import TypeDBRetriever
    from glam_extractor.api.visualization import VisualizationSelector
# Import retrievers (with graceful fallbacks).
# Each optional dependency is bound to a module-level name that stays None
# when its import fails, so the rest of the app can feature-detect at call
# time instead of crashing at import time.
RETRIEVERS_AVAILABLE = False
create_hybrid_retriever: Any = None
HeritageCustodianRetriever: Any = None
create_typedb_retriever: Any = None
select_visualization: Any = None
VisualizationSelector: Any = None  # type: ignore[no-redef]
generate_sparql: Any = None
configure_dspy: Any = None
get_province_code: Any = None  # Province name to ISO 3166-2 code converter
try:
    import sys
    # Make the repo's sibling src/ package importable when running from a checkout
    sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src")))
    from glam_extractor.api.hybrid_retriever import (
        HybridRetriever as _HybridRetriever,
        create_hybrid_retriever as _create_hybrid_retriever,
        get_province_code as _get_province_code,
    )
    from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever as _HeritageCustodianRetriever
    from glam_extractor.api.typedb_retriever import TypeDBRetriever as _TypeDBRetriever, create_typedb_retriever as _create_typedb_retriever
    from glam_extractor.api.visualization import select_visualization as _select_visualization, VisualizationSelector as _VisualizationSelector
    # Assign to module-level variables
    create_hybrid_retriever = _create_hybrid_retriever
    HeritageCustodianRetriever = _HeritageCustodianRetriever
    create_typedb_retriever = _create_typedb_retriever
    select_visualization = _select_visualization
    VisualizationSelector = _VisualizationSelector
    get_province_code = _get_province_code
    RETRIEVERS_AVAILABLE = True
except ImportError as e:
    logger.warning(f"Core retrievers not available: {e}")
    # Provide a fallback get_province_code that returns None
    def get_province_code(province_name: str | None) -> str | None:
        """Fallback when hybrid_retriever is not available."""
        return None
# DSPy is optional - don't block retrievers if it's missing.
# NOTE: generate_sparql/configure_dspy can therefore still be None even when
# RETRIEVERS_AVAILABLE is True — callers must check them individually.
try:
    from glam_extractor.api.dspy_sparql import generate_sparql as _generate_sparql, configure_dspy as _configure_dspy
    generate_sparql = _generate_sparql
    configure_dspy = _configure_dspy
except ImportError as e:
    logger.warning(f"DSPy SPARQL not available: {e}")
# Atomic query decomposition for geographic/type filtering
decompose_query: Any = None
DECOMPOSER_AVAILABLE = False
try:
    from atomic_decomposer import decompose_query as _decompose_query
    decompose_query = _decompose_query
    DECOMPOSER_AVAILABLE = True
    logger.info("Query decomposer loaded successfully")
except ImportError as e:
    logger.info(f"Query decomposer not available: {e}")
# Cost tracker is optional - gracefully degrades if unavailable
COST_TRACKER_AVAILABLE = False
get_tracker = None
reset_tracker = None
try:
    from cost_tracker import get_tracker as _get_tracker, reset_tracker as _reset_tracker
    get_tracker = _get_tracker
    reset_tracker = _reset_tracker
    COST_TRACKER_AVAILABLE = True
    logger.info("Cost tracker module loaded successfully")
except ImportError as e:
    logger.info(f"Cost tracker not available (optional): {e}")
# Province detection for geographic filtering.
# Lowercased spellings (Dutch, English, Frisian, hyphen variants) of the
# twelve Dutch provinces.
DUTCH_PROVINCES = {
    "noord-holland", "noordholland", "north holland", "north-holland",
    "zuid-holland", "zuidholland", "south holland", "south-holland",
    "utrecht", "gelderland", "noord-brabant", "noordbrabant", "brabant",
    "north brabant", "limburg", "overijssel", "friesland", "fryslân",
    "fryslan", "groningen", "drenthe", "flevoland", "zeeland",
}


def infer_location_level(location: str) -> str:
    """Classify a location name as 'province', 'region', or 'city'.

    Args:
        location: Free-text location name; surrounding whitespace and case
            are ignored.

    Returns:
        'province' if the name is a Dutch province, 'region' if it is a
        known sub-provincial region, 'city' for anything else (the default).
    """
    normalized = location.lower().strip()
    if normalized in DUTCH_PROVINCES:
        return "province"
    # Known sub-provincial regions (not administrative provinces)
    sub_provincial = {"randstad", "veluwe", "achterhoek", "twente", "de betuwe", "betuwe"}
    return "region" if normalized in sub_provincial else "city"
def extract_geographic_filters(question: str) -> dict[str, list[str] | None]:
    """Extract geographic filters from a question using query decomposition.

    Args:
        question: The user's natural-language question.

    Returns:
        dict with keys: region_codes, cities, institution_types — each a
        single-element list when detected, otherwise None. All keys are
        None when the decomposer is unavailable or decomposition fails.
    """
    filters: dict[str, list[str] | None] = {
        "region_codes": None,
        "cities": None,
        "institution_types": None,
    }
    # Without the optional decomposer there is nothing to extract.
    if not (DECOMPOSER_AVAILABLE and decompose_query):
        return filters
    try:
        decomposed = decompose_query(question)
        # Geographic part: decide whether the location is a province or a city.
        location = decomposed.location
        if location:
            level = infer_location_level(location)
            if level == "province":
                # Convert province name to ISO 3166-2 code for Qdrant
                # filtering, e.g. "Noord-Holland" → "NH".
                province_code = get_province_code(location)
                if province_code:
                    filters["region_codes"] = [province_code]
                    logger.info(f"Province filter: {location}{province_code}")
            elif level == "city":
                filters["cities"] = [location]
                logger.info(f"City filter: {location}")
        # Institution-type part: normalize common Dutch/English words onto
        # the canonical enum values; unknown types are upper-cased as-is.
        if decomposed.institution_type:
            type_mapping = {
                "archive": "ARCHIVE",
                "archief": "ARCHIVE",
                "archieven": "ARCHIVE",
                "museum": "MUSEUM",
                "musea": "MUSEUM",
                "museums": "MUSEUM",
                "library": "LIBRARY",
                "bibliotheek": "LIBRARY",
                "bibliotheken": "LIBRARY",
                "gallery": "GALLERY",
                "galerie": "GALLERY",
            }
            inst_type = decomposed.institution_type.lower()
            mapped_type = type_mapping.get(inst_type, inst_type.upper())
            filters["institution_types"] = [mapped_type]
            logger.info(f"Institution type filter: {mapped_type}")
    except Exception as e:
        # Best-effort: filters are an optimization, never fatal.
        logger.warning(f"Failed to extract geographic filters: {e}")
    return filters
# Configuration
class Settings:
    """Application settings from environment variables.

    All values are read once at class-definition (import) time; changing the
    environment afterwards has no effect on an already-imported module.
    """
    # API Configuration
    api_title: str = "Heritage RAG API"
    api_version: str = "1.0.0"
    debug: bool = os.getenv("DEBUG", "false").lower() == "true"
    # Valkey Cache
    valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache")
    cache_ttl: int = int(os.getenv("CACHE_TTL", "900"))  # 15 minutes
    # Qdrant Vector DB
    # Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy
    qdrant_host: str = os.getenv("QDRANT_HOST", "localhost")
    qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333"))
    qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true"
    qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant")
    # Multi-Embedding Support
    # Enable to use named vectors with multiple embedding models (OpenAI 1536, MiniLM 384, BGE 768)
    use_multi_embedding: bool = os.getenv("USE_MULTI_EMBEDDING", "true").lower() == "true"
    preferred_embedding_model: str | None = os.getenv("PREFERRED_EMBEDDING_MODEL", None)  # e.g., "minilm_384" or "openai_1536"
    # Oxigraph SPARQL
    # Production: Use bronhouder.nl/sparql reverse proxy
    sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql")
    # TypeDB
    # Note: TypeDB not exposed via reverse proxy - always use localhost
    typedb_host: str = os.getenv("TYPEDB_HOST", "localhost")
    typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729"))
    typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians")
    typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true"  # Default off
    # PostGIS/Geo API
    # Production: Use bronhouder.nl/api/geo reverse proxy
    postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo")
    # DuckLake Analytics
    ducklake_url: str = os.getenv("DUCKLAKE_URL", "http://localhost:8001")
    # LLM Configuration (API keys default to empty string when unset)
    anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
    openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
    huggingface_api_key: str = os.getenv("HUGGINGFACE_API_KEY", "")
    groq_api_key: str = os.getenv("GROQ_API_KEY", "")
    zai_api_token: str = os.getenv("ZAI_API_TOKEN", "")
    default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101")
    # LLM Provider: "anthropic", "openai", "huggingface", "zai" (FREE), or "groq" (FREE)
    llm_provider: str = os.getenv("LLM_PROVIDER", "anthropic")
    # Fast LM Provider for routing/extraction: "openai" (fast ~1-2s) or "zai" (FREE but slow ~13s)
    # Default to openai for speed. Set to "zai" to save costs (free but adds ~12s latency)
    fast_lm_provider: str = os.getenv("FAST_LM_PROVIDER", "openai")
    # Retrieval weights used by the RRF result merger
    vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5"))
    graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3"))
    typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2"))


# Module-level singleton used throughout this file.
settings = Settings()
# Enums and Models
class QueryIntent(str, Enum):
    """Detected query intent for routing.

    The str mixin makes values JSON-serializable and directly comparable
    to plain strings.
    """
    GEOGRAPHIC = "geographic"    # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"    # Relationships between entities
    TEMPORAL = "temporal"        # Historical, timeline queries
    SEARCH = "search"            # General text search (also the fallback intent)
    DETAIL = "detail"            # Specific entity lookup
class DataSource(str, Enum):
    """Available data sources a query can be routed to."""
    QDRANT = "qdrant"      # Vector similarity search
    SPARQL = "sparql"      # Oxigraph knowledge graph
    TYPEDB = "typedb"      # Relationship traversal
    POSTGIS = "postgis"    # Geospatial queries
    CACHE = "cache"        # Valkey semantic cache
    DUCKLAKE = "ducklake"  # SQL analytics
@dataclass
class RetrievalResult:
    """Result from a single retriever."""
    source: DataSource  # which backend produced these items
    items: list[dict[str, Any]]  # raw result payloads from the source
    score: float = 0.0  # source-level confidence (semantics vary per source)
    query_time_ms: float = 0.0  # wall-clock retrieval time in milliseconds
    metadata: dict[str, Any] = field(default_factory=dict)  # source-specific extras (e.g. generated SQL)
class QueryRequest(BaseModel):
    """RAG query request.

    `context` uses default_factory=list rather than a mutable `[]` literal —
    the idiomatic form for mutable defaults (behavior-equivalent here).
    """
    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(default_factory=list, description="Conversation history")
    sources: list[DataSource] | None = Field(
        default=None,
        description="Data sources to query. If None, auto-routes based on query intent.",
    )
    k: int = Field(default=10, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    stream: bool = Field(default=False, description="Stream response")
class QueryResponse(BaseModel):
    """RAG query response."""
    question: str  # the original question, echoed back
    sparql: str | None = None  # generated SPARQL, when applicable
    results: list[dict[str, Any]]  # merged results across sources
    visualization: dict[str, Any] | None = None  # optional visualization config
    sources_used: list[DataSource]  # sources actually queried
    cache_hit: bool = False  # whether the response came from cache
    query_time_ms: float  # total wall-clock time
    result_count: int  # len(results), for convenience
class SPARQLRequest(BaseModel):
    """SPARQL generation request."""
    question: str
    language: str = "nl"
    # default_factory=list instead of a mutable [] literal (idiomatic;
    # behavior-equivalent for Pydantic models).
    context: list[dict[str, Any]] = Field(default_factory=list)
    use_rag: bool = True
class SPARQLResponse(BaseModel):
    """SPARQL generation response."""
    sparql: str
    explanation: str
    rag_used: bool
    # default_factory=list instead of a mutable [] literal (idiomatic;
    # behavior-equivalent for Pydantic models).
    retrieved_passages: list[str] = Field(default_factory=list)
class TypeDBSearchRequest(BaseModel):
    """TypeDB search request."""
    query: str = Field(..., description="Search query (name, type, or location)")
    search_type: str = Field(
        default="semantic",
        description="Search type: semantic, name, type, or location"
    )
    k: int = Field(default=10, ge=1, le=100, description="Number of results")
class TypeDBSearchResponse(BaseModel):
    """TypeDB search response."""
    query: str  # the original query, echoed back
    search_type: str  # search type that was applied
    results: list[dict[str, Any]]  # matched entities
    result_count: int  # len(results)
    query_time_ms: float  # wall-clock search time
class PersonSearchRequest(BaseModel):
    """Person/staff search request."""
    query: str = Field(..., description="Search query for person/staff (e.g., 'Wie werkt er in het Nationaal Archief?')")
    k: int = Field(default=10, ge=1, le=100, description="Number of results to return")
    filter_custodian: str | None = Field(default=None, description="Filter by custodian slug (e.g., 'nationaal-archief')")
    only_heritage_relevant: bool = Field(default=False, description="Only return heritage-relevant staff")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use (e.g., 'minilm_384', 'openai_1536'). If None, auto-selects best available."
    )
class PersonSearchResponse(BaseModel):
    """Person/staff search response."""
    query: str  # the original query, echoed back
    results: list[dict[str, Any]]  # matched persons/staff
    result_count: int  # len(results)
    query_time_ms: float  # wall-clock search time
    collection_stats: dict[str, Any] | None = None  # optional Qdrant collection stats
    embedding_model_used: str | None = None  # embedding model actually used
class DSPyQueryRequest(BaseModel):
    """DSPy RAG query request with conversation support.

    `context` uses default_factory=list rather than a mutable `[]` literal —
    the idiomatic form for mutable defaults (behavior-equivalent here).
    """
    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Conversation history as list of {question, answer} dicts"
    )
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    llm_provider: str | None = Field(
        default=None,
        description="LLM provider to use for this request: 'zai', 'anthropic', 'huggingface', or 'openai'. If None, uses server default (LLM_PROVIDER env)."
    )
    llm_model: str | None = Field(
        default=None,
        description="Specific LLM model to use (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929', 'gpt-4o'). If None, uses provider default."
    )
class DSPyQueryResponse(BaseModel):
    """DSPy RAG query response.

    `sources_used` uses default_factory=list rather than a mutable `[]`
    literal — the idiomatic form for mutable defaults.
    """
    question: str
    resolved_question: str | None = None
    answer: str
    sources_used: list[str] = Field(default_factory=list)
    visualization: dict[str, Any] | None = None
    retrieved_results: list[dict[str, Any]] | None = None  # Raw retrieved data for frontend visualization
    query_type: str | None = None  # "person" or "institution" - helps frontend choose visualization
    query_time_ms: float = 0.0
    conversation_turn: int = 0
    embedding_model_used: str | None = None  # Which embedding model was used for the search
    llm_provider_used: str | None = None  # Which LLM provider handled this request (zai, anthropic, huggingface, openai)
    llm_model_used: str | None = None  # Which specific LLM model was used (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929')
    # Cost tracking fields (from cost_tracker module)
    timing_ms: float | None = None  # Total pipeline timing from cost tracker
    cost_usd: float | None = None  # Estimated LLM cost in USD
    timing_breakdown: dict[str, float] | None = None  # Per-stage timing breakdown
    # Cache tracking
    cache_hit: bool = False  # Whether response was served from cache
# Cache Client
class ValkeyClient:
    """Client for Valkey semantic cache API.

    Thin async HTTP wrapper over the cache service's GET/SET endpoints.
    All cache operations are best-effort: failures are logged and treated
    as cache misses so the caller's flow is never interrupted.
    """

    def __init__(self, base_url: str = settings.valkey_api_url):
        """Create a client against *base_url* (default from settings).

        NOTE: the default is captured at class-definition time — changing
        settings.valkey_api_url later will not affect new instances.
        """
        self.base_url = base_url.rstrip("/")
        # Lazily-created HTTP client; see the `client` property.
        self._client: httpx.AsyncClient | None = None

    @property
    async def client(self) -> httpx.AsyncClient:
        """Get or create async HTTP client (recreated if closed)."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=30.0)
        return self._client

    def _cache_key(self, question: str, sources: list[DataSource] | None) -> str:
        """Generate cache key from question and sources.

        The question is lowercased/stripped and the source list sorted so
        equivalent queries share a key; "auto" marks auto-routed queries.
        """
        if sources:
            sources_str = ",".join(sorted(s.value for s in sources))
        else:
            sources_str = "auto"
        key_str = f"{question.lower().strip()}:{sources_str}"
        # Truncated SHA-256 keeps keys short while remaining collision-safe
        return hashlib.sha256(key_str.encode()).hexdigest()[:32]

    async def get(self, question: str, sources: list[DataSource] | None) -> dict[str, Any] | None:
        """Get cached response, or None on miss/any failure."""
        try:
            key = self._cache_key(question, sources)
            client = await self.client
            response = await client.get(f"{self.base_url}/get/{key}")
            if response.status_code == 200:
                data = response.json()
                if data.get("value"):
                    logger.info(f"Cache hit for question: {question[:50]}...")
                    # Values are stored as JSON strings; decode back to a dict
                    return json.loads(data["value"])  # type: ignore[no-any-return]
            return None
        except Exception as e:
            # Cache is best-effort: degrade to a miss instead of failing the query
            logger.warning(f"Cache get failed: {e}")
            return None

    async def set(
        self,
        question: str,
        sources: list[DataSource] | None,
        response: dict[str, Any],
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Cache response; returns True on (apparent) success, False on error."""
        try:
            key = self._cache_key(question, sources)
            client = await self.client
            await client.post(
                f"{self.base_url}/set",
                json={
                    "key": key,
                    "value": json.dumps(response),
                    "ttl": ttl,
                },
            )
            logger.debug(f"Cached response for: {question[:50]}...")
            return True
        except Exception as e:
            logger.warning(f"Cache set failed: {e}")
            return False

    def _dspy_cache_key(
        self,
        question: str,
        language: str,
        llm_provider: str | None,
        embedding_model: str | None,
        context_hash: str | None = None,
    ) -> str:
        """Generate cache key for DSPy query responses.

        Cache key components:
        - Question text (normalized)
        - Language code
        - LLM provider (different providers give different answers)
        - Embedding model (affects retrieval results)
        - Context hash (for multi-turn conversations)
        """
        components = [
            question.lower().strip(),
            language,
            llm_provider or "default",
            embedding_model or "auto",
            context_hash or "no_context",
        ]
        key_str = ":".join(components)
        # "dspy:" prefix namespaces these keys apart from _cache_key entries
        return f"dspy:{hashlib.sha256(key_str.encode()).hexdigest()[:32]}"

    async def get_dspy(
        self,
        question: str,
        language: str,
        llm_provider: str | None,
        embedding_model: str | None,
        context: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any] | None:
        """Get cached DSPy response, or None on miss/any failure."""
        try:
            # Generate context hash if there's conversation history
            context_hash = None
            if context:
                # sort_keys makes the hash stable across dict orderings
                context_str = json.dumps(context, sort_keys=True)
                context_hash = hashlib.sha256(context_str.encode()).hexdigest()[:16]
            key = self._dspy_cache_key(question, language, llm_provider, embedding_model, context_hash)
            client = await self.client
            response = await client.get(f"{self.base_url}/get/{key}")
            if response.status_code == 200:
                data = response.json()
                if data.get("value"):
                    logger.info(f"DSPy cache hit for question: {question[:50]}...")
                    return json.loads(data["value"])  # type: ignore[no-any-return]
            return None
        except Exception as e:
            logger.warning(f"DSPy cache get failed: {e}")
            return None

    async def set_dspy(
        self,
        question: str,
        language: str,
        llm_provider: str | None,
        embedding_model: str | None,
        response: dict[str, Any],
        context: list[dict[str, Any]] | None = None,
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Cache DSPy response; returns True on success, False on error."""
        try:
            # Generate context hash if there's conversation history
            context_hash = None
            if context:
                context_str = json.dumps(context, sort_keys=True)
                context_hash = hashlib.sha256(context_str.encode()).hexdigest()[:16]
            key = self._dspy_cache_key(question, language, llm_provider, embedding_model, context_hash)
            client = await self.client
            await client.post(
                f"{self.base_url}/set",
                json={
                    "key": key,
                    "value": json.dumps(response),
                    "ttl": ttl,
                },
            )
            logger.debug(f"Cached DSPy response for: {question[:50]}...")
            return True
        except Exception as e:
            logger.warning(f"DSPy cache set failed: {e}")
            return False

    async def close(self) -> None:
        """Close HTTP client (safe to call repeatedly)."""
        if self._client:
            await self._client.aclose()
            self._client = None
# Query Router
class QueryRouter:
    """Routes queries to appropriate data sources based on intent.

    Intent detection is keyword-based (Dutch + English keywords per intent)
    with word-boundary matching, so e.g. "land" does not match inside
    "netherlands". Patterns are compiled once in __init__ so detect_intent()
    does not rebuild and re-escape every regex on each call.
    """

    def __init__(self) -> None:
        # Keyword lists per intent (kept as a plain attribute for
        # introspection/compatibility with the original interface).
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart", "where", "waar", "location", "locatie",
                "city", "stad", "country", "land", "region", "gebied",
                "coordinates", "coördinaten", "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel", "count", "aantal", "total", "totaal",
                "average", "gemiddeld", "distribution", "verdeling",
                "percentage", "statistics", "statistiek", "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd", "connected", "verbonden",
                "relationship", "relatie", "network", "netwerk",
                "parent", "child", "merged", "fusie", "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis", "timeline", "tijdlijn",
                "when", "wanneer", "founded", "opgericht", "closed", "gesloten",
                "over time", "evolution", "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie", "about", "over",
                "specific", "specifiek", "what is", "wat is",
            ],
        }
        # Preferred source order per intent (first entry = primary source).
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.DUCKLAKE, DataSource.SPARQL, DataSource.QDRANT],
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }
        # Precompile word-boundary patterns once (hoisted from detect_intent,
        # which previously escaped and matched raw strings on every call).
        self._intent_patterns = {
            intent: [re.compile(r"\b" + re.escape(keyword) + r"\b") for keyword in keywords]
            for intent, keywords in self.intent_keywords.items()
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Detect query intent from question text.

        Scores each intent by the number of its keywords present in the
        question; ties resolve in QueryIntent declaration order. Falls back
        to SEARCH when no keyword matches at all.
        """
        question_lower = question.lower()
        intent_scores = {
            intent: sum(1 for pattern in patterns if pattern.search(question_lower))
            for intent, patterns in self._intent_patterns.items()
        }
        # SEARCH has no keyword list, so it only wins via the zero-score fallback
        max_intent = max(intent_scores, key=intent_scores.get)  # type: ignore
        if intent_scores[max_intent] == 0:
            return QueryIntent.SEARCH
        return max_intent

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Get optimal sources for a query.

        Args:
            question: User's question
            requested_sources: Explicitly requested sources (overrides routing)

        Returns:
            Tuple of (detected_intent, list_of_sources)
        """
        intent = self.detect_intent(question)
        if requested_sources:
            return intent, requested_sources
        return intent, self.source_routing.get(intent, [DataSource.QDRANT])
# Multi-Source Retriever
class MultiSourceRetriever:
"""Orchestrates retrieval across multiple data sources."""
    def __init__(self) -> None:
        """Set up the cache, the router, and lazy slots for each backend."""
        self.cache = ValkeyClient()
        self.router = QueryRouter()
        # Initialize retrievers lazily: None until first accessed via the
        # corresponding property/getter, so startup never blocks on backends.
        self._qdrant: HybridRetriever | None = None
        self._typedb: TypeDBRetriever | None = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None
        self._ducklake_client: httpx.AsyncClient | None = None
    @property
    def qdrant(self) -> HybridRetriever | None:
        """Lazy-load Qdrant hybrid retriever with multi-embedding support.

        Returns None when the retriever package is unavailable or creation
        fails; a failed attempt is retried on the next access because
        _qdrant stays None.
        """
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production,
                    use_multi_embedding=settings.use_multi_embedding,
                    preferred_embedding_model=settings.preferred_embedding_model,
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant
    @property
    def typedb(self) -> TypeDBRetriever | None:
        """Lazy-load TypeDB retriever.

        Returns None when the retriever package is unavailable or creation
        fails; retried on next access while _typedb remains None.
        """
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                self._typedb = create_typedb_retriever(
                    use_production=settings.typedb_use_production  # Use TypeDB-specific setting
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb
async def _get_sparql_client(self) -> httpx.AsyncClient:
"""Get SPARQL HTTP client."""
if self._sparql_client is None or self._sparql_client.is_closed:
self._sparql_client = httpx.AsyncClient(timeout=30.0)
return self._sparql_client
async def _get_postgis_client(self) -> httpx.AsyncClient:
"""Get PostGIS HTTP client."""
if self._postgis_client is None or self._postgis_client.is_closed:
self._postgis_client = httpx.AsyncClient(timeout=30.0)
return self._postgis_client
async def _get_ducklake_client(self) -> httpx.AsyncClient:
"""Get DuckLake HTTP client."""
if self._ducklake_client is None or self._ducklake_client.is_closed:
self._ducklake_client = httpx.AsyncClient(timeout=60.0) # Longer timeout for SQL
return self._ducklake_client
async def retrieve_from_qdrant(
self,
query: str,
k: int = 10,
embedding_model: str | None = None,
region_codes: list[str] | None = None,
cities: list[str] | None = None,
institution_types: list[str] | None = None,
) -> RetrievalResult:
"""Retrieve from Qdrant vector + SPARQL hybrid search.
Args:
query: Search query
k: Number of results to return
embedding_model: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')
region_codes: Filter by province/region codes (e.g., ['NH', 'ZH'])
cities: Filter by city names (e.g., ['Amsterdam', 'Rotterdam'])
institution_types: Filter by institution types (e.g., ['ARCHIVE', 'MUSEUM'])
"""
start = asyncio.get_event_loop().time()
items = []
if self.qdrant:
try:
results = self.qdrant.search(
query,
k=k,
using=embedding_model,
region_codes=region_codes,
cities=cities,
institution_types=institution_types,
)
items = [r.to_dict() for r in results]
except Exception as e:
logger.error(f"Qdrant retrieval failed: {e}")
elapsed = (asyncio.get_event_loop().time() - start) * 1000
return RetrievalResult(
source=DataSource.QDRANT,
items=items,
score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0),
query_time_ms=elapsed,
)
async def retrieve_from_sparql(
self,
query: str,
k: int = 10,
) -> RetrievalResult:
"""Retrieve from SPARQL endpoint."""
start = asyncio.get_event_loop().time()
# Use DSPy to generate SPARQL
items = []
try:
if RETRIEVERS_AVAILABLE:
sparql_result = generate_sparql(query, language="nl", use_rag=False)
sparql_query = sparql_result.get("sparql", "")
if sparql_query:
client = await self._get_sparql_client()
response = await client.post(
settings.sparql_endpoint,
data={"query": sparql_query},
headers={"Accept": "application/sparql-results+json"},
)
if response.status_code == 200:
data = response.json()
bindings = data.get("results", {}).get("bindings", [])
items = [
{k: v.get("value") for k, v in b.items()}
for b in bindings[:k]
]
except Exception as e:
logger.error(f"SPARQL retrieval failed: {e}")
elapsed = (asyncio.get_event_loop().time() - start) * 1000
return RetrievalResult(
source=DataSource.SPARQL,
items=items,
score=1.0 if items else 0.0,
query_time_ms=elapsed,
)
async def retrieve_from_typedb(
self,
query: str,
k: int = 10,
) -> RetrievalResult:
"""Retrieve from TypeDB knowledge graph."""
start = asyncio.get_event_loop().time()
items = []
if self.typedb:
try:
results = self.typedb.semantic_search(query, k=k)
items = [r.to_dict() for r in results]
except Exception as e:
logger.error(f"TypeDB retrieval failed: {e}")
elapsed = (asyncio.get_event_loop().time() - start) * 1000
return RetrievalResult(
source=DataSource.TYPEDB,
items=items,
score=max((r.get("relevance_score", 0) for r in items), default=0),
query_time_ms=elapsed,
)
async def retrieve_from_postgis(
self,
query: str,
k: int = 10,
) -> RetrievalResult:
"""Retrieve from PostGIS geospatial database."""
start = asyncio.get_event_loop().time()
# Extract location from query for geospatial search
# This is a simplified implementation
items = []
try:
client = await self._get_postgis_client()
# Try to detect city name for bbox search
query_lower = query.lower()
# Simple city detection
cities = {
"amsterdam": {"lat": 52.3676, "lon": 4.9041},
"rotterdam": {"lat": 51.9244, "lon": 4.4777},
"den haag": {"lat": 52.0705, "lon": 4.3007},
"utrecht": {"lat": 52.0907, "lon": 5.1214},
}
for city, coords in cities.items():
if city in query_lower:
# Query PostGIS for nearby institutions
response = await client.get(
f"{settings.postgis_url}/api/institutions/nearby",
params={
"lat": coords["lat"],
"lon": coords["lon"],
"radius_km": 10,
"limit": k,
},
)
if response.status_code == 200:
items = response.json()
break
except Exception as e:
logger.error(f"PostGIS retrieval failed: {e}")
elapsed = (asyncio.get_event_loop().time() - start) * 1000
return RetrievalResult(
source=DataSource.POSTGIS,
items=items,
score=1.0 if items else 0.0,
query_time_ms=elapsed,
)
async def retrieve_from_ducklake(
self,
query: str,
k: int = 10,
) -> RetrievalResult:
"""Retrieve from DuckLake SQL analytics database.
Uses DSPy HeritageSQLGenerator to convert natural language to SQL,
then executes against the custodians table.
Args:
query: Natural language question or SQL query
k: Maximum number of results to return
Returns:
RetrievalResult with query results
"""
start = asyncio.get_event_loop().time()
items = []
metadata = {}
try:
# Import the SQL generator from dspy module
import dspy
from dspy_heritage_rag import HeritageSQLGenerator
# Initialize DSPy predictor for SQL generation
sql_generator = dspy.Predict(HeritageSQLGenerator)
# Generate SQL from natural language
sql_result = sql_generator(
question=query,
intent="statistical",
entities="",
context="",
)
sql_query = sql_result.sql
metadata["generated_sql"] = sql_query
metadata["sql_explanation"] = sql_result.explanation
# Execute SQL against DuckLake
client = await self._get_ducklake_client()
response = await client.post(
f"{settings.ducklake_url}/query",
json={"sql": sql_query},
timeout=60.0,
)
if response.status_code == 200:
data = response.json()
columns = data.get("columns", [])
rows = data.get("rows", [])
# Convert to list of dicts
items = [dict(zip(columns, row)) for row in rows[:k]]
metadata["row_count"] = data.get("row_count", len(items))
metadata["execution_time_ms"] = data.get("execution_time_ms", 0)
else:
logger.error(f"DuckLake query failed: {response.status_code} - {response.text}")
metadata["error"] = f"HTTP {response.status_code}"
except Exception as e:
logger.error(f"DuckLake retrieval failed: {e}")
metadata["error"] = str(e)
elapsed = (asyncio.get_event_loop().time() - start) * 1000
return RetrievalResult(
source=DataSource.DUCKLAKE,
items=items,
score=1.0 if items else 0.0,
query_time_ms=elapsed,
metadata=metadata,
)
async def retrieve(
self,
question: str,
sources: list[DataSource],
k: int = 10,
embedding_model: str | None = None,
region_codes: list[str] | None = None,
cities: list[str] | None = None,
institution_types: list[str] | None = None,
) -> list[RetrievalResult]:
"""Retrieve from multiple sources concurrently.
Args:
question: User's question
sources: Data sources to query
k: Number of results per source
embedding_model: Optional embedding model for Qdrant (e.g., 'minilm_384', 'openai_1536')
region_codes: Filter by province/region codes (e.g., ['NH', 'ZH']) - Qdrant only
cities: Filter by city names (e.g., ['Amsterdam']) - Qdrant only
institution_types: Filter by institution types (e.g., ['ARCHIVE']) - Qdrant only
Returns:
List of RetrievalResult from each source
"""
tasks = []
for source in sources:
if source == DataSource.QDRANT:
tasks.append(self.retrieve_from_qdrant(
question,
k,
embedding_model,
region_codes=region_codes,
cities=cities,
institution_types=institution_types,
))
elif source == DataSource.SPARQL:
tasks.append(self.retrieve_from_sparql(question, k))
elif source == DataSource.TYPEDB:
tasks.append(self.retrieve_from_typedb(question, k))
elif source == DataSource.POSTGIS:
tasks.append(self.retrieve_from_postgis(question, k))
elif source == DataSource.DUCKLAKE:
tasks.append(self.retrieve_from_ducklake(question, k))
results = await asyncio.gather(*tasks, return_exceptions=True)
# Filter out exceptions
valid_results = []
for r in results:
if isinstance(r, RetrievalResult):
valid_results.append(r)
elif isinstance(r, Exception):
logger.error(f"Retrieval task failed: {r}")
return valid_results
def merge_results(
    self,
    results: list[RetrievalResult],
    max_results: int = 20,
) -> list[dict[str, Any]]:
    """Merge and deduplicate results from multiple sources.

    Uses reciprocal rank fusion (RRF) for score combination, with a
    per-source weight, and deduplicates items by their GHCID.

    Args:
        results: Per-source retrieval results to fuse.
        max_results: Maximum number of merged items to return.

    Returns:
        Items sorted by descending fused ``_rrf_score``; each item is
        annotated with ``_sources`` (which backends returned it).
    """
    # The weight table is loop-invariant: build it once instead of
    # reconstructing the dict for every single item (as the original did
    # inside the inner loop).
    source_weights = {
        DataSource.QDRANT: settings.vector_weight,
        DataSource.SPARQL: settings.graph_weight,
        DataSource.TYPEDB: settings.typedb_weight,
        DataSource.POSTGIS: 0.3,
    }
    # Track items by GHCID for deduplication
    merged: dict[str, dict[str, Any]] = {}
    for result in results:
        weight = source_weights.get(result.source, 0.5)
        for rank, item in enumerate(result.items):
            ghcid = item.get("ghcid", item.get("id", f"unknown_{rank}"))
            if ghcid not in merged:
                merged[ghcid] = item.copy()
                merged[ghcid]["_sources"] = []
                merged[ghcid]["_rrf_score"] = 0.0
            # Reciprocal Rank Fusion; k=60 is the standard damping constant.
            rrf_score = 1.0 / (60 + rank)
            merged[ghcid]["_rrf_score"] += rrf_score * weight
            merged[ghcid]["_sources"].append(result.source.value)
    # Sort by fused RRF score, best first.
    sorted_items = sorted(
        merged.values(),
        key=lambda x: x.get("_rrf_score", 0),
        reverse=True,
    )
    return sorted_items[:max_results]
async def close(self) -> None:
    """Release the cache, async HTTP clients, and retriever handles."""
    await self.cache.close()
    # Async HTTP clients expose aclose(); skip any that were never created.
    for http_client in (self._sparql_client, self._postgis_client):
        if http_client:
            await http_client.aclose()
    # Qdrant/TypeDB retrievers close synchronously.
    for backend in (self._qdrant, self._typedb):
        if backend:
            backend.close()
def search_persons(
self,
query: str,
k: int = 10,
filter_custodian: str | None = None,
only_heritage_relevant: bool = False,
using: str | None = None,
) -> list[Any]:
"""Search for persons/staff in the heritage_persons collection.
Delegates to HybridRetriever.search_persons() if available.
Args:
query: Search query
k: Number of results
filter_custodian: Optional custodian slug to filter by
only_heritage_relevant: Only return heritage-relevant staff
using: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')
Returns:
List of RetrievedPerson objects
"""
if self.qdrant:
try:
return self.qdrant.search_persons( # type: ignore[no-any-return]
query=query,
k=k,
filter_custodian=filter_custodian,
only_heritage_relevant=only_heritage_relevant,
using=using,
)
except Exception as e:
logger.error(f"Person search failed: {e}")
return []
def get_stats(self) -> dict[str, Any]:
"""Get statistics from all retrievers.
Returns combined stats from Qdrant (including persons collection) and TypeDB.
"""
stats = {}
if self.qdrant:
try:
qdrant_stats = self.qdrant.get_stats()
stats.update(qdrant_stats)
except Exception as e:
logger.warning(f"Failed to get Qdrant stats: {e}")
if self.typedb:
try:
typedb_stats = self.typedb.get_stats()
stats["typedb"] = typedb_stats
except Exception as e:
logger.warning(f"Failed to get TypeDB stats: {e}")
return stats
# Global instances
# Populated once by the FastAPI lifespan handler at startup; all are None
# until the app has started (endpoints must check for None before use).
retriever: MultiSourceRetriever | None = None  # multi-source retrieval facade
viz_selector: VisualizationSelector | None = None  # DSPy-backed visualization selector (optional)
dspy_pipeline: Any = None  # HeritageRAGPipeline instance (loaded with optimized model)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan manager.

    Startup sequence:
      1. Create the MultiSourceRetriever facade.
      2. (If retriever modules loaded) build the VisualizationSelector.
      3. Configure DSPy's LM following the LLM_PROVIDER preference with a
         fallback chain: zai -> huggingface -> anthropic -> openai.
      4. Load the optimized HeritageRAGPipeline weights from disk if present.
      5. Eagerly warm up the embedding model to avoid cold-start latency.

    Shutdown closes the retriever and all of its clients.
    """
    global retriever, viz_selector, dspy_pipeline
    # Startup
    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()
    if RETRIEVERS_AVAILABLE:
        # Check for any available LLM API key (Anthropic preferred, OpenAI fallback)
        has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key)
        # VisualizationSelector requires DSPy - make it optional
        try:
            viz_selector = VisualizationSelector(use_dspy=has_llm_key)
        except RuntimeError as e:
            logger.warning(f"VisualizationSelector not available: {e}")
            viz_selector = None
        # Configure DSPy based on LLM_PROVIDER setting
        # Respect user's provider preference, with fallback chain
        import dspy
        llm_provider = settings.llm_provider.lower()
        logger.info(f"LLM_PROVIDER configured as: {llm_provider}")
        dspy_configured = False
        # Try Z.AI GLM-4.6 if configured as provider (FREE!)
        if llm_provider == "zai" and settings.zai_api_token:
            try:
                # Z.AI uses OpenAI-compatible API format
                lm = dspy.LM(
                    "openai/glm-4.6",
                    api_key=settings.zai_api_token,
                    api_base="https://api.z.ai/api/coding/paas/v4",
                )
                dspy.configure(lm=lm)
                logger.info("Configured DSPy with Z.AI GLM-4.6 (FREE)")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with Z.AI: {e}")
        # Try HuggingFace if configured as provider
        if not dspy_configured and llm_provider == "huggingface" and settings.huggingface_api_key:
            try:
                lm = dspy.LM("huggingface/utter-project/EuroLLM-9B-Instruct", api_key=settings.huggingface_api_key)
                dspy.configure(lm=lm)
                logger.info("Configured DSPy with HuggingFace EuroLLM-9B-Instruct")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with HuggingFace: {e}")
        # Try Anthropic if not yet configured (either as primary or fallback)
        # NOTE(review): a failed "zai" attempt does NOT fall back to Anthropic
        # here (only "anthropic"/"huggingface" providers do) — confirm whether
        # that asymmetry is intentional.
        if not dspy_configured and (llm_provider == "anthropic" or (llm_provider == "huggingface" and settings.anthropic_api_key)):
            if settings.anthropic_api_key and configure_dspy:
                try:
                    configure_dspy(
                        provider="anthropic",
                        model=settings.default_model,
                        api_key=settings.anthropic_api_key,
                    )
                    dspy_configured = True
                except Exception as e:
                    logger.warning(f"Failed to configure DSPy with Anthropic: {e}")
        # Try OpenAI as final fallback
        if not dspy_configured and settings.openai_api_key and configure_dspy:
            try:
                configure_dspy(
                    provider="openai",
                    model="gpt-4o-mini",
                    api_key=settings.openai_api_key,
                )
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with OpenAI: {e}")
        if not dspy_configured:
            logger.warning("No LLM provider configured - DSPy queries will fail")
        # Initialize optimized HeritageRAGPipeline (if DSPy is configured)
        if dspy_configured:
            try:
                from dspy_heritage_rag import HeritageRAGPipeline
                from pathlib import Path
                # Create pipeline with Qdrant retriever
                qdrant_retriever = retriever.qdrant if retriever else None
                dspy_pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)
                # Load optimized model (BootstrapFewShot: 14.3% quality improvement)
                optimized_model_path = Path(__file__).parent / "optimized_models" / "heritage_rag_bootstrap_latest.json"
                if optimized_model_path.exists():
                    dspy_pipeline.load(str(optimized_model_path))
                    logger.info(f"Loaded optimized DSPy pipeline from {optimized_model_path}")
                else:
                    logger.warning(f"Optimized model not found at {optimized_model_path}, using unoptimized pipeline")
            except Exception as e:
                logger.warning(f"Failed to initialize DSPy pipeline: {e}")
                dspy_pipeline = None
        # === HOT LOADING: Warmup embedding model to avoid cold-start latency ===
        # The sentence-transformers model takes 3-15 seconds to load on first use.
        # By loading it eagerly at startup, we eliminate this delay for users.
        if retriever.qdrant:
            logger.info("Warming up embedding model (this takes 3-15 seconds on first startup)...")
            try:
                # Trigger model load with a dummy embedding request
                _ = retriever.qdrant._get_embedding("archief warmup query")
                logger.info("✅ Embedding model warmed up - ready for fast queries!")
            except Exception as e:
                logger.warning(f"Failed to warm up embedding model: {e}")
    logger.info("Heritage RAG API started")
    yield
    # Shutdown
    logger.info("Shutting down Heritage RAG API...")
    if retriever:
        await retriever.close()
    logger.info("Heritage RAG API stopped")
# Create FastAPI app; startup/shutdown handled by the lifespan context manager.
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)
# CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# contradictory under the CORS spec (browsers reject the wildcard origin on
# credentialed requests) — confirm whether credentials are actually needed
# and, if so, pin explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API Endpoints
@app.get("/api/rag/health")
async def health_check() -> dict[str, Any]:
    """Health check for all services."""
    report: dict[str, Any] = {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": {},
    }
    svc = report["services"]
    # Qdrant is considered healthy if stats can be fetched.
    if retriever and retriever.qdrant:
        try:
            qstats = retriever.qdrant.get_stats()
            svc["qdrant"] = {
                "status": "ok",
                "vectors": qstats.get("qdrant", {}).get("vectors_count", 0),
            }
        except Exception as e:
            svc["qdrant"] = {"status": "error", "error": str(e)}
    # SPARQL: probe the endpoint root with a short timeout.
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{settings.sparql_endpoint.replace('/query', '')}")
            svc["sparql"] = {
                "status": "ok" if response.status_code < 500 else "error"
            }
    except Exception as e:
        svc["sparql"] = {"status": "error", "error": str(e)}
    # TypeDB is considered healthy if stats can be fetched.
    if retriever and retriever.typedb:
        try:
            tstats = retriever.typedb.get_stats()
            svc["typedb"] = {
                "status": "ok",
                "entities": tstats.get("entities", {}),
            }
        except Exception as e:
            svc["typedb"] = {"status": "error", "error": str(e)}
    # Aggregate: ok (no errors) / degraded (1-2 errors) / error (3+).
    error_count = sum(
        1 for s in svc.values() if isinstance(s, dict) and s.get("status") == "error"
    )
    if error_count == 0:
        report["status"] = "ok"
    elif error_count < 3:
        report["status"] = "degraded"
    else:
        report["status"] = "error"
    return report
@app.get("/api/rag/stats")
async def get_stats() -> dict[str, Any]:
    """Get retriever statistics."""
    payload: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "retrievers": {},
    }
    if retriever:
        backends = payload["retrievers"]
        # Report only the backends that were successfully initialized.
        if retriever.qdrant:
            backends["qdrant"] = retriever.qdrant.get_stats()
        if retriever.typedb:
            backends["typedb"] = retriever.typedb.get_stats()
    return payload
@app.get("/api/rag/stats/costs")
async def get_cost_stats() -> dict[str, Any]:
    """Get cost tracking session statistics.

    Returns cumulative statistics for the current session including:
    - Total LLM calls and token usage
    - Total retrieval operations and latencies
    - Estimated costs by model
    - Pipeline timing statistics

    Returns:
        Dict with cost tracker statistics or unavailable message
    """
    if not (COST_TRACKER_AVAILABLE and get_tracker):
        return {
            "available": False,
            "message": "Cost tracker module not available",
        }
    return {
        "available": True,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "session": get_tracker().get_session_summary(),
    }
@app.post("/api/rag/stats/costs/reset")
async def reset_cost_stats() -> dict[str, Any]:
    """Reset cost tracking statistics.

    Clears all accumulated statistics and starts a fresh session.
    Useful for per-conversation or per-session cost tracking.

    Returns:
        Confirmation message
    """
    if not (COST_TRACKER_AVAILABLE and reset_tracker):
        return {
            "available": False,
            "message": "Cost tracker module not available",
        }
    reset_tracker()
    return {
        "available": True,
        "message": "Cost tracking statistics reset",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
@app.get("/api/rag/embedding/models")
async def get_embedding_models() -> dict[str, Any]:
    """List available embedding models for the Qdrant collections.

    Returns information about which embedding models are available in each
    collection's named vectors, helping clients choose the right model for
    their use case.

    Returns:
        Dict with available models per collection, current settings, and recommendations
    """
    result: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "multi_embedding_enabled": settings.use_multi_embedding,
        "preferred_model": settings.preferred_embedding_model,
        "collections": {},
        "models": {
            "openai_1536": {
                "description": "OpenAI text-embedding-3-small (1536 dimensions)",
                "quality": "high",
                "cost": "paid API",
                "recommended_for": "production, high-quality semantic search",
            },
            "minilm_384": {
                "description": "sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)",
                "quality": "good",
                "cost": "free (local)",
                "recommended_for": "development, cost-sensitive deployments",
            },
            "bge_768": {
                "description": "BAAI/bge-small-en-v1.5 (768 dimensions)",
                "quality": "very good",
                "cost": "free (local)",
                "recommended_for": "balanced quality/cost, multilingual support",
            },
        },
    }
    if retriever and retriever.qdrant:
        qdrant = retriever.qdrant
        # Check if multi-embedding is enabled and get available models
        if getattr(qdrant, 'use_multi_embedding', False):
            multi = getattr(qdrant, 'multi_retriever', None)
            if multi:
                # Same probe for both collections; the original duplicated
                # this try-block verbatim per collection — loop instead.
                for collection in ("heritage_custodians", "heritage_persons"):
                    try:
                        available = multi.get_available_models(collection)
                        selected = multi.select_model(collection)
                        result["collections"][collection] = {
                            "available_models": [m.value for m in available],
                            "uses_named_vectors": multi.uses_named_vectors(collection),
                            "recommended": selected.value if selected else None,
                        }
                    except Exception as e:
                        result["collections"][collection] = {"error": str(e)}
        else:
            # Single embedding mode.
            # (Removed dead `stats = qdrant.get_stats()`: the value was
            # assigned but never used.)
            result["single_embedding_mode"] = True
            result["note"] = "Collections use single embedding vectors. Enable USE_MULTI_EMBEDDING=true to use named vectors."
    return result
class EmbeddingCompareRequest(BaseModel):
    """Request for comparing embedding models.

    Used by POST /api/rag/embedding/compare; pydantic validates the fields
    (`k` is bounded to the range [1, 20]).
    """
    # Free-text query that will be embedded by each available model.
    query: str = Field(..., description="Query to search with")
    # Qdrant collection whose named vectors are searched.
    collection: str = Field(default="heritage_persons", description="Collection to search")
    # Number of hits returned per embedding model.
    k: int = Field(default=5, ge=1, le=20, description="Number of results per model")
@app.post("/api/rag/embedding/compare")
async def compare_embedding_models(request: EmbeddingCompareRequest) -> dict[str, Any]:
    """Compare search results across different embedding models.

    Runs the same query through every available embedding model so clients
    can A/B test embedding quality: evaluate which model fits their queries,
    inspect per-model similarity differences, and choose a production model.

    Returns:
        Dict with results from each embedding model, including scores and overlap analysis
    """
    import time
    start_time = time.time()
    if not retriever or not retriever.qdrant:
        raise HTTPException(status_code=503, detail="Qdrant retriever not available")
    qdrant = retriever.qdrant
    # Both multi-embedding flags must be present and truthy.
    if not getattr(qdrant, 'use_multi_embedding', False):
        raise HTTPException(
            status_code=400,
            detail="Multi-embedding mode not enabled. Set USE_MULTI_EMBEDDING=true to use this endpoint."
        )
    multi = getattr(qdrant, 'multi_retriever', None)
    if not multi:
        raise HTTPException(status_code=503, detail="Multi-embedding retriever not initialized")
    try:
        # Delegate the per-model fan-out to MultiEmbeddingRetriever.
        comparison = multi.compare_models(
            query=request.query,
            collection=request.collection,
            k=request.k,
        )
        return {
            "query": request.query,
            "collection": request.collection,
            "k": request.k,
            "query_time_ms": round((time.time() - start_time) * 1000, 2),
            "comparison": comparison,
        }
    except Exception as e:
        logger.exception(f"Embedding comparison failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/rag/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest) -> QueryResponse:
    """Main RAG query endpoint.

    Orchestrates retrieval from multiple sources, merges results,
    and optionally generates visualization configuration.

    Flow: cache lookup -> intent routing -> geographic filter extraction ->
    concurrent retrieval -> RRF merge -> optional visualization -> cache store.
    """
    if not retriever:
        raise HTTPException(status_code=503, detail="Retriever not initialized")
    start_time = asyncio.get_event_loop().time()
    # Check cache first
    cached = await retriever.cache.get(request.question, request.sources)
    if cached:
        cached["cache_hit"] = True
        return QueryResponse(**cached)
    # Route query to appropriate sources
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    logger.info(f"Query intent: {intent}, sources: {sources}")
    # Extract geographic filters from question (province, city, institution type)
    geo_filters = extract_geographic_filters(request.question)
    if any(geo_filters.values()):
        logger.info(f"Geographic filters extracted: {geo_filters}")
    # Retrieve from all sources concurrently
    results = await retriever.retrieve(
        request.question,
        sources,
        request.k,
        embedding_model=request.embedding_model,
        region_codes=geo_filters["region_codes"],
        cities=geo_filters["cities"],
        institution_types=geo_filters["institution_types"],
    )
    # Merge results (fetch 2x k so fusion has room to rerank)
    merged_items = retriever.merge_results(results, max_results=request.k * 2)
    # Generate visualization config if requested
    visualization = None
    if request.include_visualization and viz_selector and merged_items:
        # Extract schema from first result, hiding internal "_"-prefixed fields
        schema_fields = list(merged_items[0].keys()) if merged_items else []
        schema_str = ", ".join(f for f in schema_fields if not f.startswith("_"))
        visualization = viz_selector.select(
            request.question,
            schema_str,
            len(merged_items),
        )
    elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000
    response_data = {
        "question": request.question,
        "sparql": None,  # Could be populated from SPARQL result
        "results": merged_items,
        "visualization": visualization,
        # Fixed idiom: identity list comprehension `[s for s in sources]`
        # replaced with list() (ruff C416).
        "sources_used": list(sources),
        "cache_hit": False,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged_items),
    }
    # Cache the response for subsequent identical queries
    await retriever.cache.set(request.question, request.sources, response_data)
    return QueryResponse(**response_data)  # type: ignore[arg-type]
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
async def generate_sparql_endpoint(request: SPARQLRequest) -> SPARQLResponse:
    """Generate SPARQL query from natural language.

    Uses DSPy with optional RAG enhancement for context.

    Raises:
        HTTPException: 503 when the generator is unavailable; 500 when
            generation fails (the cause is chained for debugging).
    """
    if not RETRIEVERS_AVAILABLE:
        raise HTTPException(status_code=503, detail="SPARQL generator not available")
    try:
        result = generate_sparql(
            request.question,
            language=request.language,
            context=request.context,
            use_rag=request.use_rag,
        )
        return SPARQLResponse(
            sparql=result["sparql"],
            explanation=result.get("explanation", ""),
            rag_used=result.get("rag_used", False),
            retrieved_passages=result.get("retrieved_passages", []),
        )
    except Exception as e:
        logger.exception("SPARQL generation failed")
        # Chain the original exception so tracebacks keep the root cause (B904).
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/rag/visualize")
async def get_visualization_config(
    question: str = Query(..., description="User's question"),
    schema: str = Query(..., description="Comma-separated field names"),
    result_count: int = Query(default=0, description="Number of results"),
) -> dict[str, Any]:
    """Get visualization configuration for a query."""
    if not viz_selector:
        # Selector is only created at startup when DSPy is available.
        raise HTTPException(status_code=503, detail="Visualization selector not available")
    return viz_selector.select(question, schema, result_count)  # type: ignore[no-any-return]
@app.post("/api/rag/typedb/search", response_model=TypeDBSearchResponse)
async def typedb_search(request: TypeDBSearchRequest) -> TypeDBSearchResponse:
    """Direct TypeDB search endpoint.

    Search heritage custodians in TypeDB using various strategies:
    - semantic: Natural language search (combines type + location patterns)
    - name: Search by institution name
    - type: Search by institution type (museum, archive, library, gallery)
    - location: Search by city/location name

    Examples:
        - {"query": "museums in Amsterdam", "search_type": "semantic"}
        - {"query": "Rijksmuseum", "search_type": "name"}
        - {"query": "archive", "search_type": "type"}
        - {"query": "Rotterdam", "search_type": "location"}
    """
    import time
    # perf_counter is monotonic; time.time() can jump (NTP/DST) and yield
    # negative or wildly wrong latencies.
    start_time = time.perf_counter()
    # Check if TypeDB retriever is available
    if not retriever or not retriever.typedb:
        raise HTTPException(
            status_code=503,
            detail="TypeDB retriever not available. Ensure TypeDB is running."
        )
    try:
        typedb_retriever = retriever.typedb
        # Route to appropriate search method
        if request.search_type == "name":
            results = typedb_retriever.search_by_name(request.query, k=request.k)
        elif request.search_type == "type":
            results = typedb_retriever.search_by_type(request.query, k=request.k)
        elif request.search_type == "location":
            results = typedb_retriever.search_by_location(city=request.query, k=request.k)
        else:  # semantic (default)
            results = typedb_retriever.semantic_search(request.query, k=request.k)
        # Convert results to dicts, deduplicating by institution name
        result_dicts = []
        seen_names = set()
        for r in results:
            # Handle both dict and object results
            if hasattr(r, 'to_dict'):
                item = r.to_dict()
            elif isinstance(r, dict):
                item = r
            else:
                item = {"name": str(r)}
            # Deduplicate by name (fall back to observed_name)
            name = item.get("name") or item.get("observed_name", "")
            if name and name not in seen_names:
                seen_names.add(name)
                result_dicts.append(item)
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        return TypeDBSearchResponse(
            query=request.query,
            search_type=request.search_type,
            results=result_dicts,
            result_count=len(result_dicts),
            query_time_ms=round(elapsed_ms, 2),
        )
    except Exception as e:
        logger.exception(f"TypeDB search failed: {e}")
        # Chain the cause so the 500 keeps its traceback context (B904).
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/rag/persons/search", response_model=PersonSearchResponse)
async def person_search(request: PersonSearchRequest) -> PersonSearchResponse:
    """Search for persons/staff in heritage institutions.

    Queries the heritage_persons Qdrant collection for staff members at
    heritage custodian institutions, using semantic vector similarity over
    name, role, headline, and custodian affiliation.

    Examples:
        - {"query": "Wie werkt er in het Nationaal Archief?"}
        - {"query": "archivist at Rijksmuseum", "k": 20}
        - {"query": "conservator", "filter_custodian": "rijksmuseum"}
        - {"query": "digital preservation", "only_heritage_relevant": true}
    """
    import time
    started = time.time()
    if not retriever:
        raise HTTPException(
            status_code=503,
            detail="Hybrid retriever not available. Ensure Qdrant is running."
        )

    def _as_dict(r: Any) -> dict[str, Any]:
        # Normalize heterogeneous result objects into plain dicts.
        if hasattr(r, 'to_dict'):
            return r.to_dict()
        if hasattr(r, '__dict__'):
            return {
                "name": getattr(r, 'name', 'Unknown'),
                "headline": getattr(r, 'headline', None),
                "custodian_name": getattr(r, 'custodian_name', None),
                "custodian_slug": getattr(r, 'custodian_slug', None),
                "linkedin_url": getattr(r, 'linkedin_url', None),
                "heritage_relevant": getattr(r, 'heritage_relevant', None),
                "heritage_type": getattr(r, 'heritage_type', None),
                "location": getattr(r, 'location', None),
                "score": getattr(r, 'combined_score', getattr(r, 'vector_score', None)),
            }
        if isinstance(r, dict):
            return r
        return {"name": str(r)}

    try:
        # Delegate to the hybrid retriever's person search
        results = retriever.search_persons(
            query=request.query,
            k=request.k,
            filter_custodian=request.filter_custodian,
            only_heritage_relevant=request.only_heritage_relevant,
            using=request.embedding_model,  # Pass embedding model
        )
        # Work out which embedding model actually served the query
        embedding_model_used = None
        qdrant = retriever.qdrant
        if qdrant and getattr(qdrant, 'use_multi_embedding', False):
            if request.embedding_model:
                embedding_model_used = request.embedding_model
            else:
                selected = getattr(qdrant, '_selected_multi_model', None)
                if selected:
                    embedding_model_used = selected.value
        result_dicts = [_as_dict(r) for r in results]
        elapsed_ms = (time.time() - started) * 1000
        # Attach collection stats (persons subset only) on a best-effort basis
        stats = None
        try:
            stats = retriever.get_stats()
            if stats and 'persons' in stats:
                stats = {'persons': stats['persons']}
        except Exception:
            pass
        return PersonSearchResponse(
            query=request.query,
            results=result_dicts,
            result_count=len(result_dicts),
            query_time_ms=round(elapsed_ms, 2),
            collection_stats=stats,
            embedding_model_used=embedding_model_used,
        )
    except Exception as e:
        logger.exception(f"Person search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse)
async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse:
"""DSPy RAG query endpoint with multi-turn conversation support.
Uses the HeritageRAGPipeline for conversation-aware question answering.
Follow-up questions like "Welke daarvan behoren archieven?" will be
resolved using previous conversation context.
Args:
request: Query request with question, language, and conversation context
Returns:
DSPyQueryResponse with answer, resolved question, and optional visualization
"""
import time
start_time = time.time()
# Check cache first (before expensive LLM configuration)
if retriever:
cached = await retriever.cache.get_dspy(
question=request.question,
language=request.language,
llm_provider=request.llm_provider,
embedding_model=request.embedding_model,
context=request.context if request.context else None,
)
if cached:
elapsed_ms = (time.time() - start_time) * 1000
logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")
# Add cache_hit flag and update timing
cached["query_time_ms"] = round(elapsed_ms, 2)
cached["cache_hit"] = True
return DSPyQueryResponse(**cached)
try:
# Import DSPy pipeline and History
import dspy
from dspy import History
from dspy_heritage_rag import HeritageRAGPipeline
# Configure DSPy LM per-request based on request.llm_provider (or server default)
# This allows frontend to switch LLM providers dynamically
#
# IMPORTANT: We use dspy.settings.context() instead of dspy.configure() because
# configure() can only be called from the same async task that initially configured DSPy.
# context() provides thread-local overrides that work correctly in async request handlers.
requested_provider = (request.llm_provider or settings.llm_provider).lower()
llm_provider_used: str | None = None
llm_model_used: str | None = None
lm = None
logger.info(f"LLM provider requested: {requested_provider} (request.llm_provider={request.llm_provider}, server default={settings.llm_provider})")
# Provider configuration priority: requested provider first, then fallback chain
providers_to_try = [requested_provider]
# Add fallback chain (but not duplicates)
for fallback in ["zai", "groq", "anthropic", "openai"]:
if fallback not in providers_to_try:
providers_to_try.append(fallback)
for provider in providers_to_try:
if lm is not None:
break
# Default models per provider (used if request.llm_model is not specified)
default_models = {
"zai": "glm-4.6",
"groq": "llama-3.1-8b-instant",
"anthropic": "claude-sonnet-4-20250514",
"openai": "gpt-4o-mini",
"huggingface": "meta-llama/Llama-3.1-8B-Instruct",
}
# HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct)
# Groq models use simple names (e.g., llama-3.1-8b-instant)
model_prefixes = {
"glm-": "zai",
"llama-3.1-": "groq",
"llama-3.3-": "groq",
"claude-": "anthropic",
"gpt-": "openai",
# HuggingFace organization prefixes
"mistralai/": "huggingface",
"google/": "huggingface",
"Qwen/": "huggingface",
"deepseek-ai/": "huggingface",
"meta-llama/": "huggingface",
"utter-project/": "huggingface",
"microsoft/": "huggingface",
"tiiuae/": "huggingface",
}
# Determine which model to use: requested model (if valid for this provider) or default
requested_model = request.llm_model
model_to_use = default_models.get(provider, "")
# Check if requested model matches this provider
if requested_model:
for prefix, model_provider in model_prefixes.items():
if requested_model.startswith(prefix) and model_provider == provider:
model_to_use = requested_model
break
if provider == "zai" and settings.zai_api_token:
try:
lm = dspy.LM(
f"openai/{model_to_use}",
api_key=settings.zai_api_token,
api_base="https://api.z.ai/api/coding/paas/v4",
)
llm_provider_used = "zai"
llm_model_used = model_to_use
logger.info(f"Using Z.AI {model_to_use} (FREE) for this request")
except Exception as e:
logger.warning(f"Failed to create Z.AI LM: {e}")
elif provider == "groq" and settings.groq_api_key:
try:
lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
llm_provider_used = "groq"
llm_model_used = model_to_use
logger.info(f"Using Groq {model_to_use} (FREE) for this request")
except Exception as e:
logger.warning(f"Failed to create Groq LM: {e}")
elif provider == "huggingface" and settings.huggingface_api_key:
try:
lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
llm_provider_used = "huggingface"
llm_model_used = model_to_use
logger.info(f"Using HuggingFace {model_to_use} for this request")
except Exception as e:
logger.warning(f"Failed to create HuggingFace LM: {e}")
elif provider == "anthropic" and settings.anthropic_api_key:
try:
lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
llm_provider_used = "anthropic"
llm_model_used = model_to_use
logger.info(f"Using Anthropic {model_to_use} for this request")
except Exception as e:
logger.warning(f"Failed to create Anthropic LM: {e}")
elif provider == "openai" and settings.openai_api_key:
try:
lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
llm_provider_used = "openai"
llm_model_used = model_to_use
logger.info(f"Using OpenAI {model_to_use} for this request")
except Exception as e:
logger.warning(f"Failed to create OpenAI LM: {e}")
# No LM could be configured
if lm is None:
raise ValueError(
f"No LLM could be configured. Requested provider: {requested_provider}. "
"Ensure the appropriate API key is set: ZAI_API_TOKEN, GROQ_API_KEY, ANTHROPIC_API_KEY, HUGGINGFACE_API_KEY, or OPENAI_API_KEY."
)
logger.info(f"LLM provider for this request: {llm_provider_used}")
# =================================================================
# PERFORMANCE OPTIMIZATION: Create fast LM for routing/extraction
# Use a fast, cheap model (glm-4.5-flash FREE, gpt-4o-mini $0.15/1M)
# for routing, entity extraction, and SPARQL generation.
# The quality_lm (lm) is used only for final answer generation.
# This can reduce total latency by 2-3x (from ~20s to ~7s).
# =================================================================
fast_lm = None
# Try to create fast_lm based on FAST_LM_PROVIDER setting
# Options: "openai" (fast ~1-2s, $0.15/1M) or "zai" (FREE but slow ~13s)
# Default: openai for speed. Override with FAST_LM_PROVIDER=zai to save costs.
if settings.fast_lm_provider == "openai" and settings.openai_api_key:
try:
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
logger.info("Using OpenAI GPT-4o-mini as fast_lm for routing/extraction (~1-2s)")
except Exception as e:
logger.warning(f"Failed to create fast OpenAI LM: {e}")
if fast_lm is None and settings.fast_lm_provider == "zai" and settings.zai_api_token:
try:
fast_lm = dspy.LM(
"openai/glm-4.5-flash",
api_key=settings.zai_api_token,
api_base="https://api.z.ai/api/coding/paas/v4",
)
logger.info("Using Z.AI GLM-4.5-flash (FREE) as fast_lm for routing/extraction (~13s)")
except Exception as e:
logger.warning(f"Failed to create fast Z.AI LM: {e}")
# Fallback: try the other provider if preferred one failed
if fast_lm is None and settings.openai_api_key:
try:
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
logger.info("Fallback: Using OpenAI GPT-4o-mini as fast_lm")
except Exception as e:
logger.warning(f"Fallback failed - no fast_lm available: {e}")
if fast_lm is None:
logger.info("No fast_lm available - all stages will use quality_lm (slower but works)")
# Convert context to DSPy History format
# Context comes as [{question: "...", answer: "..."}, ...]
# History expects messages in the same format: [{question: "...", answer: "..."}, ...]
# (NOT role/content format - that was a bug!)
history_messages = []
for turn in request.context:
# Only include turns that have both question AND answer
if turn.get("question") and turn.get("answer"):
history_messages.append({
"question": turn["question"],
"answer": turn["answer"]
})
history = History(messages=history_messages) if history_messages else None
# Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
# Falls back to creating a new pipeline if global not available
if dspy_pipeline is not None:
pipeline = dspy_pipeline
logger.debug("Using global optimized DSPy pipeline")
else:
# Fallback: create pipeline without optimized weights
qdrant_retriever = retriever.qdrant if retriever else None
pipeline = HeritageRAGPipeline(
retriever=qdrant_retriever,
fast_lm=fast_lm,
quality_lm=lm,
)
logger.debug("Using fallback (unoptimized) DSPy pipeline")
# Execute query with conversation history
# Retry logic for transient API errors (e.g., Anthropic "Overloaded" errors)
#
# IMPORTANT: We use dspy.settings.context(lm=lm) to set the LLM for this request.
# This provides thread-local overrides that work correctly in async request handlers,
# unlike dspy.configure() which can only be called from the main async task.
max_retries = 3
last_error: Exception | None = None
result = None
with dspy.settings.context(lm=lm):
for attempt in range(max_retries):
try:
# Use pipeline() instead of pipeline.forward() per DSPy 3.0 best practices
result = pipeline(
embedding_model=request.embedding_model,
question=request.question,
language=request.language,
history=history,
include_viz=request.include_visualization,
)
break # Success, exit retry loop
except Exception as e:
last_error = e
error_str = str(e).lower()
# Check for retryable errors (API overload, rate limits, temporary failures)
is_retryable = any(keyword in error_str for keyword in [
"overloaded", "rate_limit", "rate limit", "too many requests",
"529", "503", "502", "504", # HTTP status codes
"temporarily unavailable", "service unavailable",
"connection reset", "connection refused", "timeout"
])
if is_retryable and attempt < max_retries - 1:
wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
logger.warning(
f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}. "
f"Retrying in {wait_time}s..."
)
time.sleep(wait_time)
continue
else:
# Non-retryable error or max retries reached
raise
# If we get here without a result (all retries exhausted), raise the last error
if result is None:
if last_error:
raise last_error
raise HTTPException(status_code=500, detail="Pipeline execution failed with no result")
elapsed_ms = (time.time() - start_time) * 1000
# Extract visualization if present
visualization = None
if request.include_visualization and hasattr(result, "visualization"):
viz = result.visualization
if viz:
visualization = {
"type": getattr(viz, "viz_type", "table"),
"sparql_query": getattr(result, "sparql", None),
}
# Extract retrieved results for frontend visualization (tables, graphs)
retrieved_results = getattr(result, "retrieved_results", None)
query_type = getattr(result, "query_type", None)
# Build response object
response = DSPyQueryResponse(
question=request.question,
resolved_question=getattr(result, "resolved_question", None),
answer=getattr(result, "answer", "Geen antwoord gevonden."),
sources_used=getattr(result, "sources_used", []),
visualization=visualization,
retrieved_results=retrieved_results, # Raw data for frontend visualization
query_type=query_type, # "person" or "institution"
query_time_ms=round(elapsed_ms, 2),
conversation_turn=len(request.context),
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
# Cost tracking fields
timing_ms=getattr(result, "timing_ms", None),
cost_usd=getattr(result, "cost_usd", None),
timing_breakdown=getattr(result, "timing_breakdown", None),
# LLM provider tracking
llm_provider_used=llm_provider_used,
llm_model_used=llm_model_used,
cache_hit=False,
)
# Cache the successful response for future requests
if retriever:
await retriever.cache.set_dspy(
question=request.question,
language=request.language,
llm_provider=llm_provider_used, # Use actual provider, not requested
embedding_model=request.embedding_model,
response=response.model_dump(),
context=request.context if request.context else None,
)
return response
except ImportError as e:
logger.warning(f"DSPy pipeline not available: {e}")
# Fallback to simple response
return DSPyQueryResponse(
question=request.question,
answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.",
query_time_ms=0,
conversation_turn=len(request.context),
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
)
except Exception as e:
logger.exception("DSPy query failed")
raise HTTPException(status_code=500, detail=str(e))
async def stream_dspy_query_response(
    request: DSPyQueryRequest,
) -> AsyncIterator[str]:
    """Stream DSPy query response with progress updates for long-running queries.

    Yields NDJSON lines with status updates at each pipeline stage:
    - {"type": "status", "stage": "cache", "message": "🔍 Cache controleren..."}
    - {"type": "status", "stage": "config", "message": "⚙️ LLM configureren..."}
    - {"type": "status", "stage": "routing", "message": "🧭 Vraag analyseren..."}
    - {"type": "status", "stage": "retrieval", "message": "📊 Database doorzoeken..."}
    - {"type": "status", "stage": "generation", "message": "💡 Antwoord genereren..."}
    - {"type": "complete", "data": {...DSPyQueryResponse...}}

    Failures are reported as {"type": "error", ...} lines instead of being
    raised, so the frontend always receives a terminal event on the stream.
    """
    import time
    start_time = time.time()

    def emit_status(stage: str, message: str) -> str:
        """Helper to emit a status JSON line (includes elapsed time since start)."""
        return json.dumps({
            "type": "status",
            "stage": stage,
            "message": message,
            "elapsed_ms": round((time.time() - start_time) * 1000, 2),
        }) + "\n"

    def emit_error(error: str, details: str | None = None) -> str:
        """Helper to emit an error JSON line (includes elapsed time since start)."""
        return json.dumps({
            "type": "error",
            "error": error,
            "details": details,
            "elapsed_ms": round((time.time() - start_time) * 1000, 2),
        }) + "\n"

    def extract_user_friendly_error(exception: Exception) -> tuple[str, str | None]:
        """Extract a user-friendly (Dutch) error message from various exception types.

        Matches known provider/network failure signatures case-insensitively;
        anything unrecognized falls through to a truncated generic message.

        Returns:
            tuple: (user_message, technical_details)
        """
        error_str = str(exception)
        error_lower = error_str.lower()
        # HuggingFace / LiteLLM specific errors
        # NOTE(review): the bare "hf" substring test is broad and may match
        # unrelated error text — confirm that is acceptable.
        if "huggingface" in error_lower or "hf" in error_lower:
            if "model_not_supported" in error_lower or "not a chat model" in error_lower:
                # Extract model name if present
                import re
                model_match = re.search(r"model['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", error_str)
                model_name = model_match.group(1) if model_match else "geselecteerde model"
                return (
                    f"Het model '{model_name}' wordt niet ondersteund door HuggingFace. Kies een ander model.",
                    error_str
                )
            if "rate limit" in error_lower or "too many requests" in error_lower:
                return (
                    "HuggingFace API limiet bereikt. Probeer het over een minuut opnieuw.",
                    error_str
                )
            if "unauthorized" in error_lower or "invalid api key" in error_lower:
                return (
                    "HuggingFace API sleutel ongeldig. Neem contact op met de beheerder.",
                    error_str
                )
            # Precedence note: `and` binds tighter than `or`, so this reads as
            # "model is loading" OR ("loading" AND "model").
            if "model is loading" in error_lower or "loading" in error_lower and "model" in error_lower:
                return (
                    "Het HuggingFace model wordt geladen. Probeer het over 30 seconden opnieuw.",
                    error_str
                )
        # Anthropic errors
        if "anthropic" in error_lower:
            if "rate limit" in error_lower or "overloaded" in error_lower:
                return (
                    "Anthropic API is overbelast. Probeer het over een minuut opnieuw.",
                    error_str
                )
            if "invalid api key" in error_lower or "unauthorized" in error_lower:
                return (
                    "Anthropic API sleutel ongeldig. Neem contact op met de beheerder.",
                    error_str
                )
        # OpenAI errors
        if "openai" in error_lower:
            if "rate limit" in error_lower:
                return (
                    "OpenAI API limiet bereikt. Probeer het over een minuut opnieuw.",
                    error_str
                )
            if "invalid api key" in error_lower:
                return (
                    "OpenAI API sleutel ongeldig. Neem contact op met de beheerder.",
                    error_str
                )
        # Z.AI errors
        if "z.ai" in error_lower or "zai" in error_lower:
            if "rate limit" in error_lower or "quota" in error_lower:
                return (
                    "Z.AI API limiet bereikt. Probeer het over een minuut opnieuw.",
                    error_str
                )
        # Generic network/connection errors
        if "connection" in error_lower or "timeout" in error_lower:
            return (
                "Verbindingsfout met de AI service. Controleer uw internetverbinding en probeer het opnieuw.",
                error_str
            )
        if "503" in error_str or "service unavailable" in error_lower:
            return (
                "De AI service is tijdelijk niet beschikbaar. Probeer het over een minuut opnieuw.",
                error_str
            )
        # Qdrant/retrieval errors
        if "qdrant" in error_lower:
            return (
                "Fout bij het doorzoeken van de database. Probeer het later opnieuw.",
                error_str
            )
        # Default: return the raw error but in a nicer format.
        # Details are only attached when the message had to be truncated.
        return (
            f"Er is een fout opgetreden: {error_str[:200]}{'...' if len(error_str) > 200 else ''}",
            error_str if len(error_str) > 200 else None
        )

    # Stage 1: Check cache
    # NOTE(review): this stage runs outside the try/except below, so a cache
    # backend failure would propagate out of the generator — confirm intended.
    yield emit_status("cache", "🔍 Cache controleren...")
    if retriever:
        cached = await retriever.cache.get_dspy(
            question=request.question,
            language=request.language,
            llm_provider=request.llm_provider,
            embedding_model=request.embedding_model,
            context=request.context if request.context else None,
        )
        if cached:
            elapsed_ms = (time.time() - start_time) * 1000
            logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")
            cached["query_time_ms"] = round(elapsed_ms, 2)
            cached["cache_hit"] = True
            yield emit_status("cache", "✅ Antwoord gevonden in cache!")
            yield json.dumps({"type": "complete", "data": cached}) + "\n"
            return
    try:
        # Stage 2: Configure LLM
        yield emit_status("config", "⚙️ LLM configureren...")
        import dspy
        from dspy import History
        from dspy_heritage_rag import HeritageRAGPipeline
        requested_provider = (request.llm_provider or settings.llm_provider).lower()
        llm_provider_used: str | None = None
        llm_model_used: str | None = None
        lm = None
        # Try the requested provider first, then the remaining providers in a
        # fixed fallback order until one yields a working LM.
        providers_to_try = [requested_provider]
        for fallback in ["zai", "groq", "anthropic", "openai"]:
            if fallback not in providers_to_try:
                providers_to_try.append(fallback)
        for provider in providers_to_try:
            if lm is not None:
                break
            # Default models per provider (used if request.llm_model is not specified)
            # NOTE(review): these two dicts are rebuilt every loop iteration;
            # they are loop-invariant and could be hoisted.
            default_models = {
                "zai": "glm-4.6",
                "groq": "llama-3.1-8b-instant",
                "anthropic": "claude-sonnet-4-20250514",
                "openai": "gpt-4o-mini",
                "huggingface": "meta-llama/Llama-3.1-8B-Instruct",
            }
            # HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct)
            # Groq models use simple names (e.g., llama-3.1-8b-instant)
            model_prefixes = {
                "glm-": "zai",
                "llama-3.1-": "groq",
                "llama-3.3-": "groq",
                "claude-": "anthropic",
                "gpt-": "openai",
                # HuggingFace organization prefixes
                "mistralai/": "huggingface",
                "google/": "huggingface",
                "Qwen/": "huggingface",
                "deepseek-ai/": "huggingface",
                "meta-llama/": "huggingface",
                "utter-project/": "huggingface",
                "microsoft/": "huggingface",
                "tiiuae/": "huggingface",
            }
            # Determine which model to use: requested model (if valid for this provider) or default
            requested_model = request.llm_model
            model_to_use = default_models.get(provider, "")
            # Check if requested model matches this provider
            if requested_model:
                for prefix, model_provider in model_prefixes.items():
                    if requested_model.startswith(prefix) and model_provider == provider:
                        model_to_use = requested_model
                        break
            # Each provider branch requires the corresponding API key setting;
            # a construction failure is logged and the next provider is tried.
            if provider == "zai" and settings.zai_api_token:
                try:
                    lm = dspy.LM(
                        f"openai/{model_to_use}",
                        api_key=settings.zai_api_token,
                        api_base="https://api.z.ai/api/coding/paas/v4",
                    )
                    llm_provider_used = "zai"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create Z.AI LM: {e}")
            elif provider == "groq" and settings.groq_api_key:
                try:
                    lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
                    llm_provider_used = "groq"
                    llm_model_used = model_to_use
                    logger.info(f"Using Groq {model_to_use} (FREE) for streaming request")
                except Exception as e:
                    logger.warning(f"Failed to create Groq LM: {e}")
            elif provider == "huggingface" and settings.huggingface_api_key:
                try:
                    lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
                    llm_provider_used = "huggingface"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create HuggingFace LM: {e}")
            elif provider == "anthropic" and settings.anthropic_api_key:
                try:
                    lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
                    llm_provider_used = "anthropic"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create Anthropic LM: {e}")
            elif provider == "openai" and settings.openai_api_key:
                try:
                    lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
                    llm_provider_used = "openai"
                    llm_model_used = model_to_use
                except Exception as e:
                    logger.warning(f"Failed to create OpenAI LM: {e}")
        if lm is None:
            yield emit_error(f"Geen LLM beschikbaar. Controleer API keys.")
            return
        yield emit_status("config", f"✅ LLM geconfigureerd ({llm_provider_used})")
        # Stage 3: Prepare conversation history
        yield emit_status("routing", "🧭 Vraag analyseren...")
        # History expects {question, answer} message dicts; skip incomplete turns.
        history_messages = []
        for turn in request.context:
            if turn.get("question") and turn.get("answer"):
                history_messages.append({
                    "question": turn["question"],
                    "answer": turn["answer"]
                })
        history = History(messages=history_messages) if history_messages else None
        # Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
        if dspy_pipeline is not None:
            pipeline = dspy_pipeline
            logger.debug("Using global optimized DSPy pipeline (streaming)")
        else:
            # Fallback: create pipeline without optimized weights
            qdrant_retriever = retriever.qdrant if retriever else None
            pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)
            logger.debug("Using fallback (unoptimized) DSPy pipeline (streaming)")
        # Stage 4: Execute pipeline with status updates
        yield emit_status("retrieval", "📊 Database doorzoeken...")
        max_retries = 3
        last_error: Exception | None = None
        result = None
        # dspy.settings.context provides a thread-local LM override that is
        # safe inside async request handlers (unlike dspy.configure()).
        # The generator suspends on yield while this context is held.
        with dspy.settings.context(lm=lm):
            for attempt in range(max_retries):
                try:
                    # Emit progress for retries
                    if attempt > 0:
                        yield emit_status("retrieval", f"🔄 Opnieuw proberen ({attempt + 1}/{max_retries})...")
                    # Use pipeline() instead of pipeline.forward() per DSPy 3.0 best practices
                    result = pipeline(
                        embedding_model=request.embedding_model,
                        question=request.question,
                        language=request.language,
                        history=history,
                        include_viz=request.include_visualization,
                    )
                    break
                except Exception as e:
                    last_error = e
                    error_str = str(e).lower()
                    # Retry only on transient API failures (overload, rate
                    # limits, gateway errors, connection problems).
                    is_retryable = any(keyword in error_str for keyword in [
                        "overloaded", "rate_limit", "rate limit", "too many requests",
                        "529", "503", "502", "504",
                        "temporarily unavailable", "service unavailable",
                        "connection reset", "connection refused", "timeout"
                    ])
                    if is_retryable and attempt < max_retries - 1:
                        wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s
                        logger.warning(f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}")
                        yield emit_status("retrieval", f"⏳ API overbelast, wachten {wait_time}s...")
                        await asyncio.sleep(wait_time)
                        continue
                    else:
                        # Don't re-raise - yield error directly to ensure it reaches frontend
                        # Re-raising in async generators can fail to propagate to outer except blocks
                        logger.exception(f"Pipeline execution failed after {attempt + 1} attempts")
                        user_msg, details = extract_user_friendly_error(e)
                        yield emit_error(user_msg, details)
                        return
        if result is None:
            if last_error:
                user_msg, details = extract_user_friendly_error(last_error)
                yield emit_error(user_msg, details)
                return
            yield emit_error("Pipeline uitvoering mislukt zonder resultaat")
            return
        # Stage 5: Generate response
        yield emit_status("generation", "💡 Antwoord genereren...")
        elapsed_ms = (time.time() - start_time) * 1000
        # Optional visualization hint derived from the pipeline result.
        visualization = None
        if request.include_visualization and hasattr(result, "visualization"):
            viz = result.visualization
            if viz:
                visualization = {
                    "type": getattr(viz, "viz_type", "table"),
                    "sparql_query": getattr(result, "sparql", None),
                }
        # Raw retrieval payload passed through for frontend tables/graphs.
        retrieved_results = getattr(result, "retrieved_results", None)
        query_type = getattr(result, "query_type", None)
        response = DSPyQueryResponse(
            question=request.question,
            resolved_question=getattr(result, "resolved_question", None),
            answer=getattr(result, "answer", "Geen antwoord gevonden."),
            sources_used=getattr(result, "sources_used", []),
            visualization=visualization,
            retrieved_results=retrieved_results,
            query_type=query_type,
            query_time_ms=round(elapsed_ms, 2),
            conversation_turn=len(request.context),
            embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
            timing_ms=getattr(result, "timing_ms", None),
            cost_usd=getattr(result, "cost_usd", None),
            timing_breakdown=getattr(result, "timing_breakdown", None),
            llm_provider_used=llm_provider_used,
            llm_model_used=llm_model_used,
            cache_hit=False,
        )
        # Cache the response (keyed on the provider actually used, not requested)
        if retriever:
            await retriever.cache.set_dspy(
                question=request.question,
                language=request.language,
                llm_provider=llm_provider_used,
                embedding_model=request.embedding_model,
                response=response.model_dump(),
                context=request.context if request.context else None,
            )
        yield emit_status("complete", "✅ Klaar!")
        yield json.dumps({"type": "complete", "data": response.model_dump()}) + "\n"
    except ImportError as e:
        logger.warning(f"DSPy pipeline not available: {e}")
        yield emit_error("DSPy pipeline is niet beschikbaar.")
    except Exception as e:
        logger.exception("DSPy streaming query failed")
        user_msg, details = extract_user_friendly_error(e)
        yield emit_error(user_msg, details)
@app.post("/api/rag/dspy/query/stream")
async def dspy_query_stream(request: DSPyQueryRequest) -> StreamingResponse:
    """Streaming variant of the DSPy RAG query endpoint.

    Wraps stream_dspy_query_response() in a StreamingResponse so the client
    receives NDJSON progress lines while the pipeline runs. Stages emitted:
    cache -> config -> routing -> retrieval -> generation -> complete.
    """
    ndjson_stream = stream_dspy_query_response(request)
    return StreamingResponse(ndjson_stream, media_type="application/x-ndjson")
async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Stream a multi-source RAG query as NDJSON progress lines.

    Emits a routing status, one status + partial line per data source, and a
    final complete line with the merged results.
    """
    if not retriever:
        yield json.dumps({"error": "Retriever not initialized"})
        return

    def _line(payload: dict) -> str:
        # Serialize one NDJSON event (dict insertion order is preserved).
        return json.dumps(payload) + "\n"

    t0 = asyncio.get_event_loop().time()
    # Decide which backends to hit, then pull geographic filters
    # (province, city, institution type) out of the question text.
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    geo_filters = extract_geographic_filters(request.question)
    yield _line({
        "type": "status",
        "message": f"Routing query to {len(sources)} sources...",
        "intent": intent.value,
        "geo_filters": {k: v for k, v in geo_filters.items() if v},
    })

    # Query each source sequentially so progress can be streamed per source.
    gathered = []
    for source in sources:
        yield _line({
            "type": "status",
            "message": f"Querying {source.value}...",
        })
        hits = await retriever.retrieve(
            request.question,
            [source],
            request.k,
            embedding_model=request.embedding_model,
            region_codes=geo_filters["region_codes"],
            cities=geo_filters["cities"],
            institution_types=geo_filters["institution_types"],
        )
        gathered.extend(hits)
        yield _line({
            "type": "partial",
            "source": source.value,
            "count": len(hits[0].items) if hits else 0,
        })

    # Fuse per-source results and report total elapsed time.
    merged = retriever.merge_results(gathered)
    elapsed_ms = (asyncio.get_event_loop().time() - t0) * 1000
    yield _line({
        "type": "complete",
        "results": merged,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged),
    })
@app.post("/api/rag/query/stream")
async def query_rag_stream(request: QueryRequest) -> StreamingResponse:
    """Stream the multi-source RAG query as NDJSON progress lines."""
    body = stream_query_response(request)
    return StreamingResponse(body, media_type="application/x-ndjson")
# =============================================================================
# SEMANTIC CACHE ENDPOINTS (Qdrant-backed)
# =============================================================================
#
# High-performance semantic cache using Qdrant's HNSW vector index.
# Replaces slow client-side cosine similarity with server-side ANN search.
#
# Performance target:
# - Cache lookup: <20ms (vs 500-2000ms with client-side scan)
# - Cache store: <50ms
#
# Architecture:
# Frontend → /api/cache/lookup → Qdrant ANN search → cached response
# /api/cache/store → embed + upsert to Qdrant
# =============================================================================
# Lazy singletons for the cache backend; created on first use by the
# get_cache_qdrant_client() / get_cache_embedding_model() helpers below.
_cache_qdrant_client: Any = None
_cache_embedding_model: Any = None
CACHE_COLLECTION_NAME = "query_cache"
CACHE_EMBEDDING_DIM = 384  # all-MiniLM-L6-v2 output dimension
def get_cache_qdrant_client() -> Any:
    """Return the process-wide Qdrant client for the cache, creating it lazily.

    Connects via settings.qdrant_host/qdrant_port; the cache is co-located
    with the RAG backend, so this is a direct local connection without any
    reverse-proxy hop. Returns None when qdrant-client is not installed or
    the client cannot be constructed.
    """
    global _cache_qdrant_client
    if _cache_qdrant_client is not None:
        return _cache_qdrant_client
    try:
        from qdrant_client import QdrantClient
        client = QdrantClient(
            host=settings.qdrant_host,
            port=settings.qdrant_port,
            timeout=30,
        )
    except ImportError:
        logger.error("qdrant-client not installed")
        return None
    except Exception as e:
        logger.error(f"Failed to create Qdrant cache client: {e}")
        return None
    _cache_qdrant_client = client
    logger.info(f"Qdrant cache client: {settings.qdrant_host}:{settings.qdrant_port}")
    return _cache_qdrant_client
def get_cache_embedding_model() -> Any:
    """Return the lazily-created SentenceTransformer used for cache embeddings.

    Loads all-MiniLM-L6-v2 (384-dim) once per process. Returns None when
    sentence-transformers is not installed or the model fails to load.
    """
    global _cache_embedding_model
    if _cache_embedding_model is not None:
        return _cache_embedding_model
    try:
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    except ImportError:
        logger.error("sentence-transformers not installed")
        return None
    except Exception as e:
        logger.error(f"Failed to load cache embedding model: {e}")
        return None
    _cache_embedding_model = model
    logger.info("Loaded cache embedding model: all-MiniLM-L6-v2")
    return _cache_embedding_model
def ensure_cache_collection_exists() -> bool:
    """Create the query_cache collection in Qdrant if it does not exist yet.

    Returns True when the collection is present (pre-existing or freshly
    created with a 384-dim cosine vector config), False when the client is
    unavailable or Qdrant reports an error.
    """
    client = get_cache_qdrant_client()
    if client is None:
        return False
    try:
        from qdrant_client.models import Distance, VectorParams

        existing = {c.name for c in client.get_collections().collections}
        if CACHE_COLLECTION_NAME in existing:
            return True
        client.create_collection(
            collection_name=CACHE_COLLECTION_NAME,
            vectors_config=VectorParams(
                size=CACHE_EMBEDDING_DIM,
                distance=Distance.COSINE,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to ensure cache collection: {e}")
        return False
    logger.info(f"Created Qdrant collection: {CACHE_COLLECTION_NAME}")
    return True
# Request/Response Models for Cache API
class CacheLookupRequest(BaseModel):
    """Request body for POST /api/cache/lookup.

    Either the raw query text or a pre-computed 384-dim embedding may be
    supplied; when `embedding` is None the server embeds `query` itself.
    """
    query: str = Field(..., description="Query text to look up")
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    # Cosine-similarity cutoff for an ANN hit to count as a cache match.
    similarity_threshold: float = Field(default=0.92, description="Minimum similarity for match")
    language: str = Field(default="nl", description="Language filter")
class CacheLookupResponse(BaseModel):
    """Response body for POST /api/cache/lookup."""
    found: bool
    # Cached point payload (id, query, response, timestamps, ...) when found.
    entry: dict[str, Any] | None = None
    # Cosine similarity of the best match; 0.0 when nothing was found.
    similarity: float = 0.0
    # "semantic" for an ANN lookup, "error" on backend failure, "none" default.
    method: str = "none"
    lookup_time_ms: float = 0.0
class CacheStoreRequest(BaseModel):
    """Request body for POST /api/cache/store."""
    query: str = Field(..., description="Query text")
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    response: dict[str, Any] = Field(..., description="Response to cache")
    language: str = Field(default="nl", description="Language")
    model: str = Field(default="unknown", description="LLM model used")
    # NOTE(review): ttl_seconds is stored in the point payload, but no code in
    # this file expires entries by it — confirm an external sweeper honours it.
    ttl_seconds: int = Field(default=86400, description="Time-to-live in seconds")
class CacheStoreResponse(BaseModel):
    """Response body for POST /api/cache/store."""
    success: bool
    # UUID of the stored Qdrant point; None on failure.
    id: str | None = None
    message: str = ""
class CacheStatsResponse(BaseModel):
    """Response body for GET /api/cache/stats."""
    total_entries: int = 0
    collection_name: str = CACHE_COLLECTION_NAME
    embedding_dim: int = CACHE_EMBEDDING_DIM
    backend: str = "qdrant"
    # "ok", "no_collection", or an "error..." string on failure.
    status: str = "ok"
@app.post("/api/cache/lookup", response_model=CacheLookupResponse)
async def cache_lookup(request: CacheLookupRequest) -> CacheLookupResponse:
    """Look up a query in the semantic cache via Qdrant ANN search.

    Runs a server-side HNSW similarity search (score_threshold gated),
    replacing slow client-side cosine-similarity scans. The embedding is
    generated on the fly when the request does not supply one.
    """
    import time
    start_time = time.perf_counter()

    def _miss(method: str) -> CacheLookupResponse:
        # Uniform not-found/error response with up-to-date elapsed time.
        return CacheLookupResponse(
            found=False,
            similarity=0.0,
            method=method,
            lookup_time_ms=(time.perf_counter() - start_time) * 1000,
        )

    # Backend preconditions: collection and client must both be available.
    if not ensure_cache_collection_exists():
        return _miss("error")
    client = get_cache_qdrant_client()
    if client is None:
        return _miss("error")

    # Use the caller-provided vector, or embed the query text ourselves.
    vector = request.embedding
    if vector is None:
        model = get_cache_embedding_model()
        if model is None:
            return _miss("error")
        vector = model.encode(request.query).tolist()

    try:
        from qdrant_client.models import Filter, FieldCondition, MatchValue

        # Restrict candidates to the request's language.
        lang_filter = Filter(
            must=[
                FieldCondition(
                    key="language",
                    match=MatchValue(value=request.language),
                )
            ]
        )
        # ANN search via query_points (qdrant-client >= 1.7); only matches at
        # or above the similarity threshold are returned.
        hits = client.query_points(
            collection_name=CACHE_COLLECTION_NAME,
            query=vector,
            query_filter=lang_filter,
            limit=1,
            score_threshold=request.similarity_threshold,
        ).points
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        if not hits:
            return CacheLookupResponse(
                found=False,
                similarity=0.0,
                method="semantic",
                lookup_time_ms=elapsed_ms,
            )
        best = hits[0]
        payload = best.payload or {}
        return CacheLookupResponse(
            found=True,
            entry={
                "id": str(best.id),
                "query": payload.get("query", ""),
                "query_normalized": payload.get("query_normalized", ""),
                "response": payload.get("response", {}),
                "timestamp": payload.get("timestamp", 0),
                "hit_count": payload.get("hit_count", 0),
                "last_accessed": payload.get("last_accessed", 0),
                "language": payload.get("language", "nl"),
                "model": payload.get("model", "unknown"),
            },
            similarity=best.score,
            method="semantic",
            lookup_time_ms=elapsed_ms,
        )
    except Exception as e:
        logger.error(f"Cache lookup error: {e}")
        return _miss("error")
@app.post("/api/cache/store", response_model=CacheStoreResponse)
async def cache_store(request: CacheStoreRequest) -> CacheStoreResponse:
    """Store a query/response pair in the semantic cache.

    Embeds the query text when no vector was supplied, then upserts a single
    point (with query, response, and bookkeeping fields) into Qdrant.
    """
    import time
    import uuid

    # Backend preconditions: collection and client must both be available.
    if not ensure_cache_collection_exists():
        return CacheStoreResponse(
            success=False,
            message="Failed to ensure cache collection exists",
        )
    client = get_cache_qdrant_client()
    if client is None:
        return CacheStoreResponse(
            success=False,
            message="Qdrant client not available",
        )

    # Use the caller-provided vector, or embed the query text ourselves.
    vector = request.embedding
    if vector is None:
        model = get_cache_embedding_model()
        if model is None:
            return CacheStoreResponse(
                success=False,
                message="Embedding model not available",
            )
        vector = model.encode(request.query).tolist()

    try:
        from qdrant_client.models import PointStruct

        point_id = str(uuid.uuid4())  # fresh UUID per stored entry
        now_ms = int(time.time() * 1000)
        client.upsert(
            collection_name=CACHE_COLLECTION_NAME,
            points=[
                PointStruct(
                    id=point_id,
                    vector=vector,
                    payload={
                        "query": request.query,
                        # Lowercased/stripped copy kept for exact matching.
                        "query_normalized": request.query.lower().strip(),
                        "response": request.response,
                        "language": request.language,
                        "model": request.model,
                        "timestamp": now_ms,
                        "hit_count": 0,
                        "last_accessed": now_ms,
                        "ttl_seconds": request.ttl_seconds,
                    },
                )
            ],
        )
        logger.debug(f"Cached query: {request.query[:50]}...")
        return CacheStoreResponse(
            success=True,
            id=point_id,
            message="Stored successfully",
        )
    except Exception as e:
        logger.error(f"Cache store error: {e}")
        return CacheStoreResponse(
            success=False,
            message=str(e),
        )
@app.get("/api/cache/stats", response_model=CacheStatsResponse)
async def cache_stats() -> CacheStatsResponse:
    """Report cache entry count and backend status."""
    client = get_cache_qdrant_client()
    if client is None:
        return CacheStatsResponse(
            status="error",
            total_entries=0,
        )
    try:
        # Distinguish "collection missing" from a hard error.
        names = [c.name for c in client.get_collections().collections]
        if CACHE_COLLECTION_NAME not in names:
            return CacheStatsResponse(
                status="no_collection",
                total_entries=0,
            )
        info = client.get_collection(CACHE_COLLECTION_NAME)
        return CacheStatsResponse(
            total_entries=info.points_count,
            collection_name=CACHE_COLLECTION_NAME,
            embedding_dim=CACHE_EMBEDDING_DIM,
            backend="qdrant",
            status="ok",
        )
    except Exception as e:
        logger.error(f"Cache stats error: {e}")
        return CacheStatsResponse(
            status=f"error: {e}",
            total_entries=0,
        )
@app.delete("/api/cache/clear")
async def cache_clear() -> dict[str, Any]:
    """Drop every cache entry by deleting and recreating the collection (admin only)."""
    client = get_cache_qdrant_client()
    if client is None:
        return {"success": False, "message": "Qdrant client not available"}
    try:
        # Nothing to do when the collection was never created.
        names = [c.name for c in client.get_collections().collections]
        if CACHE_COLLECTION_NAME not in names:
            return {"success": True, "message": "Collection does not exist", "deleted": 0}
        # Remember how many points we are about to discard, then recreate.
        count = client.get_collection(CACHE_COLLECTION_NAME).points_count
        client.delete_collection(CACHE_COLLECTION_NAME)
        ensure_cache_collection_exists()
        return {"success": True, "message": f"Cleared {count} entries", "deleted": count}
    except Exception as e:
        logger.error(f"Cache clear error: {e}")
        return {"success": False, "message": str(e)}
# Main entry point: run the FastAPI app directly with uvicorn.
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces; 8003 is the RAG backend's published port.
    # Auto-reload is enabled only when debug mode is on in settings.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8003,
        reload=settings.debug,
        log_level="info",
    )