Skip to content

Commit 35e5f88

Browse files
chengzeyiclaude
andcommitted
Add retry and sync mode support to API client
- Add enable_sync_mode parameter for single-request synchronous calls - Add max_retries for task-level retries (entire submit+wait cycle) - Add max_connection_retries for HTTP request retries (connection errors, timeouts) - Add retry_interval config for delay between retries - Update tests to include new config attributes 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 15f14ee commit 35e5f88

File tree

4 files changed

+226
-30
lines changed

4 files changed

+226
-30
lines changed

src/wavespeed/api/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def run(
4141
*,
4242
timeout: float | None = None,
4343
poll_interval: float = 1.0,
44+
enable_sync_mode: bool = False,
45+
max_retries: int | None = None,
4446
) -> dict:
4547
"""Run a model and wait for the output.
4648
@@ -49,6 +51,8 @@ def run(
4951
input: Input parameters for the model.
5052
timeout: Maximum time to wait for completion (None = no timeout).
5153
poll_interval: Interval between status checks in seconds.
54+
enable_sync_mode: If True, use synchronous mode (single request).
55+
max_retries: Maximum retries for this request (overrides default setting).
5256
5357
Returns:
5458
Dict containing "outputs" array with model outputs.
@@ -64,12 +68,28 @@ def run(
6468
input={"prompt": "A cat sitting on a windowsill"}
6569
)
6670
print(output["outputs"][0]) # First output URL
71+
72+
# With sync mode
73+
output = wavespeed.run(
74+
"wavespeed-ai/z-image/turbo",
75+
input={"prompt": "A cat"},
76+
enable_sync_mode=True
77+
)
78+
79+
# With retry
80+
output = wavespeed.run(
81+
"wavespeed-ai/z-image/turbo",
82+
input={"prompt": "A cat"},
83+
max_retries=3
84+
)
6785
"""
6886
return _get_default_client().run(
6987
model,
7088
input=input,
7189
timeout=timeout,
7290
poll_interval=poll_interval,
91+
enable_sync_mode=enable_sync_mode,
92+
max_retries=max_retries,
7393
)
7494

7595

src/wavespeed/api/client.py

Lines changed: 184 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import time
5+
import traceback
56
from typing import Any, BinaryIO
67

78
import requests
@@ -16,22 +17,45 @@ class Client:
1617
api_key: WaveSpeed API key. If not provided, uses wavespeed.config.api.api_key.
1718
base_url: Base URL for the API. If not provided, uses wavespeed.config.api.base_url.
1819
connection_timeout: Timeout for HTTP requests in seconds.
20+
max_retries: Maximum number of retries for the entire operation.
21+
max_connection_retries: Maximum retries for individual HTTP requests.
22+
retry_interval: Base interval between retries in seconds.
1923
2024
Example:
2125
client = Client(api_key="your-api-key")
2226
output = client.run("wavespeed-ai/z-image/turbo", input={"prompt": "Cat"})
27+
28+
# With sync mode (single request, waits for result)
29+
output = client.run("wavespeed-ai/z-image/turbo", input={"prompt": "Cat"}, enable_sync_mode=True)
30+
31+
# With retry
32+
output = client.run("wavespeed-ai/z-image/turbo", input={"prompt": "Cat"}, max_retries=3)
2333
"""
2434

2535
def __init__(
2636
self,
2737
api_key: str | None = None,
2838
base_url: str | None = None,
2939
connection_timeout: float | None = None,
40+
max_retries: int | None = None,
41+
max_connection_retries: int | None = None,
42+
retry_interval: float | None = None,
3043
) -> None:
3144
"""Initialize the client."""
3245
self.api_key = api_key or api_config.api_key
3346
self.base_url = (base_url or api_config.base_url).rstrip("/")
3447
self.connection_timeout = connection_timeout or api_config.connection_timeout
48+
self.max_retries = (
49+
max_retries if max_retries is not None else api_config.max_retries
50+
)
51+
self.max_connection_retries = (
52+
max_connection_retries
53+
if max_connection_retries is not None
54+
else api_config.max_connection_retries
55+
)
56+
self.retry_interval = (
57+
retry_interval if retry_interval is not None else api_config.retry_interval
58+
)
3559

3660
def _get_headers(self) -> dict[str, str]:
3761
"""Get request headers with authentication."""
@@ -45,64 +69,138 @@ def _get_headers(self) -> dict[str, str]:
4569
"Authorization": f"Bearer {self.api_key}",
4670
}
4771

48-
def _submit(self, model: str, input: dict[str, Any] | None) -> str:
72+
def _submit(
73+
self,
74+
model: str,
75+
input: dict[str, Any] | None,
76+
enable_sync_mode: bool = False,
77+
timeout: float | None = None,
78+
) -> tuple[str | None, dict[str, Any] | None]:
4979
"""Submit a prediction request.
5080
5181
Args:
5282
model: Model identifier.
5383
input: Input parameters.
84+
enable_sync_mode: If True, wait for result in single request.
85+
timeout: Request timeout in seconds.
5486
5587
Returns:
56-
Request ID for polling.
88+
Tuple of (request_id, result). In async mode, result is None.
89+
In sync mode, request_id is None and result contains the response.
5790
5891
Raises:
59-
RuntimeError: If submission fails.
92+
RuntimeError: If submission fails after retries.
6093
"""
6194
url = f"{self.base_url}/api/v3/{model}"
62-
body = input or {}
95+
body = dict(input) if input else {}
96+
97+
if enable_sync_mode:
98+
body["enable_sync_mode"] = True
6399

64-
response = requests.post(
65-
url, json=body, headers=self._get_headers(), timeout=self.connection_timeout
100+
request_timeout = timeout if timeout is not None else api_config.timeout
101+
# Use connection timeout for connect, request_timeout for read
102+
connect_timeout = (
103+
min(self.connection_timeout, request_timeout)
104+
if request_timeout
105+
else self.connection_timeout
66106
)
107+
timeouts = (connect_timeout, request_timeout)
67108

68-
if response.status_code != 200:
69-
raise RuntimeError(
70-
f"Failed to submit prediction: HTTP {response.status_code}: "
71-
f"{response.text}"
72-
)
109+
for retry in range(self.max_connection_retries + 1):
110+
try:
111+
response = requests.post(
112+
url, json=body, headers=self._get_headers(), timeout=timeouts
113+
)
73114

74-
result = response.json()
75-
request_id = result.get("data", {}).get("id")
115+
if response.status_code != 200:
116+
raise RuntimeError(
117+
f"Failed to submit prediction: HTTP {response.status_code}: "
118+
f"{response.text}"
119+
)
120+
121+
result = response.json()
122+
123+
if enable_sync_mode:
124+
return None, result
76125

77-
if not request_id:
78-
raise RuntimeError(f"No request ID in response: {result}")
126+
request_id = result.get("data", {}).get("id")
127+
if not request_id:
128+
raise RuntimeError(f"No request ID in response: {result}")
79129

80-
return request_id
130+
return request_id, None
81131

82-
def _get_result(self, request_id: str) -> dict[str, Any]:
132+
except (
133+
requests.exceptions.ConnectionError,
134+
requests.exceptions.Timeout,
135+
) as e:
136+
print(
137+
f"Connection error on attempt {retry + 1}/{self.max_connection_retries + 1}:"
138+
)
139+
traceback.print_exc()
140+
141+
if retry < self.max_connection_retries:
142+
delay = self.retry_interval * (retry + 1)
143+
print(f"Retrying in {delay} seconds...")
144+
time.sleep(delay)
145+
else:
146+
raise RuntimeError(
147+
f"Failed to submit prediction after {self.max_connection_retries + 1} attempts"
148+
) from e
149+
150+
def _get_result(
151+
self, request_id: str, timeout: float | None = None
152+
) -> dict[str, Any]:
83153
"""Get prediction result.
84154
85155
Args:
86156
request_id: The prediction request ID.
157+
timeout: Request timeout in seconds.
87158
88159
Returns:
89160
Full API response.
90161
91162
Raises:
92-
RuntimeError: If fetching result fails.
163+
RuntimeError: If fetching result fails after retries.
93164
"""
94165
url = f"{self.base_url}/api/v3/predictions/{request_id}/result"
95-
96-
response = requests.get(
97-
url, headers=self._get_headers(), timeout=self.connection_timeout
166+
request_timeout = timeout if timeout is not None else api_config.timeout
167+
connect_timeout = (
168+
min(self.connection_timeout, request_timeout)
169+
if request_timeout
170+
else self.connection_timeout
98171
)
172+
timeouts = (connect_timeout, request_timeout)
99173

100-
if response.status_code != 200:
101-
raise RuntimeError(
102-
f"Failed to get result: HTTP {response.status_code}: {response.text}"
103-
)
174+
for retry in range(self.max_connection_retries + 1):
175+
try:
176+
response = requests.get(
177+
url, headers=self._get_headers(), timeout=timeouts
178+
)
179+
180+
if response.status_code != 200:
181+
raise RuntimeError(
182+
f"Failed to get result: HTTP {response.status_code}: {response.text}"
183+
)
184+
185+
return response.json()
186+
187+
except (
188+
requests.exceptions.ConnectionError,
189+
requests.exceptions.Timeout,
190+
) as e:
191+
print(
192+
f"Connection error getting result on attempt {retry + 1}/{self.max_connection_retries + 1}:"
193+
)
194+
traceback.print_exc()
104195

105-
return response.json()
196+
if retry < self.max_connection_retries:
197+
delay = self.retry_interval * (retry + 1)
198+
print(f"Retrying in {delay} seconds...")
199+
time.sleep(delay)
200+
else:
201+
raise RuntimeError(
202+
f"Failed to get result after {self.max_connection_retries + 1} attempts"
203+
) from e
106204

107205
def _wait(
108206
self,
@@ -133,7 +231,7 @@ def _wait(
133231
if elapsed >= timeout:
134232
raise TimeoutError(f"Prediction timed out after {timeout} seconds")
135233

136-
result = self._get_result(request_id)
234+
result = self._get_result(request_id, timeout=timeout)
137235
data = result.get("data", {})
138236
status = data.get("status")
139237

@@ -146,13 +244,38 @@ def _wait(
146244

147245
time.sleep(poll_interval)
148246

247+
def _is_retryable_error(self, error: Exception) -> bool:
248+
"""Determine if an error is worth retrying at the task level.
249+
250+
Args:
251+
error: The exception to check.
252+
253+
Returns:
254+
True if the error is retryable.
255+
"""
256+
# Always retry timeout and connection errors
257+
if isinstance(
258+
error, (requests.exceptions.Timeout, requests.exceptions.ConnectionError)
259+
):
260+
return True
261+
262+
# Retry server errors (5xx) and rate limiting (429)
263+
if isinstance(error, RuntimeError):
264+
error_str = str(error)
265+
if "HTTP 5" in error_str or "HTTP 429" in error_str:
266+
return True
267+
268+
return False
269+
149270
def run(
150271
self,
151272
model: str,
152273
input: dict[str, Any] | None = None,
153274
*,
154275
timeout: float | None = None,
155276
poll_interval: float = 1.0,
277+
enable_sync_mode: bool = False,
278+
max_retries: int | None = None,
156279
) -> dict[str, Any]:
157280
"""Run a model and wait for the output.
158281
@@ -161,6 +284,8 @@ def run(
161284
input: Input parameters for the model.
162285
timeout: Maximum time to wait for completion (None = no timeout).
163286
poll_interval: Interval between status checks in seconds.
287+
enable_sync_mode: If True, use synchronous mode (single request).
288+
max_retries: Maximum task-level retries (overrides client setting).
164289
165290
Returns:
166291
Dict containing "outputs" array with model outputs.
@@ -170,8 +295,38 @@ def run(
170295
RuntimeError: If the prediction fails.
171296
TimeoutError: If the prediction times out.
172297
"""
173-
request_id = self._submit(model, input)
174-
return self._wait(request_id, timeout, poll_interval)
298+
task_retries = max_retries if max_retries is not None else self.max_retries
299+
last_error = None
300+
301+
for attempt in range(task_retries + 1):
302+
try:
303+
request_id, sync_result = self._submit(
304+
model, input, enable_sync_mode=enable_sync_mode, timeout=timeout
305+
)
306+
307+
if enable_sync_mode:
308+
# In sync mode, extract outputs from the result
309+
data = sync_result.get("data", {})
310+
return {"outputs": data.get("outputs", [])}
311+
312+
return self._wait(request_id, timeout, poll_interval)
313+
314+
except Exception as e:
315+
last_error = e
316+
is_retryable = self._is_retryable_error(e)
317+
318+
if not is_retryable or attempt >= task_retries:
319+
raise
320+
321+
print(f"Task attempt {attempt + 1}/{task_retries + 1} failed: {e}")
322+
delay = self.retry_interval * (attempt + 1)
323+
print(f"Retrying in {delay} seconds...")
324+
time.sleep(delay)
325+
326+
# Should not reach here, but just in case
327+
if last_error:
328+
raise last_error
329+
raise RuntimeError(f"All {task_retries + 1} attempts failed")
175330

176331
def upload(self, file: str | BinaryIO, *, timeout: float | None = None) -> str:
177332
"""Upload a file to WaveSpeed.

src/wavespeed/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ class api:
2727
# Total API call timeout in seconds
2828
timeout: float = 36000.0
2929

30+
# Maximum number of retries for the entire operation (task-level retries)
31+
max_retries: int = 0
32+
33+
# Maximum number of retries for individual HTTP requests (connection errors, timeouts)
34+
max_connection_retries: int = 5
35+
36+
# Base interval between retries in seconds (actual delay = retry_interval * attempt)
37+
retry_interval: float = 1.0
38+
3039

3140
class serverless:
3241
"""Serverless configuration options.

0 commit comments

Comments
 (0)