Bug while fitting a multivariate normal #2442
Answered by aldopareja · aldopareja asked this question in Q&A
I'm trying to fit a multivariate normal and it's not working. This should be trivial since I'm initializing at the optimum.

```python
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax

samples = jax.random.multivariate_normal(jax.random.PRNGKey(12346), jnp.array([5]), jnp.array([[10.0]]), shape=(1000,))

class MultiVariateNormal(nn.Module):
    @nn.compact
    def __call__(self, samples, mu, cov_terms):
        mu = self.param('mu',
                        nn.initializers.constant(5.0),  # Initialization function
                        (1,))
        cov = self.param('cov',
                         nn.initializers.constant(10.0),  # Initialization function
                         (1, 1))
        return jax.vmap(jax.scipy.stats.multivariate_normal.logpdf, in_axes=(0, None, None))(x, mu, cov)

def update_step(apply_fn, samples, opt_state, params, mu, cov_terms):
    def loss(params):
        l = apply_fn(params, samples, mu, cov_terms)
        return -l.sum()
    l, grads = jax.value_and_grad(loss, has_aux=False)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return opt_state, params, l

m = MultiVariateNormal()
params = m.init(jax.random.PRNGKey(567), jnp.ones((10, 2)), jnp.array([0.5, 0.5]), jnp.array([1.0, 0.0, 1.0]))
tx = optax.adam(learning_rate=0.01)
opt_state = tx.init(params)
for i in range(1000):
    # opt_state, params, l = jax.jit(update_step, static_argnums=(0,))(m.apply, samples, opt_state, params, jnp.array([0.5, 0.5]), jnp.array([1.0, 0.0, 1.0]))
    opt_state, params, l = update_step(m.apply, samples, opt_state, params, jnp.array([0.5, 0.5]), jnp.array([1.0, 0.0, 1.0]))
    if i % 100 == 0:
        print(l)
print(params)
```

This is the output I'm getting:
Weirdly enough, when I print the summed log-likelihood under the learned parameters and under the optimal parameters, the optimal ones are clearly better, yet the loss printed during training is different. What could be the problem?

```python
mu = params['params']['mu']
cov = params['params']['cov']
print(jax.vmap(jax.scipy.stats.multivariate_normal.logpdf, in_axes=(0, None, None))(samples, mu, cov).sum())
print(jax.vmap(jax.scipy.stats.multivariate_normal.logpdf, in_axes=(0, None, None))(samples, jnp.array([0.5]), jnp.array([[10.0]])).sum())
```
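For reference, the parameters that maximize the likelihood of this particular batch are not exactly the generating values (mean 5, variance 10) but the sample statistics, so even a correct fit will drift slightly away from that initialization. A minimal sketch, assuming the same `samples` as in the question, that computes the closed-form Gaussian MLE (sample mean and biased 1/N sample covariance) and compares its summed log-likelihood with that of the generating parameters:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats

# Same data-generating call as in the question (mean written as a float here).
samples = jax.random.multivariate_normal(
    jax.random.PRNGKey(12346), jnp.array([5.0]), jnp.array([[10.0]]), shape=(1000,)
)

# Closed-form Gaussian MLE: sample mean and biased (1/N) sample covariance.
mu_mle = samples.mean(axis=0)                        # shape (1,)
centered = samples - mu_mle
cov_mle = centered.T @ centered / samples.shape[0]   # shape (1, 1)

# Evaluate logpdf per sample, sharing mean and covariance across the batch.
batched_logpdf = jax.vmap(
    jax.scipy.stats.multivariate_normal.logpdf, in_axes=(0, None, None)
)

ll_mle = batched_logpdf(samples, mu_mle, cov_mle).sum()
ll_true = batched_logpdf(samples, jnp.array([5.0]), jnp.array([[10.0]])).sum()

# Up to float rounding, ll_mle should be at least as large as ll_true.
print(mu_mle, cov_mle, ll_mle, ll_true)
```

A gradient-based fit on the same data should converge toward `mu_mle` and `cov_mle`, which makes them a useful sanity check for the learned parameters.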
Answered by aldopareja, Sep 7, 2022 · Replies: 1 comment
I had a bug in the code: I was using `x` instead of `samples` in the forward function. Sorry.
Answer selected by aldopareja