#!/usr/bin/env python3
"""Test rate limit handling for the streaming pipeline.

This script tests that:
1. Rate limit errors are detected correctly
2. Retry logic works with exponential backoff
3. Streaming works after rate limit recovery

Run from project root:
    source .venv/bin/activate && source .env
    python backend/rag/test_rate_limit_handling.py
"""

import asyncio
import os
import sys

# Add project root to path (three directory levels above this file) so the
# absolute ``backend.*`` import below resolves when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from backend.rag.dspy_heritage_rag import (
    is_rate_limit_error,
    extract_actual_error,
    call_with_rate_limit_retry,
)


class _MockExceptionGroup(Exception):
    """Minimal stand-in for an exception group (TaskGroup wrapper).

    Exposes the wrapped errors through an ``exceptions`` attribute. The
    wrapper's own message is deliberately neutral: it must not embed the text
    of the nested errors, otherwise a check that only looks at ``str(error)``
    would appear to "see through" the group.
    """

    def __init__(self, exceptions):
        self.exceptions = exceptions
        super().__init__("ExceptionGroup")


def test_rate_limit_detection():
    """Check that ``is_rate_limit_error`` flags rate limit errors only.

    Covers a plain 429 message, the Z.AI ``1305`` error code, a textual
    "rate limit" message, a rate limit error nested inside an exception
    group, and an unrelated error that must not be flagged.

    Raises:
        AssertionError: if any detection expectation is not met.
    """
    print("\n=== Testing Rate Limit Detection ===\n")

    # Test direct 429 error
    error_429 = Exception("Error code: 429 - Rate limit exceeded")
    assert is_rate_limit_error(error_429), "Should detect 429 error"
    print("✓ Direct 429 error detected")

    # Test Z.AI specific error code
    error_1305 = Exception("{'error': {'code': '1305', 'message': 'Too many API requests'}}")
    assert is_rate_limit_error(error_1305), "Should detect Z.AI 1305 error"
    print("✓ Z.AI error code 1305 detected")

    # Test rate limit in error message
    error_rate = Exception("Rate limit exceeded, please try again later")
    assert is_rate_limit_error(error_rate), "Should detect 'rate' in message"
    print("✓ Rate limit message detected")

    # Test ExceptionGroup (TaskGroup wrapper). The wrapper message is neutral,
    # so this only passes if the nested exceptions themselves are inspected.
    nested_error = _MockExceptionGroup([Exception("unrelated"), Exception("Error 429")])
    assert is_rate_limit_error(nested_error), "Should detect nested rate limit"
    print("✓ Nested rate limit in ExceptionGroup detected")

    # Test non-rate-limit error
    other_error = Exception("Connection timeout")
    assert not is_rate_limit_error(other_error), "Should not detect non-rate-limit error"
    print("✓ Non-rate-limit error correctly ignored")

    print("\n✅ All rate limit detection tests passed!")


def test_error_extraction():
    """Check that ``extract_actual_error`` unwraps exception groups.

    A plain exception must be returned unchanged (same object); for a group
    containing a rate limit error, that nested error object must be returned.

    Raises:
        AssertionError: if any extraction expectation is not met.
    """
    print("\n=== Testing Error Extraction ===\n")

    # Test simple error (no extraction needed)
    simple_error = Exception("Simple error")
    extracted = extract_actual_error(simple_error)
    assert extracted is simple_error, "Simple error should return itself"
    print("✓ Simple error extraction works")

    # Test ExceptionGroup with rate limit
    rate_error = Exception("Error 429 rate limit")
    wrapped_error = _MockExceptionGroup([Exception("other"), rate_error])
    extracted = extract_actual_error(wrapped_error)
    assert extracted is rate_error, "Should extract rate limit error from group"
    print("✓ Rate limit error extracted from ExceptionGroup")

    print("\n✅ All error extraction tests passed!")


async def test_retry_logic():
    """Check the retry behaviour of ``call_with_rate_limit_retry``.

    Scenarios: immediate success (one call), success on the third attempt
    after two rate limit failures, persistent rate limiting (initial call plus
    ``max_retries`` retries, then the error propagates), and a non-rate-limit
    error (propagates after a single call).

    Only call counts and outcomes are verified; delay durations are not
    measured here.

    Raises:
        AssertionError: if any retry expectation is not met.
    """
    print("\n=== Testing Retry Logic ===\n")

    # Test successful call (no retry needed)
    call_count = 0

    def success_func():
        nonlocal call_count
        call_count += 1
        return "success"

    result = await call_with_rate_limit_retry(success_func, max_retries=3)
    assert result == "success", "Should return result on success"
    assert call_count == 1, "Should only call once on success"
    print("✓ Successful call works without retry")

    # Test retry on rate limit
    call_count = 0

    def fail_then_succeed():
        nonlocal call_count
        call_count += 1
        if call_count < 3:
            raise Exception("Error 429 rate limit")
        return "success after retry"

    result = await call_with_rate_limit_retry(
        fail_then_succeed,
        max_retries=3,
        base_delay=0.1,  # Fast delays for testing
    )
    assert result == "success after retry", "Should succeed after retries"
    assert call_count == 3, "Should retry twice before success"
    print(f"✓ Retry logic works (succeeded after {call_count} attempts)")

    # Test max retries exceeded
    call_count = 0

    def always_fail():
        nonlocal call_count
        call_count += 1
        raise Exception("Error 429 rate limit")

    try:
        await call_with_rate_limit_retry(
            always_fail,
            max_retries=2,
            base_delay=0.1,
        )
    except Exception as e:
        assert is_rate_limit_error(e), "Should raise rate limit error"
        assert call_count == 3, "Should try initial + 2 retries"
        print(f"✓ Max retries exceeded correctly (tried {call_count} times)")
    else:
        # Raised from the ``else`` clause so the broad ``except`` above cannot
        # swallow the failure when no exception was raised at all.
        raise AssertionError("Should raise after max retries")

    # Test non-rate-limit error (no retry)
    call_count = 0

    def other_error():
        nonlocal call_count
        call_count += 1
        raise Exception("Connection timeout")

    try:
        await call_with_rate_limit_retry(other_error, max_retries=3, base_delay=0.1)
    except Exception:
        assert call_count == 1, "Should not retry non-rate-limit errors"
        print("✓ Non-rate-limit errors not retried")
    else:
        # Same reasoning as above: a silently swallowed error must fail the
        # test instead of being masked by the ``except`` branch.
        raise AssertionError("Should raise on non-rate-limit error")

    print("\n✅ All retry logic tests passed!")


async def test_live_streaming():
    """Test live streaming with the actual API (optional, requires ZAI_API_TOKEN).

    Skipped with a notice when ``ZAI_API_TOKEN`` is not set. Otherwise streams
    one ChainOfThought answer through ``dspy.streamify``, echoing content
    deltas as they arrive and counting them, then reports whether a final
    ``dspy.Prediction`` was received.

    A rate limit error during the run is reported as a warning and not
    re-raised.

    Raises:
        Exception: any non-rate-limit error from the streaming call is
            reported and re-raised.
    """
    import dspy

    api_key = os.environ.get('ZAI_API_TOKEN')
    if not api_key:
        print("\n⚠️ Skipping live streaming test (ZAI_API_TOKEN not set)")
        return

    print("\n=== Testing Live Streaming ===\n")

    # Configure DSPy with Z.AI API
    lm = dspy.LM(
        'openai/glm-4.5-flash',
        api_base='https://api.z.ai/api/coding/paas/v4',
        api_key=api_key,
        cache=False,
    )
    dspy.configure(lm=lm)

    # Simple signature for testing
    class Simple(dspy.Signature):
        """Answer a simple question."""
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    cot = dspy.ChainOfThought(Simple)
    streamified = dspy.streamify(cot)

    print("Streaming test: 'What is 2+2?'")
    token_count = 0
    prediction_received = False

    try:
        async for value in streamified(question="What is 2+2?"):
            if isinstance(value, dspy.Prediction):
                prediction_received = True
                print(f"\n✓ Final answer: {value.answer}")
            elif hasattr(value, 'choices') and value.choices:
                # Streaming chunk: only count chunks that carry non-empty
                # text in ``choices[0].delta.content``.
                delta = getattr(value.choices[0], 'delta', None)
                if delta:
                    content = getattr(delta, 'content', None)
                    if content:
                        token_count += 1
                        print(content, end='', flush=True)

        print(f"\n✓ Received {token_count} streaming tokens")
        print(f"✓ Prediction received: {prediction_received}")

        if token_count > 10:
            print("\n✅ Live streaming test PASSED - received multiple tokens!")
        else:
            print("\n⚠️ Live streaming test: low token count (may indicate rate limiting)")

    except Exception as e:
        actual = extract_actual_error(e)
        if is_rate_limit_error(e):
            # Being rate limited is an acceptable outcome for the live test.
            print(f"\n⚠️ Rate limited during live test: {actual}")
            print(" This is expected if the API is being used frequently.")
        else:
            print(f"\n❌ Live streaming test failed: {actual}")
            raise


async def main():
    """Run all tests: the unit tests first, then the optional live test."""
    print("=" * 60)
    print("Rate Limit Handling Test Suite")
    print("=" * 60)

    # Unit tests (always run)
    test_rate_limit_detection()
    test_error_extraction()
    await test_retry_logic()

    # Live test (optional, requires API key)
    await test_live_streaming()

    print("\n" + "=" * 60)
    print("All tests completed!")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(main())