#!/usr/bin/env python3
"""
Test Schema-Driven LLM Annotation.

Tests the GLiNER2-style schema builder and LLM annotator integration.

Can be run under pytest (the async tests use the ``@pytest.mark.asyncio``
marker) or directly as a script, which executes ``main()``.
"""

import asyncio
import sys
from pathlib import Path

import pytest

# Add the repository's src directory to the import path. This has to happen
# before the glam_extractor imports below, which is why those imports (and the
# dotenv import) carry ``# noqa: E402``.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from dotenv import load_dotenv  # noqa: E402

# Load a local .env file; the LLM tests expect ZAI_API_TOKEN to come from it
# (see their skip messages).
load_dotenv()

from glam_extractor.annotators.llm_annotator import (  # noqa: E402
    LLMAnnotator,
    LLMAnnotatorConfig,
    LLMProvider,
)
from glam_extractor.annotators.schema_builder import (  # noqa: E402
    GLAMSchema,
    FieldSpec,
    # NOTE(review): not referenced elsewhere in this module; importing it
    # still checks that the symbol is exported.
    heritage_custodian_schema,  # noqa: F401
)


# Sample HTML for testing: a small Dutch historical-society home page
# ("Historische Kring Elden"). It contains the details the extraction tests
# look for: name, founding year, collection size, e-mail, postal address,
# phone number, Facebook/Instagram links and a KvK (Chamber of Commerce)
# number in the footer.
SAMPLE_HTML = """
<!DOCTYPE html>
<html>
<head>
<title>Historische Kring Elden - Home</title>
<meta name="description" content="De Historische Kring Elden is een vereniging die zich bezighoudt met de geschiedenis van Elden en omgeving.">
</head>
<body>
<header>
<h1>Historische Kring Elden</h1>
<nav>
<ul>
<li><a href="/">Home</a></li>
<li><a href="/over-ons">Over ons</a></li>
<li><a href="/archief">Archief</a></li>
</ul>
</nav>
</header>
<main>
<section class="intro">
<h2>Welkom bij de Historische Kring Elden</h2>
<p>De Historische Kring Elden is opgericht in 1985 en houdt zich bezig met het verzamelen en bewaren van historisch materiaal over Elden en omgeving.</p>
<p>Wij hebben een collectie van meer dan 5.000 foto's, documenten en objecten.</p>
</section>
<section class="contact">
<h2>Contact</h2>
<p>Email: <a href="mailto:info@historischekringelden.nl">info@historischekringelden.nl</a></p>
<p>Adres: Dorpsstraat 12, 6832 AA Elden</p>
<p>Telefoon: 026-1234567</p>
</section>
<section class="social">
<h2>Volg ons</h2>
<ul>
<li><a href="https://facebook.com/hkelden">Facebook</a></li>
<li><a href="https://instagram.com/hkelden">Instagram</a></li>
</ul>
</section>
</main>
<footer>
<p>© 2024 Historische Kring Elden | KvK: 12345678</p>
</footer>
</body>
</html>
"""


def test_gliner2_field_syntax():
    """Test GLiNER2-style field specification parsing.

    Each case pairs a spec string, such as ``"name::str::Institution name"``,
    ``"type::[MUSEUM|ARCHIVE|LIBRARY]::str::Type"`` or a bare ``"simple_field"``,
    with the attributes ``FieldSpec.from_gliner2_syntax`` should produce.
    Every key present in an expected dict is compared against the attribute
    of the same name on the parsed ``FieldSpec``. Keys that are absent are
    not checked.
    """
    print("\n" + "=" * 60)
    print("TEST 1: GLiNER2-style Field Syntax Parsing")
    print("=" * 60 + "\n")

    test_cases = [
        ("name::str::Institution name", {"name": "name", "dtype": "str", "description": "Institution name"}),
        ("type::[MUSEUM|ARCHIVE|LIBRARY]::str::Type", {"name": "type", "dtype": "str", "choices": ["MUSEUM", "ARCHIVE", "LIBRARY"]}),
        ("features::[indoor|outdoor]::list::Available features", {"name": "features", "dtype": "list", "choices": ["indoor", "outdoor"]}),
        ("price::str::Monthly cost", {"name": "price", "dtype": "str", "description": "Monthly cost"}),
        ("simple_field", {"name": "simple_field", "dtype": "str"}),
    ]

    for spec_string, expected in test_cases:
        print(f" Input: '{spec_string}'")
        parsed = FieldSpec.from_gliner2_syntax(spec_string)
        print(f" → name: {parsed.name}")
        print(f" → dtype: {parsed.dtype}")
        print(f" → choices: {parsed.choices}")
        print(f" → description: {parsed.description}")

        # Verify every attribute the expected dict specifies. Previously,
        # "description" was listed in some cases but never checked.
        for attr, expected_value in expected.items():
            actual = getattr(parsed, attr)
            assert actual == expected_value, (
                f"{attr} mismatch for {spec_string!r}: "
                f"{actual!r} != {expected_value!r}"
            )

        print(" ✓ PASSED\n")

    print("All field syntax tests passed!\n")


def test_schema_builder():
    """Test schema builder with fluent API.

    Builds a small schema through the chained ``GLAMSchema`` calls, prints
    its parts, and asserts on the following:

    * the schema's name;
    * its classification key;
    * the order of its structure fields;
    * the LLM prompt;
    * the generated JSON Schema.

    The test previously had no assertions and could not fail on wrong output.
    """
    print("\n" + "=" * 60)
    print("TEST 2: Schema Builder Fluent API")
    print("=" * 60 + "\n")

    # Build schema using GLiNER2-style syntax
    schema = (
        GLAMSchema("test_institution")
        .entities("GRP", "TOP", "AGT")
        .classification("type", choices=["MUSEUM", "ARCHIVE", "COLLECTING_SOCIETY"])
        .structure()
        .field("name", dtype="str", required=True, description="Institution name")
        .field("email", dtype="str", pattern=r"^[^@]+@[^@]+\.[^@]+$")
        .field("city", dtype="str")
        .build()
    )

    field_names = [spec.name for spec in schema.fields]

    print(f" Schema name: {schema.name}")
    print(f" Entity types: {schema.entity_types}")
    print(f" Classifications: {list(schema.classifications.keys())}")
    print(f" Fields: {field_names}")

    assert schema.name == "test_institution"
    assert "type" in schema.classifications
    assert field_names == ["name", "email", "city"], f"Unexpected fields: {field_names}"

    # Generate LLM prompt
    prompt = schema.to_llm_prompt()
    print(f"\n Prompt length: {len(prompt)} chars")
    print(f" Prompt preview:\n{prompt[:500]}...\n")
    assert prompt, "LLM prompt should not be empty"

    # Generate JSON Schema
    json_schema = schema.to_json_schema()
    print(f" JSON Schema properties: {list(json_schema['properties'].keys())}")
    print(f" Required fields: {json_schema['required']}")

    # Every declared field must appear as a property. Only "name" was
    # declared with required=True.
    missing = {"name", "email", "city"} - set(json_schema["properties"])
    assert not missing, f"Fields missing from JSON Schema properties: {missing}"
    assert "name" in json_schema["required"]

    print("\n ✓ Schema builder test passed!\n")


def test_json_schema_generation():
    """Test JSON Schema generation from FieldSpec.

    Covers ``FieldSpec`` objects of several dtypes, including:

    * GLiNER2-syntax fields with choices;
    * plain ``int``, ``float``, ``bool`` and ``date`` fields.

    For each one, it checks that ``to_json_schema_property()`` yields a
    non-empty dict. Previously nothing was asserted, so an empty or ``None``
    result would have passed.
    """
    print("\n" + "=" * 60)
    print("TEST 3: JSON Schema Generation from FieldSpec")
    print("=" * 60 + "\n")

    # Test various field types
    test_fields = [
        FieldSpec.from_gliner2_syntax("name::str::Institution name"),
        FieldSpec.from_gliner2_syntax("type::[MUSEUM|ARCHIVE]::str::Type"),
        FieldSpec.from_gliner2_syntax("tags::[indoor|outdoor|accessible]::list::Features"),
        FieldSpec(name="count", dtype="int", description="Item count"),
        FieldSpec(name="rating", dtype="float", description="Average rating"),
        FieldSpec(name="is_open", dtype="bool", description="Currently open"),
        FieldSpec(name="founded", dtype="date", description="Founding date"),
    ]

    for spec in test_fields:
        prop = spec.to_json_schema_property()
        print(f" {spec.name} ({spec.dtype}):")
        print(f" JSON Schema: {prop}")

        # A JSON Schema property is a mapping; an empty one describes nothing.
        assert isinstance(prop, dict) and prop, (
            f"Invalid JSON Schema property for {spec.name!r}: {prop!r}"
        )

    print("\n ✓ JSON Schema generation test passed!\n")


@pytest.mark.asyncio
async def test_schema_driven_annotation():
    """Test schema-driven LLM annotation.

    Builds a custom ``historical_society`` schema and runs
    ``LLMAnnotator.annotate_with_schema`` on ``SAMPLE_HTML``. It then prints
    the session summary, the structured data, and up to five entity claims.
    This talks to the configured LLM provider (ZAI, ``glm-4.6``).

    If constructing the annotator raises ``ValueError`` (presumably a missing
    ``ZAI_API_TOKEN``, per the message printed below), the behaviour depends
    on how the test is run. Under pytest it is reported as skipped; when
    called from ``main()`` it just returns.
    """
    import os  # local import: only needed for the pytest-vs-script skip check

    print("\n" + "=" * 60)
    print("TEST 4: Schema-Driven LLM Annotation")
    print("=" * 60 + "\n")

    # Create custom schema
    schema = (
        GLAMSchema("historical_society")
        .entities("GRP", "TOP", "TMP", "APP", "QTY")
        .classification("institution_type", choices=["MUSEUM", "ARCHIVE", "COLLECTING_SOCIETY", "HISTORICAL_SOCIETY"])
        .structure()
        .field("full_name", dtype="str", required=True, description="Official institution name")
        .field("description", dtype="str", description="Brief description")
        .field("email", dtype="str", description="Contact email")
        .field("phone", dtype="str", description="Contact phone")
        .field("address", dtype="str", description="Physical address")
        .field("city", dtype="str", description="City")
        .field("founding_year", dtype="str", description="Year founded")
        .field("collection_size", dtype="str", description="Size of collection")
        .field("kvk_number", dtype="str", description="Chamber of Commerce number")
        .field("social_facebook", dtype="str", description="Facebook URL")
        .field("social_instagram", dtype="str", description="Instagram URL")
        .build()
    )

    print(f" Schema: {schema.name}")
    print(f" Fields: {[spec.name for spec in schema.fields]}")

    # Create annotator
    try:
        annotator = LLMAnnotator(LLMAnnotatorConfig(
            provider=LLMProvider.ZAI,
            model="glm-4.6",
        ))
    except ValueError as e:
        print(f"\n ⚠️ Skipping LLM test: {e}")
        print(" Set ZAI_API_TOKEN in .env to run this test\n")
        # A bare `return` would make pytest report PASSED although nothing was
        # exercised. PYTEST_CURRENT_TEST is only set while pytest is running a
        # test, so emit a real skip there. In script mode, keep returning so
        # main() goes on to the remaining tests.
        if "PYTEST_CURRENT_TEST" in os.environ:
            pytest.skip(f"LLM annotator not configured: {e}")
        return

    print("\n Calling LLM with schema-driven prompt...")

    # Run annotation
    session, structured_data = await annotator.annotate_with_schema(
        document=SAMPLE_HTML,
        schema=schema,
        source_url="https://historischekringelden.nl",
    )

    print(f"\n Session ID: {session.session_id}")
    print(f" Entities: {len(session.entity_claims)}")
    print(f" Layout regions: {len(session.layout_claims)}")
    print(f" Claims: {len(session.aggregate_claims)}")
    print(f" Errors: {session.errors}")

    print("\n Structured Data:")
    for key, value in structured_data.items():
        # Keys starting with "_" are skipped in the printout.
        if not key.startswith('_'):
            print(f" {key}: {value}")

    # Show some entities
    if session.entity_claims:
        print("\n Sample Entities:")
        for claim in session.entity_claims[:5]:
            # hypernym may be an enum-like object (use .value) or a plain value.
            hypernym = claim.hypernym.value if hasattr(claim.hypernym, 'value') else str(claim.hypernym)
            print(f" [{hypernym}] {claim.text_content} (conf: {claim.recognition_confidence:.2f})")

    print("\n ✓ Schema-driven annotation test passed!\n")


@pytest.mark.asyncio
async def test_quick_extraction():
    """Test the quick extraction method with GLiNER2 syntax.

    Passes GLiNER2-style field spec strings directly to
    ``LLMAnnotator.extract_structured``, with no ``GLAMSchema`` object, and
    prints the values extracted from ``SAMPLE_HTML``. This talks to the
    configured LLM provider (ZAI, ``glm-4.6``).

    If constructing the annotator raises ``ValueError`` (presumably a missing
    ``ZAI_API_TOKEN``, per the message printed below), the behaviour depends
    on how the test is run. Under pytest it is reported as skipped; when
    called from ``main()`` it just returns.
    """
    import os  # local import: only needed for the pytest-vs-script skip check

    print("\n" + "=" * 60)
    print("TEST 5: Quick Structured Extraction")
    print("=" * 60 + "\n")

    # Create annotator
    try:
        annotator = LLMAnnotator(LLMAnnotatorConfig(
            provider=LLMProvider.ZAI,
            model="glm-4.6",
        ))
    except ValueError as e:
        print(f"\n ⚠️ Skipping LLM test: {e}")
        print(" Set ZAI_API_TOKEN in .env to run this test\n")
        # A bare `return` would make pytest report PASSED although nothing was
        # exercised. PYTEST_CURRENT_TEST is only set while pytest is running a
        # test, so emit a real skip there. In script mode, keep returning so
        # main() goes on to the remaining tests.
        if "PYTEST_CURRENT_TEST" in os.environ:
            pytest.skip(f"LLM annotator not configured: {e}")
        return

    # Define fields using GLiNER2 syntax
    fields = [
        "name::str::Institution name",
        "email::str::Contact email",
        "phone::str::Phone number",
        "city::str::City",
        "type::[MUSEUM|ARCHIVE|HISTORICAL_SOCIETY]::str::Institution type",
    ]

    print(" Fields to extract:")
    for spec in fields:
        print(f" • {spec}")

    print("\n Running quick extraction...")

    result = await annotator.extract_structured(
        document=SAMPLE_HTML,
        fields=fields,
        source_url="https://historischekringelden.nl",
    )

    print("\n Results:")
    for key, value in result.items():
        # Keys starting with "_" are skipped in the printout.
        if not key.startswith('_'):
            print(f" {key}: {value}")

    print("\n ✓ Quick extraction test passed!\n")


async def main():
    """Run every test in this module, in order, when executed as a script.

    The three synchronous tests run first. The two async tests, which need
    an LLM API key, follow. The whole run is framed by a banner.
    """
    rule = "=" * 60

    print("\n" + rule)
    print("GLAM Schema-Driven Annotation Tests")
    print(rule)

    # Synchronous tests: pure schema/field handling.
    for run_sync in (
        test_gliner2_field_syntax,
        test_schema_builder,
        test_json_schema_generation,
    ):
        run_sync()

    # Async tests: these go through the LLM annotator and need an API key.
    for run_async in (test_schema_driven_annotation, test_quick_extraction):
        await run_async()

    print("\n" + rule)
    print("ALL TESTS COMPLETED")
    print(rule + "\n")


if __name__ == "__main__":
    asyncio.run(main())