Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 60 additions & 60 deletions tests/test_handlers_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,73 @@ def setup_module():
pass


TEST_CASES = []
RAW_TEST_CASES = []


@dataclass
class RawTestCase:
raw_dist: str
raw_params: dict[str, str]
raw_params: Sequence[tuple[str, str]]
batch_shape: tuple[int, ...]
xfail: str | None = None


class DistTestCase:
raw_dist: str
params: dict[str, jax.Array]
indexed_params: dict[str, jax.Array]
batch_shape: tuple[int, ...]
xfail: str | None
kind: str

def __init__(
self,
raw_dist: str,
params: dict[str, jax.Array],
indexed_params: dict[str, jax.Array],
batch_shape: tuple[int, ...],
xfail: str | None,
kind: str,
):
self.raw_dist = re.sub(r"\s+", " ", raw_dist.strip())
self.params = params
self.indexed_params = indexed_params
self.batch_shape = batch_shape
self.xfail = xfail
self.kind = kind

def get_dist(self):
"""Return positional and indexed distributions."""
if self.xfail is not None:
pytest.xfail(self.xfail)

Case = namedtuple("Case", tuple(name for name, _ in self.params.items()))

case = Case(**self.params)
dist_ = eval(self.raw_dist)

# case is used by generated code in self.raw_dist
case = Case(**self.indexed_params) # noqa: F841
indexed_dist = eval(self.raw_dist)

return dist_, indexed_dist

def __eq__(self, other):
if isinstance(other, DistTestCase):
return (
self.raw_dist == other.raw_dist
and self.batch_shape == other.batch_shape
and self.kind == other.kind
)

def __hash__(self):
return hash((self.raw_dist, self.batch_shape, self.kind))

def __repr__(self):
return f"{self.raw_dist} {self.batch_shape} {self.kind}"


TEST_CASES: list[DistTestCase] = []
RAW_TEST_CASES: list[RawTestCase] = []


def add_case(raw_dist, raw_params, batch_shape, xfail=None):
RAW_TEST_CASES.append(RawTestCase(raw_dist, raw_params, batch_shape, xfail))

Expand Down Expand Up @@ -474,61 +529,6 @@ def from_indexed(tensor, batch_dims):
return bind_dims(tensor, *indices)


class DistTestCase:
raw_dist: str
params: dict[str, jax.Array]
indexed_params: dict[str, jax.Array]
batch_shape: tuple[int, ...]
xfail: str | None
kind: str

def __init__(
self,
raw_dist: str,
params: dict[str, jax.Array],
indexed_params: dict[str, jax.Array],
batch_shape: tuple[int, ...],
xfail: str | None,
kind: str,
):
self.raw_dist = re.sub(r"\s+", " ", raw_dist.strip())
self.params = params
self.indexed_params = indexed_params
self.batch_shape = batch_shape
self.xfail = xfail
self.kind = kind

def get_dist(self):
"""Return positional and indexed distributions."""
if self.xfail is not None:
pytest.xfail(self.xfail)

Case = namedtuple("Case", tuple(name for name, _ in self.params.items()))

case = Case(**self.params)
dist_ = eval(self.raw_dist)

# case is used by generated code in self.raw_dist
case = Case(**self.indexed_params) # noqa: F841
indexed_dist = eval(self.raw_dist)

return dist_, indexed_dist

def __eq__(self, other):
if isinstance(other, DistTestCase):
return (
self.raw_dist == other.raw_dist
and self.batch_shape == other.batch_shape
and self.kind == other.kind
)

def __hash__(self):
return hash((self.raw_dist, self.batch_shape, self.kind))

def __repr__(self):
return f"{self.raw_dist} {self.batch_shape} {self.kind}"


def full_indexed_test_case(
raw_dist: str,
params: dict[str, jax.Array],
Expand Down
42 changes: 21 additions & 21 deletions tests/test_handlers_pyro_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,6 @@ def setup_module():
torch.distributions.Distribution.set_default_validate_args(False)


TEST_CASES = []


def random_scale_tril(*args):
if isinstance(args[0], tuple):
assert len(args) == 1
shape = args[0]
else:
shape = args

data = torch.randn(shape)
return dist.transforms.transform_to(dist.constraints.lower_cholesky)(data)


def from_indexed(tensor, batch_dims):
tensor_sizes = sizesof(tensor)
indices = [name_to_sym(str(i)) for i in range(batch_dims)]
indices = [i for i in indices if i in tensor_sizes]
return bind_dims(tensor, *indices)


class DistTestCase:
raw_dist: str
params: dict[str, torch.Tensor]
Expand Down Expand Up @@ -104,6 +83,27 @@ def __repr__(self):
return f"{self.raw_dist} {self.batch_shape} {self.kind}"


TEST_CASES: list[DistTestCase] = []


def random_scale_tril(*args):
if isinstance(args[0], tuple):
assert len(args) == 1
shape = args[0]
else:
shape = args

data = torch.randn(shape)
return dist.transforms.transform_to(dist.constraints.lower_cholesky)(data)


def from_indexed(tensor, batch_dims):
tensor_sizes = sizesof(tensor)
indices = [name_to_sym(str(i)) for i in range(batch_dims)]
indices = [i for i in indices if i in tensor_sizes]
return bind_dims(tensor, *indices)


def full_indexed_test_case(
raw_dist: str,
params: dict[str, torch.Tensor],
Expand Down
Loading