Skip to content
This repository was archived by the owner on Mar 12, 2024. It is now read-only.

Commit 79b3d52

Browse files
authored
Optuna Integration (#215)
* optuna integration and tests * optuna integration * optuna integration * added typing for tests
1 parent 6e02427 commit 79b3d52

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

hiplot/experiment.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
if tp.TYPE_CHECKING:
1616
import pandas as pd
1717
from .streamlit_helpers import ExperimentStreamlitComponent
18+
import optuna
1819

1920
DisplayableType = tp.Union[bool, int, float, str]
2021

@@ -502,6 +503,31 @@ def from_dataframe(dataframe: "pd.DataFrame") -> "Experiment": # No type hint t
502503

503504
return experiment
504505

506+
@staticmethod
507+
def from_optuna(study: "optuna.study.Study") -> "Experiment": # No type hint to avoid having optuna as an additional dependency
508+
"""
509+
Creates a HiPlot experiment from a Optuna Study.
510+
511+
:param study: Optuna Study
512+
"""
513+
514+
515+
# Create a list of dictionary objects using study trials
516+
# All parameters are taken using params.copy()
517+
518+
hyper_opt_data = []
519+
for each_trial in study.trials:
520+
trial_params = {}
521+
trial_params["value"] = each_trial.value # name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning
522+
trial_params["uid"] = each_trial.number
523+
trial_params.update(each_trial.params.copy())
524+
hyper_opt_data.append(trial_params)
525+
experiment = Experiment.from_iterable(hyper_opt_data)
526+
527+
return experiment
528+
529+
530+
505531
@staticmethod
506532
def merge(xp_dict: tp.Dict[str, "Experiment"]) -> "Experiment":
507533
"""

hiplot/test_experiment.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pytest
1212
import pandas as pd
13+
import optuna
1314

1415
import hiplot as hip
1516

@@ -39,6 +40,24 @@ def test_from_dataframe() -> None:
3940
xp.validate()
4041
xp._asdict()
4142

43+
def test_from_optuna() -> None:
44+
45+
def objective(trial: "optuna.trial.Trial") -> float:
46+
x = trial.suggest_float("x", -1, 1)
47+
return x ** 2
48+
49+
study = optuna.create_study()
50+
study.optimize(objective, n_trials=3)
51+
52+
# Create a dataframe from the study.
53+
df = study.trials_dataframe()
54+
assert isinstance(df, pd.DataFrame)
55+
assert df.shape[0] == 3 # n_trials.
56+
xp = hip.Experiment.from_optuna(study)
57+
assert len(xp.datapoints) == 3
58+
xp.validate()
59+
xp._asdict()
60+
4261

4362
def test_from_dataframe_nan_values() -> None:
4463
# Pandas automatically convert numeric-based columns None to NaN in dataframes

requirements/dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ pre-commit
1111
pandas
1212
streamlit>=0.63
1313
beautifulsoup4
14+
optuna

0 commit comments

Comments
 (0)