The Dataset class represents a potentially large set of elements in TensorFlow's tf.data API. It provides methods for creating, transforming, and iterating over datasets.
- Supports creation of datasets from various sources (tensors, generators, files, etc.)
- Provides transformations such as mapping, filtering, batching, and shuffling
- Enables efficient input pipelines through prefetching and parallelism
- Supports windowing, grouping, and other advanced operations
- Allows saving/loading datasets
Creation methods:
- from_tensor_slices(): Creates a dataset from a tensor or nested structure of tensors
- from_generator(): Creates a dataset from a Python generator function
- range(): Creates a dataset of a step-separated range of values
- list_files(): Creates a dataset of filenames matching glob patterns
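A minimal sketch of the creation methods above (omitting `list_files()`, which needs files on disk; the generator and values here are illustrative):

```python
import tensorflow as tf

# From an in-memory tensor: each slice along the first axis becomes one element.
ds = tf.data.Dataset.from_tensor_slices([10, 20, 30])

# From a Python generator; output_signature declares the element shape and dtype.
def squares():
    for i in range(3):
        yield i * i

gen_ds = tf.data.Dataset.from_generator(
    squares, output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))

# A step-separated range, analogous to Python's built-in range().
range_ds = tf.data.Dataset.range(0, 10, 2)  # elements 0, 2, 4, 6, 8
```

Note that `from_tensor_slices()` embeds the data in the graph, so it is best suited to small in-memory datasets; generators and file-based sources scale better.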
Transformation methods:
- map(): Applies a function to each element
- filter(): Filters elements based on a predicate
- batch(): Combines consecutive elements into batches
- shuffle(): Randomly shuffles elements
- repeat(): Repeats the dataset a given number of times
- take(): Takes a specified number of elements from the start
- skip(): Skips a specified number of elements from the start
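Transformations chain into a pipeline, each returning a new dataset. A small illustrative chain (the values are arbitrary):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)         # 0, 1, ..., 9
ds = ds.map(lambda x: x * 2)           # 0, 2, ..., 18
ds = ds.filter(lambda x: x > 4)        # 6, 8, 10, 12, 14, 16, 18
ds = ds.batch(3)                       # [6, 8, 10], [12, 14, 16], [18]
ds = ds.take(2)                        # keep only the first two batches
```

Order matters: batching before filtering, for example, would apply the predicate to whole batch tensors rather than individual elements.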
Iteration methods:
- __iter__(): Allows iterating over the dataset elements
- as_numpy_iterator(): Returns an iterator over NumPy arrays
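The difference between the two iteration styles, sketched on a toy dataset:

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])

# __iter__ yields tf.Tensor objects (available in eager execution).
tensors = [t for t in ds]

# as_numpy_iterator() yields plain NumPy values, convenient for inspection and tests.
values = list(ds.as_numpy_iterator())
```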
Utility and performance methods:
- cache(): Caches elements of the dataset
- prefetch(): Prefetches elements to improve performance
- apply(): Applies a custom transformation function
- save(): Saves the dataset to disk
- cardinality(): Returns the number of elements in the dataset
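A minimal sketch of a performance-oriented pipeline using these methods (the dataset size is arbitrary; `save()` is omitted since it writes to disk):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100)
ds = ds.cache()                     # keep elements in memory after the first pass
ds = ds.prefetch(tf.data.AUTOTUNE)  # overlap data production with consumption

# cardinality() reports the element count when it is statically known;
# otherwise it returns tf.data.UNKNOWN_CARDINALITY or INFINITE_CARDINALITY.
n = int(ds.cardinality().numpy())
```

`prefetch()` is conventionally placed last in the pipeline so that fully prepared elements are buffered ahead of the training loop.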
import tensorflow as tf

# Create a dataset from a tensor
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])

# Apply transformations
dataset = dataset.map(lambda x: x * 2)
dataset = dataset.batch(2)

# Iterate over elements
for element in dataset:
    print(element)

The Dataset class forms the foundation of efficient data loading and preprocessing pipelines in TensorFlow, enabling scalable machine learning workflows.