-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcells.py
67 lines (55 loc) · 2.61 KB
/
cells.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from flax import linen as nn
from flax.linen import GRUCell
from jax import numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd
from config import Config
class RSSMPrior(nn.Module):
    """Prior step of the recurrent state-space model (no observation used).

    Advances the deterministic GRU state from the previous stochastic sample
    and a context vector, then parameterizes a diagonal Gaussian over the
    next stochastic state and draws a sample from it.
    """
    # Hyperparameter container; this module reads `cell_embed_size`,
    # `cell_stoch_size` and `cell_min_stddev` from it.
    c: Config

    @nn.compact
    def __call__(self, prev_state, context):
        """Compute the prior state for one step.

        Args:
            prev_state: mapping with at least the keys ``"sample"`` and
                ``"det_state"`` (previous stochastic sample and GRU carry).
            context: array concatenated with ``prev_state["sample"]`` along
                the last axis (presumably action/conditioning features --
                confirm against the caller).

        Returns:
            dict with keys ``mean``, ``stddev``, ``sample``, ``det_out``,
            ``det_state`` and ``output`` (``sample`` and ``det_out``
            concatenated along the last axis).

        Requires a ``'sample'`` RNG stream to be supplied to the module.
        """
        # NOTE: the order in which Dense/GRUCell submodules are created fixes
        # their auto-generated parameter names; keep it stable.
        inputs = jnp.concatenate([prev_state["sample"], context], -1)
        hl = nn.relu(nn.Dense(self.c.cell_embed_size)(inputs))
        det_state, det_out = GRUCell()(prev_state["det_state"], hl)
        hl = nn.relu(nn.Dense(self.c.cell_embed_size)(det_out))
        mean = nn.Dense(self.c.cell_stoch_size)(hl)
        # The 0.54 offset belongs on the pre-softplus activation, not on the
        # hidden features: softplus(0.54) ~= 1.0, so the stddev starts close
        # to 1 (+ cell_min_stddev) whenever the layer output is near zero.
        # Previously the constant was added to `hl` before the Dense layer.
        raw_stddev = nn.Dense(self.c.cell_stoch_size)(hl)
        stddev = nn.softplus(raw_stddev + .54) + self.c.cell_min_stddev
        dist = tfd.MultivariateNormalDiag(mean, stddev)
        sample = dist.sample(seed=self.make_rng('sample'))
        return dict(mean=mean, stddev=stddev, sample=sample,
                    det_out=det_out, det_state=det_state,
                    output=jnp.concatenate([sample, det_out], -1))
class RSSMPosterior(nn.Module):
    """Posterior step of the recurrent state-space model.

    Refines a prior state using observation features: the prior's
    deterministic output is combined with the observation input to
    parameterize a diagonal Gaussian over the stochastic state, from which a
    sample is drawn. The deterministic parts are passed through unchanged.
    """
    # Hyperparameter container; this module reads `cell_embed_size`,
    # `cell_stoch_size` and `cell_min_stddev` from it.
    c: Config

    @nn.compact
    def __call__(self, prior, obs_inputs):
        """Compute the posterior state for one step.

        Args:
            prior: mapping with at least the keys ``"det_out"`` and
                ``"det_state"``.
            obs_inputs: array concatenated with ``prior["det_out"]`` along
                the last axis (presumably encoded observations -- confirm
                against the caller).

        Returns:
            dict with keys ``mean``, ``stddev``, ``sample``, ``det_out``,
            ``det_state`` and ``output`` (``sample`` and the prior's
            ``det_out`` concatenated along the last axis).

        Requires a ``'sample'`` RNG stream to be supplied to the module.
        """
        # NOTE: the order in which Dense submodules are created fixes their
        # auto-generated parameter names; keep it stable.
        det_out = prior["det_out"]
        inputs = jnp.concatenate([det_out, obs_inputs], -1)
        hl = nn.relu(nn.Dense(self.c.cell_embed_size)(inputs))
        hl = nn.relu(nn.Dense(self.c.cell_embed_size)(hl))
        mean = nn.Dense(self.c.cell_stoch_size)(hl)
        # The 0.54 offset belongs on the pre-softplus activation, not on the
        # hidden features: softplus(0.54) ~= 1.0, so the stddev starts close
        # to 1 (+ cell_min_stddev) whenever the layer output is near zero.
        # Previously the constant was added to `hl` before the Dense layer.
        raw_stddev = nn.Dense(self.c.cell_stoch_size)(hl)
        stddev = nn.softplus(raw_stddev + .54) + self.c.cell_min_stddev
        dist = tfd.MultivariateNormalDiag(mean, stddev)
        sample = dist.sample(seed=self.make_rng('sample'))
        return dict(mean=mean, stddev=stddev, sample=sample,
                    det_out=det_out, det_state=prior["det_state"],
                    output=jnp.concatenate([sample, det_out], -1))
class RSSMCell(nn.Module):
    """Single-step RSSM cell combining the prior and posterior modules."""
    # Hyperparameter container; `cell_stoch_size` and `cell_deter_size` are
    # read here, and the whole object is forwarded to the sub-modules.
    c: Config

    @property
    def state_size(self):
        """Return a dict mapping each state entry to its feature width."""
        stoch = self.c.cell_stoch_size
        deter = self.c.cell_deter_size
        return {
            "mean": stoch,
            "stddev": stoch,
            "sample": stoch,
            "det_out": deter,
            "det_state": deter,
            "output": stoch + deter,
        }

    def zero_state(self, batch_size, dtype=jnp.float32):
        """Build an all-zeros state dict.

        Each entry has shape ``(batch_size, width)`` where ``width`` comes
        from ``state_size``, and uses the given ``dtype``.
        """
        zeros = {}
        for name, width in self.state_size.items():
            zeros[name] = jnp.zeros((batch_size, width), dtype=dtype)
        return zeros

    @nn.compact
    def __call__(self, state, inputs, use_obs):
        """Advance the cell by one step.

        Args:
            state: previous state dict passed to ``RSSMPrior``.
            inputs: pair ``(obs_input, context)``.
            use_obs: when truthy, the posterior is computed from the prior
                and ``obs_input``; otherwise the prior itself is used as the
                posterior and no ``RSSMPosterior`` module is created.

        Returns:
            ``(posterior, (prior, posterior))``.
        """
        obs_input, context = inputs
        prior = RSSMPrior(self.c)(state, context)
        if use_obs:
            posterior = RSSMPosterior(self.c)(prior, obs_input)
        else:
            posterior = prior
        return posterior, (prior, posterior)