diff --git a/src/ninetoothed/generation.py b/src/ninetoothed/generation.py index 6ea1d0f..5517971 100644 --- a/src/ninetoothed/generation.py +++ b/src/ninetoothed/generation.py @@ -28,20 +28,12 @@ class CodeGenerator(ast.NodeTransformer): def __init__(self): super().__init__() - - device = triton.runtime.driver.active.get_current_device() - properties = triton.runtime.driver.active.utils.get_device_properties(device) - self._min_num_elements = 1 + self._max_num_elements = 2**18 // 8 + self._configs = [] + self._free_symbols = set() + self._block_size = None - if "max_num_regs" in properties: - max_innermost_size = 4 * properties["max_num_regs"] - elif "max_nram_size" in properties: - max_innermost_size = properties["max_nram_size"] - else: - max_innermost_size = 2**18 - - self._max_num_elements = max_innermost_size // 8 def __call__( self, @@ -724,15 +716,16 @@ def _generate_slices(tensor, dim): @staticmethod def _generate_overall_offsets_and_mask(tensor, indices): indices = list(indices) - offsets, mask = CodeGenerator._generate_offsets_and_mask(tensor, indices) - tensor._last_generated_offsets = offsets - overall_offsets = sum( - offsets[source_dim] * Symbol(tensor.source.stride_string(source_dim)) - for source_dim in range(tensor.source.ndim) - ) + if hasattr(tensor.source, "is_contiguous") and tensor.source.is_contiguous: + overall_offsets = indices[0] if len(indices) == 1 else sum(offsets) + else: + overall_offsets = sum( + offsets[source_dim] * Symbol(tensor.source.stride_string(source_dim)) + for source_dim in range(tensor.source.ndim) + ) if tensor.source.jagged_dim is not None: overall_offsets += CodeGenerator._name_for_seq_start(tensor) * Symbol( @@ -740,9 +733,7 @@ def _generate_overall_offsets_and_mask(tensor, indices): ) tensor._last_generated_overall_offsets = overall_offsets - return overall_offsets, mask - @staticmethod def _generate_offsets_and_mask(tensor, indices): offsets = [Symbol(0) for _ in range(tensor.source.ndim)] diff --git a/src/ninetoothed/tensor.py b/src/ninetoothed/tensor.py index 7c4b14f..8e8a323 100644 --- a/src/ninetoothed/tensor.py +++ b/src/ninetoothed/tensor.py @@ -121,9 +121,23 @@ def _offsets(indices): self._inputs = [] self._history = [] - + self.is_contiguous = self._check_contiguous() type(self).num_instances += 1 - + def _check_contiguous(self): + """ + 判断张量是否为标准内存连续布局 + 连续张量的stride满足:stride[i] = product(shape[i+1:]) + """ + if self.ndim == 0: + return True + + expected_stride = 1 + # 从最后一个维度向前遍历 + for i in reversed(range(self.ndim)): + if self.strides[i] != expected_stride: + return False + expected_stride *= self.shape[i] + return True def __getitem__(self, indices): """Returns an indexed tensor using the specified ``indices``.