-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
31 lines (24 loc) · 798 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch.optim as optim
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from model import Model
def main():
    """Run one optimisation step on random data as a smoke test.

    Builds ``Model(K=288, Tx=8, Rx=2)``, prints the name and size of every
    parameter, feeds a random batch through the model, computes a binary
    cross-entropy loss against random 0/1 labels, applies a single SGD
    update and prints the scalar loss value.
    """
    print('Train....')
    model = Model(K=288, Tx=8, Rx=2)
    for name, param in model.named_parameters():
        print('Name:', name, 'Size:', param.size())

    # Build the loss and optimizer up front so the step below reads as the
    # usual zero_grad -> forward -> loss -> backward -> step sequence.
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    batch_size = 8
    # requires_grad=True is kept from the original; it makes autograd also
    # compute a gradient with respect to the random input batch.
    in_tensor = torch.randn(batch_size, 4, 288, 8, requires_grad=True)
    # random_(2) fills the tensor with values drawn from {0, 1}.
    gt_labels = torch.empty(batch_size).random_(2)

    # Clear any previously accumulated gradients; without this, gradients
    # add up across steps as soon as this body is looped or called again.
    optimizer.zero_grad()

    out_tensor = model(in_tensor)
    print('out size:', out_tensor.size())

    # NOTE(review): BCELoss needs probabilities in [0, 1] with the same shape
    # as gt_labels (batch_size,) -- presumably Model ends in a sigmoid and
    # returns one value per sample; confirm against model.py.
    loss = criterion(out_tensor, gt_labels)
    loss.backward()
    optimizer.step()
    print(loss.item())
# Run the training smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()