How to skip integer params from Flax Model? #2621
Hey @amankhandelia

**Params Issue**

For the integer `relative_position_index` you have a few options:

**Option 1**

Since `relative_position_index` is a fixed, non-learnable array, you can accept it as a field of the Module and pass it in from the outside:

```python
class FlaxDonutSwinSelfAttention(nn.Module):
    config: DonutSwinConfig
    dim: int
    num_attention_heads: int
    relative_position_index: jnp.ndarray  # accept it as an input
    dtype: jnp.dtype = jnp.float32
```

This way it's not in your `params` at all.

**Option 2**

Use `optax.multi_transform` to prevent the optimizer from trying to modify those parameters. See the "Freezing Layers" section from our Transfer Learning guide as an example.
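For illustration, a minimal sketch of what that could look like, assuming `params` is a plain nested dict and the integer leaves sit under a `relative_position_index` key (the label function and learning rate below are illustrative, not from this thread):

```python
import optax
from flax import traverse_util

def label_params(params):
    # Label every leaf "frozen" if its path contains the (assumed) key
    # "relative_position_index", otherwise "trainable".
    flat = traverse_util.flatten_dict(params)
    labels = {
        path: "frozen" if "relative_position_index" in path else "trainable"
        for path in flat
    }
    return traverse_util.unflatten_dict(labels)

# The frozen partition gets optax.set_to_zero(), so its updates are always 0.
tx = optax.multi_transform(
    {"trainable": optax.adamw(1e-4), "frozen": optax.set_to_zero()},
    label_params,
)
```

If your `params` is a `FrozenDict`, unfreeze it first (or build the labels with the same container type) so the label tree matches the param tree structure.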
**Option 3**

You can create a separate collection (e.g. `"index"`) for it with `self.variable`:

```python
# we are piggybacking the `params` rng here for simplicity
index_key = self.make_rng('params') if self.has_rng('params') else None
self.relative_position_index = self.variable(
"index", "relative_position_index",
relative_position_index_init, index_key, self.window_size, self.dtype
).value
```

Just beware that you have to pass this collection around now, same as with the `params` collection (see the sketch at the end of this reply).

**Updating issue**

I am not familiar enough with HuggingFace's API, so it's hard to comment on whether re-running the initialization after resizing the token embeddings is part of the problem.
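Regarding passing the extra collection around, here is a minimal sketch of what the call sites for Option 3 could look like (names such as `module`, `dummy_inputs`, and `inputs` are placeholders):

```python
import jax

# init returns every collection the module created, not just "params"
variables = module.init(jax.random.PRNGKey(0), dummy_inputs)
params = variables["params"]
index_vars = variables["index"]  # the extra, non-trainable collection

# apply needs the extra collection alongside "params" on every forward pass
outputs = module.apply({"params": params, "index": index_vars}, inputs)
```

Only `params` would then be handed to the optimizer and to `jax.value_and_grad`.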
Hello there,
I have been working on porting the Donut model from torch to Flax. In my tests, inference is fully functional and gives the same results as the torch model. In the same vein, I am finetuning the Flax model to check whether it reaches the same performance as the torch model. For that I am following this notebook by @NielsRogge.
This is where I am getting errors. To reproduce the error, please refer to the colab notebook.
On this line we are defining `relative_position_index`, which is of type `int32` and is not a learnable parameter. In order to make `jax.value_and_grad` work with the int type I am passing `allow_int=True` (roughly as in the sketch at the end of this post). I am guessing that not skipping the gradient calculation on this integer variable is the source of the problem (although I could be wrong). So my primary question is: how do I skip the gradient calculation for such parameters? And if that is not the cause, what should I do to correct the error? Any help is deeply appreciated.
I have done one more hacky thing, which could also be the source of my pains, although I am not so certain of that. As you will see in the notebook, we have to expand the number of tokens in the pretrained model to include special tokens for the document classes. In order to do the same (the exact code is in the notebook), after running that code I am running `model.init` again through `model.init_weights`. Could this be the problem? My understanding of Jax/Flax is not good enough to tell whether it is, so if somebody else sees a problem please do comment.
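For reference, a stripped-down sketch of the gradient call I am describing (the forward pass and loss helper here are placeholders, not the exact notebook code):

```python
import jax

def loss_fn(params):
    # placeholder forward pass and loss; the real code lives in the notebook
    logits = model(**batch, params=params, train=True).logits
    return cross_entropy(logits, batch["labels"])  # hypothetical helper

# allow_int=True makes value_and_grad accept the int32
# relative_position_index leaves inside `params`; their "gradients" come
# back as zero-sized float0 values instead of real gradients.
loss, grads = jax.value_and_grad(loss_fn, allow_int=True)(state.params)
```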
cc: @cgarciae