glam/backend/rag/test_live_rag.py
2025-12-11 22:32:09 +01:00

281 lines
8.7 KiB
Python

#!/usr/bin/env python3
"""
Test DSPy Heritage RAG with live SPARQL endpoint.
Requires SSH tunnel to be active:
ssh -f -N -L 7878:localhost:7878 -L 6333:localhost:6333 root@91.98.224.44
"""
import json
import os
import sys
from datetime import datetime
import httpx
# Add project root to path: three dirname() calls walk up from this file to
# the directory that contains the `backend` package, so the
# `backend.rag.dspy_heritage_rag` import further down resolves when this file
# is run directly as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# Configure DSPy
import dspy
# Use GPT-4o-mini for fast testing.
# NOTE: this configuration runs at import time, before the project modules
# are imported, so every DSPy call made by the tests below goes through `lm`.
# NOTE(review): the "openai/..." model presumably needs an OpenAI API key in
# the environment - confirm before running.
lm = dspy.LM(
model="openai/gpt-4o-mini",
temperature=0.3,
max_tokens=2000,
)
dspy.configure(lm=lm)
from backend.rag.dspy_heritage_rag import (
HeritageRAGPipeline,
HeritageQueryRouter,
HeritageSPARQLGenerator,
HeritageEntityExtractor,
)
def test_sparql_endpoint():
    """Check that the SPARQL endpoint on localhost:7878 is reachable.

    Posts a COUNT query over ``hc:Custodian`` instances to
    ``http://localhost:7878/query`` and prints the outcome.

    Returns:
        bool: True when the endpoint answered with HTTP 200 and the count
        could be read from the JSON result. False on a non-200 status, on a
        transport failure (for example when the SSH tunnel is down), or when
        the response body does not have the expected shape.
    """
    print("\n" + "="*60)
    print("Testing SPARQL Endpoint (localhost:7878)")
    print("="*60)
    # Count custodians
    query = """
PREFIX hc: <https://nde.nl/ontology/hc/class/>
SELECT (COUNT(*) as ?count) WHERE { ?s a hc:Custodian }
"""
    try:
        response = httpx.post(
            "http://localhost:7878/query",
            content=query,
            headers={
                "Content-Type": "application/sparql-query",
                "Accept": "application/sparql-results+json",
            },
            timeout=30.0,
        )
    except httpx.RequestError as exc:
        # Connection refused / timeout: report "not available" instead of
        # crashing, so the caller can print its SSH tunnel hint.
        print(f"✗ Failed: {exc}")
        return False

    if response.status_code != 200:
        print(f"✗ Failed: {response.status_code}")
        return False

    try:
        data = response.json()
        count = data["results"]["bindings"][0]["count"]["value"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        # Non-JSON body, or JSON without a single COUNT binding.
        print(f"✗ Failed: unexpected response body ({exc!r})")
        return False

    print(f"✓ Connected! Found {count} heritage custodians")
    return True
def test_query_router():
    """Run sample questions through HeritageQueryRouter and print each intent.

    For every (question, expected intent) pair the router is called with
    ``question=...`` and the ``intent`` attribute of its result is compared
    with the expected label. One line is printed per question, prefixed with
    ✓ on a match and ✗ on a mismatch. Nothing is returned, and a mismatch
    does not raise.
    """
    print("\n" + "="*60)
    print("Testing Query Router")
    print("="*60)
    # HeritageQueryRouter is a dspy.Module, instantiate it directly
    router = HeritageQueryRouter()
    test_questions = [
        ("Hoeveel musea zijn er in Amsterdam?", "statistical"),
        ("Waar is het Rijksmuseum?", "geographic"),
        ("Welke archieven hebben WO2 documenten?", "exploration"),
        ("Wat is de ISIL code van de KB?", "entity_lookup"),
        ("Vergelijk het Mauritshuis met het Rijksmuseum", "comparative"),
    ]
    for question, expected in test_questions:
        result = router(question=question)
        # Both branches used to be the empty string, which made a match and
        # a mismatch print identically; use the ✓/✗ markers that the other
        # checks in this file print.
        status = "✓" if result.intent == expected else "✗"
        # The question is cut to 40 characters to keep the line short.
        print(f"{status} '{question[:40]}...' → {result.intent} (expected: {expected})")
def test_sparql_generation():
    """Generate SPARQL for sample questions and run each query on the endpoint.

    Wraps HeritageSPARQLGenerator in ``dspy.ChainOfThought``, calls it once
    per sample case (question, intent, entities, empty context), prints the
    generated ``sparql`` and ``explanation``, then posts the query to
    ``http://localhost:7878/query`` and prints how many bindings came back.
    Errors while executing a query are printed and do not stop the loop.
    """
    banner = "="*60
    print("\n" + banner)
    print("Testing SPARQL Generation")
    print(banner)

    sparql_generator = dspy.ChainOfThought(HeritageSPARQLGenerator)
    request_headers = {
        "Content-Type": "application/sparql-query",
        "Accept": "application/sparql-results+json",
    }
    sample_cases = [
        {
            "question": "Hoeveel musea zijn er in Nederland?",
            "intent": "statistical",
            "entities": ["museum", "Nederland"],
        },
        {
            "question": "Geef me alle archieven in Amsterdam",
            "intent": "geographic",
            "entities": ["archief", "Amsterdam"],
        },
        {
            "question": "Welke bibliotheken hebben een website?",
            "intent": "exploration",
            "entities": ["bibliotheek", "website"],
        },
    ]

    for case in sample_cases:
        print(f"\nQuestion: {case['question']}")
        # The dict keys match the generator's keyword arguments.
        prediction = sparql_generator(**case, context="")
        print(f"SPARQL:\n{prediction.sparql}")
        print(f"Explanation: {prediction.explanation}")

        # Try to execute the generated query against the live endpoint.
        try:
            reply = httpx.post(
                "http://localhost:7878/query",
                content=prediction.sparql,
                headers=request_headers,
                timeout=30.0,
            )
            if reply.status_code != 200:
                print(f"✗ Query failed: {reply.status_code} - {reply.text[:200]}")
                continue
            payload = reply.json()
            rows = payload.get("results", {}).get("bindings", [])
            print(f"✓ Query executed successfully, {len(rows)} results")
        except Exception as e:
            print(f"✗ Query error: {e}")
def test_full_pipeline():
    """Run sample questions through HeritageRAGPipeline and print the results.

    Instantiates the pipeline and, if
    ``backend/rag/optimized_models/heritage_rag_latest.json`` exists relative
    to the current working directory, calls ``pipeline.load`` on it. Each
    (question, language) pair is then passed to the pipeline; the intent, the
    answer (cut to 200 characters), the SPARQL length and the visualization
    type are printed when present. Errors raised for one question are printed
    and do not stop the loop. Nothing is returned.
    """
    print("\n" + "="*60)
    print("Testing Full RAG Pipeline")
    print("="*60)
    # Load saved pipeline
    pipeline = HeritageRAGPipeline()
    model_path = "backend/rag/optimized_models/heritage_rag_latest.json"
    if os.path.exists(model_path):
        print(f"Loading saved model from {model_path}")
        pipeline.load(model_path)
    else:
        print("No saved model found, using default pipeline")
    test_questions = [
        ("Hoeveel musea zijn er in Amsterdam?", "nl"),
        ("What archives are in The Hague?", "en"),
        ("Welke bibliotheken hebben sociale media?", "nl"),
    ]
    for question, language in test_questions:
        # Visual separator between questions. The repeated string used to be
        # empty, so only a blank line was printed.
        print(f"\n{'-'*50}")
        print(f"Q: {question}")
        try:
            result = pipeline(question=question, language=language)
            print(f"Intent: {result.intent}")
            answer = result.answer
            if len(answer) > 200:
                print(f"Answer: {answer[:200]}...")
            else:
                print(f"Answer: {answer}")
            if result.sparql:
                print(f"SPARQL generated: {len(result.sparql)} chars")
            if result.visualization:
                print(f"Visualization: {result.visualization.get('type', 'none')}")
        except Exception as e:
            # Best-effort smoke test: report the failure and move on.
            print(f"✗ Pipeline error: {e}")
def run_sample_queries():
    """Post a few fixed SPARQL queries to the endpoint and print sample rows.

    Each (label, query) pair is sent to ``http://localhost:7878/query``. On
    HTTP 200 the first five bindings are printed as ``var=value`` pairs
    (values cut to 40 characters), followed by a count of any remaining rows.
    A non-200 status or an exception is printed as an error line, after which
    the next query runs. Nothing is returned.
    """
    banner = "="*60
    print("\n" + banner)
    print("Sample Queries Against Live Data")
    print(banner)

    request_headers = {
        "Content-Type": "application/sparql-query",
        "Accept": "application/sparql-results+json",
    }
    # The query text stays at column 0 so the strings are kept verbatim.
    sample_queries = [
        ("Museums by country", """
PREFIX hc: <https://nde.nl/ontology/hc/class/>
PREFIX hcp: <https://nde.nl/ontology/hc/>
PREFIX schema: <http://schema.org/>
SELECT ?country (COUNT(?s) as ?count) WHERE {
?s a hc:Custodian ;
hcp:custodian_type "MUSEUM" ;
schema:addressCountry ?country .
} GROUP BY ?country ORDER BY DESC(?count) LIMIT 10
"""),
        ("Dutch archives with websites", """
PREFIX hc: <https://nde.nl/ontology/hc/class/>
PREFIX hcp: <https://nde.nl/ontology/hc/>
PREFIX schema: <http://schema.org/>
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
SELECT ?name ?homepage WHERE {
?s a hc:Custodian ;
hcp:custodian_type "ARCHIVE" ;
schema:addressCountry "NL" ;
skos:prefLabel ?name ;
foaf:homepage ?homepage .
} LIMIT 10
"""),
        ("Heritage institutions with social media", """
PREFIX hc: <https://nde.nl/ontology/hc/class/>
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
SELECT ?name (COUNT(?account) as ?social_count) WHERE {
?s a hc:Custodian ;
skos:prefLabel ?name ;
foaf:account ?account .
} GROUP BY ?s ?name ORDER BY DESC(?social_count) LIMIT 10
"""),
    ]

    for label, sparql in sample_queries:
        print(f"\n{label}:")
        try:
            reply = httpx.post(
                "http://localhost:7878/query",
                content=sparql.strip(),
                headers=request_headers,
                timeout=30.0,
            )
            if reply.status_code != 200:
                print(f" Error: {reply.status_code}")
                continue
            rows = reply.json().get("results", {}).get("bindings", [])
            # Show at most five rows, one "var=value" list per row.
            for row in rows[:5]:
                pairs = ", ".join(
                    f"{var}={cell['value'][:40]}" for var, cell in row.items()
                )
                print(f" {pairs}")
            remaining = len(rows) - 5
            if remaining > 0:
                print(f" ... and {remaining} more")
        except Exception as e:
            print(f" Error: {e}")
if __name__ == "__main__":
    # Script entry point: verify the SPARQL endpoint first, then run the
    # sample queries and the DSPy component checks in a fixed order.
    rule = "="*60
    print(rule)
    print("DSPy Heritage RAG - Live Testing")
    print(f"Started: {datetime.now().isoformat()}")
    print(rule)

    # Test SPARQL first; the remaining steps query the live endpoint.
    endpoint_ok = test_sparql_endpoint()
    if not endpoint_ok:
        for hint in (
            "\n⚠️ SPARQL endpoint not available!",
            "Make sure SSH tunnel is active:",
            " ssh -f -N -L 7878:localhost:7878 root@91.98.224.44",
        ):
            print(hint)
        sys.exit(1)

    # Sample queries show the live data, then the DSPy components are tested.
    for step in (
        run_sample_queries,
        test_query_router,
        test_sparql_generation,
        test_full_pipeline,
    ):
        step()

    print("\n" + rule)
    print("Testing Complete!")
    print(rule)