Why use JAXnet over stax?
JAXnet brings extensibility and user-friendliness on par with Keras/TensorFlow2/PyTorch, while retaining the functional character of stax. The comparisons below illustrate the advantages of JAXnet.

Defining a `Dense` layer in JAXnet:
```python
def Dense(out_dim, W_init=glorot(), b_init=randn()):
    @parametrized
    def dense(inputs):
        W = parameter((inputs.shape[-1], out_dim), W_init)
        b = parameter((out_dim,), b_init)
        return np.dot(inputs, W) + b
    return dense
```
The same in stax:
```python
def Dense(out_dim, W_init=glorot(), b_init=randn()):
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
        return output_shape, (W, b)
    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return np.dot(inputs, W) + b
    return init_fun, apply_fun
```
From the MNIST VAE example:
```python
@parametrized
def elbo(rng, images):
    mu_z, sigmasq_z = encode(images)
    logits_x = decode(gaussian_sample(rng, mu_z, sigmasq_z))
    return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

params = elbo.init_parameters(example_rng, example_batch, key=PRNGKey(2))
```
The equivalent from the stax example:
```python
def elbo(rng, params, images):
    enc_params, dec_params = params
    mu_z, sigmasq_z = encode(enc_params, images)
    logits_x = decode(dec_params, gaussian_sample(rng, mu_z, sigmasq_z))
    return bernoulli_logpdf(logits_x, images) - gaussian_kl(mu_z, sigmasq_z)

enc_init_rng, dec_init_rng = random.split(PRNGKey(2))
_, encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
params = encoder_params, decoder_params
```
JAXnet does not require boilerplate code for parameter initialization (output-shape inference, random-key splitting) or parameter handling (destructuring, passing parameters to submodules).
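As a minimal sketch of what this looks like in use (the network and example batch below are hypothetical, but `init_parameters` and `apply` are called the same way as elsewhere in this comparison):

```python
import jax.numpy as np
from jax.random import PRNGKey
from jaxnet import Sequential, relu

# Hypothetical model reusing the Dense layer defined above, plus a dummy batch.
net = Sequential(Dense(10), relu, Dense(3))
example_batch = np.zeros((4, 2))

# One call initializes all parameters; shapes are inferred from the inputs and
# the random key is split for submodules internally.
params = net.init_parameters(example_batch, key=PRNGKey(0))

# Application is purely functional: parameters are passed in explicitly.
output = net.apply(params, example_batch)
```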
Training is also more concise in JAXnet:

```python
opt = optimizers.Adam()
state = opt.init(params)
for _ in range(10):
    state = opt.update(loss.apply, state, *next_batch(), jit=True)

trained_params = opt.get_parameters(state)
```
The same in stax:
```python
opt_init, opt_update, get_params = optimizers.adam(0.001)

@jit
def update(i, state, batch):
    params = get_params(state)
    return opt_update(i, grad(loss)(params, batch), state)

state = opt_init(params)
for i in range(10):
    state = update(i, state, next_batch())

trained_params = get_params(state)
```
Plain Python functions can be used directly as layers in JAXnet:

```python
def fancy_relu(x):
    return relu(x * x)

layer = Sequential(Dense(10), fancy_relu)
```
The same in stax:
```python
def fancy_relu(x):
    return relu(x * x)

FancyRelu = elementwise(fancy_relu)
layer = Serial(Dense(10), FancyRelu)
```
- Streamlined support for parameter sharing and reuse (see the sketch after this list).
- Support for all stax functionality, and more.
- No need for `shape_dependent`; all submodules can depend on input shapes when defined inline.
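The parameter-sharing point in a hedged sketch (the module and function names below are hypothetical; it assumes that applying the same module object twice shares one set of weights between both uses):

```python
import jax.numpy as np
from jax.random import PRNGKey
from jaxnet import Dense, parametrized

# Hypothetical: one Dense object applied to both inputs, so both applications
# are assumed to share the same weights.
shared_encoder = Dense(16)

@parametrized
def similarity(x1, x2):
    return np.sum(shared_encoder(x1) * shared_encoder(x2), axis=-1)

x1, x2 = np.zeros((3, 8)), np.ones((3, 8))
params = similarity.init_parameters(x1, x2, key=PRNGKey(0))
scores = similarity.apply(params, x1, x2)
```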
To port existing stax code to JAXnet:

- Update custom layers (a porting sketch follows this list):
  - Add `@parametrized` to your `apply_fun`.
  - Remove the `params` argument and define or use parameters and layers inline.
  - Remove `init_fun`.
- Update predefined layers according to this table:

  | stax | JAXnet |
  | --- | --- |
  | `Serial` | `Sequential` |
  | `Relu`, `Flatten`, `Softmax`, ... | `relu`, `flatten`, `softmax`, ... |
  | `FanInConcat` | `lambda x: np.concatenate(x, axis=-1)` |
  | `FanInSum` | `sum` |
  | `FanOut`, `parallel` | Reformulate as a `@parametrized` module. |
  | `shape_dependent` | Define layers inline, dependent on the input. |
  | All other layers | Stay the same. |

- Update usage of your model as described in the README.
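As a hedged sketch of these porting steps applied end to end (the `StaxScale`/`Scale` layer below is made up for illustration and not taken from either library; `parametrized` and `parameter` are used as in the `Dense` example above):

```python
import jax.numpy as np
from jaxnet import parametrized, parameter

# Hypothetical stax-style custom layer: an elementwise learned scale.
def StaxScale():
    def init_fun(rng, input_shape):
        return input_shape, (np.ones(input_shape[-1]),)
    def apply_fun(params, inputs, **kwargs):
        scale, = params
        return inputs * scale
    return init_fun, apply_fun

# The same layer ported to JAXnet following the steps above: @parametrized on
# the apply function, the parameter defined inline, and no separate init_fun.
def Scale():
    @parametrized
    def scale(inputs):
        s = parameter((inputs.shape[-1],), lambda key, shape: np.ones(shape))
        return inputs * s
    return scale
```

Predefined combinators and activations are then swapped according to the table above, for example `Serial` becomes `Sequential` and `Relu` becomes `relu`.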