Skip to content

Commit

Permalink
Fix types in example code.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583453706
  • Loading branch information
DrMarcII authored and Flax Authors committed Nov 17, 2023
1 parent 70214f4 commit b18c613
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions examples/seq2seq/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def get_train_state(rng: PRNGKey, ctable: CTable) -> train_state.TrainState:
return state


def cross_entropy_loss(logits: Array, labels: Array, lengths: Array) -> float:
def cross_entropy_loss(
logits: Array, labels: Array, lengths: Array
) -> jax.Array:
"""Returns cross-entropy loss."""
xe = jnp.sum(nn.log_softmax(logits) * labels, axis=-1)
masked_xe = jnp.mean(mask_sequences(xe, lengths))
Expand All @@ -113,7 +115,7 @@ def cross_entropy_loss(logits: Array, labels: Array, lengths: Array) -> float:

def compute_metrics(
logits: Array, labels: Array, eos_id: int
) -> Dict[str, float]:
) -> Dict[str, jax.Array]:
"""Computes metrics and returns them."""
lengths = get_sequence_lengths(labels, eos_id)
loss = cross_entropy_loss(logits, labels, lengths)
Expand All @@ -134,7 +136,7 @@ def compute_metrics(
@jax.jit
def train_step(
state: train_state.TrainState, batch: Array, lstm_rng: PRNGKey, eos_id: int
) -> Tuple[train_state.TrainState, Dict[str, float]]:
) -> Tuple[train_state.TrainState, Dict[str, jax.Array]]:
"""Trains one step."""
labels = batch['answer'][:, 1:]
lstm_key = jax.random.fold_in(lstm_rng, state.step)
Expand Down

0 comments on commit b18c613

Please sign in to comment.