Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories.

base repository: nengo/nengo-extras
base: da5895f42f0729b1dc061f710d9c72ecfa7f775b
head repository: nengo/nengo-extras
compare: c39951fdb29f2012e14a59b2b1731ce07076aae0
5 changes: 1 addition & 4 deletions examples/cuda_convnet/imagenet_spiking_cnn.py
@@ -11,10 +11,7 @@
 from nengo_extras.cuda_convnet import CudaConvnetNetwork, load_model_pickle
 from nengo_extras.gui import image_display_function
 
-# retrieve from https://figshare.com/s/cdde71007405eb11a88f
-filename = 'ilsvrc-2012-batches-test3.tar.gz'
-X_test, Y_test, data_mean, label_names = load_ilsvrc2012(filename, n_files=1)
-
+X_test, Y_test, data_mean, label_names = load_ilsvrc2012(n_files=1)
 X_test = X_test.astype('float32')
 
 # crop data
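
Note on the change above: load_ilsvrc2012 now resolves the dataset file itself, so the example no longer hard-codes the figshare download. A minimal sketch of the updated usage; the nengo_extras.data import path is assumed from the rest of the repo, and the final print is only for orientation.

from nengo_extras.data import load_ilsvrc2012

# Load one batch of ILSVRC-2012 test data; the helper locates the data
# file itself, so no manual filename is passed anymore.
X_test, Y_test, data_mean, label_names = load_ilsvrc2012(n_files=1)
X_test = X_test.astype('float32')
print(X_test.shape, Y_test.shape, len(label_names))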
81 changes: 76 additions & 5 deletions nengo_extras/convnet.py
@@ -1,10 +1,16 @@
 import numpy as np
 
 from nengo.processes import Process
-from nengo.params import EnumParam, NdarrayParam, ShapeParam
+from nengo.params import EnumParam, NdarrayParam, NumberParam, ShapeParam
 from nengo.utils.compat import is_iterable, range
 
 
+def softmax(x, axis=None):
+    """Stable softmax function"""
+    ex = np.exp(x - x.max(axis=axis, keepdims=True))
+    return ex / ex.sum(axis=axis, keepdims=True)
+
+
 class Conv2d(Process):
     """Perform 2-D (image) convolution on an input.
@@ -225,7 +231,72 @@ def step_pool2d(t, x):
         return step_pool2d
 
 
-def softmax(x, axis=None):
-    """Stable softmax function"""
-    ex = np.exp(x - x.max(axis=axis, keepdims=True))
-    return ex / ex.sum(axis=axis, keepdims=True)
+class PresentJitteredImages(Process):
+    images = NdarrayParam('images', shape=('...',))
+    image_shape = ShapeParam('image_shape', length=3, low=1)
+    output_shape = ShapeParam('output_shape', length=2, low=1)
+    presentation_time = NumberParam('presentation_time', low=0, low_open=True)
+    jitter_std = NumberParam('jitter_std', low=0, low_open=True, optional=True)
+    jitter_tau = NumberParam('jitter_tau', low=0, low_open=True)
+
+    def __init__(self, images, presentation_time, output_shape,
+                 jitter_std=None, jitter_tau=None, **kwargs):
+        import scipy.ndimage.interpolation  # noqa: F401
+        # ^ required for simulation, so check it here
+
+        self.images = images
+        self.presentation_time = presentation_time
+        self.image_shape = images.shape[1:]
+        self.output_shape = output_shape
+        self.jitter_std = jitter_std
+        self.jitter_tau = (presentation_time if jitter_tau is None else
+                           jitter_tau)
+
+        nc = self.image_shape[0]
+        nyi, nyj = self.output_shape
+        super(PresentJitteredImages, self).__init__(
+            default_size_in=0, default_size_out=nc*nyi*nyj, **kwargs)
+
+    def make_step(self, shape_in, shape_out, dt, rng):
+        import scipy.ndimage.interpolation
+
+        nc, nxi, nxj = self.image_shape
+        nyi, nyj = self.output_shape
+        ni, nj = nxi - nyi, nxj - nyj
+        nij = np.array([ni, nj])
+        assert shape_in == (0,)
+        assert shape_out == (nc*nyi*nyj,)
+
+        if self.jitter_std is None:
+            si, sj = ni / 4., nj / 4.
+        else:
+            si = sj = self.jitter_std
+
+        tau = self.jitter_tau
+
+        n = len(self.images)
+        images = self.images.reshape((n, nc, nxi, nxj))
+        presentation_time = float(self.presentation_time)
+
+        cij = (nij - 1) / 2.
+        dt7tau = dt / tau
+        sigma2 = np.sqrt(2.*dt/tau) * np.array([si, sj])
+        ij = cij.copy()
+
+        def step_presentjitteredimages(t):
+            # update jitter position
+            ij0 = dt7tau*(cij - ij) + sigma2*rng.normal(size=2)
+            ij[:] = (ij + ij0).clip((0, 0), (ni, nj))
+
+            # select image
+            k = int((t-dt) / presentation_time + 1e-7)
+            image = images[k % n]
+
+            # interpolate jittered sub-image
+            i, j = ij
+            image = scipy.ndimage.interpolation.shift(
+                image, (0, ni-i, nj-j))[:, -nyi:, -nyj:]
+
+            return image.ravel()
+
+        return step_presentjitteredimages
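
Usage sketch for the new process: each image is shown for presentation_time while a mean-reverting random walk (time constant jitter_tau, scale jitter_std) jitters which output_shape-sized window of the image appears. The random images below are stand-ins and the parameter values are illustrative; scipy is required, as the import check above enforces.

import numpy as np
import nengo
from nengo_extras.convnet import PresentJitteredImages

rng = np.random.RandomState(0)
images = rng.uniform(size=(10, 1, 28, 28))  # 10 single-channel 28x28 images

with nengo.Network() as net:
    # show each image for 100 ms, jittering a 24x24 crop of the 28x28 frame
    u = nengo.Node(PresentJitteredImages(
        images, presentation_time=0.1, output_shape=(24, 24)))
    p = nengo.Probe(u)

with nengo.Simulator(net) as sim:
    sim.run(0.3)

print(sim.data[p].shape)  # (n_steps, 1 * 24 * 24)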
4 changes: 4 additions & 0 deletions nengo_extras/deepnetworks.py
@@ -325,6 +325,10 @@ def input(self):
     def output(self):
         return self._output
 
+    @property
+    def neurons(self):
+        return self.ensemble.neurons
+
     @property
     def amplitude(self):
         return self.connection.transform
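
The new neurons property matches the neighbouring output and amplitude accessors, so callers can reach a layer's neuron signals without going through .ensemble directly. A stand-in sketch of the pattern; EnsembleLayer here is hypothetical and models only the attributes used:

import nengo

class EnsembleLayer(object):  # hypothetical stand-in for the deepnetworks layer
    def __init__(self, ensemble):
        self.ensemble = ensemble

    @property
    def neurons(self):
        return self.ensemble.neurons

with nengo.Network():
    ens = nengo.Ensemble(10, 1)
    layer = EnsembleLayer(ens)
    assert layer.neurons is ens.neurons  # the property forwards to the ensemble
    spikes = nengo.Probe(layer.neurons)  # e.g. probe the layer's spiking output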
11 changes: 8 additions & 3 deletions nengo_extras/learning_rules.py
@@ -6,7 +6,7 @@
 from nengo.builder.operator import DotInc, ElementwiseInc, Reset, SimPyFunc
 from nengo.exceptions import ValidationError
 from nengo.learning_rules import LearningRuleType
-from nengo.params import FunctionParam, NumberParam
+from nengo.params import EnumParam, FunctionParam, NumberParam
 from nengo.synapses import Lowpass
 
 
@@ -65,15 +65,17 @@ class DeltaRule(LearningRuleType):
     pre_tau = NumberParam('pre_tau', low=0, low_open=True)
     post_tau = NumberParam('post_tau', low=0, low_open=True, optional=True)
     post_fn = DeltaRuleFunctionParam('post_fn', optional=True)
+    post_target = EnumParam('post_target', values=('in', 'out'))
 
     def __init__(self, learning_rate=1e-4, pre_tau=0.005,
-                 post_fn=None, post_tau=None):
+                 post_fn=None, post_tau=None, post_target='in'):
         if learning_rate >= 1.0:
             warnings.warn("This learning rate is very high, and can result "
                           "in floating point errors from too much current.")
         self.pre_tau = pre_tau
         self.post_tau = post_tau
         self.post_fn = post_fn
+        self.post_target = post_target
         super(DeltaRule, self).__init__(learning_rate, size_in='post')
 
     @property
@@ -87,6 +89,8 @@ def _argreprs(self):
             args.append("post_fn=%s" % self.post_fn.function)
         if self.post_tau is not None:
             args.append("post_tau=%f" % self.post_tau)
+        if self.post_target != 'in':
+            args.append("post_target=%s" % self.post_target)
 
         return args
 
@@ -103,8 +107,9 @@ def build_delta_rule(model, delta_rule, rule):
     # Multiply by post_fn output if necessary
     post_fn = delta_rule.post_fn.function
     post_tau = delta_rule.post_tau
+    post_target = delta_rule.post_target
     if post_fn is not None:
-        post_sig = model.sig[conn.post_obj]['in']
+        post_sig = model.sig[conn.post_obj][post_target]
         post_synapse = Lowpass(post_tau) if post_tau is not None else None
         post_input = (post_sig if post_synapse is None else
                       model.build(post_synapse, post_sig))
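
With post_target, the signal fed to post_fn can be the post population's input current ('in', the previous behaviour) or its output activity ('out'). A sketch of the new option on a neuron-to-neuron connection; the ensemble sizes and the rate threshold of 18 are illustrative, mirroring the test further down.

import numpy as np
import nengo
from nengo_extras.learning_rules import DeltaRule

n = 50
step = lambda s: (s > 18).astype(s.dtype)  # gate learning on post firing rates

with nengo.Network():
    a = nengo.Ensemble(n, 1)
    b = nengo.Ensemble(n, 1)
    c = nengo.Connection(a.neurons, b.neurons, transform=np.zeros((n, n)))
    c.learning_rule_type = DeltaRule(learning_rate=1e-4, post_fn=step,
                                     post_tau=0.005, post_target='out')
    err = nengo.Node(size_in=n)  # error input: one element per post neuron
    nengo.Connection(err, c.learning_rule, synapse=None)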
6 changes: 1 addition & 5 deletions nengo_extras/neurons.py
@@ -50,20 +50,16 @@ class SoftLIFRate(nengo.neurons.LIFRate):
     """
 
     sigma = NumberParam('sigma', low=0, low_open=True)
-    amplitude = NumberParam('amplitude', low=0, low_open=True)
 
-    def __init__(self, sigma=1., amplitude=1., **lif_args):
+    def __init__(self, sigma=1., **lif_args):
         super(SoftLIFRate, self).__init__(**lif_args)
         self.sigma = sigma  # smoothing around the threshold
-        self.amplitude = amplitude  # scaling on the output rates
 
     @property
     def _argreprs(self):
         args = super(SoftLIFRate, self)._argreprs
         if self.sigma != 1.:
             args.append("sigma=%s" % self.sigma)
-        if self.amplitude != 1.:
-            args.append("amplitude=%s" % self.amplitude)
         return args
 
     def rates(self, x, gain, bias):
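
After this change SoftLIFRate only adds the sigma smoothing on top of LIFRate; output scaling is no longer a neuron-model parameter (the deepnetworks amplitude property above already exposes scaling as a connection transform). A small sketch of the smoothed response, assuming the constructor shown here:

import numpy as np
from nengo_extras.neurons import SoftLIFRate

neuron = SoftLIFRate(sigma=0.02)
x = np.linspace(-0.1, 0.1, 5)
# rates() maps x through gain and bias, then the sigma-smoothed LIF curve;
# unlike plain LIFRate, rates rise smoothly through the firing threshold
print(neuron.rates(x, gain=1., bias=1.))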
34 changes: 28 additions & 6 deletions nengo_extras/tests/test_learning_rules.py
@@ -8,7 +8,10 @@
 from nengo_extras.learning_rules import DeltaRule
 
 
-def test_delta_rule(Simulator, seed, rng, plt):
+@pytest.mark.parametrize('post_target', [None, 'in', 'out'])
+def test_delta_rule(Simulator, seed, rng, plt, post_target):
+    f = lambda x: np.abs(x)
+
     learning_rate = 2e-2
 
     tau_s = 0.005
@@ -26,6 +29,19 @@ def test_delta_rule(Simulator, seed, rng, plt):
                       max_rates=nengo.dists.Choice([max_rate]),
                       intercepts=nengo.dists.Uniform(-1, 0.8))
 
+    if post_target == 'in':
+        step = lambda j: (j > 1).astype(j.dtype)
+        learning_rule_type = DeltaRule(
+            learning_rate=learning_rate, post_fn=step, post_target=post_target,
+            post_tau=0.005)
+    elif post_target == 'out':
+        step = lambda s: (s > 18).astype(s.dtype)
+        learning_rule_type = DeltaRule(
+            learning_rate=learning_rate, post_fn=step, post_target=post_target,
+            post_tau=0.005)
+    else:
+        learning_rule_type = DeltaRule(learning_rate=learning_rate)
+
     with nengo.Network(seed=seed) as model:
         u = nengo.Node(nengo.processes.WhiteSignal(period=10, high=5))
         a = nengo.Ensemble(n, 1, **ens_params)
@@ -39,13 +55,13 @@ def test_delta_rule(Simulator, seed, rng, plt):
         e = nengo.Node(lambda t, x: x if t < t_train else 0, size_in=1)
         eb = nengo.Node(size_in=n)
 
-        nengo.Connection(u, e, transform=-1,
+        nengo.Connection(u, e, transform=-1, function=f,
                          synapse=nengo.synapses.Alpha(tau_s))
         nengo.Connection(b.neurons, e, transform=decoders, synapse=tau_s)
         nengo.Connection(e, eb, synapse=None, transform=decoders.T)
 
         c.transform = np.zeros((n, n))
-        c.learning_rule_type = DeltaRule(learning_rate=learning_rate)
+        c.learning_rule_type = learning_rule_type
        nengo.Connection(eb, c.learning_rule, synapse=None)
 
         ep = nengo.Probe(e)
@@ -58,21 +74,27 @@ def test_delta_rule(Simulator, seed, rng, plt):
     t = sim.trange()
     filt = nengo.synapses.Alpha(0.005)
     x = filt.filtfilt(sim.data[up])
+    fx = f(x)
     y = filt.filtfilt(sim.data[yp])
 
     plt.subplot(311)
     plt.plot(t, sim.data[ep])
+    plt.ylabel('error')
 
     plt.subplot(312)
-    plt.plot(t, x)
+    plt.plot(t, fx)
     plt.plot(t, y)
+    plt.ylabel('output')
 
     plt.subplot(313)
-    plt.plot(t[t > t_train], x[t > t_train])
+    plt.plot(t[t > t_train], fx[t > t_train])
     plt.plot(t[t > t_train], y[t > t_train])
+    plt.ylabel('test output')
+
+    plt.tight_layout()
 
     m = t > t_train
-    rms_error = rms(y[m] - x[m]) / rms(x[m])
+    rms_error = rms(y[m] - fx[m]) / rms(fx[m])
     assert rms_error < 0.3


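One note on the two gates in the parametrized test above: with post_target='in' the step triggers when the filtered post input current exceeds 1 (the LIF firing threshold in normalized units), while with 'out' it triggers when the output rate exceeds 18, a rate threshold chosen for this test. The gates are plain elementwise steps:

import numpy as np

step_in = lambda j: (j > 1).astype(j.dtype)    # gate on post input current
step_out = lambda s: (s > 18).astype(s.dtype)  # gate on post output rate

print(step_in(np.array([0.5, 1.5])))   # [0. 1.]
print(step_out(np.array([10., 25.])))  # [0. 1.]
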
9 changes: 3 additions & 6 deletions nengo_extras/tests/test_solvers.py
@@ -47,12 +47,9 @@ def gen(m):
     outs = np.dot(acts, sim.data[c].weights.T)
     error = (np.argmax(outs, axis=1) != testY).mean()
 
-    assert error < 0.1
-
-
-# def test_lstsq(Simulator, seed, rng):
-#     solver = nengo.solvers.LstsqL2(reg=0.01)
-#     _test_classifier(solver, Simulator, seed=seed, rng=rng)
+    assert error < 0.065
+    # ^ Threshold chosen based on empirical upper bound for solvers in the
+    # repo. Should catch if something breaks.


def test_lstsqclassifier(Simulator, seed, rng):