#!/usr/bin/env python3
"""
End-to-end test of the LLM Annotator with a real NDE entry.

This script tests the complete annotation pipeline:
1. Load an archived HTML page from an NDE entry
2. Run LLM-based entity extraction
3. Validate the results

Usage:
    python scripts/test_llm_annotator.py

Environment Variables:
    ANTHROPIC_API_KEY - Required for Claude (preferred for testing)
    ZAI_API_TOKEN - Required for Z.AI GLM-4
    OPENAI_API_KEY - Required for OpenAI GPT-4
"""
import asyncio
import json
import logging
import os
import sys
from pathlib import Path

# Add project root to path so `glam_extractor` imports resolve when the
# package is not installed (script lives in scripts/, sources in src/).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))

# Load environment variables from .env file BEFORE importing project code,
# so API keys are visible to anything that reads them at import time.
from dotenv import load_dotenv
load_dotenv(project_root / ".env")

from glam_extractor.annotators import (
    create_llm_annotator,
    heritage_custodian_schema,
    LLMAnnotator,
    LLMAnnotatorConfig,
    LLMProvider,
)

# Configure logging for the whole test run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Test data paths: an enriched NDE entry and one of its archived web pages.
NDE_ENTRY_PATH = project_root / "data/nde/enriched/entries/0000_Q22246632.yaml"
HTML_PAGE_PATH = project_root / "data/nde/enriched/entries/web/0000/kampwesterbork.nl/pages/collectie.html"
def check_api_keys() -> tuple[str | None, list[str]]:
|
|
"""Check which API keys are available and return preferred provider."""
|
|
available = []
|
|
|
|
if os.environ.get("ANTHROPIC_API_KEY"):
|
|
available.append("anthropic")
|
|
if os.environ.get("ZAI_API_TOKEN"):
|
|
available.append("zai")
|
|
if os.environ.get("OPENAI_API_KEY"):
|
|
available.append("openai")
|
|
|
|
if not available:
|
|
return None, []
|
|
|
|
# Prefer Anthropic for testing (most reliable)
|
|
if "anthropic" in available:
|
|
return "anthropic", available
|
|
return available[0], available
|
|
|
|
|
|
async def test_basic_annotation():
    """Test basic LLM annotation with a heritage institution HTML page.

    Returns:
        ``True`` when at least one entity or aggregate claim was extracted,
        ``False`` on failure, or ``None`` when skipped (no API key set).
    """
    print("\n" + "="*60)
    print("TEST: Basic LLM Annotation")
    print("="*60)

    # Check API keys
    provider, available = check_api_keys()

    if not provider:
        print("\n⚠️ SKIPPED: No API keys found!")
        print("To run this test, set one of:")
        print(" - ANTHROPIC_API_KEY")
        print(" - ZAI_API_TOKEN")
        print(" - OPENAI_API_KEY")
        return None  # Skipped

    print(f"\nUsing provider: {provider}")
    print(f"Available providers: {available}")

    # Fail fast on a missing fixture, before building any annotator state.
    if not HTML_PAGE_PATH.exists():
        print(f"ERROR: Test file not found: {HTML_PAGE_PATH}")
        return False

    # Create annotator with longer retry delays for rate limits
    print("\nCreating LLM annotator...")
    # Only RetryConfig is needed here; LLMAnnotatorConfig, LLMProvider and
    # LLMAnnotator are already imported at module level.
    from glam_extractor.annotators import RetryConfig

    # Use longer delays for Z.AI rate limits (60s reset)
    retry_config = RetryConfig(
        max_retries=5,
        base_delay=10.0,   # Start with 10s delay
        max_delay=120.0,   # Cap at 2 minutes
        exponential_base=2.0,
        jitter=True,
    )

    provider_enum = LLMProvider(provider)
    config = LLMAnnotatorConfig(
        provider=provider_enum,
        retry=retry_config,
        fallback_providers=[],  # Disable fallback for focused testing
    )
    annotator = LLMAnnotator(config)

    # The annotator reads the document path itself; stat() reports the true
    # byte size without loading the whole file into memory here.
    print(f"\nLoading HTML file: {HTML_PAGE_PATH.name}")
    print(f"HTML size: {HTML_PAGE_PATH.stat().st_size:,} bytes")

    # Run annotation
    print("\nRunning LLM annotation (this may take 30-60 seconds)...")
    try:
        session = await annotator.annotate(
            document=HTML_PAGE_PATH,
            source_url="https://kampwesterbork.nl/collectie",
        )
    except Exception as e:
        # Provider/network failures count as a test failure, not a crash.
        print(f"ERROR: Annotation failed: {e}")
        return False

    # Print results
    print("\n" + "-"*40)
    print("ANNOTATION RESULTS")
    print("-"*40)

    print(f"\nSession ID: {session.session_id}")
    print(f"Source URL: {session.source_url}")
    print(f"Completed at: {session.completed_at}")
    print(f"Errors: {len(session.errors)}")

    if session.errors:
        print("\nErrors encountered:")
        for error in session.errors:
            print(f" - {error}")

    print(f"\n📊 Entity Claims: {len(session.entity_claims)}")
    for claim in session.entity_claims[:10]:  # First 10
        text = (claim.text_content or "")[:50]
        conf = claim.recognition_confidence
        print(f" [{claim.hypernym.value}] {text}... (conf: {conf:.2f})")
    if len(session.entity_claims) > 10:
        print(f" ... and {len(session.entity_claims) - 10} more")

    print(f"\n📄 Layout Claims: {len(session.layout_claims)}")
    for claim in session.layout_claims[:5]:  # First 5
        text = (claim.text_content or "")[:50]
        print(f" [{claim.region.value}] {text}...")
    if len(session.layout_claims) > 5:
        print(f" ... and {len(session.layout_claims) - 5} more")

    print(f"\n📋 Aggregate Claims: {len(session.aggregate_claims)}")
    for claim in session.aggregate_claims[:10]:  # First 10
        value = (claim.claim_value or "")[:60] if claim.claim_value else ""
        # Missing provenance falls back to a neutral 0.5 confidence.
        conf = claim.provenance.confidence if claim.provenance else 0.5
        print(f" [{claim.claim_type}]: {value}... (conf: {conf:.2f})")
    if len(session.aggregate_claims) > 10:
        print(f" ... and {len(session.aggregate_claims) - 10} more")

    # Validate we got some results: either entities or aggregate claims count.
    success = (
        len(session.entity_claims) > 0 or
        len(session.aggregate_claims) > 0
    )

    print("\n" + "="*60)
    if success:
        print("✅ TEST PASSED: LLM annotation completed with results")
    else:
        print("❌ TEST FAILED: No entities or claims extracted")
    print("="*60)

    return success
async def test_schema_builder():
    """Test the schema builder generates valid prompts."""
    wide = "=" * 60
    thin = "-" * 40

    print("\n" + wide)
    print("TEST: Schema Builder")
    print(wide)

    # Build the heritage custodian schema and report its shape.
    custodian = heritage_custodian_schema()

    print(f"\nSchema: {custodian.name}")
    print(f"Entity types: {custodian.entity_types}")
    print(f"Fields: {len(custodian.fields)}")
    print(f"Relations: {custodian.relation_types}")

    # Render the extraction prompt and show a short preview.
    prompt_text = custodian.to_llm_prompt()
    print(f"\nGenerated prompt length: {len(prompt_text)} chars")
    print("\nPrompt preview (first 500 chars):")
    print(thin)
    print(prompt_text[:500])
    print(thin)

    # Render the JSON schema and surface its top-level structure.
    json_spec = custodian.to_json_schema()
    print(f"\nJSON Schema properties: {list(json_spec['properties'].keys())}")
    print(f"Required fields: {json_spec['required']}")

    print("\n" + wide)
    print("✅ TEST PASSED: Schema builder works correctly")
    print(wide)

    return True
async def main():
    """Run all tests and return a process exit code (0 = all passed)."""
    banner = "#" * 60

    print("\n" + banner)
    print("# LLM ANNOTATOR END-TO-END TESTS")
    print(banner)

    # Run in order: schema builder needs no API key; annotation may skip.
    outcomes = [
        ("Schema Builder", await test_schema_builder()),
        ("LLM Annotation", await test_basic_annotation()),
    ]

    # Summary
    print("\n" + banner)
    print("# TEST SUMMARY")
    print(banner)

    failures = 0
    for name, outcome in outcomes:
        # None means skipped (no API key) and does not count as a failure.
        if outcome is None:
            status = "⚠️ SKIP"
        elif outcome:
            status = "✅ PASS"
        else:
            status = "❌ FAIL"
            failures += 1
        print(f" {status}: {name}")

    print(banner)

    return 1 if failures else 0
if __name__ == "__main__":
    # Propagate the aggregate test result as the process exit code.
    sys.exit(asyncio.run(main()))
|