@@ -42,6 +42,10 @@ class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
4242 pass
4343
4444
45+ class _NoValue :
46+ pass
47+
48+
4549class PytatoFakeNumpyNamespace (BaseFakeNumpyNamespace ):
4650 """
4751 A :mod:`numpy` mimic for :class:`PytatoPyOpenCLArrayContext`.
@@ -91,22 +95,50 @@ def minimum(self, x, y):
9195 def where (self , criterion , then , else_ ):
9296 return rec_multimap_array_container (pt .where , criterion , then , else_ )
9397
94- def sum (self , a , axis = None , dtype = None ):
95- def _pt_sum (ary ):
98+ @staticmethod
99+ def _reduce (container_binop , array_reduce ,
100+ ary , * ,
101+ axis , dtype , initial ):
102+ def container_reduce (ctr ):
103+ if initial is _NoValue :
104+ try :
105+ return reduce (container_binop , ctr )
106+ except TypeError as exc :
107+ assert "empty sequence" in str (exc )
108+ raise ValueError ("zero-size reduction operation "
109+ "without supplied 'initial' value" )
110+ else :
111+ return reduce (container_binop , ctr , initial )
112+
113+ def actual_array_reduce (ary ):
96114 if dtype not in [ary .dtype , None ]:
97115 raise NotImplementedError
98116
99- return pt .sum (ary , axis = axis )
100-
101- return rec_map_reduce_array_container (sum , _pt_sum , a )
102-
103- def min (self , a , axis = None ):
104- return rec_map_reduce_array_container (
105- partial (reduce , pt .minimum ), partial (pt .amin , axis = axis ), a )
117+ if initial is _NoValue :
118+ return array_reduce (ary , axis = axis )
119+ else :
120+ return array_reduce (ary , axis = axis , initial = initial )
106121
107- def max (self , a , axis = None ):
108122 return rec_map_reduce_array_container (
109- partial (reduce , pt .maximum ), partial (pt .amax , axis = axis ), a )
123+ container_reduce ,
124+ actual_array_reduce ,
125+ ary )
126+
127+ # * appears where positional signature starts diverging from numpy
128+ def sum (self , a , axis = None , dtype = None , * , initial = 0 ):
129+ import operator
130+ return self ._reduce (operator .add , pt .sum , a ,
131+ axis = axis , dtype = dtype , initial = initial )
132+
133+ # * appears where positional signature starts diverging from numpy
134+ def min (self , a , axis = None , * , initial = _NoValue ):
135+ return self ._reduce (pt .minimum , pt .amin , a ,
136+ axis = axis , dtype = None , initial = initial )
137+
138+ # * appears where positional signature starts diverging from numpy
139+ def max (self , a , axis = None , * , initial = _NoValue ):
140+ return self ._reduce (pt .maximum , pt .amax , a ,
141+ axis = axis , dtype = None , initial = initial )
110142
111143 def stack (self , arrays , axis = 0 ):
112144 return rec_multimap_array_container (
0 commit comments