-
Notifications
You must be signed in to change notification settings - Fork 7
/
test_models.py
40 lines (33 loc) · 1.16 KB
/
test_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Tests for AugmentedNet.models."""
import os

# Force no-gpu mode in this test. The variable is read when CUDA is first
# initialised, so it is set before importing the project modules below, which
# may pull in the deep-learning backend (assumption: AugmentedNet.models /
# AugmentedNet.train import TensorFlow — confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import unittest  # noqa: E402

import numpy as np  # noqa: E402

from AugmentedNet import models  # noqa: E402
from AugmentedNet.train import InputOutput  # noqa: E402
class TestModels(unittest.TestCase):
    """Check that each model builder yields the expected parameter count."""

    @staticmethod
    def _make_io(
        *,
        n_inputs=2,
        input_features=19,
        n_outputs=6,
        output_features=30,
        timesteps=640,
    ):
        """Build dummy ``InputOutput`` lists shared by the model tests.

        Inputs are named ``input0..`` and hold zero arrays of shape
        ``(1, timesteps, input_features)``; outputs are named ``output0..``
        and hold zero arrays of shape ``(1, timesteps, output_features)``.
        The axis meaning (batch, time, features) is an assumption — confirm
        against ``InputOutput``.

        Returns:
            A ``(inputs, outputs)`` tuple of lists.
        """
        inputs = [
            InputOutput(f"input{i}", np.zeros((1, timesteps, input_features)))
            for i in range(n_inputs)
        ]
        outputs = [
            InputOutput(
                f"output{i}", np.zeros((1, timesteps, output_features))
            )
            for i in range(n_outputs)
        ]
        for output in outputs:
            # The model builders are given these attributes explicitly here;
            # presumably they are normally filled in elsewhere (TODO confirm).
            # outputFeatures is tied to the array width so the two can't drift.
            output.shortname = output.name
            output.outputFeatures = output_features
        return inputs, outputs

    def test_augmentednet(self):
        """AugmentedNet (6 blocks, 2 inputs, 6 outputs) has 77002 params."""
        inputs, outputs = self._make_io()
        model = models.AugmentedNet(inputs, outputs, blocks=6)
        self.assertEqual(model.count_params(), 77002)

    def test_micchietal(self):
        """Micchi2020 on the same inputs/outputs has 89412 params."""
        inputs, outputs = self._make_io()
        model = models.Micchi2020(inputs, outputs)
        self.assertEqual(model.count_params(), 89412)