Add capture_time_hooks to make_graphed_callables for non-capturable per-callable hooks#2831

Open
buptzyb wants to merge 2 commits into NVIDIA:main from buptzyb:robinz/capture-time-hooks

Conversation

Contributor

@buptzyb buptzyb commented Apr 3, 2026

Description

Summary

make_graphed_callables previously had no mechanism to run per-callable
hooks during warmup/capture that must execute outside the CUDA graph
capture context (i.e., hooks that are inherently non-capturable, such as
CPU-side state updates or FSDP parameter un-shard/re-shard calls).

This PR adds a new capture_time_hooks parameter to
make_graphed_callables that accepts per-callable hooks invoked at
capture time (warmup iterations and graph capture), but intentionally
executed outside the CUDA graph capture context so they are not
recorded into the graph and will not be replayed.

Changes

  • transformer_engine/pytorch/graph.py:
    • Add capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]]
      parameter to make_graphed_callables
    • Invoke hooks around forward and backward passes during both warmup
      iterations and graph capture, in both the _order is not None and
      _order is None capture paths
    • Hook dict structure mirrors PyTorch's _forward_pre_hooks /
      _forward_hooks format: Dict[hook_type, Dict[handle_id, hook_fn]]
      where hook types are forward_pre, forward, pre_backward,
      backward
    • Rename parameter from capture_hooks to capture_time_hooks, with an
      updated docstring clarifying the non-capturable semantics
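
The hook structure described above can be sketched as follows. The hook names and signatures follow this PR's description; `unshard_params` and `reshard_params` are hypothetical stand-ins for e.g. FSDP state management, not functions from the PR.

```python
# Hypothetical sketch of the capture_time_hooks structure: one entry per
# callable, each a Dict[hook_type, Dict[handle_id, hook_fn]] (or None).
# unshard_params/reshard_params are illustrative placeholders.

def unshard_params(module, args, kwargs):
    pass  # e.g. FSDP all-gather of full params; must run outside graph capture

def reshard_params(module):
    pass  # e.g. free the full params again after backward

capture_time_hooks = [
    {
        "forward_pre": {0: unshard_params},         # hook(module, args, kwargs)
        "forward": {0: lambda m, args, out: None},  # hook(module, args, outputs)
        "pre_backward": {0: lambda m: None},        # hook(module)
        "backward": {0: reshard_params},            # hook(module)
    },
    None,  # second callable: no capture-time hooks
]
```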

Motivation

Used by Megatron-LM's FSDP integration: during CUDA Graph capture,
PyTorch's memory allocator is frozen, causing FSDP parameter
un-shard/re-shard (which requires allocation) to fail. By routing FSDP
hooks through capture_time_hooks, they execute outside the capture
context and are manually driven at the right points during
warmup/capture, while the graph itself only records pure GPU compute.

Hook Invocation Order

For each callable at each warmup/capture iteration:

  1. forward_pre hooks — before func(*args, **kwargs)
  2. (CUDA graph capture context entered)
  3. Forward pass
  4. (CUDA graph capture context exited)
  5. forward hooks — after forward
  6. pre_backward hooks — before torch.autograd.backward
  7. Backward pass
  8. backward hooks — after backward
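
A minimal sketch of this driver loop, with a dummy context manager standing in for the CUDA graph capture context and zero-argument hooks for brevity (names and signatures are illustrative, not the PR's actual helpers):

```python
from contextlib import contextmanager

trace = []  # records the invocation order for illustration

@contextmanager
def capture_context():  # stand-in for the real CUDA graph capture context
    trace.append("capture enter")
    yield
    trace.append("capture exit")

def run_iteration(hooks, forward, backward):
    for h in hooks.get("forward_pre", {}).values():
        h()                      # 1. outside capture
    with capture_context():      # 2./4. only this region would be recorded
        out = forward()          # 3. forward pass
    for h in hooks.get("forward", {}).values():
        h()                      # 5. outside capture
    for h in hooks.get("pre_backward", {}).values():
        h()                      # 6. outside capture
    backward(out)                # 7. backward pass
    for h in hooks.get("backward", {}).values():
        h()                      # 8. outside capture

hooks = {k: {0: (lambda k=k: trace.append(k))}
         for k in ("forward_pre", "forward", "pre_backward", "backward")}
run_iteration(hooks, lambda: trace.append("fwd"), lambda out: trace.append("bwd"))
```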

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

buptzyb and others added 2 commits April 3, 2026 05:20
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Contributor

greptile-apps bot commented Apr 3, 2026

Greptile Summary

This PR adds a capture_time_hooks parameter to make_graphed_callables, letting callers inject per-callable hooks that run outside the CUDA graph capture context (e.g., FSDP un-shard/re-shard). It also refactors the monolithic warmup loop into _run_warmup_forward / _run_warmup_backward helpers and unifies the _order is None / _order is not None warmup paths.

  • P1 – hook output inconsistency: _run_warmup_forward passes flattened outputs to forward hooks, while both capture paths pass the raw (unflattened) return value of func(...). Any hook that inspects output tensors will observe different types depending on the phase.
  • P1 – pre_warmup_hook/post_warmup_hook regression: These hooks were previously called once per callable; after the refactor they are called once globally, silently breaking callers that relied on per-callable invocation.

Confidence Score: 3/5

Not safe to merge as-is; two P1 behavioral regressions need fixing before the feature is reliable.

The forward-hook output-type mismatch (flattened vs. raw) and the silently changed pre/post_warmup_hook invocation semantics are both present defects in the changed code that can produce wrong behavior for existing and new callers without any runtime error.

transformer_engine/pytorch/graph.py — specifically the _run_warmup_forward forward-hook call and the pre/post_warmup_hook placement

Important Files Changed

Filename Overview
transformer_engine/pytorch/graph.py Adds capture_time_hooks parameter and refactors warmup into helper functions; two behavioral issues: (1) forward hooks get flattened outputs during warmup but raw outputs during capture, breaking hook API contract; (2) pre/post_warmup_hook are now called once globally instead of once per callable, silently changing existing semantics.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant make_graphed_callables
    participant _run_warmup_forward
    participant _run_warmup_backward
    participant CUDAGraph

    Caller->>make_graphed_callables: capture_time_hooks=[{forward_pre, forward, pre_backward, backward}]
    make_graphed_callables->>make_graphed_callables: pre_warmup_hook() [once globally]

    loop warmup_iter in range(num_warmup_iters)
        make_graphed_callables->>_run_warmup_forward: (func_idx, func)
        _run_warmup_forward->>Caller: forward_pre hooks(func, args, kwargs) [outside capture]
        _run_warmup_forward->>_run_warmup_forward: func(*args, **kwargs)
        _run_warmup_forward->>Caller: forward hooks(func, args, FLATTENED outputs) [outside capture]
        make_graphed_callables->>_run_warmup_backward: (func_idx, func, outputs, warmup_iter)
        _run_warmup_backward->>Caller: pre_backward hooks(func) [outside capture]
        _run_warmup_backward->>_run_warmup_backward: torch.autograd.backward(...)
        _run_warmup_backward->>Caller: backward hooks(func) [outside capture]
    end

    make_graphed_callables->>make_graphed_callables: post_warmup_hook() [once globally]

    note over make_graphed_callables,CUDAGraph: Graph Capture Phase
    make_graphed_callables->>Caller: forward_pre hooks(func, args, kwargs) [outside capture]
    make_graphed_callables->>CUDAGraph: begin capture
    CUDAGraph->>CUDAGraph: func(*args, **kwargs) [recorded]
    make_graphed_callables->>CUDAGraph: end capture
    make_graphed_callables->>Caller: forward hooks(func, args, RAW outputs) [outside capture]
    make_graphed_callables->>Caller: pre_backward hooks(func) [outside capture]
    make_graphed_callables->>CUDAGraph: begin capture
    CUDAGraph->>CUDAGraph: autograd.backward(...) [recorded]
    make_graphed_callables->>CUDAGraph: end capture
    make_graphed_callables->>Caller: backward hooks(func) [outside capture]

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +492 to +503
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
    hook.remove()

if (
    capture_time_hooks is not None
    and func_idx < len(capture_time_hooks)
    and capture_time_hooks[func_idx] is not None
    and "forward" in capture_time_hooks[func_idx]
):
    for hook in capture_time_hooks[func_idx]["forward"].values():
        hook(func, args, outputs)
Contributor
P1 forward hook receives inconsistent output types between warmup and capture

During warmup, _run_warmup_forward calls _tree_flatten on the raw outputs and then passes the already-flattened list to the forward hook:

outputs, _ = _tree_flatten(func(*args, **kwargs))
# ...
hook(func, args, outputs)   # ← outputs is a flat list of tensors

But in both capture paths (_order is None and _order is not None, lines ~690 and ~913), the forward hook is called with the raw, unflattened return value of func(...):

with _graph_context_wrapper(fwd_graph, pool=mempool):
    outputs = func(*args, **kwargs)   # raw module output
# ...
hook(func, args, outputs)   # ← outputs is the original nested structure

Any hook that inspects or stores output will see different types depending on whether it is called during warmup or during capture. For FSDP re-shard hooks that key off tensor identities or shapes, this can silently produce wrong behavior. The fix is to separate the flatten step from the hook call in _run_warmup_forward so the hook always receives the original (unflattened) return value.
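
The suggested fix can be sketched like this. `tree_flatten` below is a toy stand-in for graph.py's `_tree_flatten`, and the dummy `func` and hook dict are illustrative:

```python
# Sketch of the suggested fix: pass the raw return value to the forward
# hook, and flatten only afterwards for internal bookkeeping.

def tree_flatten(obj):
    # Toy flatten of nested dicts/tuples/lists into a leaf list.
    if isinstance(obj, dict):
        return [l for v in obj.values() for l in tree_flatten(v)[0]], None
    if isinstance(obj, (list, tuple)):
        return [l for v in obj for l in tree_flatten(v)[0]], None
    return [obj], None

def func(x):
    return {"y": x * 2, "z": (x + 1,)}  # nested output structure

seen = []
forward_hooks = {0: lambda f, args, out: seen.append(type(out))}

args = (3,)
raw_outputs = func(*args)                       # raw, unflattened
for hook in forward_hooks.values():
    hook(func, args, raw_outputs)               # hook sees the same type as in capture
outputs, _ = tree_flatten(raw_outputs)          # flat list, warmup bookkeeping only
```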

Collaborator
This is a good point.

Comment on lines 583 to +630
with torch.cuda.stream(torch.cuda.Stream()):
for func_idx, func in zip(warmup_func_idx, warmup_func):
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx]

def hook_fn(
module, inputs, outputs, func_idx=func_idx
): # pylint: disable=unused-argument
modules = set()
if isinstance(module, TransformerEngineBaseModule):
modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
if module._module_groups is None:
raise RuntimeError(
"module._module_groups should have been initialized by warmup"
)
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
modules.add(basic_op)
if modules:
if func_idx not in visited_te_modules:
visited_te_modules[func_idx] = modules
else:
visited_te_modules[func_idx].update(modules)

if pre_warmup_hook is not None:
pre_warmup_hook()
for warmup_iter in range(num_warmup_iters):
hooks = []
for module in func.modules():
hook = module.register_forward_hook(hook_fn)
hooks.append(hook)
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
outputs_requiring_grad = tuple(
o for o in outputs if o is not None and o.requires_grad
)
torch.autograd.backward(
outputs_requiring_grad,
grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
)
grad_inputs = tuple(input.grad for input in inputs)
if pre_warmup_hook is not None:
pre_warmup_hook()

# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
# registered to these params are not wrongly triggered.
num_required_grad_sample_args = sum(
arg.requires_grad for arg in flatten_sample_args[func_idx]
)
required_grad_input_idx = []
for i, arg in enumerate(static_input_surface):
if arg.requires_grad:
required_grad_input_idx.append(i)
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
if (
grad_inputs[grad_inputs_idx] is None
and grad_inputs_idx < num_required_grad_sample_args
):
if not allow_unused_input:
raise RuntimeError(
"The input tensor requires grad, but the grad is None after"
" backward pass."
for warmup_iter in range(num_warmup_iters):
if _order is None:
# All forwards in order, then all backwards in reverse order.
warmup_outputs = []
for func_idx, func in zip(warmup_func_idx, warmup_func):
outputs = _run_warmup_forward(func_idx, func)
warmup_outputs.append((func_idx, func, outputs))
if is_training:
for func_idx, func, outputs in reversed(warmup_outputs):
_run_warmup_backward(func_idx, func, outputs, warmup_iter)
else:
# Follow _order exactly, mirroring the capture phase.
per_fwd_outputs = {} # per_callable_fwd_idx -> flattened outputs
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
# Forward pass for chunk c_id.
m_chunk = c_id - 1
for l_no in range(_num_layers_per_chunk[m_chunk]):
per_callable_fwd_idx = (
_prefix_num_layers[m_chunk] * num_microbatches
) + (fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
func = callables[_prefix_num_layers[m_chunk] + l_no]
outputs = _run_warmup_forward(per_callable_fwd_idx, func)
per_fwd_outputs[per_callable_fwd_idx] = outputs
fwd_idx[m_chunk] += 1
elif ceil(c_id) == c_id:
# Backward pass for chunk -c_id.
if is_training:
m_chunk = -c_id - 1
for l_no in reversed(range(_num_layers_per_chunk[m_chunk])):
per_callable_bwd_idx = (
_prefix_num_layers[m_chunk] * num_microbatches
) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
func = callables[_prefix_num_layers[m_chunk] + l_no]
outputs = per_fwd_outputs[per_callable_bwd_idx]
_run_warmup_backward(
per_callable_bwd_idx, func, outputs, warmup_iter
)
elif (
grad_inputs[grad_inputs_idx] is not None
and grad_inputs_idx >= num_required_grad_sample_args
):
module_params_with_grad.append(static_input_surface[inputs_idx])
if len(module_params_with_grad) != len(per_callable_module_params[func_idx]):
if warmup_iter != 0:
raise RuntimeError(
"no-grad params should only be used as inputs in the first warmup"
f" iteration, but found in iteration {warmup_iter}"
)
per_callable_module_params[func_idx] = tuple(module_params_with_grad)
static_input_surface = flatten_sample_args[func_idx] + tuple(
module_params_with_grad
)
per_callable_static_input_surfaces[func_idx] = static_input_surface

# Run wgrad. This is essential for some TE modules when they have
# delay_wgrad_compute enabled.
need_backward_dw = False
for module in visited_te_modules.get(func_idx, set()):
if hasattr(module, "need_backward_dw") and module.need_backward_dw():
need_backward_dw = True
module.backward_dw()
need_bwd_dw_graph[func_idx] = need_backward_dw
else:
grad_inputs = None
del outputs, grad_inputs
if post_warmup_hook is not None:
post_warmup_hook()
bwd_idx[m_chunk] += 1

if post_warmup_hook is not None:
post_warmup_hook()
Contributor
P1 pre_warmup_hook/post_warmup_hook invocation count changed silently

In the original code these hooks were called once per callable, wrapping that callable's num_warmup_iters warmup iterations:

# original
for func_idx, func in zip(warmup_func_idx, warmup_func):
    if pre_warmup_hook is not None:
        pre_warmup_hook()
    for warmup_iter in range(num_warmup_iters):
        ...
    if post_warmup_hook is not None:
        post_warmup_hook()

The refactor moved the hooks outside both loops, so they now fire once globally, regardless of how many callables are registered. If a caller passes a pre_warmup_hook that relies on being invoked per-callable (e.g. to set up per-callable CUDA state before each callable's warmup), it will silently run only once and skip N-1 callables. This is an unintentional behavior regression for existing users of these hooks.
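
One way to restore the old semantics for the `_order is None` path would be to move the hooks back inside the per-callable loop. A rough sketch, with dummies standing in for the PR's actual helpers:

```python
# Sketch only: restores per-callable pre/post_warmup_hook invocation for
# the _order is None path. _run_warmup_forward, warmup_func_idx, etc. are
# dummies standing in for the PR's actual helpers and state.
calls = []

def pre_warmup_hook():
    calls.append("pre")

def post_warmup_hook():
    calls.append("post")

def _run_warmup_forward(func_idx, func):
    calls.append(f"fwd{func_idx}")

warmup_func_idx, warmup_func = [0, 1], [object(), object()]
num_warmup_iters = 2

for func_idx, func in zip(warmup_func_idx, warmup_func):
    if pre_warmup_hook is not None:
        pre_warmup_hook()            # once per callable, as before the refactor
    for warmup_iter in range(num_warmup_iters):
        _run_warmup_forward(func_idx, func)
    if post_warmup_hook is not None:
        post_warmup_hook()           # once per callable
```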

Collaborator
We should confirm that this doesn't cause problems, and probably update the docs for pre_warmup_hook/post_warmup_hook to be more explicit about its behavior. That said, I don't think this function will break the use-cases from https://github.com/NVIDIA/TransformerEngine/pull/2435/changes#r2850937378.

CC @lhb8125

Comment on lines +1377 to +1378
- ``"forward_pre"``: dict of hooks called *before* the forward pass.
Hook signature: ``hook(module, args, kwargs)``.
Collaborator
Doesn't this sound weird? Let's make it "pre_forward" to be consistent with "pre_backward".

Comment on lines +1377 to +1378
- ``"forward_pre"``: dict of hooks called *before* the forward pass.
Hook signature: ``hook(module, args, kwargs)``.
Collaborator
Why is capture_time_hooks[callable_idx]["forward_pre"] a dict? As far as I can tell, we just iterate through the hooks, so it would be more straightforward to make it a list (and similar for "forward", "pre_backward", "backward").
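
The list-based alternative being proposed might look like this (hypothetical API, not what the PR currently implements; zero-argument hooks for brevity):

```python
# Hypothetical list-based variant: hooks run in list order, no handle ids.
ran = []
capture_time_hooks = [
    {
        "forward_pre": [lambda: ran.append("unshard"), lambda: ran.append("log")],
        "backward": [lambda: ran.append("reshard")],
    },
]

# Iteration is then a plain loop over the list, preserving order:
for hook in capture_time_hooks[0].get("forward_pre", []):
    hook()
for hook in capture_time_hooks[0].get("backward", []):
    hook()
```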

Comment on lines 583 to +630
with torch.cuda.stream(torch.cuda.Stream()):
for func_idx, func in zip(warmup_func_idx, warmup_func):
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx]

def hook_fn(
module, inputs, outputs, func_idx=func_idx
): # pylint: disable=unused-argument
modules = set()
if isinstance(module, TransformerEngineBaseModule):
modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
if module._module_groups is None:
raise RuntimeError(
"module._module_groups should have been initialized by warmup"
)
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
modules.add(basic_op)
if modules:
if func_idx not in visited_te_modules:
visited_te_modules[func_idx] = modules
else:
visited_te_modules[func_idx].update(modules)

if pre_warmup_hook is not None:
pre_warmup_hook()
for warmup_iter in range(num_warmup_iters):
hooks = []
for module in func.modules():
hook = module.register_forward_hook(hook_fn)
hooks.append(hook)
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
if is_training:
inputs = tuple(i for i in static_input_surface if i.requires_grad)
with _none_grad_context_wrapper(inputs):
outputs_requiring_grad = tuple(
o for o in outputs if o is not None and o.requires_grad
)
torch.autograd.backward(
outputs_requiring_grad,
grad_tensors=tuple(torch.empty_like(o) for o in outputs_requiring_grad),
)
grad_inputs = tuple(input.grad for input in inputs)
if pre_warmup_hook is not None:
pre_warmup_hook()

# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
# registered to these params are not wrongly triggered.
num_required_grad_sample_args = sum(
arg.requires_grad for arg in flatten_sample_args[func_idx]
)
required_grad_input_idx = []
for i, arg in enumerate(static_input_surface):
if arg.requires_grad:
required_grad_input_idx.append(i)
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
if (
grad_inputs[grad_inputs_idx] is None
and grad_inputs_idx < num_required_grad_sample_args
):
if not allow_unused_input:
raise RuntimeError(
"The input tensor requires grad, but the grad is None after"
" backward pass."
for warmup_iter in range(num_warmup_iters):
if _order is None:
# All forwards in order, then all backwards in reverse order.
warmup_outputs = []
for func_idx, func in zip(warmup_func_idx, warmup_func):
outputs = _run_warmup_forward(func_idx, func)
warmup_outputs.append((func_idx, func, outputs))
if is_training:
for func_idx, func, outputs in reversed(warmup_outputs):
_run_warmup_backward(func_idx, func, outputs, warmup_iter)
else:
# Follow _order exactly, mirroring the capture phase.
per_fwd_outputs = {} # per_callable_fwd_idx -> flattened outputs
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
# Forward pass for chunk c_id.
m_chunk = c_id - 1
for l_no in range(_num_layers_per_chunk[m_chunk]):
per_callable_fwd_idx = (
_prefix_num_layers[m_chunk] * num_microbatches
) + (fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
func = callables[_prefix_num_layers[m_chunk] + l_no]
outputs = _run_warmup_forward(per_callable_fwd_idx, func)
per_fwd_outputs[per_callable_fwd_idx] = outputs
fwd_idx[m_chunk] += 1
elif ceil(c_id) == c_id:
# Backward pass for chunk -c_id.
if is_training:
m_chunk = -c_id - 1
for l_no in reversed(range(_num_layers_per_chunk[m_chunk])):
per_callable_bwd_idx = (
_prefix_num_layers[m_chunk] * num_microbatches
) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
func = callables[_prefix_num_layers[m_chunk] + l_no]
outputs = per_fwd_outputs[per_callable_bwd_idx]
_run_warmup_backward(
per_callable_bwd_idx, func, outputs, warmup_iter
)
elif (
grad_inputs[grad_inputs_idx] is not None
and grad_inputs_idx >= num_required_grad_sample_args
):
module_params_with_grad.append(static_input_surface[inputs_idx])
if len(module_params_with_grad) != len(per_callable_module_params[func_idx]):
if warmup_iter != 0:
raise RuntimeError(
"no-grad params should only be used as inputs in the first warmup"
f" iteration, but found in iteration {warmup_iter}"
)
per_callable_module_params[func_idx] = tuple(module_params_with_grad)
static_input_surface = flatten_sample_args[func_idx] + tuple(
module_params_with_grad
)
per_callable_static_input_surfaces[func_idx] = static_input_surface

# Run wgrad. This is essential for some TE modules when they have
# delay_wgrad_compute enabled.
need_backward_dw = False
for module in visited_te_modules.get(func_idx, set()):
if hasattr(module, "need_backward_dw") and module.need_backward_dw():
need_backward_dw = True
module.backward_dw()
need_bwd_dw_graph[func_idx] = need_backward_dw
else:
grad_inputs = None
del outputs, grad_inputs
if post_warmup_hook is not None:
post_warmup_hook()
bwd_idx[m_chunk] += 1

if post_warmup_hook is not None:
post_warmup_hook()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should confirm that this doesn't cause problems, and probably update the docs for pre_warmup_hook/post_warmup_hook to be more explicit about its behavior. That said, I don't think this function will break the use-cases from https://github.com/NVIDIA/TransformerEngine/pull/2435/changes#r2850937378.

CC @lhb8125


if (
    capture_time_hooks is not None
    and func_idx < len(capture_time_hooks)
Collaborator
@timmoon10 timmoon10 Apr 3, 2026


I'd prefer if length mismatches triggered errors. This pattern is inviting silent bugs.
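
A fail-fast check along these lines might look like the following (illustrative, not the PR's code; the function name is made up):

```python
def validate_capture_time_hooks(capture_time_hooks, callables):
    """Raise on a length mismatch instead of silently skipping hooks."""
    if capture_time_hooks is None:
        return
    if len(capture_time_hooks) != len(callables):
        raise ValueError(
            f"capture_time_hooks has {len(capture_time_hooks)} entries, "
            f"but {len(callables)} callables were provided"
        )

validate_capture_time_hooks(None, [object()])     # ok: hooks disabled
validate_capture_time_hooks([None], [object()])   # ok: lengths match
try:
    validate_capture_time_hooks([None], [object(), object()])
except ValueError as e:
    mismatch_error = str(e)
```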

Comment on lines 1323 to +1325
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
capture_time_hooks: Optional[List[Optional[Dict[str, Dict]]]] = None,
Collaborator
Nit: It's kind of uncomfortable that the APIs for the warmup and capture hooks are inconsistent. Also, the capture_time_hooks is slightly misleading because they are applied during warmup, not just during capture.

Comment on lines +492 to +503
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()

if (
capture_time_hooks is not None
and func_idx < len(capture_time_hooks)
and capture_time_hooks[func_idx] is not None
and "forward" in capture_time_hooks[func_idx]
):
for hook in capture_time_hooks[func_idx]["forward"].values():
hook(func, args, outputs)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point.
