Skip to content

Commit

Permalink
Fix FixedSampleInfoDataset to support multiple inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Jan 13, 2025
1 parent e65d561 commit a770899
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 19 deletions.
29 changes: 24 additions & 5 deletions model_compression_toolkit/core/keras/data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,30 @@ def __init__(self, samples: Sequence, sample_info: Sequence):
self.samples = samples
self.sample_info = sample_info

# Create a TensorFlow dataset that holds (sample, sample_info) tuples
self.tf_dataset = tf.data.Dataset.from_tensor_slices((
tf.convert_to_tensor(self.samples),
tuple(tf.convert_to_tensor(info) for info in self.sample_info)
))
# Get the number of tensors in each tuple (corresponds to the number of input layers the model has)
num_tensors = len(samples[0])

# Create one list per input layer and distribute the tensors of each sample tuple across these lists
sample_tensor_lists = [[] for _ in range(num_tensors)]
for s in samples:
for i, data_tensor in enumerate(s):
sample_tensor_lists[i].append(data_tensor)

# To handle models whose input layers have different input shapes, we first need to
# organize the data in a dictionary so that tf.data.Dataset.from_tensor_slices can be used
samples_dict = {f'tensor_{i}': tensors for i, tensors in enumerate(sample_tensor_lists)}
info_dict = {f'info_{i}': tf.convert_to_tensor(info) for i, info in enumerate(self.sample_info)}
combined_dict = {**samples_dict, **info_dict}

tf_dataset = tf.data.Dataset.from_tensor_slices(combined_dict)

# Map the dataset to return tuples instead of a dict
def reorganize_ds_outputs(ds_output):
tensors = tuple(ds_output[f'tensor_{i}'] for i in range(num_tensors))
infos = tuple(ds_output[f'info_{i}'] for i in range(len(sample_info)))
return tensors, infos

self.tf_dataset = tf_dataset.map(reorganize_ds_outputs)

def __len__(self):
return len(self.samples)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self,
quant_method=QuantizationMethod.SYMMETRIC,
rounding_type=RoundingType.STE,
per_channel=True,
val_batch_size=1,
input_shape=(16, 16, 3),
hessian_weights=True,
log_norm_weights=True,
Expand All @@ -80,7 +81,8 @@ def __init__(self,

super().__init__(unit_test,
input_shape=input_shape,
num_calibration_iter=num_calibration_iter)
num_calibration_iter=num_calibration_iter,
val_batch_size=val_batch_size)

self.quant_method = quant_method
self.rounding_type = rounding_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -734,21 +734,13 @@ def test_gptq_with_sample_layer_attention(self):
tf.config.run_functions_eagerly(True)
kwargs = dict(per_sample=True, loss=sample_layer_attention_loss,
hessian_weights=True, hessian_num_samples=None,
norm_scores=False, log_norm_weights=False, scaled_log_norm=False)
norm_scores=False, log_norm_weights=False, scaled_log_norm=False, val_batch_size=7)
GradientPTQTest(self, **kwargs).run_test()
GradientPTQTest(self, hessian_batch_size=16, rounding_type=RoundingType.SoftQuantizer, **kwargs).run_test()
GradientPTQTest(self, hessian_batch_size=5, rounding_type=RoundingType.SoftQuantizer, gradual_activation_quantization=True, **kwargs).run_test()
GradientPTQTest(self, rounding_type=RoundingType.STE, **kwargs)
tf.config.run_functions_eagerly(False)

# TODO: reuven - new experimental facade needs to be tested regardless of the exporter.
# def test_gptq_new_exporter(self):
# self.test_gptq(experimental_exporter=True)

# Commented out due to a problem in TensorFlow 2.8
# def test_gptq_conv_group(self):
# GradientPTQLearnRateZeroConvGroupTest(self).run_test()
# GradientPTQWeightsUpdateConvGroupTest(self).run_test()

def test_gptq_conv_group_dilation(self):
# This call removes the effect of @tf.function decoration and executes the decorated function eagerly, which
Expand Down
15 changes: 11 additions & 4 deletions tests_pytest/keras/core/test_data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def test_create_tf_dataloader_fixed_tfdataset(self, fixed_gen):

for i, sample in enumerate(dataset):
assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))


batch_size = 16
dataloader = create_tf_dataloader(dataset, batch_size=batch_size)
Expand All @@ -86,13 +88,15 @@ def test_fixed_tfdataset_too_many_requested_samples(self, fixed_gen):
def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen):
samples = []
for b in list(fixed_gen()):
samples.extend(tf.unstack(b[0], axis=0))
break # Take one batch only (since this tests a fixed, small dataset)
for sample in zip(tf.unstack(b[0], axis=0), tf.unstack(b[1], axis=0)):
samples.append(sample)
break # Take one batch only (since this tests a fixed, small dataset)
dataset = FixedSampleInfoDataset(samples, [tf.range(32)])

for i, sample_with_info in enumerate(dataset):
sample, info = sample_with_info
assert np.array_equal(sample.cpu().numpy(), np.full((3, 30, 20), i))
assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
assert np.array_equal(sample[1].cpu().numpy(), np.full((10, ), i+10))
assert info == (i,)

batch_size = 16
Expand All @@ -103,7 +107,8 @@ def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen):

for batch in dataloader:
samples, additional_info = batch
assert samples.shape == (batch_size, 3, 30, 20)
assert samples[0].shape == (batch_size, 3, 30, 20)
assert samples[1].shape == (batch_size, 10)
assert additional_info[0].shape == (batch_size,)

def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen):
Expand All @@ -113,6 +118,7 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen
for i, sample_with_info in enumerate(dataset):
sample, info = sample_with_info
assert np.array_equal(sample[0].cpu().numpy(), np.full((3, 30, 20), i))
assert np.array_equal(sample[1].cpu().numpy(), np.full((10,), i+10))
assert info == tf.constant("some_string")

batch_size = 16
Expand All @@ -124,5 +130,6 @@ def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen
for batch in dataloader:
samples, additional_info = batch
assert samples[0].shape == (batch_size, 3, 30, 20)
assert samples[1].shape == (batch_size, 10)
assert additional_info.shape == (batch_size,)
assert all(additional_info == tf.constant("some_string"))

0 comments on commit a770899

Please sign in to comment.