glam/tests/rag/test_specificity_token_counter.py
kempersc 11983014bb Enhance specificity scoring system integration with existing infrastructure
- Updated documentation to clarify integration points with existing components in the RAG pipeline and DSPy framework.
- Added detailed mapping of SPARQL templates to context templates for improved specificity filtering.
- Implemented wrapper patterns around existing classifiers to extend functionality without duplication.
- Introduced new tests for the SpecificityAwareClassifier and SPARQLToContextMapper to ensure proper integration and functionality.
- Enhanced the CustodianRDFConverter to include ISO country and subregion codes from GHCID for better geospatial data handling.
2026-01-05 17:37:49 +01:00

471 lines
16 KiB
Python

"""
Tests for the specificity token counter module.
These tests verify token counting behavior WITHOUT hardcoding specific token values,
since token counts depend on the tokenizer/model being used.
"""
import pytest
from unittest.mock import patch, MagicMock
class TestCountTokens:
    """Tests for the count_tokens function.

    Token counts depend on the tokenizer/model, so these tests assert
    relative properties (positivity, monotonicity, determinism) rather
    than hardcoded token values.
    """

    def test_empty_string_returns_zero(self):
        """Empty string should return 0 tokens; None should be handled gracefully."""
        from backend.rag.specificity.token_counter import count_tokens
        assert count_tokens("") == 0
        # BUG FIX: the previous assertion
        #   assert count_tokens(None) == 0 if count_tokens(None) is not None else True
        # parsed as a conditional expression that collapsed to `assert True`
        # whenever count_tokens(None) returned None, so it could never fail.
        # Pin the accepted outcomes explicitly (and call the function once).
        assert count_tokens(None) in (0, None)

    def test_returns_positive_integer_for_text(self):
        """Non-empty text should return a positive integer token count."""
        from backend.rag.specificity.token_counter import count_tokens
        result = count_tokens("Hello world")
        assert isinstance(result, int)
        assert result > 0

    def test_longer_text_has_more_tokens(self):
        """Longer text should generally have more tokens."""
        from backend.rag.specificity.token_counter import count_tokens
        # Renamed locals: `long` shadowed the Python 2 builtin name and reads
        # poorly; use explicit *_count names instead.
        short_count = count_tokens("Hello")
        long_count = count_tokens(
            "Hello world, this is a much longer piece of text with many more words"
        )
        assert long_count > short_count

    def test_model_parameter_accepted(self):
        """Should accept different model parameters without raising."""
        from backend.rag.specificity.token_counter import count_tokens
        # Known models plus an unknown one (falls back to cl100k_base).
        for model in ("gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo", "unknown-model"):
            count_tokens("test", model=model)

    def test_fallback_when_tiktoken_unavailable(self):
        """Should fall back to a length-based approximation when tiktoken is unavailable."""
        from backend.rag.specificity import token_counter
        # Save original state so other tests see the real flag afterwards.
        original_available = token_counter.TIKTOKEN_AVAILABLE
        try:
            token_counter.TIKTOKEN_AVAILABLE = False
            result = token_counter.count_tokens("Hello world test")
            assert isinstance(result, int)
            assert result > 0
            # Fallback approximation is len(text) // 4.
            assert result == len("Hello world test") // 4
        finally:
            token_counter.TIKTOKEN_AVAILABLE = original_available

    def test_consistent_results_for_same_input(self):
        """Same input should produce the same output (deterministic)."""
        from backend.rag.specificity.token_counter import count_tokens
        text = "The quick brown fox jumps over the lazy dog"
        assert count_tokens(text) == count_tokens(text)
class TestCountTokensForContext:
    """Tests for the count_tokens_for_context function."""

    def test_returns_integer(self):
        """The result should be an integer token count."""
        from backend.rag.specificity.token_counter import count_tokens_for_context
        assert isinstance(count_tokens_for_context("general_heritage", 0.5), int)

    def test_returns_positive_for_valid_template(self):
        """A known template name should yield a strictly positive count."""
        from backend.rag.specificity.token_counter import count_tokens_for_context
        assert count_tokens_for_context("archive_search", 0.5) > 0

    def test_accepts_threshold_parameter(self):
        """A range of threshold values should be accepted without raising."""
        from backend.rag.specificity.token_counter import count_tokens_for_context
        for threshold in (0.1, 0.5, 0.9):
            count_tokens_for_context("general_heritage", threshold)
class TestContextSizeComparison:
    """Tests for the ContextSizeComparison dataclass."""

    @staticmethod
    def _make_comparison():
        """Build the standard fixture: 200 -> 100 tokens, 20 -> 10 classes."""
        from backend.rag.specificity.token_counter import ContextSizeComparison
        return ContextSizeComparison(
            template="test",
            threshold=0.5,
            filtered_tokens=100,
            unfiltered_tokens=200,
            filtered_classes=10,
            unfiltered_classes=20,
        )

    def test_dataclass_creation(self):
        """The dataclass should hold the required fields as provided."""
        comparison = self._make_comparison()
        assert comparison.template == "test"
        assert comparison.threshold == 0.5
        assert comparison.filtered_tokens == 100
        assert comparison.unfiltered_tokens == 200

    def test_derived_metrics_calculated(self):
        """Derived reduction metrics should be computed automatically."""
        comparison = self._make_comparison()
        # 200 - 100 tokens saved, i.e. a 50% reduction.
        assert comparison.token_reduction == 100
        assert comparison.token_reduction_percent == 50.0
        # 20 - 10 classes removed, i.e. a 50% reduction.
        assert comparison.class_reduction == 10
        assert comparison.class_reduction_percent == 50.0

    def test_handles_zero_unfiltered(self):
        """All-zero inputs must not trigger a division-by-zero error."""
        from backend.rag.specificity.token_counter import ContextSizeComparison
        comparison = ContextSizeComparison(
            template="test",
            threshold=0.5,
            filtered_tokens=0,
            unfiltered_tokens=0,
            filtered_classes=0,
            unfiltered_classes=0,
        )
        assert comparison.token_reduction_percent == 0.0
        assert comparison.class_reduction_percent == 0.0

    def test_to_dict_method(self):
        """to_dict() should produce a dict exposing the key fields."""
        result = self._make_comparison().to_dict()
        assert isinstance(result, dict)
        for key in ("template", "threshold", "filtered_tokens", "token_reduction_percent"):
            assert key in result

    def test_str_representation(self):
        """str() should mention the template name and threshold."""
        text = str(self._make_comparison())
        assert "test" in text
        assert "0.5" in text
class TestCompareContextSizes:
    """Tests for the compare_context_sizes function."""

    def test_returns_context_size_comparison(self):
        """The function should return a ContextSizeComparison instance."""
        from backend.rag.specificity.token_counter import (
            ContextSizeComparison,
            compare_context_sizes,
        )
        assert isinstance(compare_context_sizes("archive_search", 0.5), ContextSizeComparison)

    def test_contains_expected_fields(self):
        """All comparison fields should be populated with the right types."""
        from backend.rag.specificity.token_counter import compare_context_sizes
        result = compare_context_sizes("archive_search", 0.5)
        assert result.template == "archive_search"
        assert result.threshold == 0.5
        for attr in ("filtered_tokens", "unfiltered_tokens", "filtered_classes", "unfiltered_classes"):
            assert isinstance(getattr(result, attr), int)

    def test_unfiltered_uses_all_classes(self):
        """The unfiltered side (threshold 1.0) should never have fewer classes."""
        from backend.rag.specificity.token_counter import compare_context_sizes
        result = compare_context_sizes("general_heritage", 0.5)
        assert result.unfiltered_classes >= result.filtered_classes
class TestBenchmarkAllTemplates:
    """Tests for the benchmark_all_templates function."""

    def test_returns_dict(self):
        """The benchmark should return a dictionary of results."""
        from backend.rag.specificity.token_counter import benchmark_all_templates
        results = benchmark_all_templates(0.5)
        assert isinstance(results, dict)

    def test_contains_all_templates(self):
        """Every ContextTemplate member should have an entry in the results."""
        from backend.rag.specificity.token_counter import benchmark_all_templates
        from backend.rag.specificity import ContextTemplate
        results = benchmark_all_templates(0.5)
        for template in ContextTemplate:
            assert template.value in results

    def test_all_values_are_comparisons(self):
        """Every result value should be a ContextSizeComparison object."""
        from backend.rag.specificity.token_counter import (
            benchmark_all_templates,
            ContextSizeComparison,
        )
        results = benchmark_all_templates(0.5)
        # Was `for template_name, comparison in results.items():` with the key
        # unused (ruff PERF102) — iterate only the values we actually check.
        for comparison in results.values():
            assert isinstance(comparison, ContextSizeComparison)
class TestFormatBenchmarkReport:
    """Tests for the format_benchmark_report function."""

    @staticmethod
    def _benchmark_results():
        """Run the standard benchmark used by every test in this class."""
        from backend.rag.specificity.token_counter import benchmark_all_templates
        return benchmark_all_templates(0.5)

    def test_returns_string(self):
        """The report should be a non-empty string."""
        from backend.rag.specificity.token_counter import format_benchmark_report
        report = format_benchmark_report(self._benchmark_results())
        assert isinstance(report, str)
        assert len(report) > 0

    def test_includes_header_by_default(self):
        """Without arguments, the header banner should be present."""
        from backend.rag.specificity.token_counter import format_benchmark_report
        report = format_benchmark_report(self._benchmark_results())
        assert "BENCHMARK REPORT" in report

    def test_can_exclude_header(self):
        """include_header=False should suppress the header banner."""
        from backend.rag.specificity.token_counter import format_benchmark_report
        report = format_benchmark_report(self._benchmark_results(), include_header=False)
        assert "BENCHMARK REPORT" not in report

    def test_includes_template_names(self):
        """Each template's name should appear in the rendered report."""
        from backend.rag.specificity.token_counter import format_benchmark_report
        report = format_benchmark_report(self._benchmark_results())
        assert "archive_search" in report
        assert "general_heritage" in report
class TestCostEstimate:
    """Tests for the CostEstimate dataclass."""

    @staticmethod
    def _make_estimate():
        """Build the standard fixture: 2000 -> 1000 tokens, 50% savings."""
        from backend.rag.specificity.token_counter import CostEstimate
        return CostEstimate(
            template="test",
            threshold=0.5,
            filtered_tokens=1000,
            unfiltered_tokens=2000,
            filtered_cost_1k=0.10,
            unfiltered_cost_1k=0.20,
            savings_1k=0.10,
            savings_percent=50.0,
        )

    def test_dataclass_creation(self):
        """The dataclass should store the provided field values."""
        estimate = self._make_estimate()
        assert estimate.template == "test"
        assert estimate.savings_percent == 50.0

    def test_str_representation(self):
        """str() should mention the template and the 1k-query framing."""
        text = str(self._make_estimate())
        assert "test" in text
        assert "1000 queries" in text
class TestEstimateCostSavings:
    """Tests for the estimate_cost_savings function."""

    def test_returns_cost_estimate(self):
        """The function should return a CostEstimate instance."""
        from backend.rag.specificity.token_counter import (
            CostEstimate,
            estimate_cost_savings,
        )
        assert isinstance(estimate_cost_savings("archive_search", 0.5), CostEstimate)

    def test_accepts_custom_pricing(self):
        """Doubling the input price should raise the estimated costs."""
        from backend.rag.specificity.token_counter import estimate_cost_savings
        cheap = estimate_cost_savings("archive_search", 0.5, input_price_per_1m=0.15)
        expensive = estimate_cost_savings("archive_search", 0.5, input_price_per_1m=0.30)
        assert expensive.filtered_cost_1k > cheap.filtered_cost_1k

    def test_savings_calculation_correct(self):
        """savings_1k should equal unfiltered cost minus filtered cost."""
        from backend.rag.specificity.token_counter import estimate_cost_savings
        result = estimate_cost_savings("archive_search", 0.5)
        expected = result.unfiltered_cost_1k - result.filtered_cost_1k
        # Small absolute tolerance to absorb float rounding.
        assert abs(result.savings_1k - expected) < 0.0001
class TestQuickBenchmark:
    """Tests for the quick_benchmark convenience function."""

    def test_runs_without_error(self, capsys):
        """quick_benchmark should complete and mention 'benchmark' on stdout."""
        from backend.rag.specificity.token_counter import quick_benchmark
        quick_benchmark(0.5)
        output = capsys.readouterr().out
        assert "benchmark" in output.lower()

    def test_prints_report(self, capsys):
        """quick_benchmark should print a non-empty report to stdout."""
        from backend.rag.specificity.token_counter import quick_benchmark
        quick_benchmark(0.5)
        output = capsys.readouterr().out
        assert output
        assert ("archive_search" in output) or ("Template" in output)
class TestModuleExports:
    """Tests for module-level exports of the specificity package."""

    def test_all_functions_importable_from_package(self):
        """Every public function and class should import from the package root."""
        from backend.rag.specificity import (
            ContextSizeComparison,
            CostEstimate,
            benchmark_all_templates,
            compare_context_sizes,
            count_tokens,
            count_tokens_for_context,
            estimate_cost_savings,
            format_benchmark_report,
            quick_benchmark,
        )
        # Functions must be callable...
        for fn in (
            count_tokens,
            count_tokens_for_context,
            compare_context_sizes,
            benchmark_all_templates,
            format_benchmark_report,
            estimate_cost_savings,
            quick_benchmark,
        ):
            assert callable(fn)
        # ...and the dataclasses must be actual classes.
        for cls in (ContextSizeComparison, CostEstimate):
            assert isinstance(cls, type)
class TestTiktokenAvailability:
    """Tests for the module's tiktoken availability handling."""

    def test_tiktoken_available_flag_exists(self):
        """The module should expose a boolean TIKTOKEN_AVAILABLE flag."""
        from backend.rag.specificity import token_counter
        flag = getattr(token_counter, 'TIKTOKEN_AVAILABLE', None)
        assert flag is not None
        assert isinstance(flag, bool)

    def test_works_regardless_of_tiktoken_availability(self):
        """count_tokens should yield a valid count whether or not tiktoken is installed."""
        from backend.rag.specificity.token_counter import count_tokens
        result = count_tokens("test text")
        assert isinstance(result, int)
        assert result >= 0