Skip to content

Commit 0166d9e

Browse files
committed
feat(client): enhance InferenceGatewayClient with support for multiple providers and message handling
Also add health check method. docs(README): update documentation to include health check and new client usage examples test: add unit tests for client initialization, health check, and message serialization Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent ab5202b commit 0166d9e

4 files changed

Lines changed: 208 additions & 42 deletions

File tree

README.md

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ An SDK written in Python for the [Inference Gateway](https://github.com/edenreic
88
- [Creating a Client](#creating-a-client)
99
- [Listing Models](#listing-models)
1010
- [Generating Content](#generating-content)
11+
- [Health Check](#health-check)
1112
- [License](#license)
1213

1314
## Installation
@@ -21,17 +22,12 @@ pip install inference-gateway
2122
### Creating a Client
2223

2324
```python
24-
from inference_gateway.client import InferenceGatewayClient
25+
from inference_gateway.client import InferenceGatewayClient, Provider
2526

27+
client = InferenceGatewayClient("http://localhost:8080")
2628

27-
if __name__ == "__main__":
28-
client = InferenceGatewayClient("http://localhost:8080")
29-
30-
models = client.list_models()
31-
print("Available models:", models)
32-
33-
response = client.generate_content("providerName", "modelName", "your prompt here")
34-
print("Generated content:", response["Response"]["Content"])
29+
# With an authentication token (optional)
30+
client = InferenceGatewayClient("http://localhost:8080", token="your-token")
3531
```
3632

3733
### Listing Models
@@ -48,8 +44,28 @@ print("Available models:", models)
4844
To generate content using a model, use the generate_content method:
4945

5046
```python
51-
response = client.generate_content("providerName", "modelName", "your prompt here")
52-
print("Generated content:", response["Response"]["Content"])
47+
from inference_gateway.client import Provider, Role, Message
48+
49+
messages = [
50+
Message(Role.SYSTEM, "You are a helpful assistant"),
51+
Message(Role.USER, "Hello!"),
52+
]
53+
54+
response = client.generate_content(
55+
provider=Provider.OPENAI,
56+
model="gpt-4",
57+
messages=messages
58+
)
59+
print("Assistant:", response["choices"][0]["message"]["content"])
60+
```
61+
62+
### Health Check
63+
64+
To check the health of the API, use the health_check method:
65+
66+
```python
67+
is_healthy = client.health_check()
68+
print("API Status:", "Healthy" if is_healthy else "Unhealthy")
5369
```
5470

5571
## License

Taskfile.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ tasks:
77
cmds:
88
- curl -o openapi.yaml https://raw.githubusercontent.com/inference-gateway/inference-gateway/refs/heads/main/openapi.yaml
99

10+
test:
11+
desc: Run tests
12+
cmds:
13+
- pytest tests/
14+
1015
clean:
1116
desc: Clean up
1217
cmds:

inference_gateway/client.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,92 @@
1+
from dataclasses import dataclass
2+
from enum import Enum
3+
from typing import List, Dict, Optional
14
import requests
25

36

7+
class Provider(str, Enum):
    """Identifiers of the LLM providers the client can address.

    The ``str`` mixin makes each member compare equal to its plain string
    value, e.g. ``Provider.OPENAI == "openai"``.
    """

    OLLAMA = "ollama"
    GROQ = "groq"
    OPENAI = "openai"
    GOOGLE = "google"
    CLOUDFLARE = "cloudflare"
    COHERE = "cohere"
15+
16+
17+
class Role(str, Enum):
    """Who authored a chat message.

    The ``str`` mixin makes each member compare equal to its plain string
    value, e.g. ``Role.USER == "user"``.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
22+
23+
24+
@dataclass
class Message:
    """A single chat message: who is speaking and what was said.

    ``role`` is normally a ``Role`` member.  Dataclasses do not enforce
    annotations, so a plain string such as ``"user"`` can also end up here;
    ``to_dict`` handles both.
    """

    role: Role
    content: str

    def to_dict(self) -> Dict[str, str]:
        """Convert message to dictionary format with string values.

        Returns:
            ``{"role": <role string>, "content": <content>}``.  An enum role
            is unwrapped to its ``.value``; a role given as a plain string is
            passed through instead of raising ``AttributeError``.
        """
        if isinstance(self.role, Enum):
            role = self.role.value
        else:
            role = str(self.role)
        return {"role": role, "content": self.content}
35+
36+
37+
class Model:
    """Represents an LLM model.

    A plain value holder; the four constructor arguments are stored
    unchanged as attributes of the same name.
    """

    def __init__(self, id: str, object: str, owned_by: str, created: int):
        # ``id`` and ``object`` shadow builtins, but they are kept because
        # renaming them would break callers that pass keyword arguments.
        self.id = id
        self.object = object
        self.owned_by = owned_by
        self.created = created

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging."""
        return (
            f"{type(self).__name__}(id={self.id!r}, object={self.object!r}, "
            f"owned_by={self.owned_by!r}, created={self.created!r})"
        )
45+
46+
47+
class ProviderModels:
    """Groups models by provider.

    Holds one provider together with the list of models given for it; both
    arguments are stored unchanged.
    """

    def __init__(self, provider: Provider, models: List[Model]):
        self.provider = provider
        self.models = models

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging."""
        return (
            f"{type(self).__name__}"
            f"(provider={self.provider!r}, models={self.models!r})"
        )
53+
54+
455
class InferenceGatewayClient:
    """Client for interacting with the Inference Gateway API.

    All requests go through a single ``requests.Session`` so that the
    optional bearer token is sent with every call.  The client can be used
    as a context manager so the session is closed when the block exits::

        with InferenceGatewayClient("http://localhost:8080") as client:
            client.health_check()
    """

    def __init__(self, base_url: str, token: Optional[str] = None):
        """Initialize the client with base URL and optional auth token.

        Args:
            base_url: Root URL of the gateway.  Trailing slashes are
                stripped so endpoint paths can be appended with one ``/``.
            token: When given and non-empty, sent on every request as
                ``Authorization: Bearer <token>``.
        """
        self.base_url = base_url.rstrip('/')
        self.session = requests.Session()
        if token:
            self.session.headers.update({"Authorization": f"Bearer {token}"})

    def __enter__(self) -> "InferenceGatewayClient":
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session on leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def list_models(self) -> List[ProviderModels]:
        """List all available language models.

        Sends ``GET {base_url}/llms``.

        NOTE(review): the decoded JSON body is returned as-is; it is not
        converted into ``ProviderModels`` instances despite the annotation.

        Raises:
            requests.HTTPError: If the gateway answers with an error status.
        """
        response = self.session.get(f"{self.base_url}/llms")
        response.raise_for_status()
        return response.json()

    def generate_content(
        self,
        provider: Provider,
        model: str,
        messages: List[Message]
    ) -> Dict:
        """Generate content from a list of chat messages.

        Sends ``POST {base_url}/llms/{provider}/generate`` with a JSON body
        of the form ``{"model": ..., "messages": [...]}``.

        Args:
            provider: A ``Provider`` member, or the provider name as a plain
                string.
            model: Name of the model to use.
            messages: Conversation so far; each item is serialized with its
                ``to_dict()`` method.

        Returns:
            The decoded JSON body of the response.

        Raises:
            requests.HTTPError: If the gateway answers with an error status.
        """
        # Accept plain strings as well as enum members so that callers which
        # pass the provider name directly do not fail on ``.value``.
        if isinstance(provider, Enum):
            provider_name = provider.value
        else:
            provider_name = str(provider)

        payload = {
            "model": model,
            "messages": [msg.to_dict() for msg in messages]
        }

        response = self.session.post(
            f"{self.base_url}/llms/{provider_name}/generate",
            json=payload
        )
        response.raise_for_status()
        return response.json()

    def health_check(self) -> bool:
        """Check if the API is healthy.

        Returns:
            ``True`` only when ``GET {base_url}/health`` answers with status
            200.  Any other status, or a failure to complete the request at
            all, yields ``False``.
        """
        try:
            response = self.session.get(f"{self.base_url}/health")
        except requests.RequestException:
            # An unreachable gateway is unhealthy, not an error for the
            # caller of a boolean health probe.
            return False
        return response.status_code == 200

tests/test_client.py

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,102 @@
1-
import unittest
2-
from unittest.mock import patch, Mock
3-
from inference_gateway.client import InferenceGatewayClient
1+
import pytest
2+
from unittest.mock import Mock, patch
3+
from inference_gateway.client import InferenceGatewayClient, Provider, Role, Message
44

55

6-
class TestInferenceGatewayClient(unittest.TestCase):
7-
def setUp(self):
8-
self.client = InferenceGatewayClient("http://localhost:8080")
6+
@pytest.fixture
def client():
    """Provide an InferenceGatewayClient pointed at a placeholder base URL."""
    base_url = "http://test-api"
    return InferenceGatewayClient(base_url)
910

10-
@patch("inference_gateway.client.requests.get")
11-
def test_list_models(self, mock_get):
12-
mock_response = Mock()
13-
mock_response.json.return_value = {"models": ["model1", "model2"]}
14-
mock_response.raise_for_status = Mock()
15-
mock_get.return_value = mock_response
1611

17-
models = self.client.list_models()
18-
self.assertEqual(models, {"models": ["model1", "model2"]})
12+
@pytest.fixture
def mock_response():
    """Provide a stand-in HTTP response: status 200 with a fixed JSON body."""
    response = Mock(status_code=200)
    response.json.return_value = {"response": "test"}
    return response
1919

20-
@patch("inference_gateway.client.requests.post")
21-
def test_generate_content(self, mock_post):
22-
mock_response = Mock()
23-
mock_response.json.return_value = {"Response": {"Content": "generated content"}}
24-
mock_response.raise_for_status = Mock()
25-
mock_post.return_value = mock_response
2620

27-
response = self.client.generate_content("provider", "model", "prompt")
28-
self.assertEqual(response, {"Response": {"Content": "generated content"}})
21+
def test_client_initialization():
    """The Authorization header is present only when a token is supplied."""
    base_url = "http://test-api"

    # Without a token: URL stored, no auth header on the session.
    anonymous = InferenceGatewayClient(base_url)
    assert anonymous.base_url == base_url
    assert "Authorization" not in anonymous.session.headers

    # With a token: the session carries a bearer Authorization header.
    authenticated = InferenceGatewayClient(base_url, token="test-token")
    headers = authenticated.session.headers
    assert "Authorization" in headers
    assert headers["Authorization"] == "Bearer test-token"
3031

31-
if __name__ == "__main__":
32-
unittest.main()
32+
33+
@patch("requests.Session.get")
def test_list_models(mock_get, client, mock_response):
    """list_models GETs /llms exactly once and returns the JSON body."""
    mock_get.return_value = mock_response

    models = client.list_models()

    assert models == {"response": "test"}
    mock_get.assert_called_once_with("http://test-api/llms")
41+
42+
43+
@patch("requests.Session.post")
def test_generate_content(mock_post, client, mock_response):
    """generate_content POSTs the serialized messages to the provider URL."""
    mock_post.return_value = mock_response
    conversation = [
        Message(Role.SYSTEM, "You are a helpful assistant"),
        Message(Role.USER, "Hello!"),
    ]
    expected_payload = {
        "model": "gpt-4",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello!"},
        ],
    }

    result = client.generate_content(Provider.OPENAI, "gpt-4", conversation)

    assert result == {"response": "test"}
    mock_post.assert_called_once_with(
        "http://test-api/llms/openai/generate",
        json=expected_payload,
    )
65+
66+
67+
@patch("requests.Session.get")
def test_health_check(mock_get, client):
    """health_check is True for status 200 and False for status 500."""
    response = Mock(status_code=200)
    mock_get.return_value = response

    # Healthy: a 200 from /health.
    assert client.health_check() is True
    mock_get.assert_called_once_with("http://test-api/health")

    # Unhealthy: the same mocked response now reports a server error.
    response.status_code = 500
    assert client.health_check() is False
80+
81+
82+
def test_message_to_dict():
    """Message.to_dict yields the role's string value and the content."""
    expected = {"role": "user", "content": "Hello!"}
    serialized = Message(Role.USER, "Hello!").to_dict()
    assert serialized == expected
86+
87+
88+
def test_provider_enum():
    """Every Provider member compares equal to its lowercase string value."""
    expected_values = [
        (Provider.OPENAI, "openai"),
        (Provider.OLLAMA, "ollama"),
        (Provider.GROQ, "groq"),
        (Provider.GOOGLE, "google"),
        (Provider.CLOUDFLARE, "cloudflare"),
        (Provider.COHERE, "cohere"),
    ]
    for member, value in expected_values:
        assert member == value
96+
97+
98+
def test_role_enum():
    """Every Role member compares equal to its lowercase string value."""
    expected_values = [
        (Role.SYSTEM, "system"),
        (Role.USER, "user"),
        (Role.ASSISTANT, "assistant"),
    ]
    for member, value in expected_values:
        assert member == value

0 commit comments

Comments
 (0)