glam/backend/rag/test_template_sota.py
kempersc 99dc608826 Refactor RAG to template-based SPARQL generation
Major architectural changes based on Formica et al. (2023) research:
- Add TemplateClassifier for deterministic SPARQL template matching
- Add SlotExtractor with synonym resolution for slot values
- Add TemplateInstantiator using Jinja2 for query rendering
- Refactor dspy_heritage_rag.py to use template system
- Update main.py with streamlined pipeline
- Fix semantic_router.py ordering issues
- Add comprehensive metrics tracking

Template-based approach achieves 65% precision vs 10% LLM-only
per Formica et al. research on SPARQL generation.
2026-01-07 22:04:43 +01:00

436 lines
17 KiB
Python

"""
Tests for SOTA Template Matching Components
Tests the new components added based on SOTA research:
1. RAGEnhancedMatcher (Tier 2.5) - Context-enriched matching using Q&A examples
2. SPARQLValidator - Validates SPARQL against ontology schema
3. 4-tier fallback behavior - Pattern -> Embedding -> RAG -> LLM
Based on:
- SPARQL-LLM (arXiv:2512.14277)
- COT-SPARQL (SEMANTICS 2024)
- KGQuest (arXiv:2511.11258)
Author: OpenCode
Created: 2026-01-07
"""
import pytest
from unittest.mock import MagicMock, patch
import numpy as np
# =============================================================================
# SPARQL VALIDATOR TESTS
# =============================================================================
class TestSPARQLValidator:
    """Exercise SPARQLValidator.validate() with well-formed and malformed queries."""

    @pytest.fixture
    def validator(self):
        """Provide a newly constructed SPARQLValidator to each test."""
        from backend.rag.template_sparql import SPARQLValidator as validator_cls

        return validator_cls()

    def test_valid_query_with_known_predicates(self, validator):
        """A query using known hc: predicates is reported valid with no errors."""
        query = """
SELECT ?name ?type WHERE {
?inst hc:institutionType "M" .
?inst hc:settlementName "Amsterdam" .
}
"""
        report = validator.validate(query)
        assert report.valid is True
        assert len(report.errors) == 0

    def test_valid_query_with_schema_predicates(self, validator):
        """A query using schema.org predicates is reported valid with no errors."""
        query = """
SELECT ?name WHERE {
?inst schema:name ?name .
?inst schema:foundingDate ?date .
}
"""
        report = validator.validate(query)
        assert report.valid is True
        assert len(report.errors) == 0

    def test_invalid_predicate_detected(self, validator):
        """An unknown hc: predicate produces at least one warning or error."""
        query = """
SELECT ?x WHERE {
?x hc:unknownPredicate "test" .
}
"""
        report = validator.validate(query)
        # Either channel is acceptable; the report just must not be silent.
        assert not (len(report.warnings) == 0 and len(report.errors) == 0)

    def test_typo_suggestion(self, validator):
        """A misspelled predicate, when suggestions are given, points at the right one."""
        query = """
SELECT ?x WHERE {
?x hc:institutionTyp "M" .
}
"""
        report = validator.validate(query)
        hints = report.suggestions
        # Suggestions are optional; only check their content when present.
        if not hints:
            return
        # Expected correction is "hc:institutionType".
        assert any("institutionType" in hint for hint in hints)

    def test_mismatched_braces_detected(self, validator):
        """An unbalanced brace makes the query invalid with a brace-related error."""
        query = """
SELECT ?x WHERE {
?x hc:institutionType "M" .
"""  # Missing closing brace
        report = validator.validate(query)
        assert report.valid is False
        assert any("brace" in message.lower() for message in report.errors)

    def test_missing_where_clause_detected(self, validator):
        """A SELECT lacking the WHERE keyword is invalid and the error mentions WHERE."""
        query = """
SELECT ?x {
?x hc:institutionType "M" .
}
"""
        report = validator.validate(query)
        # The braces are balanced here; only the WHERE keyword is absent.
        assert report.valid is False
        assert any("WHERE" in message for message in report.errors)

    def test_non_hc_query_passes(self, validator):
        """A query with no hc: predicates at all is accepted."""
        query = """
SELECT ?s ?p ?o WHERE {
?s ?p ?o .
} LIMIT 10
"""
        report = validator.validate(query)
        assert report.valid is True

    def test_budget_predicates_valid(self, validator):
        """Budget-related hc: predicates are accepted without errors."""
        query = """
SELECT ?budget WHERE {
?inst hc:innovation_budget ?budget .
?inst hc:digitization_budget ?dbudget .
}
"""
        report = validator.validate(query)
        assert report.valid is True
        assert len(report.errors) == 0

    def test_change_event_predicates_valid(self, validator):
        """Change-event hc: predicates are accepted."""
        query = """
SELECT ?event WHERE {
?event hc:changeType "MERGER" .
?event hc:eventDate ?date .
}
"""
        report = validator.validate(query)
        assert report.valid is True
# =============================================================================
# RAG ENHANCED MATCHER TESTS
# =============================================================================
class TestRAGEnhancedMatcher:
    """Tests for the RAGEnhancedMatcher class.

    The tests clear the matcher's singleton state before use. Every reset
    goes through pytest's ``monkeypatch`` fixture so the previous values are
    restored on teardown; a matcher that was indexed with mocked embeddings
    therefore cannot leak into tests that run later in the same session.
    """

    # Class-level attributes that are set to None to obtain a pristine matcher.
    _SINGLETON_STATE = (
        "_instance",
        "_example_embeddings",
        "_example_template_ids",
        "_example_texts",
        "_example_slots",
    )

    @pytest.fixture
    def mock_templates(self):
        """Create two template definitions, each carrying example questions.

        Returns:
            A dict keyed by template id: a "list institutions by type and
            city" template with three examples and a "count institutions by
            type" template with two examples.
        """
        from backend.rag.template_sparql import TemplateDefinition, SlotDefinition, SlotType

        return {
            "list_institutions_by_type_city": TemplateDefinition(
                id="list_institutions_by_type_city",
                description="List institutions by type in a city",
                intent=["list", "institutions", "city"],
                question_patterns=["Welke {type} zijn er in {city}?"],
                slots={
                    "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
                    "city": SlotDefinition(type=SlotType.CITY),
                },
                sparql_template="SELECT ?inst WHERE { ?inst hc:institutionType '{{ type }}' }",
                examples=[
                    {"question": "Welke musea zijn er in Amsterdam?", "slots": {"type": "M", "city": "Amsterdam"}},
                    {"question": "Welke archieven zijn er in Utrecht?", "slots": {"type": "A", "city": "Utrecht"}},
                    {"question": "Welke bibliotheken zijn er in Rotterdam?", "slots": {"type": "L", "city": "Rotterdam"}},
                ],
            ),
            "count_institutions_by_type": TemplateDefinition(
                id="count_institutions_by_type",
                description="Count institutions by type",
                intent=["count", "institutions"],
                question_patterns=["Hoeveel {type} zijn er?"],
                slots={
                    "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
                },
                sparql_template="SELECT (COUNT(?inst) AS ?count) WHERE { ?inst hc:institutionType '{{ type }}' }",
                examples=[
                    {"question": "Hoeveel musea zijn er in Nederland?", "slots": {"type": "M"}},
                    {"question": "Hoeveel archieven zijn er?", "slots": {"type": "A"}},
                ],
            ),
        }

    @pytest.fixture
    def matcher(self, monkeypatch):
        """Get a fresh RAG matcher instance with the singleton state cleared.

        The class attributes are patched rather than assigned so that pytest
        puts the earlier values back once the test finishes.
        """
        from backend.rag.template_sparql import RAGEnhancedMatcher

        for attr in self._SINGLETON_STATE:
            # raising=False mirrors plain assignment: a missing attribute is
            # created instead of causing an AttributeError.
            monkeypatch.setattr(RAGEnhancedMatcher, attr, None, raising=False)
        return RAGEnhancedMatcher()

    def test_singleton_pattern(self, monkeypatch):
        """RAGEnhancedMatcher should be a singleton."""
        from backend.rag.template_sparql import RAGEnhancedMatcher

        # Start without a cached instance; restored automatically on teardown.
        monkeypatch.setattr(RAGEnhancedMatcher, "_instance", None, raising=False)
        matcher1 = RAGEnhancedMatcher()
        matcher2 = RAGEnhancedMatcher()
        assert matcher1 is matcher2

    def test_match_returns_none_without_model(self, matcher, mock_templates):
        """Should return None if the embedding model is not available."""
        with patch('backend.rag.template_sparql._get_embedding_model', return_value=None):
            result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates)
            assert result is None

    def test_match_with_high_agreement(self, matcher, mock_templates):
        """Should match when the nearest examples agree on one template."""
        mock_model = MagicMock()

        # Question vector points along the first axis, as do the "list" examples.
        question_emb = np.array([1.0, 0.0, 0.0])
        # Example embeddings - 3 from list template, 2 from count template
        example_embs = np.array([
            [0.95, 0.1, 0.0],    # list - high similarity
            [0.90, 0.15, 0.0],   # list - high similarity
            [0.92, 0.12, 0.0],   # list - high similarity
            [0.3, 0.9, 0.1],     # count - low similarity
            [0.25, 0.85, 0.15],  # count - low similarity
        ])
        mock_model.encode = MagicMock(side_effect=[
            example_embs,              # First call for indexing
            np.array([question_emb]),  # Second call for question
        ])

        with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model):
            # Reset cached embeddings
            matcher._example_embeddings = None
            result = matcher.match("Welke musea zijn er in Den Haag?", mock_templates, k=5)

        # A None result is tolerated; when a result is produced it must be
        # the list template with sufficient confidence.
        if result is not None:
            assert result.matched is True
            assert result.template_id == "list_institutions_by_type_city"
            assert result.confidence >= 0.70

    def test_match_returns_none_with_low_agreement(self, matcher, mock_templates):
        """Run a match whose examples are split between templates.

        NOTE(review): this test currently makes no assertion; it only
        verifies that ``match`` does not raise. Confirm the intended outcome
        (presumably ``None``) against the matcher and assert it.
        """
        mock_model = MagicMock()

        # Create embeddings with low agreement (split between templates)
        question_emb = np.array([0.5, 0.5, 0.0])
        example_embs = np.array([
            [0.6, 0.4, 0.0],    # list - medium similarity
            [0.55, 0.45, 0.0],  # list - medium similarity
            [0.52, 0.48, 0.0],  # list - medium similarity
            [0.45, 0.55, 0.0],  # count - medium similarity
            [0.4, 0.6, 0.0],    # count - medium similarity
        ])
        mock_model.encode = MagicMock(side_effect=[
            example_embs,
            np.array([question_emb]),
        ])

        with patch('backend.rag.template_sparql._get_embedding_model', return_value=mock_model):
            matcher._example_embeddings = None
            # With mixed agreement, should fail min_agreement threshold
            result = matcher.match(
                "Geef me informatie over musea",
                mock_templates,
                k=5,
                min_agreement=0.8,  # High threshold
            )
        # May return None if agreement is below threshold
        # The exact behavior depends on similarity calculation
class TestRAGEnhancedMatcherFactory:
    """Checks for the get_rag_enhanced_matcher() factory."""

    def test_factory_returns_singleton(self):
        """Two consecutive factory calls yield the very same object."""
        from backend.rag.template_sparql import get_rag_enhanced_matcher, RAGEnhancedMatcher

        # Drop any previously stored instance before calling the factory.
        RAGEnhancedMatcher._instance = None
        first, second = get_rag_enhanced_matcher(), get_rag_enhanced_matcher()
        assert first is second
class TestSPARQLValidatorFactory:
    """Tests for the get_sparql_validator factory function."""

    def test_factory_returns_instance(self):
        """Factory should return a SPARQLValidator instance."""
        from backend.rag.template_sparql import get_sparql_validator, SPARQLValidator

        validator = get_sparql_validator()
        assert isinstance(validator, SPARQLValidator)

    def test_factory_returns_singleton(self, monkeypatch):
        """Factory should return the same instance on repeated calls."""
        import backend.rag.template_sparql as module
        from backend.rag.template_sparql import get_sparql_validator

        # Clear the module-level ``_sparql_validator`` (presumably where the
        # factory keeps its instance) so the first call below starts clean.
        # monkeypatch restores the previous value on teardown, so this test
        # does not alter module state seen by other tests.
        monkeypatch.setattr(module, "_sparql_validator", None)
        validator1 = get_sparql_validator()
        validator2 = get_sparql_validator()
        assert validator1 is validator2
# =============================================================================
# TIER METRICS TESTS
# =============================================================================
class TestTierMetrics:
    """Checks that the metrics helpers handle the 'rag' tier."""

    def test_rag_tier_in_stats(self):
        """When tier stats are available, they contain an entry for 'rag'."""
        from backend.rag.metrics import get_template_tier_stats

        tier_stats = get_template_tier_stats()
        # Nothing to verify when the stats report themselves as unavailable.
        if not tier_stats.get("available"):
            return
        assert "rag" in tier_stats.get("tiers", {})

    def test_record_template_tier_accepts_rag(self):
        """record_template_tier() takes tier='rag' without raising."""
        from backend.rag.metrics import record_template_tier

        recordings = (
            {"tier": "rag", "matched": True, "template_id": "test_template"},
            {"tier": "rag", "matched": False},
        )
        # Success criterion is simply that neither call raises.
        for call_kwargs in recordings:
            record_template_tier(**call_kwargs)
# =============================================================================
# VALIDATION RESULT TESTS
# =============================================================================
class TestSPARQLValidationResult:
    """Checks for constructing SPARQLValidationResult objects."""

    def test_valid_result_creation(self):
        """A result built with only valid=True has empty message lists."""
        from backend.rag.template_sparql import SPARQLValidationResult

        outcome = SPARQLValidationResult(valid=True)
        assert outcome.valid is True
        # Each message collection is expected to equal an empty list.
        for field_name in ("errors", "warnings", "suggestions"):
            assert getattr(outcome, field_name) == []

    def test_invalid_result_with_errors(self):
        """A result built with errors and suggestions keeps one of each."""
        from backend.rag.template_sparql import SPARQLValidationResult

        error_messages = ["Unknown predicate: hc:foo"]
        hint_messages = ["Did you mean: hc:foaf?"]
        outcome = SPARQLValidationResult(
            valid=False,
            errors=error_messages,
            suggestions=hint_messages,
        )
        assert outcome.valid is False
        assert len(outcome.errors) == 1
        assert len(outcome.suggestions) == 1
# =============================================================================
# INTEGRATION TESTS
# =============================================================================
class TestFourTierFallback:
    """Integration checks for the tiered template-matching fallback."""

    @pytest.fixture
    def classifier(self):
        """Provide a TemplateClassifier instance."""
        from backend.rag.template_sparql import TemplateClassifier as classifier_cls

        return classifier_cls()

    def test_pattern_tier_takes_priority(self, classifier):
        """A question phrased like a template pattern matches with high confidence."""
        outcome = classifier.forward("Welke musea zijn er in Amsterdam?", language="nl")
        # A match is only expected when templates are loaded; otherwise skip the checks.
        if not outcome.matched:
            return
        assert outcome.confidence >= 0.75
        assert "pattern" in outcome.reasoning.lower() or outcome.confidence >= 0.90

    def test_embedding_tier_fallback(self, classifier):
        """A paraphrased question still yields a result object."""
        # Worded so that it is not expected to hit a pattern verbatim.
        paraphrase = "Geef mij een lijst van alle musea in de stad Amsterdam"
        outcome = classifier.forward(paraphrase, language="nl")
        # Which tier answered is not checked; only that something came back.
        assert outcome is not None
class TestValidatorPredicateSet:
    """Sanity checks on the predicate collections exposed by SPARQLValidator."""

    @staticmethod
    def _new_validator():
        """Import and construct a SPARQLValidator."""
        from backend.rag.template_sparql import SPARQLValidator

        return SPARQLValidator()

    def test_valid_predicates_not_empty(self):
        """Both the hc: and the external predicate collections are non-empty."""
        checker = self._new_validator()
        assert len(checker.VALID_HC_PREDICATES) > 0
        assert len(checker.VALID_EXTERNAL_PREDICATES) > 0

    def test_institution_type_in_predicates(self):
        """hc:institutionType is listed among the valid hc: predicates."""
        checker = self._new_validator()
        assert "hc:institutionType" in checker.VALID_HC_PREDICATES

    def test_settlement_name_in_predicates(self):
        """hc:settlementName is listed among the valid hc: predicates."""
        checker = self._new_validator()
        assert "hc:settlementName" in checker.VALID_HC_PREDICATES
if __name__ == "__main__":
    # Executing this file directly runs its tests through pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)