#!/usr/bin/env python3
"""
Test script for DSPy Heritage RAG endpoints.

Tests both the DSPy module directly and the API endpoints.

Usage:
    # Test DSPy module directly (no server required)
    python test_dspy_rag.py --module

    # Test API endpoints (requires server running on localhost:8000)
    python test_dspy_rag.py --api

    # Run all suites: module, API and streaming (default when no flag is given)
    python test_dspy_rag.py --all

    # Test streaming endpoint
    python test_dspy_rag.py --stream

    # Point the API/streaming tests at another server
    python test_dspy_rag.py --api --url http://host:port
"""

import argparse
import asyncio
import json
import sys
import time
from dataclasses import dataclass
from typing import Any

import httpx

# =============================================================================
# TEST QUERIES - Real heritage institution questions
# =============================================================================

# Each entry holds the question, its language, the routing intent and entities
# we expect, and an optional expected_entity_type (the module test treats a
# missing value as "institution").
TEST_QUERIES = [
    # Dutch queries
    {
        "question": "Hoeveel musea zijn er in Amsterdam?",
        "language": "nl",
        "expected_intent": "statistical",
        "expected_entities": ["amsterdam", "musea"],
        "description": "Count museums in Amsterdam",
    },
    {
        "question": "Waar is het Rijksmuseum gevestigd?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["rijksmuseum"],
        "description": "Location of Rijksmuseum",
    },
    {
        "question": "Welke archieven zijn gefuseerd in Noord-Holland sinds 2000?",
        "language": "nl",
        "expected_intent": "temporal",
        "expected_entities": ["archieven", "noord-holland", "2000"],
        "description": "Archive mergers in Noord-Holland",
    },
    {
        "question": "Toon erfgoedinstellingen in de buurt van Rotterdam Centraal",
        "language": "nl",
        "expected_intent": "geographic",
        "expected_entities": ["rotterdam centraal"],
        "description": "Heritage institutions near Rotterdam",
    },
    # English queries
    {
        "question": "How many libraries are there in the Netherlands?",
        "language": "en",
        "expected_intent": "statistical",
        "expected_entities": ["libraries", "netherlands"],
        "description": "Count libraries in NL",
    },
    {
        "question": "Show me archives related to World War II",
        "language": "en",
        "expected_intent": "exploration",
        "expected_entities": ["archives", "world war ii"],
        "description": "WWII archives exploration",
    },
    {
        "question": "Compare the collections of Rijksmuseum and Van Gogh Museum",
        "language": "en",
        "expected_intent": "comparative",
        "expected_entities": ["rijksmuseum", "van gogh museum"],
        "description": "Compare two major museums",
    },
    {
        "question": "What institutions are part of the Erfgoed Leiden network?",
        "language": "en",
        "expected_intent": "relational",
        "expected_entities": ["erfgoed leiden"],
        "description": "Network membership query",
    },
    {
        "question": "When was the Nationaal Archief founded?",
        "language": "en",
        "expected_intent": "temporal",
        "expected_entities": ["nationaal archief"],
        "description": "Founding date query",
    },
    {
        "question": "Find galleries in the Randstad region",
        "language": "en",
        "expected_intent": "geographic",
        "expected_entities": ["galleries", "randstad"],
        "expected_entity_type": "institution",
        "description": "Geographic gallery search",
    },
    # Person queries - Dutch
    {
        "question": "Wie werkt bij het Nationaal Archief?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["nationaal archief"],
        "expected_entity_type": "person",
        "description": "Staff at Nationaal Archief (Dutch)",
    },
    {
        "question": "Welke curatoren werken bij het Rijksmuseum?",
        "language": "nl",
        "expected_intent": "entity_lookup",
        "expected_entities": ["rijksmuseum", "curator"],
        "expected_entity_type": "person",
        "description": "Curators at Rijksmuseum (Dutch)",
    },
    # Person queries - English
    {
        "question": "Who works at the Eye Filmmuseum?",
        "language": "en",
        "expected_intent": "entity_lookup",
        "expected_entities": ["eye filmmuseum"],
        "expected_entity_type": "person",
        "description": "Staff at Eye Filmmuseum (English)",
    },
    {
        "question": "Show me the director of the Amsterdam Museum",
        "language": "en",
        "expected_intent": "entity_lookup",
        "expected_entities": ["amsterdam museum", "director"],
        "expected_entity_type": "person",
        "description": "Director of Amsterdam Museum",
    },
]


@dataclass
class TestResult:
    """Result of a single test.

    Attributes:
        query: The question that was sent.
        description: Human-readable label from TEST_QUERIES.
        success: Whether the test's pass criteria were met.
        duration_ms: Elapsed time for the test, in milliseconds.
        response: Selected fields extracted from the response, if any.
        error: Error message when the test raised or got a non-200 reply.
    """

    query: str
    description: str
    success: bool
    duration_ms: float
    response: dict[str, Any] | None = None
    error: str | None = None


def _elapsed_ms(start: float) -> float:
    """Return milliseconds elapsed since *start* (a time.perf_counter() value)."""
    return (time.perf_counter() - start) * 1000


# =============================================================================
# MODULE TESTS - Direct DSPy module testing
# =============================================================================


async def test_dspy_module() -> list[TestResult]:
    """Test DSPy Heritage RAG module directly.

    Configures DSPy with Claude (requires ANTHROPIC_API_KEY) and runs every
    TEST_QUERIES entry through both the query router and the full pipeline.

    Returns:
        One TestResult per query; an empty list if the API key is missing or
        the dspy / dspy_heritage_rag modules cannot be imported.
    """
    results: list[TestResult] = []

    print("\n" + "=" * 60)
    print("TESTING DSPy HERITAGE RAG MODULE DIRECTLY")
    print("=" * 60)

    try:
        # Import and configure DSPy
        import dspy
        from dspy_heritage_rag import (
            HeritageRAGPipeline,
            HeritageQueryRouter,
            create_heritage_rag_pipeline,
            create_heritage_training_data,
        )

        # Configure DSPy with Claude
        import os

        api_key = os.environ.get("ANTHROPIC_API_KEY")
        if not api_key:
            print("WARNING: ANTHROPIC_API_KEY not set, skipping module tests")
            return results

        lm = dspy.LM(model="anthropic/claude-sonnet-4-20250514", api_key=api_key)
        dspy.configure(lm=lm)
        print("\nDSPy configured with Claude Sonnet 4")

        # Create pipeline
        pipeline = create_heritage_rag_pipeline(use_tools=False)
        router = HeritageQueryRouter()

        print(f"Pipeline created, testing {len(TEST_QUERIES)} queries...\n")

        for i, test in enumerate(TEST_QUERIES, 1):
            print(f"[{i}/{len(TEST_QUERIES)}] {test['description']}...")
            start = time.perf_counter()

            try:
                # Test router first
                routing = router(
                    question=test["question"],
                    language=test["language"],
                )

                # Test full pipeline
                result = pipeline(
                    question=test["question"],
                    language=test["language"],
                    include_viz=True,
                    use_agent=False,
                )

                duration_ms = _elapsed_ms(start)

                # Pass/fail only checks that the pipeline produced an answer and
                # exposes intent/confidence; intent and entity-type mismatches
                # are reported below but do not fail the test.
                success = all([
                    hasattr(result, "answer") and result.answer,
                    hasattr(result, "intent"),
                    hasattr(result, "confidence"),
                ])

                # Check intent match
                intent_match = routing.intent == test["expected_intent"]

                # Check entity_type match (if expected_entity_type provided)
                entity_type = getattr(routing, 'entity_type', 'institution')
                expected_entity_type = test.get("expected_entity_type", "institution")
                entity_type_match = entity_type == expected_entity_type

                results.append(TestResult(
                    query=test["question"],
                    description=test["description"],
                    success=success,
                    duration_ms=duration_ms,
                    response={
                        "intent": routing.intent,
                        "intent_match": intent_match,
                        "entity_type": entity_type,
                        "entity_type_match": entity_type_match,
                        "entities": routing.entities,
                        "answer_preview": result.answer[:100] if result.answer else None,
                        "confidence": result.confidence,
                        "sources": result.sources_used,
                    },
                ))

                status = "PASS" if success else "FAIL"
                intent_status = "OK" if intent_match else f"MISMATCH (got {routing.intent})"
                entity_type_status = "OK" if entity_type_match else f"MISMATCH (got {entity_type})"
                print(f" [{status}] {duration_ms:.0f}ms - Intent: {intent_status}, Entity Type: {entity_type_status}")

            except Exception as e:
                duration_ms = _elapsed_ms(start)
                results.append(TestResult(
                    query=test["question"],
                    description=test["description"],
                    success=False,
                    duration_ms=duration_ms,
                    error=str(e),
                ))
                print(f" [ERROR] {e}")

    except ImportError as e:
        print(f"ERROR: Could not import DSPy module: {e}")
        print("Make sure you're running from the backend/rag directory")

    return results


async def test_training_data() -> None:
    """Test that training data is valid.

    Loads the train/validation split from dspy_heritage_rag, prints a sample,
    and checks every example carries the fields the pipeline is trained on.
    Results are printed only; they do not affect the process exit code.
    """
    print("\n" + "=" * 60)
    print("TESTING GEPA TRAINING DATA")
    print("=" * 60)

    required_fields = (
        "question",
        "language",
        "expected_intent",
        "expected_entities",
        "expected_sources",
    )

    try:
        from dspy_heritage_rag import create_heritage_training_data

        trainset, valset = create_heritage_training_data()

        print(f"\nTraining examples: {len(trainset)}")
        print(f"Validation examples: {len(valset)}")

        # Guard the sample print below; an empty set previously surfaced as
        # an opaque "list index out of range" error.
        if not trainset:
            print("\n[FAIL] Training set is empty")
            return

        print("\nSample training example:")
        sample = trainset[0]
        print(f" Question: {sample.question}")
        print(f" Language: {sample.language}")
        print(f" Expected intent: {sample.expected_intent}")
        print(f" Expected entities: {sample.expected_entities}")
        print(f" Expected sources: {sample.expected_sources}")

        # Validate all examples have required fields
        all_valid = True
        for ex in list(trainset) + list(valset):
            if not all(hasattr(ex, field) for field in required_fields):
                all_valid = False
                print(f" INVALID: {ex}")

        status = "PASS" if all_valid else "FAIL"
        print(f"\n[{status}] All training examples valid")

    except Exception as e:
        print(f"ERROR: {e}")


# =============================================================================
# API TESTS - Test endpoints via HTTP
# =============================================================================


async def test_api_health(client: httpx.AsyncClient) -> bool:
    """Test DSPy RAG health endpoint.

    Returns:
        True when the endpoint answers HTTP 200 with ``status == "ok"``.
    """
    print("\n--- Testing /api/dspy/rag/health ---")

    try:
        response = await client.get("/api/dspy/rag/health")
        data = response.json()

        components = data.get("components", {})
        print(f"Status: {response.status_code}")
        print(f"DSPy available: {components.get('dspy_available')}")
        print(f"Pipeline initialized: {components.get('pipeline_initialized')}")
        print(f"GEPA optimized: {components.get('gepa_optimized')}")

        return response.status_code == 200 and data.get("status") == "ok"

    except Exception as e:
        print(f"ERROR: {e}")
        return False


async def test_api_training_data(client: httpx.AsyncClient) -> bool:
    """Test training data endpoint.

    Returns:
        True when the endpoint answers HTTP 200 with a positive
        ``total_training`` count.
    """
    print("\n--- Testing /api/dspy/rag/training-data ---")

    try:
        response = await client.get("/api/dspy/rag/training-data")
        data = response.json()

        print(f"Status: {response.status_code}")
        print(f"Training examples: {data.get('total_training')}")
        print(f"Validation examples: {data.get('total_validation')}")

        return response.status_code == 200 and data.get("total_training", 0) > 0

    except Exception as e:
        print(f"ERROR: {e}")
        return False


async def test_api_query(client: httpx.AsyncClient, test: dict) -> TestResult:
    """Test DSPy RAG query endpoint.

    Args:
        client: HTTP client whose base_url points at the API server.
        test: One TEST_QUERIES entry.

    Returns:
        A TestResult that passes when the reply has a non-empty answer, an
        intent, and a non-null confidence.
    """
    start = time.perf_counter()

    try:
        response = await client.post(
            "/api/dspy/rag/query",
            json={
                "question": test["question"],
                "language": test["language"],
                "include_visualization": True,
                "use_agent": False,
            },
            timeout=60.0,
        )

        duration_ms = _elapsed_ms(start)

        if response.status_code != 200:
            return TestResult(
                query=test["question"],
                description=test["description"],
                success=False,
                duration_ms=duration_ms,
                error=f"HTTP {response.status_code}: {response.text}",
            )

        data = response.json()

        # `or` (rather than a .get default) also covers keys present with a
        # JSON null value, which previously raised and turned a valid reply
        # into a spurious failure.
        answer = data.get("answer") or ""
        visualization = data.get("visualization") or {}

        success = all([
            answer,
            data.get("intent"),
            data.get("confidence") is not None,
        ])

        return TestResult(
            query=test["question"],
            description=test["description"],
            success=success,
            duration_ms=duration_ms,
            response={
                "intent": data.get("intent"),
                "intent_match": data.get("intent") == test["expected_intent"],
                "answer_preview": answer[:100],
                "confidence": data.get("confidence"),
                "sources": data.get("sources_used"),
                "visualization": (
                    visualization.get("type") if isinstance(visualization, dict) else None
                ),
            },
        )

    except Exception as e:
        duration_ms = _elapsed_ms(start)
        return TestResult(
            query=test["question"],
            description=test["description"],
            success=False,
            duration_ms=duration_ms,
            error=str(e),
        )


async def test_api_stream(client: httpx.AsyncClient, test: dict) -> TestResult:
    """Test DSPy RAG streaming endpoint.

    The reply body is read line by line and each non-blank line is parsed as
    one JSON chunk. The test passes when a ``"complete"`` chunk arrives and
    more than two chunks were received in total.

    Args:
        client: HTTP client whose base_url points at the API server.
        test: One TEST_QUERIES entry.
    """
    start = time.perf_counter()
    chunks: list[dict[str, Any]] = []

    try:
        async with client.stream(
            "POST",
            "/api/dspy/rag/query/stream",
            json={
                "question": test["question"],
                "language": test["language"],
                "include_visualization": True,
            },
            timeout=60.0,
        ) as response:
            # Report HTTP errors explicitly (as test_api_query does) instead of
            # letting them surface as a missing "complete" chunk.
            if response.status_code != 200:
                await response.aread()
                return TestResult(
                    query=test["question"],
                    description=test["description"],
                    success=False,
                    duration_ms=_elapsed_ms(start),
                    error=f"HTTP {response.status_code}: {response.text}",
                )

            async for line in response.aiter_lines():
                if line.strip():
                    chunks.append(json.loads(line))

        duration_ms = _elapsed_ms(start)

        # Find complete message
        complete = next((c for c in chunks if c.get("type") == "complete"), None)
        routing = next((c for c in chunks if c.get("type") == "routing"), None)

        success = complete is not None and len(chunks) > 2

        return TestResult(
            query=test["question"],
            description=test["description"],
            success=success,
            duration_ms=duration_ms,
            response={
                "chunk_count": len(chunks),
                "chunk_types": [c.get("type") for c in chunks],
                "intent": routing.get("intent") if routing else None,
                # `or ""` also covers an explicit null answer.
                "answer_preview": (complete.get("answer") or "")[:100] if complete else None,
            },
        )

    except Exception as e:
        duration_ms = _elapsed_ms(start)
        return TestResult(
            query=test["question"],
            description=test["description"],
            success=False,
            duration_ms=duration_ms,
            error=str(e),
        )


async def test_api_endpoints(base_url: str = "http://localhost:8000") -> list[TestResult]:
    """Test all API endpoints.

    Runs the health and training-data checks (printed only) and then every
    TEST_QUERIES entry against the query endpoint.

    Returns:
        One TestResult per query.
    """
    results: list[TestResult] = []

    print("\n" + "=" * 60)
    print(f"TESTING DSPy RAG API ENDPOINTS ({base_url})")
    print("=" * 60)

    async with httpx.AsyncClient(base_url=base_url) as client:
        # Test health endpoint
        health_ok = await test_api_health(client)
        if not health_ok:
            print("\nWARNING: DSPy RAG not available, some tests may fail")

        # Test training data endpoint
        await test_api_training_data(client)

        # Test query endpoint with all test queries
        print(f"\n--- Testing /api/dspy/rag/query ({len(TEST_QUERIES)} queries) ---")

        for i, test in enumerate(TEST_QUERIES, 1):
            print(f"[{i}/{len(TEST_QUERIES)}] {test['description']}...")
            result = await test_api_query(client, test)
            results.append(result)

            if result.success:
                intent_match = result.response.get("intent_match", False) if result.response else False
                intent_status = "OK" if intent_match else "MISMATCH"
                print(f" [PASS] {result.duration_ms:.0f}ms - Intent: {intent_status}")
            else:
                print(f" [FAIL] {result.error or 'Unknown error'}")

    return results


async def test_streaming_endpoint(base_url: str = "http://localhost:8000") -> list[TestResult]:
    """Test streaming endpoint specifically.

    Only the first three TEST_QUERIES entries are streamed.

    Returns:
        One TestResult per streamed query.
    """
    results: list[TestResult] = []

    print("\n" + "=" * 60)
    print(f"TESTING DSPy RAG STREAMING ENDPOINT ({base_url})")
    print("=" * 60)

    # Test a subset of queries for streaming
    stream_tests = TEST_QUERIES[:3]

    async with httpx.AsyncClient(base_url=base_url) as client:
        for i, test in enumerate(stream_tests, 1):
            print(f"\n[{i}/{len(stream_tests)}] {test['description']}...")
            result = await test_api_stream(client, test)
            results.append(result)

            if result.success:
                chunk_count = result.response.get("chunk_count", 0) if result.response else 0
                print(f" [PASS] {result.duration_ms:.0f}ms - Received {chunk_count} chunks")
                if result.response:
                    print(f" Chunk types: {result.response.get('chunk_types')}")
            else:
                print(f" [FAIL] {result.error or 'Unknown error'}")

    return results


# =============================================================================
# MAIN
# =============================================================================


def print_summary(results: list[TestResult], title: str) -> None:
    """Print test summary.

    Prints totals, pass/fail counts, success rate, average duration, and the
    error of each failed test. Prints nothing when *results* is empty.
    """
    if not results:
        return

    passed = sum(1 for r in results if r.success)
    failed = len(results) - passed
    avg_duration = sum(r.duration_ms for r in results) / len(results)

    print("\n" + "=" * 60)
    print(f"SUMMARY: {title}")
    print("=" * 60)
    print(f"Total tests: {len(results)}")
    print(f"Passed: {passed}")
    print(f"Failed: {failed}")
    print(f"Success rate: {passed/len(results)*100:.1f}%")
    print(f"Average duration: {avg_duration:.0f}ms")

    if failed > 0:
        print("\nFailed tests:")
        for r in results:
            if not r.success:
                print(f" - {r.description}: {r.error or 'Unknown error'}")


async def main() -> None:
    """Parse CLI flags, run the selected suites, and exit 0 only if all passed.

    Only per-query TestResults count toward the exit code; the health and
    training-data checks are informational.
    """
    parser = argparse.ArgumentParser(description="Test DSPy Heritage RAG")
    parser.add_argument("--module", action="store_true", help="Test DSPy module directly")
    parser.add_argument("--api", action="store_true", help="Test API endpoints")
    parser.add_argument("--stream", action="store_true", help="Test streaming endpoint")
    parser.add_argument("--all", action="store_true", help="Run all tests")
    parser.add_argument("--url", default="http://localhost:8000", help="API base URL")

    args = parser.parse_args()

    # Default to --all if no specific test selected
    if not any([args.module, args.api, args.stream, args.all]):
        args.all = True

    all_results: list[TestResult] = []

    if args.module or args.all:
        results = await test_dspy_module()
        all_results.extend(results)
        print_summary(results, "DSPy Module Tests")

        await test_training_data()

    if args.api or args.all:
        results = await test_api_endpoints(args.url)
        all_results.extend(results)
        print_summary(results, "API Endpoint Tests")

    if args.stream or args.all:
        results = await test_streaming_endpoint(args.url)
        all_results.extend(results)
        print_summary(results, "Streaming Tests")

    # Overall summary
    if all_results:
        print_summary(all_results, "ALL TESTS")

    # Exit code based on success
    passed = sum(1 for r in all_results if r.success)
    sys.exit(0 if passed == len(all_results) else 1)


if __name__ == "__main__":
    asyncio.run(main())