
Using Distrax bijectors with Flax #2572

Answered by cgarciae
nalzok asked this question in Q&A

Hey @nalzok, the problem here is that self.flow is not a Flax Module, so Flax can't set the scope for the inner Modules that live in MaskedCoupling.conditioner. You can get around this by defining everything inside a compact method. The following code runs:

from typing import Sequence, List, Callable, Any

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import distrax


def make_conditioner(
    event_shape: Sequence[int], hidden_sizes: Sequence[int], num_bijector_params: int
) -> nn.Module:
    """Creates an MLP conditioner for each layer of the flow."""
    layers: List[Callable[..., Any]] = [
        lambda x: x.reshape((-1, *x.shape[-len(event_shape) :]))
    ]

    for

Replies: 1 comment 3 replies

3 replies
@nalzok
@cgarciae
@nalzok
Answer selected by nalzok
Category
Q&A
Labels
None yet
2 participants