Skip to content

Commit 18262d4

Browse files
authored
[TK] Refactor ShapedType out of KernelBuffer and Grid (#479)
This patch moves out KernelBuffer and Grid into lang and creates a ShapedType/ShapedDataType class in _support/shaped_type which can be used elsewhere. This is in preparation for VectorType in TK. This patch also does some other things: - Removes the string IndexExpr syntax for KernelBuffer/Grid, which I don't think was very useful. - Changes the syntax of KernelBufffer to include data type
1 parent bc02c45 commit 18262d4

17 files changed

+407
-328
lines changed

core/shark_turbine/kernel/_support/dtype.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_INDEX_TYPES = ["index"]
1818

1919

20+
# TODO: this should really be a type.
2021
class DataType:
2122
_name: str
2223
_ir_type_asm: str

core/shark_turbine/kernel/_support/indexing.py

Lines changed: 6 additions & 274 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,24 @@
1-
from typing import Any, ClassVar, Optional, Type, TypeVar, Union, cast
1+
from typing import Any, ClassVar, Optional, Type, TypeVar, Union
22

3-
from abc import ABC, abstractmethod
3+
from abc import ABC
44
from dataclasses import dataclass
5-
from enum import Enum
65

76
import sympy
8-
import torch
9-
10-
from .. import ops
117

128
from . import context
139
from . import dtype
10+
from .shaped_type import ShapedType, ShapedDataType
1411

1512
__all__ = [
1613
"backed_sym_index_type",
1714
"sym",
1815
"BoundedRelation",
1916
"EqualRelation",
20-
"Grid",
2117
"IndexingContext",
2218
"IndexRelation",
2319
"IndexExpr",
2420
"IndexSymbol",
25-
"InputBuffer",
26-
"KernelBuffer",
27-
"OutputBuffer",
2821
"SymIndex",
29-
"TemporaryBuffer",
3022
]
3123

3224
DataType = dtype.DataType
@@ -74,270 +66,12 @@ def __getattr__(self, n):
7466
SymbolicDimable = Union[str, IndexExpr]
7567
SymbolicShapeable = tuple[SymbolicDimable]
7668
SymbolicShapeExpr = tuple[IndexExpr]
77-
78-
79-
def make_symbolic_shape(elements: SymbolicShapeable) -> SymbolicShapeExpr:
80-
return tuple(
81-
index_symbol(expr) if isinstance(expr, str) else expr for expr in elements
82-
)
83-
84-
85-
###############################################################################
86-
# Grid
87-
###############################################################################
88-
89-
90-
class _GridMeta(type):
91-
"""Meta-class for a symbolically shaped grid."""
92-
93-
def __new__(
94-
mcls,
95-
name: str,
96-
bases,
97-
dct,
98-
*,
99-
symbolic_shape: Optional[SymbolicShapeExpr],
100-
):
101-
new_class = type.__new__(mcls, name, bases, dct)
102-
new_class.symbolic_shape = symbolic_shape
103-
new_class.rank = len(symbolic_shape) if symbolic_shape is not None else None
104-
new_class.__qualname__ = repr(new_class)
105-
return new_class
106-
107-
def __repr__(self):
108-
if self.symbolic_shape:
109-
return f"Grid[{', '.join(repr(s) for s in self.symbolic_shape)}]"
110-
else:
111-
return "Grid"
112-
113-
114-
class Grid(metaclass=_GridMeta, symbolic_shape=None):
115-
"""Grid with bounding symbolic shape information in the type."""
116-
117-
symbolic_shape: ClassVar[Optional[SymbolicShapeExpr]]
118-
# TODO: dims should also allow dynamic dimensions.
119-
dims: list[int]
120-
rank: int
121-
122-
def __init__(self):
123-
# Resolve the symbolic shape to concrete values.
124-
idxc = IndexingContext.current()
125-
if self.symbolic_shape:
126-
dims = [idxc.get_static_value(dim) for dim in self.symbolic_shape]
127-
if None in dims:
128-
raise ValueError(f"NYI: Dynamic dims in Grid")
129-
self.dims = cast(list[int], dims)
130-
else:
131-
self.dims = []
132-
133-
# Shadow the type rank with the actual, which makes it concrete
134-
# for the generic case.
135-
self.rank = len(self.dims)
136-
137-
def __class_getitem__(
138-
cls, symbolic_shape: Union[SymbolicDimable, tuple[SymbolicShapeable]]
139-
) -> Type["Grid"]:
140-
if not isinstance(symbolic_shape, tuple):
141-
symbolic_shape = (symbolic_shape,)
142-
return cast(Grid, _make_shaped_grid(cls, make_symbolic_shape(symbolic_shape)))
143-
144-
def __repr__(self):
145-
return f"{repr(type(self))}({', '.join(str(i) for i in self.dims)})"
146-
147-
def __getitem__(self, index: int) -> int:
148-
return self.dims[index]
149-
150-
def __len__(self) -> int:
151-
return len(self.dims)
152-
153-
def __iter__(self):
154-
return iter(self.dims)
155-
156-
157-
def _make_shaped_grid(cls: Type[Grid], symbolic_shape: tuple[IndexExpr]):
158-
class ShapedGrid(Grid, symbolic_shape=symbolic_shape):
159-
...
160-
161-
return ShapedGrid
162-
163-
164-
###############################################################################
165-
# KernelBuffer
166-
###############################################################################
167-
16869
Dims = list[Union[None, IndexSymbol, int]]
16970

170-
171-
class KernelBufferUsage(Enum):
172-
NONE = 0
173-
INPUT = 1
174-
OUTPUT = 2
175-
TEMPORARY = 3
176-
177-
@staticmethod
178-
def _type_name(v) -> str:
179-
if v == KernelBufferUsage.NONE:
180-
return "KernelBuffer"
181-
elif v == KernelBufferUsage.INPUT:
182-
return "InputBuffer"
183-
elif v == KernelBufferUsage.OUTPUT:
184-
return "OutputBuffer"
185-
elif v == KernelBufferUsage.TEMPORARY:
186-
return "TemporaryBuffer"
187-
else:
188-
raise AssertionError(f"uncovered KernelBufferUsage enum ({v})")
189-
190-
191-
class _KernelBufferMeta(type):
192-
"""Meta-class for kernel buffers.
193-
194-
This lets us specialize with symbolic shape information.
195-
"""
196-
197-
element_type: DataType
198-
usage: KernelBufferUsage
199-
symbolic_shape: Optional[SymbolicShapeExpr]
200-
rank: Optional[int]
201-
202-
def __new__(
203-
mcls,
204-
name: str,
205-
bases,
206-
dct,
207-
):
208-
element_type = dct.get("element_type") or DefaultDataType
209-
dct["element_type"] = element_type
210-
usage = dct.get("usage") or KernelBufferUsage.NONE
211-
dct["usage"] = usage
212-
if "usage" not in dct:
213-
dct["usage"] = KernelBufferUsage.NONE
214-
symbolic_shape = dct.get("symbolic_shape")
215-
dct["symbolic_shape"] = symbolic_shape
216-
dct["rank"] = len(symbolic_shape) if symbolic_shape is not None else None
217-
dct["__qualname__"] = _kernel_buffer_type_repr(
218-
element_type=element_type, usage=usage, symbolic_shape=symbolic_shape
219-
)
220-
new_class = type.__new__(mcls, name, bases, dct)
221-
return new_class
222-
223-
def new_subtype(
224-
cls: Type[SubtypeT],
225-
*,
226-
element_type: Union[NotSetType, DataType] = NotSet,
227-
symbolic_shape: Union[NotSetType, Optional[SymbolicShapeable]] = NotSet,
228-
usage: Union[NotSetType, KernelBufferUsage] = NotSet,
229-
) -> Type[SubtypeT]:
230-
init_element_type = (
231-
element_type if element_type is not NotSet else cls.element_type
232-
)
233-
init_symbolic_shape = (
234-
symbolic_shape if symbolic_shape is not NotSet else cls.symbolic_shape
235-
)
236-
init_usage = usage if usage is not NotSet else cls.usage
237-
238-
class Subtype(cls):
239-
element_type = init_element_type
240-
symbolic_shape = make_symbolic_shape(init_symbolic_shape)
241-
usage = init_usage
242-
243-
return Subtype
244-
245-
def of(cls: Type[SubtypeT], element_type: Union[Any, DataType]) -> Type[SubtypeT]:
246-
return cls.new_subtype(element_type=element_type)
247-
248-
def __repr__(cls):
249-
return _kernel_buffer_type_repr(
250-
element_type=cls.element_type,
251-
usage=cls.usage,
252-
symbolic_shape=cls.symbolic_shape,
253-
)
254-
255-
256-
def is_kernel_buffer_meta_derived(t: type) -> bool:
257-
return isinstance(t, _KernelBufferMeta)
258-
259-
260-
def _kernel_buffer_type_repr(
261-
*,
262-
element_type: DataType,
263-
usage: KernelBufferUsage,
264-
symbolic_shape: Optional[tuple[IndexExpr]],
265-
) -> str:
266-
root = KernelBufferUsage._type_name(usage)
267-
if symbolic_shape:
268-
stem = f"{root}[{', '.join(repr(s) for s in symbolic_shape)}]"
269-
else:
270-
stem = f"{root}"
271-
if element_type != DefaultDataType:
272-
stem += f".of({element_type})"
273-
return stem
274-
275-
276-
class KernelBuffer(metaclass=_KernelBufferMeta):
277-
"""Represents a buffer in global memory.
278-
279-
Top level kernels always operate on global memory via these
280-
buffers, and the primary operations that can be performed on
281-
them are loads/stores and DMAs to some form of compute
282-
capable local buffer.
283-
284-
When executing eagerly, these are backed by a normal torch
285-
Tensor. When compiling, an appropriate duck-typed proxy
286-
is used.
287-
"""
288-
289-
usage: ClassVar[KernelBufferUsage]
290-
symbolic_shape: ClassVar[Optional[SymbolicShapeExpr]]
291-
rank: Optional[int]
292-
293-
def __init__(self, tensor: torch.Tensor):
294-
assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
295-
type_rank = type(self).rank
296-
tensor_rank = len(tensor.shape)
297-
if type_rank is not None and type_rank != tensor_rank:
298-
raise ValueError(
299-
f"Cannot create {type(self)}(tensor({tensor.shape})): mismatched symbolic rank"
300-
)
301-
self._tensor = tensor
302-
self.rank = tensor_rank
303-
304-
def __class_getitem__(
305-
cls, symbolic_shape: Union[IndexExpr, SymbolicShapeExpr]
306-
) -> Type["KernelBuffer"]:
307-
if not isinstance(symbolic_shape, tuple):
308-
symbolic_shape = (symbolic_shape,)
309-
return cast(
310-
cls, cls.new_subtype(symbolic_shape=make_symbolic_shape(symbolic_shape))
311-
)
312-
313-
def __repr__(self):
314-
return f"{type(self)}({self._tensor})"
315-
316-
def __setitem__(self, key, item):
317-
ops.kernel_buffer_setitem(self, key, item)
318-
319-
def __getitem__(self, key):
320-
return ops.kernel_buffer_getitem(self, key)
321-
322-
323-
class InputBuffer(KernelBuffer):
324-
usage = KernelBufferUsage.INPUT
325-
326-
327-
class OutputBuffer(KernelBuffer):
328-
usage = KernelBufferUsage.OUTPUT
329-
330-
331-
class TemporaryBuffer(KernelBuffer):
332-
usage = KernelBufferUsage.TEMPORARY
333-
334-
33571
###############################################################################
33672
# IndexingContext
33773
###############################################################################
33874

339-
ShapedType = Union[Type[KernelBuffer], Type[Grid]]
340-
34175

34276
@dataclass(slots=True)
34377
class _ShapedBinding:
@@ -377,7 +111,7 @@ def __init__(self):
377111
# Indexed by .instance
378112
self.shaped_bindings: dict[Any, _ShapedBinding] = {}
379113
self.dyn_dims: list[IndexSymbol] = []
380-
self.frozen_subs: list[IndexSymbol, int] = []
114+
self.frozen_subs: list[tuple[IndexSymbol, int]] = []
381115
self.unbacked_symbols: list[IndexSymbol] = []
382116

383117
def next_dyn_dim(self) -> IndexSymbol:
@@ -390,9 +124,7 @@ def new_unbacked_symbol(self) -> IndexSymbol:
390124
self.unbacked_symbols.append(s)
391125
return s
392126

393-
def bind_shaped(
394-
self, instance: Any, shaped_type: ShapedType, dims: Dims
395-
) -> _ShapedBinding:
127+
def bind_shaped(self, instance: Any, shaped_type: ShapedType, dims: Dims) -> None:
396128
if instance in self.shaped_bindings:
397129
raise ValueError(f"Argument binding {instance} is already bound")
398130
symbolic_shape = shaped_type.symbolic_shape
@@ -406,7 +138,7 @@ def bind_shaped(
406138
)
407139
self.shaped_bindings[instance] = binding
408140

409-
def bind_constant(self, sym: IndexSymbol, value: int):
141+
def bind_constant(self, sym: IndexSymbol, value: int) -> None:
410142
try:
411143
self._bind_symbol(sym, value)
412144
except ValueError:

core/shark_turbine/kernel/_support/regions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def create_proxy(
137137
kwargs,
138138
name=None,
139139
type_expr=None,
140-
proxy_factor_fn=None,
140+
proxy_factory_fn=None,
141141
):
142142
if self.parent is not None:
143143
flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
@@ -154,7 +154,7 @@ def create_proxy(
154154
kwargs,
155155
name,
156156
type_expr,
157-
proxy_factor_fn,
157+
proxy_factory_fn,
158158
)
159159

160160
return rv

0 commit comments

Comments
 (0)