-
Notifications
You must be signed in to change notification settings - Fork 65
/
persist_test.go
57 lines (46 loc) · 1.13 KB
/
persist_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
package deep
import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
)
// Test_RestoreFromDump checks that a network rebuilt with FromDump from the
// output of Dump matches the original: identical bias weights, identical
// String() rendering, and identical prediction for a fixed input.
func Test_RestoreFromDump(t *testing.T) {
	// Seed the global math/rand source so weight initialisation is repeatable
	// (presumably NewUniform draws from it — confirm). NOTE(review): rand.Seed
	// is deprecated since Go 1.20 and ignored by default on newer toolchains;
	// the assertions below only compare the two networks, not fixed values.
	rand.Seed(0)
	n := NewNeural(&Config{
		Inputs:     1,
		Layout:     []int{5, 3, 1},
		Activation: ActivationSigmoid,
		Weight:     NewUniform(0.5, 0),
		Bias:       true,
	})

	// Named "restored" rather than "new" so the built-in new is not shadowed.
	restored := FromDump(n.Dump())

	// Check the bias shapes before indexing, so a mismatched restore is
	// reported as a test failure instead of an index-out-of-range panic.
	if !assert.Len(t, restored.Biases, len(n.Biases), "bias layer count") {
		return
	}
	for i, biases := range n.Biases {
		if !assert.Len(t, restored.Biases[i], len(biases), "bias count in layer %d", i) {
			continue
		}
		for j, bias := range biases {
			assert.Equal(t, bias.Weight, restored.Biases[i][j].Weight, "bias [%d][%d]", i, j)
		}
	}

	assert.Equal(t, n.String(), restored.String())
	assert.Equal(t, n.Predict([]float64{0}), restored.Predict([]float64{0}))
}
// Test_Marshal checks that Marshal followed by Unmarshal round-trips a
// network: the decoded copy must have identical bias weights, an identical
// String() rendering, and the same prediction for a fixed input.
func Test_Marshal(t *testing.T) {
	// Seed the global math/rand source so weight initialisation is repeatable
	// (presumably NewUniform draws from it — confirm). NOTE(review): rand.Seed
	// is deprecated since Go 1.20 and ignored by default on newer toolchains;
	// the assertions below only compare the two networks, not fixed values.
	rand.Seed(0)
	n := NewNeural(&Config{
		Inputs:     1,
		Layout:     []int{3, 3, 1},
		Activation: ActivationSigmoid,
		Weight:     NewUniform(0.5, 0),
		Bias:       true,
	})

	// Stop at the first error: when Marshal/Unmarshal fails, the returned value
	// is not meaningful, and using it below could panic instead of failing.
	dump, err := n.Marshal()
	if !assert.NoError(t, err, "Marshal") {
		return
	}
	// Named "restored" rather than "new" so the built-in new is not shadowed.
	restored, err := Unmarshal(dump)
	if !assert.NoError(t, err, "Unmarshal") {
		return
	}

	// Check the bias shapes before indexing, so a mismatched decode is
	// reported as a test failure instead of an index-out-of-range panic.
	if !assert.Len(t, restored.Biases, len(n.Biases), "bias layer count") {
		return
	}
	for i, biases := range n.Biases {
		if !assert.Len(t, restored.Biases[i], len(biases), "bias count in layer %d", i) {
			continue
		}
		for j, bias := range biases {
			assert.Equal(t, bias.Weight, restored.Biases[i][j].Weight, "bias [%d][%d]", i, j)
		}
	}

	assert.Equal(t, n.String(), restored.String())
	assert.Equal(t, n.Predict([]float64{0}), restored.Predict([]float64{0}))
}