Skip to content

Commit 0e33502

Browse files
author
Pilger, David
committed
Merge branch 'develop'
2 parents 5f9aeac + 3b32388 commit 0e33502

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

include/NumCpp/Special/softmax.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ namespace nc
7575
}
7676
case Axis::ROW:
7777
{
78-
auto returnArray = exp(inArray).transpose();
79-
auto expSums = returnArray.sum(inAxis);
78+
auto returnArray = exp(inArray.transpose());
79+
auto expSums = returnArray.sum(Axis::COL);
8080

8181
for (uint32 row = 0; row < returnArray.shape().rows; ++row)
8282
{

unitTests/testScripts/TestSpecial.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,30 @@ def doTest():
583583
else:
584584
print(colored('\tFAIL', 'red'))
585585

586+
print(colored('Testing softmax Axis::ROW', 'cyan'))
587+
shapeInput = np.random.randint(20, 100, [2, ])
588+
shape = NumCpp.Shape(shapeInput[0].item(), shapeInput[1].item())
589+
cArray = NumCpp.NdArray(shape)
590+
data = np.random.rand(shape.rows, shape.cols)
591+
cArray.setArray(data)
592+
if np.array_equal(roundArray(NumCpp.softmax(cArray, NumCpp.Axis.ROW), NUM_DECIMALS_ROUND),
593+
roundArray(sp.softmax(data, axis=0), NUM_DECIMALS_ROUND)):
594+
print(colored('\tPASS', 'green'))
595+
else:
596+
print(colored('\tFAIL', 'red'))
597+
598+
print(colored('Testing softmax Axis::COL', 'cyan'))
599+
shapeInput = np.random.randint(20, 100, [2, ])
600+
shape = NumCpp.Shape(shapeInput[0].item(), shapeInput[1].item())
601+
cArray = NumCpp.NdArray(shape)
602+
data = np.random.rand(shape.rows, shape.cols)
603+
cArray.setArray(data)
604+
if np.array_equal(roundArray(NumCpp.softmax(cArray, NumCpp.Axis.COL), NUM_DECIMALS_ROUND),
605+
roundArray(sp.softmax(data, axis=1), NUM_DECIMALS_ROUND)):
606+
print(colored('\tPASS', 'green'))
607+
else:
608+
print(colored('\tFAIL', 'red'))
609+
586610
print(colored('Testing spherical_bessel_jn scaler', 'cyan'))
587611
order = np.random.randint(0, 10)
588612
value = np.random.rand(1).item()

0 commit comments

Comments
 (0)