Skip to content

Commit 1f7e87c

Browse files
committed
feat: update Cerebras inference provider to support dynamic model listing
- update Cerebras to use OpenAIMixin
- enable openai completions tests
- enable openai chat completions tests
- disable tests with n > 1
- add recording for --setup cerebras --subdirs inference --pattern openai

test with: `./scripts/integration-tests.sh --stack-config server:ci-tests --setup cerebras --subdirs inference --pattern openai`
1 parent 521865c commit 1f7e87c

File tree

16 files changed

+3369
-14
lines changed

16 files changed

+3369
-14
lines changed

docs/source/providers/inference/remote_cerebras.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
99
| Field | Type | Required | Default | Description |
1010
|-------|------|----------|---------|-------------|
1111
| `base_url` | `<class 'str'>` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
12-
| `api_key` | `pydantic.types.SecretStr \| None` | No | | Cerebras API Key |
12+
| `api_key` | `<class 'pydantic.types.SecretStr'>` | No | | Cerebras API Key |
1313

1414
## Sample Configuration
1515

llama_stack/providers/remote/inference/cerebras/cerebras.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# the root directory of this source tree.
66

77
from collections.abc import AsyncGenerator
8+
from urllib.parse import urljoin
89

910
from cerebras.cloud.sdk import AsyncCerebras
1011

@@ -35,14 +36,13 @@
3536
ModelRegistryHelper,
3637
)
3738
from llama_stack.providers.utils.inference.openai_compat import (
38-
OpenAIChatCompletionToLlamaStackMixin,
39-
OpenAICompletionToLlamaStackMixin,
4039
get_sampling_options,
4140
process_chat_completion_response,
4241
process_chat_completion_stream_response,
4342
process_completion_response,
4443
process_completion_stream_response,
4544
)
45+
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
4646
from llama_stack.providers.utils.inference.prompt_adapter import (
4747
chat_completion_request_to_prompt,
4848
completion_request_to_prompt,
@@ -53,10 +53,9 @@
5353

5454

5555
class CerebrasInferenceAdapter(
56+
OpenAIMixin,
5657
ModelRegistryHelper,
5758
Inference,
58-
OpenAIChatCompletionToLlamaStackMixin,
59-
OpenAICompletionToLlamaStackMixin,
6059
):
6160
def __init__(self, config: CerebrasImplConfig) -> None:
6261
ModelRegistryHelper.__init__(
@@ -66,11 +65,17 @@ def __init__(self, config: CerebrasImplConfig) -> None:
6665
self.config = config
6766

6867
# TODO: make this use provider data, etc. like other providers
69-
self.client = AsyncCerebras(
68+
self._cerebras_client = AsyncCerebras(
7069
base_url=self.config.base_url,
7170
api_key=self.config.api_key.get_secret_value(),
7271
)
7372

73+
def get_api_key(self) -> str:
74+
return self.config.api_key.get_secret_value()
75+
76+
def get_base_url(self) -> str:
77+
return urljoin(self.config.base_url, "v1")
78+
7479
async def initialize(self) -> None:
7580
return
7681

@@ -107,14 +112,14 @@ async def completion(
107112
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
108113
params = await self._get_params(request)
109114

110-
r = await self.client.completions.create(**params)
115+
r = await self._cerebras_client.completions.create(**params)
111116

112117
return process_completion_response(r)
113118

114119
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
115120
params = await self._get_params(request)
116121

117-
stream = await self.client.completions.create(**params)
122+
stream = await self._cerebras_client.completions.create(**params)
118123

119124
async for chunk in process_completion_stream_response(stream):
120125
yield chunk
@@ -156,14 +161,14 @@ async def chat_completion(
156161
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
157162
params = await self._get_params(request)
158163

159-
r = await self.client.completions.create(**params)
164+
r = await self._cerebras_client.completions.create(**params)
160165

161166
return process_chat_completion_response(r, request)
162167

163168
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
164169
params = await self._get_params(request)
165170

166-
stream = await self.client.completions.create(**params)
171+
stream = await self._cerebras_client.completions.create(**params)
167172

168173
async for chunk in process_chat_completion_stream_response(stream, request):
169174
yield chunk

llama_stack/providers/remote/inference/cerebras/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ class CerebrasImplConfig(BaseModel):
2020
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
2121
description="Base URL for the Cerebras API",
2222
)
23-
api_key: SecretStr | None = Field(
24-
default=os.environ.get("CEREBRAS_API_KEY"),
23+
api_key: SecretStr = Field(
24+
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
2525
description="Cerebras API Key",
2626
)
2727

tests/integration/inference/test_openai_completion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
4040
"inline::sentence-transformers",
4141
"inline::vllm",
4242
"remote::bedrock",
43-
"remote::cerebras",
4443
"remote::databricks",
4544
# Technically Nvidia does support OpenAI completions, but none of their hosted models
4645
# support both completions and chat completions endpoint and all the Llama models are
@@ -98,6 +97,8 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
9897
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
9998
"remote::tgi", # TGI ignores n param silently
10099
"remote::together", # `n` > 1 is not supported when streaming tokens. Please disable `stream`
100+
# Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'}
101+
"remote::cerebras",
101102
):
102103
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
103104

@@ -109,7 +110,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
109110
"inline::sentence-transformers",
110111
"inline::vllm",
111112
"remote::bedrock",
112-
"remote::cerebras",
113113
"remote::databricks",
114114
"remote::runpod",
115115
"remote::watsonx", # watsonx returns 404 when hitting the /openai/v1 endpoint
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
{
2+
"request": {
3+
"method": "POST",
4+
"url": "https://api.cerebras.ai/v1/v1/completions",
5+
"headers": {},
6+
"body": {
7+
"model": "llama-3.3-70b",
8+
"prompt": "Respond to this question and explain your answer. Complete the sentence using one word: Roses are red, violets are ",
9+
"stream": false,
10+
"extra_body": {}
11+
},
12+
"endpoint": "/v1/completions",
13+
"model": "llama-3.3-70b"
14+
},
15+
"response": {
16+
"body": {
17+
"__type__": "openai.types.completion.Completion",
18+
"__data__": {
19+
"id": "chatcmpl-6438a448-bbbd-4da1-af88-19390676b0e9",
20+
"choices": [
21+
{
22+
"finish_reason": "stop",
23+
"index": 0,
24+
"logprobs": null,
25+
"text": " blue, sugar is white, but my heart is ________________________.\nA) black\nB) pink\nC) blank\nD) broken\nMy answer is D) broken. This is because the traditional romantic poem has a positive tone until it comes to the heart, which represents the speaker's emotional state. The word \"broken\" shows that the speaker is hurting, which adds a element of sadness to the poem. This is a typical way to express sorrow or longing in poetry.\nThe best answer is D.<|eot_id|>"
26+
}
27+
],
28+
"created": 1758191351,
29+
"model": "llama-3.3-70b",
30+
"object": "text_completion",
31+
"system_fingerprint": "fp_c5ec625e72d41732d8fd",
32+
"usage": {
33+
"completion_tokens": 105,
34+
"prompt_tokens": 26,
35+
"total_tokens": 131,
36+
"completion_tokens_details": null,
37+
"prompt_tokens_details": {
38+
"audio_tokens": null,
39+
"cached_tokens": 0
40+
}
41+
},
42+
"time_info": {
43+
"queue_time": 0.00016155,
44+
"prompt_time": 0.001595551,
45+
"completion_time": 0.107480394,
46+
"total_time": 0.11038637161254883,
47+
"created": 1758191351
48+
}
49+
}
50+
},
51+
"is_streaming": false
52+
}
53+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
{
2+
"request": {
3+
"method": "POST",
4+
"url": "https://api.cerebras.ai/v1/v1/chat/completions",
5+
"headers": {},
6+
"body": {
7+
"model": "llama-3.3-70b",
8+
"messages": [
9+
{
10+
"role": "user",
11+
"content": "What's the weather in Tokyo? Use the get_weather function to get the weather."
12+
}
13+
],
14+
"stream": true,
15+
"tools": [
16+
{
17+
"type": "function",
18+
"function": {
19+
"name": "get_weather",
20+
"description": "Get the weather in a given city",
21+
"parameters": {
22+
"type": "object",
23+
"properties": {
24+
"city": {
25+
"type": "string",
26+
"description": "The city to get the weather for"
27+
}
28+
}
29+
}
30+
}
31+
}
32+
]
33+
},
34+
"endpoint": "/v1/chat/completions",
35+
"model": "llama-3.3-70b"
36+
},
37+
"response": {
38+
"body": [
39+
{
40+
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
41+
"__data__": {
42+
"id": "chatcmpl-8b6a9499-1a5f-46dc-96b7-3d2b71eecd99",
43+
"choices": [
44+
{
45+
"delta": {
46+
"content": null,
47+
"function_call": null,
48+
"refusal": null,
49+
"role": "assistant",
50+
"tool_calls": null
51+
},
52+
"finish_reason": null,
53+
"index": 0,
54+
"logprobs": null
55+
}
56+
],
57+
"created": 1758191362,
58+
"model": "llama-3.3-70b",
59+
"object": "chat.completion.chunk",
60+
"service_tier": null,
61+
"system_fingerprint": "fp_c5ec625e72d41732d8fd",
62+
"usage": null
63+
}
64+
},
65+
{
66+
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
67+
"__data__": {
68+
"id": "chatcmpl-8b6a9499-1a5f-46dc-96b7-3d2b71eecd99",
69+
"choices": [
70+
{
71+
"delta": {
72+
"content": null,
73+
"function_call": null,
74+
"refusal": null,
75+
"role": null,
76+
"tool_calls": [
77+
{
78+
"index": 0,
79+
"id": "439c86fe5",
80+
"function": {
81+
"arguments": "{\"city\": \"Tokyo\"}",
82+
"name": "get_weather"
83+
},
84+
"type": "function"
85+
}
86+
]
87+
},
88+
"finish_reason": null,
89+
"index": 0,
90+
"logprobs": null
91+
}
92+
],
93+
"created": 1758191362,
94+
"model": "llama-3.3-70b",
95+
"object": "chat.completion.chunk",
96+
"service_tier": null,
97+
"system_fingerprint": "fp_c5ec625e72d41732d8fd",
98+
"usage": null
99+
}
100+
},
101+
{
102+
"__type__": "openai.types.chat.chat_completion_chunk.ChatCompletionChunk",
103+
"__data__": {
104+
"id": "chatcmpl-8b6a9499-1a5f-46dc-96b7-3d2b71eecd99",
105+
"choices": [
106+
{
107+
"delta": {
108+
"content": null,
109+
"function_call": null,
110+
"refusal": null,
111+
"role": null,
112+
"tool_calls": null
113+
},
114+
"finish_reason": "tool_calls",
115+
"index": 0,
116+
"logprobs": null
117+
}
118+
],
119+
"created": 1758191362,
120+
"model": "llama-3.3-70b",
121+
"object": "chat.completion.chunk",
122+
"service_tier": null,
123+
"system_fingerprint": "fp_c5ec625e72d41732d8fd",
124+
"usage": {
125+
"completion_tokens": 12,
126+
"prompt_tokens": 248,
127+
"total_tokens": 260,
128+
"completion_tokens_details": null,
129+
"prompt_tokens_details": {
130+
"audio_tokens": null,
131+
"cached_tokens": 0
132+
}
133+
},
134+
"time_info": {
135+
"queue_time": 0.00016941,
136+
"prompt_time": 0.007276727,
137+
"completion_time": 0.00388514,
138+
"total_time": 0.013146162033081055,
139+
"created": 1758191362
140+
}
141+
}
142+
}
143+
],
144+
"is_streaming": true
145+
}
146+
}

0 commit comments

Comments (0)