diff --git a/arrayfire/array_object.py b/arrayfire/array_object.py index d44ee29..ea8c5cd 100755 --- a/arrayfire/array_object.py +++ b/arrayfire/array_object.py @@ -746,7 +746,7 @@ def __getitem__(self, key: IndexKey, /) -> Array: ---------- self : Array Array instance. - key : int | slice | tuple[int | slice, ...] | Array + key : int | slice | tuple[int | slice | Array, ...] | Array Index key. Returns @@ -754,19 +754,40 @@ def __getitem__(self, key: IndexKey, /) -> Array: out : Array An array containing the accessed value(s). The returned array must have the same data type as self. """ - # TODO - # API Specification - key: Union[int, slice, ellipsis, tuple[Union[int, slice, ellipsis], ...], array]. - # consider using af.span to replace ellipsis during refactoring out = Array() ndims = self.ndim - if isinstance(key, Array) and key == afbool.c_api_value: + indexing = key + + if isinstance(key, int | float | slice): # when indexing with one dimension, treat it as indexing a flat array + ndims = 1 + elif isinstance(key, Array): # when indexing with one array, treat it as indexing a flat array ndims = 1 - if wrapper.count_all(key.arr) == 0: # HACK was count() method before - return out + if key.is_bool: + indexing = wrapper.where(key.arr) + else: + indexing = key.arr + elif isinstance(key, tuple): + key_list = [] + for elem in key: + if isinstance(elem, Array): + if elem.is_bool: + key_list.append(wrapper.where(elem.arr)) + else: + key_list.append(elem.arr) + else: + key_list.append(elem) + indexing = tuple(key_list) + + out._arr = wrapper.index_gen(self._arr, ndims, wrapper.get_indices(indexing)) # type: ignore[arg-type] + + if isinstance(key, Array) and key.is_bool: + wrapper.release_array(indexing) + elif isinstance(key, tuple): + for i in range(len(key)): + if isinstance(key[i], Array) and key[i].is_bool: + wrapper.release_array(indexing[i]) - # HACK known issue - out._arr = wrapper.index_gen(self._arr, ndims, wrapper.get_indices(key)) # type: ignore[arg-type] return out def __index__(self) -> int: @@ -781,8 +802,19 @@ def __len__(self) -> int: return self.shape[0] if self.shape else 0 def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> None: - ndims = self.ndim + """ + Assigns self[key] = value + + Parameters + ---------- + self : Array + Array instance. + key : int | slice | tuple[int | slice | Array, ...] | Array + Index key. + value: int | float | complex | bool | Array + """ + ndims = self.ndim is_array_with_bool = isinstance(key, Array) and type(key) is afbool if is_array_with_bool: @@ -803,12 +835,41 @@ def __setitem__(self, key: IndexKey, value: int | float | bool | Array, /) -> No other_arr = value.arr del_other = False - indices = wrapper.get_indices(key) # type: ignore[arg-type] # FIXME - out = wrapper.assign_gen(self._arr, other_arr, ndims, indices) + indexing = key + if isinstance(key, int | float | slice): # when indexing with one dimension, treat it as indexing a flat array + ndims = 1 + elif isinstance(key, Array): # when indexing with one array, treat it as indexing a flat array + ndims = 1 + if key.is_bool: + indexing = wrapper.where(key.arr) + else: + indexing = key.arr + elif isinstance(key, tuple): + key_list = [] + for elem in key: + if isinstance(elem, Array): + if elem.is_bool: + locs = wrapper.where(elem.arr) + key_list.append(locs) + else: + key_list.append(elem.arr) + else: + key_list.append(elem) + indexing = tuple(key_list) + + out = wrapper.assign_gen(self._arr, other_arr, ndims, wrapper.get_indices(indexing)) + + if isinstance(key, Array) and key.is_bool: + wrapper.release_array(indexing) + elif isinstance(key, tuple): + for i in range(len(key)): + if isinstance(key[i], Array) and key[i].is_bool: + wrapper.release_array(indexing[i]) wrapper.release_array(self._arr) if del_other: wrapper.release_array(other_arr) + self._arr = out def __str__(self) -> str: @@ -1144,7 +1205,10 @@ def _get_processed_index(key: IndexKey, shape: tuple[int, ...]) -> tuple[int, .. if isinstance(key, tuple): return tuple(_index_to_afindex(key[i], shape[i]) for i in range(len(key))) - return (_index_to_afindex(key, shape[0]),) + shape[1:] + size = 1 + for dim_size in shape: + size *= dim_size + return (_index_to_afindex(key, size),) def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.ParallelRange | Array, axis: int) -> int: @@ -1168,6 +1232,10 @@ def _index_to_afindex(key: int | float | complex | bool | slice | wrapper.Parall def _slice_to_length(key: slice, axis: int) -> int: + start = key.start + stop = key.stop + step = key.step + if key.start is None: start = 0 elif key.start < 0: