From c0ffaa0c22291f16c1eef2f2ad168d0d6e50cf2f Mon Sep 17 00:00:00 2001 From: Joshua Spear Date: Sun, 29 Sep 2024 16:06:59 +0100 Subject: [PATCH] added weight standard deviation metric --- src/offline_rl_ope/Metrics/MetricBase.py | 2 +- .../Metrics/ValidWeightsProp.py | 1 - src/offline_rl_ope/Metrics/WeightStd.py | 21 +++++++++++++ src/offline_rl_ope/Metrics/__init__.py | 1 + tests/Metrics/test_ValidWeightsProp.py | 2 +- tests/Metrics/test_WeightStd.py | 31 +++++++++++++++++++ 6 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 src/offline_rl_ope/Metrics/WeightStd.py create mode 100644 tests/Metrics/test_WeightStd.py diff --git a/src/offline_rl_ope/Metrics/MetricBase.py b/src/offline_rl_ope/Metrics/MetricBase.py index ba7543b..95d4308 100644 --- a/src/offline_rl_ope/Metrics/MetricBase.py +++ b/src/offline_rl_ope/Metrics/MetricBase.py @@ -11,4 +11,4 @@ class MetricBase(metaclass=ABCMeta): @abstractmethod def __call__(self, weights:torch.Tensor, *args:Any, **kwargs:Any) -> float: - return self.__ess(weights=weights) \ No newline at end of file + pass \ No newline at end of file diff --git a/src/offline_rl_ope/Metrics/ValidWeightsProp.py b/src/offline_rl_ope/Metrics/ValidWeightsProp.py index 91f7c63..5ecddf5 100644 --- a/src/offline_rl_ope/Metrics/ValidWeightsProp.py +++ b/src/offline_rl_ope/Metrics/ValidWeightsProp.py @@ -1,6 +1,5 @@ import torch from .MetricBase import MetricBase -from ..OPEEstimators.utils import get_traj_weight_final from jaxtyping import jaxtyped from typeguard import typechecked as typechecker diff --git a/src/offline_rl_ope/Metrics/WeightStd.py b/src/offline_rl_ope/Metrics/WeightStd.py new file mode 100644 index 0000000..720de45 --- /dev/null +++ b/src/offline_rl_ope/Metrics/WeightStd.py @@ -0,0 +1,21 @@ +import torch +from .MetricBase import MetricBase +from jaxtyping import jaxtyped +from typeguard import typechecked as typechecker + +from ..types import WeightTensor + +__all__ = [ + "WeightStd" +] + +class WeightStd(MetricBase): + + @jaxtyped(typechecker=typechecker) + def __call__( + self, + weights:WeightTensor, + weight_msk:WeightTensor + ) -> float: + sum_weights = torch.mul(weights,weight_msk).sum(dim=1) + return torch.std(sum_weights).item() \ No newline at end of file diff --git a/src/offline_rl_ope/Metrics/__init__.py b/src/offline_rl_ope/Metrics/__init__.py index 0eebf82..3e2ed5c 100644 --- a/src/offline_rl_ope/Metrics/__init__.py +++ b/src/offline_rl_ope/Metrics/__init__.py @@ -1,3 +1,4 @@ from .EffectiveSampleSize import * from .ValidWeightsProp import * +from .WeightStd import * from .MetricBase import * \ No newline at end of file diff --git a/tests/Metrics/test_ValidWeightsProp.py b/tests/Metrics/test_ValidWeightsProp.py index 7b97f31..85342b6 100644 --- a/tests/Metrics/test_ValidWeightsProp.py +++ b/tests/Metrics/test_ValidWeightsProp.py @@ -17,7 +17,7 @@ def test_call(self): min_val=0.000001 fnl_weights = [] for idx,i in enumerate(self.test_conf.traj_lengths): - fnl_weights.append(self.test_conf.weight_test_res[idx,:i-1].sum( + fnl_weights.append(self.test_conf.weight_test_res[idx,:i].sum( dim=0, keepdim=True )) diff --git a/tests/Metrics/test_WeightStd.py b/tests/Metrics/test_WeightStd.py new file mode 100644 index 0000000..dd59836 --- /dev/null +++ b/tests/Metrics/test_WeightStd.py @@ -0,0 +1,31 @@ +import unittest +import torch +import numpy as np +import copy +from offline_rl_ope.Metrics import WeightStd +from offline_rl_ope import logger +from parameterized import parameterized_class +from ..base import test_configs_fmt_class, TestConfig + +@parameterized_class(test_configs_fmt_class) +class TestWeightStd(unittest.TestCase): + + test_conf:TestConfig + + def test_call(self): + fnl_weights = [] + for idx,i in enumerate(self.test_conf.traj_lengths): + fnl_weights.append( + self.test_conf.weight_test_res[idx,:i].sum( + dim=0, + keepdim=True + ) + ) + fnl_weights_tens = torch.concat(fnl_weights, axis=0) + act_res = torch.std(fnl_weights_tens).item() + metric = WeightStd() + pred_res = metric( + weights=self.test_conf.weight_test_res, + weight_msk=self.test_conf.msk_test_res + ) + self.assertEqual(act_res,pred_res) \ No newline at end of file