Implementation of "Kernelized information bottleneck leads to biologically plausible 3-factor Hebbian learning in deep networks" by R. Pogodin and P. E. Latham (https://arxiv.org/abs/2006.07123)
The scripts for the reported experiments are stored in `./experiments_scripts/`.
The `_grid.sh` files loop over all setups, calling the corresponding `_single.sh` script, which sets the hyperparameters and runs `experiments.py`.
`experiments.py` runs a single experiment with the given command-line arguments;
run `python3 experiments.py --help` for the list of arguments
(or see `utils.py`, function `parse_arguments`).
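As a usage illustration only, a single run can also be launched from Python. The flag names below (`--dataset`, `--kernel`) are hypothetical placeholders; check `python3 experiments.py --help` or `utils.parse_arguments` for the real argument names.

```python
# Hypothetical launch of one experiment run. The arguments "--dataset" and
# "--kernel" are placeholders for illustration; the real options are listed
# by `python3 experiments.py --help`.
import subprocess

subprocess.run(
    ["python3", "experiments.py", "--dataset", "cifar10", "--kernel", "cossim"],
    check=True,  # raise if the run exits with a non-zero status
)
```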
MLP experiments: `./experiments_scripts/run_mlp_experiments_grid.sh`
Mean test accuracy over 5 trials. First header row: the method (cossim: cosine similarity kernel; Gaussian: Gaussian kernel; last layer: training of the last layer only). Second header row: an additional modification of the method (grp: grouping without divisive normalization; grp+div: grouping with divisive normalization); a blank entry means no modification.
| | backprop | | last layer | | cossim | | | Gaussian | | |
|---|---|---|---|---|---|---|---|---|---|---|
| | | grp+div | | grp+div | | grp | grp+div | | grp | grp+div |
| MNIST | 98.6 | 98.4 | 92.0 | 95.4 | 94.9 | 95.8 | 96.3 | 94.6 | 98.4 | 98.1 |
| fMNIST | 90.2 | 90.8 | 83.3 | 85.7 | 86.3 | 88.7 | 88.1 | 86.5 | 88.6 | 88.8 |
| kMNIST | 93.4 | 93.5 | 71.2 | 78.2 | 80.4 | 86.2 | 87.2 | 80.2 | 92.7 | 91.1 |
| CIFAR10 | 60.0 | 60.3 | 39.2 | 38.0 | 51.1 | 52.5 | 47.6 | 41.4 | 48.4 | 46.4 |
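For reference, the two kernels compared above are the standard cosine-similarity and Gaussian kernels. The sketch below only restates those textbook definitions for a batch of activations; it is not taken from the repo, which may add centering, bandwidth selection, or the paper's plausible pHSIC estimator on top.

```python
import torch

def cossim_kernel(z: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Cosine-similarity kernel matrix K[i, j] = cos(z_i, z_j) for z of shape (batch, features)."""
    z_normed = z / (z.norm(dim=1, keepdim=True) + eps)
    return z_normed @ z_normed.t()

def gaussian_kernel(z: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Gaussian kernel matrix K[i, j] = exp(-||z_i - z_j||^2 / (2 * sigma^2))."""
    sq_dists = torch.cdist(z, z, p=2).pow(2)
    return torch.exp(-sq_dists / (2 * sigma ** 2))
```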
Conv-net experiments:
- SGD: `./experiments_scripts/run_vgg_sgd_experiments_grid.sh`
- AdamW: `./experiments_scripts/run_vgg_adam_experiments_grid.sh`
Mean test accuracy on CIFAR10 over 5 runs for 7-layer conv nets (1x and 2x wide). Cossim: cosine similarity kernel; div/divnorm: divisive normalization; bn: batchnorm. Empty entries: experiments for which we could not find a satisfactory set of hyperparameters due to instabilities in the method.
| | backprop | | pHSIC: cossim | | pHSIC: Gaussian | |
|---|---|---|---|---|---|---|
| | | div | grp | grp+div | grp | grp+div |
| 1x wide net + SGD | 91.0 | | 91.0 | 88.8 | 89.8 | 86.2 |
| 2x wide net + SGD | 91.9 | | 90.9 | 89.4 | 91.3 | 90.4 |
| 1x wide net + AdamW + batchnorm | 94.1 | 94.3 | 91.3 | 90.1 | 89.9 | 89.4 |
| 2x wide net + AdamW + batchnorm | 94.3 | 94.5 | 91.9 | 91.0 | 91.0 | 91.2 |
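The grp/grp+div variants group hidden units and, in the +div case, divisively normalize activity within each group. As a rough illustration of the idea only (a generic Heeger-style divisive normalization, not the repo's exact formula), grouping with divisive normalization for a fully connected layer could look like:

```python
import torch

def grouped_divisive_norm(a: torch.Tensor, num_groups: int, c: float = 0.1) -> torch.Tensor:
    """Illustrative grouping + divisive normalization for activations a of shape (batch, units).

    Units are split into `num_groups` equal groups (units must be divisible by num_groups),
    and each unit is divided by the root-mean-square activity of its group plus a constant c.
    This is a generic form of divisive normalization; the paper/repo may use a different one.
    """
    batch, units = a.shape
    grouped = a.view(batch, num_groups, units // num_groups)
    denom = c + grouped.pow(2).mean(dim=2, keepdim=True).sqrt()
    return (grouped / denom).view(batch, units)
```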
Mean test accuracy on CIFAR10 over 5 runs for 7-layer conv nets (1x and 2x wide). FA: feedback alignment; sign sym.: sign symmetry; layer class.: layer-wise classification (+ FA: combined with feedback alignment); divnorm: divisive normalization; bn: batchnorm. Empty entries: experiments for which we could not find a satisfactory set of hyperparameters due to instabilities in the method.
| | FA | sign sym. | layer class. | layer class. + FA |
|---|---|---|---|---|
| 1x wide net + SGD | | | 90.0 | |
| 2x wide net + SGD | | | 90.3 | |
| 1x wide net + SGD + divnorm | 80.4 | 89.5 | 90.5 | 81.0 |
| 2x wide net + SGD + divnorm | 80.6 | 91.3 | 91.3 | 81.2 |
| 1x wide net + AdamW + bn | 82.4 | 93.6 | 92.1 | 90.3 |
| 2x wide net + AdamW + bn | 81.6 | 93.9 | 92.1 | 91.1 |
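For context on these baselines: feedback alignment (FA) replaces the transposed forward weights in the backward pass with a fixed random matrix, and sign symmetry keeps only the signs of the forward weights. Below is a minimal sketch of the error signal sent to the previous layer of a single linear layer, using the standard definitions of these methods (not code from this repo):

```python
import torch

def backward_error(delta_out: torch.Tensor, W: torch.Tensor, B: torch.Tensor, mode: str) -> torch.Tensor:
    """Error signal for the previous layer of a linear layer y = x @ W.t().

    delta_out: (batch, out_features) error at this layer's output.
    W: (out_features, in_features) forward weights.
    B: (out_features, in_features) fixed random feedback matrix (used by FA).
    """
    if mode == "backprop":
        return delta_out @ W               # exact transpose pathway
    if mode == "feedback_alignment":
        return delta_out @ B               # fixed random feedback weights
    if mode == "sign_symmetry":
        return delta_out @ torch.sign(W)   # only the sign of W is shared
    raise ValueError(f"unknown mode: {mode}")
```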
The code was tested with the following setup:
- OS: Debian GNU/Linux 9.12 (stretch)
- CUDA: driver version 418.87.01, CUDA version 10.1
- Python 3.7.6
- numpy 1.18.1
- torch 1.6.0+cu101
- torchvision 0.7.0+cu101
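A quick sanity check that an environment matches the tested versions above:

```python
# Print the package versions relevant to the tested setup listed above.
import sys
import numpy
import torch
import torchvision

print("python     :", sys.version.split()[0])        # tested with 3.7.6
print("numpy      :", numpy.__version__)              # tested with 1.18.1
print("torch      :", torch.__version__)              # tested with 1.6.0+cu101
print("torchvision:", torchvision.__version__)        # tested with 0.7.0+cu101
print("cuda available:", torch.cuda.is_available())   # CUDA 10.1 driver expected
```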