Random Subspace Training using a custom GradientTransformation
#641
-
The above-mentioned Colab notebook contains a demo for training an MLP-based model for image classification on the MNIST dataset. I wrote a custom `GradientTransformation` in such a way that it can be chained together with other transformations, such as optimizer = combine.chain(
optax.adam(learning_rate),
random_subspace_gradients(
subspace_dim = 16
)
) More than happy to make a PR if the devs find it interesting ! Would also appreciate any feedback ! from typing import NamedTuple
from optax._src import base
class RandomSubSpaceState(NamedTuple):
    """Optimizer state threaded through `random_subspace_gradients`.

    The transformation is stochastic, so the only thing it has to remember
    between updates is the PRNG key from which the next random draw is made.
    """

    # Created from the user-supplied seed in `init_fn`; every call to
    # `update_fn` splits it and stores one half back here for the next step.
    rng_key: chex.PRNGKey
def random_subspace_gradients(
    seed: int,
    subspace_dim: int
) -> base.GradientTransformation:
    """Projects gradients onto a random `subspace_dim`-dimensional subspace.

    On every update the whole gradient pytree is flattened into one vector,
    an orthonormal basis of a freshly sampled random subspace of the full
    parameter space is built, and the gradient is replaced by its orthogonal
    projection onto that subspace. The result is reshaped back into the
    structure (and dtypes) of the incoming gradients. A new subspace is
    drawn at each step from the PRNG key carried in the state.

    Note that the basis is materialised as a dense matrix with
    `num_params * subspace_dim` entries, so memory grows linearly with both.

    References:
      [Li et al, 2018](https://arxiv.org/abs/1804.08838)

    Args:
      seed (int): initial seed used for the jax.random.PRNGKey
      subspace_dim (int): number of subspace dimensions. Must be at least 1.
        Values larger than the total number of parameters are clipped to it,
        in which case the projection leaves the gradients unchanged.

    Returns:
      A `GradientTransformation`.

    Raises:
      ValueError: if `subspace_dim` is smaller than 1.
    """
    # Imported locally so the block does not depend on the submodule having
    # been imported elsewhere in the file.
    from jax.flatten_util import ravel_pytree

    if subspace_dim < 1:
        raise ValueError(
            f"subspace_dim must be a positive integer, got {subspace_dim}."
        )

    def init_fn(params: base.Params) -> RandomSubSpaceState:
        del params
        return RandomSubSpaceState(rng_key=jax.random.PRNGKey(seed))

    def update_fn(grads, state, params=None):
        del params
        # One key is carried forward in the state, the other is consumed by
        # this step's random draw.
        next_key, sample_key = jax.random.split(state.rng_key)

        # Treat all gradient leaves as a single vector so that the subspace
        # lives in the full parameter space rather than per tensor.
        flat_grads, unravel = ravel_pytree(grads)
        num_params = flat_grads.size  # static Python int, safe under jit
        if num_params == 0:
            # Nothing to project (empty pytree).
            return grads, RandomSubSpaceState(rng_key=next_key)

        # A subspace cannot have more dimensions than the ambient space.
        num_directions = min(subspace_dim, num_params)
        directions = jax.random.normal(
            sample_key, (num_params, num_directions)
        )
        # Reduced QR yields orthonormal columns spanning the same subspace
        # as the Gaussian directions, which makes `basis @ basis.T` an
        # orthogonal projector.
        basis, _ = jnp.linalg.qr(directions)

        coordinates = basis.T @ flat_grads
        projected_flat = (basis @ coordinates).astype(flat_grads.dtype)

        return unravel(projected_flat), RandomSubSpaceState(rng_key=next_key)

    return base.GradientTransformation(init_fn, update_fn)
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 4 replies
-
Thanks for sharing @SauravMaheshkar ! |
Beta Was this translation helpful? Give feedback.
Thanks for sharing @SauravMaheshkar !