glam/tests/rag/test_specificity_schema_integration.py
kempersc 11983014bb Enhance specificity scoring system integration with existing infrastructure
- Updated documentation to clarify integration points with existing components in the RAG pipeline and DSPy framework.
- Added detailed mapping of SPARQL templates to context templates for improved specificity filtering.
- Implemented wrapper patterns around existing classifiers to extend functionality without duplication.
- Introduced new tests for the SpecificityAwareClassifier and SPARQLToContextMapper to ensure proper integration and functionality.
- Enhanced the CustodianRDFConverter to include ISO country and subregion codes from GHCID for better geospatial data handling.
2026-01-05 17:37:49 +01:00

325 lines
13 KiB
Python

"""
Tests for specificity-aware functions in schema_loader.py.
These tests verify the integration between the specificity system and
the schema loader, ensuring that filtered ontology context is generated
correctly for DSPy prompts.
"""
import pytest
class TestGetFilteredClassesForContext:
    """Checks for ``get_filtered_classes_for_context``.

    Each test imports the function lazily from ``backend.rag.schema_loader``
    and calls it with a context-template name and a numeric threshold.
    """

    def test_returns_list_of_class_names(self):
        """The result is a non-empty list made up solely of strings."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        result = get_filtered_classes_for_context("archive_search", 0.6)

        assert isinstance(result, list)
        assert result
        for name in result:
            assert isinstance(name, str)

    def test_archive_search_includes_archive_classes(self):
        """``archive_search`` at threshold 0.3 lists ``ArchiveOrganizationType``."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        # Archive-specific classes are expected to survive even a low threshold.
        low_threshold_classes = get_filtered_classes_for_context(
            "archive_search", 0.3
        )

        assert "ArchiveOrganizationType" in low_threshold_classes

    def test_lower_threshold_returns_fewer_classes(self):
        """Threshold 0.3 yields strictly fewer classes than threshold 0.6."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        count_at_03 = len(get_filtered_classes_for_context("archive_search", 0.3))
        count_at_06 = len(get_filtered_classes_for_context("archive_search", 0.6))

        assert count_at_06 > count_at_03

    def test_unknown_template_falls_back_to_general(self):
        """An unrecognised template name does not raise and still gives a list."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        # Expected to use the general_heritage fallback instead of raising;
        # only the "returns a list" part is asserted here.
        result = get_filtered_classes_for_context("nonexistent_template", 0.6)

        assert isinstance(result, list)

    def test_custodian_always_included_at_reasonable_threshold(self):
        """``Custodian`` is present for several search templates at 0.5."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        search_templates = ("archive_search", "museum_search", "library_search")
        for template_name in search_templates:
            selected = get_filtered_classes_for_context(template_name, 0.5)
            assert "Custodian" in selected, f"Custodian not in {template_name}"
class TestGetFilteredClassScoresForContext:
    """Checks for ``get_filtered_class_scores_for_context``.

    Each test imports the function lazily from ``backend.rag.schema_loader``.
    """

    def test_returns_list_of_tuples(self):
        """The result is a non-empty list of ``(str, number)`` 2-tuples."""
        from backend.rag.schema_loader import get_filtered_class_scores_for_context

        scored = get_filtered_class_scores_for_context("archive_search", 0.6)

        assert isinstance(scored, list)
        assert scored
        for entry in scored:
            assert isinstance(entry, tuple)
            assert len(entry) == 2
            # Safe to unpack: the length was just checked.
            name, value = entry
            assert isinstance(name, str)
            assert isinstance(value, (int, float))

    def test_scores_are_sorted_ascending(self):
        """Score values appear in ascending order (lowest first)."""
        from backend.rag.schema_loader import get_filtered_class_scores_for_context

        scored = get_filtered_class_scores_for_context("archive_search", 0.6)
        values = [value for _name, value in scored]

        assert sorted(values) == values

    def test_all_scores_below_threshold(self):
        """No returned score exceeds the threshold that was requested."""
        from backend.rag.schema_loader import get_filtered_class_scores_for_context

        limit = 0.5
        scored = get_filtered_class_scores_for_context("archive_search", limit)

        for name, value in scored:
            assert value <= limit, f"{name} has score {value} > {limit}"
class TestFormatFilteredOntologyContext:
    """Checks for ``format_filtered_ontology_context``.

    Each test renders a context for ``archive_search`` and inspects the
    returned text for an expected marker.
    """

    def test_returns_string(self):
        """The rendered context is a non-empty string."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        rendered = format_filtered_ontology_context("archive_search", 0.6)

        assert isinstance(rendered, str)
        assert rendered

    def test_includes_template_name(self):
        """The template name appears verbatim in the rendered context."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        rendered = format_filtered_ontology_context("archive_search", 0.6)

        assert "archive_search" in rendered

    def test_includes_threshold_info(self):
        """The context shows the value ``0.4`` or mentions a threshold."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        rendered = format_filtered_ontology_context("archive_search", 0.4)

        shows_value = "0.4" in rendered
        mentions_threshold = "threshold" in rendered.lower()
        assert shows_value or mentions_threshold

    def test_includes_hub_architecture(self):
        """Both ``Hub Architecture`` and ``Custodian`` appear in the context."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        rendered = format_filtered_ontology_context("archive_search", 0.6)

        for marker in ("Hub Architecture", "Custodian"):
            assert marker in rendered

    def test_includes_relevant_classes_section(self):
        """The context mentions "relevant classes" in any letter case."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        rendered = format_filtered_ontology_context("archive_search", 0.6)

        # A case-insensitive match also covers the exact "Relevant Classes"
        # spelling, so one lowercase comparison is sufficient.
        assert "relevant classes" in rendered.lower()

    def test_includes_ontology_prefixes(self):
        """The context contains the literal text ``PREFIX``."""
        from backend.rag.schema_loader import format_filtered_ontology_context

        rendered = format_filtered_ontology_context("archive_search", 0.6)

        assert "PREFIX" in rendered
class TestCreateSpecificityAwareSparqlDocstring:
    """Checks for ``create_specificity_aware_sparql_docstring``.

    Each test builds a docstring for a template/threshold pair and inspects
    the returned text.
    """

    def test_returns_docstring(self):
        """The result is a string longer than 500 characters."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring

        generated = create_specificity_aware_sparql_docstring("museum_search", 0.6)

        assert isinstance(generated, str)
        # The text is meant to be substantial, hence the length floor.
        assert len(generated) > 500

    def test_includes_template_context(self):
        """The template name appears in the generated docstring."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring

        generated = create_specificity_aware_sparql_docstring("museum_search", 0.6)

        assert "museum_search" in generated

    def test_includes_sparql_guidance(self):
        """The docstring contains both ``SPARQL`` and ``PREFIX``."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring

        generated = create_specificity_aware_sparql_docstring("archive_search", 0.6)

        for marker in ("SPARQL", "PREFIX"):
            assert marker in generated

    def test_includes_hub_architecture(self):
        """The docstring contains both ``Hub Architecture`` and ``Custodian``."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring

        generated = create_specificity_aware_sparql_docstring("archive_search", 0.6)

        for marker in ("Hub Architecture", "Custodian"):
            assert marker in generated

    def test_different_templates_produce_different_output(self):
        """Archive and museum templates give different text at threshold 0.3."""
        from backend.rag.schema_loader import create_specificity_aware_sparql_docstring

        for_archive = create_specificity_aware_sparql_docstring("archive_search", 0.3)
        for_museum = create_specificity_aware_sparql_docstring("museum_search", 0.3)

        # The outputs are expected to differ because the class lists differ.
        assert for_museum != for_archive
class TestGetAvailableContextTemplates:
    """Checks for ``get_available_context_templates``."""

    def test_returns_list_of_strings(self):
        """The result is a list and every member is a string."""
        from backend.rag.schema_loader import get_available_context_templates

        available = get_available_context_templates()

        assert isinstance(available, list)
        for name in available:
            assert isinstance(name, str)

    def test_includes_expected_templates(self):
        """All ten expected template names are present in the result."""
        from backend.rag.schema_loader import get_available_context_templates

        available = get_available_context_templates()
        required = (
            "archive_search",
            "museum_search",
            "library_search",
            "collection_discovery",
            "person_research",
            "location_browse",
            "identifier_lookup",
            "organizational_change",
            "digital_platform",
            "general_heritage",
        )

        absent = [name for name in required if name not in available]
        assert not absent

    def test_returns_10_templates(self):
        """Exactly ten templates are returned."""
        from backend.rag.schema_loader import get_available_context_templates

        available = get_available_context_templates()

        assert len(available) == 10
class TestGetClassCountByTemplate:
    """Checks for ``get_class_count_by_template``."""

    def test_returns_dict_with_all_templates(self):
        """The result is a dict holding ten entries."""
        from backend.rag.schema_loader import get_class_count_by_template

        per_template = get_class_count_by_template(0.6)

        assert isinstance(per_template, dict)
        assert len(per_template) == 10

    def test_all_counts_are_integers(self):
        """Every count is a non-negative ``int``."""
        from backend.rag.schema_loader import get_class_count_by_template

        per_template = get_class_count_by_template(0.6)

        for value in per_template.values():
            assert isinstance(value, int)
            assert value >= 0

    def test_lower_threshold_returns_lower_counts(self):
        """At least one template has a smaller count at 0.3 than at 0.6."""
        from backend.rag.schema_loader import get_class_count_by_template

        at_03 = get_class_count_by_template(0.3)
        at_06 = get_class_count_by_template(0.6)

        # Not every template has to shrink; one strictly smaller count is enough.
        shrunk = [name for name, value in at_03.items() if value < at_06[name]]
        assert len(shrunk) > 0

    def test_archive_search_has_more_classes_at_low_threshold(self):
        """``archive_search`` outnumbers ``museum_search`` at threshold 0.3."""
        from backend.rag.schema_loader import get_class_count_by_template

        # Rationale for the expectation: there are many archive-specific
        # RecordSetType classes, so archive_search should keep more classes
        # at the 0.3 threshold.
        per_template = get_class_count_by_template(0.3)

        assert per_template["museum_search"] < per_template["archive_search"]
class TestSpecificityIntegrationEndToEnd:
    """End-to-end tests for specificity integration.

    These tests call several ``backend.rag.schema_loader`` helpers in
    sequence for the same template/threshold pair.
    """

    def test_full_pipeline_archive_query(self):
        """Run all four helpers for ``archive_search`` at threshold 0.3."""
        from backend.rag.schema_loader import (
            get_filtered_classes_for_context,
            get_filtered_class_scores_for_context,
            format_filtered_ontology_context,
            create_specificity_aware_sparql_docstring,
        )

        # Step 1: Get filtered classes
        classes = get_filtered_classes_for_context("archive_search", 0.3)
        assert len(classes) > 0
        assert len(classes) < 200  # Should be filtered down from 600+

        # Step 2: Get scores
        scores = get_filtered_class_scores_for_context("archive_search", 0.3)
        # Check for content before indexing so that an empty result fails as
        # a readable assertion instead of an IndexError.
        assert len(scores) > 0, "expected at least one (class_name, score) pair"
        # The first score must not exceed the requested threshold.
        assert scores[0][1] <= 0.3

        # Step 3: Generate context
        context = format_filtered_ontology_context("archive_search", 0.3)
        assert "archive_search" in context

        # Step 4: Generate SPARQL docstring
        docstring = create_specificity_aware_sparql_docstring("archive_search", 0.3)
        assert "SPARQL" in docstring

    def test_template_specific_filtering_is_effective(self):
        """Verify that ``archive_search`` has classes the other templates lack."""
        from backend.rag.schema_loader import get_filtered_classes_for_context

        archive_classes = set(get_filtered_classes_for_context("archive_search", 0.3))
        museum_classes = set(get_filtered_classes_for_context("museum_search", 0.3))
        person_classes = set(get_filtered_classes_for_context("person_research", 0.3))

        # There should be some overlap (core classes) but significant differences
        archive_only = archive_classes - museum_classes - person_classes
        # NOTE(review): museum_only and person_only are computed but never
        # asserted on in this method - confirm whether checks were intended.
        museum_only = museum_classes - archive_classes - person_classes
        person_only = person_classes - archive_classes - museum_classes

        # Only the archive template is required to have unique classes here.
        assert len(archive_only) > 0, "Archive should have unique classes"