Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sancarlim committed Feb 11, 2022
1 parent 6f48db7 commit bc6606b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 45 deletions.
28 changes: 12 additions & 16 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,38 +94,34 @@ def generate_images(
os.makedirs(outdir, exist_ok=True)


# Labels.
label = torch.zeros([1, G.c_dim], device=device)
if G.c_dim != 0:
if class_idx is None:
ctx.fail('Must specify class label with --class when using a conditional network')
label[:, class_idx] = 1
else:
if class_idx is not None:
print ('warn: --class=lbl ignored when running on an unconditional network')


# Synthesize the result of a W projection.
if projected_w is not None:
if seeds is not None:
print ('warn: --seeds is ignored when using --projected-w')
print(f'Generating images from projected W "{projected_w}"')
ws_np = np.load(projected_w)['w']
ws = torch.tensor(ws_np, device=device) # pylint: disable=not-callable
ws = np.load(projected_w)['w']
ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
assert ws.shape[1:] == (G.num_ws, G.w_dim)
for idx, w in enumerate(ws):
img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.jpg')


return

if seeds is None:
seeds=list(np.random.randint(0,1000000,num_imgs))
#ctx.fail('--seeds option is required when not using --projected-w')


# Labels.
label = torch.zeros([1, G.c_dim], device=device)
if G.c_dim != 0:
if class_idx is None:
ctx.fail('Must specify class label with --class when using a conditional network')
label[:, class_idx] = 1
else:
if class_idx is not None:
print ('warn: --class=lbl ignored when running on an unconditional network')

# Generate images.
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
Expand Down
14 changes: 11 additions & 3 deletions melanoma_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ def train(model, train_loader, validate_loader, k_fold = 0, epochs = 10, es_pati
writer.add_scalar('Validation AUC Score', val_auc_score, e+1 )
"""

scheduler.step(val_auc_score)
scheduler.step(val_accuracy)

if val_auc_score > best_val:
best_val = val_auc_score
if val_accuracy > best_val:
best_val = val_accuracy
wandb.run.summary["best_auc_score"] = val_auc_score
wandb.run.summary["best_acc_score"] = val_accuracy
patience = es_patience # Resetting patience since we have new best validation accuracy
model_path = os.path.join(writer_path, f'./classifier_{args.model}_{best_val:.4f}_{datetime.datetime.now()}.pth')
torch.save(model.state_dict(), model_path) # Saving current best model
Expand Down Expand Up @@ -199,6 +200,8 @@ def val(model, validate_loader, criterion):
def test(model, test_loader):
test_preds=[]
all_labels=[]
misclassified = []
low_confidence = []
with torch.no_grad():

for _, (test_images, test_labels) in enumerate(test_loader):
Expand All @@ -215,6 +218,11 @@ def test(model, test_loader):
test_pred2 = torch.tensor(test_pred)
test_gt = np.concatenate(all_labels)
test_gt2 = torch.tensor(test_gt)

indeces_misclassified = np.where(test_gt != np.round(test_pred))[0]
well_classified = list(set(list(range(0, len(test_gt2)))) - set(indeces_misclassified.tolist()))
edge_cases = np.where( (test_gt[well_classified] - test_pred[well_classified]) > 0.25 )[0]

try:
test_accuracy = accuracy_score(test_gt2.cpu(), torch.round(test_pred2))
test_auc_score = roc_auc_score(test_gt, test_pred)
Expand Down
44 changes: 19 additions & 25 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import seaborn as sb
from argparse import ArgumentParser
from melanoma_classifier import test
from utils import Net, Synth_Dataset, CustomDataset , confussion_matrix
from utils import load_model, load_isic_data, load_synthetic_data, CustomDataset , confussion_matrix
import pandas as pd
from sklearn.model_selection import train_test_split
from datetime import date, datetime
Expand Down Expand Up @@ -156,43 +156,37 @@ def plot_diagnosis(predict_image_path, model,label):
parser = ArgumentParser()
#parser.add_argument("--path", type=str, default='/home/Data/generated/seed9984_1.png', help="Path to image to predict")
parser.add_argument('--seeds', type=num_range, help='List of random seeds Ex. 0-3 or 0,1,2')
parser.add_argument("--data_path", type=str, default='/workspace/melanoma_isic_dataset')
parser.add_argument("--data_path", type=str, default='/workspace/generated-no-valset')
args = parser.parse_args()

# Setting up GPU for processing or CPU if GPU isn't available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model
arch = EfficientNet.from_pretrained('efficientnet-b2')
model = Net(arch=arch)
model = load_model()
model.load_state_dict(torch.load('/workspace/stylegan2-ada-pytorch/CNN_trainings/melanoma_model_0_0.9225_16_12_train_reals+15melanoma.pth'))
model.eval()
model.to(device)

# TEST
#testing_dataset = Synth_Dataset(source_dir = args.data_path, transform = testing_transforms,
# id_list = None, test=True, unbalanced=False)

input_images = [str(f) for f in sorted(Path("/workspace/stylegan2-ada-pytorch/processed_dataset_512_SAM").rglob('*jpg')) if os.path.isfile(f)]
y = [1 for i in range(len(input_images))]
test_df = pd.DataFrame({'image_name': input_images, 'target': y})
"""
# For testing with ISIC dataset
df = pd.read_csv(os.path.join(args.data_path , 'train_concat.csv'))
train_img_dir = os.path.join(args.data_path ,'train/train/')
train_split, valid_split = train_test_split (df, stratify=df.target, test_size = 0.20, random_state=42)
validation_df=pd.DataFrame(valid_split)
validation_df['image_name'] = [os.path.join(train_img_dir, validation_df.iloc[index]['image_name'] + '.jpg') for index in range(len(validation_df))]
"""


if "SAM" in args.data_path:
input_images = [str(f) for f in sorted(Path("/workspace/stylegan2-ada-pytorch/processed_dataset_512_SAM").rglob('*jpg')) if os.path.isfile(f)]
y = [1 for i in range(len(input_images))]
test_df = pd.DataFrame({'image_name': input_images, 'target': y})
elif "isic" in args.data_path:
# For testing with ISIC dataset
_, test_df = load_isic_data(args.data_path)
else:
test_df = load_synthetic_data(args.data_path, "3,3")


testing_dataset = CustomDataset(df = test_df, train = True, transforms = testing_transforms )
test_loader = torch.utils.data.DataLoader(testing_dataset, batch_size=16, shuffle = False)
test_pred, test_gt, test_accuracy = test(model, test_loader)
confussion_matrix(test_gt, test_pred, test_accuracy)
# confussion_matrix(test_gt, test_pred, test_accuracy)

# Plot diagnosis
for seed_idx, seed in enumerate(args.seeds):
""" for seed_idx, seed in enumerate(args.seeds):
print('Predicting image for seed %d (%d/%d) ...' % (seed, seed_idx, len(args.seeds)))
path = '/home/Data/generated/seed' + str(seed).zfill(4)
path += '_0.png' if seed <= 5000 else '_1.png'
plot_diagnosis(path, model)
plot_diagnosis(path, model) """
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def load_isic_data(path):
validation_df=pd.DataFrame(valid_split)
return train_df, validation_df

def load_synthetic_data(syn_data_path, synt_n_imgs, only_syn):
def load_synthetic_data(syn_data_path, synt_n_imgs, only_syn=False):
#Load all images and labels from path
input_images = [str(f) for f in sorted(Path(syn_data_path).rglob('*')) if os.path.isfile(f)]
y = [0 if f.split('.jpg')[0][-1] == '0' else 1 for f in input_images]
Expand Down

0 comments on commit bc6606b

Please sign in to comment.