-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcoster.go
54 lines (45 loc) · 1.44 KB
/
coster.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package automata
import (
"math"
)
// Coster is the contract for a cost function that the Trainer can use to
// score an output against its expected target values.
type Coster interface {
	// Cost scores output against target. Callers always pass two slices of
	// equal length.
	Cost(target, output []float64) float64
}
// MeanSquaredErrorCost implements the mean squared error (MSE) cost function.
type MeanSquaredErrorCost struct{}

// Cost returns the mean of the squared differences target[i]-output[i].
//
// An empty output yields 0 instead of the NaN that a 0/0 division would
// produce, so an empty sample cannot poison an accumulated training error.
// target must be at least as long as output.
func (c *MeanSquaredErrorCost) Cost(target, output []float64) (cost float64) {
	if len(output) == 0 {
		return 0
	}
	for i, o := range output {
		d := target[i] - o
		cost += d * d
	}
	return cost / float64(len(output))
}
// CrossEntropyCost implements the binary cross-entropy cost function. The
// per-output terms are summed, not averaged.
type CrossEntropyCost struct{}

// Cost returns -Σ [t·ln(o+nudge) + (1-t)·ln(1+nudge-o)] over all positions,
// where t is target[i] and o is output[i] clamped to [0, 1].
//
// target must be at least as long as output.
func (c *CrossEntropyCost) Cost(target, output []float64) (cost float64) {
	// nudge keeps both logarithm arguments strictly positive, so an output of
	// exactly 0 or 1 cannot produce math.Log(0) == -Inf.
	nudge := 1e-15
	for i, o := range output {
		// Clamp into [0, 1]. Outside that range one logarithm argument goes
		// negative, and math.Log then returns NaN. NaN survives multiplication
		// by a zero weight (e.g. target 1 with output 1.2), so one bad output
		// would otherwise poison the whole cost. In-range values are unchanged.
		o = math.Min(math.Max(o, 0), 1)
		t := target[i]
		cost -= (t * math.Log(o+nudge)) + ((1 - t) * math.Log((1+nudge)-o))
	}
	return cost
}
// BinaryCost implements a zero-one style loss. It counts how many outputs
// land in a different bucket than their targets.
type BinaryCost struct{}

// Cost returns the number of positions where target[i] and output[i] disagree.
// Before comparing, each value is doubled and rounded half-up. For values in
// [0, 1] this gives three buckets: [0, 0.25) -> 0, [0.25, 0.75) -> 1 and
// [0.75, 1] -> 2.
func (c *BinaryCost) Cost(target, output []float64) float64 {
	misses := 0.0
	for i, o := range output {
		if c.round(2*target[i]) == c.round(2*o) {
			continue
		}
		misses += 1
	}
	return misses
}

// round rounds in to the nearest integer, with exact halves rounded toward
// +Inf (so -0.5 becomes 0, unlike math.Round).
func (c *BinaryCost) round(in float64) float64 {
	return math.Floor(0.5 + in)
}