338 lines
13 KiB
Python
338 lines
13 KiB
Python
"""
|
|
Layer 1: Unit Tests - Fast tests without LLM calls
|
|
|
|
Tests core components:
|
|
- Semantic signal extraction
|
|
- Query routing rules
|
|
- Entity extraction patterns
|
|
- SPARQL template selection
|
|
- Metrics calculations
|
|
|
|
Target: < 10 seconds, 100% pass rate required for merge
|
|
"""
|
|
|
|
import pytest
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Make the repository root importable so the absolute imports used in this
# module (``backend.rag.*`` and ``tests.dspy_gitops.*``) resolve. The three
# ``.parent`` hops assume this file lives at <repo>/tests/dspy_gitops/ -
# TODO confirm if the file is moved.
_PROJECT_ROOT = Path(__file__).parent.parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

# backend/rag is still put at the front of sys.path exactly as before,
# presumably for modules there that import each other by bare module name -
# verify before removing.
sys.path.insert(0, str(_PROJECT_ROOT / "backend" / "rag"))
|
|
|
|
|
|
# =============================================================================
# Semantic Router Tests
# =============================================================================
|
|
|
|
class TestSemanticSignalExtractor:
    """Rule-based semantic signal extraction; none of these tests call an LLM."""

    @pytest.fixture
    def extractor(self):
        """Build the ``SemanticSignalExtractor`` under test."""
        # Imported inside the fixture, so an import failure is reported when
        # the fixture runs rather than when this module is imported.
        from backend.rag.semantic_router import SemanticSignalExtractor as Extractor

        return Extractor()

    def test_detect_person_entity_type(self, extractor):
        """A 'Wie is ...' (who is) question about a role is typed as a person query."""
        question = "Wie is de directeur van het Rijksmuseum?"
        extracted = extractor.extract_signals(question)
        assert extracted.entity_type == "person"

    def test_detect_institution_entity_type(self, extractor):
        """Counting museums in a city is typed as an institution query."""
        question = "Hoeveel musea zijn er in Amsterdam?"
        extracted = extractor.extract_signals(question)
        assert extracted.entity_type == "institution"

    def test_detect_statistical_intent(self, extractor):
        """A 'Hoeveel ...' (how many) question is statistical and needs aggregation."""
        question = "Hoeveel archieven zijn er in Noord-Holland?"
        extracted = extractor.extract_signals(question)
        assert extracted.intent == "statistical"
        assert extracted.requires_aggregation is True

    def test_detect_temporal_intent(self, extractor):
        """A question bounded by a founding date is temporal and flags a time constraint."""
        question = "Welke musea zijn opgericht voor 1900?"
        extracted = extractor.extract_signals(question)
        assert extracted.intent == "temporal"
        assert extracted.has_temporal_constraint is True

    def test_detect_geographic_constraint(self, extractor):
        """A place name sets the geographic flag and is captured as a location."""
        question = "Welke bibliotheken zijn er in Leiden?"
        extracted = extractor.extract_signals(question)
        assert extracted.has_geographic_constraint is True
        assert "Leiden" in extracted.location_mentions

    def test_detect_dutch_language(self, extractor):
        """A Dutch question is tagged with language code 'nl'."""
        question = "Hoeveel musea zijn er in Nederland?"
        extracted = extractor.extract_signals(question)
        assert extracted.language == "nl"

    def test_detect_english_language(self, extractor):
        """An English question is tagged with language code 'en'."""
        question = "How many museums are there in the Netherlands?"
        extracted = extractor.extract_signals(question)
        assert extracted.language == "en"

    def test_extract_institutions(self, extractor):
        """A named institution is either extracted or at least types the query."""
        question = "Wat is de collectie van het Rijksmuseum?"
        extracted = extractor.extract_signals(question)
        # Either an explicit Rijksmuseum mention or, failing that, an
        # institution-typed query is accepted (the second check only runs
        # when the first one fails).
        assert (
            any("rijksmuseum" in mention.lower() for mention in extracted.institution_mentions)
            or extracted.entity_type == "institution"
        )

    def test_year_pattern_detection(self, extractor):
        """A year such as 1850 is enough to flag a temporal constraint."""
        question = "Musea gesticht in 1850"
        extracted = extractor.extract_signals(question)
        assert extracted.has_temporal_constraint is True
|
|
|
|
|
|
class TestSemanticDecisionRouter:
    """Routing decisions derived from extracted semantic signals."""

    @pytest.fixture
    def router(self):
        """Build the ``SemanticDecisionRouter`` under test."""
        from backend.rag.semantic_router import SemanticDecisionRouter as Router

        return Router()

    @pytest.fixture
    def extractor(self):
        """Build a ``SemanticSignalExtractor`` that turns questions into signals."""
        from backend.rag.semantic_router import SemanticSignalExtractor as Extractor

        return Extractor()

    @staticmethod
    def _route_question(router, extractor, question):
        """Extract signals from *question*, then return *router*'s route for them."""
        return router.route(extractor.extract_signals(question))

    def test_route_person_query_to_qdrant(self, router, extractor):
        """Person questions go to Qdrant's ``heritage_persons`` collection."""
        decision = self._route_question(
            router, extractor, "Wie werkt als archivaris bij het Nationaal Archief?"
        )
        assert decision.primary_backend == "qdrant"
        assert decision.qdrant_collection == "heritage_persons"

    def test_route_statistical_to_sparql(self, router, extractor):
        """Counting questions are answered through SPARQL."""
        decision = self._route_question(
            router, extractor, "Hoeveel musea zijn er in Amsterdam?"
        )
        assert decision.primary_backend == "sparql"

    def test_route_temporal_with_templates(self, router, extractor):
        """Date-bounded questions switch on the temporal query templates."""
        decision = self._route_question(
            router, extractor, "Welke archieven zijn opgericht na 1945?"
        )
        assert decision.use_temporal_templates is True
|
|
|
|
|
|
# =============================================================================
# Metrics Unit Tests
# =============================================================================
|
|
|
|
class TestIntentAccuracyMetrics:
|
|
"""Test intent accuracy calculations."""
|
|
|
|
def test_exact_match_returns_1(self):
|
|
from tests.dspy_gitops.metrics.intent_accuracy import intent_accuracy
|
|
assert intent_accuracy("statistical", "statistical") == 1.0
|
|
|
|
def test_case_insensitive_match(self):
|
|
from tests.dspy_gitops.metrics.intent_accuracy import intent_accuracy
|
|
assert intent_accuracy("Statistical", "statistical") == 1.0
|
|
|
|
def test_no_match_returns_0(self):
|
|
from tests.dspy_gitops.metrics.intent_accuracy import intent_accuracy
|
|
assert intent_accuracy("statistical", "temporal") == 0.0
|
|
|
|
def test_similarity_gives_partial_credit(self):
|
|
from tests.dspy_gitops.metrics.intent_accuracy import intent_similarity_score
|
|
# Similar intents should get partial credit
|
|
score = intent_similarity_score("statistical", "exploration")
|
|
assert 0 < score < 1
|
|
|
|
|
|
class TestEntityExtractionMetrics:
|
|
"""Test entity extraction metrics."""
|
|
|
|
def test_perfect_f1(self):
|
|
from tests.dspy_gitops.metrics.entity_extraction import entity_f1
|
|
expected = ["amsterdam", "museum"]
|
|
predicted = ["amsterdam", "museum"]
|
|
assert entity_f1(expected, predicted) == 1.0
|
|
|
|
def test_partial_match_f1(self):
|
|
from tests.dspy_gitops.metrics.entity_extraction import entity_f1
|
|
expected = ["amsterdam", "museum", "library"]
|
|
predicted = ["amsterdam", "museum"]
|
|
score = entity_f1(expected, predicted)
|
|
assert 0 < score < 1 # Should be partial credit
|
|
|
|
def test_no_match_f1(self):
|
|
from tests.dspy_gitops.metrics.entity_extraction import entity_f1
|
|
expected = ["amsterdam"]
|
|
predicted = ["rotterdam"]
|
|
assert entity_f1(expected, predicted) == 0.0
|
|
|
|
def test_precision_calculation(self):
|
|
from tests.dspy_gitops.metrics.entity_extraction import entity_precision
|
|
expected = ["amsterdam"]
|
|
predicted = ["amsterdam", "rotterdam"]
|
|
# Precision = 1 correct / 2 predicted = 0.5
|
|
assert entity_precision(expected, predicted) == 0.5
|
|
|
|
def test_recall_calculation(self):
|
|
from tests.dspy_gitops.metrics.entity_extraction import entity_recall
|
|
expected = ["amsterdam", "rotterdam"]
|
|
predicted = ["amsterdam"]
|
|
# Recall = 1 correct / 2 expected = 0.5
|
|
assert entity_recall(expected, predicted) == 0.5
|
|
|
|
def test_empty_expected_recall(self):
|
|
from tests.dspy_gitops.metrics.entity_extraction import entity_recall
|
|
# If nothing expected, recall is perfect
|
|
assert entity_recall([], ["something"]) == 1.0
|
|
|
|
|
|
class TestSPARQLMetrics:
|
|
"""Test SPARQL validation metrics."""
|
|
|
|
def test_valid_sparql_syntax(self):
|
|
from tests.dspy_gitops.metrics.sparql_correctness import validate_sparql_syntax
|
|
sparql = """
|
|
PREFIX hc: <https://nde.nl/ontology/hc/>
|
|
SELECT ?s WHERE { ?s a hc:Custodian }
|
|
"""
|
|
is_valid, error = validate_sparql_syntax(sparql)
|
|
assert is_valid is True
|
|
assert error is None
|
|
|
|
def test_invalid_sparql_missing_where(self):
|
|
from tests.dspy_gitops.metrics.sparql_correctness import validate_sparql_syntax
|
|
sparql = "SELECT ?s"
|
|
is_valid, error = validate_sparql_syntax(sparql)
|
|
assert is_valid is False
|
|
assert "WHERE" in error
|
|
|
|
def test_invalid_sparql_unbalanced_braces(self):
|
|
from tests.dspy_gitops.metrics.sparql_correctness import validate_sparql_syntax
|
|
sparql = "SELECT ?s WHERE { ?s a hc:Custodian" # Missing closing brace
|
|
is_valid, error = validate_sparql_syntax(sparql)
|
|
assert is_valid is False
|
|
assert "brace" in error.lower()
|
|
|
|
def test_sparql_validation_score(self):
|
|
from tests.dspy_gitops.metrics.sparql_correctness import sparql_validation_score
|
|
valid_sparql = """
|
|
PREFIX hc: <https://nde.nl/ontology/hc/>
|
|
PREFIX crm: <http://www.cidoc-crm.org/cidoc-crm/>
|
|
SELECT (COUNT(?s) as ?count) WHERE {
|
|
?s a crm:E39_Actor ;
|
|
hc:institutionType "M" .
|
|
}
|
|
"""
|
|
score = sparql_validation_score(valid_sparql)
|
|
assert score > 0.8
|
|
|
|
|
|
class TestAnswerRelevanceMetrics:
|
|
"""Test answer relevance calculations."""
|
|
|
|
def test_answer_has_content(self):
|
|
from tests.dspy_gitops.metrics.answer_relevance import answer_has_content
|
|
assert answer_has_content("Er zijn 45 musea in Amsterdam.") is True
|
|
assert answer_has_content("I don't know") is False
|
|
assert answer_has_content("") is False
|
|
assert answer_has_content(" ") is False
|
|
|
|
def test_answer_mentions_entities(self):
|
|
from tests.dspy_gitops.metrics.answer_relevance import answer_mentions_entities
|
|
answer = "Er zijn 45 musea in Amsterdam, waaronder het Rijksmuseum."
|
|
entities = ["amsterdam", "rijksmuseum"]
|
|
score = answer_mentions_entities(answer, entities)
|
|
assert score == 1.0 # Both entities mentioned
|
|
|
|
def test_partial_entity_mention(self):
|
|
from tests.dspy_gitops.metrics.answer_relevance import answer_mentions_entities
|
|
answer = "Er zijn 45 musea in Amsterdam."
|
|
entities = ["amsterdam", "rijksmuseum"]
|
|
score = answer_mentions_entities(answer, entities)
|
|
assert score == 0.5 # Only Amsterdam mentioned
|
|
|
|
def test_language_match_dutch(self):
|
|
from tests.dspy_gitops.metrics.answer_relevance import language_match_score
|
|
dutch_answer = "Er zijn 45 musea in Nederland. De meeste zijn in Amsterdam te vinden."
|
|
assert language_match_score("nl", dutch_answer) == 1.0
|
|
|
|
def test_language_match_english(self):
|
|
from tests.dspy_gitops.metrics.answer_relevance import language_match_score
|
|
english_answer = "There are 45 museums in the Netherlands. Most of them are in Amsterdam."
|
|
assert language_match_score("en", english_answer) == 1.0
|
|
|
|
|
|
class TestCompositeMetrics:
|
|
"""Test composite metric calculations."""
|
|
|
|
def test_heritage_rag_metric_structure(self):
|
|
"""Verify metric accepts correct input structure."""
|
|
from tests.dspy_gitops.metrics.composite import heritage_rag_metric
|
|
from unittest.mock import MagicMock
|
|
|
|
# Create mock example and prediction
|
|
example = MagicMock()
|
|
example.expected_intent = "statistical"
|
|
example.expected_entities = ["amsterdam", "museum"]
|
|
example.language = "nl"
|
|
|
|
pred = MagicMock()
|
|
pred.intent = "statistical"
|
|
pred.entities = ["amsterdam", "museum"]
|
|
pred.sparql = "SELECT ?s WHERE { ?s a ?t }"
|
|
pred.answer = "Er zijn 45 musea in Amsterdam."
|
|
pred.citations = ["oxigraph"]
|
|
pred.confidence = 0.85
|
|
|
|
score = heritage_rag_metric(example, pred)
|
|
assert 0 <= score <= 1
|
|
|
|
def test_create_weighted_metric(self):
|
|
"""Test custom metric creation."""
|
|
from tests.dspy_gitops.metrics.composite import create_weighted_metric
|
|
|
|
# Create intent-only metric
|
|
metric = create_weighted_metric(
|
|
weights={"intent": 1.0},
|
|
include_sparql=False,
|
|
include_answer=False,
|
|
)
|
|
|
|
assert callable(metric)
|
|
|
|
|
|
# =============================================================================
# Dataset Loading Tests
# =============================================================================
|
|
|
|
class TestDatasetLoading:
|
|
"""Test dataset loading functionality."""
|
|
|
|
def test_load_dev_examples(self):
|
|
from tests.dspy_gitops.conftest import load_examples_from_json
|
|
examples = load_examples_from_json("heritage_rag_dev.json")
|
|
assert len(examples) > 0
|
|
|
|
# Check structure
|
|
for ex in examples:
|
|
assert "question" in ex
|
|
assert "language" in ex
|
|
assert "expected_intent" in ex
|
|
|
|
def test_golden_queries_exist(self):
|
|
import yaml
|
|
from pathlib import Path
|
|
|
|
golden_path = Path(__file__).parent / "datasets" / "golden_queries.yaml"
|
|
assert golden_path.exists()
|
|
|
|
with open(golden_path) as f:
|
|
data = yaml.safe_load(f)
|
|
|
|
assert "golden_tests" in data
|
|
assert len(data["golden_tests"]) > 0
|
|
|
|
|
|
# =============================================================================
# Run tests when executed directly
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of raising SystemExit.
    # Propagate it so a direct run reports test failures to the shell/CI.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
|