@@ -812,16 +812,15 @@ def __init__(
812812 self .current_window_index = None
813813 self .stop_window_index = None
814814
815- # TODO(https://github.com/apache/beam/issues/28776): Remove caching after
816- # fully rolling out.
817- # If true, always recalculate window args. If false, has_cached_window_args
818- # and has_cached_window_batch_args will be set to true if the corresponding
819- # self.args_for_process,have been updated and should be reused directly.
820- self .recalculate_window_args = (
821- self .has_windowed_inputs or 'disable_global_windowed_args_caching'
822- in RuntimeValueProvider .experiments )
823- self .has_cached_window_args = False
824- self .has_cached_window_batch_args = False
815+ # If true, after the first process invocation the args for process will be cached
816+ # in cached_args_for_process and cached_kwargs_for_process and reused on
817+ # subsequent invocations in the same bundle.
818+ self .should_cache_args = (not self .has_windowed_inputs )
819+ self .cached_args_for_process = None
820+ self .cached_kwargs_for_process = None
821+ # See above, similar cached args for process_batch invocations.
822+ self .cached_args_for_process_batch = None
823+ self .cached_kwargs_for_process_batch = None
825824
826825 # Try to prepare all the arguments that can just be filled in
827826 # without any additional work in the process function.
@@ -984,9 +983,9 @@ def _invoke_process_per_window(
984983 additional_kwargs ,
985984 ):
986985 # type: (...) -> Optional[SplitResultResidual]
987- if self .has_cached_window_args :
986+ if self .cached_args_for_process :
988987 args_for_process , kwargs_for_process = (
989- self .args_for_process , self .kwargs_for_process )
988+ self .cached_args_for_process , self .cached_kwargs_for_process )
990989 else :
991990 if self .has_windowed_inputs :
992991 assert len (windowed_value .windows ) <= 1
@@ -997,10 +996,9 @@ def _invoke_process_per_window(
997996 side_inputs .extend (additional_args )
998997 args_for_process , kwargs_for_process = util .insert_values_in_args (
999998 self .args_for_process , self .kwargs_for_process , side_inputs )
1000- if not self .recalculate_window_args :
1001- self .args_for_process , self .kwargs_for_process = (
999+ if self .should_cache_args :
1000+ self .cached_args_for_process , self .cached_kwargs_for_process = (
10021001 args_for_process , kwargs_for_process )
1003- self .has_cached_window_args = True
10041002
10051003 # Extract key in the case of a stateful DoFn. Note that in the case of a
10061004 # stateful DoFn, we set during __init__ self.has_windowed_inputs to be
@@ -1088,9 +1086,9 @@ def _invoke_process_batch_per_window(
10881086 ):
10891087 # type: (...) -> Optional[SplitResultResidual]
10901088
1091- if self .has_cached_window_batch_args :
1089+ if self .cached_args_for_process_batch :
10921090 args_for_process_batch , kwargs_for_process_batch = (
1093- self .args_for_process_batch , self .kwargs_for_process_batch )
1091+ self .cached_args_for_process_batch , self .cached_kwargs_for_process_batch )
10941092 else :
10951093 if self .has_windowed_inputs :
10961094 assert isinstance (windowed_batch , HomogeneousWindowedBatch )
@@ -1107,10 +1105,9 @@ def _invoke_process_batch_per_window(
11071105 side_inputs ,
11081106 )
11091107 )
1110- if not self .recalculate_window_args :
1111- self .args_for_process_batch , self .kwargs_for_process_batch = (
1108+ if self .should_cache_args :
1109+ self .cached_args_for_process_batch , self .cached_kwargs_for_process_batch = (
11121110 args_for_process_batch , kwargs_for_process_batch )
1113- self .has_cached_window_batch_args = True
11141111
11151112 for i , p in self .placeholders_for_process_batch :
11161113 if core .DoFn .ElementParam == p :
@@ -1150,6 +1147,15 @@ def _invoke_process_batch_per_window(
11501147 * args_for_process_batch , ** kwargs_for_process_batch ),
11511148 self .threadsafe_watermark_estimator )
11521149
1150+ def invoke_finish_bundle (self ):
1151+ # type: () -> None
1152+ # Clear the cached args to allow for refreshing of side inputs across bundles.
1153+ self .cached_args_for_process , self .cached_kwargs_for_process = (None , None )
1154+ self .cached_args_for_process_batch , self .cached_kwargs_for_process_batch = (
1155+ None , None )
1156+
1157+ super (PerWindowInvoker , self ).invoke_finish_bundle ()
1158+
11531159 @staticmethod
11541160 def _try_split (
11551161 fraction ,
0 commit comments