@@ -1,32 +1,24 @@
- from typing import Any, ClassVar, Optional, Type, TypeVar, Union, cast
+ from typing import Any, ClassVar, Optional, Type, TypeVar, Union

- from abc import ABC, abstractmethod
+ from abc import ABC

from dataclasses import dataclass
- from enum import Enum

import sympy
- import torch
-
- from .. import ops

from . import context
from . import dtype
+ from .shaped_type import ShapedType, ShapedDataType

__all__ = [
    "backed_sym_index_type",
    "sym",
    "BoundedRelation",
    "EqualRelation",
-     "Grid",
    "IndexingContext",
    "IndexRelation",
    "IndexExpr",
    "IndexSymbol",
-     "InputBuffer",
-     "KernelBuffer",
-     "OutputBuffer",
    "SymIndex",
-     "TemporaryBuffer",
]

DataType = dtype.DataType
@@ -74,270 +66,12 @@ def __getattr__(self, n):
SymbolicDimable = Union[str, IndexExpr]
SymbolicShapeable = tuple[SymbolicDimable]
SymbolicShapeExpr = tuple[IndexExpr]
-
-
- def make_symbolic_shape(elements: SymbolicShapeable) -> SymbolicShapeExpr:
-     return tuple(
-         index_symbol(expr) if isinstance(expr, str) else expr for expr in elements
-     )
-
-
- ###############################################################################
- # Grid
- ###############################################################################
-
-
- class _GridMeta(type):
-     """Meta-class for a symbolically shaped grid."""
-
-     def __new__(
-         mcls,
-         name: str,
-         bases,
-         dct,
-         *,
-         symbolic_shape: Optional[SymbolicShapeExpr],
-     ):
-         new_class = type.__new__(mcls, name, bases, dct)
-         new_class.symbolic_shape = symbolic_shape
-         new_class.rank = len(symbolic_shape) if symbolic_shape is not None else None
-         new_class.__qualname__ = repr(new_class)
-         return new_class
-
-     def __repr__(self):
-         if self.symbolic_shape:
-             return f"Grid[{', '.join(repr(s) for s in self.symbolic_shape)}]"
-         else:
-             return "Grid"
-
-
- class Grid(metaclass=_GridMeta, symbolic_shape=None):
-     """Grid with bounding symbolic shape information in the type."""
-
-     symbolic_shape: ClassVar[Optional[SymbolicShapeExpr]]
-     # TODO: dims should also allow dynamic dimensions.
-     dims: list[int]
-     rank: int
-
-     def __init__(self):
-         # Resolve the symbolic shape to concrete values.
-         idxc = IndexingContext.current()
-         if self.symbolic_shape:
-             dims = [idxc.get_static_value(dim) for dim in self.symbolic_shape]
-             if None in dims:
-                 raise ValueError(f"NYI: Dynamic dims in Grid")
-             self.dims = cast(list[int], dims)
-         else:
-             self.dims = []
-
-         # Shadow the type rank with the actual, which makes it concrete
-         # for the generic case.
-         self.rank = len(self.dims)
-
-     def __class_getitem__(
-         cls, symbolic_shape: Union[SymbolicDimable, tuple[SymbolicShapeable]]
-     ) -> Type["Grid"]:
-         if not isinstance(symbolic_shape, tuple):
-             symbolic_shape = (symbolic_shape,)
-         return cast(Grid, _make_shaped_grid(cls, make_symbolic_shape(symbolic_shape)))
-
-     def __repr__(self):
-         return f"{repr(type(self))}({', '.join(str(i) for i in self.dims)})"
-
-     def __getitem__(self, index: int) -> int:
-         return self.dims[index]
-
-     def __len__(self) -> int:
-         return len(self.dims)
-
-     def __iter__(self):
-         return iter(self.dims)
-
-
- def _make_shaped_grid(cls: Type[Grid], symbolic_shape: tuple[IndexExpr]):
-     class ShapedGrid(Grid, symbolic_shape=symbolic_shape):
-         ...
-
-     return ShapedGrid
-
-
- ###############################################################################
- # KernelBuffer
- ###############################################################################
-

Dims = list[Union[None, IndexSymbol, int]]

-
- class KernelBufferUsage(Enum):
-     NONE = 0
-     INPUT = 1
-     OUTPUT = 2
-     TEMPORARY = 3
-
-     @staticmethod
-     def _type_name(v) -> str:
-         if v == KernelBufferUsage.NONE:
-             return "KernelBuffer"
-         elif v == KernelBufferUsage.INPUT:
-             return "InputBuffer"
-         elif v == KernelBufferUsage.OUTPUT:
-             return "OutputBuffer"
-         elif v == KernelBufferUsage.TEMPORARY:
-             return "TemporaryBuffer"
-         else:
-             raise AssertionError(f"uncovered KernelBufferUsage enum ({v})")
-
-
- class _KernelBufferMeta(type):
-     """Meta-class for kernel buffers.
-
-     This lets us specialize with symbolic shape information.
-     """
-
-     element_type: DataType
-     usage: KernelBufferUsage
-     symbolic_shape: Optional[SymbolicShapeExpr]
-     rank: Optional[int]
-
-     def __new__(
-         mcls,
-         name: str,
-         bases,
-         dct,
-     ):
-         element_type = dct.get("element_type") or DefaultDataType
-         dct["element_type"] = element_type
-         usage = dct.get("usage") or KernelBufferUsage.NONE
-         dct["usage"] = usage
-         if "usage" not in dct:
-             dct["usage"] = KernelBufferUsage.NONE
-         symbolic_shape = dct.get("symbolic_shape")
-         dct["symbolic_shape"] = symbolic_shape
-         dct["rank"] = len(symbolic_shape) if symbolic_shape is not None else None
-         dct["__qualname__"] = _kernel_buffer_type_repr(
-             element_type=element_type, usage=usage, symbolic_shape=symbolic_shape
-         )
-         new_class = type.__new__(mcls, name, bases, dct)
-         return new_class
-
-     def new_subtype(
-         cls: Type[SubtypeT],
-         *,
-         element_type: Union[NotSetType, DataType] = NotSet,
-         symbolic_shape: Union[NotSetType, Optional[SymbolicShapeable]] = NotSet,
-         usage: Union[NotSetType, KernelBufferUsage] = NotSet,
-     ) -> Type[SubtypeT]:
-         init_element_type = (
-             element_type if element_type is not NotSet else cls.element_type
-         )
-         init_symbolic_shape = (
-             symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape
-         )
-         init_usage = usage if usage is not NotSet else cls.usage
-
-         class Subtype(cls):
-             element_type = init_element_type
-             symbolic_shape = make_symbolic_shape(init_symbolic_shape)
-             usage = init_usage
-
-         return Subtype
-
-     def of(cls: Type[SubtypeT], element_type: Union[Any, DataType]) -> Type[SubtypeT]:
-         return cls.new_subtype(element_type=element_type)
-
-     def __repr__(cls):
-         return _kernel_buffer_type_repr(
-             element_type=cls.element_type,
-             usage=cls.usage,
-             symbolic_shape=cls.symbolic_shape,
-         )
-
-
- def is_kernel_buffer_meta_derived(t: type) -> bool:
-     return isinstance(t, _KernelBufferMeta)
-
-
- def _kernel_buffer_type_repr(
-     *,
-     element_type: DataType,
-     usage: KernelBufferUsage,
-     symbolic_shape: Optional[tuple[IndexExpr]],
- ) -> str:
-     root = KernelBufferUsage._type_name(usage)
-     if symbolic_shape:
-         stem = f"{root}[{', '.join(repr(s) for s in symbolic_shape)}]"
-     else:
-         stem = f"{root}"
-     if element_type != DefaultDataType:
-         stem += f".of({element_type})"
-     return stem
-
-
- class KernelBuffer(metaclass=_KernelBufferMeta):
-     """Represents a buffer in global memory.
-
-     Top level kernels always operate on global memory via these
-     buffers, and the primary operations that can be performed on
-     them are loads/stores and DMAs to some form of compute
-     capable local buffer.
-
-     When executing eagerly, these are backed by a normal torch
-     Tensor. When compiling, an appropriate duck-typed proxy
-     is used.
-     """
-
-     usage: ClassVar[KernelBufferUsage]
-     symbolic_shape: ClassVar[Optional[SymbolicShapeExpr]]
-     rank: Optional[int]
-
-     def __init__(self, tensor: torch.Tensor):
-         assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
-         type_rank = type(self).rank
-         tensor_rank = len(tensor.shape)
-         if type_rank is not None and type_rank != tensor_rank:
-             raise ValueError(
-                 f"Cannot create {type(self)}(tensor({tensor.shape})): mismatched symbolic rank"
-             )
-         self._tensor = tensor
-         self.rank = tensor_rank
-
-     def __class_getitem__(
-         cls, symbolic_shape: Union[IndexExpr, SymbolicShapeExpr]
-     ) -> Type["KernelBuffer"]:
-         if not isinstance(symbolic_shape, tuple):
-             symbolic_shape = (symbolic_shape,)
-         return cast(
-             cls, cls.new_subtype(symbolic_shape=make_symbolic_shape(symbolic_shape))
-         )
-
-     def __repr__(self):
-         return f"{type(self)}({self._tensor})"
-
-     def __setitem__(self, key, item):
-         ops.kernel_buffer_setitem(self, key, item)
-
-     def __getitem__(self, key):
-         return ops.kernel_buffer_getitem(self, key)
-
-
- class InputBuffer(KernelBuffer):
-     usage = KernelBufferUsage.INPUT
-
-
- class OutputBuffer(KernelBuffer):
-     usage = KernelBufferUsage.OUTPUT
-
-
- class TemporaryBuffer(KernelBuffer):
-     usage = KernelBufferUsage.TEMPORARY
-
-
###############################################################################
# IndexingContext
###############################################################################

- ShapedType = Union[Type[KernelBuffer], Type[Grid]]
-

@dataclass(slots=True)
class _ShapedBinding:
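
Note: the Grid/KernelBuffer machinery removed above relies on a metaclass plus __class_getitem__ so that the symbolic shape travels in the class object itself. For reference, a condensed, standalone sketch of that pattern follows; the names here (ShapedMeta, Shaped, Tile) are illustrative only and not this module's API.

    # Illustrative sketch only: a metaclass stamps the symbolic shape onto
    # dynamically created subclasses, and __class_getitem__ builds those
    # subclasses from subscription syntax.
    from typing import Optional

    class ShapedMeta(type):
        def __new__(mcls, name, bases, dct, *, symbolic_shape=None):
            new_class = super().__new__(mcls, name, bases, dct)
            new_class.symbolic_shape = symbolic_shape
            new_class.rank = None if symbolic_shape is None else len(symbolic_shape)
            return new_class

    class Shaped(metaclass=ShapedMeta):
        symbolic_shape: Optional[tuple] = None

        def __class_getitem__(cls, shape):
            if not isinstance(shape, tuple):
                shape = (shape,)

            # A one-off subclass that carries the shape in its type.
            class _Specialized(cls, symbolic_shape=shape):
                ...

            return _Specialized

    # Subscripting yields a new type with the shape attached:
    Tile = Shaped["M", "N"]
    assert Tile.symbolic_shape == ("M", "N") and Tile.rank == 2

The removed code above is the full version of this idea, with element-type and usage tracking layered on top for KernelBuffer.
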
@@ -377,7 +111,7 @@ def __init__(self):
        # Indexed by .instance
        self.shaped_bindings: dict[Any, _ShapedBinding] = {}
        self.dyn_dims: list[IndexSymbol] = []
-         self.frozen_subs: list[IndexSymbol, int] = []
+         self.frozen_subs: list[tuple[IndexSymbol, int]] = []
        self.unbacked_symbols: list[IndexSymbol] = []

    def next_dyn_dim(self) -> IndexSymbol:
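
The corrected annotation above describes a list of (symbol, value) pairs; the old spelling list[IndexSymbol, int] is not a valid parameterization of list. A pair list is also exactly the shape that sympy's subs() accepts, which is presumably how frozen_subs is consumed downstream. A small illustrative sketch using plain sympy Symbols (not this module's IndexSymbol):

    import sympy

    M, K = sympy.symbols("M K")
    frozen_subs = [(M, 128), (K, 64)]   # list[tuple[Symbol, int]]
    expr = M * K + 1
    print(expr.subs(frozen_subs))       # -> 8193
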
@@ -390,9 +124,7 @@ def new_unbacked_symbol(self) -> IndexSymbol:
        self.unbacked_symbols.append(s)
        return s

-     def bind_shaped(
-         self, instance: Any, shaped_type: ShapedType, dims: Dims
-     ) -> _ShapedBinding:
+     def bind_shaped(self, instance: Any, shaped_type: ShapedType, dims: Dims) -> None:
        if instance in self.shaped_bindings:
            raise ValueError(f"Argument binding {instance} is already bound")
        symbolic_shape = shaped_type.symbolic_shape
@@ -406,7 +138,7 @@ def bind_shaped(
        )
        self.shaped_bindings[instance] = binding

-     def bind_constant(self, sym: IndexSymbol, value: int):
+     def bind_constant(self, sym: IndexSymbol, value: int) -> None:
        try:
            self._bind_symbol(sym, value)
        except ValueError: