- JP: 2,846 processed (24% of 12,096) - CZ: 2,550 processed (30% of 8,432) - CH, NL, BE, AT, BR: 100% complete - Total: 10,913 of 31,772 files (34%) - Using crawl4ai favicon extraction
3539 lines
140 KiB
Python
from __future__ import annotations
|
|
|
|
"""
|
|
Unified RAG Backend for Heritage Custodian Data
|
|
|
|
Multi-source retrieval-augmented generation system that combines:
|
|
- Qdrant vector search (semantic similarity)
|
|
- Oxigraph SPARQL (knowledge graph queries)
|
|
- TypeDB (relationship traversal)
|
|
- PostGIS (geospatial queries)
|
|
- Valkey (semantic caching)
|
|
|
|
Architecture:
|
|
User Query → Query Analysis
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Router │
|
|
└─────┬─────┘
|
|
┌─────┬─────┼─────┬─────┐
|
|
↓ ↓ ↓ ↓ ↓
|
|
Qdrant SPARQL TypeDB PostGIS Cache
|
|
│ │ │ │ │
|
|
└─────┴─────┴─────┴─────┘
|
|
↓
|
|
┌─────┴─────┐
|
|
│ Merger │
|
|
└─────┬─────┘
|
|
↓
|
|
DSPy Generator
|
|
↓
|
|
Visualization Selector
|
|
↓
|
|
Response (JSON/Streaming)
|
|
|
|
Features:
|
|
- Intelligent query routing to appropriate data sources
|
|
- Score fusion for multi-source results
|
|
- Semantic caching via Valkey API
|
|
- Streaming responses for long-running queries
|
|
- DSPy assertions for output validation
|
|
|
|
Endpoints:
|
|
- POST /api/rag/query - Main RAG query endpoint
|
|
- POST /api/rag/sparql - Generate SPARQL with RAG context
|
|
- POST /api/rag/typedb/search - Direct TypeDB search
|
|
- GET /api/rag/health - Health check for all services
|
|
- GET /api/rag/stats - Retriever statistics
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any, AsyncIterator, TYPE_CHECKING
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Type hints for optional imports (only used during type checking)
|
|
if TYPE_CHECKING:
|
|
from glam_extractor.api.hybrid_retriever import HybridRetriever
|
|
from glam_extractor.api.typedb_retriever import TypeDBRetriever
|
|
from glam_extractor.api.visualization import VisualizationSelector
|
|
|
|
# Import retrievers (with graceful fallbacks)
|
|
RETRIEVERS_AVAILABLE = False
|
|
create_hybrid_retriever: Any = None
|
|
HeritageCustodianRetriever: Any = None
|
|
create_typedb_retriever: Any = None
|
|
select_visualization: Any = None
|
|
VisualizationSelector: Any = None # type: ignore[no-redef]
|
|
generate_sparql: Any = None
|
|
configure_dspy: Any = None
|
|
get_province_code: Any = None # Province name to ISO 3166-2 code converter
|
|
|
|
try:
|
|
import sys
|
|
sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src")))
|
|
from glam_extractor.api.hybrid_retriever import (
|
|
HybridRetriever as _HybridRetriever,
|
|
create_hybrid_retriever as _create_hybrid_retriever,
|
|
get_province_code as _get_province_code,
|
|
PERSON_JSONLD_CONTEXT,
|
|
)
|
|
from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever as _HeritageCustodianRetriever
|
|
from glam_extractor.api.typedb_retriever import TypeDBRetriever as _TypeDBRetriever, create_typedb_retriever as _create_typedb_retriever
|
|
from glam_extractor.api.visualization import select_visualization as _select_visualization, VisualizationSelector as _VisualizationSelector
|
|
# Assign to module-level variables
|
|
create_hybrid_retriever = _create_hybrid_retriever
|
|
HeritageCustodianRetriever = _HeritageCustodianRetriever
|
|
create_typedb_retriever = _create_typedb_retriever
|
|
select_visualization = _select_visualization
|
|
VisualizationSelector = _VisualizationSelector
|
|
get_province_code = _get_province_code
|
|
RETRIEVERS_AVAILABLE = True
|
|
except ImportError as e:
|
|
logger.warning(f"Core retrievers not available: {e}")
|
|
# Provide a fallback get_province_code that returns None
|
|
def get_province_code(province_name: str | None) -> str | None:
|
|
"""Fallback when hybrid_retriever is not available."""
|
|
return None
|
|
|
|
# DSPy is optional - don't block retrievers if it's missing
|
|
try:
|
|
from glam_extractor.api.dspy_sparql import generate_sparql as _generate_sparql, configure_dspy as _configure_dspy
|
|
generate_sparql = _generate_sparql
|
|
configure_dspy = _configure_dspy
|
|
except ImportError as e:
|
|
logger.warning(f"DSPy SPARQL not available: {e}")
|
|
|
|
# Atomic query decomposition for geographic/type filtering
|
|
decompose_query: Any = None
|
|
DECOMPOSER_AVAILABLE = False
|
|
try:
|
|
from atomic_decomposer import decompose_query as _decompose_query
|
|
decompose_query = _decompose_query
|
|
DECOMPOSER_AVAILABLE = True
|
|
logger.info("Query decomposer loaded successfully")
|
|
except ImportError as e:
|
|
logger.info(f"Query decomposer not available: {e}")
|
|
|
|
# Cost tracker is optional - gracefully degrades if unavailable
|
|
COST_TRACKER_AVAILABLE = False
|
|
get_tracker = None
|
|
reset_tracker = None
|
|
try:
|
|
from cost_tracker import get_tracker as _get_tracker, reset_tracker as _reset_tracker
|
|
get_tracker = _get_tracker
|
|
reset_tracker = _reset_tracker
|
|
COST_TRACKER_AVAILABLE = True
|
|
logger.info("Cost tracker module loaded successfully")
|
|
except ImportError as e:
|
|
logger.info(f"Cost tracker not available (optional): {e}")
|
|
|
|
|
|
# Province detection for geographic filtering.
# Lower-cased Dutch province names plus common spelling variants
# (Dutch/English, hyphenated and unhyphenated) consulted by
# infer_location_level() below.
DUTCH_PROVINCES = {
    "noord-holland", "noordholland", "north holland", "north-holland",
    "zuid-holland", "zuidholland", "south holland", "south-holland",
    "utrecht", "gelderland", "noord-brabant", "noordbrabant", "brabant",
    "north brabant", "limburg", "overijssel", "friesland", "fryslân",
    "fryslan", "groningen", "drenthe", "flevoland", "zeeland",
}
|
|
|
|
|
|
def infer_location_level(location: str) -> str:
    """Classify a location string as a province, region, or city.

    Returns:
        'province' when the name matches a Dutch province,
        'region' when it matches a known sub-provincial region,
        'city' for anything else.
    """
    # Known sub-provincial regions (lower-cased).
    sub_provincial = {"randstad", "veluwe", "achterhoek", "twente", "de betuwe", "betuwe"}

    normalized = location.strip().lower()

    if normalized in DUTCH_PROVINCES:
        return "province"
    return "region" if normalized in sub_provincial else "city"
|
|
|
|
|
|
def extract_geographic_filters(question: str) -> dict[str, list[str] | None]:
    """Derive geographic/type filters from a question via query decomposition.

    Returns:
        dict with keys: region_codes, cities, institution_types — each a
        single-element list when detected, otherwise None.
    """
    result: dict[str, list[str] | None] = {
        "region_codes": None,
        "cities": None,
        "institution_types": None,
    }

    # Without the optional decomposer module no filters can be inferred.
    if not (DECOMPOSER_AVAILABLE and decompose_query):
        return result

    # Common Dutch/English institution words mapped to enum values.
    inst_type_map = {
        "archive": "ARCHIVE",
        "archief": "ARCHIVE",
        "archieven": "ARCHIVE",
        "museum": "MUSEUM",
        "musea": "MUSEUM",
        "museums": "MUSEUM",
        "library": "LIBRARY",
        "bibliotheek": "LIBRARY",
        "bibliotheken": "LIBRARY",
        "gallery": "GALLERY",
        "galerie": "GALLERY",
    }

    try:
        parts = decompose_query(question)

        place = parts.location
        if place:
            where = infer_location_level(place)
            if where == "province":
                # Provinces become ISO 3166-2 codes for Qdrant filtering,
                # e.g. "Noord-Holland" → "NH".
                code = get_province_code(place)
                if code:
                    result["region_codes"] = [code]
                    logger.info(f"Province filter: {place} → {code}")
            elif where == "city":
                result["cities"] = [place]
                logger.info(f"City filter: {place}")
            # Sub-provincial regions currently produce no filter.

        if parts.institution_type:
            raw = parts.institution_type.lower()
            # Unknown types fall through as upper-cased pass-through.
            mapped = inst_type_map.get(raw, raw.upper())
            result["institution_types"] = [mapped]
            logger.info(f"Institution type filter: {mapped}")

    except Exception as e:
        # Filter extraction is best-effort; never fail the query over it.
        logger.warning(f"Failed to extract geographic filters: {e}")

    return result
|
|
|
|
|
|
# Configuration
class Settings:
    """Application settings from environment variables.

    All values are read once at class-definition (import) time via
    os.getenv; restart the process to pick up environment changes.
    """

    # API Configuration
    api_title: str = "Heritage RAG API"
    api_version: str = "1.0.0"
    debug: bool = os.getenv("DEBUG", "false").lower() == "true"

    # Valkey Cache
    valkey_api_url: str = os.getenv("VALKEY_API_URL", "https://bronhouder.nl/api/cache")
    cache_ttl: int = int(os.getenv("CACHE_TTL", "900"))  # 15 minutes

    # Qdrant Vector DB
    # Production: Use URL-based client via bronhouder.nl/qdrant reverse proxy
    qdrant_host: str = os.getenv("QDRANT_HOST", "localhost")
    qdrant_port: int = int(os.getenv("QDRANT_PORT", "6333"))
    qdrant_use_production: bool = os.getenv("QDRANT_USE_PRODUCTION", "true").lower() == "true"
    qdrant_production_url: str = os.getenv("QDRANT_PRODUCTION_URL", "https://bronhouder.nl/qdrant")

    # Multi-Embedding Support
    # Enable to use named vectors with multiple embedding models (OpenAI 1536, MiniLM 384, BGE 768)
    use_multi_embedding: bool = os.getenv("USE_MULTI_EMBEDDING", "true").lower() == "true"
    preferred_embedding_model: str | None = os.getenv("PREFERRED_EMBEDDING_MODEL", None)  # e.g., "minilm_384" or "openai_1536"

    # Oxigraph SPARQL
    # Production: Use bronhouder.nl/sparql reverse proxy
    sparql_endpoint: str = os.getenv("SPARQL_ENDPOINT", "https://bronhouder.nl/sparql")

    # TypeDB
    # Note: TypeDB not exposed via reverse proxy - always use localhost
    typedb_host: str = os.getenv("TYPEDB_HOST", "localhost")
    typedb_port: int = int(os.getenv("TYPEDB_PORT", "1729"))
    typedb_database: str = os.getenv("TYPEDB_DATABASE", "heritage_custodians")
    typedb_use_production: bool = os.getenv("TYPEDB_USE_PRODUCTION", "false").lower() == "true"  # Default off

    # PostGIS/Geo API
    # Production: Use bronhouder.nl/api/geo reverse proxy
    postgis_url: str = os.getenv("POSTGIS_URL", "https://bronhouder.nl/api/geo")

    # DuckLake Analytics
    ducklake_url: str = os.getenv("DUCKLAKE_URL", "http://localhost:8001")

    # LLM Configuration (API keys default to empty string when unset)
    anthropic_api_key: str = os.getenv("ANTHROPIC_API_KEY", "")
    openai_api_key: str = os.getenv("OPENAI_API_KEY", "")
    huggingface_api_key: str = os.getenv("HUGGINGFACE_API_KEY", "")
    groq_api_key: str = os.getenv("GROQ_API_KEY", "")
    zai_api_token: str = os.getenv("ZAI_API_TOKEN", "")
    default_model: str = os.getenv("DEFAULT_MODEL", "claude-opus-4-5-20251101")
    # LLM Provider: "anthropic", "openai", "huggingface", "zai" (FREE), or "groq" (FREE)
    llm_provider: str = os.getenv("LLM_PROVIDER", "anthropic")
    # LLM Model: Specific model to use. Defaults depend on provider.
    # For Z.AI: "glm-4.5-flash" (fast, recommended) or "glm-4.6" (reasoning, slow)
    llm_model: str = os.getenv("LLM_MODEL", "glm-4.5-flash")
    # Fast LM Provider for routing/extraction: "openai" (fast ~1-2s) or "zai" (FREE but slow ~13s)
    # Default to openai for speed. Set to "zai" to save costs (free but adds ~12s latency)
    fast_lm_provider: str = os.getenv("FAST_LM_PROVIDER", "openai")

    # Retrieval weights
    vector_weight: float = float(os.getenv("VECTOR_WEIGHT", "0.5"))
    graph_weight: float = float(os.getenv("GRAPH_WEIGHT", "0.3"))
    typedb_weight: float = float(os.getenv("TYPEDB_WEIGHT", "0.2"))


# Module-level singleton referenced throughout this file.
settings = Settings()
|
|
|
|
|
|
# Enums and Models
class QueryIntent(str, Enum):
    """Detected query intent; drives source selection in the query router."""

    GEOGRAPHIC = "geographic"    # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"    # Relationships between entities
    TEMPORAL = "temporal"        # Historical, timeline queries
    SEARCH = "search"            # General text search
    DETAIL = "detail"            # Specific entity lookup
|
|
|
|
|
|
class DataSource(str, Enum):
    """Available data sources a query can be routed to."""

    QDRANT = "qdrant"      # Vector search (semantic similarity)
    SPARQL = "sparql"      # Oxigraph knowledge-graph queries
    TYPEDB = "typedb"      # Relationship traversal
    POSTGIS = "postgis"    # Geospatial queries
    CACHE = "cache"        # Valkey semantic cache
    DUCKLAKE = "ducklake"  # Analytics backend
|
|
|
|
|
|
@dataclass
class RetrievalResult:
    """Result from a single retriever backend."""

    source: DataSource                # Which backend produced these items
    items: list[dict[str, Any]]      # Raw result records from that backend
    score: float = 0.0               # Aggregate score (semantics set by the caller)
    query_time_ms: float = 0.0       # Backend query wall-clock time in ms
    metadata: dict[str, Any] = field(default_factory=dict)  # Backend-specific extras
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """RAG query request (POST /api/rag/query)."""

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    # default_factory is the canonical way to declare a mutable default,
    # avoiding any shared-list pitfalls.
    context: list[dict[str, Any]] = Field(default_factory=list, description="Conversation history")
    sources: list[DataSource] | None = Field(
        default=None,
        description="Data sources to query. If None, auto-routes based on query intent.",
    )
    k: int = Field(default=10, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    stream: bool = Field(default=False, description="Stream response")
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """RAG query response."""

    question: str                                # Echo of the input question
    sparql: str | None = None                    # Generated SPARQL, when applicable
    results: list[dict[str, Any]]                # Result records
    visualization: dict[str, Any] | None = None  # Visualization config, if requested
    sources_used: list[DataSource]               # Data sources that contributed
    cache_hit: bool = False                      # True when served from the semantic cache
    query_time_ms: float                         # Total query wall-clock time in ms
    result_count: int                            # Number of results returned
|
|
|
|
|
|
class SPARQLRequest(BaseModel):
    """SPARQL generation request (POST /api/rag/sparql)."""

    question: str        # Natural language question to translate
    language: str = "nl"
    # default_factory avoids declaring a shared mutable list default.
    context: list[dict[str, Any]] = Field(default_factory=list)  # Conversation history
    use_rag: bool = True  # Whether to retrieve context passages before generation
|
|
|
|
|
|
class SPARQLResponse(BaseModel):
    """SPARQL generation response."""

    sparql: str       # Generated SPARQL query text
    explanation: str  # Human-readable explanation of the query
    rag_used: bool    # Whether RAG context was used during generation
    # default_factory avoids declaring a shared mutable list default.
    retrieved_passages: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class TypeDBSearchRequest(BaseModel):
    """TypeDB search request (POST /api/rag/typedb/search)."""

    query: str = Field(..., description="Search query (name, type, or location)")
    search_type: str = Field(
        default="semantic",
        description="Search type: semantic, name, type, or location"
    )
    # k is bounded server-side to 1..100 results.
    k: int = Field(default=10, ge=1, le=100, description="Number of results")
|
|
|
|
|
|
class TypeDBSearchResponse(BaseModel):
    """TypeDB search response."""

    query: str                     # Echo of the search query
    search_type: str               # Echo of the search type used
    results: list[dict[str, Any]]  # Matching records
    result_count: int              # Number of results returned
    query_time_ms: float           # Search wall-clock time in ms
|
|
|
|
|
|
class PersonSearchRequest(BaseModel):
    """Person/staff search request."""

    query: str = Field(..., description="Search query for person/staff (e.g., 'Wie werkt er in het Nationaal Archief?')")
    # k is bounded server-side to 1..100 results.
    k: int = Field(default=10, ge=1, le=100, description="Number of results to return")
    filter_custodian: str | None = Field(default=None, description="Filter by custodian slug (e.g., 'nationaal-archief')")
    only_heritage_relevant: bool = Field(default=False, description="Only return heritage-relevant staff")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use (e.g., 'minilm_384', 'openai_1536'). If None, auto-selects best available."
    )
|
|
|
|
|
|
class PersonSearchResponse(BaseModel):
    """Person/staff search response with JSON-LD linked data."""

    # Serialized under "@context" (see alias) for JSON-LD consumers.
    context: dict[str, Any] | None = Field(
        default=None,
        alias="@context",
        description="JSON-LD context for linked data semantic interoperability"
    )
    query: str                                       # Echo of the search query
    results: list[dict[str, Any]]                    # Matching person records
    result_count: int                                # Number of results returned
    query_time_ms: float                             # Search wall-clock time in ms
    collection_stats: dict[str, Any] | None = None   # Optional collection statistics
    embedding_model_used: str | None = None          # Embedding model actually used

    # Allow population by field name ("context") as well as alias ("@context").
    model_config = {"populate_by_name": True}
|
|
|
|
|
|
class DSPyQueryRequest(BaseModel):
    """DSPy RAG query request with conversation support."""

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    # default_factory is the canonical way to declare a mutable default,
    # avoiding any shared-list pitfalls.
    context: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Conversation history as list of {question, answer} dicts"
    )
    include_visualization: bool = Field(default=True, description="Include visualization config")
    embedding_model: str | None = Field(
        default=None,
        description="Embedding model to use for vector search (e.g., 'minilm_384', 'openai_1536', 'bge_768'). If None, auto-selects best available."
    )
    llm_provider: str | None = Field(
        default=None,
        description="LLM provider to use for this request: 'zai', 'anthropic', 'huggingface', or 'openai'. If None, uses server default (LLM_PROVIDER env)."
    )
    llm_model: str | None = Field(
        default=None,
        description="Specific LLM model to use (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929', 'gpt-4o'). If None, uses provider default."
    )
|
|
|
|
|
|
class LLMResponseMetadata(BaseModel):
    """LLM response provenance metadata (aligned with LinkML LLMResponse schema).

    Captures GLM 4.7 Interleaved Thinking chain-of-thought reasoning and
    full API response metadata for audit trails and debugging. All fields
    are optional; populated on a best-effort basis by
    extract_llm_response_metadata().

    See: schemas/20251121/linkml/modules/classes/LLMResponse.yaml
    """
    # Core response content
    content: str | None = None            # The final LLM response text
    reasoning_content: str | None = None  # GLM 4.7 Interleaved Thinking chain-of-thought

    # Model identification
    model: str | None = None     # Model identifier (e.g., 'glm-4.7', 'claude-3-opus')
    provider: str | None = None  # Provider enum: zai, anthropic, openai, huggingface, groq

    # Request tracking
    request_id: str | None = None  # Provider-assigned request ID
    created: str | None = None     # ISO 8601 timestamp of response generation

    # Token usage (for cost estimation and monitoring)
    prompt_tokens: int | None = None      # Tokens in input prompt
    completion_tokens: int | None = None  # Tokens in response (content + reasoning)
    total_tokens: int | None = None       # Total tokens used
    cached_tokens: int | None = None      # Tokens served from provider cache

    # Response metadata
    finish_reason: str | None = None  # stop, length, tool_calls, content_filter
    latency_ms: int | None = None     # Response latency in milliseconds

    # GLM 4.7 Thinking Mode configuration
    thinking_mode: str | None = None    # enabled, disabled, interleaved, preserved
    clear_thinking: bool | None = None  # False = Preserved Thinking enabled
|
|
|
|
|
|
class DSPyQueryResponse(BaseModel):
    """DSPy RAG query response."""

    question: str                           # Original question as submitted
    resolved_question: str | None = None    # Question after conversational resolution
    answer: str                             # Generated answer text
    # default_factory avoids declaring a shared mutable list default.
    sources_used: list[str] = Field(default_factory=list)
    visualization: dict[str, Any] | None = None
    retrieved_results: list[dict[str, Any]] | None = None  # Raw retrieved data for frontend visualization
    query_type: str | None = None  # "person" or "institution" - helps frontend choose visualization
    query_time_ms: float = 0.0
    conversation_turn: int = 0
    embedding_model_used: str | None = None  # Which embedding model was used for the search
    llm_provider_used: str | None = None     # Which LLM provider handled this request (zai, anthropic, huggingface, openai)
    llm_model_used: str | None = None        # Which specific LLM model was used (e.g., 'glm-4.6', 'claude-sonnet-4-5-20250929')

    # Cost tracking fields (from cost_tracker module)
    timing_ms: float | None = None   # Total pipeline timing from cost tracker
    cost_usd: float | None = None    # Estimated LLM cost in USD
    timing_breakdown: dict[str, float] | None = None  # Per-stage timing breakdown

    # Cache tracking
    cache_hit: bool = False  # Whether response was served from cache

    # LLM response provenance (GLM 4.7 Thinking Mode support)
    llm_response: LLMResponseMetadata | None = None  # Full LLM response metadata including reasoning_content
|
|
|
|
|
|
def extract_llm_response_metadata(
    lm: Any,
    provider: str | None = None,
    latency_ms: int | None = None,
) -> LLMResponseMetadata | None:
    """Extract LLM response metadata from DSPy LM history.

    DSPy stores the raw API response in lm.history[-1]["response"], which includes:
    - choices[0].message.content (final response text)
    - choices[0].message.reasoning_content (GLM 4.7 Interleaved Thinking)
    - usage.prompt_tokens, completion_tokens, total_tokens
    - model, created, id, finish_reason

    This enables capturing GLM 4.7's chain-of-thought reasoning for provenance.
    Extraction is best-effort: any failure is logged and None is returned.

    Args:
        lm: DSPy LM instance with history attribute
        provider: LLM provider name (zai, anthropic, openai, etc.)
        latency_ms: Response latency in milliseconds

    Returns:
        LLMResponseMetadata or None if history is empty
    """
    try:
        # Check if LM has history
        if not hasattr(lm, "history") or not lm.history:
            logger.debug("No LM history available for metadata extraction")
            return None

        # Get the last history entry (most recent LLM call)
        last_entry = lm.history[-1]
        response = last_entry.get("response")

        if response is None:
            logger.debug("No response in LM history entry")
            return None

        # Extract content and reasoning_content from the response.
        # Responses may be SDK objects (attribute access) or plain dicts.
        content = None
        reasoning_content = None
        finish_reason = None

        if hasattr(response, "choices") and response.choices:
            choice = response.choices[0]
            if hasattr(choice, "message"):
                message = choice.message
                content = getattr(message, "content", None)
                # GLM 4.7 Interleaved Thinking - check for reasoning_content
                reasoning_content = getattr(message, "reasoning_content", None)
            elif isinstance(choice, dict):
                content = choice.get("text") or choice.get("message", {}).get("content")
                reasoning_content = choice.get("message", {}).get("reasoning_content")

            # Extract finish_reason (object attribute first, dict fallback)
            finish_reason = getattr(choice, "finish_reason", None)
            if finish_reason is None and isinstance(choice, dict):
                finish_reason = choice.get("finish_reason")

        # Extract usage statistics - handle both dict and object types
        # (DSPy/OpenAI SDK may return CompletionUsage objects instead of dicts)
        usage = last_entry.get("usage")
        prompt_tokens = None
        completion_tokens = None
        total_tokens = None
        cached_tokens = None

        if usage is not None:
            if hasattr(usage, "prompt_tokens"):
                # It's an object (e.g., CompletionUsage from OpenAI SDK)
                prompt_tokens = getattr(usage, "prompt_tokens", None)
                completion_tokens = getattr(usage, "completion_tokens", None)
                total_tokens = getattr(usage, "total_tokens", None)
                prompt_details = getattr(usage, "prompt_tokens_details", None)
                if prompt_details is not None:
                    cached_tokens = getattr(prompt_details, "cached_tokens", None)
            elif isinstance(usage, dict):
                # It's a plain dict
                prompt_tokens = usage.get("prompt_tokens")
                completion_tokens = usage.get("completion_tokens")
                total_tokens = usage.get("total_tokens")
                prompt_details = usage.get("prompt_tokens_details")
                if isinstance(prompt_details, dict):
                    cached_tokens = prompt_details.get("cached_tokens")

        # Extract model info
        model = last_entry.get("response_model") or last_entry.get("model")
        request_id = getattr(response, "id", None)
        created = getattr(response, "created", None)

        # Convert unix timestamp to ISO 8601 if needed.
        # Uses the module-level `datetime`/`timezone` imports instead of
        # re-importing the datetime module locally (which shadowed them).
        created_str = None
        if created:
            if isinstance(created, (int, float)):
                created_str = datetime.fromtimestamp(created, tz=timezone.utc).isoformat()
            else:
                created_str = str(created)

        # Determine thinking mode (GLM 4.7 specific)
        thinking_mode = None
        if reasoning_content:
            # If we got reasoning_content, the model used interleaved thinking
            thinking_mode = "interleaved"

        metadata = LLMResponseMetadata(
            content=content,
            reasoning_content=reasoning_content,
            model=model,
            provider=provider,
            request_id=request_id,
            created=created_str,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            cached_tokens=cached_tokens,
            finish_reason=finish_reason,
            latency_ms=latency_ms,
            thinking_mode=thinking_mode,
        )

        if reasoning_content:
            logger.info(
                f"Captured GLM 4.7 reasoning_content ({len(reasoning_content)} chars) "
                f"from {provider}/{model}"
            )

        return metadata

    except Exception as e:
        # Provenance capture must never break the main query path.
        logger.warning(f"Failed to extract LLM response metadata: {e}")
        return None
|
|
|
|
|
|
# Cache Client
|
|
class ValkeyClient:
|
|
"""Client for Valkey semantic cache API."""
|
|
|
|
def __init__(self, base_url: str = settings.valkey_api_url):
|
|
self.base_url = base_url.rstrip("/")
|
|
self._client: httpx.AsyncClient | None = None
|
|
|
|
@property
|
|
async def client(self) -> httpx.AsyncClient:
|
|
"""Get or create async HTTP client."""
|
|
if self._client is None or self._client.is_closed:
|
|
self._client = httpx.AsyncClient(timeout=30.0)
|
|
return self._client
|
|
|
|
def _cache_key(self, question: str, sources: list[DataSource] | None) -> str:
|
|
"""Generate cache key from question and sources."""
|
|
if sources:
|
|
sources_str = ",".join(sorted(s.value for s in sources))
|
|
else:
|
|
sources_str = "auto"
|
|
key_str = f"{question.lower().strip()}:{sources_str}"
|
|
return hashlib.sha256(key_str.encode()).hexdigest()[:32]
|
|
|
|
async def get(self, question: str, sources: list[DataSource] | None) -> dict[str, Any] | None:
|
|
"""Get cached response using semantic cache lookup."""
|
|
try:
|
|
client = await self.client
|
|
response = await client.post(
|
|
f"{self.base_url}/cache/lookup",
|
|
json={
|
|
"query": question,
|
|
"similarity_threshold": 0.92,
|
|
},
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
if data.get("found") and data.get("entry"):
|
|
logger.info(f"Cache hit for question: {question[:50]}... (similarity: {data.get('similarity', 0):.3f})")
|
|
return data["entry"].get("response") # type: ignore[no-any-return]
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Cache get failed: {e}")
|
|
return None
|
|
|
|
async def set(
|
|
self,
|
|
question: str,
|
|
sources: list[DataSource] | None,
|
|
response: dict[str, Any],
|
|
ttl: int = settings.cache_ttl,
|
|
) -> bool:
|
|
"""Cache response using semantic cache store."""
|
|
try:
|
|
client = await self.client
|
|
|
|
# Build CachedResponse schema
|
|
cached_response = {
|
|
"answer": response.get("answer", ""),
|
|
"sparql_query": response.get("sparql_query"),
|
|
"typeql_query": response.get("typeql_query"),
|
|
"visualization_type": response.get("visualization_type"),
|
|
"visualization_data": response.get("visualization_data"),
|
|
"sources": response.get("sources", []),
|
|
"confidence": response.get("confidence", 0.0),
|
|
"context": response.get("context"),
|
|
}
|
|
|
|
await client.post(
|
|
f"{self.base_url}/cache/store",
|
|
json={
|
|
"query": question,
|
|
"response": cached_response,
|
|
"language": response.get("language", "nl"),
|
|
"model": response.get("llm_model", "unknown"),
|
|
},
|
|
)
|
|
logger.debug(f"Cached response for: {question[:50]}...")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Cache set failed: {e}")
|
|
return False
|
|
|
|
def _dspy_cache_key(
|
|
self,
|
|
question: str,
|
|
language: str,
|
|
llm_provider: str | None,
|
|
embedding_model: str | None,
|
|
context_hash: str | None = None,
|
|
) -> str:
|
|
"""Generate cache key for DSPy query responses.
|
|
|
|
Cache key components:
|
|
- Question text (normalized)
|
|
- Language code
|
|
- LLM provider (different providers give different answers)
|
|
- Embedding model (affects retrieval results)
|
|
- Context hash (for multi-turn conversations)
|
|
"""
|
|
components = [
|
|
question.lower().strip(),
|
|
language,
|
|
llm_provider or "default",
|
|
embedding_model or "auto",
|
|
context_hash or "no_context",
|
|
]
|
|
key_str = ":".join(components)
|
|
return f"dspy:{hashlib.sha256(key_str.encode()).hexdigest()[:32]}"
|
|
|
|
async def get_dspy(
|
|
self,
|
|
question: str,
|
|
language: str,
|
|
llm_provider: str | None,
|
|
embedding_model: str | None,
|
|
context: list[dict[str, Any]] | None = None,
|
|
) -> dict[str, Any] | None:
|
|
"""Get cached DSPy response using semantic cache lookup.
|
|
|
|
Cache hits are filtered by LLM provider to ensure responses from different
|
|
providers (e.g., anthropic vs huggingface) are cached separately.
|
|
"""
|
|
try:
|
|
client = await self.client
|
|
response = await client.post(
|
|
f"{self.base_url}/cache/lookup",
|
|
json={
|
|
"query": question,
|
|
"language": language,
|
|
"similarity_threshold": 0.92,
|
|
},
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
if data.get("found") and data.get("entry"):
|
|
cached_response = data["entry"].get("response")
|
|
|
|
# Verify the cached response matches the requested LLM provider
|
|
# The model field in cache contains the provider (e.g., "anthropic", "huggingface")
|
|
cached_model = data["entry"].get("model")
|
|
requested_provider = llm_provider or settings.llm_provider
|
|
|
|
if cached_model and cached_model != requested_provider:
|
|
logger.info(
|
|
f"DSPy cache miss (provider mismatch): cached={cached_model}, requested={requested_provider}"
|
|
)
|
|
return None
|
|
|
|
similarity = data.get("similarity", 0)
|
|
method = data.get("method", "unknown")
|
|
logger.info(f"DSPy cache hit for question: {question[:50]}... (similarity: {similarity:.3f}, method: {method}, provider: {cached_model})")
|
|
return cached_response # type: ignore[no-any-return]
|
|
return None
|
|
|
|
except Exception as e:
|
|
logger.warning(f"DSPy cache get failed: {e}")
|
|
return None
|
|
|
|
async def set_dspy(
|
|
self,
|
|
question: str,
|
|
language: str,
|
|
llm_provider: str | None,
|
|
embedding_model: str | None,
|
|
response: dict[str, Any],
|
|
context: list[dict[str, Any]] | None = None,
|
|
ttl: int = settings.cache_ttl,
|
|
) -> bool:
|
|
"""Cache DSPy response using semantic cache store.
|
|
|
|
Maps DSPyQueryResponse fields to CachedResponse schema:
|
|
- sources_used -> sources
|
|
- visualization -> visualization_type + visualization_data
|
|
- Additional context from query_type, resolved_question, etc.
|
|
"""
|
|
try:
|
|
client = await self.client
|
|
|
|
# Extract visualization components if present
|
|
visualization = response.get("visualization")
|
|
viz_type = None
|
|
viz_data = None
|
|
if visualization:
|
|
viz_type = visualization.get("type")
|
|
viz_data = visualization.get("data")
|
|
|
|
# Build CachedResponse schema matching the Valkey API
|
|
# Maps DSPyQueryResponse fields to CachedResponse expected fields
|
|
#
|
|
# IMPORTANT: Include llm_response metadata (GLM 4.7 reasoning_content) in cache
|
|
# so that cached responses also return the chain-of-thought reasoning.
|
|
llm_response_data = None
|
|
if response.get("llm_response"):
|
|
llm_resp = response["llm_response"]
|
|
# Handle both dict and LLMResponseMetadata object
|
|
if hasattr(llm_resp, "model_dump"):
|
|
llm_response_data = llm_resp.model_dump()
|
|
elif isinstance(llm_resp, dict):
|
|
llm_response_data = llm_resp
|
|
|
|
cached_response = {
|
|
"answer": response.get("answer", ""),
|
|
"sparql_query": None, # DSPy doesn't generate SPARQL
|
|
"typeql_query": None, # DSPy doesn't generate TypeQL
|
|
"visualization_type": viz_type,
|
|
"visualization_data": viz_data,
|
|
"sources": response.get("sources_used", []), # DSPy uses sources_used
|
|
"confidence": 0.95, # DSPy responses are generally high confidence
|
|
"context": {
|
|
"query_type": response.get("query_type"),
|
|
"resolved_question": response.get("resolved_question"),
|
|
"retrieved_results": response.get("retrieved_results"),
|
|
"embedding_model": response.get("embedding_model_used"),
|
|
"llm_model": response.get("llm_model_used"),
|
|
"original_context": context,
|
|
"llm_response": llm_response_data, # GLM 4.7 reasoning_content
|
|
},
|
|
}
|
|
|
|
result = await client.post(
|
|
f"{self.base_url}/cache/store",
|
|
json={
|
|
"query": question,
|
|
"response": cached_response,
|
|
"language": language,
|
|
"model": llm_provider or "unknown",
|
|
},
|
|
)
|
|
|
|
# Check if store was successful
|
|
if result.status_code == 200:
|
|
logger.info(f"✓ Cached DSPy response for: {question[:50]}...")
|
|
return True
|
|
else:
|
|
logger.warning(f"Cache store returned {result.status_code}: {result.text[:200]}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.warning(f"DSPy cache set failed: {e}")
|
|
return False
|
|
|
|
async def close(self) -> None:
    """Release the underlying HTTP client, if one was ever created."""
    client = self._client
    if client is not None:
        await client.aclose()
        # Drop the reference so a later call can lazily recreate the client.
        self._client = None
|
|
|
|
|
|
# Query Router
|
|
class QueryRouter:
    """Routes queries to appropriate data sources based on intent."""

    def __init__(self) -> None:
        # Bilingual (English/Dutch) trigger keywords per intent.
        # Matching is done on word boundaries in detect_intent().
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart", "where", "waar", "location", "locatie",
                "city", "stad", "country", "land", "region", "gebied",
                "coordinates", "coördinaten", "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel", "count", "aantal", "total", "totaal",
                "average", "gemiddeld", "distribution", "verdeling",
                "percentage", "statistics", "statistiek", "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd", "connected", "verbonden",
                "relationship", "relatie", "network", "netwerk",
                "parent", "child", "merged", "fusie", "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis", "timeline", "tijdlijn",
                "when", "wanneer", "founded", "opgericht", "closed", "gesloten",
                "over time", "evolution", "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie", "about", "over",
                "specific", "specifiek", "what is", "wat is",
            ],
        }

        # Preferred backend order per detected intent.
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.DUCKLAKE, DataSource.SPARQL, DataSource.QDRANT],
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Detect query intent from question text."""
        import re

        text = question.lower()

        # Count whole-word keyword hits per intent. Word boundaries prevent
        # partial matches, e.g. "land" must not match inside "netherlands".
        scores: dict[QueryIntent, int] = {intent: 0 for intent in QueryIntent}
        for intent, keywords in self.intent_keywords.items():
            scores[intent] = sum(
                1
                for keyword in keywords
                if re.search(r"\b" + re.escape(keyword) + r"\b", text)
            )

        best = max(scores, key=scores.get)  # type: ignore

        # No keyword matched at all: fall back to generic semantic search.
        return best if scores[best] > 0 else QueryIntent.SEARCH

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Get optimal sources for a query.

        Args:
            question: User's question
            requested_sources: Explicitly requested sources (overrides routing)

        Returns:
            Tuple of (detected_intent, list_of_sources)
        """
        intent = self.detect_intent(question)
        if requested_sources:
            # Caller knows best: honor the explicit source list.
            return intent, requested_sources
        return intent, self.source_routing.get(intent, [DataSource.QDRANT])
|
|
|
|
|
|
# Multi-Source Retriever
|
|
class MultiSourceRetriever:
    """Orchestrates retrieval across multiple data sources.

    Wraps the individual backends (Qdrant, SPARQL, TypeDB, PostGIS, DuckLake)
    behind one async API, runs the selected retrievers concurrently, and
    merges their results with weighted reciprocal rank fusion.
    """

    def __init__(self) -> None:
        self.cache = ValkeyClient()
        self.router = QueryRouter()

        # Initialize retrievers lazily — the properties/getters below create
        # them on first use so startup doesn't pay for unused backends.
        self._qdrant: HybridRetriever | None = None
        self._typedb: TypeDBRetriever | None = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None
        self._ducklake_client: httpx.AsyncClient | None = None

    @property
    def qdrant(self) -> HybridRetriever | None:
        """Lazy-load Qdrant hybrid retriever with multi-embedding support.

        Returns None (and logs a warning) if initialization fails or the
        retriever modules are unavailable.
        """
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production,
                    use_multi_embedding=settings.use_multi_embedding,
                    preferred_embedding_model=settings.preferred_embedding_model,
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant

    @property
    def typedb(self) -> TypeDBRetriever | None:
        """Lazy-load TypeDB retriever (None if unavailable or init fails)."""
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                self._typedb = create_typedb_retriever(
                    use_production=settings.typedb_use_production  # Use TypeDB-specific setting
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb

    async def _get_sparql_client(self) -> httpx.AsyncClient:
        """Get the SPARQL HTTP client, recreating it if it was closed."""
        if self._sparql_client is None or self._sparql_client.is_closed:
            self._sparql_client = httpx.AsyncClient(timeout=30.0)
        return self._sparql_client

    async def _get_postgis_client(self) -> httpx.AsyncClient:
        """Get the PostGIS HTTP client, recreating it if it was closed."""
        if self._postgis_client is None or self._postgis_client.is_closed:
            self._postgis_client = httpx.AsyncClient(timeout=30.0)
        return self._postgis_client

    async def _get_ducklake_client(self) -> httpx.AsyncClient:
        """Get the DuckLake HTTP client, recreating it if it was closed."""
        if self._ducklake_client is None or self._ducklake_client.is_closed:
            self._ducklake_client = httpx.AsyncClient(timeout=60.0)  # Longer timeout for SQL
        return self._ducklake_client

    async def retrieve_from_qdrant(
        self,
        query: str,
        k: int = 10,
        embedding_model: str | None = None,
        region_codes: list[str] | None = None,
        cities: list[str] | None = None,
        institution_types: list[str] | None = None,
    ) -> RetrievalResult:
        """Retrieve from Qdrant vector + SPARQL hybrid search.

        Args:
            query: Search query
            k: Number of results to return
            embedding_model: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')
            region_codes: Filter by province/region codes (e.g., ['NH', 'ZH'])
            cities: Filter by city names (e.g., ['Amsterdam', 'Rotterdam'])
            institution_types: Filter by institution types (e.g., ['ARCHIVE', 'MUSEUM'])

        Returns:
            RetrievalResult; empty items (score 0) on failure.
        """
        start = asyncio.get_event_loop().time()

        items = []
        if self.qdrant:
            try:
                results = self.qdrant.search(
                    query,
                    k=k,
                    using=embedding_model,
                    region_codes=region_codes,
                    cities=cities,
                    institution_types=institution_types,
                )
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"Qdrant retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.QDRANT,
            items=items,
            # Best combined hybrid score among the hits, 0 if none.
            score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_sparql(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from SPARQL endpoint.

        Uses DSPy to translate the natural-language question into SPARQL,
        then executes it against the configured endpoint. Failures are
        logged and yield an empty result rather than raising.
        """
        start = asyncio.get_event_loop().time()

        # Use DSPy to generate SPARQL
        items = []
        try:
            if RETRIEVERS_AVAILABLE:
                sparql_result = generate_sparql(query, language="nl", use_rag=False)
                sparql_query = sparql_result.get("sparql", "")

                if sparql_query:
                    client = await self._get_sparql_client()
                    response = await client.post(
                        settings.sparql_endpoint,
                        data={"query": sparql_query},
                        headers={"Accept": "application/sparql-results+json"},
                    )

                    if response.status_code == 200:
                        data = response.json()
                        bindings = data.get("results", {}).get("bindings", [])
                        # Flatten SPARQL JSON bindings ({var: {"value": ...}})
                        # into plain {var: value} dicts. Note: variable names
                        # deliberately don't reuse `k` (the result limit).
                        items = [
                            {var: cell.get("value") for var, cell in b.items()}
                            for b in bindings[:k]
                        ]
        except Exception as e:
            logger.error(f"SPARQL retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.SPARQL,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve_from_typedb(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from TypeDB knowledge graph via semantic search."""
        start = asyncio.get_event_loop().time()

        items = []
        if self.typedb:
            try:
                results = self.typedb.semantic_search(query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"TypeDB retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.TYPEDB,
            items=items,
            score=max((r.get("relevance_score", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_postgis(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from PostGIS geospatial database.

        Simplified implementation: scans the query for a known city name and,
        on the first match, asks PostGIS for institutions within 10 km.
        """
        start = asyncio.get_event_loop().time()

        # Extract location from query for geospatial search
        # This is a simplified implementation
        items = []
        try:
            client = await self._get_postgis_client()

            # Try to detect city name for bbox search
            query_lower = query.lower()

            # Simple city detection — hard-coded centroids of major NL cities.
            cities = {
                "amsterdam": {"lat": 52.3676, "lon": 4.9041},
                "rotterdam": {"lat": 51.9244, "lon": 4.4777},
                "den haag": {"lat": 52.0705, "lon": 4.3007},
                "utrecht": {"lat": 52.0907, "lon": 5.1214},
            }

            for city, coords in cities.items():
                if city in query_lower:
                    # Query PostGIS for nearby institutions
                    response = await client.get(
                        f"{settings.postgis_url}/api/institutions/nearby",
                        params={
                            "lat": coords["lat"],
                            "lon": coords["lon"],
                            "radius_km": 10,
                            "limit": k,
                        },
                    )

                    if response.status_code == 200:
                        items = response.json()
                        break

        except Exception as e:
            logger.error(f"PostGIS retrieval failed: {e}")

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.POSTGIS,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve_from_ducklake(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from DuckLake SQL analytics database.

        Uses DSPy HeritageSQLGenerator to convert natural language to SQL,
        then executes against the custodians table.

        Args:
            query: Natural language question or SQL query
            k: Maximum number of results to return

        Returns:
            RetrievalResult with query results (generated SQL and execution
            info are recorded in ``metadata``).
        """
        start = asyncio.get_event_loop().time()
        items = []
        metadata = {}

        try:
            # Import the SQL generator from dspy module
            import dspy
            from dspy_heritage_rag import HeritageSQLGenerator

            # Initialize DSPy predictor for SQL generation
            sql_generator = dspy.Predict(HeritageSQLGenerator)

            # Generate SQL from natural language
            sql_result = sql_generator(
                question=query,
                intent="statistical",
                entities="",
                context="",
            )

            sql_query = sql_result.sql
            metadata["generated_sql"] = sql_query
            metadata["sql_explanation"] = sql_result.explanation

            # Execute SQL against DuckLake
            client = await self._get_ducklake_client()
            response = await client.post(
                f"{settings.ducklake_url}/query",
                json={"sql": sql_query},
                timeout=60.0,
            )

            if response.status_code == 200:
                data = response.json()
                columns = data.get("columns", [])
                rows = data.get("rows", [])

                # Convert to list of dicts
                items = [dict(zip(columns, row)) for row in rows[:k]]
                metadata["row_count"] = data.get("row_count", len(items))
                metadata["execution_time_ms"] = data.get("execution_time_ms", 0)
            else:
                logger.error(f"DuckLake query failed: {response.status_code} - {response.text}")
                metadata["error"] = f"HTTP {response.status_code}"

        except Exception as e:
            logger.error(f"DuckLake retrieval failed: {e}")
            metadata["error"] = str(e)

        elapsed = (asyncio.get_event_loop().time() - start) * 1000

        return RetrievalResult(
            source=DataSource.DUCKLAKE,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve(
        self,
        question: str,
        sources: list[DataSource],
        k: int = 10,
        embedding_model: str | None = None,
        region_codes: list[str] | None = None,
        cities: list[str] | None = None,
        institution_types: list[str] | None = None,
    ) -> list[RetrievalResult]:
        """Retrieve from multiple sources concurrently.

        Args:
            question: User's question
            sources: Data sources to query
            k: Number of results per source
            embedding_model: Optional embedding model for Qdrant (e.g., 'minilm_384', 'openai_1536')
            region_codes: Filter by province/region codes (e.g., ['NH', 'ZH']) - Qdrant only
            cities: Filter by city names (e.g., ['Amsterdam']) - Qdrant only
            institution_types: Filter by institution types (e.g., ['ARCHIVE']) - Qdrant only

        Returns:
            List of RetrievalResult from each source that completed without
            raising; failed tasks are logged and dropped.
        """
        tasks = []

        for source in sources:
            if source == DataSource.QDRANT:
                tasks.append(self.retrieve_from_qdrant(
                    question,
                    k,
                    embedding_model,
                    region_codes=region_codes,
                    cities=cities,
                    institution_types=institution_types,
                ))
            elif source == DataSource.SPARQL:
                tasks.append(self.retrieve_from_sparql(question, k))
            elif source == DataSource.TYPEDB:
                tasks.append(self.retrieve_from_typedb(question, k))
            elif source == DataSource.POSTGIS:
                tasks.append(self.retrieve_from_postgis(question, k))
            elif source == DataSource.DUCKLAKE:
                tasks.append(self.retrieve_from_ducklake(question, k))

        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out exceptions so one failing backend can't sink the query.
        valid_results = []
        for r in results:
            if isinstance(r, RetrievalResult):
                valid_results.append(r)
            elif isinstance(r, Exception):
                logger.error(f"Retrieval task failed: {r}")

        return valid_results

    def merge_results(
        self,
        results: list[RetrievalResult],
        max_results: int = 20,
    ) -> list[dict[str, Any]]:
        """Merge and deduplicate results from multiple sources.

        Uses reciprocal rank fusion (RRF) for score combination, weighted by
        per-source weights from settings. Items are deduplicated by GHCID
        (falling back to "id"), accumulating RRF score across sources.
        """
        # Per-source fusion weights. Loop-invariant, so built once here
        # instead of once per merged item (as the previous version did).
        source_weights = {
            DataSource.QDRANT: settings.vector_weight,
            DataSource.SPARQL: settings.graph_weight,
            DataSource.TYPEDB: settings.typedb_weight,
            DataSource.POSTGIS: 0.3,
            DataSource.DUCKLAKE: 0.5,  # explicit; same as the previous implicit default
        }

        # Track items by GHCID for deduplication
        merged: dict[str, dict[str, Any]] = {}

        for result in results:
            weight = source_weights.get(result.source, 0.5)
            for rank, item in enumerate(result.items):
                ghcid = item.get("ghcid", item.get("id", f"unknown_{rank}"))

                if ghcid not in merged:
                    merged[ghcid] = item.copy()
                    merged[ghcid]["_sources"] = []
                    merged[ghcid]["_rrf_score"] = 0.0

                # Reciprocal Rank Fusion
                rrf_score = 1.0 / (60 + rank)  # k=60 is standard

                merged[ghcid]["_rrf_score"] += rrf_score * weight
                merged[ghcid]["_sources"].append(result.source.value)

        # Sort by RRF score
        sorted_items = sorted(
            merged.values(),
            key=lambda x: x.get("_rrf_score", 0),
            reverse=True,
        )

        return sorted_items[:max_results]

    async def close(self) -> None:
        """Clean up resources (cache, HTTP clients, retrievers)."""
        await self.cache.close()

        if self._sparql_client:
            await self._sparql_client.aclose()
        if self._postgis_client:
            await self._postgis_client.aclose()
        if self._ducklake_client:
            # BUGFIX: the DuckLake client was previously never closed here,
            # leaking its connection pool on shutdown.
            await self._ducklake_client.aclose()
        if self._qdrant:
            self._qdrant.close()
        if self._typedb:
            self._typedb.close()

    def search_persons(
        self,
        query: str,
        k: int = 10,
        filter_custodian: str | None = None,
        only_heritage_relevant: bool = False,
        using: str | None = None,
    ) -> list[Any]:
        """Search for persons/staff in the heritage_persons collection.

        Delegates to HybridRetriever.search_persons() if available.

        Args:
            query: Search query
            k: Number of results
            filter_custodian: Optional custodian slug to filter by
            only_heritage_relevant: Only return heritage-relevant staff
            using: Optional embedding model to use (e.g., 'minilm_384', 'openai_1536')

        Returns:
            List of RetrievedPerson objects; empty list if Qdrant is
            unavailable or the search fails.
        """
        if self.qdrant:
            try:
                return self.qdrant.search_persons(  # type: ignore[no-any-return]
                    query=query,
                    k=k,
                    filter_custodian=filter_custodian,
                    only_heritage_relevant=only_heritage_relevant,
                    using=using,
                )
            except Exception as e:
                logger.error(f"Person search failed: {e}")
        return []

    def get_stats(self) -> dict[str, Any]:
        """Get statistics from all retrievers.

        Returns combined stats from Qdrant (including persons collection)
        and TypeDB (nested under the "typedb" key). Failures are logged and
        the corresponding stats simply omitted.
        """
        stats = {}

        if self.qdrant:
            try:
                qdrant_stats = self.qdrant.get_stats()
                stats.update(qdrant_stats)
            except Exception as e:
                logger.warning(f"Failed to get Qdrant stats: {e}")

        if self.typedb:
            try:
                typedb_stats = self.typedb.get_stats()
                stats["typedb"] = typedb_stats
            except Exception as e:
                logger.warning(f"Failed to get TypeDB stats: {e}")

        return stats
|
|
|
|
|
|
# Global instances
|
|
# Module-level singletons, populated during application startup in `lifespan`
# and torn down on shutdown. They stay None when startup hasn't run.
retriever: MultiSourceRetriever | None = None  # Multi-source retrieval orchestrator
viz_selector: VisualizationSelector | None = None  # DSPy-backed visualization chooser (optional)
dspy_pipeline: Any = None  # HeritageRAGPipeline instance (loaded with optimized model)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan manager.

    Startup: builds the MultiSourceRetriever, optionally creates the
    VisualizationSelector, configures DSPy against the first working LLM
    provider (Z.AI → HuggingFace → Anthropic → OpenAI fallback chain),
    loads the optimized HeritageRAGPipeline, and warms up the embedding
    model. Shutdown: closes the retriever.
    """
    global retriever, viz_selector, dspy_pipeline

    # Startup
    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()

    if RETRIEVERS_AVAILABLE:
        # Check for any available LLM API key (Anthropic preferred, OpenAI fallback)
        has_llm_key = bool(settings.anthropic_api_key or settings.openai_api_key)

        # VisualizationSelector requires DSPy - make it optional
        try:
            viz_selector = VisualizationSelector(use_dspy=has_llm_key)
        except RuntimeError as e:
            logger.warning(f"VisualizationSelector not available: {e}")
            viz_selector = None

        # Configure DSPy based on LLM_PROVIDER setting
        # Respect user's provider preference, with fallback chain
        import dspy
        llm_provider = settings.llm_provider.lower()
        logger.info(f"LLM_PROVIDER configured as: {llm_provider}")
        dspy_configured = False

        # Try Z.AI GLM if configured as provider (FREE!)
        if llm_provider == "zai" and settings.zai_api_token:
            try:
                # Z.AI uses OpenAI-compatible API format
                # Use LLM_MODEL from settings (default: glm-4.5-flash for speed)
                zai_model = settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash"
                lm = dspy.LM(
                    f"openai/{zai_model}",
                    api_key=settings.zai_api_token,
                    api_base="https://api.z.ai/api/coding/paas/v4",
                )
                dspy.configure(lm=lm)
                logger.info(f"Configured DSPy with Z.AI {zai_model} (FREE)")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with Z.AI: {e}")

        # Try HuggingFace if configured as provider
        if not dspy_configured and llm_provider == "huggingface" and settings.huggingface_api_key:
            try:
                lm = dspy.LM("huggingface/utter-project/EuroLLM-9B-Instruct", api_key=settings.huggingface_api_key)
                dspy.configure(lm=lm)
                logger.info("Configured DSPy with HuggingFace EuroLLM-9B-Instruct")
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with HuggingFace: {e}")

        # Try Anthropic if not yet configured (either as primary or fallback)
        # NOTE(review): the fallback here only fires when llm_provider is
        # "huggingface" — a failed "zai" setup falls straight through to the
        # OpenAI fallback below. Confirm whether that asymmetry is intended.
        if not dspy_configured and (llm_provider == "anthropic" or (llm_provider == "huggingface" and settings.anthropic_api_key)):
            if settings.anthropic_api_key and configure_dspy:
                try:
                    configure_dspy(
                        provider="anthropic",
                        model=settings.default_model,
                        api_key=settings.anthropic_api_key,
                    )
                    dspy_configured = True
                except Exception as e:
                    logger.warning(f"Failed to configure DSPy with Anthropic: {e}")

        # Try OpenAI as final fallback
        if not dspy_configured and settings.openai_api_key and configure_dspy:
            try:
                configure_dspy(
                    provider="openai",
                    model="gpt-4o-mini",
                    api_key=settings.openai_api_key,
                )
                dspy_configured = True
            except Exception as e:
                logger.warning(f"Failed to configure DSPy with OpenAI: {e}")

        if not dspy_configured:
            logger.warning("No LLM provider configured - DSPy queries will fail")

        # Initialize optimized HeritageRAGPipeline (if DSPy is configured)
        if dspy_configured:
            try:
                from dspy_heritage_rag import HeritageRAGPipeline
                from pathlib import Path

                # Create pipeline with Qdrant retriever
                qdrant_retriever = retriever.qdrant if retriever else None
                dspy_pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)

                # Load optimized model (BootstrapFewShot: 14.3% quality improvement)
                optimized_model_path = Path(__file__).parent / "optimized_models" / "heritage_rag_bootstrap_latest.json"
                if optimized_model_path.exists():
                    dspy_pipeline.load(str(optimized_model_path))
                    logger.info(f"Loaded optimized DSPy pipeline from {optimized_model_path}")
                else:
                    logger.warning(f"Optimized model not found at {optimized_model_path}, using unoptimized pipeline")
            except Exception as e:
                logger.warning(f"Failed to initialize DSPy pipeline: {e}")
                dspy_pipeline = None

    # === HOT LOADING: Warmup embedding model to avoid cold-start latency ===
    # The sentence-transformers model takes 3-15 seconds to load on first use.
    # By loading it eagerly at startup, we eliminate this delay for users.
    # (Safe when retrievers are unavailable: the qdrant property is then None.)
    if retriever.qdrant:
        logger.info("Warming up embedding model (this takes 3-15 seconds on first startup)...")
        try:
            # Trigger model load with a dummy embedding request
            _ = retriever.qdrant._get_embedding("archief warmup query")
            logger.info("✅ Embedding model warmed up - ready for fast queries!")
        except Exception as e:
            logger.warning(f"Failed to warm up embedding model: {e}")

    logger.info("Heritage RAG API started")

    yield

    # Shutdown
    logger.info("Shutting down Heritage RAG API...")
    if retriever:
        await retriever.close()
    logger.info("Heritage RAG API stopped")
|
|
|
|
|
|
# Create FastAPI app
|
|
# FastAPI application; startup/shutdown is handled by the lifespan manager.
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is
# disallowed by the CORS spec (browsers reject "*" for credentialed
# requests) — confirm whether credentials are actually needed here, or pin
# the allowed origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# API Endpoints
|
|
@app.get("/api/rag/health")
async def health_check() -> dict[str, Any]:
    """Health check for all services.

    Probes Qdrant, the SPARQL endpoint, and TypeDB; each probe failure is
    recorded per-service rather than raised. Overall status is "ok" with no
    failures, "degraded" with one or two, and "error" otherwise.
    """
    services: dict[str, Any] = {}
    health: dict[str, Any] = {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": services,
    }

    # Check Qdrant
    qdrant = retriever.qdrant if retriever else None
    if qdrant:
        try:
            stats = qdrant.get_stats()
            vectors = stats.get("qdrant", {}).get("vectors_count", 0)
            services["qdrant"] = {"status": "ok", "vectors": vectors}
        except Exception as e:
            services["qdrant"] = {"status": "error", "error": str(e)}

    # Check SPARQL — hit the endpoint root; anything below HTTP 500 counts as alive.
    try:
        base_url = settings.sparql_endpoint.replace('/query', '')
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{base_url}")
            sparql_ok = response.status_code < 500
        services["sparql"] = {"status": "ok" if sparql_ok else "error"}
    except Exception as e:
        services["sparql"] = {"status": "error", "error": str(e)}

    # Check TypeDB
    typedb = retriever.typedb if retriever else None
    if typedb:
        try:
            stats = typedb.get_stats()
            services["typedb"] = {"status": "ok", "entities": stats.get("entities", {})}
        except Exception as e:
            services["typedb"] = {"status": "error", "error": str(e)}

    # Overall status: derived from the number of failed probes.
    errors = sum(
        1
        for s in services.values()
        if isinstance(s, dict) and s.get("status") == "error"
    )
    if errors == 0:
        health["status"] = "ok"
    elif errors < 3:
        health["status"] = "degraded"
    else:
        health["status"] = "error"

    return health
|
|
|
|
|
|
@app.get("/api/rag/stats")
async def get_stats() -> dict[str, Any]:
    """Get retriever statistics.

    Collects stats from whichever retrievers are currently available
    (Qdrant and/or TypeDB); absent retrievers are simply omitted.
    """
    retrievers: dict[str, Any] = {}

    if retriever is not None:
        qdrant = retriever.qdrant
        if qdrant:
            retrievers["qdrant"] = qdrant.get_stats()
        typedb = retriever.typedb
        if typedb:
            retrievers["typedb"] = typedb.get_stats()

    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "retrievers": retrievers,
    }
|
|
|
|
|
|
@app.get("/api/rag/stats/costs")
async def get_cost_stats() -> dict[str, Any]:
    """Get cost tracking session statistics.

    Returns cumulative statistics for the current session including:
    - Total LLM calls and token usage
    - Total retrieval operations and latencies
    - Estimated costs by model
    - Pipeline timing statistics

    Returns:
        Dict with cost tracker statistics or unavailable message
    """
    # Guard clause: the tracker module may not be installed/importable.
    tracker_ready = COST_TRACKER_AVAILABLE and get_tracker
    if not tracker_ready:
        return {
            "available": False,
            "message": "Cost tracker module not available",
        }

    session = get_tracker().get_session_summary()
    return {
        "available": True,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "session": session,
    }
|
|
|
|
|
|
@app.post("/api/rag/stats/costs/reset")
async def reset_cost_stats() -> dict[str, Any]:
    """Reset cost tracking statistics.

    Clears all accumulated statistics and starts a fresh session.
    Useful for per-conversation or per-session cost tracking.

    Returns:
        Confirmation message
    """
    # Guard clause: the tracker module may not be installed/importable.
    if not (COST_TRACKER_AVAILABLE and reset_tracker):
        return {
            "available": False,
            "message": "Cost tracker module not available",
        }

    reset_tracker()
    return {
        "available": True,
        "message": "Cost tracking statistics reset",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
|
|
|
|
|
|
@app.get("/api/rag/embedding/models")
async def get_embedding_models() -> dict[str, Any]:
    """List available embedding models for the Qdrant collections.

    Returns information about which embedding models are available in each
    collection's named vectors, helping clients choose the right model for
    their use case.

    Returns:
        Dict with available models per collection, current settings, and recommendations
    """
    result: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "multi_embedding_enabled": settings.use_multi_embedding,
        "preferred_model": settings.preferred_embedding_model,
        "collections": {},
        # Static catalogue of the embedding models this deployment knows about.
        "models": {
            "openai_1536": {
                "description": "OpenAI text-embedding-3-small (1536 dimensions)",
                "quality": "high",
                "cost": "paid API",
                "recommended_for": "production, high-quality semantic search",
            },
            "minilm_384": {
                "description": "sentence-transformers/all-MiniLM-L6-v2 (384 dimensions)",
                "quality": "good",
                "cost": "free (local)",
                "recommended_for": "development, cost-sensitive deployments",
            },
            "bge_768": {
                "description": "BAAI/bge-small-en-v1.5 (768 dimensions)",
                "quality": "very good",
                "cost": "free (local)",
                "recommended_for": "balanced quality/cost, multilingual support",
            },
        },
    }

    if retriever and retriever.qdrant:
        qdrant = retriever.qdrant

        # Check if multi-embedding is enabled and get available models
        if hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding:
            if hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever:
                multi = qdrant.multi_retriever

                # Get available models for institutions collection
                try:
                    inst_models = multi.get_available_models("heritage_custodians")
                    selected = multi.select_model("heritage_custodians")
                    result["collections"]["heritage_custodians"] = {
                        "available_models": [m.value for m in inst_models],
                        "uses_named_vectors": multi.uses_named_vectors("heritage_custodians"),
                        "recommended": selected.value if selected else None,
                    }
                except Exception as e:
                    result["collections"]["heritage_custodians"] = {"error": str(e)}

                # Get available models for persons collection
                try:
                    person_models = multi.get_available_models("heritage_persons")
                    selected = multi.select_model("heritage_persons")
                    result["collections"]["heritage_persons"] = {
                        "available_models": [m.value for m in person_models],
                        "uses_named_vectors": multi.uses_named_vectors("heritage_persons"),
                        "recommended": selected.value if selected else None,
                    }
                except Exception as e:
                    result["collections"]["heritage_persons"] = {"error": str(e)}
        else:
            # Single embedding mode.
            # FIX: removed an unused `stats = qdrant.get_stats()` call — its
            # result was never read, and a backend failure in it would have
            # turned this informational endpoint into a 500.
            result["single_embedding_mode"] = True
            result["note"] = "Collections use single embedding vectors. Enable USE_MULTI_EMBEDDING=true to use named vectors."

    return result
|
|
|
|
|
|
class EmbeddingCompareRequest(BaseModel):
    """Request for comparing embedding models.

    Used by POST /api/rag/embedding/compare to run the same query against
    every available embedding model in one collection.
    """
    # Query text to run against each embedding model.
    query: str = Field(..., description="Query to search with")
    # Qdrant collection to search (defaults to the persons collection).
    collection: str = Field(default="heritage_persons", description="Collection to search")
    # Results per model, clamped to 1..20 by validation.
    k: int = Field(default=5, ge=1, le=20, description="Number of results per model")
|
|
|
|
|
|
@app.post("/api/rag/embedding/compare")
async def compare_embedding_models(request: EmbeddingCompareRequest) -> dict[str, Any]:
    """Compare search results across different embedding models.

    Performs the same search query using each available embedding model,
    allowing A/B testing of embedding quality.

    This endpoint is useful for:
    - Evaluating which embedding model works best for your queries
    - Understanding differences in semantic similarity between models
    - Making informed decisions about which model to use in production

    Raises:
        HTTPException: 503 if the retriever stack isn't ready, 400 if
            multi-embedding mode is disabled, 500 on comparison failure.

    Returns:
        Dict with results from each embedding model, including scores and overlap analysis
    """
    import time
    start_time = time.time()

    if not retriever or not retriever.qdrant:
        raise HTTPException(status_code=503, detail="Qdrant retriever not available")

    qdrant = retriever.qdrant

    # Check if multi-embedding is available
    if not (hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding):
        raise HTTPException(
            status_code=400,
            detail="Multi-embedding mode not enabled. Set USE_MULTI_EMBEDDING=true to use this endpoint."
        )

    if not (hasattr(qdrant, 'multi_retriever') and qdrant.multi_retriever):
        raise HTTPException(status_code=503, detail="Multi-embedding retriever not initialized")

    multi = qdrant.multi_retriever

    try:
        # Use the compare_models method from MultiEmbeddingRetriever
        comparison = multi.compare_models(
            query=request.query,
            collection=request.collection,
            k=request.k,
        )

        elapsed_ms = (time.time() - start_time) * 1000

        return {
            "query": request.query,
            "collection": request.collection,
            "k": request.k,
            "query_time_ms": round(elapsed_ms, 2),
            "comparison": comparison,
        }

    except Exception as e:
        logger.exception(f"Embedding comparison failed: {e}")
        # FIX: chain the original exception so the traceback is preserved
        # (previously re-raised without `from e`).
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/api/rag/query", response_model=QueryResponse)
|
|
async def query_rag(request: QueryRequest) -> QueryResponse:
|
|
"""Main RAG query endpoint.
|
|
|
|
Orchestrates retrieval from multiple sources, merges results,
|
|
and optionally generates visualization configuration.
|
|
"""
|
|
if not retriever:
|
|
raise HTTPException(status_code=503, detail="Retriever not initialized")
|
|
|
|
start_time = asyncio.get_event_loop().time()
|
|
|
|
# Check cache first
|
|
cached = await retriever.cache.get(request.question, request.sources)
|
|
if cached:
|
|
cached["cache_hit"] = True
|
|
return QueryResponse(**cached)
|
|
|
|
# Route query to appropriate sources
|
|
intent, sources = retriever.router.get_sources(request.question, request.sources)
|
|
logger.info(f"Query intent: {intent}, sources: {sources}")
|
|
|
|
# Extract geographic filters from question (province, city, institution type)
|
|
geo_filters = extract_geographic_filters(request.question)
|
|
if any(geo_filters.values()):
|
|
logger.info(f"Geographic filters extracted: {geo_filters}")
|
|
|
|
# Retrieve from all sources
|
|
results = await retriever.retrieve(
|
|
request.question,
|
|
sources,
|
|
request.k,
|
|
embedding_model=request.embedding_model,
|
|
region_codes=geo_filters["region_codes"],
|
|
cities=geo_filters["cities"],
|
|
institution_types=geo_filters["institution_types"],
|
|
)
|
|
|
|
# Merge results
|
|
merged_items = retriever.merge_results(results, max_results=request.k * 2)
|
|
|
|
# Generate visualization config if requested
|
|
visualization = None
|
|
if request.include_visualization and viz_selector and merged_items:
|
|
# Extract schema from first result
|
|
schema_fields = list(merged_items[0].keys()) if merged_items else []
|
|
schema_str = ", ".join(f for f in schema_fields if not f.startswith("_"))
|
|
|
|
visualization = viz_selector.select(
|
|
request.question,
|
|
schema_str,
|
|
len(merged_items),
|
|
)
|
|
|
|
elapsed_ms = (asyncio.get_event_loop().time() - start_time) * 1000
|
|
|
|
response_data = {
|
|
"question": request.question,
|
|
"sparql": None, # Could be populated from SPARQL result
|
|
"results": merged_items,
|
|
"visualization": visualization,
|
|
"sources_used": [s for s in sources],
|
|
"cache_hit": False,
|
|
"query_time_ms": round(elapsed_ms, 2),
|
|
"result_count": len(merged_items),
|
|
}
|
|
|
|
# Cache the response
|
|
await retriever.cache.set(request.question, request.sources, response_data)
|
|
|
|
return QueryResponse(**response_data) # type: ignore[arg-type]
|
|
|
|
|
|
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
|
|
async def generate_sparql_endpoint(request: SPARQLRequest) -> SPARQLResponse:
|
|
"""Generate SPARQL query from natural language.
|
|
|
|
Uses DSPy with optional RAG enhancement for context.
|
|
"""
|
|
if not RETRIEVERS_AVAILABLE:
|
|
raise HTTPException(status_code=503, detail="SPARQL generator not available")
|
|
|
|
try:
|
|
result = generate_sparql(
|
|
request.question,
|
|
language=request.language,
|
|
context=request.context,
|
|
use_rag=request.use_rag,
|
|
)
|
|
|
|
return SPARQLResponse(
|
|
sparql=result["sparql"],
|
|
explanation=result.get("explanation", ""),
|
|
rag_used=result.get("rag_used", False),
|
|
retrieved_passages=result.get("retrieved_passages", []),
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.exception("SPARQL generation failed")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/api/rag/visualize")
|
|
async def get_visualization_config(
|
|
question: str = Query(..., description="User's question"),
|
|
schema: str = Query(..., description="Comma-separated field names"),
|
|
result_count: int = Query(default=0, description="Number of results"),
|
|
) -> dict[str, Any]:
|
|
"""Get visualization configuration for a query."""
|
|
if not viz_selector:
|
|
raise HTTPException(status_code=503, detail="Visualization selector not available")
|
|
|
|
config = viz_selector.select(question, schema, result_count)
|
|
return config # type: ignore[no-any-return]
|
|
|
|
|
|
@app.post("/api/rag/typedb/search", response_model=TypeDBSearchResponse)
|
|
async def typedb_search(request: TypeDBSearchRequest) -> TypeDBSearchResponse:
|
|
"""Direct TypeDB search endpoint.
|
|
|
|
Search heritage custodians in TypeDB using various strategies:
|
|
- semantic: Natural language search (combines type + location patterns)
|
|
- name: Search by institution name
|
|
- type: Search by institution type (museum, archive, library, gallery)
|
|
- location: Search by city/location name
|
|
|
|
Examples:
|
|
- {"query": "museums in Amsterdam", "search_type": "semantic"}
|
|
- {"query": "Rijksmuseum", "search_type": "name"}
|
|
- {"query": "archive", "search_type": "type"}
|
|
- {"query": "Rotterdam", "search_type": "location"}
|
|
"""
|
|
import time
|
|
start_time = time.time()
|
|
|
|
# Check if TypeDB retriever is available
|
|
if not retriever or not retriever.typedb:
|
|
raise HTTPException(
|
|
status_code=503,
|
|
detail="TypeDB retriever not available. Ensure TypeDB is running."
|
|
)
|
|
|
|
try:
|
|
typedb_retriever = retriever.typedb
|
|
|
|
# Route to appropriate search method
|
|
if request.search_type == "name":
|
|
results = typedb_retriever.search_by_name(request.query, k=request.k)
|
|
elif request.search_type == "type":
|
|
results = typedb_retriever.search_by_type(request.query, k=request.k)
|
|
elif request.search_type == "location":
|
|
results = typedb_retriever.search_by_location(city=request.query, k=request.k)
|
|
else: # semantic (default)
|
|
results = typedb_retriever.semantic_search(request.query, k=request.k)
|
|
|
|
# Convert results to dicts
|
|
result_dicts = []
|
|
seen_names = set() # Deduplicate by name
|
|
|
|
for r in results:
|
|
# Handle both dict and object results
|
|
if hasattr(r, 'to_dict'):
|
|
item = r.to_dict()
|
|
elif isinstance(r, dict):
|
|
item = r
|
|
else:
|
|
item = {"name": str(r)}
|
|
|
|
# Deduplicate by name
|
|
name = item.get("name") or item.get("observed_name", "")
|
|
if name and name not in seen_names:
|
|
seen_names.add(name)
|
|
result_dicts.append(item)
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
return TypeDBSearchResponse(
|
|
query=request.query,
|
|
search_type=request.search_type,
|
|
results=result_dicts,
|
|
result_count=len(result_dicts),
|
|
query_time_ms=round(elapsed_ms, 2),
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.exception(f"TypeDB search failed: {e}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/api/rag/persons/search", response_model=PersonSearchResponse)
|
|
async def person_search(request: PersonSearchRequest) -> PersonSearchResponse:
|
|
"""Search for persons/staff in heritage institutions.
|
|
|
|
Search the heritage_persons Qdrant collection for staff members
|
|
at heritage custodian institutions.
|
|
|
|
Examples:
|
|
- {"query": "Wie werkt er in het Nationaal Archief?"}
|
|
- {"query": "archivist at Rijksmuseum", "k": 20}
|
|
- {"query": "conservator", "filter_custodian": "rijksmuseum"}
|
|
- {"query": "digital preservation", "only_heritage_relevant": true}
|
|
|
|
The search uses semantic vector similarity to find relevant staff members
|
|
based on their name, role, headline, and custodian affiliation.
|
|
"""
|
|
import time
|
|
start_time = time.time()
|
|
|
|
# Check if retriever is available
|
|
if not retriever:
|
|
raise HTTPException(
|
|
status_code=503,
|
|
detail="Hybrid retriever not available. Ensure Qdrant is running."
|
|
)
|
|
|
|
try:
|
|
# Use the hybrid retriever's person search
|
|
results = retriever.search_persons(
|
|
query=request.query,
|
|
k=request.k,
|
|
filter_custodian=request.filter_custodian,
|
|
only_heritage_relevant=request.only_heritage_relevant,
|
|
using=request.embedding_model, # Pass embedding model
|
|
)
|
|
|
|
# Determine which embedding model was actually used
|
|
embedding_model_used = None
|
|
qdrant = retriever.qdrant
|
|
if qdrant and hasattr(qdrant, 'use_multi_embedding') and qdrant.use_multi_embedding:
|
|
if request.embedding_model:
|
|
embedding_model_used = request.embedding_model
|
|
elif hasattr(qdrant, '_selected_multi_model') and qdrant._selected_multi_model:
|
|
embedding_model_used = qdrant._selected_multi_model.value
|
|
|
|
# Convert results to dicts using to_dict() method if available
|
|
result_dicts = []
|
|
for r in results:
|
|
if hasattr(r, 'to_dict'):
|
|
item = r.to_dict()
|
|
elif hasattr(r, '__dict__'):
|
|
item = {
|
|
"name": getattr(r, 'name', 'Unknown'),
|
|
"headline": getattr(r, 'headline', None),
|
|
"custodian_name": getattr(r, 'custodian_name', None),
|
|
"custodian_slug": getattr(r, 'custodian_slug', None),
|
|
"linkedin_url": getattr(r, 'linkedin_url', None),
|
|
"heritage_relevant": getattr(r, 'heritage_relevant', None),
|
|
"heritage_type": getattr(r, 'heritage_type', None),
|
|
"location": getattr(r, 'location', None),
|
|
"score": getattr(r, 'combined_score', getattr(r, 'vector_score', None)),
|
|
}
|
|
elif isinstance(r, dict):
|
|
item = r
|
|
else:
|
|
item = {"name": str(r)}
|
|
|
|
result_dicts.append(item)
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
# Get collection stats
|
|
stats = None
|
|
try:
|
|
stats = retriever.get_stats()
|
|
# Only include person collection stats if available
|
|
if stats and 'persons' in stats:
|
|
stats = {'persons': stats['persons']}
|
|
except Exception:
|
|
pass
|
|
|
|
return PersonSearchResponse(
|
|
context=PERSON_JSONLD_CONTEXT, # JSON-LD context for linked data
|
|
query=request.query,
|
|
results=result_dicts,
|
|
result_count=len(result_dicts),
|
|
query_time_ms=round(elapsed_ms, 2),
|
|
collection_stats=stats,
|
|
embedding_model_used=embedding_model_used,
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.exception(f"Person search failed: {e}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/api/rag/dspy/query", response_model=DSPyQueryResponse)
|
|
async def dspy_query(request: DSPyQueryRequest) -> DSPyQueryResponse:
|
|
"""DSPy RAG query endpoint with multi-turn conversation support.
|
|
|
|
Uses the HeritageRAGPipeline for conversation-aware question answering.
|
|
Follow-up questions like "Welke daarvan behoren archieven?" will be
|
|
resolved using previous conversation context.
|
|
|
|
Args:
|
|
request: Query request with question, language, and conversation context
|
|
|
|
Returns:
|
|
DSPyQueryResponse with answer, resolved question, and optional visualization
|
|
"""
|
|
import time
|
|
start_time = time.time()
|
|
|
|
# Resolve the provider BEFORE cache lookup to ensure consistent cache keys
|
|
# This is critical: cache GET and SET must use the same provider value
|
|
resolved_provider = (request.llm_provider or settings.llm_provider).lower()
|
|
|
|
# Check cache first (before expensive LLM configuration)
|
|
if retriever:
|
|
cached = await retriever.cache.get_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=resolved_provider, # Use resolved provider, not request.llm_provider
|
|
embedding_model=request.embedding_model,
|
|
context=request.context if request.context else None,
|
|
)
|
|
if cached:
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")
|
|
|
|
# Transform CachedResponse format back to DSPyQueryResponse format
|
|
# CachedResponse has: sources, visualization_type, visualization_data, context
|
|
# DSPyQueryResponse needs: sources_used, visualization, query_type, etc.
|
|
cached_context = cached.get("context") or {}
|
|
visualization = None
|
|
if cached.get("visualization_type") or cached.get("visualization_data"):
|
|
visualization = {
|
|
"type": cached.get("visualization_type"),
|
|
"data": cached.get("visualization_data"),
|
|
}
|
|
|
|
# Restore llm_response metadata (GLM 4.7 reasoning_content) from cache
|
|
llm_response_cached = cached_context.get("llm_response")
|
|
llm_response_obj = None
|
|
if llm_response_cached:
|
|
try:
|
|
llm_response_obj = LLMResponseMetadata(**llm_response_cached)
|
|
except Exception:
|
|
# Fall back to dict if LLMResponseMetadata fails
|
|
llm_response_obj = llm_response_cached # type: ignore[assignment]
|
|
|
|
response_data = {
|
|
"question": request.question,
|
|
"answer": cached.get("answer", ""),
|
|
"sources_used": cached.get("sources", []),
|
|
"visualization": visualization,
|
|
"resolved_question": cached_context.get("resolved_question"),
|
|
"retrieved_results": cached_context.get("retrieved_results"),
|
|
"query_type": cached_context.get("query_type"),
|
|
"embedding_model_used": cached_context.get("embedding_model"),
|
|
"llm_model_used": cached_context.get("llm_model"),
|
|
"query_time_ms": round(elapsed_ms, 2),
|
|
"cache_hit": True,
|
|
"llm_response": llm_response_obj, # GLM 4.7 reasoning_content from cache
|
|
}
|
|
return DSPyQueryResponse(**response_data)
|
|
|
|
try:
|
|
# Import DSPy pipeline and History
|
|
import dspy
|
|
from dspy import History
|
|
from dspy_heritage_rag import HeritageRAGPipeline
|
|
|
|
# Configure DSPy LM per-request based on request.llm_provider (or server default)
|
|
# This allows frontend to switch LLM providers dynamically
|
|
#
|
|
# IMPORTANT: We use dspy.settings.context() instead of dspy.configure() because
|
|
# configure() can only be called from the same async task that initially configured DSPy.
|
|
# context() provides thread-local overrides that work correctly in async request handlers.
|
|
requested_provider = resolved_provider # Already resolved above
|
|
llm_provider_used: str | None = None
|
|
llm_model_used: str | None = None
|
|
lm = None
|
|
|
|
logger.info(f"LLM provider requested: {requested_provider} (request.llm_provider={request.llm_provider}, server default={settings.llm_provider})")
|
|
|
|
# Provider configuration priority: requested provider first, then fallback chain
|
|
providers_to_try = [requested_provider]
|
|
# Add fallback chain (but not duplicates)
|
|
for fallback in ["zai", "groq", "anthropic", "openai"]:
|
|
if fallback not in providers_to_try:
|
|
providers_to_try.append(fallback)
|
|
|
|
for provider in providers_to_try:
|
|
if lm is not None:
|
|
break
|
|
|
|
# Default models per provider (used if request.llm_model is not specified)
|
|
# Use LLM_MODEL from settings when it matches the provider prefix
|
|
default_models = {
|
|
"zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash",
|
|
"groq": "llama-3.1-8b-instant",
|
|
"anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514",
|
|
"openai": "gpt-4o-mini",
|
|
# Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference
|
|
# Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2
|
|
"huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct",
|
|
}
|
|
# HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct)
|
|
# Groq models use simple names (e.g., llama-3.1-8b-instant)
|
|
model_prefixes = {
|
|
"glm-": "zai",
|
|
"llama-3.1-": "groq",
|
|
"llama-3.3-": "groq",
|
|
"claude-": "anthropic",
|
|
"gpt-": "openai",
|
|
# HuggingFace organization prefixes
|
|
"mistralai/": "huggingface",
|
|
"google/": "huggingface",
|
|
"Qwen/": "huggingface",
|
|
"deepseek-ai/": "huggingface",
|
|
"meta-llama/": "huggingface",
|
|
"utter-project/": "huggingface",
|
|
"microsoft/": "huggingface",
|
|
"tiiuae/": "huggingface",
|
|
}
|
|
|
|
# Determine which model to use: requested model (if valid for this provider) or default
|
|
requested_model = request.llm_model
|
|
model_to_use = default_models.get(provider, "")
|
|
|
|
# Check if requested model matches this provider
|
|
if requested_model:
|
|
for prefix, model_provider in model_prefixes.items():
|
|
if requested_model.startswith(prefix) and model_provider == provider:
|
|
model_to_use = requested_model
|
|
break
|
|
|
|
if provider == "zai" and settings.zai_api_token:
|
|
try:
|
|
lm = dspy.LM(
|
|
f"openai/{model_to_use}",
|
|
api_key=settings.zai_api_token,
|
|
api_base="https://api.z.ai/api/coding/paas/v4",
|
|
)
|
|
llm_provider_used = "zai"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Z.AI {model_to_use} (FREE) for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Z.AI LM: {e}")
|
|
|
|
elif provider == "groq" and settings.groq_api_key:
|
|
try:
|
|
lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
|
|
llm_provider_used = "groq"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Groq {model_to_use} (FREE) for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Groq LM: {e}")
|
|
|
|
elif provider == "huggingface" and settings.huggingface_api_key:
|
|
try:
|
|
lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
|
|
llm_provider_used = "huggingface"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using HuggingFace {model_to_use} for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create HuggingFace LM: {e}")
|
|
|
|
elif provider == "anthropic" and settings.anthropic_api_key:
|
|
try:
|
|
lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
|
|
llm_provider_used = "anthropic"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Anthropic {model_to_use} for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Anthropic LM: {e}")
|
|
|
|
elif provider == "openai" and settings.openai_api_key:
|
|
try:
|
|
lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
|
|
llm_provider_used = "openai"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using OpenAI {model_to_use} for this request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create OpenAI LM: {e}")
|
|
|
|
# No LM could be configured
|
|
if lm is None:
|
|
raise ValueError(
|
|
f"No LLM could be configured. Requested provider: {requested_provider}. "
|
|
"Ensure the appropriate API key is set: ZAI_API_TOKEN, GROQ_API_KEY, ANTHROPIC_API_KEY, HUGGINGFACE_API_KEY, or OPENAI_API_KEY."
|
|
)
|
|
|
|
logger.info(f"LLM provider for this request: {llm_provider_used}")
|
|
|
|
# =================================================================
|
|
# PERFORMANCE OPTIMIZATION: Create fast LM for routing/extraction
|
|
# Use a fast, cheap model (glm-4.5-flash FREE, gpt-4o-mini $0.15/1M)
|
|
# for routing, entity extraction, and SPARQL generation.
|
|
# The quality_lm (lm) is used only for final answer generation.
|
|
# This can reduce total latency by 2-3x (from ~20s to ~7s).
|
|
# =================================================================
|
|
fast_lm = None
|
|
|
|
# Try to create fast_lm based on FAST_LM_PROVIDER setting
|
|
# Options: "openai" (fast ~1-2s, $0.15/1M) or "zai" (FREE but slow ~13s)
|
|
# Default: openai for speed. Override with FAST_LM_PROVIDER=zai to save costs.
|
|
|
|
if settings.fast_lm_provider == "openai" and settings.openai_api_key:
|
|
try:
|
|
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
|
|
logger.info("Using OpenAI GPT-4o-mini as fast_lm for routing/extraction (~1-2s)")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create fast OpenAI LM: {e}")
|
|
|
|
if fast_lm is None and settings.fast_lm_provider == "zai" and settings.zai_api_token:
|
|
try:
|
|
fast_lm = dspy.LM(
|
|
"openai/glm-4.5-flash",
|
|
api_key=settings.zai_api_token,
|
|
api_base="https://api.z.ai/api/coding/paas/v4",
|
|
)
|
|
logger.info("Using Z.AI GLM-4.5-flash (FREE) as fast_lm for routing/extraction (~13s)")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create fast Z.AI LM: {e}")
|
|
|
|
# Fallback: try the other provider if preferred one failed
|
|
if fast_lm is None and settings.openai_api_key:
|
|
try:
|
|
fast_lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.openai_api_key)
|
|
logger.info("Fallback: Using OpenAI GPT-4o-mini as fast_lm")
|
|
except Exception as e:
|
|
logger.warning(f"Fallback failed - no fast_lm available: {e}")
|
|
|
|
if fast_lm is None:
|
|
logger.info("No fast_lm available - all stages will use quality_lm (slower but works)")
|
|
|
|
# Convert context to DSPy History format
|
|
# Context comes as [{question: "...", answer: "..."}, ...]
|
|
# History expects messages in the same format: [{question: "...", answer: "..."}, ...]
|
|
# (NOT role/content format - that was a bug!)
|
|
history_messages = []
|
|
for turn in request.context:
|
|
# Only include turns that have both question AND answer
|
|
if turn.get("question") and turn.get("answer"):
|
|
history_messages.append({
|
|
"question": turn["question"],
|
|
"answer": turn["answer"]
|
|
})
|
|
|
|
history = History(messages=history_messages) if history_messages else None
|
|
|
|
# Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
|
|
# Falls back to creating a new pipeline if global not available
|
|
if dspy_pipeline is not None:
|
|
pipeline = dspy_pipeline
|
|
logger.debug("Using global optimized DSPy pipeline")
|
|
else:
|
|
# Fallback: create pipeline without optimized weights
|
|
qdrant_retriever = retriever.qdrant if retriever else None
|
|
pipeline = HeritageRAGPipeline(
|
|
retriever=qdrant_retriever,
|
|
fast_lm=fast_lm,
|
|
quality_lm=lm,
|
|
)
|
|
logger.debug("Using fallback (unoptimized) DSPy pipeline")
|
|
|
|
# Execute query with conversation history
|
|
# Retry logic for transient API errors (e.g., Anthropic "Overloaded" errors)
|
|
#
|
|
# IMPORTANT: We use dspy.settings.context(lm=lm) to set the LLM for this request.
|
|
# This provides thread-local overrides that work correctly in async request handlers,
|
|
# unlike dspy.configure() which can only be called from the main async task.
|
|
max_retries = 3
|
|
last_error: Exception | None = None
|
|
result = None
|
|
|
|
with dspy.settings.context(lm=lm):
|
|
for attempt in range(max_retries):
|
|
try:
|
|
# Use pipeline() instead of pipeline.forward() per DSPy 3.0 best practices
|
|
result = pipeline(
|
|
embedding_model=request.embedding_model,
|
|
question=request.question,
|
|
language=request.language,
|
|
history=history,
|
|
include_viz=request.include_visualization,
|
|
)
|
|
break # Success, exit retry loop
|
|
except Exception as e:
|
|
last_error = e
|
|
error_str = str(e).lower()
|
|
# Check for retryable errors (API overload, rate limits, temporary failures)
|
|
is_retryable = any(keyword in error_str for keyword in [
|
|
"overloaded", "rate_limit", "rate limit", "too many requests",
|
|
"529", "503", "502", "504", # HTTP status codes
|
|
"temporarily unavailable", "service unavailable",
|
|
"connection reset", "connection refused", "timeout"
|
|
])
|
|
|
|
if is_retryable and attempt < max_retries - 1:
|
|
wait_time = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
|
|
logger.warning(
|
|
f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}. "
|
|
f"Retrying in {wait_time}s..."
|
|
)
|
|
time.sleep(wait_time)
|
|
continue
|
|
else:
|
|
# Non-retryable error or max retries reached
|
|
raise
|
|
|
|
# If we get here without a result (all retries exhausted), raise the last error
|
|
if result is None:
|
|
if last_error:
|
|
raise last_error
|
|
raise HTTPException(status_code=500, detail="Pipeline execution failed with no result")
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
# Extract visualization if present
|
|
visualization = None
|
|
if request.include_visualization and hasattr(result, "visualization"):
|
|
viz = result.visualization
|
|
if viz:
|
|
visualization = {
|
|
"type": getattr(viz, "viz_type", "table"),
|
|
"sparql_query": getattr(result, "sparql", None),
|
|
}
|
|
|
|
# Extract retrieved results for frontend visualization (tables, graphs)
|
|
retrieved_results = getattr(result, "retrieved_results", None)
|
|
query_type = getattr(result, "query_type", None)
|
|
|
|
# Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support)
|
|
llm_response_metadata = extract_llm_response_metadata(
|
|
lm=lm,
|
|
provider=llm_provider_used,
|
|
latency_ms=int(elapsed_ms),
|
|
)
|
|
|
|
# Build response object
|
|
response = DSPyQueryResponse(
|
|
question=request.question,
|
|
resolved_question=getattr(result, "resolved_question", None),
|
|
answer=getattr(result, "answer", "Geen antwoord gevonden."),
|
|
sources_used=getattr(result, "sources_used", []),
|
|
visualization=visualization,
|
|
retrieved_results=retrieved_results, # Raw data for frontend visualization
|
|
query_type=query_type, # "person" or "institution"
|
|
query_time_ms=round(elapsed_ms, 2),
|
|
conversation_turn=len(request.context),
|
|
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
|
|
# Cost tracking fields
|
|
timing_ms=getattr(result, "timing_ms", None),
|
|
cost_usd=getattr(result, "cost_usd", None),
|
|
timing_breakdown=getattr(result, "timing_breakdown", None),
|
|
# LLM provider tracking
|
|
llm_provider_used=llm_provider_used,
|
|
llm_model_used=llm_model_used,
|
|
cache_hit=False,
|
|
# LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought)
|
|
llm_response=llm_response_metadata,
|
|
)
|
|
|
|
# Cache the successful response for future requests
|
|
if retriever:
|
|
await retriever.cache.set_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=llm_provider_used, # Use actual provider, not requested
|
|
embedding_model=request.embedding_model,
|
|
response=response.model_dump(),
|
|
context=request.context if request.context else None,
|
|
)
|
|
|
|
return response
|
|
|
|
except ImportError as e:
|
|
logger.warning(f"DSPy pipeline not available: {e}")
|
|
# Fallback to simple response
|
|
return DSPyQueryResponse(
|
|
question=request.question,
|
|
answer="DSPy pipeline is niet beschikbaar. Probeer de standaard /api/rag/query endpoint.",
|
|
query_time_ms=0,
|
|
conversation_turn=len(request.context),
|
|
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.exception("DSPy query failed")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
async def stream_dspy_query_response(
|
|
request: DSPyQueryRequest,
|
|
) -> AsyncIterator[str]:
|
|
"""Stream DSPy query response with progress updates for long-running queries.
|
|
|
|
Yields NDJSON lines with status updates at each pipeline stage:
|
|
- {"type": "status", "stage": "cache", "message": "🔍 Cache controleren..."}
|
|
- {"type": "status", "stage": "config", "message": "⚙️ LLM configureren..."}
|
|
- {"type": "status", "stage": "routing", "message": "🧭 Vraag analyseren..."}
|
|
- {"type": "status", "stage": "retrieval", "message": "📊 Database doorzoeken..."}
|
|
- {"type": "status", "stage": "generation", "message": "💡 Antwoord genereren..."}
|
|
- {"type": "complete", "data": {...DSPyQueryResponse...}}
|
|
"""
|
|
import time
|
|
start_time = time.time()
|
|
|
|
def emit_status(stage: str, message: str) -> str:
|
|
"""Helper to emit status JSON line."""
|
|
return json.dumps({
|
|
"type": "status",
|
|
"stage": stage,
|
|
"message": message,
|
|
"elapsed_ms": round((time.time() - start_time) * 1000, 2),
|
|
}) + "\n"
|
|
|
|
def emit_error(error: str, details: str | None = None) -> str:
|
|
"""Helper to emit error JSON line."""
|
|
return json.dumps({
|
|
"type": "error",
|
|
"error": error,
|
|
"details": details,
|
|
"elapsed_ms": round((time.time() - start_time) * 1000, 2),
|
|
}) + "\n"
|
|
|
|
def extract_user_friendly_error(exception: Exception) -> tuple[str, str | None]:
|
|
"""Extract a user-friendly error message from various exception types.
|
|
|
|
Returns:
|
|
tuple: (user_message, technical_details)
|
|
"""
|
|
error_str = str(exception)
|
|
error_lower = error_str.lower()
|
|
|
|
# HuggingFace / LiteLLM specific errors
|
|
if "huggingface" in error_lower or "hf" in error_lower:
|
|
if "model_not_supported" in error_lower or "not a chat model" in error_lower:
|
|
# Extract model name if present
|
|
import re
|
|
model_match = re.search(r"model['\"]?\s*[:=]\s*['\"]?([^'\"}\s,]+)", error_str)
|
|
model_name = model_match.group(1) if model_match else "geselecteerde model"
|
|
return (
|
|
f"Het model '{model_name}' wordt niet ondersteund door HuggingFace. Kies een ander model.",
|
|
error_str
|
|
)
|
|
if "rate limit" in error_lower or "too many requests" in error_lower:
|
|
return (
|
|
"HuggingFace API limiet bereikt. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
if "unauthorized" in error_lower or "invalid api key" in error_lower:
|
|
return (
|
|
"HuggingFace API sleutel ongeldig. Neem contact op met de beheerder.",
|
|
error_str
|
|
)
|
|
if "model is loading" in error_lower or "loading" in error_lower and "model" in error_lower:
|
|
return (
|
|
"Het HuggingFace model wordt geladen. Probeer het over 30 seconden opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Anthropic errors
|
|
if "anthropic" in error_lower:
|
|
if "rate limit" in error_lower or "overloaded" in error_lower:
|
|
return (
|
|
"Anthropic API is overbelast. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
if "invalid api key" in error_lower or "unauthorized" in error_lower:
|
|
return (
|
|
"Anthropic API sleutel ongeldig. Neem contact op met de beheerder.",
|
|
error_str
|
|
)
|
|
|
|
# OpenAI errors
|
|
if "openai" in error_lower:
|
|
if "rate limit" in error_lower:
|
|
return (
|
|
"OpenAI API limiet bereikt. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
if "invalid api key" in error_lower:
|
|
return (
|
|
"OpenAI API sleutel ongeldig. Neem contact op met de beheerder.",
|
|
error_str
|
|
)
|
|
|
|
# Z.AI errors
|
|
if "z.ai" in error_lower or "zai" in error_lower:
|
|
if "rate limit" in error_lower or "quota" in error_lower:
|
|
return (
|
|
"Z.AI API limiet bereikt. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Generic network/connection errors
|
|
if "connection" in error_lower or "timeout" in error_lower:
|
|
return (
|
|
"Verbindingsfout met de AI service. Controleer uw internetverbinding en probeer het opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
if "503" in error_str or "service unavailable" in error_lower:
|
|
return (
|
|
"De AI service is tijdelijk niet beschikbaar. Probeer het over een minuut opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Qdrant/retrieval errors
|
|
if "qdrant" in error_lower:
|
|
return (
|
|
"Fout bij het doorzoeken van de database. Probeer het later opnieuw.",
|
|
error_str
|
|
)
|
|
|
|
# Default: return the raw error but in a nicer format
|
|
return (
|
|
f"Er is een fout opgetreden: {error_str[:200]}{'...' if len(error_str) > 200 else ''}",
|
|
error_str if len(error_str) > 200 else None
|
|
)
|
|
|
|
# Resolve the provider BEFORE cache lookup to ensure consistent cache keys
|
|
# This is critical: cache GET and SET must use the same provider value
|
|
resolved_provider = (request.llm_provider or settings.llm_provider).lower()
|
|
|
|
# Stage 1: Check cache
|
|
yield emit_status("cache", "🔍 Cache controleren...")
|
|
|
|
if retriever:
|
|
cached = await retriever.cache.get_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=resolved_provider, # Use resolved provider, not request.llm_provider
|
|
embedding_model=request.embedding_model,
|
|
context=request.context if request.context else None,
|
|
)
|
|
if cached:
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
logger.info(f"DSPy cache hit - returning cached response in {elapsed_ms:.2f}ms")
|
|
|
|
# Transform CachedResponse format back to DSPyQueryResponse format
|
|
cached_context = cached.get("context") or {}
|
|
visualization = None
|
|
if cached.get("visualization_type") or cached.get("visualization_data"):
|
|
visualization = {
|
|
"type": cached.get("visualization_type"),
|
|
"data": cached.get("visualization_data"),
|
|
}
|
|
|
|
response_data = {
|
|
"question": request.question,
|
|
"answer": cached.get("answer", ""),
|
|
"sources_used": cached.get("sources", []),
|
|
"visualization": visualization,
|
|
"resolved_question": cached_context.get("resolved_question"),
|
|
"retrieved_results": cached_context.get("retrieved_results"),
|
|
"query_type": cached_context.get("query_type"),
|
|
"embedding_model_used": cached_context.get("embedding_model"),
|
|
"llm_model_used": cached_context.get("llm_model"),
|
|
"query_time_ms": round(elapsed_ms, 2),
|
|
"cache_hit": True,
|
|
}
|
|
yield emit_status("cache", "✅ Antwoord gevonden in cache!")
|
|
yield json.dumps({"type": "complete", "data": response_data}) + "\n"
|
|
return
|
|
|
|
try:
|
|
# Stage 2: Configure LLM
|
|
yield emit_status("config", "⚙️ LLM configureren...")
|
|
|
|
import dspy
|
|
from dspy import History
|
|
from dspy_heritage_rag import HeritageRAGPipeline
|
|
|
|
requested_provider = resolved_provider # Already resolved above
|
|
llm_provider_used: str | None = None
|
|
llm_model_used: str | None = None
|
|
lm = None
|
|
|
|
providers_to_try = [requested_provider]
|
|
for fallback in ["zai", "groq", "anthropic", "openai"]:
|
|
if fallback not in providers_to_try:
|
|
providers_to_try.append(fallback)
|
|
|
|
for provider in providers_to_try:
|
|
if lm is not None:
|
|
break
|
|
|
|
# Default models per provider (used if request.llm_model is not specified)
|
|
# Use LLM_MODEL from settings when it matches the provider prefix
|
|
default_models = {
|
|
"zai": settings.llm_model if settings.llm_model.startswith("glm-") else "glm-4.5-flash",
|
|
"groq": "llama-3.1-8b-instant",
|
|
"anthropic": settings.llm_model if settings.llm_model.startswith("claude-") else "claude-sonnet-4-20250514",
|
|
"openai": "gpt-4o-mini",
|
|
# Llama 3.1 8B: Good balance of speed/quality, available on HF serverless inference
|
|
# Alternatives: Qwen/QwQ-32B (better reasoning), mistralai/Mistral-7B-Instruct-v0.2
|
|
"huggingface": settings.llm_model if "/" in settings.llm_model else "meta-llama/Llama-3.1-8B-Instruct",
|
|
}
|
|
# HuggingFace models use org/model format (e.g., meta-llama/Llama-3.1-8B-Instruct)
|
|
# Groq models use simple names (e.g., llama-3.1-8b-instant)
|
|
model_prefixes = {
|
|
"glm-": "zai",
|
|
"llama-3.1-": "groq",
|
|
"llama-3.3-": "groq",
|
|
"claude-": "anthropic",
|
|
"gpt-": "openai",
|
|
# HuggingFace organization prefixes
|
|
"mistralai/": "huggingface",
|
|
"google/": "huggingface",
|
|
"Qwen/": "huggingface",
|
|
"deepseek-ai/": "huggingface",
|
|
"meta-llama/": "huggingface",
|
|
"utter-project/": "huggingface",
|
|
"microsoft/": "huggingface",
|
|
"tiiuae/": "huggingface",
|
|
}
|
|
|
|
# Determine which model to use: requested model (if valid for this provider) or default
|
|
requested_model = request.llm_model
|
|
model_to_use = default_models.get(provider, "")
|
|
|
|
# Check if requested model matches this provider
|
|
if requested_model:
|
|
for prefix, model_provider in model_prefixes.items():
|
|
if requested_model.startswith(prefix) and model_provider == provider:
|
|
model_to_use = requested_model
|
|
break
|
|
|
|
if provider == "zai" and settings.zai_api_token:
|
|
try:
|
|
lm = dspy.LM(
|
|
f"openai/{model_to_use}",
|
|
api_key=settings.zai_api_token,
|
|
api_base="https://api.z.ai/api/coding/paas/v4",
|
|
)
|
|
llm_provider_used = "zai"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Z.AI LM: {e}")
|
|
|
|
elif provider == "groq" and settings.groq_api_key:
|
|
try:
|
|
lm = dspy.LM(f"groq/{model_to_use}", api_key=settings.groq_api_key)
|
|
llm_provider_used = "groq"
|
|
llm_model_used = model_to_use
|
|
logger.info(f"Using Groq {model_to_use} (FREE) for streaming request")
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Groq LM: {e}")
|
|
|
|
elif provider == "huggingface" and settings.huggingface_api_key:
|
|
try:
|
|
lm = dspy.LM(f"huggingface/{model_to_use}", api_key=settings.huggingface_api_key)
|
|
llm_provider_used = "huggingface"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create HuggingFace LM: {e}")
|
|
|
|
elif provider == "anthropic" and settings.anthropic_api_key:
|
|
try:
|
|
lm = dspy.LM(f"anthropic/{model_to_use}", api_key=settings.anthropic_api_key)
|
|
llm_provider_used = "anthropic"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create Anthropic LM: {e}")
|
|
|
|
elif provider == "openai" and settings.openai_api_key:
|
|
try:
|
|
lm = dspy.LM(f"openai/{model_to_use}", api_key=settings.openai_api_key)
|
|
llm_provider_used = "openai"
|
|
llm_model_used = model_to_use
|
|
except Exception as e:
|
|
logger.warning(f"Failed to create OpenAI LM: {e}")
|
|
|
|
if lm is None:
|
|
yield emit_error(f"Geen LLM beschikbaar. Controleer API keys.")
|
|
return
|
|
|
|
yield emit_status("config", f"✅ LLM geconfigureerd ({llm_provider_used})")
|
|
|
|
# Stage 3: Prepare conversation history
|
|
yield emit_status("routing", "🧭 Vraag analyseren...")
|
|
|
|
history_messages = []
|
|
for turn in request.context:
|
|
if turn.get("question") and turn.get("answer"):
|
|
history_messages.append({
|
|
"question": turn["question"],
|
|
"answer": turn["answer"]
|
|
})
|
|
|
|
history = History(messages=history_messages) if history_messages else None
|
|
|
|
# Use global optimized pipeline (loaded with BootstrapFewShot weights: +14.3% quality)
|
|
if dspy_pipeline is not None:
|
|
pipeline = dspy_pipeline
|
|
logger.debug("Using global optimized DSPy pipeline (streaming)")
|
|
else:
|
|
# Fallback: create pipeline without optimized weights
|
|
qdrant_retriever = retriever.qdrant if retriever else None
|
|
pipeline = HeritageRAGPipeline(retriever=qdrant_retriever)
|
|
logger.debug("Using fallback (unoptimized) DSPy pipeline (streaming)")
|
|
|
|
# Stage 4: Execute pipeline with STREAMING answer generation
|
|
yield emit_status("retrieval", "📊 Database doorzoeken...")
|
|
|
|
result = None
|
|
|
|
# Check if pipeline supports streaming
|
|
if hasattr(pipeline, 'forward_streaming'):
|
|
# Use streaming mode - tokens arrive as they're generated
|
|
try:
|
|
with dspy.settings.context(lm=lm):
|
|
async for event in pipeline.forward_streaming(
|
|
embedding_model=request.embedding_model,
|
|
question=request.question,
|
|
language=request.language,
|
|
history=history,
|
|
include_viz=request.include_visualization,
|
|
):
|
|
event_type = event.get("type")
|
|
|
|
if event_type == "cache_hit":
|
|
# Cache hit - return immediately
|
|
result = event["prediction"]
|
|
yield emit_status("complete", "✅ Klaar! (cache)")
|
|
break
|
|
|
|
elif event_type == "retrieval_complete":
|
|
# Retrieval done, now generating answer
|
|
yield emit_status("generation", "💡 Antwoord genereren...")
|
|
|
|
elif event_type == "token":
|
|
# Stream token to frontend
|
|
yield json.dumps({"type": "token", "content": event["content"]}) + "\n"
|
|
|
|
elif event_type == "status":
|
|
# Status message from pipeline
|
|
yield emit_status("generation", event.get("message", "..."))
|
|
|
|
elif event_type == "answer_complete":
|
|
# Final prediction ready
|
|
result = event["prediction"]
|
|
|
|
except Exception as e:
|
|
logger.exception(f"Streaming pipeline execution failed: {e}")
|
|
user_msg, details = extract_user_friendly_error(e)
|
|
yield emit_error(user_msg, details)
|
|
return
|
|
else:
|
|
# Fallback: Non-streaming mode (original behavior)
|
|
max_retries = 3
|
|
last_error: Exception | None = None
|
|
|
|
with dspy.settings.context(lm=lm):
|
|
for attempt in range(max_retries):
|
|
try:
|
|
if attempt > 0:
|
|
yield emit_status("retrieval", f"🔄 Opnieuw proberen ({attempt + 1}/{max_retries})...")
|
|
|
|
result = pipeline(
|
|
embedding_model=request.embedding_model,
|
|
question=request.question,
|
|
language=request.language,
|
|
history=history,
|
|
include_viz=request.include_visualization,
|
|
)
|
|
break
|
|
except Exception as e:
|
|
last_error = e
|
|
error_str = str(e).lower()
|
|
is_retryable = any(keyword in error_str for keyword in [
|
|
"overloaded", "rate_limit", "rate limit", "too many requests",
|
|
"529", "503", "502", "504",
|
|
"temporarily unavailable", "service unavailable",
|
|
"connection reset", "connection refused", "timeout"
|
|
])
|
|
|
|
if is_retryable and attempt < max_retries - 1:
|
|
wait_time = 2 ** attempt
|
|
logger.warning(f"Transient API error (attempt {attempt + 1}/{max_retries}): {e}")
|
|
yield emit_status("retrieval", f"⏳ API overbelast, wachten {wait_time}s...")
|
|
await asyncio.sleep(wait_time)
|
|
continue
|
|
else:
|
|
logger.exception(f"Pipeline execution failed after {attempt + 1} attempts")
|
|
user_msg, details = extract_user_friendly_error(e)
|
|
yield emit_error(user_msg, details)
|
|
return
|
|
|
|
if result is None:
|
|
if last_error:
|
|
user_msg, details = extract_user_friendly_error(last_error)
|
|
yield emit_error(user_msg, details)
|
|
return
|
|
yield emit_error("Pipeline uitvoering mislukt zonder resultaat")
|
|
return
|
|
|
|
# Stage 5: Generate response (only for non-streaming fallback)
|
|
yield emit_status("generation", "💡 Antwoord genereren...")
|
|
|
|
elapsed_ms = (time.time() - start_time) * 1000
|
|
|
|
visualization = None
|
|
if request.include_visualization and hasattr(result, "visualization"):
|
|
viz = result.visualization
|
|
if viz:
|
|
visualization = {
|
|
"type": getattr(viz, "viz_type", "table"),
|
|
"sparql_query": getattr(result, "sparql", None),
|
|
}
|
|
|
|
retrieved_results = getattr(result, "retrieved_results", None)
|
|
query_type = getattr(result, "query_type", None)
|
|
|
|
# Extract LLM response metadata from DSPy history (GLM 4.7 reasoning_content support)
|
|
llm_response_metadata = extract_llm_response_metadata(
|
|
lm=lm,
|
|
provider=llm_provider_used,
|
|
latency_ms=int(elapsed_ms),
|
|
)
|
|
|
|
response = DSPyQueryResponse(
|
|
question=request.question,
|
|
resolved_question=getattr(result, "resolved_question", None),
|
|
answer=getattr(result, "answer", "Geen antwoord gevonden."),
|
|
sources_used=getattr(result, "sources_used", []),
|
|
visualization=visualization,
|
|
retrieved_results=retrieved_results,
|
|
query_type=query_type,
|
|
query_time_ms=round(elapsed_ms, 2),
|
|
conversation_turn=len(request.context),
|
|
embedding_model_used=getattr(result, "embedding_model_used", request.embedding_model),
|
|
timing_ms=getattr(result, "timing_ms", None),
|
|
cost_usd=getattr(result, "cost_usd", None),
|
|
timing_breakdown=getattr(result, "timing_breakdown", None),
|
|
llm_provider_used=llm_provider_used,
|
|
llm_model_used=llm_model_used,
|
|
cache_hit=False,
|
|
# LLM response provenance (GLM 4.7 Thinking Mode chain-of-thought)
|
|
llm_response=llm_response_metadata,
|
|
)
|
|
|
|
# Cache the response
|
|
if retriever:
|
|
await retriever.cache.set_dspy(
|
|
question=request.question,
|
|
language=request.language,
|
|
llm_provider=llm_provider_used,
|
|
embedding_model=request.embedding_model,
|
|
response=response.model_dump(),
|
|
context=request.context if request.context else None,
|
|
)
|
|
|
|
yield emit_status("complete", "✅ Klaar!")
|
|
yield json.dumps({"type": "complete", "data": response.model_dump()}) + "\n"
|
|
|
|
except ImportError as e:
|
|
logger.warning(f"DSPy pipeline not available: {e}")
|
|
yield emit_error("DSPy pipeline is niet beschikbaar.")
|
|
|
|
except Exception as e:
|
|
logger.exception("DSPy streaming query failed")
|
|
user_msg, details = extract_user_friendly_error(e)
|
|
yield emit_error(user_msg, details)
|
|
|
|
|
|
@app.post("/api/rag/dspy/query/stream")
|
|
async def dspy_query_stream(request: DSPyQueryRequest) -> StreamingResponse:
|
|
"""Streaming version of DSPy RAG query endpoint.
|
|
|
|
Returns NDJSON stream with status updates at each pipeline stage,
|
|
allowing the frontend to show progress during long-running queries.
|
|
|
|
Status stages:
|
|
- cache: Checking for cached response
|
|
- config: Configuring LLM provider
|
|
- routing: Analyzing query intent
|
|
- retrieval: Searching databases (Qdrant, SPARQL, etc.)
|
|
- generation: Generating answer with LLM
|
|
- complete: Final response ready
|
|
"""
|
|
return StreamingResponse(
|
|
stream_dspy_query_response(request),
|
|
media_type="application/x-ndjson",
|
|
)
|
|
|
|
|
|
async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Stream a multi-source RAG query as NDJSON status/partial/complete events."""

    def ndjson(payload: dict[str, Any]) -> str:
        # One JSON object per line, per the NDJSON convention.
        return json.dumps(payload) + "\n"

    if not retriever:
        yield json.dumps({"error": "Retriever not initialized"})
        return

    loop = asyncio.get_event_loop()
    started = loop.time()

    # Decide which backend sources should answer this question.
    intent, sources = retriever.router.get_sources(request.question, request.sources)

    # Geographic constraints (province, city, institution type) parsed from the text.
    geo_filters = extract_geographic_filters(request.question)

    yield ndjson({
        "type": "status",
        "message": f"Routing query to {len(sources)} sources...",
        "intent": intent.value,
        "geo_filters": {key: val for key, val in geo_filters.items() if val},
    })

    # Query each source in turn, streaming a progress event per source.
    collected = []
    for source in sources:
        yield ndjson({
            "type": "status",
            "message": f"Querying {source.value}...",
        })

        batch = await retriever.retrieve(
            request.question,
            [source],
            request.k,
            embedding_model=request.embedding_model,
            region_codes=geo_filters["region_codes"],
            cities=geo_filters["cities"],
            institution_types=geo_filters["institution_types"],
        )
        collected.extend(batch)

        yield ndjson({
            "type": "partial",
            "source": source.value,
            "count": len(batch[0].items) if batch else 0,
        })

    # Fuse per-source results into a single ranked list and report timing.
    merged = retriever.merge_results(collected)
    elapsed_ms = (loop.time() - started) * 1000

    yield ndjson({
        "type": "complete",
        "results": merged,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged),
    })
|
|
|
|
|
|
@app.post("/api/rag/query/stream")
|
|
async def query_rag_stream(request: QueryRequest) -> StreamingResponse:
|
|
"""Streaming version of RAG query endpoint."""
|
|
return StreamingResponse(
|
|
stream_query_response(request),
|
|
media_type="application/x-ndjson",
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# SEMANTIC CACHE ENDPOINTS (Qdrant-backed)
|
|
# =============================================================================
|
|
#
|
|
# High-performance semantic cache using Qdrant's HNSW vector index.
|
|
# Replaces slow client-side cosine similarity with server-side ANN search.
|
|
#
|
|
# Performance target:
|
|
# - Cache lookup: <20ms (vs 500-2000ms with client-side scan)
|
|
# - Cache store: <50ms
|
|
#
|
|
# Architecture:
|
|
# Frontend → /api/cache/lookup → Qdrant ANN search → cached response
|
|
# /api/cache/store → embed + upsert to Qdrant
|
|
# =============================================================================
|
|
|
|
# Lazy-loaded singletons for the semantic cache: created on first use and
# reused for the lifetime of the process (see get_cache_qdrant_client /
# get_cache_embedding_model below).
_cache_qdrant_client: Any = None
_cache_embedding_model: Any = None

# Qdrant collection backing the semantic query cache.
CACHE_COLLECTION_NAME = "query_cache"
# Vector size; must match the output dimension of the cache embedding model.
CACHE_EMBEDDING_DIM = 384 # all-MiniLM-L6-v2
|
|
|
|
|
|
def get_cache_qdrant_client() -> Any:
    """Return the process-wide Qdrant client for the cache collection.

    Created lazily on first call and memoized in ``_cache_qdrant_client``.
    Connection details come from ``settings.qdrant_host`` /
    ``settings.qdrant_port`` (per the module comment these default to the
    co-located localhost:6333 instance, avoiding reverse-proxy overhead).
    Returns ``None`` when qdrant-client is not installed or the client
    cannot be created.
    """
    global _cache_qdrant_client

    if _cache_qdrant_client is None:
        try:
            from qdrant_client import QdrantClient

            # Cache is co-located with the RAG backend; settings default
            # to localhost:6333 for a direct local connection.
            client = QdrantClient(
                host=settings.qdrant_host,
                port=settings.qdrant_port,
                timeout=30,
            )
            logger.info(f"Qdrant cache client: {settings.qdrant_host}:{settings.qdrant_port}")
            _cache_qdrant_client = client
        except ImportError:
            logger.error("qdrant-client not installed")
            return None
        except Exception as e:
            logger.error(f"Failed to create Qdrant cache client: {e}")
            return None

    return _cache_qdrant_client
|
|
|
|
|
|
def get_cache_embedding_model() -> Any:
    """Return the shared SentenceTransformer used for cache embeddings.

    Lazily loads ``all-MiniLM-L6-v2`` (384-dim, matching
    ``CACHE_EMBEDDING_DIM``) on first call and memoizes it in
    ``_cache_embedding_model``. Returns ``None`` when
    sentence-transformers is missing or the model fails to load.
    """
    global _cache_embedding_model

    if _cache_embedding_model is None:
        try:
            from sentence_transformers import SentenceTransformer

            _cache_embedding_model = SentenceTransformer(
                "sentence-transformers/all-MiniLM-L6-v2"
            )
            logger.info("Loaded cache embedding model: all-MiniLM-L6-v2")
        except ImportError:
            logger.error("sentence-transformers not installed")
            return None
        except Exception as e:
            logger.error(f"Failed to load cache embedding model: {e}")
            return None

    return _cache_embedding_model
|
|
|
|
|
|
def ensure_cache_collection_exists() -> bool:
    """Create the query_cache collection in Qdrant if it is missing.

    Returns True when the collection exists (or was just created),
    False when the client is unavailable or a Qdrant call fails.
    """
    client = get_cache_qdrant_client()
    if client is None:
        return False

    try:
        from qdrant_client.models import Distance, VectorParams

        existing = {c.name for c in client.get_collections().collections}
        if CACHE_COLLECTION_NAME in existing:
            return True

        # Not found: create it with cosine distance over 384-dim vectors
        # (Qdrant builds its HNSW index for new collections by default).
        client.create_collection(
            collection_name=CACHE_COLLECTION_NAME,
            vectors_config=VectorParams(
                size=CACHE_EMBEDDING_DIM,
                distance=Distance.COSINE,
            ),
        )
        logger.info(f"Created Qdrant collection: {CACHE_COLLECTION_NAME}")
        return True
    except Exception as e:
        logger.error(f"Failed to ensure cache collection: {e}")
        return False
|
|
|
|
|
|
# Request/Response Models for Cache API
|
|
|
|
class CacheLookupRequest(BaseModel):
    """Cache lookup request.

    Carries the query text (and optionally a pre-computed embedding) used
    to search the Qdrant-backed semantic cache.
    """
    query: str = Field(..., description="Query text to look up")
    # If omitted, the backend embeds `query` with all-MiniLM-L6-v2 (384-dim).
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    # Cosine-similarity cutoff, passed to Qdrant as score_threshold.
    similarity_threshold: float = Field(default=0.92, description="Minimum similarity for match")
    # Only entries stored with this language value are considered.
    language: str = Field(default="nl", description="Language filter")
|
|
|
|
|
|
class CacheLookupResponse(BaseModel):
    """Cache lookup response."""
    # True when a sufficiently similar cached entry was found.
    found: bool
    # Cached entry payload (query, response, metadata); None on miss/error.
    entry: dict[str, Any] | None = None
    # Cosine similarity of the best match; 0.0 on miss or error.
    similarity: float = 0.0
    # "semantic" for a Qdrant ANN search, "error" on failure, "none" default.
    method: str = "none"
    # Wall-clock lookup duration in milliseconds.
    lookup_time_ms: float = 0.0
|
|
|
|
|
|
class CacheStoreRequest(BaseModel):
    """Cache store request."""
    query: str = Field(..., description="Query text")
    # If omitted, the backend embeds `query` with all-MiniLM-L6-v2 (384-dim).
    embedding: list[float] | None = Field(default=None, description="Pre-computed embedding (optional)")
    # Arbitrary JSON response payload to return on future cache hits.
    response: dict[str, Any] = Field(..., description="Response to cache")
    language: str = Field(default="nl", description="Language")
    model: str = Field(default="unknown", description="LLM model used")
    # Stored with the entry payload; default is 24 hours.
    ttl_seconds: int = Field(default=86400, description="Time-to-live in seconds")
|
|
|
|
|
|
class CacheStoreResponse(BaseModel):
    """Cache store response."""
    # True when the entry was upserted to Qdrant.
    success: bool
    # Qdrant point ID of the stored entry; None on failure.
    id: str | None = None
    # Human-readable status / error description.
    message: str = ""
|
|
|
|
|
|
class CacheStatsResponse(BaseModel):
    """Cache statistics response."""
    # Number of points currently in the cache collection.
    total_entries: int = 0
    collection_name: str = CACHE_COLLECTION_NAME
    embedding_dim: int = CACHE_EMBEDDING_DIM
    backend: str = "qdrant"
    # "ok", "no_collection", or an "error..." string describing failures.
    status: str = "ok"
|
|
|
|
|
|
@app.post("/api/cache/lookup", response_model=CacheLookupResponse)
|
|
async def cache_lookup(request: CacheLookupRequest) -> CacheLookupResponse:
|
|
"""Look up a query in the semantic cache using Qdrant ANN search.
|
|
|
|
This endpoint performs sub-millisecond vector similarity search using
|
|
Qdrant's HNSW index, replacing slow client-side cosine similarity scans.
|
|
"""
|
|
import time
|
|
start_time = time.perf_counter()
|
|
|
|
# Ensure collection exists
|
|
if not ensure_cache_collection_exists():
|
|
return CacheLookupResponse(
|
|
found=False,
|
|
similarity=0.0,
|
|
method="error",
|
|
lookup_time_ms=(time.perf_counter() - start_time) * 1000,
|
|
)
|
|
|
|
client = get_cache_qdrant_client()
|
|
if client is None:
|
|
return CacheLookupResponse(
|
|
found=False,
|
|
similarity=0.0,
|
|
method="error",
|
|
lookup_time_ms=(time.perf_counter() - start_time) * 1000,
|
|
)
|
|
|
|
# Get or generate embedding
|
|
embedding = request.embedding
|
|
if embedding is None:
|
|
model = get_cache_embedding_model()
|
|
if model is None:
|
|
return CacheLookupResponse(
|
|
found=False,
|
|
similarity=0.0,
|
|
method="error",
|
|
lookup_time_ms=(time.perf_counter() - start_time) * 1000,
|
|
)
|
|
embedding = model.encode(request.query).tolist()
|
|
|
|
try:
|
|
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
|
|
|
# Build filter for language
|
|
search_filter = Filter(
|
|
must=[
|
|
FieldCondition(
|
|
key="language",
|
|
match=MatchValue(value=request.language),
|
|
)
|
|
]
|
|
)
|
|
|
|
# Perform ANN search using query_points (qdrant-client >= 1.7)
|
|
results = client.query_points(
|
|
collection_name=CACHE_COLLECTION_NAME,
|
|
query=embedding,
|
|
query_filter=search_filter,
|
|
limit=1,
|
|
score_threshold=request.similarity_threshold,
|
|
).points
|
|
|
|
elapsed_ms = (time.perf_counter() - start_time) * 1000
|
|
|
|
if not results:
|
|
return CacheLookupResponse(
|
|
found=False,
|
|
similarity=0.0,
|
|
method="semantic",
|
|
lookup_time_ms=elapsed_ms,
|
|
)
|
|
|
|
# Extract best match
|
|
best = results[0]
|
|
payload = best.payload or {}
|
|
|
|
return CacheLookupResponse(
|
|
found=True,
|
|
entry={
|
|
"id": str(best.id),
|
|
"query": payload.get("query", ""),
|
|
"query_normalized": payload.get("query_normalized", ""),
|
|
"response": payload.get("response", {}),
|
|
"timestamp": payload.get("timestamp", 0),
|
|
"hit_count": payload.get("hit_count", 0),
|
|
"last_accessed": payload.get("last_accessed", 0),
|
|
"language": payload.get("language", "nl"),
|
|
"model": payload.get("model", "unknown"),
|
|
},
|
|
similarity=best.score,
|
|
method="semantic",
|
|
lookup_time_ms=elapsed_ms,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Cache lookup error: {e}")
|
|
return CacheLookupResponse(
|
|
found=False,
|
|
similarity=0.0,
|
|
method="error",
|
|
lookup_time_ms=(time.perf_counter() - start_time) * 1000,
|
|
)
|
|
|
|
|
|
@app.post("/api/cache/store", response_model=CacheStoreResponse)
|
|
async def cache_store(request: CacheStoreRequest) -> CacheStoreResponse:
|
|
"""Store a query/response pair in the semantic cache.
|
|
|
|
Generates embedding if not provided and upserts to Qdrant.
|
|
"""
|
|
import time
|
|
import uuid
|
|
|
|
# Ensure collection exists
|
|
if not ensure_cache_collection_exists():
|
|
return CacheStoreResponse(
|
|
success=False,
|
|
message="Failed to ensure cache collection exists",
|
|
)
|
|
|
|
client = get_cache_qdrant_client()
|
|
if client is None:
|
|
return CacheStoreResponse(
|
|
success=False,
|
|
message="Qdrant client not available",
|
|
)
|
|
|
|
# Get or generate embedding
|
|
embedding = request.embedding
|
|
if embedding is None:
|
|
model = get_cache_embedding_model()
|
|
if model is None:
|
|
return CacheStoreResponse(
|
|
success=False,
|
|
message="Embedding model not available",
|
|
)
|
|
embedding = model.encode(request.query).tolist()
|
|
|
|
try:
|
|
from qdrant_client.models import PointStruct
|
|
|
|
# Generate unique ID
|
|
point_id = str(uuid.uuid4())
|
|
timestamp = int(time.time() * 1000)
|
|
|
|
# Normalize query for exact matching
|
|
query_normalized = request.query.lower().strip()
|
|
|
|
# Create point
|
|
point = PointStruct(
|
|
id=point_id,
|
|
vector=embedding,
|
|
payload={
|
|
"query": request.query,
|
|
"query_normalized": query_normalized,
|
|
"response": request.response,
|
|
"language": request.language,
|
|
"model": request.model,
|
|
"timestamp": timestamp,
|
|
"hit_count": 0,
|
|
"last_accessed": timestamp,
|
|
"ttl_seconds": request.ttl_seconds,
|
|
},
|
|
)
|
|
|
|
# Upsert to Qdrant
|
|
client.upsert(
|
|
collection_name=CACHE_COLLECTION_NAME,
|
|
points=[point],
|
|
)
|
|
|
|
logger.debug(f"Cached query: {request.query[:50]}...")
|
|
|
|
return CacheStoreResponse(
|
|
success=True,
|
|
id=point_id,
|
|
message="Stored successfully",
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Cache store error: {e}")
|
|
return CacheStoreResponse(
|
|
success=False,
|
|
message=str(e),
|
|
)
|
|
|
|
|
|
@app.get("/api/cache/stats", response_model=CacheStatsResponse)
|
|
async def cache_stats() -> CacheStatsResponse:
|
|
"""Get cache statistics."""
|
|
client = get_cache_qdrant_client()
|
|
if client is None:
|
|
return CacheStatsResponse(
|
|
status="error",
|
|
total_entries=0,
|
|
)
|
|
|
|
try:
|
|
# Check if collection exists
|
|
collections = client.get_collections().collections
|
|
if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
|
|
return CacheStatsResponse(
|
|
status="no_collection",
|
|
total_entries=0,
|
|
)
|
|
|
|
# Get collection info
|
|
info = client.get_collection(CACHE_COLLECTION_NAME)
|
|
|
|
return CacheStatsResponse(
|
|
total_entries=info.points_count,
|
|
collection_name=CACHE_COLLECTION_NAME,
|
|
embedding_dim=CACHE_EMBEDDING_DIM,
|
|
backend="qdrant",
|
|
status="ok",
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Cache stats error: {e}")
|
|
return CacheStatsResponse(
|
|
status=f"error: {e}",
|
|
total_entries=0,
|
|
)
|
|
|
|
|
|
@app.delete("/api/cache/clear")
|
|
async def cache_clear() -> dict[str, Any]:
|
|
"""Clear all cache entries (admin only)."""
|
|
client = get_cache_qdrant_client()
|
|
if client is None:
|
|
return {"success": False, "message": "Qdrant client not available"}
|
|
|
|
try:
|
|
# Check if collection exists
|
|
collections = client.get_collections().collections
|
|
if not any(c.name == CACHE_COLLECTION_NAME for c in collections):
|
|
return {"success": True, "message": "Collection does not exist", "deleted": 0}
|
|
|
|
# Get count before deletion
|
|
info = client.get_collection(CACHE_COLLECTION_NAME)
|
|
count = info.points_count
|
|
|
|
# Delete and recreate collection
|
|
client.delete_collection(CACHE_COLLECTION_NAME)
|
|
ensure_cache_collection_exists()
|
|
|
|
return {"success": True, "message": f"Cleared {count} entries", "deleted": count}
|
|
except Exception as e:
|
|
logger.error(f"Cache clear error: {e}")
|
|
return {"success": False, "message": str(e)}
|
|
|
|
|
|
# Main entry point: run the FastAPI app directly with uvicorn (dev usage;
# production deployments typically launch uvicorn/gunicorn externally).
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces; auto-reload only when settings.debug is set.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8003,
        reload=settings.debug,
        log_level="info",
    )
|