@@ -34,7 +34,7 @@ def scan_layers(layers: Iterable[torch.nn.Module],
    input_data: The input to be given to the first layer from `layers`.

-   partition_fn: (Optional[Callable]) The graph parition function passed to AOTAutograd.
+   partition_fn: (Optional[Callable]) The graph partition function passed to AOTAutograd.
      Since this function uses AOTAutograd to trace `fn`, you may override what computation
      happen in the forward and backward passes by specifying different partition functions.
      `default_partition` implies no activation checkpointing. You may specify
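For context on the `partition_fn` argument documented above, here is a minimal usage sketch. The layer and input names are placeholders, the import paths (`torch_xla.experimental.scan_layers`, `functorch.compile`) are assumptions about where these symbols live, and `min_cut_rematerialization_partition` is shown only as one example of a non-default AOTAutograd partitioner:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from functorch.compile import min_cut_rematerialization_partition
from torch_xla.experimental.scan_layers import scan_layers  # assumed import path

device = xm.xla_device()
# A stack of structurally identical layers to scan over (placeholder sizes).
blocks = nn.ModuleList([nn.Linear(128, 128) for _ in range(8)]).to(device)
hidden = torch.randn(4, 128, device=device)  # input handed to the first layer

# Default behaviour: `default_partition`, i.e. no activation checkpointing.
out = scan_layers(blocks, hidden)

# Passing a different partitioner changes what gets recomputed in the backward pass.
out_ckpt = scan_layers(blocks, hidden,
                       partition_fn=min_cut_rematerialization_partition)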
@@ -76,16 +76,12 @@ def scan_layers(layers: Iterable[torch.nn.Module],
  stacked_buffers = tree_map(lambda *tensors: torch.stack(tensors, dim=0),
                             *buffers_list)

- # Use the first layer as the example/template layer.
- from copy import deepcopy
- example_layer = deepcopy(first_layer)
-
  # Define the function to apply at each step
  def one_layer(carry, params_buffers):
    # Apply the current layer's weights and biases to the example layer,
    # then run the resulting layer.
    output = torch.func.functional_call(  # type: ignore
-       example_layer, params_buffers, carry, strict=True)
+       first_layer, params_buffers, carry, strict=True)
    return output, None

  stacked_params_buffers = (stacked_params, stacked_buffers)
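The change from the deepcopied `example_layer` to `first_layer` relies on the fact that `torch.func.functional_call` only substitutes the supplied parameters and buffers for the duration of the call and does not mutate the module it is given, so no template copy is needed. A small plain-PyTorch sketch of that behaviour (names are illustrative):

import torch

layer = torch.nn.Linear(4, 4)
saved_weight = layer.weight.detach().clone()

# Run the layer with a replacement set of parameters/buffers.
override = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
x = torch.randn(2, 4)
out = torch.func.functional_call(layer, override, (x,), strict=True)

# Zero weights and bias produce a zero output for this call only...
assert torch.equal(out, torch.zeros(2, 4))
# ...while the module's own parameters are left untouched.
assert torch.equal(layer.weight, saved_weight)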