Skip to content

Commit

Permalink
Support json dumping for more collections (#398)
Browse files Browse the repository at this point in the history
* Support json dumping for more collections

e.g. pandas or other numpy wrappers

* Formatting
  • Loading branch information
WardBrian authored Jul 2, 2021
1 parent 289c35f commit 263ebc8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
8 changes: 5 additions & 3 deletions cmdstanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import sys
import tempfile
from collections import OrderedDict
from collections.abc import Sequence
from collections.abc import Sequence, Collection
from numbers import Integral, Real
from typing import Dict, List, TextIO, Tuple, Union

Expand Down Expand Up @@ -371,8 +371,10 @@ def jsondump(path: str, data: Dict) -> None:
data = data.copy()
for key, val in data.items():
if type(val).__module__ == 'numpy':
val = val.tolist()
data[key] = val
data[key] = val.tolist()
elif isinstance(val, Collection):
data[key] = np.asarray(val).tolist()

with open(path, 'w') as fd:
json.dump(data, fd)

Expand Down
17 changes: 14 additions & 3 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import string
import tempfile
import unittest
import collections.abc

import numpy as np
import pandas as pd

from cmdstanpy import _DOT_CMDSTAN, _DOT_CMDSTANPY, _TMPDIR
from cmdstanpy.model import CmdStanModel
Expand Down Expand Up @@ -212,8 +214,9 @@ def cmp(d1, d2):
for k in d1:
data_1 = d1[k]
data_2 = d2[k]
if isinstance(data_2, np.ndarray):
data_2 = data_2.tolist()
if isinstance(data_2, collections.abc.Collection):
data_2 = np.asarray(data_2).tolist()

self.assertEqual(data_1, data_2)

dict_list = {'a': [1.0, 2.0, 3.0]}
Expand All @@ -222,12 +225,20 @@ def cmp(d1, d2):
with open(file_list) as fd:
cmp(json.load(fd), dict_list)

dict_vec = {'a': np.repeat(3, 4)}
arr = np.repeat(3, 4)
dict_vec = {'a': arr}
file_vec = os.path.join(_TMPDIR, 'vec.json')
jsondump(file_vec, dict_vec)
with open(file_vec) as fd:
cmp(json.load(fd), dict_vec)

series = pd.Series(arr)
dict_vec_pd = {'a': series}
file_vec_pd = os.path.join(_TMPDIR, 'vec_pd.json')
jsondump(file_vec_pd, dict_vec_pd)
with open(file_vec_pd) as fd:
cmp(json.load(fd), dict_vec_pd)

dict_zero_vec = {'a': []}
file_zero_vec = os.path.join(_TMPDIR, 'empty_vec.json')
jsondump(file_zero_vec, dict_zero_vec)
Expand Down

0 comments on commit 263ebc8

Please sign in to comment.