Skip to content

Commit

Permalink
updated directions for retrieving weights file
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeid committed Jul 4, 2017
1 parent 2979ff0 commit def6101
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions src/custom_vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,10 @@ def __init__(self, vgg16_npy_path=None):
vgg16_npy_path = path
else:
print("VGG16 weights were not found in the project directory")

answer = 0
while answer is not 'y' and answer is not 'N':
answer = input("Would you like to download the 528 MB file? [y/N] ").replace(" ", "")

# Download weights if yes, else exit the program
if answer == 'y':
print("Downloading. Please be patient...")
urllib.request.urlretrieve(weights_url, weights_name)
vgg16_npy_path = path
elif answer == 'N':
print("Exiting the program..")
exit(0)
print("Please download the numpy weights file and place it in the 'lib/descriptor' directory")
print("Download link: https://mega.nz/#!YU1FWJrA!O1ywiCS2IiOlUCtCpI6HTJOMrneN-Qdv3ywQP5poecM")
print("Exiting the program..")
exit(1)

if data is None:
data = np.load(vgg16_npy_path, encoding='latin1')
Expand Down Expand Up @@ -122,4 +113,4 @@ def __get_bias(self, name):
return tf.constant(self.data_dict[name][1], name="biases")

def __get_fc_weight(self, name):
return tf.constant(self.data_dict[name][0], name="weights")
return tf.constant(self.data_dict[name][0], name="weights")

0 comments on commit def6101

Please sign in to comment.