- Updated documentation to clarify integration points with existing components in the RAG pipeline and DSPy framework.
- Added a detailed mapping of SPARQL templates to context templates for improved specificity filtering.
- Implemented wrapper patterns around existing classifiers to extend functionality without duplication.
- Introduced new tests for the SpecificityAwareClassifier and SPARQLToContextMapper to ensure proper integration and functionality.
- Enhanced the CustodianRDFConverter to include ISO country and subregion codes from GHCID for better geospatial data handling.
325 lines
13 KiB
Python
325 lines
13 KiB
Python
"""
Tests for specificity-aware functions in schema_loader.py.

These tests verify the integration between the specificity system and
the schema loader, ensuring that filtered ontology context is generated
correctly for DSPy prompts.
"""
|
|
|
|
import pytest
|
|
|
|
|
|
class TestGetFilteredClassesForContext:
    """Tests for get_filtered_classes_for_context function."""

    def test_returns_list_of_class_names(self):
        """Should return a list of string class names."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        classes = get_filtered_classes_for_context("archive_search", 0.6)

        assert isinstance(classes, list)
        assert len(classes) > 0
        assert all(isinstance(c, str) for c in classes)

    def test_archive_search_includes_archive_classes(self):
        """Archive search should include archive-related classes."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        classes = get_filtered_classes_for_context("archive_search", 0.3)

        # Archive-specific classes should be included at low threshold
        assert "ArchiveOrganizationType" in classes

    def test_lower_threshold_returns_fewer_classes(self):
        """Lower threshold should return fewer classes."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        classes_03 = get_filtered_classes_for_context("archive_search", 0.3)
        classes_06 = get_filtered_classes_for_context("archive_search", 0.6)

        assert len(classes_03) < len(classes_06)

    def test_unknown_template_falls_back_to_general(self):
        """Unknown template should fall back to general_heritage."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        # Should not raise, should use general_heritage fallback
        classes = get_filtered_classes_for_context("nonexistent_template", 0.6)
        general = get_filtered_classes_for_context("general_heritage", 0.6)

        assert isinstance(classes, list)
        # Actually verify the fallback: an unknown template must select the
        # same classes as general_heritage. Compared as sets because the
        # contract is about which classes are chosen, not their order.
        assert set(classes) == set(general)

    def test_custodian_always_included_at_reasonable_threshold(self):
        """Core classes like Custodian should always be included."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        for template in ["archive_search", "museum_search", "library_search"]:
            classes = get_filtered_classes_for_context(template, 0.5)
            assert "Custodian" in classes, f"Custodian not in {template}"


class TestGetFilteredClassScoresForContext:
    """Tests for get_filtered_class_scores_for_context function.

    The function is expected to return ``(class_name, score)`` pairs,
    ordered from lowest score (most relevant) upward, and limited to scores
    that do not exceed the requested threshold.
    """

    def test_returns_list_of_tuples(self):
        """Should return list of (class_name, score) tuples."""
        from backend.rag.schema_loader import get_filtered_class_scores_for_context

        result = get_filtered_class_scores_for_context("archive_search", 0.6)

        assert isinstance(result, list)
        assert len(result) > 0
        for pair in result:
            # Shape first, then the type of each element.
            assert isinstance(pair, tuple)
            assert len(pair) == 2
            name, value = pair
            assert isinstance(name, str)
            assert isinstance(value, (int, float))

    def test_scores_are_sorted_ascending(self):
        """Scores should be sorted ascending (lowest/most relevant first)."""
        from backend.rag.schema_loader import get_filtered_class_scores_for_context

        result = get_filtered_class_scores_for_context("archive_search", 0.6)

        values = [value for _name, value in result]
        expected_order = sorted(values)
        assert values == expected_order

    def test_all_scores_below_threshold(self):
        """All returned scores should be at or below threshold."""
        from backend.rag.schema_loader import get_filtered_class_scores_for_context

        limit = 0.5
        result = get_filtered_class_scores_for_context("archive_search", limit)

        for name, value in result:
            assert value <= limit, f"{name} has score {value} > {limit}"


class TestFormatFilteredOntologyContext:
    """Tests for format_filtered_ontology_context function."""

    def test_returns_string(self):
        """Should return a formatted string."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        context = format_filtered_ontology_context("archive_search", 0.6)

        assert isinstance(context, str)
        assert len(context) > 0

    def test_includes_template_name(self):
        """Context should include the template name."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        context = format_filtered_ontology_context("archive_search", 0.6)

        assert "archive_search" in context

    def test_includes_threshold_info(self):
        """Context should include threshold information."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        context = format_filtered_ontology_context("archive_search", 0.4)

        # Accept either the literal threshold value or a mention of "threshold".
        assert "0.4" in context or "threshold" in context.lower()

    def test_includes_hub_architecture(self):
        """Context should include hub architecture section."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        context = format_filtered_ontology_context("archive_search", 0.6)

        assert "Hub Architecture" in context
        assert "Custodian" in context

    def test_includes_relevant_classes_section(self):
        """Context should include relevant classes section."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        context = format_filtered_ontology_context("archive_search", 0.6)

        # A case-insensitive match already covers an exact "Relevant Classes"
        # heading, so a separate case-sensitive check would be redundant.
        assert "relevant classes" in context.lower()

    def test_includes_ontology_prefixes(self):
        """Context should include ontology prefix declarations."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        context = format_filtered_ontology_context("archive_search", 0.6)

        assert "PREFIX" in context


class TestCreateSpecificityAwareSparqlDocstring:
    """Tests for create_specificity_aware_sparql_docstring function.

    Each test builds a docstring for a given context template and threshold
    and checks that the expected guidance sections appear in it.
    """

    def test_returns_docstring(self):
        """Should return a docstring suitable for DSPy signatures."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring as build

        text = build("museum_search", 0.6)

        assert isinstance(text, str)
        # A real signature docstring is long; anything short is suspicious.
        assert len(text) > 500

    def test_includes_template_context(self):
        """Docstring should mention the query context."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring as build

        text = build("museum_search", 0.6)

        assert "museum_search" in text

    def test_includes_sparql_guidance(self):
        """Docstring should include SPARQL generation guidance."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring as build

        text = build("archive_search", 0.6)

        for marker in ("SPARQL", "PREFIX"):
            assert marker in text

    def test_includes_hub_architecture(self):
        """Docstring should include hub architecture info."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring as build

        text = build("archive_search", 0.6)

        for marker in ("Hub Architecture", "Custodian"):
            assert marker in text

    def test_different_templates_produce_different_output(self):
        """Different templates should produce different docstrings."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring as build

        # The filtered class lists differ per template, so the rendered
        # docstrings must differ as well.
        assert build("archive_search", 0.3) != build("museum_search", 0.3)


class TestGetAvailableContextTemplates:
    """Tests for get_available_context_templates function."""

    # Single source of truth for the context templates these tests expect,
    # so the membership check and the count check cannot drift apart.
    EXPECTED_TEMPLATES = (
        "archive_search", "museum_search", "library_search",
        "collection_discovery", "person_research", "location_browse",
        "identifier_lookup", "organizational_change", "digital_platform",
        "general_heritage",
    )

    def test_returns_list_of_strings(self):
        """Should return list of template names."""
        from backend.rag.schema_loader import get_available_context_templates

        templates = get_available_context_templates()

        assert isinstance(templates, list)
        assert all(isinstance(t, str) for t in templates)

    def test_includes_expected_templates(self):
        """Should include all 10 expected templates."""
        from backend.rag.schema_loader import get_available_context_templates

        templates = get_available_context_templates()

        # Collect every missing template so one failure reports all of them,
        # rather than stopping at the first absent name with no message.
        missing = [t for t in self.EXPECTED_TEMPLATES if t not in templates]
        assert not missing, f"Missing templates: {missing}"

    def test_returns_10_templates(self):
        """Should return exactly 10 templates."""
        from backend.rag.schema_loader import get_available_context_templates

        templates = get_available_context_templates()

        assert len(templates) == len(self.EXPECTED_TEMPLATES)


class TestGetClassCountByTemplate:
    """Tests for get_class_count_by_template function."""

    def test_returns_dict_with_all_templates(self):
        """Should return dict with count for each template."""
        from backend.rag.schema_loader import (
            get_available_context_templates,
            get_class_count_by_template,
        )

        counts = get_class_count_by_template(0.6)

        assert isinstance(counts, dict)
        assert len(counts) == 10
        # Ten entries alone does not prove coverage: the keys must be exactly
        # the available context template names.
        assert set(counts) == set(get_available_context_templates())

    def test_all_counts_are_integers(self):
        """All counts should be integers."""
        from backend.rag.schema_loader import get_class_count_by_template

        counts = get_class_count_by_template(0.6)

        for template, count in counts.items():
            assert isinstance(count, int), f"{template}: count {count!r} is not an int"
            assert count >= 0, f"{template}: negative count {count}"

    def test_lower_threshold_returns_lower_counts(self):
        """Lower threshold should return lower counts."""
        from backend.rag.schema_loader import get_class_count_by_template

        counts_03 = get_class_count_by_template(0.3)
        counts_06 = get_class_count_by_template(0.6)

        # At least some templates should have lower counts at lower threshold
        lower_count = sum(1 for t in counts_03 if counts_03[t] < counts_06[t])
        assert lower_count > 0

    def test_archive_search_has_more_classes_at_low_threshold(self):
        """Archive search should have more classes than museum search at low threshold."""
        from backend.rag.schema_loader import get_class_count_by_template

        # At 0.3 threshold, archive_search should have more classes because
        # there are many archive-specific RecordSetType classes
        counts = get_class_count_by_template(0.3)

        assert counts["archive_search"] > counts["museum_search"]


class TestSpecificityIntegrationEndToEnd:
|
|
"""End-to-end tests for specificity integration."""
|
|
|
|
def test_full_pipeline_archive_query(self):
    """Test full pipeline for archive-related query."""
    from backend.rag.schema_loader import (
        get_filtered_classes_for_context,
        get_filtered_class_scores_for_context,
        format_filtered_ontology_context,
        create_specificity_aware_sparql_docstring,
    )

    # Step 1: Get filtered classes
    classes = get_filtered_classes_for_context("archive_search", 0.3)
    assert len(classes) > 0
    assert len(classes) < 200  # Should be filtered down from 600+

    # Step 2: Get scores
    scores = get_filtered_class_scores_for_context("archive_search", 0.3)
    # Guard before indexing so an empty result fails with a clear message
    # instead of an IndexError.
    assert scores, "Expected at least one scored class for archive_search"
    assert scores[0][1] <= 0.3  # First score should be very low

    # Step 3: Generate context
    context = format_filtered_ontology_context("archive_search", 0.3)
    assert "archive_search" in context

    # Step 4: Generate SPARQL docstring
    docstring = create_specificity_aware_sparql_docstring("archive_search", 0.3)
    assert "SPARQL" in docstring

def test_template_specific_filtering_is_effective(self):
|
|
"""Verify that different templates get different class sets."""
|
|
from backend.rag.schema_loader import get_filtered_classes_for_context
|
|
|
|
archive_classes = set(get_filtered_classes_for_context("archive_search", 0.3))
|
|
museum_classes = set(get_filtered_classes_for_context("museum_search", 0.3))
|
|
person_classes = set(get_filtered_classes_for_context("person_research", 0.3))
|
|
|
|
# There should be some overlap (core classes) but significant differences
|
|
archive_only = archive_classes - museum_classes - person_classes
|
|
museum_only = museum_classes - archive_classes - person_classes
|
|
person_only = person_classes - archive_classes - museum_classes
|
|
|
|
# Each template should have some unique classes at low threshold
|
|
assert len(archive_only) > 0, "Archive should have unique classes"
|