@@ -36,12 +36,12 @@ def __init__(self, d_in: int, d_out: int):
3636
3737 super ().__init__ ()
3838 std = math .sqrt (2 / (d_in + d_out ))
39- self .weight : Float [Tensor , d_out d_in ] = nn .Parameter (
39+ self .weight : Float [Tensor , " d_out d_in" ] = nn .Parameter (
4040 nn .init .trunc_normal_ (torch .empty (d_out , d_in ), std = std , a = - 3 * std , b = 3 * std ),
4141 requires_grad = True ,
4242 )
4343
44- def forward (self , x : Float [Tensor , ... d_in ]) -> Float [Tensor , ... d_out ]:
44+ def forward (self , x : Float [Tensor , " ... d_in" ]) -> Float [Tensor , " ... d_out" ]:
4545 return einsum (x , self .weight , "... d_in, d_out d_in -> ... d_out" )
4646
4747 def extra_repr (self ):
@@ -59,7 +59,7 @@ def __init__(self, vocab_size: int, d_model: int):
5959 requires_grad = True ,
6060 )
6161
62- def forward (self , token_ids : Int [Tensor , ...]) -> Float [Tensor , ... d_model ]:
62+ def forward (self , token_ids : Int [Tensor , " ..." ]) -> Float [Tensor , " ... d_model" ]:
6363 return self .weight [token_ids , :]
6464
6565 def extra_repr (self ):
@@ -78,7 +78,7 @@ def __init__(self, context_length: int, dim: int, theta: float = 10000.0):
7878 @staticmethod
7979 def _init_cache (
8080 context_length : int , dim : int , theta : float
81- ) -> Float [Tensor , 2 context_length half_dim ]:
81+ ) -> Float [Tensor , " 2 context_length half_dim" ]:
8282 assert dim % 2 == 0
8383
8484 d = torch .arange (0 , dim , 2 ) / dim
@@ -91,8 +91,8 @@ def _init_cache(
9191 return torch .stack ((cos , sin ))
9292
9393 def forward (
94- self , x : Float [Tensor , ... seq d ], pos_ids : Int [Tensor , ... seq ]
95- ) -> Float [Tensor , ... seq d ]:
94+ self , x : Float [Tensor , " ... seq d" ], pos_ids : Int [Tensor , " ... seq" ]
95+ ) -> Float [Tensor , " ... seq d" ]:
9696 x1 , x2 = rearrange (x , "... (half_d xy) -> xy ... half_d" , xy = 2 )
9797
9898 # einx
@@ -172,9 +172,9 @@ def __init__(
172172
173173 def forward (
174174 self ,
175- x : Float [Tensor , ... seq d_k ],
176- token_positions : Int [Tensor , ... seq ] | None = None ,
177- ) -> Float [Tensor , ... seq d_v ]:
175+ x : Float [Tensor , " ... seq d_k" ],
176+ token_positions : Int [Tensor , " ... seq" ] | None = None ,
177+ ) -> Float [Tensor , " ... seq d_v" ]:
178178 """
179179 Args:
180180 x: The input to perform multi-headed self-attention on.
0 commit comments