55from collections import OrderedDict
66from concurrent .futures import ThreadPoolExecutor
77from datetime import datetime
8- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
8+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union , Iterator
99
1010import pandas as pd
1111import pyarrow as arrow
@@ -578,6 +578,7 @@ def map_batches(
578578 func : Callable [[arrow .Table ], arrow .Table ],
579579 * ,
580580 batch_size : int = 122880 ,
581+ streaming : bool = False ,
581582 ** kwargs ,
582583 ) -> DataFrame :
583584 """
@@ -590,18 +591,35 @@ def map_batches(
590591 It should take a `arrow.Table` as input and returns a `arrow.Table`.
591592 batch_size, optional
592593 The number of rows in each batch. Defaults to 122880.
594+ streaming, optional
595+ If true, the function takes an iterator of `arrow.Table` as input and yields a stream of `arrow.Table` as output.
596+ i.e. func: Callable[[Iterator[arrow.Table]], Iterator[arrow.Table]]
597+ Defaults to false.
593598 """
594599
595- def process_func (_runtime_ctx , tables : List [arrow .Table ]) -> arrow .Table :
596- return func (tables [0 ])
600+ if streaming :
601+ def process_func (_runtime_ctx , readers : List [arrow .RecordBatchReader ]) -> Iterator [arrow .Table ]:
602+ tables = map (lambda batch : arrow .Table .from_batches ([batch ]), readers [0 ])
603+ return func (tables )
597604
598- plan = ArrowBatchNode (
599- self .session ._ctx ,
600- (self .plan ,),
601- process_func = process_func ,
602- streaming_batch_size = batch_size ,
603- ** kwargs ,
604- )
605+ plan = ArrowStreamNode (
606+ self .session ._ctx ,
607+ (self .plan ,),
608+ process_func = process_func ,
609+ streaming_batch_size = batch_size ,
610+ ** kwargs ,
611+ )
612+ else :
613+ def process_func (_runtime_ctx , tables : List [arrow .Table ]) -> arrow .Table :
614+ return func (tables [0 ])
615+
616+ plan = ArrowBatchNode (
617+ self .session ._ctx ,
618+ (self .plan ,),
619+ process_func = process_func ,
620+ streaming_batch_size = batch_size ,
621+ ** kwargs ,
622+ )
605623 return DataFrame (self .session , plan , recompute = self .need_recompute )
606624
607625 def limit (self , limit : int ) -> DataFrame :
0 commit comments