Skip to content
12 changes: 6 additions & 6 deletions areal/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,10 +1446,11 @@ def __call__(
mean = torch.zeros_like(x)

# Subtract mean
x_centered = x - mean
# mask irrelevant elements as 0
if loss_mask is not None:
x_centered = x_centered * loss_mask
x_safe = torch.where(loss_mask.bool(), x, 0.0)
x_centered = (x_safe - mean) * loss_mask
else:
x_centered = x - mean
Comment thread
fishcrap marked this conversation as resolved.

# Step 2: Compute std
if self.std_level == "batch":
Expand Down Expand Up @@ -1517,7 +1518,7 @@ def _compute_mean(
x_sum = x.sum(dim=dim, keepdim=True)
else:
mask = mask.to(dtype)
x_masked = x * mask
x_masked = torch.where(mask.bool(), x, 0.0)
factor = mask.sum(dim, keepdim=True)
x_sum = x_masked.sum(dim=dim, keepdim=True)

Expand Down Expand Up @@ -1576,9 +1577,8 @@ def _compute_std(
x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True)
else:
mask = mask.to(dtype)
x_masked = x * mask
factor = mask.sum(dim, keepdim=True)
x_centered = x_masked - mean * mask # only apply mean where mask is 1
x_centered = torch.where(mask.bool(), x - mean, 0.0)
x_sum_sq = (x_centered**2).sum(dim=dim, keepdim=True)

if dist.is_initialized() and all_reduce:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_adv_norm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,48 @@ def test_non_trivial_loss_mask_batch_normalization():
assert torch.abs(non_masked_values.std() - 1.0) < 1e-5


def test_masked_invalid_values_do_not_poison_batch_normalization():
    """Check that a NaN hidden behind a zero loss mask stays out of the result.

    The normalizer is configured with batch-level mean and std. The input
    holds one NaN at a position where the loss mask is 0; the three
    unmasked entries are 1, 3 and 5.
    """
    normalizer = Normalization(
        NormConfig(mean_level="batch", std_level="batch", group_size=1)
    )

    raw = torch.tensor(
        [
            [1.0, float("nan")],
            [3.0, 5.0],
        ],
        dtype=torch.float32,
    )
    mask = torch.tensor(
        [
            [1.0, 0.0],
            [1.0, 1.0],
        ],
        dtype=torch.float32,
    )

    result = normalizer(raw, mask)

    # No NaN/inf may survive anywhere in the output, masked slot included.
    assert torch.isfinite(result).all()

    # The unmasked entries (1, 3, 5) are expected to map to -1, 0, 1.
    wanted = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32)
    assert torch.allclose(result[mask.bool()], wanted, atol=1e-6)

    # The masked position must come back as exactly zero.
    assert result[0, 1].item() == 0.0


def test_masked_invalid_values_do_not_poison_group_normalization():
    """Check that an inf hidden behind a zero loss mask stays out of the result.

    The normalizer is configured with group-level mean and std and
    ``group_size=2``. The input holds one ``inf`` at a position where the
    loss mask is 0; the three unmasked entries are 1, 3 and 5.
    """
    normalizer = Normalization(
        NormConfig(mean_level="group", std_level="group", group_size=2)
    )

    raw = torch.tensor(
        [
            [1.0, float("inf")],
            [3.0, 5.0],
        ],
        dtype=torch.float32,
    )
    mask = torch.tensor(
        [
            [1.0, 0.0],
            [1.0, 1.0],
        ],
        dtype=torch.float32,
    )

    result = normalizer(raw, mask)

    # No NaN/inf may survive anywhere in the output, masked slot included.
    assert torch.isfinite(result).all()

    # The unmasked entries (1, 3, 5) are expected to map to -1, 0, 1.
    wanted = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32)
    assert torch.allclose(result[mask.bool()], wanted, atol=1e-6)

    # The masked position must come back as exactly zero.
    assert result[0, 1].item() == 0.0


def test_non_trivial_loss_mask_leave_one_out():
"""Test leave-one-out normalization with non-trivial loss mask and verify expected values."""
config = NormConfig(
Expand Down
Loading