Multiple inheritance: class is not recognized as a Module and raises `ValueError: parent must be None, Module or Scope` #1390
-
I'm working on a Flax implementation of ProteinBERT: A universal deep-learning model of protein sequence and function. My work so far is in SauravMaheshkar/ProteinBERT. I've made a simple test:

```python
from proteinbert import ProteinBERT
import jax
from jax import random


def test():
    seq = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=21, shape=(2, 2048)
    )
    # maxval is exclusive, so this produces an all-zero annotation array
    annotation = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=1, shape=(2, 8943)
    )
    init_rngs = {"params": random.PRNGKey(0), "layers": random.PRNGKey(1)}
    ProteinBERT().init(init_rngs, seq, annotation)


if __name__ == "__main__":
    test()
```

And I've been getting this error: `ValueError: parent must be None, Module or Scope`.
The problem lies in the `Reduce` class:

```python
import flax.linen as nn
from einops.layers import ReduceMixin


class Reduce(ReduceMixin, nn.Module):
    """
    Flax Module to act as a Reduce layer (from einops)
    """

    def __call__(self, input):
        return self._apply_recipe(input)
```

The idea is to create a …

Any help would be much appreciated 😊.
-
@marcvanzee will be looking into this.
-
Hi @SauravMaheshkar, Flax Modules are special kinds of dataclasses: they attach `parent` and `name` attributes automatically. The `parent` attribute is usually set to the parent Module, or to `None` if this is the top-level Module. However, if you provide a Module with a string argument, the `parent` attribute will be set to that string. Here is a minimal example:

```python
import flax.linen as nn


class Bar(nn.Module):
    @nn.compact
    def __call__(self, x):
        pass
```

In your case, in `ProteinBERT.setup()`, you initialize `Reduce` as follows: `Reduce("b n d -> b d", "mean")`. You intend to call the constructor of…

(Note that I think this is something we should try to improve from the Flax side as well, but in the meantime an easy solution is to make …)

```python
import flax.linen as nn
import jax
from jax import random
from einops.layers import ReduceMixin


class Reduce(nn.Module):
    pattern: str
    reduction: str

    def setup(self):
        # Hold a ReduceMixin instead of inheriting from it, so Flax's
        # dataclass-generated __init__ no longer shadows ReduceMixin.__init__.
        self.reducer = ReduceMixin(self.pattern, self.reduction)

    def __call__(self, input):
        return self.reducer._apply_recipe(input)


class Foo(nn.Module):
    def setup(self):
        # Keyword arguments map directly onto the `pattern` and `reduction` fields.
        self.reducer = Reduce(pattern="b n d -> b d", reduction="mean")

    def __call__(self, x):
        ...


seq = jax.random.randint(key=random.PRNGKey(0), minval=0, maxval=21, shape=(2, 2048))
annotation = jax.random.randint(key=random.PRNGKey(0), minval=0, maxval=1, shape=(2, 8943))
init_rngs = {"params": random.PRNGKey(0), "layers": random.PRNGKey(1)}
Foo().init(init_rngs, seq)
```

Please let me know if this works for you!
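The mechanism described above can be reproduced without Flax or einops. The sketch below is a minimal pure-Python illustration, not Flax's actual implementation: `ModuleSketch` and `ReduceMixinSketch` are hypothetical stand-ins, assuming `nn.Module` behaves like a dataclass whose only fields are `parent` and `name` (as in the Flax version discussed here; newer releases may behave differently, for example by making those fields keyword-only). Because `Reduce` is itself turned into a dataclass, its generated `__init__` is found first in the method resolution order and shadows the mixin's `__init__`, so the two strings are bound to `parent` and `name`:

```python
import dataclasses
from typing import Any, Optional


class ReduceMixinSketch:
    """Hypothetical stand-in for einops' ReduceMixin: a plain class whose
    __init__ stores the pattern and reduction."""

    def __init__(self, pattern: str, reduction: str):
        self.pattern = pattern
        self.reduction = reduction


@dataclasses.dataclass
class ModuleSketch:
    """Hypothetical stand-in for flax.linen.Module: a dataclass whose only
    fields are `parent` and `name`."""

    parent: Any = None
    name: Optional[str] = None

    def __post_init__(self):
        # Simplified version of the check that raises the error in the title.
        if self.parent is not None and not isinstance(self.parent, ModuleSketch):
            raise ValueError("parent must be None, Module or Scope")


# Flax applies the dataclass transform to every Module subclass automatically;
# here it is applied explicitly. The generated __init__(self, parent=None,
# name=None) is defined on Reduce itself, so it wins over
# ReduceMixinSketch.__init__ in the method resolution order.
@dataclasses.dataclass
class Reduce(ReduceMixinSketch, ModuleSketch):
    pass


r = Reduce()
print(hasattr(r, "pattern"))  # False: ReduceMixinSketch.__init__ never ran

try:
    # The two strings are bound to `parent` and `name`, not pattern/reduction.
    Reduce("b n d -> b d", "mean")
except ValueError as err:
    print(err)  # parent must be None, Module or Scope
```

This is why composition (storing a `ReduceMixin` inside a Module that declares its own `pattern` and `reduction` fields) sidesteps the problem: the constructor arguments then land in real dataclass fields rather than in `parent` and `name`.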