Skip to content

Commit

Permalink
improve coverage
Browse files — browse the repository at this point in the history
  • Loading branch information
irenaby committed Oct 14, 2024
1 parent 0a4aed4 commit c029475
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -457,4 +457,4 @@ def get_inferable_quantizers(self, node: BaseNode):

@staticmethod
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
raise NotImplementedError()
raise NotImplementedError() # pragma: no cover
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/pytorch/data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, data_gen_fn: Callable[[], Generator]):
# validate one batch
test_batch = next(data_gen_fn())
if not isinstance(test_batch, list):
raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(test_batch)}')
self.orig_batch_size = test_batch[0].shape[0]

self._size = None
Expand Down
37 changes: 37 additions & 0 deletions tests/keras_tests/function_tests/test_hessian_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,43 @@ def test_double_fetch_hessian(self):
"approximation, for the single target node.")
self.assertEqual(hessian[target_node.name].shape[0], 2, "Expecting 2 Hessian scores.")

def test_invalid_request(self):
    """A request with neither a data loader nor a sample count must be rejected."""
    expected_msg = 'Data loader and the number of samples cannot both be None'
    with self.assertRaises(ValueError, msg=expected_msg):
        HessianScoresRequest(
            mode=HessianMode.ACTIVATION,
            granularity=HessianScoresGranularity.PER_TENSOR,
            data_loader=None,
            n_samples=None,
            target_nodes=list(self.graph.nodes),
        )

def test_fetch_hessian_invalid_args(self):
    """fetch_hessian must reject a request whose n_samples is None (when not force-computed)."""
    invalid_request = HessianScoresRequest(
        mode=HessianMode.ACTIVATION,
        granularity=HessianScoresGranularity.PER_TENSOR,
        data_loader=data_gen_to_dataloader(representative_dataset, batch_size=1),
        n_samples=None,
        target_nodes=list(self.graph.nodes),
    )
    expected_msg = 'Number of samples can be None only when force_compute is True.'
    with self.assertRaises(ValueError, msg=expected_msg):
        self.hessian_service.fetch_hessian(invalid_request)

def test_double_fetch_more_samples(self):
    """Fetching twice with a growing n_samples yields the larger score count (coverage test)."""
    self.hessian_service.clear_cache()
    target = list(self.graph.get_topo_sorted_nodes())[0]
    # Second request asks for more samples than the first; both go through the same service.
    for n_samples in (2, 4):
        request = HessianScoresRequest(
            mode=HessianMode.ACTIVATION,
            granularity=HessianScoresGranularity.PER_TENSOR,
            n_samples=n_samples,
            data_loader=data_gen_to_dataloader(representative_dataset, batch_size=1),
            target_nodes=[target],
        )
        scores = self.hessian_service.fetch_hessian(request)
        assert scores[target.name].shape[0] == n_samples


# Allow running this test module directly (e.g. `python test_hessian_service.py`).
if __name__ == "__main__":
    unittest.main()
19 changes: 18 additions & 1 deletion tests_pytest/pytorch/core/test_data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.utils.data import IterableDataset, Dataset

from model_compression_toolkit.core.pytorch.data_util import (data_gen_to_dataloader, IterableDatasetFromGenerator,
FixedDatasetFromGenerator)
FixedDatasetFromGenerator, FixedSampleInfoDataset)


@pytest.fixture(scope='session')
Expand Down Expand Up @@ -90,6 +90,23 @@ def test_fixed_dataset_from_random_gen_subset(self):
ds = FixedDatasetFromGenerator(get_random_data_gen_fn(), n_samples=123)
self._validate_fixed_ds(ds, exp_len=123, exp_batch_size=32)

def test_not_enough_samples_in_datagen(self):
    """Requesting more samples than the generator can supply must raise ValueError."""
    def small_gen():
        # Single batch of 10 samples; 11 are requested below.
        yield [np.ones((10, 3))]

    with pytest.raises(ValueError, match='Not enough samples in the data generator'):
        FixedDatasetFromGenerator(small_gen, n_samples=11)

def test_extra_info_mismatch(self, fixed_gen):
    """Mismatched lengths between samples and complementary info must raise ValueError."""
    # NOTE(review): the fixed_gen fixture looks unused in this test — confirm it is needed.
    samples = [1] * 10
    matching_info = [2] * 10
    mismatched_info = [3] * 11
    with pytest.raises(ValueError,
                       match='Mismatch in the number of samples between samples and complementary data'):
        FixedSampleInfoDataset(samples, matching_info, mismatched_info)

@pytest.mark.parametrize('ds_cls', [FixedDatasetFromGenerator, IterableDatasetFromGenerator])
def test_invalid_gen(self, ds_cls):
    """A generator yielding a bare array (not a list of tensors) must raise TypeError."""
    def bad_gen():
        yield np.ones((10, 3))  # not wrapped in a list — invalid shape of a yielded batch

    with pytest.raises(TypeError, match='Data generator is expected to yield a list of tensors'):
        ds_cls(bad_gen)

def _validate_ds_from_fixed_gen(self, ds, exp_len):
for _ in range(2):
for i, sample in enumerate(ds):
Expand Down

0 comments on commit c029475

Please sign in to comment.