diff --git a/README.md b/README.md
index ab6baa5..1c85885 100644
--- a/README.md
+++ b/README.md
@@ -95,11 +95,6 @@ The different kinds of importance samples can also be visualised by querying the
 ### Release log
 
-#### 8.0.0
-* Removed OPEEstimatorBase as it was already tightly coupled with ISEstimatorBase;
-* Altered the internal location of the process_weights call for estimators inheriting from ISEstimatorBase
-    * Previously, the weights for weighted estimators were not being clipped etc
-
 #### 7.0.1
 * Made the logging location for d3rlpy/FQE callback an optional parameter
 * D3rlpy API now handles the use of observation and action scalers
diff --git a/src/offline_rl_ope/OPEEstimators/DoublyRobust.py b/src/offline_rl_ope/OPEEstimators/DoublyRobust.py
index c5097e4..f44a5b1 100644
--- a/src/offline_rl_ope/OPEEstimators/DoublyRobust.py
+++ b/src/offline_rl_ope/OPEEstimators/DoublyRobust.py
@@ -9,12 +9,12 @@
 from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
 from .WeightDenom import WeightDenomBase
 from ..types import WeightTensor
-from .IS import ISEstimatorBase
+from .IS import ISEstimator
 from .DirectMethod import DirectMethodBase
 from ..RuntimeChecks import check_array_shape
 
 
-class DREstimator(ISEstimatorBase):
+class DREstimator(ISEstimator):
 
     def __init__(
         self,
@@ -84,6 +84,7 @@ def predict_traj_rewards(
         discnt_rewards = self.get_dataset_discnt_reward(
             rewards=rewards, discount=discount, h=h)
         # weights dim is (n_trajectories, max_length)
+        weights = self.process_weights(weights=weights, is_msk=is_msk)
         v:List[Float[torch.Tensor, "max_length 1"]] = []
         q:List[Float[torch.Tensor, "max_length 1"]] = []
         for s,a in zip(states, actions):
diff --git a/src/offline_rl_ope/OPEEstimators/IS.py b/src/offline_rl_ope/OPEEstimators/IS.py
index a2e6beb..26109e8 100644
--- a/src/offline_rl_ope/OPEEstimators/IS.py
+++ b/src/offline_rl_ope/OPEEstimators/IS.py
@@ -1,18 +1,19 @@
 import torch
-from abc import ABCMeta, abstractmethod
 from typing import Any, Dict, List, Union
 from jaxtyping import jaxtyped, Float
 from typeguard import typechecked as typechecker
+
 from .utils import (
     clip_weights_pass as cwp,
     clip_weights as cw
 )
 from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
 from .WeightDenom import WeightDenomBase
+from .base import OPEEstimatorBase
 from ..types import (RewardTensor,WeightTensor)
 
 
-class ISEstimatorBase(metaclass=ABCMeta):
+class ISEstimatorBase(OPEEstimatorBase):
 
     def __init__(
         self,
@@ -22,29 +23,21 @@ def __init__(
         cache_traj_rewards:bool=False,
         clip:float=0.0,
     ) -> None:
+        super().__init__(
+            empirical_denom=empirical_denom,
+            cache_traj_rewards=cache_traj_rewards
+        )
         assert isinstance(weight_denom,WeightDenomBase)
         assert isinstance(clip_weights,bool)
         assert isinstance(cache_traj_rewards,bool)
         assert isinstance(clip,float)
-        self.traj_rewards_cache:torch.Tensor = torch.Tensor(0)
-        if cache_traj_rewards:
-            self.__cache_func = self.__cache
-        else:
-            self.__cache_func = self.__pass_cache
-        self.empirical_denom = empirical_denom
         self.clip = clip
         if clip_weights:
             self.clip_weights = cw
         else:
             self.clip_weights = cwp
         self.weight_denom = weight_denom
-
-    def __cache(self, traj_rewards):
-        self.traj_rewards_cache = traj_rewards
-
-    def __pass_cache(self, traj_rewards):
-        pass
-
+
     @jaxtyped(typechecker=typechecker)
     def process_weights(
         self,
@@ -131,71 +124,7 @@ def get_traj_discnt_reward(
         reward_array = reward_array.squeeze()
         discnt_reward = reward_array*discnt_vals
         return discnt_reward
-
-    @jaxtyped(typechecker=typechecker)
-    def predict(
-        self,
-        rewards:List[torch.Tensor],
-        states:List[torch.Tensor],
-        actions:List[torch.Tensor],
-        weights:torch.Tensor,
-        discount:float,
-        is_msk:torch.Tensor
-    )->torch.Tensor:
-        l_s = len(states)
-        l_r = len(rewards)
-        l_a = len(actions)
-        _msg = f"State({l_s}), rewards({l_r}), actions({l_a}), should be equal"
-        assert l_s==l_r==l_a, _msg
-        assert isinstance(weights,torch.Tensor)
-        assert isinstance(discount,float)
-        assert isinstance(is_msk,torch.Tensor)
-        weights = self.process_weights(weights=weights, is_msk=is_msk)
-        traj_rewards = self.predict_traj_rewards(
-            rewards=rewards, states=states, actions=actions, weights=weights,
-            discount=discount, is_msk=is_msk
-        )
-        self.__cache_func(traj_rewards)
-        denom = self.empirical_denom(
-            weights=weights,
-            is_msk=is_msk
-        )
-        return traj_rewards.sum()/denom
-    @abstractmethod
-    def predict_traj_rewards(
-        self,
-        rewards:List[torch.Tensor],
-        states:List[torch.Tensor],
-        actions:List[torch.Tensor],
-        weights:WeightTensor,
-        discount:float,
-        is_msk:WeightTensor
-    )->Float[torch.Tensor, "n_trajectories"]:
-        """Function for subclasses to override defining the trajectory level
-        estimates of return
-
-        Args:
-            rewards (List[torch.Tensor]): List of Tensors of undiscounted
-                rewards of dimension (max horizon, 1). Trajectories with
-                length < max_horizon should have zero weight imputed
-            states (List[torch.Tensor]): List of Tensors of state values. Should
-                be of dimension (traj horizon, state features)
-            actions (List[torch.Tensor]): List of Tensors of state values.
-                Should be of dimension (traj horizon, action features)
-            weights (torch.Tensor): Tensor of IS weights of dimension
-                (# trajectories, max_horizon). Trajectories with length <
-                max_horizon should have zero weight imputed
-            discount (float): One step discount factor
-            is_msk (torch.Tensor): Tensor of dimension
-                (# trajectories, max_horizon) defining the lengths of individual
-                trajectories
-
-        Returns:
-            torch.Tensor: tensor of size (# trajectories,) defining the
-                individual trajectory rewards
-        """
-        pass
 
 
 class ISEstimator(ISEstimatorBase):
 
@@ -260,6 +189,7 @@ def predict_traj_rewards(
         discnt_rewards = self.get_dataset_discnt_reward(
             rewards=rewards, discount=discount, h=h)
         # weights dim is (n_trajectories, max_length)
+        weights = self.process_weights(weights=weights, is_msk=is_msk)
         # (n_trajectories,max_length) ELEMENT WISE * (n_trajectories,max_length)
         res = torch.mul(discnt_rewards,weights).sum(dim=1)
         return res
\ No newline at end of file
diff --git a/src/offline_rl_ope/OPEEstimators/__init__.py b/src/offline_rl_ope/OPEEstimators/__init__.py
index 1bf47bc..11c98ec 100644
--- a/src/offline_rl_ope/OPEEstimators/__init__.py
+++ b/src/offline_rl_ope/OPEEstimators/__init__.py
@@ -1,5 +1,6 @@
 from .DirectMethod import DirectMethodBase, D3rlpyQlearnDM
 from .DoublyRobust import DREstimator
 from .IS import ISEstimatorBase, ISEstimator
+from .base import OPEEstimatorBase
 from .EmpiricalMeanDenom import *
 from .WeightDenom import *
\ No newline at end of file
diff --git a/src/offline_rl_ope/OPEEstimators/base.py b/src/offline_rl_ope/OPEEstimators/base.py
new file mode 100644
index 0000000..b1ae93c
--- /dev/null
+++ b/src/offline_rl_ope/OPEEstimators/base.py
@@ -0,0 +1,93 @@
+from abc import ABCMeta, abstractmethod
+import torch
+from typing import List
+from jaxtyping import jaxtyped, Float
+from typeguard import typechecked as typechecker
+
+from ..types import WeightTensor
+from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
+
+class OPEEstimatorBase(metaclass=ABCMeta):
+
+
+    def __init__(
+        self,
+        empirical_denom:EmpiricalMeanDenomBase,
+        cache_traj_rewards:bool=False
+    ) -> None:
+        self.traj_rewards_cache:torch.Tensor = torch.Tensor(0)
+        if cache_traj_rewards:
+            self.__cache_func = self.__cache
+        else:
+            self.__cache_func = self.__pass_cache
+        self.empirical_denom = empirical_denom
+
+    def __cache(self, traj_rewards):
+        self.traj_rewards_cache = traj_rewards
+
+    def __pass_cache(self, traj_rewards):
+        pass
+
+    @jaxtyped(typechecker=typechecker)
+    def predict(
+        self,
+        rewards:List[torch.Tensor],
+        states:List[torch.Tensor],
+        actions:List[torch.Tensor],
+        weights:torch.Tensor,
+        discount:float,
+        is_msk:torch.Tensor
+    )->torch.Tensor:
+        l_s = len(states)
+        l_r = len(rewards)
+        l_a = len(actions)
+        _msg = f"State({l_s}), rewards({l_r}), actions({l_a}), should be equal"
+        assert l_s==l_r==l_a, _msg
+        assert isinstance(weights,torch.Tensor)
+        assert isinstance(discount,float)
+        assert isinstance(is_msk,torch.Tensor)
+        traj_rewards = self.predict_traj_rewards(
+            rewards=rewards, states=states, actions=actions, weights=weights,
+            discount=discount, is_msk=is_msk
+        )
+        self.__cache_func(traj_rewards)
+        denom = self.empirical_denom(
+            weights=weights,
+            is_msk=is_msk
+        )
+        return traj_rewards.sum()/denom
+
+    @abstractmethod
+    def predict_traj_rewards(
+        self,
+        rewards:List[torch.Tensor],
+        states:List[torch.Tensor],
+        actions:List[torch.Tensor],
+        weights:WeightTensor,
+        discount:float,
+        is_msk:WeightTensor
+    )->Float[torch.Tensor, "n_trajectories"]:
+        """Function for subclasses to override defining the trajectory level
+        estimates of return
+
+        Args:
+            rewards (List[torch.Tensor]): List of Tensors of undiscounted
+                rewards of dimension (max horizon, 1). Trajectories with
+                length < max_horizon should have zero weight imputed
+            states (List[torch.Tensor]): List of Tensors of state values. Should
+                be of dimension (traj horizon, state features)
+            actions (List[torch.Tensor]): List of Tensors of action values.
+                Should be of dimension (traj horizon, action features)
+            weights (torch.Tensor): Tensor of IS weights of dimension
+                (# trajectories, max_horizon). Trajectories with length <
+                max_horizon should have zero weight imputed
+            discount (float): One step discount factor
+            is_msk (torch.Tensor): Tensor of dimension
+                (# trajectories, max_horizon) defining the lengths of individual
+                trajectories
+
+        Returns:
+            torch.Tensor: tensor of size (# trajectories,) defining the
+                individual trajectory rewards
+        """
+        pass
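
Reviewer note on the contract this diff establishes: `OPEEstimatorBase.predict` now passes the raw weights straight to `predict_traj_rewards` (and to the empirical-mean denominator), so each estimator is responsible for calling `self.process_weights` itself, exactly as the updated `ISEstimator.predict_traj_rewards` and `DREstimator.predict_traj_rewards` do. Below is a minimal sketch of a custom estimator written against that contract. It is illustrative only and not part of the patch: the class name `MyISEstimator` is hypothetical, and the inline discounting loop is a stand-in rather than a use of the library's own discounting helpers.

```python
# Hypothetical sketch: a toy weighted-IS estimator against the refactored bases.
from typing import List

import torch

from offline_rl_ope.OPEEstimators import ISEstimatorBase


class MyISEstimator(ISEstimatorBase):

    def predict_traj_rewards(
        self,
        rewards: List[torch.Tensor],
        states: List[torch.Tensor],
        actions: List[torch.Tensor],
        weights: torch.Tensor,
        discount: float,
        is_msk: torch.Tensor,
    ) -> torch.Tensor:
        # OPEEstimatorBase.predict no longer calls process_weights, so clipping
        # and weight normalisation must happen here in the subclass.
        weights = self.process_weights(weights=weights, is_msk=is_msk)
        n_traj, max_len = weights.shape
        # Build a (n_trajectories, max_length) matrix of discounted rewards,
        # zero-padded to the maximum trajectory length (stand-in computation).
        discnt_rewards = torch.zeros(n_traj, max_len)
        for i, r in enumerate(rewards):
            r = r.squeeze(-1)  # (traj horizon,)
            gammas = discount ** torch.arange(r.shape[0], dtype=r.dtype)
            discnt_rewards[i, : r.shape[0]] = r * gammas
        # Element-wise product with the processed weights, summed over time,
        # giving one return estimate per trajectory.
        return torch.mul(discnt_rewards, weights).sum(dim=1)
```

One behavioural consequence worth flagging in review: previously `ISEstimatorBase.predict` processed the weights before evaluating the empirical-mean denominator, whereas `OPEEstimatorBase.predict` now hands `self.empirical_denom` the raw weights; denominators that expect clipped or normalised weights need to account for that themselves.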