@@ -5508,6 +5508,18 @@ def h2(x, y):
55085508 ('4' , (1 , 4 , 1 , 6 , 1 ), (1 , 4 , 6 ),
55095509 P (None , 'x' , None , None , None ), P (None , 'x' , None ), False ),
55105510 ('5' , (4 , 6 ), (4 , 6 ), P (None , 'x' ), P (None , 'x' ), False ),
5511+ ('6' , (1024 , 4096 ), (1024 , 2048 , 2 , 1 , 1 , 1 , 1 ),
5512+ P ('x' , None ), P ('x' , None , None , None , None , None , None ), False ),
5513+ ('7' , (1024 , 4096 , 32 ), (1024 , 2048 , 2 , 1 , 1 , 32 ),
5514+ P ('x' , None , None ), P ('x' , None , None , None , None , None ), False ),
5515+ ('8' , (1024 , 4096 ), (1024 , 1 , 1 , 4096 ),
5516+ P ('x' , None ), P ('x' , None , None , None ), False ),
5517+ ('9' , (1024 , 4096 ), (1024 , 1 , 1 , 4096 ),
5518+ P (None , 'x' ), P (None , None , None , 'x' ), False ),
5519+ ('10' , (1024 , 2048 , 2 , 1 , 1 , 1 ), (1024 , 4096 ),
5520+ P ('x' , None , None , None , None , None ), P ('x' , None ), False ),
5521+ ('11' , (1024 , 2048 , 2 , 1 , 1 , 1 ), (1024 , 4096 ),
5522+ P (None , 'x' , None , None , None , None ), P (None , 'x' ), False ),
55115523 )
55125524 @jtu .with_user_mesh ((2 ,), ('x' ,))
55135525 def test_reshape (self , src_shape , dst_shape , src_spec , dst_spec ,
@@ -5519,6 +5531,8 @@ def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec,
55195531 @partial (jax .jit , static_argnums = 1 )
55205532 def f (x , new_sharding ):
55215533 y = lax .reshape (x , dst_shape , out_sharding = new_sharding )
5534+ self .assertEqual (y .aval .sharding .spec , dst_spec )
5535+ self .assertEqual (y .shape , dst_shape )
55225536 y = y * 2
55235537 self .assertEqual (y .aval .sharding .spec , dst_spec )
55245538 return y
0 commit comments