#!/usr/bin/env python3
"""
Test script for DSPy Heritage RAG endpoints.

Tests both the DSPy module directly and the API endpoints.

Usage:
    # Test DSPy module directly (no server required)
    python test_dspy_rag.py --module

    # Test API endpoints (requires server running on localhost:8000)
    python test_dspy_rag.py --api

    # Run every suite: module, API, and streaming (also the default
    # when no flag is given)
    python test_dspy_rag.py --all

    # Test streaming endpoint
    python test_dspy_rag.py --stream
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
|
|
# =============================================================================
|
|
# TEST QUERIES - Real heritage institution questions
|
|
# =============================================================================
|
|
|
|
# Fixture queries shared by the module, API, and streaming test runners.
#
# Every entry carries "question", "language", "expected_intent",
# "expected_entities", and "description". "expected_entity_type" is optional;
# the module test treats a missing value as "institution".
TEST_QUERIES: list[dict[str, Any]] = [
    # ---- Institution queries (Dutch) ----
    {
        "question": "Hoeveel musea zijn er in Amsterdam?",
        "language": "nl",
        "expected_intent": "statistical",
        "expected_entities": ["amsterdam", "musea"],
        "description": "Count museums in Amsterdam",
    },
    {
        "question": "Waar is het Rijksmuseum gevestigd?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["rijksmuseum"],
        "description": "Location of Rijksmuseum",
    },
    {
        "question": "Welke archieven zijn gefuseerd in Noord-Holland sinds 2000?",
        "language": "nl",
        "expected_intent": "temporal",
        "expected_entities": ["archieven", "noord-holland", "2000"],
        "description": "Archive mergers in Noord-Holland",
    },
    {
        "question": "Toon erfgoedinstellingen in de buurt van Rotterdam Centraal",
        "language": "nl",
        "expected_intent": "geographic",
        "expected_entities": ["rotterdam centraal"],
        "description": "Heritage institutions near Rotterdam",
    },
    # ---- Institution queries (English) ----
    {
        "question": "How many libraries are there in the Netherlands?",
        "language": "en",
        "expected_intent": "statistical",
        "expected_entities": ["libraries", "netherlands"],
        "description": "Count libraries in NL",
    },
    {
        "question": "Show me archives related to World War II",
        "language": "en",
        "expected_intent": "exploration",
        "expected_entities": ["archives", "world war ii"],
        "description": "WWII archives exploration",
    },
    {
        "question": "Compare the collections of Rijksmuseum and Van Gogh Museum",
        "language": "en",
        "expected_intent": "comparative",
        "expected_entities": ["rijksmuseum", "van gogh museum"],
        "description": "Compare two major museums",
    },
    {
        "question": "What institutions are part of the Erfgoed Leiden network?",
        "language": "en",
        "expected_intent": "relational",
        "expected_entities": ["erfgoed leiden"],
        "description": "Network membership query",
    },
    {
        "question": "When was the Nationaal Archief founded?",
        "language": "en",
        "expected_intent": "temporal",
        "expected_entities": ["nationaal archief"],
        "description": "Founding date query",
    },
    {
        "question": "Find galleries in the Randstad region",
        "language": "en",
        "expected_intent": "geographic",
        "expected_entities": ["galleries", "randstad"],
        "expected_entity_type": "institution",
        "description": "Geographic gallery search",
    },
    # ---- Person queries (Dutch) ----
    {
        "question": "Wie werkt bij het Nationaal Archief?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["nationaal archief"],
        "expected_entity_type": "person",
        "description": "Staff at Nationaal Archief (Dutch)",
    },
    {
        "question": "Welke curatoren werken bij het Rijksmuseum?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["rijksmuseum", "curator"],
        "expected_entity_type": "person",
        "description": "Curators at Rijksmuseum (Dutch)",
    },
    # ---- Person queries (English) ----
    {
        "question": "Who works at the Eye Filmmuseum?",
        "language": "en",
        "expected_intent": "entity_lookup",
        "expected_entities": ["eye filmmuseum"],
        "expected_entity_type": "person",
        "description": "Staff at Eye Filmmuseum (English)",
    },
    {
        "question": "Show me the director of the Amsterdam Museum",
        "language": "en",
        "expected_intent": "entity_lookup",
        "expected_entities": ["amsterdam museum", "director"],
        "expected_entity_type": "person",
        "description": "Director of Amsterdam Museum",
    },
]
|
|
|
|
|
|
@dataclass
|
|
class TestResult:
|
|
"""Result of a single test."""
|
|
query: str
|
|
description: str
|
|
success: bool
|
|
duration_ms: float
|
|
response: dict[str, Any] | None = None
|
|
error: str | None = None
|
|
|
|
|
|
# =============================================================================
|
|
# MODULE TESTS - Direct DSPy module testing
|
|
# =============================================================================
|
|
|
|
async def test_dspy_module() -> list[TestResult]:
    """Exercise the DSPy Heritage RAG router and pipeline in-process.

    Each entry of ``TEST_QUERIES`` is sent through ``HeritageQueryRouter``
    and then through the pipeline built by ``create_heritage_rag_pipeline``.
    A query passes when the pipeline result has a non-empty ``answer`` and
    exposes ``intent`` and ``confidence``. Whether the routed intent and
    entity type match the expectations is recorded and printed, but does not
    influence pass/fail.

    Returns:
        One ``TestResult`` per query. The list is empty when
        ``ANTHROPIC_API_KEY`` is not set or the DSPy modules cannot be
        imported (both cases are reported on stdout).
    """
    results = []

    print("\n" + "=" * 60)
    print("TESTING DSPy HERITAGE RAG MODULE DIRECTLY")
    print("=" * 60)

    try:
        # Imported inside the function so that a missing dependency is
        # reported by the ImportError handler below instead of breaking the
        # whole script at import time.
        import os

        import dspy
        from dspy_heritage_rag import (
            HeritageQueryRouter,
            create_heritage_rag_pipeline,
        )

        # Configure DSPy with Claude
        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            print("WARNING: ANTHROPIC_API_KEY not set, skipping module tests")
            return results

        lm = dspy.LM(model="anthropic/claude-sonnet-4-20250514", api_key=api_key)
        dspy.configure(lm=lm)

        print("\nDSPy configured with Claude Sonnet 4")

        # Create pipeline
        pipeline = create_heritage_rag_pipeline(use_tools=False)
        router = HeritageQueryRouter()

        print(f"Pipeline created, testing {len(TEST_QUERIES)} queries...\n")

        for i, test in enumerate(TEST_QUERIES, 1):
            print(f"[{i}/{len(TEST_QUERIES)}] {test['description']}...")

            # perf_counter() is monotonic, so the measured duration cannot be
            # distorted by system clock adjustments the way time.time() can.
            start = time.perf_counter()
            try:
                # Test router first
                routing = router(
                    question=test["question"],
                    language=test["language"],
                )

                # Test full pipeline
                result = pipeline(
                    question=test["question"],
                    language=test["language"],
                    include_viz=True,
                    use_agent=False,
                )

                # The duration covers the router call plus the pipeline call.
                duration_ms = (time.perf_counter() - start) * 1000

                # Validate response: non-empty answer, and the intent and
                # confidence attributes must exist.
                success = bool(
                    getattr(result, "answer", None)
                    and hasattr(result, "intent")
                    and hasattr(result, "confidence")
                )

                # Check intent match
                intent_match = routing.intent == test["expected_intent"]

                # Check entity_type match; both sides fall back to
                # "institution" when the value is absent.
                entity_type = getattr(routing, "entity_type", "institution")
                expected_entity_type = test.get("expected_entity_type", "institution")
                entity_type_match = entity_type == expected_entity_type

                results.append(TestResult(
                    query=test["question"],
                    description=test["description"],
                    success=success,
                    duration_ms=duration_ms,
                    response={
                        "intent": routing.intent,
                        "intent_match": intent_match,
                        "entity_type": entity_type,
                        "entity_type_match": entity_type_match,
                        "entities": routing.entities,
                        "answer_preview": result.answer[:100] if result.answer else None,
                        "confidence": result.confidence,
                        "sources": result.sources_used,
                    },
                ))

                status = "PASS" if success else "FAIL"
                intent_status = "OK" if intent_match else f"MISMATCH (got {routing.intent})"
                entity_type_status = "OK" if entity_type_match else f"MISMATCH (got {entity_type})"
                print(f" [{status}] {duration_ms:.0f}ms - Intent: {intent_status}, Entity Type: {entity_type_status}")

            except Exception as e:
                # A failing query is recorded and the run continues with the
                # next one.
                duration_ms = (time.perf_counter() - start) * 1000
                results.append(TestResult(
                    query=test["question"],
                    description=test["description"],
                    success=False,
                    duration_ms=duration_ms,
                    error=str(e),
                ))
                print(f" [ERROR] {e}")

    except ImportError as e:
        print(f"ERROR: Could not import DSPy module: {e}")
        print("Make sure you're running from the backend/rag directory")

    return results
|
|
|
|
|
|
async def test_training_data() -> None:
    """Print a report on the GEPA training and validation examples.

    Loads both sets through ``create_heritage_training_data``, prints their
    sizes and the first training example, then checks that every example
    exposes the five expected attributes. Any exception raised along the way
    is caught and printed as ``ERROR: ...``; nothing is returned.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TESTING GEPA TRAINING DATA")
    print(rule)

    try:
        from dspy_heritage_rag import create_heritage_training_data

        trainset, valset = create_heritage_training_data()

        print(f"\nTraining examples: {len(trainset)}")
        print(f"Validation examples: {len(valset)}")

        print("\nSample training example:")
        sample = trainset[0]
        print(f" Question: {sample.question}")
        print(f" Language: {sample.language}")
        print(f" Expected intent: {sample.expected_intent}")
        print(f" Expected entities: {sample.expected_entities}")
        print(f" Expected sources: {sample.expected_sources}")

        # Attributes every training/validation example must expose.
        required_attrs = (
            "question",
            "language",
            "expected_intent",
            "expected_entities",
            "expected_sources",
        )

        all_valid = True
        for example in trainset + valset:
            has_everything = all([hasattr(example, attr) for attr in required_attrs])
            if not has_everything:
                all_valid = False
                print(f" INVALID: {example}")

        if all_valid:
            status = "PASS"
        else:
            status = "FAIL"
        print(f"\n[{status}] All training examples valid")

    except Exception as e:
        print(f"ERROR: {e}")
|
|
|
|
|
|
# =============================================================================
|
|
# API TESTS - Test endpoints via HTTP
|
|
# =============================================================================
|
|
|
|
async def test_api_health(client: httpx.AsyncClient) -> bool:
    """Probe ``/api/dspy/rag/health`` and print the reported component flags.

    Args:
        client: Async HTTP client used to issue the GET request.

    Returns:
        True only when the endpoint answers HTTP 200 and the JSON body has
        ``status == "ok"``. Any exception is printed and yields False.
    """
    print("\n--- Testing /api/dspy/rag/health ---")

    # (label printed, key looked up under the "components" object)
    component_rows = (
        ("DSPy available", "dspy_available"),
        ("Pipeline initialized", "pipeline_initialized"),
        ("GEPA optimized", "gepa_optimized"),
    )

    try:
        response = await client.get("/api/dspy/rag/health")
        payload = response.json()

        print(f"Status: {response.status_code}")
        for label, key in component_rows:
            print(f"{label}: {payload.get('components', {}).get(key)}")

        if response.status_code != 200:
            return False
        return payload.get("status") == "ok"
    except Exception as exc:
        print(f"ERROR: {exc}")
        return False
|
|
|
|
|
|
async def test_api_training_data(client: httpx.AsyncClient) -> bool:
    """Probe ``/api/dspy/rag/training-data`` and print the reported counts.

    Args:
        client: Async HTTP client used to issue the GET request.

    Returns:
        True when the endpoint answers HTTP 200 and ``total_training`` in the
        JSON body is greater than zero. Any exception is printed and yields
        False.
    """
    print("\n--- Testing /api/dspy/rag/training-data ---")

    try:
        response = await client.get("/api/dspy/rag/training-data")
        payload = response.json()

        print(f"Status: {response.status_code}")
        print(f"Training examples: {payload.get('total_training')}")
        print(f"Validation examples: {payload.get('total_validation')}")

        if response.status_code != 200:
            return False
        return payload.get("total_training", 0) > 0
    except Exception as exc:
        print(f"ERROR: {exc}")
        return False
|
|
|
|
|
|
async def test_api_query(client: httpx.AsyncClient, test: dict) -> TestResult:
    """Send one test query to ``/api/dspy/rag/query`` and validate the reply.

    Args:
        client: Async HTTP client used to issue the POST request.
        test: Query fixture; reads ``question``, ``language``,
            ``expected_intent`` and ``description``.

    Returns:
        A ``TestResult``. It is successful when the endpoint answers HTTP 200
        and the JSON body has a non-empty ``answer``, a non-empty ``intent``
        and a ``confidence`` that is not null. On failure ``error`` describes
        the HTTP status, the missing field(s), or the exception raised.
    """
    # perf_counter() is monotonic, so the measured duration cannot be
    # distorted by system clock adjustments.
    start = time.perf_counter()

    try:
        response = await client.post(
            "/api/dspy/rag/query",
            json={
                "question": test["question"],
                "language": test["language"],
                "include_visualization": True,
                "use_agent": False,
            },
            timeout=60.0,
        )

        duration_ms = (time.perf_counter() - start) * 1000

        if response.status_code != 200:
            return TestResult(
                query=test["question"],
                description=test["description"],
                success=False,
                duration_ms=duration_ms,
                error=f"HTTP {response.status_code}: {response.text}",
            )

        data = response.json()

        # Required fields and whether each one is usable.
        checks = {
            "answer": bool(data.get("answer")),
            "intent": bool(data.get("intent")),
            "confidence": data.get("confidence") is not None,
        }
        missing = [name for name, ok in checks.items() if not ok]
        success = not missing

        return TestResult(
            query=test["question"],
            description=test["description"],
            success=success,
            duration_ms=duration_ms,
            response={
                "intent": data.get("intent"),
                "intent_match": data.get("intent") == test["expected_intent"],
                # `or` guards against an explicit JSON null, which
                # dict.get's default does not cover.
                "answer_preview": (data.get("answer") or "")[:100],
                "confidence": data.get("confidence"),
                "sources": data.get("sources_used"),
                "visualization": (data.get("visualization") or {}).get("type"),
            },
            # Give the summary a concrete reason instead of "Unknown error".
            error=(
                f"Response missing required field(s): {', '.join(missing)}"
                if missing
                else None
            ),
        )

    except Exception as e:
        duration_ms = (time.perf_counter() - start) * 1000
        return TestResult(
            query=test["question"],
            description=test["description"],
            success=False,
            duration_ms=duration_ms,
            error=str(e),
        )
|
|
|
|
|
|
async def test_api_stream(client: httpx.AsyncClient, test: dict) -> TestResult:
    """Send one test query to ``/api/dspy/rag/query/stream`` and check the stream.

    Every non-blank line of the response body is parsed as a JSON document.

    Args:
        client: Async HTTP client used to open the streaming POST request.
        test: Query fixture; reads ``question``, ``language`` and
            ``description``.

    Returns:
        A ``TestResult``. It is successful when the endpoint answers HTTP 200,
        a chunk with ``type == "complete"`` was received, and more than two
        chunks arrived in total. On failure ``error`` describes the HTTP
        status, what was missing from the stream, or the exception raised.
    """
    # perf_counter() is monotonic, so the measured duration cannot be
    # distorted by system clock adjustments.
    start = time.perf_counter()
    chunks = []

    try:
        async with client.stream(
            "POST",
            "/api/dspy/rag/query/stream",
            json={
                "question": test["question"],
                "language": test["language"],
                "include_visualization": True,
            },
            timeout=60.0,
        ) as response:
            if response.status_code != 200:
                # A streamed body must be read explicitly before `.text`
                # can be accessed.
                await response.aread()
                return TestResult(
                    query=test["question"],
                    description=test["description"],
                    success=False,
                    duration_ms=(time.perf_counter() - start) * 1000,
                    error=f"HTTP {response.status_code}: {response.text}",
                )

            async for line in response.aiter_lines():
                if line.strip():
                    chunks.append(json.loads(line))

        duration_ms = (time.perf_counter() - start) * 1000

        # Find the final and the routing messages (first of each type).
        complete = next((c for c in chunks if c.get("type") == "complete"), None)
        routing = next((c for c in chunks if c.get("type") == "routing"), None)

        # Pass criteria: a "complete" chunk and more than two chunks overall.
        if complete is None:
            error = "Stream ended without a 'complete' chunk"
        elif len(chunks) <= 2:
            error = f"Expected more than 2 chunks, received {len(chunks)}"
        else:
            error = None
        success = error is None

        return TestResult(
            query=test["question"],
            description=test["description"],
            success=success,
            duration_ms=duration_ms,
            response={
                "chunk_count": len(chunks),
                "chunk_types": [c.get("type") for c in chunks],
                "intent": routing.get("intent") if routing else None,
                # `or` guards against an explicit JSON null answer.
                "answer_preview": (complete.get("answer") or "")[:100] if complete else None,
            },
            error=error,
        )

    except Exception as e:
        duration_ms = (time.perf_counter() - start) * 1000
        return TestResult(
            query=test["question"],
            description=test["description"],
            success=False,
            duration_ms=duration_ms,
            error=str(e),
        )
|
|
|
|
|
|
async def test_api_endpoints(base_url: str = "http://localhost:8000") -> list[TestResult]:
    """Run the health, training-data and query checks against a live server.

    Args:
        base_url: Root URL of the API server under test.

    Returns:
        One ``TestResult`` per entry of ``TEST_QUERIES`` from the query
        endpoint. The health and training-data probes are only printed; a
        failed health probe produces a warning but does not stop the run.
    """
    collected = []

    rule = "=" * 60
    print("\n" + rule)
    print(f"TESTING DSPy RAG API ENDPOINTS ({base_url})")
    print(rule)

    async with httpx.AsyncClient(base_url=base_url) as client:
        # Health first; a failure is reported but the run continues.
        if not await test_api_health(client):
            print("\nWARNING: DSPy RAG not available, some tests may fail")

        # Training-data probe (its boolean result is not used here).
        await test_api_training_data(client)

        # Query endpoint, once per fixture.
        total = len(TEST_QUERIES)
        print(f"\n--- Testing /api/dspy/rag/query ({total} queries) ---")

        for position, query in enumerate(TEST_QUERIES, 1):
            print(f"[{position}/{total}] {query['description']}...")
            outcome = await test_api_query(client, query)
            collected.append(outcome)

            if not outcome.success:
                print(f" [FAIL] {outcome.error or 'Unknown error'}")
                continue

            details = outcome.response or {}
            intent_status = "OK" if details.get("intent_match", False) else "MISMATCH"
            print(f" [PASS] {outcome.duration_ms:.0f}ms - Intent: {intent_status}")

    return collected
|
|
|
|
|
|
async def test_streaming_endpoint(base_url: str = "http://localhost:8000") -> list[TestResult]:
    """Run the streaming check for the first three test queries.

    Args:
        base_url: Root URL of the API server under test.

    Returns:
        One ``TestResult`` per streamed query, in fixture order.
    """
    collected = []

    rule = "=" * 60
    print("\n" + rule)
    print(f"TESTING DSPy RAG STREAMING ENDPOINT ({base_url})")
    print(rule)

    # Only a subset of the fixtures is streamed.
    stream_tests = TEST_QUERIES[:3]
    total = len(stream_tests)

    async with httpx.AsyncClient(base_url=base_url) as client:
        for position, query in enumerate(stream_tests, 1):
            print(f"\n[{position}/{total}] {query['description']}...")
            outcome = await test_api_stream(client, query)
            collected.append(outcome)

            if not outcome.success:
                print(f" [FAIL] {outcome.error or 'Unknown error'}")
                continue

            details = outcome.response
            chunk_count = details.get("chunk_count", 0) if details else 0
            print(f" [PASS] {outcome.duration_ms:.0f}ms - Received {chunk_count} chunks")
            if details:
                print(f" Chunk types: {details.get('chunk_types')}")

    return collected
|
|
|
|
|
|
# =============================================================================
|
|
# MAIN
|
|
# =============================================================================
|
|
|
|
def print_summary(results: list[TestResult], title: str) -> None:
    """Print pass/fail counts, success rate and mean duration for a run.

    Args:
        results: Test outcomes to summarise. Nothing is printed when empty.
        title: Heading shown after ``SUMMARY:``.
    """
    if not results:
        return

    total = len(results)
    failures = [r for r in results if not r.success]
    failed = len(failures)
    passed = total - failed
    avg_duration = sum(r.duration_ms for r in results) / total

    rule = "=" * 60
    print("\n" + rule)
    print(f"SUMMARY: {title}")
    print(rule)
    print(f"Total tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {failed}")
    print(f"Success rate: {passed/total*100:.1f}%")
    print(f"Average duration: {avg_duration:.0f}ms")

    if failed > 0:
        print("\nFailed tests:")
        for r in failures:
            print(f" - {r.description}: {r.error or 'Unknown error'}")
|
|
|
|
|
|
async def main():
    """Parse the command line, run the selected suites and exit.

    With no suite flag given, every suite is run. After the per-suite
    summaries an overall summary is printed (when any result exists), and the
    process exits with status 0 if every collected result succeeded, 1
    otherwise.
    """
    parser = argparse.ArgumentParser(description="Test DSPy Heritage RAG")
    suite_flags = (
        ("--module", "Test DSPy module directly"),
        ("--api", "Test API endpoints"),
        ("--stream", "Test streaming endpoint"),
        ("--all", "Run all tests"),
    )
    for flag, help_text in suite_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("--url", default="http://localhost:8000", help="API base URL")

    args = parser.parse_args()

    # Default to --all if no specific test selected
    if not (args.module or args.api or args.stream or args.all):
        args.all = True

    all_results = []

    if args.module or args.all:
        module_results = await test_dspy_module()
        all_results.extend(module_results)
        print_summary(module_results, "DSPy Module Tests")

        # The training-data report only prints; it adds no results.
        await test_training_data()

    if args.api or args.all:
        api_results = await test_api_endpoints(args.url)
        all_results.extend(api_results)
        print_summary(api_results, "API Endpoint Tests")

    if args.stream or args.all:
        stream_results = await test_streaming_endpoint(args.url)
        all_results.extend(stream_results)
        print_summary(stream_results, "Streaming Tests")

    # Overall summary
    if all_results:
        print_summary(all_results, "ALL TESTS")

    # Exit code based on success (an empty result list counts as success).
    exit_code = 0 if all(r.success for r in all_results) else 1
    sys.exit(exit_code)
|
|
|
|
|
|
# Script entry point: when executed directly, run the async CLI driver on a
# fresh event loop. main() ends the process itself via sys.exit().
if __name__ == "__main__":
    asyncio.run(main())
|