Skip to content

Commit

Permalink
[FIX] fix seed in comparison of agents (#394)
Browse files Browse the repository at this point in the history
* fix seed in comparison of agents

* precommit
  • Loading branch information
TimotheeMathieu authored Nov 16, 2023
1 parent e951ea2 commit d92845b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/doc_stable.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ name: documentation_stable
on:
push:
# Pattern matched against refs/tags
tags:
tags:
- '*' # Push events to every tag not containing /


permissions:
contents: write

Expand Down
12 changes: 9 additions & 3 deletions rlberry/manager/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import rlberry
from rlberry.manager import ExperimentManager
from rlberry.seeding import Seeder
import pathlib

logger = rlberry.logger
Expand All @@ -18,6 +19,7 @@ def compare_agents(
n_simulations=50,
alpha=0.05,
B=10_000,
seed=None,
):
"""
Compare several trained agents using the mean over n_simulations evaluations for each agent.
Expand All @@ -41,6 +43,8 @@ def compare_agents(
Level of the test; controls the family-wise error rate.
B: int, default = 10_000
Number of random permutations used to approximate the permutation test if method = "permutation"
seed: int or None, default = None
Seed of the random number generator from which the permutations are sampled. If None, a new generator is created.
Returns
-------
Expand All @@ -53,6 +57,7 @@ def compare_agents(
[2]: Testing Statistical Hypotheses by E. L. Lehmann, Joseph P. Romano (Section 15.4.4), https://doi.org/10.1007/0-387-27605-X, Springer
"""

# Construction of the array of evaluations
df = pd.DataFrame()
assert isinstance(agent_source, list)
Expand Down Expand Up @@ -156,7 +161,7 @@ def compare_agents(
}
)
elif method == "permutation":
results_perm = _permutation_test(data, B, alpha) == 1
results_perm = _permutation_test(data, B, alpha, seed) == 1
decisions = [
"accept" if results_perm[i][j] else "reject"
for i in range(n_agents)
Expand All @@ -181,7 +186,7 @@ def compare_agents(
return results


def _permutation_test(data, B, alpha):
def _permutation_test(data, B, alpha, seed):
"""
Permutation test with Step-Down method
"""
Expand All @@ -195,6 +200,7 @@ def _permutation_test(data, B, alpha):

decisions = np.array(["accept" for i in range(len(comparisons))])
comparisons_alive = np.arange(len(comparisons))
seeder = Seeder(seed)

logger.info("Beginning permutationt test")
while True:
Expand All @@ -205,7 +211,7 @@ def _permutation_test(data, B, alpha):
if B is None:
permutations = combinations(2 * n_fit, n_fit)
else:
permutations = (np.random.permutation(2 * n_fit) for _ in range(B))
permutations = (seeder.rng.permutation(2 * n_fit) for _ in range(B))

# Test statistics
T0_max = 0
Expand Down
5 changes: 1 addition & 4 deletions rlberry/manager/tests/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_compare(method):
agent1.fit()
agent2.fit()

df = compare_agents([agent1, agent2], method=method, B=10, n_simulations=5)
df = compare_agents([agent1, agent2], method=method, B=20, n_simulations=5, seed=42)
assert len(df) > 0
if method == "tukey_hsd":
assert df["p-val"].item() < 0.05
Expand All @@ -74,6 +74,3 @@ def test_compare(method):
[agent1_pickle, agent2_pickle], method=method, B=10, n_simulations=5
)
assert len(df) > 0


test_compare("tukey_hsd")

0 comments on commit d92845b

Please sign in to comment.