diff --git a/packages/client-python/src/rocketride/mixins/data.py b/packages/client-python/src/rocketride/mixins/data.py index c96c10166..9b9f231c2 100644 --- a/packages/client-python/src/rocketride/mixins/data.py +++ b/packages/client-python/src/rocketride/mixins/data.py @@ -139,6 +139,20 @@ def pipe_id(self) -> Optional[int]: """Get the unique ID assigned to this pipe by the server.""" return self._pipe_id + @staticmethod + def _is_transient_open_error(message: str) -> bool: + """Return True for short-lived connection errors while task ports are starting.""" + if not message: + return False + + lowered = message.lower() + transient_markers = [ + 'connect call failed', + '[errno 111]', + 'connection refused', + ] + return any(marker in lowered for marker in transient_markers) + async def open(self) -> 'DataMixin.DataPipe': """ Open the pipe for data transmission. @@ -171,10 +185,24 @@ async def open(self) -> 'DataMixin.DataPipe': token=self._token, ) - response = await self._client.request(request) + # Right after use(), the task's data port may need a brief warm-up. + # Retry connect-refused open errors for a short period. + max_attempts = 20 + retry_delay_seconds = 0.25 + response = None - if self._client.did_fail(response): - raise RuntimeError(response.get('message', 'Your pipeline is not currently running.')) + for attempt in range(max_attempts): + response = await self._client.request(request) + + if not self._client.did_fail(response): + break + + message = response.get('message', '') + is_last_attempt = attempt == (max_attempts - 1) + if is_last_attempt or not self._is_transient_open_error(message): + raise RuntimeError(response.get('message', 'Your pipeline is not currently running.')) + + await asyncio.sleep(retry_delay_seconds) self._pipe_id = response.get('body', {}).get('pipe_id') self._opened = True diff --git a/packages/client-python/tests/RocketRideClient_test.py b/packages/client-python/tests/RocketRideClient_test.py index c70540f14..2737c9e19 100644 --- a/packages/client-python/tests/RocketRideClient_test.py +++ b/packages/client-python/tests/RocketRideClient_test.py @@ -63,6 +63,7 @@ import json import random import string +import tempfile import time from pathlib import Path from typing import Dict, Any @@ -348,6 +349,47 @@ async def test_should_send_text_data_no_mime_type(self): if client.is_connected(): await client.disconnect() + @pytest.mark.asyncio + async def test_should_send_text_data_after_use_with_filepath(self): + """Regression for docs flow: use(filepath=...) followed by send(...).""" + client = RocketRideClient(auth=TEST_CONFIG['auth'], uri=TEST_CONFIG['uri']) + pipeline_token = None + token = f'{self.DATA_TOKEN}-filepath' + try: + await client.connect() + await ensure_clean_pipeline(client, token) + + # Mirror docs-style usage where pipeline config is loaded from disk. + pipeline_config = get_echo_pipeline() + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', encoding='utf-8', delete=False) as temp_file: + json.dump(pipeline_config, temp_file) + temp_path = temp_file.name + + try: + result = await client.use(filepath=temp_path, token=token) + pipeline_token = result['token'] + + send_result = await client.send(pipeline_token, 'Hello, pipeline!', objinfo={'name': 'input.txt'}, mimetype='text/plain') + + assert send_result is not None + assert isinstance(send_result, dict) + assert 'result_types' in send_result + assert isinstance(send_result['result_types'], dict) + assert send_result['result_types'].get('text') == 'text' + assert 'text' in send_result + assert isinstance(send_result['text'], list) + assert any('Hello, pipeline!' in chunk for chunk in send_result['text']) + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + finally: + if pipeline_token: + await ensure_clean_pipeline(client, pipeline_token) + else: + await ensure_clean_pipeline(client, token) + if client.is_connected(): + await client.disconnect() + @pytest.mark.asyncio async def test_should_send_text_data_with_mime_type(self): client = RocketRideClient(auth=TEST_CONFIG['auth'], uri=TEST_CONFIG['uri'])