glam/backend/rag/main.py
2025-12-11 22:32:09 +01:00

942 lines
32 KiB
Python

"""
Unified RAG Backend for Heritage Custodian Data
Multi-source retrieval-augmented generation system that combines:
- Qdrant vector search (semantic similarity)
- Oxigraph SPARQL (knowledge graph queries)
- TypeDB (relationship traversal)
- PostGIS (geospatial queries)
- Valkey (semantic caching)
Architecture:
    User Query → Query Analysis
                      ↓
                ┌─────┴─────┐
                │  Router   │
                └─────┬─────┘
          ┌─────┬─────┼─────┬─────┐
          ↓     ↓     ↓     ↓     ↓
       Qdrant SPARQL TypeDB PostGIS Cache
          │     │     │     │     │
          └─────┴─────┼─────┴─────┘
                      ↓
                ┌─────┴─────┐
                │  Merger   │
                └─────┬─────┘
                      ↓
               DSPy Generator
                      ↓
           Visualization Selector
                      ↓
          Response (JSON/Streaming)
Features:
- Intelligent query routing to appropriate data sources
- Score fusion for multi-source results
- Semantic caching via Valkey API
- Streaming responses for long-running queries
- DSPy assertions for output validation
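Endpoints:
- POST /api/rag/query - Main RAG query endpoint
- POST /api/rag/query/stream - Streaming (NDJSON) variant of the query endpoint
- POST /api/rag/sparql - Generate SPARQL with RAG context
- POST /api/rag/visualize - Visualization config for a question and result schema
- GET /api/rag/health - Health check for all services
- GET /api/rag/stats - Retriever statistics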
"""
import asyncio
import hashlib
import json
import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, AsyncIterator
import httpx
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Import retrievers (with graceful fallbacks)
try:
import sys
sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src")))
from glam_extractor.api.hybrid_retriever import HybridRetriever, create_hybrid_retriever
from glam_extractor.api.qdrant_retriever import HeritageCustodianRetriever
from glam_extractor.api.typedb_retriever import TypeDBRetriever, create_typedb_retriever
from glam_extractor.api.visualization import select_visualization, VisualizationSelector
from glam_extractor.api.dspy_sparql import generate_sparql, configure_dspy
RETRIEVERS_AVAILABLE = True
except ImportError as e:
logger.warning(f"Some retrievers not available: {e}")
RETRIEVERS_AVAILABLE = False
# Configuration
def _env_flag(name: str) -> bool:
    """Read environment variable *name* as a boolean: "true" (any case) is True, default False."""
    return os.getenv(name, "false").lower() == "true"


class Settings:
    """Application configuration, resolved from environment variables once at import time."""

    # -- API --------------------------------------------------------------
    api_title: str = "Heritage RAG API"
    api_version: str = "1.0.0"
    debug: bool = _env_flag("DEBUG")

    # -- Valkey cache -----------------------------------------------------
    valkey_api_url: str = os.environ.get("VALKEY_API_URL", "https://bronhouder.nl/api/cache")
    cache_ttl: int = int(os.environ.get("CACHE_TTL", "900"))  # seconds (15 minutes)

    # -- Qdrant vector DB -------------------------------------------------
    qdrant_host: str = os.environ.get("QDRANT_HOST", "localhost")
    qdrant_port: int = int(os.environ.get("QDRANT_PORT", "6333"))
    qdrant_use_production: bool = _env_flag("QDRANT_USE_PRODUCTION")

    # -- Oxigraph SPARQL --------------------------------------------------
    sparql_endpoint: str = os.environ.get("SPARQL_ENDPOINT", "http://localhost:7878/query")

    # -- TypeDB -----------------------------------------------------------
    typedb_host: str = os.environ.get("TYPEDB_HOST", "localhost")
    typedb_port: int = int(os.environ.get("TYPEDB_PORT", "1729"))
    typedb_database: str = os.environ.get("TYPEDB_DATABASE", "heritage_custodians")

    # -- PostGIS ----------------------------------------------------------
    postgis_url: str = os.environ.get("POSTGIS_URL", "http://localhost:8001")

    # -- LLM --------------------------------------------------------------
    anthropic_api_key: str = os.environ.get("ANTHROPIC_API_KEY", "")
    openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
    default_model: str = os.environ.get("DEFAULT_MODEL", "claude-opus-4-5-20251101")

    # -- Retrieval weights (used for rank fusion) --------------------------
    vector_weight: float = float(os.environ.get("VECTOR_WEIGHT", "0.5"))
    graph_weight: float = float(os.environ.get("GRAPH_WEIGHT", "0.3"))
    typedb_weight: float = float(os.environ.get("TYPEDB_WEIGHT", "0.2"))


settings = Settings()
# Enums and Models
class QueryIntent(str, Enum):
    """Detected query intent, used by QueryRouter to choose data sources.

    Member order matters: QueryRouter.detect_intent scores intents in the
    iteration order of this enum and, on a tie, keeps the first one.
    """
    GEOGRAPHIC = "geographic"  # Location-based queries
    STATISTICAL = "statistical"  # Counts, aggregations
    RELATIONAL = "relational"  # Relationships between entities
    TEMPORAL = "temporal"  # Historical, timeline queries
    SEARCH = "search"  # General text search (fallback when no keyword matches)
    DETAIL = "detail"  # Specific entity lookup
class DataSource(str, Enum):
    """Available data sources.

    The string values are used in cache keys, in the "_sources" tags of merged
    results and in streamed progress messages. CACHE is not dispatched by
    MultiSourceRetriever.retrieve; only the other four are queried there.
    """
    QDRANT = "qdrant"
    SPARQL = "sparql"
    TYPEDB = "typedb"
    POSTGIS = "postgis"
    CACHE = "cache"
@dataclass
class RetrievalResult:
    """Result from a single retriever.

    Fields:
        source: backend that produced the items.
        items: result records as plain dicts; their keys depend on the backend.
        score: backend-level score. The retrieve_from_* methods set it to the
            best per-item score (Qdrant, TypeDB) or to 1.0/0.0 depending on
            whether any items came back (SPARQL, PostGIS).
        query_time_ms: wall-clock time spent in the retriever call.
        metadata: free-form extras; not populated by the retrievers in this module.
    """
    source: DataSource
    items: list[dict[str, Any]]
    score: float = 0.0
    query_time_ms: float = 0.0
    metadata: dict[str, Any] = field(default_factory=dict)
class QueryRequest(BaseModel):
    """RAG query request."""

    question: str = Field(..., description="Natural language question")
    language: str = Field(default="nl", description="Language code (nl or en)")
    context: list[dict[str, Any]] = Field(default=[], description="Conversation history")
    # A non-empty list overrides intent-based routing in QueryRouter.get_sources.
    sources: list[DataSource] = Field(
        default=[DataSource.QDRANT, DataSource.SPARQL],
        description="Data sources to query",
    )
    # ge=1: k is used as a search limit and a slice bound (and k * 2 as the
    # merge limit); 0 or negative values would silently yield empty or
    # truncated results instead of a validation error.
    k: int = Field(default=10, ge=1, description="Number of results per source")
    include_visualization: bool = Field(default=True, description="Include visualization config")
    stream: bool = Field(default=False, description="Stream response")


class QueryResponse(BaseModel):
    """RAG query response."""

    question: str
    sparql: str | None = None
    results: list[dict[str, Any]]
    visualization: dict[str, Any] | None = None
    sources_used: list[DataSource]
    cache_hit: bool = False
    query_time_ms: float
    result_count: int


class SPARQLRequest(BaseModel):
    """SPARQL generation request."""

    question: str
    language: str = "nl"
    context: list[dict[str, Any]] = []
    use_rag: bool = True


class SPARQLResponse(BaseModel):
    """SPARQL generation response."""

    sparql: str
    explanation: str
    rag_used: bool
    retrieved_passages: list[str] = []
# Cache Client
class ValkeyClient:
    """Client for the Valkey cache HTTP API.

    Entries are stored under a truncated SHA-256 of the normalized question
    plus the sorted source list (see ``_cache_key``). Every failure is logged
    and reported as a miss (``get`` -> None) or a failed write (``set`` ->
    False), so caching problems never fail a query.
    """

    def __init__(self, base_url: str = settings.valkey_api_url):
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    @property
    async def client(self) -> httpx.AsyncClient:
        """Get or create the async HTTP client; access it as ``await self.client``."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=30.0)
        return self._client

    def _cache_key(self, question: str, sources: list[DataSource]) -> str:
        """Return a 32-hex-character key for (question, sources).

        The question is lower-cased and stripped and the sources are sorted,
        so casing, surrounding whitespace and source order do not matter.
        """
        sources_str = ",".join(sorted(s.value for s in sources))
        key_str = f"{question.lower().strip()}:{sources_str}"
        return hashlib.sha256(key_str.encode()).hexdigest()[:32]

    async def get(self, question: str, sources: list[DataSource]) -> dict[str, Any] | None:
        """Return the cached response dict, or None on a miss or any error."""
        try:
            key = self._cache_key(question, sources)
            client = await self.client
            response = await client.get(f"{self.base_url}/get/{key}")
            if response.status_code == 200:
                data = response.json()
                if data.get("value"):
                    logger.info(f"Cache hit for question: {question[:50]}...")
                    return json.loads(data["value"])
            return None
        except Exception as e:
            logger.warning(f"Cache get failed: {e}")
            return None

    async def set(
        self,
        question: str,
        sources: list[DataSource],
        response: dict[str, Any],
        ttl: int = settings.cache_ttl,
    ) -> bool:
        """Store *response* for *ttl* seconds; return True only if the API accepted it."""
        try:
            key = self._cache_key(question, sources)
            client = await self.client
            reply = await client.post(
                f"{self.base_url}/set",
                json={
                    "key": key,
                    "value": json.dumps(response),
                    "ttl": ttl,
                },
            )
            # A non-success status means nothing was stored; raise so the
            # failure is logged below and False is returned, instead of
            # reporting success.
            reply.raise_for_status()
            logger.debug(f"Cached response for: {question[:50]}...")
            return True
        except Exception as e:
            logger.warning(f"Cache set failed: {e}")
            return False

    async def close(self):
        """Close the HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None
# Query Router
class QueryRouter:
    """Pick data sources for a question from keyword-based intent detection."""

    def __init__(self):
        # Keyword lists per intent (English and Dutch). NOTE(review): keywords
        # are matched as raw substrings of the lower-cased question, so e.g.
        # "count" also fires on "country" and "over" on "Overijssel".
        self.intent_keywords = {
            QueryIntent.GEOGRAPHIC: [
                "map", "kaart",
                "where", "waar",
                "location", "locatie",
                "city", "stad",
                "country", "land",
                "region", "gebied",
                "coordinates", "coördinaten",
                "near", "nearby", "in de buurt",
            ],
            QueryIntent.STATISTICAL: [
                "how many", "hoeveel",
                "count", "aantal",
                "total", "totaal",
                "average", "gemiddeld",
                "distribution", "verdeling",
                "percentage",
                "statistics", "statistiek",
                "most", "meest",
            ],
            QueryIntent.RELATIONAL: [
                "related", "gerelateerd",
                "connected", "verbonden",
                "relationship", "relatie",
                "network", "netwerk",
                "parent", "child",
                "merged", "fusie",
                "member of",
            ],
            QueryIntent.TEMPORAL: [
                "history", "geschiedenis",
                "timeline", "tijdlijn",
                "when", "wanneer",
                "founded", "opgericht",
                "closed", "gesloten",
                "over time", "evolution",
                "change", "verandering",
            ],
            QueryIntent.DETAIL: [
                "details", "information", "informatie",
                "about", "over",
                "specific", "specifiek",
                "what is", "wat is",
            ],
        }
        # Preferred sources per intent, most relevant first.
        self.source_routing = {
            QueryIntent.GEOGRAPHIC: [DataSource.POSTGIS, DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.STATISTICAL: [DataSource.SPARQL, DataSource.QDRANT],
            QueryIntent.RELATIONAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.TEMPORAL: [DataSource.TYPEDB, DataSource.SPARQL],
            QueryIntent.SEARCH: [DataSource.QDRANT, DataSource.SPARQL],
            QueryIntent.DETAIL: [DataSource.SPARQL, DataSource.QDRANT],
        }

    def detect_intent(self, question: str) -> QueryIntent:
        """Return the intent whose keywords occur most often in *question*.

        Ties go to the intent that comes first in QueryIntent; a question
        with no keyword hits is classified as SEARCH.
        """
        text = question.lower()
        best_intent, best_hits = QueryIntent.SEARCH, 0
        for intent in QueryIntent:
            hits = sum(1 for keyword in self.intent_keywords.get(intent, ()) if keyword in text)
            # Strictly greater keeps the earliest intent on ties.
            if hits > best_hits:
                best_intent, best_hits = intent, hits
        return best_intent

    def get_sources(
        self,
        question: str,
        requested_sources: list[DataSource] | None = None,
    ) -> tuple[QueryIntent, list[DataSource]]:
        """Return ``(intent, sources)`` for *question*.

        Args:
            question: User's question.
            requested_sources: Explicit source list; when non-empty it is
                returned as-is instead of the intent-based routing.

        Returns:
            The detected intent (always computed) and the sources to query.
        """
        intent = self.detect_intent(question)
        if not requested_sources:
            return intent, self.source_routing.get(intent, [DataSource.QDRANT])
        return intent, requested_sources
# Multi-Source Retriever
class MultiSourceRetriever:
    """Orchestrates retrieval across multiple data sources.

    Owns the cache client, the query router and lazily created backend
    clients. The synchronous backend calls (Qdrant/TypeDB retrievers and
    DSPy SPARQL generation) run in worker threads via ``asyncio.to_thread``.
    This keeps the event loop responsive and lets ``retrieve`` actually
    query sources concurrently.
    """

    # Coordinates used for the simplified PostGIS city detection.
    _POSTGIS_CITIES = {
        "amsterdam": {"lat": 52.3676, "lon": 4.9041},
        "rotterdam": {"lat": 51.9244, "lon": 4.4777},
        "den haag": {"lat": 52.0705, "lon": 4.3007},
        "utrecht": {"lat": 52.0907, "lon": 5.1214},
    }

    def __init__(self):
        self.cache = ValkeyClient()
        self.router = QueryRouter()
        # Backends are created lazily on first use.
        self._qdrant: "HybridRetriever | None" = None
        self._typedb: "TypeDBRetriever | None" = None
        self._sparql_client: httpx.AsyncClient | None = None
        self._postgis_client: httpx.AsyncClient | None = None

    # Return annotations are quoted so this class can still be defined when
    # the optional glam_extractor imports are unavailable.
    @property
    def qdrant(self) -> "HybridRetriever | None":
        """Lazy-load the Qdrant hybrid retriever; None if unavailable.

        A failed initialisation leaves the attribute None, so it is retried
        on the next access.
        """
        if self._qdrant is None and RETRIEVERS_AVAILABLE:
            try:
                self._qdrant = create_hybrid_retriever(
                    use_production=settings.qdrant_use_production
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Qdrant: {e}")
        return self._qdrant

    @property
    def typedb(self) -> "TypeDBRetriever | None":
        """Lazy-load the TypeDB retriever; None if unavailable (retried on next access)."""
        if self._typedb is None and RETRIEVERS_AVAILABLE:
            try:
                # NOTE(review): reuses the Qdrant production flag, and
                # settings.typedb_host/port/database are not passed here —
                # confirm this is intended.
                self._typedb = create_typedb_retriever(
                    use_production=settings.qdrant_use_production
                )
            except Exception as e:
                logger.warning(f"Failed to initialize TypeDB: {e}")
        return self._typedb

    async def _get_sparql_client(self) -> httpx.AsyncClient:
        """Get (or re-create if closed) the SPARQL HTTP client."""
        if self._sparql_client is None or self._sparql_client.is_closed:
            self._sparql_client = httpx.AsyncClient(timeout=30.0)
        return self._sparql_client

    async def _get_postgis_client(self) -> httpx.AsyncClient:
        """Get (or re-create if closed) the PostGIS HTTP client."""
        if self._postgis_client is None or self._postgis_client.is_closed:
            self._postgis_client = httpx.AsyncClient(timeout=30.0)
        return self._postgis_client

    async def retrieve_from_qdrant(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from Qdrant vector + SPARQL hybrid search.

        Errors are logged and yield an empty result.
        """
        loop = asyncio.get_running_loop()
        start = loop.time()
        items: list[dict[str, Any]] = []
        qdrant = self.qdrant
        if qdrant:
            try:
                # search() is synchronous; keep it off the event loop.
                results = await asyncio.to_thread(qdrant.search, query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"Qdrant retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.QDRANT,
            items=items,
            score=max((r.get("scores", {}).get("combined", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_sparql(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Generate SPARQL for *query* with DSPy and run it against the endpoint.

        Returns at most *k* bindings, each flattened to ``{variable: value}``.
        Errors are logged and yield an empty result.
        """
        loop = asyncio.get_running_loop()
        start = loop.time()
        items: list[dict[str, Any]] = []
        try:
            if RETRIEVERS_AVAILABLE:
                # DSPy generation is synchronous (LLM call); run it in a thread.
                sparql_result = await asyncio.to_thread(
                    generate_sparql, query, language="nl", use_rag=False
                )
                sparql_query = sparql_result.get("sparql", "")
                if sparql_query:
                    client = await self._get_sparql_client()
                    response = await client.post(
                        settings.sparql_endpoint,
                        data={"query": sparql_query},
                        headers={"Accept": "application/sparql-results+json"},
                    )
                    if response.status_code == 200:
                        bindings = response.json().get("results", {}).get("bindings", [])
                        items = [
                            {var: term.get("value") for var, term in row.items()}
                            for row in bindings[:k]
                        ]
                    else:
                        logger.warning(
                            f"SPARQL endpoint returned HTTP {response.status_code}"
                        )
        except Exception as e:
            logger.error(f"SPARQL retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.SPARQL,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve_from_typedb(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve from the TypeDB knowledge graph; errors yield an empty result."""
        loop = asyncio.get_running_loop()
        start = loop.time()
        items: list[dict[str, Any]] = []
        typedb = self.typedb
        if typedb:
            try:
                # semantic_search() is synchronous; keep it off the event loop.
                results = await asyncio.to_thread(typedb.semantic_search, query, k=k)
                items = [r.to_dict() for r in results]
            except Exception as e:
                logger.error(f"TypeDB retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.TYPEDB,
            items=items,
            score=max((r.get("relevance_score", 0) for r in items), default=0),
            query_time_ms=elapsed,
        )

    async def retrieve_from_postgis(
        self,
        query: str,
        k: int = 10,
    ) -> RetrievalResult:
        """Retrieve institutions near a city mentioned in *query* from PostGIS.

        Simplified implementation: only the cities in ``_POSTGIS_CITIES`` are
        recognised (first match wins), using a fixed 10 km radius. Errors and
        unexpected payloads are logged and yield an empty result.
        """
        loop = asyncio.get_running_loop()
        start = loop.time()
        items: list[dict[str, Any]] = []
        try:
            client = await self._get_postgis_client()
            query_lower = query.lower()
            for city, coords in self._POSTGIS_CITIES.items():
                if city not in query_lower:
                    continue
                response = await client.get(
                    f"{settings.postgis_url}/api/institutions/nearby",
                    params={
                        "lat": coords["lat"],
                        "lon": coords["lon"],
                        "radius_km": 10,
                        "limit": k,
                    },
                )
                if response.status_code == 200:
                    payload = response.json()
                    # merge_results expects a list of dict records.
                    # NOTE(review): the endpoint's response shape is not visible
                    # here; only a JSON list is accepted — confirm.
                    if isinstance(payload, list):
                        items = payload
                    else:
                        logger.warning(
                            f"PostGIS returned unexpected payload type: {type(payload).__name__}"
                        )
                else:
                    logger.warning(f"PostGIS returned HTTP {response.status_code}")
                break
        except Exception as e:
            logger.error(f"PostGIS retrieval failed: {e}")
        elapsed = (loop.time() - start) * 1000
        return RetrievalResult(
            source=DataSource.POSTGIS,
            items=items,
            score=1.0 if items else 0.0,
            query_time_ms=elapsed,
        )

    async def retrieve(
        self,
        question: str,
        sources: list[DataSource],
        k: int = 10,
    ) -> list[RetrievalResult]:
        """Retrieve from multiple sources concurrently.

        Args:
            question: User's question.
            sources: Data sources to query; sources without a retriever
                (e.g. CACHE) are skipped.
            k: Number of results per source.

        Returns:
            One RetrievalResult per source that completed; failed tasks are logged.
        """
        dispatch = {
            DataSource.QDRANT: self.retrieve_from_qdrant,
            DataSource.SPARQL: self.retrieve_from_sparql,
            DataSource.TYPEDB: self.retrieve_from_typedb,
            DataSource.POSTGIS: self.retrieve_from_postgis,
        }
        tasks = [dispatch[source](question, k) for source in sources if source in dispatch]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        valid_results = []
        for r in results:
            if isinstance(r, RetrievalResult):
                valid_results.append(r)
            elif isinstance(r, Exception):
                logger.error(f"Retrieval task failed: {r}")
        return valid_results

    def merge_results(
        self,
        results: list[RetrievalResult],
        max_results: int = 20,
    ) -> list[dict[str, Any]]:
        """Merge and deduplicate results from multiple sources.

        Items are identified by "ghcid", falling back to "id". Items with
        neither get a key unique to their source and rank, so anonymous items
        from different sources are not collapsed into one. Each occurrence
        adds ``weight / (60 + rank)`` to the item's "_rrf_score"; this is
        weighted reciprocal rank fusion with a 0-based rank. "_sources"
        records the contributing sources.

        Returns:
            Copies of the items, best "_rrf_score" first, at most *max_results*.
        """
        source_weights = {
            DataSource.QDRANT: settings.vector_weight,
            DataSource.SPARQL: settings.graph_weight,
            DataSource.TYPEDB: settings.typedb_weight,
            DataSource.POSTGIS: 0.3,
        }
        merged: dict[Any, dict[str, Any]] = {}
        for result in results:
            weight = source_weights.get(result.source, 0.5)
            for rank, item in enumerate(result.items):
                key = item.get("ghcid")
                if key is None:
                    key = item.get("id")
                if key is None:
                    key = f"unknown_{result.source.value}_{rank}"
                entry = merged.get(key)
                if entry is None:
                    entry = item.copy()
                    entry["_sources"] = []
                    entry["_rrf_score"] = 0.0
                    merged[key] = entry
                rrf_score = 1.0 / (60 + rank)  # k=60 is the usual RRF constant
                entry["_rrf_score"] += rrf_score * weight
                entry["_sources"].append(result.source.value)

        sorted_items = sorted(
            merged.values(),
            key=lambda x: x.get("_rrf_score", 0),
            reverse=True,
        )
        return sorted_items[:max_results]

    async def close(self):
        """Release all clients and backends.

        Each close is attempted even if an earlier one fails; failures are logged.
        """
        try:
            await self.cache.close()
        except Exception as e:
            logger.warning(f"Failed to close cache client: {e}")
        for client in (self._sparql_client, self._postgis_client):
            if client:
                try:
                    await client.aclose()
                except Exception as e:
                    logger.warning(f"Failed to close HTTP client: {e}")
        for name, backend in (("Qdrant", self._qdrant), ("TypeDB", self._typedb)):
            if backend:
                try:
                    backend.close()
                except Exception as e:
                    logger.warning(f"Failed to close {name} retriever: {e}")
        self._sparql_client = None
        self._postgis_client = None
        self._qdrant = None
        self._typedb = None
# Global instances
retriever: MultiSourceRetriever | None = None
# Quoted so the annotation is not evaluated at import time: the real
# VisualizationSelector only exists when the optional imports succeeded.
viz_selector: "VisualizationSelector | None" = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create shared services on startup and release them on shutdown.

    Startup always builds the MultiSourceRetriever. When the optional
    retrievers are importable, it also creates the visualization selector
    and configures DSPy (the latter only if ANTHROPIC_API_KEY is set).
    Optional components that fail to initialise are logged and skipped
    instead of aborting startup. Shutdown runs even if the app exits with
    an error.
    """
    global retriever, viz_selector
    # Startup
    logger.info("Starting Heritage RAG API...")
    retriever = MultiSourceRetriever()
    if RETRIEVERS_AVAILABLE:
        try:
            viz_selector = VisualizationSelector(use_dspy=bool(settings.anthropic_api_key))
        except Exception as e:
            logger.warning(f"Failed to initialize visualization selector: {e}")
            viz_selector = None
        # Configure DSPy if API key available
        if settings.anthropic_api_key:
            try:
                configure_dspy(
                    provider="anthropic",
                    model=settings.default_model,
                    api_key=settings.anthropic_api_key,
                )
            except Exception as e:
                logger.warning(f"Failed to configure DSPy: {e}")
    logger.info("Heritage RAG API started")
    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down Heritage RAG API...")
        if retriever:
            try:
                await retriever.close()
            except Exception as e:
                logger.warning(f"Error while closing retriever: {e}")
            retriever = None
        logger.info("Heritage RAG API stopped")
# Create FastAPI app
app = FastAPI(
    title=settings.api_title,
    version=settings.api_version,
    lifespan=lifespan,
)

# Allowed CORS origins: comma-separated list in CORS_ORIGINS. When unset or
# empty, every origin is allowed ("*").
# NOTE(review): a wildcard origin together with allow_credentials=True is
# generally unsafe for credential-bearing requests. Set CORS_ORIGINS
# explicitly in production and confirm how the installed Starlette version
# handles this combination.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# API Endpoints
@app.get("/api/rag/health")
async def health_check():
    """Report the status of the Qdrant, SPARQL and TypeDB backends.

    Qdrant and TypeDB are only checked when their retriever could be
    initialised. The overall status is:
    - "ok" when no checked service reports an error;
    - "error" when every checked service does;
    - "degraded" otherwise.
    """
    services: dict[str, dict[str, Any]] = {}
    health = {
        "status": "ok",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "services": services,
    }

    # Qdrant
    if retriever and retriever.qdrant:
        try:
            stats = retriever.qdrant.get_stats()
            services["qdrant"] = {
                "status": "ok",
                "vectors": stats.get("qdrant", {}).get("vectors_count", 0),
            }
        except Exception as e:
            services["qdrant"] = {"status": "error", "error": str(e)}

    # SPARQL: probe the server root. Only a trailing "/query" is stripped, so
    # other path segments that happen to contain "/query" are preserved.
    sparql_root = settings.sparql_endpoint.rstrip("/").removesuffix("/query")
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(sparql_root)
            services["sparql"] = {
                "status": "ok" if response.status_code < 500 else "error"
            }
    except Exception as e:
        services["sparql"] = {"status": "error", "error": str(e)}

    # TypeDB
    if retriever and retriever.typedb:
        try:
            stats = retriever.typedb.get_stats()
            services["typedb"] = {
                "status": "ok",
                "entities": stats.get("entities", {}),
            }
        except Exception as e:
            services["typedb"] = {"status": "error", "error": str(e)}

    # Overall status, relative to the number of services actually checked.
    statuses = [svc.get("status") for svc in services.values()]
    errors = statuses.count("error")
    if errors == 0:
        health["status"] = "ok"
    elif errors < len(statuses):
        health["status"] = "degraded"
    else:
        health["status"] = "error"
    return health
@app.get("/api/rag/stats")
async def get_stats():
    """Return statistics from each retriever that is currently available."""
    stats = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "retrievers": {},
    }
    if not retriever:
        return stats
    # Accessing the attribute triggers the lazy initialisation of each backend.
    for name in ("qdrant", "typedb"):
        backend = getattr(retriever, name)
        if backend:
            stats["retrievers"][name] = backend.get_stats()
    return stats
def _cache_question(request: QueryRequest) -> str:
    """Return the question string used for the cache lookup of *request*.

    It folds in the request options that change the response body (k, which
    bounds the result count, and include_visualization). Without them, an
    answer cached for one setting would be served for another.
    """
    return (
        f"{request.question}"
        f"|k={request.k}|viz={request.include_visualization}"
    )


@app.post("/api/rag/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Main RAG query endpoint.

    Checks the cache, routes the question to data sources, retrieves from
    them concurrently, merges the results (at most ``2 * k``) and optionally
    adds a visualization configuration. Non-empty answers are cached.

    Raises:
        HTTPException: 503 if the retriever has not been initialised.
    """
    if not retriever:
        raise HTTPException(status_code=503, detail="Retriever not initialized")
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    cache_question = _cache_question(request)

    # Check cache first
    cached = await retriever.cache.get(cache_question, request.sources)
    if cached:
        try:
            cached["cache_hit"] = True
            return QueryResponse(**cached)
        except (TypeError, ValueError) as e:
            # A malformed or outdated entry must not fail the request; fall
            # through to a fresh retrieval (which overwrites the entry).
            logger.warning(f"Ignoring unusable cache entry: {e}")

    # Route query to appropriate sources
    intent, sources = retriever.router.get_sources(request.question, request.sources)
    logger.info(f"Query intent: {intent}, sources: {sources}")

    # Retrieve from all sources and merge
    results = await retriever.retrieve(request.question, sources, request.k)
    merged_items = retriever.merge_results(results, max_results=request.k * 2)

    # Generate visualization config if requested
    visualization = None
    if request.include_visualization and viz_selector and merged_items:
        # Schema = public field names of the first result ("_" fields are merge metadata).
        schema_str = ", ".join(f for f in merged_items[0] if not f.startswith("_"))
        visualization = viz_selector.select(
            request.question,
            schema_str,
            len(merged_items),
        )

    elapsed_ms = (loop.time() - start_time) * 1000
    response_data = {
        "question": request.question,
        "sparql": None,  # Could be populated from SPARQL result
        "results": merged_items,
        "visualization": visualization,
        "sources_used": list(sources),
        "cache_hit": False,
        "query_time_ms": round(elapsed_ms, 2),
        "result_count": len(merged_items),
    }

    # Only cache non-empty answers: retrievers log and swallow their errors
    # and return no items, so an empty result may just mean a backend was
    # down. Caching it would serve "no results" for the whole TTL.
    if merged_items:
        await retriever.cache.set(cache_question, request.sources, response_data)
    return QueryResponse(**response_data)
@app.post("/api/rag/sparql", response_model=SPARQLResponse)
async def generate_sparql_endpoint(request: SPARQLRequest):
    """Generate a SPARQL query from natural language.

    Uses DSPy with optional RAG enhancement for context. The synchronous
    generator runs in a worker thread so it does not block the event loop.

    Raises:
        HTTPException: 503 if the generator is not importable, 500 if
            generation fails (detail carries the error message).
    """
    if not RETRIEVERS_AVAILABLE:
        raise HTTPException(status_code=503, detail="SPARQL generator not available")
    try:
        result = await asyncio.to_thread(
            generate_sparql,
            request.question,
            language=request.language,
            context=request.context,
            use_rag=request.use_rag,
        )
        return SPARQLResponse(
            sparql=result["sparql"],
            explanation=result.get("explanation", ""),
            rag_used=result.get("rag_used", False),
            retrieved_passages=result.get("retrieved_passages", []),
        )
    except Exception as e:
        logger.exception("SPARQL generation failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/rag/visualize")
async def get_visualization_config(
    question: str = Query(..., description="User's question"),
    schema: str = Query(..., description="Comma-separated field names"),
    result_count: int = Query(default=0, description="Number of results"),
):
    """Return the visualization configuration chosen for a question and result schema.

    Responds 503 when no visualization selector was initialised at startup.
    """
    selector = viz_selector
    if not selector:
        raise HTTPException(status_code=503, detail="Visualization selector not available")
    return selector.select(question, schema, result_count)
async def stream_query_response(
    request: QueryRequest,
) -> AsyncIterator[str]:
    """Yield newline-terminated JSON lines describing a RAG query's progress.

    Emits:
    - a "status" line after routing;
    - a "status" and a "partial" line per source (sources are queried one
      after another so progress can be reported);
    - a final "complete" line with the merged results (at most ``2 * k``,
      as in the non-streaming endpoint).
    If anything fails, an "error" line is emitted instead of silently
    cutting the stream.
    """
    if not retriever:
        yield json.dumps({"type": "error", "error": "Retriever not initialized"}) + "\n"
        return

    loop = asyncio.get_running_loop()
    start_time = loop.time()
    try:
        # Route query
        intent, sources = retriever.router.get_sources(request.question, request.sources)
        yield json.dumps({
            "type": "status",
            "message": f"Routing query to {len(sources)} sources...",
            "intent": intent.value,
        }) + "\n"

        # Retrieve from sources and stream progress
        results = []
        for source in sources:
            yield json.dumps({
                "type": "status",
                "message": f"Querying {source.value}...",
            }) + "\n"
            source_results = await retriever.retrieve(request.question, [source], request.k)
            results.extend(source_results)
            yield json.dumps({
                "type": "partial",
                "source": source.value,
                "count": len(source_results[0].items) if source_results else 0,
            }) + "\n"

        # Merge and finalize
        merged = retriever.merge_results(results, max_results=request.k * 2)
        elapsed_ms = (loop.time() - start_time) * 1000
        yield json.dumps({
            "type": "complete",
            "results": merged,
            "query_time_ms": round(elapsed_ms, 2),
            "result_count": len(merged),
        }) + "\n"
    except Exception as e:
        logger.exception("Streaming query failed")
        yield json.dumps({"type": "error", "error": str(e)}) + "\n"
@app.post("/api/rag/query/stream")
async def query_rag_stream(request: QueryRequest):
    """Streaming variant of /api/rag/query; the body is newline-delimited JSON."""
    ndjson_lines = stream_query_response(request)
    return StreamingResponse(ndjson_lines, media_type="application/x-ndjson")
# Main entry point
if __name__ == "__main__":
    # Development entry point: serve this module's `app` on port 8002,
    # with auto-reload when DEBUG=true.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8002,
        "reload": settings.debug,
        "log_level": "info",
    }
    uvicorn.run("main:app", **server_options)