diff --git a/lecture_014/A_Practitioners_Guide_to_Triton.ipynb b/lecture_014/A_Practitioners_Guide_to_Triton.ipynb index af4503e..98d5ac0 100644 --- a/lecture_014/A_Practitioners_Guide_to_Triton.ipynb +++ b/lecture_014/A_Practitioners_Guide_to_Triton.ipynb @@ -1221,8 +1221,11 @@ " acc = tl.zeros((bm, bn), dtype=tl.float32)\n", " for _ in range(0, k, bk):\n", " # todo umer: don't we need mask when loading a & b?\n", - " a = tl.load(offs_a)\n", - " b = tl.load(offs_b)\n", + " mask_a = get_2d_mask(rm, rn, m, k)\n", + " mask_b = get_2d_mask(rm, rn, k, n)\n", + + " a = tl.load(offs_a, mask_a)\n", + " b = tl.load(offs_b, mask_b)\n", " acc += tl.dot(a, b, allow_tf32=False) # matmul in block ; Weirdness: allow_tf32 must be set to False for older GPUs, otherwise won't compile\n", " # increase offets, so next iteration loads next chunks\n", " offs_a += bk * stride_ak\n",