Skip to content
This repository was archived by the owner on Mar 12, 2024. It is now read-only.

Commit 4bd678e

Browse files
authored
Support multi-objective study for from_optuna (#216)
* clarify the supported study type * revert docs change and support multi-objective * add test for from_optuna with multi-objective value * fix returned type
1 parent 79b3d52 commit 4bd678e

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

hiplot/experiment.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# LICENSE file in the root directory of this source tree.
44

55
import csv
6+
import enum
67
import uuid
78
import json
89
import warnings
@@ -514,11 +515,18 @@ def from_optuna(study: "optuna.study.Study") -> "Experiment": # No type hint to
514515

515516
# Create a list of dictionary objects using study trials
516517
# All parameters are taken using params.copy()
517-
518+
518519
hyper_opt_data = []
519520
for each_trial in study.trials:
520521
trial_params = {}
521-
trial_params["value"] = each_trial.value # name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning
522+
num_objectives = len(each_trial.values)
523+
524+
if num_objectives == 1:
525+
trial_params["value"] = each_trial.value # name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning
526+
else:
527+
for objective_id, value in enumerate(each_trial.values):
528+
trial_params[f"value_{objective_id}"] = value
529+
522530
trial_params["uid"] = each_trial.number
523531
trial_params.update(each_trial.params.copy())
524532
hyper_opt_data.append(trial_params)

hiplot/test_experiment.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,26 @@ def objective(trial: "optuna.trial.Trial") -> float:
5959
xp._asdict()
6060

6161

62+
def test_from_optuna_multi_objective() -> None:
63+
64+
def objective(trial: "optuna.trial.Trial") -> tp.Tuple[float, float]:
65+
x = trial.suggest_float("x", -1, 1)
66+
y = trial.suggest_float("y", -1, 1)
67+
return x ** 2, y
68+
69+
study = optuna.create_study(directions=["minimize", "minimize"])
70+
study.optimize(objective, n_trials=3)
71+
72+
# Create a dataframe from the study.
73+
df = study.trials_dataframe()
74+
assert isinstance(df, pd.DataFrame)
75+
assert df.shape[0] == 3 # n_trials.
76+
xp = hip.Experiment.from_optuna(study)
77+
assert len(xp.datapoints) == 3
78+
xp.validate()
79+
xp._asdict()
80+
81+
6282
def test_from_dataframe_nan_values() -> None:
6383
# Pandas automatically convert numeric-based columns None to NaN in dataframes
6484
# Pandas will also automatically convert columns with NaN from integer to floats, since NaN is considered a float

0 commit comments

Comments
 (0)