Hello, I am trying to train a simple binary image classifier to learn how to use Flax and JAX. The approach I describe here is a mix of https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html and https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html from the docs.

To start, my training data is stored in `d_set`, where each sample is an `(image, label)` pair. My model is defined as follows (a simple MLP):

```python
class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = jnp.ravel(x)  # flatten
    x = nn.Dense(features=784)(x)
    x = nn.Dense(features=2)(x)  # 2 because we are using softmax cross entropy
    return nn.softmax(x)  # This will output probabilities for each class, e.g. [0.8, 0.2]
```

I initialized my params as follows:

```python
mlp_model = MLP()
sample_batch = jnp.ones((1, 28, 28))
sample_batch.shape
# output: (1, 28, 28)
parameters = mlp_model.init(jax.random.PRNGKey(0), sample_batch)
```

And with these randomly initialized params, we can now predict on a sample (the 2nd sample, at index 1) from the dataset:

```python
real_output = mlp_model.apply(parameters, d_set[1][0])
real_output
# output: DeviceArray([1.000000e+00, 1.981861e-30], dtype=float32)
```

I then move on to training. For the setup, I defined a loss function that uses softmax cross entropy:

```python
def binary_loss(logits, labels):
  # binary cross entropy loss
  # both logits and labels should be shape [batch, num_classes],
  # so one-hot encode the given label
  one_hot_labels = jax.nn.one_hot([labels], num_classes=2)
  return optax.softmax_cross_entropy(logits=logits.reshape(1, 2), labels=one_hot_labels)
```

Now onto the training loop that updates the model params. This is where I think I may be making a mistake. The setup above is quite similar to the annotated MNIST notebook, but I diverge here because I am trying to produce an extremely minimal example of training a classifier, and thus avoid patterns like `TrainState`. To update the params, I am essentially following the steps in this section of Flax Basics: https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html#optimizing-with-optax.

Setup:

```python
import optax
# Same as version above but with model.apply().
@jax.jit  # wrap in jit for speed-up!
def forward_loss(params, x, y):
  pred_logits = mlp_model.apply(params, x)
  return binary_loss(pred_logits, y).mean()  # mean to extract value

tx = optax.sgd(learning_rate=0.001)  # tx is the optimizer
opt_state = tx.init(parameters)
loss_grad_fn = jax.value_and_grad(forward_loss)
```

Now for the loop:

```python
for i in range(30):
  # loop over entire dataset
  for sample in d_set:
    x = sample[0]
    y = sample[1]
    loss_val, grads = loss_grad_fn(parameters, x, y)
    updates, opt_state = tx.update(grads, opt_state)
    parameters = optax.apply_updates(parameters, updates)
  print('Loss step {}: '.format(i + 1), loss_val)
```

Output:
As we can see, the loss is clearly not changing, and that is what I am trying to debug. Why is this the case? Some thoughts I have: there are no non-linear activation functions (such as relu) in the model, or something is wrong with the softmax loss function (maybe I should use sigmoid instead). I believe the error is somewhere in the model + loss, because it turns out that the loss always produces the same value no matter the input (see the quick check sketched below). Any tips/suggestions for what I am doing wrong? Thanks!
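For reference, this is roughly the check I mean (a sketch; it assumes `d_set[i]` yields `(image, label)` pairs as in the snippets above):

```python
# Evaluate the loss and gradient magnitudes on two different samples.
l0, g0 = loss_grad_fn(parameters, d_set[0][0], d_set[0][1])
l1, g1 = loss_grad_fn(parameters, d_set[1][0], d_set[1][1])
print(l0, l1)  # these come out the same regardless of the input
print(jax.tree_util.tree_map(lambda g: float(jnp.abs(g).max()), g0))
# if the entries above are all ~0, the gradients are vanishing, which would explain the flat loss
```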
Replies: 1 comment 1 reply
You don't need to apply the softmax in your model, because `optax.softmax_cross_entropy` does this for you. I'm also a bit unsure about the `one_hot_labels` computation; you might want to double-check that it is producing something sensible.
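For concreteness, here is a minimal sketch of what those two suggestions could look like, reusing the names from the question (an illustration, not the original code):

```python
# Return raw logits from the model; optax.softmax_cross_entropy applies the
# (log-)softmax internally, so there is no nn.softmax at the end.
class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = jnp.ravel(x)               # flatten the 28x28 input
    x = nn.Dense(features=784)(x)
    x = nn.Dense(features=2)(x)
    return x                       # raw logits, no softmax here

# Quick sanity check of the one-hot encoding for an integer label:
print(jax.nn.one_hot(jnp.array([1]), num_classes=2))  # should encode class 1 as [[0., 1.]]
```

With raw logits, the loss no longer receives an already-softmaxed (and, as the `[1.0, ~1.98e-30]` prediction shown in the question suggests, saturated) output, which is a likely reason the loss was stuck.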