Support Rate Limit #1782

Open · wants to merge 1 commit into base: main
3 changes: 3 additions & 0 deletions metagpt/configs/llm_config.py
@@ -92,6 +92,9 @@ class LLMConfig(YamlModel):
top_logprobs: Optional[int] = None
timeout: int = 600
context_length: Optional[int] = None # Max input tokens
# For rate limit control
rpm: Optional[int] = 0
tpm: Optional[int] = 0

# For Amazon Bedrock
region_name: str = None
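With these fields in place, per-model rate limits can be declared directly on the LLM config; leaving rpm and tpm at their default of 0 disables throttling. A minimal sketch of constructing such a config in code (the key and limit values are hypothetical, for illustration only):

from metagpt.configs.llm_config import LLMConfig

# Hypothetical limits: at most 60 requests and 90,000 tokens per minute.
config = LLMConfig(api_key="sk-...", model="gpt-4o-mini", rpm=60, tpm=90000)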
10 changes: 5 additions & 5 deletions metagpt/ext/aflow/scripts/optimizer.py
@@ -5,7 +5,7 @@

import asyncio
import time
from typing import List, Literal
from typing import List, Literal, Optional

from pydantic import BaseModel, Field

@@ -24,9 +24,9 @@


class GraphOptimize(BaseModel):
modification: str = Field(default="", description="modification")
graph: str = Field(default="", description="graph")
prompt: str = Field(default="", description="prompt")
modification: Optional[str] = Field(default="", description="modification")
graph: Optional[str] = Field(default="", description="graph")
prompt: Optional[str] = Field(default="", description="prompt")


class Optimizer:
@@ -90,7 +90,7 @@ def optimize(self, mode: OptimizerType = "Graph"):
break
except Exception as e:
retry_count += 1
logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
logger.exception(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
if retry_count == max_retries:
logger.info("Max retries reached. Moving to next round.")
score = None
2 changes: 1 addition & 1 deletion metagpt/logs.py
@@ -149,5 +149,5 @@ def get_llm_stream_queue():


def _llm_stream_log(msg):
if _print_level in ["INFO"]:
if _print_level in ["INFO", "DEBUG"]:
print(msg, end="")
17 changes: 14 additions & 3 deletions metagpt/provider/base_llm.py
@@ -30,6 +30,7 @@
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.token_counter import TOKEN_MAX
from metagpt.utils.rate_limitor import RateLimitor, rate_limitor_registry


class BaseLLM(ABC):
@@ -44,6 +45,7 @@ class BaseLLM(ABC):
cost_manager: Optional[CostManager] = None
model: Optional[str] = None # deprecated
pricing_plan: Optional[str] = None
current_rate_limitor: Optional[RateLimitor] = None

_reasoning_content: Optional[str] = None # content from reasoning mode

@@ -134,6 +136,7 @@ def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
self.rate_limitor.cost_token(usage)
except Exception as e:
logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")

@@ -197,11 +200,13 @@ async def aask(
message.extend(msg)
if stream is None:
stream = self.config.stream

async with self.rate_limitor:
await self.rate_limitor.acquire(message)

# the image data is replaced with placeholders to avoid long output
masked_message = [self.mask_base64_data(m) for m in message]
logger.debug(masked_message)

compressed_message = self.compress_messages(message, compress_type=self.config.compress_type)
rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
# rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
@@ -323,6 +328,12 @@ def with_model(self, model: str):
"""Set model and return self. For example, `with_model("gpt-3.5-turbo")`."""
self.config.model = model
return self

@property
def rate_limitor(self) -> RateLimitor:
if not self.current_rate_limitor:
self.current_rate_limitor = rate_limitor_registry.register(None, self.config)
return self.current_rate_limitor

def get_timeout(self, timeout: int) -> int:
return timeout or self.config.timeout or LLM_API_TIMEOUT
@@ -407,4 +418,4 @@ def compress_messages(
)
break

return compressed
return compressed
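With the limiter wired into aask, every request first enters the limiter's context manager and awaits acquire(message) before the completion call, and _update_costs charges the actual usage back to the token bucket afterwards. A rough usage sketch, assuming the existing metagpt.llm.LLM helper builds a provider from the active config (the prompts are placeholders):

import asyncio

from metagpt.llm import LLM  # assumed helper that instantiates a BaseLLM subclass from config

async def main():
    llm = LLM()
    # Concurrent calls share one RateLimitor per config, so each underlying
    # completion request is throttled to the configured rpm/tpm.
    answers = await asyncio.gather(*(llm.aask(f"question {i}") for i in range(5)))
    print(answers)

asyncio.run(main())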
11 changes: 11 additions & 0 deletions metagpt/utils/common.py
@@ -31,6 +31,8 @@
from asyncio import iscoroutinefunction
from datetime import datetime
from functools import partial
import asyncio
import nest_asyncio
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@@ -672,6 +674,15 @@ def format_trackback_info(limit: int = 2):
return traceback.format_exc(limit=limit)


def asyncio_run(future):
nest_asyncio.apply()
try:
loop = asyncio.get_event_loop()
return loop.run_until_complete(future)
except RuntimeError:
return asyncio.run(future)


def serialize_decorator(func):
async def wrapper(self, *args, **kwargs):
try:
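asyncio_run exists so that synchronous call sites (for example TokenBucket._cost and the available_tokens property in the new rate limiter) can drive the async _refill coroutine; nest_asyncio.apply() patches the loop so this also works from code that is already running inside an event loop. An illustrative sketch of the intended behavior:

from metagpt.utils.common import asyncio_run

async def double(x):
    return 2 * x

# Callable from plain synchronous code: reuses the current event loop when one
# exists, otherwise falls back to asyncio.run().
assert asyncio_run(double(21)) == 42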
169 changes: 169 additions & 0 deletions metagpt/utils/rate_limitor.py
@@ -0,0 +1,169 @@
import time
import asyncio
import math
import json

from pydantic_core import to_jsonable_python
from metagpt.utils.token_counter import count_message_tokens
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.configs.models_config import ModelsConfig
import metagpt.utils.common as common

class RateLimitor:
def __init__(self, rpm: int, tpm: int):
self.rpm = rpm
self.tpm = tpm
self.tpm_bucket = TokenBucket(tpm)
self.rpm_bucket = TokenBucket(rpm)
self.lock = asyncio.Semaphore(rpm)

async def acquire_rpm(self, tokens=1):
await self.rpm_bucket.acquire(tokens)


async def __enter__(self):
if self.rpm > 0 or self.tpm > 0:
await self.lock.acquire()
return self

async def __exit__(self, exc_type, exc_val, exc_tb):
if self.rpm > 0 or self.tpm > 0:
self.lock.release()
return None

async def __aenter__(self):
return await self.__enter__()

async def __aexit__(self, exc_type, exc_val, exc_tb):
return await self.__exit__(exc_type, exc_val, exc_tb)

def cost_token(self, usage: dict):
if not isinstance(usage, dict):
usage = dict(usage)
self.tpm_bucket._cost(usage.get("input_tokens", usage.get('prompt_tokens', 0)))
self.tpm_bucket._cost(usage.get("output_tokens", usage.get('completion_tokens', 0)))


async def acquire(self, messages):
tokens = count_message_tokens(messages)
await self.tpm_bucket._wait(tokens)
await self.acquire_rpm(1)


class TokenBucket:
def __init__(self, rpm):
"""
Initialize the token bucket (thread-safe version)
:param rpm: the number of requests per minute
"""
if rpm is None:
rpm = 0
self.capacity = rpm # the capacity of the bucket
self.tokens = rpm # the current number of tokens
self.rate = rpm / 60.0 if rpm else 0 # the number of tokens generated per second
self.last_refill = time.time()
self.lock = asyncio.Lock()  # lock guarding concurrent refills

async def _refill(self, desc_tokens=0):
"""Refill the bucket; if enough tokens are available, deduct `desc_tokens` and return True, otherwise False."""
async with self.lock:
if self.capacity is None or self.capacity <= 0:
return
# assert self.capacity >= desc_tokens, f"Token bucket capacity [{self.capacity}] cannot cover this request's consumption: {desc_tokens}."
now = time.time()
elapsed = now - self.last_refill
new_tokens = elapsed * self.rate

if new_tokens + self.tokens >= desc_tokens or self.tokens >= self.capacity:
self.tokens = min(self.capacity, self.tokens + new_tokens) - desc_tokens
self.last_refill = now
return True  # indicates tokens were replenished
else:
self.tokens = min(self.capacity, self.tokens + new_tokens)
self.last_refill = now
return False

def _cost(self, tokens: int):
if self.capacity is None or self.capacity <= 0:
return
assert tokens >= 0
common.asyncio_run(self._refill())
self.tokens -= tokens

async def _wait(self, tokens: int):
while True:
if await self._refill(desc_tokens=tokens):
# enough tokens, return immediately
return True
deficit = tokens - self.tokens
wait_time = deficit / self.rate

logger.warning(f"current [{asyncio.current_task().get_name()}] with [{self.tokens:.5f}] tokens, wait_time for tpm: {wait_time:.3f}")
await asyncio.sleep(wait_time)

async def acquire(self, tokens=1):
"""
Block until acquiring the specified number of tokens
:param tokens: the number of tokens needed (default is 1)
"""
if self.capacity is None or self.capacity <= 0:
return

while True:
# if the tokens are enough, return immediately
if await self._refill(desc_tokens=tokens):
return

# calculate the time to wait
deficit = tokens - self.tokens
wait_time = deficit / self.rate

logger.warning(f"current [{asyncio.current_task().get_name()}] with [{self.tokens:.5f}] tokens, wait_time for rpm: {wait_time:.3f}")

# wait until the tokens are replenished (with timeout and notification)
await asyncio.sleep(wait_time)

@property
def available_tokens(self):
"""Get the current number of available tokens (refreshed in real time)"""
if self.capacity is None or self.capacity <= 0:
return math.inf
common.asyncio_run(self._refill())
return self.tokens


class RateLimitorRegistry:
def __init__(self):
self.rate_limitors = {}
self.config_items = {}

def init_rate_limitors(self):
for model_name, llm_config in ModelsConfig.default().items():
self.register(model_name, llm_config)

def _config_to_key(self, llm_config: LLMConfig):
return json.dumps(llm_config.model_dump(), default=to_jsonable_python)

def register(self, model_name: str, llm_config: LLMConfig) -> RateLimitor:
if not llm_config:
raise ValueError("llm_config is required")
if not model_name:
model_name = self._config_to_key(llm_config)
if model_name not in self.rate_limitors:
self.rate_limitors[model_name] = RateLimitor(llm_config.rpm, llm_config.tpm)
self.config_items[self._config_to_key(llm_config)] = model_name
return self.rate_limitors[model_name]

def get(self, model_name: str):
if not model_name:
model_name = "_default_llm"
return self.rate_limitors.get(model_name)

def get_by_config(self, llm_config: LLMConfig):
rate_limitor_key = self._config_to_key(llm_config)
return self.rate_limitors.get(rate_limitor_key, default_rate_limitor)

rate_limitor_registry = RateLimitorRegistry()

default_rate_limitor = RateLimitor(0, 0)
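A minimal end-to-end sketch of the module as added in this diff: register a limiter for a config, gate a request with acquire, then charge back the tokens actually consumed. The model name, API key, and limits below are hypothetical:

import asyncio

from metagpt.configs.llm_config import LLMConfig
from metagpt.utils.rate_limitor import rate_limitor_registry

async def main():
    cfg = LLMConfig(api_key="sk-...", model="gpt-4o-mini", rpm=3, tpm=4000)
    limiter = rate_limitor_registry.register("gpt-4o-mini", cfg)

    messages = [{"role": "user", "content": "hello"}]
    async with limiter:                  # bounds concurrency via the internal semaphore
        await limiter.acquire(messages)  # waits on the tpm bucket, then the rpm bucket
        ...                              # call the provider here

    # After the response arrives, deduct the tokens that were actually used.
    limiter.cost_token({"prompt_tokens": 12, "completion_tokens": 30})

asyncio.run(main())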
5 changes: 3 additions & 2 deletions tests/data/code/python/1.py
@@ -48,8 +48,9 @@
ax2.set_xlabel("Degree")
ax2.set_ylabel("# of Nodes")

fig.tight_layout()
plt.show()
if __name__ == "__main__":
fig.tight_layout()
plt.show()


class Game: