Refactor keras data utils
reuvenp committed Jan 6, 2025
1 parent 5ed46c7 commit 917b414
Showing 3 changed files with 32 additions and 32 deletions.
56 changes: 28 additions & 28 deletions model_compression_toolkit/core/keras/data_util.py
@@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Generator, Callable

import tensorflow as tf

from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor

import tensorflow as tf
from typing import Callable, Generator, Sequence, Any
@@ -58,7 +53,6 @@ def gen():

return gen


class TFDatasetFromGenerator:
"""
TensorFlow dataset from a data generator function, batched to a specified size.
@@ -77,15 +71,15 @@ def __init__(self, data_gen_fn: Callable[[], Generator]):

# TFDatasetFromGenerator flattens the dataset, thus we ignore the batch dimension
output_signature = get_tensor_spec(inputs, ignore_batch_dim=True)
self.dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)
self.tf_dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)

def __iter__(self):
return iter(self.dataset)
return iter(self.tf_dataset)

def __len__(self):
""" Returns the number of batches. """
if self._size is None:
self._size = sum(1 for _ in self.dataset)
self._size = sum(1 for _ in self.tf_dataset)
return self._size


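The attribute rename above makes the wrapped tf.data.Dataset explicit: consumers now reach it as tf_dataset. A minimal sketch of how the wrapper is used, assuming a data generator that yields lists of batched numpy arrays (the generator format is an assumption for illustration, not part of this commit):

```python
import numpy as np

# Hypothetical generator (assumed format): yields lists of batched
# numpy arrays, which flat_gen_fn then flattens sample by sample.
def data_gen():
    for _ in range(4):
        yield [np.random.rand(2, 8, 8, 3).astype(np.float32)]

ds = TFDatasetFromGenerator(data_gen)
print(len(ds))        # iterates ds.tf_dataset once, then caches the count
for sample in ds:     # __iter__ delegates to the renamed ds.tf_dataset
    pass
```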
@@ -116,6 +110,12 @@ def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
raise ValueError(f'Not enough samples to create a dataset with {n_samples} samples')
self.samples = samples

# Use from_generator to keep tuples intact
self.tf_dataset = tf.data.Dataset.from_generator(
lambda: iter(self.samples),
output_signature=tuple(tf.TensorSpec(shape=sample.shape, dtype=sample.dtype) for sample in self.samples[0])
)

def __len__(self):
return len(self.samples)

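The block added in this hunk builds the fixed dataset with from_generator plus an explicit output_signature, so multi-input samples stay tuples rather than being collapsed into one stacked tensor. A standalone sketch of the same pattern, with toy (image, label) tuples standing in for self.samples:

```python
import tensorflow as tf

# Toy stand-in for self.samples: each sample is an (image, label) tuple.
samples = [(tf.random.uniform((8, 8, 3)), tf.constant(i)) for i in range(4)]

# The output_signature describes one sample, so from_generator yields
# each tuple intact instead of stacking the components.
ds = tf.data.Dataset.from_generator(
    lambda: iter(samples),
    output_signature=tuple(tf.TensorSpec(shape=s.shape, dtype=s.dtype)
                           for s in samples[0]))

for image, label in ds.batch(2):
    print(image.shape, label.numpy())   # (2, 8, 8, 3) [0 1], then [2 3]
```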
@@ -134,6 +134,12 @@ def __init__(self, samples: Sequence, sample_info: Sequence):
self.samples = samples
self.sample_info = sample_info

# Create a TensorFlow dataset that holds (sample, sample_info) tuples
self.tf_dataset = tf.data.Dataset.from_tensor_slices((
tf.convert_to_tensor(self.samples),
tuple(tf.convert_to_tensor(info) for info in self.sample_info)
))

def __len__(self):
return len(self.samples)

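Here the per-sample info is aligned one-to-one with the samples, so the added code can use from_tensor_slices, which slices every component along the first axis. A toy sketch of the resulting element structure, with assumed shapes:

```python
import tensorflow as tf

samples = tf.random.uniform((4, 8, 8, 3))     # 4 toy samples
weights = tf.constant([0.1, 0.2, 0.3, 0.4])   # one weight per sample

# Slicing along axis 0 of every component yields (sample, (weight,))
# pairs, mirroring the (samples, sample_info) pairing above.
ds = tf.data.Dataset.from_tensor_slices((samples, (weights,)))

for sample, info in ds:
    print(sample.shape, [w.numpy() for w in info])   # (8, 8, 3) [0.1]
```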
@@ -150,18 +156,23 @@ def __init__(self, samples_dataset: tf.data.Dataset, *info: Any):
self.samples_dataset = samples_dataset
self.info = info

# Map to ensure the output is always (sample, info) as a tuple
self.tf_dataset = self.samples_dataset.map(
lambda *x: ((x,) if not isinstance(x, tuple) else x, *self.info)
)

def __iter__(self):
for sample in self.samples_dataset:
yield (sample, *self.info)
yield ((sample,) if not isinstance(sample, tuple) else sample, *self.info)


def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
"""Create a DataLoader based on samples yielded by data_gen."""
ds = TFDatasetFromGenerator(data_gen_fn)
return create_tf_dataloader(dataset=ds, batch_size=batch_size)
return create_tf_dataloader(mct_dataset=ds, batch_size=batch_size)


def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
def create_tf_dataloader(mct_dataset, batch_size, shuffle=False, collate_fn=None):
"""
Creates a tf.data.Dataset with specified loading options.
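The map added in this hunk attaches the same constant info to every element of the wrapped dataset, matching what the eager __iter__ below it yields. A toy sketch of the pattern in isolation, with an assumed scalar weight as the constant info:

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.random.uniform((4, 8)))
weight = tf.constant(0.5)   # constant info shared by every sample

# map receives each element via *x, so x is always a tuple; the sample
# tuple is then paired with the constant info, as in the class above.
with_info = base.map(lambda *x: (x, weight))

for sample_tuple, w in with_info:
    print(sample_tuple[0].shape, w.numpy())   # (8,) 0.5
```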
@@ -174,26 +185,15 @@ def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
Returns:
tf.data.Dataset: Configured for batching, shuffling, and custom transformations.
"""
def generator():
for item in dataset:
yield item

dummy_input_tensors = next(generator())

output_signature = get_tensor_spec(dummy_input_tensors)

tf_dataset = tf.data.Dataset.from_generator(
generator,
output_signature=output_signature
)
dataset = mct_dataset.tf_dataset

if shuffle:
tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
dataset = dataset.shuffle(buffer_size=len(dataset))

tf_dataset = tf_dataset.batch(batch_size)
dataset = dataset.batch(batch_size)

# Apply collate function if provided
if collate_fn:
tf_dataset = tf_dataset.map(lambda *args: collate_fn(args))
dataset = dataset.map(lambda *args: collate_fn(args))

return tf_dataset
return dataset
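After this hunk, create_tf_dataloader no longer rebuilds a generator and re-derives an output signature; it batches the tf_dataset that the MCT dataset object already carries. A minimal usage sketch through data_gen_to_dataloader, again assuming the list-of-numpy-batches generator format:

```python
import numpy as np

# Hypothetical generator (assumed format, as above).
def data_gen():
    for _ in range(4):
        yield [np.random.rand(2, 8, 8, 3).astype(np.float32)]

# Wraps the generator in TFDatasetFromGenerator, then batches its
# tf_dataset directly; no intermediate generator is rebuilt.
loader = data_gen_to_dataloader(data_gen, batch_size=4)
for batch in loader:
    pass
```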
6 changes: 3 additions & 3 deletions model_compression_toolkit/gptq/keras/gptq_training.py
@@ -151,7 +151,7 @@ def collate_fn(samples_with_loss_weights):

# Create final dataset using the new dataloader with collate_fn
final_dataset = create_tf_dataloader(
dataset=sla_train_dataset,
sla_train_dataset,
batch_size=orig_batch_size,
shuffle=True,
collate_fn=collate_fn
@@ -176,14 +176,14 @@ def _prepare_train_dataloader_for_non_sla(self,

# Step 2: Compute loss weights
if self.gptq_config.hessian_weights_config:
hessian_dataset = create_tf_dataloader(dataset=dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
hessian_dataset = create_tf_dataloader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
hessian_weights = self.compute_hessian_based_weights(hessian_dataset)
loss_weights = tf.convert_to_tensor(hessian_weights, dtype=tf.float32)
else:
loss_weights = tf.ones(num_nodes, dtype=tf.float32) / num_nodes

# Step 3: Create a dataset with samples and loss weights
augmented_dataset = IterableSampleWithConstInfoDataset(dataset.dataset, loss_weights)
augmented_dataset = IterableSampleWithConstInfoDataset(dataset.tf_dataset, loss_weights)

# Step 4: Add constant regularization weights
reg_weights = tf.ones(num_nodes, dtype=tf.float32)
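Step 3 pairs every sample with one shared loss-weight tensor through IterableSampleWithConstInfoDataset, which now receives the underlying tf_dataset. A toy sketch of that wiring, using the uniform fallback weights from the else branch and made-up shapes in place of real Hessian weights:

```python
import tensorflow as tf

num_nodes = 3
base = tf.data.Dataset.from_tensor_slices(tf.random.uniform((4, 8)))

# Uniform fallback weights, as in the else branch above.
loss_weights = tf.ones(num_nodes, dtype=tf.float32) / num_nodes

augmented = IterableSampleWithConstInfoDataset(base, loss_weights)
for sample_tuple, weights in augmented:
    print(sample_tuple[0].shape, weights.shape)   # (8,) (3,)
```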
2 changes: 1 addition & 1 deletion tests_pytest/keras/core/test_data_util.py
@@ -108,7 +108,7 @@ def test_create_tf_dataloader_fixed_tfdataset_with_info(self, fixed_gen):

def test_create_tf_dataloader_iterable_tfdataset_with_const_info(self, fixed_gen):
iterable_ds = TFDatasetFromGenerator(fixed_gen)
dataset = IterableSampleWithConstInfoDataset(iterable_ds, tf.constant("some_string"))
dataset = IterableSampleWithConstInfoDataset(iterable_ds.tf_dataset, tf.constant("some_string"))

for i, sample_with_info in enumerate(dataset):
sample, info = sample_with_info
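The updated test hands the wrapper the raw tf_dataset, since after the refactor the MCT dataset object is no longer itself a tf.data.Dataset. A self-contained sketch of the behavior the test exercises, with toy tensors in place of the fixed_gen fixture:

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.random.uniform((2, 8)))
dataset = IterableSampleWithConstInfoDataset(base, tf.constant("some_string"))

for sample_with_info in dataset:
    sample, info = sample_with_info           # sample is a 1-tuple of tensors
    assert info.numpy() == b"some_string"     # the constant info rides along
```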
