Skip to content

Commit

Permalink
escape errors in swarmplot
Browse files Browse the repository at this point in the history
  • Loading branch information
aiueola committed Jul 23, 2023
1 parent e823f6f commit ea029c2
Showing 1 changed file with 73 additions and 52 deletions.
125 changes: 73 additions & 52 deletions scope_rl/ope/ope.py
Original file line number Diff line number Diff line change
Expand Up @@ -2059,14 +2059,17 @@ def visualize_policy_value_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = policy_value_dict[behavior_policy][eval_policy][
"on_policy"
Expand Down Expand Up @@ -2272,13 +2275,16 @@ def visualize_policy_value_with_multiple_estimates(
ax=ax,
)
else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = policy_value_dict[eval_policy]["on_policy"]
if on_policy is not None:
Expand Down Expand Up @@ -2357,13 +2363,16 @@ def visualize_policy_value_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

if visualize_on_policy:
ax.scatter(
Expand Down Expand Up @@ -5530,14 +5539,17 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = estimation_dict[behavior_policy][eval_policy][
"on_policy"
Expand Down Expand Up @@ -5645,14 +5657,17 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
hue="behavior_policy",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

if visualize_on_policy:
ax.scatter(
Expand Down Expand Up @@ -5740,13 +5755,16 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
ax=ax,
)
else:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="estimator",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

on_policy = estimation_dict[eval_policy]["on_policy"]
if on_policy is not None:
Expand Down Expand Up @@ -5838,13 +5856,16 @@ def _visualize_off_policy_estimates_with_multiple_estimates(
)

else:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
try:
sns.swarmplot(
data=df,
x="eval_policy",
y="policy_value",
palette=palette,
ax=ax,
)
except:
warn("Encountered NaN values during plot.")

if visualize_on_policy:
ax.scatter(
Expand Down

0 comments on commit ea029c2

Please sign in to comment.