- Implemented a new script `test_pico_arabic_waqf.py` to test the GLM annotator's ability to extract person observations from Arabic historical documents. - The script includes environment variable handling for API token, structured prompts for the GLM API, and validation of extraction results. - Added comprehensive logging for API responses, extraction results, and validation errors. - Included a sample Arabic waqf text for testing purposes, following the PiCo ontology pattern.
1091 lines
38 KiB
Python
1091 lines
38 KiB
Python
from __future__ import annotations
|
|
|
|
"""
|
|
Unified RAG Backend for Heritage Custodian Data
|
|
|
|
Multi-source retrieval-augmented generation system that combines:
|
|
- Qdrant vector search (semantic similarity)
|
|
- Oxigraph SPARQL (knowledge graph queries)
|
|
- TypeDB (relationship traversal)
|
|
- PostGIS (geospatial queries)
|
|
- Valkey (semantic caching)
|
|
|
|
Architecture:
|
|
User Query → Query Analysis
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Router │
|
|
└─────┬─────┘
|
|
┌─────┬─────┼─────┬─────┐
|
|
↓ ↓ ↓ ↓ ↓
|
|
Qdrant SPARQL TypeDB PostGIS Cache
|
|
│ │ │ │ │
|
|
└─────┴─────┴─────┴─────┘
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Merger │
|
|
└─────┬─────┘
|
|
↓
|
|
DSPy Generator
|
|
↓
|
|
Visualization Selector
|
|
↓
|
|
Response (JSON/Streaming)
|
|
|
|
Features:
|
|
- Intelligent query routing to appropriate data sources
|
|
- Score fusion for multi-source results
|
|
- Semantic caching via Valkey API
|
|
- Streaming responses for long-running queries
|
|
- DSPy assertions for output validation
|
|
|
|
Endpoints:
|
|
- POST /api/rag/query - Main RAG query endpoint
|
|
- POST /api/rag/sparql - Generate SPARQL with RAG context
|
|
- GET /api/rag/health - Health check for all services
|
|
- GET /api/rag/stats - Retriever statistics
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any, AsyncIterator
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Import retrievers (with graceful fallbacks)
|
|
try:
|
|
import sys
|
|
sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src")))
|
|
from glam_extractor.api.hybrid_retriever import HybridRetriever, create_hybrid_retriever
|
|
from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever
|
|
from glam_extractor.api.typedb_retriever import TypeDBRetriever, create_typedb_retriever
|
|
from glam_extractor.api.visualization import select_visualization, VisualizationSelector
|
|
from glam_extractor.api.dspy_sparql import generate_sparql, configure_dspy
|
|
RETRIEVERS_AVAILABLE = True
|
|
except ImportError as e:
|
|
logger.warning(f"Some retrievers not available: {e}")
|
|
RETRIEVERS_AVAILABLE = False
|
|
|
|
|
|
# Configuration
|
|
class Settings:
    """Application settings read from environment variables.

    NOTE: every value is evaluated once, when this class body executes at
    import time; changing the process environment afterwards has no effect.
    """

    # API Configuration
    api_title: str = "Heritage RAG API"
    api_version: str = "1.0.0"
    debug: bool = os.getenv("DEBUG", "false").lower() == "true"

    # Valkey cache (semantic response cache exposed over HTTP)
    valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache")
    cache_ttl: int = int(os.getenv("CACHE_TTL", "900"))  # 15 minutes

    # Qdrant Vector DB
    # Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy
    qdrant_host: str = os.getenv("QDRANT_HOST", "localhost")
    qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333"))
    qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true"
    qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant")

    # Oxigraph SPARQL
    # Production: Use bronhouder.nl/sparql reverse proxy
    sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql")

    # TypeDB
    # Note: TypeDB not exposed via reverse proxy - always use localhost
    typedb_host: str = os.getenv("TYPEDB_HOST", "localhost")
    typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729"))
    typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians")
    typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true"  # Default off

    # PostGIS/Geo API
    # Production: Use bronhouder.nl/api/geo reverse proxy
    postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo")

    # LLM Configuration (Anthropic preferred, OpenAI used as fallback elsewhere)
    anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
    openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
    default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101")

    # Retrieval weights used for score fusion in MultiSourceRetriever
    vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5"))
    graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3"))
    typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2"))


# Module-level singleton used throughout the service.
settings = Settings()
|
|
|
|
|
|
# Enums and Models
|
|
class QueryIntent(str, Enum):
    """Detected query intent; QueryRouter maps each intent to data sources."""

    GEOGRAPHIC = "geographic"    # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"    # Relationships between entities
    TEMPORAL = "temporal"        # Historical, timeline queries
    SEARCH = "search"            # General text search (fallback when nothing matches)
    DETAIL = "detail"            # Specific entity lookup
|
|
|
|
|
|
class DataSource(str, Enum):
    """Available backing data sources for retrieval."""

    QDRANT = "qdrant"    # Vector similarity search
    SPARQL = "sparql"    # Oxigraph knowledge-graph queries
    TYPEDB = "typedb"    # Relationship traversal
    POSTGIS = "postgis"  # Geospatial queries
    CACHE = "cache"      # Valkey semantic cache
|
|
|
|
|
|
@dataclass
class RetrievalResult:
    """Result from a single retriever.

    Attributes:
        source: Backend that produced these items.
        items: Raw result records from that backend.
        score: Best relevance score among the items (backend-specific scale).
        query_time_ms: Wall-clock retrieval time in milliseconds.
        metadata: Optional backend-specific extras.
    """

    source: DataSource
    items: list[dict[str, Any]]
    score: float = 0.0
    query_time_ms: float = 0.0
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body for POST /api/rag/query (and its streaming variant)."""

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(default=[], description="Conversation history")
    # Explicit sources override the router's intent-based source selection.
    sources: list[DataSource] = Field(
        default=[DataSource.QDRANT, DataSource.SPARQL],
        description="Data sources to query",
    )
    k: int = Field(default=10, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    stream: bool = Field(default=False, description="Stream response")
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """Response body for POST /api/rag/query."""

    question: str
    sparql: str | None = None  # generated SPARQL, currently always None here
    results: list[dict[str, Any]]
    visualization: dict[str, Any] | None = None
    sources_used: list[DataSource]
    cache_hit: bool = False  # True when served from the Valkey cache
    query_time_ms: float
    result_count: int
|
|
|
|
|
|
class SPARQLRequest(BaseModel):
    """Request body for POST /api/rag/sparql (SPARQL generation)."""

    question: str
    language: str = "nl"
    context: list[dict[str, Any]] = []
    use_rag: bool = True  # enrich generation with retrieved context
|
|
|
|
|
|
class SPARQLResponse(BaseModel):
    """Response body for POST /api/rag/sparql."""

    sparql: str
    explanation: str
    rag_used: bool  # whether RAG context was actually applied
    retrieved_passages: list[str] = []
|
|
|
|
|
|
class DSPyQueryRequest(BaseModel):
    """Request body for POST /api/rag/dspy/query with conversation support."""

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(
        default=[],
        description="Conversation history as list of {question, answer} dicts"
    )
    include_visualization: bool = Field(default=True, description="Include visualization config")
|
|
|
|
|
|
class DSPyQueryResponse(BaseModel):
    """Response body for POST /api/rag/dspy/query."""

    question: str
    resolved_question: str | None = None  # follow-up question rewritten with history
    answer: str
    sources_used: list[str] = []
    visualization: dict[str, Any] | None = None
    query_time_ms: float = 0.0
    conversation_turn: int = 0  # number of prior turns supplied by the caller
|
|
|
|
|
|
# Cache Client
|
|
class ValkeyClient:
    """Client for the Valkey semantic cache HTTP API.

    Entries are keyed by a truncated SHA-256 digest of the normalized
    question plus the sorted set of data sources, so the same query against
    the same sources always maps to the same cache slot. All operations are
    best-effort: failures are logged and reported as a miss / False, never
    raised to the caller.
    """

    def __init__(self, base_url: str = settings.valkey_api_url):
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    @property
    async def client(self) -> httpx.AsyncClient:
        """Return the shared async HTTP client, (re)creating it when closed."""
        needs_new = self._client is None or self._client.is_closed
        if needs_new:
            self._client = httpx.AsyncClient(timeout=30.0)
        return self._client

    def _cache_key(self, question: str, sources: list[DataSource]) -> str:
        """Derive a deterministic 32-hex-char key for (question, sources)."""
        normalized_sources = ",".join(sorted(source.value for source in sources))
        raw = f"{question.lower().strip()}:{normalized_sources}"
        return hashlib.sha256(raw.encode()).hexdigest()[:32]

    async def get(self, question: str, sources: list[DataSource]) -> dict[str, Any] | None:
        """Fetch a cached response; None on miss or any cache error."""
        try:
            http = await self.client
            cache_key = self._cache_key(question, sources)
            response = await http.get(f"{self.base_url}/get/{cache_key}")

            if response.status_code == 200:
                payload = response.json()
                if payload.get("value"):
                    logger.info(f"Cache hit for question: {question[:50]}...")
                    return json.loads(payload["value"])
            return None

        except Exception as e:
            logger.warning(f"Cache get failed: {e}")
            return None

    async def set(
        self,
        question: str,
        sources: list[DataSource],
        response: dict[str, Any],
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Store a response under the derived key; returns success flag."""
        try:
            http = await self.client
            await http.post(
                f"{self.base_url}/set",
                json={
                    "key": self._cache_key(question, sources),
                    "value": json.dumps(response),
                    "ttl": ttl,
                },
            )
            logger.debug(f"Cached response for: {question[:50]}...")
            return True

        except Exception as e:
            logger.warning(f"Cache set failed: {e}")
            return False

    async def close(self):
        """Close and drop the underlying HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
|
|
|
|
|
# Query Router
|
|
class QueryRouter:
    """Routes queries to appropriate data sources based on detected intent.

    Intent detection is keyword counting over bilingual (English/Dutch)
    keyword lists; the winning intent selects an ordered list of sources.
    """

    def __init__(self):
        # Bilingual keyword lists per intent. SEARCH has no keywords on
        # purpose: it is the fallback when nothing else matches.
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart", "where", "waar", "location", "locatie",
                "city", "stad", "country", "land", "region", "gebied",
                "coordinates", "coördinaten", "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel", "count", "aantal", "total", "totaal",
                "average", "gemiddeld", "distribution", "verdeling",
                "percentage", "statistics", "statistiek", "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd", "connected", "verbonden",
                "relationship", "relatie", "network", "netwerk",
                "parent", "child", "merged", "fusie", "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis", "timeline", "tijdlijn",
                "when", "wanneer", "founded", "opgericht", "closed", "gesloten",
                "over time", "evolution", "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie", "about", "over",
                "specific", "specifiek", "what is", "wat is",
            ],
        }

        # Ordered source preference per intent.
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT],
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Detect query intent from question text via keyword counting."""
        text = question.lower()

        # Score every intent (including keyword-less ones at zero) so the
        # max() tie-break follows QueryIntent declaration order.
        scores = {intent: 0 for intent in QueryIntent}
        for intent, keywords in self.intent_keywords.items():
            scores[intent] = sum(1 for keyword in keywords if keyword in text)

        best = max(scores, key=scores.get)  # type: ignore
        # Nothing matched at all: fall back to generic search.
        return best if scores[best] > 0 else QueryIntent.SEARCH

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Get optimal sources for a query.

        Args:
            question: User's question
            requested_sources: Explicitly requested sources (overrides routing)

        Returns:
            Tuple of (detected_intent, list_of_sources)
        """
        intent = self.detect_intent(question)
        if requested_sources:
            return intent, requested_sources
        return intent, self.source_routing.get(intent, [DataSource.QDRANT])
|
|
|
|
|
|
# Multi-Source Retriever
|
|
class MultiSourceRetriever:
    """Orchestrates retrieval across multiple data sources.

    Backends (Qdrant, TypeDB) and HTTP clients (SPARQL, PostGIS) are created
    lazily on first use so the service can start even when some of them are
    unavailable; each retrieval method is best-effort and returns an empty
    item list on failure.
    """

    def __init__(self):
        self.cache = ValkeyClient()
        self.router = QueryRouter()

        # Initialize retrievers lazily
        self._qdrant: HybridRetriever | None = None
        self._typedb: TypeDBRetriever | None = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None

    @property
    def qdrant(self) -> HybridRetriever | None:
        """Lazy-load Qdrant hybrid retriever; None when unavailable."""
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant

    @property
    def typedb(self) -> TypeDBRetriever | None:
        """Lazy-load TypeDB retriever; None when unavailable."""
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                self._typedb = create_typedb_retriever(
                    use_production=settings.typedb_use_production  # Use TypeDB-specific setting
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb

    async def _get_sparql_client(self) -> httpx.AsyncClient:
        """Get (or recreate) the SPARQL HTTP client."""
        if self._sparql_client is None or self._sparql_client.is_closed:
            self._sparql_client = httpx.AsyncClient(timeout=30.0)
        return self._sparql_client

    async def _get_postgis_client(self) -> httpx.AsyncClient:
        """Get (or recreate) the PostGIS HTTP client."""
        if self._postgis_client is None or self._postgis_client.is_closed:
            self._postgis_client = httpx.AsyncClient(timeout=30.0)
        return self._postgis_client

    async def retrieve_from_qdrant(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from Qdrant vector + SPARQL hybrid search."""
        start = asyncio.get_event_loop().time()

        items = []
        if self.qdrant:
            try:
                results = self.qdrant.search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"Qdrant retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.QDRANT,
            items=items,
            # Best combined score among hits; 0 when nothing was found.
            score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_sparql(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from the SPARQL endpoint via DSPy-generated queries."""
        start = asyncio.get_event_loop().time()

        # Use DSPy to generate SPARQL from the natural-language question.
        items = []
        try:
            if RETRIEVERS_AVAILABLE:
                sparql_result = generate_sparql(query, language="nl", use_rag=False)
                sparql_query = sparql_result.get("sparql", "")

                if sparql_query:
                    client = await self._get_sparql_client()
                    response = await client.post(
                        settings.sparql_endpoint,
                        data={"query": sparql_query},
                        headers={"Accept": "application/sparql-results+json"},
                    )

                    if response.status_code == 200:
                        data = response.json()
                        bindings = data.get("results", {}).get("bindings", [])
                        # Flatten SPARQL JSON bindings to plain value dicts.
                        items = [
                            {k: v.get("value") for k, v in b.items()}
                            for b in bindings[:k]
                        ]
        except Exception as e:
            logger.error(f"SPARQL retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.SPARQL,
            items=items,
            score=1.0 if items else 0.0,  # SPARQL results carry no native score
            query_time_ms=elapsed,
        )

    async def retrieve_from_typedb(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from TypeDB knowledge graph."""
        start = asyncio.get_event_loop().time()

        items = []
        if self.typedb:
            try:
                results = self.typedb.semantic_search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"TypeDB retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.TYPEDB,
            items=items,
            score=max((r.get("relevance_score", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_postgis(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from PostGIS geospatial database.

        Simplified implementation: scans the query for a known city name and,
        on the first match, asks the geo API for institutions within 10 km.
        """
        start = asyncio.get_event_loop().time()

        items = []
        try:
            client = await self._get_postgis_client()

            query_lower = query.lower()

            # Hard-coded city centroids for simple city detection.
            cities = {
                "amsterdam": {"lat": 52.3676, "lon": 4.9041},
                "rotterdam": {"lat": 51.9244, "lon": 4.4777},
                "den haag": {"lat": 52.0705, "lon": 4.3007},
                "utrecht": {"lat": 52.0907, "lon": 5.1214},
            }

            for city, coords in cities.items():
                if city in query_lower:
                    # Query PostGIS for nearby institutions
                    response = await client.get(
                        f"{settings.postgis_url}/api/institutions/nearby",
                        params={
                            "lat": coords["lat"],
                            "lon": coords["lon"],
                            "radius_km": 10,
                            "limit": k,
                        },
                    )

                    if response.status_code == 200:
                        items = response.json()
                    break

        except Exception as e:
            logger.error(f"PostGIS retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.POSTGIS,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve(
        self,
        question: str,
        sources: list[DataSource],
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Retrieve from multiple sources concurrently.

        Args:
            question: User's question
            sources: Data sources to query
            k: Number of results per source

        Returns:
            List of RetrievalResult from each source (failed tasks are
            logged and dropped; CACHE is not a retrieval source here).
        """
        tasks = []

        for source in sources:
            if source == DataSource.QDRANT:
                tasks.append(self.retrieve_from_qdrant(question, k))
            elif source == DataSource.SPARQL:
                tasks.append(self.retrieve_from_sparql(question, k))
            elif source == DataSource.TYPEDB:
                tasks.append(self.retrieve_from_typedb(question, k))
            elif source == DataSource.POSTGIS:
                tasks.append(self.retrieve_from_postgis(question, k))

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out exceptions so one failing backend never sinks the query.
        valid_results = []
        for r in results:
            if isinstance(r, RetrievalResult):
                valid_results.append(r)
            elif isinstance(r, Exception):
                logger.error(f"Retrieval task failed: {r}")

        return valid_results

    def merge_results(
        self,
        results: list[RetrievalResult],
        max_results: int = 20,
    ) -> list[dict[str, Any]]:
        """Merge and deduplicate results from multiple sources.

        Uses weighted reciprocal rank fusion (RRF): each occurrence of an
        item contributes 1/(60 + rank) scaled by its source's weight, and
        items are deduplicated by GHCID (falling back to "id").
        """
        # Track items by GHCID for deduplication
        merged: dict[str, dict[str, Any]] = {}

        # Per-source fusion weights. Hoisted out of the item loop: the dict
        # is constant, and it was previously rebuilt for every single item.
        source_weights = {
            DataSource.QDRANT: settings.vector_weight,
            DataSource.SPARQL: settings.graph_weight,
            DataSource.TYPEDB: settings.typedb_weight,
            DataSource.POSTGIS: 0.3,
        }

        for result in results:
            for rank, item in enumerate(result.items):
                ghcid = item.get("ghcid", item.get("id", f"unknown_{rank}"))

                if ghcid not in merged:
                    merged[ghcid] = item.copy()
                    merged[ghcid]["_sources"] = []
                    merged[ghcid]["_rrf_score"] = 0.0

                # Reciprocal Rank Fusion contribution; k=60 is standard.
                rrf_score = 1.0 / (60 + rank)

                weight = source_weights.get(result.source, 0.5)
                merged[ghcid]["_rrf_score"] += rrf_score * weight
                merged[ghcid]["_sources"].append(result.source.value)

        # Sort by fused RRF score, best first.
        sorted_items = sorted(
            merged.values(),
            key=lambda x: x.get("_rrf_score", 0),
            reverse=True,
        )

        return sorted_items[:max_results]

    async def close(self):
        """Clean up cache, HTTP clients, and backend retrievers."""
        await self.cache.close()

        if self._sparql_client:
            await self._sparql_client.aclose()
        if self._postgis_client:
            await self._postgis_client.aclose()
        if self._qdrant:
            self._qdrant.close()
        if self._typedb:
            self._typedb.close()
|
|
|
|
|
|
# Global singletons: created in lifespan() at startup, torn down at shutdown.
retriever: MultiSourceRetriever | None = None
viz_selector: VisualizationSelector | None = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Startup: builds the MultiSourceRetriever, the visualization selector,
    and configures DSPy with whichever LLM key is available (Anthropic
    preferred, OpenAI fallback). Shutdown: closes the retriever's resources.
    """
    global retriever, viz_selector

    # Startup
    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()

    if RETRIEVERS_AVAILABLE:
        # Check for any available LLM API key (Anthropic preferred, OpenAI fallback)
        has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key)
        viz_selector = VisualizationSelector(use_dspy=has_llm_key)

        # Configure DSPy if API key available; failure is non-fatal.
        if settings.anthropic_api_key:
            try:
                configure_dspy(
                    provider="anthropic",
                    model=settings.default_model,
                    api_key=settings.anthropic_api_key,
                )
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with Anthropic: {e}")
        elif settings.openai_api_key:
            try:
                configure_dspy(
                    provider="openai",
                    model="gpt-4o-mini",
                    api_key=settings.openai_api_key,
                )
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with OpenAI: {e}")

    logger.info("Heritage RAG API started")

    yield

    # Shutdown
    logger.info("Shutting down Heritage RAG API...")
    if retriever:
        await retriever.close()
    logger.info("Heritage RAG API stopped")
|
|
|
|
|
|
# Create FastAPI app
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests and is overly permissive —
# consider an explicit origin allow-list for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# API Endpoints
|
|
@app.get("/api/rag/health")
async def health_check():
    """Health check for all services.

    Probes Qdrant and TypeDB when their retrievers are initialized, and the
    SPARQL endpoint always. Overall status is "ok" with zero failures,
    "error" when every checked service failed, otherwise "degraded".
    """
    health = {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": {},
    }

    # Check Qdrant (only when the lazy retriever initialized successfully)
    if retriever and retriever.qdrant:
        try:
            stats = retriever.qdrant.get_stats()
            health["services"]["qdrant"] = {
                "status": "ok",
                "vectors": stats.get("qdrant", {}).get("vectors_count", 0),
            }
        except Exception as e:
            health["services"]["qdrant"] = {"status": "error", "error": str(e)}

    # Check SPARQL: any non-5xx response from the endpoint root counts as up.
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{settings.sparql_endpoint.replace('/query', '')}")
            health["services"]["sparql"] = {
                "status": "ok" if response.status_code < 500 else "error"
            }
    except Exception as e:
        health["services"]["sparql"] = {"status": "error", "error": str(e)}

    # Check TypeDB
    if retriever and retriever.typedb:
        try:
            stats = retriever.typedb.get_stats()
            health["services"]["typedb"] = {
                "status": "ok",
                "entities": stats.get("entities", {}),
            }
        except Exception as e:
            health["services"]["typedb"] = {"status": "error", "error": str(e)}

    # Overall status. Fixed: compare failures against the number of services
    # actually checked rather than a hard-coded 3 — previously, when fewer
    # than 3 services were checked and ALL of them failed, the endpoint
    # still reported "degraded" instead of "error".
    errors = sum(1 for s in health["services"].values() if s.get("status") == "error")
    checked = len(health["services"])
    if errors == 0:
        health["status"] = "ok"
    elif errors < checked:
        health["status"] = "degraded"
    else:
        health["status"] = "error"

    return health
|
|
|
|
|
|
@app.get("/api/rag/stats")
async def get_stats():
    """Return statistics from whichever retrievers are initialized."""
    stats = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "retrievers": {},
    }

    if retriever:
        qdrant_backend = retriever.qdrant
        if qdrant_backend:
            stats["retrievers"]["qdrant"] = qdrant_backend.get_stats()
        typedb_backend = retriever.typedb
        if typedb_backend:
            stats["retrievers"]["typedb"] = typedb_backend.get_stats()

    return stats
|
|
|
|
|
|
@app.post("/api/rag/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Main RAG query endpoint.

    Orchestrates retrieval from multiple sources, merges results,
    and optionally generates visualization configuration.

    Flow: cache lookup → intent routing → concurrent retrieval →
    RRF merge → visualization selection → cache store.

    Raises:
        HTTPException: 503 when the retriever was never initialized.
    """
    if not retriever:
        raise HTTPException(status_code=503, detail="Retriever not initialized")

    start_time = asyncio.get_event_loop().time()

    # Check cache first; the key is built from the REQUESTED sources, so
    # get and set below stay consistent even when routing picks others.
    cached = await retriever.cache.get(request.question, request.sources)
    if cached:
        cached["cache_hit"] = True
        return QueryResponse(**cached)

    # Route query to appropriate sources (explicit request.sources wins)
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    logger.info(f"Query intent: {intent}, sources: {sources}")

    # Retrieve from all sources concurrently
    results = await retriever.retrieve(request.question, sources, request.k)

    # Merge results (RRF fusion, deduplicated by GHCID)
    merged_items = retriever.merge_results(results, max_results=request.k * 2)

    # Generate visualization config if requested
    visualization = None
    if request.include_visualization and viz_selector and merged_items:
        # Extract schema from first result, dropping internal "_"-prefixed fields
        schema_fields = list(merged_items[0].keys()) if merged_items else []
        schema_str = ", ".join(f for f in schema_fields if not f.startswith("_"))

        visualization = viz_selector.select(
            request.question,
            schema_str,
            len(merged_items),
        )

    elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000

    response_data = {
        "question": request.question,
        "sparql": None,  # Could be populated from SPARQL result
        "results": merged_items,
        "visualization": visualization,
        "sources_used": [s for s in sources],
        "cache_hit": False,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged_items),
    }

    # Cache the response (best-effort; failures are logged inside)
    await retriever.cache.set(request.question, request.sources, response_data)

    return QueryResponse(**response_data)
|
|
|
|
|
|
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
async def generate_sparql_endpoint(request: SPARQLRequest):
    """Generate SPARQL query from natural language.

    Uses DSPy with optional RAG enhancement for context.

    Raises:
        HTTPException: 503 when the generator is unavailable,
            500 when generation itself fails.
    """
    if not RETRIEVERS_AVAILABLE:
        raise HTTPException(status_code=503, detail="SPARQL generator not available")

    try:
        result = generate_sparql(
            request.question,
            language=request.language,
            context=request.context,
            use_rag=request.use_rag,
        )

        return SPARQLResponse(
            sparql=result["sparql"],
            explanation=result.get("explanation", ""),
            rag_used=result.get("rag_used", False),
            retrieved_passages=result.get("retrieved_passages", []),
        )

    except Exception as e:
        logger.exception("SPARQL generation failed")
        # Chain the original exception so the root cause survives in
        # tracebacks (raise-without-from inside except loses it).
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/api/rag/visualize")
async def get_visualization_config(
    question: str = Query(..., description="User's question"),
    schema: str = Query(..., description="Comma-separated field names"),
    result_count: int = Query(default=0, description="Number of results"),
):
    """Return the visualization configuration chosen for a query."""
    if not viz_selector:
        raise HTTPException(status_code=503, detail="Visualization selector not available")
    return viz_selector.select(question, schema, result_count)
|
|
|
|
|
|
@app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse)
async def dspy_query(request: DSPyQueryRequest):
    """DSPy RAG query endpoint with multi-turn conversation support.

    Uses the HeritageRAGPipeline for conversation-aware question answering.
    Follow-up questions like "Welke daarvan behoren archieven?" will be
    resolved using previous conversation context.

    Args:
        request: Query request with question, language, and conversation context

    Returns:
        DSPyQueryResponse with answer, resolved question, and optional visualization

    Raises:
        HTTPException: 500 on any failure other than a missing DSPy pipeline
            (a missing pipeline falls back to a friendly Dutch message).
    """
    import time
    start_time = time.time()

    try:
        # Import DSPy pipeline and History lazily so the rest of the API
        # works even when the dspy packages are not installed.
        import dspy
        from dspy import History
        from dspy_heritage_rag import HeritageRAGPipeline

        # Ensure DSPy has an LM configured
        # Check if LM is already configured by testing if we can get the settings
        try:
            current_lm = dspy.settings.lm
            if current_lm is None:
                raise ValueError("No LM configured")
        except (AttributeError, ValueError):
            # No LM configured yet - try to configure one
            api_key = settings.anthropic_api_key or os.getenv("ANTHROPIC_API_KEY", "")
            if api_key:
                lm = dspy.LM("anthropic/claude-sonnet-4-20250514", api_key=api_key)
                dspy.configure(lm=lm)
                logger.info("Configured DSPy with Anthropic Claude")
            else:
                # Try OpenAI as fallback
                openai_key = os.getenv("OPENAI_API_KEY", "")
                if openai_key:
                    lm = dspy.LM("openai/gpt-4o-mini", api_key=openai_key)
                    dspy.configure(lm=lm)
                    logger.info("Configured DSPy with OpenAI GPT-4o-mini")
                else:
                    raise ValueError(
                        "No LLM API key found. Set ANTHROPIC_API_KEY or OPENAI_API_KEY environment variable."
                    )

        # Convert context to DSPy History format
        # Context comes as [{question: "...", answer: "..."}, ...]
        # History expects messages with role and content
        history_messages = []
        for turn in request.context:
            if turn.get("question"):
                history_messages.append({"role": "user", "content": turn["question"]})
            if turn.get("answer"):
                history_messages.append({"role": "assistant", "content": turn["answer"]})

        history = History(messages=history_messages) if history_messages else None

        # Initialize pipeline (could be cached globally for performance)
        pipeline = HeritageRAGPipeline()

        # Execute query with conversation history
        result = pipeline.forward(
            question=request.question,
            language=request.language,
            history=history,
            include_viz=request.include_visualization,
        )

        elapsed_ms = (time.time() - start_time) * 1000

        # Extract visualization if present; pipeline result shape is duck-typed,
        # hence the getattr defaults throughout.
        visualization = None
        if request.include_visualization and hasattr(result, "visualization"):
            viz = result.visualization
            if viz:
                visualization = {
                    "type": getattr(viz, "viz_type", "table"),
                    "sparql_query": getattr(result, "sparql", None),
                }

        return DSPyQueryResponse(
            question=request.question,
            resolved_question=getattr(result, "resolved_question", None),
            answer=getattr(result, "answer", "Geen antwoord gevonden."),
            sources_used=getattr(result, "sources_used", []),
            visualization=visualization,
            query_time_ms=round(elapsed_ms, 2),
            conversation_turn=len(request.context),
        )

    except ImportError as e:
        logger.warning(f"DSPy pipeline not available: {e}")
        # Fallback to simple response (in Dutch, pointing at the plain endpoint)
        return DSPyQueryResponse(
            question=request.question,
            answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.",
            query_time_ms=0,
            conversation_turn=len(request.context),
        )

    except Exception as e:
        logger.exception("DSPy query failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Stream query response for long-running queries.

    Yields newline-delimited JSON objects: "status" messages while routing
    and querying each source sequentially, a "partial" record per source,
    and a final "complete" record with the merged results.
    """
    if not retriever:
        yield json.dumps({"error": "Retriever not initialized"})
        return

    start_time = asyncio.get_event_loop().time()

    # Route query (explicit request.sources overrides intent routing)
    intent, sources = retriever.router.get_sources(request.question, request.sources)

    yield json.dumps({
        "type": "status",
        "message": f"Routing query to {len(sources)} sources...",
        "intent": intent.value,
    }) + "\n"

    # Retrieve from sources one at a time so progress can be streamed
    results = []
    for source in sources:
        yield json.dumps({
            "type": "status",
            "message": f"Querying {source.value}...",
        }) + "\n"

        source_results = await retriever.retrieve(request.question, [source], request.k)
        results.extend(source_results)

        yield json.dumps({
            "type": "partial",
            "source": source.value,
            # retrieve() may drop a failed source, hence the guard
            "count": len(source_results[0].items) if source_results else 0,
        }) + "\n"

    # Merge and finalize
    merged = retriever.merge_results(results)
    elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000

    yield json.dumps({
        "type": "complete",
        "results": merged,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged),
    }) + "\n"
|
|
|
|
|
|
@app.post("/api/rag/query/stream")
async def query_rag_stream(request: QueryRequest):
    """Streaming (NDJSON) version of the RAG query endpoint."""
    body_iterator = stream_query_response(request)
    return StreamingResponse(body_iterator, media_type="application/x-ndjson")
|
|
|
|
|
|
# Main entry point
|
|
if __name__ == "__main__":
    # Local/dev entry point; auto-reload follows the DEBUG env flag.
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8003,
        reload=settings.debug,
        log_level="info",
    )
|