diff --git a/predict.py b/predict.py index 254af9c..74d6bb6 100644 --- a/predict.py +++ b/predict.py @@ -154,9 +154,10 @@ def plot_diagnosis(predict_image_path, model,label): if __name__ == "__main__": parser = ArgumentParser() - #parser.add_argument("--path", type=str, default='/home/Data/generated/seed9984_1.png', help="Path to image to predict") + #parser.add_argument("--image_path", type=str, default='/home/Data/generated/seed9984_1.png', help="Path to image to predict") parser.add_argument('--seeds', type=num_range, help='List of random seeds Ex. 0-3 or 0,1,2') parser.add_argument("--data_path", type=str, default='/workspace/generated-no-valset') + parser.add_argument("--model_path", type=str, default='/workspace/stylegan2-ada-pytorch/CNN_trainings/melanoma_model_0_0.9225_16_12_train_reals+15melanoma.pth') args = parser.parse_args() # Setting up GPU for processing or CPU if GPU isn't available @@ -164,12 +165,12 @@ def plot_diagnosis(predict_image_path, model,label): # Load model model = load_model() - model.load_state_dict(torch.load('/workspace/stylegan2-ada-pytorch/CNN_trainings/melanoma_model_0_0.9225_16_12_train_reals+15melanoma.pth')) + model.load_state_dict(torch.load(args.model_path)) model.eval() if "SAM" in args.data_path: - input_images = [str(f) for f in sorted(Path("/workspace/stylegan2-ada-pytorch/processed_dataset_512_SAM").rglob('*jpg')) if os.path.isfile(f)] + input_images = [str(f) for f in sorted(Path(args.data_path).rglob('*jpg')) if os.path.isfile(f)] y = [1 for i in range(len(input_images))] test_df = pd.DataFrame({'image_name': input_images, 'target': y}) elif "isic" in args.data_path: