#!/usr/bin/env python3
"""
Test script for the Hybrid Retriever (Vector + Knowledge Graph)

Tests the combination of Qdrant vector search with Oxigraph SPARQL expansion.

Usage:
    # Local development (requires SSH tunnel for Oxigraph)
    python scripts/test_hybrid_retriever.py

    # Production endpoints
    python scripts/test_hybrid_retriever.py --production

    # Custom query
    python scripts/test_hybrid_retriever.py --query "Dutch colonial museums"

    # Skip the hybrid search step (which needs OPENAI_API_KEY)
    python scripts/test_hybrid_retriever.py --skip-search
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Put the repository's ``src`` directory first on the import path so the
# in-repo package (``glam_extractor``) is importable when this script is run
# directly from a checkout.
_SRC_DIR = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(_SRC_DIR))

# Timestamped, level-tagged log lines for all script output.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def test_sparql_connectivity(endpoint: str) -> bool:
    """Check that a SPARQL endpoint answers a simple COUNT query.

    Sends ``COUNT(DISTINCT ?s)`` over ``hcc:Custodian`` instances and logs the
    returned count.

    Args:
        endpoint: SPARQL query URL.

    Returns:
        True if the request succeeded (2xx) and the body parsed as JSON;
        False on any failure. Errors are logged, never raised.
    """
    query = """
    PREFIX hcc: <https://nde.nl/ontology/hc/class/>
    SELECT (COUNT(DISTINCT ?s) as ?count) WHERE { ?s a hcc:Custodian }
    """

    try:
        # Imported inside the try (as test_qdrant_connectivity does for
        # qdrant_client) so a missing dependency is reported as a failed
        # check rather than aborting the whole script.
        import httpx

        response = httpx.post(
            endpoint,
            data={"query": query},
            headers={"Accept": "application/sparql-results+json"},
            timeout=10.0,
        )
        response.raise_for_status()
        bindings = response.json().get("results", {}).get("bindings", [])
        # An aggregate normally yields exactly one row; treat an empty result
        # set as a count of 0 instead of raising IndexError.
        count = bindings[0].get("count", {}).get("value", "0") if bindings else "0"
        logger.info(f"✓ SPARQL endpoint connected: {count} custodians")
        return True
    except Exception as e:
        # Best-effort probe: any failure (network, HTTP status, bad JSON,
        # missing httpx) simply marks the endpoint as unavailable.
        logger.error(f"✗ SPARQL endpoint failed: {e}")
        return False


# This file's name matches pytest's default ``test_*.py`` pattern; keep pytest
# from collecting this helper as a test (it needs a real ``endpoint``).
test_sparql_connectivity.__test__ = False
|
|
|
|
|
|
def test_qdrant_connectivity(host: str, port: int, https: bool = False, prefix: str | None = None) -> bool:
    """Check that Qdrant is reachable and list its collections.

    Connects over REST, logs the collection names and, if a
    ``heritage_custodians`` collection exists, its size.

    Args:
        host: Qdrant host name.
        port: Qdrant REST port.
        https: connect over HTTPS.
        prefix: optional path prefix placed in front of the Qdrant REST API
            path (None for none).

    Returns:
        True if the collection list could be fetched, False otherwise. Errors
        are logged, never raised.
    """
    try:
        from qdrant_client import QdrantClient

        # One construction path for local and remote setups. ``https`` and
        # ``prefix`` are passed through as given (prefix used to be dropped
        # when https=False), and both cases get the same explicit 10 s timeout.
        client = QdrantClient(
            host=host,
            port=port,
            https=https,
            prefix=prefix,
            prefer_grpc=False,
            timeout=10,
        )

        collection_names = [c.name for c in client.get_collections().collections]
        logger.info(f"✓ Qdrant connected: {len(collection_names)} collections - {collection_names}")

        if "heritage_custodians" in collection_names:
            info = client.get_collection("heritage_custodians")
            # ``vectors_count`` is deprecated and may be None or missing,
            # depending on the qdrant-client/server version. Fall back to
            # ``points_count`` so this informational line cannot turn a
            # working connection into a reported failure.
            vectors = getattr(info, "vectors_count", None)
            if vectors is not None:
                logger.info(f" → heritage_custodians: {vectors} vectors")
            else:
                logger.info(f" → heritage_custodians: {getattr(info, 'points_count', None)} points")

        return True
    except Exception as e:
        logger.error(f"✗ Qdrant connection failed: {e}")
        return False


# This file's name matches pytest's default ``test_*.py`` pattern; keep pytest
# from collecting this helper as a test (it needs real connection arguments).
test_qdrant_connectivity.__test__ = False
|
|
|
|
|
|
def test_hybrid_search(retriever, query: str) -> None:
    """Run one hybrid query and log the ranked results.

    Calls ``retriever.search(query, k=5)`` and logs, per result, the name,
    GHCID, and vector/graph/combined scores. It also logs the institution
    type, city and graph-expansion reason when they are set. It then calls
    ``retriever(query, k=3)`` (the DSPy-compatible callable interface) and
    logs a preview of up to 100 characters for each returned passage.

    Args:
        retriever: hybrid retriever exposing ``search()`` and ``__call__``,
            e.g. the object ``create_hybrid_retriever()`` returns in main().
        query: natural-language search query.

    Raises:
        Exception: any error from the retriever is logged and re-raised.
    """
    logger.info(f"\n{'=' * 60}")
    logger.info(f"Hybrid Search: '{query}'")
    logger.info(f"{'=' * 60}")

    try:
        results = retriever.search(query, k=5)

        logger.info(f"\nFound {len(results)} results:\n")

        for i, result in enumerate(results, 1):
            logger.info(f"{i}. {result.name}")
            logger.info(f" GHCID: {result.ghcid}")
            logger.info(f" Scores: vector={result.vector_score:.3f}, graph={result.graph_score:.3f}, combined={result.combined_score:.3f}")
            # Optional metadata: only logged when present.
            if result.institution_type:
                logger.info(f" Type: {result.institution_type}")
            if result.city:
                logger.info(f" City: {result.city}")
            if result.expansion_reason:
                logger.info(f" Graph expansion: {result.expansion_reason}")
            logger.info("")

        # Exercise the DSPy-compatible interface (returns passage texts).
        logger.info("\nDSPy-compatible output (passage texts):")
        passages = retriever(query, k=3)
        for i, passage in enumerate(passages, 1):
            # Only mark truncation when the passage was actually cut.
            preview = passage if len(passage) <= 100 else f"{passage[:100]}..."
            logger.info(f" {i}. {preview}")

    except Exception as e:
        logger.error(f"Hybrid search failed: {e}")
        raise


# This file's name matches pytest's default ``test_*.py`` pattern; keep pytest
# from collecting this helper as a test (it needs a real retriever).
test_hybrid_search.__test__ = False
|
|
|
|
|
|
def _run_expansion_query(client, endpoint: str, label: str, query: str) -> None:
    """POST one SPARQL SELECT via *client* and log its hit count and first three rows.

    Failures are logged and swallowed so that later queries still run.
    """
    try:
        response = client.post(
            endpoint,
            data={"query": query},
            headers={"Accept": "application/sparql-results+json"},
        )
        # Report HTTP errors explicitly instead of trying to parse an error
        # body as SPARQL results.
        response.raise_for_status()
        results = response.json().get("results", {}).get("bindings", [])
        logger.info(f"\n{label}: {len(results)} found")
        for r in results[:3]:
            name = r.get("name", {}).get("value", "")
            ghcid = r.get("ghcid", {}).get("value", "")
            logger.info(f" - {name} ({ghcid})")
    except Exception as e:
        logger.error(f" Query failed: {e}")


def test_graph_expansion_queries(sparql_endpoint: str) -> None:
    """Run sample SPARQL lookups of the kind used for graph expansion.

    Runs three independent SELECT queries against *sparql_endpoint* and logs
    the hit count plus up to three sample rows for each. A failing query is
    logged and does not stop the others. The queries find:

    1. custodians whose GHCID contains the Amsterdam city code ``-AMS-``;
    2. Dutch custodians whose GHCID contains the museum type code ``-M-``;
    3. custodians with Wikidata country ``wdt:P17 wd:Q55``.

    Args:
        sparql_endpoint: SPARQL query URL.
    """
    import httpx

    logger.info("\n" + "=" * 60)
    logger.info("Testing Graph Expansion Queries")
    logger.info("=" * 60)

    # (log label, SPARQL query) pairs, run in order.
    checks = [
        (
            "1. Institutions in Amsterdam (city code AMS)",
            """
            PREFIX hc: <https://nde.nl/ontology/hc/>
            PREFIX hcc: <https://nde.nl/ontology/hc/class/>
            PREFIX skos: <http://www.w3.org/2004/02/skos/core#>

            SELECT ?name ?ghcid WHERE {
                ?s a hcc:Custodian ;
                   skos:prefLabel ?name ;
                   hc:ghcid ?ghcid .
                FILTER(CONTAINS(?ghcid, "-AMS-"))
            }
            LIMIT 5
            """,
        ),
        (
            "2. Dutch Museums (NL-*-*-M-*)",
            """
            PREFIX hc: <https://nde.nl/ontology/hc/>
            PREFIX hcc: <https://nde.nl/ontology/hc/class/>
            PREFIX skos: <http://www.w3.org/2004/02/skos/core#>

            SELECT ?name ?ghcid WHERE {
                ?s a hcc:Custodian ;
                   skos:prefLabel ?name ;
                   hc:ghcid ?ghcid .
                FILTER(STRSTARTS(?ghcid, "NL-"))
                FILTER(CONTAINS(?ghcid, "-M-"))
            }
            LIMIT 5
            """,
        ),
        (
            "3. Institutions with Wikidata P17=Q55 (Netherlands)",
            """
            PREFIX hc: <https://nde.nl/ontology/hc/>
            PREFIX hcc: <https://nde.nl/ontology/hc/class/>
            PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
            PREFIX wdt: <http://www.wikidata.org/prop/direct/>
            PREFIX wd: <http://www.wikidata.org/entity/>

            SELECT ?name ?ghcid WHERE {
                ?s a hcc:Custodian ;
                   skos:prefLabel ?name ;
                   hc:ghcid ?ghcid ;
                   wdt:P17 wd:Q55 .  # Netherlands
            }
            LIMIT 5
            """,
        ),
    ]

    # One client, so all queries share a connection and the 10 s timeout.
    with httpx.Client(timeout=10.0) as client:
        for label, query in checks:
            _run_expansion_query(client, sparql_endpoint, label, query)


# This file's name matches pytest's default ``test_*.py`` pattern; keep pytest
# from collecting this helper as a test (it needs a real endpoint).
test_graph_expansion_queries.__test__ = False
|
|
|
|
|
|
def _log_section(title: str) -> None:
    """Log *title* between two 60-character ``=`` rules."""
    logger.info("\n" + "=" * 60)
    logger.info(title)
    logger.info("=" * 60)


def main():
    """Parse CLI options and run the connectivity, SPARQL and hybrid-search checks.

    Steps:
      1. Pick local or production endpoints (``--production``).
      2. Probe SPARQL and Qdrant connectivity. Failures are logged and the run
         continues with whatever is available.
      3. If SPARQL is reachable, run the sample graph-expansion queries.
      4. Only if ``--skip-search`` is absent, both services are reachable and
         ``OPENAI_API_KEY`` is set: build the hybrid retriever, log its stats,
         then run ``--query`` plus a few fixed queries. The retriever is
         always closed afterwards, even if a query raises.
    """
    parser = argparse.ArgumentParser(description="Test Hybrid Retriever")
    parser.add_argument("--production", action="store_true", help="Use production endpoints")
    parser.add_argument("--query", type=str, default="Dutch art museums", help="Search query to test")
    parser.add_argument("--skip-search", action="store_true", help="Skip search test (requires OpenAI key)")
    args = parser.parse_args()

    # Determine endpoints
    if args.production:
        sparql_endpoint = "https://bronhouder.nl/query"
        qdrant_host = "bronhouder.nl"
        qdrant_port = 443
        qdrant_https = True
        # Qdrant's REST API is reached under the "qdrant" path prefix here.
        qdrant_prefix = "qdrant"
        logger.info("Using PRODUCTION endpoints")
    else:
        sparql_endpoint = "http://localhost:7878/query"
        qdrant_host = "localhost"
        qdrant_port = 6333
        qdrant_https = False
        qdrant_prefix = None
        logger.info("Using LOCAL endpoints (ensure SSH tunnel is active)")

    # Test 1: connectivity
    _log_section("Testing Connectivity")

    sparql_ok = test_sparql_connectivity(sparql_endpoint)
    qdrant_ok = test_qdrant_connectivity(qdrant_host, qdrant_port, qdrant_https, qdrant_prefix)

    if not (sparql_ok and qdrant_ok):
        logger.warning("\nSome connectivity tests failed. Continuing with available services...")

    # Test 2: Graph expansion queries (SPARQL only)
    if sparql_ok:
        test_graph_expansion_queries(sparql_endpoint)

    # Test 3: Hybrid search (requires both services + OpenAI)
    if not args.skip_search and sparql_ok and qdrant_ok:
        import os

        if not os.getenv("OPENAI_API_KEY"):
            logger.warning("\nOPENAI_API_KEY not set, skipping hybrid search test")
            logger.info("Set OPENAI_API_KEY to test vector search")
        else:
            from glam_extractor.api.hybrid_retriever import create_hybrid_retriever

            retriever = create_hybrid_retriever(use_production=args.production)
            try:
                _log_section("Retriever Statistics")
                stats = retriever.get_stats()
                logger.info(json.dumps(stats, indent=2))

                test_hybrid_search(retriever, args.query)

                # Additional test queries
                for q in (
                    "libraries with medieval manuscripts",
                    "museums in Prague",
                    "Japanese national archives",
                ):
                    test_hybrid_search(retriever, q)
            finally:
                # test_hybrid_search re-raises on failure; release the
                # retriever even when a query blows up.
                retriever.close()

    _log_section("Test Complete")


if __name__ == "__main__":
    main()
|