From c3fcfe41d1e8d9124d36a9861ae521b0c3cfb381 Mon Sep 17 00:00:00 2001
From: phact
Date: Wed, 1 May 2024 16:12:57 -0400
Subject: [PATCH] v1 check

---
 impl/main.py           | 28 +++++++++++++++++++++++-----
 impl/routes/threads.py |  2 +-
 poetry.lock            | 25 +++++++++++++------------
 pyproject.toml         |  4 ++--
 4 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/impl/main.py b/impl/main.py
index 5725760..5f7c95a 100644
--- a/impl/main.py
+++ b/impl/main.py
@@ -12,6 +12,8 @@ from prometheus_fastapi_instrumentator.metrics import Info
 from slowapi import Limiter, _rate_limit_exceeded_handler
 from slowapi.errors import RateLimitExceeded
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.responses import Response
 
 from impl.background import background_task_set
 from impl.rate_limiter import get_dbid, limiter
 
@@ -25,18 +27,18 @@
 logger = logging.getLogger('cassandra')
 logger.setLevel(logging.WARN)
 
-
 logger = logging.getLogger(__name__)
 
 app = FastAPI(
     title="Astra Assistants API",
-    description="Drop in replacement for OpenAI Assistants API. .",
+    description="Drop in replacement for OpenAI Assistants API.",
     version="2.0.0",
 )
 
 app.state.limiter = limiter
 app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
 
+
 @app.on_event("shutdown")
 async def shutdown_event():
     logger.info("shutting down server")
@@ -55,6 +57,24 @@ async def shutdown_event():
 app.include_router(threads.router, prefix="/v1")
 
 
+class APIVersionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        version_header = request.headers.get("OpenAI-Beta")
+
+        original_path = request.url.path
+        if version_header is None or version_header == "assistants=v1":
+            request.scope['path'] = original_path
+            response = await call_next(request)
+            return response
+        else:
+            return Response(
+                "Unsupported version, please use openai SDK compatible with: OpenAI-Beta: assistants=v1 (python sdk openai 1.21.0 or older)",
+                status_code=400)
+
+
+app.add_middleware(APIVersionMiddleware)
+
+
 def count_dbid(
     latency_lowr_buckets: Sequence[Union[float, str]] = (0.1, 0.5, 1),
 ) -> Callable[[Info], None]:
@@ -130,7 +150,6 @@ def instrumentation(info: Info) -> None:
             info.modified_duration
         )
 
-
     return instrumentation
 
 
@@ -182,6 +201,5 @@ async def unimplemented(request: Request, full_path: str):
         status_code=501, content={"message": "This feature is not yet implemented"}
     )
 
-
-#if __name__ == "__main__":
+# if __name__ == "__main__":
 # uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/impl/routes/threads.py b/impl/routes/threads.py
index 8d2770c..93baaac 100644
--- a/impl/routes/threads.py
+++ b/impl/routes/threads.py
@@ -23,7 +23,7 @@
     ThreadRunStepCreated, ThreadRunStepInProgress, ThreadRunStepCompleted
 from openai.types.beta.threads import Message, MessageDeltaEvent, MessageDelta, TextDeltaBlock, TextDelta
 from openai.types.beta.threads.runs import RunStepDeltaEvent, RunStepDelta, ToolCallsStepDetails, FunctionToolCallDelta, \
-    ToolCallDeltaObject, RunStep, MessageCreationStepDetails, RetrievalToolCall, RetrievalToolCallDelta
+    ToolCallDeltaObject, RunStep, MessageCreationStepDetails, RetrievalToolCall
 from openai.types.beta.threads.runs.message_creation_step_details import MessageCreation
 from starlette.responses import StreamingResponse
 
diff --git a/poetry.lock b/poetry.lock
index 92f8160..7e77561 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1030,19 +1030,20 @@ testing = ["pytest"]
 
 [[package]]
 name = "google-generativeai"
-version = "0.3.2"
+version = "0.4.1"
 description = "Google Generative AI High level API client library and tools."
 optional = false
 python-versions = ">=3.9"
 files = [
-    {file = "google_generativeai-0.3.2-py3-none-any.whl", hash = "sha256:8761147e6e167141932dc14a7b7af08f2310dd56668a78d206c19bb8bd85bcd7"},
+    {file = "google_generativeai-0.4.1-py3-none-any.whl", hash = "sha256:89be3c00c2e688108fccefc50f47f45fc9d37ecd53c1ade9d86b5d982919c24a"},
 ]
 
 [package.dependencies]
 google-ai-generativelanguage = "0.4.0"
 google-api-core = "*"
-google-auth = "*"
+google-auth = ">=2.15.0"
 protobuf = "*"
+pydantic = "*"
 tqdm = "*"
 typing-extensions = "*"
 
@@ -1883,13 +1884,13 @@ files = [
 
 [[package]]
 name = "openai"
-version = "1.14.0"
+version = "1.20.0"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.14.0-py3-none-any.whl", hash = "sha256:5c9fd3a59f5cbdb4020733ddf79a22f6b7a36d561968cb3f3dd255cdd263d9fe"},
-    {file = "openai-1.14.0.tar.gz", hash = "sha256:e287057adf0ec3315abc32ddcc968d095879abd9b68bf51c0402dab13ab5ae9b"},
+    {file = "openai-1.20.0-py3-none-any.whl", hash = "sha256:9fcc75256b2425393800e358cd520b02b5ab1a8731921e45aa7ae6aec3ee8187"},
+    {file = "openai-1.20.0.tar.gz", hash = "sha256:d7c0e824b7da3c043731943965c737595cf9631c913b7a1464c502fdf492b9a9"},
 ]
 
 [package.dependencies]
@@ -2747,21 +2748,21 @@ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7
 
 [[package]]
 name = "streaming-assistants"
-version = "0.15.7"
+version = "0.16.0"
 description = "Streaming enabled Assistants API"
 optional = false
 python-versions = "<4.0,>=3.10"
 files = [
-    {file = "streaming_assistants-0.15.7-py3-none-any.whl", hash = "sha256:dccd6e4646788ddc63a0fe4758a1efb3b965bbb790f9db4fe8b4c5a13902942c"},
-    {file = "streaming_assistants-0.15.7.tar.gz", hash = "sha256:5bd6ed6a2af40abda4b6dec95ad1a6fc4d5c22579aae6a9ec9a6800a7a50a7af"},
+    {file = "streaming_assistants-0.16.0-py3-none-any.whl", hash = "sha256:a107400ce685eee4e48a90c8858526e5824024990d27866b4dfc550bd09d39e1"},
+    {file = "streaming_assistants-0.16.0.tar.gz", hash = "sha256:927949b6b78adb66c32fb730d053aaad81ca72ac404cf1d302131567ce9df26b"},
 ]
 
 [package.dependencies]
 boto3 = ">=1.34.31,<2.0.0"
-google-generativeai = ">=0.3.2,<0.4.0"
+google-generativeai = ">=0.4.1,<0.5.0"
 httpx = ">=0.26.0,<0.27.0"
 litellm = ">=1.34.18,<2.0.0"
-openai = ">=1.14.0,<2.0.0"
+openai = ">=1.14.0,<1.21.0"
 
 [[package]]
 name = "tenacity"
@@ -3564,4 +3565,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10.12,<3.12"
-content-hash = "107ec855247ad2f66810945319c1eae27c70386444c0f33d17efba27b00bda66"
+content-hash = "c6fcbcb8489fbabe2e954419cdb771c53da48e86f40e2780fcf70e394eaefb21"
diff --git a/pyproject.toml b/pyproject.toml
index 9220f21..f03e52a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,8 +57,8 @@ litellm = "1.34.42"
 boto3 = "^1.29.6"
 prometheus-fastapi-instrumentator = "^6.1.0"
 google-cloud-aiplatform = "^1.38.0"
-google-generativeai = "^0.3.1"
-streaming-assistants = "^0.15.7"
+google-generativeai = "^0.4.1"
+streaming-assistants = "^0.16.0"
 annotated-types = "^0.6.0"
 pydantic-core = "^2.16.3"
 pydantic = "^2.6.4"
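
Note for reviewers (not part of the patch): the sketch below is one way to exercise the new APIVersionMiddleware locally. It assumes the FastAPI app object imports cleanly as impl.main.app without external services, and the request path is a placeholder; since the middleware runs before routing, any path demonstrates the header check.

# Hypothetical local smoke test for the OpenAI-Beta version gate (not in this patch).
# Assumes `impl.main.app` is importable and starts without external dependencies.
from fastapi.testclient import TestClient

from impl.main import app

client = TestClient(app)

# A missing header or an explicit "assistants=v1" header passes through to routing,
# so whatever the route returns, it is not the middleware's 400.
for headers in ({}, {"OpenAI-Beta": "assistants=v1"}):
    resp = client.get("/v1/threads", headers=headers)  # placeholder path
    assert resp.status_code != 400, resp.text

# Any other version (e.g. "assistants=v2", sent by openai>=1.21.0) is rejected with 400.
resp = client.get("/v1/threads", headers={"OpenAI-Beta": "assistants=v2"})
assert resp.status_code == 400
print("version gate behaves as expected")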