 
 class Provider(str, Enum):
     """Supported LLM providers"""
+
     OLLAMA = "ollama"
     GROQ = "groq"
     OPENAI = "openai"
@@ -16,6 +17,7 @@ class Provider(str, Enum):
 
 class Role(str, Enum):
     """Message role types"""
+
     SYSTEM = "system"
     USER = "user"
     ASSISTANT = "assistant"
@@ -28,10 +30,7 @@ class Message:
 
     def to_dict(self) -> Dict[str, str]:
         """Convert message to dictionary format with string values"""
-        return {
-            "role": self.role.value,
-            "content": self.content
-        }
+        return {"role": self.role.value, "content": self.content}
 
 
 class Model:
@@ -57,7 +56,7 @@ class InferenceGatewayClient:
 
     def __init__(self, base_url: str, token: Optional[str] = None):
         """Initialize the client with base URL and optional auth token"""
-        self.base_url = base_url.rstrip('/')
+        self.base_url = base_url.rstrip("/")
         self.session = requests.Session()
         if token:
             self.session.headers.update({"Authorization": f"Bearer {token}"})
@@ -68,20 +67,11 @@ def list_models(self) -> List[ProviderModels]:
         response.raise_for_status()
         return response.json()
 
-    def generate_content(
-        self,
-        provider: Provider,
-        model: str,
-        messages: List[Message]
-    ) -> Dict:
-        payload = {
-            "model": model,
-            "messages": [msg.to_dict() for msg in messages]
-        }
+    def generate_content(self, provider: Provider, model: str, messages: List[Message]) -> Dict:
+        payload = {"model": model, "messages": [msg.to_dict() for msg in messages]}
 
         response = self.session.post(
-            f"{self.base_url}/llms/{provider.value}/generate",
-            json=payload
+            f"{self.base_url}/llms/{provider.value}/generate", json=payload
         )
         response.raise_for_status()
         return response.json()
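
For context, a brief usage sketch of the client touched by this diff. The package import path, the Message constructor signature, the gateway URL, and the model name are illustrative assumptions and are not defined in the diff itself.

# Usage sketch (assumed import path and Message(role, content) constructor;
# the gateway URL and model name below are placeholders).
from inference_gateway import InferenceGatewayClient, Message, Provider, Role

client = InferenceGatewayClient("http://localhost:8080")
messages = [
    Message(Role.SYSTEM, "You are a helpful assistant"),
    Message(Role.USER, "Hello!"),
]
# Sends {"model": ..., "messages": [...]} to {base_url}/llms/ollama/generate
response = client.generate_content(Provider.OLLAMA, "llama3.2", messages)
print(response)

As shown in the diff, both list_models and generate_content call raise_for_status(), so HTTP errors surface as requests exceptions rather than being returned in the response body.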