glam/backend/rag/test_dspy_rag.py
2025-12-14 17:09:55 +01:00

605 lines
21 KiB
Python

#!/usr/bin/env python3
"""
Test script for DSPy Heritage RAG endpoints.
Tests both the DSPy module directly and the API endpoints.
Usage:
# Test DSPy module directly (no server required)
python test_dspy_rag.py --module
# Test API endpoints (requires server running on localhost:8000)
python test_dspy_rag.py --api
# Test both
python test_dspy_rag.py --all
# Test streaming endpoint
python test_dspy_rag.py --stream
"""
import argparse
import asyncio
import json
import sys
import time
from dataclasses import dataclass
from typing import Any
import httpx
# =============================================================================
# TEST QUERIES - Real heritage institution questions
# =============================================================================
# Each entry describes one heritage question and what the router is expected
# to make of it:
#   question             - the natural-language query sent to the pipeline/API
#   language             - language code of the question ("nl" or "en")
#   expected_intent      - intent label compared against the router's output
#   expected_entities    - lower-cased entities the query mentions
#   expected_entity_type - optional; consumers default it to "institution"
#   description          - short label printed in the test log
TEST_QUERIES = [
    # ---- Dutch queries ----------------------------------------------------
    {
        "question": "Hoeveel musea zijn er in Amsterdam?",
        "language": "nl",
        "expected_intent": "statistical",
        "expected_entities": ["amsterdam", "musea"],
        "description": "Count museums in Amsterdam",
    },
    {
        "question": "Waar is het Rijksmuseum gevestigd?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["rijksmuseum"],
        "description": "Location of Rijksmuseum",
    },
    {
        "question": "Welke archieven zijn gefuseerd in Noord-Holland sinds 2000?",
        "language": "nl",
        "expected_intent": "temporal",
        "expected_entities": ["archieven", "noord-holland", "2000"],
        "description": "Archive mergers in Noord-Holland",
    },
    {
        "question": "Toon erfgoedinstellingen in de buurt van Rotterdam Centraal",
        "language": "nl",
        "expected_intent": "geographic",
        "expected_entities": ["rotterdam centraal"],
        "description": "Heritage institutions near Rotterdam",
    },
    # ---- English queries --------------------------------------------------
    {
        "question": "How many libraries are there in the Netherlands?",
        "language": "en",
        "expected_intent": "statistical",
        "expected_entities": ["libraries", "netherlands"],
        "description": "Count libraries in NL",
    },
    {
        "question": "Show me archives related to World War II",
        "language": "en",
        "expected_intent": "exploration",
        "expected_entities": ["archives", "world war ii"],
        "description": "WWII archives exploration",
    },
    {
        "question": "Compare the collections of Rijksmuseum and Van Gogh Museum",
        "language": "en",
        "expected_intent": "comparative",
        "expected_entities": ["rijksmuseum", "van gogh museum"],
        "description": "Compare two major museums",
    },
    {
        "question": "What institutions are part of the Erfgoed Leiden network?",
        "language": "en",
        "expected_intent": "relational",
        "expected_entities": ["erfgoed leiden"],
        "description": "Network membership query",
    },
    {
        "question": "When was the Nationaal Archief founded?",
        "language": "en",
        "expected_intent": "temporal",
        "expected_entities": ["nationaal archief"],
        "description": "Founding date query",
    },
    {
        "question": "Find galleries in the Randstad region",
        "language": "en",
        "expected_intent": "geographic",
        "expected_entities": ["galleries", "randstad"],
        "expected_entity_type": "institution",
        "description": "Geographic gallery search",
    },
    # ---- Person queries (Dutch) -------------------------------------------
    {
        "question": "Wie werkt bij het Nationaal Archief?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["nationaal archief"],
        "expected_entity_type": "person",
        "description": "Staff at Nationaal Archief (Dutch)",
    },
    {
        "question": "Welke curatoren werken bij het Rijksmuseum?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["rijksmuseum", "curator"],
        "expected_entity_type": "person",
        "description": "Curators at Rijksmuseum (Dutch)",
    },
    # ---- Person queries (English) -----------------------------------------
    {
        "question": "Who works at the Eye Filmmuseum?",
        "language": "en",
        "expected_intent": "entity_lookup",
        "expected_entities": ["eye filmmuseum"],
        "expected_entity_type": "person",
        "description": "Staff at Eye Filmmuseum (English)",
    },
    {
        "question": "Show me the director of the Amsterdam Museum",
        "language": "en",
        "expected_intent": "entity_lookup",
        "expected_entities": ["amsterdam museum", "director"],
        "expected_entity_type": "person",
        "description": "Director of Amsterdam Museum",
    },
]
@dataclass
class TestResult:
"""Result of a single test."""
query: str
description: str
success: bool
duration_ms: float
response: dict[str, Any] | None = None
error: str | None = None
# =============================================================================
# MODULE TESTS - Direct DSPy module testing
# =============================================================================
async def test_dspy_module() -> list[TestResult]:
    """Exercise the DSPy Heritage RAG router and pipeline in-process.

    Every entry of TEST_QUERIES is sent through ``HeritageQueryRouter`` and
    the pipeline built by ``create_heritage_rag_pipeline``; both calls are
    timed together. A query succeeds when the pipeline result has a non-empty
    ``answer`` plus ``intent`` and ``confidence`` attributes. Agreement of the
    routed intent / entity type with the expectations is recorded and printed
    but does not influence ``success``.

    Returns:
        One TestResult per query. The list is empty when ANTHROPIC_API_KEY is
        unset or when an ImportError escapes the setup phase.
    """
    results = []
    rule = "=" * 60
    print("\n" + rule)
    print("TESTING DSPy HERITAGE RAG MODULE DIRECTLY")
    print(rule)

    try:
        # Imports are local so that the API tests still work without DSPy.
        import dspy
        from dspy_heritage_rag import (
            HeritageRAGPipeline,
            HeritageQueryRouter,
            create_heritage_rag_pipeline,
            create_heritage_training_data,
        )
        import os

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            print("WARNING: ANTHROPIC_API_KEY not set, skipping module tests")
            return results

        # Point DSPy at Claude before any module is instantiated.
        dspy.configure(
            lm=dspy.LM(model="anthropic/claude-sonnet-4-20250514", api_key=api_key)
        )
        print("\nDSPy configured with Claude Sonnet 4")

        pipeline = create_heritage_rag_pipeline(use_tools=False)
        router = HeritageQueryRouter()
        total = len(TEST_QUERIES)
        print(f"Pipeline created, testing {total} queries...\n")

        for position, case in enumerate(TEST_QUERIES, 1):
            print(f"[{position}/{total}] {case['description']}...")
            started = time.time()
            try:
                # Router first, then the complete pipeline on the same input.
                routing = router(
                    question=case["question"],
                    language=case["language"],
                )
                result = pipeline(
                    question=case["question"],
                    language=case["language"],
                    include_viz=True,
                    use_agent=False,
                )
                elapsed = time.time() - started
                duration_ms = elapsed * 1000

                # Structural validation of the pipeline output.
                has_answer = bool(hasattr(result, "answer") and result.answer)
                has_intent = hasattr(result, "intent")
                has_confidence = hasattr(result, "confidence")
                success = has_answer and has_intent and has_confidence

                intent_match = routing.intent == case["expected_intent"]

                # Routers without an entity_type attribute count as "institution",
                # as do test cases that do not state an expectation.
                entity_type = getattr(routing, 'entity_type', 'institution')
                wanted_entity_type = case.get("expected_entity_type", "institution")
                entity_type_match = entity_type == wanted_entity_type

                summary = {
                    "intent": routing.intent,
                    "intent_match": intent_match,
                    "entity_type": entity_type,
                    "entity_type_match": entity_type_match,
                    "entities": routing.entities,
                    "answer_preview": result.answer[:100] if result.answer else None,
                    "confidence": result.confidence,
                    "sources": result.sources_used,
                }
                results.append(
                    TestResult(
                        query=case["question"],
                        description=case["description"],
                        success=success,
                        duration_ms=duration_ms,
                        response=summary,
                    )
                )

                if success:
                    status = "PASS"
                else:
                    status = "FAIL"
                if intent_match:
                    intent_status = "OK"
                else:
                    intent_status = f"MISMATCH (got {routing.intent})"
                if entity_type_match:
                    entity_type_status = "OK"
                else:
                    entity_type_status = f"MISMATCH (got {entity_type})"
                print(f" [{status}] {duration_ms:.0f}ms - Intent: {intent_status}, Entity Type: {entity_type_status}")
            except Exception as e:
                # Any failure of a single query is recorded, not propagated.
                duration_ms = (time.time() - started) * 1000
                results.append(
                    TestResult(
                        query=case["question"],
                        description=case["description"],
                        success=False,
                        duration_ms=duration_ms,
                        error=str(e),
                    )
                )
                print(f" [ERROR] {e}")
    except ImportError as e:
        print(f"ERROR: Could not import DSPy module: {e}")
        print("Make sure you're running from the backend/rag directory")

    return results
async def test_training_data() -> None:
    """Print a summary of the GEPA training data and check its fields.

    Loads the train/validation sets from ``create_heritage_training_data``,
    shows their sizes and the first training example, then verifies that every
    example exposes the attributes ``question``, ``language``,
    ``expected_intent``, ``expected_entities`` and ``expected_sources``.
    Any exception is printed rather than raised; nothing is returned.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TESTING GEPA TRAINING DATA")
    print(rule)

    required_fields = (
        "question",
        "language",
        "expected_intent",
        "expected_entities",
        "expected_sources",
    )

    try:
        from dspy_heritage_rag import create_heritage_training_data

        trainset, valset = create_heritage_training_data()
        print(f"\nTraining examples: {len(trainset)}")
        print(f"Validation examples: {len(valset)}")

        print("\nSample training example:")
        sample = trainset[0]
        print(f" Question: {sample.question}")
        print(f" Language: {sample.language}")
        print(f" Expected intent: {sample.expected_intent}")
        print(f" Expected entities: {sample.expected_entities}")
        print(f" Expected sources: {sample.expected_sources}")

        # Report each incomplete example as it is encountered.
        all_valid = True
        for example in trainset + valset:
            if all(hasattr(example, field) for field in required_fields):
                continue
            all_valid = False
            print(f" INVALID: {example}")

        if all_valid:
            status = "PASS"
        else:
            status = "FAIL"
        print(f"\n[{status}] All training examples valid")
    except Exception as e:
        print(f"ERROR: {e}")
# =============================================================================
# API TESTS - Test endpoints via HTTP
# =============================================================================
async def test_api_health(client: httpx.AsyncClient) -> bool:
    """Query /api/dspy/rag/health and print the reported component flags.

    Args:
        client: httpx async client used for the GET request (the path is
            relative, so the client is expected to carry a base URL).

    Returns:
        True only for an HTTP 200 reply whose JSON ``status`` equals "ok";
        False otherwise, including when the request or decoding raises.
    """
    print("\n--- Testing /api/dspy/rag/health ---")
    try:
        response = await client.get("/api/dspy/rag/health")
        payload = response.json()
        print(f"Status: {response.status_code}")

        components = payload.get('components', {})
        flag_labels = (
            ("DSPy available", 'dspy_available'),
            ("Pipeline initialized", 'pipeline_initialized'),
            ("GEPA optimized", 'gepa_optimized'),
        )
        for label, key in flag_labels:
            print(f"{label}: {components.get(key)}")

        if response.status_code != 200:
            return False
        return payload.get("status") == "ok"
    except Exception as e:
        print(f"ERROR: {e}")
        return False
async def test_api_training_data(client: httpx.AsyncClient) -> bool:
    """Query /api/dspy/rag/training-data and print the example counts.

    Args:
        client: httpx async client used for the GET request.

    Returns:
        True when the reply is HTTP 200 and ``total_training`` is greater
        than zero; False otherwise, including when anything raises.
    """
    print("\n--- Testing /api/dspy/rag/training-data ---")
    try:
        response = await client.get("/api/dspy/rag/training-data")
        payload = response.json()
        print(f"Status: {response.status_code}")
        print(f"Training examples: {payload.get('total_training')}")
        print(f"Validation examples: {payload.get('total_validation')}")

        if response.status_code != 200:
            return False
        return payload.get("total_training", 0) > 0
    except Exception as e:
        print(f"ERROR: {e}")
        return False
async def test_api_query(client: httpx.AsyncClient, test: dict) -> TestResult:
    """POST one test query to /api/dspy/rag/query and grade the reply.

    Args:
        client: httpx async client used for the request (the path is relative,
            so the client is expected to carry a base URL).
        test: A test case shaped like the entries of TEST_QUERIES; the keys
            "question", "language", "expected_intent" and "description" are
            read.

    Returns:
        A TestResult. ``success`` is True when the JSON reply has a truthy
        ``answer`` and ``intent`` and a non-null ``confidence``. Non-200
        replies and raised exceptions yield ``success=False`` with ``error``
        set; this function itself does not raise for request failures.
    """
    start = time.time()
    try:
        response = await client.post(
            "/api/dspy/rag/query",
            json={
                "question": test["question"],
                "language": test["language"],
                "include_visualization": True,
                "use_agent": False,
            },
            timeout=60.0,
        )
        duration_ms = (time.time() - start) * 1000

        if response.status_code != 200:
            return TestResult(
                query=test["question"],
                description=test["description"],
                success=False,
                duration_ms=duration_ms,
                error=f"HTTP {response.status_code}: {response.text}",
            )

        data = response.json()
        success = all([
            data.get("answer"),
            data.get("intent"),
            data.get("confidence") is not None,
        ])

        # A key that is present with JSON null comes back as None, which the
        # default argument of dict.get() does not cover. Fall back explicitly
        # so that a null "visualization"/"answer" cannot turn an otherwise
        # gradable reply into an AttributeError/TypeError result.
        answer = data.get("answer") or ""
        visualization = data.get("visualization") or {}

        return TestResult(
            query=test["question"],
            description=test["description"],
            success=success,
            duration_ms=duration_ms,
            response={
                "intent": data.get("intent"),
                "intent_match": data.get("intent") == test["expected_intent"],
                "answer_preview": answer[:100],
                "confidence": data.get("confidence"),
                "sources": data.get("sources_used"),
                "visualization": visualization.get("type"),
            },
        )
    except Exception as e:
        # Connection errors, timeouts and malformed JSON all land here.
        duration_ms = (time.time() - start) * 1000
        return TestResult(
            query=test["question"],
            description=test["description"],
            success=False,
            duration_ms=duration_ms,
            error=str(e),
        )
async def test_api_stream(client: httpx.AsyncClient, test: dict) -> TestResult:
    """POST one test query to /api/dspy/rag/query/stream and grade the stream.

    Each non-blank line of the response body is parsed as a JSON object.

    Args:
        client: httpx async client used for the streaming request.
        test: A test case shaped like the entries of TEST_QUERIES; the keys
            "question", "language" and "description" are read.

    Returns:
        A TestResult. ``success`` is True when a chunk with ``type ==
        "complete"`` was received and more than two chunks arrived in total.
        Non-200 replies, incomplete streams and raised exceptions yield
        ``success=False`` with ``error`` describing the cause.
    """
    start = time.time()
    chunks = []
    try:
        async with client.stream(
            "POST",
            "/api/dspy/rag/query/stream",
            json={
                "question": test["question"],
                "language": test["language"],
                "include_visualization": True,
            },
            timeout=60.0,
        ) as response:
            if response.status_code != 200:
                # Report HTTP failures the same way test_api_query does. The
                # body of a streamed response must be read explicitly before
                # .text is available.
                await response.aread()
                return TestResult(
                    query=test["question"],
                    description=test["description"],
                    success=False,
                    duration_ms=(time.time() - start) * 1000,
                    error=f"HTTP {response.status_code}: {response.text}",
                )
            async for line in response.aiter_lines():
                if line.strip():
                    chunks.append(json.loads(line))
        duration_ms = (time.time() - start) * 1000

        # Locate the chunks that carry the routing decision and final answer.
        complete = next((c for c in chunks if c.get("type") == "complete"), None)
        routing = next((c for c in chunks if c.get("type") == "routing"), None)
        chunk_types = [c.get("type") for c in chunks]
        success = complete is not None and len(chunks) > 2

        # Give failed results a reason instead of leaving error empty.
        error = None
        if complete is None:
            error = f"Stream ended without a 'complete' chunk (chunk types: {chunk_types})"
        elif not success:
            error = f"Expected more than 2 chunks, got {len(chunks)}"

        # "answer" may be present but null; avoid slicing None.
        answer_preview = None
        if complete:
            answer_preview = (complete.get("answer") or "")[:100]

        return TestResult(
            query=test["question"],
            description=test["description"],
            success=success,
            duration_ms=duration_ms,
            response={
                "chunk_count": len(chunks),
                "chunk_types": chunk_types,
                "intent": routing.get("intent") if routing else None,
                "answer_preview": answer_preview,
            },
            error=error,
        )
    except Exception as e:
        # Connection errors, timeouts and undecodable lines all land here.
        duration_ms = (time.time() - start) * 1000
        return TestResult(
            query=test["question"],
            description=test["description"],
            success=False,
            duration_ms=duration_ms,
            error=str(e),
        )
async def test_api_endpoints(base_url: str = "http://localhost:8000") -> list[TestResult]:
    """Run the health, training-data and query checks against the HTTP API.

    Args:
        base_url: Root URL of the server under test.

    Returns:
        One TestResult per entry of TEST_QUERIES, from the query endpoint.
        The health and training-data checks only print; a failed health check
        produces a warning but does not stop the run.
    """
    results = []
    rule = "=" * 60
    print("\n" + rule)
    print(f"TESTING DSPy RAG API ENDPOINTS ({base_url})")
    print(rule)

    async with httpx.AsyncClient(base_url=base_url) as client:
        # Health first; keep going even when it fails.
        if not await test_api_health(client):
            print("\nWARNING: DSPy RAG not available, some tests may fail")

        await test_api_training_data(client)

        total = len(TEST_QUERIES)
        print(f"\n--- Testing /api/dspy/rag/query ({total} queries) ---")
        for position, case in enumerate(TEST_QUERIES, 1):
            print(f"[{position}/{total}] {case['description']}...")
            outcome = await test_api_query(client, case)
            results.append(outcome)

            if not outcome.success:
                print(f" [FAIL] {outcome.error or 'Unknown error'}")
                continue

            if outcome.response:
                intent_match = outcome.response.get("intent_match", False)
            else:
                intent_match = False
            intent_status = "OK" if intent_match else "MISMATCH"
            print(f" [PASS] {outcome.duration_ms:.0f}ms - Intent: {intent_status}")

    return results
async def test_streaming_endpoint(base_url: str = "http://localhost:8000") -> list[TestResult]:
    """Run the streaming check for the first three entries of TEST_QUERIES.

    Args:
        base_url: Root URL of the server under test.

    Returns:
        One TestResult per streamed query.
    """
    results = []
    rule = "=" * 60
    print("\n" + rule)
    print(f"TESTING DSPy RAG STREAMING ENDPOINT ({base_url})")
    print(rule)

    # Only a subset of the queries is streamed.
    stream_tests = TEST_QUERIES[:3]
    total = len(stream_tests)

    async with httpx.AsyncClient(base_url=base_url) as client:
        for position, case in enumerate(stream_tests, 1):
            print(f"\n[{position}/{total}] {case['description']}...")
            outcome = await test_api_stream(client, case)
            results.append(outcome)

            if not outcome.success:
                print(f" [FAIL] {outcome.error or 'Unknown error'}")
                continue

            summary = outcome.response
            chunk_count = summary.get("chunk_count", 0) if summary else 0
            print(f" [PASS] {outcome.duration_ms:.0f}ms - Received {chunk_count} chunks")
            if summary:
                print(f" Chunk types: {summary.get('chunk_types')}")

    return results
# =============================================================================
# MAIN
# =============================================================================
def print_summary(results: list[TestResult], title: str) -> None:
    """Print pass/fail counts, success rate and mean duration for a run.

    Args:
        results: Results to summarise; nothing is printed when it is empty.
        title: Heading shown after "SUMMARY: ".
    """
    if not results:
        return

    total = len(results)
    failures = [r for r in results if not r.success]
    failed = len(failures)
    passed = total - failed
    avg_duration = sum(r.duration_ms for r in results) / total

    rule = "=" * 60
    print("\n" + rule)
    print(f"SUMMARY: {title}")
    print(rule)
    print(f"Total tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {failed}")
    print(f"Success rate: {passed/total*100:.1f}%")
    print(f"Average duration: {avg_duration:.0f}ms")

    if failures:
        print("\nFailed tests:")
        for r in failures:
            print(f" - {r.description}: {r.error or 'Unknown error'}")
async def main():
    """Parse the command line, run the selected suites and exit.

    With no selection flag, all suites run. The process exits with status 0
    when every collected TestResult succeeded (or none was collected) and
    with status 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Test DSPy Heritage RAG")
    suite_flags = (
        ("--module", "Test DSPy module directly"),
        ("--api", "Test API endpoints"),
        ("--stream", "Test streaming endpoint"),
        ("--all", "Run all tests"),
    )
    for flag, help_text in suite_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("--url", default="http://localhost:8000", help="API base URL")
    args = parser.parse_args()

    # Running without any selection flag means "run everything".
    if not (args.module or args.api or args.stream or args.all):
        args.all = True

    all_results = []

    if args.all or args.module:
        module_results = await test_dspy_module()
        all_results.extend(module_results)
        print_summary(module_results, "DSPy Module Tests")
        await test_training_data()

    if args.all or args.api:
        api_results = await test_api_endpoints(args.url)
        all_results.extend(api_results)
        print_summary(api_results, "API Endpoint Tests")

    if args.all or args.stream:
        stream_results = await test_streaming_endpoint(args.url)
        all_results.extend(stream_results)
        print_summary(stream_results, "Streaming Tests")

    # Combined summary over everything that ran.
    if all_results:
        print_summary(all_results, "ALL TESTS")

    # The exit status reflects whether every collected result succeeded.
    everything_passed = all(r.success for r in all_results)
    sys.exit(0 if everything_passed else 1)


if __name__ == "__main__":
    asyncio.run(main())