Major architectural changes based on Formica et al. (2023) research:

- Add TemplateClassifier for deterministic SPARQL template matching
- Add SlotExtractor with synonym resolution for slot values
- Add TemplateInstantiator using Jinja2 for query rendering
- Refactor dspy_heritage_rag.py to use template system
- Update main.py with streamlined pipeline
- Fix semantic_router.py ordering issues
- Add comprehensive metrics tracking

Template-based approach achieves 65% precision vs 10% LLM-only per Formica et al. research on SPARQL generation.
436 lines
17 KiB
Python
436 lines
17 KiB
Python
"""
Tests for SOTA Template Matching Components

Tests the new components added based on SOTA research:
1. RAGEnhancedMatcher (Tier 2.5) - Context-enriched matching using Q&A examples
2. SPARQLValidator - Validates SPARQL against ontology schema
3. 4-tier fallback behavior - Pattern -> Embedding -> RAG -> LLM

Based on:
- SPARQL-LLM (arXiv:2512.14277)
- COT-SPARQL (SEMANTICS 2024)
- KGQuest (arXiv:2511.11258)

Author: OpenCode
Created: 2025-01-07
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
import numpy as np
|
|
|
|
# =============================================================================
# SPARQL VALIDATOR TESTS
# =============================================================================
|
|
|
|
class TestSPARQLValidator:
    """Behavioural checks for ``SPARQLValidator.validate``.

    Each test feeds a small SPARQL snippet to a newly constructed validator
    and inspects the ``valid``, ``errors``, ``warnings`` and ``suggestions``
    attributes of the returned result object.
    """

    @pytest.fixture
    def validator(self):
        """Provide a newly constructed ``SPARQLValidator`` for each test."""
        from backend.rag.template_sparql import SPARQLValidator

        return SPARQLValidator()

    def test_valid_query_with_known_predicates(self, validator):
        """A query built from known hc: predicates is accepted without errors."""
        query = """
        SELECT ?name ?type WHERE {
            ?inst hc:institutionType "M" .
            ?inst hc:settlementName "Amsterdam" .
        }
        """
        outcome = validator.validate(query)

        assert outcome.valid is True
        assert not outcome.errors

    def test_valid_query_with_schema_predicates(self, validator):
        """A query built from schema.org predicates is accepted without errors."""
        query = """
        SELECT ?name WHERE {
            ?inst schema:name ?name .
            ?inst schema:foundingDate ?date .
        }
        """
        outcome = validator.validate(query)

        assert outcome.valid is True
        assert not outcome.errors

    def test_invalid_predicate_detected(self, validator):
        """An unknown hc: predicate produces at least one warning or error."""
        query = """
        SELECT ?x WHERE {
            ?x hc:unknownPredicate "test" .
        }
        """
        outcome = validator.validate(query)

        # Either channel is acceptable: the test does not pin the severity.
        assert outcome.warnings or outcome.errors

    def test_typo_suggestion(self, validator):
        """A near-miss predicate name should yield an ``institutionType`` hint."""
        query = """
        SELECT ?x WHERE {
            ?x hc:institutionTyp "M" .
        }
        """
        hints = validator.validate(query).suggestions

        # NOTE(review): when no suggestions are returned this assertion passes
        # vacuously, so the test does not prove that a suggestion is produced.
        assert not hints or any("institutionType" in hint for hint in hints)

    def test_mismatched_braces_detected(self, validator):
        """An unbalanced ``{`` makes the query invalid with a brace error."""
        # The closing brace is intentionally left out of this query.
        query = """
        SELECT ?x WHERE {
            ?x hc:institutionType "M" .
        """
        outcome = validator.validate(query)

        assert outcome.valid is False
        brace_errors = [msg for msg in outcome.errors if "brace" in msg.lower()]
        assert brace_errors

    def test_missing_where_clause_detected(self, validator):
        """A SELECT with a braced pattern but no WHERE keyword is rejected."""
        # Braces are balanced here; only the WHERE keyword is absent.
        query = """
        SELECT ?x {
            ?x hc:institutionType "M" .
        }
        """
        outcome = validator.validate(query)

        assert outcome.valid is False
        where_errors = [msg for msg in outcome.errors if "WHERE" in msg]
        assert where_errors

    def test_non_hc_query_passes(self, validator):
        """A query that uses no hc: predicates at all is treated as valid."""
        query = """
        SELECT ?s ?p ?o WHERE {
            ?s ?p ?o .
        } LIMIT 10
        """
        assert validator.validate(query).valid is True

    def test_budget_predicates_valid(self, validator):
        """The budget-related hc: predicates are accepted without errors."""
        query = """
        SELECT ?budget WHERE {
            ?inst hc:innovation_budget ?budget .
            ?inst hc:digitization_budget ?dbudget .
        }
        """
        outcome = validator.validate(query)

        assert outcome.valid is True
        assert not outcome.errors

    def test_change_event_predicates_valid(self, validator):
        """The change-event hc: predicates are accepted."""
        query = """
        SELECT ?event WHERE {
            ?event hc:changeType "MERGER" .
            ?event hc:eventDate ?date .
        }
        """
        assert validator.validate(query).valid is True
|
|
|
|
|
|
# =============================================================================
# RAG ENHANCED MATCHER TESTS
# =============================================================================
|
|
|
|
class TestRAGEnhancedMatcher:
    """Exercise ``RAGEnhancedMatcher`` with a stubbed embedding model.

    The embedding model lookup (``_get_embedding_model``) is patched in every
    matching test, so no real model is loaded.
    """

    @pytest.fixture
    def mock_templates(self):
        """Build two template definitions keyed by their id.

        The city-listing template carries three example questions and the
        counting template carries two, five examples in total.
        """
        from backend.rag.template_sparql import (
            SlotDefinition,
            SlotType,
            TemplateDefinition,
        )

        list_by_city = TemplateDefinition(
            id="list_institutions_by_type_city",
            description="List institutions by type in a city",
            intent=["list", "institutions", "city"],
            question_patterns=["Welke {type} zijn er in {city}?"],
            slots={
                "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
                "city": SlotDefinition(type=SlotType.CITY),
            },
            sparql_template="SELECT ?inst WHERE { ?inst hc:institutionType '{{ type }}' }",
            examples=[
                {
                    "question": "Welke musea zijn er in Amsterdam?",
                    "slots": {"type": "M", "city": "Amsterdam"},
                },
                {
                    "question": "Welke archieven zijn er in Utrecht?",
                    "slots": {"type": "A", "city": "Utrecht"},
                },
                {
                    "question": "Welke bibliotheken zijn er in Rotterdam?",
                    "slots": {"type": "L", "city": "Rotterdam"},
                },
            ],
        )
        count_by_type = TemplateDefinition(
            id="count_institutions_by_type",
            description="Count institutions by type",
            intent=["count", "institutions"],
            question_patterns=["Hoeveel {type} zijn er?"],
            slots={
                "type": SlotDefinition(type=SlotType.INSTITUTION_TYPE),
            },
            sparql_template="SELECT (COUNT(?inst) AS ?count) WHERE { ?inst hc:institutionType '{{ type }}' }",
            examples=[
                {
                    "question": "Hoeveel musea zijn er in Nederland?",
                    "slots": {"type": "M"},
                },
                {
                    "question": "Hoeveel archieven zijn er?",
                    "slots": {"type": "A"},
                },
            ],
        )
        return {
            "list_institutions_by_type_city": list_by_city,
            "count_institutions_by_type": count_by_type,
        }

    @pytest.fixture
    def matcher(self):
        """Return a matcher created after clearing the class-level state.

        The singleton slot and the four cached example attributes are set to
        ``None`` on the class before instantiating, so each test starts clean.
        """
        from backend.rag.template_sparql import RAGEnhancedMatcher

        RAGEnhancedMatcher._instance = None
        for cached_attr in (
            "_example_embeddings",
            "_example_template_ids",
            "_example_texts",
            "_example_slots",
        ):
            setattr(RAGEnhancedMatcher, cached_attr, None)
        return RAGEnhancedMatcher()

    def test_singleton_pattern(self):
        """Two constructor calls hand back the very same object."""
        from backend.rag.template_sparql import RAGEnhancedMatcher

        # Clear any instance left behind by an earlier test.
        RAGEnhancedMatcher._instance = None

        first = RAGEnhancedMatcher()
        second = RAGEnhancedMatcher()
        assert first is second

    def test_match_returns_none_without_model(self, matcher, mock_templates):
        """``match`` yields ``None`` when no embedding model can be obtained."""
        with patch(
            'backend.rag.template_sparql._get_embedding_model',
            return_value=None,
        ):
            outcome = matcher.match(
                "Welke musea zijn er in Den Haag?", mock_templates
            )

        assert outcome is None

    def test_match_with_high_agreement(self, matcher, mock_templates):
        """When the nearest examples share one template, that template wins."""
        # Question vector points along the first axis.
        question_vec = np.array([1.0, 0.0, 0.0])
        # One row per example: three "list" rows close to the question vector,
        # two "count" rows far from it.
        example_matrix = np.array([
            [0.95, 0.1, 0.0],    # list  - close to the question
            [0.90, 0.15, 0.0],   # list  - close to the question
            [0.92, 0.12, 0.0],   # list  - close to the question
            [0.3, 0.9, 0.1],     # count - far from the question
            [0.25, 0.85, 0.15],  # count - far from the question
        ])

        stub_model = MagicMock()
        # encode() is expected to be called for the examples first and for the
        # question second (per the original author's note) - TODO confirm.
        stub_model.encode = MagicMock(side_effect=[
            example_matrix,
            np.array([question_vec]),
        ])

        with patch(
            'backend.rag.template_sparql._get_embedding_model',
            return_value=stub_model,
        ):
            # Drop any cached example embeddings so the stub is consulted.
            matcher._example_embeddings = None
            outcome = matcher.match(
                "Welke musea zijn er in Den Haag?", mock_templates, k=5
            )

        # NOTE(review): a ``None`` outcome skips every assertion below, so the
        # test cannot fail when the matcher declines to match.
        if outcome is None:
            return
        assert outcome.matched is True
        assert outcome.template_id == "list_institutions_by_type_city"
        assert outcome.confidence >= 0.70

    def test_match_returns_none_with_low_agreement(self, matcher, mock_templates):
        """Run ``match`` with a strict agreement threshold on mixed examples."""
        # Question vector sits between the two groups of example vectors.
        question_vec = np.array([0.5, 0.5, 0.0])
        example_matrix = np.array([
            [0.6, 0.4, 0.0],    # list  - medium similarity
            [0.55, 0.45, 0.0],  # list  - medium similarity
            [0.52, 0.48, 0.0],  # list  - medium similarity
            [0.45, 0.55, 0.0],  # count - medium similarity
            [0.4, 0.6, 0.0],    # count - medium similarity
        ])

        stub_model = MagicMock()
        stub_model.encode = MagicMock(side_effect=[
            example_matrix,
            np.array([question_vec]),
        ])

        with patch(
            'backend.rag.template_sparql._get_embedding_model',
            return_value=stub_model,
        ):
            matcher._example_embeddings = None
            # NOTE(review): the return value is not asserted on. The original
            # author noted it "may" be None below the threshold, depending on
            # the similarity calculation, so this is only a does-not-raise
            # smoke test.
            matcher.match(
                "Geef me informatie over musea",
                mock_templates,
                k=5,
                min_agreement=0.8,  # deliberately strict threshold
            )
|
|
|
|
|
|
class TestRAGEnhancedMatcherFactory:
    """Checks for the ``get_rag_enhanced_matcher`` factory function."""

    def test_factory_returns_singleton(self):
        """Consecutive factory calls return the identical matcher object."""
        from backend.rag.template_sparql import (
            RAGEnhancedMatcher,
            get_rag_enhanced_matcher,
        )

        # Start from a cleared singleton slot.
        RAGEnhancedMatcher._instance = None

        first = get_rag_enhanced_matcher()
        second = get_rag_enhanced_matcher()
        assert first is second
|
|
|
|
|
|
class TestSPARQLValidatorFactory:
    """Tests for the ``get_sparql_validator`` factory function."""

    def test_factory_returns_instance(self):
        """Factory should return a ``SPARQLValidator`` instance."""
        from backend.rag.template_sparql import SPARQLValidator, get_sparql_validator

        validator = get_sparql_validator()
        assert isinstance(validator, SPARQLValidator)

    def test_factory_returns_singleton(self):
        """Factory should return the same instance on repeated calls.

        The module-level ``_sparql_validator`` cache is cleared for the
        duration of the test and put back afterwards, so the reset does not
        leak into tests that run later in the same session.
        """
        import backend.rag.template_sparql as module
        from backend.rag.template_sparql import get_sparql_validator

        # Reading the attribute also verifies that the cache global exists.
        previous = module._sparql_validator
        module._sparql_validator = None
        try:
            validator1 = get_sparql_validator()
            validator2 = get_sparql_validator()
            assert validator1 is validator2
        finally:
            # Restore whatever was cached before this test ran.
            module._sparql_validator = previous
|
|
|
|
|
|
# =============================================================================
# TIER METRICS TESTS
# =============================================================================
|
|
|
|
class TestTierMetrics:
    """Checks that the metrics helpers know about the ``rag`` tier."""

    def test_rag_tier_in_stats(self):
        """When stats report as available, the tier mapping contains ``rag``."""
        from backend.rag.metrics import get_template_tier_stats

        stats = get_template_tier_stats()

        # Nothing to verify when the stats report themselves as unavailable.
        if not stats.get("available"):
            return
        assert "rag" in stats.get("tiers", {})

    def test_record_template_tier_accepts_rag(self):
        """Recording a hit and a miss for the ``rag`` tier must not raise."""
        from backend.rag.metrics import record_template_tier

        recordings = (
            {"tier": "rag", "matched": True, "template_id": "test_template"},
            {"tier": "rag", "matched": False},
        )
        for call_kwargs in recordings:
            record_template_tier(**call_kwargs)
|
|
|
|
|
|
# =============================================================================
# VALIDATION RESULT TESTS
# =============================================================================
|
|
|
|
class TestSPARQLValidationResult:
    """Construction checks for the ``SPARQLValidationResult`` model."""

    def test_valid_result_creation(self):
        """A result built with only ``valid=True`` has empty list fields."""
        from backend.rag.template_sparql import SPARQLValidationResult

        outcome = SPARQLValidationResult(valid=True)

        assert outcome.valid is True
        for list_field in ("errors", "warnings", "suggestions"):
            assert getattr(outcome, list_field) == []

    def test_invalid_result_with_errors(self):
        """Errors and suggestions passed to the constructor are retained."""
        from backend.rag.template_sparql import SPARQLValidationResult

        init_kwargs = {
            "valid": False,
            "errors": ["Unknown predicate: hc:foo"],
            "suggestions": ["Did you mean: hc:foaf?"],
        }
        outcome = SPARQLValidationResult(**init_kwargs)

        assert outcome.valid is False
        assert len(outcome.errors) == 1
        assert len(outcome.suggestions) == 1
|
|
|
|
|
|
# =============================================================================
# INTEGRATION TESTS
# =============================================================================
|
|
|
|
class TestFourTierFallback:
    """Integration-level checks of ``TemplateClassifier.forward`` fallback."""

    @pytest.fixture
    def classifier(self):
        """Provide a ``TemplateClassifier`` constructed with no arguments."""
        from backend.rag.template_sparql import TemplateClassifier

        return TemplateClassifier()

    def test_pattern_tier_takes_priority(self, classifier):
        """A question phrased like a template pattern matches confidently."""
        outcome = classifier.forward(
            "Welke musea zijn er in Amsterdam?", language="nl"
        )

        # A match is only expected when templates are loaded; otherwise there
        # is nothing to check.
        if not outcome.matched:
            return

        assert outcome.confidence >= 0.75
        # Accept either an explicit mention of the pattern tier in the
        # reasoning text or a very high confidence score.
        mentions_pattern = "pattern" in outcome.reasoning.lower()
        assert mentions_pattern or outcome.confidence >= 0.90

    def test_embedding_tier_fallback(self, classifier):
        """A paraphrased question still produces a result object."""
        # Wording chosen so that it does not literally follow a pattern; which
        # tier ends up answering is not asserted here.
        paraphrase = "Geef mij een lijst van alle musea in de stad Amsterdam"

        outcome = classifier.forward(paraphrase, language="nl")

        # Only the existence of a result is verified.
        assert outcome is not None
|
|
|
|
|
|
class TestValidatorPredicateSet:
    """Sanity checks on the predicate collections exposed by the validator."""

    def test_valid_predicates_not_empty(self):
        """Both the hc: and the external predicate collections hold entries."""
        from backend.rag.template_sparql import SPARQLValidator

        instance = SPARQLValidator()

        for collection in (
            instance.VALID_HC_PREDICATES,
            instance.VALID_EXTERNAL_PREDICATES,
        ):
            assert len(collection) > 0

    def test_institution_type_in_predicates(self):
        """``hc:institutionType`` is listed among the hc: predicates."""
        from backend.rag.template_sparql import SPARQLValidator

        known = SPARQLValidator().VALID_HC_PREDICATES
        assert "hc:institutionType" in known

    def test_settlement_name_in_predicates(self):
        """``hc:settlementName`` is listed among the hc: predicates."""
        from backend.rag.template_sparql import SPARQLValidator

        known = SPARQLValidator().VALID_HC_PREDICATES
        assert "hc:settlementName" in known
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session exit code instead of exiting itself;
    # forward it so that running this file directly fails when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|