diff --git a/da_rnn/keras/model.py b/da_rnn/keras/model.py index 5406b0b..8b357b5 100644 --- a/da_rnn/keras/model.py +++ b/da_rnn/keras/model.py @@ -7,7 +7,8 @@ Layer, LSTM, Dense, - Permute + Permute, + Masking ) from tensorflow.keras.models import Model @@ -145,6 +146,8 @@ def __init__( self.input_lstm = LSTM(m, return_state=True) self.input_attention = InputAttention(T) + self.masking = Masking(mask_value=float('nan')) + def call(self, X) -> tf.Tensor: """ @@ -162,6 +165,8 @@ def call(self, X) -> tf.Tensor: X_encoded = [] + X = self.masking(X) + for t in range(self.T): Alpha_t = self.input_attention(hidden_state, cell_state, X) @@ -307,6 +312,8 @@ def __init__( self.Wb = Dense(p) self.vb = Dense(y_dim) + self.masking = Masking(mask_value=float('nan')) + def call(self, Y, X_encoded) -> tf.Tensor: """ Args: @@ -324,6 +331,8 @@ def call(self, Y, X_encoded) -> tf.Tensor: # c in the paper context_vector = tf.zeros((batch_size, 1, self.m)) # -> (batch_size, 1, m) + + X_encoded = self.masking(X_encoded) for t in range(self.T - 1): Beta_t = self.temp_attention(