Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 89 additions & 2 deletions chirho/interventional/handlers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
from __future__ import annotations

import collections
import dataclasses
import functools
from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar, Union
from contextlib import contextmanager
from typing import (
Callable,
Collection,
Dict,
Generic,
Hashable,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)

import pyro
import torch
Expand All @@ -12,6 +25,7 @@
CompoundIntervention,
intervene,
)
from chirho.observational.handlers.predictive import BatchedLatents

K = TypeVar("K")
T = TypeVar("T")
Expand Down Expand Up @@ -64,7 +78,6 @@ def _dict_intervene(
act: Union[Dict[K, AtomicIntervention[T]], Callable[[Dict[K, T]], Dict[K, T]]],
**kwargs,
) -> Dict[K, T]:

if callable(act):
return _dict_intervene_callable(obs, act, **kwargs)

Expand Down Expand Up @@ -137,3 +150,77 @@ def _pyro_post_sample(self, msg):

# Handler/decorator pair generated from the ``Interventions`` messenger by
# pyro's ``_make_handler``: ``do(fn, actions)`` presumably wraps ``fn`` so the
# given actions intervene at the matching sample sites. The ``...`` body is
# intentional -- the decorator supplies the implementation from the messenger.
@pyro.poutine.handlers._make_handler(Interventions)
def do(fn: Callable, actions: Mapping[Hashable, AtomicIntervention[T]]): ...


# A single site's batch of interventions, as consumed by
# ``_BatchedInterventions._pyro_intervene``.
@dataclasses.dataclass
class _BatchedAction:
    # Intervention values, one per batch index; shape (batch_size, ...)
    # per the ``batched_do`` docstring.
    act: torch.Tensor
    # Boolean selector of shape (batch_size,) saying at which batch indices
    # the intervention is applied (used via ``torch.where(mask, act, obs)``).
    mask: torch.Tensor


class _BatchedInterventions(Interventions):
    """Effect handler that applies a whole batch of tensor interventions at once.

    Each action is a :class:`_BatchedAction` whose ``act`` tensor carries one
    intervention value per batch index and whose ``mask`` selects the batch
    indices at which the intervention actually replaces the observed value.

    :param actions: nonempty mapping from sample-site name to
        :class:`_BatchedAction`; all ``act`` tensors must share the same
        leading (batch) dimension.
    :param name: label for this batch of interventions.
    :raises ValueError: if ``actions`` is empty or batch sizes disagree.
    """

    def __init__(
        self, actions: Mapping[Hashable, _BatchedAction], name="batched_interventions"
    ):
        if not actions:
            raise ValueError("Expected a nonempty actions dict.")

        batch_sizes = {v.act.shape[0] for v in actions.values()}
        if len(batch_sizes) != 1:
            raise ValueError("Expected each intervention to have the same batch size.")

        # Common leading dimension shared by every action tensor.
        self.batch_size = next(iter(batch_sizes))
        # Previously the name argument was accepted but silently dropped;
        # keep it so callers can identify this handler.
        self.name = name

        super().__init__(actions)

    def _pyro_intervene(self, msg):
        (obs, act) = msg["args"]
        if not isinstance(act, _BatchedAction):
            # Not one of ours -- leave the message for other handlers.
            return

        # Take the intervened value where the mask is set, the observed value
        # elsewhere; relies on broadcasting between mask, act, and obs.
        # NOTE(review): msg["done"] is not set here -- confirm the intervene
        # op treats a populated msg["value"] as final.
        msg["value"] = torch.where(act.mask, act.act, obs)


@contextmanager
def batched_do(
    interventions: (
        Mapping[Hashable, Tuple[torch.Tensor, torch.Tensor]]
        | Collection[Mapping[Hashable, torch.Tensor]]
    ),
    name="batched_interventions",
):
    """Perform a batch of interventions efficiently.

    Batches can be specified either as:

    1. A collection of individual interventions, as might be passed to `do`. The
       actions are restricted to be tensors, however. The action tensors may be of
       different shapes, but they will all be broadcast together.

    2. A mapping from sample sites to pairs of tensors (act, mask) that specify
       the intervention to apply for each index in the batch and whether an
       intervention should be applied. Each act tensor should be of shape
       (batch_size, ...) and each mask tensor should be of shape (batch_size).

    :raises ValueError: if ``interventions`` is empty or batch sizes disagree.

    .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` .
    """
    if isinstance(interventions, collections.abc.Mapping):
        # Form 2: each value is already a batched (act, mask) pair.
        batches = {k: _BatchedAction(*v) for (k, v) in interventions.items()}
    else:
        # Form 1: stack the per-intervention action tensors into one batch
        # per sample site.
        if not interventions:
            # set.union(*[]) below would raise a confusing TypeError.
            raise ValueError("Expected a nonempty collection of interventions.")

        vars_ = set.union(*(set(i.keys()) for i in interventions))
        masks = {k: torch.zeros(len(interventions), dtype=torch.bool) for k in vars_}
        # NaN placeholders mark batch indices where a site is not intervened;
        # they are never selected because the corresponding mask entry stays False.
        acts = {k: [torch.tensor(float("nan"))] * len(interventions) for k in vars_}
        for i, intv in enumerate(interventions):
            for k, v in intv.items():
                masks[k][i] = True
                acts[k][i] = v

        # BUG FIX: stack the accumulated per-site action lists (``acts``), not
        # the last loop mapping (``intv``) -- the old code raised KeyError for
        # any site absent from the final intervention -- and unpack the list so
        # broadcast_tensors aligns the individual action tensors before stacking.
        batched_acts = {
            k: torch.stack(torch.broadcast_tensors(*v)) for (k, v) in acts.items()
        }
        batches = {k: _BatchedAction(batched_acts[k], masks[k]) for k in vars_}

    batched_intervene = _BatchedInterventions(batches)
    batched_latents = BatchedLatents(batched_intervene.batch_size, name=name)
    with batched_latents, batched_intervene:
        yield
18 changes: 17 additions & 1 deletion tests/interventional/test_do_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
SingleWorldFactual,
TwinWorldCounterfactual,
)
from chirho.interventional.handlers import do
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.indexed.ops import indices_of
from chirho.interventional.handlers import batched_do, do
from chirho.interventional.ops import intervene

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,6 +57,20 @@ def create_intervened_model_2(x_cf_value):
return intervene(model, {"x": torch.tensor(x_cf_value)})


def test_do_messenger_batched():
    """Batched interventions index all sites and apply where masked."""
    interventions = {
        "y": (torch.tensor([1, 300, 0]), torch.tensor([True, True, False])),
        "z": (torch.tensor([100, 0, 200]), torch.tensor([True, False, True])),
    }

    with IndexPlatesMessenger(), batched_do(interventions):
        z, x, y = model()

    # Every site carries the same batch plate covering all three interventions.
    batch_indices = indices_of(z)
    assert (
        indices_of(x) == indices_of(y) == batch_indices
        and batch_indices["batched_interventions"] == set(range(3))
    )
    # Each batch element had at least one of z, y intervened to >= 100.
    assert ((z >= 100.0) | (y >= 100)).all()


@pytest.mark.parametrize("x_cf_value", x_cf_values)
def test_do_messenger_factual(x_cf_value):
intervened_model = create_intervened_model(x_cf_value)
Expand Down
26 changes: 17 additions & 9 deletions tests/observational/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ def _soft_eq(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
return soft_eq(constraints.real, v1, v2, scale=scale)

reparam_config = {name: KernelSoftConditionReparam(_soft_eq) for name in data}
with pyro.poutine.trace() as tr, pyro.poutine.reparam(
config=reparam_config
), condition(data=data):
with (
pyro.poutine.trace() as tr,
pyro.poutine.reparam(config=reparam_config),
condition(data=data),
):
continuous_scm_1()

tr.trace.compute_log_prob()
Expand Down Expand Up @@ -114,9 +116,11 @@ def _soft_eq(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
return soft_eq(constraints.boolean, v1, v2, scale=scale)

reparam_config = {name: KernelSoftConditionReparam(_soft_eq) for name in data}
with pyro.poutine.trace() as tr, pyro.poutine.reparam(
config=reparam_config
), condition(data=data):
with (
pyro.poutine.trace() as tr,
pyro.poutine.reparam(config=reparam_config),
condition(data=data),
):
discrete_scm_1()

tr.trace.compute_log_prob()
Expand Down Expand Up @@ -158,9 +162,13 @@ def test_soft_conditioning_counterfactual_continuous_1(

actions = {"x": torch.tensor(0.1234)}

with pyro.poutine.trace() as tr, pyro.poutine.reparam(
config=reparam_config
), cf_class(cf_dim), do(actions=actions), condition(data=data):
with (
pyro.poutine.trace() as tr,
pyro.poutine.reparam(config=reparam_config),
cf_class(cf_dim),
do(actions=actions),
condition(data=data),
):
continuous_scm_1()

tr.trace.compute_log_prob()
Expand Down
Loading