Skip to content

Commit

Permalink
Merge pull request #12 from void-intelligence/develop
Browse files Browse the repository at this point in the history
Closing #11
  • Loading branch information
nirex0 authored May 12, 2020
2 parents 1a8be6a + 85856b9 commit bd43b68
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
17 changes: 6 additions & 11 deletions src/Vortex-Tests/ActivationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -334,23 +334,18 @@ public void SoftmaxPrimeTest()
a.InRandomize();
var b = a.Duplicate();

var res = b.Duplicate();
var sumExp = 0.0;

for (var i = 0; i < res.Rows; i++)
for (var j = 0; j < res.Columns; j++) sumExp += Math.Exp(b[i, j]);

for (var i = 0; i < res.Rows; i++)
for (var j = 0; j < res.Columns; j++) res[i, j] = Math.Exp(b[i, j]) / sumExp;

b = res;
for (var i = 0; i < b.Rows; i++)
for (var j = 0; j < b.Columns; j++) sumExp += Math.Exp(a[i, j]);

for (var i = 0; i < b.Rows; i++)
for (var j = 0; j < b.Columns; j++)
b[i, j] = Math.Exp(a[i, j]) / sumExp * (1.0 - Math.Exp(a[i, j]) / sumExp);

var s = new Softmax();
a = s.Forward(a);
a = s.Backward(a);

b.InMap((x) => Math.Exp(x) / sumExp * (1 - Math.Exp(x) / sumExp));

Assert.IsTrue(a == b, "Softmax Derivative successful");
}

Expand Down
15 changes: 12 additions & 3 deletions src/Vortex/Activation/Kernels/Softmax.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,22 @@ public override Matrix Forward(Matrix input)
for (var i = 0; i < res.Rows; i++)
for (var j = 0; j < res.Columns; j++) res[i, j] = Exp(input[i, j]) / SumExp;


return res;
}

public override Matrix Backward(Matrix input)
{
return input.Map(Derivative);
var res = input.Duplicate();
SumExp = 0.0;

for (var i = 0; i < res.Rows; i++)
for (var j = 0; j < res.Columns; j++) SumExp += Exp(input[i, j]);

for (var i = 0; i < res.Rows; i++)
for (var j = 0; j < res.Columns; j++)
res[i, j] = Exp(input[i, j]) / SumExp * (1.0 - Exp(input[i, j]) / SumExp);

return res;
}

protected override double Activate(double input)
Expand All @@ -42,7 +51,7 @@ protected override double Activate(double input)

protected override double Derivative(double input)
{
return Exp(input) / SumExp * (1 - Exp(input) / SumExp);
return 0;
}

public override EActivationType Type()
Expand Down

0 comments on commit bd43b68

Please sign in to comment.