
get_batch: potential off-by-one in torch.randint upper bound #56

@saslifat-gif

Description


In get_batch, the random starting index is sampled as:

ix = torch.randint(len(data) - block_size, (batch_size,))

But y is sliced as:

y = torch.stack([data[i + 1: i + block_size + 1] for i in ix])

When i equals len(data) - block_size - 1 (the largest value randint can return, since its upper bound is exclusive), y's slice endpoint is i + block_size + 1 = len(data). Slice endpoints are also exclusive, so the slice stays in bounds and y is still block_size tokens long; the last read touches data[len(data) - 1], sitting exactly at the edge of the data. Nothing crashes and nothing is silently clipped, but the bound leaves zero slack.
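The boundary case can be checked directly. This standalone sketch uses a plain Python list as a stand-in for the token tensor, since list slicing follows the same exclusive-endpoint rules as the tensor indexing above:

```python
block_size = 8
data = list(range(100))  # stand-in for the token tensor

# torch.randint(high, ...) samples from [0, high), so the largest
# possible start index under the current code is:
i_max = (len(data) - block_size) - 1  # 91

x = data[i_max : i_max + block_size]        # indices 91..98
y = data[i_max + 1 : i_max + block_size + 1]  # indices 92..99

print(len(x), len(y))     # 8 8 — both chunks are full length
print(y[-1] == data[-1])  # True — y's last token is the final element
```

So at the maximal start index, y ends exactly at the data boundary rather than past it.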
Suggested fix:

ix = torch.randint(len(data) - block_size - 1, (batch_size,))

This keeps a one-token margin, so y's slice endpoint never lands on len(data) itself and every sampled chunk trivially has room for both x and y.
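For illustration, here is a minimal self-contained get_batch with the suggested bound. This is a sketch, not the repository's exact code: it uses random.randrange (same exclusive upper bound as torch.randint) and plain lists so it runs without torch.

```python
import random

def get_batch(data, block_size, batch_size):
    # randrange(n) samples from [0, n); the extra -1 is the suggested fix,
    # keeping one token of slack before the end of data.
    ix = [random.randrange(len(data) - block_size - 1) for _ in range(batch_size)]
    x = [data[i : i + block_size] for i in ix]           # inputs
    y = [data[i + 1 : i + block_size + 1] for i in ix]   # targets, shifted by one
    return x, y

data = list(range(50))
x, y = get_batch(data, block_size=4, batch_size=3)
assert all(len(chunk) == 4 for chunk in x + y)  # every chunk is full length
```

Each target chunk is the corresponding input chunk shifted right by one token, which is the invariant both the current code and the proposed bound preserve.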
