parametric conflicts with certain usages of customized __init_subclass__ #105

Closed
femtomc opened this issue Sep 18, 2023 · 13 comments · Fixed by #108

@femtomc

femtomc commented Sep 18, 2023

Hi!

I'm using JAX and also plum. In my library, I've defined a mixin class called Pytree which automatically implements the Pytree interface for classes that mix it in.

It's quite simple:

import jax.tree_util as jtu

class Pytree:
    def __init_subclass__(cls, **kwargs):
        jtu.register_pytree_node(
            cls,
            cls.flatten,
            cls.unflatten,
        )

If I use this mixin together with parametric, I run into problems: I get a duplicate registration error:

ERROR ... ValueError: Duplicate custom PyTreeDef type registration for <class...>

I'm not sure why this happens, but I'm hoping to find a fix, because I'd like to use parametric classes to guide some of the dispatch in my library functions.

@PhilipVinc
Collaborator

PhilipVinc commented Sep 18, 2023

Can you provide a runnable MWE?

@femtomc
Author

femtomc commented Sep 18, 2023

Sure! One moment.

@femtomc
Author

femtomc commented Sep 18, 2023

import abc
import jax.tree_util as jtu
from plum import parametric

class Pytree:
    def __init_subclass__(cls, **kwargs):
        jtu.register_pytree_node(
            cls,
            cls.flatten,
            cls.unflatten,
        )

    @abc.abstractmethod
    def flatten(self):
        raise NotImplementedError

    @classmethod
    def unflatten(cls, data, xs):
        return cls(*data, *xs)
    

@parametric
class Wrapper(Pytree):
    def flatten(self):
        return (), ()

Wrapper[int]  # raises the duplicate-registration ValueError

Even if there's no convenient idiom using parametric, I'm wondering whether I can define the Pytree behavior once (i.e., not repeatedly on each subclass) while still taking advantage of the typing.

@PhilipVinc
Collaborator

I suspect the issue lies in this line, which runs after the concrete class Wrapper[int] (not Wrapper) is created: it calls the __init_subclass__ you implemented, but with cls bound to Wrapper instead of Wrapper[int].

There should be a way to ensure that the proper class is passed there...

@femtomc
Author

femtomc commented Sep 18, 2023

Right, this is a bit of an unusual setup, in the sense that __init_subclass__ here has a global side effect (it registers the class with JAX's global registry). I suspect this doesn't come up very often in practice, but it happens to be a convenient way to do the registration.

I wonder if I could do something on my end to get around it.

@PhilipVinc
Collaborator

Maybe just changing that line in plum to super(original_class, cls).__init_subclass__(**kw_args) will fix it. Let me try a bit more...
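To see why the explicit call misbehaves and why the `super(original_class, cls)` form fixes it, here is a plain-Python model of the pattern. It is a simplified, hypothetical stand-in, not plum's actual code: `ParametricBuggy` and `ParametricFixed` play the role of the class that `@parametric` generates, and the list `registry` stands in for JAX's global pytree registry.

```python
registry = []  # stand-in for JAX's global pytree registry


class Pytree:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls in registry:
            raise ValueError(f"duplicate registration for {cls.__name__}")
        registry.append(cls)


class Wrapper(Pytree):  # registers Wrapper
    pass


class ParametricBuggy(Wrapper):  # registers ParametricBuggy
    @classmethod
    def __init_subclass__(cls, **kw_args):
        # Looking the hook up on Wrapper binds it to Wrapper, so Pytree's
        # hook sees cls=Wrapper instead of the new concrete subclass.
        Wrapper.__init_subclass__(**kw_args)


class ParametricFixed(Wrapper):  # registers ParametricFixed
    @classmethod
    def __init_subclass__(cls, **kw_args):
        # super(Wrapper, cls) continues the MRO lookup after Wrapper but keeps
        # the hook bound to cls, the new concrete subclass.
        super(Wrapper, cls).__init_subclass__(**kw_args)


try:
    class BuggyInt(ParametricBuggy):  # tries to register Wrapper a second time
        pass
except ValueError as err:
    print(err)  # duplicate registration for Wrapper


class FixedInt(ParametricFixed):  # registers FixedInt
    pass
```

Inside a class body, the zero-argument `super()` would work too. The explicit two-argument form is presumably needed when the hook is written as a plain function outside any class body, which is how a decorator would typically build it.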

@femtomc
Author

femtomc commented Sep 18, 2023

Thanks for any help!

@PhilipVinc
Collaborator

Yes, this fixes it. I'm preparing a PR.
By the way, for future reference, here is a smaller repro without JAX:

    from plum import parametric

    register = set()

    class Pytree:
        def __init_subclass__(cls, **kwargs):
            if cls in register:
                raise ValueError("duplicate")
            else:
                register.add(cls)

    @parametric
    class Wrapper(Pytree):
        pass

    Wrapper[int]  # before the fix, raises ValueError("duplicate")

@PhilipVinc
Collaborator

By the way, could you share (a gist?) your Pytree code that uses __init_subclass__? I have long considered moving away from flax dataclasses, but this exact bug discouraged me from going forward a few times. Now that it's fixed, I might reconsider, and having a starting point would speed me up a bit.

@femtomc
Author

femtomc commented Sep 18, 2023

Sure, one moment.

@femtomc
Author

femtomc commented Sep 18, 2023

Or rather, what would you want beyond the MWE I posted above?

@PhilipVinc
Collaborator

Ah OK, so that's what you do?
(I hate defining flatten/unflatten by hand. I like things like https://github.com/cgarciae/simple-pytree/blob/main/simple_pytree/pytree.py that automate it, and I thought you were doing the same, so I was curious to see an alternative implementation.)

@femtomc
Copy link
Author

femtomc commented Sep 18, 2023

Ah right, I define unflatten automatically, but flatten is written by hand for each class.

Nothing fancy with introspection: just a custom flatten and an assumption about field order for unflatten.
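The convention described here (a hand-written `flatten` plus an `unflatten` that relies on constructor argument order) can be sketched in plain Python. This is an illustrative assumption, not the author's library code: the `Affine` class and its fields are hypothetical, and the `jax.tree_util.register_pytree_node` call is left out so the round trip runs without JAX. JAX expects the flatten function to return `(children, aux_data)` and calls the unflatten function as `unflatten(aux_data, children)`, which is why `unflatten` unpacks `data` before `xs`.

```python
import dataclasses


class Pytree:
    # In the real mixin, __init_subclass__ would also call
    # jax.tree_util.register_pytree_node(cls, cls.flatten, cls.unflatten).

    def flatten(self):
        raise NotImplementedError

    @classmethod
    def unflatten(cls, data, xs):
        # Field-order assumption: the constructor takes the static (aux) data
        # first, followed by the children.
        return cls(*data, *xs)


@dataclasses.dataclass
class Affine(Pytree):
    name: str      # static metadata (aux_data)
    weight: float  # child (leaf)
    bias: float    # child (leaf)

    def flatten(self):
        # Returns (children, aux_data), the order JAX expects.
        return (self.weight, self.bias), (self.name,)


layer = Affine("dense", 2.0, 0.5)
children, aux = layer.flatten()
rebuilt = Affine.unflatten(aux, children)
print(rebuilt == layer)  # True
```

The trade-off is that field order in each class definition becomes part of the contract: reordering fields silently changes what `unflatten` rebuilds, which is what introspection-based approaches like simple-pytree avoid.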
