[WIP] Add an embedding layer for the months #120

Open
wants to merge 5 commits into base: master
15 changes: 15 additions & 0 deletions src/models/neural_networks/base.py
@@ -5,6 +5,7 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

import shap
@@ -348,3 +349,17 @@ def make_shap_input(self, x: TrainData, start_idx: int = 0,
else:
output_tensors.append(x.static[start_idx: start_idx + num_inputs])
return output_tensors


class OneHotMonthEncoder(nn.Module):
"""Since the months are one hot encoded, using a linear layer
is equivalent to the lookup action of an embedding layer.
"""
def __init__(self, embedding_size: int = 12) -> None:
super().__init__()

self.encoder = nn.Linear(in_features=12, out_features=embedding_size,
bias=True)

def forward(self, x):
return self.encoder(x)
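
For reference, a minimal standalone sketch (not part of the diff) of the equivalence the docstring describes: for a one-hot input, a bias-free linear layer returns one column of its weight matrix, which is exactly what an embedding lookup does with its rows. The layer added in this PR additionally learns a bias shared across all months. Names and sizes below are illustrative only.

import torch
from torch import nn

embedding_size = 4
linear = nn.Linear(in_features=12, out_features=embedding_size, bias=False)
embedding = nn.Embedding(num_embeddings=12, embedding_dim=embedding_size)

# tie the weights so both modules hold the same per-month vectors
embedding.weight.data = linear.weight.data.t().clone()

month_idx = torch.tensor([4])       # month index 0-11 (assuming January = 0)
one_hot = torch.eye(12)[month_idx]  # shape (1, 12)

with torch.no_grad():
    assert torch.allclose(linear(one_hot), embedding(month_idx))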
155 changes: 21 additions & 134 deletions src/models/neural_networks/ealstm.py
@@ -6,7 +6,7 @@

from typing import Dict, List, Optional, Tuple

from .base import NNBase
from .base import NNBase, OneHotMonthEncoder


class EARecurrentNetwork(NNBase):
@@ -21,19 +21,23 @@ def __init__(self, hidden_size: int,
experiment: str = 'one_month_forecast',
pred_months: Optional[List[int]] = None,
include_latlons: bool = False,
include_pred_month: bool = True,
pred_month_embedding_size: Optional[int] = 12,
include_monthly_aggs: bool = True,
include_yearly_aggs: bool = True,
surrounding_pixels: Optional[int] = None,
ignore_vars: Optional[List[str]] = None,
include_static: bool = True,
device: str = 'cuda:0') -> None:
include_pred_month = False
if pred_month_embedding_size is not None:
include_pred_month = True
super().__init__(data_folder, batch_size, experiment, pred_months, include_pred_month,
include_latlons, include_monthly_aggs, include_yearly_aggs,
surrounding_pixels, ignore_vars, include_static, device)

# to initialize and save the model
self.hidden_size = hidden_size
self.pred_month_embedding_size = pred_month_embedding_size
self.rnn_dropout = rnn_dropout
self.input_dense = copy(dense_features) # this is to make sure we can reload the model
if dense_features is None: dense_features = []
@@ -58,11 +62,11 @@ def save_model(self):
'hidden_size': self.hidden_size,
'rnn_dropout': self.rnn_dropout,
'dense_features': self.input_dense,
'include_pred_month': self.include_pred_month,
'include_latlons': self.include_latlons,
'surrounding_pixels': self.surrounding_pixels,
'include_monthly_aggs': self.include_monthly_aggs,
'include_yearly_aggs': self.include_yearly_aggs,
'pred_month_embedding_size': self.pred_month_embedding_size,
'experiment': self.experiment,
'ignore_vars': self.ignore_vars,
'include_static': self.include_static,
@@ -82,7 +86,7 @@ def load(self, state_dict: Dict, features_per_month: int, current_size: Optional
dense_features=self.dense_features,
hidden_size=self.hidden_size,
rnn_dropout=self.rnn_dropout,
include_pred_month=self.include_pred_month,
pred_month_embedding_size=self.pred_month_embedding_size,
experiment=self.experiment,
current_size=self.current_size,
yearly_agg_size=self.yearly_agg_size,
@@ -114,7 +118,7 @@ def _initialize_model(self, x_ref: Optional[Tuple[torch.Tensor, ...]]) -> nn.Mod
dense_features=self.dense_features,
hidden_size=self.hidden_size,
rnn_dropout=self.rnn_dropout,
include_pred_month=self.include_pred_month,
pred_month_embedding_size=self.pred_month_embedding_size,
experiment=self.experiment,
yearly_agg_size=self.yearly_agg_size,
current_size=self.current_size,
@@ -126,13 +130,14 @@ def _initialize_model(self, x_ref: Optional[Tuple[torch.Tensor, ...]]) -> nn.Mod

class EALSTM(nn.Module):
def __init__(self, features_per_month, dense_features, hidden_size,
rnn_dropout, include_latlons, include_pred_month,
rnn_dropout, include_latlons, pred_month_embedding_size,
experiment, yearly_agg_size=None, current_size=None,
static_size=None):
super().__init__()

self.experiment = experiment
self.include_pred_month = include_pred_month
self.pred_month_embedding_size = pred_month_embedding_size
self.include_pred_month = False
self.include_latlons = include_latlons
self.include_yearly_agg = False
self.include_static = False
@@ -148,14 +153,16 @@ def __init__(self, features_per_month, dense_features, hidden_size,
if static_size is not None:
self.include_static = True
ea_static_size += static_size
if include_pred_month:
ea_static_size += 12
if pred_month_embedding_size is not None:
self.include_pred_month = True
self.month_encoder = OneHotMonthEncoder(pred_month_embedding_size)
ea_static_size += pred_month_embedding_size

self.dropout = nn.Dropout(rnn_dropout)
self.rnn = OrgEALSTMCell(input_size_dyn=features_per_month,
input_size_stat=ea_static_size,
hidden_size=hidden_size,
batch_first=True)
self.rnn = EALSTMCell(input_size_dyn=features_per_month,
input_size_stat=ea_static_size,
hidden_size=hidden_size,
batch_first=True)
self.hidden_size = hidden_size
self.rnn_dropout = nn.Dropout(rnn_dropout)

@@ -198,7 +205,7 @@ def forward(self, x, pred_month=None, latlons=None, current=None, yearly_aggs=No
if self.include_static:
static_x.append(static)
if self.include_pred_month:
static_x.append(pred_month)
static_x.append(self.month_encoder(pred_month))

hidden_state, cell_state = self.rnn(x, torch.cat(static_x, dim=-1))

@@ -214,126 +221,6 @@ def forward(self, x, pred_month=None, latlons=None, current=None, yearly_aggs=No


class EALSTMCell(nn.Module):
"""See below. Implemented using modules so it can be explained with shap
"""
def __init__(self,
input_size_dyn: int,
input_size_stat: int,
hidden_size: int,
batch_first: bool = True):
super().__init__()

self.input_size_dyn = input_size_dyn
self.input_size_stat = input_size_stat
self.hidden_size = hidden_size
self.batch_first = batch_first

self.forget_gate_i = nn.Linear(in_features=input_size_dyn,
out_features=hidden_size, bias=False)
self.forget_gate_h = nn.Linear(in_features=hidden_size, out_features=hidden_size,
bias=True)

self.update_gate = nn.Sequential(*[
nn.Linear(in_features=input_size_stat, out_features=hidden_size),
nn.Sigmoid()
])

self.update_candidates_i = nn.Linear(in_features=input_size_dyn, out_features=hidden_size,
bias=False)
self.update_candidates_h = nn.Linear(in_features=hidden_size, out_features=hidden_size,
bias=True)

self.output_gate_i = nn.Linear(in_features=input_size_dyn, out_features=hidden_size,
bias=False)
self.output_gate_h = nn.Linear(in_features=hidden_size, out_features=hidden_size,
bias=True)

self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()

self.reset_parameters()

def reset_parameters(self):
self._reset_i(self.forget_gate_i)
self._reset_i(self.update_candidates_i)
self._reset_i(self.output_gate_i)
self._reset_i(self.update_gate[0])
nn.init.constant_(self.update_gate[0].bias.data, val=0)

self._reset_h(self.forget_gate_h, self.hidden_size)
self._reset_h(self.update_candidates_h, self.hidden_size)
self._reset_h(self.output_gate_h, self.hidden_size)

@staticmethod
def _reset_i(layer):
nn.init.orthogonal(layer.weight.data)

@staticmethod
def _reset_h(layer, hidden_size):
weight_hh_data = torch.eye(hidden_size)
layer.weight.data = weight_hh_data
nn.init.constant_(layer.bias.data, val=0)

def forward(self, x_d, x_s):
"""[summary]
Parameters
----------
x_d : torch.Tensor
Tensor, containing a batch of sequences of the dynamic features. Shape has to match
the format specified with batch_first.
x_s : torch.Tensor
Tensor, containing a batch of static features.
Returns
-------
h_n : torch.Tensor
The hidden states of each time step of each sample in the batch.
c_n : torch.Tensor
The cell states of each time step of each sample in the batch.
"""
if self.batch_first:
x_d = x_d.transpose(0, 1)

seq_len, batch_size, _ = x_d.size()

h_0 = x_d.data.new(batch_size, self.hidden_size).zero_()
c_0 = x_d.data.new(batch_size, self.hidden_size).zero_()
h_x = (h_0, c_0)

# empty lists to temporally store all intermediate hidden/cell states
h_n, c_n = [], []

# calculate input gate only once because inputs are static
i = self.update_gate(x_s)

# perform forward steps over input sequence
for t in range(seq_len):
h_0, c_0 = h_x

forget_state = self.sigmoid(self.forget_gate_i(x_d[t]) + self.forget_gate_h(h_0))
cell_candidates = self.tanh(self.update_candidates_i(x_d[t]) +
self.update_candidates_h(h_0))
output_state = self.sigmoid(self.output_gate_i(x_d[t]) + self.output_gate_h(h_0))

c_1 = forget_state * c_0 + i * cell_candidates
h_1 = output_state * self.tanh(c_1)

# store intermediate hidden/cell state in list
h_n.append(h_1)
c_n.append(c_1)

h_x = (h_1, c_1)

h_n = torch.stack(h_n, 0)
c_n = torch.stack(c_n, 0)

if self.batch_first:
h_n = h_n.transpose(0, 1)
c_n = c_n.transpose(0, 1)

return h_n, c_n


class OrgEALSTMCell(nn.Module):
"""Implementation of the Entity-Aware-LSTM (EA-LSTM)

This code was copied from
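
Aside (not part of the diff): with this change the static input to the EA-LSTM input gate grows by pred_month_embedding_size rather than by a fixed 12, and passing pred_month_embedding_size=None drops the month input entirely. A rough sketch of the shape bookkeeping, with illustrative sizes:

import torch
from torch import nn

batch_size, static_size, pred_month_embedding_size = 3, 4, 6
month_encoder = nn.Linear(12, pred_month_embedding_size)  # stand-in for OneHotMonthEncoder

latlons = torch.rand(batch_size, 2)
static = torch.rand(batch_size, static_size)
pred_month = torch.eye(12)[torch.tensor([0, 5, 11])]  # one-hot months

static_x = [latlons, static, month_encoder(pred_month)]
ea_static = torch.cat(static_x, dim=-1)

# input_size_stat passed to the EA-LSTM cell must match this width
assert ea_static.shape[-1] == 2 + static_size + pred_month_embedding_size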
25 changes: 25 additions & 0 deletions tests/models/neural_networks/test_base.py
@@ -0,0 +1,25 @@
import torch

from src.models.neural_networks.base import OneHotMonthEncoder


class TestOneHotMonthEncoder:

def test_embedding_likeness(self):

model = OneHotMonthEncoder(15)

init_weights = torch.stack([torch.arange(1, 13)] * 15).float()
init_bias = torch.zeros(15).float()
model.encoder.weight.data = init_weights
model.encoder.bias.data = init_bias

for month in range(1, 13):
indices = torch.tensor([month, month, month])
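# eye(14) sliced to columns 1..12 yields a 12-dim one-hot with month m at index m - 1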
month_tensor = torch.eye(14)[indices.long()][:, 1:-1].float()

model.eval()
with torch.no_grad():
output_vals = model(month_tensor)

assert (output_vals == month).all()
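
Aside: with every weight row set to [1, ..., 12] and zero bias, the one-hot month vector with a 1 at index m - 1 selects weight column m - 1, so each of the 15 outputs equals m, which is what the assertion checks.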
62 changes: 5 additions & 57 deletions tests/models/neural_networks/test_ealstm.py
@@ -1,12 +1,10 @@
import pickle
import pytest
from copy import copy
import numpy as np

import torch
from torch import nn

from src.models.neural_networks.ealstm import EALSTM, EALSTMCell, OrgEALSTMCell
from src.models.neural_networks.ealstm import EALSTM
from src.models import EARecurrentNetwork

from tests.utils import _make_dataset
@@ -22,21 +20,21 @@ def test_save(self, tmp_path, monkeypatch):
hidden_size = 128
rnn_dropout = 0.25
include_latlons = True
include_pred_month = True
pred_month_embedding_size = 12
include_yearly_aggs = True
yearly_agg_size = 3

def mocktrain(self):
self.model = EALSTM(features_per_month, dense_features, hidden_size,
rnn_dropout, include_latlons, include_pred_month,
rnn_dropout, include_latlons, pred_month_embedding_size,
experiment='one_month_forecast', yearly_agg_size=yearly_agg_size)
self.features_per_month = features_per_month
self.yearly_agg_size = yearly_agg_size

monkeypatch.setattr(EARecurrentNetwork, 'train', mocktrain)

model = EARecurrentNetwork(hidden_size=hidden_size, dense_features=dense_features,
include_pred_month=include_pred_month,
pred_month_embedding_size=pred_month_embedding_size,
include_latlons=include_latlons,
rnn_dropout=rnn_dropout, data_folder=tmp_path,
include_yearly_aggs=include_yearly_aggs)
@@ -56,7 +54,7 @@ def mocktrain(self):
assert model_dict['hidden_size'] == hidden_size
assert model_dict['rnn_dropout'] == rnn_dropout
assert model_dict['dense_features'] == input_dense_features
assert model_dict['include_pred_month'] == include_pred_month
assert model_dict['pred_month_embedding_size'] == pred_month_embedding_size
assert model_dict['include_latlons'] == include_latlons
assert model_dict['include_yearly_aggs'] == include_yearly_aggs
assert model_dict['experiment'] == 'one_month_forecast'
@@ -151,53 +149,3 @@ def test_predict(self, tmp_path, use_pred_months):

# _make_dataset with const=True returns all ones
assert (test_arrays_dict['hello']['y'] == 1).all()


class TestEALSTMCell:
@staticmethod
def test_ealstm(monkeypatch):
"""
We implement our own unrolled RNN, so that it can be explained with
shap. This test makes sure it roughly mirrors the behaviour of the pytorch
LSTM.
"""

batch_size, hidden_size, timesteps, dyn_input, static_input = 3, 5, 2, 6, 4

@staticmethod
def i_init(layer):
nn.init.constant_(layer.weight.data, val=1)
monkeypatch.setattr(EALSTMCell, '_reset_i', i_init)

def org_init(self):
"""Initialize all learnable parameters of the LSTM"""
nn.init.constant_(self.weight_ih.data, val=1)
nn.init.constant_(self.weight_sh, val=1)

weight_hh_data = torch.eye(self.hidden_size)
weight_hh_data = weight_hh_data.repeat(1, 3)
self.weight_hh.data = weight_hh_data

nn.init.constant_(self.bias.data, val=0)
nn.init.constant_(self.bias_s.data, val=0)
monkeypatch.setattr(OrgEALSTMCell, 'reset_parameters', org_init)

org_ealstm = OrgEALSTMCell(input_size_dyn=dyn_input,
input_size_stat=static_input,
hidden_size=hidden_size)

our_ealstm = EALSTMCell(input_size_dyn=dyn_input,
input_size_stat=static_input,
hidden_size=hidden_size)

static = torch.rand(batch_size, static_input)
dynamic = torch.rand(batch_size, timesteps, dyn_input)

with torch.no_grad():
org_hn, org_cn = org_ealstm(dynamic, static)
our_hn, our_cn = our_ealstm(dynamic, static)

assert np.isclose(org_hn.numpy(), our_hn.numpy(), 0.01).all(), \
"Difference in hidden state"
assert np.isclose(org_cn.numpy(), our_cn.numpy(), 0.01).all(), \
"Difference in cell state"