glam/tests/annotators/test_hybrid_annotator.py
2025-12-05 16:25:39 +01:00

445 lines
15 KiB
Python

"""
Tests for the Hybrid GLiNER2 + LLM Annotator.
This module tests the hybrid annotation pipeline that combines fast
encoder-based NER (GLiNER2) with LLM reasoning for comprehensive
entity and relationship extraction.
Tests are organized by component:
1. AnnotationCandidate and RelationshipCandidate data classes
2. GLiNER2Annotator wrapper
3. HybridAnnotator pipeline stages
4. Integration tests (requires LLM API)
"""
import pytest
from datetime import datetime
from unittest.mock import Mock, patch, AsyncMock
# Import from annotators module
from glam_extractor.annotators.hybrid_annotator import (
AnnotationCandidate,
RelationshipCandidate,
CandidateSource,
CandidateStatus,
GLiNER2Annotator,
HybridAnnotator,
HybridConfig,
GLINER2_TO_GLAM_MAPPING,
)
from glam_extractor.annotators.base import (
EntityClaim,
Provenance,
EntityHypernym,
)
# =============================================================================
# FIXTURES
# =============================================================================
@pytest.fixture
def sample_text():
    """Provide a short multi-line museum description used as extraction input."""
    # The literal starts and ends with a newline; its content is kept verbatim.
    museum_blurb = """
The Rijksmuseum in Amsterdam, founded in 1800, houses over 8,000 objects
of art and history. Vincent van Gogh's "The Night Watch" by Rembrandt
is the most famous painting in the collection. The museum is located at
Museumstraat 1, 1071 XX Amsterdam, Netherlands.
"""
    return museum_blurb
@pytest.fixture
def sample_candidates():
    """Provide four hand-written, GLiNER2-sourced annotation candidates.

    The list order is: Rijksmuseum, Amsterdam, 1800, Vincent van Gogh.
    """
    # One row per candidate:
    # (text, start_offset, end_offset, hypernym, hyponym, overall_confidence)
    rows = [
        ("Rijksmuseum", 8, 19, "GRP", "GRP.CUS", 0.92),
        ("Amsterdam", 23, 32, "GEO", "GEO.PPL", 0.88),
        ("1800", 45, 49, "TMP", "TMP.DAB", 0.75),
        ("Vincent van Gogh", 82, 98, "AGT", "AGT.PER", 0.95),
    ]
    return [
        AnnotationCandidate(
            text=surface,
            start_offset=begin,
            end_offset=finish,
            hypernym=broad_type,
            hyponym=narrow_type,
            overall_confidence=score,
            source=CandidateSource.GLINER2,
        )
        for surface, begin, finish, broad_type, narrow_type, score in rows
    ]
@pytest.fixture
def hybrid_config():
    """Provide a HybridConfig with refinement and relationships switched off."""
    stage_flags = {
        "enable_fast_pass": True,
        # Refinement is disabled to keep the LLM out of unit tests.
        "enable_refinement": False,
        "enable_validation": True,
        "enable_relationships": False,
    }
    return HybridConfig(**stage_flags)
# =============================================================================
# ANNOTATION CANDIDATE TESTS
# =============================================================================
class TestAnnotationCandidate:
    """Construction of AnnotationCandidate and conversion to/from EntityClaim."""

    def test_create_minimal(self):
        """Build a candidate from the four required fields and check defaults."""
        required = {
            "text": "Museum",
            "start_offset": 0,
            "end_offset": 6,
            "hypernym": "GRP",
        }
        minimal = AnnotationCandidate(**required)

        # Values that were passed in are stored unchanged.
        assert minimal.text == "Museum"
        assert minimal.start_offset == 0
        assert minimal.end_offset == 6
        assert minimal.hypernym == "GRP"
        # Everything else is expected to take its default.
        assert minimal.hyponym is None
        assert minimal.overall_confidence == 0.0
        assert minimal.source == CandidateSource.GLINER2
        assert minimal.status == CandidateStatus.DETECTED

    def test_create_with_all_fields(self):
        """Build a candidate with the optional fields set explicitly."""
        populated = AnnotationCandidate(
            text="Rijksmuseum Amsterdam",
            start_offset=10,
            end_offset=31,
            hypernym="GRP",
            hyponym="GRP.CUS",
            overall_confidence=0.95,
            source=CandidateSource.LLM,
            status=CandidateStatus.VALIDATED,
            wikidata_id="Q190804",
        )

        assert populated.hyponym == "GRP.CUS"
        assert populated.overall_confidence == 0.95
        assert populated.source == CandidateSource.LLM
        assert populated.status == CandidateStatus.VALIDATED
        assert populated.wikidata_id == "Q190804"

    def test_span_property(self):
        """Check that start/end offsets can be read back, singly and as a pair."""
        spanned = AnnotationCandidate(
            text="Test", start_offset=5, end_offset=9, hypernym="THG"
        )

        assert spanned.start_offset == 5
        assert spanned.end_offset == 9
        assert (spanned.start_offset, spanned.end_offset) == (5, 9)

    def test_to_entity_claim_basic(self):
        """Convert a candidate to an EntityClaim and compare the copied fields."""
        city = AnnotationCandidate(
            text="Amsterdam",
            start_offset=10,
            end_offset=19,
            hypernym="GEO",
            hyponym="GEO.PPL",
            overall_confidence=0.9,
            source=CandidateSource.HYBRID,
        )

        converted = city.to_entity_claim()

        assert isinstance(converted, EntityClaim)
        assert converted.text_content == "Amsterdam"
        assert converted.hyponym == "GEO.PPL"
        assert converted.start_offset == 10
        assert converted.end_offset == 19
        # The candidate's overall confidence is expected on the claim's
        # recognition_confidence field.
        assert converted.recognition_confidence == 0.9

    def test_from_entity_claim(self):
        """Build a candidate from an EntityClaim and compare the copied fields."""
        origin = Provenance(
            namespace="glam-ner",
            path="/test/document.html",
            timestamp=datetime.now().isoformat(),
            agent="glm-4",
        )
        museum_claim = EntityClaim(
            text_content="Van Gogh Museum",
            hypernym=EntityHypernym.GRP,
            hyponym="GRP.CUS",
            start_offset=0,
            end_offset=15,
            recognition_confidence=0.88,
            provenance=origin,
        )

        rebuilt = AnnotationCandidate.from_entity_claim(museum_claim)

        assert rebuilt.text == "Van Gogh Museum"
        # The enum member given on the claim is expected back as the string "GRP".
        assert rebuilt.hypernym == "GRP"
        assert rebuilt.hyponym == "GRP.CUS"
        assert rebuilt.start_offset == 0
        assert rebuilt.end_offset == 15
        assert rebuilt.overall_confidence == 0.88
# =============================================================================
# RELATIONSHIP CANDIDATE TESTS
# =============================================================================
class TestRelationshipCandidate:
    """Construction of RelationshipCandidate from two annotation candidates."""

    def test_create_basic(self, sample_candidates):
        """Link the first two fixture candidates and check the stored fields."""
        museum = sample_candidates[0]  # Rijksmuseum
        city = sample_candidates[1]  # Amsterdam

        relation = RelationshipCandidate(
            subject_id=museum.candidate_id,
            subject_text=museum.text,
            # Use the hyponym when it is set, otherwise fall back to the hypernym.
            subject_type=museum.hyponym or museum.hypernym,
            object_id=city.candidate_id,
            object_text=city.text,
            object_type=city.hyponym or city.hypernym,
            relationship_type="REL.SPA.LOC",
            relationship_label="located in",
            confidence=0.85,
        )

        assert relation.subject_text == "Rijksmuseum"
        assert relation.relationship_type == "REL.SPA.LOC"
        assert relation.object_text == "Amsterdam"
        assert relation.confidence == 0.85
        # is_valid was not passed, so this checks its default.
        assert relation.is_valid is True
# =============================================================================
# GLINER2 ANNOTATOR TESTS
# =============================================================================
class TestGLiNER2Annotator:
    """Checks on a default-constructed GLiNER2Annotator wrapper."""

    def test_init_default(self):
        """A no-argument instance exposes the default threshold and labels."""
        # Construction is expected to succeed even when gliner is not installed.
        wrapper = GLiNER2Annotator()

        assert wrapper.threshold == 0.5
        assert wrapper.entity_labels is not None

    def test_is_available_property(self):
        """is_available is a bool; the tests do not require gliner itself."""
        wrapper = GLiNER2Annotator()

        # Either value is acceptable here; only the type is checked.
        assert isinstance(wrapper.is_available, bool)

    def test_default_entity_labels(self):
        """The default entity label collection is present and non-empty."""
        wrapper = GLiNER2Annotator()

        assert wrapper.entity_labels is not None
        assert len(wrapper.entity_labels) > 0
# =============================================================================
# GLINER2 TO GLAM MAPPING TESTS
# =============================================================================
class TestGLiNER2ToGLAMMapping:
    """Tests for the GLINER2_TO_GLAM_MAPPING dictionary."""

    def test_mapping_exists(self):
        """The mapping is defined and contains at least one entry."""
        assert GLINER2_TO_GLAM_MAPPING is not None
        assert len(GLINER2_TO_GLAM_MAPPING) > 0

    def test_mapping_values_format(self):
        """Every mapped value is a bare hypernym or starts with an uppercase one.

        A value containing "." must have an uppercase segment before the first
        dot; a value without "." must be exactly three uppercase characters.
        Each assertion names the offending entry so a failure can be traced
        back to a specific key in the mapping.
        """
        for gliner_label, hyponym in GLINER2_TO_GLAM_MAPPING.items():
            entry = f"{gliner_label!r} -> {hyponym!r}"
            if "." in hyponym:
                # Only the segment before the first dot is checked.
                hypernym_prefix = hyponym.partition(".")[0]
                assert hypernym_prefix.isupper(), (
                    f"hypernym prefix must be uppercase: {entry}"
                )
            else:
                # Top-level hypernym only (e.g., "GRP", "TOP").
                assert len(hyponym) == 3, (
                    f"bare hypernym must be 3 characters: {entry}"
                )
                assert hyponym.isupper(), (
                    f"bare hypernym must be uppercase: {entry}"
                )

    def test_person_mapping(self):
        """The "person" label is present and maps to "AGT.PER"."""
        assert "person" in GLINER2_TO_GLAM_MAPPING
        assert GLINER2_TO_GLAM_MAPPING["person"] == "AGT.PER"

    def test_organization_mapping(self):
        """The "organization" label is present (its target is not checked)."""
        assert "organization" in GLINER2_TO_GLAM_MAPPING
# =============================================================================
# HYBRID CONFIG TESTS
# =============================================================================
class TestHybridConfig:
    """Tests for HybridConfig defaults and overrides."""

    def test_default_values(self):
        """A no-argument HybridConfig carries the expected default settings."""
        config = HybridConfig()

        assert config.gliner_model == "urchade/gliner_multi-v2.1"
        assert config.gliner_threshold == 0.5
        assert config.llm_model == "glm-4"
        assert config.enable_fast_pass is True
        assert config.enable_refinement is True
        assert config.enable_validation is True
        assert config.enable_relationships is True
        assert config.merge_threshold == 0.3
        assert config.prefer_llm_on_conflict is True

    def test_custom_values(self):
        """Every value passed to the constructor is readable back unchanged."""
        config = HybridConfig(
            gliner_model="urchade/gliner_small",
            gliner_threshold=0.7,
            llm_model="glm-4.5",
            enable_fast_pass=False,
            enable_refinement=True,
            enable_validation=False,
            merge_threshold=0.5,
        )

        assert config.gliner_model == "urchade/gliner_small"
        assert config.gliner_threshold == 0.7
        # These three overrides were previously passed but never verified.
        assert config.llm_model == "glm-4.5"
        assert config.enable_fast_pass is False
        assert config.enable_refinement is True
        assert config.enable_validation is False
        assert config.merge_threshold == 0.5
# =============================================================================
# HYBRID ANNOTATOR TESTS
# =============================================================================
class TestHybridAnnotator:
    """Construction and fast-pass gating of HybridAnnotator."""

    def test_init_default(self):
        """A no-argument annotator has a config with the fast pass enabled."""
        pipeline = HybridAnnotator()

        assert pipeline.config is not None
        assert pipeline.config.enable_fast_pass is True

    def test_init_with_config(self, hybrid_config):
        """A config passed to the constructor is the one the annotator exposes."""
        pipeline = HybridAnnotator(config=hybrid_config)

        assert pipeline.config.enable_refinement is False

    def test_gliner_available_property(self):
        """gliner_available is a bool whether or not gliner is installed."""
        pipeline = HybridAnnotator()

        assert isinstance(pipeline.gliner_available, bool)

    def test_fast_pass_disabled(self, sample_text):
        """fast_pass returns an empty list when the stage is switched off."""
        fast_pass_off = HybridConfig(enable_fast_pass=False)
        pipeline = HybridAnnotator(config=fast_pass_off)

        candidates = pipeline.fast_pass(sample_text)

        assert candidates == []
# =============================================================================
# EDGE CASES
# =============================================================================
class TestEdgeCases:
    """Boundary inputs for fast_pass and a degenerate candidate span."""

    def test_empty_text(self):
        """An empty string produces no candidates."""
        pipeline = HybridAnnotator()

        candidates = pipeline.fast_pass("")

        assert candidates == []

    def test_whitespace_only_text(self):
        """Text made only of whitespace produces no candidates."""
        pipeline = HybridAnnotator()

        candidates = pipeline.fast_pass(" \n\t ")

        assert candidates == []

    def test_very_long_text(self):
        """A sentence repeated 1000 times is processed without raising."""
        settings = HybridConfig(enable_fast_pass=True, enable_refinement=False)
        pipeline = HybridAnnotator(config=settings)
        repeated = "The Rijksmuseum in Amsterdam. " * 1000

        # The result may be empty if GLiNER2 is not available; only the
        # return type is checked.
        candidates = pipeline.fast_pass(repeated)

        assert isinstance(candidates, list)

    def test_unicode_text(self):
        """French text with accented characters is processed into a list."""
        pipeline = HybridAnnotator()
        french = "Le Musée du Louvre à Paris contient la Joconde de Léonard de Vinci."

        candidates = pipeline.fast_pass(french)

        assert isinstance(candidates, list)

    def test_candidate_with_zero_length(self):
        """A candidate whose start and end offsets coincide can be constructed."""
        empty_span = AnnotationCandidate(
            text="",
            start_offset=5,
            end_offset=5,  # same as start: zero-length span
            hypernym="THG",
        )

        assert empty_span.start_offset == 5
        assert empty_span.end_offset == 5
if __name__ == "__main__":
    # Run this file's tests verbosely and hand pytest's exit code to the
    # shell, so a direct `python test_hybrid_annotator.py` run fails when
    # tests fail instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))