Skip to content
This repository was archived by the owner on Mar 12, 2024. It is now read-only.

Commit c7f0cd4

Browse files
authored
Fix encoding in hiplot-render for windows users (#220)
Co-authored-by: danthe3rd <[email protected]>
1 parent eaa37fc commit c7f0cd4

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

hiplot/experiment.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# LICENSE file in the root directory of this source tree.
44

55
import csv
6-
import enum
76
import uuid
87
import json
8+
import codecs
99
import warnings
1010
from abc import ABCMeta, abstractmethod
1111
from enum import Enum
@@ -18,6 +18,8 @@
1818
from .streamlit_helpers import ExperimentStreamlitComponent
1919
import optuna
2020

21+
TextWriterIO = tp.Union[tp.IO[str], codecs.StreamWriter]
22+
2123
DisplayableType = tp.Union[bool, int, float, str]
2224

2325

@@ -345,7 +347,7 @@ def get_experiment():
345347
To render an experiment to HTML, use `experiment.to_html(file_name)` or `html_page = experiment.to_html()`""")
346348
return streamlit_helpers.ExperimentStreamlitComponent(self, key=key, ret=ret)
347349

348-
def to_html(self, file: tp.Optional[tp.Union[Path, str, tp.IO[str]]] = None, **kwargs: tp.Any) -> str:
350+
def to_html(self, file: tp.Optional[tp.Union[Path, str, TextWriterIO]] = None, **kwargs: tp.Any) -> str:
349351
"""
350352
Returns the content of a standalone .html file that displays this experiment
351353
without any dependency to HiPlot server or static files.
@@ -368,7 +370,7 @@ def to_html(self, file: tp.Optional[tp.Union[Path, str, tp.IO[str]]] = None, **k
368370
file.write(html)
369371
return html
370372

371-
def to_csv(self, file: tp.Union[Path, str, tp.IO[str]]) -> None:
373+
def to_csv(self, file: tp.Union[Path, str, TextWriterIO]) -> None:
372374
"""
373375
Dumps this Experiment as a .csv file.
374376
Information about display_data, parameters definition will be lost.
@@ -381,7 +383,7 @@ def to_csv(self, file: tp.Union[Path, str, tp.IO[str]]) -> None:
381383
else:
382384
return self._to_csv(file)
383385

384-
def _to_csv(self, fh: tp.IO[str]) -> None:
386+
def _to_csv(self, fh: TextWriterIO) -> None:
385387
fieldnames: tp.Set[str] = set()
386388
for dp in self.datapoints:
387389
for f in dp.values.keys():
@@ -512,20 +514,23 @@ def from_optuna(study: "optuna.study.Study") -> "Experiment": # No type hint to
512514
:param study: Optuna Study
513515
"""
514516

515-
516517
# Create a list of dictionary objects using study trials
517518
# All parameters are taken using params.copy()
519+
# pylint: disable=redefined-outer-name
518520
import optuna
519-
521+
520522
hyper_opt_data = []
521523
for each_trial in study.get_trials(states=(optuna.trial.TrialState.COMPLETE, )):
522524
trial_params = {}
523-
if not each_trial.values: # This checks if the trial was fully completed - the value will be None if the trial was interrupted halfway (e.g. via KeyboardInterrupt)
525+
# This checks if the trial was fully completed
526+
# the value will be None if the trial was interrupted halfway (e.g. via KeyboardInterrupt)
527+
if not each_trial.values:
524528
continue
525529
num_objectives = len(each_trial.values)
526530

527531
if num_objectives == 1:
528-
trial_params["value"] = each_trial.value # name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning
532+
# name = value, as it could be RMSE / accuracy, or any value that the user selects for tuning
533+
trial_params["value"] = each_trial.value
529534
else:
530535
for objective_id, value in enumerate(each_trial.values):
531536
trial_params[f"value_{objective_id}"] = value
@@ -537,8 +542,6 @@ def from_optuna(study: "optuna.study.Study") -> "Experiment": # No type hint to
537542

538543
return experiment
539544

540-
541-
542545
@staticmethod
543546
def merge(xp_dict: tp.Dict[str, "Experiment"]) -> "Experiment":
544547
"""

hiplot/render.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import json
1010
from typing import Any, Dict
1111
from pathlib import Path
12+
import codecs
1213

1314
from . import fetchers
1415

@@ -100,10 +101,11 @@ def hiplot_render_main() -> int:
100101

101102
exp = fetchers.load_xp_with_fetchers(fetchers.get_fetchers(args.fetchers), args.experiment_uri)
102103
exp.validate()
104+
stdout_writer = codecs.getwriter("utf-8")(sys.stdout.buffer)
103105
if args.format == 'csv':
104-
exp.to_csv(sys.stdout)
106+
exp.to_csv(stdout_writer)
105107
elif args.format == 'html':
106-
exp.to_html(sys.stdout)
108+
exp.to_html(stdout_writer)
107109
else:
108110
assert False, args.format
109111
return 0

0 commit comments

Comments
 (0)