diff --git a/CHANGELOG.md b/CHANGELOG.md index 070cba11f..5eed7de90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ - Added support for HDMF Common Schema 1.9.0. - Introduced a new data type `MeaningsTable` and changes to `DynamicTable` to support included `MeaningsTable` objects. @rly [#1376](https://github.com/hdmf-dev/hdmf/pull/1376) - Promoted `HERD` from the hdmf-experimental namespace to the HDMF Common namespace. @rly [#1387](https://github.com/hdmf-dev/hdmf/pull/1387) +- Added a check when setting or adding data to a `DynamicTableRegion` or setting the `table` attribute of a `DynamicTableRegion` + that the data values are in bounds of the linked table. This can be turned off for + `DynamicTableRegion.__init__` using the keyword argument `validate_data=False`. @rly [#1168](https://github.com/hdmf-dev/hdmf/pull/1168) - Added warning when `data_type_def` and `data_type_inc` are the same in a spec. @rly [#1312](https://github.com/hdmf-dev/hdmf/pull/1312) - Added abstract methods `HDMFIO.load_namespaces` and `HDMFIO.load_namespaces_io`. @rly [#1299](https://github.com/hdmf-dev/hdmf/pull/1299) diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index 503f49e00..5cdb7d6b8 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -1081,12 +1081,6 @@ def create_region(self, **kwargs): msg = 'region slice %s is out of range for this DynamicTable of length %d' % (str(region), len(self)) raise IndexError(msg) region = list(range(*region.indices(len(self)))) - else: - for idx in region: - if idx < 0 or idx >= len(self): - raise IndexError('The index ' + str(idx) + - ' is out of range for this DynamicTable of length ' - + str(len(self))) desc = getargs('description', kwargs) name = getargs('name', kwargs) return DynamicTableRegion(name=name, data=region, description=desc, table=self) @@ -1505,11 +1499,58 @@ class DynamicTableRegion(VectorData): {'name': 'description', 'type': str, 'doc': 'a description of what this region represents'}, {'name': 'table', 'type': DynamicTable, 'doc': 'the DynamicTable this region applies to', 'default': None}, + {'name': 'validate_data', 'type': bool, + 'doc': 'whether to validate the data is in bounds of the linked table', 'default': True}, allow_positional=AllowPositional.WARNING) def __init__(self, **kwargs): - t = popargs('table', kwargs) + table, validate_data = popargs('table', 'validate_data', kwargs) + data = getargs('data', kwargs) + self._validate_data = validate_data + if self._validate_data: + self._validate_index_in_range(data, table) + super().__init__(**kwargs) - self.table = t + if table is not None: # set the table attribute using fields to avoid another validation in the setter + self.fields['table'] = table + + def _check_indices_in_bounds(self, data, table): + """Check if data contains indices that are out of bounds. + + Args: + data: Single index or array of indices to check + table: The DynamicTable to check bounds against + + Returns: + Error message string if validation fails, None if validation passes or table is None + """ + if not table: + return None + + # Convert to numpy array for efficient checking + if isinstance(data, (list, tuple)): + data_arr = np.array(data) + elif np.isscalar(data): + data_arr = np.array([data]) + else: + data_arr = data[:] + + # Find indices that are out of bounds + violators = np.where((data_arr >= len(table)) | (data_arr < 0))[0] + if violators.size > 0: + return ( + f"DynamicTableRegion values {data_arr[violators]} are out of bounds for " + f"{type(table)} '{table.name}'." + ) + return None + + def _validate_index_in_range(self, data, table): + """If data contains an index that is out of bounds, then raise an error. + If the object is being constructed from a file, raise a warning instead to ensure invalid data can still be + read. + """ + error_msg = self._check_indices_in_bounds(data, table) + if error_msg: + self._error_on_new_warn_on_construct(error_msg, error_cls=IndexError) @property def table(self): @@ -1517,24 +1558,37 @@ def table(self): return self.fields.get('table') @table.setter - def table(self, val): + def table(self, table): """ - Set the table this DynamicTableRegion should be pointing to + Set the table this DynamicTableRegion should be pointing to. - :param val: The DynamicTable this DynamicTableRegion should be pointing to + This will validate all data elements in this DynamicTableRegion to ensure they are within bounds if + validate_data was set to True. + + :param table: The DynamicTable this DynamicTableRegion should be pointing to :raises: AttributeError if table is already in fields :raises: IndexError if the current indices are out of bounds for the new table given by val """ - if val is None: + if table is None: return if 'table' in self.fields: msg = "can't set attribute 'table' -- already set" raise AttributeError(msg) - dat = self.data - if isinstance(dat, DataIO): - dat = dat.data - self.fields['table'] = val + + self.fields['table'] = table + if self._validate_data: + self._validate_index_in_range(self.data, table) + + def extend(self, arg): + """Add all elements of the iterable arg to the end of this DynamicTableRegion. + + This override uses efficient batch validation instead of validating element-by-element. + """ + # Use the parent Data class extend which calls _validate_new_data for batch validation + # Skip VectorData.extend which would fall back to element-by-element add_row + from hdmf.container import Data + Data.extend(self, arg) def __getitem__(self, arg): return self.get(arg) @@ -1678,6 +1732,26 @@ def _validate_on_set_parent(self): warn(msg, stacklevel=2) return super()._validate_on_set_parent() + def _validate_new_data(self, data): + """Validate a batch of indices before adding to this DynamicTableRegion. + + Validation only occurs if validate_data was set to True (the default). + """ + if self._validate_data: + error_msg = self._check_indices_in_bounds(data, self.table) + if error_msg: + raise IndexError(error_msg) + + def _validate_new_data_element(self, arg): + """Validate that the new index is within bounds of the table. Raises an IndexError if not. + + Validation only occurs if validate_data was set to True (the default). + """ + if self._validate_data: + error_msg = self._check_indices_in_bounds(arg, self.table) + if error_msg: + raise IndexError(error_msg) + def _uint_precision(elements): """ Calculate the uint precision needed to encode a set of elements """ diff --git a/src/hdmf/container.py b/src/hdmf/container.py index ec2345050..762555112 100644 --- a/src/hdmf/container.py +++ b/src/hdmf/container.py @@ -551,6 +551,19 @@ def _validate_on_set_parent(self): """ pass + def _error_on_new_warn_on_construct(self, error_msg: str, error_cls: type = ValueError): + """Raise a ValueError when a check is violated on instance creation. + To ensure backwards compatibility, this method throws a warning + instead of raising an error when reading from a file, ensuring that + files with invalid data can be read. If error_msg is set to None + the function will simply return without further action. + """ + if error_msg is None: + return + if not self._in_construct_mode: + raise error_cls(error_msg) + warn(error_msg) + class Container(AbstractContainer): """A container that can contain other containers and has special functionality for printing.""" diff --git a/tests/unit/build_tests/test_classgenerator.py b/tests/unit/build_tests/test_classgenerator.py index 0a39a0f7d..82f9d5193 100644 --- a/tests/unit/build_tests/test_classgenerator.py +++ b/tests/unit/build_tests/test_classgenerator.py @@ -1,5 +1,4 @@ import numpy as np -import os import shutil import tempfile from warnings import warn @@ -677,9 +676,6 @@ class TestGetClassSeparateNamespace(TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - if os.path.exists(self.test_dir): # start clean - self.tearDown() - os.mkdir(self.test_dir) self.bar_spec = GroupSpec( doc='A test group specification with a data type', @@ -863,9 +859,6 @@ class TestGetClassObjectReferences(TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() - if os.path.exists(self.test_dir): # start clean - self.tearDown() - os.mkdir(self.test_dir) self.type_map = TypeMap() def tearDown(self): diff --git a/tests/unit/common/test_table.py b/tests/unit/common/test_table.py index a07d0a4d4..1421c8d71 100644 --- a/tests/unit/common/test_table.py +++ b/tests/unit/common/test_table.py @@ -1390,6 +1390,46 @@ def test_no_df_nested(self): with self.assertRaisesWith(ValueError, msg): dynamic_table_region.get(0, df=False, index=False) + def test_init_out_of_bounds(self): + table = self.with_columns_and_data() + with self.assertRaises(IndexError): + DynamicTableRegion(name='dtr', data=[0, 1, 2, 2, 5], description='desc', table=table) + + def test_init_out_of_bounds_no_validate(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1, 5], description='desc', table=table, validate_data=False) + self.assertEqual(dtr.data, [0, 1, 5]) # no exception raised + + def test_add_row_out_of_bounds(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1, 2, 2], description='desc', table=table) + with self.assertRaises(IndexError): + dtr.add_row(5) + + def test_add_row_out_of_bounds_no_validate(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1], description='desc', table=table, validate_data=False) + dtr.add_row(5) # should not raise an error + self.assertEqual(list(dtr.data), [0, 1, 5]) + + def test_set_table_out_of_bounds(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1, 5], description='desc') + with self.assertRaises(IndexError): + dtr.table = table + + def test_extend_out_of_bounds(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1], description='desc', table=table) + with self.assertRaises(IndexError): + dtr.extend([2, 10, 20]) + + def test_extend_out_of_bounds_no_validate(self): + table = self.with_columns_and_data() + dtr = DynamicTableRegion(name='dtr', data=[0, 1], description='desc', table=table, validate_data=False) + dtr.extend([10, 20]) # should not raise an error + self.assertEqual(list(dtr.data), [0, 1, 10, 20]) + def test_create_region_with_valid_slice_range(self): table = self.with_columns_and_data() region = table.create_region(name='region', region=slice(0, 2), description='test region') @@ -1406,18 +1446,6 @@ def test_create_region_with_none_slice(self): region = table.create_region(name='region2', region=slice(0, None), description='test region') self.assertEqual(region.data, [0, 1, 2, 3, 4]) - def test_create_region_with_negative_index(self): - table = self.with_columns_and_data() - - msg = 'The index -1 is out of range for this DynamicTable of length 5' - with self.assertRaisesWith(IndexError, msg): - table.create_region(name='region', region=[-1, 0], description='test region') - - def test_create_region_with_out_of_range_index(self): - table = self.with_columns_and_data() - msg = 'The index 10 is out of range for this DynamicTable of length 5' - with self.assertRaisesWith(IndexError, msg): - table.create_region(name='region', region=[0, 10], description='test region') class DynamicTableRegionRoundTrip(H5RoundTripMixin, TestCase): diff --git a/tests/unit/test_io_hdf5_streaming.py b/tests/unit/test_io_hdf5_streaming.py index 1a487b939..c03110a76 100644 --- a/tests/unit/test_io_hdf5_streaming.py +++ b/tests/unit/test_io_hdf5_streaming.py @@ -80,9 +80,9 @@ def setUp(self): self.manager = BuildManager(type_map) def tearDown(self): - if os.path.exists(self.ns_filename): + if hasattr(self, 'ns_filename') and os.path.exists(self.ns_filename): os.remove(self.ns_filename) - if os.path.exists(self.ext_filename): + if hasattr(self, 'ext_filename') and os.path.exists(self.ext_filename): os.remove(self.ext_filename) def test_basic_read(self):