Skip to content

Commit 41b9e45

Browse files
authored
Avoid recompilation caused by scan_layers (#9367)
1 parent 9b0b02f commit 41b9e45

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

torch_xla/experimental/scan_layers.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def scan_layers(layers: Iterable[torch.nn.Module],
3434
3535
input_data: The input to be given to the first layer from `layers`.
3636
37-
partition_fn: (Optional[Callable]) The graph parition function passed to AOTAutograd.
37+
partition_fn: (Optional[Callable]) The graph partition function passed to AOTAutograd.
3838
Since this function uses AOTAutograd to trace `fn`, you may override what computation
3939
happen in the forward and backward passes by specifying different partition functions.
4040
`default_partition` implies no activation checkpointing. You may specify
@@ -76,16 +76,12 @@ def scan_layers(layers: Iterable[torch.nn.Module],
7676
stacked_buffers = tree_map(lambda *tensors: torch.stack(tensors, dim=0),
7777
*buffers_list)
7878

79-
# Use the first layer as the example/template layer.
80-
from copy import deepcopy
81-
example_layer = deepcopy(first_layer)
82-
8379
# Define the function to apply at each step
8480
def one_layer(carry, params_buffers):
8581
# Apply the current layer's weights and biases to the example layer,
8682
# then run the resulting layer.
8783
output = torch.func.functional_call( # type: ignore
88-
example_layer, params_buffers, carry, strict=True)
84+
first_layer, params_buffers, carry, strict=True)
8985
return output, None
9086

9187
stacked_params_buffers = (stacked_params, stacked_buffers)

0 commit comments

Comments (0)