glam/scripts/test_hybrid_retriever.py
2025-12-09 09:16:19 +01:00

298 lines
10 KiB
Python

#!/usr/bin/env python3
"""
Test script for the Hybrid Retriever (Vector + Knowledge Graph)
Tests the combination of Qdrant vector search with Oxigraph SPARQL expansion.
Usage:
# Local development (requires SSH tunnel for Oxigraph)
python scripts/test_hybrid_retriever.py
# Production endpoints
python scripts/test_hybrid_retriever.py --production
# Custom query
python scripts/test_hybrid_retriever.py --query "Dutch colonial museums"
"""
import argparse
import json
import logging
import sys
from pathlib import Path
# Make the repository's ``src`` directory importable so ``glam_extractor`` can be
# loaded without installing the package.
sys.path.insert(0, str(Path(__file__).parent.parent.joinpath("src")))

# Script-wide logging: INFO level, "timestamp [LEVEL] message" lines.
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)
def test_sparql_connectivity(endpoint: str) -> bool:
    """Probe a SPARQL endpoint by counting ``hcc:Custodian`` instances.

    Args:
        endpoint: URL of the SPARQL query endpoint
            (e.g. ``http://localhost:7878/query``).

    Returns:
        True if the count query succeeded, False on any failure (missing
        ``httpx``, network/HTTP error, unparsable response). Failures are
        logged rather than raised so the caller can continue with whichever
        services are reachable.
    """
    query = """
    PREFIX hcc: <https://nde.nl/ontology/hc/class/>
    SELECT (COUNT(DISTINCT ?s) as ?count) WHERE { ?s a hcc:Custodian }
    """
    try:
        # Imported inside the try so a missing dependency is reported as a
        # failed probe (like the Qdrant check) instead of aborting the script.
        import httpx

        response = httpx.post(
            endpoint,
            data={"query": query},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        response.raise_for_status()
        data = response.json()
        # An empty ``bindings`` list would make ``[0]`` raise IndexError and
        # misreport a reachable endpoint as down; treat it as a zero count.
        bindings = data.get("results", {}).get("bindings") or [{}]
        count = bindings[0].get("count", {}).get("value", "0")
        logger.info(f"✓ SPARQL endpoint connected: {count} custodians")
        return True
    except Exception as e:
        logger.error(f"✗ SPARQL endpoint failed: {e}")
        return False


# Despite the ``test_`` prefix this is not a pytest test: it takes a plain
# argument (not a fixture), so stop pytest from collecting it.
test_sparql_connectivity.__test__ = False
def test_qdrant_connectivity(host: str, port: int, https: bool = False, prefix: str | None = None) -> bool:
"""Test Qdrant connectivity."""
try:
from qdrant_client import QdrantClient
if https:
client = QdrantClient(host=host, port=port, https=True, prefix=prefix, prefer_grpc=False, timeout=10)
else:
client = QdrantClient(host=host, port=port)
collections = client.get_collections()
collection_names = [c.name for c in collections.collections]
logger.info(f"✓ Qdrant connected: {len(collection_names)} collections - {collection_names}")
# Check heritage_custodians collection
if "heritage_custodians" in collection_names:
info = client.get_collection("heritage_custodians")
logger.info(f" → heritage_custodians: {info.vectors_count} vectors")
return True
except Exception as e:
logger.error(f"✗ Qdrant connection failed: {e}")
return False
def test_hybrid_search(retriever, query: str) -> None:
    """Run ``query`` through a hybrid retriever and log what comes back.

    Two retriever entry points are exercised:

    * ``retriever.search(query, k=5)``: each result's ``name``, ``ghcid``,
      ``vector_score`` / ``graph_score`` / ``combined_score`` and, when set,
      ``institution_type``, ``city`` and ``expansion_reason`` are logged;
    * ``retriever(query, k=3)``: the DSPy-compatible call; its items are
      presumably passage strings (they are sliced for a 100-char preview).

    Raises:
        Any exception from the retriever. It is logged, then re-raised.
    """
    banner = "=" * 60
    logger.info(f"\n{banner}")
    logger.info(f"Hybrid Search: '{query}'")
    logger.info(banner)
    try:
        results = retriever.search(query, k=5)
        logger.info(f"\nFound {len(results)} results:\n")
        for i, result in enumerate(results, 1):
            logger.info(f"{i}. {result.name}")
            logger.info(f" GHCID: {result.ghcid}")
            logger.info(f" Scores: vector={result.vector_score:.3f}, graph={result.graph_score:.3f}, combined={result.combined_score:.3f}")
            if result.institution_type:
                logger.info(f" Type: {result.institution_type}")
            if result.city:
                logger.info(f" City: {result.city}")
            if result.expansion_reason:
                logger.info(f" Graph expansion: {result.expansion_reason}")
            logger.info("")
        # Test DSPy-compatible interface
        logger.info("\nDSPy-compatible output (passage texts):")
        passages = retriever(query, k=3)
        for i, passage in enumerate(passages, 1):
            # Only mark truncation when the passage was actually cut short.
            preview = passage if len(passage) <= 100 else f"{passage[:100]}..."
            logger.info(f" {i}. {preview}")
    except Exception as e:
        logger.error(f"Hybrid search failed: {e}")
        raise


# Despite the ``test_`` prefix this is not a pytest test: it takes plain
# arguments (not fixtures), so stop pytest from collecting it.
test_hybrid_search.__test__ = False
# (log title, SPARQL SELECT ?name ?ghcid query) pairs run by
# ``test_graph_expansion_queries``; each returns at most 5 rows.
_EXPANSION_QUERIES = (
    (
        "1. Institutions in Amsterdam (city code AMS)",
        """
        PREFIX hc: <https://nde.nl/ontology/hc/>
        PREFIX hcc: <https://nde.nl/ontology/hc/class/>
        PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
        SELECT ?name ?ghcid WHERE {
            ?s a hcc:Custodian ;
               skos:prefLabel ?name ;
               hc:ghcid ?ghcid .
            FILTER(CONTAINS(?ghcid, "-AMS-"))
        }
        LIMIT 5
        """,
    ),
    (
        "2. Dutch Museums (NL-*-*-M-*)",
        """
        PREFIX hc: <https://nde.nl/ontology/hc/>
        PREFIX hcc: <https://nde.nl/ontology/hc/class/>
        PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
        SELECT ?name ?ghcid WHERE {
            ?s a hcc:Custodian ;
               skos:prefLabel ?name ;
               hc:ghcid ?ghcid .
            FILTER(STRSTARTS(?ghcid, "NL-"))
            FILTER(CONTAINS(?ghcid, "-M-"))
        }
        LIMIT 5
        """,
    ),
    (
        "3. Institutions with Wikidata P17=Q55 (Netherlands)",
        """
        PREFIX hc: <https://nde.nl/ontology/hc/>
        PREFIX hcc: <https://nde.nl/ontology/hc/class/>
        PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
        PREFIX wdt: <http://www.wikidata.org/prop/direct/>
        PREFIX wd: <http://www.wikidata.org/entity/>
        SELECT ?name ?ghcid WHERE {
            ?s a hcc:Custodian ;
               skos:prefLabel ?name ;
               hc:ghcid ?ghcid ;
               wdt:P17 wd:Q55 . # Netherlands
        }
        LIMIT 5
        """,
    ),
)


def _run_expansion_query(sparql_endpoint: str, title: str, query: str) -> None:
    """Run one ``?name ?ghcid`` SELECT and log the hit count plus up to 3 rows.

    Any failure (missing ``httpx``, network/HTTP error, unexpected JSON) is
    logged and swallowed so the remaining queries still run.
    """
    try:
        import httpx

        response = httpx.post(
            sparql_endpoint,
            data={"query": query},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        # Without this, an HTTP error page (e.g. 400 for a malformed query)
        # surfaces as a confusing JSON decode error or a bogus "0 found".
        response.raise_for_status()
        results = response.json().get("results", {}).get("bindings", [])
        logger.info(f"\n{title}: {len(results)} found")
        for row in results[:3]:
            name = row.get("name", {}).get("value", "")
            ghcid = row.get("ghcid", {}).get("value", "")
            logger.info(f" - {name} ({ghcid})")
    except Exception as e:
        logger.error(f" Query failed: {e}")


def test_graph_expansion_queries(sparql_endpoint: str) -> None:
    """Run the sample graph-expansion SPARQL queries and log their results.

    Args:
        sparql_endpoint: URL of the SPARQL query endpoint.

    Never raises for query failures: each query is independent and its
    failure is only logged.
    """
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("Testing Graph Expansion Queries")
    logger.info(banner)
    for title, query in _EXPANSION_QUERIES:
        _run_expansion_query(sparql_endpoint, title, query)


# Despite the ``test_`` prefix this is not a pytest test: it takes a plain
# argument (not a fixture), so stop pytest from collecting it.
test_graph_expansion_queries.__test__ = False
def main() -> None:
    """Run the connectivity, graph-expansion and hybrid-search checks.

    Command-line flags:
        --production   use the bronhouder.nl endpoints instead of localhost
        --query TEXT   primary hybrid-search query (default "Dutch art museums")
        --skip-search  skip the hybrid-search stage

    Each stage runs only when the services it needs are reachable; the hybrid
    search additionally requires ``OPENAI_API_KEY`` in the environment.
    """
    import os

    parser = argparse.ArgumentParser(description="Test Hybrid Retriever")
    parser.add_argument("--production", action="store_true", help="Use production endpoints")
    parser.add_argument("--query", type=str, default="Dutch art museums", help="Search query to test")
    parser.add_argument("--skip-search", action="store_true", help="Skip search test (requires OpenAI key)")
    args = parser.parse_args()

    # Determine endpoints
    if args.production:
        sparql_endpoint = "https://bronhouder.nl/query"
        qdrant_host = "bronhouder.nl"
        qdrant_port = 443
        qdrant_https = True
        qdrant_prefix = "qdrant"
        logger.info("Using PRODUCTION endpoints")
    else:
        sparql_endpoint = "http://localhost:7878/query"
        qdrant_host = "localhost"
        qdrant_port = 6333
        qdrant_https = False
        qdrant_prefix = None
        logger.info("Using LOCAL endpoints (ensure SSH tunnel is active)")

    # Test 1: SPARQL connectivity
    logger.info("\n" + "=" * 60)
    logger.info("Testing Connectivity")
    logger.info("=" * 60)
    sparql_ok = test_sparql_connectivity(sparql_endpoint)
    qdrant_ok = test_qdrant_connectivity(qdrant_host, qdrant_port, qdrant_https, qdrant_prefix)
    if not sparql_ok or not qdrant_ok:
        logger.warning("\nSome connectivity tests failed. Continuing with available services...")

    # Test 2: Graph expansion queries (SPARQL only)
    if sparql_ok:
        test_graph_expansion_queries(sparql_endpoint)

    # Test 3: Hybrid search (requires both services + OpenAI)
    if not args.skip_search and sparql_ok and qdrant_ok:
        if not os.getenv("OPENAI_API_KEY"):
            logger.warning("\nOPENAI_API_KEY not set, skipping hybrid search test")
            logger.info("Set OPENAI_API_KEY to test vector search")
        else:
            from glam_extractor.api.hybrid_retriever import create_hybrid_retriever

            retriever = create_hybrid_retriever(use_production=args.production)
            try:
                # Print stats
                logger.info("\n" + "=" * 60)
                logger.info("Retriever Statistics")
                logger.info("=" * 60)
                stats = retriever.get_stats()
                logger.info(json.dumps(stats, indent=2))
                # Test search
                test_hybrid_search(retriever, args.query)
                # Additional test queries
                test_queries = [
                    "libraries with medieval manuscripts",
                    "museums in Prague",
                    "Japanese national archives",
                ]
                for q in test_queries:
                    test_hybrid_search(retriever, q)
            finally:
                # test_hybrid_search re-raises on failure; still release the
                # retriever's resources before the exception propagates.
                retriever.close()

    logger.info("\n" + "=" * 60)
    logger.info("Test Complete")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()