version 8.0.0
joshuaspear committed Nov 26, 2024
1 parent ed797db commit cef21bb
Showing 5 changed files with 86 additions and 106 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -95,6 +95,11 @@ The different kinds of importance samples can also be visualised by querying the

### Release log

#### 8.0.0
* Removed OPEEstimatorBase, as it was already tightly coupled with ISEstimatorBase
* Moved the internal process_weights call into ISEstimatorBase.predict for all estimators inheriting from ISEstimatorBase (sketched below)
* Previously, the weights for weighted estimators were not being clipped or normalised

#### 7.0.1
* Made the logging location for d3rlpy/FQE callback an optional parameter
* D3rlpy API now handles the use of observation and action scalers
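The 8.0.0 change is easiest to read as a control-flow change: weight processing moves out of each estimator's predict_traj_rewards and into the shared predict entry point. Below is a minimal runnable sketch of that before/after; EstimatorSketch and its stub bodies are illustrative stand-ins, not the library's real clipping and normalisation code.

```python
# Hedged sketch of the 8.0.0 change; stubs stand in for the real methods.
import torch

class EstimatorSketch:
    def __init__(self, clip: float = 10.0):
        self.clip = clip

    def process_weights(self, weights: torch.Tensor,
                        is_msk: torch.Tensor) -> torch.Tensor:
        # stand-in for the real clipping + weight-denominator logic
        return torch.clamp(weights, max=self.clip) * is_msk

    def predict_traj_rewards(self, discnt_rewards: torch.Tensor,
                             weights: torch.Tensor) -> torch.Tensor:
        # pre-8.0.0, each subclass called process_weights in here itself;
        # a weighted estimator that skipped the call was never clipped
        return torch.mul(discnt_rewards, weights).sum(dim=1)

    def predict(self, discnt_rewards: torch.Tensor, weights: torch.Tensor,
                is_msk: torch.Tensor) -> torch.Tensor:
        # 8.0.0: the shared entry point processes weights exactly once,
        # before dispatching to any subclass implementation
        weights = self.process_weights(weights, is_msk)
        return self.predict_traj_rewards(discnt_rewards, weights).mean()
```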
5 changes: 2 additions & 3 deletions src/offline_rl_ope/OPEEstimators/DoublyRobust.py
@@ -9,12 +9,12 @@
from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
from .WeightDenom import WeightDenomBase
from ..types import WeightTensor
from .IS import ISEstimator
from .IS import ISEstimatorBase
from .DirectMethod import DirectMethodBase
from ..RuntimeChecks import check_array_shape


class DREstimator(ISEstimator):
class DREstimator(ISEstimatorBase):

def __init__(
self,
@@ -84,7 +84,6 @@ def predict_traj_rewards(
discnt_rewards = self.get_dataset_discnt_reward(
rewards=rewards, discount=discount, h=h)
# weights dim is (n_trajectories, max_length)
weights = self.process_weights(weights=weights, is_msk=is_msk)
v:List[Float[torch.Tensor, "max_length 1"]] = []
q:List[Float[torch.Tensor, "max_length 1"]] = []
for s,a in zip(states, actions):
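With its own process_weights call removed, DREstimator's predict_traj_rewards now receives weights that the base class has already processed, and its loop over (states, actions) gathers the direct-method values v and q. For orientation, here is a sketch of the standard step-wise doubly-robust recursion those quantities feed into; this is the textbook form under assumed cumulative importance weights, not the library's exact implementation.

```python
import torch

def dr_trajectory_value(rewards: torch.Tensor, weights: torch.Tensor,
                        v: torch.Tensor, q: torch.Tensor,
                        discount: float) -> torch.Tensor:
    """All inputs are 1-D tensors of length H for one trajectory; weights
    are cumulative importance ratios, already processed by the base class."""
    total = torch.tensor(0.0)
    prev_w = torch.tensor(1.0)  # cumulative weight before the first step
    for t in range(len(rewards)):
        # direct-method baseline plus an IS-weighted correction term
        total = total + (discount ** t) * (
            prev_w * v[t] + weights[t] * (rewards[t] - q[t])
        )
        prev_w = weights[t]
    return total

# toy call: three steps, zero direct-method estimates reduce this to
# per-decision weighted importance sampling
print(dr_trajectory_value(torch.ones(3), torch.tensor([0.9, 0.8, 0.7]),
                          torch.zeros(3), torch.zeros(3), 0.99))
```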
88 changes: 79 additions & 9 deletions src/offline_rl_ope/OPEEstimators/IS.py
@@ -1,19 +1,18 @@
import torch
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Union
from jaxtyping import jaxtyped, Float
from typeguard import typechecked as typechecker

from .utils import (
clip_weights_pass as cwp,
clip_weights as cw
)
from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
from .WeightDenom import WeightDenomBase
from .base import OPEEstimatorBase
from ..types import (RewardTensor,WeightTensor)


class ISEstimatorBase(OPEEstimatorBase):
class ISEstimatorBase(metaclass=ABCMeta):

def __init__(
self,
@@ -23,21 +22,29 @@ def __init__(
cache_traj_rewards:bool=False,
clip:float=0.0,
) -> None:
super().__init__(
empirical_denom=empirical_denom,
cache_traj_rewards=cache_traj_rewards
)
assert isinstance(weight_denom,WeightDenomBase)
assert isinstance(clip_weights,bool)
assert isinstance(cache_traj_rewards,bool)
assert isinstance(clip,float)
self.traj_rewards_cache:torch.Tensor = torch.Tensor(0)
if cache_traj_rewards:
self.__cache_func = self.__cache
else:
self.__cache_func = self.__pass_cache
self.empirical_denom = empirical_denom
self.clip = clip
if clip_weights:
self.clip_weights = cw
else:
self.clip_weights = cwp
self.weight_denom = weight_denom


def __cache(self, traj_rewards):
self.traj_rewards_cache = traj_rewards

def __pass_cache(self, traj_rewards):
pass

@jaxtyped(typechecker=typechecker)
def process_weights(
self,
@@ -124,7 +131,71 @@ def get_traj_discnt_reward(
reward_array = reward_array.squeeze()
discnt_reward = reward_array*discnt_vals
return discnt_reward

@jaxtyped(typechecker=typechecker)
def predict(
self,
rewards:List[torch.Tensor],
states:List[torch.Tensor],
actions:List[torch.Tensor],
weights:torch.Tensor,
discount:float,
is_msk:torch.Tensor
)->torch.Tensor:
l_s = len(states)
l_r = len(rewards)
l_a = len(actions)
_msg = f"State({l_s}), rewards({l_r}), actions({l_a}), should be equal"
assert l_s==l_r==l_a, _msg
assert isinstance(weights,torch.Tensor)
assert isinstance(discount,float)
assert isinstance(is_msk,torch.Tensor)
weights = self.process_weights(weights=weights, is_msk=is_msk)
traj_rewards = self.predict_traj_rewards(
rewards=rewards, states=states, actions=actions, weights=weights,
discount=discount, is_msk=is_msk
)
self.__cache_func(traj_rewards)
denom = self.empirical_denom(
weights=weights,
is_msk=is_msk
)
return traj_rewards.sum()/denom

@abstractmethod
def predict_traj_rewards(
self,
rewards:List[torch.Tensor],
states:List[torch.Tensor],
actions:List[torch.Tensor],
weights:WeightTensor,
discount:float,
is_msk:WeightTensor
)->Float[torch.Tensor, "n_trajectories"]:
"""Function for subclasses to override defining the trajectory level
estimates of return
Args:
rewards (List[torch.Tensor]): List of Tensors of undiscounted
rewards of dimension (max_horizon, 1). Trajectories with
length < max_horizon should have zero rewards imputed
states (List[torch.Tensor]): List of Tensors of state values. Should
be of dimension (traj horizon, state features)
actions (List[torch.Tensor]): List of Tensors of action values.
Should be of dimension (traj horizon, action features)
weights (torch.Tensor): Tensor of IS weights of dimension
(# trajectories, max_horizon). Trajectories with length <
max_horizon should have zero weight imputed
discount (float): One step discount factor
is_msk (torch.Tensor): Tensor of dimension
(# trajectories, max_horizon) defining the lengths of individual
trajectories
Returns:
torch.Tensor: tensor of size (# trajectories,) defining the
individual trajectory rewards
"""
pass


class ISEstimator(ISEstimatorBase):
@@ -189,7 +260,6 @@ def predict_traj_rewards(
discnt_rewards = self.get_dataset_discnt_reward(
rewards=rewards, discount=discount, h=h)
# weights dim is (n_trajectories, max_length)
weights = self.process_weights(weights=weights, is_msk=is_msk)
# (n_trajectories,max_length) ELEMENT WISE * (n_trajectories,max_length)
res = torch.mul(discnt_rewards,weights).sum(dim=1)
return res
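
After this change a subclass only implements predict_traj_rewards; the inherited predict applies process_weights once, optionally caches the per-trajectory returns, and divides by the empirical-mean denominator. A self-contained sketch of that template-method pattern follows; MeanISEstimatorSketch and its simplified constructor are illustrative stand-ins for the real EmpiricalMeanDenomBase and WeightDenomBase dependencies.

```python
from typing import List
import torch

class MeanISEstimatorSketch:
    """Illustrative stand-in for an ISEstimatorBase subclass."""

    def __init__(self, clip: float = 10.0, cache_traj_rewards: bool = False):
        self.clip = clip
        self.traj_rewards_cache = torch.empty(0)
        # mirrors the __cache/__pass_cache toggle: the caching behaviour
        # is chosen once at construction time rather than per call
        if cache_traj_rewards:
            self._cache = lambda r: setattr(self, "traj_rewards_cache", r)
        else:
            self._cache = lambda r: None

    def process_weights(self, weights, is_msk):
        # stand-in for clipping plus the configured weight denominator
        return torch.clamp(weights, max=self.clip) * is_msk

    def predict_traj_rewards(self, rewards: List[torch.Tensor],
                             weights: torch.Tensor,
                             discount: float) -> torch.Tensor:
        # rewards are (max_length, 1) tensors, zero-padded as the
        # docstring above requires; weights are already processed
        h = weights.shape[1]
        discnt = discount ** torch.arange(h, dtype=torch.float32)
        rew = torch.stack([r.squeeze(-1) for r in rewards]) * discnt
        return torch.mul(rew, weights).sum(dim=1)

    def predict(self, rewards, weights, discount, is_msk):
        weights = self.process_weights(weights, is_msk)
        traj = self.predict_traj_rewards(rewards, weights, discount)
        self._cache(traj)
        return traj.sum() / weights.shape[0]  # plain empirical-mean denom


# toy usage: two trajectories, the second padded to max_length 3 with a
# zero reward and zero weight, matching the documented imputation
rewards = [torch.ones(3, 1), torch.tensor([[1.0], [2.0], [0.0]])]
weights = torch.tensor([[1.0, 0.5, 0.25], [2.0, 1.0, 0.0]])
is_msk = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
print(MeanISEstimatorSketch().predict(rewards, weights, 0.99, is_msk))
```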
1 change: 0 additions & 1 deletion src/offline_rl_ope/OPEEstimators/__init__.py
@@ -1,6 +1,5 @@
from .DirectMethod import DirectMethodBase, D3rlpyQlearnDM
from .DoublyRobust import DREstimator
from .IS import ISEstimatorBase, ISEstimator
from .base import OPEEstimatorBase
from .EmpiricalMeanDenom import *
from .WeightDenom import *
93 changes: 0 additions & 93 deletions src/offline_rl_ope/OPEEstimators/base.py

This file was deleted.
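
For downstream code, deleting base.py means OPEEstimatorBase can no longer be imported. A hypothetical migration sketch (MyEstimator is illustrative; ISEstimatorBase remains exported, per the __init__.py diff above):

```python
from offline_rl_ope.OPEEstimators import ISEstimatorBase

class MyEstimator(ISEstimatorBase):
    # implement the abstract hook defined in IS.py; predict is inherited
    def predict_traj_rewards(self, rewards, states, actions, weights,
                             discount, is_msk):
        raise NotImplementedError  # trajectory-level returns go here
```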
