@@ -160,7 +160,9 @@ def forward(self, x, y):
160160
161161@register_test_case (module_factory = lambda : ElementwiseGeFloatTensorModule ())
162162def ElementwiseGeFloatTensorModule_basic (module , tu : TestUtils ):
163- module .forward (tu .rand (3 , 5 ), tu .rand (5 ))
163+ module .forward (
164+ torch .tensor ([[1.0 , 2.2 , torch .nan ], [6.0 , 2.0 , 3.1 ]]).to (torch .float32 ),
165+ torch .tensor ([6.0 , 2.1 , torch .nan ]).to (torch .float32 ))
164166
165167# ==============================================================================
166168
@@ -200,7 +202,9 @@ def forward(self, x, y):
200202
201203@register_test_case (module_factory = lambda : ElementwiseGtFloatTensorModule ())
202204def ElementwiseGtFloatTensorModule_basic (module , tu : TestUtils ):
203- module .forward (tu .rand (3 , 5 ), tu .rand (5 ))
205+ module .forward (
206+ torch .tensor ([[1.0 , 2.2 , torch .nan ], [6.0 , 2.0 , 3.1 ]]).to (torch .float32 ),
207+ torch .tensor ([6.0 , 2.1 , torch .nan ]).to (torch .float32 ))
204208
205209# ==============================================================================
206210
@@ -378,6 +382,28 @@ def ElementwiseLeFloatTensorModule_basic(module, tu: TestUtils):
378382
379383# ==============================================================================
380384
385+ class ElementwiseLeFloatTensorNanModule (torch .nn .Module ):
386+ def __init__ (self ):
387+ super ().__init__ ()
388+
389+ @export
390+ @annotate_args ([
391+ None ,
392+ ([- 1 , - 1 ], torch .float32 , True ),
393+ ([- 1 ], torch .float32 , True ),
394+ ])
395+ def forward (self , x , y ):
396+ return torch .le (x , y )
397+
398+
399+ @register_test_case (module_factory = lambda : ElementwiseLeFloatTensorNanModule ())
400+ def ElementwiseLeFloatTensorNanModule_basic (module , tu : TestUtils ):
401+ module .forward (
402+ torch .tensor ([[1.0 , 2.2 , torch .nan ], [6.0 , 2.0 , 3.1 ]]).to (torch .float32 ),
403+ torch .tensor ([6.0 , 2.1 , torch .nan ]).to (torch .float32 ))
404+
405+ # ==============================================================================
406+
381407class ElementwiseLeIntTensorModule (torch .nn .Module ):
382408 def __init__ (self ):
383409 super ().__init__ ()
@@ -414,7 +440,9 @@ def forward(self, x, y):
414440
415441@register_test_case (module_factory = lambda : ElementwiseLtFloatTensorModule ())
416442def ElementwiseLtFloatTensorModule_basic (module , tu : TestUtils ):
417- module .forward (tu .rand (3 , 5 ), tu .rand (5 ))
443+ module .forward (
444+ torch .tensor ([[1.0 , 2.2 , torch .nan ], [6.0 , 2.0 , 3.1 ]]).to (torch .float32 ),
445+ torch .tensor ([6.0 , 2.1 , torch .nan ]).to (torch .float32 ))
418446
419447# ==============================================================================
420448
0 commit comments