Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Module Slicing 2 #169

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
6 changes: 3 additions & 3 deletions docs/api/Module.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/AvgPool.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/BatchNormalization.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Conv1D.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Conv2D.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Conv3D.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/ConvND.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Dropout.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/EMAParamsTree.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Embedding.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Flatten.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/InstanceNormalization.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/LayerNormalization.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Linear.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/MaxPool.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/MultiHeadAttention.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Reshape.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Sequential.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/Transformer.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/TransformerDecoder.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/TransformerDecoderLayer.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/TransformerEncoder.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

6 changes: 3 additions & 3 deletions docs/api/nn/TransformerEncoderLayer.md
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@
- __init__
- call
- add_parameter
- get_parameters
- set_parameters
- reset
- get_default_parameters
- set_default_parameters
- init
- slice

2 changes: 2 additions & 0 deletions elegy/__init__.py
Original file line number Diff line number Diff line change
@@ -45,6 +45,8 @@
except types.DependencyUnavailable as e:
pass

from .slicing import slice_module


__all__ = [
"GeneralizedModule",
17 changes: 17 additions & 0 deletions elegy/hooks.py
Original file line number Diff line number Diff line change
@@ -30,19 +30,22 @@ class HooksContext(types.Protocol):
losses: tp.Optional[types.Logs]
metrics: tp.Optional[types.Logs]
summaries: tp.Optional[types.Summaries]
named_call: tp.Optional[bool]


@dataclass
class _HooksContext(threading.local):
losses: tp.Optional[types.Logs]
metrics: tp.Optional[types.Logs]
summaries: tp.Optional[types.Summaries]
named_call: tp.Optional[bool]


LOCAL: HooksContext = _HooksContext(
losses=None,
metrics=None,
summaries=None,
named_call=None,
)


@@ -164,6 +167,10 @@ def summaries_active() -> bool:
return LOCAL.summaries is not None


def named_call_active() -> bool:
return bool(LOCAL.named_call)


# ----------------------------------------------------------------
# contexts
# ----------------------------------------------------------------
@@ -174,6 +181,7 @@ def context(
losses: tp.Union[types.Logs, bool, None] = None,
metrics: tp.Union[types.Logs, bool, None] = None,
summaries: tp.Union[types.Summaries, bool, None] = None,
named_call: tp.Union[str, None] = None,
set_all: bool = False,
) -> tp.ContextManager[None]:

@@ -184,6 +192,8 @@ def context(
metrics = True
if summaries is None:
summaries = True
if named_call is None:
named_call = False

if isinstance(losses, bool):
losses = {} if losses else None
@@ -194,10 +204,13 @@ def context(
if isinstance(summaries, bool):
summaries = [] if summaries else None

named_call = bool(named_call)

return _context(
losses=losses,
metrics=metrics,
summaries=summaries,
named_call=named_call,
)


@@ -206,22 +219,26 @@ def _context(
losses: tp.Optional[types.Logs],
metrics: tp.Optional[types.Logs],
summaries: tp.Optional[types.Summaries],
named_call: tp.Optional[str],
) -> tp.Iterator[None]:

prev_losses = LOCAL.losses
prev_metrics = LOCAL.metrics
prev_summaries = LOCAL.summaries
prev_named_call = LOCAL.named_call

LOCAL.losses = losses
LOCAL.metrics = metrics
LOCAL.summaries = summaries
LOCAL.named_call = named_call

try:
yield
finally:
LOCAL.losses = prev_losses
LOCAL.metrics = prev_metrics
LOCAL.summaries = prev_summaries
LOCAL.named_call = prev_named_call


# -------------------------------------------------------------
30 changes: 30 additions & 0 deletions elegy/hooks_test.py
Original file line number Diff line number Diff line change
@@ -66,3 +66,33 @@ def f(x):
assert losses["x_loss"] == 6
assert metrics["x"] == 7
assert summaries[0] == (("a", 0, "b"), jax.nn.relu, 8)

def test_named_call(self):
class Module0(elegy.Module):
def call(self, x):
x = elegy.nn.Linear(5)(x)
x = elegy.nn.Linear(7)(x)
return x

m = elegy.Model(Module0())
m.init(jnp.ones(4))

with elegy.hooks.context(named_call=True):
jaxpr = jax.make_jaxpr(
lambda x, states: m.pred_step(x, states, False, False)
)(jnp.ones([4]), m.states)
print(jaxpr)

assert jaxpr.jaxpr.eqns[0].params["name"] == ()
assert jaxpr.jaxpr.eqns[0].params["call_jaxpr"].eqns[0].params["name"] == (
"linear",
)
assert jaxpr.jaxpr.eqns[0].params["call_jaxpr"].eqns[1].params["name"] == (
"linear_1",
)

# no named call without hook
jaxpr = jax.make_jaxpr(lambda x, states: m.pred_step(x, states, False, False))(
jnp.ones([4]), m.states
)
assert jaxpr.jaxpr.eqns[0].primitive.name != "named_call"
59 changes: 55 additions & 4 deletions elegy/module.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,11 @@
from elegy import hooks, utils
from elegy import types


# placeholder for elegy.slicing
# injected from inside the module because of a circular dependency
slicing = None

__all__ = [
"Module",
"to_module",
@@ -154,10 +159,10 @@ class Module(metaclass=ModuleMeta):
"__init__",
"call",
"add_parameter",
"get_parameters",
"set_parameters",
"reset",
"get_default_parameters",
"set_default_parameters",
"init",
"slice",
]

def __init__(self, name: tp.Optional[str] = None, dtype: tp.Any = jnp.float32):
@@ -362,8 +367,11 @@ def __call__(self, *args, **kwargs) -> tp.Any:
# this marks initialization

with call_context(self):
method_fn = self.call
if hooks.named_call_active():
method_fn = jax.named_call(method_fn, name=get_module_path(self))

outputs = self.call(*args, **kwargs)
outputs = method_fn(*args, **kwargs)

if hooks.summaries_active():
path = get_module_path(self)
@@ -830,6 +838,49 @@ def _get_parameters(self, defaults: bool) -> tp.Dict[str, tp.Any]:
def has_parameter(self, name: str) -> bool:
return hasattr(self, name)

def slice(
self,
start: tp.Union[str, None],
end: tp.Union[str, None, tp.List[tp.Union[str, None]]],
sample_input: tp.Any,
) -> "Module":
"""
Creates a new submodule starting from the input of `start` to the outputs of `end`.
Current limitations:
- Only elegy.Module can be specified as `start` or `end`
- Only one `start` is supported
Note:
You might need to call `model.update_modules()` before slicing to make sure the parameters are transferred to the new submodule.
Example usage:
```
x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
submodule = resnet.slice(
start = '/inputs',
end = ["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7" ],
sample_input = x,
)
outputs = elegy.Model(submodule).predict(x, initialize=True)
assert outputs[0].shape == (2, 56, 56, 64)
assert outputs[1].shape == (2, 28, 28, 128)
assert outputs[2].shape == (2, 14, 14, 256)
assert outputs[3].shape == (2, 7, 7, 512)
```
Arguments:
start: Name of a child module which will be the input module of the resulting module.
If `None`, the first module is used.
end: Name of child module, `None` or a list thereof which will be the output module(s) of the resulting module.
If `None`, the last module is used.
sample_input: An example input to the model.
Returns: A new Module.
"""
return slicing.slice_module(self, start, end, sample_input)


# -------------------------------------------------------------
# hooks
322 changes: 322 additions & 0 deletions elegy/slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
import typing as tp
import functools

import numpy as np
import jax, jax.numpy as jnp
import elegy
from elegy import Module, Model
import elegy.hooks as hooks
import elegy.types as types

import sys
from . import module

# injecting this module into elegy.module because of a circular dependency
module.slicing = sys.modules[__name__]


def slice_module(
module: Module,
start: tp.Union[str, None],
end: tp.Union[str, None, tp.List[tp.Union[str, None]]],
sample_input: tp.Any,
) -> Module:

model = Model(module)
model.init(sample_input)

# create a jaxpr, marking all modules with named_call
with hooks.context(named_call=True), jax.disable_jit():
jaxpr = jax.make_jaxpr(model.pred_step, static_argnums=[2, 3])(
sample_input, model.states, False, False
)
jaxpr = jaxpr.jaxpr
jaxpr = replace_named_call_vars(jaxpr)

n_inputs = len(jax.tree_leaves(sample_input))
# actual module inputs
toplevel_input_vars = jaxpr.invars[:n_inputs]
# weights and biases
state_vars = jaxpr.invars[n_inputs:]

ends = end if isinstance(end, (list, tuple)) else [end]
input_vars, output_vars, unresolved_vars, eqn_sequence = analyze_jaxpr(
jaxpr, start, ends
)

stateless_unresolved_vars = [
v for v in unresolved_vars if v not in state_vars + input_vars
]
stateless_unresolved_vars = filter_literals(stateless_unresolved_vars)
stateless_input_vars = [v for v in input_vars if v not in state_vars]
# flatten
output_vars = [var for vars in output_vars for var in vars]
output_vars = [v for v in output_vars if v not in state_vars]

if len(stateless_unresolved_vars) or None in output_vars:
raise RuntimeError(f"No path from {start} to {end}")

return SlicedModule(model.module, eqn_sequence, stateless_input_vars, output_vars)


def filter_literals(
vars: tp.List[tp.Union[jax.core.Var, jax.core.Literal]]
) -> tp.List[jax.core.Var]:
"""Removes jax.core.Literal values from the input list"""
return [v for v in vars if not isinstance(v, jax.core.Literal)]


def replace_named_call_vars(
jaxpr: jax.core.Jaxpr,
env: tp.Dict[jax.core.Var, jax.core.Var] = None,
level: str = "",
) -> jax.core.Jaxpr:
"""Replaces vars of inner jaxprs with vars from outer jaxprs if they are equivalent."""
env = env or dict()
for inv in jaxpr.invars:
if inv not in env:
env[inv] = inv

top_eqns = []
invars = [
env[v] if not isinstance(v, jax.core.Literal) else v for v in jaxpr.invars
]

for eq in jaxpr.eqns:
eq_outvars = [env[v] if v in env else v for v in eq.outvars]
for outv in eq_outvars:
env[outv] = jax.core.Var(outv.count, f"_{level}", outv.aval)

inner_invars = [
env[v] if not isinstance(v, jax.core.Literal) else v for v in eq.invars
]
inner_outvars = [env[v] for v in eq.outvars]
eq_params = eq.params

if eq.primitive.name == "named_call":
inner_jaxpr = eq.params["call_jaxpr"]
inner_env = dict(
(
[
(inner_v, env[outer_v])
for inner_v, outer_v in zip(inner_jaxpr.invars, eq.invars)
]
+ [
(inner_v, env[outer_v])
for inner_v, outer_v in zip(inner_jaxpr.outvars, eq.outvars)
]
)
)
inner_jaxpr = replace_named_call_vars(
inner_jaxpr, inner_env, eq.params["name"]
)
eq_params = {"call_jaxpr": inner_jaxpr, "name": eq.params["name"]}

new_eqn = jax.core.JaxprEqn(
invars=inner_invars,
outvars=inner_outvars,
primitive=eq.primitive,
params=eq_params,
source_info=eq.source_info,
)
top_eqns += [new_eqn]

outvars = [env[v] for v in jaxpr.outvars]
new_jaxpr = jax.core.Jaxpr(
constvars=jaxpr.constvars, invars=invars, outvars=outvars, eqns=top_eqns
)
return new_jaxpr


def strict_startswith(s0: str, s1: str) -> bool:
return s0.startswith(s1) and not s0 == s1


def path_to_str(path: tp.Tuple[str]) -> str:
"""Converts a module path to a string e.g. ('A','B') -> "/A/B/" """
pathstr = "/" + "/".join(path)
if len(path) > 0:
pathstr += "/"
return pathstr


# constants, variations of the string "input"
INPUTS_STRINGS = ["input", "inputs"]
INPUTS_STRINGS += ["/" + s for s in INPUTS_STRINGS]
INPUTS_STRINGS += [s.upper() for s in INPUTS_STRINGS]


def normalize_module_path(module_path: tp.Union[str, None]) -> str:
if module_path in [None] + INPUTS_STRINGS:
return "/"
if not module_path.startswith("/"):
module_path = "/" + module_path
if not module_path.endswith("/"):
module_path = module_path + "/"
return module_path


def analyze_jaxpr(
jaxpr: jax.core.Jaxpr,
start_path: tp.Union[None, str],
end_paths: tp.List[tp.Union[None, str]],
unresolved: tp.Set[jax.core.Var] = None,
):
"""Analyze a jaxpr, collecting only the necessary equations from start_path to end_path"""
start_path = normalize_module_path(start_path)
normed_end_paths = [normalize_module_path(e) for e in end_paths]

output_vars = [[None]] * len(end_paths)
input_vars = []
unresolved = unresolved or set()
eqn_sequence = []

# iterate over equations in reverse order to make sure only necessary equations are collected
for eq in reversed(jaxpr.eqns):
if eq.primitive.name == "named_call":
# encountered an elegy.Module

eq_module_name = path_to_str(eq.params["name"])
# check if it's one of the targets in end_paths
for idx, end_path in enumerate(normed_end_paths):
if eq_module_name == end_path:
# target found
if end_paths[idx] in INPUTS_STRINGS:
# special case "/input" or similar
# collect the input of the whole module
output_vars[idx] = eq.invars
else:
# otherwise collect the output of the module
output_vars[idx] = eq.outvars
unresolved = unresolved.union(eq.outvars)

if any(
[
strict_startswith(p, eq_module_name)
for p in normed_end_paths + [start_path]
]
):
# module is not a target in end_paths but is a parent module
# e.g. "/module_A" when "/module_A/module_B" is in end_paths
# go inside the module: analyze the inner jaxpr
inner_jaxpr = eq.params["call_jaxpr"]
(
inner_invars,
inner_outvars,
inner_unresolved,
inner_eqns,
) = analyze_jaxpr(inner_jaxpr, start_path, end_paths, unresolved)
input_vars += inner_invars
for idx, outvar in enumerate(inner_outvars):
if outvar != [None]:
output_vars[idx] = outvar
unresolved = inner_unresolved
eqn_sequence = inner_eqns + eqn_sequence

# check whether the equation is necessary
# i.e. if some of its outputs are required by previously collected equations
common_vars = unresolved.intersection(eq.outvars)
if len(common_vars):
# yes it is required
unresolved = unresolved.difference(eq.outvars)
unresolved = unresolved.union(filter_literals(eq.invars))
eqn_sequence = [eq] + eqn_sequence

# check if this equation is the specified start of slicing
if eq.primitive.name == "named_call" and eq_module_name == start_path:
input_vars = eq.invars
unresolved = unresolved.difference(filter_literals(input_vars))

return input_vars, output_vars, unresolved, eqn_sequence


class Environment(dict):
"""A dict that can deal with jax.core.Literal which is not hashable"""

def __getitem__(self, var):
if isinstance(var, jax.core.Literal):
return var.val
else:
return super().__getitem__(var)

def __setitem__(self, var, value):
if isinstance(var, jax.core.Literal):
pass
else:
super().__setitem__(var, value)

def __contains__(self, var):
if isinstance(var, jax.core.Literal):
return True
else:
return super().__contains__(var)


def get_module(parentmodule: Module, path: tp.Tuple[str]):
"""Returns a submodule from the parentmodule according to path"""
if path == ():
return parentmodule
else:
for n in path:
parentmodule = getattr(parentmodule, n)
return parentmodule


class SlicedModule(Module):
def __init__(
self,
mainmodule: Module,
equations: tp.List[jax.core.JaxprEqn],
input_vars: tp.List[jax.core.Var],
output_vars: tp.List[jax.core.Var],
):
super().__init__()
for eq in equations:
if eq.primitive.name == "named_call":
setattr(
self,
"_".join(eq.params["name"]),
get_module(mainmodule, eq.params["name"]),
)
self.equations = equations
self.input_vars = input_vars
self.output_vars = output_vars

def call(self, *args):
if len(args) != len(self.input_vars):
raise TypeError(
f"Expected {len(self.input_vars)} inputs, received {len(args)}"
)

environment: tp.Dict[jax.core.Var, tp.Any] = Environment()
for var, arg in zip(self.input_vars, args):
environment[var] = arg

# execute all equations one by one
for eq in self.equations:
eq_inputs = [environment[v] for v in eq.invars if v in environment]

# if the equation is a module, execute the module instead
if eq.primitive.name == "named_call":
# yes it is a module
module = getattr(self, "_".join(eq.params["name"]))
outputs = module(*eq_inputs)
else:
# not a module, simple execute the primitive
if isinstance(eq.primitive.impl, functools.partial):
outputs = eq.primitive.bind(*eq_inputs, **eq.params)
else:
outputs = eq.primitive.impl(*eq_inputs, **eq.params)

if isinstance(outputs, list):
outputs = tuple(outputs)
elif not isinstance(outputs, tuple):
outputs = (outputs,)

for o, v in zip(outputs, eq.outvars):
environment[v] = o

outputs = tuple(environment[v] for v in self.output_vars)
if len(outputs) == 1:
outputs = outputs[0]
return outputs
198 changes: 198 additions & 0 deletions elegy/slicing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import jax, jax.numpy as jnp
import numpy as np

import elegy, elegy.slicing
import optax
from unittest import TestCase


class BasicModuleSlicingTest(TestCase):
def setUp(self):
self.x = np.random.random((32, 100)).astype("float32")
self.module = BasicModule0()
self.module.init(rng=elegy.RNGSeq(0), set_defaults=True)(self.x)

def test_basic_slice_by_name0(self):
start, end = ("linear0", "linear1")
submodule = self.module.slice(start, end, self.x)
submodel = elegy.Model(submodule)
assert submodel.predict(self.x, initialize=True).shape == (32, 10)
assert jnp.all(submodel.predict(self.x) == self.module.test_call0(self.x))
# different batch size
assert submodel.predict(self.x[:8]).shape == (8, 10)

def test_basic_slice_by_name1(self):
start, end = (None, "linear1") # None means input
submodule = self.module.slice(start, end, self.x)
submodel = elegy.Model(submodule)
assert submodel.predict(self.x, initialize=True).shape == (32, 10)
assert jnp.allclose(submodel.predict(self.x), self.module.test_call1(self.x))

def test_slice_multi_output(self):
start, end = None, ["linear2", "linear1"]
submodule = self.module.slice(start, end, self.x)
submodel = elegy.Model(submodule)

outputs = submodel.predict(self.x, initialize=True)
true_outputs = self.module.test_call2(self.x)
assert len(outputs) == 2
assert outputs[0].shape == true_outputs[0].shape
assert outputs[1].shape == true_outputs[1].shape
assert jnp.allclose(outputs[0], true_outputs[0])
assert jnp.allclose(outputs[1], true_outputs[1])

def test_slice_return_input(self):
submodule = self.module.slice("input", ["/linear1", "input"], self.x)
submodel = elegy.Model(submodule)
submodel.summary(self.x)
ypred = submodel.predict(self.x, initialize=True)
assert jnp.all(ypred[1] == self.x)
assert ypred[0].shape == (32, 10)
assert jnp.allclose(ypred[0], self.module.test_call1(self.x))
assert "linear2" not in submodel.states["net_params"].keys()

def test_no_path(self):
for start_module in ["linear2", "linear1"]:
try:
submodule = self.module.slice(start_module, "linear0", self.x)
submodel = elegy.Model(submodule)
submodel.summary(self.x)
except RuntimeError as e:
assert e.args[0].startswith(f"No path from {start_module} to linear0")
else:
assert False, "No error or wrong error raised"

def test_retrain(self):
y = jnp.zeros((32, 10))

submodule = self.module.slice("linear0", "linear1", self.x)
submodel = elegy.Model(
submodule,
loss=elegy.losses.MeanAbsoluteError(),
optimizer=optax.sgd(0.05),
)
submodel.init(self.x, y)
y0 = submodel.predict(self.x)

submodel.fit(self.x, y, epochs=3, verbose=2)

y2 = submodel.predict(self.x)
# output after training should be closer to zero because targets are zero
assert jnp.abs(y2.mean()) < jnp.abs(y0.mean())


class ResNetSlicingTest(TestCase):
def test_multi_out(self):
x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
resnet.init(rng=elegy.RNGSeq(0), set_defaults=True)(x)

submodule = resnet.slice(
start=None,
end=[
"/res_net_block_1",
"/res_net_block_3",
"/res_net_block_5",
"/res_net_block_6",
"/res_net_block_7",
],
sample_input=x,
)
submodel = elegy.Model(submodule, run_eagerly=True)

# submodel.summary(x)
outputs = submodel.predict(x, initialize=True)
print(jax.tree_map(jnp.shape, outputs))
assert len(outputs) == 5
assert outputs[0].shape == (2, 56, 56, 64)
assert outputs[1].shape == (2, 28, 28, 128)
assert outputs[2].shape == (2, 14, 14, 256)
assert outputs[3].shape == (2, 7, 7, 512)
assert outputs[4].shape == (2, 7, 7, 512)


class NestedSlicingTest(TestCase):
def test_basic_nested(self):
self.x = np.random.random((32, 100)).astype("float32")
self.module = NestedModule0()
self.module.init(rng=elegy.RNGSeq(0), set_defaults=True)(self.x)

# self.model.summary(self.x)
submodule = self.module.slice("/module0/linear1", "/module1/linear1", self.x)
submodel = elegy.Model(submodule)

x_for_submodel = np.random.random([16, 25])
submodel.predict(x_for_submodel, initialize=True)
submodel.summary(x_for_submodel)

assert jnp.allclose(
submodel.predict(x_for_submodel), self.module.test_call0(x_for_submodel)
)
assert "module0_linear1" in submodel.states["net_params"].keys()
assert "module0_linear0" not in submodel.states["net_params"].keys()
assert "module1_linear2" not in submodel.states["net_params"].keys()


def test_no_default_parameters():
x = np.random.random((32, 100)).astype("float32")
module = BasicModule0()
model = elegy.Model(module, seed=np.random.randint(100, 100000))
model.init(x)
model.update_modules()

submodel = elegy.Model(model.module.slice("linear0", "linear1", x))
assert submodel.predict(x, initialize=True).shape == (32, 10)

assert jnp.allclose(submodel.predict(x), module.test_call0(x))


class BasicModule0(elegy.Module):
def call(self, x):
x = x / 255.0
x = elegy.nn.Linear(25, name="linear0")(x)
x = jax.nn.relu(x)
x = elegy.nn.Linear(10, name="linear1")(x)
x = jax.nn.relu(x)
x = elegy.nn.Linear(5, name="linear2")(x)
return x

def test_call0(self, x):
x = self.linear0.call_with_defaults()(x)
x = jax.nn.relu(x)
x = self.linear1.call_with_defaults()(x)
return x

def test_call1(self, x):
x = x / 255.0
x = self.linear0.call_with_defaults()(x)
x = jax.nn.relu(x)
x = self.linear1.call_with_defaults()(x)
return x

def test_call2(self, x):
x = x / 255.0
x = self.linear0.call_with_defaults()(x)
x = jax.nn.relu(x)
x = x0 = self.linear1.call_with_defaults()(x)
x = jax.nn.relu(x)
x = x1 = self.linear2.call_with_defaults()(x)
return x1, x0


class NestedModule0(elegy.Module):
def call(self, x):
x = BasicModule0(name="module0")(x)
x = x * 255
x = BasicModule0(name="module1")(x)
return x

def test_call0(self, x):
x = self.module0.linear1.call_with_defaults()(x)
x = jax.nn.relu(x)
x = self.module0.linear2.call_with_defaults()(x)
x = x * 255
x = x / 255.0
x = self.module1.linear0.call_with_defaults()(x)
x = jax.nn.relu(x)
x = self.module1.linear1.call_with_defaults()(x)
return x
19 changes: 0 additions & 19 deletions mkdocs.yml
Original file line number Diff line number Diff line change
@@ -9,32 +9,13 @@ nav:
- Low Level API: getting-started/low-level-api.ipynb
- High Level API:
- Intro: basic-api/modules-losses-metrics.md
# - Data Sources: na.md
# - Module:
# - Flax: na.md
# - Haiku: na.md
# - Elegy: na.md
# - Pure Jax: na.md
# - Supporting Other Modules: na.md
# - Losses: na.md
# - Metrics: na.md
# - Dependency Injection: na.md
# - Hooks: na.md
# - Optimizer:
# - Optax Optimizer: na.md
# - Elegy Optimizer: na.md
# - Monitoring Learning Rate: na.md
# - Supporting Other Optimizers: na.md
# - Callbacks: na.md
# - Serialization: na.md
- Low Level API:
- Basics: low-level-api/basics.md
- States: low-level-api/states.md
- Methods:
- pred_step: low-level-api/methods/pred_step.md
- test_step: low-level-api/methods/test_step.md
- Default Implementation: low-level-api/default-implementation.md
# - Elegy Module: module-system.md
- Contributing: contributing.md
- API Reference:
GeneralizedModule: api/GeneralizedModule.md