diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py
index 55141a43e8..b9dd0316ce 100644
--- a/examples/seq2seq/train.py
+++ b/examples/seq2seq/train.py
@@ -104,7 +104,9 @@ def get_train_state(rng: PRNGKey, ctable: CTable) -> train_state.TrainState:
   return state
 
 
-def cross_entropy_loss(logits: Array, labels: Array, lengths: Array) -> float:
+def cross_entropy_loss(
+    logits: Array, labels: Array, lengths: Array
+) -> jax.Array:
   """Returns cross-entropy loss."""
   xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
   masked_xe = jnp.mean(mask_sequences(xe, lengths))
@@ -113,7 +115,7 @@ def cross_entropy_loss(logits: Array, labels: Array, lengths: Array) -> float:
 
 def compute_metrics(
     logits: Array, labels: Array, eos_id: int
-) -> Dict[str, float]:
+) -> Dict[str, jax.Array]:
   """Computes metrics and returns them."""
   lengths = get_sequence_lengths(labels, eos_id)
   loss = cross_entropy_loss(logits, labels, lengths)
@@ -134,7 +136,7 @@ def compute_metrics(
 @jax.jit
 def train_step(
     state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey, eos_id: int
-) -> Tuple[train_state.TrainState, Dict[str, float]]:
+) -> Tuple[train_state.TrainState, Dict[str, jax.Array]]:
   """Trains one step."""
   labels = batch['answer'][:, 1:]
   lstm_key = jax.random.fold_in(lstm_rng, state.step)
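
Context for the annotation change, not part of the patch: reductions such as jnp.mean produce zero-dimensional jax.Array values rather than Python floats, especially under jax.jit, which is what the updated return types reflect. A minimal standalone sketch illustrating this (the function name mean_loss is hypothetical and not taken from train.py):

import jax
import jax.numpy as jnp


@jax.jit
def mean_loss(x: jax.Array) -> jax.Array:
  # jnp.mean returns a zero-dimensional jax.Array, not a Python float.
  return jnp.mean(x)


out = mean_loss(jnp.ones((4, 3)))
assert isinstance(out, jax.Array)  # jax.Array is the common array type in recent JAX releases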