-
Notifications
You must be signed in to change notification settings - Fork 0
/
HeatmapGenerator1.py
86 lines (61 loc) · 2.92 KB
/
HeatmapGenerator1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import numpy as np
import time
import sys
import re
from PIL import Image
import cv2
## Model file
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from DensenetModels import DenseNet121
from DensenetModels import DenseNet169
from DensenetModels import DenseNet201
#--------------------------------------------------------------------------------
#---- Class to generate heatmaps (CAM)
class HeatmapGenerator ():
    """Load a trained DenseNet checkpoint and run single-image inference.

    Note: as written, ``generate`` only returns the raw network output; it
    does not itself render or save a heatmap.
    """

    # Rewrites legacy DenseNet parameter names such as
    # ``...denselayer1.norm.1.weight`` into ``...denselayer1.norm1.weight``:
    # group 1 is the prefix ending in norm/relu/conv, group 2 the
    # ``<1|2>.<param>`` suffix, and concatenating them drops the dot.
    # (Presumably the checkpoint was saved with an older torchvision DenseNet
    # naming scheme — confirm against the training code.)
    _LEGACY_KEY_PATTERN = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    #---- Initialize heatmap generator
    def __init__ (self, pathModel, nnArchitecture, nnClassCount, transCrop):
        """Build the network, load its weights and prepare the image transform.

        pathModel - path to the trained densenet checkpoint; must contain a
            'state_dict' entry. It is read with ``torch.load`` (pickle based),
            so only load checkpoints from a trusted source.
        nnArchitecture - one of 'DENSE-NET-121', 'DENSE-NET-169',
            'DENSE-NET-201' (exact, case-sensitive).
        nnClassCount - number of output classes, e.g. 14 for ChestX-ray14.
        transCrop - size passed to ``transforms.Resize`` for input images.

        Raises ValueError for an unrecognised ``nnArchitecture``.
        """
        # Lambdas defer the model construction (and name lookup) until the
        # architecture has been validated.
        builders = {
            'DENSE-NET-121': lambda: DenseNet121(nnClassCount, True),
            'DENSE-NET-169': lambda: DenseNet169(nnClassCount, True),
            'DENSE-NET-201': lambda: DenseNet201(nnClassCount, True),
        }
        try:
            build = builders[nnArchitecture]
        except KeyError:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                'Unknown nnArchitecture {!r}; expected one of {}'.format(
                    nnArchitecture, ', '.join(sorted(builders)))) from None

        #---- Initialize the network (second arg: presumably the "pretrained"
        #---- flag of the DensenetModels wrappers — confirm there)
        model = torch.nn.DataParallel(build().cuda()).cuda()

        modelCheckpoint = torch.load(pathModel)
        state_dict = self._rename_legacy_keys(modelCheckpoint['state_dict'])
        model.load_state_dict(state_dict)

        self.model = model
        self.model.eval()

        #---- Initialize the image transform - resize + to-tensor + normalize
        #---- (the commonly used ImageNet channel mean/std)
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.transformSequence = transforms.Compose([
            transforms.Resize(transCrop),
            transforms.ToTensor(),
            normalize,
        ])

    @classmethod
    def _rename_legacy_keys(cls, state_dict):
        """Rename legacy ``norm.1``-style keys of *state_dict* in place.

        Non-matching keys are left untouched. Returns the same dict object.
        """
        # Snapshot the keys first: the dict is mutated inside the loop.
        for key in list(state_dict):
            match = cls._LEGACY_KEY_PATTERN.match(key)
            if match:
                state_dict[match.group(1) + match.group(2)] = state_dict.pop(key)
        return state_dict

    #--------------------------------------------------------------------------------
    def generate (self, pathImageFile, pathOutputFile, transCrop):
        """Run the model on one image and return its raw output.

        pathImageFile - path of the input image; it is converted to RGB.
        pathOutputFile - currently unused; kept for interface compatibility.
        transCrop - currently unused; the resize size is fixed in __init__.

        Returns the model output for a batch containing that single image.
        """
        #---- Load image, transform, add a batch dimension
        with Image.open(pathImageFile) as image:
            imageData = self.transformSequence(image.convert('RGB'))
        batch = imageData.unsqueeze(0).cuda()

        # Re-assert GPU placement in case the caller moved the model.
        self.model.cuda()
        # Call the module itself (not .forward) so registered hooks run.
        return self.model(batch)
#-------------------------------------