Gary Simmons committed on
Commit 5557a0e · 1 Parent(s): 13127ce

enhance RateLimitedModel: move retry logic into the model wrapper and make the retry parameters configurable
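A minimal sketch of the configuration surface this commit adds. The environment variable names, defaults, and constructor parameters are taken from the diff below; the values shown are illustrative only.

# Configure the rate-limit/retry wrapper before app.py builds its model.
import os

os.environ.setdefault("MODEL_RPM", "8")           # top-level requests per minute for the token bucket
os.environ.setdefault("MODEL_MAX_RETRIES", "10")  # attempts before a rate-limit error is re-raised
os.environ.setdefault("MODEL_BASE_DELAY", "30.0") # seconds to wait after each rate-limit error

# Equivalent direct construction, using the names from app.py:
# model = RateLimitedModel(
#     LiteLLMModel(model_id="gemini/gemini-2.5-flash", temperature=0.2),
#     rpm=8, burst=None, max_retries=10, base_delay=30.0,
# )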

Files changed (1)
  1. app.py +83 -67
app.py CHANGED
@@ -61,10 +61,19 @@ class TokenBucketRateLimiter:
 
 
 class RateLimitedModel:
-    """Wraps a model-like callable and enforces a TokenBucketRateLimiter before each call."""
-
-    def __init__(self, model_obj, rpm: int = 8, burst: int | None = None):
+    """Wraps a model-like callable and enforces a TokenBucketRateLimiter before each call with retry logic."""
+
+    def __init__(
+        self,
+        model_obj,
+        rpm: int = 8,
+        burst: int | None = None,
+        max_retries: int = 10,
+        base_delay: float = 30.0,
+    ):
         self._model = model_obj
+        self.max_retries = max_retries
+        self.base_delay = base_delay
         # rpm -> tokens per minute
         capacity = burst if burst is not None else max(1, rpm)
         refill_rate = float(rpm) / 60.0
@@ -72,24 +81,65 @@ class RateLimitedModel:
             capacity=capacity, refill_rate=refill_rate
         )
 
+    def _call_with_retry(self, func, *args, **kwargs):
+        """Call a function with retry logic for rate limit errors."""
+        last_exception = None
+        for attempt in range(1, self.max_retries + 1):
+            try:
+                # Apply rate limiting before each attempt
+                wait = self._limiter.acquire(1.0)
+                if wait > 0:
+                    jitter = random.uniform(0.0, 0.5)
+                    total_wait = wait + jitter
+                    print(
+                        f"RateLimitedModel sleeping {total_wait:.2f}s to respect RPM limit"
+                    )
+                    time.sleep(total_wait)
+
+                print(f"Model call attempt {attempt} of {self.max_retries}")
+                result = func(*args, **kwargs)
+                print(f"Model call attempt {attempt} succeeded")
+                return result
+
+            except Exception as e:
+                last_exception = e
+                error_str = str(e).lower()
+
+                # Check if this is a rate limit error (various ways it might be reported)
+                is_rate_limit = (
+                    isinstance(e, RateLimitError)
+                    or "rate limit" in error_str
+                    or "quota" in error_str
+                    or "429" in error_str
+                    or "resource_exhausted" in error_str
+                    or "too many requests" in error_str
+                )
+
+                if is_rate_limit:
+                    print(f"Rate limit error on attempt {attempt}: {e}")
+                    if attempt < self.max_retries:
+                        # Use a longer delay for rate limit errors
+                        delay = self.base_delay + random.uniform(0, 5)
+                        print(f"Waiting {delay:.1f}s before retry {attempt + 1}...")
+                        time.sleep(delay)
+                        continue
+                else:
+                    # Non-rate-limit error, don't retry
+                    print(f"Non-rate-limit error on attempt {attempt}: {e}")
+                    raise e
+
+        # All retries exhausted
+        print(f"All {self.max_retries} attempts failed. Raising last exception.")
+        raise last_exception
+
     def __call__(self, *args, **kwargs):
-        # Each high-level call consumes 1 token. Internal model calls
-        # inside smolagents may still produce multiple requests; this
-        # aims to protect against too many top-level calls per-minute.
-        wait = self._limiter.acquire(1.0)
-        if wait > 0:
-            # Small jitter to avoid stampede
-            jitter = random.uniform(0.0, 0.5)
-            total_wait = wait + jitter
-            print(f"RateLimitedModel sleeping {total_wait:.2f}s to respect RPM limit")
-            time.sleep(total_wait)
-        return self._model(*args, **kwargs)
+        return self._call_with_retry(self._model, *args, **kwargs)
 
     def __getattr__(self, name: str):
         """Proxy attribute access to the underlying model.
 
         For callable attributes (like `generate`) we wrap the call so the
-        token-bucket rate limiter is applied consistently.
+        token-bucket rate limiter and retry logic are applied consistently.
         """
         # Avoid recursion
         if name.startswith("_"):
@@ -100,15 +150,7 @@ class RateLimitedModel:
         if callable(attr):
 
             def wrapped(*args, **kwargs):
-                wait = self._limiter.acquire(1.0)
-                if wait > 0:
-                    jitter = random.uniform(0.0, 0.5)
-                    total_wait = wait + jitter
-                    print(
-                        f"RateLimitedModel sleeping {total_wait:.2f}s to respect RPM limit"
-                    )
-                    time.sleep(total_wait)
-                return attr(*args, **kwargs)
+                return self._call_with_retry(attr, *args, **kwargs)
 
             # Preserve original metadata where possible
             try:
@@ -120,14 +162,19 @@ class RateLimitedModel:
         return attr
 
 
-# Wrap the model with a rate-limiter. Default RPM is reduced to 8
+# Wrap the model with a rate-limiter and retry logic. Default RPM is reduced to 8
 # but can be configured via the MODEL_RPM environment variable.
 _configured_rpm = int(os.getenv("MODEL_RPM", "8"))
 _configured_burst = None
+_configured_max_retries = int(os.getenv("MODEL_MAX_RETRIES", "10"))
+_configured_base_delay = float(os.getenv("MODEL_BASE_DELAY", "30.0"))
+
 model = RateLimitedModel(
     LiteLLMModel(model_id="gemini/gemini-2.5-flash", temperature=0.2),
     rpm=_configured_rpm,
     burst=_configured_burst,
+    max_retries=_configured_max_retries,
+    base_delay=_configured_base_delay,
 )
 
 
@@ -183,56 +230,25 @@ class BasicAgent:
 
     def __call__(self, question: str) -> str:
         print(f"Agent received question (first 50 chars): {question[:50]}...")
+        print(f"Starting agent execution with model retry logic enabled...")
+
+        start_time = time.time()
         try:
-            # Use retry wrapper to handle rate limit (429) errors from the model
-            # The CodeAgent is callable and expects a task string as first argument
-            # Pass it through the `call_model_with_retry` helper so RateLimitError is
-            # handled with exponential backoff.
-            response = call_model_with_retry(self.code_agent, question)
+            # The retry logic is now handled at the model level within RateLimitedModel
+            # so we can call the agent directly
+            response = self.code_agent(question)
+
+            duration = time.time() - start_time
+            print(f"Agent completed successfully in {duration:.1f}s")
             print(f"Agent returning response: {response}")
             return response
         except Exception as e:
-            print(f"Error in code agent: {e}")
+            duration = time.time() - start_time
+            print(f"Error in code agent after {duration:.1f}s: {e}")
             return f"AGENT ERROR: {e}"
 
 
-# Note: The backoff_multiplier was changed from 3 to 4, which increases the delay between retries exponentially.
-# This means that after each failed attempt, the wait time before the next retry will grow more rapidly,
-# potentially resulting in significantly longer total retry durations.
-
 
-def call_model_with_retry(
-    callable_fn,
-    *args,
-    max_retries=15,
-    initial_delay=10.0,
-    rate_limit_delay=60.0,
-    **kwargs,
-):
-    """
-    Calls a function with retry logic and a configurable wait on rate-limit errors.
-    Retries up to `max_retries` times, waiting `rate_limit_delay` seconds between attempts after a RateLimitError.
-    """
-    delay = initial_delay
-    # Initial delay before the first attempt in case a very recent call occurred
-    time.sleep(delay)
-    for attempt in range(1, max_retries + 1):
-        try:
-            print(f"Attempt {attempt} of {max_retries}...")
-            return callable_fn(*args, **kwargs)
-        except RateLimitError as e:
-            # If we've exhausted retries, re-raise
-            print(f"RateLimitError on attempt {attempt}: {e}")
-            print(f"max_retries={max_retries}, attempt={attempt}")
-            if attempt == max_retries:
-                raise
-            # Wait a configurable delay on rate-limit errors instead of recalculating
-            # using exponential backoff. This avoids growing wait times and keeps
-            # retry behavior predictable.
-            print(
-                f"RateLimitError encountered (attempt {attempt}/{max_retries}). Waiting {rate_limit_delay}s before retrying..."
-            )
-            time.sleep(rate_limit_delay)
 
 
 def run_and_submit_all(profile: gr.OAuthProfile | None):
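For reference, a minimal, self-contained sketch of the retry flow this commit moves into RateLimitedModel. FlakyModel and call_with_retry are stand-ins invented for illustration; only the control flow (classify the error, sleep, retry, and re-raise once attempts are exhausted) mirrors _call_with_retry in the diff above.

import random
import time


class FlakyModel:
    """Stand-in model that fails twice with a 429-style error, then succeeds."""

    def __init__(self, failures: int = 2):
        self.failures = failures

    def __call__(self, prompt: str) -> str:
        if self.failures > 0:
            self.failures -= 1
            raise RuntimeError("429 Too Many Requests")
        return f"answer to: {prompt}"


def call_with_retry(func, *args, max_retries=3, base_delay=0.1, **kwargs):
    """Retry rate-limit-looking errors; re-raise everything else immediately."""
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            last_exc = exc
            text = str(exc).lower()
            if "429" in text or "rate limit" in text or "quota" in text:
                if attempt < max_retries:
                    # fixed base delay plus a little jitter, as in the wrapper
                    time.sleep(base_delay + random.uniform(0, 0.05))
                    continue
            else:
                raise  # non-rate-limit errors are not retried
    # all retries were consumed by rate-limit errors
    raise last_exc


print(call_with_retry(FlakyModel(), "2 + 2?"))  # succeeds on the third attempt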