Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions snntorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Note: need NumPy 1.17 or later for RNG functions
import numpy as np

import snntorch as snn


Expand Down Expand Up @@ -185,26 +186,26 @@ def _layer_check(net):
global is_sconv2dlstm
global is_slstm

for idx in range(len(list(net._modules.values()))):
if isinstance(list(net._modules.values())[idx], snn.Lapicque):
for module in net.modules():
if isinstance(module, snn.Lapicque):
is_lapicque = True
if isinstance(list(net._modules.values())[idx], snn.Synaptic):
if isinstance(module, snn.Synaptic):
is_synaptic = True
if isinstance(list(net._modules.values())[idx], snn.Leaky):
if isinstance(module, snn.Leaky):
is_leaky = True
if isinstance(list(net._modules.values())[idx], snn.LinearLeaky):
if isinstance(module, snn.LinearLeaky):
is_linearleaky = True
if isinstance(list(net._modules.values())[idx], snn.StateLeaky):
if isinstance(module, snn.StateLeaky):
is_stateleaky = True
if isinstance(list(net._modules.values())[idx], snn.Alpha):
if isinstance(module, snn.Alpha):
is_alpha = True
if isinstance(list(net._modules.values())[idx], snn.RLeaky):
if isinstance(module, snn.RLeaky):
is_rleaky = True
if isinstance(list(net._modules.values())[idx], snn.RSynaptic):
if isinstance(module, snn.RSynaptic):
is_rsynaptic = True
if isinstance(list(net._modules.values())[idx], snn.SConv2dLSTM):
if isinstance(module, snn.SConv2dLSTM):
is_sconv2dlstm = True
if isinstance(list(net._modules.values())[idx], snn.SLSTM):
if isinstance(module, snn.SLSTM):
is_slstm = True


Expand Down
107 changes: 107 additions & 0 deletions tests/test_snntorch/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import unittest

import torch
import torch.nn as nn

import snntorch as snn
from snntorch import utils


class Model(nn.Module):
def __init__(self):
super().__init__()
self.neuron = snn.Leaky(beta=0.5, init_hidden=True)
self.neuron2 = snn.Leaky(beta=0.5, init_hidden=True)

def forward(self, x):
return self.neuron(x) + self.neuron2(x)


class SequentialModel(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(snn.Leaky(beta=0.5, init_hidden=True))

def forward(self, x):
return self.net(x)


class ListModel(nn.Module):
def __init__(self):
super().__init__()
self.list = nn.ModuleList([snn.Leaky(beta=0.5, init_hidden=True)])

def forward(self, x):
return self.list[0](x)


class Block(nn.Module):
def __init__(self):
super().__init__()
self.neuron = snn.Leaky(beta=0.5, init_hidden=True)

def forward(self, x):
return self.neuron(x)


class NestedModel(nn.Module):
def __init__(self):
super().__init__()
self.block = Block()

def forward(self, x):
return self.block(x)


class MultiNeuronBlock(nn.Module):
def __init__(self):
super().__init__()
self.neuron1 = snn.Leaky(beta=0.5, init_hidden=True)
self.neuron2 = snn.Leaky(beta=0.5, init_hidden=True)

def forward(self, x):
x = self.neuron1(x)
return self.neuron2(x)


class NestedMultiNeuronModel(nn.Module):
def __init__(self):
super().__init__()
self.block = MultiNeuronBlock()

def forward(self, x):
return self.block(x)


class TestResetMechanism(unittest.TestCase):

def _check_reset(self, model_class):
model = model_class()
x = torch.randn(1, 10)
model(x)

neurons = [m for m in model.modules() if isinstance(m, snn.Leaky)]
self.assertTrue(len(neurons) > 0)

for neuron in neurons:
self.assertNotEqual(neuron.mem.abs().sum().item(), 0)

utils.reset(model)

for neuron in neurons:
self.assertEqual(neuron.mem.abs().sum().item(), 0)

def test_flat_model(self):
self._check_reset(Model)

def test_sequential_model(self):
self._check_reset(SequentialModel)

def test_list_model(self):
self._check_reset(ListModel)

def test_nested_custom_model(self):
self._check_reset(NestedModel)

def test_nested_multi_neuron_model(self):
self._check_reset(NestedMultiNeuronModel)