Skip to content

Commit

Permalink
Don't try to bind unused inputs in the Training frontend (microsoft#6166
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Sergii Dymchenko authored Dec 18, 2020
1 parent adc2071 commit 824ef9a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
19 changes: 12 additions & 7 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,15 +804,20 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de
else:
iobinding = self._eval_io_binding

# Get the list of the actual session inputs because unused inputs can be removed.
input_nodes = self._training_session.get_inputs()
input_node_names = [input_node.name for input_node in input_nodes]

# Bind input tensors
for input, input_desc in zip(inputs, inputs_desc):
device_index = _utils.get_device_index_from_input(input)
iobinding.bind_input(input_desc.name,
input.device.type,
device_index,
_utils.dtype_torch_to_numpy(input.dtype),
list(input.size()),
input.data_ptr())
if input_desc.name in input_node_names:
device_index = _utils.get_device_index_from_input(input)
iobinding.bind_input(input_desc.name,
input.device.type,
device_index,
_utils.dtype_torch_to_numpy(input.dtype),
list(input.size()),
input.data_ptr())

# Bind output tensors
outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1428,3 +1428,21 @@ def testORTTrainerOptionsDisabledAdasumFlag(test_input):

actual_values = orttrainer_options.ORTTrainerOptions(test_input)
assert actual_values.distributed.enable_adasum == False

def testORTTrainerUnusedInput():
class UnusedInputModel(torch.nn.Module):
def __init__(self):
super(UnusedInputModel, self).__init__()
def forward(self, x, y):
return torch.mean(x)

model = UnusedInputModel()
model_desc = {'inputs': [('x', [1]), ('y', [1])], 'outputs': [('loss', [], True)]}
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)

# Run just one step to make sure there are no iobinding errors for the unused input.
try:
trainer.train_step(torch.FloatTensor([1.0]), torch.FloatTensor([1.0]))
except RuntimeError:
pytest.fail("RuntimeError doing train_step with an unused input.")

0 comments on commit 824ef9a

Please sign in to comment.