#!/usr/bin/env python3
"""
Test DSPy Heritage RAG with live SPARQL endpoint.

Requires SSH tunnel to be active:
    ssh -f -N -L 7878:localhost:7878 -L 6333:localhost:6333 root@91.98.224.44
"""

import json
import os
import sys
from datetime import datetime

import httpx

# Add project root (three directory levels above this file) to the import
# path; this runs before the `backend` import below.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

# Configure DSPy
import dspy

# Use GPT-4o-mini for fast testing
lm = dspy.LM(
    model="openai/gpt-4o-mini",
    temperature=0.3,
    max_tokens=2000,
)
dspy.configure(lm=lm)

# Imported after the sys.path tweak and the dspy.configure() call; the
# original statement order is preserved on purpose.
from backend.rag.dspy_heritage_rag import (
    HeritageRAGPipeline,
    HeritageQueryRouter,
    HeritageSPARQLGenerator,
    HeritageEntityExtractor,
    MultiHopHeritageRetriever,
    SCHEMA_LOADER_AVAILABLE,
    get_schema_aware_sparql_signature,
    get_schema_aware_entity_signature,
    get_schema_aware_answer_signature,
    validate_custodian_type,
)
|
|
|
|
|
|
def test_sparql_endpoint():
    """Check that the SPARQL endpoint on localhost:7878 answers a count query.

    Posts a ``COUNT`` query over ``hc:Custodian`` instances and prints the
    outcome.

    Returns:
        bool: True when the endpoint replied with HTTP 200 and the count could
        be read from the result; False on a non-200 status, on a transport
        failure (for example when nothing is listening on the port), or when
        the response body does not have the expected shape.
    """
    print("\n" + "="*60)
    print("Testing SPARQL Endpoint (localhost:7878)")
    print("="*60)

    # Count custodians
    query = """
    PREFIX hc: <https://nde.nl/ontology/hc/class/>
    SELECT (COUNT(*) as ?count) WHERE { ?s a hc:Custodian }
    """

    try:
        response = httpx.post(
            "http://localhost:7878/query",
            content=query,
            headers={
                "Content-Type": "application/sparql-query",
                "Accept": "application/sparql-results+json",
            },
            timeout=30.0,
        )
    except httpx.RequestError as exc:
        # Connection refused / timeout: report it as "not available" so the
        # caller can print its hint instead of dying with a traceback.
        print(f"✗ Failed: {exc}")
        return False

    if response.status_code != 200:
        print(f"✗ Failed: {response.status_code}")
        return False

    try:
        count = response.json()["results"]["bindings"][0]["count"]["value"]
    except (ValueError, KeyError, IndexError, TypeError) as exc:
        # Non-JSON body or a result without the expected single binding.
        print(f"✗ Failed: unexpected response body ({exc!r})")
        return False

    print(f"✓ Connected! Found {count} heritage custodians")
    return True
|
|
|
|
|
|
def test_query_router():
    """Classify sample Dutch questions and print a pass/fail line per question.

    Each question is paired with the intent label it is expected to receive;
    the router's ``intent`` output is compared against that label.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Testing Query Router")
    print(rule)

    # HeritageQueryRouter is a dspy.Module, instantiate it directly
    classify = HeritageQueryRouter()

    # question -> expected intent (dicts keep insertion order)
    expected_intents = {
        "Hoeveel musea zijn er in Amsterdam?": "statistical",
        "Waar is het Rijksmuseum?": "geographic",
        "Welke archieven hebben WO2 documenten?": "exploration",
        "Wat is de ISIL code van de KB?": "entity_lookup",
        "Vergelijk het Mauritshuis met het Rijksmuseum": "comparative",
    }

    for question, expected in expected_intents.items():
        prediction = classify(question=question)
        if prediction.intent == expected:
            mark = "✓"
        else:
            mark = "✗"
        print(f"{mark} '{question[:40]}...' → {prediction.intent} (expected: {expected})")
|
|
|
|
|
|
def test_sparql_generation():
    """Generate SPARQL for sample questions and try each query on the endpoint.

    For every sample the generated query and its explanation are printed,
    then the query is posted to localhost:7878 and the number of result
    bindings (or the failure) is reported.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Testing SPARQL Generation")
    print(rule)

    generator = dspy.ChainOfThought(HeritageSPARQLGenerator)

    endpoint = "http://localhost:7878/query"
    sparql_headers = {
        "Content-Type": "application/sparql-query",
        "Accept": "application/sparql-results+json",
    }

    # (question, intent, entities)
    samples = [
        ("Hoeveel musea zijn er in Nederland?", "statistical", ["museum", "Nederland"]),
        ("Geef me alle archieven in Amsterdam", "geographic", ["archief", "Amsterdam"]),
        ("Welke bibliotheken hebben een website?", "exploration", ["bibliotheek", "website"]),
    ]

    for question, intent, entities in samples:
        print(f"\nQuestion: {question}")
        generated = generator(
            question=question,
            intent=intent,
            entities=entities,
            context="",
        )
        print(f"SPARQL:\n{generated.sparql}")
        print(f"Explanation: {generated.explanation}")

        # Try to execute the query; any failure is reported, not raised.
        try:
            reply = httpx.post(
                endpoint,
                content=generated.sparql,
                headers=sparql_headers,
                timeout=30.0,
            )
            if reply.status_code != 200:
                print(f"✗ Query failed: {reply.status_code} - {reply.text[:200]}")
                continue
            rows = reply.json().get("results", {}).get("bindings", [])
            print(f"✓ Query executed successfully, {len(rows)} results")
        except Exception as err:
            print(f"✗ Query error: {err}")
|
|
|
|
|
|
def test_full_pipeline():
    """Run sample questions through HeritageRAGPipeline and print the results.

    A saved model is loaded from ``backend/rag/optimized_models`` when that
    file exists (path is relative to the current working directory);
    otherwise the freshly constructed pipeline is used as-is.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Testing Full RAG Pipeline")
    print(rule)

    # Load saved pipeline
    pipeline = HeritageRAGPipeline()
    model_path = "backend/rag/optimized_models/heritage_rag_latest.json"

    if not os.path.exists(model_path):
        print("No saved model found, using default pipeline")
    else:
        print(f"Loading saved model from {model_path}")
        pipeline.load(model_path)

    # (question, language code)
    samples = (
        ("Hoeveel musea zijn er in Amsterdam?", "nl"),
        ("What archives are in The Hague?", "en"),
        ("Welke bibliotheken hebben sociale media?", "nl"),
    )

    for question, language in samples:
        print("\n" + "─" * 50)
        print(f"Q: {question}")

        try:
            outcome = pipeline(question=question, language=language)
            print(f"Intent: {outcome.intent}")
            # Answers longer than 200 characters are truncated for display.
            if len(outcome.answer) > 200:
                print(f"Answer: {outcome.answer[:200]}...")
            else:
                print(f"Answer: {outcome.answer}")
            if outcome.sparql:
                print(f"SPARQL generated: {len(outcome.sparql)} chars")
            if outcome.visualization:
                print(f"Visualization: {outcome.visualization.get('type', 'none')}")
        except Exception as err:
            print(f"✗ Pipeline error: {err}")
|
|
|
|
|
|
def run_sample_queries():
    """Run some interesting sample queries against the live data.

    Each hand-written SPARQL query is posted to localhost:7878. Up to five
    result rows are printed (every value cut to 40 characters), followed by
    a note on how many further rows were returned.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Sample Queries Against Live Data")
    print(rule)

    museums_by_country = """
        PREFIX hc: <https://nde.nl/ontology/hc/class/>
        PREFIX hcp: <https://nde.nl/ontology/hc/>
        PREFIX schema: <http://schema.org/>
        SELECT ?country (COUNT(?s) as ?count) WHERE {
            ?s a hc:Custodian ;
               hcp:custodian_type "MUSEUM" ;
               schema:addressCountry ?country .
        } GROUP BY ?country ORDER BY DESC(?count) LIMIT 10
    """
    dutch_archives = """
        PREFIX hc: <https://nde.nl/ontology/hc/class/>
        PREFIX hcp: <https://nde.nl/ontology/hc/>
        PREFIX schema: <http://schema.org/>
        PREFIX foaf: <http://xmlns.com/foaf/0.1/>
        PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
        SELECT ?name ?homepage WHERE {
            ?s a hc:Custodian ;
               hcp:custodian_type "ARCHIVE" ;
               schema:addressCountry "NL" ;
               skos:prefLabel ?name ;
               foaf:homepage ?homepage .
        } LIMIT 10
    """
    social_media = """
        PREFIX hc: <https://nde.nl/ontology/hc/class/>
        PREFIX foaf: <http://xmlns.com/foaf/0.1/>
        PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
        SELECT ?name (COUNT(?account) as ?social_count) WHERE {
            ?s a hc:Custodian ;
               skos:prefLabel ?name ;
               foaf:account ?account .
        } GROUP BY ?s ?name ORDER BY DESC(?social_count) LIMIT 10
    """

    queries = [
        ("Museums by country", museums_by_country),
        ("Dutch archives with websites", dutch_archives),
        ("Heritage institutions with social media", social_media),
    ]

    shown = 5  # number of rows printed per query
    for label, sparql in queries:
        print(f"\n{label}:")
        try:
            reply = httpx.post(
                "http://localhost:7878/query",
                content=sparql.strip(),
                headers={
                    "Content-Type": "application/sparql-query",
                    "Accept": "application/sparql-results+json",
                },
                timeout=30.0,
            )
            if reply.status_code != 200:
                print(f" Error: {reply.status_code}")
                continue

            rows = reply.json().get("results", {}).get("bindings", [])
            for row in rows[:shown]:
                cells = ", ".join(
                    f"{var}={cell['value'][:40]}" for var, cell in row.items()
                )
                print(" " + cells)
            if len(rows) > shown:
                print(f" ... and {len(rows)-5} more")
        except Exception as err:
            print(f" Error: {err}")
|
|
|
|
|
|
def test_schema_aware_signatures():
    """Test schema-aware signature functionality.

    Returns early when ``SCHEMA_LOADER_AVAILABLE`` is falsy. Otherwise runs
    four numbered checks: the signature factories, custodian type validation,
    schema-aware SPARQL generation (including executing the generated query
    on localhost:7878), and construction of a schema-aware
    MultiHopHeritageRetriever. Failures are printed, never raised.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("Testing Schema-Aware Signatures")
    print(rule)

    print(f"Schema loader available: {SCHEMA_LOADER_AVAILABLE}")

    if not SCHEMA_LOADER_AVAILABLE:
        print("⚠️ Schema loader not available, skipping schema-aware tests")
        return

    # Test signature retrieval
    print("\n1. Testing signature factories:")
    factories = (
        ("SPARQL", get_schema_aware_sparql_signature),
        ("Entity", get_schema_aware_entity_signature),
        ("Answer", get_schema_aware_answer_signature),
    )
    for label, factory in factories:
        try:
            signature = factory()
            print(f" ✓ {label} signature: {signature.__name__}")
            print(f" Docstring length: {len(signature.__doc__)} chars")
        except Exception as err:
            print(f" ✗ {label} signature failed: {err}")

    # Test custodian type validation
    print("\n2. Testing custodian type validation:")
    for candidate in ("MUSEUM", "LIBRARY", "ARCHIVE", "GALLERY"):
        verdict = validate_custodian_type(candidate)
        mark = "✓" if verdict else "✗"
        print(f" {mark} validate_custodian_type('{candidate}'): {verdict}")

    for candidate in ("museum", "INVALID_TYPE", "", "123"):
        verdict = validate_custodian_type(candidate)
        mark = "✗" if verdict else "✓"  # These should be False
        print(f" {mark} validate_custodian_type('{candidate}'): {verdict} (expected: False)")

    # Test schema-aware SPARQL generation
    print("\n3. Testing schema-aware SPARQL generation:")
    try:
        schema_sparql_gen = dspy.ChainOfThought(get_schema_aware_sparql_signature())
        generated = schema_sparql_gen(
            question="Hoeveel musea zijn er in Amsterdam?",
            intent="statistical",
            entities=["museum", "Amsterdam"],
            context="",
        )
        print(" ✓ Schema-aware SPARQL generated:")
        print(f" Query length: {len(generated.sparql)} chars")
        print(f" Explanation: {generated.explanation[:100]}...")

        # Try to execute the query
        reply = httpx.post(
            "http://localhost:7878/query",
            content=generated.sparql,
            headers={
                "Content-Type": "application/sparql-query",
                "Accept": "application/sparql-results+json",
            },
            timeout=30.0,
        )
        if reply.status_code != 200:
            print(f" ✗ Query failed: {reply.status_code}")
        else:
            rows = reply.json().get("results", {}).get("bindings", [])
            print(f" ✓ Query executed: {len(rows)} results")
    except Exception as err:
        print(f" ✗ Schema-aware SPARQL generation failed: {err}")

    # Test MultiHopHeritageRetriever with schema-aware signatures
    print("\n4. Testing MultiHopHeritageRetriever (schema-aware):")
    try:
        retriever = MultiHopHeritageRetriever(max_hops=2, use_schema_aware=True)
        print(f" ✓ Created retriever with use_schema_aware={retriever.use_schema_aware}")
    except Exception as err:
        print(f" ✗ Failed to create schema-aware retriever: {err}")

    print("\nSchema-aware signature tests complete!")
|
|
|
|
|
|
def test_multi_turn_conversation():
    """Test multi-turn conversation with dspy.History.

    Three Dutch questions are sent in sequence. After turns 1 and 2 the
    question/answer pair is appended to the running conversation, which is
    passed as history to the next turn. A failure in turn 1 or 2 aborts the
    test; a failure in turn 3 is only reported.
    """
    from dspy import History

    rule = "=" * 60
    print("\n" + rule)
    print("Testing Multi-Turn Conversation")
    print(rule)

    pipeline = HeritageRAGPipeline()

    def ask(question, prior_turns):
        # One pipeline call with the given turns as history, no visualization.
        return pipeline(
            question=question,
            language="nl",
            history=History(messages=prior_turns),
            include_viz=False,
        )

    def report(result):
        # Print intent and the answer, truncated to 150 characters.
        print(f" Intent: {result.intent}")
        if len(result.answer) > 150:
            print(f" Answer: {result.answer[:150]}...")
        else:
            print(f" Answer: {result.answer}")

    # Simulate a multi-turn conversation
    conversation = []

    # Turn 1: Initial query about museums in Amsterdam
    question1 = "Hoeveel musea zijn er in Amsterdam?"
    print(f"\nTurn 1: {question1}")

    try:
        first = ask(question1, [])  # Empty history for first turn
        report(first)

        # Add to conversation history
        conversation.append({"question": question1, "answer": first.answer})
    except Exception as err:
        print(f" ✗ Turn 1 failed: {err}")
        return

    # Turn 2: Follow-up question (should use context from turn 1)
    question2 = "Welke van deze beheren ook archieven?"
    print(f"\nTurn 2: {question2}")
    print(" (This is a follow-up that refers to 'these' from previous turn)")

    try:
        second = ask(question2, conversation)

        # Check if the resolved_question was captured
        resolved = getattr(second, 'resolved_question', None)
        if resolved and resolved != question2:
            print(f" ✓ Query resolved: {resolved[:100]}...")

        report(second)

        # Add to conversation
        conversation.append({"question": question2, "answer": second.answer})
    except Exception as err:
        print(f" ✗ Turn 2 failed: {err}")
        return

    # Turn 3: Another follow-up
    question3 = "Geef me de websites van de eerste drie"
    print(f"\nTurn 3: {question3}")
    print(" (This refers to 'the first three' from previous results)")

    try:
        third = ask(question3, conversation)
        report(third)
    except Exception as err:
        print(f" ✗ Turn 3 failed: {err}")

    # Turn 3 is never appended, hence the "+ 1" in the turn count.
    print("\n✓ Multi-turn conversation test complete!")
    print(f" Total turns: {len(conversation) + 1}")
    print(f" History messages: {len(conversation)}")
|
|
|
|
|
|
def test_query_router_with_history():
    """Test that HeritageQueryRouter properly resolves follow-up questions.

    Routes an initial question without history, then a follow-up with a
    one-turn history, and prints intent, entities and the resolved question
    for both. The follow-up counts as resolved when the resolved question
    mentions "Den Haag" or "musea" (case-insensitive).
    """
    print("\n" + "="*60)
    print("Testing Query Router with History")
    print("="*60)

    from dspy import History

    router = HeritageQueryRouter()

    # Test 1: Initial question (no history)
    q1 = "Toon alle musea in Den Haag"
    print(f"\n1. Initial query: {q1}")

    result1 = router(question=q1, language="nl")
    print(f" Intent: {result1.intent}")
    print(f" Entities: {result1.entities}")
    # `or q1` also covers an attribute that exists but is None/empty, which
    # a plain getattr default would let through.
    resolved1 = getattr(result1, 'resolved_question', None) or q1
    print(f" Resolved: {resolved1}")

    # Test 2: Follow-up with history
    q2 = "Welke hebben een bibliotheek?"
    history = History(messages=[
        {"question": q1, "answer": "Ik heb 15 musea gevonden in Den Haag..."}
    ])

    print(f"\n2. Follow-up: {q2}")
    print(" (With history about Den Haag museums)")

    result2 = router(question=q2, language="nl", history=history)
    print(f" Intent: {result2.intent}")
    print(f" Entities: {result2.entities}")
    resolved2 = getattr(result2, 'resolved_question', None) or q2
    print(f" Resolved: {resolved2}")

    # Check if "Den Haag" or "musea" appears in resolved question; compare
    # both terms case-insensitively so "den haag" is not missed.
    resolved_lower = str(resolved2).lower()
    if "den haag" in resolved_lower or "musea" in resolved_lower:
        print(" ✓ Context resolution working - Den Haag/musea referenced")
    else:
        print(" ⚠️ Context may not have been fully resolved")

    print("\n✓ Query router history test complete!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: verify the endpoint, then run every check in order.
    divider = "=" * 60
    print(divider)
    print("DSPy Heritage RAG - Live Testing")
    print(f"Started: {datetime.now().isoformat()}")
    print(divider)

    # Test SPARQL first; nothing else is useful without the endpoint.
    endpoint_ok = test_sparql_endpoint()
    if not endpoint_ok:
        for hint in (
            "\n⚠️ SPARQL endpoint not available!",
            "Make sure SSH tunnel is active:",
            " ssh -f -N -L 7878:localhost:7878 root@91.98.224.44",
        ):
            print(hint)
        sys.exit(1)

    steps = (
        run_sample_queries,              # sample queries to show live data
        test_query_router,               # DSPy components
        test_sparql_generation,
        test_schema_aware_signatures,    # schema-aware signatures
        test_query_router_with_history,  # multi-turn conversation support
        test_multi_turn_conversation,
        test_full_pipeline,
    )
    for step in steps:
        step()

    print("\n" + divider)
    print("Testing Complete!")
    print(divider)
|