338 lines
12 KiB
Python
338 lines
12 KiB
Python
"""
Tests for NLP-based institution extraction from conversation text.
"""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
|
|
from glam_extractor.extractors.nlp_extractor import (
|
|
InstitutionExtractor,
|
|
ExtractionPatterns,
|
|
ExtractedEntity,
|
|
Result
|
|
)
|
|
from glam_extractor.models import InstitutionType, DataSource, DataTier
|
|
from glam_extractor.parsers.conversation import Conversation, ChatMessage
|
|
|
|
|
|
class TestExtractionPatterns:
    """Checks for the pattern attributes exposed by ExtractionPatterns."""

    def test_isil_pattern(self):
        """ISIL_PATTERN finds the single ISIL code in a sentence."""
        sentence = "The ISIL code NL-AsdRM identifies the Rijksmuseum."

        found = ExtractionPatterns().ISIL_PATTERN.findall(sentence)

        assert len(found) == 1
        assert found[0] == "NL-AsdRM"

    def test_wikidata_pattern(self):
        """WIKIDATA_PATTERN finds the single Q-identifier in a sentence."""
        sentence = "The museum has Wikidata ID Q924335."

        found = ExtractionPatterns().WIKIDATA_PATTERN.findall(sentence)

        assert len(found) == 1
        assert found[0] == "Q924335"

    def test_viaf_pattern(self):
        """VIAF_URL_PATTERN yields the numeric ID from a viaf.org URL."""
        sentence = "See https://viaf.org/viaf/123456789 for more information."

        found = ExtractionPatterns().VIAF_URL_PATTERN.findall(sentence)

        assert len(found) == 1
        assert found[0] == "123456789"

    def test_city_pattern(self):
        """CITY_PATTERN picks up the city from an 'in [City]' phrase."""
        sentence = "The Rijksmuseum in Amsterdam is a major museum."

        found = ExtractionPatterns().CITY_PATTERN.findall(sentence)

        # Other matches are tolerated; only Amsterdam's presence is required.
        assert len(found) >= 1
        assert "Amsterdam" in found
|
|
|
|
|
|
class TestInstitutionExtractor:
    """Tests for InstitutionExtractor text and conversation extraction.

    Every test receives a fresh extractor through the ``extractor`` fixture
    and inspects the ``Result`` returned by ``extract_from_text`` or
    ``extract_from_conversation`` (``success`` flag plus ``value`` list).
    """

    @pytest.fixture
    def extractor(self):
        """Return a new InstitutionExtractor built with default arguments."""
        return InstitutionExtractor()

    def test_extract_from_text_simple(self, extractor):
        """A sentence with name, city and ISIL code yields a museum record."""
        text = "The Rijksmuseum in Amsterdam (ISIL: NL-AsdRM) is a major art museum."

        result = extractor.extract_from_text(text)

        assert result.success is True
        assert isinstance(result.value, list)
        assert len(result.value) > 0

        # Only the first extracted institution is inspected.
        institution = result.value[0]
        assert "Rijksmuseum" in institution.name or "Museum" in institution.name
        # Compare as strings since InstitutionTypeEnum != PermissibleValue
        assert str(institution.institution_type) == 'MUSEUM'
        assert str(institution.provenance.data_source) == 'CONVERSATION_NLP'
        assert str(institution.provenance.data_tier) == 'TIER_4_INFERRED'
        assert institution.provenance.confidence_score is not None
        assert 0.0 <= institution.provenance.confidence_score <= 1.0

    def test_extract_isil_identifier(self, extractor):
        """The ISIL code in the text is attached as an identifier."""
        text = "The Amsterdam Museum has ISIL code NL-AsdAM."

        result = extractor.extract_from_text(text)

        assert result.success is True
        assert len(result.value) > 0

        institution = result.value[0]
        assert len(institution.identifiers) > 0

        # Narrow the identifiers down to the ISIL scheme.
        isil_ids = [i for i in institution.identifiers if i.identifier_scheme == "ISIL"]
        assert len(isil_ids) > 0
        assert isil_ids[0].identifier_value == "NL-AsdAM"

    def test_extract_location(self, extractor):
        """When a location is extracted, its city must be populated."""
        text = "The Museum of Modern Art in New York is a famous gallery."

        result = extractor.extract_from_text(text)

        assert result.success is True
        assert len(result.value) > 0

        institution = result.value[0]
        # Should detect city from "in New York" pattern.
        # NOTE: deliberately lenient -- nothing is asserted when no location
        # was extracted at all.
        if institution.locations:
            assert institution.locations[0].city is not None

    def test_extract_wikidata_identifier(self, extractor):
        """A Wikidata ID, if extracted, carries the right value and URL."""
        text = "The British Museum (Q6373) is located in London."

        result = extractor.extract_from_text(text)

        assert result.success is True
        assert len(result.value) > 0

        institution = result.value[0]
        wikidata_ids = [i for i in institution.identifiers if i.identifier_scheme == "Wikidata"]

        if wikidata_ids:  # May or may not be extracted depending on patterns
            assert wikidata_ids[0].identifier_value == "Q6373"
            assert "wikidata.org" in str(wikidata_ids[0].identifier_url)

    def test_classify_library(self, extractor):
        """A 'National Library' mention is classified as LIBRARY."""
        text = "The National Library of Brazil in Rio de Janeiro holds over 9 million items."

        result = extractor.extract_from_text(text)

        assert result.success is True
        assert len(result.value) > 0

        institution = result.value[0]
        # Compare as strings since InstitutionTypeEnum != PermissibleValue
        assert str(institution.institution_type) == 'LIBRARY'

    def test_classify_archive(self, extractor):
        """A 'National Archives' mention is classified ARCHIVE (or UNKNOWN)."""
        text = "The National Archives in Washington DC preserves government records."

        result = extractor.extract_from_text(text)

        assert result.success is True
        assert len(result.value) > 0

        institution = result.value[0]
        # Should detect "archive" or "archives"; UNKNOWN is tolerated too.
        assert str(institution.institution_type) in ['ARCHIVE', 'UNKNOWN']

    def test_empty_text(self, extractor):
        """Empty input succeeds and produces an empty list."""
        result = extractor.extract_from_text("")

        assert result.success is True
        assert result.value == []

    def test_no_institutions(self, extractor):
        """Text without institutions still yields a successful list result."""
        text = "This is just some random text about nothing in particular."

        result = extractor.extract_from_text(text)

        assert result.success is True
        # May return empty list or very low confidence results
        assert isinstance(result.value, list)

    def test_extract_from_conversation(self, extractor):
        """Extraction from a Conversation records the conversation UUID."""
        # Build a one-message conversation around the reference sentence.
        message = ChatMessage(
            uuid="msg-123",
            text="The Rijksmuseum in Amsterdam (ISIL: NL-AsdRM) is a major art museum.",
            sender="assistant",
            content=[],
            created_at=None,
            updated_at=None,
        )
        conversation = Conversation(
            uuid="conv-123",
            name="Dutch Museums",
            summary=None,
            created_at=None,
            updated_at=None,
            chat_messages=[message],
        )

        result = extractor.extract_from_conversation(conversation)

        assert result.success is True
        assert len(result.value) > 0

        institution = result.value[0]
        assert institution.provenance.conversation_id == "conv-123"

    def test_confidence_scoring(self, extractor):
        """A mention with type, location and identifier scores at least 0.6."""
        # High confidence: has type, location, and identifier
        text = "The Rijksmuseum in Amsterdam (ISIL: NL-AsdRM) is a major art museum."

        result = extractor.extract_from_text(text)

        # Assert the preconditions instead of guarding with ``if``: a failed or
        # empty extraction must fail this test rather than let it pass without
        # checking anything. test_extract_from_text_simple already requires
        # the same outcome for this exact sentence.
        assert result.success is True
        assert len(result.value) > 0

        institution = result.value[0]
        # Should have relatively high confidence (type + location + identifier)
        assert institution.provenance.confidence_score >= 0.6

    def test_multilingual_keywords(self, extractor):
        """A Spanish 'Museo' mention, if extracted, is typed as MUSEUM."""
        # Spanish-language sentence containing the keyword "Museo".
        text = "El Museo Nacional de Brasil está en Brasilia."

        result = extractor.extract_from_text(text)

        # Should detect "museo" keyword
        assert result.success is True
        # May or may not extract name properly due to Spanish grammar
        if result.value:
            institution = result.value[0]
            # Compare as strings since InstitutionTypeEnum != PermissibleValue
            assert str(institution.institution_type) == 'MUSEUM'

    def test_provenance_metadata(self, extractor):
        """Provenance on an extracted institution is fully populated."""
        text = "The Amsterdam Museum has ISIL code NL-AsdAM."

        result = extractor.extract_from_text(
            text,
            conversation_id="test-conv-123",
            conversation_name="Test Conversation",
        )

        assert result.success is True
        assert len(result.value) > 0

        prov = result.value[0].provenance

        # Compare as strings since enum types differ
        assert str(prov.data_source) == 'CONVERSATION_NLP'
        assert str(prov.data_tier) == 'TIER_4_INFERRED'
        assert prov.extraction_date is not None
        # LinkML may serialize datetime as string, check for either
        assert isinstance(prov.extraction_date, (datetime, str))
        assert prov.extraction_method is not None
        assert prov.confidence_score is not None
        assert prov.conversation_id == "test-conv-123"
        assert prov.verified_by is None  # Not yet verified

    def test_deduplication(self, extractor):
        """Two mentions of the same institution give at most one result."""
        # Mention same institution twice
        text = (
            "The Rijksmuseum is in Amsterdam. "
            "The Rijksmuseum has a large collection."
        )

        result = extractor.extract_from_text(text)

        assert result.success is True
        # Count names containing "rijksmuseum", ignoring case; duplicates
        # should have been collapsed into a single instance.
        names = [inst.name for inst in result.value]
        rijks_count = sum(1 for name in names if "rijksmuseum" in name.lower())
        assert rijks_count <= 1
|
|
|
|
|
|
class TestExtractedEntity:
    """Construction checks for the ExtractedEntity dataclass."""

    def test_create_entity(self):
        """Fields passed to the constructor are stored unchanged."""
        fields = {
            "name": "Test Museum",
            "institution_type": InstitutionType.MUSEUM,
            "city": "Amsterdam",
            "country": "NL",
            "confidence_score": 0.85,
        }

        entity = ExtractedEntity(**fields)

        # Each supplied field must read back equal to what was passed in.
        for attr, expected in fields.items():
            assert getattr(entity, attr) == expected
        # No identifiers were passed, so the list is expected to be empty.
        assert entity.identifiers == []

    def test_entity_with_identifiers(self):
        """An identifiers list given at construction is kept on the entity."""
        from glam_extractor.models import Identifier

        isil = Identifier(
            identifier_scheme="ISIL",
            identifier_value="NL-Test",
            identifier_url=None,
            assigned_date=None,
        )

        entity = ExtractedEntity(name="Test Museum", identifiers=[isil])

        assert len(entity.identifiers) == 1
        assert entity.identifiers[0].identifier_scheme == "ISIL"
|
|
|
|
|
|
class TestResult:
    """Checks for the Result wrapper's ok/err constructors."""

    def test_result_ok(self):
        """Result.ok stores the value, flags success and has no error."""
        payload = [1, 2, 3]

        outcome = Result.ok(payload)

        assert outcome.success is True
        assert outcome.value == [1, 2, 3]
        assert outcome.error is None

    def test_result_err(self):
        """Result.err stores the message, flags failure and has no value."""
        message = "Something went wrong"

        outcome = Result.err(message)

        assert outcome.success is False
        assert outcome.value is None
        assert outcome.error == "Something went wrong"
|