# Changelog notes (carried over from the commit description):
# - Updated documentation to clarify integration points with existing
#   components in the RAG pipeline and DSPy framework.
# - Added detailed mapping of SPARQL templates to context templates for
#   improved specificity filtering.
# - Implemented wrapper patterns around existing classifiers to extend
#   functionality without duplication.
# - Introduced new tests for the SpecificityAwareClassifier and
#   SPARQLToContextMapper to ensure proper integration and functionality.
# - Enhanced the CustodianRDFConverter to include ISO country and subregion
#   codes from GHCID for better geospatial data handling.
"""
|
|
Tests for the specificity token counter module.
|
|
|
|
These tests verify token counting behavior WITHOUT hardcoding specific token values,
|
|
since token counts depend on the tokenizer/model being used.
|
|
"""
|
|
|
|
from unittest.mock import MagicMock, patch

import pytest
class TestCountTokens:
    """Tests for the count_tokens function."""

    def test_empty_string_returns_zero(self):
        """Empty string should return 0 tokens; None is tolerated or rejected."""
        from backend.rag.specificity.token_counter import count_tokens

        assert count_tokens("") == 0
        # The previous None check was a tautology ("x == 0 if x is not None
        # else True") that could never fail.  Pin the two acceptable
        # behaviours explicitly: either None counts as empty (0 tokens), or
        # the function rejects it with a TypeError.
        try:
            result = count_tokens(None)
        except TypeError:
            pass  # Rejecting None outright is acceptable.
        else:
            assert result == 0

    def test_returns_positive_integer_for_text(self):
        """Non-empty text should return a positive token count."""
        from backend.rag.specificity.token_counter import count_tokens

        result = count_tokens("Hello world")
        assert isinstance(result, int)
        assert result > 0

    def test_longer_text_has_more_tokens(self):
        """Longer text should generally have more tokens."""
        from backend.rag.specificity.token_counter import count_tokens

        short_count = count_tokens("Hello")
        long_count = count_tokens(
            "Hello world, this is a much longer piece of text with many more words"
        )

        assert long_count > short_count

    def test_model_parameter_accepted(self):
        """Should accept different model parameters without error."""
        from backend.rag.specificity.token_counter import count_tokens

        # None of these should raise; unknown models are expected to fall
        # back to the cl100k_base encoding.
        for model in ("gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo", "unknown-model"):
            count_tokens("test", model=model)

    def test_fallback_when_tiktoken_unavailable(self):
        """Should fall back to approximation when tiktoken is unavailable."""
        from backend.rag.specificity import token_counter

        # Save original state so the module flag is always restored.
        original_available = token_counter.TIKTOKEN_AVAILABLE
        try:
            token_counter.TIKTOKEN_AVAILABLE = False

            result = token_counter.count_tokens("Hello world test")
            assert isinstance(result, int)
            assert result > 0
            # The documented fallback approximates tokens as len(text) // 4.
            assert result == len("Hello world test") // 4
        finally:
            token_counter.TIKTOKEN_AVAILABLE = original_available

    def test_consistent_results_for_same_input(self):
        """Same input should produce the same output (deterministic)."""
        from backend.rag.specificity.token_counter import count_tokens

        text = "The quick brown fox jumps over the lazy dog"
        assert count_tokens(text) == count_tokens(text)
class TestCountTokensForContext:
    """Behavioural checks for the count_tokens_for_context function."""

    def test_returns_integer(self):
        """The function yields an integer token count."""
        from backend.rag.specificity.token_counter import count_tokens_for_context

        count = count_tokens_for_context("general_heritage", 0.5)
        assert isinstance(count, int)

    def test_returns_positive_for_valid_template(self):
        """A known template produces a strictly positive count."""
        from backend.rag.specificity.token_counter import count_tokens_for_context

        assert count_tokens_for_context("archive_search", 0.5) > 0

    def test_accepts_threshold_parameter(self):
        """A range of threshold values is accepted without raising."""
        from backend.rag.specificity.token_counter import count_tokens_for_context

        for threshold in (0.1, 0.5, 0.9):
            count_tokens_for_context("general_heritage", threshold)
class TestContextSizeComparison:
    """Tests for the ContextSizeComparison dataclass."""

    @staticmethod
    def _build(filtered_tokens=100, unfiltered_tokens=200,
               filtered_classes=10, unfiltered_classes=20):
        """Create a ContextSizeComparison with the fixture values shared by these tests."""
        from backend.rag.specificity.token_counter import ContextSizeComparison

        return ContextSizeComparison(
            template="test",
            threshold=0.5,
            filtered_tokens=filtered_tokens,
            unfiltered_tokens=unfiltered_tokens,
            filtered_classes=filtered_classes,
            unfiltered_classes=unfiltered_classes,
        )

    def test_dataclass_creation(self):
        """All required fields are stored as given."""
        comparison = self._build()

        assert comparison.template == "test"
        assert comparison.threshold == 0.5
        assert comparison.filtered_tokens == 100
        assert comparison.unfiltered_tokens == 200

    def test_derived_metrics_calculated(self):
        """Derived reduction metrics are computed automatically."""
        comparison = self._build()

        # 200 - 100 = 100 tokens saved, i.e. 50% of the unfiltered total.
        assert comparison.token_reduction == 100
        assert comparison.token_reduction_percent == 50.0
        # 20 - 10 = 10 classes removed, i.e. 50% of the unfiltered classes.
        assert comparison.class_reduction == 10
        assert comparison.class_reduction_percent == 50.0

    def test_handles_zero_unfiltered(self):
        """Zero unfiltered values must not trigger a division error."""
        comparison = self._build(
            filtered_tokens=0,
            unfiltered_tokens=0,
            filtered_classes=0,
            unfiltered_classes=0,
        )

        assert comparison.token_reduction_percent == 0.0
        assert comparison.class_reduction_percent == 0.0

    def test_to_dict_method(self):
        """to_dict() exposes both raw and derived fields."""
        payload = self._build().to_dict()

        assert isinstance(payload, dict)
        for key in ("template", "threshold", "filtered_tokens",
                    "token_reduction_percent"):
            assert key in payload

    def test_str_representation(self):
        """str() includes the template name and the threshold."""
        rendered = str(self._build())

        assert "test" in rendered
        assert "0.5" in rendered
class TestCompareContextSizes:
    """Tests for the compare_context_sizes function."""

    def test_returns_context_size_comparison(self):
        """The function hands back a ContextSizeComparison instance."""
        from backend.rag.specificity.token_counter import (
            ContextSizeComparison,
            compare_context_sizes,
        )

        comparison = compare_context_sizes("archive_search", 0.5)
        assert isinstance(comparison, ContextSizeComparison)

    def test_contains_expected_fields(self):
        """Every field of the comparison is populated with the expected type."""
        from backend.rag.specificity.token_counter import compare_context_sizes

        comparison = compare_context_sizes("archive_search", 0.5)

        assert comparison.template == "archive_search"
        assert comparison.threshold == 0.5
        for attr in ("filtered_tokens", "unfiltered_tokens",
                     "filtered_classes", "unfiltered_classes"):
            assert isinstance(getattr(comparison, attr), int)

    def test_unfiltered_uses_all_classes(self):
        """The unfiltered side (threshold 1.0) never has fewer classes than the filtered side."""
        from backend.rag.specificity.token_counter import compare_context_sizes

        comparison = compare_context_sizes("general_heritage", 0.5)
        assert comparison.unfiltered_classes >= comparison.filtered_classes
class TestBenchmarkAllTemplates:
    """Tests for the benchmark_all_templates function."""

    def test_returns_dict(self):
        """The benchmark returns a mapping of template name to result."""
        from backend.rag.specificity.token_counter import benchmark_all_templates

        assert isinstance(benchmark_all_templates(0.5), dict)

    def test_contains_all_templates(self):
        """Every ContextTemplate member has a corresponding entry."""
        from backend.rag.specificity import ContextTemplate
        from backend.rag.specificity.token_counter import benchmark_all_templates

        results = benchmark_all_templates(0.5)

        missing = [t for t in ContextTemplate if t.value not in results]
        assert not missing

    def test_all_values_are_comparisons(self):
        """Each entry in the results dict is a ContextSizeComparison."""
        from backend.rag.specificity.token_counter import (
            ContextSizeComparison,
            benchmark_all_templates,
        )

        results = benchmark_all_templates(0.5)

        assert all(isinstance(value, ContextSizeComparison)
                   for value in results.values())
class TestFormatBenchmarkReport:
    """Tests for the format_benchmark_report function."""

    @staticmethod
    def _make_report(**kwargs):
        """Run a 0.5-threshold benchmark and format it with the given options."""
        from backend.rag.specificity.token_counter import (
            benchmark_all_templates,
            format_benchmark_report,
        )

        return format_benchmark_report(benchmark_all_templates(0.5), **kwargs)

    def test_returns_string(self):
        """A non-empty string comes back."""
        report = self._make_report()

        assert isinstance(report, str)
        assert report

    def test_includes_header_by_default(self):
        """The header banner is present unless explicitly disabled."""
        assert "BENCHMARK REPORT" in self._make_report()

    def test_can_exclude_header(self):
        """Passing include_header=False drops the banner."""
        assert "BENCHMARK REPORT" not in self._make_report(include_header=False)

    def test_includes_template_names(self):
        """Template names appear in the rendered report."""
        report = self._make_report()

        assert "archive_search" in report
        assert "general_heritage" in report
class TestCostEstimate:
    """Tests for the CostEstimate dataclass."""

    @staticmethod
    def _sample_estimate():
        """Build a CostEstimate with the fixture values shared by these tests."""
        from backend.rag.specificity.token_counter import CostEstimate

        return CostEstimate(
            template="test",
            threshold=0.5,
            filtered_tokens=1000,
            unfiltered_tokens=2000,
            filtered_cost_1k=0.10,
            unfiltered_cost_1k=0.20,
            savings_1k=0.10,
            savings_percent=50.0,
        )

    def test_dataclass_creation(self):
        """All required fields are stored on construction."""
        estimate = self._sample_estimate()

        assert estimate.template == "test"
        assert estimate.savings_percent == 50.0

    def test_str_representation(self):
        """str() yields a readable summary mentioning the template and query volume."""
        rendered = str(self._sample_estimate())

        assert "test" in rendered
        assert "1000 queries" in rendered
class TestEstimateCostSavings:
    """Tests for the estimate_cost_savings function."""

    def test_returns_cost_estimate(self):
        """Should return a CostEstimate object."""
        from backend.rag.specificity.token_counter import (
            CostEstimate,
            estimate_cost_savings,
        )

        result = estimate_cost_savings("archive_search", 0.5)
        assert isinstance(result, CostEstimate)

    def test_accepts_custom_pricing(self):
        """Should accept a custom input price; a higher price yields higher costs."""
        from backend.rag.specificity.token_counter import estimate_cost_savings

        cheap = estimate_cost_savings("archive_search", 0.5, input_price_per_1m=0.15)
        pricey = estimate_cost_savings("archive_search", 0.5, input_price_per_1m=0.30)

        assert pricey.filtered_cost_1k > cheap.filtered_cost_1k

    def test_savings_calculation_correct(self):
        """Savings should equal the difference between unfiltered and filtered costs."""
        import math

        from backend.rag.specificity.token_counter import estimate_cost_savings

        result = estimate_cost_savings("archive_search", 0.5)

        expected_savings = result.unfiltered_cost_1k - result.filtered_cost_1k
        # math.isclose states the float tolerance explicitly instead of the
        # previous hand-rolled abs() comparison.
        assert math.isclose(result.savings_1k, expected_savings, abs_tol=1e-4)
class TestQuickBenchmark:
    """Tests for the quick_benchmark function."""

    def test_runs_without_error(self, capsys):
        """quick_benchmark completes and mentions 'benchmark' on stdout."""
        from backend.rag.specificity.token_counter import quick_benchmark

        quick_benchmark(0.5)

        output = capsys.readouterr().out
        assert "benchmark" in output.lower()

    def test_prints_report(self, capsys):
        """A non-empty benchmark report is printed to stdout."""
        from backend.rag.specificity.token_counter import quick_benchmark

        quick_benchmark(0.5)

        output = capsys.readouterr().out
        assert output
        assert "archive_search" in output or "Template" in output
class TestModuleExports:
    """Tests for module-level exports."""

    def test_all_functions_importable_from_package(self):
        """The specificity package re-exports the full public token-counter API."""
        from backend.rag.specificity import (
            ContextSizeComparison,
            CostEstimate,
            benchmark_all_templates,
            compare_context_sizes,
            count_tokens,
            count_tokens_for_context,
            estimate_cost_savings,
            format_benchmark_report,
            quick_benchmark,
        )

        # Every function export must be callable...
        for exported in (
            count_tokens,
            count_tokens_for_context,
            compare_context_sizes,
            benchmark_all_templates,
            format_benchmark_report,
            estimate_cost_savings,
            quick_benchmark,
        ):
            assert callable(exported)

        # ...and both dataclass exports must be actual classes.
        for exported in (ContextSizeComparison, CostEstimate):
            assert isinstance(exported, type)
class TestTiktokenAvailability:
    """Tests for tiktoken availability handling."""

    def test_tiktoken_available_flag_exists(self):
        """The module exposes a boolean TIKTOKEN_AVAILABLE flag."""
        from backend.rag.specificity import token_counter

        assert hasattr(token_counter, "TIKTOKEN_AVAILABLE")
        assert isinstance(token_counter.TIKTOKEN_AVAILABLE, bool)

    def test_works_regardless_of_tiktoken_availability(self):
        """count_tokens returns a valid count whether or not tiktoken is installed."""
        from backend.rag.specificity.token_counter import count_tokens

        token_count = count_tokens("test text")
        assert isinstance(token_count, int)
        assert token_count >= 0