1661 lines
61 KiB
Python
1661 lines
61 KiB
Python
from __future__ import annotations
|
|
|
|
"""
|
|
Unified RAG Backend for Heritage Custodian Data
|
|
|
|
Multi-source retrieval-augmented generation system that combines:
|
|
- Qdrant vector search (semantic similarity)
|
|
- Oxigraph SPARQL (knowledge graph queries)
|
|
- TypeDB (relationship traversal)
|
|
- PostGIS (geospatial queries)
|
|
- Valkey (semantic caching)
|
|
|
|
Architecture:
|
|
User Query → Query Analysis
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Router │
|
|
└─────┬─────┘
|
|
┌─────┬─────┼─────┬─────┐
|
|
↓ ↓ ↓ ↓ ↓
|
|
Qdrant SPARQL TypeDB PostGIS Cache
|
|
│ │ │ │ │
|
|
└─────┴─────┴─────┴─────┘
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Merger │
|
|
└─────┬─────┘
|
|
↓
|
|
DSPy Generator
|
|
↓
|
|
Visualization Selector
|
|
↓
|
|
Response (JSON/Streaming)
|
|
|
|
Features:
|
|
- Intelligent query routing to appropriate data sources
|
|
- Score fusion for multi-source results
|
|
- Semantic caching via Valkey API
|
|
- Streaming responses for long-running queries
|
|
- DSPy assertions for output validation
|
|
|
|
Endpoints:
|
|
- POST /api/rag/query - Main RAG query endpoint
|
|
- POST /api/rag/sparql - Generate SPARQL with RAG context
|
|
- POST /api/rag/typedb/search - Direct TypeDB search
|
|
- GET /api/rag/health - Health check for all services
|
|
- GET /api/rag/stats - Retriever statistics
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any, AsyncIterator, TYPE_CHECKING
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Type hints for optional imports (only used during type checking)
|
|
if TYPE_CHECKING:
|
|
from glam_extractor.api.hybrid_retriever import HybridRetriever
|
|
from glam_extractor.api.typedb_retriever import TypeDBRetriever
|
|
from glam_extractor.api.visualization import VisualizationSelector
|
|
|
|
# Import retrievers (with graceful fallbacks)
|
|
RETRIEVERS_AVAILABLE = False
|
|
create_hybrid_retriever: Any = None
|
|
HeritageCustodianRetriever: Any = None
|
|
create_typedb_retriever: Any = None
|
|
select_visualization: Any = None
|
|
VisualizationSelector: Any = None # type: ignore[no-redef]
|
|
generate_sparql: Any = None
|
|
configure_dspy: Any = None
|
|
|
|
try:
|
|
import sys
|
|
sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src")))
|
|
from glam_extractor.api.hybrid_retriever import HybridRetriever as _HybridRetriever, create_hybrid_retriever as _create_hybrid_retriever
|
|
from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever as _HeritageCustodianRetriever
|
|
from glam_extractor.api.typedb_retriever import TypeDBRetriever as _TypeDBRetriever, create_typedb_retriever as _create_typedb_retriever
|
|
from glam_extractor.api.visualization import select_visualization as _select_visualization, VisualizationSelector as _VisualizationSelector
|
|
# Assign to module-level variables
|
|
create_hybrid_retriever = _create_hybrid_retriever
|
|
HeritageCustodianRetriever = _HeritageCustodianRetriever
|
|
create_typedb_retriever = _create_typedb_retriever
|
|
select_visualization = _select_visualization
|
|
VisualizationSelector = _VisualizationSelector
|
|
RETRIEVERS_AVAILABLE = True
|
|
except ImportError as e:
|
|
logger.warning(f"Core retrievers not available: {e}")
|
|
|
|
# DSPy is optional - don't block retrievers if it's missing
|
|
try:
|
|
from glam_extractor.api.dspy_sparql import generate_sparql as _generate_sparql, configure_dspy as _configure_dspy
|
|
generate_sparql = _generate_sparql
|
|
configure_dspy = _configure_dspy
|
|
except ImportError as e:
|
|
logger.warning(f"DSPy SPARQL not available: {e}")
|
|
|
|
# Cost tracker is optional - gracefully degrades if unavailable
|
|
COST_TRACKER_AVAILABLE = False
|
|
get_tracker = None
|
|
reset_tracker = None
|
|
try:
|
|
from cost_tracker import get_tracker as _get_tracker, reset_tracker as _reset_tracker
|
|
get_tracker = _get_tracker
|
|
reset_tracker = _reset_tracker
|
|
COST_TRACKER_AVAILABLE = True
|
|
logger.info("Cost tracker module loaded successfully")
|
|
except ImportError as e:
|
|
logger.info(f"Cost tracker not available (optional): {e}")
|
|
|
|
|
|
# Configuration
|
|
class Settings:
|
|
"""Application settings from environment variables."""
|
|
|
|
# API Configuration
|
|
api_title: str = "Heritage RAG API"
|
|
api_version: str = "1.0.0"
|
|
debug: bool = os.getenv("DEBUG", "false").lower() == "true"
|
|
|
|
# Valkey Cache
|
|
valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache")
|
|
cache_ttl: int = int(os.getenv("CACHE_TTL", "900")) # 15 minutes
|
|
|
|
# Qdrant Vector DB
|
|
# Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy
|
|
qdrant_host: str = os.getenv("QDRANT_HOST", "localhost")
|
|
qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333"))
|
|
qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true"
|
|
qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant")
|
|
|
|
# Multi-Embedding Support
|
|
# Enable to use named vectors with multiple embedding models (OpenAI 1536, MiniLM 384, BGE 768)
|
|
use_multi_embedding: bool = os.getenv("USE_MULTI_EMBEDDING", "true").lower() == "true"
|
|
preferred_embedding_model: str | None = os.getenv("PREFERRED_EMBEDDING_MODEL", None) # e.g., "minilm_384" or "openai_1536"
|
|
|
|
# Oxigraph SPARQL
|
|
# Production: Use bronhouder.nl/sparql reverse proxy
|
|
sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql")
|
|
|
|
# TypeDB
|
|
# Note: TypeDB not exposed via reverse proxy - always use localhost
|
|
typedb_host: str = os.getenv("TYPEDB_HOST", "localhost")
|
|
typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729"))
|
|
typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians")
|
|
typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true" # Default off
|
|
|
|
# PostGIS/Geo API
|
|
# Production: Use bronhouder.nl/api/geo reverse proxy
|
|
postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo")
|
|
|
|
# LLM Configuration
|
|
anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
|
|
openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
|
|
default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101")
|
|
|
|
# Retrieval weights
|
|
vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5"))
|
|
graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3"))
|
|
typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2"))
|
|
|
|
|
|
settings = Settings()
|
|
|
|
|
|
# Enums and Models
|
|
class QueryIntent(str, Enum):
    """Detected query intent for routing.

    Each intent maps to an ordered list of data sources in
    QueryRouter.source_routing; SEARCH is the fallback when no
    intent keywords match.
    """
    GEOGRAPHIC = "geographic"  # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"  # Relationships between entities
    TEMPORAL = "temporal"  # Historical, timeline queries
    SEARCH = "search"  # General text search
    DETAIL = "detail"  # Specific entity lookup
|
|
|
|
|
|
class DataSource(str, Enum):
    """Available data sources.

    Used for routing requests, weighting merged scores, and tagging
    which backend produced each result.
    """
    QDRANT = "qdrant"  # vector similarity search
    SPARQL = "sparql"  # Oxigraph knowledge-graph queries
    TYPEDB = "typedb"  # relationship traversal
    POSTGIS = "postgis"  # geospatial queries
    CACHE = "cache"  # Valkey semantic cache
|
|
|
|
|
|
@dataclass
class RetrievalResult:
    """Result from a single retriever."""
    source: DataSource  # backend that produced these items
    items: list[dict[str, Any]]  # raw result records from that backend
    score: float = 0.0  # best relevance score among items (source-specific scale)
    query_time_ms: float = 0.0  # wall-clock retrieval time in milliseconds
    metadata: dict[str, Any] = field(default_factory=dict)  # optional extra info
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """RAG query request.

    Mutable defaults (`context`, `sources`) use ``default_factory`` so each
    request instance gets its own list.
    """
    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(default_factory=list, description="Conversation history")
    sources: list[DataSource] = Field(
        default_factory=lambda: [DataSource.QDRANT, DataSource.SPARQL],
        description="Data sources to query",
    )
    k: int = Field(default=10, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    # BUG FIX: this field was previously declared twice (before and after
    # `stream`); the redundant duplicate declaration has been removed.
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    stream: bool = Field(default=False, description="Stream response")
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """RAG query response."""
    question: str  # echoed input question
    sparql: str | None = None  # generated SPARQL query, when one was produced
    results: list[dict[str, Any]]  # merged result records across sources
    visualization: dict[str, Any] | None = None  # visualization config, if requested
    sources_used: list[DataSource]  # backends that were queried
    cache_hit: bool = False  # True when served from the semantic cache
    query_time_ms: float  # total query wall-clock time in milliseconds
    result_count: int  # number of returned results
|
|
|
|
|
|
class SPARQLRequest(BaseModel):
    """SPARQL generation request."""
    question: str  # natural-language question to translate into SPARQL
    language: str = "nl"  # language code (nl or en)
    context: list[dict[str, Any]] = []  # conversation history
    use_rag: bool = True  # whether to use RAG context during generation
|
|
|
|
|
|
class SPARQLResponse(BaseModel):
    """SPARQL generation response."""
    sparql: str  # generated SPARQL query text
    explanation: str  # human-readable explanation of the generated query
    rag_used: bool  # whether RAG context was actually used
    retrieved_passages: list[str] = []  # supporting passages, when RAG was used
|
|
|
|
|
|
class TypeDBSearchRequest(BaseModel):
    """TypeDB search request.

    ``search_type`` selects the TypeDB lookup strategy; see the field
    description for accepted values.
    """
    query: str = Field(..., description="Search query (name, type, or location)")
    search_type: str = Field(
        default="semantic",
        description="Search type: semantic, name, type, or location"
    )
    k: int = Field(default=10, ge=1, le=100, description="Number of results")
|
|
|
|
|
|
class TypeDBSearchResponse(BaseModel):
    """TypeDB search response."""
    query: str  # echoed search query
    search_type: str  # echoed search strategy
    results: list[dict[str, Any]]  # matched records from TypeDB
    result_count: int  # number of returned results
    query_time_ms: float  # search wall-clock time in milliseconds
|
|
|
|
|
|
class PersonSearchRequest(BaseModel):
    """Person/staff search request (heritage_persons collection)."""
    query: str = Field(..., description="Search query for person/staff (e.g., 'Wie werkt er in het Nationaal Archief?')")
    k: int = Field(default=10, ge=1, le=100, description="Number of results to return")
    filter_custodian: str | None = Field(default=None, description="Filter by custodian slug (e.g., 'nationaal-archief')")
    only_heritage_relevant: bool = Field(default=False, description="Only return heritage-relevant staff")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use (e.g., 'minilm_384', 'openai_1536'). If None, auto-selects best available."
    )
|
|
|
|
|
|
class PersonSearchResponse(BaseModel):
    """Person/staff search response."""
    query: str  # echoed search query
    results: list[dict[str, Any]]  # matched person records
    result_count: int  # number of returned results
    query_time_ms: float  # search wall-clock time in milliseconds
    collection_stats: dict[str, Any] | None = None  # optional collection statistics
    embedding_model_used: str | None = None  # model actually used for the search
|
|
|
|
|
|
class DSPyQueryRequest(BaseModel):
    """DSPy RAG query request with conversation support."""
    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(
        default=[],
        description="Conversation history as list of {question, answer} dicts"
    )
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
|
|
|
|
|
|
class DSPyQueryResponse(BaseModel):
    """DSPy RAG query response."""
    question: str  # original question as submitted
    resolved_question: str | None = None  # question after conversation-context resolution
    answer: str  # generated natural-language answer
    sources_used: list[str] = []  # names of backends that contributed
    visualization: dict[str, Any] | None = None  # visualization config, if requested
    retrieved_results: list[dict[str, Any]] | None = None  # Raw retrieved data for frontend visualization
    query_type: str | None = None  # "person" or "institution" - helps frontend choose visualization
    query_time_ms: float = 0.0  # total pipeline wall-clock time
    conversation_turn: int = 0  # index of this turn in the conversation
    embedding_model_used: str | None = None  # Which embedding model was used for the search

    # Cost tracking fields (from cost_tracker module)
    timing_ms: float | None = None  # Total pipeline timing from cost tracker
    cost_usd: float | None = None  # Estimated LLM cost in USD
    timing_breakdown: dict[str, float] | None = None  # Per-stage timing breakdown
|
|
|
|
|
|
# Cache Client
|
|
class ValkeyClient:
    """Client for Valkey semantic cache API.

    Talks to the cache service over HTTP; all failures are swallowed and
    logged so cache outages never break a query.
    """

    def __init__(self, base_url: str = settings.valkey_api_url):
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    @property
    async def client(self) -> httpx.AsyncClient:
        """Get or create async HTTP client.

        Note: this is an *async* property — callers do ``await self.client``.
        """
        current = self._client
        if current is None or current.is_closed:
            current = httpx.AsyncClient(timeout=30.0)
            self._client = current
        return current

    def _cache_key(self, question: str, sources: list[DataSource]) -> str:
        """Generate cache key from question and sources."""
        normalized_question = question.lower().strip()
        source_part = ",".join(sorted(src.value for src in sources))
        digest = hashlib.sha256(f"{normalized_question}:{source_part}".encode()).hexdigest()
        return digest[:32]

    async def get(self, question: str, sources: list[DataSource]) -> dict[str, Any] | None:
        """Get cached response; returns None on miss or any error."""
        try:
            http = await self.client
            key = self._cache_key(question, sources)
            response = await http.get(f"{self.base_url}/get/{key}")

            if response.status_code != 200:
                return None

            cached_value = response.json().get("value")
            if not cached_value:
                return None

            logger.info(f"Cache hit for question: {question[:50]}...")
            return json.loads(cached_value)

        except Exception as e:
            logger.warning(f"Cache get failed: {e}")
            return None

    async def set(
        self,
        question: str,
        sources: list[DataSource],
        response: dict[str, Any],
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Cache response; returns False (without raising) on any error."""
        try:
            http = await self.client
            payload = {
                "key": self._cache_key(question, sources),
                "value": json.dumps(response),
                "ttl": ttl,
            }
            await http.post(f"{self.base_url}/set", json=payload)
            logger.debug(f"Cached response for: {question[:50]}...")
            return True

        except Exception as e:
            logger.warning(f"Cache set failed: {e}")
            return False

    async def close(self) -> None:
        """Close HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
|
|
|
|
|
# Query Router
|
|
class QueryRouter:
    """Routes queries to appropriate data sources based on intent.

    Intent detection is simple keyword counting over a bilingual (nl/en)
    vocabulary; ties break in QueryIntent declaration order.
    """

    def __init__(self) -> None:
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart", "where", "waar", "location", "locatie",
                "city", "stad", "country", "land", "region", "gebied",
                "coordinates", "coördinaten", "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel", "count", "aantal", "total", "totaal",
                "average", "gemiddeld", "distribution", "verdeling",
                "percentage", "statistics", "statistiek", "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd", "connected", "verbonden",
                "relationship", "relatie", "network", "netwerk",
                "parent", "child", "merged", "fusie", "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis", "timeline", "tijdlijn",
                "when", "wanneer", "founded", "opgericht", "closed", "gesloten",
                "over time", "evolution", "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie", "about", "over",
                "specific", "specifiek", "what is", "wat is",
            ],
        }

        # Preferred source ordering per detected intent.
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT],
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Detect query intent from question text.

        Counts keyword hits per intent; falls back to SEARCH when nothing
        matches.
        """
        text = question.lower()

        hit_counts = {
            intent: sum(1 for keyword in keywords if keyword in text)
            for intent, keywords in self.intent_keywords.items()
        }

        # max over the enum keeps the original tie-breaking (declaration order);
        # intents with no keyword list (SEARCH) score 0.
        best = max(QueryIntent, key=lambda intent: hit_counts.get(intent, 0))

        if hit_counts.get(best, 0) == 0:
            return QueryIntent.SEARCH
        return best

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Get optimal sources for a query.

        Args:
            question: User's question
            requested_sources: Explicitly requested sources (overrides routing)

        Returns:
            Tuple of (detected_intent, list_of_sources)
        """
        intent = self.detect_intent(question)
        chosen = requested_sources or self.source_routing.get(intent, [DataSource.QDRANT])
        return intent, chosen
|
|
|
|
|
|
# Multi-Source Retriever
|
|
class MultiSourceRetriever:
    """Orchestrates retrieval across multiple data sources.

    Backend retrievers (Qdrant, TypeDB) are initialized lazily on first
    access; HTTP clients for SPARQL and PostGIS are created on demand and
    reused. Every retrieval method logs failures and degrades to an empty
    result set so a single unhealthy backend cannot fail the whole query.
    """

    def __init__(self) -> None:
        self.cache = ValkeyClient()
        self.router = QueryRouter()

        # Initialize retrievers lazily
        self._qdrant: HybridRetriever | None = None
        self._typedb: TypeDBRetriever | None = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None

    @property
    def qdrant(self) -> HybridRetriever | None:
        """Lazy-load Qdrant hybrid retriever with multi-embedding support."""
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production,
                    use_multi_embedding=settings.use_multi_embedding,
                    preferred_embedding_model=settings.preferred_embedding_model,
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant

    @property
    def typedb(self) -> TypeDBRetriever | None:
        """Lazy-load TypeDB retriever."""
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                self._typedb = create_typedb_retriever(
                    use_production=settings.typedb_use_production  # Use TypeDB-specific setting
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb

    async def _get_sparql_client(self) -> httpx.AsyncClient:
        """Get (or recreate) the shared SPARQL HTTP client."""
        if self._sparql_client is None or self._sparql_client.is_closed:
            self._sparql_client = httpx.AsyncClient(timeout=30.0)
        return self._sparql_client

    async def _get_postgis_client(self) -> httpx.AsyncClient:
        """Get (or recreate) the shared PostGIS HTTP client."""
        if self._postgis_client is None or self._postgis_client.is_closed:
            self._postgis_client = httpx.AsyncClient(timeout=30.0)
        return self._postgis_client

    async def retrieve_from_qdrant(
        self,
        query: str,
        k: int = 10,
        embedding_model: str | None = None,
    ) -> RetrievalResult:
        """Retrieve from Qdrant vector + SPARQL hybrid search.

        Args:
            query: Search query
            k: Number of results to return
            embedding_model: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')
        """
        loop = asyncio.get_running_loop()
        start = loop.time()

        items: list[dict[str, Any]] = []
        if self.qdrant:
            try:
                results = self.qdrant.search(query, k=k, using=embedding_model)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"Qdrant retrieval failed: {e}")

        elapsed = (loop.time() - start) * 1000

        return RetrievalResult(
            source=DataSource.QDRANT,
            items=items,
            # Best combined score across items, 0 when nothing was retrieved.
            score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_sparql(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from SPARQL endpoint.

        Generates a SPARQL query from the natural-language question via DSPy,
        executes it against the configured endpoint, and flattens the JSON
        bindings to plain key -> value dicts.
        """
        loop = asyncio.get_running_loop()
        start = loop.time()

        items: list[dict[str, Any]] = []
        try:
            # BUG FIX: generate_sparql comes from the *optional* DSPy import
            # block, not the core retriever imports, so gate on it directly.
            # The old guard (RETRIEVERS_AVAILABLE) could be True while
            # generate_sparql was still None, raising a TypeError on every call.
            if generate_sparql is not None:
                sparql_result = generate_sparql(query, language="nl", use_rag=False)
                sparql_query = sparql_result.get("sparql", "")

                if sparql_query:
                    client = await self._get_sparql_client()
                    response = await client.post(
                        settings.sparql_endpoint,
                        data={"query": sparql_query},
                        headers={"Accept": "application/sparql-results+json"},
                    )

                    if response.status_code == 200:
                        data = response.json()
                        bindings = data.get("results", {}).get("bindings", [])
                        # Renamed loop vars: the old `{k: v ...}` dict comp
                        # shadowed the `k` parameter.
                        items = [
                            {var: cell.get("value") for var, cell in b.items()}
                            for b in bindings[:k]
                        ]
        except Exception as e:
            logger.error(f"SPARQL retrieval failed: {e}")

        elapsed = (loop.time() - start) * 1000

        return RetrievalResult(
            source=DataSource.SPARQL,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve_from_typedb(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from TypeDB knowledge graph via semantic search."""
        loop = asyncio.get_running_loop()
        start = loop.time()

        items: list[dict[str, Any]] = []
        if self.typedb:
            try:
                results = self.typedb.semantic_search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"TypeDB retrieval failed: {e}")

        elapsed = (loop.time() - start) * 1000

        return RetrievalResult(
            source=DataSource.TYPEDB,
            items=items,
            score=max((r.get("relevance_score", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_postgis(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from PostGIS geospatial database.

        Simplified implementation: detects one of a few hard-coded Dutch
        city names in the query and asks the geo API for institutions
        within a 10 km radius of that city's center.
        """
        loop = asyncio.get_running_loop()
        start = loop.time()

        items: list[dict[str, Any]] = []
        try:
            client = await self._get_postgis_client()

            query_lower = query.lower()

            # Simple city detection (centers only; no geocoding service).
            cities = {
                "amsterdam": {"lat": 52.3676, "lon": 4.9041},
                "rotterdam": {"lat": 51.9244, "lon": 4.4777},
                "den haag": {"lat": 52.0705, "lon": 4.3007},
                "utrecht": {"lat": 52.0907, "lon": 5.1214},
            }

            for city, coords in cities.items():
                if city in query_lower:
                    # Query PostGIS for nearby institutions
                    response = await client.get(
                        f"{settings.postgis_url}/api/institutions/nearby",
                        params={
                            "lat": coords["lat"],
                            "lon": coords["lon"],
                            "radius_km": 10,
                            "limit": k,
                        },
                    )

                    if response.status_code == 200:
                        items = response.json()
                        break

        except Exception as e:
            logger.error(f"PostGIS retrieval failed: {e}")

        elapsed = (loop.time() - start) * 1000

        return RetrievalResult(
            source=DataSource.POSTGIS,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve(
        self,
        question: str,
        sources: list[DataSource],
        k: int = 10,
        embedding_model: str | None = None,
    ) -> list[RetrievalResult]:
        """Retrieve from multiple sources concurrently.

        Args:
            question: User's question
            sources: Data sources to query
            k: Number of results per source
            embedding_model: Optional embedding model for Qdrant (e.g., 'minilm_384', 'openai_1536')

        Returns:
            List of RetrievalResult from each source (failed tasks are
            logged and omitted).
        """
        tasks = []
        for source in sources:
            if source == DataSource.QDRANT:
                tasks.append(self.retrieve_from_qdrant(question, k, embedding_model))
            elif source == DataSource.SPARQL:
                tasks.append(self.retrieve_from_sparql(question, k))
            elif source == DataSource.TYPEDB:
                tasks.append(self.retrieve_from_typedb(question, k))
            elif source == DataSource.POSTGIS:
                tasks.append(self.retrieve_from_postgis(question, k))
            # DataSource.CACHE has no retrieval task here; caching is handled
            # separately via ValkeyClient.

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out exceptions so one failed source doesn't sink the query.
        valid_results: list[RetrievalResult] = []
        for r in results:
            if isinstance(r, RetrievalResult):
                valid_results.append(r)
            elif isinstance(r, Exception):
                logger.error(f"Retrieval task failed: {r}")

        return valid_results

    def merge_results(
        self,
        results: list[RetrievalResult],
        max_results: int = 20,
    ) -> list[dict[str, Any]]:
        """Merge and deduplicate results from multiple sources.

        Uses reciprocal rank fusion for score combination: each item gains
        ``weight / (60 + rank)`` per source that returned it, where the
        weight is source-specific.

        Args:
            results: Per-source retrieval results
            max_results: Maximum number of merged items to return

        Returns:
            Items sorted by descending fused score; each item is annotated
            with "_sources" (contributing source names) and "_rrf_score".
        """
        # Per-source fusion weights, hoisted out of the loop (previously
        # this dict was rebuilt for every single item).
        source_weights = {
            DataSource.QDRANT: settings.vector_weight,
            DataSource.SPARQL: settings.graph_weight,
            DataSource.TYPEDB: settings.typedb_weight,
            DataSource.POSTGIS: 0.3,
        }

        # Track items by GHCID for deduplication
        merged: dict[str, dict[str, Any]] = {}

        for result in results:
            weight = source_weights.get(result.source, 0.5)

            for rank, item in enumerate(result.items):
                # BUG FIX: the fallback key now includes the source, so
                # id-less items from *different* sources at the same rank are
                # no longer merged into a single record.
                ghcid = item.get("ghcid", item.get("id", f"{result.source.value}_unknown_{rank}"))

                if ghcid not in merged:
                    merged[ghcid] = item.copy()
                    merged[ghcid]["_sources"] = []
                    merged[ghcid]["_rrf_score"] = 0.0

                # Reciprocal Rank Fusion; k=60 is the standard constant.
                rrf_score = 1.0 / (60 + rank)
                merged[ghcid]["_rrf_score"] += rrf_score * weight
                merged[ghcid]["_sources"].append(result.source.value)

        # Sort by fused RRF score, best first.
        sorted_items = sorted(
            merged.values(),
            key=lambda x: x.get("_rrf_score", 0),
            reverse=True,
        )

        return sorted_items[:max_results]

    async def close(self) -> None:
        """Clean up resources (cache, HTTP clients, backend retrievers)."""
        await self.cache.close()

        if self._sparql_client:
            await self._sparql_client.aclose()
        if self._postgis_client:
            await self._postgis_client.aclose()
        if self._qdrant:
            self._qdrant.close()
        if self._typedb:
            self._typedb.close()

    def search_persons(
        self,
        query: str,
        k: int = 10,
        filter_custodian: str | None = None,
        only_heritage_relevant: bool = False,
        using: str | None = None,
    ) -> list[Any]:
        """Search for persons/staff in the heritage_persons collection.

        Delegates to HybridRetriever.search_persons() if available.

        Args:
            query: Search query
            k: Number of results
            filter_custodian: Optional custodian slug to filter by
            only_heritage_relevant: Only return heritage-relevant staff
            using: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')

        Returns:
            List of RetrievedPerson objects (empty on failure or when the
            Qdrant retriever is unavailable).
        """
        if self.qdrant:
            try:
                return self.qdrant.search_persons(  # type: ignore[no-any-return]
                    query=query,
                    k=k,
                    filter_custodian=filter_custodian,
                    only_heritage_relevant=only_heritage_relevant,
                    using=using,
                )
            except Exception as e:
                logger.error(f"Person search failed: {e}")
        return []

    def get_stats(self) -> dict[str, Any]:
        """Get statistics from all retrievers.

        Returns combined stats from Qdrant (including persons collection) and TypeDB.
        Stats failures are logged and skipped rather than raised.
        """
        stats: dict[str, Any] = {}

        if self.qdrant:
            try:
                stats.update(self.qdrant.get_stats())
            except Exception as e:
                logger.warning(f"Failed to get Qdrant stats: {e}")

        if self.typedb:
            try:
                stats["typedb"] = self.typedb.get_stats()
            except Exception as e:
                logger.warning(f"Failed to get TypeDB stats: {e}")

        return stats
|
|
|
|
|
|
# Global instances
|
|
retriever: MultiSourceRetriever | None = None
|
|
viz_selector: VisualizationSelector | None = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan manager.

    Startup: build the multi-source retriever, optionally construct the
    visualization selector, and configure DSPy with whichever LLM key is
    available (Anthropic preferred, OpenAI fallback). Shutdown: release
    all retriever resources.
    """
    global retriever, viz_selector

    # --- Startup -----------------------------------------------------------
    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()

    if RETRIEVERS_AVAILABLE:
        has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key)

        # VisualizationSelector requires DSPy - make it optional
        try:
            viz_selector = VisualizationSelector(use_dspy=has_llm_key)
        except RuntimeError as e:
            logger.warning(f"VisualizationSelector not available: {e}")
            viz_selector = None

        # Pick one DSPy provider: Anthropic first, then OpenAI.
        chosen: tuple[str, str, str, str] | None = None
        if settings.anthropic_api_key:
            chosen = ("anthropic", settings.default_model, settings.anthropic_api_key, "Anthropic")
        elif settings.openai_api_key:
            chosen = ("openai", "gpt-4o-mini", settings.openai_api_key, "OpenAI")

        if configure_dspy and chosen:
            provider, model, api_key, display_name = chosen
            try:
                configure_dspy(provider=provider, model=model, api_key=api_key)
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with {display_name}: {e}")

    logger.info("Heritage RAG API started")

    yield

    # --- Shutdown ----------------------------------------------------------
    logger.info("Shutting down Heritage RAG API...")
    if retriever:
        await retriever.close()
    logger.info("Heritage RAG API stopped")
|
|
|
|
|
|
# Create FastAPI app
|
|
# Create the FastAPI application; lifespan handles retriever setup/teardown.
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive; consider restricting origins for production — verify
# against deployment requirements.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# API Endpoints
|
|
@app.get("/api/rag/health")
async def health_check() -> dict[str, Any]:
    """Health check for all services.

    Probes Qdrant, the SPARQL endpoint, and TypeDB; overall status is
    "ok" with zero failures, "degraded" with one or two, "error" otherwise.
    """
    services: dict[str, Any] = {}
    health: dict[str, Any] = {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": services,
    }

    # Qdrant: report vector count from retriever stats.
    if retriever and retriever.qdrant:
        try:
            qdrant_stats = retriever.qdrant.get_stats()
            services["qdrant"] = {
                "status": "ok",
                "vectors": qdrant_stats.get("qdrant", {}).get("vectors_count", 0),
            }
        except Exception as e:
            services["qdrant"] = {"status": "error", "error": str(e)}

    # SPARQL: lightweight GET against the endpoint base (strip any /query suffix).
    try:
        endpoint_base = settings.sparql_endpoint.replace("/query", "")
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(endpoint_base)
            services["sparql"] = {
                "status": "ok" if response.status_code < 500 else "error"
            }
    except Exception as e:
        services["sparql"] = {"status": "error", "error": str(e)}

    # TypeDB: report entity counts from retriever stats.
    if retriever and retriever.typedb:
        try:
            typedb_stats = retriever.typedb.get_stats()
            services["typedb"] = {
                "status": "ok",
                "entities": typedb_stats.get("entities", {}),
            }
        except Exception as e:
            services["typedb"] = {"status": "error", "error": str(e)}

    # Aggregate per-service statuses into an overall status.
    error_count = sum(
        1 for svc in services.values()
        if isinstance(svc, dict) and svc.get("status") == "error"
    )
    if error_count == 0:
        health["status"] = "ok"
    elif error_count < 3:
        health["status"] = "degraded"
    else:
        health["status"] = "error"

    return health
|
|
|
|
|
|
@app.get("/api/rag/stats")
async def get_stats() -> dict[str, Any]:
    """Get retriever statistics.

    Collects stats from whichever backends (Qdrant, TypeDB) are attached
    to the global retriever; missing backends are simply omitted.
    """
    per_retriever: dict[str, Any] = {}
    if retriever:
        if retriever.qdrant:
            per_retriever["qdrant"] = retriever.qdrant.get_stats()
        if retriever.typedb:
            per_retriever["typedb"] = retriever.typedb.get_stats()

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "retrievers": per_retriever,
    }
|
|
|
|
|
|
@app.get("/api/rag/stats/costs")
async def get_cost_stats() -> dict[str, Any]:
    """Get cost tracking session statistics.

    Returns cumulative statistics for the current session including:
    - Total LLM calls and token usage
    - Total retrieval operations and latencies
    - Estimated costs by model
    - Pipeline timing statistics

    Returns:
        Dict with cost tracker statistics or unavailable message
    """
    # Guard: the cost-tracker module is optional and may not be importable.
    if not (COST_TRACKER_AVAILABLE and get_tracker):
        return {
            "available": False,
            "message": "Cost tracker module not available",
        }

    tracker = get_tracker()
    summary = tracker.get_session_summary()
    return {
        "available": True,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "session": summary,
    }
|
|
|
|
|
|
@app.post("/api/rag/stats/costs/reset")
async def reset_cost_stats() -> dict[str, Any]:
    """Reset cost tracking statistics.

    Clears all accumulated statistics and starts a fresh session.
    Useful for per-conversation or per-session cost tracking.

    Returns:
        Confirmation message
    """
    # Guard: the cost-tracker module is optional and may not be importable.
    if not (COST_TRACKER_AVAILABLE and reset_tracker):
        return {
            "available": False,
            "message": "Cost tracker module not available",
        }

    reset_tracker()
    return {
        "available": True,
        "message": "Cost tracking statistics reset",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
|
|
|
|
|
|
@app.get("/api/rag/embedding/models")
async def get_embedding_models() -> dict[str, Any]:
    """List available embedding models for the Qdrant collections.

    Returns information about which embedding models are available in each
    collection's named vectors, helping clients choose the right model for
    their use case.

    Returns:
        Dict with available models per collection, current settings, and recommendations
    """
    result: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "multi_embedding_enabled": settings.use_multi_embedding,
        "preferred_model": settings.preferred_embedding_model,
        "collections": {},
        # Static catalog of the models this deployment knows about.
        "models": {
            "openai_1536": {
                "description": "OpenAI text-embedding-3-small (1536 dimensions)",
                "quality": "high",
                "cost": "paid API",
                "recommended_for": "production, high-quality semantic search",
            },
            "minilm_384": {
                "description": "sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)",
                "quality": "good",
                "cost": "free (local)",
                "recommended_for": "development, cost-sensitive deployments",
            },
            "bge_768": {
                "description": "BAAI/bge-small-en-v1.5 (768 dimensions)",
                "quality": "very good",
                "cost": "free (local)",
                "recommended_for": "balanced quality/cost, multilingual support",
            },
        },
    }

    if retriever and retriever.qdrant:
        qdrant = retriever.qdrant

        # Check if multi-embedding is enabled and get available models.
        # hasattr guards keep this compatible with retrievers that predate
        # the multi-embedding feature.
        if hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding:
            if hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever:
                multi = qdrant.multi_retriever

                # Get available models for institutions collection
                try:
                    inst_models = multi.get_available_models("heritage_custodians")
                    selected = multi.select_model("heritage_custodians")
                    result["collections"]["heritage_custodians"] = {
                        "available_models": [m.value for m in inst_models],
                        "uses_named_vectors": multi.uses_named_vectors("heritage_custodians"),
                        "recommended": selected.value if selected else None,
                    }
                except Exception as e:
                    result["collections"]["heritage_custodians"] = {"error": str(e)}

                # Get available models for persons collection
                try:
                    person_models = multi.get_available_models("heritage_persons")
                    selected = multi.select_model("heritage_persons")
                    result["collections"]["heritage_persons"] = {
                        "available_models": [m.value for m in person_models],
                        "uses_named_vectors": multi.uses_named_vectors("heritage_persons"),
                        "recommended": selected.value if selected else None,
                    }
                except Exception as e:
                    result["collections"]["heritage_persons"] = {"error": str(e)}
        else:
            # Single embedding mode. (A previous revision fetched
            # qdrant.get_stats() here to "detect dimension" but never used
            # the value; the dead call has been removed.)
            result["single_embedding_mode"] = True
            result["note"] = "Collections use single embedding vectors. Enable USE_MULTI_EMBEDDING=true to use named vectors."

    return result
|
|
|
|
|
|
class EmbeddingCompareRequest(BaseModel):
    """Request body for /api/rag/embedding/compare.

    Carries one query that is executed against every available embedding
    model so result quality can be compared side by side.
    """

    # Natural-language query to run with each embedding model.
    query: str = Field(..., description="Query to search with")
    # Qdrant collection to search; defaults to the persons collection.
    collection: str = Field(default="heritage_persons", description="Collection to search")
    # Results per model; bounded (1-20) to keep comparison calls cheap.
    k: int = Field(default=5, ge=1, le=20, description="Number of results per model")
|
|
|
|
|
|
@app.post("/api/rag/embedding/compare")
async def compare_embedding_models(request: EmbeddingCompareRequest) -> dict[str, Any]:
    """Compare search results across different embedding models.

    Performs the same search query using each available embedding model,
    allowing A/B testing of embedding quality.

    This endpoint is useful for:
    - Evaluating which embedding model works best for your queries
    - Understanding differences in semantic similarity between models
    - Making informed decisions about which model to use in production

    Returns:
        Dict with results from each embedding model, including scores and overlap analysis

    Raises:
        HTTPException: 503 when Qdrant / the multi-embedding retriever is
            unavailable, 400 when multi-embedding mode is disabled,
            500 when the comparison itself fails.
    """
    import time

    start_time = time.time()

    if not retriever or not retriever.qdrant:
        raise HTTPException(status_code=503, detail="Qdrant retriever not available")

    qdrant = retriever.qdrant

    # Check if multi-embedding is available — without named vectors there is
    # nothing to compare.
    if not (hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding):
        raise HTTPException(
            status_code=400,
            detail="Multi-embedding mode not enabled. Set USE_MULTI_EMBEDDING=true to use this endpoint."
        )

    if not (hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever):
        raise HTTPException(status_code=503, detail="Multi-embedding retriever not initialized")

    multi = qdrant.multi_retriever

    try:
        # Use the compare_models method from MultiEmbeddingRetriever
        comparison = multi.compare_models(
            query=request.query,
            collection=request.collection,
            k=request.k,
        )

        elapsed_ms = (time.time() - start_time) * 1000

        return {
            "query": request.query,
            "collection": request.collection,
            "k": request.k,
            "query_time_ms": round(elapsed_ms, 2),
            "comparison": comparison,
        }

    except Exception as e:
        logger.exception(f"Embedding comparison failed: {e}")
        # Chain the cause so the original traceback is preserved (ruff B904).
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/api/rag/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest) -> QueryResponse:
    """Main RAG query endpoint.

    Orchestrates retrieval from multiple sources, merges results,
    and optionally generates visualization configuration.

    Flow: semantic-cache lookup → intent routing → multi-source retrieval
    → score-fusion merge → optional visualization selection → cache write.

    Raises:
        HTTPException: 503 when the retriever has not been initialized.
    """
    if not retriever:
        raise HTTPException(status_code=503, detail="Retriever not initialized")

    start_time = asyncio.get_event_loop().time()

    # Check cache first — a hit short-circuits retrieval entirely.
    cached = await retriever.cache.get(request.question, request.sources)
    if cached:
        cached["cache_hit"] = True
        return QueryResponse(**cached)

    # Route query to appropriate sources based on detected intent.
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    logger.info(f"Query intent: {intent}, sources: {sources}")

    # Retrieve from all sources
    results = await retriever.retrieve(
        request.question,
        sources,
        request.k,
        embedding_model=request.embedding_model,
    )

    # Merge results (over-fetch 2x k so fusion has room to re-rank).
    merged_items = retriever.merge_results(results, max_results=request.k * 2)

    # Generate visualization config if requested
    visualization = None
    if request.include_visualization and viz_selector and merged_items:
        # Extract schema from first result; underscore-prefixed fields are
        # internal (scores, provenance) and excluded from the schema string.
        schema_fields = list(merged_items[0].keys()) if merged_items else []
        schema_str = ", ".join(f for f in schema_fields if not f.startswith("_"))

        visualization = viz_selector.select(
            request.question,
            schema_str,
            len(merged_items),
        )

    elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000

    response_data = {
        "question": request.question,
        "sparql": None,  # Could be populated from SPARQL result
        "results": merged_items,
        "visualization": visualization,
        "sources_used": list(sources),  # plain copy; identity comprehension was redundant
        "cache_hit": False,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged_items),
    }

    # Cache the response for subsequent semantically-similar queries.
    await retriever.cache.set(request.question, request.sources, response_data)

    return QueryResponse(**response_data)  # type: ignore[arg-type]
|
|
|
|
|
|
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
async def generate_sparql_endpoint(request: SPARQLRequest) -> SPARQLResponse:
    """Generate SPARQL query from natural language.

    Uses DSPy with optional RAG enhancement for context.

    Raises:
        HTTPException: 503 when the SPARQL generator module is unavailable,
            500 when generation fails.
    """
    if not RETRIEVERS_AVAILABLE:
        raise HTTPException(status_code=503, detail="SPARQL generator not available")

    try:
        result = generate_sparql(
            request.question,
            language=request.language,
            context=request.context,
            use_rag=request.use_rag,
        )

        # generate_sparql always returns "sparql"; the remaining keys are
        # optional and default to empty values.
        return SPARQLResponse(
            sparql=result["sparql"],
            explanation=result.get("explanation", ""),
            rag_used=result.get("rag_used", False),
            retrieved_passages=result.get("retrieved_passages", []),
        )

    except Exception as e:
        logger.exception("SPARQL generation failed")
        # Chain the cause so the original traceback is preserved (ruff B904).
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/api/rag/visualize")
async def get_visualization_config(
    question: str = Query(..., description="User's question"),
    schema: str = Query(..., description="Comma-separated field names"),
    result_count: int = Query(default=0, description="Number of results"),
) -> dict[str, Any]:
    """Get visualization configuration for a query.

    Delegates to the module-level ``viz_selector`` to pick a visualization
    for the given question/schema/result-count combination.

    Raises:
        HTTPException: 503 when the visualization selector is not initialized.
    """
    # NOTE(review): parameters are declared with Query() on a POST route, so
    # clients must pass them in the URL query string, not the request body —
    # confirm this is intentional.
    if not viz_selector:
        raise HTTPException(status_code=503, detail="Visualization selector not available")

    config = viz_selector.select(question, schema, result_count)
    return config  # type: ignore[no-any-return]
|
|
|
|
|
|
@app.post("/api/rag/typedb/search", response_model=TypeDBSearchResponse)
async def typedb_search(request: TypeDBSearchRequest) -> TypeDBSearchResponse:
    """Direct TypeDB search endpoint.

    Search heritage custodians in TypeDB using various strategies:
    - semantic: Natural language search (combines type + location patterns)
    - name: Search by institution name
    - type: Search by institution type (museum, archive, library, gallery)
    - location: Search by city/location name

    Examples:
        - {"query": "museums in Amsterdam", "search_type": "semantic"}
        - {"query": "Rijksmuseum", "search_type": "name"}
        - {"query": "archive", "search_type": "type"}
        - {"query": "Rotterdam", "search_type": "location"}

    Raises:
        HTTPException: 503 when the TypeDB retriever is unavailable,
            500 when the underlying search fails.
    """
    import time

    start_time = time.time()

    # Check if TypeDB retriever is available
    if not retriever or not retriever.typedb:
        raise HTTPException(
            status_code=503,
            detail="TypeDB retriever not available. Ensure TypeDB is running."
        )

    try:
        typedb_retriever = retriever.typedb

        # Route to appropriate search method; any unknown search_type falls
        # back to semantic search.
        if request.search_type == "name":
            results = typedb_retriever.search_by_name(request.query, k=request.k)
        elif request.search_type == "type":
            results = typedb_retriever.search_by_type(request.query, k=request.k)
        elif request.search_type == "location":
            results = typedb_retriever.search_by_location(city=request.query, k=request.k)
        else:  # semantic (default)
            results = typedb_retriever.semantic_search(request.query, k=request.k)

        # Normalize heterogeneous results (objects with to_dict, plain dicts,
        # anything else) to dicts, deduplicating by name (first wins).
        result_dicts = []
        seen_names = set()

        for r in results:
            if hasattr(r, 'to_dict'):
                item = r.to_dict()
            elif isinstance(r, dict):
                item = r
            else:
                item = {"name": str(r)}

            # Deduplicate by name; fall back to observed_name when name is absent.
            name = item.get("name") or item.get("observed_name", "")
            if name and name not in seen_names:
                seen_names.add(name)
                result_dicts.append(item)

        elapsed_ms = (time.time() - start_time) * 1000

        return TypeDBSearchResponse(
            query=request.query,
            search_type=request.search_type,
            results=result_dicts,
            result_count=len(result_dicts),
            query_time_ms=round(elapsed_ms, 2),
        )

    except Exception as e:
        logger.exception(f"TypeDB search failed: {e}")
        # Chain the cause so the original traceback is preserved (ruff B904).
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/api/rag/persons/search", response_model=PersonSearchResponse)
async def person_search(request: PersonSearchRequest) -> PersonSearchResponse:
    """Search for persons/staff in heritage institutions.

    Search the heritage_persons Qdrant collection for staff members
    at heritage custodian institutions.

    Examples:
        - {"query": "Wie werkt er in het Nationaal Archief?"}
        - {"query": "archivist at Rijksmuseum", "k": 20}
        - {"query": "conservator", "filter_custodian": "rijksmuseum"}
        - {"query": "digital preservation", "only_heritage_relevant": true}

    The search uses semantic vector similarity to find relevant staff members
    based on their name, role, headline, and custodian affiliation.
    """
    import time
    start_time = time.time()

    # Check if retriever is available
    if not retriever:
        raise HTTPException(
            status_code=503,
            detail="Hybrid retriever not available. Ensure Qdrant is running."
        )

    try:
        # Use the hybrid retriever's person search
        results = retriever.search_persons(
            query=request.query,
            k=request.k,
            filter_custodian=request.filter_custodian,
            only_heritage_relevant=request.only_heritage_relevant,
            using=request.embedding_model,  # Pass embedding model
        )

        # Determine which embedding model was actually used: the explicit
        # request value wins; otherwise fall back to whatever model the
        # multi-embedding retriever selected. Stays None in single-embedding
        # mode. hasattr guards keep this compatible with older retrievers.
        embedding_model_used = None
        qdrant = retriever.qdrant
        if qdrant and hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding:
            if request.embedding_model:
                embedding_model_used = request.embedding_model
            elif hasattr(qdrant, '_selected_multi_model') and qdrant._selected_multi_model:
                embedding_model_used = qdrant._selected_multi_model.value

        # Convert results to dicts using to_dict() method if available.
        # Fallback order matters: to_dict() > attribute extraction via
        # __dict__ > plain dict passthrough > str() wrapper.
        result_dicts = []
        for r in results:
            if hasattr(r, 'to_dict'):
                item = r.to_dict()
            elif hasattr(r, '__dict__'):
                # Pull the known person fields off the result object;
                # missing attributes become None ("Unknown" for name).
                item = {
                    "name": getattr(r, 'name', 'Unknown'),
                    "headline": getattr(r, 'headline', None),
                    "custodian_name": getattr(r, 'custodian_name', None),
                    "custodian_slug": getattr(r, 'custodian_slug', None),
                    "linkedin_url": getattr(r, 'linkedin_url', None),
                    "heritage_relevant": getattr(r, 'heritage_relevant', None),
                    "heritage_type": getattr(r, 'heritage_type', None),
                    "location": getattr(r, 'location', None),
                    # Prefer the fused score; fall back to raw vector score.
                    "score": getattr(r, 'combined_score', getattr(r, 'vector_score', None)),
                }
            elif isinstance(r, dict):
                item = r
            else:
                item = {"name": str(r)}

            result_dicts.append(item)

        elapsed_ms = (time.time() - start_time) * 1000

        # Get collection stats (best-effort; response works without them).
        stats = None
        try:
            stats = retriever.get_stats()
            # Only include person collection stats if available
            if stats and 'persons' in stats:
                stats = {'persons': stats['persons']}
        except Exception:
            pass

        return PersonSearchResponse(
            query=request.query,
            results=result_dicts,
            result_count=len(result_dicts),
            query_time_ms=round(elapsed_ms, 2),
            collection_stats=stats,
            embedding_model_used=embedding_model_used,
        )

    except Exception as e:
        logger.exception(f"Person search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse)
async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse:
    """DSPy RAG query endpoint with multi-turn conversation support.

    Uses the HeritageRAGPipeline for conversation-aware question answering.
    Follow-up questions like "Welke daarvan behoren archieven?" will be
    resolved using previous conversation context.

    Args:
        request: Query request with question, language, and conversation context

    Returns:
        DSPyQueryResponse with answer, resolved question, and optional visualization

    Raises:
        HTTPException: 500 when the pipeline fails (ImportError instead
            produces a graceful fallback response).
    """
    import time
    start_time = time.time()

    try:
        # Import DSPy pipeline and History lazily so the rest of the API
        # works even when DSPy is not installed.
        import dspy
        from dspy import History
        from dspy_heritage_rag import HeritageRAGPipeline

        # Ensure DSPy has an LM configured.
        # Check if LM is already configured by testing if we can get the settings.
        try:
            current_lm = dspy.settings.lm
            if current_lm is None:
                raise ValueError("No LM configured")
        except (AttributeError, ValueError):
            # No LM configured yet - try to configure one (Anthropic first,
            # then OpenAI as fallback).
            api_key = settings.anthropic_api_key or os.getenv("ANTHROPIC_API_KEY", "")
            if api_key:
                lm = dspy.LM("anthropic/claude-sonnet-4-20250514", api_key=api_key)
                dspy.configure(lm=lm)
                logger.info("Configured DSPy with Anthropic Claude")
            else:
                openai_key = os.getenv("OPENAI_API_KEY", "")
                if openai_key:
                    lm = dspy.LM("openai/gpt-4o-mini", api_key=openai_key)
                    dspy.configure(lm=lm)
                    logger.info("Configured DSPy with OpenAI GPT-4o-mini")
                else:
                    raise ValueError(
                        "No LLM API key found. Set ANTHROPIC_API_KEY or OPENAI_API_KEY environment variable."
                    )

        # Convert context to DSPy History format.
        # Context comes as [{question: "...", answer: "..."}, ...];
        # History expects messages with role and content.
        history_messages = []
        for turn in request.context:
            if turn.get("question"):
                history_messages.append({"role": "user", "content": turn["question"]})
            if turn.get("answer"):
                history_messages.append({"role": "assistant", "content": turn["answer"]})

        history = History(messages=history_messages) if history_messages else None

        # Initialize pipeline with retriever for actual data retrieval.
        # Pass the qdrant retriever (HybridRetriever) for person/institution searches.
        qdrant_retriever = retriever.qdrant if retriever else None
        pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)

        # Execute query with conversation history.
        # NOTE(review): forward() is a synchronous call inside an async
        # handler; confirm whether it should be offloaded to a thread.
        result = pipeline.forward(
            embedding_model=request.embedding_model,
            question=request.question,
            language=request.language,
            history=history,
            include_viz=request.include_visualization,
        )

        elapsed_ms = (time.time() - start_time) * 1000

        # Extract visualization if present
        visualization = None
        if request.include_visualization and hasattr(result, "visualization"):
            viz = result.visualization
            if viz:
                visualization = {
                    "type": getattr(viz, "viz_type", "table"),
                    "sparql_query": getattr(result, "sparql", None),
                }

        # Extract retrieved results for frontend visualization (tables, graphs)
        retrieved_results = getattr(result, "retrieved_results", None)
        query_type = getattr(result, "query_type", None)

        return DSPyQueryResponse(
            question=request.question,
            resolved_question=getattr(result, "resolved_question", None),
            answer=getattr(result, "answer", "Geen antwoord gevonden."),
            sources_used=getattr(result, "sources_used", []),
            visualization=visualization,
            retrieved_results=retrieved_results,  # Raw data for frontend visualization
            query_type=query_type,  # "person" or "institution"
            query_time_ms=round(elapsed_ms, 2),
            conversation_turn=len(request.context),
            embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
            # Cost tracking fields
            timing_ms=getattr(result, "timing_ms", None),
            cost_usd=getattr(result, "cost_usd", None),
            timing_breakdown=getattr(result, "timing_breakdown", None),
        )

    except ImportError as e:
        logger.warning(f"DSPy pipeline not available: {e}")
        # Fallback to simple response.
        # BUG FIX: this branch previously read `getattr(result, ...)`, but
        # `result` is never assigned when the pipeline import fails, which
        # raised NameError (→ 500) instead of returning the fallback.
        return DSPyQueryResponse(
            question=request.question,
            answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.",
            query_time_ms=0,
            conversation_turn=len(request.context),
            embedding_model_used=request.embedding_model,
        )

    except Exception as e:
        logger.exception("DSPy query failed")
        # Chain the cause so the original traceback is preserved (ruff B904).
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Stream query response for long-running queries.

    Yields newline-delimited JSON (NDJSON) events in order:
    ``status`` (routing / per-source progress), ``partial`` (per-source
    result counts), and finally one ``complete`` event with the merged
    results and timing.
    """
    # Without a retriever there is nothing to stream — emit one error
    # object and end the stream.
    if not retriever:
        yield json.dumps({"error": "Retriever not initialized"})
        return

    start_time = asyncio.get_event_loop().time()

    # Route query to the data sources matching the detected intent.
    intent, sources = retriever.router.get_sources(request.question, request.sources)

    yield json.dumps({
        "type": "status",
        "message": f"Routing query to {len(sources)} sources...",
        "intent": intent.value,
    }) + "\n"

    # Retrieve from sources one at a time (rather than in parallel) so a
    # progress event can be streamed before and after each source.
    results = []
    for source in sources:
        yield json.dumps({
            "type": "status",
            "message": f"Querying {source.value}...",
        }) + "\n"

        source_results = await retriever.retrieve(
            request.question,
            [source],
            request.k,
            embedding_model=request.embedding_model,
        )
        results.extend(source_results)

        # Per-source partial event; retrieve() was called with a single
        # source, so the first entry holds that source's items.
        yield json.dumps({
            "type": "partial",
            "source": source.value,
            "count": len(source_results[0].items) if source_results else 0,
        }) + "\n"

    # Merge all per-source results and emit the terminal event.
    merged = retriever.merge_results(results)
    elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000

    yield json.dumps({
        "type": "complete",
        "results": merged,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged),
    }) + "\n"
|
|
|
|
|
|
@app.post("/api/rag/query/stream")
async def query_rag_stream(request: QueryRequest) -> StreamingResponse:
    """Streaming version of RAG query endpoint.

    Wraps the NDJSON event generator in a StreamingResponse so clients
    receive progress events as retrieval proceeds.
    """
    event_stream = stream_query_response(request)
    return StreamingResponse(event_stream, media_type="application/x-ndjson")
|
|
|
|
|
|
# Main entry point — run the API directly with uvicorn when executed as a
# script (in deployment the app is typically served by an external runner).
if __name__ == "__main__":
    import uvicorn

    # Auto-reload only when debug mode is enabled in settings.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8003,
        reload=settings.debug,
        log_level="info",
    )
|