-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch_high_level_img_to_img.py
113 lines (103 loc) · 3.65 KB
/
pytorch_high_level_img_to_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import cv2
import numpy as np
from tqdm import tqdm
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Choose the compute device once, at import time. Only the first CUDA device
# ("cuda:0") is ever used; torch.cuda.device_count() would tell how many exist.
GPU = torch.cuda.is_available()
device = torch.device("cuda:0" if GPU else "cpu")
if GPU:
    print('running on GPU')
else:
    print("running on CPU")
def fwd_pass(net, X, Y, optimizer, loss_function, train=False):
    """Run ``net`` on one batch and return the loss, optionally taking a training step.

    Args:
        net: model, called as ``net(X)``.
        X: input batch.
        Y: target batch; must have exactly the shape of ``net(X)``.
        optimizer: stepped once after ``loss.backward()`` when ``train`` is true;
            not used otherwise.
        loss_function: criterion, called as ``loss_function(output, Y)``.
        train: when true, zero the model's gradients, back-propagate and step.
            When false, the forward pass runs with autograd disabled, so no graph
            is built for a loss that will never be back-propagated.

    Returns:
        The loss tensor. It has no ``grad_fn`` when ``train`` is false.

    Raises:
        ValueError: if the network output's shape differs from ``Y``'s shape.
            (Previously this printed the shapes and called ``exit()``, which
            terminated the whole interpreter.)
    """
    if train:
        net.zero_grad()
    # Only track gradients when we are going to call backward().
    with torch.set_grad_enabled(bool(train)):
        output = net(X)
        if output.shape != Y.shape:
            raise ValueError(
                "output shape does not match target shape! "
                f"input shape: {tuple(X.shape)}, "
                f"output shape: {tuple(output.shape)}, "
                f"target shape: {tuple(Y.shape)}"
            )
        loss = loss_function(output, Y)
        # Drop our reference to the prediction early; backward() does not need it.
        del output
        if train:
            loss.backward()
            optimizer.step()
    return loss
def fit(net, X, Y, train_log, optimizer, loss_function, validation_set,
        BATCH_SIZE, EPOCHS, CHANNELS=1, HEIGHT=240, WIDTH=320):
    """Train ``net`` on the leading part of (X, Y) and validate on the rest.

    The first ``len(X) - int(validation_set * len(X))`` samples are used for
    training and the remaining ones for validation. Both ranges are walked in
    order, in slices of ``BATCH_SIZE``. Each slice is converted with
    ``torch.Tensor``, reshaped to ``(-1, CHANNELS, HEIGHT, WIDTH)`` and moved to
    the module-level ``device`` before being passed to ``fwd_pass``.

    Args:
        net: model to train. It is put in train mode for the training pass and
            in eval mode for the validation pass, then restored to the mode it
            had on entry.
        X, Y: input and target data, sliced with ``[i:i + BATCH_SIZE]``.
        train_log: list that receives one ``[mean_train_loss, mean_val_loss]``
            entry per epoch. It is mutated in place and also returned.
        optimizer: optimizer used for every training step. If ``None``, one
            ``Adam(lr=1e-4)`` is created for the whole run. This is the rate the
            previous code hard-coded while ignoring this argument.
        loss_function: criterion for both passes; ``nn.MSELoss()`` if ``None``.
        validation_set: fraction of the samples held out for validation.
        BATCH_SIZE: samples per slice.
        EPOCHS: number of passes over the data.
        CHANNELS, HEIGHT, WIDTH: per-sample layout used for the reshape. The
            defaults reproduce the previously hard-coded 1 x 240 x 320.

    Returns:
        ``train_log``. Each mean is taken over every batch of its phase; a phase
        with no batches reports 0.0.
    """
    data_size = len(X)
    val_size = int(validation_set * data_size)
    train_size = data_size - val_size
    sample_shape = (-1, CHANNELS, HEIGHT, WIDTH)

    # Build defaults once, not per epoch, so Adam's moment estimates persist
    # across epochs instead of being reset every time.
    if optimizer is None:
        optimizer = optim.Adam(net.parameters(), lr=0.0001)
    if loss_function is None:
        loss_function = nn.MSELoss()

    def to_batch(data, start):
        # One slice of `data` as a float tensor in the network's layout, on `device`.
        chunk = torch.Tensor(data[start:start + BATCH_SIZE])
        return chunk.view(*sample_shape).to(device)

    was_training = net.training
    try:
        for _ in range(EPOCHS):
            # In-sample (training) pass.
            net.train()
            train_total, train_batches = 0.0, 0
            for i in tqdm(range(0, train_size, BATCH_SIZE)):
                batch_X, batch_Y = to_batch(X, i), to_batch(Y, i)
                loss = fwd_pass(net, batch_X, batch_Y, optimizer, loss_function,
                                train=True)
                train_total += loss.item()
                train_batches += 1
                # Release this batch before the next slice is allocated (lower peak memory).
                del batch_X, batch_Y, loss
            torch.cuda.empty_cache()

            # Out-of-sample (validation) pass: eval mode, no autograd graph.
            net.eval()
            val_total, val_batches = 0.0, 0
            with torch.no_grad():
                for i in tqdm(range(train_size, data_size, BATCH_SIZE)):
                    batch_X, batch_Y = to_batch(X, i), to_batch(Y, i)
                    loss = fwd_pass(net, batch_X, batch_Y, optimizer, loss_function,
                                    train=False)
                    val_total += loss.item()
                    val_batches += 1
                    del batch_X, batch_Y, loss
            torch.cuda.empty_cache()

            train_log.append([
                train_total / train_batches if train_batches else 0.0,
                val_total / val_batches if val_batches else 0.0,
            ])
    finally:
        net.train(was_training)
    return train_log