diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 8d5472ba30601..c79ba59a07ee9 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
 ```
 
+#### ORTMODULE_ATEN_SDPA_FALLBACK
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, this is disabled. This env var enables a pre-export fallback of torch.nn.functional.scaled_dot_product_attention to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel for execution. NOTE: only use this feature if the user model leverages memory-efficient attention WITHOUT masking (i.e. attn_mask=None). Use a GPU profiling tool such as NVIDIA Nsight Systems to identify whether the user model leverages memory-efficient attention.
+
+  ```bash
+  export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
+  unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
+  ```
+
 ### 2.2 Memory Optimization
 
 Q: *Want to run a bigger batch size?*
diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 22dcf4eb92411..76fe0ee91d4c6 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1794,7 +1794,20 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
   }
 
   std::vector<ArgDef> output_args;
-  for (const auto& output : node_def.outputs) {
+  for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
+    // Only produce gradient outputs that map to an input of the source node, and emit an empty
+    // ArgDef for inputs that do not require a gradient. Required for ORTMODULE_ATEN_SDPA_FALLBACK.
+    if (static_cast<int>(output_index) >= GetSrcNodeInputSize()) {
+      continue;
+    }
+
+    if (!IsGradientRequiredForSrcNodeInput(static_cast<int>(output_index))) {
+      output_args.emplace_back(ArgDef());
+      continue;
+    }
+
+    const auto& output = node_def.outputs[output_index];
+
     if (output.find("GI(") == 0) {
       size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
       output_args.emplace_back(GI(index));
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index a8590cea22887..97650f509ac88 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -25,6 +25,7 @@
 # 'is_tensor' is optional, if not present, the default is False.
 
 import json
+import os
 
 from onnxruntime.capi import _pybind_state as C
 
@@ -276,3 +277,39 @@ def upsample_nearest3d_gradient():
 @register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
 def upsample_bicubic2d_gradient():
     return _upsample_gradient("upsample_bicubic2d_backward", 2)
+
+
+ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
+if ATEN_SDPA_FALLBACK:
+    # based on the following internal PyTorch kernel for efficient attention:
+    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
+    @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
+    def scaled_dot_product_attention_gradient():
+        return [
+            (
+                "Constant",
+                [],
+                ["grad_input_mask"],
+                {"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
+            ),
+            (
+                ("ATen", "org.pytorch.aten"),
+                [
+                    "GO(0)",
+                    "I(0)",
+                    "I(1)",
+                    "I(2)",
+                    "I(3)",
+                    "O(0)",
+                    "O(1)",
+                    "O(2)",
+                    "O(3)",
+                    "I(5)",
+                    "grad_input_mask",
+                    "I(6)",
+                    "I(7)",
+                ],
+                ["GI(0)", "GI(1)", "GI(2)", ""],
+                {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
+            ),
+        ]
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 10e7f60b7da0f..c48968efbb262 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
+import os
 from typing import Callable
 
 import torch
@@ -969,3 +970,26 @@ def softmax(g, input, dim, dtype=None):
         softmax = g.op("Softmax", casted_input, axis_i=dim)
 
     return softmax
+
+
+ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
+if ATEN_SDPA_FALLBACK:
+    # based on the following internal PyTorch kernel for efficient attention:
+    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
+    @register_symbolic("scaled_dot_product_attention")
+    def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+        compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
+        return g.op(
+            "org.pytorch.aten::ATen",
+            query,
+            key,
+            value,
+            attn_mask,
+            compute_logsumexp,
+            dropout_p_f,
+            is_causal,
+            scale,
+            operator_s="_scaled_dot_product_efficient_attention",
+            outputs=4,
+        )[0]
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 6f5b03685e801..fe59c398d7abb 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6953,3 +6953,74 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
     else:
         if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
             del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
+
+
+@pytest.mark.skipif(
+    Version(torch.__version__) < Version("2.3.0"),
+    reason="torch.nn.attention module was introduced in PyTorch 2.3.0",
+)
+def test_aten_attention():
+    from torch.nn.attention import SDPBackend, sdpa_kernel
+
+    class _NeuralNetAttention(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, q, k, v, attn_mask=None):
+            with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+                return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
+
+    def gen_inputs(device, dtype):
+        return [
+            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
+            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
+            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
+        ]
+
+    def run_step(model, inputs, attn_mask=None):
+        prediction = model(*inputs, attn_mask)
+        prediction.sum().backward()
+        return prediction
+
+    device = "cuda"
+
+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"  # TESTING WITHOUT ATTN_MASK
+
+    pt_model = _NeuralNetAttention().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))
+
+    # reset manual seed to reset the generator
+    torch.manual_seed(2333)
+    pt_input = gen_inputs(device=device, dtype=torch.float32)
+    ort_input = copy.deepcopy(pt_input)
+    pt_prediction = run_step(pt_model, pt_input)
+    ort_prediction = run_step(ort_model, ort_input)
+
+    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
+    _test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
+    _test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
+    _test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)
+
+    execution_mgr = ort_model._torch_module._execution_manager._training_manager
+    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name
+
+    path = os.path.join(
+        execution_mgr._debug_options.save_onnx_models.path,
+        _get_onnx_file_name(
+            execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
+        ),
+    )
+
+    onnx_model = onnx.load(path)
+    onnx_nodes = onnx_model.graph.node
+
+    mem_eff_attn_nodes = 0
+    for node in onnx_nodes:
+        if "ATen" in node.name:
+            for attr in node.attribute:
+                if b"_scaled_dot_product_efficient_attention" in attr.s:
+                    mem_eff_attn_nodes += 1
+
+    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes found"
+
+    del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]
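For context, a minimal usage sketch of the fallback this change enables (not part of the patch): it assumes a CUDA build of onnxruntime-training and a PyTorch release that provides torch.nn.functional.scaled_dot_product_attention; the model class, tensor shapes, and names are illustrative only. Because the registries above read ORTMODULE_ATEN_SDPA_FALLBACK at import time, the variable has to be set before onnxruntime.training.ortmodule is imported.

```python
import os

# Set the fallback flag before the ortmodule package is imported, since the symbolic and
# gradient registries read it at import time.
os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"

import torch
from onnxruntime.training.ortmodule import ORTModule


class SimpleAttention(torch.nn.Module):  # hypothetical model, for illustration only
    def forward(self, q, k, v):
        # No attn_mask: the fallback targets the mask-free memory-efficient attention path.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


device = "cuda"
model = ORTModule(SimpleAttention().to(device))

q, k, v = (torch.randn(2, 8, 128, 64, device=device, requires_grad=True) for _ in range(3))
out = model(q, k, v)  # first call exports the model; SDPA is captured as the ATen efficient-attention op
out.sum().backward()  # backward runs the registered _scaled_dot_product_efficient_attention_backward gradient
```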