pos_emb = self.position_embedding_table(...) why create this every time in forward pass? #54

@adityaghai07

Description

Problem

In the current implementation, this tensor is created on every forward pass:

pos_emb = self.position_embedding_table(torch.arange(T, device=device))

This allocates a new tensor on every call, even though its values (0 through T-1) are constant and never change across iterations.

This repeated allocation feels unnecessary and slightly wasteful.


Proposed Solution: Register as a Buffer

We can pre-create the tensor once and register it as a buffer on the module:

self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size)

# Pre-create position indices once
self.register_buffer("position_ids", torch.arange(block_size))

Then in forward():

pos = self.position_ids[:T]
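Putting the two pieces together, here is a minimal runnable sketch of the idea (the class name and hyperparameter values are made up for illustration; the repo's actual model will differ):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):  # hypothetical name, just to demo the pattern
    def __init__(self, vocab_size=65, n_embd=32, block_size=8):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Created once; follows the module across devices, saved in state_dict
        self.register_buffer("position_ids", torch.arange(block_size))

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)                        # (B, T, n_embd)
        pos_emb = self.position_embedding_table(self.position_ids[:T])   # (T, n_embd)
        x = tok_emb + pos_emb                                            # broadcasts over batch
        return self.lm_head(x)                                           # (B, T, vocab_size)

model = TinyLM()
logits = model(torch.randint(0, 65, (4, 8)))
print(logits.shape)  # torch.Size([4, 8, 65])
```

Note that no `device=` argument is needed in `forward()` anymore; the buffer is already wherever the module's weights are.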

Why This Helps

  • Avoids repeated tensor allocation every forward pass
  • Keeps the tensor on the correct device (CPU/GPU) automatically
  • Makes it part of the model state (saved in state_dict, but not trainable)
  • Slightly more efficient and cleaner design for static tensors
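The second and third bullets are easy to verify directly: a registered buffer shows up in `state_dict()` but not in `parameters()`, so it moves with `.to(device)` and gets checkpointed without ever reaching the optimizer. A quick check (hypothetical `Demo` module, just for illustration):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, block_size=8):
        super().__init__()
        self.emb = nn.Embedding(block_size, 4)
        self.register_buffer("position_ids", torch.arange(block_size))

m = Demo()
print("position_ids" in m.state_dict())                  # True: saved with the model
print(any(p is m.position_ids for p in m.parameters()))  # False: not trainable
# m.to("cuda") would move position_ids along with the embedding weights
```

If saving the indices in checkpoints is undesirable, `register_buffer(..., persistent=False)` keeps the device-following behavior while excluding the buffer from `state_dict()`.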

Question

Is this the recommended approach? I am also learning, so I would appreciate some feedback :)
If this makes sense, I can open a small PR for it.
