@@ -256,46 +256,70 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
256256
257257def rec_map_array_container (
258258 f : Callable [[Any ], Any ],
259- ary : ArrayOrContainerT ) -> ArrayOrContainerT :
259+ ary : ArrayOrContainerT ,
260+ leaf_class : Optional [type ] = None ) -> ArrayOrContainerT :
260261 r"""Applies *f* recursively to an :class:`ArrayContainer`.
261262
262263 For a non-recursive version see :func:`map_array_container`.
263264
264265 :param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
265266 or an instance of a base array type.
266267 """
267- return _map_array_container_impl (f , ary , recursive = True )
268+ return _map_array_container_impl (f , ary , leaf_cls = leaf_class , recursive = True )
268269
269270
270271def mapped_over_array_containers (
271- f : Callable [[Any ], Any ]) -> Callable [[ArrayOrContainerT ], ArrayOrContainerT ]:
272+ f : Optional [Callable [[Any ], Any ]] = None ,
273+ leaf_class : Optional [type ] = None ) -> Union [
274+ Callable [[ArrayOrContainerT ], ArrayOrContainerT ],
275+ Callable [
276+ [Callable [[Any ], Any ]],
277+ Callable [[ArrayOrContainerT ], ArrayOrContainerT ]]]:
272278 """Decorator around :func:`rec_map_array_container`."""
273- wrapper = partial (rec_map_array_container , f )
274- update_wrapper (wrapper , f )
275- return wrapper
279+ def decorator (g : Callable [[Any ], Any ]) -> Callable [
280+ [ArrayOrContainerT ], ArrayOrContainerT ]:
281+ wrapper = partial (rec_map_array_container , g , leaf_class = leaf_class )
282+ update_wrapper (wrapper , g )
283+ return wrapper
284+ if f is not None :
285+ return decorator (f )
286+ else :
287+ return decorator
276288
277289
278- def rec_multimap_array_container (f : Callable [..., Any ], * args : Any ) -> Any :
290+ def rec_multimap_array_container (
291+ f : Callable [..., Any ],
292+ * args : Any ,
293+ leaf_class : Optional [type ] = None ) -> Any :
279294 r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
280295
281296 For a non-recursive version see :func:`multimap_array_container`.
282297
283298 :param args: all :class:`ArrayContainer` arguments must be of the same
284299 type and with the same structure (same number of components, etc.).
285300 """
286- return _multimap_array_container_impl (f , * args , recursive = True )
301+ return _multimap_array_container_impl (
302+ f , * args , leaf_cls = leaf_class , recursive = True )
287303
288304
289305def multimapped_over_array_containers (
290- f : Callable [..., Any ]) -> Callable [..., Any ]:
306+ f : Optional [Callable [..., Any ]] = None ,
307+ leaf_class : Optional [type ] = None ) -> Union [
308+ Callable [..., Any ],
309+ Callable [[Callable [..., Any ]], Callable [..., Any ]]]:
291310 """Decorator around :func:`rec_multimap_array_container`."""
292- # can't use functools.partial, because its result is insufficiently
293- # function-y to be used as a method definition.
294- def wrapper (* args : Any ) -> Any :
295- return rec_multimap_array_container (f , * args )
311+ def decorator (g : Callable [..., Any ]) -> Callable [..., Any ]:
312+ # can't use functools.partial, because its result is insufficiently
313+ # function-y to be used as a method definition.
314+ def wrapper (* args : Any ) -> Any :
315+ return rec_multimap_array_container (g , * args , leaf_class = leaf_class )
316+ update_wrapper (wrapper , g )
317+ return wrapper
318+ if f is not None :
319+ return decorator (f )
320+ else :
321+ return decorator
296322
297- update_wrapper (wrapper , f )
298- return wrapper
299323
300324# }}}
301325
@@ -401,7 +425,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401425def rec_map_reduce_array_container (
402426 reduce_func : Callable [[Iterable [Any ]], Any ],
403427 map_func : Callable [[Any ], Any ],
404- ary : ArrayOrContainerT ) -> "DeviceArray" :
428+ ary : ArrayOrContainerT ,
429+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
405430 """Perform a map-reduce over array containers recursively.
406431
407432 :param reduce_func: callable used to reduce over the components of *ary*
@@ -440,22 +465,26 @@ def rec_map_reduce_array_container(
440465 or any other such traversal.
441466 """
442467 def rec (_ary : ArrayOrContainerT ) -> ArrayOrContainerT :
443- try :
444- iterable = serialize_container (_ary )
445- except NotAnArrayContainerError :
468+ if type (_ary ) is leaf_class :
446469 return map_func (_ary )
447470 else :
448- return reduce_func ([
449- rec (subary ) for _ , subary in iterable
450- ])
471+ try :
472+ iterable = serialize_container (_ary )
473+ except NotAnArrayContainerError :
474+ return map_func (_ary )
475+ else :
476+ return reduce_func ([
477+ rec (subary ) for _ , subary in iterable
478+ ])
451479
452480 return rec (ary )
453481
454482
455483def rec_multimap_reduce_array_container (
456484 reduce_func : Callable [[Iterable [Any ]], Any ],
457485 map_func : Callable [..., Any ],
458- * args : Any ) -> "DeviceArray" :
486+ * args : Any ,
487+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
459488 r"""Perform a map-reduce over multiple array containers recursively.
460489
461490 :param reduce_func: callable used to reduce over the components of any
@@ -478,7 +507,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
478507
479508 return _multimap_array_container_impl (
480509 map_func , * args ,
481- reduce_func = _reduce_wrapper , leaf_cls = None , recursive = True )
510+ reduce_func = _reduce_wrapper , leaf_cls = leaf_class , recursive = True )
482511
483512# }}}
484513
0 commit comments