Skip to content

Commit 642ed74

Browse files
committed
chore(lint): add linting task using Black for code formatting
Signed-off-by: Eden Reich <[email protected]>
1 parent 0166d9e commit 642ed74

File tree

3 files changed

+17
-26
lines changed

3 files changed

+17
-26
lines changed

Taskfile.yml

+5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ tasks:
77
cmds:
88
- curl -o openapi.yaml https://raw.githubusercontent.com/inference-gateway/inference-gateway/refs/heads/main/openapi.yaml
99

10+
lint:
11+
desc: Lint the code
12+
cmds:
13+
- black inference_gateway/ tests/
14+
1015
test:
1116
desc: Run tests
1217
cmds:

inference_gateway/client.py

+7-17
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
class Provider(str, Enum):
88
"""Supported LLM providers"""
9+
910
OLLAMA = "ollama"
1011
GROQ = "groq"
1112
OPENAI = "openai"
@@ -16,6 +17,7 @@ class Provider(str, Enum):
1617

1718
class Role(str, Enum):
1819
"""Message role types"""
20+
1921
SYSTEM = "system"
2022
USER = "user"
2123
ASSISTANT = "assistant"
@@ -28,10 +30,7 @@ class Message:
2830

2931
def to_dict(self) -> Dict[str, str]:
3032
"""Convert message to dictionary format with string values"""
31-
return {
32-
"role": self.role.value,
33-
"content": self.content
34-
}
33+
return {"role": self.role.value, "content": self.content}
3534

3635

3736
class Model:
@@ -57,7 +56,7 @@ class InferenceGatewayClient:
5756

5857
def __init__(self, base_url: str, token: Optional[str] = None):
5958
"""Initialize the client with base URL and optional auth token"""
60-
self.base_url = base_url.rstrip('/')
59+
self.base_url = base_url.rstrip("/")
6160
self.session = requests.Session()
6261
if token:
6362
self.session.headers.update({"Authorization": f"Bearer {token}"})
@@ -68,20 +67,11 @@ def list_models(self) -> List[ProviderModels]:
6867
response.raise_for_status()
6968
return response.json()
7069

71-
def generate_content(
72-
self,
73-
provider: Provider,
74-
model: str,
75-
messages: List[Message]
76-
) -> Dict:
77-
payload = {
78-
"model": model,
79-
"messages": [msg.to_dict() for msg in messages]
80-
}
70+
def generate_content(self, provider: Provider, model: str, messages: List[Message]) -> Dict:
71+
payload = {"model": model, "messages": [msg.to_dict() for msg in messages]}
8172

8273
response = self.session.post(
83-
f"{self.base_url}/llms/{provider.value}/generate",
84-
json=payload
74+
f"{self.base_url}/llms/{provider.value}/generate", json=payload
8575
)
8676
response.raise_for_status()
8777
return response.json()

tests/test_client.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ def test_client_initialization():
2424
assert client.base_url == "http://test-api"
2525
assert "Authorization" not in client.session.headers
2626

27-
client_with_token = InferenceGatewayClient(
28-
"http://test-api", token="test-token")
27+
client_with_token = InferenceGatewayClient("http://test-api", token="test-token")
2928
assert "Authorization" in client_with_token.session.headers
3029
assert client_with_token.session.headers["Authorization"] == "Bearer test-token"
3130

@@ -43,10 +42,7 @@ def test_list_models(mock_get, client, mock_response):
4342
@patch("requests.Session.post")
4443
def test_generate_content(mock_post, client, mock_response):
4544
"""Test content generation"""
46-
messages = [
47-
Message(Role.SYSTEM, "You are a helpful assistant"),
48-
Message(Role.USER, "Hello!")
49-
]
45+
messages = [Message(Role.SYSTEM, "You are a helpful assistant"), Message(Role.USER, "Hello!")]
5046

5147
mock_post.return_value = mock_response
5248
response = client.generate_content(Provider.OPENAI, "gpt-4", messages)
@@ -57,9 +53,9 @@ def test_generate_content(mock_post, client, mock_response):
5753
"model": "gpt-4",
5854
"messages": [
5955
{"role": "system", "content": "You are a helpful assistant"},
60-
{"role": "user", "content": "Hello!"}
61-
]
62-
}
56+
{"role": "user", "content": "Hello!"},
57+
],
58+
},
6359
)
6460
assert response == {"response": "test"}
6561

0 commit comments

Comments
 (0)