Skip to content

Commit

Permalink
Merge pull request #8 from joshuaspear/hot_fix/wis_correction
Browse files Browse the repository at this point in the history
fixed per decision weighted IS. Updated testing. Altered effective sa…
  • Loading branch information
joshuaspear authored Mar 1, 2024
2 parents d3ed4d9 + bf8d506 commit 10308ac
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 48 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# offline_rl_ope (BETA RELEASE)

**WARNING**
- Per-decision weighted importance sampling was incorrectly implemented in versions < 5.X
- Weighted importance sampling was incorrectly implemented in versions 1.X.X and 2.1.X, 2.2.X
- Unit tests currently run only on Python 3.11. 3.10 will be supported in the future
- Only 1 dimensional discrete action spaces are currently supported!
Expand Down Expand Up @@ -88,6 +89,16 @@ If importance sampling based methods are evaluating to 0, consider visualising t
The different kinds of importance samples can also be visualised by querying the ```traj_is_weights``` attribute of a given ```ImportanceSampler``` object. If, for example, vanilla importance sampling is being used and the samples are not ```NaN``` or ```Inf```, then visualising the ```traj_is_weights``` may provide insight. In particular, IS weights will tend to infinity when the evaluation policy places large density on an action in comparison to the behaviour policy.

### Release log
#### 5.0.0
* Correctly implemented per-decision weighted importance sampling
* Expanded the different types of weights that can be implemented based on:
* http://proceedings.mlr.press/v48/jiang16.pdf: Per-decision weights are defined as the average weight at a given timepoint. This results in a different denominator for different timepoints. This is implemented with the following ```WISWeightNorm(avg_denom=True)```
* https://scholarworks.umass.edu/cgi/viewcontent.cgi?article=1079&context=cs_faculty_pubs: Per-decision weights are defined as the sum of discounted weights across all timesteps. This is implemented with the following ```WISWeightNorm(discount=discount_value)```
* Combinations of different weights can be easily implemented, for example 'average discounted weights' via ```WISWeightNorm(discount=discount_value, avg_denom=True)```; however, these do not necessarily have backing from the literature.
* EffectiveSampleSize metric optionally returns nan if all weights are 0
* Bug fixes:
* Fix bug when running on cuda where tensors were not being pushed to CPU
* Improved static typing
#### 4.0.0
* Predefined propensity models including:
* Generic feedforward MLP for continuous and discrete action spaces built in PyTorch
Expand Down
9 changes: 5 additions & 4 deletions src/offline_rl_ope/Metrics/EffectiveSampleSize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ def __init__(self, nan_if_all_0:bool=True) -> None:

def __ess(self, weights:torch.Tensor) -> float:
# https://victorelvira.github.io/papers/kong92.pdf
weights = weights.sum(dim=1)
numer = len(weights)
w_var = torch.var(weights).item()
if (w_var == 0) and (self.__nan_if_all_0):
all_0 = (weights == 0).all().item()
if (all_0) and (self.__nan_if_all_0):
res = np.nan
else:
weights = weights.sum(dim=1)
numer = len(weights)
w_var = torch.var(weights).item()
res = (numer/(1+w_var))
return res

Expand Down
6 changes: 3 additions & 3 deletions src/offline_rl_ope/OPEEstimators/IS.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List

from .utils import (
WISNormWeights, NormWeightsPass, WeightNorm,
WISWeightNorm, VanillaNormWeights, WeightNorm,
clip_weights_pass as cwp,
clip_weights as cw
)
Expand All @@ -23,9 +23,9 @@ def __init__(
) -> None:
super().__init__(cache_traj_rewards)
if norm_weights:
_norm_weights = WISNormWeights(**norm_kwargs)
_norm_weights = WISWeightNorm(**norm_kwargs)
else:
_norm_weights = NormWeightsPass(**norm_kwargs)
_norm_weights = VanillaNormWeights(**norm_kwargs)
self.norm_weights:WeightNorm = _norm_weights
self.clip = clip
if clip_weights:
Expand Down
81 changes: 60 additions & 21 deletions src/offline_rl_ope/OPEEstimators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,52 @@ class WeightNorm(metaclass=ABCMeta):
def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
) -> torch.Tensor:
pass

# is_msk.sum(axis=0, keepdim=True) is taken as the
# denominator since it is required to take the average over valid time t
# importance ratios. This may differ for different episodes.
# ref: http://proceedings.mlr.press/v48/jiang16.pdf


class WISWeightNorm(WeightNorm):

class WISNormWeights(WeightNorm):

def __init__(self, smooth_eps:float=0.0, *args, **kwargs) -> None:
def __init__(
self,
smooth_eps:float=0.0,
avg_denom:bool=False,
discount:float=1,
*args,
**kwargs
) -> None:
self.smooth_eps = smooth_eps
self.avg_denom = avg_denom
self.discount = discount

def calc_norm(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
) -> torch.Tensor:
"""Calculates the denominator for weighted importance sampling i.e.
w_{t} = 1/n sum_{i=1}^{n} p_{1:t}. Note, if traj_is_weights represent
vanilla IS samples then this will be w_{t} = 1/n sum_{i=1}^{n} p_{1:H}
for all samples. is_msk.sum(axis=0, keepdim=True) is taken as the
denominator since it is required to take the average over valid time t
importance ratios. This may differ for different episodes.
ref: http://proceedings.mlr.press/v48/jiang16.pdf
def calc_norm(
self,
traj_is_weights:torch.Tensor,
is_msk:torch.Tensor
) -> torch.Tensor:
"""Calculates the denominator for weighted importance sampling.
smooth_eps prevents nan values occuring in instances where there exists
valid time t importance ratios however, these are all 0. This should
be set as small as possible.
avg_denom: defines the denominator as the average weight for time t
as per http://proceedings.mlr.press/v48/jiang16.pdf
Note:
- If traj_is_weights represents vanilla IS samples then:
- The denominator will be w_{t} = sum_{i=1}^{n} p_{1:H} for all
samples.
- If avg_denom is set to true, the denominator will be
w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:H} where n_{t} is the number of
trajectories of at least length, t.
- If traj_is_weights represents PD IS samples then:
- The denominator will be w_{t} = sum_{i=1}^{n} p_{1:t}.
- If avg_denom is set to true, the denominator will be
w_{t} = 1/n_{t} sum_{i=1}^{n} p_{1:t} where n_{t} is the number of
trajectories of at least length, t. This definition aligns with
http://proceedings.mlr.press/v48/jiang16.pdf
Args:
traj_is_weights (torch.Tensor): (# trajectories, max(traj_length))
Tensor. traj_is_weights[i,j] defines the jth timestep propensity
Expand All @@ -40,11 +67,19 @@ def calc_norm(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
torch.Tensor: Tensor of dimension (# trajectories, 1) defining the
normalisation value for each timestep
"""
denom:torch.Tensor = traj_is_weights.sum(dim=0, keepdim=True)
denom = (denom+self.smooth_eps)/(
is_msk.sum(dim=0, keepdim=True)+self.smooth_eps)
discnt_tens = torch.full(traj_is_weights.shape, self.discount)
discnt_pows = torch.arange(0, traj_is_weights.shape[1])[None,:].repeat(
traj_is_weights.shape[0],1)
discnt_tens = torch.pow(discnt_tens,discnt_pows)
traj_is_weights = torch.mul(traj_is_weights,discnt_tens)
denom = (
traj_is_weights.sum(dim=0, keepdim=True) + self.smooth_eps
)
if self.avg_denom:
denom = denom/(
is_msk.sum(dim=0, keepdim=True)+self.smooth_eps)
return denom

def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
) -> torch.Tensor:
"""Normalised propensity weights according to
Expand All @@ -63,10 +98,12 @@ def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
with normalised weights
"""
denom = self.calc_norm(traj_is_weights=traj_is_weights, is_msk=is_msk)
res = traj_is_weights/(denom+self.smooth_eps)
res = traj_is_weights/denom
return res



class NormWeightsPass(WeightNorm):
class VanillaNormWeights(WeightNorm):

def __init__(self, *args, **kwargs) -> None:
pass
Expand All @@ -84,9 +121,11 @@ def __call__(self, traj_is_weights:torch.Tensor, is_msk:torch.Tensor
ith trajectory was observed
Returns:
torch.Tensor: Identical tensor to traj_is_weights
torch.Tensor: traj_is_weights with element wise average
"""
return traj_is_weights
# The first dimension defines the number of trajectories and we require
# the average over trajectories
return traj_is_weights/traj_is_weights.shape[0]

def clip_weights(
traj_is_weights:torch.Tensor,
Expand Down
9 changes: 6 additions & 3 deletions tests/Metrics/test_EffectiveSampleSize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def test_call(self):
assert len(weights) == 2
denum = 1 + torch.var(weights)
act_res = (num/denum).item()
metric = EffectiveSampleSize(is_obj=TestImportanceSampler())
pred_res = metric()
self.assertEqual(act_res,pred_res)
metric = EffectiveSampleSize(nan_if_all_0=True)
pred_res = metric(
weights=weight_test_res
)
tol = act_res/1000
np.testing.assert_allclose(pred_res, act_res, atol=tol)
8 changes: 5 additions & 3 deletions tests/OPEEstimators/test_DoublyRobust.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,20 @@ def v_side_effect(state:torch.Tensor):
weights=weight_test_res,
discount=gamma,
is_msk=msk_test_res)
#weight_test_res = weight_test_res/weight_test_res.shape[0]
denom = weight_test_res.shape[0]
for idx, (r,s,a,w,msk) in enumerate(zip(rewards, states, actions,
weight_test_res, msk_test_res)):
w = w/denom
p = torch.masked_select(w, msk>0)
__test_res = is_est.get_traj_discnt_reward(
reward_array=r, discount=gamma, state_array=s, action_array=a,
weight_array=p)
test_res.append(__test_res.numpy())
#test_res = np.concatenate(test_res).mean()
test_res = np.concatenate(test_res)
tol = (test_res.mean()/1000).item()
tol = (np.abs(test_res.mean()/100)).item()
self.assertEqual(pred_res.shape, torch.Size((len(rewards),)))
np.testing.assert_allclose(pred_res.numpy(),test_res, atol=tol)

np.testing.assert_allclose(pred_res.numpy(),test_res, atol=tol)


5 changes: 4 additions & 1 deletion tests/OPEEstimators/test_IS.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def __mock_return(rewards, discount, h):
pred_res = self.is_estimator.predict_traj_rewards(
rewards=rewards, actions=[], states=[], weights=weight_test_res,
discount=gamma, is_msk=msk_test_res)
test_res = np.multiply(reward_test_res.numpy(), weight_test_res.numpy())
test_res = np.multiply(
reward_test_res.numpy(),
weight_test_res.numpy()/weight_test_res.shape[0]
)
test_res=test_res.sum(axis=1)
#test_res = test_res.sum(axis=1).mean()
tol = test_res.mean()/1000
Expand Down
Loading

0 comments on commit 10308ac

Please sign in to comment.