# glam/backend/rag/test_template_sota.py
# 2026-01-08 15:56:28 +01:00
#
# 588 lines
# 22 KiB
# Python

"""
Tests for SOTA Template Matching Components

Tests the new components added based on SOTA research:
1. RAGEnhancedMatcher (Tier 2.5) - Context-enriched matching using Q&A examples
2. SPARQLValidator - Validates SPARQL against ontology schema
3. 4-tier fallback behavior - Pattern -> Embedding -> RAG -> LLM

Based on:
- SPARQL-LLM (arXiv:2512.14277)
- COT-SPARQL (SEMANTICS 2024)
- KGQuest (arXiv:2511.11258)

Author: OpenCode
Created: 2025-01-07
"""
import pytest
from unittest.mock import MagicMock, patch
import numpy as np
# =============================================================================
# SPARQL VALIDATOR TESTS
# =============================================================================
class TestSPARQLValidator:
    """Tests for ``SPARQLValidator.validate``.

    Each test passes a small SPARQL query to a fresh validator and inspects
    the ``valid`` flag and/or the ``errors``, ``warnings`` and ``suggestions``
    collections of the returned result.
    """

    @pytest.fixture
    def validator(self):
        """Return a new ``SPARQLValidator`` instance for each test."""
        from backend.rag.template_sparql import SPARQLValidator

        return SPARQLValidator()

    def test_valid_query_with_known_predicates(self, validator):
        """Valid query using known hc: predicates should pass."""
        sparql = """
SELECT ?name ?type WHERE {
?inst hc:institutionType "M" .
?inst hc:settlementName "Amsterdam" .
}
"""
        result = validator.validate(sparql)
        assert result.valid is True
        assert len(result.errors) == 0

    def test_valid_query_with_schema_predicates(self, validator):
        """Valid query using schema.org predicates should pass."""
        sparql = """
SELECT ?name WHERE {
?inst schema:name ?name .
?inst schema:foundingDate ?date .
}
"""
        result = validator.validate(sparql)
        assert result.valid is True
        assert len(result.errors) == 0

    def test_invalid_predicate_detected(self, validator):
        """Unknown hc: predicate should be flagged."""
        sparql = """
SELECT ?x WHERE {
?x hc:unknownPredicate "test" .
}
"""
        result = validator.validate(sparql)
        # Either severity is accepted: the test only requires that the
        # unknown predicate is reported at all.
        assert len(result.warnings) > 0 or len(result.errors) > 0

    def test_typo_suggestion(self, validator):
        """Typo in predicate should be reported and, if suggested, corrected."""
        sparql = """
SELECT ?x WHERE {
?x hc:institutionTyp "M" .
}
"""
        result = validator.validate(sparql)
        # The misspelled predicate must produce some signal; without this
        # check the test would pass even if the typo went entirely unnoticed.
        assert result.warnings or result.errors or result.suggestions
        # Suggestions are optional, but when offered at least one of them
        # must name the intended predicate "hc:institutionType".
        if result.suggestions:
            assert any("institutionType" in s for s in result.suggestions)

    def test_mismatched_braces_detected(self, validator):
        """Mismatched braces should be flagged."""
        sparql = """
SELECT ?x WHERE {
?x hc:institutionType "M" .
"""  # Missing closing brace
        result = validator.validate(sparql)
        assert result.valid is False
        assert any("brace" in e.lower() for e in result.errors)

    def test_missing_where_clause_detected(self, validator):
        """SELECT without WHERE should be flagged."""
        sparql = """
SELECT ?x {
?x hc:institutionType "M" .
}
"""
        result = validator.validate(sparql)
        # Note: the query has balanced braces but no WHERE keyword.
        assert result.valid is False
        assert any("WHERE" in e for e in result.errors)

    def test_non_hc_query_passes(self, validator):
        """Query without hc: predicates should pass (not our responsibility)."""
        sparql = """
SELECT ?s ?p ?o WHERE {
?s ?p ?o .
} LIMIT 10
"""
        result = validator.validate(sparql)
        assert result.valid is True

    def test_budget_predicates_valid(self, validator):
        """Budget-related predicates should be valid."""
        sparql = """
SELECT ?budget WHERE {
?inst hc:innovation_budget ?budget .
?inst hc:digitization_budget ?dbudget .
}
"""
        result = validator.validate(sparql)
        assert result.valid is True
        assert len(result.errors) == 0

    def test_change_event_predicates_valid(self, validator):
        """Change event predicates should be valid."""
        sparql = """
SELECT ?event WHERE {
?event hc:changeType "MERGER" .
?event hc:eventDate ?date .
}
"""
        result = validator.validate(sparql)
        assert result.valid is True
# =============================================================================
# RAG ENHANCED MATCHER TESTS
# =============================================================================
class TestRAGEnhancedMatcher:
    """Tests for ``RAGEnhancedMatcher``.

    The embedding model is replaced by a ``MagicMock`` whose ``encode``
    returns hand-written vectors, so the tests control which examples look
    similar to the question.
    """

    @pytest.fixture
    def mock_templates(self):
        """Build two template definitions: three "list" and two "count" examples."""
        from backend.rag.template_sparql import TemplateDefinition, SlotDefinition, SlotType

        return {
            "list_institutions_by_type_city": TemplateDefinition(
                id="list_institutions_by_type_city",
                description="List institutions by type in a city",
                intent=["list", "institutions", "city"],
                question_patterns=["Welke {type} zijn er in {city}?"],
                slots={
                    "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
                    "city": SlotDefinition(type=SlotType.CITY),
                },
                sparql_template="SELECT ?inst WHERE { ?inst hc:institutionType '{{ type }}' }",
                examples=[
                    {"question": "Welke musea zijn er in Amsterdam?", "slots": {"type": "M", "city": "Amsterdam"}},
                    {"question": "Welke archieven zijn er in Utrecht?", "slots": {"type": "A", "city": "Utrecht"}},
                    {"question": "Welke bibliotheken zijn er in Rotterdam?", "slots": {"type": "L", "city": "Rotterdam"}},
                ],
            ),
            "count_institutions_by_type": TemplateDefinition(
                id="count_institutions_by_type",
                description="Count institutions by type",
                intent=["count", "institutions"],
                question_patterns=["Hoeveel {type} zijn er?"],
                slots={
                    "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
                },
                sparql_template="SELECT (COUNT(?inst) AS ?count) WHERE { ?inst hc:institutionType '{{ type }}' }",
                examples=[
                    {"question": "Hoeveel musea zijn er in Nederland?", "slots": {"type": "M"}},
                    {"question": "Hoeveel archieven zijn er?", "slots": {"type": "A"}},
                ],
            ),
        }

    @pytest.fixture
    def matcher(self):
        """Reset the class-level singleton/cache attributes and return a new matcher."""
        from backend.rag.template_sparql import RAGEnhancedMatcher

        # Reset singleton state for clean tests
        RAGEnhancedMatcher._instance = None
        RAGEnhancedMatcher._example_embeddings = None
        RAGEnhancedMatcher._example_template_ids = None
        RAGEnhancedMatcher._example_texts = None
        RAGEnhancedMatcher._example_slots = None
        return RAGEnhancedMatcher()

    def test_singleton_pattern(self):
        """RAGEnhancedMatcher should be a singleton."""
        from backend.rag.template_sparql import RAGEnhancedMatcher

        # Reset first
        RAGEnhancedMatcher._instance = None
        matcher1 = RAGEnhancedMatcher()
        matcher2 = RAGEnhancedMatcher()
        assert matcher1 is matcher2

    def test_match_returns_none_without_model(self, matcher, mock_templates):
        """Should return None if embedding model not available."""
        with patch('backend.rag.template_sparql._get_embedding_model', return_value=None):
            result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates)
            assert result is None

    def test_match_with_high_agreement(self, matcher, mock_templates):
        """Should match when examples agree on template."""
        # Create mock embedding model
        mock_model = MagicMock()
        # Create embeddings that will give high similarity for "list" template
        # The question embedding
        question_emb = np.array([1.0, 0.0, 0.0])
        # Example embeddings - 3 from list template, 2 from count template
        example_embs = np.array([
            [0.95, 0.1, 0.0],  # list - high similarity
            [0.90, 0.15, 0.0],  # list - high similarity
            [0.92, 0.12, 0.0],  # list - high similarity
            [0.3, 0.9, 0.1],  # count - low similarity
            [0.25, 0.85, 0.15],  # count - low similarity
        ])
        mock_model.encode = MagicMock(side_effect=[
            example_embs,  # First call for indexing
            np.array([question_emb]),  # Second call for question
        ])
        with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model):
            # Reset cached embeddings
            matcher._example_embeddings = None
            result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates, k=5)
            # A None result is tolerated here; when a match is returned it
            # must be the list template with sufficient confidence.
            if result is not None:
                assert result.matched is True
                assert result.template_id == "list_institutions_by_type_city"
                assert result.confidence >= 0.70

    def test_match_returns_none_with_low_agreement(self, matcher, mock_templates):
        """Should return None when examples don't agree."""
        mock_model = MagicMock()
        # Create embeddings with low agreement (split between templates)
        question_emb = np.array([0.5, 0.5, 0.0])
        example_embs = np.array([
            [0.6, 0.4, 0.0],  # list - medium similarity
            [0.55, 0.45, 0.0],  # list - medium similarity
            [0.52, 0.48, 0.0],  # list - medium similarity
            [0.45, 0.55, 0.0],  # count - medium similarity
            [0.4, 0.6, 0.0],  # count - medium similarity
        ])
        mock_model.encode = MagicMock(side_effect=[
            example_embs,
            np.array([question_emb]),
        ])
        with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model):
            matcher._example_embeddings = None
            # With mixed agreement, should fail min_agreement threshold
            result = matcher.match(
                "Geef me informatie over musea",
                mock_templates,
                k=5,
                min_agreement=0.8  # High threshold
            )
        # The five examples split 3 (list) / 2 (count), so no template can
        # reach the 0.8 agreement threshold. Previously this test made no
        # assertion at all and could never fail.
        # NOTE(review): assumes the matcher signals "no match" with None, as
        # it does when the embedding model is unavailable - confirm.
        assert result is None
class TestRAGEnhancedMatcherFactory:
    """Checks on the ``get_rag_enhanced_matcher`` factory function."""

    def test_factory_returns_singleton(self):
        """Two consecutive factory calls must return the identical object."""
        from backend.rag.template_sparql import get_rag_enhanced_matcher, RAGEnhancedMatcher

        # Clear the class-level instance before exercising the factory.
        RAGEnhancedMatcher._instance = None
        first = get_rag_enhanced_matcher()
        second = get_rag_enhanced_matcher()
        assert first is second
class TestSPARQLValidatorFactory:
    """Tests for the ``get_sparql_validator`` factory function."""

    def test_factory_returns_instance(self):
        """Factory should return validator instance."""
        from backend.rag.template_sparql import get_sparql_validator, SPARQLValidator

        validator = get_sparql_validator()
        assert isinstance(validator, SPARQLValidator)

    def test_factory_returns_singleton(self):
        """Factory should return same instance."""
        import backend.rag.template_sparql as module
        from backend.rag.template_sparql import get_sparql_validator

        # Reset the module-level ``_sparql_validator`` through the module
        # object. A name bound with ``from ... import _sparql_validator``
        # would only be a local copy of the old value, so that unused import
        # has been dropped.
        module._sparql_validator = None
        validator1 = get_sparql_validator()
        validator2 = get_sparql_validator()
        assert validator1 is validator2
# =============================================================================
# TIER METRICS TESTS
# =============================================================================
class TestTierMetrics:
    """Checks that the metrics helpers handle the ``rag`` tier."""

    def test_rag_tier_in_stats(self):
        """When stats report as available, their tier mapping must contain ``rag``."""
        from backend.rag.metrics import get_template_tier_stats

        tier_stats = get_template_tier_stats()
        if not tier_stats.get("available"):
            # Nothing to verify when the stats are not flagged as available.
            return
        assert "rag" in tier_stats.get("tiers", {})

    def test_record_template_tier_accepts_rag(self):
        """Recording a ``rag`` tier event, matched or not, must not raise."""
        from backend.rag.metrics import record_template_tier

        # Matched event carrying a template id.
        record_template_tier(tier="rag", matched=True, template_id="test_template")
        # Unmatched event without a template id.
        record_template_tier(tier="rag", matched=False)
# =============================================================================
# VALIDATION RESULT TESTS
# =============================================================================
class TestSPARQLValidationResult:
    """Construction checks for the ``SPARQLValidationResult`` model."""

    def test_valid_result_creation(self):
        """A result built with only ``valid=True`` has empty message lists."""
        from backend.rag.template_sparql import SPARQLValidationResult

        outcome = SPARQLValidationResult(valid=True)
        assert outcome.valid is True
        # The three message collections are compared against an empty list.
        assert outcome.errors == []
        assert outcome.warnings == []
        assert outcome.suggestions == []

    def test_invalid_result_with_errors(self):
        """Errors and suggestions passed to the constructor are kept."""
        from backend.rag.template_sparql import SPARQLValidationResult

        error_messages = ["Unknown predicate: hc:foo"]
        hints = ["Did you mean: hc:foaf?"]
        outcome = SPARQLValidationResult(
            valid=False,
            errors=error_messages,
            suggestions=hints,
        )
        assert outcome.valid is False
        assert len(outcome.errors) == 1
        assert len(outcome.suggestions) == 1
# =============================================================================
# INTEGRATION TESTS
# =============================================================================
class TestFourTierFallback:
    """Integration checks that run questions through ``TemplateClassifier.forward``."""

    @pytest.fixture
    def classifier(self):
        """Return a newly constructed ``TemplateClassifier``."""
        from backend.rag.template_sparql import TemplateClassifier

        return TemplateClassifier()

    def test_pattern_tier_takes_priority(self, classifier):
        """A question phrased like a template pattern should match confidently."""
        outcome = classifier.forward("Welke musea zijn er in Amsterdam?", language="nl")
        if not outcome.matched:
            # No match (e.g. templates not loaded): nothing further to check.
            return
        # A match must clear the confidence floor and either mention the
        # pattern tier in its reasoning or be very confident.
        assert outcome.confidence >= 0.75
        assert "pattern" in outcome.reasoning.lower() or outcome.confidence >= 0.90

    def test_embedding_tier_fallback(self, classifier):
        """A paraphrased question must still yield a result object."""
        # Paraphrase that is not expected to match a pattern verbatim, so a
        # later tier (embedding or LLM) may handle it.
        paraphrase = "Geef mij een lijst van alle musea in de stad Amsterdam"
        outcome = classifier.forward(paraphrase, language="nl")
        # Only the existence of a result is verified, not which tier produced it.
        assert outcome is not None
class TestValidatorPredicateSet:
    """Sanity checks on the predicate collections exposed by ``SPARQLValidator``."""

    def test_valid_predicates_not_empty(self):
        """Both the hc: and the external predicate collections are non-empty."""
        from backend.rag.template_sparql import SPARQLValidator

        instance = SPARQLValidator()
        assert len(instance.VALID_HC_PREDICATES) > 0
        assert len(instance.VALID_EXTERNAL_PREDICATES) > 0

    def test_institution_type_in_predicates(self):
        """``hc:institutionType`` is listed among the hc: predicates."""
        from backend.rag.template_sparql import SPARQLValidator

        known = SPARQLValidator().VALID_HC_PREDICATES
        assert "hc:institutionType" in known

    def test_settlement_name_in_predicates(self):
        """``hc:settlementName`` is listed among the hc: predicates."""
        from backend.rag.template_sparql import SPARQLValidator

        known = SPARQLValidator().VALID_HC_PREDICATES
        assert "hc:settlementName" in known
# =============================================================================
# SCHEMA-AWARE SLOT VALIDATOR TESTS
# =============================================================================
class TestSchemaAwareSlotValidator:
    """Checks on slot validation and correction by ``SchemaAwareSlotValidator``."""

    @pytest.fixture
    def validator(self):
        """Clear the class-level singleton state and return a new validator."""
        from backend.rag.template_sparql import SchemaAwareSlotValidator

        # Wipe the cached instance and its loaded data before each test.
        SchemaAwareSlotValidator._instance = None
        SchemaAwareSlotValidator._valid_values = {}
        SchemaAwareSlotValidator._synonym_maps = {}
        SchemaAwareSlotValidator._loaded = False
        return SchemaAwareSlotValidator()

    def test_singleton_pattern(self):
        """Constructing the class twice yields the identical object."""
        from backend.rag.template_sparql import SchemaAwareSlotValidator

        SchemaAwareSlotValidator._instance = None
        first = SchemaAwareSlotValidator()
        second = SchemaAwareSlotValidator()
        assert first is second

    def test_direct_institution_type_match(self, validator):
        """``museum`` is valid and maps to ``M`` with full confidence."""
        outcome = validator.validate_slot("institution_type", "museum")
        assert outcome.valid is True
        assert outcome.corrected_value == "M"
        assert outcome.confidence == 1.0

    def test_code_institution_type_valid(self, validator):
        """The single-letter code ``M`` is accepted as-is."""
        outcome = validator.validate_slot("institution_type", "M")
        assert outcome.valid is True

    def test_dutch_plural_institution_type(self, validator):
        """The Dutch plural ``musea`` resolves to ``M``."""
        outcome = validator.validate_slot("institution_type", "musea")
        assert outcome.corrected_value == "M"

    def test_typo_auto_correction(self, validator):
        """A misspelling is corrected to ``M`` with reduced confidence."""
        outcome = validator.validate_slot("institution_type", "msueum")
        assert outcome.corrected_value == "M"
        # Confidence must be below a perfect score but not under 0.7.
        assert outcome.confidence < 1.0
        assert outcome.confidence >= 0.7

    def test_subregion_name_resolution(self, validator):
        """The subregion name ``noord-holland`` resolves to ``NL-NH``."""
        outcome = validator.validate_slot("subregion", "noord-holland")
        assert outcome.valid is True
        assert outcome.corrected_value == "NL-NH"

    def test_subregion_code_valid(self, validator):
        """The subregion code ``NL-NH`` is accepted as-is."""
        outcome = validator.validate_slot("subregion", "NL-NH")
        assert outcome.valid is True

    def test_country_name_resolution(self, validator):
        """The country name ``nederland`` resolves to ``NL``."""
        outcome = validator.validate_slot("country", "nederland")
        assert outcome.valid is True
        assert outcome.corrected_value == "NL"

    def test_validate_multiple_slots(self, validator):
        """``validate_slots`` returns one result per supplied slot."""
        raw_slots = {
            "institution_type": "museum",
            "subregion": "noord-holland",
        }
        per_slot = validator.validate_slots(raw_slots)
        assert len(per_slot) == 2
        assert per_slot["institution_type"].corrected_value == "M"
        assert per_slot["subregion"].corrected_value == "NL-NH"

    def test_get_corrected_slots(self, validator):
        """``get_corrected_slots`` maps ``bibliotheek`` to ``L``."""
        raw_slots = {
            "institution_type": "bibliotheek",
            "subregion": "zuid-holland",
        }
        fixed = validator.get_corrected_slots(raw_slots)
        assert fixed["institution_type"] == "L"
class TestSchemaSlotValidatorFactory:
    """Checks on the ``get_schema_slot_validator`` factory function."""

    def test_factory_returns_instance(self):
        """The factory returns a ``SchemaAwareSlotValidator``."""
        from backend.rag.template_sparql import get_schema_slot_validator, SchemaAwareSlotValidator

        # Clear the class-level instance before calling the factory.
        SchemaAwareSlotValidator._instance = None
        produced = get_schema_slot_validator()
        assert isinstance(produced, SchemaAwareSlotValidator)

    def test_factory_returns_singleton(self):
        """Two consecutive factory calls return the identical object."""
        from backend.rag.template_sparql import get_schema_slot_validator, SchemaAwareSlotValidator
        import backend.rag.template_sparql as module

        # Clear both the class-level instance and the module-level cache.
        SchemaAwareSlotValidator._instance = None
        module._schema_slot_validator = None
        first = get_schema_slot_validator()
        second = get_schema_slot_validator()
        assert first is second
class TestSlotValidationResult:
    """Construction checks for the ``SlotValidationResult`` model."""

    def test_valid_result_creation(self):
        """Fields passed for a valid result are readable afterwards."""
        from backend.rag.template_sparql import SlotValidationResult

        outcome = SlotValidationResult(
            valid=True,
            original_value="museum",
            slot_name="institution_type",
            corrected_value="M",
        )
        assert outcome.valid is True
        assert outcome.original_value == "museum"
        assert outcome.corrected_value == "M"

    def test_invalid_result_with_suggestions(self):
        """An invalid result keeps its single error and single suggestion."""
        from backend.rag.template_sparql import SlotValidationResult

        error_messages = ["Invalid value 'xyz'"]
        hints = ["Valid values include: M, L, A"]
        outcome = SlotValidationResult(
            valid=False,
            original_value="xyz",
            slot_name="institution_type",
            errors=error_messages,
            suggestions=hints,
        )
        assert outcome.valid is False
        assert len(outcome.errors) == 1
        assert len(outcome.suggestions) == 1
if __name__ == "__main__":
    # pytest.main returns the run's exit code; pass it on so that executing
    # this file directly exits non-zero when tests fail instead of always 0.
    raise SystemExit(pytest.main([__file__, "-v"]))