""" Tests for the specificity token counter module. These tests verify token counting behavior WITHOUT hardcoding specific token values, since token counts depend on the tokenizer/model being used. """ import pytest from unittest.mock import patch, MagicMock class TestCountTokens: """Tests for the count_tokens function.""" def test_empty_string_returns_zero(self): """Empty string should return 0 tokens.""" from backend.rag.specificity.token_counter import count_tokens assert count_tokens("") == 0 assert count_tokens(None) == 0 if count_tokens(None) is not None else True def test_returns_positive_integer_for_text(self): """Non-empty text should return positive token count.""" from backend.rag.specificity.token_counter import count_tokens result = count_tokens("Hello world") assert isinstance(result, int) assert result > 0 def test_longer_text_has_more_tokens(self): """Longer text should generally have more tokens.""" from backend.rag.specificity.token_counter import count_tokens short = count_tokens("Hello") long = count_tokens("Hello world, this is a much longer piece of text with many more words") assert long > short def test_model_parameter_accepted(self): """Should accept different model parameters without error.""" from backend.rag.specificity.token_counter import count_tokens # Should not raise count_tokens("test", model="gpt-4o") count_tokens("test", model="gpt-4o-mini") count_tokens("test", model="gpt-3.5-turbo") count_tokens("test", model="unknown-model") # Falls back to cl100k_base def test_fallback_when_tiktoken_unavailable(self): """Should fall back to approximation when tiktoken unavailable.""" from backend.rag.specificity import token_counter # Save original state original_available = token_counter.TIKTOKEN_AVAILABLE try: # Disable tiktoken token_counter.TIKTOKEN_AVAILABLE = False result = token_counter.count_tokens("Hello world test") assert isinstance(result, int) assert result > 0 # Fallback uses len(text) // 4 assert result == len("Hello world test") // 4 finally: # Restore token_counter.TIKTOKEN_AVAILABLE = original_available def test_consistent_results_for_same_input(self): """Same input should produce same output (deterministic).""" from backend.rag.specificity.token_counter import count_tokens text = "The quick brown fox jumps over the lazy dog" result1 = count_tokens(text) result2 = count_tokens(text) assert result1 == result2 class TestCountTokensForContext: """Tests for count_tokens_for_context function.""" def test_returns_integer(self): """Should return an integer token count.""" from backend.rag.specificity.token_counter import count_tokens_for_context result = count_tokens_for_context("general_heritage", 0.5) assert isinstance(result, int) def test_returns_positive_for_valid_template(self): """Should return positive count for valid template.""" from backend.rag.specificity.token_counter import count_tokens_for_context result = count_tokens_for_context("archive_search", 0.5) assert result > 0 def test_accepts_threshold_parameter(self): """Should accept different threshold values.""" from backend.rag.specificity.token_counter import count_tokens_for_context # Should not raise count_tokens_for_context("general_heritage", 0.1) count_tokens_for_context("general_heritage", 0.5) count_tokens_for_context("general_heritage", 0.9) class TestContextSizeComparison: """Tests for the ContextSizeComparison dataclass.""" def test_dataclass_creation(self): """Should create dataclass with required fields.""" from backend.rag.specificity.token_counter import ContextSizeComparison comparison = ContextSizeComparison( template="test", threshold=0.5, filtered_tokens=100, unfiltered_tokens=200, filtered_classes=10, unfiltered_classes=20, ) assert comparison.template == "test" assert comparison.threshold == 0.5 assert comparison.filtered_tokens == 100 assert comparison.unfiltered_tokens == 200 def test_derived_metrics_calculated(self): """Should calculate derived metrics automatically.""" from backend.rag.specificity.token_counter import ContextSizeComparison comparison = ContextSizeComparison( template="test", threshold=0.5, filtered_tokens=100, unfiltered_tokens=200, filtered_classes=10, unfiltered_classes=20, ) # Token reduction: 200 - 100 = 100 assert comparison.token_reduction == 100 # Token reduction percent: 100/200 * 100 = 50% assert comparison.token_reduction_percent == 50.0 # Class reduction: 20 - 10 = 10 assert comparison.class_reduction == 10 # Class reduction percent: 10/20 * 100 = 50% assert comparison.class_reduction_percent == 50.0 def test_handles_zero_unfiltered(self): """Should handle zero unfiltered values without division error.""" from backend.rag.specificity.token_counter import ContextSizeComparison comparison = ContextSizeComparison( template="test", threshold=0.5, filtered_tokens=0, unfiltered_tokens=0, filtered_classes=0, unfiltered_classes=0, ) assert comparison.token_reduction_percent == 0.0 assert comparison.class_reduction_percent == 0.0 def test_to_dict_method(self): """Should convert to dictionary with all fields.""" from backend.rag.specificity.token_counter import ContextSizeComparison comparison = ContextSizeComparison( template="test", threshold=0.5, filtered_tokens=100, unfiltered_tokens=200, filtered_classes=10, unfiltered_classes=20, ) d = comparison.to_dict() assert isinstance(d, dict) assert "template" in d assert "threshold" in d assert "filtered_tokens" in d assert "token_reduction_percent" in d def test_str_representation(self): """Should have readable string representation.""" from backend.rag.specificity.token_counter import ContextSizeComparison comparison = ContextSizeComparison( template="test", threshold=0.5, filtered_tokens=100, unfiltered_tokens=200, filtered_classes=10, unfiltered_classes=20, ) s = str(comparison) assert "test" in s assert "0.5" in s class TestCompareContextSizes: """Tests for compare_context_sizes function.""" def test_returns_context_size_comparison(self): """Should return ContextSizeComparison object.""" from backend.rag.specificity.token_counter import ( compare_context_sizes, ContextSizeComparison, ) result = compare_context_sizes("archive_search", 0.5) assert isinstance(result, ContextSizeComparison) def test_contains_expected_fields(self): """Should populate all expected fields.""" from backend.rag.specificity.token_counter import compare_context_sizes result = compare_context_sizes("archive_search", 0.5) assert result.template == "archive_search" assert result.threshold == 0.5 assert isinstance(result.filtered_tokens, int) assert isinstance(result.unfiltered_tokens, int) assert isinstance(result.filtered_classes, int) assert isinstance(result.unfiltered_classes, int) def test_unfiltered_uses_all_classes(self): """Unfiltered should use threshold 1.0 (all classes).""" from backend.rag.specificity.token_counter import compare_context_sizes result = compare_context_sizes("general_heritage", 0.5) # Unfiltered should have more or equal classes than filtered assert result.unfiltered_classes >= result.filtered_classes class TestBenchmarkAllTemplates: """Tests for benchmark_all_templates function.""" def test_returns_dict(self): """Should return dictionary of results.""" from backend.rag.specificity.token_counter import benchmark_all_templates results = benchmark_all_templates(0.5) assert isinstance(results, dict) def test_contains_all_templates(self): """Should have entry for each context template.""" from backend.rag.specificity.token_counter import benchmark_all_templates from backend.rag.specificity import ContextTemplate results = benchmark_all_templates(0.5) for template in ContextTemplate: assert template.value in results def test_all_values_are_comparisons(self): """All values should be ContextSizeComparison objects.""" from backend.rag.specificity.token_counter import ( benchmark_all_templates, ContextSizeComparison, ) results = benchmark_all_templates(0.5) for template_name, comparison in results.items(): assert isinstance(comparison, ContextSizeComparison) class TestFormatBenchmarkReport: """Tests for format_benchmark_report function.""" def test_returns_string(self): """Should return formatted string.""" from backend.rag.specificity.token_counter import ( benchmark_all_templates, format_benchmark_report, ) results = benchmark_all_templates(0.5) report = format_benchmark_report(results) assert isinstance(report, str) assert len(report) > 0 def test_includes_header_by_default(self): """Should include header by default.""" from backend.rag.specificity.token_counter import ( benchmark_all_templates, format_benchmark_report, ) results = benchmark_all_templates(0.5) report = format_benchmark_report(results) assert "BENCHMARK REPORT" in report def test_can_exclude_header(self): """Should be able to exclude header.""" from backend.rag.specificity.token_counter import ( benchmark_all_templates, format_benchmark_report, ) results = benchmark_all_templates(0.5) report = format_benchmark_report(results, include_header=False) assert "BENCHMARK REPORT" not in report def test_includes_template_names(self): """Should include template names in report.""" from backend.rag.specificity.token_counter import ( benchmark_all_templates, format_benchmark_report, ) results = benchmark_all_templates(0.5) report = format_benchmark_report(results) assert "archive_search" in report assert "general_heritage" in report class TestCostEstimate: """Tests for the CostEstimate dataclass.""" def test_dataclass_creation(self): """Should create dataclass with required fields.""" from backend.rag.specificity.token_counter import CostEstimate estimate = CostEstimate( template="test", threshold=0.5, filtered_tokens=1000, unfiltered_tokens=2000, filtered_cost_1k=0.10, unfiltered_cost_1k=0.20, savings_1k=0.10, savings_percent=50.0, ) assert estimate.template == "test" assert estimate.savings_percent == 50.0 def test_str_representation(self): """Should have readable string representation.""" from backend.rag.specificity.token_counter import CostEstimate estimate = CostEstimate( template="test", threshold=0.5, filtered_tokens=1000, unfiltered_tokens=2000, filtered_cost_1k=0.10, unfiltered_cost_1k=0.20, savings_1k=0.10, savings_percent=50.0, ) s = str(estimate) assert "test" in s assert "1000 queries" in s class TestEstimateCostSavings: """Tests for estimate_cost_savings function.""" def test_returns_cost_estimate(self): """Should return CostEstimate object.""" from backend.rag.specificity.token_counter import ( estimate_cost_savings, CostEstimate, ) result = estimate_cost_savings("archive_search", 0.5) assert isinstance(result, CostEstimate) def test_accepts_custom_pricing(self): """Should accept custom pricing parameter.""" from backend.rag.specificity.token_counter import estimate_cost_savings result1 = estimate_cost_savings("archive_search", 0.5, input_price_per_1m=0.15) result2 = estimate_cost_savings("archive_search", 0.5, input_price_per_1m=0.30) # Higher price should result in higher costs assert result2.filtered_cost_1k > result1.filtered_cost_1k def test_savings_calculation_correct(self): """Savings should equal difference between unfiltered and filtered.""" from backend.rag.specificity.token_counter import estimate_cost_savings result = estimate_cost_savings("archive_search", 0.5) expected_savings = result.unfiltered_cost_1k - result.filtered_cost_1k assert abs(result.savings_1k - expected_savings) < 0.0001 # Float tolerance class TestQuickBenchmark: """Tests for quick_benchmark function.""" def test_runs_without_error(self, capsys): """Should run without raising errors.""" from backend.rag.specificity.token_counter import quick_benchmark # Should not raise quick_benchmark(0.5) captured = capsys.readouterr() assert "benchmark" in captured.out.lower() def test_prints_report(self, capsys): """Should print benchmark report to stdout.""" from backend.rag.specificity.token_counter import quick_benchmark quick_benchmark(0.5) captured = capsys.readouterr() assert len(captured.out) > 0 assert "archive_search" in captured.out or "Template" in captured.out class TestModuleExports: """Tests for module-level exports.""" def test_all_functions_importable_from_package(self): """All functions should be importable from the specificity package.""" from backend.rag.specificity import ( count_tokens, count_tokens_for_context, compare_context_sizes, benchmark_all_templates, format_benchmark_report, estimate_cost_savings, quick_benchmark, ContextSizeComparison, CostEstimate, ) # All should be callable or classes assert callable(count_tokens) assert callable(count_tokens_for_context) assert callable(compare_context_sizes) assert callable(benchmark_all_templates) assert callable(format_benchmark_report) assert callable(estimate_cost_savings) assert callable(quick_benchmark) assert isinstance(ContextSizeComparison, type) assert isinstance(CostEstimate, type) class TestTiktokenAvailability: """Tests for tiktoken availability handling.""" def test_tiktoken_available_flag_exists(self): """Module should have TIKTOKEN_AVAILABLE flag.""" from backend.rag.specificity import token_counter assert hasattr(token_counter, 'TIKTOKEN_AVAILABLE') assert isinstance(token_counter.TIKTOKEN_AVAILABLE, bool) def test_works_regardless_of_tiktoken_availability(self): """Should work whether tiktoken is available or not.""" from backend.rag.specificity.token_counter import count_tokens # Should return valid result regardless result = count_tokens("test text") assert isinstance(result, int) assert result >= 0