@@ -15,6 +15,7 @@ class Client:
1515 Args:
1616 api_key: WaveSpeed API key. If not provided, uses wavespeed.config.api.api_key.
1717 base_url: Base URL for the API. If not provided, uses wavespeed.config.api.base_url.
18+ connection_timeout: Timeout for HTTP requests in seconds.
1819
1920 Example:
2021 client = Client(api_key="your-api-key")
@@ -25,10 +26,12 @@ def __init__(
2526 self ,
2627 api_key : str | None = None ,
2728 base_url : str | None = None ,
29+ connection_timeout : float | None = None ,
2830 ) -> None :
2931 """Initialize the client."""
3032 self .api_key = api_key or api_config .api_key
3133 self .base_url = (base_url or api_config .base_url ).rstrip ("/" )
34+ self .connection_timeout = connection_timeout or api_config .connection_timeout
3235
3336 def _get_headers (self ) -> dict [str , str ]:
3437 """Get request headers with authentication."""
@@ -58,7 +61,9 @@ def _submit(self, model: str, input: dict[str, Any] | None) -> str:
5861 url = f"{ self .base_url } /api/v3/{ model } "
5962 body = input or {}
6063
61- response = requests .post (url , json = body , headers = self ._get_headers ())
64+ response = requests .post (
65+ url , json = body , headers = self ._get_headers (), timeout = self .connection_timeout
66+ )
6267
6368 if response .status_code != 200 :
6469 raise RuntimeError (
@@ -88,7 +93,9 @@ def _get_result(self, request_id: str) -> dict[str, Any]:
8893 """
8994 url = f"{ self .base_url } /api/v3/predictions/{ request_id } /result"
9095
91- response = requests .get (url , headers = self ._get_headers ())
96+ response = requests .get (
97+ url , headers = self ._get_headers (), timeout = self .connection_timeout
98+ )
9299
93100 if response .status_code != 200 :
94101 raise RuntimeError (
@@ -166,11 +173,12 @@ def run(
166173 request_id = self ._submit (model , input )
167174 return self ._wait (request_id , timeout , poll_interval )
168175
169- def upload (self , file : str | BinaryIO ) -> str :
176+ def upload (self , file : str | BinaryIO , * , timeout : float | None = None ) -> str :
170177 """Upload a file to WaveSpeed.
171178
172179 Args:
173180 file: File path string or file-like object to upload.
181+ timeout: Total API call timeout in seconds.
174182
175183 Returns:
176184 URL of the uploaded file.
@@ -192,19 +200,25 @@ def upload(self, file: str | BinaryIO) -> str:
192200
193201 url = f"{ self .base_url } /api/v3/media/upload/binary"
194202 headers = {"Authorization" : f"Bearer { self .api_key } " }
203+ timeout = timeout or api_config .timeout
204+ request_timeout = (min (self .connection_timeout , timeout ), timeout )
195205
196206 if isinstance (file , str ):
197207 if not os .path .exists (file ):
198208 raise FileNotFoundError (f"File not found: { file } " )
199209 with open (file , "rb" ) as f :
200210 files = {"file" : (os .path .basename (file ), f )}
201- response = requests .post (url , headers = headers , files = files )
211+ response = requests .post (
212+ url , headers = headers , files = files , timeout = request_timeout
213+ )
202214 else :
203215 filename = getattr (file , "name" , "upload" )
204216 if isinstance (filename , str ) and os .path .sep in filename :
205217 filename = os .path .basename (filename )
206218 files = {"file" : (filename , file )}
207- response = requests .post (url , headers = headers , files = files )
219+ response = requests .post (
220+ url , headers = headers , files = files , timeout = request_timeout
221+ )
208222
209223 if response .status_code != 200 :
210224 raise RuntimeError (
0 commit comments