@@ -583,6 +583,30 @@ def doTest():
583583 else :
584584 print (colored ('\t FAIL' , 'red' ))
585585
586+ print (colored ('Testing softmax Axis::ROW' , 'cyan' ))
587+ shapeInput = np .random .randint (20 , 100 , [2 , ])
588+ shape = NumCpp .Shape (shapeInput [0 ].item (), shapeInput [1 ].item ())
589+ cArray = NumCpp .NdArray (shape )
590+ data = np .random .rand (shape .rows , shape .cols )
591+ cArray .setArray (data )
592+ if np .array_equal (roundArray (NumCpp .softmax (cArray , NumCpp .Axis .ROW ), NUM_DECIMALS_ROUND ),
593+ roundArray (sp .softmax (data , axis = 0 ), NUM_DECIMALS_ROUND )):
594+ print (colored ('\t PASS' , 'green' ))
595+ else :
596+ print (colored ('\t FAIL' , 'red' ))
597+
598+ print (colored ('Testing softmax Axis::COL' , 'cyan' ))
599+ shapeInput = np .random .randint (20 , 100 , [2 , ])
600+ shape = NumCpp .Shape (shapeInput [0 ].item (), shapeInput [1 ].item ())
601+ cArray = NumCpp .NdArray (shape )
602+ data = np .random .rand (shape .rows , shape .cols )
603+ cArray .setArray (data )
604+ if np .array_equal (roundArray (NumCpp .softmax (cArray , NumCpp .Axis .COL ), NUM_DECIMALS_ROUND ),
605+ roundArray (sp .softmax (data , axis = 1 ), NUM_DECIMALS_ROUND )):
606+ print (colored ('\t PASS' , 'green' ))
607+ else :
608+ print (colored ('\t FAIL' , 'red' ))
609+
586610 print (colored ('Testing spherical_bessel_jn scaler' , 'cyan' ))
587611 order = np .random .randint (0 , 10 )
588612 value = np .random .rand (1 ).item ()
0 commit comments