diff --git a/.changes/next-release/waiter-delay-max-attempts.json b/.changes/next-release/waiter-delay-max-attempts.json
new file mode 100644
index 000000000000..f4e75bfc936f
--- /dev/null
+++ b/.changes/next-release/waiter-delay-max-attempts.json
@@ -0,0 +1,7 @@
+[
+ {
+ "category": "waiter",
+ "description": "Add ``--delay`` and ``--max-attempts`` arguments to all ``wait`` commands, allowing users to override the default polling interval and maximum number of attempts.",
+ "type": "enhancement"
+ }
+]
diff --git a/awscli/customizations/waiters.py b/awscli/customizations/waiters.py
index a5ce1d9590a7..1117ff131e41 100644
--- a/awscli/customizations/waiters.py
+++ b/awscli/customizations/waiters.py
@@ -10,15 +10,84 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
-from botocore import xform_name
+from botocore import model, xform_name
from botocore.exceptions import DataNotFoundError
+from awscli.arguments import BaseCLIArgument
from awscli.clidriver import ServiceOperation
from awscli.customizations.commands import (
BasicCommand,
BasicDocHandler,
BasicHelp,
)
+from awscli.customizations.exceptions import ParamValidationError
+
+
+DELAY_HELP = (
+ '
The amount of time in seconds to wait between attempts. '
+ 'If not specified, the default delay for the waiter is used.
'
+)
+
+MAX_ATTEMPTS_HELP = (
+ 'The maximum number of attempts to be made. '
+ 'If not specified, the default max attempts value for the '
+ 'waiter is used.
'
+)
+
+
+class WaiterArgument(BaseCLIArgument):
+
+ def __init__(self, name, documentation, serialized_name):
+ self.argument_model = model.Shape(
+ 'WaiterArgument', {'type': 'integer'}
+ )
+ self._name = name
+ self._serialized_name = serialized_name
+ self._documentation = documentation
+ self._required = False
+
+ @property
+ def cli_name(self):
+ return '--' + self._name
+
+ @property
+ def cli_type_name(self):
+ return 'integer'
+
+ @property
+ def required(self):
+ return self._required
+
+ @required.setter
+ def required(self, value):
+ self._required = value
+
+ @property
+ def documentation(self):
+ return self._documentation
+
+ def add_to_parser(self, parser):
+ parser.add_argument(
+ self.cli_name,
+ dest=self.py_name,
+ type=int,
+ )
+
+ def add_to_params(self, parameters, value):
+ if value is not None:
+ if self._serialized_name == 'Delay' and value < 0:
+ raise ParamValidationError(
+ '--delay must be a non-negative integer, '
+ 'got %s' % value
+ )
+ if self._serialized_name == 'MaxAttempts' and value < 1:
+ raise ParamValidationError(
+ '--max-attempts must be a positive integer, '
+ 'got %s' % value
+ )
+ waiter_config = parameters.get('WaiterConfig', {})
+ waiter_config[self._serialized_name] = value
+ parameters['WaiterConfig'] = waiter_config
def register_add_waiters(cli):
@@ -203,7 +272,9 @@ def _build_polling_description(self, delay, max_attempts):
description = (
' It will poll every %s seconds until a successful state '
'has been reached. This will exit with a return code of 255 '
- 'after %s failed checks.' % (delay, max_attempts)
+ 'after %s failed checks. You can override the default polling '
+ 'behavior with ``--delay`` and ``--max-attempts``.'
+ % (delay, max_attempts)
)
return description
@@ -212,6 +283,10 @@ class WaiterCaller:
def __init__(self, session, waiter_name):
self._session = session
self._waiter_name = waiter_name
+ self._waiter_config = None
+
+ def set_waiter_config(self, waiter_config):
+ self._waiter_config = waiter_config
def invoke(self, service_name, operation_name, parameters, parsed_globals):
client = self._session.create_client(
@@ -221,6 +296,8 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals):
verify=parsed_globals.verify_ssl,
)
waiter = client.get_waiter(xform_name(self._waiter_name))
+ if self._waiter_config is not None:
+ parameters = dict(parameters, WaiterConfig=self._waiter_config)
waiter.wait(**parameters)
return 0
@@ -228,6 +305,25 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals):
class WaiterStateCommand(ServiceOperation):
DESCRIPTION = ''
+ def _create_argument_table(self):
+ argument_table = super()._create_argument_table()
+ argument_table['delay'] = WaiterArgument(
+ 'delay', DELAY_HELP, 'Delay',
+ )
+ argument_table['max-attempts'] = WaiterArgument(
+ 'max-attempts', MAX_ATTEMPTS_HELP, 'MaxAttempts',
+ )
+ return argument_table
+
+ def _build_call_parameters(self, args, arg_table):
+ service_params = super()._build_call_parameters(args, arg_table)
+ # Strip WaiterConfig from call parameters so it doesn't leak
+ # into --generate-cli-skeleton or other API model validation.
+ # The WaiterCaller will inject it back when calling waiter.wait().
+ waiter_config = service_params.pop('WaiterConfig', None)
+ self._operation_caller.set_waiter_config(waiter_config)
+ return service_params
+
def create_help_command(self):
help_command = super(WaiterStateCommand, self).create_help_command()
# Change the operation object's description by changing it to the
diff --git a/tests/unit/customizations/test_waiters.py b/tests/unit/customizations/test_waiters.py
index 11673e877f2d..13ca4686855b 100644
--- a/tests/unit/customizations/test_waiters.py
+++ b/tests/unit/customizations/test_waiters.py
@@ -16,6 +16,7 @@
from awscli.customizations.exceptions import ParamValidationError
from awscli.customizations.waiters import (
WaitCommand,
+ WaiterArgument,
WaiterCaller,
WaiterStateCommand,
WaiterStateCommandBuilder,
@@ -166,6 +167,13 @@ def test_wait_state_help_command(self):
self.assert_contains('``--filters`` (list)')
self.assert_contains('======\nOutput\n======\n\nNone')
+ def test_wait_state_help_shows_waiter_params(self):
+ self.driver.main(['ec2', 'wait', 'instance-running', 'help'])
+ self.assert_contains('[--delay ]')
+ self.assert_contains('[--max-attempts ]')
+ self.assert_contains('``--delay``')
+ self.assert_contains('``--max-attempts``')
+
class TestWait(BaseAWSCommandParamsTest):
"""This is merely a smoke test.
@@ -203,6 +211,53 @@ def test_rds_jobs_complete(self):
}
self.assert_params_for_cmd(cmdline, result)
+ def test_ec2_wait_with_delay_and_max_attempts(self):
+ cmdline = 'ec2 wait instance-running'
+ cmdline += ' --instance-ids i-12345678'
+ cmdline += ' --delay 10 --max-attempts 100'
+ # WaiterConfig is popped by botocore's Waiter.wait() before
+ # the API call, so only API params appear in last_kwargs.
+ result = {'InstanceIds': ['i-12345678']}
+ self.parsed_response = {
+ 'Reservations': [{'Instances': [{'State': {'Name': 'running'}}]}]
+ }
+ self.assert_params_for_cmd(cmdline, result)
+
+ def test_ec2_wait_forwards_waiter_config_end_to_end(self):
+ # End-to-end check: --delay and --max-attempts must flow from the
+ # command line all the way into waiter.wait() as WaiterConfig,
+ # alongside the regular API parameters.
+ with mock.patch(
+ 'botocore.client.BaseClient.get_waiter'
+ ) as mock_get_waiter:
+ mock_waiter = mock.Mock()
+ mock_get_waiter.return_value = mock_waiter
+ cmdline = (
+ 'ec2 wait instance-running '
+ '--instance-ids i-12345678 '
+ '--delay 10 --max-attempts 100'
+ )
+ self.run_cmd(cmdline, expected_rc=0)
+ mock_get_waiter.assert_called_with('instance_running')
+ mock_waiter.wait.assert_called_once_with(
+ InstanceIds=['i-12345678'],
+ WaiterConfig={'Delay': 10, 'MaxAttempts': 100},
+ )
+
+ def test_ec2_wait_without_waiter_flags_omits_waiter_config(self):
+ # When neither --delay nor --max-attempts are supplied, no
+ # WaiterConfig kwarg should be passed to waiter.wait().
+ with mock.patch(
+ 'botocore.client.BaseClient.get_waiter'
+ ) as mock_get_waiter:
+ mock_waiter = mock.Mock()
+ mock_get_waiter.return_value = mock_waiter
+ cmdline = 'ec2 wait instance-running --instance-ids i-12345678'
+ self.run_cmd(cmdline, expected_rc=0)
+ mock_waiter.wait.assert_called_once_with(
+ InstanceIds=['i-12345678']
+ )
+
class TestWaiterStateCommandBuilder(unittest.TestCase):
def setUp(self):
@@ -258,13 +313,17 @@ def test_build_waiter_state_cmds(self):
instance_running_cmd.DESCRIPTION,
'My waiter description. It will poll every 1 seconds until '
'a successful state has been reached. This will exit with a '
- 'return code of 255 after 10 failed checks.',
+ 'return code of 255 after 10 failed checks. You can override '
+ 'the default polling behavior with ``--delay`` and '
+ '``--max-attempts``.',
)
self.assertEqual(
bucket_exists_cmd.DESCRIPTION,
'My waiter description. It will poll every 1 seconds until '
'a successful state has been reached. This will exit with a '
- 'return code of 255 after 10 failed checks.',
+ 'return code of 255 after 10 failed checks. You can override '
+ 'the default polling behavior with ``--delay`` and '
+ '``--max-attempts``.',
)
@@ -299,7 +358,9 @@ def test_config_provided_description(self):
description,
'My description. It will poll every 5 seconds until a '
'successful state has been reached. This will exit with a '
- 'return code of 255 after 20 failed checks.',
+ 'return code of 255 after 20 failed checks. You can override '
+ 'the default polling behavior with ``--delay`` and '
+ '``--max-attempts``.',
)
def test_error_acceptor(self):
@@ -311,7 +372,9 @@ def test_error_acceptor(self):
'Wait until MyException is thrown when polling with '
'``my-operation``. It will poll every 5 seconds until a '
'successful state has been reached. This will exit with a '
- 'return code of 255 after 20 failed checks.',
+ 'return code of 255 after 20 failed checks. You can override '
+ 'the default polling behavior with ``--delay`` and '
+ '``--max-attempts``.',
)
def test_status_acceptor(self):
@@ -323,7 +386,9 @@ def test_status_acceptor(self):
'Wait until 200 response is received when polling with '
'``my-operation``. It will poll every 5 seconds until a '
'successful state has been reached. This will exit with a '
- 'return code of 255 after 20 failed checks.',
+ 'return code of 255 after 20 failed checks. You can override '
+ 'the default polling behavior with ``--delay`` and '
+ '``--max-attempts``.',
)
def test_path_acceptor(self):
@@ -336,7 +401,9 @@ def test_path_acceptor(self):
'Wait until JMESPath query MyResource.name returns running when '
'polling with ``my-operation``. It will poll every 5 seconds '
'until a successful state has been reached. This will exit with '
- 'a return code of 255 after 20 failed checks.',
+ 'a return code of 255 after 20 failed checks. You can override '
+ 'the default polling behavior with ``--delay`` and '
+ '``--max-attempts``.',
)
def test_path_all_acceptor(self):
@@ -349,7 +416,9 @@ def test_path_all_acceptor(self):
'Wait until JMESPath query MyResource[].name returns running for '
'all elements when polling with ``my-operation``. It will poll '
'every 5 seconds until a successful state has been reached. '
- 'This will exit with a return code of 255 after 20 failed checks.',
+ 'This will exit with a return code of 255 after 20 failed checks. '
+ 'You can override the default polling behavior with ``--delay`` '
+ 'and ``--max-attempts``.',
)
def test_path_any_acceptor(self):
@@ -362,7 +431,9 @@ def test_path_any_acceptor(self):
'Wait until JMESPath query MyResource[].name returns running for '
'any element when polling with ``my-operation``. It will poll '
'every 5 seconds until a successful state has been reached. '
- 'This will exit with a return code of 255 after 20 failed checks.',
+ 'This will exit with a return code of 255 after 20 failed checks. '
+ 'You can override the default polling behavior with ``--delay`` '
+ 'and ``--max-attempts``.',
)
@@ -399,3 +470,129 @@ def test_invoke(self):
# Ensure the wait command was called properly.
waiter.wait.assert_called_with(Foo='bar', Baz='biz')
+
+ def test_invoke_with_set_waiter_config(self):
+ waiter = mock.Mock()
+ waiter_name = 'my_waiter'
+ session = mock.Mock()
+ session.create_client.return_value.get_waiter.return_value = waiter
+
+ parameters = {'Foo': 'bar'}
+ parsed_globals = mock.Mock()
+ parsed_globals.region = 'us-east-1'
+ parsed_globals.endpoint_url = 'myurl'
+ parsed_globals.verify_ssl = True
+
+ waiter_caller = WaiterCaller(session, waiter_name)
+ waiter_caller.set_waiter_config({'Delay': 10, 'MaxAttempts': 50})
+ waiter_caller.invoke(
+ 'myservice', 'MyWaiter', parameters, parsed_globals
+ )
+
+ waiter.wait.assert_called_with(
+ Foo='bar',
+ WaiterConfig={'Delay': 10, 'MaxAttempts': 50},
+ )
+
+ def test_invoke_without_waiter_config(self):
+ waiter = mock.Mock()
+ waiter_name = 'my_waiter'
+ session = mock.Mock()
+ session.create_client.return_value.get_waiter.return_value = waiter
+
+ parameters = {'Foo': 'bar'}
+ parsed_globals = mock.Mock()
+ parsed_globals.region = 'us-east-1'
+ parsed_globals.endpoint_url = 'myurl'
+ parsed_globals.verify_ssl = True
+
+ waiter_caller = WaiterCaller(session, waiter_name)
+ waiter_caller.invoke(
+ 'myservice', 'MyWaiter', parameters, parsed_globals
+ )
+
+ # Without set_waiter_config, no WaiterConfig should be passed.
+ waiter.wait.assert_called_with(Foo='bar')
+
+
+class TestWaiterArgument(unittest.TestCase):
+ def test_add_delay_to_params(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ params = {}
+ arg.add_to_params(params, 10)
+ self.assertEqual(params, {'WaiterConfig': {'Delay': 10}})
+
+ def test_add_max_attempts_to_params(self):
+ arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts')
+ params = {}
+ arg.add_to_params(params, 50)
+ self.assertEqual(params, {'WaiterConfig': {'MaxAttempts': 50}})
+
+ def test_both_params_together(self):
+ delay_arg = WaiterArgument('delay', 'help text', 'Delay')
+ max_arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts')
+ params = {}
+ delay_arg.add_to_params(params, 10)
+ max_arg.add_to_params(params, 50)
+ self.assertEqual(
+ params,
+ {'WaiterConfig': {'Delay': 10, 'MaxAttempts': 50}},
+ )
+
+ def test_none_value_not_added(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ params = {}
+ arg.add_to_params(params, None)
+ self.assertEqual(params, {})
+
+ def test_cli_name(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ self.assertEqual(arg.cli_name, '--delay')
+
+ def test_cli_type_name(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ self.assertEqual(arg.cli_type_name, 'integer')
+
+ def test_required_defaults_to_false(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ self.assertFalse(arg.required)
+
+ def test_delay_rejects_negative(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ params = {}
+ with self.assertRaises(Exception) as ctx:
+ arg.add_to_params(params, -1)
+ self.assertIn('--delay', str(ctx.exception))
+ self.assertEqual(params, {})
+
+ def test_delay_rejects_negative_large(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ params = {}
+ with self.assertRaises(Exception):
+ arg.add_to_params(params, -100)
+
+ def test_delay_accepts_zero(self):
+ arg = WaiterArgument('delay', 'help text', 'Delay')
+ params = {}
+ arg.add_to_params(params, 0)
+ self.assertEqual(params, {'WaiterConfig': {'Delay': 0}})
+
+ def test_max_attempts_rejects_zero(self):
+ arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts')
+ params = {}
+ with self.assertRaises(Exception) as ctx:
+ arg.add_to_params(params, 0)
+ self.assertIn('--max-attempts', str(ctx.exception))
+ self.assertEqual(params, {})
+
+ def test_max_attempts_rejects_negative(self):
+ arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts')
+ params = {}
+ with self.assertRaises(Exception):
+ arg.add_to_params(params, -1)
+
+ def test_max_attempts_accepts_one(self):
+ arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts')
+ params = {}
+ arg.add_to_params(params, 1)
+ self.assertEqual(params, {'WaiterConfig': {'MaxAttempts': 1}})