-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
113 lines (101 loc) · 3.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Standard library imports
import os
import glob
import argparse
# Third party library imports
import torch
import torch.nn as nn
from torchvision import models
# Local imports
import config
import inference
import training
def main(args):
    """Apply command-line overrides to ``config`` and run the selected mode.

    Args:
        args: parsed argument namespace. This function reads
            ``freeze_epochs``, ``unfreeze_epochs``, ``batch_size`` and
            ``mode`` from it, and stores the whole namespace on
            ``config.ARGS``.

    Side effects:
        May overwrite ``config.FREEZE_EPOCHS``, ``config.UNFREEZE_EPOCHS``
        and ``config.PARAMS["batch_size"]``; always sets ``config.ARGS``.
    """
    # Compare against None rather than relying on truthiness: 0 is a
    # meaningful epoch count and must not be silently dropped in favour of
    # whatever value config already holds.
    if args.freeze_epochs is not None:
        config.FREEZE_EPOCHS = args.freeze_epochs
    if args.unfreeze_epochs is not None:
        config.UNFREEZE_EPOCHS = args.unfreeze_epochs

    # A batch size of 0 (or None) is not usable, so a falsy value keeps the
    # batch size already present in config.PARAMS.
    if args.batch_size:
        config.PARAMS["batch_size"] = args.batch_size

    # Expose the full namespace through config.ARGS.
    config.ARGS = args

    # Any mode other than "inference" falls through to training.
    if args.mode == "inference":
        inference.inference()
    else:
        training.training(args)
def parse_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` cannot be used with argparse for this purpose: argparse
    calls the type on the raw string, and ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Args:
        value: the raw string from the command line (or an actual bool).

    Returns:
        True or False according to the token; the empty string maps to False.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def build_parser():
    """Build the command-line parser for training and inference runs.

    Returns:
        An ``argparse.ArgumentParser`` with every supported option registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        type=str,
        default="inference",
        help="inference mode or training mode")

    # arguments for training
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="the batch_size for training as well as for inference")
    parser.add_argument(
        "--freeze_epochs",
        type=int,
        default=1,
        help="the total number of epochs for which the initial few layers will be frozen")
    parser.add_argument(
        "--unfreeze_epochs",
        type=int,
        default=200,
        help="the total number of epochs for which the full network will be unfrozen")
    parser.add_argument(
        "--resume",
        # parse_bool instead of bool so that "--resume False" really is False.
        type=parse_bool,
        default=True,
        help="Flag to resume the training from where it was stopped")
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default="checkpoint.pth",
        help="the name of the checkpoint file where the weights will be saved")
    parser.add_argument(
        "--positive_image_dict",
        type=str,
        default="./zalando/cropped_images/top/positive_image_dict.pkl",
        help="the dictionary containing information about positive images in this format: {'anchor_image':['postive_image_1', 'positive_image_2']")
    parser.add_argument(
        "--image_path",
        type=str,
        default="./zalando/cropped_images/top/images/",
        help="the directory which has all the images on which the training will be done")

    # arguments for inference
    parser.add_argument(
        "--proximities_from",
        # nargs="+" collects one or more names into a list; type=list would
        # instead split a single name into its individual characters.
        type=str,
        nargs="+",
        help="list of images from which the similar images have to be shown")
    parser.add_argument(
        "--proximities_from_path",
        type=str,
        default="./zalando/cropped_images/top/images/",
        help="the directory containing the image collection from which the similar images have to be shown")
    parser.add_argument(
        "--proximities_for",
        type=str,
        nargs="+",
        help="list of images for which the similar images have to be shown")
    parser.add_argument(
        "--proximities_for_path",
        type=str,
        default="./zalando/cropped_images/top/images/",
        help="the directory of images for which the similar images have to be shown")
    parser.add_argument(
        "--top_count",
        type=int,
        default=50,
        help="the number of similar images to be shown")
    parser.add_argument(
        "--inference_output_path",
        type=str,
        default="./output/results/",
        help="the output directory where the inference output as images will be stored")
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())