Skip to content

Commit

Permalink
checkpoint off by default
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 3, 2022
1 parent c4b8f16 commit facf810
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 13 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'x-clip',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.8.0',
version = '0.8.2',
license='MIT',
description = 'X-CLIP',
author = 'Phil Wang',
Expand Down
16 changes: 4 additions & 12 deletions x_clip/x_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,7 @@ def matrix_diag(t):

# checkpointing helper function

def make_checkpointable(fn, **kwargs):
if isinstance(fn, nn.ModuleList):
return [maybe(make_checkpointable)(el, **kwargs) for el in fn]

condition = kwargs.pop('condition', None)

if exists(condition) and not condition(fn):
return fn

def make_checkpointable(fn):
@wraps(fn)
def inner(*args):
input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
Expand Down Expand Up @@ -246,12 +238,12 @@ def forward(
mask = None
):
can_checkpoint = self.training and self.checkpoint_during_training
checkpoint_fn = make_checkpointable if can_checkpoint else identity

x = self.norm_in(x)

for attn, ff in self.layers:
if can_checkpoint:
attn, ff = map(make_checkpointable, (attn, ff))
attn, ff = map(checkpoint_fn, (attn, ff))

x = attn(x, mask, rotary_pos_emb) + x
x = ff(x) + x
Expand Down Expand Up @@ -404,7 +396,7 @@ def __init__(
simclr_temperature = 0.1,
image_ssl_loss_weight = 0.05,
multiview_loss_weight = 0.1,
checkpoint_during_training = True,
checkpoint_during_training = False,
**kwargs
):
super().__init__()
Expand Down

0 comments on commit facf810

Please sign in to comment.