
sradc/SmallPebble


SmallPebble


This README is generated from a notebook you can run!

SmallPebble is a minimalist autodiff and deep learning library written from scratch in Python. It runs automatic differentiation on NumPy arrays.

The core implementation is in a single file: smallpebble.py

The only dependency of the core implementation is NumPy.

SmallPebble demonstrates the key concepts under the hood of deep learning frameworks, without the complexity of performance optimizations or GPU support.

Recommended Learning Path:

  • Read this introduction to autodiff by the author (featured in university teaching materials).
    • It presents a similar autodiff implementation to SmallPebble, but simplified, on scalars.
  • Read the source code: smallpebble.py.
  • Take a look at the tests/ folder — this library was written test-first.
  • Clone the repo and play with it...

Installation

Clone the repository and install it locally:

git clone https://github.com/sradc/smallpebble.git
cd smallpebble

# Install core library
pip install .

# Or, to run the examples/notebooks
pip install ".[examples]"

Alternatively, to use it as a package:

pip install smallpebble

Highlights

  • Compact: < 1000 lines of core logic.
  • Educational: Prioritizes code clarity over raw performance (CPU-only).
  • Feature-rich: Supports matmul, conv2d, maxpool2d, and array broadcasting.
  • Flexible: Supports eager execution and implicit graph building.
  • Extensible: Easy API for adding custom operations.

How it works (TL;DR)

SmallPebble builds dynamic computation graphs implicitly via Python object referencing (similar to PyTorch). When get_gradients is called, autodiff is performed by traversing the graph backward.
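To make the TL;DR above concrete, here is a minimal scalar sketch of the same idea. It was written as an illustration and is not code from smallpebble.py. Each operation returns a new node that holds references to its parent nodes and to the local derivative with respect to each parent. A hypothetical get_gradients then walks those references backward and applies the chain rule. The names Scalar, add, mul and get_gradients are illustrative only; SmallPebble's real operations work on NumPy arrays.

```python
from collections import defaultdict


class Scalar:
    """Toy graph node: a value plus references to the nodes it was computed from."""

    def __init__(self, value, local_gradients=()):
        self.value = value
        # Pairs of (parent_node, d(self)/d(parent)).
        self.local_gradients = local_gradients


def add(a, b):
    # d(a+b)/da = 1, d(a+b)/db = 1
    return Scalar(a.value + b.value, ((a, 1.0), (b, 1.0)))


def mul(a, b):
    # d(a*b)/da = b, d(a*b)/db = a
    return Scalar(a.value * b.value, ((a, b.value), (b, a.value)))


def get_gradients(node):
    """Reverse-mode autodiff: push d(node)/d(x) backward along every path."""
    gradients = defaultdict(float)

    def propagate(current, path_value):
        for parent, local_grad in current.local_gradients:
            value_of_path_to_parent = path_value * local_grad
            gradients[parent] += value_of_path_to_parent
            propagate(parent, value_of_path_to_parent)

    propagate(node, 1.0)
    return gradients


a = Scalar(4.0)
b = Scalar(3.0)
c = add(a, mul(a, b))  # c = a + a*b, computed eagerly as the graph is built
grads = get_gradients(c)
print(grads[a], grads[b])  # dc/da = 1 + b = 4.0, dc/db = a = 4.0
```

This recursive version revisits a shared subgraph once for every path through it. That is fine for a toy, but it is exponential in the worst case. A practical implementation would visit nodes in reverse topological order so that each node is processed once.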

Read on to see:

  • Example models created and trained using SmallPebble.
  • A brief guide to using the library.

# Uncomment if you are running in Google Colab:
# !pip install smallpebble
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import smallpebble as sp
from smallpebble.datasets import load_data

Training a neural network to classify handwritten digits (MNIST)

"Load the dataset, and create a validation set."

X_train, y_train, _, _ = load_data('mnist')  # downloads from the SmallPebble GitHub repo, then cached
X_train = X_train/255  # normalize

# Separate out data for validation.
X = X_train[:50_000, ...]
y = y_train[:50_000]
X_eval = X_train[50_000:60_000, ...]
y_eval = y_train[50_000:60_000]
Downloading mnist.npz...


100%|██████████| 11.5M/11.5M [00:00<00:00, 31.9MB/s]
"Plot, to check we have the right data."

plt.figure(figsize=(5,5))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(X_train[i,:].reshape(28,28), cmap='gray', vmin=0, vmax=1)

plt.show()

[Figure: the first 25 MNIST training images, shown in a 5 by 5 grid.]

"Create a model, with two fully connected hidden layers."

X_in = sp.Placeholder()
y_true = sp.Placeholder()

h = sp.linearlayer(28*28, 100)(X_in)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.linearlayer(100, 100)(h)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.linearlayer(100, 10)(h)
y_pred = sp.Lazy(sp.softmax)(h)
loss = sp.Lazy(sp.cross_entropy)(y_pred, y_true)

learnables = sp.get_learnables(y_pred)

loss_vals = []
validation_acc = []
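The model above ends with sp.softmax followed by sp.cross_entropy. y_true is fed integer class labels, which are later compared directly with np.argmax outputs. Below is a plain-NumPy sketch of what these two steps compute. It assumes integer labels and a mean over the batch. Whether sp.cross_entropy sums or averages over the batch is not stated in this README, and this is not SmallPebble's code.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax; subtracting each row's max avoids overflow in exp."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def cross_entropy(probs, labels):
    """Mean negative log-probability assigned to the true class (labels are integer indices)."""
    rows = np.arange(len(labels))
    return -np.mean(np.log(probs[rows, labels]))


logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.0]])
labels = np.array([0, 1])
probs = softmax(logits)
print(probs.sum(axis=1))  # each row sums to 1
print(cross_entropy(probs, labels))
```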
"Train model, while measuring performance on the validation dataset."

NUM_ITERS = 300
BATCH_SIZE = 200

eval_batch = sp.batch(X_eval, y_eval, BATCH_SIZE)
adam = sp.Adam()  # Adam optimization

for i, (xbatch, ybatch) in tqdm(enumerate(sp.batch(X, y, BATCH_SIZE)), total=NUM_ITERS):
    if i >= NUM_ITERS: break
    
    X_in.assign_value(sp.Variable(xbatch))
    y_true.assign_value(ybatch)
    
    loss_val = loss.run()  # run the graph
    if np.isnan(loss_val.array):
        print("loss is nan, aborting.")
        break
    loss_vals.append(loss_val.array)
        
    # Compute gradients, and use to carry out learning step:
    gradients = sp.get_gradients(loss_val)
    adam.training_step(learnables, gradients)
        
    # Compute validation accuracy:
    x_eval_batch, y_eval_batch = next(eval_batch)
    X_in.assign_value(sp.Variable(x_eval_batch))
    predictions = y_pred.run()
    predictions = np.argmax(predictions.array, axis=1)
    accuracy = (y_eval_batch == predictions).mean()
    validation_acc.append(accuracy)

# Plot results:
print(f'Final validation accuracy: {np.mean(validation_acc[-10:])}')
plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.ylabel('Loss')
plt.xlabel('Iteration')
plt.plot(loss_vals)
plt.subplot(1, 2, 2)
plt.ylabel('Validation accuracy')
plt.xlabel('Iteration')
plt.suptitle('Neural network trained on MNIST, using SmallPebble.')
plt.ylim([0, 1])
plt.plot(validation_acc)
plt.show()
100%|██████████| 300/300 [00:01<00:00, 189.88it/s]


Final validation accuracy: 0.9155000000000001

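[Figure: Neural network trained on MNIST, using SmallPebble. Left: loss vs. iteration. Right: validation accuracy vs. iteration.]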
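The loop above calls adam.training_step(learnables, gradients) once per iteration to update every learnable parameter. As a hedged sketch of what an Adam step does, here is the standard update rule from Kingma and Ba (2015) in plain NumPy. AdamSketch is a hypothetical name. Parameters are held in a dict of arrays rather than in SmallPebble Variables. The default hyperparameters are the paper's and are not necessarily those of sp.Adam.

```python
import numpy as np


class AdamSketch:
    """Standard Adam update rule, applied to a dict of NumPy arrays (hypothetical helper)."""

    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.t = 0
        self.m = {}  # first-moment (mean) estimates, keyed like the params
        self.v = {}  # second-moment (uncentered variance) estimates

    def training_step(self, params, grads):
        self.t += 1
        for key, g in grads.items():
            m = self.m.get(key, np.zeros_like(g))
            v = self.v.get(key, np.zeros_like(g))
            m = self.beta1 * m + (1 - self.beta1) * g
            v = self.beta2 * v + (1 - self.beta2) * g**2
            self.m[key], self.v[key] = m, v
            # Bias correction compensates for m and v being initialised at zero.
            m_hat = m / (1 - self.beta1**self.t)
            v_hat = v / (1 - self.beta2**self.t)
            params[key] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


# Minimise f(w) = sum(w**2), whose gradient is 2*w.
params = {"w": np.array([1.0, -2.0])}
opt = AdamSketch(lr=0.1)
for _ in range(200):
    opt.training_step(params, {"w": 2 * params["w"]})
print(params["w"])  # moves toward the minimum at [0, 0]
```

Because of the bias correction, the very first step moves each parameter by about lr, in the direction opposite the sign of its gradient, whatever the gradient's magnitude.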

Training a convolutional neural network on a CIFAR-10 subset

We use a subset of the data because training CNNs is very slow on CPU.

"Load the CIFAR dataset."

X_train, y_train, _, _ = load_data('cifar')
X_train = X_train/255  # normalize
Downloading cifar.npz...


100%|██████████| 170M/170M [00:04<00:00, 37.3MB/s] 
"""Plot, to check it's the right data.

(This cell's code is from: https://www.tensorflow.org/tutorials/images/cnn#verify_the_data)
"""

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(8,8))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(X_train[i,:].reshape(32,32,3))
    plt.xlabel(class_names[y_train[i]])

plt.show()

[Figure: the first 25 CIFAR-10 training images in a 5 by 5 grid, each labelled with its class name.]

# Separate out data for validation, as before.
X = X_train[:1000, ...]
y = y_train[:1000]
X_eval = X_train[1000:1100, ...]
y_eval = y_train[1000:1100]
"""Define a model."""

X_in = sp.Placeholder()
y_true = sp.Placeholder()

h = sp.convlayer(height=3, width=3, depth=3, n_kernels=32)(X_in)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.Lazy(lambda a: sp.maxpool2d(a, 2, 2, strides=[2, 2]))(h)

h = sp.convlayer(3, 3, 32, 128, padding='VALID')(h)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.Lazy(lambda a: sp.maxpool2d(a, 2, 2, strides=[2, 2]))(h)

h = sp.convlayer(3, 3, 128, 128, padding='VALID')(h)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.Lazy(lambda a: sp.maxpool2d(a, 2, 2, strides=[2, 2]))(h)

h = sp.Lazy(lambda x: sp.reshape(x, [-1, 3*3*128]))(h)
h = sp.linearlayer(3*3*128, 10)(h)
h = sp.Lazy(sp.softmax)(h)

y_pred = h
loss = sp.Lazy(sp.cross_entropy)(y_pred, y_true)

learnables = sp.get_learnables(y_pred)

loss_vals = []
validation_acc = []

# Check we get the expected dimensions
X_in.assign_value(sp.Variable(X[0:3, :].reshape([-1, 32, 32, 3])))
h.run().shape
(3, 10)
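The (3, 10) result can be checked by hand. Assume TensorFlow-style padding rules: 'SAME' keeps ceil(n / stride) positions, and 'VALID' keeps ceil((n - kernel + 1) / stride). Also assume that the first convlayer defaults to 'SAME' padding with stride 1, and that sp.maxpool2d defaults to 'SAME' padding. Both defaults are assumptions, not stated in this README. Under these assumptions, the spatial size goes 32 → 32 → 16 → 14 → 7 → 5 → 3. That matches the 3*3*128 used in the flatten step.

```python
import math


def out_size(n, kernel, stride, padding):
    """Output length along one spatial axis, using TensorFlow-style padding rules."""
    if padding == "SAME":
        return math.ceil(n / stride)
    if padding == "VALID":
        return math.ceil((n - kernel + 1) / stride)
    raise ValueError(f"unknown padding: {padding}")


# (kernel, stride, padding) for each spatial layer in the CNN above.
# Assumed: the first convlayer and every maxpool2d use 'SAME' padding.
layers = [
    (3, 1, "SAME"),   # convlayer 1
    (2, 2, "SAME"),   # maxpool
    (3, 1, "VALID"),  # convlayer 2
    (2, 2, "SAME"),   # maxpool
    (3, 1, "VALID"),  # convlayer 3
    (2, 2, "SAME"),   # maxpool
]

size = 32
sizes = [size]
for kernel, stride, padding in layers:
    size = out_size(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)              # [32, 32, 16, 14, 7, 5, 3]
print(size * size * 128)  # 1152 == 3*3*128 features per image
```

If the final pooling used 'VALID' padding instead, the last size would be 2. The reshape to 3*3*128 would then fail, so the (3, 10) output is consistent with 'SAME' pooling.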

Train the model.

NUM_ITERS = 100
BATCH_SIZE = 128

eval_batch = sp.batch(X_eval, y_eval, BATCH_SIZE)
adam = sp.Adam()

for i, (xbatch, ybatch) in tqdm(enumerate(sp.batch(X, y, BATCH_SIZE)), total=NUM_ITERS):
    if i >= NUM_ITERS: break
       
    xbatch_images = xbatch.reshape([-1, 32, 32, 3])
    X_in.assign_value(sp.Variable(xbatch_images))
    y_true.assign_value(ybatch)
    
    loss_val = loss.run()
    if np.isnan(loss_val.array):
        print("Aborting, loss is nan.")
        break
    loss_vals.append(loss_val.array)
    
    # Compute gradients, and carry out learning step.
    gradients = sp.get_gradients(loss_val)  
    adam.training_step(learnables, gradients)
    
    # Compute validation accuracy:
    x_eval_batch, y_eval_batch = next(eval_batch)
    X_in.assign_value(sp.Variable(x_eval_batch.reshape([-1, 32, 32, 3])))
    predictions = y_pred.run()
    predictions = np.argmax(predictions.array, axis=1)
    accuracy = (y_eval_batch == predictions).mean()
    validation_acc.append(accuracy)

print(f'Final validation accuracy: {np.mean(validation_acc[-10:])}')
plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.ylabel('Loss')
plt.xlabel('Iteration')
plt.plot(loss_vals)
plt.subplot(1, 2, 2)
plt.ylabel('Validation accuracy')
plt.xlabel('Iteration')
plt.suptitle('CNN trained on CIFAR-10, using SmallPebble.')
plt.ylim([0, 1])
plt.plot(validation_acc)
plt.show()
100%|██████████| 100/100 [01:22<00:00,  1.22it/s]

Final validation accuracy: 0.3390625

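[Figure: CNN trained on CIFAR-10, using SmallPebble. Left: loss vs. iteration. Right: validation accuracy vs. iteration.]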

Note: Accuracy is limited here by the small subset of data used to ensure the example runs quickly on CPU.
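Both training loops call next() on their batch generators more times than one pass over the data provides. For MNIST, 300 iterations at batch size 200 exceed both the 250 training batches and the 50 evaluation batches. So sp.batch must keep yielding after a full pass. This README does not document its exact behaviour, such as shuffling or handling of the final partial batch. The generator below is a hypothetical stand-in (cycling_batches is not a SmallPebble function) that reshuffles and cycles forever.

```python
import numpy as np


def cycling_batches(X, y, batch_size, seed=0):
    """Yield (X_batch, y_batch) pairs forever, reshuffling at the start of each pass."""
    rng = np.random.default_rng(seed)
    n = len(X)
    while True:
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield X[idx], y[idx]  # same indices keep inputs and labels aligned


X_demo = np.arange(10).reshape(5, 2)
y_demo = np.arange(5)
gen = cycling_batches(X_demo, y_demo, batch_size=2)
for _ in range(6):  # more batches than a single pass (3) provides
    xb, yb = next(gen)
    print(xb.shape, yb)
```

On the CIFAR subset, 100 iterations at batch size 128 draw 12,800 examples. That is roughly a dozen passes over the 1,000 training images.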


Brief guide to using SmallPebble

SmallPebble provides the following building blocks to make models with:

  • sp.Variable
  • Operations, such as sp.add, sp.mul, etc.
  • sp.get_gradients
  • sp.Lazy
  • sp.Placeholder (this is really just sp.Lazy on the identity function)
  • sp.learnable
  • sp.get_learnables

The following examples show how these are used.

sp.Variable & sp.get_gradients

With SmallPebble, you can:

  • Wrap NumPy arrays in sp.Variable
  • Apply SmallPebble operations (e.g. sp.matmul, sp.add, etc.)
  • Compute gradients with sp.get_gradients
a = sp.Variable(np.random.random([2, 2]))
b = sp.Variable(np.random.random([2, 2]))
c = sp.Variable(np.random.random([2]))
y = sp.mul(a, b) + c
print('y.array:\n', y.array)

gradients = sp.get_gradients(y)
grad_a = gradients[a]
grad_b = gradients[b]
grad_c = gradients[c]
print('grad_a:\n', grad_a)
print('grad_b:\n', grad_b)
print('grad_c:\n', grad_c)
y.array:
 [[0.43457083 1.14068148]
 [0.63550632 0.97019613]]
grad_a:
 [[0.25686699 0.46522723]
 [0.68233942 0.03916393]]
grad_b:
 [[0.77756233 0.38989307]
 [0.58719396 0.27840733]]
grad_c:
 [2. 2.]

Note that y is computed straight away, i.e. the (forward) computation happens immediately.

Also note that y is a sp.Variable and we could continue to carry out SmallPebble operations on it.
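The grad_c output of [2. 2.] is a consequence of broadcasting. Suppose sp.get_gradients seeds a non-scalar output with an upstream gradient of ones; this is an assumption, but it matches the printed values. Then the gradients are those of y.sum(). Since c has shape (2,) and is added to each of the 2 rows of the (2, 2) product, its gradient sums the upstream gradient over that broadcast axis. A plain-NumPy check of the same arithmetic (not SmallPebble code):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((2, 2))
b = rng.random((2, 2))
c = rng.random(2)
y = a * b + c  # c is broadcast across the 2 rows of a * b

# Upstream gradient dL/dy, taken as all ones (i.e. L = y.sum()).
upstream = np.ones_like(y)

grad_a = upstream * b          # d(a*b)/da = b
grad_b = upstream * a          # d(a*b)/db = a
grad_c = upstream.sum(axis=0)  # sum over the axis that c was broadcast along

print(grad_c)  # [2. 2.]
```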

sp.Lazy & sp.Placeholder

Lazy graphs are constructed using sp.Lazy and sp.Placeholder.

lazy_node = sp.Lazy(lambda a, b: a + b)(1, 2)
print(lazy_node)
print(lazy_node.run())
<smallpebble.smallpebble.Lazy object at 0x10e785960>
3
a = sp.Lazy(lambda a: a)(2)
y = sp.Lazy(lambda a, b, c: a * b + c)(a, 3, 4)
print(y)
print(y.run())
<smallpebble.smallpebble.Lazy object at 0x10e785930>
10

Forward computation does not happen immediately; it happens only when .run() is called.
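To show how deferred evaluation can work, here is a toy sketch. It is not SmallPebble's implementation, and LazySketch and PlaceholderSketch are hypothetical names. A lazy node records a function and its arguments, and run() first evaluates any lazy arguments, then applies the function. Following the description of sp.Placeholder earlier, the placeholder is just a lazy identity node whose argument can be reassigned.

```python
class LazySketch:
    """Defers a function call until run() is invoked."""

    def __init__(self, function):
        self.function = function
        self.args = ()

    def __call__(self, *args):
        # Record the arguments instead of computing anything.
        self.args = args
        return self

    def run(self):
        # Evaluate lazy arguments first (depth-first), then apply the function.
        values = [arg.run() if isinstance(arg, LazySketch) else arg for arg in self.args]
        return self.function(*values)


class PlaceholderSketch(LazySketch):
    """A lazy identity node whose single argument can be reassigned."""

    def __init__(self):
        super().__init__(lambda value: value)

    def assign_value(self, value):
        self.args = (value,)


x = PlaceholderSketch()
y = LazySketch(lambda a, b, c: a * b + c)(x, 3, 4)

x.assign_value(2)
print(y.run())  # 10
x.assign_value(5)
print(y.run())  # 19
```

This sketch recomputes shared subgraphs on every run() and does no caching. It only illustrates the idea.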

a = sp.Placeholder()
b = sp.Variable(np.random.random([2, 2]))
y = sp.Lazy(sp.matmul)(a, b)

a.assign_value(sp.Variable(np.array([[1,2], [3,4]])))

result = y.run()
print('result.array:\n', result.array)
result.array:
 [[2.24308885 1.05636046]
 [5.26999297 2.85349506]]

You can use .run() as many times as you like.

Let's change the placeholder value and re-run the graph:

a.assign_value(sp.Variable(np.array([[10,20], [30,40]])))
result = y.run()
print('result.array:\n', result.array)
result.array:
 [[22.43088846 10.5636046 ]
 [52.6999297  28.53495056]]

Finally, let's compute gradients:

gradients = sp.get_gradients(result)

Note that sp.get_gradients is called on result, which is a sp.Variable, not on y, which is a sp.Lazy instance.
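Matrix multiplication has closed-form gradients. Assume the upstream gradient of the non-scalar result is all ones, as in the broadcasting example (this seeding is an assumption). Then the gradient with respect to a is ones @ b.T, and the gradient with respect to b is a.T @ ones. A NumPy sketch follows. Here a mirrors the second value assigned to the placeholder, and b is a random stand-in.

```python
import numpy as np

a = np.array([[10.0, 20.0], [30.0, 40.0]])
b = np.random.default_rng(0).random((2, 2))
result = a @ b

upstream = np.ones_like(result)  # dL/d(result), i.e. L = result.sum()
grad_a = upstream @ b.T          # same shape as a
grad_b = a.T @ upstream          # same shape as b

print(grad_a)
print(grad_b)  # [[40, 40], [60, 60]]: each row holds a column sum of a
```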

sp.learnable & sp.get_learnables

Use sp.learnable to flag parameters as learnable, allowing them to be extracted from a lazy graph with sp.get_learnables.

This enables a workflow of building a model while flagging parameters as learnable, then extracting all of the parameters in one go at the end.
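One way such a flag-then-collect workflow can be implemented is sketched below. The README does not describe SmallPebble's actual mechanism, and Param, Node, learnable and get_learnables here are hypothetical stand-ins. learnable sets a flag and returns the same object, so it can be used inline. get_learnables walks the graph from the output node and collects each flagged parameter once.

```python
class Param:
    """Stand-in for an array-holding variable."""

    def __init__(self, value):
        self.value = value
        self.is_learnable = False


class Node:
    """Stand-in for a lazy graph node: a function plus its arguments."""

    def __init__(self, function, *args):
        self.function = function
        self.args = args


def learnable(param):
    """Flag a parameter as learnable and return it, so it can be used inline."""
    param.is_learnable = True
    return param


def get_learnables(node):
    """Depth-first walk from the output node, collecting each flagged Param once."""
    found, seen = [], set()

    def visit(obj):
        if id(obj) in seen:
            return
        seen.add(id(obj))
        if isinstance(obj, Param) and obj.is_learnable:
            found.append(obj)
        elif isinstance(obj, Node):
            for arg in obj.args:
                visit(arg)

    visit(node)
    return found


w = learnable(Param(0.5))
bias = learnable(Param(5.0))
frozen = Param(2.0)  # not flagged, so it is not collected
y = Node(lambda u, v: u + v, Node(lambda u, v: u * v, "input", w), bias)
y = Node(lambda u, v: u * v, y, frozen)
print([p.value for p in get_learnables(y)])  # [0.5, 5.0]
```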

a = sp.Placeholder()
b = sp.learnable(sp.Variable(np.random.random([2, 1])))
y = sp.Lazy(sp.matmul)(a, b)
y = sp.Lazy(sp.add)(y, sp.learnable(sp.Variable(np.array([5]))))

learnables = sp.get_learnables(y)

for learnable in learnables:
    print(learnable)
<smallpebble.smallpebble.Variable object at 0x10ec685e0>
<smallpebble.smallpebble.Variable object at 0x10e8be050>

SmallPebble (2022–2026)
