Add dataset loading guide (Issue #2116) #3450

Merged 1 commit on Nov 15, 2023
1 change: 1 addition & 0 deletions docs/guides/data_preprocessing/index.rst
@@ -5,3 +5,4 @@ Data preprocessing
:maxdepth: 1

full_eval
loading_datasets
296 changes: 296 additions & 0 deletions docs/guides/data_preprocessing/loading_datasets.ipynb
@@ -0,0 +1,296 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading datasets\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/data_preprocessing/loading_datasets.ipynb)\n",
"\n",
"A neural net written in Jax+Flax expects its input data as `jax.numpy` array instances. Therefore, loading a dataset from any source is as simple as converting it to `jax.numpy` types and reshaping it to the appropriate dimensions for your network.\n",
"\n",
"As an example, this guide demonstrates how to import [MNIST](http://yann.lecun.com/exdb/mnist/) using the APIs from Torchvision, Tensorflow, and Hugging Face. We'll load the whole dataset into memory. For datasets that don't fit into memory the process is analogous but should be done in a batchwise fashion.\n",
"\n",
"The MNIST dataset consists of greyscale images of 28x28 pixels of handwritten digits, and has a designated 60k/10k train/test split. The task is to predict the correct class (digit 0, ..., 9) of each image.\n",
"\n",
"Assuming a CNN-based classifier, the input data should have shape `(B, 28, 28, 1)`, where the trailing singleton dimension denotes the greyscale image channel.\n",
"\n",
"The labels are simply the integer denoting the digit corresponding to the image. Labels should therefore have shape `(B,)`, to enable loss computation with [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy_with_integer_labels)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax.numpy as jnp"
]
},
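{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of these shapes, the cell below runs the loss on dummy data (a minimal sketch: the random `logits` and all-zero `labels` are illustrative only, not part of any real pipeline)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"import jax\n",
"import optax\n",
"\n",
"# Dummy batch: logits of shape (B, 10), integer labels of shape (B,).\n",
"logits = jax.random.normal(jax.random.PRNGKey(0), (32, 10))\n",
"labels = jnp.zeros((32,), dtype=jnp.int16)\n",
"\n",
"# Per-example losses have shape (B,); average them to get a scalar.\n",
"loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()\n",
"print(loss.shape)"
]
},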
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading from `torchvision.datasets`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"import torchvision"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"def get_dataset_torch():\n",
" mnist = {\n",
" 'train': torchvision.datasets.MNIST('./data', train=True, download=True),\n",
" 'test': torchvision.datasets.MNIST('./data', train=False, download=True)\n",
" }\n",
"\n",
" ds = {}\n",
"\n",
" for split in ['train', 'test']:\n",
" ds[split] = {\n",
" 'image': mnist[split].data.numpy(),\n",
" 'label': mnist[split].targets.numpy()\n",
" }\n",
"\n",
" # cast from np to jnp and rescale the pixel values from [0,255] to [0,1]\n",
" ds[split]['image'] = jnp.float32(ds[split]['image']) / 255\n",
" ds[split]['label'] = jnp.int16(ds[split]['label'])\n",
"\n",
" # torchvision returns shape (B, 28, 28).\n",
" # hence, append the trailing channel dimension.\n",
" ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)\n",
"\n",
" return ds['train'], ds['test']"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"outputId": "be39b756-d13e-4380-b99e-a5cbf61458cc",
"tags": [
"skip-execution"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28, 1) float32\n",
"(60000,) int16\n",
"(10000, 28, 28, 1) float32\n",
"(10000,) int16\n"
]
}
],
"source": [
"train, test = get_dataset_torch()\n",
"print(train['image'].shape, train['image'].dtype)\n",
"print(train['label'].shape, train['label'].dtype)\n",
"print(test['image'].shape, test['image'].dtype)\n",
"print(test['label'].shape, test['label'].dtype)"
]
},
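{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the whole split now in memory as `jax.numpy` arrays, minibatches can be sliced off directly. The `batch_iter` helper below is a minimal sketch (a hypothetical convenience, not part of torchvision or Flax):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"def batch_iter(ds, batch_size=128):\n",
"    # Yield successive minibatches by slicing the in-memory arrays.\n",
"    n = ds['image'].shape[0]\n",
"    for i in range(0, n, batch_size):\n",
"        yield {k: v[i:i + batch_size] for k, v in ds.items()}\n",
"\n",
"batch = next(batch_iter(train))\n",
"print(batch['image'].shape, batch['label'].shape)"
]
},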
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading from `tensorflow_datasets`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"import tensorflow_datasets as tfds"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"def get_dataset_tf():\n",
" mnist = tfds.builder('mnist')\n",
" mnist.download_and_prepare()\n",
"\n",
" ds = {}\n",
"\n",
" for split in ['train', 'test']:\n",
" ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))\n",
"\n",
" # cast to jnp and rescale pixel values\n",
" ds[split]['image'] = jnp.float32(ds[split]['image']) / 255\n",
" ds[split]['label'] = jnp.int16(ds[split]['label'])\n",
"\n",
" return ds['train'], ds['test']"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"outputId": "25d2c468-cbc8-4971-a738-1295ce8c6f16",
"tags": [
"skip-execution"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28, 1) float32\n",
"(60000,) int16\n",
"(10000, 28, 28, 1) float32\n",
"(10000,) int16\n"
]
}
],
"source": [
"train, test = get_dataset_tf()\n",
"print(train['image'].shape, train['image'].dtype)\n",
"print(train['label'].shape, train['label'].dtype)\n",
"print(test['image'].shape, test['image'].dtype)\n",
"print(test['label'].shape, test['label'].dtype)"
]
},
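{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a split that does not fit into memory, the same builder can stream minibatches instead of materializing everything with `batch_size=-1`. The generator below is a minimal sketch of that pattern:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"def stream_tf_batches(batch_size=32):\n",
"    mnist = tfds.builder('mnist')\n",
"    mnist.download_and_prepare()\n",
"    # Batch the tf.data pipeline and convert one minibatch at a time.\n",
"    ds = mnist.as_dataset(split='train').batch(batch_size)\n",
"    for batch in tfds.as_numpy(ds):\n",
"        yield jnp.float32(batch['image']) / 255, jnp.int16(batch['label'])\n",
"\n",
"image, label = next(stream_tf_batches())\n",
"print(image.shape, label.shape)"
]
},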
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading from 🤗 Hugging Face `datasets`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"#!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"def get_dataset_hf():\n",
" mnist = load_dataset(\"mnist\")\n",
"\n",
" ds = {}\n",
"\n",
" for split in ['train', 'test']:\n",
" ds[split] = {\n",
" 'image': np.array([np.array(im) for im in mnist[split]['image']]),\n",
" 'label': np.array(mnist[split]['label'])\n",
" }\n",
"\n",
" # cast to jnp and rescale pixel values\n",
" ds[split]['image'] = jnp.float32(ds[split]['image']) / 255\n",
" ds[split]['label'] = jnp.int16(ds[split]['label'])\n",
"\n",
" # append trailing channel dimension\n",
" ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)\n",
"\n",
" return ds['train'], ds['test']"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"outputId": "b026b33f-3bdd-4d26-867c-49400fff1c96",
"tags": [
"skip-execution"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28, 1) float32\n",
"(60000,) int16\n",
"(10000, 28, 28, 1) float32\n",
"(10000,) int16\n"
]
}
],
"source": [
"train, test = get_dataset_hf()\n",
"print(train['image'].shape, train['image'].dtype)\n",
"print(train['label'].shape, train['label'].dtype)\n",
"print(test['image'].shape, test['image'].dtype)\n",
"print(test['label'].shape, test['label'].dtype)"
]
},
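{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, `datasets` can be asked for NumPy-formatted columns via `with_format`, which avoids the per-image Python loop above. A minimal sketch (assuming the fixed-size MNIST images stack into a single array):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"# Ask datasets for numpy arrays instead of PIL images.\n",
"mnist_np = load_dataset(\"mnist\").with_format(\"numpy\")\n",
"images = mnist_np['train']['image']\n",
"print(images.shape, images.dtype)  # expected: (60000, 28, 28) uint8"
]
}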
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 0
}