diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py
index 8a1877c9fe84d..accad0c6f34af 100644
--- a/orttraining/orttraining/python/training/orttrainer.py
+++ b/orttraining/orttraining/python/training/orttrainer.py
@@ -804,15 +804,21 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de
         else:
             iobinding = self._eval_io_binding
 
+        # Unused inputs can be removed from the session, so collect the names of the
+        # inputs the session actually has. A set keeps each lookup in the loop O(1).
+        input_node_names = {input_node.name for input_node in self._training_session.get_inputs()}
+
         # Bind input tensors
         for input, input_desc in zip(inputs, inputs_desc):
-            device_index = _utils.get_device_index_from_input(input)
-            iobinding.bind_input(input_desc.name,
-                                 input.device.type,
-                                 device_index,
-                                 _utils.dtype_torch_to_numpy(input.dtype),
-                                 list(input.size()),
-                                 input.data_ptr())
+            # Bind only the inputs that are still present in the session.
+            if input_desc.name in input_node_names:
+                device_index = _utils.get_device_index_from_input(input)
+                iobinding.bind_input(input_desc.name,
+                                     input.device.type,
+                                     device_index,
+                                     _utils.dtype_torch_to_numpy(input.dtype),
+                                     list(input.size()),
+                                     input.data_ptr())
 
         # Bind output tensors
         outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc)
diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
index c6c17482a5c08..23ad32f5b62fa 100644
--- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
+++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
@@ -1428,3 +1428,25 @@ def testORTTrainerOptionsDisabledAdasumFlag(test_input):
 
     actual_values = orttrainer_options.ORTTrainerOptions(test_input)
     assert actual_values.distributed.enable_adasum == False
+
+
+def testORTTrainerUnusedInput():
+    """Check that train_step works when the model's forward ignores one of its inputs."""
+    class UnusedInputModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            # `y` is deliberately ignored so that it is an unused model input.
+            return torch.mean(x)
+
+    model = UnusedInputModel()
+    model_desc = {'inputs': [('x', [1]), ('y', [1])], 'outputs': [('loss', [], True)]}
+    optim_config = optim.LambConfig(lr=0.001)
+    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)
+
+    # Run just one step to make sure there are no iobinding errors for the unused input.
+    try:
+        trainer.train_step(torch.FloatTensor([1.0]), torch.FloatTensor([1.0]))
+    except RuntimeError:
+        pytest.fail("RuntimeError doing train_step with an unused input.")