glam/backend/rag/test_rate_limit_handling.py
2025-12-23 13:27:35 +01:00

244 lines
8.2 KiB
Python

#!/usr/bin/env python3
"""Test rate limit handling for the streaming pipeline.
This script tests that:
1. Rate limit errors are detected correctly
2. Retry logic works with exponential backoff
3. Streaming works after rate limit recovery
Run from project root:
source .venv/bin/activate && source .env
python backend/rag/test_rate_limit_handling.py
"""
import asyncio
import os
import sys
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from backend.rag.dspy_heritage_rag import (
is_rate_limit_error,
extract_actual_error,
call_with_rate_limit_retry,
)
def test_rate_limit_detection():
    """Check which exceptions ``is_rate_limit_error`` classifies as rate limits.

    Covers a plain 429 message, the Z.AI ``1305`` error code, a generic
    "rate limit" message, an exception-group-like wrapper whose
    ``exceptions`` list contains a 429 error, and finally an unrelated
    error that must not be classified as a rate limit.
    """
    print("\n=== Testing Rate Limit Detection ===\n")

    class MockExceptionGroup(Exception):
        # Minimal stand-in for an ExceptionGroup (TaskGroup wrapper): it only
        # exposes the wrapped errors through an ``exceptions`` attribute.
        def __init__(self, exceptions):
            self.exceptions = exceptions
            super().__init__(str(exceptions))

    # (error to classify, assertion message on failure, line printed on success)
    should_match = [
        (
            Exception("Error code: 429 - Rate limit exceeded"),
            "Should detect 429 error",
            "✓ Direct 429 error detected",
        ),
        (
            Exception("{'error': {'code': '1305', 'message': 'Too many API requests'}}"),
            "Should detect Z.AI 1305 error",
            "✓ Z.AI error code 1305 detected",
        ),
        (
            Exception("Rate limit exceeded, please try again later"),
            "Should detect 'rate' in message",
            "✓ Rate limit message detected",
        ),
        (
            MockExceptionGroup([Exception("unrelated"), Exception("Error 429")]),
            "Should detect nested rate limit",
            "✓ Nested rate limit in ExceptionGroup detected",
        ),
    ]
    for candidate, failure_message, success_line in should_match:
        assert is_rate_limit_error(candidate), failure_message
        print(success_line)

    # Negative case: an unrelated failure must not look like a rate limit.
    unrelated = Exception("Connection timeout")
    assert not is_rate_limit_error(unrelated), "Should not detect non-rate-limit error"
    print("✓ Non-rate-limit error correctly ignored")

    print("\n✅ All rate limit detection tests passed!")
def test_error_extraction():
    """Check that ``extract_actual_error`` unwraps exception-group-like errors.

    A plain exception must come back as the very same object; a wrapper
    exposing an ``exceptions`` list must yield the rate limit error it
    contains rather than the wrapper or the unrelated sibling.
    """
    print("\n=== Testing Error Extraction ===\n")

    class MockExceptionGroup(Exception):
        # Stand-in for an ExceptionGroup: carries the wrapped errors in
        # ``exceptions`` and uses a fixed message.
        def __init__(self, exceptions):
            self.exceptions = exceptions
            super().__init__("ExceptionGroup")

    # Case 1: nothing to unwrap, identity is expected.
    plain = Exception("Simple error")
    assert extract_actual_error(plain) is plain, "Simple error should return itself"
    print("✓ Simple error extraction works")

    # Case 2: the rate limit error sits second inside the wrapper.
    inner_rate_limit = Exception("Error 429 rate limit")
    group = MockExceptionGroup([Exception("other"), inner_rate_limit])
    unwrapped = extract_actual_error(group)
    assert unwrapped is inner_rate_limit, "Should extract rate limit error from group"
    print("✓ Rate limit error extracted from ExceptionGroup")

    print("\n✅ All error extraction tests passed!")
async def test_retry_logic():
    """Exercise ``call_with_rate_limit_retry`` across four scenarios.

    1. A callable that succeeds immediately is invoked exactly once.
    2. A callable that raises a rate limit error twice and then succeeds
       is invoked three times and its result is returned.
    3. A callable that always raises a rate limit error is invoked
       ``1 + max_retries`` times and the rate limit error propagates.
    4. A callable that raises an unrelated error is invoked once and the
       error propagates without retrying.

    Raises:
        AssertionError: if any scenario does not behave as described,
            including the case where an expected exception is never raised.
    """
    print("\n=== Testing Retry Logic ===\n")

    # --- Scenario 1: successful call (no retry needed) ---
    call_count = 0

    def success_func():
        nonlocal call_count
        call_count += 1
        return "success"

    result = await call_with_rate_limit_retry(success_func, max_retries=3)
    assert result == "success", "Should return result on success"
    assert call_count == 1, "Should only call once on success"
    print("✓ Successful call works without retry")

    # --- Scenario 2: retry on rate limit, then succeed ---
    call_count = 0

    def fail_then_succeed():
        nonlocal call_count
        call_count += 1
        if call_count < 3:
            raise Exception("Error 429 rate limit")
        return "success after retry"

    result = await call_with_rate_limit_retry(
        fail_then_succeed,
        max_retries=3,
        base_delay=0.1,  # Fast delays for testing
    )
    assert result == "success after retry", "Should succeed after retries"
    assert call_count == 3, "Should retry twice before success"
    print(f"✓ Retry logic works (succeeded after {call_count} attempts)")

    # --- Scenario 3: max retries exceeded ---
    call_count = 0

    def always_fail():
        nonlocal call_count
        call_count += 1
        raise Exception("Error 429 rate limit")

    try:
        await call_with_rate_limit_retry(
            always_fail,
            max_retries=2,
            base_delay=0.1,
        )
    except Exception as e:
        assert is_rate_limit_error(e), "Should raise rate limit error"
        assert call_count == 3, "Should try initial + 2 retries"
        print(f"✓ Max retries exceeded correctly (tried {call_count} times)")
    else:
        # Raised outside the ``try`` body so the broad ``except Exception``
        # above cannot catch this failure signal and mask it.
        raise AssertionError("Should raise after max retries")

    # --- Scenario 4: non-rate-limit error (no retry) ---
    call_count = 0

    def other_error():
        nonlocal call_count
        call_count += 1
        raise Exception("Connection timeout")

    try:
        await call_with_rate_limit_retry(other_error, max_retries=3, base_delay=0.1)
    except Exception:
        assert call_count == 1, "Should not retry non-rate-limit errors"
        print("✓ Non-rate-limit errors not retried")
    else:
        # Without this ``else`` branch a call that wrongly returned normally
        # would be reported as a pass.
        raise AssertionError("Should raise on non-rate-limit error")

    print("\n✅ All retry logic tests passed!")
async def test_live_streaming():
    """Stream one question through the real Z.AI endpoint (optional).

    The check is skipped when the ``ZAI_API_TOKEN`` environment variable is
    unset. Otherwise it configures DSPy with a Z.AI-hosted model, streams a
    chain-of-thought answer to "What is 2+2?", counts the content chunks
    received and reports whether a final ``dspy.Prediction`` arrived.

    A rate limit error during the run is reported as a warning; any other
    error is reported and re-raised.
    """
    import dspy

    token = os.environ.get('ZAI_API_TOKEN')
    if not token:
        print("\n⚠️ Skipping live streaming test (ZAI_API_TOKEN not set)")
        return

    print("\n=== Testing Live Streaming ===\n")

    # Point DSPy at the Z.AI API, with caching disabled.
    dspy.configure(
        lm=dspy.LM(
            'openai/glm-4.5-flash',
            api_base='https://api.z.ai/api/coding/paas/v4',
            api_key=token,
            cache=False,
        )
    )

    # Smallest possible signature: one question in, one answer out.
    class Simple(dspy.Signature):
        """Answer a simple question."""

        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    stream_program = dspy.streamify(dspy.ChainOfThought(Simple))

    print("Streaming test: 'What is 2+2?'")
    token_count = 0
    prediction_received = False

    try:
        async for chunk in stream_program(question="What is 2+2?"):
            # The final structured result arrives as a Prediction object.
            if isinstance(chunk, dspy.Prediction):
                prediction_received = True
                print(f"\n✓ Final answer: {chunk.answer}")
                continue

            # Everything else is only of interest if it carries choices.
            if not (hasattr(chunk, 'choices') and chunk.choices):
                continue

            delta = getattr(chunk.choices[0], 'delta', None)
            text = getattr(delta, 'content', None) if delta else None
            if text:
                token_count += 1
                print(text, end='', flush=True)

        print(f"\n✓ Received {token_count} streaming tokens")
        print(f"✓ Prediction received: {prediction_received}")

        verdict = (
            "\n✅ Live streaming test PASSED - received multiple tokens!"
            if token_count > 10
            else "\n⚠️ Live streaming test: low token count (may indicate rate limiting)"
        )
        print(verdict)
    except Exception as e:
        actual = extract_actual_error(e)
        if not is_rate_limit_error(e):
            print(f"\n❌ Live streaming test failed: {actual}")
            raise
        # Rate limiting is tolerated here: report it and carry on.
        print(f"\n⚠️ Rate limited during live test: {actual}")
        print(" This is expected if the API is being used frequently.")
async def main():
    """Run the whole suite: offline checks first, then the optional live check."""
    banner = "=" * 60

    print(banner)
    print("Rate Limit Handling Test Suite")
    print(banner)

    # Offline checks: these never need network access or credentials.
    test_rate_limit_detection()
    test_error_extraction()
    await test_retry_logic()

    # Live check: skips itself when no API key is configured.
    await test_live_streaming()

    print("\n" + banner)
    print("All tests completed!")
    print(banner)


# Script entry point: drive the async suite on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())