Commit 9499426

having doubts about the skip connection proposed in the paper. default to turning it off
lucidrains committed Sep 25, 2023
1 parent d36b1bc commit 9499426
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
11 changes: 8 additions & 3 deletions voicebox_pytorch/voicebox_pytorch.py
@@ -267,7 +267,9 @@ def __init__(
         attn_dropout=0,
         attn_flash = False,
         adaptive_rmsnorm = False,
-        adaptive_rmsnorm_cond_dim_in = None
+        adaptive_rmsnorm_cond_dim_in = None,
+        use_unet_skip_connection = False,
+        skip_connect_scale = None
     ):
         super().__init__()
         assert divisible_by(depth, 2)
@@ -280,9 +282,11 @@ def __init__(
         else:
             rmsnorm_klass = RMSNorm

+        self.skip_connect_scale = default(skip_connect_scale, 2 ** -0.5)
+
         for ind in range(depth):
             layer = ind + 1
-            has_skip = layer > (depth // 2)
+            has_skip = use_unet_skip_connection and layer > (depth // 2)

             self.layers.append(nn.ModuleList([
                 nn.Linear(dim * 2, dim) if has_skip else None,
@@ -315,7 +319,8 @@ def forward(
             if not exists(skip_combiner):
                 skip_connects.append(x)
             else:
-                x = torch.cat((x, skip_connects.pop()), dim = -1)
+                skip_connect = skip_connects.pop() * self.skip_connect_scale
+                x = torch.cat((x, skip_connect), dim = -1)
                 x = skip_combiner(x)

             attn_input = attn_prenorm(x, **rmsnorm_kwargs)
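
For context, below is a minimal sketch of the u-net style skip connection that this commit makes opt-in (and off by default). The class name SkipConnectStack and the nn.TransformerEncoderLayer stand-in blocks are illustrative assumptions, not the repository's actual Transformer; only the skip logic mirrors the diff above: the first half of the layers cache their hidden states, the second half pop them, scale by 2 ** -0.5, concatenate along the feature dimension, and project back to dim with nn.Linear(dim * 2, dim).

import torch
from torch import nn

# hypothetical, simplified module - illustrates only the skip connection pattern from the diff

class SkipConnectStack(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        use_unet_skip_connection = False,   # off by default, matching this commit
        skip_connect_scale = None
    ):
        super().__init__()
        assert depth % 2 == 0

        # same default scaling as in the diff
        self.skip_connect_scale = skip_connect_scale if skip_connect_scale is not None else 2 ** -0.5

        self.layers = nn.ModuleList([])

        for ind in range(depth):
            layer = ind + 1
            # only the second half of the layers receives a skip, and only when enabled
            has_skip = use_unet_skip_connection and layer > (depth // 2)

            self.layers.append(nn.ModuleList([
                nn.Linear(dim * 2, dim) if has_skip else None,                   # combines skip + current hidden
                nn.TransformerEncoderLayer(dim, nhead = 8, batch_first = True)   # stand-in for attention + feedforward
            ]))

    def forward(self, x):
        skip_connects = []

        for skip_combiner, block in self.layers:
            if skip_combiner is None:
                # first half: push hidden states onto a stack
                skip_connects.append(x)
            else:
                # second half: pop, scale, concatenate, project back to dim
                skip = skip_connects.pop() * self.skip_connect_scale
                x = skip_combiner(torch.cat((x, skip), dim = -1))

            x = block(x)

        return x

# usage: skips stay off unless explicitly requested
model = SkipConnectStack(dim = 512, depth = 4, use_unet_skip_connection = True)
out = model(torch.randn(1, 10, 512))   # (batch, seq, dim) -> (1, 10, 512)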
