diff --git a/.devcontainer/.zshrc b/.devcontainer/.zshrc deleted file mode 100644 index 13a21d5..0000000 --- a/.devcontainer/.zshrc +++ /dev/null @@ -1,111 +0,0 @@ -# If you come from bash you might have to change your $PATH. -# export PATH=$HOME/bin:$HOME/.local/bin:/usr/local/bin:$PATH - -# Path to your Oh My Zsh installation. -export ZSH=$HOME/.oh-my-zsh - -# Set name of the theme to load --- if set to "random", it will -# load a random theme each time Oh My Zsh is loaded, in which case, -# to know which specific one was loaded, run: echo $RANDOM_THEME -# See https://github.com/ohmyzsh/ohmyzsh/wiki/Themes -ZSH_THEME="powerlevel10k/powerlevel10k" - -# Set list of themes to pick from when loading at random -# Setting this variable when ZSH_THEME="powerlevel10k/powerlevel10k" -# a theme from this variable instead of looking in $ZSH/themes/ -# If set to an empty array, this variable will have no effect. -# ZSH_THEME_RANDOM_CANDIDATES=( "robbyrussell" "agnoster" ) - -# Uncomment the following line to use case-sensitive completion. -# CASE_SENSITIVE="true" - -# Uncomment the following line to use hyphen-insensitive completion. -# Case-sensitive completion must be off. _ and - will be interchangeable. -# HYPHEN_INSENSITIVE="true" - -# Uncomment one of the following lines to change the auto-update behavior -# zstyle ':omz:update' mode disabled # disable automatic updates -# zstyle ':omz:update' mode auto # update automatically without asking -# zstyle ':omz:update' mode reminder # just remind me to update when it's time - -# Uncomment the following line to change how often to auto-update (in days). -# zstyle ':omz:update' frequency 13 - -# Uncomment the following line if pasting URLs and other text is messed up. -# DISABLE_MAGIC_FUNCTIONS="true" - -# Uncomment the following line to disable colors in ls. -# DISABLE_LS_COLORS="true" - -# Uncomment the following line to disable auto-setting terminal title. -# DISABLE_AUTO_TITLE="true" - -# Uncomment the following line to enable command auto-correction. -# ENABLE_CORRECTION="true" - -# Uncomment the following line to display red dots whilst waiting for completion. -# You can also set it to another string to have that shown instead of the default red dots. -# e.g. COMPLETION_WAITING_DOTS="%F{yellow}waiting...%f" -# Caution: this setting can cause issues with multiline prompts in zsh < 5.7.1 (see #5765) -# COMPLETION_WAITING_DOTS="true" - -# Uncomment the following line if you want to disable marking untracked files -# under VCS as dirty. This makes repository status check for large repositories -# much, much faster. -# DISABLE_UNTRACKED_FILES_DIRTY="true" - -# Uncomment the following line if you want to change the command execution time -# stamp shown in the history command output. -# You can set one of the optional three formats: -# "mm/dd/yyyy"|"dd.mm.yyyy"|"yyyy-mm-dd" -# or set a custom format using the strftime function format specifications, -# see 'man strftime' for details. -# HIST_STAMPS="mm/dd/yyyy" - -# Would you like to use another custom folder than $ZSH/custom? -# ZSH_CUSTOM=/path/to/new-custom-folder - -# Which plugins would you like to load? -# Standard plugins can be found in $ZSH/plugins/ -# Custom plugins may be added to $ZSH_CUSTOM/plugins/ -# Example format: plugins=(rails git textmate ruby lighthouse) -# Add wisely, as too many plugins slow down shell startup. 
-plugins=(git) - -source $ZSH/oh-my-zsh.sh - -# User configuration - -# export MANPATH="/usr/local/man:$MANPATH" - -# You may need to manually set your language environment -# export LANG=en_US.UTF-8 - -# Preferred editor for local and remote sessions -# if [[ -n $SSH_CONNECTION ]]; then -# export EDITOR='vim' -# else -# export EDITOR='mvim' -# fi - -# Compilation flags -# export ARCHFLAGS="-arch $(uname -m)" - -# Set personal aliases, overriding those provided by Oh My Zsh libs, -# plugins, and themes. Aliases can be placed here, though Oh My Zsh -# users are encouraged to define aliases within a top-level file in -# the $ZSH_CUSTOM folder, with .zsh extension. Examples: -# - $ZSH_CUSTOM/aliases.zsh -# - $ZSH_CUSTOM/macos.zsh -# For a full list of active aliases, run `alias`. -# -# Example aliases -# alias zshconfig="mate ~/.zshrc" -# alias ohmyzsh="mate ~/.oh-my-zsh" -DISABLE_AUTO_UPDATE=true -DISABLE_UPDATE_PROMPT=true -[[ ! -f ~/.p10k.zsh ]] || source ~/.p10k.zsh - -POWERLEVEL9K_DISABLE_CONFIGURATION_WIZARD=true - -eval "$(task --completion zsh)" diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 55feeee..fd05fed 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,16 +1,22 @@ FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm -ENV ZSH_CUSTOM=/home/vscode/.oh-my-zsh/custom \ - BLACK_VERSION=24.10.0 \ +ENV BLACK_VERSION=25.1.0 \ ISORT_VERSION=5.13.2 \ - PYLINT_VERSION=3.3.3 \ + PYLINT_VERSION=3.3.7 \ BUILD_VERSION=1.2.2.post1 \ - TWINE_VERSION=6.0.1 \ - TASK_VERSION=v3.41.0 \ - PYTEST_VERSION=8.3.4 \ - PYTEST_WATCH_VERSION=4.2.0 + TWINE_VERSION=6.1.0 \ + TASK_VERSION=v3.43.3 \ + PYTEST_VERSION=8.3.5 \ + PYTEST_WATCH_VERSION=4.2.0 \ + PRE_COMMIT_VERSION=4.2.0 RUN apt-get update && \ + # Install GitHub CLI + curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg && \ + chmod go+r /usr/share/keyrings/githubcli-archive-keyring.gpg && \ + echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null && \ + apt-get update && \ + apt-get install -y gh && \ # Install nodejs and npm curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - && \ apt-get install -y nodejs && \ @@ -20,15 +26,17 @@ RUN apt-get update && \ python -m pip install --upgrade pip && \ # Install development tools using pip pip install black==${BLACK_VERSION} \ - isort==${ISORT_VERSION} \ - pylint==${PYLINT_VERSION} \ - build==${BUILD_VERSION} \ - twine==${TWINE_VERSION} \ - pytest==${PYTEST_VERSION} \ - pytest-watch==${PYTEST_WATCH_VERSION} \ - pytest-cov \ - pytest-xdist \ - debugpy && \ + isort==${ISORT_VERSION} \ + pylint==${PYLINT_VERSION} \ + build==${BUILD_VERSION} \ + twine==${TWINE_VERSION} \ + pytest==${PYTEST_VERSION} \ + pytest-watch==${PYTEST_WATCH_VERSION} \ + pytest-cov \ + pytest-xdist \ + datamodel-code-generator \ + pre-commit==${PRE_COMMIT_VERSION} \ + debugpy && \ # Clean up apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -36,7 +44,13 @@ RUN apt-get update && \ # Install semantic-release RUN npm install -g semantic-release @semantic-release/changelog @semantic-release/exec @semantic-release/git @semantic-release/github conventional-changelog-conventionalcommits -# Install powerlevel10k theme -RUN git clone --depth=1 https://github.com/romkatv/powerlevel10k.git ${ZSH_CUSTOM}/themes/powerlevel10k - USER vscode + +# Use Powerlevel10k 
theme +RUN git clone --depth=1 https://github.com/romkatv/powerlevel10k.git /home/vscode/.powerlevel10k && \ + echo 'source /home/vscode/.powerlevel10k/powerlevel10k.zsh-theme' >> /home/vscode/.zshrc && \ + echo 'POWERLEVEL9K_DISABLE_CONFIGURATION_WIZARD=true' >> /home/vscode/.zshrc + +# Shell completion +RUN echo "source <(gh completion -s zsh)" >> /home/vscode/.zshrc +RUN echo "source <(task --completion zsh)" >> /home/vscode/.zshrc diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index d23297b..e0a34fa 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,6 +1,9 @@ { "name": "Debian with Python 3", "dockerFile": "Dockerfile", + "features": { + "ghcr.io/devcontainers/features/docker-in-docker:latest": {} + }, "customizations": { "vscode": { "extensions": [ @@ -14,22 +17,27 @@ "ms-python.python", "ms-python.vscode-pylance", "ms-python.black-formatter", + "ms-python.isort", "tamasfe.even-better-toml" ], "settings": { "python.pythonPath": "/usr/local/bin/python", "python.linting.enabled": true, "python.linting.pylintEnabled": false, - "python.formatting.provider": "black", - "python.formatting.blackPath": "/usr/local/py-utils/bin/black", - "python.sortImports.path": "/usr/local/py-utils/bin/isort", + "python.defaultInterpreterPath": "/usr/local/bin/python", "terminal.integrated.defaultProfile.linux": "zsh", "editor.renderWhitespace": "all", "cSpell.enabled": true, "cSpell.files": ["**/*.md"], "editor.formatOnSave": true, - "editor.defaultFormatter": "ms-python.black-formatter", "black-formatter.args": ["--config", "pyproject.toml"], + "isort.args": ["--profile", "black", "--line-length", "100"], + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit" + }, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, "git.enableCommitSigning": true, "dev.containers.copyGitConfig": true, "githubPullRequests.experimental.chat": true, @@ -38,16 +46,57 @@ "python.testing.pytestEnabled": true, "python.testing.unittestEnabled": false, "python.testing.nosetestsEnabled": false, - "python.testing.pytestArgs": [ - "tests" - ] + "python.testing.pytestArgs": ["tests"], + "notebook.insertFinalNewline": true, + "yaml.schemas": { + "https://json.schemastore.org/pre-commit-config.json": [ + ".pre-commit-config.yaml", + ".pre-commit-config.yml" + ] + }, + "yaml.schemaStore.enable": true, + "yaml.validate": true, + "github.copilot.enable": { + "*": true + }, + "github.copilot.advanced": { + "authProvider": "github" + }, + "github.copilot.chat.codeGeneration.useInstructionFiles": true, + "github.copilot.chat.commitMessageGeneration.instructions": [ + { + "text": "Always use conventional commit message format." + } + ], + "github.copilot.chat.pullRequestDescriptionGeneration.instructions": [ + { + "text": "Always fill the pull request with the following information: \n ## Summary\n \n" + } + ], + "github.copilot.chat.testGeneration.instructions": [ + { + "text": "Always use table-driven tests." 
+ } + ], + "mcp": { + "servers": { + "Context7": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "node:lts", + "npx", + "-y", + "@upstash/context7-mcp@latest" + ] + } + } + } } } }, - "mounts": [ - "source=${localWorkspaceFolder}/.devcontainer/.zshrc,target=/home/vscode/.zshrc,type=bind,consistency=cached", - "source=${localWorkspaceFolder}/.devcontainer/launch.json,target=/workspaces/python-sdk/.vscode/launch.json,type=bind,consistency=cached" - ], "postCreateCommand": "pip install -r requirements.txt", "remoteEnv": { "GITHUB_TOKEN": "${localEnv:GITHUB_TOKEN}" diff --git a/.devcontainer/launch.json b/.devcontainer/launch.json deleted file mode 100644 index 7b7e0cc..0000000 --- a/.devcontainer/launch.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Debug Tests", - "type": "debugpy", - "request": "launch", - "program": "/usr/local/bin/python", - "args": [ - "-v", - "--no-cov", - "tests/" - ], - "console": "integratedTerminal", - "justMyCode": false - }, - { - "name": "Python: Debug Current Test", - "type": "debugpy", - "request": "launch", - "program": "/usr/local/bin/python", - "args": [ - "-v", - "--no-cov", - "${file}" - ], - "console": "integratedTerminal", - "justMyCode": false - } - ] -} diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..a295946 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,47 @@ +# Custom Instructions for Copilot + +Today is May 26, 2025. + +- Always use context7 to check for the latest updates, features, or best practices of a library relevant to the task at hand. +- Always prefer Table-Driven Testing: When writing tests. +- Always use Early Returns: Favor early returns to simplify logic and avoid deep nesting with if-else structures. +- Always prefer switch statements over if-else chains: Use switch statements for cleaner and more readable code when checking multiple conditions. +- Always run `task lint` before committing code to ensure it adheres to the project's linting rules. +- Always run `task test` before committing code to ensure all tests pass. +- Always search for the simplest solution first before considering more complex alternatives. +- Always prefer type safety over dynamic typing: Use strong typing and interfaces to ensure type safety and reduce runtime errors. +- When possible code to an interface so it's easier to mock in tests. +- When writing tests, each test case should have it's own isolated mock server mock dependecies so it's easier to understand and maintain. + +## Development Workflow + +### Configuration Changes + +When adding new configuration fields: + +1. Run `task oas-download` - OpenAPI is the source of truth - readonly file. +2. If added new Schemas to openapi.yaml, make sure to run `task generate` to regenerate the Python code. +3. Run `task lint` to ensure code quality +4. Run `task test` to ensure all tests pass +5. Update the README.md file or any documentation files with the recently added implementation + +## Available Tools and MCPs + +- context7 - Helps by finding the latest updates, features, or best practices of a library relevant to the task at hand. 
+ +## Related Repositories + +- [Inference Gateway](https://github.com/inference-gateway) + - [Inference Gateway UI](https://github.com/inference-gateway/ui) + - [Go SDK](https://github.com/inference-gateway/go-sdk) + - [Rust SDK](https://github.com/inference-gateway/rust-sdk) + - [TypeScript SDK](https://github.com/inference-gateway/typescript-sdk) + - [Python SDK](https://github.com/inference-gateway/python-sdk) + - [Documentation](https://docs.inference-gateway.com) + +## MCP Useful links + +- [Introduction](https://modelcontextprotocol.io/introduction) +- [Specification](https://modelcontextprotocol.io/specification) +- [Examples](https://modelcontextprotocol.io/examples) +- [Schema](https://raw.githubusercontent.com/modelcontextprotocol/modelcontextprotocol/refs/heads/main/schema/draft/schema.json) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3df02eb..7fdd65d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: - main - + pull_request: branches: - main @@ -15,18 +15,18 @@ jobs: name: Test steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v4.2.2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v5.6.0 with: - python-version: '3.12' + python-version: "3.12" - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install black==24.10.0 pytest + pip install black==25.1.0 pytest - name: Check formatting with Black run: black --check . diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5c766ec..40606d2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,20 +13,20 @@ jobs: new_release_version: ${{ steps.semantic.outputs.new_release_version }} new_release_published: ${{ steps.semantic.outputs.new_release_published }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 with: fetch-depth: 0 persist-credentials: false ref: main - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v4.4.0 with: node-version: "lts/*" - name: Install semantic release and plugins run: | - npm install -g semantic-release@v24.2.1 \ + npm install -g semantic-release@v24.2.5 \ conventional-changelog-cli \ conventional-changelog-conventionalcommits \ @semantic-release/changelog \ @@ -73,14 +73,14 @@ jobs: if: needs.github_release.outputs.new_release_published == 'true' steps: - name: Check out the code - uses: actions/checkout@v4 + uses: actions/checkout@v4.2.2 with: - persist-credentials: false + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v5.6.0 with: - python-version: '3.12' + python-version: "3.12" - name: Update pyproject.toml version run: | diff --git a/.gitignore b/.gitignore index 73a2c70..c481c18 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ -.vscode +.vscode/ .pytest_cache **/__pycache__ dist **.egg-info +.coverage +node_modules/ +.mypy_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c7f7768 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,56 @@ +repos: + - repo: https://github.com/psf/black + rev: 25.1.0 + hooks: + - id: black + args: [--config=pyproject.toml] + language_version: python3.12 + + - repo: https://github.com/pycqa/isort + rev: 6.0.1 + hooks: + - id: isort + args: [--profile=black, --line-length=100] + + - repo: https://github.com/pre-commit/mirrors-mypy + 
rev: v1.15.0 + hooks: + - id: mypy + additional_dependencies: + [ + pydantic>=2.11.5, + httpx>=0.28.1, + requests>=2.32.3, + types-requests, + pytest>=8.3.5, + ] + args: [--config-file=pyproject.toml] + exclude: ^tests/ + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-json + - id: check-merge-conflict + - id: check-added-large-files + - id: check-docstring-first + - id: debug-statements + - id: name-tests-test + args: [--pytest-test-first] + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + types_or: [yaml, json, markdown] + exclude: ^(CHANGELOG\.md|openapi\.yaml)$ + +ci: + autofix_commit_msg: "style: auto-fix pre-commit hooks" + autofix_prs: true + autoupdate_commit_msg: "chore: pre-commit autoupdate" + autoupdate_schedule: weekly diff --git a/.releaserc.yaml b/.releaserc.yaml index f242303..712e624 100644 --- a/.releaserc.yaml +++ b/.releaserc.yaml @@ -60,8 +60,5 @@ plugins: "releaseNameTemplate": "🚀 Version ${nextRelease.version}", }, ] - - [ - "@semantic-release/git", - { "assets": ["CHANGELOG.md", "pyproject.toml"] }, - ] + - ["@semantic-release/git", { "assets": ["CHANGELOG.md", "pyproject.toml"] }] repositoryUrl: "https://github.com/inference-gateway/python-sdk" diff --git a/README.md b/README.md index 1fb8636..7189a81 100644 --- a/README.md +++ b/README.md @@ -1,126 +1,226 @@ # Inference Gateway Python SDK -An SDK written in Python for the [Inference Gateway](https://github.com/edenreich/inference-gateway). - - [Inference Gateway Python SDK](#inference-gateway-python-sdk) - - [Installation](#installation) - - [Usage](#usage) - - [Creating a Client](#creating-a-client) + - [Features](#features) + - [Quick Start](#quick-start) + - [Installation](#installation) + - [Basic Usage](#basic-usage) + - [Requirements](#requirements) + - [Client Configuration](#client-configuration) + - [Core Functionality](#core-functionality) - [Listing Models](#listing-models) - - [List Provider's Models](#list-providers-models) - - [Generating Content](#generating-content) - - [Streaming Content](#streaming-content) - - [Health Check](#health-check) + - [Chat Completions](#chat-completions) + - [Standard Completion](#standard-completion) + - [Streaming Completion](#streaming-completion) + - [Proxy Requests](#proxy-requests) + - [Health Checking](#health-checking) + - [Error Handling](#error-handling) + - [Advanced Usage](#advanced-usage) + - [Using Tools](#using-tools) + - [Custom HTTP Configuration](#custom-http-configuration) - [License](#license) -## Installation +A modern Python SDK for interacting with the [Inference Gateway](https://github.com/edenreich/inference-gateway), providing a unified interface to multiple AI providers. + +## Features + +- 🔗 Unified interface for multiple AI providers (OpenAI, Anthropic, Ollama, etc.) 
+- 🛡️ Type-safe operations using Pydantic models +- ⚡ Support for both synchronous and streaming responses +- 🚨 Built-in error handling and validation +- 🔄 Proxy requests directly to provider APIs + +## Quick Start + +### Installation ```sh pip install inference-gateway ``` -## Usage +### Basic Usage + +```python +from inference_gateway import InferenceGatewayClient, Message, MessageRole + +# Initialize client +client = InferenceGatewayClient("http://localhost:8080") + +# Simple chat completion +response = client.create_chat_completion( + model="openai/gpt-4", + messages=[ + Message(role=MessageRole.SYSTEM, content="You are a helpful assistant"), + Message(role=MessageRole.USER, content="Hello!") + ] +) + +print(response.choices[0].message.content) +``` + +## Requirements -### Creating a Client +- Python 3.8+ +- `requests` or `httpx` (for HTTP client) +- `pydantic` (for data validation) + +## Client Configuration ```python -from inference_gateway.client import InferenceGatewayClient, Provider +from inference_gateway import InferenceGatewayClient +# Basic configuration client = InferenceGatewayClient("http://localhost:8080") -# With authentication token(optional) -client = InferenceGatewayClient("http://localhost:8080", token="your-token") +# With authentication +client = InferenceGatewayClient( + "http://localhost:8080", + token="your-api-token", + timeout=60.0 # Custom timeout +) + +# Using httpx instead of requests +client = InferenceGatewayClient( + "http://localhost:8080", + use_httpx=True +) ``` -### Listing Models +## Core Functionality -To list all available models from all providers, use the list_models method: +### Listing Models ```python +# List all available models models = client.list_models() -print("Available models: ", models) +print("All models:", models) + +# Filter by provider +openai_models = client.list_models(provider="openai") +print("OpenAI models:", openai_models) ``` -### List Provider's Models +### Chat Completions -To list available models for a specific provider, use the list_provider_models method: +#### Standard Completion ```python -models = client.list_provider_models(Provider.OPENAI) -print("Available OpenAI models: ", models) +from inference_gateway import Message, MessageRole + +response = client.create_chat_completion( + model="openai/gpt-4", + messages=[ + Message(role=MessageRole.SYSTEM, content="You are a helpful assistant"), + Message(role=MessageRole.USER, content="Explain quantum computing") + ], + max_tokens=500 +) + +print(response.choices[0].message.content) ``` -### Generating Content +#### Streaming Completion + +```python +# Using Server-Sent Events (SSE) +for chunk in client.create_chat_completion_stream( + model="ollama/llama2", + messages=[ + Message(role=MessageRole.USER, content="Tell me a story") + ], + use_sse=True +): + print(chunk.data, end="", flush=True) + +# Using JSON lines +for chunk in client.create_chat_completion_stream( + model="anthropic/claude-3", + messages=[ + Message(role=MessageRole.USER, content="Explain AI safety") + ], + use_sse=False +): + print(chunk["choices"][0]["delta"]["content"], end="", flush=True) +``` -To generate content using a model, use the generate_content method: +### Proxy Requests ```python -from inference_gateway.client import Provider, Role, Message - -messages = [ - Message( - Role.SYSTEM, - "You are an helpful assistant" - ), - Message( - Role.USER, - "Hello!" 
- ), -] - -response = client.generate_content( - provider=Provider.OPENAI, - model="gpt-4", - messages=messages +# Proxy request to OpenAI's API +response = client.proxy_request( + provider="openai", + path="/v1/models", + method="GET" ) -print("Assistant: ", response["response"]["content"]) + +print("OpenAI models:", response) +``` + +### Health Checking + +```python +if client.health_check(): + print("API is healthy") +else: + print("API is unavailable") ``` -### Streaming Content +## Error Handling -To stream content using a model, use the stream_content method: +The SDK provides several exception types: ```python -from inference_gateway.client import Provider, Role, Message - -messages = [ - Message( - Role.SYSTEM, - "You are an helpful assistant" - ), - Message( - Role.USER, - "Hello!" - ), -] - -# Use SSE for streaming -for response in client.generate_content_stream( - provider=Provider.Ollama, - model="llama2", - messages=messages, - use_sse=true -): -print("Event: ", response["event"]) -print("Assistant: ", response["data"]["content"]) - -# Or raw JSON response -for response in client.generate_content_stream( - provider=Provider.GROQ, - model="deepseek-r1", - messages=messages, - use_sse=false -): -print("Assistant: ", response.content) +try: + response = client.create_chat_completion(...) +except InferenceGatewayAPIError as e: + print(f"API Error: {e} (Status: {e.status_code})") + print("Response:", e.response_data) +except InferenceGatewayValidationError as e: + print(f"Validation Error: {e}") +except InferenceGatewayError as e: + print(f"General Error: {e}") ``` -### Health Check +## Advanced Usage + +### Using Tools + +```python +# List available MCP tools works when MCP_ENABLE and MCP_EXPOSE are set on the gateway +tools = client.list_tools() +print("Available tools:", tools) + +# Use tools in chat completion works when MCP_ENABLE and MCP_EXPOSE are set to false on the gateway +response = client.create_chat_completion( + model="openai/gpt-4", + messages=[...], + tools=[ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": {...} + } + } + ] +) +``` -To check the health of the API, use the health_check method: +### Custom HTTP Configuration ```python -is_healthy = client.health_check() -print("API Status: ", "Healthy" if is_healthy else "Unhealthy") +# With custom headers +client = InferenceGatewayClient( + "http://localhost:8080", + headers={"X-Custom-Header": "value"} +) + +# With proxy settings +client = InferenceGatewayClient( + "http://localhost:8080", + proxies={"http": "http://proxy.example.com"} +) ``` ## License diff --git a/Taskfile.yml b/Taskfile.yml index cc88613..07438e3 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -1,38 +1,170 @@ --- version: "3" +vars: + OPENAPI_URL: https://raw.githubusercontent.com/inference-gateway/inference-gateway/refs/heads/main/openapi.yaml + PYTHON_VERSION: "3.12" + tasks: + default: + desc: Show available tasks + cmds: + - task --list + + install: + desc: Install dependencies for development + cmds: + - pip install -e ".[dev]" + oas-download: - desc: Download OpenAPI specification + desc: Download latest OpenAPI specification from inference-gateway repository cmds: - - curl -o openapi.yaml https://raw.githubusercontent.com/inference-gateway/inference-gateway/refs/heads/main/openapi.yaml + - echo "Downloading OpenAPI spec from {{.OPENAPI_URL}}..." 
+ - curl -sSL -o openapi.yaml "{{.OPENAPI_URL}}" + - echo "✅ OpenAPI spec downloaded successfully" - lint: - desc: Lint the code + oas-validate: + desc: Validate the OpenAPI specification + deps: + - oas-download + cmds: + - python -c "import yaml; yaml.safe_load(open('openapi.yaml', 'r'))" + - echo "✅ OpenAPI spec is valid YAML" + + generate: + desc: Generate Pydantic models from OpenAPI specification + deps: + - oas-validate + cmds: + - echo "Generating Pydantic v2 models from OpenAPI spec..." + - > + datamodel-codegen + --input openapi.yaml + --output inference_gateway/models.py + --output-model-type pydantic_v2.BaseModel + --enum-field-as-literal all + --target-python-version {{.PYTHON_VERSION}} + --use-schema-description + --use-generic-container-types + --use-standard-collections + --use-annotated + --use-field-description + --field-constraints + --disable-appending-item-suffix + --custom-template-dir templates/ + --wrap-string-literal + --use-one-literal-as-default + --use-subclass-enum + --strict-nullable + --allow-population-by-field-name + --snake-case-field + --strip-default-none + --use-title-as-name + - echo "✅ Models generated successfully" + - task: format + + format: + desc: Format code with black and isort cmds: + - echo "Formatting code..." - black inference_gateway/ tests/ + - isort inference_gateway/ tests/ + - echo "✅ Code formatted" + + lint: + desc: Run all linting checks + cmds: + - echo "Running linting checks..." + - black --check inference_gateway/ tests/ + - isort --check-only inference_gateway/ tests/ + - mypy inference_gateway/ + - echo "✅ All linting checks passed" test: desc: Run tests cmds: + - echo "Running tests..." - pytest tests/ -v + - echo "✅ All tests passed" test:watch: desc: Run tests in watch mode cmds: + - echo "Running tests in watch mode..." - ptw tests/ -- -v test:coverage: desc: Run tests with coverage report cmds: - - pytest tests/ -v --cov=inference_gateway --cov-report=term-missing + - echo "Running tests with coverage..." + - pytest tests/ -v --cov=inference_gateway --cov-report=term-missing --cov-report=html + - echo "✅ Coverage report generated" - test:debug: - desc: Run tests with debugger enabled + clean: + desc: Clean up build artifacts and cache files cmds: - - pytest tests/ -v --pdb + - echo "Cleaning up..." + - rm -rf inference_gateway.egg-info dist build .pytest_cache .coverage htmlcov + - find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true + - find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true + - find . -type d -name "node_modules" -exec rm -rf {} + 2>/dev/null || true + - find . -type f -name "*.pyc" -delete + - echo "✅ Cleanup completed" - clean: - desc: Clean up + build: + desc: Build the package + deps: + - clean + - lint + - test + cmds: + - echo "Building package..." + - python -m build + - echo "✅ Package built successfully" + + docs:serve: + desc: Serve documentation locally (placeholder for future docs) + cmds: + - echo "📚 Documentation server would start here" + - echo "Future mkdocs serve or similar" + + dev:setup: + desc: Complete development environment setup + cmds: + - echo "Setting up development environment..." + - task: install + - task: oas-download + - task: generate + - task: pre-commit:install + - task: test + - echo "✅ Development environment setup complete" + + ci:check: + desc: Run all CI checks (lint, test, build) + cmds: + - echo "Running CI checks..." 
+ - task: lint + - task: test + - task: build + - echo "✅ All CI checks passed" + + pre-commit:install: + desc: Install pre-commit hooks + cmds: + - echo "Installing pre-commit hooks..." + - pre-commit install + - echo "✅ Pre-commit hooks installed" + + pre-commit:run: + desc: Run pre-commit hooks on all files + cmds: + - echo "Running pre-commit hooks on all files..." + - pre-commit run --all-files + - echo "✅ Pre-commit hooks completed" + + pre-commit:update: + desc: Update pre-commit hook versions cmds: - - rm -rf inference_gateway.egg-info dist .pytest_cache inference_gateway/__pycache__ tests/__pycache__ .coverage + - echo "Updating pre-commit hook versions..." + - pre-commit autoupdate + - echo "✅ Pre-commit hooks updated" diff --git a/inference_gateway/__init__.py b/inference_gateway/__init__.py index e69de29..15e4f6f 100644 --- a/inference_gateway/__init__.py +++ b/inference_gateway/__init__.py @@ -0,0 +1,47 @@ +"""Inference Gateway Python SDK. + +A modern Python SDK for the Inference Gateway API with full OpenAPI support +and type safety using Pydantic v2. +""" + +from inference_gateway.client import ( + InferenceGatewayAPIError, + InferenceGatewayClient, + InferenceGatewayError, + InferenceGatewayValidationError, +) +from inference_gateway.models import ( + ChatCompletionMessageToolCall, + CompletionUsage, + CreateChatCompletionRequest, + CreateChatCompletionResponse, + Function, + ListModelsResponse, + Message, + MessageRole, + Model, + Provider, + SSEvent, +) + +__version__ = "0.4.0" +__all__ = [ + # Client classes + "InferenceGatewayClient", + # Exceptions + "InferenceGatewayError", + "InferenceGatewayAPIError", + "InferenceGatewayValidationError", + # Core models + "Provider", + "Message", + "MessageRole", + "ListModelsResponse", + "CreateChatCompletionRequest", + "CreateChatCompletionResponse", + "SSEvent", + "Model", + "CompletionUsage", + "ChatCompletionMessageToolCall", + "Function", +] diff --git a/inference_gateway/client.py b/inference_gateway/client.py index 6407f65..d2c7101 100644 --- a/inference_gateway/client.py +++ b/inference_gateway/client.py @@ -1,235 +1,401 @@ -from typing import Generator, Optional, Union, List, Dict, Optional -import json -from dataclasses import dataclass -from enum import Enum -import requests - +"""Modern Python SDK client for Inference Gateway API. -class Provider(str, Enum): - """Supported LLM providers""" +This module provides a comprehensive client for interacting with the Inference Gateway, +supporting multiple AI providers with a unified interface. 
+""" - OLLAMA = "ollama" - GROQ = "groq" - OPENAI = "openai" - CLOUDFLARE = "cloudflare" - COHERE = "cohere" +import json +from typing import Any, Dict, Generator, List, Optional, Union +import httpx +import requests +from pydantic import ValidationError -class Role(str, Enum): - """Message role types""" +from inference_gateway.models import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, + ListModelsResponse, + ListToolsResponse, + Message, + Provider, + SSEvent, +) - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" +class InferenceGatewayError(Exception): + """Base exception for Inference Gateway SDK errors.""" -@dataclass -class Message: - role: Role - content: str + pass - def to_dict(self) -> Dict[str, str]: - """Convert message to dictionary format with string values""" - return {"role": self.role.value, "content": self.content} +class InferenceGatewayAPIError(InferenceGatewayError): + """Exception raised for API-related errors.""" -@dataclass -class Model: - """Represents an LLM model""" + def __init__( + self, + message: str, + status_code: Optional[int] = None, + response_data: Optional[Dict[str, Any]] = None, + ): + super().__init__(message) + self.status_code = status_code + self.response_data = response_data - name: str +class InferenceGatewayValidationError(InferenceGatewayError): + """Exception raised for validation errors.""" -@dataclass -class ProviderModels: - """Groups models by provider""" + pass - provider: Provider - models: List[Model] +class InferenceGatewayClient: + """Modern client for interacting with the Inference Gateway API. -@dataclass -class ResponseTokens: - """Response tokens structure as defined in the API spec""" + This client provides a comprehensive interface to the Inference Gateway, + supporting multiple AI providers with type-safe operations. - role: str - model: str - content: str + Example: + ```python + # Basic usage + client = InferenceGatewayClient("https://api.example.com") - @classmethod - def from_dict(cls, data: dict) -> "ResponseTokens": - """Create ResponseTokens from dictionary data + # With authentication + client = InferenceGatewayClient( + "https://api.example.com", + token="your-api-token" + ) - Args: - data: Dictionary containing response data + # List available models + models = client.list_models() - Returns: - ResponseTokens instance + # Create a chat completion + messages = [Message(role="user", content="Hello!")] + response = client.create_chat_completion( + model="gpt-4o", + messages=messages + ) + ``` + """ + + def __init__( + self, + base_url: str, + token: Optional[str] = None, + timeout: float = 30.0, + use_httpx: bool = False, + ): + """Initialize the client with base URL and optional auth token. 
- Raises: - TypeError: If data is not a dictionary - ValueError: If required fields are missing + Args: + base_url: The base URL of the Inference Gateway API + token: Optional authentication token + timeout: Request timeout in seconds (default: 30.0) + use_httpx: Whether to use httpx instead of requests (default: False) """ - if not isinstance(data, dict): - raise TypeError(f"Expected dict, got {type(data)}") + self.base_url = base_url.rstrip("/") + self.timeout = timeout + self.use_httpx = use_httpx + + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } - required = ["role", "model", "content"] - missing = [field for field in required if field not in data] + if token: + headers["Authorization"] = f"Bearer {token}" - if missing: - raise ValueError( - f"Missing required arguments: { - ', '.join(missing)}" + if use_httpx: + self.client = httpx.Client( + timeout=timeout, + headers=headers, ) + else: + self.session = requests.Session() + self.session.headers.update(headers) + self._timeout = timeout + + def __enter__(self) -> "InferenceGatewayClient": + """Context manager entry.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Context manager exit.""" + self.close() + + def close(self) -> None: + """Close the HTTP client.""" + if self.use_httpx and hasattr(self, "client"): + self.client.close() + + def _make_request( + self, method: str, url: str, **kwargs: Any + ) -> Union[requests.Response, httpx.Response]: + """Make an HTTP request using the configured client.""" + response: Optional[Union[requests.Response, httpx.Response]] = None + try: + if self.use_httpx: + response = self.client.request(method, url, **kwargs) + else: + if "timeout" not in kwargs and hasattr(self, "_timeout"): + kwargs["timeout"] = self._timeout + response = self.session.request(method, url, **kwargs) + + if response is not None: + response.raise_for_status() + return response + else: + raise InferenceGatewayError("No response received") + + except (requests.HTTPError, httpx.HTTPStatusError) as e: + try: + error_data = response.json() if response and response.content else {} + except (json.JSONDecodeError, ValueError): + error_data = {} + + status_code = response.status_code if response else 0 + raise InferenceGatewayAPIError( + f"Request failed: {str(e)}", + status_code=status_code, + response_data=error_data, + ) + except (requests.RequestException, httpx.RequestError) as e: + raise InferenceGatewayError(f"Request failed: {str(e)}") - return cls(role=data["role"], model=data["model"], content=data["content"]) + def list_models(self, provider: Optional[Union[Provider, str]] = None) -> ListModelsResponse: + """List all available language models. 
+ Args: + provider: Optional provider to filter models -@dataclass -class GenerateResponse: - """Response structure for token generation""" + Returns: + ListModelsResponse: List of available models - provider: str - response: ResponseTokens + Raises: + InferenceGatewayAPIError: If the API request fails + InferenceGatewayValidationError: If response validation fails + """ + url = f"{self.base_url}/v1/models" + params = {} - @classmethod - def from_dict(cls, data: dict) -> "GenerateResponse": - """Create GenerateResponse from dictionary data""" - return cls( - provider=data.get("provider", ""), response=ResponseTokens(**data.get("response", {})) - ) + if provider: + provider_value = provider.root if hasattr(provider, "root") else str(provider) + params["provider"] = provider_value + try: + response = self._make_request("GET", url, params=params) + return ListModelsResponse.model_validate(response.json()) + except ValidationError as e: + raise InferenceGatewayValidationError(f"Response validation failed: {e}") -class InferenceGatewayClient: - """Client for interacting with the Inference Gateway API""" + def list_tools(self) -> ListToolsResponse: + """List all available MCP tools. - def __init__(self, base_url: str, token: Optional[str] = None): - """Initialize the client with base URL and optional auth token""" - self.base_url = base_url.rstrip("/") - self.session = requests.Session() - if token: - self.session.headers.update({"Authorization": f"Bearer {token}"}) + Returns: + ListToolsResponse: List of available MCP tools - def list_models(self) -> List[ProviderModels]: - """List all available language models""" - response = self.session.get(f"{self.base_url}/llms") - response.raise_for_status() - return response.json() + Raises: + InferenceGatewayAPIError: If the API request fails + InferenceGatewayValidationError: If response validation fails + """ + url = f"{self.base_url}/mcp/tools" - def list_providers_models(self, provider: Provider) -> List[Model]: - """List models for a specific provider""" - response = self.session.get(f"{self.base_url}/llms/{provider.value}") - response.raise_for_status() - return response.json() + try: + response = self._make_request("GET", url) + return ListToolsResponse.model_validate(response.json()) + except ValidationError as e: + raise InferenceGatewayValidationError(f"Response validation failed: {e}") - def _parse_sse_chunk(self, chunk: bytes) -> dict: - """Parse an SSE message chunk into structured event data + def _parse_sse_chunk(self, chunk: bytes) -> SSEvent: + """Parse an SSE message chunk into structured event data. 
Args: chunk: Raw SSE message chunk in bytes format Returns: - dict: Parsed SSE message with event type and data fields + SSEvent: Parsed SSE message with event type and data fields Raises: - json.JSONDecodeError: If chunk format or content is invalid + InferenceGatewayValidationError: If chunk format or content is invalid """ if not isinstance(chunk, bytes): raise TypeError(f"Expected bytes, got {type(chunk)}") try: decoded = chunk.decode("utf-8") - message = {} + event_type = None + data = None for line in (l.strip() for l in decoded.split("\n") if l.strip()): - if line.startswith("event: "): - message["event"] = line.removeprefix("event: ") - elif line.startswith("data: "): - try: - json_str = line.removeprefix("data: ") - data = json.loads(json_str) - if not isinstance(data, dict): - raise json.JSONDecodeError( - f"Invalid SSE data format - expected object, got: { - json_str}", - json_str, - 0, - ) - message["data"] = data - except json.JSONDecodeError as e: - raise json.JSONDecodeError(f"Invalid SSE JSON: {json_str}", e.doc, e.pos) - - if not message.get("data"): - raise json.JSONDecodeError( - f"Missing or invalid data field in SSE message: { - decoded}", - decoded, - 0, - ) + if line.startswith("event:"): + event_type = line.removeprefix("event:").strip() + elif line.startswith("data:"): + data = line.removeprefix("data:").strip() - return message + return SSEvent(event=event_type, data=data, retry=None) except UnicodeDecodeError as e: - raise json.JSONDecodeError( - f"Invalid UTF-8 encoding in SSE chunk: { - chunk!r}", - str(chunk), - 0, - ) + raise InferenceGatewayValidationError(f"Invalid UTF-8 encoding in SSE chunk: {chunk!r}") + + def _parse_json_line(self, line: bytes) -> Dict[str, Any]: + """Parse a single JSON line into a dictionary. - def _parse_json_line(self, line: bytes) -> ResponseTokens: - """Parse a single JSON line into GenerateResponse""" + Args: + line: JSON line as bytes + + Returns: + Dict[str, Any]: Parsed JSON data + + Raises: + InferenceGatewayValidationError: If JSON parsing fails + """ try: decoded_line = line.decode("utf-8") - data = json.loads(decoded_line) - return ResponseTokens.from_dict(data) + result: Dict[str, Any] = json.loads(decoded_line) + return result except UnicodeDecodeError as e: - raise json.JSONDecodeError(f"Invalid UTF-8 encoding: {line}", str(line), 0) + raise InferenceGatewayValidationError(f"Invalid UTF-8 encoding: {line!r}") except json.JSONDecodeError as e: - raise json.JSONDecodeError( - f"Invalid JSON response: { - decoded_line}", - e.doc, - e.pos, + raise InferenceGatewayValidationError(f"Invalid JSON response: {decoded_line}") + + def create_chat_completion( + self, + model: str, + messages: List[Message], + provider: Optional[Union[Provider, str]] = None, + max_tokens: Optional[int] = None, + stream: bool = False, + tools: Optional[List[Dict[str, Any]]] = None, + **kwargs: Any, + ) -> CreateChatCompletionResponse: + """Generate a chat completion. 
+ + Args: + model: Name of the model to use + messages: List of messages for the conversation + provider: Optional provider specification + max_tokens: Maximum number of tokens to generate + stream: Whether to stream the response + tools: List of tools the model may call + **kwargs: Additional parameters to pass to the API + + Returns: + CreateChatCompletionResponse: The completion response + + Raises: + InferenceGatewayAPIError: If the API request fails + InferenceGatewayValidationError: If request/response validation fails + """ + url = f"{self.base_url}/v1/chat/completions" + params = {} + + if provider: + provider_value = provider.root if hasattr(provider, "root") else str(provider) + params["provider"] = provider_value + + try: + request_data = { + "model": model, + "messages": [msg.model_dump(exclude_none=True) for msg in messages], + "stream": stream, + } + + if max_tokens is not None: + request_data["max_tokens"] = max_tokens + if tools: + request_data["tools"] = tools + + request_data.update(kwargs) + + request = CreateChatCompletionRequest.model_validate(request_data) + + response = self._make_request( + "POST", url, params=params, json=request.model_dump(exclude_none=True) ) - def generate_content(self, provider: Provider, model: str, messages: List[Message]) -> Dict: - payload = {"model": model, "messages": [msg.to_dict() for msg in messages]} + return CreateChatCompletionResponse.model_validate(response.json()) - response = self.session.post( - f"{self.base_url}/llms/{provider.value}/generate", json=payload - ) - response.raise_for_status() - return response.json() + except ValidationError as e: + raise InferenceGatewayValidationError(f"Request/response validation failed: {e}") - def generate_content_stream( - self, provider: Provider, model: str, messages: List[Message], use_sse: bool = False - ) -> Generator[Union[ResponseTokens, dict], None, None]: - """Stream content generation from the model + def create_chat_completion_stream( + self, + model: str, + messages: List[Message], + provider: Optional[Union[Provider, str]] = None, + max_tokens: Optional[int] = None, + tools: Optional[List[Dict[str, Any]]] = None, + use_sse: bool = True, + **kwargs: Any, + ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]: + """Stream a chat completion. 
Args: - provider: The provider to use model: Name of the model to use messages: List of messages for the conversation + provider: Optional provider specification + max_tokens: Maximum number of tokens to generate + tools: List of tools the model may call use_sse: Whether to use Server-Sent Events format + **kwargs: Additional parameters to pass to the API Yields: - Either ResponseTokens objects (for raw JSON) or dicts (for SSE) + Union[Dict[str, Any], SSEvent]: Stream chunks + + Raises: + InferenceGatewayAPIError: If the API request fails + InferenceGatewayValidationError: If request validation fails """ - payload = { - "model": model, - "messages": [msg.to_dict() for msg in messages], - "stream": True, - "ssevents": use_sse, - } + url = f"{self.base_url}/v1/chat/completions" + params = {} - response = self.session.post( - f"{self.base_url}/llms/{provider.value}/generate", json=payload, stream=True - ) - response.raise_for_status() + if provider: + provider_value = provider.root if hasattr(provider, "root") else str(provider) + params["provider"] = provider_value + + try: + request_data = { + "model": model, + "messages": [msg.model_dump(exclude_none=True) for msg in messages], + "stream": True, + } + + if max_tokens is not None: + request_data["max_tokens"] = max_tokens + if tools: + request_data["tools"] = tools + request_data.update(kwargs) + + request = CreateChatCompletionRequest.model_validate(request_data) + + if self.use_httpx: + with self.client.stream( + "POST", url, params=params, json=request.model_dump(exclude_none=True) + ) as response: + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + raise InferenceGatewayAPIError(f"Request failed: {str(e)}") + yield from self._process_stream_response(response, use_sse) + else: + requests_response = self.session.post( + url, params=params, json=request.model_dump(exclude_none=True), stream=True + ) + try: + requests_response.raise_for_status() + except (requests.exceptions.HTTPError, Exception) as e: + raise InferenceGatewayAPIError(f"Request failed: {str(e)}") + yield from self._process_stream_response(requests_response, use_sse) + + except ValidationError as e: + raise InferenceGatewayValidationError(f"Request validation failed: {e}") + + def _process_stream_response( + self, response: Union[requests.Response, httpx.Response], use_sse: bool + ) -> Generator[Union[Dict[str, Any], SSEvent], None, None]: + """Process streaming response data.""" if use_sse: - buffer = [] + buffer: List[bytes] = [] for line in response.iter_lines(): if not line: @@ -239,14 +405,77 @@ def generate_content_stream( buffer = [] continue - buffer.append(line) + if isinstance(line, str): + line_bytes = line.encode("utf-8") + else: + line_bytes = line + buffer.append(line_bytes) else: for line in response.iter_lines(): if not line: continue - yield self._parse_json_line(line) + + if isinstance(line, str): + line_bytes = line.encode("utf-8") + else: + line_bytes = line + + if line_bytes.strip() == b"data: [DONE]": + continue + if line_bytes.startswith(b"data: "): + json_str = line_bytes[6:].decode("utf-8") + data = json.loads(json_str) + yield data + else: + yield self._parse_json_line(line_bytes) + + def proxy_request( + self, + provider: Union[Provider, str], + path: str, + method: str = "GET", + json_data: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """Proxy a request to a provider's API. 
+ + Args: + provider: The provider to route to + path: Path segment after the provider + method: HTTP method to use + json_data: Optional JSON data for request body + **kwargs: Additional parameters to pass to the request + + Returns: + Dict[str, Any]: Provider response + + Raises: + InferenceGatewayAPIError: If the API request fails + ValueError: If an unsupported HTTP method is used + """ + provider_value = provider.root if hasattr(provider, "root") else str(provider) + url = f"{self.base_url}/proxy/{provider_value}/{path.lstrip('/')}" + + method = method.upper() + if method not in ["GET", "POST", "PUT", "DELETE", "PATCH"]: + raise ValueError(f"Unsupported HTTP method: {method}") + + request_kwargs = kwargs.copy() + if json_data and method in ["POST", "PUT", "PATCH"]: + request_kwargs["json"] = json_data + + response = self._make_request(method, url, **request_kwargs) + result: Dict[str, Any] = response.json() + return result def health_check(self) -> bool: - """Check if the API is healthy""" - response = self.session.get(f"{self.base_url}/health") - return response.status_code == 200 + """Check if the API is healthy. + + Returns: + bool: True if the API is healthy, False otherwise + """ + try: + response = self._make_request("GET", f"{self.base_url}/health") + return response.status_code == 200 + except Exception: + return False diff --git a/inference_gateway/models.py b/inference_gateway/models.py new file mode 100644 index 0000000..932edeb --- /dev/null +++ b/inference_gateway/models.py @@ -0,0 +1,643 @@ +# generated by datamodel-codegen: +# filename: openapi.yaml +# timestamp: 2025-05-26T15:46:03+00:00 + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, RootModel + + +class Provider( + RootModel[Literal["ollama", "groq", "openai", "cloudflare", "cohere", "anthropic", "deepseek"]] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Literal["ollama", "groq", "openai", "cloudflare", "cohere", "anthropic", "deepseek"] + + def __eq__(self, other: Any) -> bool: + """Allow comparison with strings.""" + if isinstance(other, str): + return self.root == other + return super().__eq__(other) + + +class ProviderSpecificResponse(BaseModel): + """ + Provider-specific response format. 
Examples: + + OpenAI GET /v1/models?provider=openai response: + ```json + { + "provider": "openai", + "object": "list", + "data": [ + { + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" + } + ] + } + ``` + + Anthropic GET /v1/models?provider=anthropic response: + ```json + { + "provider": "anthropic", + "object": "list", + "data": [ + { + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" + } + ] + } + ``` + + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + + +class ProviderAuthType(RootModel[Literal["bearer", "xheader", "query", "none"]]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Literal["bearer", "xheader", "query", "none"] + """ + Authentication type for providers + """ + + +class SSEvent(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + event: Optional[ + Literal[ + "message-start", + "stream-start", + "content-start", + "content-delta", + "content-end", + "message-end", + "stream-end", + ] + ] + data: Optional[str] + retry: Optional[int] + + +class Endpoints(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + models: str + chat: str + + +class Error(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + error: Optional[str] + + +class MessageRole(RootModel[Literal["system", "user", "assistant", "tool"]]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Literal["system", "user", "assistant", "tool"] + """ + Role of the message sender + """ + + def __eq__(self, other: Any) -> bool: + """Allow comparison with strings.""" + if isinstance(other, str): + return self.root == other + return super().__eq__(other) + + +class Model(BaseModel): + """ + Common model information + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + object: str + created: int + owned_by: str + served_by: Provider + + +class ListModelsResponse(BaseModel): + """ + Response structure for listing models + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + provider: Optional[Provider] + object: str + data: Sequence[Model] + + +class MCPTool(BaseModel): + """ + An MCP tool definition + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + name: Annotated[str, Field(examples=["read_file"])] + """ + The name of the tool + """ + description: Annotated[str, Field(examples=["Read content from a file"])] + """ + A description of what the tool does + """ + server: Annotated[str, Field(examples=["http://mcp-filesystem-server:8083/mcp"])] + """ + The MCP server that provides this tool + """ + input_schema: Annotated[ + Optional[Mapping[str, Any]], + Field( + examples=[ + { + "type": "object", + "properties": { + "file_path": {"type": "string", "description": "Path to the file to read"} + }, + "required": ["file_path"], + } + ] + ), + ] + """ + JSON schema for the tool's input parameters + """ + + +class FunctionParameters(BaseModel): + """ + The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. + Omitting `parameters` defines a function with an empty parameter list. 
+ """ + + model_config = ConfigDict( + extra="allow", + populate_by_name=True, + ) + + +class ChatCompletionToolType(RootModel[Literal["function"]]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Literal["function"] = "function" + """ + The type of the tool. Currently, only `function` is supported. + """ + + +class CompletionUsage(BaseModel): + """ + Usage statistics for the completion request. + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + completion_tokens: int + """ + Number of tokens in the generated completion. + """ + prompt_tokens: int + """ + Number of tokens in the prompt. + """ + total_tokens: int + """ + Total number of tokens used in the request (prompt + completion). + """ + + +class ChatCompletionStreamOptions(BaseModel): + """ + Options for streaming response. Only set this when you set `stream: true`. + + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + include_usage: bool + """ + If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value. + + """ + + +class ChatCompletionMessageToolCallFunction(BaseModel): + """ + The function that the model called. + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + name: str + """ + The name of the function to call. + """ + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + """ + + +class ChatCompletionMessageToolCall(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + """ + The ID of the tool call. + """ + type: ChatCompletionToolType + function: ChatCompletionMessageToolCallFunction + + +class Function(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + name: Optional[str] + """ + The name of the function to call. + """ + arguments: Optional[str] + """ + The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + """ + + +class ChatCompletionMessageToolCallChunk(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + index: int + id: Optional[str] + """ + The ID of the tool call. + """ + type: Optional[str] + """ + The type of the tool. Currently, only `function` is supported. + """ + function: Optional[Function] + + +class TopLogprob(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + token: str + """ + The token. + """ + logprob: float + """ + The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. + """ + bytes: Sequence[int] + """ + A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. 
+ """ + + +class ChatCompletionTokenLogprob(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + token: str + """ + The token. + """ + logprob: float + """ + The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. + """ + bytes: Sequence[int] + """ + A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. + """ + top_logprobs: Sequence[TopLogprob] + """ + List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned. + """ + + +class FinishReason( + RootModel[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]] +): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Literal["stop", "length", "tool_calls", "content_filter", "function_call"] + """ + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, + `length` if the maximum number of tokens specified in the request was reached, + `content_filter` if content was omitted due to a flag from our content filters, + `tool_calls` if the model called a tool. + + """ + + +class Config(RootModel[Any]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: Any + + +class Message(BaseModel): + """ + Message structure for provider requests + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + role: MessageRole + content: str + tool_calls: Optional[Sequence[ChatCompletionMessageToolCall]] = None + tool_call_id: Optional[str] = None + reasoning_content: Optional[str] = None + """ + The reasoning content of the chunk message. + """ + reasoning: Optional[str] = None + """ + The reasoning of the chunk message. Same as reasoning_content. + """ + + +class ListToolsResponse(BaseModel): + """ + Response structure for listing MCP tools + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + object: Annotated[str, Field(examples=["list"])] + """ + Always "list" + """ + data: Sequence[MCPTool] + """ + Array of available MCP tools + """ + + +class FunctionObject(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + description: Optional[str] + """ + A description of what the function does, used by the model to choose when and how to call the function. + """ + name: str + """ + The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + """ + parameters: Optional[FunctionParameters] + strict: bool = False + """ + Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling). 
+ """ + + +class ChatCompletionTool(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + type: ChatCompletionToolType + function: FunctionObject + + +class CreateChatCompletionRequest(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + model: str + """ + Model ID to use + """ + messages: Annotated[Sequence[Message], Field(min_length=1)] + """ + A list of messages comprising the conversation so far. + + """ + max_tokens: Optional[int] = None + """ + An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. + + """ + stream: bool = False + """ + If set to true, the model response data will be streamed to the client as it is generated using [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format). + + """ + stream_options: Optional[ChatCompletionStreamOptions] = None + tools: Optional[Sequence[ChatCompletionTool]] = None + """ + A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. + + """ + reasoning_format: Optional[str] = None + """ + The format of the reasoning content. Can be `raw` or `parsed`. + When specified as raw some reasoning models will output tags. When specified as parsed the model will output the reasoning under `reasoning` or `reasoning_content` attribute. + + """ + + +class ChatCompletionChoice(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"] + """ + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, + `length` if the maximum number of tokens specified in the request was reached, + `content_filter` if content was omitted due to a flag from our content filters, + `tool_calls` if the model called a tool. + + """ + index: int + """ + The index of the choice in the list of choices. + """ + message: Message + + +class Logprobs(BaseModel): + """ + Log probability information for the choice. + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + content: Sequence[ChatCompletionTokenLogprob] + """ + A list of message content tokens with log probability information. + """ + refusal: Sequence[ChatCompletionTokenLogprob] + """ + A list of message refusal tokens with log probability information. + """ + + +class CreateChatCompletionResponse(BaseModel): + """ + Represents a chat completion response returned by model, based on the provided input. + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + """ + A unique identifier for the chat completion. + """ + choices: Sequence[ChatCompletionChoice] + """ + A list of chat completion choices. Can be more than one if `n` is greater than 1. + """ + created: int + """ + The Unix timestamp (in seconds) of when the chat completion was created. + """ + model: str + """ + The model used for the chat completion. + """ + object: str + """ + The object type, which is always `chat.completion`. + """ + usage: Optional[CompletionUsage] + + +class ChatCompletionStreamResponseDelta(BaseModel): + """ + A chat completion delta generated by streamed model responses. + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + content: str + """ + The contents of the chunk message. 
+ """ + reasoning_content: Optional[str] + """ + The reasoning content of the chunk message. + """ + reasoning: Optional[str] + """ + The reasoning of the chunk message. Same as reasoning_content. + """ + tool_calls: Optional[Sequence[ChatCompletionMessageToolCallChunk]] + role: MessageRole + refusal: Optional[str] + """ + The refusal message generated by the model. + """ + + +class ChatCompletionStreamChoice(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + delta: ChatCompletionStreamResponseDelta + logprobs: Optional[Logprobs] + """ + Log probability information for the choice. + """ + finish_reason: FinishReason + index: int + """ + The index of the choice in the list of choices. + """ + + +class CreateChatCompletionStreamResponse(BaseModel): + """ + Represents a streamed chunk of a chat completion response returned + by the model, based on the provided input. + + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + id: str + """ + A unique identifier for the chat completion. Each chunk has the same ID. + """ + choices: Sequence[ChatCompletionStreamChoice] + """ + A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the + last chunk if you set `stream_options: {"include_usage": true}`. + + """ + created: int + """ + The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp. + """ + model: str + """ + The model to generate the completion. + """ + system_fingerprint: Optional[str] + """ + This fingerprint represents the backend configuration that the model runs with. + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + + """ + object: str + """ + The object type, which is always `chat.completion.chunk`. + """ + usage: Optional[CompletionUsage] + reasoning_format: Optional[str] + """ + The format of the reasoning content. Can be `raw` or `parsed`. + When specified as raw some reasoning models will output tags. When specified as parsed the model will output the reasoning under reasoning_content. + + """ diff --git a/openapi.yaml b/openapi.yaml index b4abc85..c81d41f 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3,79 +3,178 @@ openapi: 3.1.0 info: title: Inference Gateway API description: | - API for interacting with various language models through the Inference Gateway. + The API for interacting with various language models and other AI services. + OpenAI, Groq, Ollama, and other providers are supported. + OpenAI compatible API for using with existing clients. + Unified API for all providers. + contact: + name: Inference Gateway + url: https://inference-gateway.github.io/docs/ version: 1.0.0 + license: + name: MIT + url: https://github.com/inference-gateway/inference-gateway/blob/main/LICENSE servers: - url: http://localhost:8080 + description: Default server without version prefix for healthcheck and proxy and points + x-server-tags: ["Health", "Proxy"] + - url: http://localhost:8080/v1 + description: Default server with version prefix for listing models and chat completions + x-server-tags: ["Models", "Completions"] + - url: https://api.inference-gateway.local/v1 + description: Local server with version prefix for listing models and chat completions + x-server-tags: ["Models", "Completions"] +tags: + - name: Models + description: List and describe the various models available in the API. 
+ - name: Completions + description: Generate completions from the models. + - name: Tools + description: List and manage MCP tools. + - name: Proxy + description: Proxy requests to provider endpoints. + - name: Health + description: Health check paths: - /llms: + /models: get: - summary: List all language models operationId: listModels + tags: + - Models + description: | + Lists the currently available models, and provides basic information + about each one such as the owner and availability. + summary: + Lists the currently available models, and provides basic information + about each one such as the owner and availability. security: - bearerAuth: [] + parameters: + - name: provider + in: query + required: false + schema: + $ref: "#/components/schemas/Provider" + description: Specific provider to query (optional) responses: "200": - description: A list of models by provider + description: List of available models content: application/json: schema: - type: array - items: - $ref: "#/components/schemas/ListModelsResponse" + $ref: "#/components/schemas/ListModelsResponse" + examples: + allProviders: + summary: Models from all providers + value: + object: "list" + data: + - id: "openai/gpt-4o" + object: "model" + created: 1686935002 + owned_by: "openai" + served_by: "openai" + - id: "openai/llama-3.3-70b-versatile" + object: "model" + created: 1723651281 + owned_by: "groq" + served_by: "groq" + - id: "cohere/claude-3-opus-20240229" + object: "model" + created: 1708905600 + owned_by: "anthropic" + served_by: "anthropic" + - id: "cohere/command-r" + object: "model" + created: 1707868800 + owned_by: "cohere" + served_by: "cohere" + - id: "ollama/phi3:3.8b" + object: "model" + created: 1718441600 + owned_by: "ollama" + served_by: "ollama" + singleProvider: + summary: Models from a specific provider + value: + object: "list" + data: + - id: "openai/gpt-4o" + object: "model" + created: 1686935002 + owned_by: "openai" + served_by: "openai" + - id: "openai/gpt-4-turbo" + object: "model" + created: 1687882410 + owned_by: "openai" + served_by: "openai" + - id: "openai/gpt-3.5-turbo" + object: "model" + created: 1677649963 + owned_by: "openai" + served_by: "openai" "401": $ref: "#/components/responses/Unauthorized" - /llms/{provider}: - get: - summary: List all models for a specific provider - operationId: listModelsByProvider + "500": + $ref: "#/components/responses/InternalError" + /chat/completions: + post: + operationId: createChatCompletion + tags: + - Completions + description: | + Generates a chat completion based on the provided input. + The completion can be streamed to the client as it is generated. 
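As a concrete illustration of the `/models` operation defined here, a hedged sketch that calls the gateway directly with `requests` (a declared dependency); the base URL matches the default server in the spec, and the bearer token is a placeholder that is only needed when authentication is enabled:

```python
import requests

BASE_URL = "http://localhost:8080"  # default server from the spec
HEADERS = {"Authorization": "Bearer <jwt>"}  # placeholder; only needed when ENABLE_AUTH=true

response = requests.get(
    f"{BASE_URL}/v1/models",
    params={"provider": "openai"},  # omit to list models from all providers
    headers=HEADERS,
    timeout=30,
)
response.raise_for_status()

for model in response.json().get("data", []):
    print(model["id"], "served by", model["served_by"])
```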
+ summary: Create a chat completion + security: + - bearerAuth: [] parameters: - name: provider - in: path - required: true + in: query + required: false schema: - $ref: "#/components/schemas/Providers" - security: - - bearerAuth: [] + $ref: "#/components/schemas/Provider" + description: Specific provider to use (default determined by model) + requestBody: + $ref: "#/components/requestBodies/CreateChatCompletionRequest" responses: "200": - description: A list of models + description: Successful response content: application/json: schema: - $ref: "#/components/schemas/ListModelsResponse" + $ref: "#/components/schemas/CreateChatCompletionResponse" + text/event-stream: + schema: + $ref: "#/components/schemas/SSEvent" "400": $ref: "#/components/responses/BadRequest" "401": $ref: "#/components/responses/Unauthorized" - /llms/{provider}/generate: - post: - summary: Generate content with a specific provider's LLM - operationId: generateContent - parameters: - - name: provider - in: path - required: true - schema: - $ref: "#/components/schemas/Providers" + "500": + $ref: "#/components/responses/InternalError" + /mcp/tools: + get: + operationId: listTools + tags: + - Tools + description: | + Lists the currently available MCP tools. Only accessible when EXPOSE_MCP is enabled. + summary: Lists the currently available MCP tools security: - bearerAuth: [] - requestBody: - content: - application/json: - schema: - $ref: "#/components/schemas/GenerateRequest" responses: "200": - description: Generated content + description: Successful response content: application/json: schema: - $ref: "#/components/schemas/GenerateResponse" - "400": - $ref: "#/components/responses/BadRequest" + $ref: "#/components/schemas/ListToolsResponse" "401": $ref: "#/components/responses/Unauthorized" + "403": + $ref: "#/components/responses/MCPNotExposed" "500": $ref: "#/components/responses/InternalError" /proxy/{provider}/{path}: @@ -84,7 +183,7 @@ paths: in: path required: true schema: - $ref: "#/components/schemas/Providers" + $ref: "#/components/schemas/Provider" - name: path in: path required: true @@ -94,8 +193,14 @@ paths: type: string description: The remaining path to proxy to the provider get: - summary: Proxy GET request to provider operationId: proxyGet + tags: + - Proxy + description: | + Proxy GET request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy GET request to provider responses: "200": $ref: "#/components/responses/ProviderResponse" @@ -108,8 +213,14 @@ paths: security: - bearerAuth: [] post: - summary: Proxy POST request to provider operationId: proxyPost + tags: + - Proxy + description: | + Proxy POST request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy POST request to provider requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -124,8 +235,14 @@ paths: security: - bearerAuth: [] put: - summary: Proxy PUT request to provider operationId: proxyPut + tags: + - Proxy + description: | + Proxy PUT request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. 
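And for the `/mcp/tools` operation above, a small sketch of listing tools and handling the 403 returned when `EXPOSE_MCP` is disabled; the URL and token are placeholders:

```python
import requests

BASE_URL = "http://localhost:8080"
response = requests.get(
    f"{BASE_URL}/v1/mcp/tools",
    headers={"Authorization": "Bearer <jwt>"},
    timeout=30,
)

if response.status_code == 403:
    # MCP tools endpoint is not exposed (EXPOSE_MCP=false on the gateway).
    print(response.json().get("error"))
else:
    response.raise_for_status()
    for tool in response.json()["data"]:
        print(tool["name"], "->", tool["server"])
```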
+ summary: Proxy PUT request to provider requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -140,8 +257,14 @@ paths: security: - bearerAuth: [] delete: - summary: Proxy DELETE request to provider operationId: proxyDelete + tags: + - Proxy + description: | + Proxy DELETE request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy DELETE request to provider responses: "200": $ref: "#/components/responses/ProviderResponse" @@ -154,8 +277,14 @@ paths: security: - bearerAuth: [] patch: - summary: Proxy PATCH request to provider operationId: proxyPatch + tags: + - Proxy + description: | + Proxy PATCH request to provider + The request body depends on the specific provider and endpoint being called. + If you decide to use this approach, please follow the provider-specific documentations. + summary: Proxy PATCH request to provider requestBody: $ref: "#/components/requestBodies/ProviderRequest" responses: @@ -171,6 +300,12 @@ paths: - bearerAuth: [] /health: get: + operationId: healthCheck + tags: + - Health + description: | + Health check endpoint + Returns a 200 status code if the service is healthy summary: Health check responses: "200": @@ -200,25 +335,34 @@ components: type: string temperature: type: number - format: float64 + format: float default: 0.7 - examples: - - openai: - summary: OpenAI chat completion request - value: - model: "gpt-3.5-turbo" - messages: - - role: "user" - content: "Hello! How can I assist you today?" - temperature: 0.7 - - anthropic: - summary: Anthropic Claude request - value: - model: "claude-3-opus-20240229" - messages: - - role: "user" - content: "Explain quantum computing" - temperature: 0.5 + examples: + openai: + summary: OpenAI chat completion request + value: + model: "gpt-3.5-turbo" + messages: + - role: "user" + content: "Hello! How can I assist you today?" + temperature: 0.7 + anthropic: + summary: Anthropic Claude request + value: + model: "claude-3-opus-20240229" + messages: + - role: "user" + content: "Explain quantum computing" + temperature: 0.5 + CreateChatCompletionRequest: + required: true + description: | + ProviderRequest depends on the specific provider and endpoint being called + If you decide to use this approach, please follow the provider-specific documentations. + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChatCompletionRequest" responses: BadRequest: description: Bad request @@ -238,6 +382,14 @@ components: application/json: schema: $ref: "#/components/schemas/Error" + MCPNotExposed: + description: MCP tools endpoint is not exposed + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + example: + error: "MCP tools endpoint is not exposed. Set EXPOSE_MCP=true to enable." ProviderResponse: description: | ProviderResponse depends on the specific provider and endpoint being called @@ -278,7 +430,7 @@ components: To enable authentication, set ENABLE_AUTH to true. When enabled, requests must include a valid JWT token in the Authorization header. 
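Tying the health check and the bearer scheme together, a minimal sketch of probing `/health` (served without the `/v1` prefix) and of attaching the JWT the gateway expects when `ENABLE_AUTH` is on; the token value is a placeholder:

```python
import requests

BASE_URL = "http://localhost:8080"

# Liveness probe: /health returns 200 when the gateway is healthy.
healthy = requests.get(f"{BASE_URL}/health", timeout=5).status_code == 200
print("gateway healthy:", healthy)

# Authenticated calls: reuse a session so every request carries the bearer token.
session = requests.Session()
session.headers["Authorization"] = "Bearer <jwt>"  # placeholder JWT
models = session.get(f"{BASE_URL}/v1/models", timeout=30).json()
print(len(models.get("data", [])), "models visible")
```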
schemas: - Providers: + Provider: type: string enum: - ollama @@ -287,36 +439,137 @@ components: - cloudflare - cohere - anthropic + - deepseek + x-provider-configs: + ollama: + id: "ollama" + url: "http://ollama:8080/v1" + auth_type: "none" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + anthropic: + id: "anthropic" + url: "https://api.anthropic.com/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + cohere: + id: "cohere" + url: "https://api.cohere.ai" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/v1/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/compatibility/v1/chat/completions" + groq: + id: "groq" + url: "https://api.groq.com/openai/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + openai: + id: "openai" + url: "https://api.openai.com/v1" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" + cloudflare: + id: "cloudflare" + url: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/finetunes/public?limit=1000" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/v1/chat/completions" + deepseek: + id: "deepseek" + url: "https://api.deepseek.com" + auth_type: "bearer" + endpoints: + models: + name: "list_models" + method: "GET" + endpoint: "/models" + chat: + name: "chat_completions" + method: "POST" + endpoint: "/chat/completions" ProviderSpecificResponse: type: object description: | Provider-specific response format. 
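The `x-provider-configs` block above is the source of truth for where each provider lives and how it authenticates. Below is a hedged sketch that mirrors that mapping in Python and builds the corresponding gateway proxy URL (`/proxy/{provider}/{path}`); the example path is illustrative only, since the proxied path and body are provider-specific:

```python
# Base URLs and auth types as declared under x-provider-configs in the spec.
PROVIDERS = {
    "ollama": {"url": "http://ollama:8080/v1", "auth": "none"},
    "openai": {"url": "https://api.openai.com/v1", "auth": "bearer"},
    "groq": {"url": "https://api.groq.com/openai/v1", "auth": "bearer"},
    "anthropic": {"url": "https://api.anthropic.com/v1", "auth": "bearer"},
    "cohere": {"url": "https://api.cohere.ai", "auth": "bearer"},
    "cloudflare": {"url": "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai", "auth": "bearer"},
    "deepseek": {"url": "https://api.deepseek.com", "auth": "bearer"},
}


def proxy_url(gateway: str, provider: str, path: str) -> str:
    """Build a gateway proxy URL of the form /proxy/{provider}/{path}."""
    if provider not in PROVIDERS:
        raise ValueError(f"unknown provider: {provider}")
    return f"{gateway}/proxy/{provider}/{path.lstrip('/')}"


print(proxy_url("http://localhost:8080", "openai", "v1/embeddings"))  # illustrative path
```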
Examples: - OpenAI GET /v1/models response: + OpenAI GET /v1/models?provider=openai response: ```json { + "provider": "openai", + "object": "list", "data": [ { "id": "gpt-4", "object": "model", - "created": 1687882410 + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" } ] } ``` - Anthropic GET /v1/models response: + Anthropic GET /v1/models?provider=anthropic response: ```json { - "models": [ + "provider": "anthropic", + "object": "list", + "data": [ { - "name": "claude-3-opus-20240229", - "description": "Most capable model for highly complex tasks" + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai" } ] } ``` - additionalProperties: true ProviderAuthType: type: string description: Authentication type for providers @@ -325,6 +578,34 @@ components: - xheader - query - none + SSEvent: + type: object + properties: + event: + type: string + enum: + - message-start + - stream-start + - content-start + - content-delta + - content-end + - message-end + - stream-end + data: + type: string + format: byte + retry: + type: integer + Endpoints: + type: object + properties: + models: + type: string + chat: + type: string + required: + - models + - chat Error: type: object properties: @@ -337,6 +618,7 @@ components: - system - user - assistant + - tool Message: type: object description: Message structure for provider requests @@ -345,83 +627,552 @@ components: $ref: "#/components/schemas/MessageRole" content: type: string + tool_calls: + type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCall" + tool_call_id: + type: string + reasoning_content: + type: string + description: The reasoning content of the chunk message. + reasoning: + type: string + description: The reasoning of the chunk message. Same as reasoning_content. 
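Since the `SSEvent` schema above only names the event types, here is a minimal, stdlib-only sketch of grouping raw SSE lines into `event`/`data` pairs (a blank line terminates one event); a production parser would also honour `retry:` and multi-line `data:` fields:

```python
def parse_sse(lines):
    """Group `event:` / `data:` lines into events; a blank line dispatches the event."""
    event = {}
    for line in lines:
        line = line.strip()
        if not line:  # end of one event
            if event:
                yield event
                event = {}
        elif line.startswith("event:"):
            event["event"] = line[len("event:"):].strip()
        elif line.startswith("data:"):
            event["data"] = line[len("data:"):].strip()
    if event:
        yield event


stream = [
    "event: content-start", "data: {}", "",
    "event: content-delta", 'data: {"content":"Hello"}', "",
    "event: content-end", "data: {}", "",
]
for ev in parse_sse(stream):
    print(ev["event"], ev.get("data"))
```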
+ required: + - role + - content Model: type: object description: Common model information properties: - name: + id: + type: string + object: type: string + created: + type: integer + format: int64 + owned_by: + type: string + served_by: + $ref: "#/components/schemas/Provider" + required: + - id + - object + - created + - owned_by + - served_by ListModelsResponse: type: object description: Response structure for listing models properties: provider: - $ref: "#/components/schemas/Providers" - models: + $ref: "#/components/schemas/Provider" + object: + type: string + data: type: array items: $ref: "#/components/schemas/Model" - GenerateRequest: + default: [] + required: + - object + - data + ListToolsResponse: + type: object + description: Response structure for listing MCP tools + properties: + object: + type: string + description: Always "list" + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/MCPTool" + default: [] + description: Array of available MCP tools + required: + - object + - data + MCPTool: type: object - description: Request structure for token generation + description: An MCP tool definition + properties: + name: + type: string + description: The name of the tool + example: "read_file" + description: + type: string + description: A description of what the tool does + example: "Read content from a file" + server: + type: string + description: The MCP server that provides this tool + example: "http://mcp-filesystem-server:8083/mcp" + input_schema: + type: object + description: JSON schema for the tool's input parameters + example: + type: "object" + properties: + file_path: + type: "string" + description: "Path to the file to read" + required: ["file_path"] required: - - model - - messages + - name + - description + - server + FunctionObject: + type: object + properties: + description: + type: string + description: + A description of what the function does, used by the model to + choose when and how to call the function. + name: + type: string + description: + The name of the function to be called. Must be a-z, A-Z, 0-9, or + contain underscores and dashes, with a maximum length of 64. + parameters: + $ref: "#/components/schemas/FunctionParameters" + strict: + type: boolean + default: false + description: + Whether to enable strict schema adherence when generating the + function call. If set to true, the model will follow the exact + schema defined in the `parameters` field. Only a subset of JSON + Schema is supported when `strict` is `true`. Learn more about + Structured Outputs in the [function calling + guide](docs/guides/function-calling). + required: + - name + ChatCompletionTool: + type: object + properties: + type: + $ref: "#/components/schemas/ChatCompletionToolType" + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + FunctionParameters: + type: object + description: >- + The parameters the functions accepts, described as a JSON Schema object. + See the [guide](/docs/guides/function-calling) for examples, and the + [JSON Schema + reference](https://json-schema.org/understanding-json-schema/) for + documentation about the format. + + Omitting `parameters` defines a function with an empty parameter list. + additionalProperties: true + ChatCompletionToolType: + type: string + description: The type of the tool. Currently, only `function` is supported. + enum: + - function + CompletionUsage: + type: object + description: Usage statistics for the completion request. 
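To make the tool schema concrete, a sketch that builds a `ChatCompletionTool` for the `read_file` example used by `MCPTool` above; it assumes the generated classes are importable from `inference_gateway.models` and passes the optional fields explicitly, since the generated models declare them without defaults:

```python
from inference_gateway.models import (  # assumed module path
    ChatCompletionTool,
    FunctionObject,
    FunctionParameters,
)

tool = ChatCompletionTool(
    type="function",
    function=FunctionObject(
        name="read_file",
        description="Read content from a file",
        parameters=FunctionParameters(  # extra keys are allowed and carry the JSON Schema
            type="object",
            properties={"file_path": {"type": "string", "description": "Path to the file to read"}},
            required=["file_path"],
        ),
        strict=False,
    ),
)
print(tool.model_dump(exclude_none=True))
```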
+ properties: + completion_tokens: + type: integer + default: 0 + format: int64 + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + default: 0 + format: int64 + description: Number of tokens in the prompt. + total_tokens: + type: integer + default: 0 + format: int64 + description: Total number of tokens used in the request (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + ChatCompletionStreamOptions: + description: > + Options for streaming response. Only set this when you set `stream: + true`. + type: object + properties: + include_usage: + type: boolean + description: > + If set, an additional chunk will be streamed before the `data: + [DONE]` message. The `usage` field on this chunk shows the token + usage statistics for the entire request, and the `choices` field + will always be an empty array. All other chunks will also include a + `usage` field, but with a null value. + required: + - include_usage + CreateChatCompletionRequest: + type: object properties: model: type: string + description: Model ID to use messages: + description: > + A list of messages comprising the conversation so far. type: array + minItems: 1 items: $ref: "#/components/schemas/Message" + max_tokens: + description: > + An upper bound for the number of tokens that can be generated + for a completion, including visible output tokens and reasoning tokens. + type: integer stream: + description: > + If set to true, the model response data will be streamed to the + client as it is generated using [server-sent + events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format). type: boolean default: false - description: Whether to stream tokens as they are generated in raw json - ssevents: - type: boolean - default: false - description: | - Whether to use Server-Sent Events for token generation. - When enabled, the response will be streamed as SSE with the following event types: - - message-start: Initial message event with assistant role - - stream-start: Stream initialization - - content-start: Content beginning - - content-delta: Content update with new tokens - - content-end: Content completion - - message-end: Message completion - - stream-end: Stream completion + stream_options: + $ref: "#/components/schemas/ChatCompletionStreamOptions" + tools: + type: array + description: > + A list of tools the model may call. Currently, only functions + are supported as a tool. Use this to provide a list of functions + the model may generate JSON inputs for. A max of 128 functions + are supported. + items: + $ref: "#/components/schemas/ChatCompletionTool" + reasoning_format: + type: string + description: > + The format of the reasoning content. Can be `raw` or `parsed`. - **Note:** Depending on the provider, some events may not be present. - ResponseTokens: + When specified as raw some reasoning models will output tags. + When specified as parsed the model will output the reasoning under + `reasoning` or `reasoning_content` attribute. + required: + - model + - messages + ChatCompletionMessageToolCallFunction: type: object - description: Token response structure + description: The function that the model called. properties: - role: + name: + type: string + description: The name of the function to call. + arguments: type: string + description: + The arguments to call the function with, as generated by the model + in JSON format. 
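A brief sketch of how `stream_options.include_usage` is used in practice: the request opts in, and the usage totals arrive on the final chunk before `data: [DONE]`. The payload below uses plain dicts, so it is independent of the generated models:

```python
import json

# Request body that opts in to the trailing usage chunk.
body = {
    "model": "openai/gpt-4o",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
    "stream_options": {"include_usage": True},
}


def final_usage(data_lines):
    """Return the usage object carried by the last chunk before [DONE], if any."""
    usage = None
    for raw in data_lines:
        if raw.strip() == "[DONE]":
            break
        usage = json.loads(raw).get("usage") or usage
    return usage


sample = ['{"choices":[],"usage":{"prompt_tokens":20,"completion_tokens":10,"total_tokens":30}}', "[DONE]"]
print(final_usage(sample))  # -> {'prompt_tokens': 20, 'completion_tokens': 10, 'total_tokens': 30}
```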
Note that the model does not always generate + valid JSON, and may hallucinate parameters not defined by your + function schema. Validate the arguments in your code before + calling your function. + required: + - name + - arguments + ChatCompletionMessageToolCall: + type: object + properties: + id: + type: string + description: The ID of the tool call. + type: + $ref: "#/components/schemas/ChatCompletionToolType" + function: + $ref: "#/components/schemas/ChatCompletionMessageToolCallFunction" + required: + - id + - type + - function + ChatCompletionChoice: + type: object + properties: + finish_reason: + type: string + description: > + The reason the model stopped generating tokens. This will be + `stop` if the model hit a natural stop point or a provided + stop sequence, + + `length` if the maximum number of tokens specified in the + request was reached, + + `content_filter` if content was omitted due to a flag from our + content filters, + + `tool_calls` if the model called a tool. + enum: + - stop + - length + - tool_calls + - content_filter + - function_call + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/Message" + required: + - finish_reason + - index + - message + - logprobs + ChatCompletionStreamChoice: + type: object + required: + - delta + - finish_reason + - index + properties: + delta: + $ref: "#/components/schemas/ChatCompletionStreamResponseDelta" + logprobs: + description: Log probability information for the choice. + type: object + properties: + content: + description: A list of message content tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + refusal: + description: A list of message refusal tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + required: + - content + - refusal + finish_reason: + $ref: "#/components/schemas/FinishReason" + index: + type: integer + description: The index of the choice in the list of choices. + CreateChatCompletionResponse: + type: object + description: + Represents a chat completion response returned by model, based on + the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: + A list of chat completion choices. Can be more than one if `n` is + greater than 1. + items: + $ref: "#/components/schemas/ChatCompletionChoice" + created: + type: integer + description: + The Unix timestamp (in seconds) of when the chat completion was + created. model: type: string + description: The model used for the chat completion. + object: + type: string + description: The object type, which is always `chat.completion`. + x-stainless-const: true + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + ChatCompletionStreamResponseDelta: + type: object + description: A chat completion delta generated by streamed model responses. + properties: content: type: string - GenerateResponse: + description: The contents of the chunk message. + reasoning_content: + type: string + description: The reasoning content of the chunk message. + reasoning: + type: string + description: The reasoning of the chunk message. Same as reasoning_content. 
+ tool_calls: + type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCallChunk" + role: + $ref: "#/components/schemas/MessageRole" + refusal: + type: string + description: The refusal message generated by the model. + required: + - content + - role + ChatCompletionMessageToolCallChunk: type: object - description: Response structure for token generation properties: - provider: + index: + type: integer + id: + type: string + description: The ID of the tool call. + type: + type: string + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: + The arguments to call the function with, as generated by the model + in JSON format. Note that the model does not always generate + valid JSON, and may hallucinate parameters not defined by your + function schema. Validate the arguments in your code before + calling your function. + required: + - index + ChatCompletionTokenLogprob: + type: object + properties: + token: &a1 + description: The token. + type: string + logprob: &a2 + description: + The log probability of this token, if it is within the top 20 most + likely tokens. Otherwise, the value `-9999.0` is used to signify + that the token is very unlikely. + type: number + bytes: &a3 + description: + A list of integers representing the UTF-8 bytes representation of + the token. Useful in instances where characters are represented by + multiple tokens and their byte representations must be combined to + generate the correct text representation. Can be `null` if there is + no bytes representation for the token. + type: array + items: + type: integer + top_logprobs: + description: + List of the most likely tokens and their log probability, at this + token position. In rare cases, there may be fewer than the number of + requested `top_logprobs` returned. + type: array + items: + type: object + properties: + token: *a1 + logprob: *a2 + bytes: *a3 + required: + - token + - logprob + - bytes + required: + - token + - logprob + - bytes + - top_logprobs + FinishReason: + type: string + description: > + The reason the model stopped generating tokens. This will be + `stop` if the model hit a natural stop point or a provided + stop sequence, + + `length` if the maximum number of tokens specified in the + request was reached, + + `content_filter` if content was omitted due to a flag from our + content filters, + + `tool_calls` if the model called a tool. + enum: + - stop + - length + - tool_calls + - content_filter + - function_call + CreateChatCompletionStreamResponse: + type: object + description: | + Represents a streamed chunk of a chat completion response returned + by the model, based on the provided input. + properties: + id: + type: string + description: + A unique identifier for the chat completion. Each chunk has the + same ID. + choices: + type: array + description: > + A list of chat completion choices. Can contain more than one + elements if `n` is greater than 1. Can also be empty for the + + last chunk if you set `stream_options: {"include_usage": true}`. + items: + $ref: "#/components/schemas/ChatCompletionStreamChoice" + created: + type: integer + description: + The Unix timestamp (in seconds) of when the chat completion was + created. Each chunk has the same timestamp. 
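Because streamed tool calls arrive as partial `ChatCompletionMessageToolCallChunk` objects, clients usually merge them by `index` and concatenate the `arguments` fragments. A hedged, dict-based sketch:

```python
def merge_tool_call_chunks(chunks):
    """Merge streamed tool-call deltas keyed by `index` into complete calls."""
    calls = {}
    for chunk in chunks:
        slot = calls.setdefault(chunk["index"], {"id": None, "name": "", "arguments": ""})
        if chunk.get("id"):
            slot["id"] = chunk["id"]
        function = chunk.get("function") or {}
        if function.get("name"):
            slot["name"] = function["name"]
        slot["arguments"] += function.get("arguments") or ""
    return [calls[i] for i in sorted(calls)]


deltas = [
    {"index": 0, "id": "call_1", "type": "function", "function": {"name": "read_file", "arguments": ""}},
    {"index": 0, "function": {"arguments": '{"file_path":'}},
    {"index": 0, "function": {"arguments": ' "README.md"}'}},
]
print(merge_tool_call_chunks(deltas))
```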
+ model: type: string - response: - $ref: "#/components/schemas/ResponseTokens" + description: The model to generate the completion. + system_fingerprint: + type: string + description: > + This fingerprint represents the backend configuration that the model + runs with. + + Can be used in conjunction with the `seed` request parameter to + understand when backend changes have been made that might impact + determinism. + object: + type: string + description: The object type, which is always `chat.completion.chunk`. + usage: + $ref: "#/components/schemas/CompletionUsage" + reasoning_format: + type: string + description: > + The format of the reasoning content. Can be `raw` or `parsed`. + + When specified as raw some reasoning models will output tags. + When specified as parsed the model will output the reasoning under reasoning_content. + required: + - choices + - created + - id + - model + - object Config: x-config: sections: - general: title: "General settings" settings: - - name: application_name - env: "APPLICATION_NAME" - type: string - default: "inference-gateway" - description: "The name of the application" - name: environment env: "ENVIRONMENT" type: string @@ -437,6 +1188,53 @@ components: type: bool default: "false" description: "Enable authentication" + - mcp: + title: "Model Context Protocol (MCP)" + settings: + - name: mcp_enable + env: "MCP_ENABLE" + type: bool + default: "false" + description: "Enable MCP" + - name: mcp_expose + env: "MCP_EXPOSE" + type: bool + default: "false" + description: "Expose MCP tools endpoint" + - name: mcp_servers + env: "MCP_SERVERS" + type: string + description: "List of MCP servers" + - name: mcp_client_timeout + env: "MCP_CLIENT_TIMEOUT" + type: time.Duration + default: "5s" + description: "MCP client HTTP timeout" + - name: mcp_dial_timeout + env: "MCP_DIAL_TIMEOUT" + type: time.Duration + default: "3s" + description: "MCP client dial timeout" + - name: mcp_tls_handshake_timeout + env: "MCP_TLS_HANDSHAKE_TIMEOUT" + type: time.Duration + default: "3s" + description: "MCP client TLS handshake timeout" + - name: mcp_response_header_timeout + env: "MCP_RESPONSE_HEADER_TIMEOUT" + type: time.Duration + default: "3s" + description: "MCP client response header timeout" + - name: mcp_expect_continue_timeout + env: "MCP_EXPECT_CONTINUE_TIMEOUT" + type: time.Duration + default: "1s" + description: "MCP client expect continue timeout" + - name: mcp_request_timeout + env: "MCP_REQUEST_TIMEOUT" + type: time.Duration + default: "5s" + description: "MCP client request timeout for initialize and tool calls" - oidc: title: "OpenID Connect" settings: @@ -526,7 +1324,7 @@ components: - name: anthropic_api_url env: "ANTHROPIC_API_URL" type: string - default: "https://api.anthropic.com" + default: "https://api.anthropic.com/v1" description: "Anthropic API URL" - name: anthropic_api_key env: "ANTHROPIC_API_KEY" @@ -536,7 +1334,7 @@ components: - name: cloudflare_api_url env: "CLOUDFLARE_API_URL" type: string - default: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}" + default: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai" description: "Cloudflare API URL" - name: cloudflare_api_key env: "CLOUDFLARE_API_KEY" @@ -546,7 +1344,7 @@ components: - name: cohere_api_url env: "COHERE_API_URL" type: string - default: "https://api.cohere.com" + default: "https://api.cohere.ai" description: "Cohere API URL" - name: cohere_api_key env: "COHERE_API_KEY" @@ -556,7 +1354,7 @@ components: - name: groq_api_url env: "GROQ_API_URL" type: string - default: 
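The `x-config` section above documents the gateway's environment variables and defaults. A small sketch of how a deployment script might read a few of them; the comma-separated format for `MCP_SERVERS` is an assumption, since the spec only calls it a list:

```python
import os

# Defaults mirror the x-config entries above.
mcp_enable = os.getenv("MCP_ENABLE", "false").lower() == "true"
mcp_expose = os.getenv("MCP_EXPOSE", "false").lower() == "true"
mcp_servers = [s for s in os.getenv("MCP_SERVERS", "").split(",") if s]  # assumed comma-separated
mcp_client_timeout = os.getenv("MCP_CLIENT_TIMEOUT", "5s")

openai_api_url = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1")
openai_api_key = os.getenv("OPENAI_API_KEY")  # secret, no default

print(mcp_enable, mcp_expose, mcp_servers, mcp_client_timeout, openai_api_url)
```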
"https://api.groq.com" + default: "https://api.groq.com/openai/v1" description: "Groq API URL" - name: groq_api_key env: "GROQ_API_KEY" @@ -566,7 +1364,7 @@ components: - name: ollama_api_url env: "OLLAMA_API_URL" type: string - default: "http://ollama:8080" + default: "http://ollama:8080/v1" description: "Ollama API URL" - name: ollama_api_key env: "OLLAMA_API_KEY" @@ -576,451 +1374,20 @@ components: - name: openai_api_url env: "OPENAI_API_URL" type: string - default: "https://api.openai.com" + default: "https://api.openai.com/v1" description: "OpenAI API URL" - name: openai_api_key env: "OPENAI_API_KEY" type: string description: "OpenAI API Key" secret: true - x-provider-configs: - ollama: - id: "ollama" - url: "http://ollama:8080" - auth_type: "none" - endpoints: - list: - endpoint: "/api/tags" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - name: - type: string - modified_at: - type: string - size: - type: integer - digest: - type: string - details: - type: object - properties: - format: - type: string - family: - type: string - families: - type: array - items: - type: string - parameter_size: - type: string - generate: - endpoint: "/api/generate" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - prompt: - type: string - stream: - type: boolean - system: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - provider: - type: string - response: - type: object - properties: - role: - type: string - model: - type: string - content: - type: string - openai: - id: "openai" - url: "https://api.openai.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - object: - type: string - data: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - owned_by: - type: string - permission: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - allow_create_engine: - type: boolean - allow_sampling: - type: boolean - allow_logprobs: - type: boolean - allow_search_indices: - type: boolean - allow_view: - type: boolean - allow_fine_tuning: - type: boolean - root: - type: string - parent: - type: string - generate: - endpoint: "/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string - groq: - id: "groq" - url: "https://api.groq.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/openai/v1/models" - method: "GET" - schema: - response: - type: object - properties: - object: - type: string - data: - type: array - items: - type: object - properties: - id: - type: string - object: - type: string - created: - type: integer - format: int64 - owned_by: - type: string - active: - type: boolean - context_window: - type: integer - public_apps: - type: object - generate: - endpoint: "/openai/v1/chat/completions" - method: 
"POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string - cloudflare: - id: "cloudflare" - url: "https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}" - auth_type: "bearer" - endpoints: - list: - endpoint: "/ai/finetunes/public" - method: "GET" - schema: - response: - type: object - properties: - result: - type: array - items: - type: object - properties: - id: - type: string - name: - type: string - description: - type: string - created_at: - type: string - modified_at: - type: string - public: - type: integer - model: - type: string - generate: - endpoint: "/v1/chat/completions" - method: "POST" - schema: - request: - type: object - properties: - prompt: - type: string - model: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - result: - type: object - properties: - response: - type: string - cohere: - id: "cohere" - url: "https://api.cohere.com" - auth_type: "bearer" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - name: - type: string - endpoints: - type: array - items: - type: string - finetuned: - type: boolean - context_length: - type: number - format: float64 - tokenizer_url: - type: string - default_endpoints: - type: array - items: - type: string - next_page_token: - type: string - generate: - endpoint: "/v2/chat" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: array - items: - type: object - properties: - type: - type: string - text: - type: string - anthropic: - id: "anthropic" - url: "https://api.anthropic.com" - auth_type: "xheader" - extra_headers: - anthropic-version: "2023-06-01" - endpoints: - list: - endpoint: "/v1/models" - method: "GET" - schema: - response: - type: object - properties: - models: - type: array - items: - type: object - properties: - type: - type: string - id: - type: string - display_name: - type: string - created_at: - type: string - has_more: - type: boolean - first_id: - type: string - last_id: - type: string - generate: - endpoint: "/v1/messages" - method: "POST" - schema: - request: - type: object - properties: - model: - type: string - messages: - type: array - items: - type: object - properties: - role: - type: string - content: - type: string - temperature: - type: number - format: float64 - default: 0.7 - response: - type: object - properties: - model: - type: string - choices: - type: array - items: - type: object - properties: - message: - type: object - properties: - role: - type: string - content: - type: string + - name: deepseek_api_url + env: "DEEPSEEK_API_URL" + type: string + default: "https://api.deepseek.com" + description: "DeepSeek API URL" + - name: deepseek_api_key + env: 
"DEEPSEEK_API_KEY" + type: string + description: "DeepSeek API Key" + secret: true diff --git a/pyproject.toml b/pyproject.toml index a786928..2adedaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,21 +1,55 @@ [project] name = "inference-gateway" -version = "0.3.0" -authors = [ - { name="Eden Reich", email="eden.reich@gmail.com" }, -] +version = "0.4.0" +authors = [{ name = "Eden Reich", email = "eden.reich@gmail.com" }] description = "A Python SDK for Inference Gateway" readme = "README.md" requires-python = ">=3.12" classifiers = [ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Operating System :: OS Independent", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Internet :: WWW/HTTP :: HTTP Servers", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ] license = { file = "LICENSE" } +dependencies = [ + "requests>=2.32.3", + "pydantic>=2.11.5", + "httpx>=0.28.1", + "typing-extensions>=4.13.2", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3.5", + "pytest-cov>=6.1.1", + "pytest-watch>=4.2.0", + "black>=25.1.0", + "isort>=6.0.1", + "mypy>=1.15.0", + "datamodel-code-generator>=0.30.1", + "pre-commit>=4.2.0", + "types-requests>=2.32.3", +] [project.urls] Homepage = "https://github.com/inference-gateway/python-sdk" Issues = "https://github.com/inference-gateway/python-sdk/issues" +Documentation = "https://inference-gateway.github.io/docs/" +Repository = "https://github.com/inference-gateway/python-sdk" + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["inference_gateway*"] [tool.black] line-length = 100 @@ -34,3 +68,39 @@ exclude = ''' | dist )/ ''' + +[tool.isort] +profile = "black" +line_length = 100 +multi_line_output = 3 + +[tool.mypy] +python_version = "3.12" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +plugins = ["pydantic.mypy"] + +[[tool.mypy.overrides]] +module = "tests.*" +disallow_untyped_defs = false +disable_error_code = ["import-not-found"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = "-v --tb=short --strict-config --strict-markers" + +[tool.coverage.run] +source = ["inference_gateway"] +omit = ["*/tests/*"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", +] diff --git a/requirements.txt b/requirements.txt index ee88e26..09b867c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,12 @@ -requests>=2.26.0 +requests>=2.32.3 +types-requests>=2.32.0.20250515 +pydantic>=2.11.5 +pytest>=8.3.5 +pytest-cov>=6.1.1 +pytest-watch>=4.2.0 +datamodel-code-generator>=0.30.1 +httpx>=0.28.1 +typing-extensions>=4.13.2 +black>=25.1.0 +isort>=6.0.1 +mypy>=1.15.0 diff --git a/templates/header.jinja2 b/templates/header.jinja2 new file mode 100644 index 0000000..90345f0 --- /dev/null +++ b/templates/header.jinja2 @@ -0,0 +1,35 @@ +# Generated by datamodel-codegen +# This file contains auto-generated Pydantic v2 models based on OpenAPI specification +# Source: https://github.com/inference-gateway/inference-gateway/blob/main/openapi.yaml +# Do not edit this file manually - it will be overwritten during code 
generation + +from __future__ import annotations + +from enum import Enum +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, Literal, Optional, Union, Dict, List +from pydantic import BaseModel, Field, ConfigDict + +# Provider enum for easy access and type safety +class Provider(str, Enum): + """Supported AI providers for the Inference Gateway.""" + OLLAMA = "ollama" + GROQ = "groq" + OPENAI = "openai" + CLOUDFLARE = "cloudflare" + COHERE = "cohere" + ANTHROPIC = "anthropic" + DEEPSEEK = "deepseek" + + def __str__(self) -> str: + return self.value + DEEPSEEK = "deepseek" + +# Message role enum +class MessageRole(str, Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + +{{ imports }} diff --git a/templates/pydantic/BaseModel.jinja2 b/templates/pydantic/BaseModel.jinja2 new file mode 100644 index 0000000..39117a9 --- /dev/null +++ b/templates/pydantic/BaseModel.jinja2 @@ -0,0 +1,13 @@ +class {{ class_name }}({{ base_class }}): + """{{ description or class_name + " model for Inference Gateway API." }}""" + + # Pydantic v2 configuration + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + use_enum_values=True, + populate_by_name=True, + ) +{%- for field in fields %} + {{ field.name }}: {{ field.type_hint }} +{%- endfor -%} diff --git a/tests/test_client.py b/tests/test_client.py index 2684421..706bc8e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,13 +1,24 @@ +from typing import Any, Dict, List +from unittest.mock import Mock, patch + import pytest import requests -from unittest.mock import Mock, patch + from inference_gateway.client import ( + InferenceGatewayAPIError, InferenceGatewayClient, - Provider, - Role, + InferenceGatewayError, + InferenceGatewayValidationError, +) +from inference_gateway.models import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, + ListModelsResponse, Message, - GenerateResponse, - ResponseTokens, + MessageRole, + Model, + Provider, + SSEvent, ) @@ -22,7 +33,21 @@ def mock_response(): """Create a mock response""" mock = Mock() mock.status_code = 200 - mock.json.return_value = {"response": "test"} + mock.json.return_value = { + "provider": "openai", + "object": "list", + "data": [ + { + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai", + } + ], + } + mock.headers = {"content-type": "application/json"} + mock.raise_for_status.return_value = None return mock @@ -31,10 +56,9 @@ def test_params(): """Fixture providing test parameters""" return { "api_url": "http://test-api", - "provider": Provider.OPENAI, + "provider": "openai", "model": "gpt-4", - "message": Message(Role.USER, "Hello"), - "endpoint": "/llms/openai/generate", + "message": Message(role="user", content="Hello"), } @@ -49,136 +73,183 @@ def test_client_initialization(): assert client_with_token.session.headers["Authorization"] == "Bearer test-token" -@patch("requests.Session.get") -def test_list_models(mock_get, client, mock_response): +@patch("requests.Session.request") +def test_list_models(mock_request, client, mock_response): """Test listing available models""" - mock_get.return_value = mock_response + mock_request.return_value = mock_response response = client.list_models() - mock_get.assert_called_once_with("http://test-api/llms") - assert response == {"response": "test"} + mock_request.assert_called_once_with( + "GET", "http://test-api/v1/models", params={}, timeout=30.0 + ) + assert isinstance(response, 
ListModelsResponse) + assert response.provider == "openai" + assert response.object == "list" + assert len(response.data) == 1 + assert response.data[0].id == "gpt-4" -@patch("requests.Session.get") -def test_list_provider_models(mock_get, client, mock_response): - """Test listing models for a specific provider""" +@patch("requests.Session.request") +def test_list_models_with_provider(mock_request, client): + """Test listing models with provider filter""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None mock_response.json.return_value = { "provider": "openai", - "models": [{"name": "gpt-4"}, {"name": "gpt-3.5-turbo"}], + "object": "list", + "data": [ + { + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai", + }, + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai", + }, + ], } - mock_get.return_value = mock_response + mock_request.return_value = mock_response - response = client.list_providers_models(Provider.OPENAI) + response = client.list_models("openai") - mock_get.assert_called_once_with("http://test-api/llms/openai") + mock_request.assert_called_once_with( + "GET", "http://test-api/v1/models", params={"provider": "openai"}, timeout=30.0 + ) + assert isinstance(response, ListModelsResponse) + assert response.provider == "openai" + assert response.object == "list" + assert len(response.data) == 2 + assert response.data[0].id == "gpt-4" + assert response.data[1].id == "gpt-3.5-turbo" - assert response == { - "provider": "openai", - "models": [{"name": "gpt-4"}, {"name": "gpt-3.5-turbo"}], - } +@patch("requests.Session.request") +def test_list_models_error(mock_request, client): + """Test error handling when listing models""" + mock_request.side_effect = requests.exceptions.HTTPError("Provider not found") -@patch("requests.Session.get") -def test_list_provider_models_error(mock_get, client): - """Test error handling when listing provider models""" - mock_get.side_effect = requests.exceptions.HTTPError("Provider not found") + with pytest.raises(InferenceGatewayError, match="Request failed"): + client.list_models("ollama") - with pytest.raises(requests.exceptions.HTTPError, match="Provider not found"): - client.list_providers_models(Provider.OLLAMA) + mock_request.assert_called_once_with( + "GET", "http://test-api/v1/models", params={"provider": "ollama"}, timeout=30.0 + ) - mock_get.assert_called_once_with("http://test-api/llms/ollama") +@patch("requests.Session.request") +def test_create_chat_completion(mock_request, client): + """Test chat completion""" + messages = [ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello!"), + ] -@patch("requests.Session.post") -def test_generate_content(mock_post, client, mock_response): - """Test content generation""" - messages = [Message(Role.SYSTEM, "You are a helpful assistant"), Message(Role.USER, "Hello!")] + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "Hello! 
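Outside of the mocked tests, the same calls against a running gateway look roughly like this; a sketch assuming a local gateway on port 8080 and the constructor and argument order exercised by the tests (token handling omitted):

```python
from inference_gateway.client import InferenceGatewayClient
from inference_gateway.models import Message

client = InferenceGatewayClient("http://localhost:8080")

if client.health_check():
    models = client.list_models("openai")
    print([m.id for m in models.data])

    response = client.create_chat_completion(
        "gpt-4",
        [Message(role="user", content="Hello!")],
        "openai",
    )
    print(response.choices[0].message.content)
```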
How can I help you today?"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30}, + } + mock_request.return_value = mock_response - mock_post.return_value = mock_response - response = client.generate_content(Provider.OPENAI, "gpt-4", messages) + response = client.create_chat_completion("gpt-4", messages, "openai") - mock_post.assert_called_once_with( - "http://test-api/llms/openai/generate", + mock_request.assert_called_once_with( + "POST", + "http://test-api/v1/chat/completions", + params={"provider": "openai"}, json={ "model": "gpt-4", "messages": [ {"role": "system", "content": "You are a helpful assistant"}, {"role": "user", "content": "Hello!"}, ], + "stream": False, }, + timeout=30.0, ) - assert response == {"response": "test"} + assert isinstance(response, CreateChatCompletionResponse) + assert response.id == "chatcmpl-123" -@patch("requests.Session.get") -def test_health_check(mock_get, client): +@patch("requests.Session.request") +def test_health_check(mock_request, client): """Test health check endpoint""" mock_response = Mock() mock_response.status_code = 200 - mock_get.return_value = mock_response + mock_response.raise_for_status.return_value = None + mock_request.return_value = mock_response assert client.health_check() is True - mock_get.assert_called_once_with("http://test-api/health") + mock_request.assert_called_once_with("GET", "http://test-api/health", timeout=30.0) - # Test unhealthy response mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Server error") assert client.health_check() is False -def test_message_to_dict(): - """Test Message class serialization""" - message = Message(Role.USER, "Hello!") - assert message.to_dict() == {"role": "user", "content": "Hello!"} +def test_message_model(): + """Test Message model creation and serialization""" + message = Message(role="user", content="Hello!") + assert message.role == "user" + assert message.content == "Hello!" + + message_dict = message.model_dump() + assert message_dict["role"] == "user" + assert message_dict["content"] == "Hello!" 
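+
+
+# Illustrative sketch, not part of the original change set: the streaming tests
+# below assert that messages are serialized with model_dump(exclude_none=True),
+# so unset optional fields should never appear in the request payload. This
+# hypothetical test only exercises that serialization behaviour.
+def test_message_model_dump_exclude_none():
+    """Sketch: exclude_none drops unset optional fields from the payload"""
+    message = Message(role="user", content="Hello!")
+    payload = message.model_dump(exclude_none=True)
+    assert payload["role"] == "user"
+    assert payload["content"] == "Hello!"
+    assert all(value is not None for value in payload.values())
+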
+ +def test_provider_values(): + """Test Provider values""" + provider = Provider("openai") + assert provider.root == "openai" -def test_provider_enum(): - """Test Provider enum values""" - assert Provider.OPENAI == "openai" - assert Provider.OLLAMA == "ollama" - assert Provider.GROQ == "groq" - assert Provider.CLOUDFLARE == "cloudflare" - assert Provider.COHERE == "cohere" + with pytest.raises(ValueError): + Provider("invalid_provider") -def test_role_enum(): - """Test Role enum values""" - assert Role.SYSTEM == "system" - assert Role.USER == "user" - assert Role.ASSISTANT == "assistant" +def test_message_role_values(): + """Test MessageRole values""" + role = MessageRole("system") + assert role.root == "system" + + role = MessageRole("user") + assert role.root == "user" + + role = MessageRole("assistant") + assert role.root == "assistant" + + with pytest.raises(ValueError): + MessageRole("invalid_role") @pytest.mark.parametrize("use_sse,expected_format", [(True, "sse"), (False, "json")]) -@patch("requests.Session.post") -def test_generate_content_stream(mock_post, client, use_sse, expected_format): - """Test streaming content generation with both raw JSON and SSE formats""" +@patch("requests.Session.request") +def test_create_chat_completion_stream(mock_request, client, use_sse, expected_format): + """Test streaming chat completion with both raw JSON and SSE formats""" mock_response = Mock() mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None if use_sse: - mock_response.raw = Mock() - mock_response.raw.read = ( - Mock( - side_effect=[ - b"event: message-start\n", - b'data: {"role":"assistant"}\n\n', - b"event: content-delta\n", - b'data: {"content":"Hello"}\n\n', - b"event: content-delta\n", - b'data: {"content":" world!"}\n\n', - b"event: message-end\n", - b'data: {"content":""}\n\n', - b"", - ] - ) - if use_sse - else Mock( - side_effect=[ - b'{"role":"assistant","model":"gpt-4","content":"Hello"}\n', - b'{"role":"assistant","model":"gpt-4","content":" world!"}\n', - b"", - ] - ) - ) mock_response.iter_lines.return_value = [ b"event: message-start", b'data: {"role":"assistant"}', @@ -195,81 +266,71 @@ def test_generate_content_stream(mock_post, client, use_sse, expected_format): ] else: mock_response.iter_lines.return_value = [ - b'{"role":"assistant","model":"gpt-4","content":"Hello"}', - b'{"role":"assistant","model":"gpt-4","content":" world!"}', + b'data: {"choices":[{"delta":{"role":"assistant"}}],"model":"gpt-4"}', + b'data: {"choices":[{"delta":{"content":"Hello"}}],"model":"gpt-4"}', + b'data: {"choices":[{"delta":{"content":" world!"}}],"model":"gpt-4"}', + b"data: [DONE]", ] - mock_post.return_value = mock_response + mock_request.return_value = mock_response - messages = [Message(Role.USER, "What's up?")] + messages = [Message(role="user", content="What's up?")] chunks = list( - client.generate_content_stream( - provider=Provider.OPENAI, model="gpt-4", messages=messages, use_sse=use_sse + client.create_chat_completion_stream( + model="gpt-4", messages=messages, provider="openai", use_sse=use_sse ) ) - mock_post.assert_called_once_with( - "http://test-api/llms/openai/generate", + mock_request.assert_called_once_with( + "POST", + "http://test-api/v1/chat/completions", + data=None, json={ "model": "gpt-4", "messages": [{"role": "user", "content": "What's up?"}], "stream": True, - "ssevents": use_sse, }, + params={"provider": "openai"}, stream=True, ) if expected_format == "sse": assert len(chunks) == 4 - assert chunks[0] == {"event": 
"message-start", "data": {"role": "assistant"}} - assert chunks[1] == {"event": "content-delta", "data": {"content": "Hello"}} - assert chunks[2] == {"event": "content-delta", "data": {"content": " world!"}} - assert chunks[3] == {"event": "message-end", "data": {"content": ""}} + assert chunks[0].event == "message-start" + assert chunks[0].data == '{"role":"assistant"}' + assert chunks[1].event == "content-delta" + assert chunks[1].data == '{"content":"Hello"}' + assert chunks[2].event == "content-delta" + assert chunks[2].data == '{"content":" world!"}' + assert chunks[3].event == "message-end" + assert chunks[3].data == '{"content":""}' else: - assert len(chunks) == 2 - assert isinstance(chunks[0], ResponseTokens) - assert isinstance(chunks[1], ResponseTokens) - assert chunks[0].role == "assistant" - assert chunks[0].model == "gpt-4" - assert chunks[0].content == "Hello" - assert chunks[1].content == " world!" - - for chunk in chunks: - if use_sse: - assert isinstance(chunk, dict) - assert "event" in chunk - else: - assert isinstance(chunk, ResponseTokens) + assert len(chunks) == 3 + assert "choices" in chunks[0] + assert "delta" in chunks[0]["choices"][0] + assert chunks[0]["choices"][0]["delta"]["role"] == "assistant" + assert chunks[1]["choices"][0]["delta"]["content"] == "Hello" + assert chunks[2]["choices"][0]["delta"]["content"] == " world!" @pytest.mark.parametrize( "error_scenario", [ - {"status_code": 500, "error": Exception("API Error"), "expected_match": "API Error"}, + {"status_code": 500, "error": Exception("API Error"), "expected_match": "Request failed"}, { "status_code": 401, "error": requests.exceptions.HTTPError("Unauthorized"), - "expected_match": "Unauthorized", + "expected_match": "Request failed", }, { "status_code": 400, "error": requests.exceptions.HTTPError("Invalid model"), - "expected_match": "Invalid model", - }, - { - "status_code": 200, - "iter_lines": [b'{"invalid": "json'], - "expected_match": r"Invalid JSON response: \{\"invalid\": \"json.*column \d+.*char \d+", - }, - { - "status_code": 200, - "iter_lines": [b"{}"], - "expected_match": r"Missing required arguments: role, model, content", + "expected_match": "Request failed", }, ], ) -@patch("requests.Session.post") -def test_generate_content_stream_error(mock_post, client, test_params, error_scenario): +@patch("requests.Session.request") +def test_create_chat_completion_stream_error(mock_request, client, test_params, error_scenario): """Test error handling during streaming for various scenarios""" mock_response = Mock() mock_response.status_code = error_scenario["status_code"] @@ -280,25 +341,221 @@ def test_generate_content_stream_error(mock_post, client, test_params, error_sce if "iter_lines" in error_scenario: mock_response.iter_lines.return_value = error_scenario["iter_lines"] - mock_post.return_value = mock_response + mock_request.return_value = mock_response use_sse = error_scenario.get("use_sse", False) - with pytest.raises(Exception, match=error_scenario["expected_match"]): + with pytest.raises(InferenceGatewayError, match=error_scenario["expected_match"]): list( - client.generate_content_stream( - provider=test_params["provider"], + client.create_chat_completion_stream( model=test_params["model"], messages=[test_params["message"]], + provider=test_params["provider"], use_sse=use_sse, ) ) - expected_url = f"{test_params['api_url']}{test_params['endpoint']}" - expected_payload = { - "model": test_params["model"], - "messages": [test_params["message"].to_dict()], - "stream": True, - "ssevents": 
use_sse, + mock_request.assert_called_once_with( + "POST", + "http://test-api/v1/chat/completions", + data=None, + json={ + "model": test_params["model"], + "messages": [test_params["message"].model_dump(exclude_none=True)], + "stream": True, + }, + params={"provider": test_params["provider"]}, + stream=True, + ) + + +@patch("requests.Session.request") +def test_proxy_request(mock_request, client): + """Test proxy request to provider""" + + mock_resp = Mock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"response": "test"} + mock_resp.raise_for_status.return_value = None + mock_request.return_value = mock_resp + + response = client.proxy_request( + provider="openai", path="completions", method="POST", json_data={"prompt": "Hello"} + ) + + mock_request.assert_called_once_with( + "POST", "http://test-api/proxy/openai/completions", json={"prompt": "Hello"}, timeout=30.0 + ) + + assert response == {"response": "test"} + + +def test_exception_hierarchy(): + """Test exception hierarchy and error handling""" + + base_error = InferenceGatewayError("Base error") + assert str(base_error) == "Base error" + assert isinstance(base_error, Exception) + + api_error = InferenceGatewayAPIError("API error", status_code=400) + assert str(api_error) == "API error" + assert api_error.status_code == 400 + assert isinstance(api_error, InferenceGatewayError) + + validation_error = InferenceGatewayValidationError("Validation error") + assert str(validation_error) == "Validation error" + assert isinstance(validation_error, InferenceGatewayError) + + +def test_context_manager(): + """Test client as context manager""" + with InferenceGatewayClient("http://test-api") as client: + assert client.base_url == "http://test-api" + assert client.session is not None + + +@patch("requests.Session.request") +def test_client_with_custom_timeout(mock_request): + """Test client with custom timeout settings""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + "provider": "openai", + "object": "list", + "data": [ + { + "id": "gpt-4", + "object": "model", + "created": 1687882410, + "owned_by": "openai", + "served_by": "openai", + } + ], + } + mock_request.return_value = mock_response + + client = InferenceGatewayClient("http://test-api", timeout=30) + client.list_models() + + mock_request.assert_called_once_with("GET", "http://test-api/v1/models", params={}, timeout=30) + + +def test_sse_event_parsing(): + """Test SSEvent model parsing""" + event = SSEvent(event="content-delta", data='{"content": "Hello"}', retry=None) + assert event.event == "content-delta" + assert event.data == '{"content": "Hello"}' + + event_dict = event.model_dump() + assert event_dict["event"] == "content-delta" + assert event_dict["data"] == '{"content": "Hello"}' + + +def test_list_tools(): + """Test listing MCP tools""" + mock_response_data = { + "object": "list", + "data": [ + { + "name": "read_file", + "description": "Read content from a file", + "server": "http://mcp-filesystem-server:8083/mcp", + "input_schema": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "Path to the file to read"} + }, + "required": ["path"], + }, + }, + { + "name": "write_file", + "description": "Write content to a file", + "server": "http://mcp-filesystem-server:8083/mcp", + "input_schema": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "Path to the file to write"}, + "content": {"type": 
"string", "description": "Content to write"}, + }, + "required": ["path", "content"], + }, + }, + ], } - mock_post.assert_called_once_with(expected_url, json=expected_payload, stream=True) + with patch("requests.Session.request") as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status.return_value = None + mock_request.return_value = mock_response + + client = InferenceGatewayClient("http://test-api/v1", "test-token") + response = client.list_tools() + + assert response.object == "list" + assert len(response.data) == 2 + + first_tool = response.data[0] + assert first_tool.name == "read_file" + assert first_tool.description == "Read content from a file" + assert first_tool.server == "http://mcp-filesystem-server:8083/mcp" + assert "path" in first_tool.input_schema["properties"] + + second_tool = response.data[1] + assert second_tool.name == "write_file" + assert second_tool.description == "Write content to a file" + assert second_tool.server == "http://mcp-filesystem-server:8083/mcp" + assert "path" in second_tool.input_schema["properties"] + assert "content" in second_tool.input_schema["properties"] + + mock_request.assert_called_once_with( + "GET", + "http://test-api/v1/mcp/tools", + timeout=30.0, + ) + + +def test_list_tools_error(): + """Test list_tools method with API error""" + with patch("requests.Session.request") as mock_request: + mock_response = Mock() + mock_response.status_code = 403 + mock_response.json.return_value = {"error": "MCP not exposed"} + mock_response.raise_for_status.side_effect = requests.HTTPError( + "403 Client Error: Forbidden" + ) + mock_request.return_value = mock_response + + client = InferenceGatewayClient("http://test-api/v1", "test-token") + + with pytest.raises(InferenceGatewayAPIError) as excinfo: + client.list_tools() + + assert "Request failed" in str(excinfo.value) + + mock_request.assert_called_once_with( + "GET", + "http://test-api/v1/mcp/tools", + timeout=30.0, + ) + + +def test_list_tools_validation_error(): + """Test list_tools method with validation error""" + invalid_response_data = {"object": "invalid", "data": "not an array"} + + with patch("requests.Session.request") as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = invalid_response_data + mock_response.raise_for_status.return_value = None + mock_request.return_value = mock_response + + client = InferenceGatewayClient("http://test-api/v1", "test-token") + + with pytest.raises(InferenceGatewayValidationError) as excinfo: + client.list_tools() + + assert "Response validation failed" in str(excinfo.value)