Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
sancarlim committed Feb 17, 2022
1 parent bc6606b commit b0e8b78
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,23 @@ def plot_diagnosis(predict_image_path, model,label):

if __name__ == "__main__":
parser = ArgumentParser()
#parser.add_argument("--path", type=str, default='/home/Data/generated/seed9984_1.png', help="Path to image to predict")
#parser.add_argument("--image_path", type=str, default='/home/Data/generated/seed9984_1.png', help="Path to image to predict")
parser.add_argument('--seeds', type=num_range, help='List of random seeds Ex. 0-3 or 0,1,2')
parser.add_argument("--data_path", type=str, default='/workspace/generated-no-valset')
parser.add_argument("--model_path", type=str, default='/workspace/stylegan2-ada-pytorch/CNN_trainings/melanoma_model_0_0.9225_16_12_train_reals+15melanoma.pth')
args = parser.parse_args()

# Setting up GPU for processing or CPU if GPU isn't available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model
model = load_model()
model.load_state_dict(torch.load('/workspace/stylegan2-ada-pytorch/CNN_trainings/melanoma_model_0_0.9225_16_12_train_reals+15melanoma.pth'))
model.load_state_dict(torch.load(args.model_path))
model.eval()


if "SAM" in args.data_path:
input_images = [str(f) for f in sorted(Path("/workspace/stylegan2-ada-pytorch/processed_dataset_512_SAM").rglob('*jpg')) if os.path.isfile(f)]
input_images = [str(f) for f in sorted(Path(args.data_path).rglob('*jpg')) if os.path.isfile(f)]
y = [1 for i in range(len(input_images))]
test_df = pd.DataFrame({'image_name': input_images, 'target': y})
elif "isic" in args.data_path:
Expand Down

0 comments on commit b0e8b78

Please sign in to comment.