diff --git a/.changes/next-release/waiter-delay-max-attempts.json b/.changes/next-release/waiter-delay-max-attempts.json new file mode 100644 index 000000000000..f4e75bfc936f --- /dev/null +++ b/.changes/next-release/waiter-delay-max-attempts.json @@ -0,0 +1,7 @@ +[ + { + "category": "waiter", + "description": "Add ``--delay`` and ``--max-attempts`` arguments to all ``wait`` commands, allowing users to override the default polling interval and maximum number of attempts.", + "type": "enhancement" + } +] diff --git a/awscli/customizations/waiters.py b/awscli/customizations/waiters.py index a5ce1d9590a7..1117ff131e41 100644 --- a/awscli/customizations/waiters.py +++ b/awscli/customizations/waiters.py @@ -10,15 +10,84 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from botocore import xform_name +from botocore import model, xform_name from botocore.exceptions import DataNotFoundError +from awscli.arguments import BaseCLIArgument from awscli.clidriver import ServiceOperation from awscli.customizations.commands import ( BasicCommand, BasicDocHandler, BasicHelp, ) +from awscli.customizations.exceptions import ParamValidationError + + +DELAY_HELP = ( + '

The amount of time in seconds to wait between attempts. ' + 'If not specified, the default delay for the waiter is used.

' +) + +MAX_ATTEMPTS_HELP = ( + '

The maximum number of attempts to be made. ' + 'If not specified, the default max attempts value for the ' + 'waiter is used.

' +) + + +class WaiterArgument(BaseCLIArgument): + + def __init__(self, name, documentation, serialized_name): + self.argument_model = model.Shape( + 'WaiterArgument', {'type': 'integer'} + ) + self._name = name + self._serialized_name = serialized_name + self._documentation = documentation + self._required = False + + @property + def cli_name(self): + return '--' + self._name + + @property + def cli_type_name(self): + return 'integer' + + @property + def required(self): + return self._required + + @required.setter + def required(self, value): + self._required = value + + @property + def documentation(self): + return self._documentation + + def add_to_parser(self, parser): + parser.add_argument( + self.cli_name, + dest=self.py_name, + type=int, + ) + + def add_to_params(self, parameters, value): + if value is not None: + if self._serialized_name == 'Delay' and value < 0: + raise ParamValidationError( + '--delay must be a non-negative integer, ' + 'got %s' % value + ) + if self._serialized_name == 'MaxAttempts' and value < 1: + raise ParamValidationError( + '--max-attempts must be a positive integer, ' + 'got %s' % value + ) + waiter_config = parameters.get('WaiterConfig', {}) + waiter_config[self._serialized_name] = value + parameters['WaiterConfig'] = waiter_config def register_add_waiters(cli): @@ -203,7 +272,9 @@ def _build_polling_description(self, delay, max_attempts): description = ( ' It will poll every %s seconds until a successful state ' 'has been reached. This will exit with a return code of 255 ' - 'after %s failed checks.' % (delay, max_attempts) + 'after %s failed checks. You can override the default polling ' + 'behavior with ``--delay`` and ``--max-attempts``.' + % (delay, max_attempts) ) return description @@ -212,6 +283,10 @@ class WaiterCaller: def __init__(self, session, waiter_name): self._session = session self._waiter_name = waiter_name + self._waiter_config = None + + def set_waiter_config(self, waiter_config): + self._waiter_config = waiter_config def invoke(self, service_name, operation_name, parameters, parsed_globals): client = self._session.create_client( @@ -221,6 +296,8 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals): verify=parsed_globals.verify_ssl, ) waiter = client.get_waiter(xform_name(self._waiter_name)) + if self._waiter_config is not None: + parameters = dict(parameters, WaiterConfig=self._waiter_config) waiter.wait(**parameters) return 0 @@ -228,6 +305,25 @@ def invoke(self, service_name, operation_name, parameters, parsed_globals): class WaiterStateCommand(ServiceOperation): DESCRIPTION = '' + def _create_argument_table(self): + argument_table = super()._create_argument_table() + argument_table['delay'] = WaiterArgument( + 'delay', DELAY_HELP, 'Delay', + ) + argument_table['max-attempts'] = WaiterArgument( + 'max-attempts', MAX_ATTEMPTS_HELP, 'MaxAttempts', + ) + return argument_table + + def _build_call_parameters(self, args, arg_table): + service_params = super()._build_call_parameters(args, arg_table) + # Strip WaiterConfig from call parameters so it doesn't leak + # into --generate-cli-skeleton or other API model validation. + # The WaiterCaller will inject it back when calling waiter.wait(). + waiter_config = service_params.pop('WaiterConfig', None) + self._operation_caller.set_waiter_config(waiter_config) + return service_params + def create_help_command(self): help_command = super(WaiterStateCommand, self).create_help_command() # Change the operation object's description by changing it to the diff --git a/tests/unit/customizations/test_waiters.py b/tests/unit/customizations/test_waiters.py index 11673e877f2d..13ca4686855b 100644 --- a/tests/unit/customizations/test_waiters.py +++ b/tests/unit/customizations/test_waiters.py @@ -16,6 +16,7 @@ from awscli.customizations.exceptions import ParamValidationError from awscli.customizations.waiters import ( WaitCommand, + WaiterArgument, WaiterCaller, WaiterStateCommand, WaiterStateCommandBuilder, @@ -166,6 +167,13 @@ def test_wait_state_help_command(self): self.assert_contains('``--filters`` (list)') self.assert_contains('======\nOutput\n======\n\nNone') + def test_wait_state_help_shows_waiter_params(self): + self.driver.main(['ec2', 'wait', 'instance-running', 'help']) + self.assert_contains('[--delay ]') + self.assert_contains('[--max-attempts ]') + self.assert_contains('``--delay``') + self.assert_contains('``--max-attempts``') + class TestWait(BaseAWSCommandParamsTest): """This is merely a smoke test. @@ -203,6 +211,53 @@ def test_rds_jobs_complete(self): } self.assert_params_for_cmd(cmdline, result) + def test_ec2_wait_with_delay_and_max_attempts(self): + cmdline = 'ec2 wait instance-running' + cmdline += ' --instance-ids i-12345678' + cmdline += ' --delay 10 --max-attempts 100' + # WaiterConfig is popped by botocore's Waiter.wait() before + # the API call, so only API params appear in last_kwargs. + result = {'InstanceIds': ['i-12345678']} + self.parsed_response = { + 'Reservations': [{'Instances': [{'State': {'Name': 'running'}}]}] + } + self.assert_params_for_cmd(cmdline, result) + + def test_ec2_wait_forwards_waiter_config_end_to_end(self): + # End-to-end check: --delay and --max-attempts must flow from the + # command line all the way into waiter.wait() as WaiterConfig, + # alongside the regular API parameters. + with mock.patch( + 'botocore.client.BaseClient.get_waiter' + ) as mock_get_waiter: + mock_waiter = mock.Mock() + mock_get_waiter.return_value = mock_waiter + cmdline = ( + 'ec2 wait instance-running ' + '--instance-ids i-12345678 ' + '--delay 10 --max-attempts 100' + ) + self.run_cmd(cmdline, expected_rc=0) + mock_get_waiter.assert_called_with('instance_running') + mock_waiter.wait.assert_called_once_with( + InstanceIds=['i-12345678'], + WaiterConfig={'Delay': 10, 'MaxAttempts': 100}, + ) + + def test_ec2_wait_without_waiter_flags_omits_waiter_config(self): + # When neither --delay nor --max-attempts are supplied, no + # WaiterConfig kwarg should be passed to waiter.wait(). + with mock.patch( + 'botocore.client.BaseClient.get_waiter' + ) as mock_get_waiter: + mock_waiter = mock.Mock() + mock_get_waiter.return_value = mock_waiter + cmdline = 'ec2 wait instance-running --instance-ids i-12345678' + self.run_cmd(cmdline, expected_rc=0) + mock_waiter.wait.assert_called_once_with( + InstanceIds=['i-12345678'] + ) + class TestWaiterStateCommandBuilder(unittest.TestCase): def setUp(self): @@ -258,13 +313,17 @@ def test_build_waiter_state_cmds(self): instance_running_cmd.DESCRIPTION, 'My waiter description. It will poll every 1 seconds until ' 'a successful state has been reached. This will exit with a ' - 'return code of 255 after 10 failed checks.', + 'return code of 255 after 10 failed checks. You can override ' + 'the default polling behavior with ``--delay`` and ' + '``--max-attempts``.', ) self.assertEqual( bucket_exists_cmd.DESCRIPTION, 'My waiter description. It will poll every 1 seconds until ' 'a successful state has been reached. This will exit with a ' - 'return code of 255 after 10 failed checks.', + 'return code of 255 after 10 failed checks. You can override ' + 'the default polling behavior with ``--delay`` and ' + '``--max-attempts``.', ) @@ -299,7 +358,9 @@ def test_config_provided_description(self): description, 'My description. It will poll every 5 seconds until a ' 'successful state has been reached. This will exit with a ' - 'return code of 255 after 20 failed checks.', + 'return code of 255 after 20 failed checks. You can override ' + 'the default polling behavior with ``--delay`` and ' + '``--max-attempts``.', ) def test_error_acceptor(self): @@ -311,7 +372,9 @@ def test_error_acceptor(self): 'Wait until MyException is thrown when polling with ' '``my-operation``. It will poll every 5 seconds until a ' 'successful state has been reached. This will exit with a ' - 'return code of 255 after 20 failed checks.', + 'return code of 255 after 20 failed checks. You can override ' + 'the default polling behavior with ``--delay`` and ' + '``--max-attempts``.', ) def test_status_acceptor(self): @@ -323,7 +386,9 @@ def test_status_acceptor(self): 'Wait until 200 response is received when polling with ' '``my-operation``. It will poll every 5 seconds until a ' 'successful state has been reached. This will exit with a ' - 'return code of 255 after 20 failed checks.', + 'return code of 255 after 20 failed checks. You can override ' + 'the default polling behavior with ``--delay`` and ' + '``--max-attempts``.', ) def test_path_acceptor(self): @@ -336,7 +401,9 @@ def test_path_acceptor(self): 'Wait until JMESPath query MyResource.name returns running when ' 'polling with ``my-operation``. It will poll every 5 seconds ' 'until a successful state has been reached. This will exit with ' - 'a return code of 255 after 20 failed checks.', + 'a return code of 255 after 20 failed checks. You can override ' + 'the default polling behavior with ``--delay`` and ' + '``--max-attempts``.', ) def test_path_all_acceptor(self): @@ -349,7 +416,9 @@ def test_path_all_acceptor(self): 'Wait until JMESPath query MyResource[].name returns running for ' 'all elements when polling with ``my-operation``. It will poll ' 'every 5 seconds until a successful state has been reached. ' - 'This will exit with a return code of 255 after 20 failed checks.', + 'This will exit with a return code of 255 after 20 failed checks. ' + 'You can override the default polling behavior with ``--delay`` ' + 'and ``--max-attempts``.', ) def test_path_any_acceptor(self): @@ -362,7 +431,9 @@ def test_path_any_acceptor(self): 'Wait until JMESPath query MyResource[].name returns running for ' 'any element when polling with ``my-operation``. It will poll ' 'every 5 seconds until a successful state has been reached. ' - 'This will exit with a return code of 255 after 20 failed checks.', + 'This will exit with a return code of 255 after 20 failed checks. ' + 'You can override the default polling behavior with ``--delay`` ' + 'and ``--max-attempts``.', ) @@ -399,3 +470,129 @@ def test_invoke(self): # Ensure the wait command was called properly. waiter.wait.assert_called_with(Foo='bar', Baz='biz') + + def test_invoke_with_set_waiter_config(self): + waiter = mock.Mock() + waiter_name = 'my_waiter' + session = mock.Mock() + session.create_client.return_value.get_waiter.return_value = waiter + + parameters = {'Foo': 'bar'} + parsed_globals = mock.Mock() + parsed_globals.region = 'us-east-1' + parsed_globals.endpoint_url = 'myurl' + parsed_globals.verify_ssl = True + + waiter_caller = WaiterCaller(session, waiter_name) + waiter_caller.set_waiter_config({'Delay': 10, 'MaxAttempts': 50}) + waiter_caller.invoke( + 'myservice', 'MyWaiter', parameters, parsed_globals + ) + + waiter.wait.assert_called_with( + Foo='bar', + WaiterConfig={'Delay': 10, 'MaxAttempts': 50}, + ) + + def test_invoke_without_waiter_config(self): + waiter = mock.Mock() + waiter_name = 'my_waiter' + session = mock.Mock() + session.create_client.return_value.get_waiter.return_value = waiter + + parameters = {'Foo': 'bar'} + parsed_globals = mock.Mock() + parsed_globals.region = 'us-east-1' + parsed_globals.endpoint_url = 'myurl' + parsed_globals.verify_ssl = True + + waiter_caller = WaiterCaller(session, waiter_name) + waiter_caller.invoke( + 'myservice', 'MyWaiter', parameters, parsed_globals + ) + + # Without set_waiter_config, no WaiterConfig should be passed. + waiter.wait.assert_called_with(Foo='bar') + + +class TestWaiterArgument(unittest.TestCase): + def test_add_delay_to_params(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + params = {} + arg.add_to_params(params, 10) + self.assertEqual(params, {'WaiterConfig': {'Delay': 10}}) + + def test_add_max_attempts_to_params(self): + arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts') + params = {} + arg.add_to_params(params, 50) + self.assertEqual(params, {'WaiterConfig': {'MaxAttempts': 50}}) + + def test_both_params_together(self): + delay_arg = WaiterArgument('delay', 'help text', 'Delay') + max_arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts') + params = {} + delay_arg.add_to_params(params, 10) + max_arg.add_to_params(params, 50) + self.assertEqual( + params, + {'WaiterConfig': {'Delay': 10, 'MaxAttempts': 50}}, + ) + + def test_none_value_not_added(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + params = {} + arg.add_to_params(params, None) + self.assertEqual(params, {}) + + def test_cli_name(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + self.assertEqual(arg.cli_name, '--delay') + + def test_cli_type_name(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + self.assertEqual(arg.cli_type_name, 'integer') + + def test_required_defaults_to_false(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + self.assertFalse(arg.required) + + def test_delay_rejects_negative(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + params = {} + with self.assertRaises(Exception) as ctx: + arg.add_to_params(params, -1) + self.assertIn('--delay', str(ctx.exception)) + self.assertEqual(params, {}) + + def test_delay_rejects_negative_large(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + params = {} + with self.assertRaises(Exception): + arg.add_to_params(params, -100) + + def test_delay_accepts_zero(self): + arg = WaiterArgument('delay', 'help text', 'Delay') + params = {} + arg.add_to_params(params, 0) + self.assertEqual(params, {'WaiterConfig': {'Delay': 0}}) + + def test_max_attempts_rejects_zero(self): + arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts') + params = {} + with self.assertRaises(Exception) as ctx: + arg.add_to_params(params, 0) + self.assertIn('--max-attempts', str(ctx.exception)) + self.assertEqual(params, {}) + + def test_max_attempts_rejects_negative(self): + arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts') + params = {} + with self.assertRaises(Exception): + arg.add_to_params(params, -1) + + def test_max_attempts_accepts_one(self): + arg = WaiterArgument('max-attempts', 'help text', 'MaxAttempts') + params = {} + arg.add_to_params(params, 1) + self.assertEqual(params, {'WaiterConfig': {'MaxAttempts': 1}})