@@ -2538,20 +2538,25 @@ def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, rng):
2538
2538
self ._CheckBatching (op , 5 , bdims , (inshape ,), dtype , rng )
2539
2539
2540
2540
@parameterized .named_parameters (jtu .cases_from_list (
2541
- {"testcase_name" : "_inshape={}_outshape={}_bdims={}" .format (
2541
+ {"testcase_name" : "_inshape={}_outshape={}_dims={} _bdims={}" .format (
2542
2542
jtu .format_shape_dtype_string (arg_shape , dtype ),
2543
2543
jtu .format_shape_dtype_string (out_shape , dtype ),
2544
- bdims ),
2544
+ dimensions , bdims ),
2545
2545
"arg_shape" : arg_shape , "out_shape" : out_shape , "dtype" : dtype ,
2546
- "bdims" : bdims , "rng" : rng }
2546
+ "dimensions" : dimensions , " bdims" : bdims , "rng" : rng }
2547
2547
for dtype in default_dtypes
2548
- for arg_shape , out_shape in [
2549
- [(3 , 4 ), (12 ,)], [(2 , 1 , 4 ), (8 ,)], [(2 , 2 , 4 ), (2 , 8 )]
2548
+ for arg_shape , dimensions , out_shape in [
2549
+ [(3 , 4 ), None , (12 ,)],
2550
+ [(2 , 1 , 4 ), None , (8 ,)],
2551
+ [(2 , 2 , 4 ), None , (2 , 8 )],
2552
+ [(2 , 2 , 4 ), (0 , 1 , 2 ), (2 , 8 )],
2553
+ [(2 , 2 , 4 ), (1 , 0 , 2 ), (8 , 2 )],
2554
+ [(2 , 2 , 4 ), (2 , 1 , 0 ), (4 , 2 , 2 )]
2550
2555
]
2551
2556
for bdims in all_bdims (arg_shape )
2552
2557
for rng in [jtu .rand_default ()]))
2553
- def testReshape (self , arg_shape , out_shape , dtype , bdims , rng ):
2554
- op = lambda x : lax .reshape (x , out_shape )
2558
+ def testReshape (self , arg_shape , out_shape , dtype , dimensions , bdims , rng ):
2559
+ op = lambda x : lax .reshape (x , out_shape , dimensions = dimensions )
2555
2560
self ._CheckBatching (op , 10 , bdims , (arg_shape ,), dtype , rng )
2556
2561
2557
2562
@parameterized .named_parameters (jtu .cases_from_list (
0 commit comments