""" Tests for SOTA Template Matching Components Tests the new components added based on SOTA research: 1. RAGEnhancedMatcher (Tier 2.5) - Context-enriched matching using Q&A examples 2. SPARQLValidator - Validates SPARQL against ontology schema 3. 4-tier fallback behavior - Pattern -> Embedding -> RAG -> LLM Based on: - SPARQL-LLM (arXiv:2512.14277) - COT-SPARQL (SEMANTICS 2024) - KGQuest (arXiv:2511.11258) Author: OpenCode Created: 2025-01-07 """ import pytest from unittest.mock import MagicMock, patch import numpy as np # ============================================================================= # SPARQL VALIDATOR TESTS # ============================================================================= class TestSPARQLValidator: """Tests for the SPARQLValidator class.""" @pytest.fixture def validator(self): """Get a fresh validator instance.""" from backend.rag.template_sparql import SPARQLValidator return SPARQLValidator() def test_valid_query_with_known_predicates(self, validator): """Valid query using known hc: predicates should pass.""" sparql = """ SELECT ?name ?type WHERE { ?inst hc:institutionType "M" . ?inst hc:settlementName "Amsterdam" . } """ result = validator.validate(sparql) assert result.valid is True assert len(result.errors) == 0 def test_valid_query_with_schema_predicates(self, validator): """Valid query using schema.org predicates should pass.""" sparql = """ SELECT ?name WHERE { ?inst schema:name ?name . ?inst schema:foundingDate ?date . } """ result = validator.validate(sparql) assert result.valid is True assert len(result.errors) == 0 def test_invalid_predicate_detected(self, validator): """Unknown hc: predicate should be flagged.""" sparql = """ SELECT ?x WHERE { ?x hc:unknownPredicate "test" . 
} """ result = validator.validate(sparql) # Should have warning or error for unknown predicate assert len(result.warnings) > 0 or len(result.errors) > 0 def test_typo_suggestion(self, validator): """Typo in predicate should suggest correction.""" sparql = """ SELECT ?x WHERE { ?x hc:institutionTyp "M" . } """ result = validator.validate(sparql) # Should suggest "hc:institutionType" if result.suggestions: assert any("institutionType" in s for s in result.suggestions) def test_mismatched_braces_detected(self, validator): """Mismatched braces should be flagged.""" sparql = """ SELECT ?x WHERE { ?x hc:institutionType "M" . """ # Missing closing brace result = validator.validate(sparql) assert result.valid is False assert any("brace" in e.lower() for e in result.errors) def test_missing_where_clause_detected(self, validator): """SELECT without WHERE should be flagged.""" sparql = """ SELECT ?x { ?x hc:institutionType "M" . } """ result = validator.validate(sparql) # Note: This has braces but no WHERE keyword assert result.valid is False assert any("WHERE" in e for e in result.errors) def test_non_hc_query_passes(self, validator): """Query without hc: predicates should pass (not our responsibility).""" sparql = """ SELECT ?s ?p ?o WHERE { ?s ?p ?o . } LIMIT 10 """ result = validator.validate(sparql) assert result.valid is True def test_budget_predicates_valid(self, validator): """Budget-related predicates should be valid.""" sparql = """ SELECT ?budget WHERE { ?inst hc:innovation_budget ?budget . ?inst hc:digitization_budget ?dbudget . } """ result = validator.validate(sparql) assert result.valid is True assert len(result.errors) == 0 def test_change_event_predicates_valid(self, validator): """Change event predicates should be valid.""" sparql = """ SELECT ?event WHERE { ?event hc:changeType "MERGER" . ?event hc:eventDate ?date . 
} """ result = validator.validate(sparql) assert result.valid is True # ============================================================================= # RAG ENHANCED MATCHER TESTS # ============================================================================= class TestRAGEnhancedMatcher: """Tests for the RAGEnhancedMatcher class.""" @pytest.fixture def mock_templates(self): """Create mock template definitions for testing.""" from backend.rag.template_sparql import TemplateDefinition, SlotDefinition, SlotType return { "list_institutions_by_type_city": TemplateDefinition( id="list_institutions_by_type_city", description="List institutions by type in a city", intent=["list", "institutions", "city"], question_patterns=["Welke {type} zijn er in {city}?"], slots={ "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE), "city": SlotDefinition(type=SlotType.CITY), }, sparql_template="SELECT ?inst WHERE { ?inst hc:institutionType '{{ type }}' }", examples=[ {"question": "Welke musea zijn er in Amsterdam?", "slots": {"type": "M", "city": "Amsterdam"}}, {"question": "Welke archieven zijn er in Utrecht?", "slots": {"type": "A", "city": "Utrecht"}}, {"question": "Welke bibliotheken zijn er in Rotterdam?", "slots": {"type": "L", "city": "Rotterdam"}}, ], ), "count_institutions_by_type": TemplateDefinition( id="count_institutions_by_type", description="Count institutions by type", intent=["count", "institutions"], question_patterns=["Hoeveel {type} zijn er?"], slots={ "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE), }, sparql_template="SELECT (COUNT(?inst) AS ?count) WHERE { ?inst hc:institutionType '{{ type }}' }", examples=[ {"question": "Hoeveel musea zijn er in Nederland?", "slots": {"type": "M"}}, {"question": "Hoeveel archieven zijn er?", "slots": {"type": "A"}}, ], ), } @pytest.fixture def matcher(self): """Get a fresh RAG matcher instance (reset singleton).""" from backend.rag.template_sparql import RAGEnhancedMatcher # Reset singleton state for clean tests 
RAGEnhancedMatcher._instance = None RAGEnhancedMatcher._example_embeddings = None RAGEnhancedMatcher._example_template_ids = None RAGEnhancedMatcher._example_texts = None RAGEnhancedMatcher._example_slots = None return RAGEnhancedMatcher() def test_singleton_pattern(self): """RAGEnhancedMatcher should be a singleton.""" from backend.rag.template_sparql import RAGEnhancedMatcher # Reset first RAGEnhancedMatcher._instance = None matcher1 = RAGEnhancedMatcher() matcher2 = RAGEnhancedMatcher() assert matcher1 is matcher2 def test_match_returns_none_without_model(self, matcher, mock_templates): """Should return None if embedding model not available.""" with patch('backend.rag.template_sparql._get_embedding_model', return_value=None): result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates) assert result is None def test_match_with_high_agreement(self, matcher, mock_templates): """Should match when examples agree on template.""" # Create mock embedding model mock_model = MagicMock() # Create embeddings that will give high similarity for "list" template # The question embedding question_emb = np.array([1.0, 0.0, 0.0]) # Example embeddings - 3 from list template, 2 from count template example_embs = np.array([ [0.95, 0.1, 0.0], # list - high similarity [0.90, 0.15, 0.0], # list - high similarity [0.92, 0.12, 0.0], # list - high similarity [0.3, 0.9, 0.1], # count - low similarity [0.25, 0.85, 0.15], # count - low similarity ]) mock_model.encode = MagicMock(side_effect=[ example_embs, # First call for indexing np.array([question_emb]), # Second call for question ]) with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model): # Reset cached embeddings matcher._example_embeddings = None result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates, k=5) # Should match list template with high confidence if result is not None: assert result.matched is True assert result.template_id == "list_institutions_by_type_city" 
assert result.confidence >= 0.70 def test_match_returns_none_with_low_agreement(self, matcher, mock_templates): """Should return None when examples don't agree.""" mock_model = MagicMock() # Create embeddings with low agreement (split between templates) question_emb = np.array([0.5, 0.5, 0.0]) example_embs = np.array([ [0.6, 0.4, 0.0], # list - medium similarity [0.55, 0.45, 0.0], # list - medium similarity [0.52, 0.48, 0.0], # list - medium similarity [0.45, 0.55, 0.0], # count - medium similarity [0.4, 0.6, 0.0], # count - medium similarity ]) mock_model.encode = MagicMock(side_effect=[ example_embs, np.array([question_emb]), ]) with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model): matcher._example_embeddings = None # With mixed agreement, should fail min_agreement threshold result = matcher.match( "Geef me informatie over musea", mock_templates, k=5, min_agreement=0.8 # High threshold ) # May return None if agreement is below threshold # The exact behavior depends on similarity calculation class TestRAGEnhancedMatcherFactory: """Tests for the get_rag_enhanced_matcher factory function.""" def test_factory_returns_singleton(self): """Factory should return same instance.""" from backend.rag.template_sparql import get_rag_enhanced_matcher, RAGEnhancedMatcher # Reset RAGEnhancedMatcher._instance = None matcher1 = get_rag_enhanced_matcher() matcher2 = get_rag_enhanced_matcher() assert matcher1 is matcher2 class TestSPARQLValidatorFactory: """Tests for the get_sparql_validator factory function.""" def test_factory_returns_instance(self): """Factory should return validator instance.""" from backend.rag.template_sparql import get_sparql_validator, SPARQLValidator validator = get_sparql_validator() assert isinstance(validator, SPARQLValidator) def test_factory_returns_singleton(self): """Factory should return same instance.""" from backend.rag.template_sparql import get_sparql_validator, _sparql_validator import 
backend.rag.template_sparql as module # Reset module._sparql_validator = None validator1 = get_sparql_validator() validator2 = get_sparql_validator() assert validator1 is validator2 # ============================================================================= # TIER METRICS TESTS # ============================================================================= class TestTierMetrics: """Tests for tier tracking in metrics.""" def test_rag_tier_in_stats(self): """RAG tier should be tracked in tier stats.""" from backend.rag.metrics import get_template_tier_stats stats = get_template_tier_stats() # Should include rag tier if stats.get("available"): assert "rag" in stats.get("tiers", {}) def test_record_template_tier_accepts_rag(self): """record_template_tier should accept 'rag' tier.""" from backend.rag.metrics import record_template_tier # Should not raise record_template_tier(tier="rag", matched=True, template_id="test_template") record_template_tier(tier="rag", matched=False) # ============================================================================= # VALIDATION RESULT TESTS # ============================================================================= class TestSPARQLValidationResult: """Tests for SPARQLValidationResult model.""" def test_valid_result_creation(self): """Should create valid result.""" from backend.rag.template_sparql import SPARQLValidationResult result = SPARQLValidationResult(valid=True) assert result.valid is True assert result.errors == [] assert result.warnings == [] assert result.suggestions == [] def test_invalid_result_with_errors(self): """Should create result with errors.""" from backend.rag.template_sparql import SPARQLValidationResult result = SPARQLValidationResult( valid=False, errors=["Unknown predicate: hc:foo"], suggestions=["Did you mean: hc:foaf?"] ) assert result.valid is False assert len(result.errors) == 1 assert len(result.suggestions) == 1 # ============================================================================= # 
# INTEGRATION TESTS
# =============================================================================

class TestFourTierFallback:
    """Integration tests for 4-tier matching fallback."""

    @pytest.fixture
    def classifier(self):
        """Get TemplateClassifier instance."""
        from backend.rag.template_sparql import TemplateClassifier
        return TemplateClassifier()

    def test_pattern_tier_takes_priority(self, classifier):
        """Tier 1 (pattern) should be checked first."""
        # A question intended to match a pattern exactly.
        question = "Welke musea zijn er in Amsterdam?"

        result = classifier.forward(question, language="nl")

        # Assertions only apply when a template matched (e.g. templates are
        # loaded); an unmatched result passes without checks.
        if result.matched:
            # Pattern matches should have high confidence
            assert result.confidence >= 0.75
            assert "pattern" in result.reasoning.lower() or result.confidence >= 0.90

    def test_embedding_tier_fallback(self, classifier):
        """Tier 2 (embedding) should be used when pattern fails."""
        # A paraphrased question that is not expected to match a pattern
        # exactly, so a later tier may handle it.
        question = "Geef mij een lijst van alle musea in de stad Amsterdam"

        result = classifier.forward(question, language="nl")

        # Only verifies that the fallback chain produces a result object; the
        # tier that produced it is not checked.
        assert result is not None


class TestValidatorPredicateSet:
    """Tests to verify the validator predicate/class sets."""

    def test_valid_predicates_not_empty(self):
        """Validator should have known predicates."""
        from backend.rag.template_sparql import SPARQLValidator

        validator = SPARQLValidator()

        assert len(validator.VALID_HC_PREDICATES) > 0
        assert len(validator.VALID_EXTERNAL_PREDICATES) > 0

    def test_institution_type_in_predicates(self):
        """institutionType should be a valid predicate."""
        from backend.rag.template_sparql import SPARQLValidator

        validator = SPARQLValidator()

        assert "hc:institutionType" in validator.VALID_HC_PREDICATES

    def test_settlement_name_in_predicates(self):
        """settlementName should be a valid predicate."""
        from backend.rag.template_sparql import SPARQLValidator

        validator = SPARQLValidator()

        assert "hc:settlementName" in validator.VALID_HC_PREDICATES


# =============================================================================
# SCHEMA-AWARE SLOT VALIDATOR TESTS
# =============================================================================

class TestSchemaAwareSlotValidator:
    """Tests for the SchemaAwareSlotValidator class."""

    @pytest.fixture
    def validator(self, monkeypatch):
        """Get a fresh validator instance (reset singleton).

        Class-level state is reset through ``monkeypatch`` so the previous
        values are restored at teardown instead of leaking between tests.
        """
        from backend.rag.template_sparql import SchemaAwareSlotValidator

        # raising=False mirrors plain assignment: the attribute is set even if
        # the class does not define it yet.
        monkeypatch.setattr(SchemaAwareSlotValidator, "_instance", None, raising=False)
        monkeypatch.setattr(SchemaAwareSlotValidator, "_valid_values", {}, raising=False)
        monkeypatch.setattr(SchemaAwareSlotValidator, "_synonym_maps", {}, raising=False)
        monkeypatch.setattr(SchemaAwareSlotValidator, "_loaded", False, raising=False)

        return SchemaAwareSlotValidator()

    def test_singleton_pattern(self, monkeypatch):
        """SchemaAwareSlotValidator should be a singleton."""
        from backend.rag.template_sparql import SchemaAwareSlotValidator

        monkeypatch.setattr(SchemaAwareSlotValidator, "_instance", None, raising=False)

        v1 = SchemaAwareSlotValidator()
        v2 = SchemaAwareSlotValidator()

        assert v1 is v2

    def test_direct_institution_type_match(self, validator):
        """Direct institution type should be valid."""
        result = validator.validate_slot("institution_type", "museum")

        assert result.valid is True
        assert result.corrected_value == "M"
        assert result.confidence == 1.0

    def test_code_institution_type_valid(self, validator):
        """Single-letter code should be valid."""
        result = validator.validate_slot("institution_type", "M")

        assert result.valid is True

    def test_dutch_plural_institution_type(self, validator):
        """Dutch plural forms should resolve correctly."""
        result = validator.validate_slot("institution_type", "musea")

        assert result.corrected_value == "M"

    def test_typo_auto_correction(self, validator):
        """Typos should be auto-corrected with fuzzy matching."""
        result = validator.validate_slot("institution_type", "msueum")

        # Expect a correction to "M" with confidence in [0.7, 1.0).
        assert result.corrected_value == "M"
        assert result.confidence < 1.0
        assert result.confidence >= 0.7

    def test_subregion_name_resolution(self, validator):
        """Subregion names should resolve to ISO codes."""
        result = validator.validate_slot("subregion", "noord-holland")

        assert result.valid is True
        assert result.corrected_value == "NL-NH"

    def test_subregion_code_valid(self, validator):
        """ISO subregion codes should be valid."""
        result = validator.validate_slot("subregion", "NL-NH")

        assert result.valid is True

    def test_country_name_resolution(self, validator):
        """Country names should resolve to codes."""
        result = validator.validate_slot("country", "nederland")

        assert result.valid is True
        assert result.corrected_value == "NL"

    def test_validate_multiple_slots(self, validator):
        """validate_slots should handle multiple slots."""
        slots = {
            "institution_type": "museum",
            "subregion": "noord-holland",
        }

        results = validator.validate_slots(slots)

        assert len(results) == 2
        assert results["institution_type"].corrected_value == "M"
        assert results["subregion"].corrected_value == "NL-NH"

    def test_get_corrected_slots(self, validator):
        """get_corrected_slots should return corrected values."""
        slots = {
            "institution_type": "bibliotheek",
            "subregion": "zuid-holland",
        }

        corrected = validator.get_corrected_slots(slots)

        # Only the institution type is asserted; the corrected subregion value
        # is not checked here.
        assert corrected["institution_type"] == "L"


class TestSchemaSlotValidatorFactory:
    """Tests for the get_schema_slot_validator factory."""

    def test_factory_returns_instance(self, monkeypatch):
        """Factory should return validator instance."""
        from backend.rag.template_sparql import get_schema_slot_validator, SchemaAwareSlotValidator

        # Reset the singleton; monkeypatch restores the prior value afterwards.
        monkeypatch.setattr(SchemaAwareSlotValidator, "_instance", None, raising=False)

        validator = get_schema_slot_validator()
        assert isinstance(validator, SchemaAwareSlotValidator)

    def test_factory_returns_singleton(self, monkeypatch):
        """Factory should return same instance."""
        from backend.rag.template_sparql import get_schema_slot_validator, SchemaAwareSlotValidator
        import backend.rag.template_sparql as module

        # Reset both the class singleton and the module-level cache; both are
        # restored automatically when the test finishes.
        monkeypatch.setattr(SchemaAwareSlotValidator, "_instance", None, raising=False)
        monkeypatch.setattr(module, "_schema_slot_validator", None, raising=False)

        v1 = get_schema_slot_validator()
        v2 = get_schema_slot_validator()

        assert v1 is v2


class TestSlotValidationResult:
    """Tests for SlotValidationResult model."""

    def test_valid_result_creation(self):
        """Should create valid result."""
        from backend.rag.template_sparql import SlotValidationResult

        result = SlotValidationResult(
            valid=True,
            original_value="museum",
            slot_name="institution_type",
            corrected_value="M"
        )

        assert result.valid is True
        assert result.original_value == "museum"
        assert result.corrected_value == "M"

    def test_invalid_result_with_suggestions(self):
        """Should create invalid result with suggestions."""
        from backend.rag.template_sparql import SlotValidationResult

        result = SlotValidationResult(
            valid=False,
            original_value="xyz",
            slot_name="institution_type",
            errors=["Invalid value 'xyz'"],
            suggestions=["Valid values include: M, L, A"]
        )

        assert result.valid is False
        assert len(result.errors) == 1
        assert len(result.suggestions) == 1


if __name__ == "__main__":
    # Propagate pytest's exit code so a failing run is visible to the shell/CI.
    raise SystemExit(pytest.main([__file__, "-v"]))