@@ -88,6 +88,8 @@ class ArrowBaseDataset(data.Dataset):
8888 and corresponding output tensor types, shapes and classes.
8989 """
9090
91+ batch_modes_supported = ('keep_remainder' , 'drop_remainder' , 'auto' )
92+
9193 def __init__ (self ,
9294 columns ,
9395 output_types ,
@@ -104,6 +106,10 @@ def __init__(self,
104106 batch_size or 0 ,
105107 dtype = dtypes .int64 ,
106108 name = "batch_size" )
109+ if batch_mode not in self .batch_modes_supported :
110+ raise ValueError (
111+ "Unsupported batch_mode: '{}', must be one of {}"
112+ .format (batch_mode , self .batch_modes_supported ))
107113 self ._batch_mode = tensorflow .convert_to_tensor (
108114 batch_mode ,
109115 dtypes .string ,
@@ -147,7 +153,10 @@ def __init__(self,
147153 output_types: Tensor dtypes of the output tensors
148154 output_shapes: TensorShapes of the output tensors or None to
149155 infer partial
150- batch_size: Batch size of output tensors
156+ batch_size: Batch size of output tensors, setting a batch size here
157+ will create batched Tensors from Arrow memory and can be more
158+ efficient than using tf.data.Dataset.batch().
159+ NOTE: batch_size does not need to be set if batch_mode='auto'
151160 batch_mode: Mode of batching, supported strings:
152161 "keep_remainder" (default, keeps partial batch data),
153162 "drop_remainder" (discard partial batch data),
@@ -187,7 +196,10 @@ def from_record_batches(cls,
187196 output_types: Tensor dtypes of the output tensors
188197 output_shapes: TensorShapes of the output tensors or None to
189198 infer partial
190- batch_size: Batch size of output tensors
199+ batch_size: Batch size of output tensors, setting a batch size here
200+ will create batched tensors from Arrow memory and can be more
201+ efficient than using tf.data.Dataset.batch().
202+ NOTE: batch_size does not need to be set if batch_mode='auto'
191203 batch_mode: Mode of batching, supported strings:
192204 "keep_remainder" (default, keeps partial batch data),
193205 "drop_remainder" (discard partial batch data),
@@ -230,7 +242,10 @@ def from_pandas(cls,
230242 df: a Pandas DataFrame
231243 columns: Optional column indices to use, if None all are used
232244 preserve_index: Flag to include the DataFrame index as the last column
233- batch_size: Batch size of output tensors
245+ batch_size: Batch size of output tensors, setting a batch size here
246+ will create batched tensors from Arrow memory and can be more
247+ efficient than using tf.data.Dataset.batch().
248+ NOTE: batch_size does not need to be set if batch_mode='auto'
234249 batch_mode: Mode of batching, supported strings:
235250 "keep_remainder" (default, keeps partial batch data),
236251 "drop_remainder" (discard partial batch data),
@@ -274,7 +289,10 @@ def __init__(self,
274289 output_types: Tensor dtypes of the output tensors
275290 output_shapes: TensorShapes of the output tensors or None to
276291 infer partial
277- batch_size: Batch size of output tensors
292+ batch_size: Batch size of output tensors, setting a batch size here
293+ will create batched tensors from Arrow memory and can be more
294+ efficient than using tf.data.Dataset.batch().
295+ NOTE: batch_size does not need to be set if batch_mode='auto'
278296 batch_mode: Mode of batching, supported strings:
279297 "keep_remainder" (default, keeps partial batch data),
280298 "drop_remainder" (discard partial batch data),
@@ -316,7 +334,10 @@ def from_schema(cls,
316334 in Arrow Feather format
317335 schema: Arrow schema defining the record batch data in the stream
318336         columns: A list of column indices to use from the schema, None for all
319- batch_size: Batch size of output tensors
337+ batch_size: Batch size of output tensors, setting a batch size here
338+ will create batched tensors from Arrow memory and can be more
339+ efficient than using tf.data.Dataset.batch().
340+ NOTE: batch_size does not need to be set if batch_mode='auto'
320341 batch_mode: Mode of batching, supported strings:
321342 "keep_remainder" (default, keeps partial batch data),
322343 "drop_remainder" (discard partial batch data),
@@ -355,7 +376,10 @@ def __init__(self,
355376 output_types: Tensor dtypes of the output tensors
356377 output_shapes: TensorShapes of the output tensors or None to
357378 infer partial
358- batch_size: Batch size of output tensors
379+ batch_size: Batch size of output tensors, setting a batch size here
380+ will create batched tensors from Arrow memory and can be more
381+ efficient than using tf.data.Dataset.batch().
382+ NOTE: batch_size does not need to be set if batch_mode='auto'
359383 batch_mode: Mode of batching, supported strings:
360384 "keep_remainder" (default, keeps partial batch data),
361385 "drop_remainder" (discard partial batch data),
@@ -397,7 +421,10 @@ def from_schema(cls,
397421 For a socket client, use "<HOST_IP>:<PORT>", for stdin use "STDIN".
398422 schema: Arrow schema defining the record batch data in the stream
399423         columns: A list of column indices to use from the schema, None for all
400- batch_size: Batch size of output tensors
424+ batch_size: Batch size of output tensors, setting a batch size here
425+ will create batched tensors from Arrow memory and can be more
426+ efficient than using tf.data.Dataset.batch().
427+ NOTE: batch_size does not need to be set if batch_mode='auto'
401428 batch_mode: Mode of batching, supported strings:
402429 "keep_remainder" (default, keeps partial batch data),
403430 "drop_remainder" (discard partial batch data),
0 commit comments