JITting a model application results in a mutated model #391
Replies: 8 comments
-
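This is "working as intended". In the pure python case you are just passing references around, but when you pass a "jax-serializable" python object like a nn.Model into a jax-jitted function and return it, you're actually serializing the object into its corresponding pure data pytree, doing whatever numerical jaxy thing your function does to that data, then reconstituting the object from scratch through a defined deserialization method that ingests the serialized data pytree.
So it is true that you're not overwriting the original model by passing it in, but in your example code above you are in fact overwriting model - if you use different variable names for the return value you should see that the original isn't touched.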
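The round trip described above can be sketched with plain JAX. This is a minimal illustration rather than the Flax implementation: `Params` is a hypothetical stand-in for a model, registered as a pytree with `jax.tree_util.register_pytree_node_class`, and `apply` is an invented function name.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Params:
    """Hypothetical stand-in for a model: a thin wrapper around one array."""

    def __init__(self, w):
        self.w = w

    def tree_flatten(self):
        # Children are the array leaves; there is no static metadata here.
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Called by JAX to rebuild the wrapper from its leaves.
        return cls(*children)


@jax.jit
def apply(params, x):
    # Return the "model" alongside its output, as in the original report.
    return params, params.w * x


original = Params(jnp.array([1.0, 2.0]))
returned, out = apply(original, 3.0)

print(returned is original)                      # False: a new wrapper was rebuilt
print(bool(jnp.all(returned.w == original.w)))   # True: the numbers are the same
```

Binding the result to a different name (`returned` rather than `original`) makes it visible that the object passed in is left alone and only the returned wrapper is new.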
-
Yes certainly, I’m overwriting it - sorry for the poor choice of words. However, it seems that serializing and reconstituting a Flax model should return precisely the same model, yes?
…On Aug 5, 2020, 5:28 PM -0400, Anselm Levskaya ***@***.***>, wrote:
This is "working as intended".
In the pure python case you are just passing references around, but when you pass a "jax-serializable" python object like a nn.Model into a jax-jitted function and return it, you're actually serializing the object into its corresponding pure data pytree, doing whatever numerical jaxy thing your function does to that data, then reconstituting the object from scratch through a defined deserialization method that ingests the serialized data pytree.
so it is true that you're not overwriting the original model by passing it in, but in your example code above you are in fact overwriting model - if you use different variable names for the return value you should see that the original isn't touched.
-
For context this resulted in a bug in my code that took almost eight hours of work to isolate.
-
Ah OK, now I understand what you're pointing out! The problem is in how == is evaluated on the dataclass fields. To illustrate:

    import flax.struct
    import jax
    import jax.numpy as jnp

    class Test:
        def __init__(self, x):
            self.x = x
        def __eq__(self, other):
            print('eq called')
            return self.x == other.x

    @flax.struct.dataclass
    class Foo:
        bar: Test

    a = Foo(Test(0))
    b = Foo(Test(0))
    print(a == a)  # True
    print(b == b)  # True
    print(a == b)  # eq called \n True

And if we do this with arrays we reproduce the phenomenon you see, e.g.:

    a = Foo(Test(jnp.array([0, 0])))
    b = Foo(Test(jnp.array([0, 0])))
    print(a == a)  # True
    print(b == b)  # True
    print(a == b)  # eq called \n ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

So if you want to test for the mathematical equality of a model rather than just the model object identity (and model is really just a holder of params), then you want to test all leaves in the tree with an array comparison, e.g.:

    def tree_eq(a, b):
        # Compare each pair of leaves elementwise, then AND the results together.
        comparisons = jax.tree_multimap(lambda x, y: jnp.all(x == y), a, b)
        return jax.tree_util.tree_reduce(lambda x, y: x and y, comparisons, True)

    tree_eq(model, initial_model)  # True

I'm sorry this subtlety cost you so much time!
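For readers on recent JAX releases, where `jax.tree_multimap` is no longer available, the same leaf-by-leaf comparison can be sketched with `jax.tree_util.tree_map`, which accepts several trees. This is one possible rewrite under that assumption, not the code from the thread; the structure check and the example dictionaries `p`, `q`, and `r` are additions for illustration.

```python
import jax
import jax.numpy as jnp


def tree_eq(a, b):
    """Return True if two pytrees have the same structure and equal leaves."""
    # Trees with different structure cannot be equal (and tree_map would raise).
    if jax.tree_util.tree_structure(a) != jax.tree_util.tree_structure(b):
        return False
    # array_equal also handles leaves whose shapes differ (it returns False).
    leaves_equal = jax.tree_util.tree_map(
        lambda x, y: bool(jnp.array_equal(x, y)), a, b
    )
    return jax.tree_util.tree_all(leaves_equal)


# Hypothetical parameter trees standing in for a model's params.
p = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
q = jax.tree_util.tree_map(lambda x: x + 0, p)   # new arrays, same values
r = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.ones(2)}}

print(tree_eq(p, q))  # True
print(tree_eq(p, r))  # False
```

Comparing values this way sidesteps object identity entirely, so it gives the same answer whether or not the model has been through a jitted function.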
-
Ah I see, that is... fiendishly subtle. Thanks for clarifying what's going on! The circumstance where I ran into this was pretty niche, and probably my fault, but it was breaking memoization for another jitted JAX function when the model was an element of a
-
Why do you want to mark the dataclass field for the Model as a leaf node? If I comment out your annotation that marks the model as a leaf (and spiritually it's not really a leaf since it's just a wrapper for a nested param dict) everything seems to work?

    from typing import Any
    from flax import struct

    @struct.dataclass
    class ModelContainer():
        model: Any  #= struct.field(pytree_node=False)
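The difference between hiding a model from JAX and exposing it can be sketched without Flax. The two container classes below are hypothetical and use `jax.tree_util.register_pytree_node_class` directly; the assumption is that a field marked `pytree_node=False` behaves like the static metadata of `OpaqueContainer`, while an unmarked field behaves like the child of `TransparentContainer`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class OpaqueContainer:
    """Hypothetical container that hides its model from JAX as static metadata."""

    def __init__(self, model):
        self.model = model

    def tree_flatten(self):
        return (), self.model            # no children; model travels as aux data

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux)


@register_pytree_node_class
class TransparentContainer:
    """Hypothetical container that exposes its model as a pytree child."""

    def __init__(self, model):
        self.model = model

    def tree_flatten(self):
        return (self.model,), None       # model is traversed like any other pytree

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


params = {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}

opaque_leaves = jax.tree_util.tree_leaves(OpaqueContainer(params))
transparent_leaves = jax.tree_util.tree_leaves(TransparentContainer(params))
print(len(opaque_leaves))        # 0: JAX sees no arrays inside
print(len(transparent_leaves))   # 2: kernel and bias are visible leaves

doubled = jax.tree_util.tree_map(lambda x: x * 2, TransparentContainer(params))
untouched = jax.tree_util.tree_map(lambda x: x * 2, OpaqueContainer(params))
```

Static metadata is compared and hashed when JAX decides whether it can reuse a compiled function, so metadata that is unhashable, or that does not compare equal to an earlier copy, is a plausible way for memoization to break.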
-
Yes, that does fix it. This is a simplified case from a library that I'm writing, where the user provides a state (in this case it's a Flax model, but it could be anything) and functions for using and updating that state. I wanted to mark it as a leaf because in general it might contain arbitrary structures like a PyTorch model. I'm not sure what the best way to deal with data like this is in JAX/Flax. The solution I'm at right now is to maintain a clear separation between known, all-JAX data structures and the functions that deal with them, and user-provided data structures / functions. Then it's easy to reason about what code can be transformed. Are there other tools or models for managing arbitrary data like this?
-
Hey Will, sorry for the slow reply - in the general case it's hard to give a complete answer to this. Inevitably there are parts of code that fall within the JAX numerical domain, some that can be interfaced with care to external systems (e.g. sharing data with numpy, pytorch via dlpack, etc.), and generally crazy python stuff that JAX can't touch -- it depends on the situation and goals. The general idea of sticking untouchable python things in leaves is probably ok, but in the case of an actual flax model/parameter-tree you don't want to handle it opaquely since you probably really do want to serialize/deserialize the data, etc.
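As a small illustration of the "interfaced with care" category, here is a sketch of moving data between JAX and NumPy. It covers only the NumPy side; the dlpack route to PyTorch mentioned above is not shown, and the variable names are made up for the example.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)          # data living in the JAX numerical domain
x_np = np.asarray(x)         # JAX -> NumPy, for code that JAX cannot transform
y = jnp.asarray(x_np * 2)    # NumPy -> JAX, back into transformable territory

print(type(x_np).__name__)   # ndarray
print(y.tolist())            # [0.0, 2.0, 4.0, 6.0]
```

Keeping these conversions at the boundary between user-provided code and all-JAX code matches the separation described earlier in the thread.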
-
Problem you have encountered:
I have a function which evaluates a model at a point, then returns both the model and its output. If I jit this function, the returned model is subtly different from the original model. In particular, the returned model is no longer comparable with ==.
What you expected to happen:
The model should not be mutated at all.
Logs, error messages, etc:
The specific error when comparing the returned model is:
Steps to reproduce:
Colab notebook reproducing this issue: https://colab.research.google.com/drive/1kSrbIMti78sqEwE9t3rMz5nLvDewSrwW