Fix and enable a few ORTModule Unit Tests (microsoft#19847)
### Fix and enable a few ORTModule Unit Tests

Fix `test_bert_inputs_with_dynamic_shape` and `test_bert_result_with_layerwise_recompute`, which were producing NaN loss in the ORT run.

The root cause is that the logic generating the attention mask test data was incorrect: only 0 or 1 is allowed in the mask, but the test data contained many other values. (The reason this did not show up with older versions of transformers, e.g. v4.4.2 or v4.16.2, is that they do not contain
huggingface/transformers@d3cb288,
which increases the mask scaling to a much larger number, causing an overflow to inf.)
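
To illustrate the failure mode, here is a hedged sketch (not code from this PR) using a simplified stand-in for HuggingFace's extended attention mask: with the newer scaling based on `torch.finfo(dtype).min`, any mask entry other than 0 or 1 produces a bias whose magnitude overflows fp16 and then propagates NaN through softmax and the loss.

```python
import torch


# Hedged sketch (not code from this PR): simplified stand-in for HuggingFace's
# extended attention mask. Newer transformers scale the additive bias with
# torch.finfo(dtype).min, so mask entries other than 0/1 overflow in fp16.
def extended_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # 1 = attend, 0 = padding; converted to an additive bias on attention scores.
    return (1.0 - mask[:, None, None, :].to(dtype)) * torch.finfo(dtype).min


good_mask = torch.ones(2, 8)                      # only 0/1 values: bias stays finite
bad_mask = torch.randint(0, 100, (2, 8)).float()  # what the broken test data looked like

print(extended_attention_mask(good_mask, torch.float16).isinf().any())  # tensor(False)
print(extended_attention_mask(bad_mask, torch.float16).isinf().any())   # almost always tensor(True)
```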

Another improvement found while investigating with the convergence tools: don't dump activations during the model export phase; otherwise the dumped data may contain results from the export-time PyTorch run, which is confusing when comparing against the stock PyTorch run results.


pengwa authored Mar 12, 2024
1 parent 0c078df commit 3e954da
Showing 3 changed files with 58 additions and 49 deletions.
2 changes: 1 addition & 1 deletion docs/ORTModule_Convergence_Notes.md
@@ -89,7 +89,7 @@ The limitation of `GlobalSubscriberManager` is, only 'nn.Module's forward output
dump the intermediate tensors in a `nn.Module`'s forward function, refer to the following example:

```diff
+ from onnxruntime.training.utils import inspect_activation
+ from onnxruntime.training.utils.hooks import inspect_activation
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config: BloomConfig):
...
```
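
For context, a hedged usage sketch of `inspect_activation` from its new import path; the module and tensor names below are made up for illustration and are not part of this commit.

```python
# Hedged sketch (not part of this commit): tagging an intermediate tensor inside
# forward() so the statistics subscriber can dump it for convergence comparison.
import torch
from onnxruntime.training.utils.hooks import inspect_activation


class TinyBlock(torch.nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.fc = torch.nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.fc(x)
        # Register the intermediate value under a readable name; the subscriber
        # manager decides whether and where to dump it.
        y = inspect_activation("fc_output", y)
        return torch.relu(y)
```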
@@ -14,6 +14,7 @@
import torch

from ._subscriber_base import RuntimeStates, SubscriberBase
from ._subscriber_manager import ORT_NO_INCREASE_GLOBAL_STEP


class _InspectActivation(torch.autograd.Function):
@@ -176,21 +177,23 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st
display_name = name + " forward run" if is_forward is True else name + " backward run"
output_file_name = name + "_forward" if is_forward is True else name + "_backward"

if tensor is None or not isinstance(tensor, torch.Tensor):
print(f"{display_name} not a torch tensor, value: {tensor}")
return
# Skip dump during model pre-export output schema preparation run and export run.
if ORT_NO_INCREASE_GLOBAL_STEP[0] is False:
if tensor is None or not isinstance(tensor, torch.Tensor):
print(f"{display_name} not a torch tensor, value: {tensor}")
return

step_path = Path(step_folder)
if not step_path.exists():
step_path.mkdir(parents=True, exist_ok=False)
order_file_path = step_path / "order.txt"
tensor_file_path = step_path / output_file_name
step_path = Path(step_folder)
if not step_path.exists():
step_path.mkdir(parents=True, exist_ok=False)
order_file_path = step_path / "order.txt"
tensor_file_path = step_path / output_file_name

with order_file_path.open(mode="a", encoding="utf-8") as f:
f.write(f"{output_file_name}\n")
with order_file_path.open(mode="a", encoding="utf-8") as f:
f.write(f"{output_file_name}\n")

with tensor_file_path.open(mode="w", encoding="utf-8") as f:
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)
with tensor_file_path.open(mode="w", encoding="utf-8") as f:
_summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size)


def _summarize_tensor(
@@ -417,24 +417,38 @@ def _get_bert_for_sequence_classification_model(
return model


def _get_bert_for_sequence_classification_sample_data(device):
    """Returns sample data to be used with BertForSequenceClassification model"""

    input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
    input_mask = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
    labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)

    return input_ids, input_mask, labels


def _generate_attention_mask_for_encoder_following_hf(batch_size, seq_length, device, past_key_values_length=0):
    """Generate attention mask for encoder following the implementation in HuggingFace.

    Be noted: past_key_values_length is 0 for training.

    Generate mask using this
    https://github.com/huggingface/transformers/blame/4f27ee936a861f56f32ea6db138978b274008006/src/transformers/models/bert/modeling_bert.py#L974C81-L974C81
    """
    attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
    return attention_mask


def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device):
"""Returns sample data with random shape to be used with BertForSequenceClassification model"""

x = random.randint(1, 100)
y = random.randint(1, 100)
input_ids = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
input_mask = torch.randint(0, 100, (x, y), dtype=torch.long, device=device)
labels = torch.randint(0, 1, (x,), dtype=torch.long, device=device)
bsz = random.randint(1, 100)
seq_length = random.randint(1, 100)
input_ids = torch.randint(0, 100, (bsz, seq_length), dtype=torch.long, device=device)
input_mask = _generate_attention_mask_for_encoder_following_hf(bsz, seq_length, device)
labels = torch.randint(0, 1, (bsz,), dtype=torch.long, device=device)

return input_ids, input_mask, labels


def _get_bert_for_sequence_classification_sample_data(device):
"""Returns sample data to be used with BertForSequenceClassification model"""

input_ids = torch.randint(0, 100, (32, 64), dtype=torch.long, device=device)
input_mask = _generate_attention_mask_for_encoder_following_hf(32, 64, device)
labels = torch.randint(0, 1, (32,), dtype=torch.long, device=device)

return input_ids, input_mask, labels

@@ -2211,32 +2225,27 @@ def run_step(model, x):
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)


# TODO(askhade): This test is failing with smaller tolerance, need to investigate! Disabling it right now to
# unblock the move to a later version of transformers to resolve security vulnerability.
# (Moving from transformers v4.4.2 to v4.30.0)
# def test_bert_inputs_with_dynamic_shape():
# # create pytorch model with dropout disabled
# pt_model = _get_bert_for_sequence_classification_model(
# "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
# )
# ort_model = ORTModule(copy.deepcopy(pt_model))
def test_bert_inputs_with_dynamic_shape():
# create pytorch model with dropout disabled
pt_model = _get_bert_for_sequence_classification_model(
"cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0
)
ort_model = ORTModule(copy.deepcopy(pt_model))

# def run_step(model, x, y, z):
# outputs = model(x, y, None, None, None, None, z)
# loss = outputs[0]
# loss.backward()
# return outputs[0]
def run_step(model, x, y, z):
outputs = model(x, y, None, None, None, None, z)
loss = outputs[0]
loss.backward()
return outputs[0]

# for _step in range(10):
# x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")
for _step in range(10):
x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda")

# pt_p = run_step(pt_model, x, y, z)
# ort_p = run_step(ort_model, x, y, z)
pt_p = run_step(pt_model, x, y, z)
ort_p = run_step(ort_model, x, y, z)

# _test_helpers.assert_values_are_close(
# ort_p, pt_p, atol=1e-01
# ) # TODO: this assert is failing with smaller tolerance, need to investigate!!
# # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) #TODO - enable this check after the investigation
_test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-01)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)


@pytest.mark.parametrize("device", ["cuda", "cpu"])
@@ -6424,9 +6433,6 @@ def run_step(model, x):
del os.environ["ORTMODULE_CONV_ALGO_SEARCH"]


@pytest.mark.skip(
reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now."
)
def test_bert_result_with_layerwise_recompute():
original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None)
# Create PyTorch model with dropout disabled.
