# Copyright (c) 2023, Haruka Kiyohara, Ren Kishimoto, HAKUHODO Technologies Inc., and Hanjuku-kaso Co., Ltd. All rights reserved.
# Licensed under the Apache 2.0 License.
"""Meta class to handle Off-Policy Selection (OPS) and evaluation of OPE/OPS."""
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, Union, List, Dict
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.metrics import mean_squared_error
from sklearn.utils import check_scalar
import matplotlib.pyplot as plt
from .ope import (
OffPolicyEvaluation,
CumulativeDistributionOPE,
)
from ..utils import (
MultipleInputDict,
estimate_confidence_interval_by_bootstrap,
estimate_confidence_interval_by_hoeffding,
estimate_confidence_interval_by_empirical_bernstein,
estimate_confidence_interval_by_t_test,
defaultdict_to_dict,
)
from ..types import OPEInputDict
markers = ["o", "v", "^", "s", "p", "P", "*", "h", "X", "D", "d"]
dkred = "#A60628"
@dataclass
class OffPolicySelection:
"""Class to conduct OPS and evaluation of OPE/OPS with multiple estimators simultaneously.
Imported as: :class:`scope_rl.ope.OffPolicySelection`
Note
-----------
**Off-Policy Selection (OPS)**
OPS selects the "best" policy among several candidates based on the policy value or other statistics estimated by OPE.
.. math::
\\hat{\\pi} := {\\arg \\max}_{\\pi \\in \\Pi} \\hat{J}(\\pi)
where :math:`\\Pi` is a set of candidate policies and :math:`\\hat{J}(\\cdot)` is some OPE estimate of the policy performance. Below, we describe the two types of OPE used to estimate such policy performance.
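As intuition, the selection step itself reduces to an argmax over whichever estimates are used. The following is a minimal sketch with made-up numbers, not the actual API of this class.
.. code-block:: python
# hypothetical policy values produced by some OPE estimator
estimated_value = {"ddqn": 18.0, "random": 7.1}
# OPS picks the candidate with the largest estimate
best_policy = max(estimated_value, key=estimated_value.get)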
**Off-Policy Evaluation (OPE)**
(Basic) OPE estimates the expected policy performance called the policy value.
.. math::
V(\\pi) := \\mathbb{E} \\left[ \\sum_{t=1}^T \\gamma^{t-1} r_t \\mid \\pi \\right]
where :math:`r_t` is the reward observed at each timestep :math:`t`,
:math:`T` is the total number of timesteps in an episode, and :math:`\\gamma` is the discount factor.
.. seealso::
:class:`OffPolicyEvaluation`
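As a quick sanity check of this definition, the policy value can be approximated by averaging discounted returns over sampled trajectories. The sketch below uses made-up reward arrays rather than the logged dataset format of this library.
.. code-block:: python
import numpy as np
gamma = 0.99
rewards = np.random.rand(100, 24)  # hypothetical (n_trajectories, T) rewards
discount = np.full(rewards.shape[1], gamma).cumprod() / gamma  # gamma^{t-1} for t = 1, ..., T
policy_value = (discount[np.newaxis, :] * rewards).sum(axis=1).mean()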
**Cumulative Distribution OPE**
In contrast, cumulative distribution OPE first estimates the following cumulative distribution function.
.. math::
F(m, \\pi) := \\mathbb{E} \\left[ \\mathbb{I} \\left \\{ \\sum_{t=1}^T \\gamma^{t-1} r_t \\leq m \\right \\} \\mid \\pi \\right]
where :math:`m` is a threshold on the (discounted) trajectory-wise reward.
Then, cumulative distribution OPE also estimates some risk functions including variance, conditional value at risk, and interquartile range based on the CDF estimate.
.. seealso::
:class:`CumulativeDistributionOPE`
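The CDF and the derived risk functions can be computed directly from trajectory-wise rewards. A minimal sketch, assuming the returns are available as a plain NumPy array rather than through the estimator interface of this library:
.. code-block:: python
import numpy as np
# hypothetical trajectory-wise (discounted) rewards of an evaluation policy
returns = np.array([10.0, 12.5, 3.0, 8.0, 15.0])
# empirical CDF evaluated at a threshold m
m = 9.0
cdf_at_m = (returns <= m).mean()
# conditional value at risk (CVaR) at alpha: mean of the lowest alpha-fraction of returns
alpha = 0.4
cvar = np.sort(returns)[: int(len(returns) * alpha)].mean()
# lower quartile (quantile at alpha) of the return distribution
lower_quartile = np.quantile(returns, q=alpha)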
Parameters
-----------
ope: OffPolicyEvaluation, default=None
Instance of the (standard) OPE class.
cumulative_distribution_ope: CumulativeDistributionOPE, default=None
Instance of the cumulative distribution OPE class.
Examples
----------
Preparation:
.. code-block:: python
# import necessary module from SCOPE-RL
from scope_rl.dataset import SyntheticDataset
from scope_rl.policy import EpsilonGreedyHead
from scope_rl.ope import CreateOPEInput
from scope_rl.ope import OffPolicySelection
from scope_rl.ope import OffPolicyEvaluation as OPE
from scope_rl.ope.discrete import TrajectoryWiseImportanceSampling as TIS
from scope_rl.ope.discrete import PerDecisionImportanceSampling as PDIS
from scope_rl.ope import CumulativeDistributionOPE
from scope_rl.ope.discrete import CumulativeDistributionTIS as CD_IS
from scope_rl.ope.discrete import CumulativeDistributionSNTIS as CD_SNIS
# import necessary module from other libraries
import gym
import rtbgym
from d3rlpy.algos import DoubleDQNConfig
from d3rlpy.dataset import create_fifo_replay_buffer
from d3rlpy.algos import ConstantEpsilonGreedy
# initialize environment
env = gym.make("RTBEnv-discrete-v0")
# define (RL) agent (i.e., policy) and train on the environment
ddqn = DoubleDQNConfig().create()
buffer = create_fifo_replay_buffer(
limit=10000,
env=env,
)
explorer = ConstantEpsilonGreedy(
epsilon=0.3,
)
ddqn.fit_online(
env=env,
buffer=buffer,
explorer=explorer,
n_steps=10000,
n_steps_per_epoch=1000,
)
# convert ddqn policy to stochastic data collection policy
behavior_policy = EpsilonGreedyHead(
ddqn,
n_actions=env.action_space.n,
epsilon=0.3,
name="ddqn_epsilon_0.3",
random_state=12345,
)
# initialize dataset class
dataset = SyntheticDataset(
env=env,
max_episode_steps=env.step_per_episode,
)
# data collection
logged_dataset = dataset.obtain_episodes(
behavior_policies=behavior_policy,
n_trajectories=100,
random_state=12345,
)
Create Input for OPE:
.. code-block:: python
# evaluation policy
ddqn_ = EpsilonGreedyHead(
base_policy=ddqn,
n_actions=env.action_space.n,
name="ddqn",
epsilon=0.0,
random_state=12345
)
random_ = EpsilonGreedyHead(
base_policy=ddqn,
n_actions=env.action_space.n,
name="random",
epsilon=1.0,
random_state=12345
)
# create input for off-policy evaluation (OPE)
prep = CreateOPEInput(
env=env,
)
input_dict = prep.obtain_whole_inputs(
logged_dataset=logged_dataset,
evaluation_policies=[ddqn_, random_],
n_trajectories_on_policy_evaluation=100,
random_state=12345,
)
**Off-Policy Evaluation and Selection**:
.. code-block:: python
# OPS
ope = OPE(
logged_dataset=logged_dataset,
ope_estimators=[TIS(), PDIS()],
)
cd_ope = CumulativeDistributionOPE(
logged_dataset=logged_dataset,
ope_estimators=[
CD_IS(estimator_name="cd_is"),
CD_SNIS(estimator_name="cd_snis"),
],
)
ops = OffPolicySelection(
ope=ope,
cumulative_distribution_ope=cd_ope,
)
ops_dict = ops.select_by_policy_value(
input_dict=input_dict,
return_metrics=True,
)
**Output**:
.. code-block:: python
>>> ops_dict
{'tis': {'estimated_ranking': ['ddqn', 'random'],
'estimated_policy_value': array([21.3624954, 0.3827044]),
'estimated_relative_policy_value': array([1.44732354, 0.02592848]),
'mean_squared_error': 94.79587393975419,
'rank_correlation': SpearmanrResult(correlation=0.9999999999999999, pvalue=nan),
'regret': (0.0, 1),
'type_i_error_rate': 0.0,
'type_ii_error_rate': 0.0,
'safety_threshold': 13.284},
'pdis': {'estimated_ranking': ['ddqn', 'random'],
'estimated_policy_value': array([18.02806424, 7.13847486]),
'estimated_relative_policy_value': array([1.22141357, 0.48363651]),
'mean_squared_error': 19.45349619733373,
'rank_correlation': SpearmanrResult(correlation=0.9999999999999999, pvalue=nan),
'regret': (0.0, 1),
'type_i_error_rate': 0.0,
'type_ii_error_rate': 0.0,
'safety_threshold': 13.284}}
.. seealso::
* :doc:`Quickstart </documentation/quickstart>`
* :doc:`Related tutorials (OPS) </documentation/examples/ops>` and :doc:`related tutorials (assessments) </documentation/examples/assessments>`
References
-------
Vladislav Kurenkov and Sergey Kolesnikov.
"Showing Your Offline Reinforcement Learning Work: Online Evaluation Budget Matters." 2022.
Shengpu Tang and Jenna Wiens.
"Model Selection for Offline Reinforcement Learning: Practical Considerations for Healthcare Settings." 2021.
Justin Fu, Mohammad Norouzi, Ofir Nachum, George Tucker, Ziyu Wang, Alexander Novikov, Mengjiao Yang,
Michael R. Zhang, Yutian Chen, Aviral Kumar, Cosmin Paduraru, Sergey Levine, and Tom Le Paine.
"Benchmarks for Deep Off-Policy Evaluation." 2021.
Tom Le Paine, Cosmin Paduraru, Andrea Michi, Caglar Gulcehre, Konrad Zolna, Alexander Novikov, Ziyu Wang, and Nando de Freitas.
"Hyperparameter Selection for Offline Reinforcement Learning." 2020.
"""
ope: Optional[OffPolicyEvaluation] = None
cumulative_distribution_ope: Optional[CumulativeDistributionOPE] = None
def __post_init__(self):
if self.ope is None and self.cumulative_distribution_ope is None:
raise RuntimeError(
"one of `ope` or `cumulative_distribution_ope` must be given"
)
if self.ope is not None and not isinstance(self.ope, OffPolicyEvaluation):
raise RuntimeError("ope must be the instance of OffPolicyEvaluation")
if self.cumulative_distribution_ope is not None and not isinstance(
self.cumulative_distribution_ope, CumulativeDistributionOPE
):
raise RuntimeError(
"cumulative_distribution_ope must be the instance of CumulativeDistributionOPE"
)
self.step_per_trajectory = self.ope.logged_dataset["step_per_trajectory"]
check_scalar(
self.step_per_trajectory,
name="ope.logged_dataset['step_per_trajectory']",
target_type=int,
min_val=1,
)
self.behavior_policy_reward = {}
if self.ope.use_multiple_logged_dataset:
for (
behavior_policy
) in self.ope.multiple_logged_dataset.behavior_policy_names:
logged_dataset_ = self.ope.multiple_logged_dataset.get(
behavior_policy_name=behavior_policy, dataset_id=0
)
self.behavior_policy_reward[behavior_policy] = logged_dataset_[
"reward"
].reshape((-1, self.step_per_trajectory))
if self.ope.disable_reward_after_done:
done = logged_dataset_["done"].reshape(
(-1, self.step_per_trajectory)
)
self.behavior_policy_reward[
behavior_policy
] = self.behavior_policy_reward[behavior_policy] * (
1 - done
).cumprod(
axis=1
)
else:
behavior_policy = self.ope.logged_dataset["behavior_policy"]
self.behavior_policy_reward[behavior_policy] = self.ope.logged_dataset[
"reward"
].reshape((-1, self.step_per_trajectory))
if self.ope.disable_reward_after_done:
done = self.ope.logged_dataset["done"].reshape(
(-1, self.step_per_trajectory)
)
self.behavior_policy_reward[
behavior_policy
] = self.behavior_policy_reward[behavior_policy] * (1 - done).cumprod(
axis=1
)
self._estimate_confidence_interval = {
"bootstrap": estimate_confidence_interval_by_bootstrap,
"hoeffding": estimate_confidence_interval_by_hoeffding,
"bernstein": estimate_confidence_interval_by_empirical_bernstein,
"ttest": estimate_confidence_interval_by_t_test,
}
def _check_compared_estimators(
self,
compared_estimators: Optional[List[str]] = None,
ope_type: str = "standard_ope",
):
if ope_type == "standard_ope":
if self.ope is None:
raise RuntimeError(
"ope is not given. Please initialize the class with ope attribute"
)
else:
if self.cumulative_distribution_ope is None:
raise RuntimeError(
"cumulative_distribution_ope is not given. Please initialize the class with cumulative_distribution_ope attribute"
)
if compared_estimators is None:
compared_estimators = self.estimators_name[ope_type]
elif not set(compared_estimators).issubset(self.estimators_name[ope_type]):
raise ValueError(
f"compared_estimators must be a subset of self.estimators_name['{ope_type}'], but found False."
)
return compared_estimators
def _check_basic_visualization_inputs(
self,
n_cols: Optional[int] = None,
fig_dir: Optional[Path] = None,
fig_name: Optional[str] = None,
):
if n_cols is not None:
check_scalar(n_cols, name="n_cols", target_type=int, min_val=1)
if fig_dir is not None and not isinstance(fig_dir, Path):
raise ValueError(f"fig_dir must be a Path, but {type(fig_dir)} is given")
if fig_name is not None and not isinstance(fig_name, str):
raise ValueError(f"fig_dir must be a string, but {type(fig_dir)} is given")
def _check_topk_inputs(
self,
input_dict: Union[OPEInputDict, MultipleInputDict],
behavior_policy_name: Optional[str] = None,
dataset_id: Optional[int] = None,
max_topk: Optional[int] = None,
metrics: Optional[List[str]] = None,
safety_threshold: Optional[float] = None,
relative_safety_criteria: Optional[float] = None,
gamma: Optional[float] = None,
):
if isinstance(input_dict, MultipleInputDict):
max_topk_ = 100
if behavior_policy_name is None:
if dataset_id is None:
for n_eval_policies in input_dict.n_eval_policies.values():
max_topk_ = min(max_topk_, n_eval_policies.min())
else:
for n_eval_policies in input_dict.n_eval_policies.values():
max_topk_ = min(max_topk_, n_eval_policies[dataset_id])
else:
if dataset_id is None:
max_topk_ = min(
max_topk_,
input_dict.n_eval_policies[behavior_policy_name].min(),
)
else:
max_topk_ = input_dict.n_eval_policies[behavior_policy_name][
dataset_id
]
else:
behavior_policy_name = input_dict[list(input_dict.keys())[0]][
"behavior_policy"
]
max_topk_ = len(input_dict)
if max_topk is None:
max_topk = int(max_topk_)
else:
check_scalar(max_topk, name="max_topk", target_type=int, min_val=1)
max_topk = min(max_topk, max_topk_)
if metrics is not None:
for metric in metrics:
if metric not in [
"k-th",
"best",
"worst",
"mean",
"std",
"safety_violation_rate",
"sharpe_ratio",
]:
raise ValueError(
f"The elements of metrics must be one of 'k-th', 'best', 'worst', 'mean', 'std', 'safety_violation_rate', or 'sharpe_ratio', but {metric} is given."
)
if safety_threshold is None:
if relative_safety_criteria is not None:
check_scalar(
relative_safety_criteria,
name="relative_safety_criteria",
target_type=float,
min_val=0.0,
)
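# discount[t] = gamma^t for t = 0, ..., T-1; used below to compute the discounted trajectory-wise reward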
discount = np.full(self.step_per_trajectory, gamma).cumprod() / gamma
if behavior_policy_name is not None:
behavior_policy_reward = self.behavior_policy_reward[
behavior_policy_name
]
behavior_policy_value = (
discount[np.newaxis, :] * behavior_policy_reward
).sum(
axis=1
).mean() + 1e-10 # to avoid zero division
safety_threshold = relative_safety_criteria * behavior_policy_value
safety_threshold = float(safety_threshold)
elif len(self.behavior_policy_reward) == 1:
behavior_policy_reward = list(self.behavior_policy_reward.values())[
0
]
behavior_policy_value = (
discount[np.newaxis, :] * behavior_policy_reward
).sum(
axis=1
).mean() + 1e-10 # to avoid zero division
safety_threshold = relative_safety_criteria * behavior_policy_value
safety_threshold = float(safety_threshold)
else:
safety_threshold = 0.0
else:
safety_threshold = 0.0
check_scalar(
safety_threshold,
name="safety_threshold",
target_type=float,
)
return max_topk, safety_threshold
def _obtain_true_selection_result(
self,
input_dict: OPEInputDict,
return_variance: bool = False,
return_lower_quartile: bool = False,
return_conditional_value_at_risk: bool = False,
return_by_dataframe: bool = False,
quartile_alpha: float = 0.05,
cvar_alpha: float = 0.05,
):
"""Obtain the oracle selection result based on the ground-truth policy value.
Parameters
-------
input_dict: OPEInputDict
Dictionary of the OPE inputs for each evaluation policy.
.. code-block:: python
key: [evaluation_policy][
evaluation_policy_action,
evaluation_policy_action_dist,
state_action_value_prediction,
initial_state_value_prediction,
state_action_marginal_importance_weight,
state_marginal_importance_weight,
on_policy_policy_value,
gamma,
behavior_policy,
evaluation_policy,
dataset_id,
]
.. seealso::
:class:`scope_rl.ope.input.CreateOPEInput` describes the components of :class:`input_dict`.
return_variance: bool, default=False
Whether to return the variance or not.
return_lower_quartile: bool, default=False
Whether to return the lower interquartile or not.
return_conditional_value_at_risk: bool, default=False
Whether to return the conditional value at risk or not.
return_by_dataframe: bool, default=False
Whether to return the result in a dataframe format.
quartile_alpha: float, default=0.05
Proportion of the shaded region of the interquartile range.
cvar_alpha: float, default=0.05
Proportion of the shaded region of the conditional value at risk.
Return
-------
ground_truth_dict/ground_truth_df: dict or dataframe
Dictionary/dataframe containing the following ground-truth (on-policy) metrics.
.. code-block:: python
key: [
ranking,
policy_value,
relative_policy_value,
variance,
ranking_by_lower_quartile,
lower_quartile,
ranking_by_conditional_value_at_risk,
conditional_value_at_risk,
parameters, # only when return_by_dataframe == False
]
ranking: list of str
Name of the candidate policies sorted by the ground-truth policy value.
policy_value: list of float
Ground-truth policy value of the candidate policies (sorted by ranking).
relative_policy_value: list of float
Ground-truth relative policy value of the candidate policies compared to the behavior policy (sorted by ranking).
variance: list of float
Ground-truth variance of the trajectory-wise reward of the candidate policies (sorted by ranking).
If return_variance is `False`, `None` is recorded.
ranking_by_lower_quartile: list of str
Name of the candidate policies sorted by the ground-truth lower quartile of the trajectory-wise reward.
If return_lower_quartile is `False`, `None` is recorded.
lower_quartile: list of float
Ground-truth lower quartile of the candidate policies (sorted by ranking_by_lower_quartile).
If return_lower_quartile is `False`, `None` is recorded.
ranking_by_conditional_value_at_risk: list of str
Name of the candidate policies sorted by the ground-truth conditional value at risk.
If return_conditional_value_at_risk is `False`, `None` is recorded.
conditional_value_at_risk: list of float
Ground-truth conditional value at risk of the candidate policies (sorted by ranking_by_conditional_value_at_risk).
If return_conditional_value_at_risk is `False`, `None` is recorded.
parameters: dict
Dictionary containing quartile_alpha and cvar_alpha.
If return_by_dataframe is `True`, parameters will not be returned.
"""
candidate_policy_names = list(input_dict.keys())
for eval_policy in candidate_policy_names:
if input_dict[eval_policy]["on_policy_policy_value"] is None:
raise ValueError(
f"one of the candidate policies, {eval_policy}, does not contain on-policy policy value in input_dict"
)
behavior_policy = input_dict[eval_policy]["behavior_policy"]
n_policies = len(candidate_policy_names)
n_samples = len(input_dict[eval_policy]["on_policy_policy_value"])
policy_value = np.zeros(n_policies)
for i, eval_policy in enumerate(candidate_policy_names):
policy_value[i] = input_dict[eval_policy]["on_policy_policy_value"].mean()
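# sort candidate policies in descending order of the ground-truth policy value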
ranking_index = np.argsort(policy_value)[::-1]
ranking = [candidate_policy_names[ranking_index[i]] for i in range(n_policies)]
gamma = input_dict[eval_policy]["gamma"]
discount = np.full(self.step_per_trajectory, gamma).cumprod() / gamma
behavior_policy_reward = self.behavior_policy_reward[behavior_policy]
behavior_policy_value = (discount[np.newaxis, :] * behavior_policy_reward).sum(
axis=1
).mean() + 1e-10 # to avoid zero division
policy_value = np.sort(policy_value)[::-1]
relative_policy_value = policy_value / behavior_policy_value
if return_variance:
variance = np.zeros(n_policies)
for i, eval_policy in enumerate(candidate_policy_names):
variance[i] = input_dict[eval_policy]["on_policy_policy_value"].var(
ddof=1
)
variance = variance[ranking_index]
if return_lower_quartile:
lower_quartile = np.zeros(n_policies)
for i, eval_policy in enumerate(candidate_policy_names):
lower_quartile[i] = np.quantile(
input_dict[eval_policy]["on_policy_policy_value"], q=quartile_alpha
)
quartile_ranking_index = np.argsort(lower_quartile)[::-1]
ranking_by_lower_quartile = [
candidate_policy_names[quartile_ranking_index[i]]
for i in range(n_policies)
]
lower_quartile = np.sort(lower_quartile)[::-1]
if return_conditional_value_at_risk:
cvar = np.zeros(n_policies)
for i, eval_policy in enumerate(candidate_policy_names):
cvar[i] = np.sort(input_dict[eval_policy]["on_policy_policy_value"])[
: int(n_samples * cvar_alpha)
].mean()
cvar_ranking_index = np.argsort(cvar)[::-1]
ranking_by_cvar = [
candidate_policy_names[cvar_ranking_index[i]] for i in range(n_policies)
]
cvar = np.sort(cvar)[::-1]
ground_truth_dict = {
"ranking": ranking,
"policy_value": policy_value,
"relative_policy_value": relative_policy_value,
"variance": variance if return_variance else None,
"ranking_by_lower_quartile": ranking_by_lower_quartile
if return_lower_quartile
else None,
"lower_quartile": lower_quartile if return_lower_quartile else None,
"ranking_by_conditional_value_at_risk": ranking_by_cvar
if return_conditional_value_at_risk
else None,
"conditional_value_at_risk": cvar
if return_conditional_value_at_risk
else None,
"parameters": {
"quartile_alpha": quartile_alpha if return_lower_quartile else None,
"cvar_alpha": cvar_alpha if return_conditional_value_at_risk else None,
},
}
if return_by_dataframe:
ground_truth_df = pd.DataFrame()
for key in ground_truth_dict.keys():
if ground_truth_dict[key] is None or key == "parameters":
continue
ground_truth_df[key] = ground_truth_dict[key]
return ground_truth_df if return_by_dataframe else ground_truth_dict
def _select_by_policy_value(
self,
input_dict: OPEInputDict,
compared_estimators: Optional[List[str]] = None,
return_true_values: bool = False,
return_metrics: bool = False,
return_by_dataframe: bool = False,
top_k_in_eval_metrics: int = 1,
safety_threshold: Optional[float] = None,
relative_safety_criteria: Optional[float] = None,
):
"""Rank the candidate policies by their estimated policy values.
Parameters
-------
input_dict: OPEInputDict
Dictionary of the OPE inputs for each evaluation policy.
.. code-block:: python
key: [evaluation_policy][
evaluation_policy_action,
evaluation_policy_action_dist,
state_action_value_prediction,
initial_state_value_prediction,
state_action_marginal_importance_weight,
state_marginal_importance_weight,
on_policy_policy_value,
gamma,
behavior_policy,
evaluation_policy,
dataset_id,
]
.. seealso::
:class:`scope_rl.ope.input.CreateOPEInput` describes the components of :class:`input_dict`.
compared_estimators: list of str, default=None
Name of compared estimators.
When `None` is given, all the estimators are compared.
return_true_values: bool, default=False
Whether to return the true policy value and corresponding ranking of the candidate policies.
return_metrics: bool, default=False
Whether to return the following evaluation metrics in terms of OPE and OPS:
mean-squared-error, rank-correlation, regret@k, and Type I and Type II error rate.
return_by_dataframe: bool, default=False
Whether to return the result in a dataframe format.
top_k_in_eval_metrics: int, default=1
How many candidate policies are included in regret@k.
safety_threshold: float, default=None.
A policy whose policy value is below the given threshold is to be considered unsafe.
relative_safety_criteria: float, default=None (>= 0)
The relative policy value required to be considered a safe policy.
For example, when 0.9 is given, a candidate policy must attain at least 90\\% of the behavior policy's performance to be considered safe.
Only applicable when using a single behavior policy.
Return
-------
ops_dict/(ranking_df_dict, metric_df): dict or dataframe
Dictionary/dataframe containing the result of OPS conducted by OPE estimators.
.. code-block:: python
key: [estimator_name][
estimated_ranking,
estimated_policy_value,
estimated_relative_policy_value,
true_ranking,
true_policy_value,
true_relative_policy_value,
mean_squared_error,
rank_correlation,
regret,
type_i_error_rate,
type_ii_error_rate,
]
estimated_ranking: list of str
Name of the candidate policies sorted by the estimated policy value.
Recorded in ranking_df_dict if return_by_dataframe is `True`.
estimated_policy_value: list of float
Estimated policy value of the candidate policies (sorted by estimated_ranking).
Recorded in ranking_df_dict if return_by_dataframe is `True`.
estimated_relative_policy_value: list of float
Estimated relative policy value of the candidate policies compared to the behavior policy (sorted by estimated_ranking).
Recorded in ranking_df_dict if return_by_dataframe is `True`.
true_ranking: list of int
Ranking index of the (true) policy value of the candidate policies (sorted by estimated_ranking).
Recorded only when return_true_values is `True`.
Recorded in ranking_df_dict if return_by_dataframe is `True`.
true_policy_value: list of float
True policy value of the candidate policies (sorted by estimated_ranking).
Recorded only when return_true_values is `True`.
Recorded in ranking_df_dict when return_by_dataframe is `True`.
true_relative_policy_value: list of float
True relative policy value of the candidate policies compared to the behavior policy (sorted by estimated_ranking).
Recorded only when return_true_values is `True`.
Recorded in ranking_df_dict if return_by_dataframe is `True`.
mean_squared_error: float
Mean-squared-error of the estimators calculated across candidate evaluation policies.
Recorded only when return_metrics is `True`.
Recorded in metric_df if return_by_dataframe is `True`.
rank_correlation: tuple of float
Rank correlation coefficient between the true ranking and the estimated ranking, and its pvalue.
Recorded only when return_metrics is `True`.
Recorded in metric_df if return_by_dataframe is `True`.
regret: tuple of float and int
Regret@k and k.
Recorded only when return_metrics is `True`.
Recorded in metric_df if return_by_dataframe is `True`.
type_i_error_rate: float
Type I error rate of the hypothesis test, i.e., the rate at which a policy that is actually safe is estimated to be unsafe.
Recorded only when return_metrics is `True`.
Recorded in metric_df if return_by_dataframe is `True`.
type_ii_error_rate: float
Type II error rate of the hypothesis test, i.e., the rate at which a policy that is actually unsafe goes undetected (estimated to be safe).
Recorded only when return_metrics is `True`.
Recorded in metric_df if return_by_dataframe is `True`.
safety_threshold: float
A policy whose policy value is below the given threshold is to be considered unsafe.
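The evaluation metrics above follow standard definitions. A rough sketch of regret@k and the safety-based error rates, using plain NumPy arrays rather than this method's internal variables:
.. code-block:: python
import numpy as np
true_value = np.array([20.0, 12.0, 5.0])       # true policy values, sorted by the true ranking
estimated_value = np.array([9.0, 22.0, 4.0])   # hypothetical OPE estimates for the same policies
estimated_ranking_index = np.argsort(estimated_value)[::-1]
k = 1
# regret@k: best achievable top-k value minus the top-k value under the estimated ranking
regret_at_k = true_value[:k].sum() - true_value[estimated_ranking_index][:k].sum()
# safety judgments with respect to a threshold
safety_threshold = 10.0
true_safety = true_value >= safety_threshold
estimated_safety = estimated_value >= safety_threshold
# type I: safe policy judged unsafe; type II: unsafe policy judged safe
type_i_error_rate = (true_safety & ~estimated_safety).sum() / max(true_safety.sum(), 1)
type_ii_error_rate = (~true_safety & estimated_safety).sum() / max((~true_safety).sum(), 1)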
"""
behavior_policy_name = list(input_dict.values())[0]["behavior_policy"]
dataset_id = list(input_dict.values())[0]["dataset_id"]
gamma = list(input_dict.values())[0]["gamma"]
discount = np.full(self.step_per_trajectory, gamma).cumprod() / gamma
behavior_policy_reward = self.behavior_policy_reward[behavior_policy_name]
behavior_policy_value = (discount[np.newaxis, :] * behavior_policy_reward).sum(
axis=1
).mean() + 1e-10 # to avoid zero division
if safety_threshold is None:
if relative_safety_criteria is None:
safety_threshold = 0.0
else:
safety_threshold = relative_safety_criteria * behavior_policy_value
estimated_policy_value_dict = self.ope.estimate_policy_value(
input_dict,
compared_estimators=compared_estimators,
behavior_policy_name=behavior_policy_name,
dataset_id=dataset_id,
)
ground_truth_dict = self.obtain_true_selection_result(input_dict)
true_ranking = ground_truth_dict["ranking"]
true_policy_value = ground_truth_dict["policy_value"]
candidate_policy_names = (
true_ranking if return_metrics else list(input_dict.keys())
)
n_policies = len(candidate_policy_names)
ops_dict = {}
for i, estimator in enumerate(compared_estimators):
estimated_policy_value_ = np.zeros(n_policies)
true_policy_value_ = np.zeros(n_policies)
for j, eval_policy in enumerate(candidate_policy_names):
estimated_policy_value_[j] = estimated_policy_value_dict[eval_policy][
estimator
]
true_policy_value_[j] = true_policy_value[j]
estimated_ranking_index_ = np.argsort(estimated_policy_value_)[::-1]
true_ranking_index_ = np.argsort(true_policy_value_)[::-1]
estimated_ranking = [
candidate_policy_names[estimated_ranking_index_[i]]
for i in range(n_policies)
]
estimated_policy_value = np.sort(estimated_policy_value_)[::-1]
estimated_relative_policy_value = (
estimated_policy_value / behavior_policy_value
)
if return_metrics:
mse = mean_squared_error(true_policy_value, estimated_policy_value_)
rankcorr = spearmanr(np.arange(n_policies), estimated_ranking_index_)
regret = (
true_policy_value[:top_k_in_eval_metrics].sum()
- true_policy_value[estimated_ranking_index_][
:top_k_in_eval_metrics
].sum()
)
true_safety = true_policy_value >= safety_threshold
estimated_safety = estimated_policy_value_ >= safety_threshold
if true_safety.sum() > 0:
type_i_error_rate = (
true_safety > estimated_safety
).sum() / true_safety.sum()
else:
type_i_error_rate = 0.0
if (1 - true_safety).sum() > 0:
type_ii_error_rate = (true_safety < estimated_safety).sum() / (
1 - true_safety
).sum()
else:
type_ii_error_rate = 0.0
ops_dict[estimator] = {
"estimated_ranking": estimated_ranking,
"estimated_policy_value": estimated_policy_value,
"estimated_relative_policy_value": estimated_relative_policy_value,
}
if return_true_values:
ops_dict[estimator]["true_ranking"] = true_ranking_index_[
estimated_ranking_index_
]
ops_dict[estimator]["true_policy_value"] = true_policy_value_[
estimated_ranking_index_
]
ops_dict[estimator]["true_relative_policy_value"] = (
true_policy_value_[estimated_ranking_index_] / behavior_policy_value
)
if return_metrics:
ops_dict[estimator]["mean_squared_error"] = mse
ops_dict[estimator]["rank_correlation"] = rankcorr
ops_dict[estimator]["regret"] = (regret, top_k_in_eval_metrics)
ops_dict[estimator]["type_i_error_rate"] = type_i_error_rate
ops_dict[estimator]["type_ii_error_rate"] = type_ii_error_rate
ops_dict[estimator]["safety_threshold"] = safety_threshold
if return_by_dataframe:
ranking_df_dict = defaultdict(pd.DataFrame)
for i, estimator in enumerate(compared_estimators):
ranking_df_ = pd.DataFrame()
ranking_df_["estimated_ranking"] = ops_dict[estimator][
"estimated_ranking"
]
ranking_df_["estimated_policy_value"] = ops_dict[estimator][
"estimated_policy_value"
]
ranking_df_["estimated_relative_policy_value"] = ops_dict[estimator][
"estimated_relative_policy_value"
]
if return_true_values:
ranking_df_["true_ranking"] = ops_dict[estimator]["true_ranking"]
ranking_df_["true_policy_value"] = ops_dict[estimator][
"true_policy_value"
]
ranking_df_["true_relative_policy_value"] = ops_dict[estimator][
"true_relative_policy_value"
]
ranking_df_dict[estimator] = ranking_df_
ranking_df_dict = defaultdict_to_dict(ranking_df_dict)
if return_metrics:
(
mse,
rankcorr,
pvalue,
regret,
type_i,
type_ii,
) = (
[],
[],
[],
[],
[],
[],
)
for i, estimator in enumerate(compared_estimators):
mse.append(ops_dict[estimator]["mean_squared_error"])
rankcorr.append(ops_dict[estimator]["rank_correlation"][0])
pvalue.append(ops_dict[estimator]["rank_correlation"][1])
regret.append(ops_dict[estimator]["regret"][0])
type_i.append(ops_dict[estimator]["type_i_error_rate"])
type_ii.append(ops_dict[estimator]["type_ii_error_rate"])
metric_df = pd.DataFrame()
metric_df["estimator"] = compared_estimators
metric_df["mean_squared_error"] = mse
metric_df["rank_correlation"] = rankcorr
metric_df["pvalue"] = pvalue
metric_df[f"regret@{top_k_in_eval_metrics}"] = regret
metric_df["type_i_error_rate"] = type_i
metric_df["type_ii_error_rate"] = type_ii
dfs = (ranking_df_dict, metric_df) if return_metrics else ranking_df_dict
return dfs if return_by_dataframe else ops_dict
def _select_by_policy_value_via_cumulative_distribution_ope(
self,
input_dict: OPEInputDict,
compared_estimators: Optional[List[str]] = None,
return_true_values: bool = False,
return_metrics: bool = False,
return_by_dataframe: bool = False,
top_k_in_eval_metrics: int = 1,