Skip to content

Commit 1bd5aa9

Browse files
BryanCutler authored and terrytangyuan committed
Arrow batch_mode fail fast for unsupported options (#255)
* Fast fail for unsupported batch_mode * expanded ArrowDataset batching documentation and added to arrow/README.md
1 parent 4f5b935 commit 1bd5aa9

File tree

3 files changed

+70
-7
lines changed

3 files changed

+70
-7
lines changed

tensorflow_io/arrow/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,23 @@ with tf.Session() as sess:
109109

110110
An alternate constructor can also be used to infer output types and shapes from
111111
a given `pyarrow.Schema`, e.g. `dataset = ArrowStreamDataset.from_schema(host, schema)`
112+
113+
## Creating Batches with Arrow Datasets
114+
115+
Arrow Datasets have optional parameters to specify a `batch_size` and
116+
`batch_mode`. Supported `batch_modes` are: 'keep_remainder', 'drop_remainder'
117+
and 'auto'. If the last elements of the Dataset do not combine to the set
118+
`batch_size`, then 'keep_remainder' will return a partial batch, while
119+
'drop_remainder' will discard the partial batch. Setting `batch_mode` to 'auto'
120+
will automatically set a batch size to the number of records in the incoming
121+
Arrow record batches. This is a good option to use if the incoming Arrow record
122+
batch size can be controlled to ensure the output batch size is not too large
123+
and sequential Arrow record batches are sized equally.
124+
125+
Setting the `batch_size` or using `batch_mode` of 'auto' can be more efficient
126+
than using `tf.data.Dataset.batch()` on an Arrow Dataset. This is because the
127+
output tensor can be sized to the desired batch size on creation, and then data
128+
is transferred directly from Arrow memory. Otherwise, if batching elements with
129+
the output of an Arrow Dataset, e.g. `ArrowDataset(...).batch(batch_size=4)`,
130+
then the tensor data will need to be aggregated and copied to get the final
131+
batched outputs.

tensorflow_io/arrow/python/ops/arrow_dataset_ops.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ class ArrowBaseDataset(data.Dataset):
8888
and corresponding output tensor types, shapes and classes.
8989
"""
9090

91+
batch_modes_supported = ('keep_remainder', 'drop_remainder', 'auto')
92+
9193
def __init__(self,
9294
columns,
9395
output_types,
@@ -104,6 +106,10 @@ def __init__(self,
104106
batch_size or 0,
105107
dtype=dtypes.int64,
106108
name="batch_size")
109+
if batch_mode not in self.batch_modes_supported:
110+
raise ValueError(
111+
"Unsupported batch_mode: '{}', must be one of {}"
112+
.format(batch_mode, self.batch_modes_supported))
107113
self._batch_mode = tensorflow.convert_to_tensor(
108114
batch_mode,
109115
dtypes.string,
@@ -147,7 +153,10 @@ def __init__(self,
147153
output_types: Tensor dtypes of the output tensors
148154
output_shapes: TensorShapes of the output tensors or None to
149155
infer partial
150-
batch_size: Batch size of output tensors
156+
batch_size: Batch size of output tensors, setting a batch size here
157+
will create batched Tensors from Arrow memory and can be more
158+
efficient than using tf.data.Dataset.batch().
159+
NOTE: batch_size does not need to be set if batch_mode='auto'
151160
batch_mode: Mode of batching, supported strings:
152161
"keep_remainder" (default, keeps partial batch data),
153162
"drop_remainder" (discard partial batch data),
@@ -187,7 +196,10 @@ def from_record_batches(cls,
187196
output_types: Tensor dtypes of the output tensors
188197
output_shapes: TensorShapes of the output tensors or None to
189198
infer partial
190-
batch_size: Batch size of output tensors
199+
batch_size: Batch size of output tensors, setting a batch size here
200+
will create batched tensors from Arrow memory and can be more
201+
efficient than using tf.data.Dataset.batch().
202+
NOTE: batch_size does not need to be set if batch_mode='auto'
191203
batch_mode: Mode of batching, supported strings:
192204
"keep_remainder" (default, keeps partial batch data),
193205
"drop_remainder" (discard partial batch data),
@@ -230,7 +242,10 @@ def from_pandas(cls,
230242
df: a Pandas DataFrame
231243
columns: Optional column indices to use, if None all are used
232244
preserve_index: Flag to include the DataFrame index as the last column
233-
batch_size: Batch size of output tensors
245+
batch_size: Batch size of output tensors, setting a batch size here
246+
will create batched tensors from Arrow memory and can be more
247+
efficient than using tf.data.Dataset.batch().
248+
NOTE: batch_size does not need to be set if batch_mode='auto'
234249
batch_mode: Mode of batching, supported strings:
235250
"keep_remainder" (default, keeps partial batch data),
236251
"drop_remainder" (discard partial batch data),
@@ -274,7 +289,10 @@ def __init__(self,
274289
output_types: Tensor dtypes of the output tensors
275290
output_shapes: TensorShapes of the output tensors or None to
276291
infer partial
277-
batch_size: Batch size of output tensors
292+
batch_size: Batch size of output tensors, setting a batch size here
293+
will create batched tensors from Arrow memory and can be more
294+
efficient than using tf.data.Dataset.batch().
295+
NOTE: batch_size does not need to be set if batch_mode='auto'
278296
batch_mode: Mode of batching, supported strings:
279297
"keep_remainder" (default, keeps partial batch data),
280298
"drop_remainder" (discard partial batch data),
@@ -316,7 +334,10 @@ def from_schema(cls,
316334
in Arrow Feather format
317335
schema: Arrow schema defining the record batch data in the stream
318336
columns: A list of column indices to use from the schema, None for all
319-
batch_size: Batch size of output tensors
337+
batch_size: Batch size of output tensors, setting a batch size here
338+
will create batched tensors from Arrow memory and can be more
339+
efficient than using tf.data.Dataset.batch().
340+
NOTE: batch_size does not need to be set if batch_mode='auto'
320341
batch_mode: Mode of batching, supported strings:
321342
"keep_remainder" (default, keeps partial batch data),
322343
"drop_remainder" (discard partial batch data),
@@ -355,7 +376,10 @@ def __init__(self,
355376
output_types: Tensor dtypes of the output tensors
356377
output_shapes: TensorShapes of the output tensors or None to
357378
infer partial
358-
batch_size: Batch size of output tensors
379+
batch_size: Batch size of output tensors, setting a batch size here
380+
will create batched tensors from Arrow memory and can be more
381+
efficient than using tf.data.Dataset.batch().
382+
NOTE: batch_size does not need to be set if batch_mode='auto'
359383
batch_mode: Mode of batching, supported strings:
360384
"keep_remainder" (default, keeps partial batch data),
361385
"drop_remainder" (discard partial batch data),
@@ -397,7 +421,10 @@ def from_schema(cls,
397421
For a socket client, use "<HOST_IP>:<PORT>", for stdin use "STDIN".
398422
schema: Arrow schema defining the record batch data in the stream
399423
columns: A list of column indices to use from the schema, None for all
400-
batch_size: Batch size of output tensors
424+
batch_size: Batch size of output tensors, setting a batch size here
425+
will create batched tensors from Arrow memory and can be more
426+
efficient than using tf.data.Dataset.batch().
427+
NOTE: batch_size does not need to be set if batch_mode='auto'
401428
batch_mode: Mode of batching, supported strings:
402429
"keep_remainder" (default, keeps partial batch data),
403430
"drop_remainder" (discard partial batch data),

tests/test_arrow.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,22 @@ def test_batch_variable_length_list(self):
639639
with self.assertRaisesRegexp(errors.OpError, 'variable.*unsupported'):
640640
self.run_test_case(dataset, truth_data, batch_size=batch_size)
641641

642+
def test_unsupported_batch_mode(self):
643+
"""Test using an unsupported batch mode
644+
"""
645+
truth_data = TruthData(
646+
self.scalar_data,
647+
self.scalar_dtypes,
648+
self.scalar_shapes)
649+
650+
with self.assertRaisesRegexp(ValueError, 'Unsupported batch_mode.*doh'):
651+
arrow_io.ArrowDataset.from_record_batches(
652+
[self.make_record_batch(truth_data)],
653+
list(range(len(truth_data.output_types))),
654+
truth_data.output_types,
655+
truth_data.output_shapes,
656+
batch_mode='doh')
657+
642658

643659
if __name__ == "__main__":
644660
test.main()

0 commit comments

Comments
 (0)