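Using:

```python
import optax
from flax import nnx

# learning_rate and momentum are assumed to be defined earlier.
tx = optax.multi_transform(
    {
        # The second positional argument of optax.adamw is b1.
        "weights": optax.adamw(learning_rate, momentum),
        "biases": optax.adamw(learning_rate, momentum),
    },
    # this doesn't work:
    # {
    #     "weights": "weights",
    #     "biases": "biases",
    # },
    # this does, but is it safe?
    nnx.State({
        "weights": "weights",
        "biases": "biases",
    }),
)
```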
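A likely reason the plain `dict` fails, assuming the parameters handed to the optimizer are an `nnx.State`: `optax.multi_transform` expects `param_labels` to be a pytree with the same structure as the parameters (or a prefix of it, or a callable that returns one), and JAX counts the container type as part of that structure. A plain `dict` therefore does not match an `nnx.State` even when the keys are identical. The pure-Python toy below only models that matching rule; `ToyState`, `treedef` and `toy_tree_map` are hypothetical helpers, not JAX, Optax, or Flax APIs.

```python
# Toy model of pytree structure matching (NOT JAX's implementation).
# Assumption: as in jax.tree_util, the container *type* is part of the
# structure, so a plain dict and a dict subclass with the same keys differ.


class ToyState(dict):
    """Hypothetical stand-in for a custom pytree container such as nnx.State."""


def treedef(tree):
    """Return a comparable description of the tree's structure (types + keys)."""
    if isinstance(tree, dict):
        return (type(tree), tuple((k, treedef(tree[k])) for k in sorted(tree)))
    return "leaf"


def toy_tree_map(fn, tree, other):
    """Apply fn leaf-wise over two trees, requiring identical structure."""
    if treedef(tree) != treedef(other):
        raise ValueError("tree structures do not match")
    if isinstance(tree, dict):
        # Rebuild the result with the same container type as `tree`.
        return type(tree)({k: toy_tree_map(fn, tree[k], other[k]) for k in tree})
    return fn(tree, other)


params = ToyState({"weights": 1.0, "biases": 0.5})

# Labels wrapped in the same container type: structures match.
labels_ok = ToyState({"weights": "weights", "biases": "biases"})
print(toy_tree_map(lambda p, l: (l, p), params, labels_ok))

# Plain dict labels: same keys, different container type, so matching fails.
labels_bad = {"weights": "weights", "biases": "biases"}
try:
    toy_tree_map(lambda p, l: (l, p), params, labels_bad)
except ValueError as err:
    print("error:", err)
```

This is also why wrapping the labels in `nnx.State` makes the call go through: the label tree then has the same container type as the parameter tree.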
Closing in favor of #3955.
Interesting. You are right that `State` is currently restrictive regarding the value type. I think it can be safely generalized to be generic over the value type, e.g. `State[V]`, and then have `V` be `StateLeaf` for the current `graph` APIs, while users can just change the type for these kinds of situations. To answer your question, I think it's safe to do what you are showing. I'm going to turn this into an issue.
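The generalization suggested above could look roughly like the following sketch. `GenericState`, `StateLeaf` and `ParamState` are hypothetical stand-ins, not Flax's actual classes; the sketch only shows how one container class parameterized by `V` can serve both the current leaf type and user-chosen values such as string labels.

```python
from typing import Dict, Generic, Iterator, Mapping, TypeVar

V = TypeVar("V")


class StateLeaf:
    """Hypothetical placeholder for the leaf type the current APIs use."""

    def __init__(self, value):
        self.value = value


class GenericState(Mapping[str, V], Generic[V]):
    """Hypothetical read-only state mapping, generic over its value type V."""

    def __init__(self, mapping: Mapping[str, V]):
        self._mapping: Dict[str, V] = dict(mapping)

    def __getitem__(self, key: str) -> V:
        return self._mapping[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._mapping)

    def __len__(self) -> int:
        return len(self._mapping)


# Existing graph-style APIs would keep the current leaf type via an alias...
ParamState = GenericState[StateLeaf]

# ...while users pick another value type, e.g. string labels for optax.
labels: GenericState[str] = GenericState({"weights": "weights", "biases": "biases"})
params: ParamState = GenericState({"weights": StateLeaf(1.0), "biases": StateLeaf(0.5)})
```

An alias is used here instead of a default for `V` because `TypeVar` defaults (PEP 696) require Python 3.13 or `typing_extensions`.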