forked from sancarlim/stylegan2-ada-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Sylwia Majchrowska
committed
Mar 7, 2022
1 parent
e3d3daf
commit 3d9270e
Showing
1 changed file
with
108 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,164 +1,166 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
# File : server_mp.py | ||
# Modified : 22.01.2022 | ||
# By : Sandra Carrasco <[email protected]> | ||
# Last Modified : 22.01.2022 | ||
# By : Sandra Carrasco <[email protected]> | ||
|
||
import numpy as np | ||
import numpy as np | ||
import re | ||
import os | ||
from typing import List | ||
import matplotlib.pyplot as plt | ||
from pathlib import Path | ||
from PIL import Image | ||
import matplotlib.pyplot as plt | ||
from PIL import Image | ||
import torch | ||
#import torchtoolbox.transform as transforms | ||
from torchvision import transforms | ||
from efficientnet_pytorch import EfficientNet | ||
import seaborn as sb | ||
from argparse import ArgumentParser | ||
from argparse import ArgumentParser | ||
from melanoma_classifier import test | ||
from utils import load_model, load_isic_data, load_synthetic_data, CustomDataset , confussion_matrix | ||
import pandas as pd | ||
from sklearn.model_selection import train_test_split | ||
from datetime import date, datetime | ||
|
||
|
||
testing_transforms = transforms.Compose([transforms.Resize(256), | ||
transforms.CenterCrop(256), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], | ||
[0.229, 0.224, 0.225])]) | ||
|
||
def num_range(s: str) -> List[int]: | ||
'''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' | ||
from utils import (load_model, load_isic_data, | ||
load_synthetic_data, CustomDataset, | ||
confussion_matrix, testing_transforms) | ||
|
||
|
||
def num_range(s: str) -> List[int]: | ||
''' | ||
Accept either a comma separated list of numbers | ||
'a,b,c' or a range 'a-c' and return as a list of ints. | ||
''' | ||
range_re = re.compile(r'^(\d+)-(\d+)$') | ||
m = range_re.match(s) | ||
if m: | ||
return list(range(int(m.group(1)), int(m.group(2))+1)) | ||
vals = s.split(',') | ||
return [int(x) for x in vals] | ||
|
||
|
||
def process_image(image_path): | ||
''' Scales, crops, and normalizes a PIL image for a PyTorch model, | ||
returns an Numpy array | ||
''' | ||
''' | ||
Scales, crops, and normalizes a PIL image for a PyTorch model, | ||
returns an Numpy array | ||
''' | ||
# Process a PIL image for use in a PyTorch model | ||
|
||
pil_image = Image.open(image_path) | ||
|
||
# Resize | ||
if pil_image.size[0] > pil_image.size[1]: | ||
pil_image.thumbnail((5000, 256)) | ||
else: | ||
pil_image.thumbnail((256, 5000)) | ||
# Crop | ||
|
||
# Crop | ||
left_margin = (pil_image.width-256)/2 | ||
bottom_margin = (pil_image.height-256)/2 | ||
right_margin = left_margin + 256 | ||
top_margin = bottom_margin + 256 | ||
|
||
pil_image = pil_image.crop((left_margin, bottom_margin, right_margin, top_margin)) | ||
|
||
|
||
pil_image = pil_image.crop((left_margin, bottom_margin, | ||
right_margin, top_margin)) | ||
|
||
# Normalize | ||
np_image = np.array(pil_image)/255 | ||
mean = np.array([0.485, 0.456, 0.406]) | ||
std = np.array([0.229, 0.224, 0.225]) | ||
np_image = (np_image - mean) / std | ||
|
||
# PyTorch expects the color channel to be the first dimension but it's the third dimension in the PIL image and Numpy array | ||
# Color channel needs to be first; retain the order of the other two dimensions. | ||
|
||
# PyTorch expects the color channel to be the first dimension | ||
# but it's the third dimension in the PIL image and Numpy array | ||
# Color channel needs to be first; retain the order of the other | ||
# two dimensions. | ||
np_image = np_image.transpose((2, 0, 1)) | ||
|
||
return np_image | ||
|
||
|
||
def imshow(image, ax=None, title=None): | ||
if ax is None: | ||
fig, ax = plt.subplots() | ||
|
||
# PyTorch tensors assume the color channel is the first dimension | ||
# but matplotlib assumes is the third dimension | ||
image = image.transpose((1, 2, 0)) | ||
|
||
# Undo preprocessing | ||
mean = np.array([0.485, 0.456, 0.406]) | ||
std = np.array([0.229, 0.224, 0.225]) | ||
image = std * image + mean | ||
|
||
if title is not None: | ||
ax.set_title(title) | ||
|
||
# Image needs to be clipped between 0 and 1 or it looks like noise when displayed | ||
|
||
# Image needs to be clipped between 0 and 1 | ||
# or it looks like noise when displayed | ||
image = np.clip(image, 0, 1) | ||
|
||
ax.imshow(image) | ||
|
||
return ax | ||
|
||
def predict(image_path, model, topk=1): #just 2 classes from 1 single output | ||
''' Predict the class (or classes) of an image using a trained deep learning model. | ||
''' | ||
#image = process_image(image_path) | ||
|
||
# Convert image to PyTorch tensor first | ||
#image = torch.from_numpy(image).type(torch.cuda.FloatTensor) | ||
#print(image.shape) | ||
#print(type(image)) | ||
|
||
# Returns a new tensor with a dimension of size one inserted at the specified position. | ||
#image = image.unsqueeze(0) | ||
|
||
output = model(testing_transforms(Image.open(image_path)).type(torch.cuda.FloatTensor).unsqueeze(0)) # same output | ||
|
||
|
||
def predict(image_path, model, topk=1, prob=0.5): | ||
# just 2 classes from 1 single output | ||
''' | ||
Predict the class (or classes) of an image | ||
using a trained deep learning model. | ||
''' | ||
output = model(testing_transforms( | ||
Image.open(image_path)).type( | ||
torch.cuda.FloatTensor).unsqueeze(0)) # same output | ||
|
||
probabilities = torch.sigmoid(output) | ||
|
||
# Probabilities and the indices of those probabilities corresponding to the classes | ||
|
||
# Probabilities and the indices of those probabilities | ||
# corresponding to the classes | ||
top_probabilities, top_indices = probabilities.topk(topk) | ||
|
||
# Convert to lists | ||
top_probabilities = top_probabilities.detach().type(torch.FloatTensor).numpy().tolist()[0] | ||
top_indices = top_indices.detach().type(torch.FloatTensor).numpy().tolist()[0] | ||
|
||
top_probabilities = top_probabilities.detach().type( | ||
torch.FloatTensor).numpy().tolist()[0] | ||
top_indices = top_indices.detach().type( | ||
torch.FloatTensor).numpy().tolist()[0] | ||
|
||
top_classes = [] | ||
if probabilities > 0.5 : | ||
|
||
if probabilities > prob: | ||
top_classes.append("Melanoma") | ||
else: | ||
top_classes.append("Benign") | ||
|
||
|
||
return top_probabilities, top_classes | ||
|
||
def plot_diagnosis(predict_image_path, model,label): | ||
|
||
def plot_diagnosis(predict_image_path, model, label): | ||
img_nb = predict_image_path.split('/')[-1].split('.')[0] | ||
probs, classes = predict(predict_image_path, model) | ||
print(probs) | ||
print(classes) | ||
probs, classes = predict(predict_image_path, model) | ||
|
||
# Display an image along with the diagnosis of melanoma or benign | ||
# Plot Skin image input image | ||
plt.figure(figsize = (6,10)) | ||
plot_1 = plt.subplot(2,1,1) | ||
plt.figure(figsize=(6, 10)) | ||
plot_1 = plt.subplot(2, 1, 1) | ||
|
||
image = process_image(predict_image_path) | ||
|
||
imshow(image, plot_1) | ||
font = {"color": 'g'} if 'Benign' in classes and label == 0 or 'Melanoma' in classes and label == 1 else {"color": 'r'} | ||
plot_1.set_title(f"Diagnosis: {classes}, Output (prob) {probs[0]:.4f}, Label: {label}", fontdict=font); | ||
if (('Benign' in classes and label == 0) | ||
or ('Melanoma' in classes and label == 1)): | ||
font = {"color": 'g'} | ||
else: | ||
font = {"color": 'r'} | ||
plot_1.set_title( | ||
f"Diagnosis: {classes}, Output (prob) {probs[0]:.4f}, Label: {label}", | ||
fontdict=font) | ||
plt.savefig(f'{args.out_path}/prediction_{img_nb}.png') | ||
|
||
|
||
|
||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser() | ||
parser.add_argument('--seeds', type=num_range, help='List of random seeds Ex. 0-3 or 0,1,2') | ||
parser.add_argument("--data_path", type=str, default='/workspace/generated-no-valset') | ||
parser.add_argument("--model_path", type=str, default='/workspace/stylegan2-ada-pytorch/CNN_trainings/melanoma_model_0_0.9225_16_12_train_reals+15melanoma.pth') | ||
parser.add_argument("--out_path", type=str, default='', help='output path for confussion matrix') | ||
|
||
parser = ArgumentParser() | ||
parser.add_argument('--seeds', type=num_range, | ||
help='List of random seeds Ex. 0-3 or 0,1,2') | ||
parser.add_argument("--data_path", type=str) | ||
parser.add_argument("--model_path", type=str) | ||
parser.add_argument("--out_path", type=str, default='', | ||
help='output path for confussion matrix') | ||
parser.add_argument( | ||
"--plot", | ||
action="store_true", | ||
default=False, | ||
help="Plot and save image with diagnosis", | ||
) | ||
args = parser.parse_args() | ||
|
||
# Setting up GPU for processing or CPU if GPU isn't available | ||
|
@@ -169,26 +171,26 @@ def plot_diagnosis(predict_image_path, model,label): | |
model.load_state_dict(torch.load(args.model_path)) | ||
model.eval() | ||
|
||
|
||
if "SAM" in args.data_path: | ||
input_images = [str(f) for f in sorted(Path(args.data_path).rglob('*jpg')) if os.path.isfile(f)] | ||
y = [1 for i in range(len(input_images))] | ||
test_df = pd.DataFrame({'image_name': input_images, 'target': y}) | ||
elif "isic" in args.data_path: | ||
if "isic" in args.data_path: | ||
# For testing with ISIC dataset | ||
_, test_df = load_isic_data(args.data_path) | ||
else: | ||
else: | ||
test_df = load_synthetic_data(args.data_path, "3,3") | ||
|
||
|
||
testing_dataset = CustomDataset(df = test_df, train = True, transforms = testing_transforms ) | ||
test_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=16, shuffle = False) | ||
test_pred, test_gt, test_accuracy = test(model, test_loader) | ||
testing_dataset = CustomDataset(df=test_df, train=True, | ||
transforms=testing_transforms) | ||
test_loader = torch.utils.data.DataLoader(testing_dataset, | ||
batch_size=16, | ||
shuffle=False) | ||
test_pred, test_gt, test_accuracy = test(model, test_loader) | ||
confussion_matrix(test_gt, test_pred, test_accuracy, args.out_path) | ||
|
||
# Plot diagnosis | ||
""" for seed_idx, seed in enumerate(args.seeds): | ||
print('Predicting image for seed %d (%d/%d) ...' % (seed, seed_idx, len(args.seeds))) | ||
path = '/home/Data/generated/seed' + str(seed).zfill(4) | ||
path += '_0.png' if seed <= 5000 else '_1.png' | ||
plot_diagnosis(path, model) """ | ||
# Plot diagnosis | ||
if args.plot: | ||
for seed_idx, seed in enumerate(args.seeds): | ||
print( | ||
f'Predicting image for seed ' | ||
f'{seed} ({seed_idx}/{len(args.seeds)}) ...') | ||
path = os.path.join(args.out_path, 'seed' + str(seed).zfill(4)) | ||
path += '_0.png' if seed <= 5000 else '_1.png' | ||
plot_diagnosis(path, model) |