#!/usr/bin/env python3
"""Test rate limit handling for the streaming pipeline.

This script tests that:

1. Rate limit errors are detected correctly, including errors nested
   inside exception-group wrappers, and the underlying error can be
   extracted from such wrappers.
2. Retry logic retries rate limit errors up to ``max_retries`` times and
   does not retry other errors.
3. Streaming works end-to-end against the live Z.AI API (optional; skipped
   when ZAI_API_TOKEN is not set, and a rate limit during the run is
   reported rather than treated as a failure).

Run from project root:

    source .venv/bin/activate && source .env
    python backend/rag/test_rate_limit_handling.py
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
|
|
# Add project root to path
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
|
|
from backend.rag.dspy_heritage_rag import (
|
|
is_rate_limit_error,
|
|
extract_actual_error,
|
|
call_with_rate_limit_retry,
|
|
)
|
|
|
|
|
|
def test_rate_limit_detection():
    """Verify that is_rate_limit_error classifies sample errors correctly.

    Positive cases: an HTTP 429 message, the Z.AI ``1305`` error code, a
    free-text "Rate limit" message, and a 429 nested inside a wrapper that
    exposes ``.exceptions`` (mimicking an ExceptionGroup from a TaskGroup).
    Negative case: a connection timeout must not be flagged.
    """
    print("\n=== Testing Rate Limit Detection ===\n")

    class MockExceptionGroup(Exception):
        # Lightweight ExceptionGroup stand-in: it only provides the
        # ``.exceptions`` attribute and a string form of the wrapped errors.
        def __init__(self, exceptions):
            self.exceptions = exceptions
            super().__init__(str(exceptions))

    # Each entry: (error, should_be_flagged, assertion message, success line).
    scenarios = (
        (
            Exception("Error code: 429 - Rate limit exceeded"),
            True,
            "Should detect 429 error",
            "✓ Direct 429 error detected",
        ),
        (
            Exception("{'error': {'code': '1305', 'message': 'Too many API requests'}}"),
            True,
            "Should detect Z.AI 1305 error",
            "✓ Z.AI error code 1305 detected",
        ),
        (
            Exception("Rate limit exceeded, please try again later"),
            True,
            "Should detect 'rate' in message",
            "✓ Rate limit message detected",
        ),
        (
            MockExceptionGroup([Exception("unrelated"), Exception("Error 429")]),
            True,
            "Should detect nested rate limit",
            "✓ Nested rate limit in ExceptionGroup detected",
        ),
        (
            Exception("Connection timeout"),
            False,
            "Should not detect non-rate-limit error",
            "✓ Non-rate-limit error correctly ignored",
        ),
    )

    for error, should_flag, failure_message, success_line in scenarios:
        # Compare truthiness only, so any truthy/falsy return value is accepted.
        assert bool(is_rate_limit_error(error)) == should_flag, failure_message
        print(success_line)

    print("\n✅ All rate limit detection tests passed!")
|
|
|
|
|
|
def test_error_extraction():
    """Verify that extract_actual_error unwraps exception-group wrappers.

    A plain exception must come back as the very same object. For a wrapper
    exposing ``.exceptions`` that contains a rate-limit error, that inner
    error object must be returned.
    """
    print("\n=== Testing Error Extraction ===\n")

    # A bare exception needs no unwrapping and must be returned unchanged.
    plain = Exception("Simple error")
    unwrapped = extract_actual_error(plain)
    assert unwrapped is plain, "Simple error should return itself"
    print("✓ Simple error extraction works")

    class MockExceptionGroup(Exception):
        """ExceptionGroup stand-in that carries wrapped errors in ``.exceptions``."""

        def __init__(self, exceptions):
            self.exceptions = exceptions
            super().__init__("ExceptionGroup")

    # The rate-limit error is placed second, so a "take the first member"
    # implementation would fail the identity check below.
    inner = Exception("Error 429 rate limit")
    group = MockExceptionGroup([Exception("other"), inner])
    unwrapped = extract_actual_error(group)
    assert unwrapped is inner, "Should extract rate limit error from group"
    print("✓ Rate limit error extracted from ExceptionGroup")

    print("\n✅ All error extraction tests passed!")
|
|
|
|
|
|
async def test_retry_logic():
    """Exercise the retry policy of call_with_rate_limit_retry.

    Scenarios checked:

    * a call that succeeds immediately is made exactly once;
    * a call that raises a rate-limit error twice and then succeeds is
      retried until it returns its value;
    * a call that always raises a rate-limit error is attempted
      ``1 + max_retries`` times, and a rate-limit error is re-raised;
    * a call that raises a non-rate-limit error is attempted only once,
      and the error propagates.

    Backoff timing is not measured here; ``base_delay=0.1`` only keeps the
    retrying scenarios fast.
    """
    print("\n=== Testing Retry Logic ===\n")

    # Test successful call (no retry needed)
    call_count = 0

    def success_func():
        nonlocal call_count
        call_count += 1
        return "success"

    result = await call_with_rate_limit_retry(success_func, max_retries=3)
    assert result == "success", "Should return result on success"
    assert call_count == 1, "Should only call once on success"
    print("✓ Successful call works without retry")

    # Test retry on rate limit
    call_count = 0

    def fail_then_succeed():
        nonlocal call_count
        call_count += 1
        if call_count < 3:
            raise Exception("Error 429 rate limit")
        return "success after retry"

    result = await call_with_rate_limit_retry(
        fail_then_succeed,
        max_retries=3,
        base_delay=0.1,  # Fast delays for testing
    )
    assert result == "success after retry", "Should succeed after retries"
    assert call_count == 3, "Should retry twice before success"
    print(f"✓ Retry logic works (succeeded after {call_count} attempts)")

    # Test max retries exceeded
    call_count = 0

    def always_fail():
        nonlocal call_count
        call_count += 1
        raise Exception("Error 429 rate limit")

    # Capture the exception and assert *outside* the try. An `assert False`
    # inside the try would raise AssertionError (an Exception subclass),
    # which the except clause would then swallow.
    raised = None
    try:
        await call_with_rate_limit_retry(
            always_fail,
            max_retries=2,
            base_delay=0.1,
        )
    except Exception as exc:
        raised = exc
    assert raised is not None, "Should raise after max retries"
    assert is_rate_limit_error(raised), "Should raise rate limit error"
    assert call_count == 3, "Should try initial + 2 retries"
    print(f"✓ Max retries exceeded correctly (tried {call_count} times)")

    # Test non-rate-limit error (no retry)
    call_count = 0

    def other_error():
        nonlocal call_count
        call_count += 1
        raise Exception("Connection timeout")

    # Same capture pattern as above: if the helper swallowed the error
    # instead of propagating it, this must fail rather than pass silently.
    raised = None
    try:
        await call_with_rate_limit_retry(other_error, max_retries=3, base_delay=0.1)
    except Exception as exc:
        raised = exc
    assert raised is not None, "Should raise on non-rate-limit error"
    assert call_count == 1, "Should not retry non-rate-limit errors"
    print("✓ Non-rate-limit errors not retried")

    print("\n✅ All retry logic tests passed!")
|
|
|
|
|
|
async def test_live_streaming():
    """Stream one ChainOfThought answer from the live Z.AI endpoint.

    The test returns early, with a warning, when ZAI_API_TOKEN is not set.
    Otherwise it configures DSPy globally with an ``openai/glm-4.5-flash``
    LM (caching disabled) and streams an answer to "What is 2+2?".
    Streamed chunks carrying ``choices[0].delta.content`` are echoed and
    counted as tokens.

    A rate-limit error during streaming is reported and the test returns
    normally. Any other error is reported and re-raised.
    """
    import dspy

    api_key = os.environ.get('ZAI_API_TOKEN')
    if not api_key:
        print("\n⚠️ Skipping live streaming test (ZAI_API_TOKEN not set)")
        return

    print("\n=== Testing Live Streaming ===\n")

    # Point DSPy at the Z.AI OpenAI-compatible endpoint; cache=False forces
    # a real request on every run.
    dspy.configure(
        lm=dspy.LM(
            'openai/glm-4.5-flash',
            api_base='https://api.z.ai/api/coding/paas/v4',
            api_key=api_key,
            cache=False,
        )
    )

    class Simple(dspy.Signature):
        """Answer a simple question."""

        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    streamified = dspy.streamify(dspy.ChainOfThought(Simple))

    def _delta_text(chunk):
        # Return the non-empty delta.content of a streaming chunk, else None.
        if not (hasattr(chunk, 'choices') and chunk.choices):
            return None
        delta = getattr(chunk.choices[0], 'delta', None)
        if not delta:
            return None
        return getattr(delta, 'content', None) or None

    print("Streaming test: 'What is 2+2?'")
    token_count = 0
    prediction_received = False

    try:
        async for item in streamified(question="What is 2+2?"):
            if isinstance(item, dspy.Prediction):
                prediction_received = True
                print(f"\n✓ Final answer: {item.answer}")
                continue
            text = _delta_text(item)
            if text:
                token_count += 1
                print(text, end='', flush=True)

        print(f"\n✓ Received {token_count} streaming tokens")
        print(f"✓ Prediction received: {prediction_received}")

        # More than 10 tokens is treated as evidence of real streaming.
        if token_count <= 10:
            print("\n⚠️ Live streaming test: low token count (may indicate rate limiting)")
        else:
            print("\n✅ Live streaming test PASSED - received multiple tokens!")

    except Exception as exc:
        underlying = extract_actual_error(exc)
        if not is_rate_limit_error(exc):
            print(f"\n❌ Live streaming test failed: {underlying}")
            raise
        # Rate limits are expected under heavy API use; report and move on.
        print(f"\n⚠️ Rate limited during live test: {underlying}")
        print(" This is expected if the API is being used frequently.")
|
|
|
|
|
|
async def main():
    """Run the synchronous unit tests, the async retry tests, then the live test."""
    banner = "=" * 60

    print(banner)
    print("Rate Limit Handling Test Suite")
    print(banner)

    # Unit tests: these use mock errors and local functions only.
    for sync_test in (test_rate_limit_detection, test_error_extraction):
        sync_test()
    await test_retry_logic()

    # Live test; it returns early by itself when ZAI_API_TOKEN is unset.
    await test_live_streaming()

    print("\n" + banner)
    print("All tests completed!")
    print(banner)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: drive the async test suite on a fresh event loop.
    asyncio.run(main())
|