Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions src/ninetoothed/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,12 @@
class CodeGenerator(ast.NodeTransformer):
def __init__(self):
super().__init__()

device = triton.runtime.driver.active.get_current_device()
properties = triton.runtime.driver.active.utils.get_device_properties(device)

self._min_num_elements = 1
self._max_num_elements = 2**18 // 8
self._configs = []
self._free_symbols = set()
self._block_size = None

if "max_num_regs" in properties:
max_innermost_size = 4 * properties["max_num_regs"]
elif "max_nram_size" in properties:
max_innermost_size = properties["max_nram_size"]
else:
max_innermost_size = 2**18

self._max_num_elements = max_innermost_size // 8

def __call__(
self,
Expand Down Expand Up @@ -724,25 +716,24 @@ def _generate_slices(tensor, dim):
@staticmethod
def _generate_overall_offsets_and_mask(tensor, indices):
indices = list(indices)

offsets, mask = CodeGenerator._generate_offsets_and_mask(tensor, indices)

tensor._last_generated_offsets = offsets

overall_offsets = sum(
offsets[source_dim] * Symbol(tensor.source.stride_string(source_dim))
for source_dim in range(tensor.source.ndim)
)
if hasattr(tensor.source, "is_contiguous") and tensor.source.is_contiguous:
overall_offsets = indices[0] if len(indices) == 1 else sum(offsets)
else:
overall_offsets = sum(
offsets[source_dim] * Symbol(tensor.source.stride_string(source_dim))
for source_dim in range(tensor.source.ndim)
)

if tensor.source.jagged_dim is not None:
overall_offsets += CodeGenerator._name_for_seq_start(tensor) * Symbol(
tensor.source.stride_string(tensor.source.jagged_dim)
)

tensor._last_generated_overall_offsets = overall_offsets

return overall_offsets, mask

@staticmethod
def _generate_offsets_and_mask(tensor, indices):
offsets = [Symbol(0) for _ in range(tensor.source.ndim)]
Expand Down
18 changes: 16 additions & 2 deletions src/ninetoothed/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,23 @@ def _offsets(indices):
self._inputs = []

self._history = []

self.is_contiguous = self._check_contiguous()
type(self).num_instances += 1

def _check_contiguous(self):
"""
判断张量是否为标准内存连续布局
连续张量的stride满足:stride[i] = product(shape[i+1:])
"""
if self.ndim == 0:
return True

expected_stride = 1
# 从最后一个维度向前遍历
for i in reversed(range(self.ndim)):
if self.strides[i] != expected_stride:
return False
expected_stride *= self.shape[i]
return True
def __getitem__(self, indices):
"""Returns an indexed tensor using the specified ``indices``.

Expand Down