
Commit 075bccd

Nic-Ma authored, with co-authors monai-bot, ericspod and wyli
2310 Add load_csv_datalist utility API (#2349)
* [DLMED] add CSV datalist Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add group feature Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add unit test Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add more unit tests Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add optional install Signed-off-by: Nic Ma <[email protected]>
* [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]>
* [DLMED] fix flake8 issue Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add doc-strings Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix typo Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add CSVDataset for non-iterable data Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix min test Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add CSVIterableDataset base Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add CSVIterableDataset Signed-off-by: Nic Ma <[email protected]>
* [DLMED] support multiple processes Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix tests Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix flake8 Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix docs-build Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix min tests Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix CI tests Signed-off-by: Nic Ma <[email protected]>
* [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]>
* [DLMED] fix typo Signed-off-by: Nic Ma <[email protected]>
* [DLMED] change sys.platform Signed-off-by: Nic Ma <[email protected]>
* [DLMED] skip if windows Signed-off-by: Nic Ma <[email protected]>
* [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]>
* [DLMED] add col_types arg Signed-off-by: Nic Ma <[email protected]>

Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
1 parent 8cda6c1 commit 075bccd

File tree

13 files changed: +622 -7 lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -18,3 +18,4 @@ sphinxcontrib-jsmath
 sphinxcontrib-qthelp
 sphinxcontrib-serializinghtml
 sphinx-autodoc-typehints==1.11.1
+pandas

docs/source/data.rst

Lines changed: 12 additions & 0 deletions
@@ -21,6 +21,12 @@ Generic Interfaces
   :members:
   :special-members: __next__

+`CSVIterableDataset`
+~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: CSVIterableDataset
+  :members:
+  :special-members: __next__
+
 `PersistentDataset`
 ~~~~~~~~~~~~~~~~~~~
 .. autoclass:: PersistentDataset
@@ -75,6 +81,12 @@ Generic Interfaces
   :members:
   :special-members: __getitem__

+`CSVDataset`
+~~~~~~~~~~~~
+.. autoclass:: CSVDataset
+  :members:
+  :special-members: __getitem__
+
 Patch-based dataset
 -------------------

docs/source/installation.md

Lines changed: 2 additions & 2 deletions
@@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is

 - The options are
 ```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil]
+[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas]
 ```
 which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb` and `psutil`, respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python` and `pandas`, respectively.

 - `pip install 'monai[all]'` installs all the optional dependencies.
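
With `pandas` added to the extras list above, the optional dependency used by the new CSV datasets can also be pulled in on its own, for example:

```
pip install 'monai[pandas]'
```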

monai/config/deviceconfig.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ def get_optional_config_values():
     output["tqdm"] = get_package_version("tqdm")
     output["lmdb"] = get_package_version("lmdb")
     output["psutil"] = psutil_version
+    output["pandas"] = get_package_version("pandas")

     return output
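
With this entry added, the installed pandas version is reported alongside the other optional dependencies. A minimal sketch of checking it (assuming the standard `monai.config.print_config()` entry point, which prints these values for the other entries):

```python
import monai

# prints system/version info plus the optional dependency versions,
# now including the "pandas" entry added above
monai.config.print_config()
```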

monai/data/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,7 @@
     ArrayDataset,
     CacheDataset,
     CacheNTransDataset,
+    CSVDataset,
     Dataset,
     LMDBDataset,
     NPZDictItemDataset,
@@ -26,7 +27,7 @@
 from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
 from .image_dataset import ImageDataset
 from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
-from .iterable_dataset import IterableDataset
+from .iterable_dataset import CSVIterableDataset, IterableDataset
 from .nifti_saver import NiftiSaver
 from .nifti_writer import write_nifti
 from .png_saver import PNGSaver
@@ -38,6 +39,7 @@
 from .utils import (
     compute_importance_map,
     compute_shape_offset,
+    convert_tables_to_dicts,
     correct_nifti_header_if_necessary,
     create_file_basename,
     decollate_batch,
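
With these exports in place, the new APIs are importable directly from the `monai.data` namespace:

```python
# the three names added by this commit, re-exported at the package level
from monai.data import CSVDataset, CSVIterableDataset, convert_tables_to_dicts
```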

monai/data/dataset.py

Lines changed: 72 additions & 2 deletions
@@ -29,9 +29,9 @@
 from torch.utils.data import Dataset as _TorchDataset
 from torch.utils.data import Subset

-from monai.data.utils import first, pickle_hashing
+from monai.data.utils import convert_tables_to_dicts, first, pickle_hashing
 from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform
-from monai.utils import MAX_SEED, get_seed, min_version, optional_import
+from monai.utils import MAX_SEED, ensure_tuple, get_seed, min_version, optional_import

 if TYPE_CHECKING:
     from tqdm import tqdm
@@ -41,6 +41,7 @@
 tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")

 lmdb, _ = optional_import("lmdb")
+pd, _ = optional_import("pandas")


 class Dataset(_TorchDataset):
@@ -1061,3 +1062,72 @@ def _transform(self, index: int):
         data = apply_transform(self.transform, data)

         return data
+
+
+class CSVDataset(Dataset):
+    """
+    Dataset to load data from CSV files and generate a list of dictionaries,
+    where every dictionary maps to a row of the CSV file and the keys of the
+    dictionary map to the column names of the CSV file.
+
+    It can load multiple CSV files and join the tables with the additional `kwargs` arg,
+    and supports loading only specific rows and columns.
+    It can also group several loaded columns to generate a new column, for example,
+    set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, then the output can be::
+
+        [
+            {"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
+            {"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
+        ]
+
+    Args:
+        filename: the filename of the expected CSV file to load. if providing a list
+            of filenames, it will load all the files and join the tables.
+        row_indices: indices of the expected rows to load. it should be a list,
+            where every item can be an int or a range `[start, end)` of the indices.
+            for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
+            load all the rows in the file.
+        col_names: names of the expected columns to load. if None, load all the columns.
+        col_types: `type` and `default value` to convert the loaded columns, if None, use the original data.
+            it should be a dictionary where every item maps to an expected column, the `key` is the column
+            name and the `value` is None or a dictionary to define the default value and data type.
+            the supported keys in the dictionary are: ["type", "default"]. for example::
+
+                col_types = {
+                    "subject_id": {"type": str},
+                    "label": {"type": int, "default": 0},
+                    "ehr_0": {"type": float, "default": 0.0},
+                    "ehr_1": {"type": float, "default": 0.0},
+                    "image": {"type": str, "default": None},
+                }
+
+        col_groups: args to group the loaded columns to generate a new column,
+            it should be a dictionary where every item maps to a group, the `key` will
+            be the new column name and the `value` is the names of the columns to combine. for example:
+            `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
+        transform: transform to apply on the loaded items of the dictionary data.
+        kwargs: additional arguments for the `pandas.merge()` API to join tables.
+
+    """
+
+    def __init__(
+        self,
+        filename: Union[str, Sequence[str]],
+        row_indices: Optional[Sequence[Union[int, str]]] = None,
+        col_names: Optional[Sequence[str]] = None,
+        col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
+        col_groups: Optional[Dict[str, Sequence[str]]] = None,
+        transform: Optional[Callable] = None,
+        **kwargs,
+    ):
+        files = ensure_tuple(filename)
+        # load every CSV file into a pandas DataFrame, then join and convert to dicts
+        dfs = [pd.read_csv(f) for f in files]
+        data = convert_tables_to_dicts(
+            dfs=dfs,
+            row_indices=row_indices,
+            col_names=col_names,
+            col_types=col_types,
+            col_groups=col_groups,
+            **kwargs,
+        )
+        super().__init__(data=data, transform=transform)
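
A minimal usage sketch of the class above (untested; the file `./subjects.csv` and its column names are hypothetical):

```python
from monai.data import CSVDataset

dataset = CSVDataset(
    filename="./subjects.csv",
    row_indices=[[0, 100], 200],  # rows 0-99 plus row 200
    col_types={"label": {"type": int, "default": 0}},  # fill NaN labels with 0, cast to int
    col_groups={"ehr": ["ehr_0", "ehr_1"]},  # combine two columns into a new "ehr" entry
)
print(len(dataset))
print(dataset[0])  # a dict keyed by column names, plus the grouped "ehr" entry
```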

monai/data/iterable_dataset.py

Lines changed: 98 additions & 1 deletion
@@ -9,11 +9,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable, Iterable, Optional
+import math
+from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union

 from torch.utils.data import IterableDataset as _TorchIterableDataset
+from torch.utils.data import get_worker_info

+from monai.data.utils import convert_tables_to_dicts
 from monai.transforms import apply_transform
+from monai.utils import ensure_tuple, optional_import
+
+pd, _ = optional_import("pandas")


 class IterableDataset(_TorchIterableDataset):
@@ -43,3 +49,94 @@ def __iter__(self):
         if self.transform is not None:
             data = apply_transform(self.transform, data)
         yield data
+
+
+class CSVIterableDataset(IterableDataset):
+    """
+    Iterable dataset to load CSV files and generate dictionary data.
+    It can be helpful when loading extremely big CSV files that can't be read into memory directly.
+    To accelerate the loading process, it supports multi-processing based on PyTorch DataLoader workers,
+    where every process executes transforms on part of every loaded chunk.
+    Note: the order of the output data may not match the data source in multi-processing mode.
+
+    It can load data from multiple CSV files and join the tables with the additional `kwargs` arg,
+    and supports loading only specific columns.
+    It can also group several loaded columns to generate a new column, for example,
+    set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, then the output can be::
+
+        [
+            {"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
+            {"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
+        ]
+
+    Args:
+        filename: the filename of the expected CSV file to load. if providing a list
+            of filenames, it will load all the files and join the tables.
+        chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details:
+            https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.
+        col_names: names of the expected columns to load. if None, load all the columns.
+        col_types: `type` and `default value` to convert the loaded columns, if None, use the original data.
+            it should be a dictionary where every item maps to an expected column, the `key` is the column
+            name and the `value` is None or a dictionary to define the default value and data type.
+            the supported keys in the dictionary are: ["type", "default"]. for example::
+
+                col_types = {
+                    "subject_id": {"type": str},
+                    "label": {"type": int, "default": 0},
+                    "ehr_0": {"type": float, "default": 0.0},
+                    "ehr_1": {"type": float, "default": 0.0},
+                    "image": {"type": str, "default": None},
+                }
+
+        col_groups: args to group the loaded columns to generate a new column,
+            it should be a dictionary where every item maps to a group, the `key` will
+            be the new column name and the `value` is the names of the columns to combine. for example:
+            `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
+        transform: transform to apply on the loaded items of the dictionary data.
+        kwargs: additional arguments for the `pandas.merge()` API to join tables.
+
+    """
+
+    def __init__(
+        self,
+        filename: Union[str, Sequence[str]],
+        chunksize: int = 1000,
+        col_names: Optional[Sequence[str]] = None,
+        col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
+        col_groups: Optional[Dict[str, Sequence[str]]] = None,
+        transform: Optional[Callable] = None,
+        **kwargs,
+    ):
+        self.files = ensure_tuple(filename)
+        self.chunksize = chunksize
+        self.iters = self.reset()
+        self.col_names = col_names
+        self.col_types = col_types
+        self.col_groups = col_groups
+        self.kwargs = kwargs
+        super().__init__(data=None, transform=transform)  # type: ignore
+
+    def reset(self, filename: Optional[Union[str, Sequence[str]]] = None):
+        if filename is not None:
+            # update files if necessary
+            self.files = ensure_tuple(filename)
+        # create a chunk iterator for every CSV file
+        self.iters = [pd.read_csv(f, chunksize=self.chunksize) for f in self.files]
+        return self.iters
+
+    def __iter__(self):
+        for chunks in zip(*self.iters):
+            self.data = convert_tables_to_dicts(
+                dfs=chunks,
+                col_names=self.col_names,
+                col_types=self.col_types,
+                col_groups=self.col_groups,
+                **self.kwargs,
+            )
+            info = get_worker_info()
+            if info is not None:
+                # split every chunk evenly across the DataLoader workers
+                length = len(self.data)
+                per_worker = int(math.ceil(length / float(info.num_workers)))
+                start = info.id * per_worker
+                self.data = self.data[start : min(start + per_worker, length)]
+            # yield items chunk by chunk instead of returning after the first chunk,
+            # so that every chunk of every file is consumed
+            yield from super().__iter__()
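
A minimal usage sketch of the iterable variant (untested; `./big_table.csv` is hypothetical). Because the pandas chunk iterators are consumed during iteration, `reset()` has to be called before a second pass:

```python
from torch.utils.data import DataLoader

from monai.data import CSVIterableDataset

dataset = CSVIterableDataset(filename="./big_table.csv", chunksize=100)
# with num_workers > 0, every chunk is split across the workers,
# so the output order may not match the file order
for batch in DataLoader(dataset, batch_size=8, num_workers=2):
    print(batch)
dataset.reset()  # re-create the chunk iterators before iterating again
```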

monai/data/utils.py

Lines changed: 83 additions & 1 deletion
@@ -16,9 +16,10 @@
 import pickle
 import warnings
 from collections import defaultdict
+from functools import reduce
 from itertools import product, starmap
 from pathlib import PurePath
-from typing import Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -37,8 +38,11 @@
 )
 from monai.utils.enums import Method

+pd, _ = optional_import("pandas")
+DataFrame, _ = optional_import("pandas", name="DataFrame")
 nib, _ = optional_import("nibabel")

+
 __all__ = [
     "get_random_patch",
     "iter_patch_slices",
@@ -65,6 +69,7 @@
     "decollate_batch",
     "pad_list_data_collate",
     "no_collation",
+    "convert_tables_to_dicts",
 ]


@@ -983,3 +988,80 @@ def sorted_dict(item, key=None, reverse=False):
     if not isinstance(item, dict):
         return item
     return {k: sorted_dict(v) if isinstance(v, dict) else v for k, v in sorted(item.items(), key=key, reverse=reverse)}
+
+
+def convert_tables_to_dicts(
+    dfs,
+    row_indices: Optional[Sequence[Union[int, str]]] = None,
+    col_names: Optional[Sequence[str]] = None,
+    col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
+    col_groups: Optional[Dict[str, Sequence[str]]] = None,
+    **kwargs,
+) -> List[Dict[str, Any]]:
+    """
+    Utility to join pandas tables, select rows and columns, and generate groups.
+    Will return a list of dictionaries, where every dictionary maps to a row of data in the tables.
+
+    Args:
+        dfs: data table in pandas DataFrame format. if providing a list of tables, will join them.
+        row_indices: indices of the expected rows to load. it should be a list,
+            where every item can be an int or a range `[start, end)` of the indices.
+            for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
+            load all the rows in the file.
+        col_names: names of the expected columns to load. if None, load all the columns.
+        col_types: `type` and `default value` to convert the loaded columns, if None, use the original data.
+            it should be a dictionary where every item maps to an expected column, the `key` is the column
+            name and the `value` is None or a dictionary to define the default value and data type.
+            the supported keys in the dictionary are: ["type", "default"], and note that the value of `default`
+            should not be `None`. for example::
+
+                col_types = {
+                    "subject_id": {"type": str},
+                    "label": {"type": int, "default": 0},
+                    "ehr_0": {"type": float, "default": 0.0},
+                    "ehr_1": {"type": float, "default": 0.0},
+                }
+
+        col_groups: args to group the loaded columns to generate a new column,
+            it should be a dictionary where every item maps to a group, the `key` will
+            be the new column name and the `value` is the names of the columns to combine. for example:
+            `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
+        kwargs: additional arguments for the `pandas.merge()` API to join tables.
+
+    """
+    df = reduce(lambda l, r: pd.merge(l, r, **kwargs), ensure_tuple(dfs))
+    # parse row indices
+    rows: List[Union[int, str]] = []
+    if row_indices is None:
+        rows = slice(df.shape[0])  # type: ignore
+    else:
+        for i in row_indices:
+            if isinstance(i, (tuple, list)):
+                if len(i) != 2:
+                    raise ValueError("range of row indices must contain 2 values: start and end.")
+                rows.extend(list(range(i[0], i[1])))
+            else:
+                rows.append(i)
+
+    # convert to a list of dictionaries corresponding to every row
+    data_ = df.loc[rows] if col_names is None else df.loc[rows, col_names]
+    if isinstance(col_types, dict):
+        # fill default values for NaN
+        defaults = {k: v["default"] for k, v in col_types.items() if v is not None and v.get("default") is not None}
+        if len(defaults) > 0:
+            data_ = data_.fillna(value=defaults)
+        # convert data types
+        types = {k: v["type"] for k, v in col_types.items() if v is not None and "type" in v}
+        if len(types) > 0:
+            data_ = data_.astype(dtype=types)
+    data: List[Dict] = data_.to_dict(orient="records")
+
+    # group columns to generate a new column
+    if col_groups is not None:
+        groups: Dict[str, List] = {}
+        for name, cols in col_groups.items():
+            groups[name] = df.loc[rows, cols].values
+        # invert the items of groups into every row of data
+        data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)]
+
+    return data
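
A small, self-contained sketch of the utility (untested), joining two in-memory tables on a shared column; the `on` keyword is forwarded to `pandas.merge()` through `**kwargs`:

```python
import pandas as pd

from monai.data import convert_tables_to_dicts

df_demo = pd.DataFrame({"subject_id": ["s1", "s2"], "label": [0, 1]})
df_ehr = pd.DataFrame({"subject_id": ["s1", "s2"], "ehr_0": [1.0, 2.0], "ehr_1": [3.0, 4.0]})

items = convert_tables_to_dicts(
    dfs=[df_demo, df_ehr],
    col_groups={"ehr": ["ehr_0", "ehr_1"]},  # grouped values come back as numpy arrays
    on="subject_id",  # passed through to pandas.merge()
)
# e.g. [{'subject_id': 's1', 'label': 0, 'ehr_0': 1.0, 'ehr_1': 3.0, 'ehr': array([1., 3.])}, ...]
print(items[0])
```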
