diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 3b9993736e4..1f25929ccd4 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -31,6 +31,7 @@ cast_to_python_objects, generate_from_arrow_type, get_nested_type, + list_of_dicts_to_pyarrow_structarray, list_of_np_array_to_pyarrow_listarray, numpy_to_pyarrow_listarray, to_pyarrow_listarray, @@ -183,6 +184,9 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None): out = numpy_to_pyarrow_listarray(data) elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], np.ndarray): out = list_of_np_array_to_pyarrow_listarray(data) + elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], dict): + # pa_type should be a struct type + out = list_of_dicts_to_pyarrow_structarray(data, pa_type) else: trying_cast_to_python_objects = True out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True)) diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index 35ebfb4ac0c..d45eb1d06ce 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -1,4 +1,5 @@ __all__ = [ + "Array1D", "Audio", "Array2D", "Array3D", @@ -14,6 +15,6 @@ "TranslationVariableLanguages", ] from .audio import Audio -from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value +from .features import Array1D, Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value from .image import Image from .translation import Translation, TranslationVariableLanguages diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 1d241e0b7b7..6ef9b1c1eab 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -545,6 +545,33 @@ def encode_example(self, value): return value +@dataclass +class Array1D(_ArrayXD): + """Create a one-dimensional array. + + Unlike Sequence, will be extracted as a numpy array irrespective of formatting. + + Args: + shape (`tuple`): + Size of each dimension. + dtype (`str`): + Name of the data type. + + Example: + + ```py + >>> from datasets import Features + >>> features = Features({'x': Array1D(shape=(3,), dtype='int32')}) + ``` + """ + + shape: tuple + dtype: str + id: Optional[str] = None + # Automatically constructed + _type: str = field(default="Array1D", init=False, repr=False) + + @dataclass class Array2D(_ArrayXD): """Create a two-dimensional array. @@ -649,8 +676,8 @@ class _ArrayXDExtensionType(pa.ExtensionType): ndims: Optional[int] = None def __init__(self, shape: tuple, dtype: str): - if self.ndims is None or self.ndims <= 1: - raise ValueError("You must instantiate an array type with a value for dim that is > 1") + if self.ndims is None: + raise ValueError("You must instantiate an array type with a value for dim that is >= 1") if len(shape) != self.ndims: raise ValueError(f"shape={shape} and ndims={self.ndims} don't match") for dim in range(1, self.ndims): @@ -691,6 +718,10 @@ def to_pandas_dtype(self): return PandasArrayExtensionDtype(self.value_type) +class Array1DExtensionType(_ArrayXDExtensionType): + ndims = 1 + + class Array2DExtensionType(_ArrayXDExtensionType): ndims = 2 @@ -708,6 +739,7 @@ class Array5DExtensionType(_ArrayXDExtensionType): # Register the extension types for deserialization +pa.register_extension_type(Array1DExtensionType((1,), "int64")) pa.register_extension_type(Array2DExtensionType((1, 2), "int64")) pa.register_extension_type(Array3DExtensionType((1, 2, 3), "int64")) pa.register_extension_type(Array4DExtensionType((1, 2, 3, 4), "int64")) @@ -791,10 +823,7 @@ def to_numpy(self, zero_copy_only=True): def to_pylist(self): zero_copy_only = _is_zero_copy_only(self.storage.type, unnest=True) numpy_arr = self.to_numpy(zero_copy_only=zero_copy_only) - if self.type.shape[0] is None and numpy_arr.dtype == object: - return [arr.tolist() for arr in numpy_arr.tolist()] - else: - return numpy_arr.tolist() + return list(numpy_arr) class PandasArrayExtensionDtype(PandasExtensionDtype): @@ -1196,6 +1225,7 @@ class LargeList: TranslationVariableLanguages, LargeList, Sequence, + Array1D, Array2D, Array3D, Array4D, @@ -1411,6 +1441,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni TranslationVariableLanguages.__name__: TranslationVariableLanguages, LargeList.__name__: LargeList, Sequence.__name__: Sequence, + Array1D.__name__: Array1D, Array2D.__name__: Array2D, Array3D.__name__: Array3D, Array4D.__name__: Array4D, @@ -1494,8 +1525,11 @@ def generate_from_arrow_type(pa_type: pa.DataType) -> FeatureType: feature = generate_from_arrow_type(pa_type.value_type) return LargeList(feature=feature) elif isinstance(pa_type, _ArrayXDExtensionType): - array_feature = [None, None, Array2D, Array3D, Array4D, Array5D][pa_type.ndims] - return array_feature(shape=pa_type.shape, dtype=pa_type.value_type) + if pa_type.ndims >= 1: + array_feature = [Array1D, Array2D, Array3D, Array4D, Array5D][pa_type.ndims - 1] + return array_feature(shape=pa_type.shape, dtype=pa_type.value_type) + else: + raise ValueError("Cannot convert 0-dimensional array to Array Feature type.") elif isinstance(pa_type, pa.DataType): return Value(dtype=_arrow_to_datasets_dtype(pa_type)) else: @@ -1586,6 +1620,61 @@ def to_pyarrow_listarray(data: Any, pa_type: _ArrayXDExtensionType) -> pa.Array: return pa.array(data, pa_type.storage_dtype) +def list_of_dicts_to_pyarrow_structarray(data: List[Dict[str, Any]], struct_type: pa.StructType) -> pa.StructArray: + """Convert a list of dictionaries to a pyarrow StructArray. + + First builds a dict of lists, then converts each list to a pyarrow array, + then creates a StructArray from the arrays. + """ + if not data: + raise ValueError("Input data must be a non-empty list of dictionaries.") + + # Get field names from struct type if available, otherwise from first non-null dict + if struct_type is not None: + field_names = [field.name for field in struct_type] + else: + first_dict = next((d for d in data if d is not None), None) + if first_dict is None: + raise ValueError("All dictionaries in input data are None") + field_names = list(first_dict.keys()) + + # Initialize empty lists for each field + field_arrays = {name: [] for name in field_names} + + null_mask = [] + for row in data: + if row is None: + null_mask.append(True) + for key in field_arrays.keys(): + field_arrays[key].append(None) + else: + null_mask.append(False) + for key in field_arrays.keys(): + value = row.get(key, None) + field_arrays[key].append(value) + + # TODO: do these need to be ordered? + pa_fields = [] + for key, values in field_arrays.items(): + if struct_type is not None: + index = struct_type.get_field_index(key) + field_type = struct_type[index].type + else: + field_type = None + # TODO: should field_type None be handled better? + pa_field = ( + to_pyarrow_listarray(values, field_type) + if contains_any_np_array(values) and field_type is not None + else pa.array(values) + ) + pa_fields.append((key, pa_field)) + + field_names, field_arrays = zip(*pa_fields) + null_mask_array = pa.array(null_mask, type=pa.bool_()) + + return pa.StructArray.from_arrays(field_arrays, field_names, mask=null_mask_array) + + def _visit(feature: FeatureType, func: Callable[[FeatureType], Optional[FeatureType]]) -> FeatureType: """Visit a (possibly nested) feature. @@ -1715,7 +1804,7 @@ class Features(dict): - - [`Array2D`], [`Array3D`], [`Array4D`] or [`Array5D`] feature for multidimensional arrays. + - [`Array1D`], [`Array2D`], [`Array3D`], [`Array4D`] or [`Array5D`] feature for multidimensional arrays. - [`Audio`] feature to store the absolute path to an audio file or a dictionary with the relative path to an audio file ("path" key) and its bytes content ("bytes" key). This feature extracts the audio data. - [`Image`] feature to store the absolute path to an image file, an `np.ndarray` object, a `PIL.Image.Image` object diff --git a/src/datasets/formatting/formatting.py b/src/datasets/formatting/formatting.py index 2dae3a52fd3..b114c09bccc 100644 --- a/src/datasets/formatting/formatting.py +++ b/src/datasets/formatting/formatting.py @@ -107,6 +107,16 @@ def _is_array_with_nulls(pa_array: pa.Array) -> bool: return pa_array.null_count > 0 +def dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[str, List[T]]) -> List[Dict[str, T]]: + # convert to list of dicts + list_of_dicts = [] + keys = dict_of_lists.keys() + value_arrays = [dict_of_lists[key] for key in keys] + for vals in zip(*value_arrays): + list_of_dicts.append(dict(zip(keys, vals))) + return list_of_dicts + + class BaseArrowExtractor(Generic[RowFormat, ColumnFormat, BatchFormat]): """ Arrow extractor are used to extract data from pyarrow tables. @@ -140,15 +150,60 @@ def extract_batch(self, pa_table: pa.Table) -> pa.Table: return pa_table +def extract_struct_array(pa_array: pa.StructArray) -> list: + """StructArray.to_pylist / to_pydict does not call sub-arrays to_pylist / to_pydict methods so handle them manually.""" + if isinstance(pa_array, pa.ChunkedArray): + batch_chunks = [extract_struct_array(chunk) for chunk in pa_array.chunks] + return [item for chunk in batch_chunks for item in chunk] + + batch = {} + for field in pa_array.type: + if pa.types.is_struct(pa_array.field(field.name).type): + batch[field.name] = extract_struct_array(pa_array.field(field.name)) + else: + # use logic from _arrow_array_to_numpy to preserve dtype + if isinstance(pa_array.type, _ArrayXDExtensionType): + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + batch[field.name] = list(pa_array.to_numpy(zero_copy_only=zero_copy_only)) + else: + batch[field.name] = pa_array.field(field.name).to_pylist() + return dict_of_lists_to_list_of_dicts(batch) + + +def extract_array_xdextension_array(pa_array: pa.Array) -> list: + print("Extracting array xdextension array") + if isinstance(pa_array, pa.ChunkedArray): + return [arr for chunk in pa_array.chunks for arr in extract_array_xdextension_array(chunk)] + else: + zero_copy_only = _is_zero_copy_only(pa_array.type.storage_dtype, unnest=True) + return list(pa_array.to_numpy(zero_copy_only=zero_copy_only)) + + class PythonArrowExtractor(BaseArrowExtractor[dict, list, dict]): def extract_row(self, pa_table: pa.Table) -> dict: - return _unnest(pa_table.to_pydict()) + return _unnest(self.extract_batch(pa_table)) def extract_column(self, pa_table: pa.Table) -> list: - return pa_table.column(0).to_pylist() + if pa.types.is_struct(pa_table[pa_table.column_names[0]].type): + return extract_struct_array(pa_table[pa_table.column_names[0]]) + # TODO: handle list of struct + else: + # should work for list of ArrayXD + return pa_table.column(0).to_pylist() def extract_batch(self, pa_table: pa.Table) -> dict: - return pa_table.to_pydict() + batch = {} + for col in pa_table.column_names: + if pa.types.is_struct(pa_table[col].type): + batch[col] = extract_struct_array(pa_table[col]) + else: + pa_array = pa_table[col] + if isinstance(pa_array.type, _ArrayXDExtensionType): + # don't call to_pylist() to preserve dtype of the fixed-size array + batch[col] = extract_array_xdextension_array(pa_array) + else: + batch[col] = pa_table[col].to_pylist() + return batch class NumpyArrowExtractor(BaseArrowExtractor[dict, np.ndarray, dict]): diff --git a/tests/features/test_array_xd.py b/tests/features/test_array_xd.py index 8a50823b996..2029b6b5230 100644 --- a/tests/features/test_array_xd.py +++ b/tests/features/test_array_xd.py @@ -175,20 +175,20 @@ def get_dict_examples(self, shape_1, shape_2): def _check_getitem_output_type(self, dataset, shape_1, shape_2, first_matrix): matrix_column = dataset["matrix"] self.assertIsInstance(matrix_column, list) - self.assertIsInstance(matrix_column[0], list) - self.assertIsInstance(matrix_column[0][0], list) + self.assertIsInstance(matrix_column[0], np.ndarray) + self.assertIsInstance(matrix_column[0][0], np.ndarray) self.assertTupleEqual(np.array(matrix_column).shape, (2, *shape_2)) matrix_field_of_first_example = dataset[0]["matrix"] - self.assertIsInstance(matrix_field_of_first_example, list) - self.assertIsInstance(matrix_field_of_first_example, list) + self.assertIsInstance(matrix_field_of_first_example, np.ndarray) + self.assertIsInstance(matrix_field_of_first_example[0], np.ndarray) self.assertEqual(np.array(matrix_field_of_first_example).shape, shape_2) np.testing.assert_array_equal(np.array(matrix_field_of_first_example), np.array(first_matrix)) matrix_field_of_first_two_examples = dataset[:2]["matrix"] self.assertIsInstance(matrix_field_of_first_two_examples, list) - self.assertIsInstance(matrix_field_of_first_two_examples[0], list) - self.assertIsInstance(matrix_field_of_first_two_examples[0][0], list) + self.assertIsInstance(matrix_field_of_first_two_examples[0], np.ndarray) + self.assertIsInstance(matrix_field_of_first_two_examples[0][0], np.ndarray) self.assertTupleEqual(np.array(matrix_field_of_first_two_examples).shape, (2, *shape_2)) with dataset.formatted_as("numpy"): @@ -268,7 +268,7 @@ def test_to_pylist(self): pylist = arr_xd.to_pylist() for first_dim, single_arr in zip(first_dim_list, pylist): - self.assertIsInstance(single_arr, list) + self.assertIsInstance(single_arr, np.ndarray) self.assertTupleEqual(np.array(single_arr).shape, (first_dim, *fixed_shape)) def test_to_numpy(self): @@ -311,8 +311,8 @@ def test_iter_dataset(self): for first_dim, ds_row in zip(first_dim_list, dataset): single_arr = ds_row["image"] - self.assertIsInstance(single_arr, list) - self.assertTupleEqual(np.array(single_arr).shape, (first_dim, *fixed_shape)) + self.assertIsInstance(single_arr, np.ndarray) + self.assertTupleEqual(single_arr.shape, (first_dim, *fixed_shape)) def test_to_pandas(self): fixed_shape = (2, 2) @@ -353,8 +353,8 @@ def test_map_dataset(self): # check also if above function resulted with 2x bigger first dim for first_dim, ds_row in zip(first_dim_list, dataset): single_arr = ds_row["image"] - self.assertIsInstance(single_arr, list) - self.assertTupleEqual(np.array(single_arr).shape, (first_dim * 2, *fixed_shape)) + self.assertIsInstance(single_arr, np.ndarray) + self.assertTupleEqual(single_arr.shape, (first_dim * 2, *fixed_shape)) @pytest.mark.parametrize("dtype, dummy_value", [("int32", 1), ("bool", True), ("float64", 1)]) @@ -419,7 +419,7 @@ def test_array_xd_with_none(): def test_array_xd_with_np(seq_type, dtype, shape, feature_class): feature = feature_class(dtype=dtype, shape=shape) data = np.zeros(shape, dtype=dtype) - expected = data.tolist() + expected = data if seq_type == "sequence": feature = datasets.Sequence(feature) data = [data] @@ -429,7 +429,12 @@ def test_array_xd_with_np(seq_type, dtype, shape, feature_class): data = [[data]] expected = [[expected]] ds = datasets.Dataset.from_dict({"col": [data]}, features=datasets.Features({"col": feature})) - assert ds[0]["col"] == expected + if seq_type == "sequence": + assert (ds[0]["col"][0] == expected[0]).all() + elif seq_type == "sequence_of_sequence": + assert (ds[0]["col"][0][0] == expected[0][0]).all() + else: + assert (ds[0]["col"] == expected).all() @pytest.mark.parametrize("with_none", [False, True]) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ffa048644e2..37167ed3316 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -195,8 +195,8 @@ def test_dummy_dataset(self, in_memory): } ), ) - self.assertEqual(dset[0]["col_2"], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]) - self.assertEqual(dset["col_2"][0], [[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]) + assert (dset[0]["col_2"] == np.array([[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]])).all() + assert (dset["col_2"][0] == np.array([[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]])).all() def test_dataset_getitem(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: @@ -4386,30 +4386,33 @@ def f(x): def f(x): """May return a mix of LazyDict and regular Dict, but using an extension type""" if x["a"][0][0] < 2: - x["a"] = [[-1]] + x["a"] = np.array([[-1]], dtype="int32") return dict(x) if return_lazy_dict is False else x else: return x if return_lazy_dict is True else {} features = Features({"a": Array2D(shape=(1, 1), dtype="int32")}) - ds = Dataset.from_dict({"a": [[[i]] for i in [0, 1, 2, 3]]}, features=features) + # If not passing array we get exceptions that are not easy to understand - not sure if there could be some type-checking needed somewhere? + ds = Dataset.from_dict({"a": [np.array([[i]], dtype="int32") for i in [0, 1, 2, 3]]}, features=features) ds = ds.map(f) outputs = ds[:] - assert outputs == {"a": [[[i]] for i in [-1, -1, 2, 3]]} + assert outputs == {"a": [np.array([[i]], dtype="int32") for i in [-1, -1, 2, 3]]} def f(x): """May return a mix of LazyDict and regular Dict, but using a nested extension type""" if x["a"]["nested"][0][0] < 2: - x["a"] = {"nested": [[-1]]} + x["a"] = {"nested": np.array([[-1]], dtype="int64")} return dict(x) if return_lazy_dict is False else x else: return x if return_lazy_dict is True else {} features = Features({"a": {"nested": Array2D(shape=(1, 1), dtype="int64")}}) - ds = Dataset.from_dict({"a": [{"nested": [[i]]} for i in [0, 1, 2, 3]]}, features=features) + ds = Dataset.from_dict( + {"a": [{"nested": np.array([[i]], dtype="int64")} for i in [0, 1, 2, 3]]}, features=features + ) ds = ds.map(f) outputs = ds[:] - assert outputs == {"a": [{"nested": [[i]]} for i in [-1, -1, 2, 3]]} + assert outputs == {"a": [{"nested": np.array([[i]], dtype="int64")} for i in [-1, -1, 2, 3]]} def test_dataset_getitem_raises(): diff --git a/tests/test_table.py b/tests/test_table.py index 3d3db09e5d6..a921527eae7 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1332,6 +1332,21 @@ def test_cast_array_to_feature_with_list_array_and_large_list_feature(from_list_ assert cast_array.type == expected_array_type +def all_arrays_equal(arr1, arr2): + if len(arr1) != len(arr2): + return False + for a1, a2 in zip(arr1, arr2): + if isinstance(a1, list) and isinstance(a2, list): + if not all_arrays_equal(a1, a2): + return False + elif isinstance(a1, np.ndarray) and isinstance(a2, np.ndarray): + if not (a1 == a2).all(): + return False + elif a1 != a2: + return False + return True + + def test_cast_array_xd_to_features_sequence(): arr = np.random.randint(0, 10, size=(8, 2, 3)).tolist() arr = Array2DExtensionType(shape=(2, 3), dtype="int64").wrap_array(pa.array(arr, pa.list_(pa.list_(pa.int64())))) @@ -1339,11 +1354,11 @@ def test_cast_array_xd_to_features_sequence(): # Variable size list casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"))) assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"))) - assert casted_array.to_pylist() == arr.to_pylist() + assert all_arrays_equal(casted_array.to_pylist(), arr.to_pylist()) # Fixed size list casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4)) assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4)) - assert casted_array.to_pylist() == arr.to_pylist() + assert all_arrays_equal(casted_array.to_pylist(), arr.to_pylist()) def test_embed_array_storage(image_file):