-
Notifications
You must be signed in to change notification settings - Fork 0
/
output.py
32 lines (22 loc) · 840 Bytes
/
output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from torchvision import transforms
from models import Generator
# Select the compute device: prefer CUDA, then Apple MPS, otherwise the CPU.
# Conditional expressions associate to the right, so this chain keeps the
# cuda -> mps -> cpu priority and only probes MPS when CUDA is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print('Device:', device)

# Load the checkpoint onto the CPU; load_state_dict below copies the stored
# tensors into the parameters of the model that already lives on `device`.
# NOTE(review): weights_only=False makes torch.load use full pickle
# deserialization, which can execute arbitrary code embedded in the file.
# Only load checkpoints from a trusted source; consider weights_only=True if
# the checkpoint contains only tensors and plain containers.
checkpoint = torch.load(
    'data/gan.pth',
    weights_only=False,
    map_location='cpu',
)

# Build the generator on the selected device, restore its trained weights
# from the 'generator_state' entry, and put it in evaluation mode.
model = Generator()
model.to(device)
model.load_state_dict(checkpoint['generator_state'])
model.eval()

# Input preprocessing: ConvertImageDtype casts to float32 (integer images are
# rescaled to [0, 1]; float images are only cast), then Normalize computes
# (x - 0.5) / 0.5, mapping [0, 1] -> [-1, 1]. The one-element mean/std
# broadcasts over all channels.
transform = transforms.Compose(
    [
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Inverse of `transform` for generator outputs: Normalize computes
# (x - (-1)) / 2, mapping [-1, 1] -> [0, 1], then ConvertImageDtype scales
# floats in [0, 1] to uint8.
# NOTE(review): assumes the generator emits values in [-1, 1] (e.g. a tanh
# head) — confirm against models.Generator.
inverse_transform = transforms.Compose(
    [
        transforms.Normalize(mean=[-1], std=[2]),
        transforms.ConvertImageDtype(torch.uint8),
    ]
)
def colorize(sar_imgs):
    """Colorize SAR image tensor(s) with the loaded generator.

    Args:
        sar_imgs: Image tensor. Integer inputs (e.g. uint8) are rescaled to
            [0, 1] by ``transform``. Float inputs are only cast, not rescaled,
            so they should already lie in [0, 1].
            NOTE(review): the expected shape (presumably a batch
            ``(N, C, H, W)`` matching what ``Generator`` was trained on) is
            not visible here — confirm against ``models.Generator``.

    Returns:
        The generator output as a ``torch.uint8`` tensor on the CPU.
    """
    sar_imgs = transform(sar_imgs).to(device)
    with torch.no_grad():
        color_imgs = model(sar_imgs)
    # ``inverse_transform`` maps [-1, 1] -> [0, 1] and then converts to uint8
    # by multiplying by ~256 and casting, without clamping. Any generator
    # output outside [-1, 1] would therefore overflow the uint8 cast and show
    # up as wrapped-around pixel values. Clamp first; this is a no-op for
    # in-range outputs.
    color_imgs = color_imgs.clamp(-1.0, 1.0)
    return inverse_transform(color_imgs.cpu())