Revert "version 8.0.0"
This reverts commit cef21bb.
joshuaspear committed Nov 26, 2024
1 parent cef21bb · commit 69f8a78
Showing 5 changed files with 106 additions and 86 deletions.
README.md (5 changes: 0 additions & 5 deletions)
@@ -95,11 +95,6 @@ The different kinds of importance samples can also be visualised by querying the

### Release log

#### 8.0.0
* Removed OPEEstimatorBase as it was already tightly coupled with ISEstimatorBase
* Altered the internal location of the process_weights call for estimators inheriting from ISEstimatorBase
    * Previously, the weights for weighted estimators were not being clipped, etc.

#### 7.0.1
* Made the logging location for d3rlpy/FQE callback an optional parameter
* D3rlpy API now handles the use of observation and action scalers
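The 8.0.0 notes above describe moving the process_weights call into the shared predict method; this commit restores the earlier layout. A minimal sketch of that difference, using simplified, assumed signatures rather than the library's actual code:

import torch

def process_weights(weights: torch.Tensor, clip: float) -> torch.Tensor:
    # Simplified stand-in for ISEstimatorBase.process_weights: clip the
    # importance weights to reduce variance.
    return weights.clamp(max=clip) if clip > 0 else weights

def predict_traj_rewards(discnt_rewards, weights, clip):
    # Pre-8.0.0 layout (restored by this commit): the subclass method
    # processes the weights itself before forming the weighted returns,
    # as in the DoublyRobust.py and IS.py hunks below.
    weights = process_weights(weights, clip)
    return torch.mul(discnt_rewards, weights).sum(dim=1)

def predict(discnt_rewards, weights, clip):
    # In the 8.0.0 layout (undone here), process_weights was instead called
    # once at this point in the shared predict() before delegating.
    return predict_traj_rewards(discnt_rewards, weights, clip).mean()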
src/offline_rl_ope/OPEEstimators/DoublyRobust.py (5 changes: 3 additions & 2 deletions)
@@ -9,12 +9,12 @@
from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
from .WeightDenom import WeightDenomBase
from ..types import WeightTensor
from .IS import ISEstimatorBase
from .IS import ISEstimator
from .DirectMethod import DirectMethodBase
from ..RuntimeChecks import check_array_shape


class DREstimator(ISEstimatorBase):
class DREstimator(ISEstimator):

def __init__(
self,
@@ -84,6 +84,7 @@ def predict_traj_rewards(
discnt_rewards = self.get_dataset_discnt_reward(
rewards=rewards, discount=discount, h=h)
# weights dim is (n_trajectories, max_length)
weights = self.process_weights(weights=weights, is_msk=is_msk)
v:List[Float[torch.Tensor, "max_length 1"]] = []
q:List[Float[torch.Tensor, "max_length 1"]] = []
for s,a in zip(states, actions):
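The DoublyRobust.py hunk above (truncated by the diff) combines the processed importance weights with direct-method value and action-value estimates. For background, a minimal per-trajectory doubly robust recursion in the standard form of Jiang & Li (2016); this illustrates the estimator family and is not the exact DREstimator implementation:

import torch

def dr_trajectory_estimate(
    rewards: torch.Tensor,    # (H,) observed rewards
    ratios: torch.Tensor,     # (H,) per-step importance ratios
    q_values: torch.Tensor,   # (H,) direct-method Q(s_t, a_t)
    v_values: torch.Tensor,   # (H,) direct-method V(s_t)
    discount: float,
) -> torch.Tensor:
    estimate = torch.tensor(0.0)
    # Recurse backwards: V_DR(t) = V(s_t) + rho_t * (r_t + gamma * V_DR(t+1) - Q(s_t, a_t))
    for t in reversed(range(rewards.shape[0])):
        estimate = v_values[t] + ratios[t] * (
            rewards[t] + discount * estimate - q_values[t]
        )
    return estimate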
src/offline_rl_ope/OPEEstimators/IS.py (88 changes: 9 additions & 79 deletions)
@@ -1,18 +1,19 @@
import torch
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Union
from jaxtyping import jaxtyped, Float
from typeguard import typechecked as typechecker

from .utils import (
clip_weights_pass as cwp,
clip_weights as cw
)
from .EmpiricalMeanDenom import EmpiricalMeanDenomBase
from .WeightDenom import WeightDenomBase
from .base import OPEEstimatorBase
from ..types import (RewardTensor,WeightTensor)


class ISEstimatorBase(metaclass=ABCMeta):
class ISEstimatorBase(OPEEstimatorBase):

def __init__(
self,
@@ -22,29 +23,21 @@ def __init__(
cache_traj_rewards:bool=False,
clip:float=0.0,
) -> None:
super().__init__(
empirical_denom=empirical_denom,
cache_traj_rewards=cache_traj_rewards
)
assert isinstance(weight_denom,WeightDenomBase)
assert isinstance(clip_weights,bool)
assert isinstance(cache_traj_rewards,bool)
assert isinstance(clip,float)
self.traj_rewards_cache:torch.Tensor = torch.Tensor(0)
if cache_traj_rewards:
self.__cache_func = self.__cache
else:
self.__cache_func = self.__pass_cache
self.empirical_denom = empirical_denom
self.clip = clip
if clip_weights:
self.clip_weights = cw
else:
self.clip_weights = cwp
self.weight_denom = weight_denom

def __cache(self, traj_rewards):
self.traj_rewards_cache = traj_rewards

def __pass_cache(self, traj_rewards):
pass


@jaxtyped(typechecker=typechecker)
def process_weights(
self,
@@ -131,71 +124,7 @@ def get_traj_discnt_reward(
reward_array = reward_array.squeeze()
discnt_reward = reward_array*discnt_vals
return discnt_reward

@jaxtyped(typechecker=typechecker)
def predict(
self,
rewards:List[torch.Tensor],
states:List[torch.Tensor],
actions:List[torch.Tensor],
weights:torch.Tensor,
discount:float,
is_msk:torch.Tensor
)->torch.Tensor:
l_s = len(states)
l_r = len(rewards)
l_a = len(actions)
_msg = f"State({l_s}), rewards({l_r}), actions({l_a}), should be equal"
assert l_s==l_r==l_a, _msg
assert isinstance(weights,torch.Tensor)
assert isinstance(discount,float)
assert isinstance(is_msk,torch.Tensor)
weights = self.process_weights(weights=weights, is_msk=is_msk)
traj_rewards = self.predict_traj_rewards(
rewards=rewards, states=states, actions=actions, weights=weights,
discount=discount, is_msk=is_msk
)
self.__cache_func(traj_rewards)
denom = self.empirical_denom(
weights=weights,
is_msk=is_msk
)
return traj_rewards.sum()/denom

@abstractmethod
def predict_traj_rewards(
self,
rewards:List[torch.Tensor],
states:List[torch.Tensor],
actions:List[torch.Tensor],
weights:WeightTensor,
discount:float,
is_msk:WeightTensor
)->Float[torch.Tensor, "n_trajectories"]:
"""Function for subclasses to override defining the trajectory level
estimates of return
Args:
rewards (List[torch.Tensor]): List of Tensors of undiscounted
rewards of dimension (max horizon, 1). Trajectories with
length < max_horizon should have zero weight imputed
states (List[torch.Tensor]): List of Tensors of state values. Should
be of dimension (traj horizon, state features)
actions (List[torch.Tensor]): List of Tensors of action values.
Should be of dimension (traj horizon, action features)
weights (torch.Tensor): Tensor of IS weights of dimension
(# trajectories, max_horizon). Trajectories with length <
max_horizon should have zero weight imputed
discount (float): One step discount factor
is_msk (torch.Tensor): Tensor of dimension
(# trajectories, max_horizon) defining the lengths of individual
trajectories
Returns:
torch.Tensor: tensor of size (# trajectories,) defining the
individual trajectory rewards
"""
pass


class ISEstimator(ISEstimatorBase):
@@ -260,6 +189,7 @@ def predict_traj_rewards(
discnt_rewards = self.get_dataset_discnt_reward(
rewards=rewards, discount=discount, h=h)
# weights dim is (n_trajectories, max_length)
weights = self.process_weights(weights=weights, is_msk=is_msk)
# (n_trajectories,max_length) ELEMENT WISE * (n_trajectories,max_length)
res = torch.mul(discnt_rewards,weights).sum(dim=1)
return res
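To make the two steps above concrete (the discount-power vector built in get_traj_discnt_reward and the per-trajectory weighted sum), a toy example with two trajectories, where the shorter trajectory has zero reward and zero weight imputed beyond its length; the numbers are illustrative only, not taken from the library's tests:

import torch

discount = 0.99
max_length = 3
# gamma^t for t = 0..max_length-1, as used when discounting rewards
discnt_vals = discount ** torch.arange(max_length, dtype=torch.float32)

rewards = torch.tensor([[1.0, 1.0, 1.0],
                        [0.5, 0.5, 0.0]])   # trajectory 2 has length 2
discnt_rewards = rewards * discnt_vals       # broadcasts over trajectories

weights = torch.tensor([[1.2, 0.8, 1.1],
                        [0.9, 1.0, 0.0]])    # zero weight imputed past the end

# (n_trajectories, max_length) element-wise product, summed over time
res = torch.mul(discnt_rewards, weights).sum(dim=1)
# res[0] = 1.0*1.2 + 0.99*0.8 + 0.9801*1.1 = 3.07011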
src/offline_rl_ope/OPEEstimators/__init__.py (1 change: 1 addition & 0 deletions)
@@ -1,5 +1,6 @@
from .DirectMethod import DirectMethodBase, D3rlpyQlearnDM
from .DoublyRobust import DREstimator
from .IS import ISEstimatorBase, ISEstimator
from .base import OPEEstimatorBase
from .EmpiricalMeanDenom import *
from .WeightDenom import *
src/offline_rl_ope/OPEEstimators/base.py (93 changes: 93 additions & 0 deletions)
@@ -0,0 +1,93 @@
from abc import ABCMeta, abstractmethod
import torch
from typing import List
from jaxtyping import jaxtyped, Float
from typeguard import typechecked as typechecker

from ..types import WeightTensor
from .EmpiricalMeanDenom import EmpiricalMeanDenomBase

class OPEEstimatorBase(metaclass=ABCMeta):


def __init__(
self,
empirical_denom:EmpiricalMeanDenomBase,
cache_traj_rewards:bool=False
) -> None:
self.traj_rewards_cache:torch.Tensor = torch.Tensor(0)
if cache_traj_rewards:
self.__cache_func = self.__cache
else:
self.__cache_func = self.__pass_cache
self.empirical_denom = empirical_denom

def __cache(self, traj_rewards):
self.traj_rewards_cache = traj_rewards

def __pass_cache(self, traj_rewards):
pass

@jaxtyped(typechecker=typechecker)
def predict(
self,
rewards:List[torch.Tensor],
states:List[torch.Tensor],
actions:List[torch.Tensor],
weights:torch.Tensor,
discount:float,
is_msk:torch.Tensor
)->torch.Tensor:
l_s = len(states)
l_r = len(rewards)
l_a = len(actions)
_msg = f"State({l_s}), rewards({l_r}), actions({l_a}), should be equal"
assert l_s==l_r==l_a, _msg
assert isinstance(weights,torch.Tensor)
assert isinstance(discount,float)
assert isinstance(is_msk,torch.Tensor)
traj_rewards = self.predict_traj_rewards(
rewards=rewards, states=states, actions=actions, weights=weights,
discount=discount, is_msk=is_msk
)
self.__cache_func(traj_rewards)
denom = self.empirical_denom(
weights=weights,
is_msk=is_msk
)
return traj_rewards.sum()/denom

@abstractmethod
def predict_traj_rewards(
self,
rewards:List[torch.Tensor],
states:List[torch.Tensor],
actions:List[torch.Tensor],
weights:WeightTensor,
discount:float,
is_msk:WeightTensor
)->Float[torch.Tensor, "n_trajectories"]:
"""Function for subclasses to override defining the trajectory level
estimates of return
Args:
rewards (List[torch.Tensor]): List of Tensors of undiscounted
rewards of dimension (max horizon, 1). Trajectories with
length < max_horizon should have zero weight imputed
states (List[torch.Tensor]): List of Tensors of state values. Should
be of dimension (traj horizon, state features)
actions (List[torch.Tensor]): List of Tensors of action values.
Should be of dimension (traj horizon, action features)
weights (torch.Tensor): Tensor of IS weights of dimension
(# trajectories, max_horizon). Trajectories with length <
max_horizon should have zero weight imputed
discount (float): One step discount factor
is_msk (torch.Tensor): Tensor of dimension
(# trajectories, max_horizon) defining the lengths of individual
trajectories
Returns:
torch.Tensor: tensor of size (# trajectories,) defining the
individual trajectory rewards
"""
pass
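The empirical_denom callable used in predict above determines whether the final estimate is an ordinary or a weighted (self-normalised) importance-sampling average. Two illustrative denominators matching the (weights, is_msk) call signature used in predict; these are sketches under that assumption, not the library's actual EmpiricalMeanDenom implementations:

import torch

class NumTrajectoriesDenom:
    """Ordinary IS: divide the summed trajectory returns by n_trajectories."""
    def __call__(self, weights: torch.Tensor, is_msk: torch.Tensor) -> torch.Tensor:
        return torch.tensor(float(weights.shape[0]))

class SumWeightsDenom:
    """Weighted (self-normalised) IS: divide by the total weight mass instead."""
    def __call__(self, weights: torch.Tensor, is_msk: torch.Tensor) -> torch.Tensor:
        return (weights * is_msk).sum()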
