This README is generated from a notebook you can run!
SmallPebble is a minimalist autodiff and deep learning library written from scratch in Python. It runs automatic differentiation on NumPy arrays.
The core implementation is in a single file: smallpebble.py
The only dependency of the core implementation is NumPy.
SmallPebble demonstrates the key concepts under the hood of deep learning frameworks, without the complexity of performance optimizations or GPU support.
- Read this introduction to autodiff by the author (featured in university teaching materials); it presents a similar autodiff implementation to SmallPebble, but simplified and operating on scalars.
- Read the source code: smallpebble.py.
- Take a look at the tests/ folder — this library was written test-first.
- Clone the repo and play with it...
Clone the repository and install it locally:
git clone https://github.com/sradc/smallpebble.git
cd smallpebble
# Install core library
pip install .
# Or, to run the examples/notebooks
pip install ".[examples]"
Alternatively, to install it as a package:
pip install smallpebble
Features:
- Compact: < 1000 lines of core logic.
- Educational: Prioritizes code clarity over raw performance (CPU-only).
- Feature-rich: Supports matmul, conv2d, maxpool2d, and array broadcasting.
- Flexible: Supports eager execution and implicit graph building.
- Extensible: Easy API for adding custom operations.
SmallPebble builds dynamic computation graphs implicitly via Python object referencing (similar to PyTorch).
When get_gradients is called, autodiff is performed by traversing the graph backward.
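To make this concrete, here is a rough, self-contained sketch of the idea (illustrative only, not SmallPebble's actual internals; the Node class and function names below are hypothetical): each operation wraps its NumPy result in an object that keeps references to the inputs it was computed from, together with local gradient functions, and the backward pass walks those references, accumulating gradients. Adding a custom operation then amounts to writing one more function of this shape.

```python
import numpy as np

class Node:
    """Holds a NumPy array plus references to the nodes it was computed from."""
    def __init__(self, array, parents=()):
        self.array = array
        self.parents = parents  # tuples of (parent_node, local_gradient_fn)

def mul(a, b):
    "Elementwise multiply; records how to send gradients back to each input."
    return Node(a.array * b.array, parents=(
        (a, lambda g: g * b.array),
        (b, lambda g: g * a.array),
    ))

def get_gradients(node):
    "Traverse the graph backward from `node`, accumulating gradients per node."
    gradients = {}
    def visit(node, grad):
        for parent, local_grad_fn in node.parents:
            parent_grad = local_grad_fn(grad)
            gradients[parent] = gradients.get(parent, 0) + parent_grad
            visit(parent, parent_grad)
    visit(node, np.ones_like(node.array))
    return gradients

a = Node(np.array([2.0, 3.0]))
b = Node(np.array([4.0, 5.0]))
y = mul(a, b)
grads = get_gradients(y)
print(grads[a])  # [4. 5.]  i.e. dy/da = b
print(grads[b])  # [2. 3.]  i.e. dy/db = a
```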
- Example models created and trained using SmallPebble.
- A brief guide to using the library.
# Uncomment if you are running in Google Colab:
# !pip install smallpebble
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import smallpebble as sp
from smallpebble.datasets import load_data
"Load the dataset, and create a validation set."
X_train, y_train, _, _ = load_data('mnist') # downloads from smallpebble github and is cached
X_train = X_train/255 # normalize
# Separate out data for validation.
X = X_train[:50_000, ...]
y = y_train[:50_000]
X_eval = X_train[50_000:60_000, ...]
y_eval = y_train[50_000:60_000]
Downloading mnist.npz...
100%|██████████| 11.5M/11.5M [00:00<00:00, 31.9MB/s]
"Plot, to check we have the right data."
plt.figure(figsize=(5,5))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(X_train[i,:].reshape(28,28), cmap='gray', vmin=0, vmax=1)
plt.show()
"Create a model, with two fully connected hidden layers."
X_in = sp.Placeholder()
y_true = sp.Placeholder()
h = sp.linearlayer(28*28, 100)(X_in)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.linearlayer(100, 100)(h)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.linearlayer(100, 10)(h)
y_pred = sp.Lazy(sp.softmax)(h)
loss = sp.Lazy(sp.cross_entropy)(y_pred, y_true)
learnables = sp.get_learnables(y_pred)
loss_vals = []
validation_acc = []
"Train model, while measuring performance on the validation dataset."
NUM_ITERS = 300
BATCH_SIZE = 200
eval_batch = sp.batch(X_eval, y_eval, BATCH_SIZE)
adam = sp.Adam() # Adam optimization
for i, (xbatch, ybatch) in tqdm(enumerate(sp.batch(X, y, BATCH_SIZE)), total=NUM_ITERS):
    if i >= NUM_ITERS: break
    X_in.assign_value(sp.Variable(xbatch))
    y_true.assign_value(ybatch)
    loss_val = loss.run() # run the graph
    if np.isnan(loss_val.array):
        print("loss is nan, aborting.")
        break
    loss_vals.append(loss_val.array)
    # Compute gradients, and use to carry out learning step:
    gradients = sp.get_gradients(loss_val)
    adam.training_step(learnables, gradients)
    # Compute validation accuracy:
    x_eval_batch, y_eval_batch = next(eval_batch)
    X_in.assign_value(sp.Variable(x_eval_batch))
    predictions = y_pred.run()
    predictions = np.argmax(predictions.array, axis=1)
    accuracy = (y_eval_batch == predictions).mean()
    validation_acc.append(accuracy)
# Plot results:
print(f'Final validation accuracy: {np.mean(validation_acc[-10:])}')
plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.ylabel('Loss')
plt.xlabel('Iteration')
plt.plot(loss_vals)
plt.subplot(1, 2, 2)
plt.ylabel('Validation accuracy')
plt.xlabel('Iteration')
plt.suptitle('Neural network trained on MNIST, using SmallPebble.')
plt.ylim([0, 1])
plt.plot(validation_acc)
plt.show()
100%|██████████| 300/300 [00:01<00:00, 189.88it/s]
Final validation accuracy: 0.9155000000000001
Use a subset of the data because CNNs are very slow on CPU.
"Load the CIFAR dataset."
X_train, y_train, _, _ = load_data('cifar')
X_train = X_train/255 # normalize
Downloading cifar.npz...
100%|██████████| 170M/170M [00:04<00:00, 37.3MB/s]
"""Plot, to check it's the right data.
(This cell's code is from: https://www.tensorflow.org/tutorials/images/cnn#verify_the_data)
"""
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
plt.figure(figsize=(8,8))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(X_train[i,:].reshape(32,32,3))
    plt.xlabel(class_names[y_train[i]])
plt.show()
# Separate out data for validation as before.
X = X_train[:1000, ...]
y = y_train[:1000]
X_eval = X_train[1000:1100, ...]
y_eval = y_train[1000:1100]
"""Define a model."""
X_in = sp.Placeholder()
y_true = sp.Placeholder()
h = sp.convlayer(height=3, width=3, depth=3, n_kernels=32)(X_in)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.Lazy(lambda a: sp.maxpool2d(a, 2, 2, strides=[2, 2]))(h)
h = sp.convlayer(3, 3, 32, 128, padding='VALID')(h)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.Lazy(lambda a: sp.maxpool2d(a, 2, 2, strides=[2, 2]))(h)
h = sp.convlayer(3, 3, 128, 128, padding='VALID')(h)
h = sp.Lazy(sp.leaky_relu)(h)
h = sp.Lazy(lambda a: sp.maxpool2d(a, 2, 2, strides=[2, 2]))(h)
h = sp.Lazy(lambda x: sp.reshape(x, [-1, 3*3*128]))(h)
h = sp.linearlayer(3*3*128, 10)(h)
h = sp.Lazy(sp.softmax)(h)
y_pred = h
loss = sp.Lazy(sp.cross_entropy)(y_pred, y_true)
learnables = sp.get_learnables(y_pred)
loss_vals = []
validation_acc = []
# Check we get the expected dimensions
X_in.assign_value(sp.Variable(X[0:3, :].reshape([-1, 32, 32, 3])))
h.run().shape
(3, 10)
Train the model.
NUM_ITERS = 100
BATCH_SIZE = 128
eval_batch = sp.batch(X_eval, y_eval, BATCH_SIZE)
adam = sp.Adam()
for i, (xbatch, ybatch) in tqdm(enumerate(sp.batch(X, y, BATCH_SIZE)), total=NUM_ITERS):
    if i >= NUM_ITERS: break
    xbatch_images = xbatch.reshape([-1, 32, 32, 3])
    X_in.assign_value(sp.Variable(xbatch_images))
    y_true.assign_value(ybatch)
    loss_val = loss.run()
    if np.isnan(loss_val.array):
        print("Aborting, loss is nan.")
        break
    loss_vals.append(loss_val.array)
    # Compute gradients, and carry out learning step.
    gradients = sp.get_gradients(loss_val)
    adam.training_step(learnables, gradients)
    # Compute validation accuracy:
    x_eval_batch, y_eval_batch = next(eval_batch)
    X_in.assign_value(sp.Variable(x_eval_batch.reshape([-1, 32, 32, 3])))
    predictions = y_pred.run()
    predictions = np.argmax(predictions.array, axis=1)
    accuracy = (y_eval_batch == predictions).mean()
    validation_acc.append(accuracy)
print(f'Final validation accuracy: {np.mean(validation_acc[-10:])}')
plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.ylabel('Loss')
plt.xlabel('Iteration')
plt.plot(loss_vals)
plt.subplot(1, 2, 2)
plt.ylabel('Validation accuracy')
plt.xlabel('Iteration')
plt.suptitle('CNN trained on CIFAR-10, using SmallPebble.')
plt.ylim([0, 1])
plt.plot(validation_acc)
plt.show()
100%|██████████| 100/100 [01:22<00:00, 1.22it/s]
Final validation accuracy: 0.3390625
Note: Accuracy is limited here by the small subset of data used to ensure the example runs quickly on CPU.
SmallPebble provides the following building blocks to make models with:
- sp.Variable
- Operations, such as sp.add, sp.mul, etc.
- sp.get_gradients
- sp.Lazy
- sp.Placeholder (this is really just sp.Lazy on the identity function)
- sp.learnable
- sp.get_learnables
The following examples show how these are used.
With SmallPebble, you can:
- Wrap NumPy arrays in sp.Variable
- Apply SmallPebble operations (e.g. sp.matmul, sp.add, etc.)
- Compute gradients with sp.get_gradients
a = sp.Variable(np.random.random([2, 2]))
b = sp.Variable(np.random.random([2, 2]))
c = sp.Variable(np.random.random([2]))
y = sp.mul(a, b) + c
print('y.array:\n', y.array)
gradients = sp.get_gradients(y)
grad_a = gradients[a]
grad_b = gradients[b]
grad_c = gradients[c]
print('grad_a:\n', grad_a)
print('grad_b:\n', grad_b)
print('grad_c:\n', grad_c)
y.array:
[[0.43457083 1.14068148]
[0.63550632 0.97019613]]
grad_a:
[[0.25686699 0.46522723]
[0.68233942 0.03916393]]
grad_b:
[[0.77756233 0.38989307]
[0.58719396 0.27840733]]
grad_c:
[2. 2.]
Note that y is computed straight away, i.e. the (forward) computation happens immediately.
Also note that y is a sp.Variable and we could continue to carry out SmallPebble operations on it.
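For example (a quick sketch continuing the snippet above, using only operations already shown in this README), y can be passed straight into further operations and differentiated again:

```python
z = sp.matmul(y, a)  # y is an sp.Variable, so it composes with other ops directly
more_gradients = sp.get_gradients(z)
print('dz/db:\n', more_gradients[b])  # the gradient flows back through y into b
```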
Lazy graphs are constructed using sp.Lazy and sp.Placeholder.
lazy_node = sp.Lazy(lambda a, b: a + b)(1, 2)
print(lazy_node)
print(lazy_node.run())
<smallpebble.smallpebble.Lazy object at 0x10e785960>
3
a = sp.Lazy(lambda a: a)(2)
y = sp.Lazy(lambda a, b, c: a * b + c)(a, 3, 4)
print(y)
print(y.run())
<smallpebble.smallpebble.Lazy object at 0x10e785930>
10
Forward computation does not happen immediately - only when .run() is called.
a = sp.Placeholder()
b = sp.Variable(np.random.random([2, 2]))
y = sp.Lazy(sp.matmul)(a, b)
a.assign_value(sp.Variable(np.array([[1,2], [3,4]])))
result = y.run()
print('result.array:\n', result.array)
result.array:
[[2.24308885 1.05636046]
[5.26999297 2.85349506]]
You can use .run() as many times as you like.
Let's change the placeholder value and re-run the graph:
a.assign_value(sp.Variable(np.array([[10,20], [30,40]])))
result = y.run()
print('result.array:\n', result.array)
result.array:
[[22.43088846 10.5636046 ]
[52.6999297 28.53495056]]
Finally, let's compute gradients:
gradients = sp.get_gradients(result)
Note that sp.get_gradients is called on result, which is a sp.Variable, not on y, which is a sp.Lazy instance.
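As in the eager example above, the gradients come back as a dict keyed by sp.Variable, so (continuing the snippet) the gradient with respect to b can be looked up directly:

```python
grad_b = gradients[b]  # b is the sp.Variable defined when building the graph
print('gradient of result with respect to b:\n', grad_b)
```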
Use sp.learnable to flag parameters as learnable,
allowing them to be extracted from a lazy graph with sp.get_learnables.
This enables a workflow of: building a model, while flagging parameters as learnable, and then extracting all the parameters in one go at the end.
a = sp.Placeholder()
b = sp.learnable(sp.Variable(np.random.random([2, 1])))
y = sp.Lazy(sp.matmul)(a, b)
y = sp.Lazy(sp.add)(y, sp.learnable(sp.Variable(np.array([5]))))
learnables = sp.get_learnables(y)
for learnable in learnables:
    print(learnable)
<smallpebble.smallpebble.Variable object at 0x10ec685e0>
<smallpebble.smallpebble.Variable object at 0x10e8be050>
SmallPebble (2022–2026)



