```
Hi! JAX Process: 0 / 1
JAX Local Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
[15:59:22] Epoch: 1
Training:   0%|          | 0/937 [00:00<?, ?it/s]
```

Below are the training and test steps I used, which were also working just fine on other runtimes. I've removed some bits of code that I felt are not causing the issue; everything is accessible via the full training script linked below.

```python
# Imports for this snippet; create_iterator, cfg, test_examples, model_dict,
# and schedule are defined in the full training script linked below.
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils
from flax.training import train_state
from jax import random
from termcolor import colored
from tqdm import tqdm

# Input pipeline and steps per epoch.
train_it, train_examples = create_iterator(
    "mnist", cfg.batch_size, cfg.data_shape, cfg.split_keys[0]
)
train_steps = train_examples // cfg.batch_size
test_steps = test_examples // cfg.batch_size

# One init key plus a per-device dropout key.
rng = random.PRNGKey(0)
rng, init_rng = random.split(rng)
dropout_rng = random.split(rng, jax.local_device_count())

# Model, optimizer, and replicated train state.
model = model_dict["PVT_V2_B0"](num_classes=cfg.num_classes)
params = model.init(rng, jnp.ones([1, 28, 28, 3]), False)["params"]
tx = optax.adamw(learning_rate=schedule)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)
state = jax_utils.replicate(state)


def train_step(state, inputs, labels, dropout_rng):
    """Perform a single training step."""
    dropout_rng = random.fold_in(dropout_rng, state.step)

    def loss_fn(params):
        logits = state.apply_fn(
            {"params": params},
            inputs,
            trainable=True,
            rngs={"dropout": dropout_rng},
        )
        one_hot = jax.nn.one_hot(labels, cfg.num_classes)
        loss = optax.softmax_cross_entropy(logits, one_hot).mean()
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    # Average probabilities and gradients across devices on the pmap axis.
    probs = jax.lax.pmean(jax.nn.softmax(logits), axis_name="batch")
    accuracy = jnp.mean(jnp.argmax(probs, -1) == labels)
    grads = jax.lax.pmean(grads, axis_name="batch")
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, loss, accuracy


p_train_step = jax.pmap(train_step, axis_name="batch")

train_loss, train_accuracy = list(), list()

# Single train epoch (for testing)
for _, batch in zip(
    range(train_steps),
    tqdm(
        train_it,
        total=train_steps,
        desc=colored(f"{' '*10} Training", "magenta"),
        colour="cyan",
    ),
):
    inputs, labels = batch["image"], batch["label"]
    state, loss, accuracy = p_train_step(state, inputs, labels, dropout_rng)
    train_loss.append(jax_utils.unreplicate(loss))
    train_accuracy.append(jax_utils.unreplicate(accuracy))
```

Training script: https://github.com/muhd-umer/PVT-Flax/blob/main/train.py

Any help is appreciated ❤️
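For reference, a minimal `pmap` collective that exercises only the runtime (a sketch, independent of the training script above, assuming the same eight local TPU cores) looks like:

```python
import jax
import jax.numpy as jnp

# One small shard per local device; psum forces a real cross-device
# collective, which is where a broken TPU runtime tends to hang.
n = jax.local_device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

summed = jax.pmap(lambda s: jax.lax.psum(s, axis_name="i"), axis_name="i")(x)
print(summed)  # prints promptly on a healthy runtime
```

If this already stalls, the problem is in the runtime rather than in the model or input pipeline.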
Replies: 1 comment
It seems that the Colab TPU drivers were the cause of this issue, and not the code itself. Since distributed training is working as expected, I'm marking this as answered.
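For anyone hitting the same symptom, it is worth confirming that the TPU driver is attached and all cores are visible before training; a minimal check, assuming a Colab TPU runtime and the `jax.tools.colab_tpu` helper:

```python
import jax
import jax.tools.colab_tpu

# Attach JAX to the Colab TPU driver before any computation runs.
jax.tools.colab_tpu.setup_tpu()

print(jax.device_count(), jax.devices())  # expect 8 TpuDevice entries
```

If the devices do not show up, restarting the Colab runtime to pick up a fresh TPU driver is usually the quickest fix.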