22
33import os
44import time
5+ import traceback
56from typing import Any , BinaryIO
67
78import requests
@@ -16,22 +17,45 @@ class Client:
1617 api_key: WaveSpeed API key. If not provided, uses wavespeed.config.api.api_key.
1718 base_url: Base URL for the API. If not provided, uses wavespeed.config.api.base_url.
1819 connection_timeout: Timeout for HTTP requests in seconds.
20+ max_retries: Maximum number of retries for the entire operation.
21+ max_connection_retries: Maximum retries for individual HTTP requests.
22+ retry_interval: Base interval between retries in seconds.
1923
2024 Example:
2125 client = Client(api_key="your-api-key")
2226 output = client.run("wavespeed-ai/z-image/turbo", input={"prompt": "Cat"})
27+
28+ # With sync mode (single request, waits for result)
29+ output = client.run("wavespeed-ai/z-image/turbo", input={"prompt": "Cat"}, enable_sync_mode=True)
30+
31+ # With retry
32+ output = client.run("wavespeed-ai/z-image/turbo", input={"prompt": "Cat"}, max_retries=3)
2333 """
2434
2535 def __init__ (
2636 self ,
2737 api_key : str | None = None ,
2838 base_url : str | None = None ,
2939 connection_timeout : float | None = None ,
40+ max_retries : int | None = None ,
41+ max_connection_retries : int | None = None ,
42+ retry_interval : float | None = None ,
3043 ) -> None :
3144 """Initialize the client."""
3245 self .api_key = api_key or api_config .api_key
3346 self .base_url = (base_url or api_config .base_url ).rstrip ("/" )
3447 self .connection_timeout = connection_timeout or api_config .connection_timeout
48+ self .max_retries = (
49+ max_retries if max_retries is not None else api_config .max_retries
50+ )
51+ self .max_connection_retries = (
52+ max_connection_retries
53+ if max_connection_retries is not None
54+ else api_config .max_connection_retries
55+ )
56+ self .retry_interval = (
57+ retry_interval if retry_interval is not None else api_config .retry_interval
58+ )
3559
3660 def _get_headers (self ) -> dict [str , str ]:
3761 """Get request headers with authentication."""
@@ -45,64 +69,138 @@ def _get_headers(self) -> dict[str, str]:
4569 "Authorization" : f"Bearer { self .api_key } " ,
4670 }
4771
48- def _submit (self , model : str , input : dict [str , Any ] | None ) -> str :
72+ def _submit (
73+ self ,
74+ model : str ,
75+ input : dict [str , Any ] | None ,
76+ enable_sync_mode : bool = False ,
77+ timeout : float | None = None ,
78+ ) -> tuple [str | None , dict [str , Any ] | None ]:
4979 """Submit a prediction request.
5080
5181 Args:
5282 model: Model identifier.
5383 input: Input parameters.
84+ enable_sync_mode: If True, wait for result in single request.
85+ timeout: Request timeout in seconds.
5486
5587 Returns:
56- Request ID for polling.
88+ Tuple of (request_id, result). In async mode, result is None.
89+ In sync mode, request_id is None and result contains the response.
5790
5891 Raises:
59- RuntimeError: If submission fails.
92+ RuntimeError: If submission fails after retries .
6093 """
6194 url = f"{ self .base_url } /api/v3/{ model } "
62- body = input or {}
95+ body = dict (input ) if input else {}
96+
97+ if enable_sync_mode :
98+ body ["enable_sync_mode" ] = True
6399
64- response = requests .post (
65- url , json = body , headers = self ._get_headers (), timeout = self .connection_timeout
100+ request_timeout = timeout if timeout is not None else api_config .timeout
101+ # Use connection timeout for connect, request_timeout for read
102+ connect_timeout = (
103+ min (self .connection_timeout , request_timeout )
104+ if request_timeout
105+ else self .connection_timeout
66106 )
107+ timeouts = (connect_timeout , request_timeout )
67108
68- if response . status_code != 200 :
69- raise RuntimeError (
70- f"Failed to submit prediction: HTTP { response . status_code } : "
71- f" { response . text } "
72- )
109+ for retry in range ( self . max_connection_retries + 1 ) :
110+ try :
111+ response = requests . post (
112+ url , json = body , headers = self . _get_headers (), timeout = timeouts
113+ )
73114
74- result = response .json ()
75- request_id = result .get ("data" , {}).get ("id" )
115+ if response .status_code != 200 :
116+ raise RuntimeError (
117+ f"Failed to submit prediction: HTTP { response .status_code } : "
118+ f"{ response .text } "
119+ )
120+
121+ result = response .json ()
122+
123+ if enable_sync_mode :
124+ return None , result
76125
77- if not request_id :
78- raise RuntimeError (f"No request ID in response: { result } " )
126+ request_id = result .get ("data" , {}).get ("id" )
127+ if not request_id :
128+ raise RuntimeError (f"No request ID in response: { result } " )
79129
80- return request_id
130+ return request_id , None
81131
82- def _get_result (self , request_id : str ) -> dict [str , Any ]:
132+ except (
133+ requests .exceptions .ConnectionError ,
134+ requests .exceptions .Timeout ,
135+ ) as e :
136+ print (
137+ f"Connection error on attempt { retry + 1 } /{ self .max_connection_retries + 1 } :"
138+ )
139+ traceback .print_exc ()
140+
141+ if retry < self .max_connection_retries :
142+ delay = self .retry_interval * (retry + 1 )
143+ print (f"Retrying in { delay } seconds..." )
144+ time .sleep (delay )
145+ else :
146+ raise RuntimeError (
147+ f"Failed to submit prediction after { self .max_connection_retries + 1 } attempts"
148+ ) from e
149+
150+ def _get_result (
151+ self , request_id : str , timeout : float | None = None
152+ ) -> dict [str , Any ]:
83153 """Get prediction result.
84154
85155 Args:
86156 request_id: The prediction request ID.
157+ timeout: Request timeout in seconds.
87158
88159 Returns:
89160 Full API response.
90161
91162 Raises:
92- RuntimeError: If fetching result fails.
163+ RuntimeError: If fetching result fails after retries .
93164 """
94165 url = f"{ self .base_url } /api/v3/predictions/{ request_id } /result"
95-
96- response = requests .get (
97- url , headers = self ._get_headers (), timeout = self .connection_timeout
166+ request_timeout = timeout if timeout is not None else api_config .timeout
167+ connect_timeout = (
168+ min (self .connection_timeout , request_timeout )
169+ if request_timeout
170+ else self .connection_timeout
98171 )
172+ timeouts = (connect_timeout , request_timeout )
99173
100- if response .status_code != 200 :
101- raise RuntimeError (
102- f"Failed to get result: HTTP { response .status_code } : { response .text } "
103- )
174+ for retry in range (self .max_connection_retries + 1 ):
175+ try :
176+ response = requests .get (
177+ url , headers = self ._get_headers (), timeout = timeouts
178+ )
179+
180+ if response .status_code != 200 :
181+ raise RuntimeError (
182+ f"Failed to get result: HTTP { response .status_code } : { response .text } "
183+ )
184+
185+ return response .json ()
186+
187+ except (
188+ requests .exceptions .ConnectionError ,
189+ requests .exceptions .Timeout ,
190+ ) as e :
191+ print (
192+ f"Connection error getting result on attempt { retry + 1 } /{ self .max_connection_retries + 1 } :"
193+ )
194+ traceback .print_exc ()
104195
105- return response .json ()
196+ if retry < self .max_connection_retries :
197+ delay = self .retry_interval * (retry + 1 )
198+ print (f"Retrying in { delay } seconds..." )
199+ time .sleep (delay )
200+ else :
201+ raise RuntimeError (
202+ f"Failed to get result after { self .max_connection_retries + 1 } attempts"
203+ ) from e
106204
107205 def _wait (
108206 self ,
@@ -133,7 +231,7 @@ def _wait(
133231 if elapsed >= timeout :
134232 raise TimeoutError (f"Prediction timed out after { timeout } seconds" )
135233
136- result = self ._get_result (request_id )
234+ result = self ._get_result (request_id , timeout = timeout )
137235 data = result .get ("data" , {})
138236 status = data .get ("status" )
139237
@@ -146,13 +244,38 @@ def _wait(
146244
147245 time .sleep (poll_interval )
148246
247+ def _is_retryable_error (self , error : Exception ) -> bool :
248+ """Determine if an error is worth retrying at the task level.
249+
250+ Args:
251+ error: The exception to check.
252+
253+ Returns:
254+ True if the error is retryable.
255+ """
256+ # Always retry timeout and connection errors
257+ if isinstance (
258+ error , (requests .exceptions .Timeout , requests .exceptions .ConnectionError )
259+ ):
260+ return True
261+
262+ # Retry server errors (5xx) and rate limiting (429)
263+ if isinstance (error , RuntimeError ):
264+ error_str = str (error )
265+ if "HTTP 5" in error_str or "HTTP 429" in error_str :
266+ return True
267+
268+ return False
269+
149270 def run (
150271 self ,
151272 model : str ,
152273 input : dict [str , Any ] | None = None ,
153274 * ,
154275 timeout : float | None = None ,
155276 poll_interval : float = 1.0 ,
277+ enable_sync_mode : bool = False ,
278+ max_retries : int | None = None ,
156279 ) -> dict [str , Any ]:
157280 """Run a model and wait for the output.
158281
@@ -161,6 +284,8 @@ def run(
161284 input: Input parameters for the model.
162285 timeout: Maximum time to wait for completion (None = no timeout).
163286 poll_interval: Interval between status checks in seconds.
287+ enable_sync_mode: If True, use synchronous mode (single request).
288+ max_retries: Maximum task-level retries (overrides client setting).
164289
165290 Returns:
166291 Dict containing "outputs" array with model outputs.
@@ -170,8 +295,38 @@ def run(
170295 RuntimeError: If the prediction fails.
171296 TimeoutError: If the prediction times out.
172297 """
173- request_id = self ._submit (model , input )
174- return self ._wait (request_id , timeout , poll_interval )
298+ task_retries = max_retries if max_retries is not None else self .max_retries
299+ last_error = None
300+
301+ for attempt in range (task_retries + 1 ):
302+ try :
303+ request_id , sync_result = self ._submit (
304+ model , input , enable_sync_mode = enable_sync_mode , timeout = timeout
305+ )
306+
307+ if enable_sync_mode :
308+ # In sync mode, extract outputs from the result
309+ data = sync_result .get ("data" , {})
310+ return {"outputs" : data .get ("outputs" , [])}
311+
312+ return self ._wait (request_id , timeout , poll_interval )
313+
314+ except Exception as e :
315+ last_error = e
316+ is_retryable = self ._is_retryable_error (e )
317+
318+ if not is_retryable or attempt >= task_retries :
319+ raise
320+
321+ print (f"Task attempt { attempt + 1 } /{ task_retries + 1 } failed: { e } " )
322+ delay = self .retry_interval * (attempt + 1 )
323+ print (f"Retrying in { delay } seconds..." )
324+ time .sleep (delay )
325+
326+ # Should not reach here, but just in case
327+ if last_error :
328+ raise last_error
329+ raise RuntimeError (f"All { task_retries + 1 } attempts failed" )
175330
176331 def upload (self , file : str | BinaryIO , * , timeout : float | None = None ) -> str :
177332 """Upload a file to WaveSpeed.
0 commit comments