Commit f2e8283

Update transformer config
1 parent c2d9bcb commit f2e8283

File tree: 2 files changed, +45 -9 lines changed

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# @package _global_
+
+# Short training run with small Transformer encoder for quick testing
+
+defaults:
+  - override /model: bert_transformer_small
+
+logger:
+  wandb:
+    name: debug-transformer-small
+    tags: ["debug"]
+
+trainer:
+  max_steps: 100
+  log_every_n_steps: 5
+  val_check_interval: 5
+  limit_val_batches: 2
+  check_val_every_n_epoch: null
+
+model:
+  net:
+    embedder:
+      d_model: 32
+    encoder:
+      n_layers: 2
+  scheduler:
+    _target_: transformers.get_cosine_schedule_with_warmup
+    _partial_: true
+    num_warmup_steps: 10
+    num_training_steps: ${trainer.max_steps}
+
+data:
+  batch_size: 8
+  per_device_batch_size: 8
+
+compile: false
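
Note on the scheduler block in the config above: it uses Hydra's _partial_ mechanism, so instantiation yields a partially applied transformers.get_cosine_schedule_with_warmup that the training code later binds to an optimizer. A minimal sketch of that flow; the optimizer and the hard-coded step counts here are illustrative assumptions, not taken from this repo:

# Sketch (not part of this commit): resolving a _partial_ scheduler config
# with Hydra and attaching it to an optimizer.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "_target_": "transformers.get_cosine_schedule_with_warmup",
    "_partial_": True,
    "num_warmup_steps": 10,
    "num_training_steps": 100,   # stands in for ${trainer.max_steps}
})

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler_fn = instantiate(cfg)                  # returns a functools.partial
scheduler = scheduler_fn(optimizer=optimizer)    # cosine decay with 10 warmup steps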

glm_experiments/models/components/transformer.py

Lines changed: 9 additions & 9 deletions
@@ -36,12 +36,12 @@ def __init__(self, d_in: int, d_out: int):
 
         super().__init__()
         std = math.sqrt(2 / (d_in + d_out))
-        self.weight: Float[Tensor, d_out d_in] = nn.Parameter(
+        self.weight: Float[Tensor, " d_out d_in"] = nn.Parameter(
             nn.init.trunc_normal_(torch.empty(d_out, d_in), std=std, a=-3 * std, b=3 * std),
             requires_grad=True,
         )
 
-    def forward(self, x: Float[Tensor, ... d_in]) -> Float[Tensor, ... d_out]:
+    def forward(self, x: Float[Tensor, " ... d_in"]) -> Float[Tensor, " ... d_out"]:
         return einsum(x, self.weight, "... d_in, d_out d_in -> ... d_out")
 
     def extra_repr(self):
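
The change in this hunk is just quoting the jaxtyping shape expressions: the unquoted forms are not valid Python, and the leading space inside the quotes is a common jaxtyping convention for keeping linters from treating the string as a forward reference. A self-contained sketch of the same pattern, not the repo's actual Linear class:

# Sketch (not part of this commit): the quoted jaxtyping annotation pattern.
import torch
from torch import Tensor, nn
from jaxtyping import Float
from einops import einsum

class TinyLinear(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        # Shape expressions are strings; "d_out d_in" names the two axes.
        self.weight: Float[Tensor, " d_out d_in"] = nn.Parameter(torch.randn(d_out, d_in))

    def forward(self, x: Float[Tensor, " ... d_in"]) -> Float[Tensor, " ... d_out"]:
        return einsum(x, self.weight, "... d_in, d_out d_in -> ... d_out")

y = TinyLinear(8, 4)(torch.randn(2, 8))   # y.shape == (2, 4)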
@@ -59,7 +59,7 @@ def __init__(self, vocab_size: int, d_model: int):
             requires_grad=True,
         )
 
-    def forward(self, token_ids: Int[Tensor, ...]) -> Float[Tensor, ... d_model]:
+    def forward(self, token_ids: Int[Tensor, " ..."]) -> Float[Tensor, " ... d_model"]:
         return self.weight[token_ids, :]
 
     def extra_repr(self):
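
For context, the lookup in this forward is plain advanced indexing on the embedding matrix; the annotation change only quotes the shape strings. A standalone illustration of that indexing, with placeholder names:

# Sketch (not part of this commit): weight[token_ids, :] broadcasts the
# lookup over all leading batch dims of token_ids.
import torch

vocab_size, d_model = 10, 4
weight = torch.randn(vocab_size, d_model)
token_ids = torch.tensor([[1, 3, 5], [2, 2, 0]])   # (batch, seq)
embeddings = weight[token_ids, :]                  # (batch, seq, d_model)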
@@ -78,7 +78,7 @@ def __init__(self, context_length: int, dim: int, theta: float = 10000.0):
     @staticmethod
     def _init_cache(
         context_length: int, dim: int, theta: float
-    ) -> Float[Tensor, 2 context_length half_dim]:
+    ) -> Float[Tensor, " 2 context_length half_dim"]:
         assert dim % 2 == 0
 
         d = torch.arange(0, dim, 2) / dim
@@ -91,8 +91,8 @@ def _init_cache(
         return torch.stack((cos, sin))
 
     def forward(
-        self, x: Float[Tensor, ... seq d], pos_ids: Int[Tensor, ... seq]
-    ) -> Float[Tensor, ... seq d]:
+        self, x: Float[Tensor, " ... seq d"], pos_ids: Int[Tensor, " ... seq"]
+    ) -> Float[Tensor, " ... seq d"]:
         x1, x2 = rearrange(x, "... (half_d xy) -> xy ... half_d", xy=2)
 
         # einx
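
The two rotary-embedding hunks only quote the annotations on the cache builder and on forward. From the visible lines, _init_cache stacks cosine and sine tables into a (2, context_length, half_dim) tensor; a standalone sketch of that shape, assuming the standard rotary frequency schedule (the intermediate names below are illustrative, not the repo's):

# Sketch (not part of this commit): the cache shape named by the annotation
# " 2 context_length half_dim", using the usual rotary frequency schedule.
import torch

context_length, dim, theta = 16, 8, 10000.0
d = torch.arange(0, dim, 2) / dim                  # (half_dim,)
inv_freq = 1.0 / (theta ** d)                      # one frequency per pair of dims
pos = torch.arange(context_length)                 # (context_length,)
angles = pos[:, None] * inv_freq[None, :]          # (context_length, half_dim)
cache = torch.stack((angles.cos(), angles.sin()))  # (2, context_length, half_dim)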
@@ -172,9 +172,9 @@ def __init__(
 
     def forward(
         self,
-        x: Float[Tensor, ... seq d_k],
-        token_positions: Int[Tensor, ... seq] | None = None,
-    ) -> Float[Tensor, ... seq d_v]:
+        x: Float[Tensor, " ... seq d_k"],
+        token_positions: Int[Tensor, " ... seq"] | None = None,
+    ) -> Float[Tensor, " ... seq d_v"]:
         """
         Args:
             x: The input to perform multi-headed self-attention on.
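
One practical payoff of quoting these annotations is that they can then be enforced at runtime. A minimal sketch, assuming jaxtyping and beartype are installed; this decorator is not used anywhere in the commit itself:

# Sketch (not part of this commit): checking a quoted shape annotation at call time.
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def scale(x: Float[Tensor, " ... seq d"], factor: float) -> Float[Tensor, " ... seq d"]:
    return x * factor

scale(torch.randn(2, 5, 8), 0.5)   # OK: matches "... seq d"
# scale(torch.randn(5), 0.5)       # would raise: only one dimension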
