Skip to content

Commit

Permalink
added weight standard deviation metric
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed Sep 29, 2024
1 parent aef7618 commit c0ffaa0
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/offline_rl_ope/Metrics/MetricBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ class MetricBase(metaclass=ABCMeta):

@abstractmethod
def __call__(self, weights:torch.Tensor, *args:Any, **kwargs:Any) -> float:
return self.__ess(weights=weights)
pass
1 change: 0 additions & 1 deletion src/offline_rl_ope/Metrics/ValidWeightsProp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from .MetricBase import MetricBase
from ..OPEEstimators.utils import get_traj_weight_final
from jaxtyping import jaxtyped
from typeguard import typechecked as typechecker

Expand Down
21 changes: 21 additions & 0 deletions src/offline_rl_ope/Metrics/WeightStd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from .MetricBase import MetricBase
from jaxtyping import jaxtyped
from typeguard import typechecked as typechecker

from ..types import WeightTensor

__all__ = [
"WeightStd"
]

class WeightStd(MetricBase):

@jaxtyped(typechecker=typechecker)
def __call__(
self,
weights:WeightTensor,
weight_msk:WeightTensor
) -> float:
sum_weights = torch.mul(weights,weight_msk).sum(dim=1)
return torch.std(sum_weights).item()
1 change: 1 addition & 0 deletions src/offline_rl_ope/Metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .EffectiveSampleSize import *
from .ValidWeightsProp import *
from .WeightStd import *
from .MetricBase import *
2 changes: 1 addition & 1 deletion tests/Metrics/test_ValidWeightsProp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_call(self):
min_val=0.000001
fnl_weights = []
for idx,i in enumerate(self.test_conf.traj_lengths):
fnl_weights.append(self.test_conf.weight_test_res[idx,:i-1].sum(
fnl_weights.append(self.test_conf.weight_test_res[idx,:i].sum(
dim=0,
keepdim=True
))
Expand Down
31 changes: 31 additions & 0 deletions tests/Metrics/test_WeightStd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
import torch
import numpy as np
import copy
from offline_rl_ope.Metrics import WeightStd
from offline_rl_ope import logger
from parameterized import parameterized_class
from ..base import test_configs_fmt_class, TestConfig

@parameterized_class(test_configs_fmt_class)
class TestWeightStd(unittest.TestCase):

test_conf:TestConfig

def test_call(self):
fnl_weights = []
for idx,i in enumerate(self.test_conf.traj_lengths):
fnl_weights.append(
self.test_conf.weight_test_res[idx,:i].sum(
dim=0,
keepdim=True
)
)
fnl_weights_tens = torch.concat(fnl_weights, axis=0)
act_res = torch.std(fnl_weights_tens).item()
metric = WeightStd()
pred_res = metric(
weights=self.test_conf.weight_test_res,
weight_msk=self.test_conf.msk_test_res
)
self.assertEqual(act_res,pred_res)

0 comments on commit c0ffaa0

Please sign in to comment.