Add capture_time_hooks to make_graphed_callables for non-capturable per-callable hooks #2831

buptzyb wants to merge 2 commits into NVIDIA:main from
Conversation
Signed-off-by: Robin Zhang <robinz@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds a

Confidence Score: 3/5

Not safe to merge as-is; two P1 behavioral regressions need fixing before the feature is reliable. The forward-hook output-type mismatch (flattened vs. raw) and the silently changed pre/post_warmup_hook invocation semantics are both present defects in the changed code that can produce wrong behavior for existing and new callers without any runtime error.

Important Files Changed: transformer_engine/pytorch/graph.py — specifically the _run_warmup_forward forward-hook call and the pre/post_warmup_hook placement
Sequence Diagram

sequenceDiagram
participant Caller
participant make_graphed_callables
participant _run_warmup_forward
participant _run_warmup_backward
participant CUDAGraph
Caller->>make_graphed_callables: capture_time_hooks=[{forward_pre, forward, pre_backward, backward}]
make_graphed_callables->>make_graphed_callables: pre_warmup_hook() [once globally]
loop warmup_iter in range(num_warmup_iters)
make_graphed_callables->>_run_warmup_forward: (func_idx, func)
_run_warmup_forward->>Caller: forward_pre hooks(func, args, kwargs) [outside capture]
_run_warmup_forward->>_run_warmup_forward: func(*args, **kwargs)
_run_warmup_forward->>Caller: forward hooks(func, args, FLATTENED outputs) [outside capture]
make_graphed_callables->>_run_warmup_backward: (func_idx, func, outputs, warmup_iter)
_run_warmup_backward->>Caller: pre_backward hooks(func) [outside capture]
_run_warmup_backward->>_run_warmup_backward: torch.autograd.backward(...)
_run_warmup_backward->>Caller: backward hooks(func) [outside capture]
end
make_graphed_callables->>make_graphed_callables: post_warmup_hook() [once globally]
note over make_graphed_callables,CUDAGraph: Graph Capture Phase
make_graphed_callables->>Caller: forward_pre hooks(func, args, kwargs) [outside capture]
make_graphed_callables->>CUDAGraph: begin capture
CUDAGraph->>CUDAGraph: func(*args, **kwargs) [recorded]
make_graphed_callables->>CUDAGraph: end capture
make_graphed_callables->>Caller: forward hooks(func, args, RAW outputs) [outside capture]
make_graphed_callables->>Caller: pre_backward hooks(func) [outside capture]
make_graphed_callables->>CUDAGraph: begin capture
CUDAGraph->>CUDAGraph: autograd.backward(...) [recorded]
make_graphed_callables->>CUDAGraph: end capture
make_graphed_callables->>Caller: backward hooks(func) [outside capture]
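The hook ordering shown in the diagram can be exercised with a small pure-Python driver. This is an illustrative sketch only — `run_one_iteration` is not TE code, and the hook signatures follow the ones described in this PR:

```python
# Illustrative driver for the per-iteration hook ordering:
# forward_pre -> forward -> pre_backward -> backward, all outside capture.
order = []

hooks = {
    "forward_pre": [lambda f, a, kw: order.append("forward_pre")],
    "forward": [lambda f, a, out: order.append("forward")],
    "pre_backward": [lambda f: order.append("pre_backward")],
    "backward": [lambda f: order.append("backward")],
}

def run_one_iteration(func, args, hooks):
    for h in hooks["forward_pre"]:
        h(func, args, {})          # before func(*args, **kwargs)
    out = func(*args)
    for h in hooks["forward"]:
        h(func, args, out)         # after forward
    for h in hooks["pre_backward"]:
        h(func)                    # before torch.autograd.backward would run
    # torch.autograd.backward(...) would execute here in the real flow
    for h in hooks["backward"]:
        h(func)                    # after backward
    return out

run_one_iteration(lambda x: x * 2, (3,), hooks)
```

Running one iteration appends the four hook names in exactly the order the diagram shows.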
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
    hook.remove()

if (
    capture_time_hooks is not None
    and func_idx < len(capture_time_hooks)
    and capture_time_hooks[func_idx] is not None
    and "forward" in capture_time_hooks[func_idx]
):
    for hook in capture_time_hooks[func_idx]["forward"].values():
        hook(func, args, outputs)
forward hook receives inconsistent output types between warmup and capture

During warmup, _run_warmup_forward calls _tree_flatten on the raw outputs and then passes the already-flattened list to the forward hook:

    outputs, _ = _tree_flatten(func(*args, **kwargs))
    # ...
    hook(func, args, outputs)  # ← outputs is a flat list of tensors

But in both capture paths (_order is None and _order is not None, lines ~690 and ~913), the forward hook is called with the raw, unflattened return value of func(...):

    with _graph_context_wrapper(fwd_graph, pool=mempool):
        outputs = func(*args, **kwargs)  # raw module output
    # ...
    hook(func, args, outputs)  # ← outputs is the original nested structure

Any hook that inspects or stores output will see different types depending on whether it is called during warmup or during capture. For FSDP re-shard hooks that key off tensor identities or shapes, this can silently produce wrong behavior. The fix is to separate the flatten step from the hook call in _run_warmup_forward so the hook always receives the original (unflattened) return value.
with torch.cuda.stream(torch.cuda.Stream()):
    for func_idx, func in zip(warmup_func_idx, warmup_func):
        args = sample_args[func_idx]
        kwargs = sample_kwargs[func_idx]
        static_input_surface = per_callable_static_input_surfaces[func_idx]

        def hook_fn(
            module, inputs, outputs, func_idx=func_idx
        ):  # pylint: disable=unused-argument
            modules = set()
            if isinstance(module, TransformerEngineBaseModule):
                modules.add(module)
            # If forward is called on a BasicOperation directly the hook will run
            elif isinstance(module, BasicOperation):
                modules.add(module)
            # If forward is called on a te.ops.Sequential it is not called on its constituent ops
            elif isinstance(module, Sequential):
                if module._module_groups is None:
                    raise RuntimeError(
                        "module._module_groups should have been initialized by warmup"
                    )
                for module_group in module._module_groups:
                    if isinstance(module_group, OperationFuser):
                        for basic_op in module_group._basic_ops:
                            modules.add(basic_op)
            if modules:
                if func_idx not in visited_te_modules:
                    visited_te_modules[func_idx] = modules
                else:
                    visited_te_modules[func_idx].update(modules)

        if pre_warmup_hook is not None:
            pre_warmup_hook()
        for warmup_iter in range(num_warmup_iters):
            hooks = []
            for module in func.modules():
                hook = module.register_forward_hook(hook_fn)
                hooks.append(hook)
            outputs, _ = _tree_flatten(func(*args, **kwargs))
            for hook in hooks:
                hook.remove()
            if is_training:
                inputs = tuple(i for i in static_input_surface if i.requires_grad)
                with _none_grad_context_wrapper(inputs):
                    outputs_requiring_grad = tuple(
                        o for o in outputs if o is not None and o.requires_grad
                    )
                    torch.autograd.backward(
                        outputs_requiring_grad,
                        grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
                    )
                    grad_inputs = tuple(input.grad for input in inputs)
                if pre_warmup_hook is not None:
                    pre_warmup_hook()

                # Filter module params that get None grad from grad_inputs and remove them
                # from static_input_surface. This is to ensure that the backward hooks
                # registered to these params are not wrongly triggered.
                num_required_grad_sample_args = sum(
                    arg.requires_grad for arg in flatten_sample_args[func_idx]
                )
                required_grad_input_idx = []
                for i, arg in enumerate(static_input_surface):
                    if arg.requires_grad:
                        required_grad_input_idx.append(i)
                module_params_with_grad = []
                for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
                    if (
                        grad_inputs[grad_inputs_idx] is None
                        and grad_inputs_idx < num_required_grad_sample_args
                    ):
                        if not allow_unused_input:
                            raise RuntimeError(
                                "The input tensor requires grad, but the grad is None after"
                                " backward pass."

for warmup_iter in range(num_warmup_iters):
    if _order is None:
        # All forwards in order, then all backwards in reverse order.
        warmup_outputs = []
        for func_idx, func in zip(warmup_func_idx, warmup_func):
            outputs = _run_warmup_forward(func_idx, func)
            warmup_outputs.append((func_idx, func, outputs))
        if is_training:
            for func_idx, func, outputs in reversed(warmup_outputs):
                _run_warmup_backward(func_idx, func, outputs, warmup_iter)
    else:
        # Follow _order exactly, mirroring the capture phase.
        per_fwd_outputs = {}  # per_callable_fwd_idx -> flattened outputs
        fwd_idx = [0] * num_model_chunks
        bwd_idx = [0] * num_model_chunks
        for c_id in _order:
            if c_id > 0:
                # Forward pass for chunk c_id.
                m_chunk = c_id - 1
                for l_no in range(_num_layers_per_chunk[m_chunk]):
                    per_callable_fwd_idx = (
                        _prefix_num_layers[m_chunk] * num_microbatches
                    ) + (fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
                    func = callables[_prefix_num_layers[m_chunk] + l_no]
                    outputs = _run_warmup_forward(per_callable_fwd_idx, func)
                    per_fwd_outputs[per_callable_fwd_idx] = outputs
                    fwd_idx[m_chunk] += 1
            elif ceil(c_id) == c_id:
                # Backward pass for chunk -c_id.
                if is_training:
                    m_chunk = -c_id - 1
                    for l_no in reversed(range(_num_layers_per_chunk[m_chunk])):
                        per_callable_bwd_idx = (
                            _prefix_num_layers[m_chunk] * num_microbatches
                        ) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
                        func = callables[_prefix_num_layers[m_chunk] + l_no]
                        outputs = per_fwd_outputs[per_callable_bwd_idx]
                        _run_warmup_backward(
                            per_callable_bwd_idx, func, outputs, warmup_iter
                        )

                    elif (
                        grad_inputs[grad_inputs_idx] is not None
                        and grad_inputs_idx >= num_required_grad_sample_args
                    ):
                        module_params_with_grad.append(static_input_surface[inputs_idx])
                if len(module_params_with_grad) != len(per_callable_module_params[func_idx]):
                    if warmup_iter != 0:
                        raise RuntimeError(
                            "no-grad params should only be used as inputs in the first warmup"
                            f" iteration, but found in iteration {warmup_iter}"
                        )
                    per_callable_module_params[func_idx] = tuple(module_params_with_grad)
                    static_input_surface = flatten_sample_args[func_idx] + tuple(
                        module_params_with_grad
                    )
                    per_callable_static_input_surfaces[func_idx] = static_input_surface

                # Run wgrad. This is essential for some TE modules when they have
                # delay_wgrad_compute enabled.
                need_backward_dw = False
                for module in visited_te_modules.get(func_idx, set()):
                    if hasattr(module, "need_backward_dw") and module.need_backward_dw():
                        need_backward_dw = True
                        module.backward_dw()
                need_bwd_dw_graph[func_idx] = need_backward_dw
            else:
                grad_inputs = None
            del outputs, grad_inputs
            if post_warmup_hook is not None:
                post_warmup_hook()
                        bwd_idx[m_chunk] += 1

if post_warmup_hook is not None:
    post_warmup_hook()
pre_warmup_hook/post_warmup_hook invocation count changed silently

In the original code these hooks were called once per callable, wrapping that callable's num_warmup_iters warmup iterations:

    # original
    for func_idx, func in zip(warmup_func_idx, warmup_func):
        if pre_warmup_hook is not None:
            pre_warmup_hook()
        for warmup_iter in range(num_warmup_iters):
            ...
        if post_warmup_hook is not None:
            post_warmup_hook()

The refactor moved the hooks outside both loops, so they now fire once globally, regardless of how many callables are registered. If a caller passes a pre_warmup_hook that relies on being invoked per-callable (e.g. to set up per-callable CUDA state before each callable's warmup), it will silently run only once and skip N-1 callables. This is an unintentional behavior regression for existing users of these hooks.
We should confirm that this doesn't cause problems, and probably update the docs for pre_warmup_hook/post_warmup_hook to be more explicit about its behavior. That said, I don't think this function will break the use-cases from https://github.com/NVIDIA/TransformerEngine/pull/2435/changes#r2850937378.
CC @lhb8125
- ``"forward_pre"``: dict of hooks called *before* the forward pass.
  Hook signature: ``hook(module, args, kwargs)``.

Doesn't this sound weird? Let's make it "pre_forward" to be consistent with "pre_backward".
- ``"forward_pre"``: dict of hooks called *before* the forward pass.
  Hook signature: ``hook(module, args, kwargs)``.

Why is capture_time_hooks[callable_idx]["forward_pre"] a dict? As far as I can tell, we just iterate through the hooks, so it would be more straightforward to make it a list (and similar for "forward", "pre_backward", "backward").
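The reviewer's suggestion can be sketched like this. The list-based structure and the `fire` helper are illustrative, not the merged API:

```python
# Since hooks are only ever iterated, a flat list per hook type is simpler
# than a Dict[handle_id, hook_fn]. Structure shown is illustrative.
calls = []

capture_time_hooks = [
    {  # hooks for callable 0, keyed by hook type
        "pre_forward": [lambda m, a, kw: calls.append("pre_forward")],
        "forward": [lambda m, a, out: calls.append("forward")],
    },
]

def fire(hooks_for_callable, hook_type, *hook_args):
    # No handle ids needed: just iterate the list in registration order.
    for hook in hooks_for_callable.get(hook_type, []):
        hook(*hook_args)

fire(capture_time_hooks[0], "pre_forward", None, (), {})
fire(capture_time_hooks[0], "forward", None, (), None)
```

A dict keyed by handle id only matters if hooks need to be removed individually, which this API does not expose.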
if (
    capture_time_hooks is not None
    and func_idx < len(capture_time_hooks)
I'd prefer if length mismatches triggered errors. This pattern is inviting silent bugs.
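The strict validation the reviewer asks for could look like the following sketch; `validate_capture_time_hooks` is an illustrative helper, not the merged code:

```python
# Fail loudly when capture_time_hooks does not match the number of callables,
# instead of silently skipping out-of-range indices.

def validate_capture_time_hooks(capture_time_hooks, num_callables):
    if capture_time_hooks is None:
        return
    if len(capture_time_hooks) != num_callables:
        raise ValueError(
            f"capture_time_hooks has {len(capture_time_hooks)} entries,"
            f" but {num_callables} callables were provided"
        )

validate_capture_time_hooks([None, None], 2)  # matching lengths: no error
try:
    validate_capture_time_hooks([None], 2)    # mismatch: raises
except ValueError as e:
    error_message = str(e)
```

Validating once up front turns a silent per-callable skip into an immediate, diagnosable error at `make_graphed_callables` call time.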
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None,
Nit: It's kind of uncomfortable that the APIs for the warmup and capture hooks are inconsistent. Also, the name capture_time_hooks is slightly misleading because these hooks are applied during warmup, not just during capture.
Description
Summary

make_graphed_callables previously had no mechanism to run per-callable hooks during warmup/capture that must execute outside the CUDA graph capture context (i.e., hooks that are inherently non-capturable, such as CPU-side state updates or FSDP parameter un-shard/re-shard calls).

This PR adds a new capture_time_hooks parameter to make_graphed_callables that accepts per-callable hooks invoked at capture time (warmup iterations and graph capture), but intentionally executed outside the CUDA graph capture context so they are not recorded into the graph and will not be replayed.
Changes

transformer_engine/pytorch/graph.py:
- Added a capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] parameter to make_graphed_callables.
- Hooks are invoked during warmup iterations and graph capture, in both the _order is not None and _order is None capture paths.
- Hook containers follow the _forward_pre_hooks/_forward_hooks format: Dict[hook_type, Dict[handle_id, hook_fn]], where hook types are forward_pre, forward, pre_backward, backward.
- Renamed capture_hooks → capture_time_hooks with an updated docstring clarifying the non-capturable semantics.
Motivation

Used by Megatron-LM's FSDP integration: during CUDA Graph capture, PyTorch's memory allocator is frozen, causing FSDP parameter un-shard/re-shard (which requires allocation) to fail. By routing FSDP hooks through capture_time_hooks, they execute outside the capture context and are manually driven at the right points during warmup/capture, while the graph itself only records pure GPU compute.
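The split between capturable compute and non-capturable allocation work can be shown conceptually. `graph_capture`, `unshard`, and `reshard` are illustrative stand-ins, not FSDP or CUDA graph APIs:

```python
# Conceptual sketch: allocation-heavy unshard/reshard work runs outside the
# capture context, while the captured region records only compute.
from contextlib import contextmanager

events = []

@contextmanager
def graph_capture():
    # Stand-in for the CUDA graph capture context; real capture would freeze
    # the allocator between begin and end.
    events.append("capture_begin")
    yield
    events.append("capture_end")

def unshard():
    # Would allocate memory for gathered params -> must stay outside capture.
    events.append("unshard")

def reshard():
    # Would free/return memory -> must also stay outside capture.
    events.append("reshard")

unshard()                     # forward_pre hook, driven outside capture
with graph_capture():
    events.append("forward")  # only pure GPU compute is recorded
reshard()                     # forward hook, driven outside capture
```

Because the hooks fire outside the `graph_capture()` block, replaying the graph re-runs only the recorded compute, never the allocation-dependent hook work.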
Hook Invocation Order

For each callable at each warmup/capture iteration:

1. forward_pre hooks — before func(*args, **kwargs)
2. forward hooks — after forward
3. pre_backward hooks — before torch.autograd.backward
4. backward hooks — after backward

Type of change
Checklist: