Feature load matches #77

Open · wants to merge 11 commits into master

nolearn/lasagne/base.py (62 changes: 45 additions & 17 deletions)

@@ -2,6 +2,7 @@

from .._compat import pickle
from collections import OrderedDict
from difflib import SequenceMatcher
import functools
import itertools
import operator

@@ -395,28 +396,55 @@ def get_all_params(self):
        params = sum([l.get_params() for l in layers], [])
        return unique(params)

    def load_weights_from(self, source):
        self.initialize()

        if isinstance(source, str):
            source = np.load(source)

        if isinstance(source, NeuralNet):
            source = source.get_all_params()

        source_weights = [
            w.get_value() if hasattr(w, 'get_value') else w for w in source]

        for w1, w2 in zip(source_weights, self.get_all_params()):
            if w1.shape != w2.get_value().shape:
                continue
            w2.set_value(w1)

    def save_weights_to(self, fname):
        weights = [w.get_value() for w in self.get_all_params()]
        with open(fname, 'wb') as f:
            pickle.dump(weights, f, -1)

    @staticmethod
    def _param_alignment(shapes0, shapes1):
        # Align two lists of parameter shapes with difflib's
        # SequenceMatcher and return the matching (source, target)
        # index pairs.
        shapes0 = list(map(str, shapes0))
        shapes1 = list(map(str, shapes1))
        matcher = SequenceMatcher(a=shapes0, b=shapes1)
        matches = []
        for block in matcher.get_matching_blocks():
            if block.size == 0:
                continue
            matches.append((list(range(block.a, block.a + block.size)),
                            list(range(block.b, block.b + block.size))))
        result = [line for match in matches for line in zip(*match)]
        return result

    def load_weights_from(self, src):
        if not hasattr(self, '_initialized'):
            raise AttributeError(
                "Please initialize the net with '.initialize()' before "
                "loading weights.")

        if isinstance(src, str):
            src = np.load(src)
        if isinstance(src, NeuralNet):
            src = src.get_all_params()

        target = self.get_all_params()
        src_params = [p.get_value() if hasattr(p, 'get_value') else p
                      for p in src]
        target_params = [p.get_value() for p in target]

        src_shapes = [p.shape for p in src_params]
        target_shapes = [p.shape for p in target_params]
        matches = self._param_alignment(src_shapes, target_shapes)

        for i, j in matches:
            target[j].set_value(src_params[i])

            if not self.verbose:
                continue
            param_shape = 'x'.join(map(str, src_params[i].shape))
            param_name = target[j].name + ' ' if target[j].name else ''
            print("* Loaded parameter {}(shape: {})".format(
                param_name, param_shape))

    def __getstate__(self):
        state = dict(self.__dict__)
        for attr in (
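
To make the alignment behavior concrete, here is a minimal standalone sketch of the same SequenceMatcher idea. The param_alignment helper below is an illustrative copy, not part of the PR, and the shape lists mirror the nets in the test further down:

from difflib import SequenceMatcher

def param_alignment(shapes0, shapes1):
    # Compare shape tuples by their string form and collect the
    # (source index, target index) pairs of every matching block.
    a, b = list(map(str, shapes0)), list(map(str, shapes1))
    pairs = []
    for block in SequenceMatcher(a=a, b=b).get_matching_blocks():
        pairs.extend(zip(range(block.a, block.a + block.size),
                         range(block.b, block.b + block.size)))
    return pairs

# Shapes of the (W, b) parameters of net0 and net1 from the test below:
src = [(784, 100), (100,), (100, 200), (200,), (200, 10), (10,)]
dst = [(784, 100), (100,), (100, 20), (20,), (20, 200), (200,),
       (200, 10), (10,)]
print(param_alignment(src, dst))
# -> [(0, 0), (1, 1), (3, 5), (4, 6), (5, 7)]
# dense0's W and b line up, and so do the output layer's parameters,
# even though net1 has an extra dense layer in between.
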
nolearn/tests/test_lasagne.py (51 changes: 51 additions & 0 deletions)

@@ -115,10 +115,61 @@ def on_epoch_finished(nn, train_history):

    # Use load_weights_from to initialize an untrained model:
    nn3 = clone(nn_def)
    nn3.initialize()
    nn3.load_weights_from(nn2)
    assert np.array_equal(nn3.predict(X_test), y_pred)


def test_lasagne_loading_params_matches():
    # The loading mechanism should find layers with matching parameter
    # shapes, even if they are not perfectly aligned.
    from nolearn.lasagne import NeuralNet

    layers0 = [('input', InputLayer),
               ('dense0', DenseLayer),
               ('dense1', DenseLayer),
               ('output', DenseLayer)]
    net0 = NeuralNet(
        layers=layers0,
        input_shape=(None, 784),
        dense0_num_units=100,
        dense1_num_units=200,
        output_nonlinearity=softmax, output_num_units=10,
        update=nesterov_momentum,
        update_learning_rate=0.01,
        max_epochs=5,
    )
    net0.initialize()
    net0.save_weights_to('tmp_params.np')

    layers1 = [('input', InputLayer),
               ('dense0', DenseLayer),
               ('dense1', DenseLayer),
               ('dense2', DenseLayer),
               ('output', DenseLayer)]
    net1 = NeuralNet(
        layers=layers1,
        input_shape=(None, 784),
        dense0_num_units=100,
        dense1_num_units=20,
        dense2_num_units=200,
        output_nonlinearity=softmax, output_num_units=10,
        update=nesterov_momentum,
        update_learning_rate=0.01,
        max_epochs=5,
    )
    net1.initialize()

    # Before loading, the output weights of both nets have the same
    # shape but different values:
    assert not (net0.layers_['output'].W.get_value() ==
                net1.layers_['output'].W.get_value()).all()
    # After loading, these weights should be equal, despite the
    # additional dense layer in net1:
    net1.load_weights_from('tmp_params.np')
    assert (net0.layers_['output'].W.get_value() ==
            net1.layers_['output'].W.get_value()).all()


def test_lasagne_functional_grid_search(mnist, monkeypatch):
    # Make sure that we can satisfy the grid search interface.
    from nolearn.lasagne import NeuralNet
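
Usage note: with this PR, load_weights_from no longer calls initialize() implicitly and instead raises if the net is uninitialized, so callers must initialize first. A minimal sketch of the new workflow, assuming an architecture like the one in the test above (the weights file name is hypothetical):

from lasagne.layers import DenseLayer, InputLayer
from lasagne.nonlinearities import softmax
from lasagne.updates import nesterov_momentum

from nolearn.lasagne import NeuralNet

net = NeuralNet(
    layers=[('input', InputLayer),
            ('dense0', DenseLayer),
            ('output', DenseLayer)],
    input_shape=(None, 784),
    dense0_num_units=100,
    output_nonlinearity=softmax, output_num_units=10,
    update=nesterov_momentum,
    update_learning_rate=0.01,
    max_epochs=5,
)

net.initialize()  # now required before loading; omitting it raises
net.load_weights_from('pretrained.params')  # hypothetical file name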