diff --git a/tensorflow/image_recognition/classify_image.py b/tensorflow/image_recognition/classify_image.py index c2850f5..e60c355 100644 --- a/tensorflow/image_recognition/classify_image.py +++ b/tensorflow/image_recognition/classify_image.py @@ -1,4 +1,5 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors, 2019 Analytics Zoo Authors. +# All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,6 +45,8 @@ import numpy as np from six.moves import urllib import tensorflow as tf +from zoo import init_nncontext +from zoo.pipeline.api.net import TFDataset, TFPredictor FLAGS = None @@ -117,14 +120,27 @@ def id_to_string(self, node_id): return self.node_lookup[node_id] -def create_graph(): +def create_graph(image_array): """Creates a graph from saved GraphDef file and returns a saver.""" # Creates graph from saved graph_def.pb. + h, w, c = image_array.shape + # get the TFDataset + sc = init_nncontext() + image_rdd = sc.parallelize(image_array[None, ...]).map(lambda x: [x]) + image_dataset = TFDataset.from_rdd(image_rdd, + names=['features'], + shapes=[[w, h, c]], + types=[tf.uint8], + batch_per_thread=1, + hard_code_batch_size=True) + image_tensor = image_dataset.tensors[0] with tf.gfile.FastGFile(os.path.join( FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) - _ = tf.import_graph_def(graph_def, name='') + _ = tf.import_graph_def(graph_def, + input_map={'DecodeJpeg:0': image_tensor[0]}, + name='') def run_inference_on_image(image): @@ -138,10 +154,10 @@ def run_inference_on_image(image): """ if not tf.gfile.Exists(image): tf.logging.fatal('File does not exist %s', image) - image_data = tf.gfile.FastGFile(image, 'rb').read() - + from PIL import Image + image_array = np.array(Image.open(image))[:, :, 0:3] # Creates graph from saved GraphDef. - create_graph() + create_graph(image_array) with tf.Session() as sess: # Some useful tensors: @@ -153,8 +169,9 @@ def run_inference_on_image(image): # encoding of the image. # Runs the softmax tensor by feeding the image_data as input to the graph. softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') - predictions = sess.run(softmax_tensor, - {'DecodeJpeg/contents:0': image_data}) + predictor = TFPredictor(sess, [softmax_tensor]) + predictions = predictor.predict().collect() + predictions = np.squeeze(predictions) # Creates node ID --> English string lookup.