Skip to content

Commit

Permalink
Merge pull request #26 from datastax/v1-check
Browse files Browse the repository at this point in the history
v1 check
  • Loading branch information
phact authored May 1, 2024
2 parents ccb6161 + 5706878 commit d32e233
Show file tree
Hide file tree
Showing 4 changed files with 904 additions and 808 deletions.
28 changes: 23 additions & 5 deletions impl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from prometheus_fastapi_instrumentator.metrics import Info
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

from impl.background import background_task_set
from impl.rate_limiter import get_dbid, limiter
Expand All @@ -25,18 +27,18 @@
logger = logging.getLogger('cassandra')
logger.setLevel(logging.WARN)


logger = logging.getLogger(__name__)

app = FastAPI(
title="Astra Assistants API",
description="Drop in replacement for OpenAI Assistants API. .",
description="Drop in replacement for OpenAI Assistants API.",
version="2.0.0",
)

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)


@app.on_event("shutdown")
async def shutdown_event():
logger.info("shutting down server")
Expand All @@ -55,6 +57,24 @@ async def shutdown_event():
app.include_router(threads.router, prefix="/v1")


class APIVersionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
version_header = request.headers.get("OpenAI-Beta")

original_path = request.url.path
if version_header is None or version_header == "assistants=v1":
request.scope['path'] = original_path
response = await call_next(request)
return response
else:
return Response(
"Unsupported version, please use openai SDK compatible with: OpenAI-Beta: assistants=v1 (python sdk openai 1.21.0 or older)",
status_code=400)


app.add_middleware(APIVersionMiddleware)


def count_dbid(
latency_lowr_buckets: Sequence[Union[float, str]] = (0.1, 0.5, 1),
) -> Callable[[Info], None]:
Expand Down Expand Up @@ -130,7 +150,6 @@ def instrumentation(info: Info) -> None:
info.modified_duration
)


return instrumentation


Expand Down Expand Up @@ -182,6 +201,5 @@ async def unimplemented(request: Request, full_path: str):
status_code=501, content={"message": "This feature is not yet implemented"}
)


#if __name__ == "__main__":
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
2 changes: 1 addition & 1 deletion impl/routes/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ThreadRunStepCreated, ThreadRunStepInProgress, ThreadRunStepCompleted
from openai.types.beta.threads import Message, MessageDeltaEvent, MessageDelta, TextDeltaBlock, TextDelta
from openai.types.beta.threads.runs import RunStepDeltaEvent, RunStepDelta, ToolCallsStepDetails, FunctionToolCallDelta, \
ToolCallDeltaObject, RunStep, MessageCreationStepDetails, RetrievalToolCall, RetrievalToolCallDelta
ToolCallDeltaObject, RunStep, MessageCreationStepDetails, RetrievalToolCall
from openai.types.beta.threads.runs.message_creation_step_details import MessageCreation
from starlette.responses import StreamingResponse

Expand Down
Loading

0 comments on commit d32e233

Please sign in to comment.