From 098619f3361a33578d0b1feac977ffcf108e02ec Mon Sep 17 00:00:00 2001 From: bozhou Date: Mon, 18 Mar 2019 16:38:13 +0800 Subject: [PATCH 1/3] modify example with analytics-zoo --- .../image_recognition/classify_image.py | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/tensorflow/image_recognition/classify_image.py b/tensorflow/image_recognition/classify_image.py index c2850f5..4e66db8 100644 --- a/tensorflow/image_recognition/classify_image.py +++ b/tensorflow/image_recognition/classify_image.py @@ -44,6 +44,8 @@ import numpy as np from six.moves import urllib import tensorflow as tf +from zoo import init_nncontext +from zoo.pipeline.api.net import TFDataset, TFPredictor FLAGS = None @@ -117,14 +119,27 @@ def id_to_string(self, node_id): return self.node_lookup[node_id] -def create_graph(): +def create_graph(image_array): """Creates a graph from saved GraphDef file and returns a saver.""" # Creates graph from saved graph_def.pb. + w, h, c = image_array.shape + # get the TFDataset + sc = init_nncontext() + image_rdd = sc.parallelize(image_array[None, ...]).map(lambda x: [x]) + image_dataset = TFDataset.from_rdd(image_rdd, + names=['features'], + shapes=[[w, h, c]], + types=[tf.uint8], + batch_per_thread=1, + hard_code_batch_size=True) + image_tensor = image_dataset.tensors[0] with tf.gfile.FastGFile(os.path.join( FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) - _ = tf.import_graph_def(graph_def, name='') + _ = tf.import_graph_def(graph_def, + input_map={'DecodeJpeg:0': image_tensor[0]}, + name='zoo') def run_inference_on_image(image): @@ -138,10 +153,10 @@ def run_inference_on_image(image): """ if not tf.gfile.Exists(image): tf.logging.fatal('File does not exist %s', image) - image_data = tf.gfile.FastGFile(image, 'rb').read() - + from PIL import Image + image_array = np.array(Image.open(image))[:, :, 0:3] # Creates graph from saved GraphDef. - create_graph() + create_graph(image_array) with tf.Session() as sess: # Some useful tensors: @@ -152,9 +167,10 @@ def run_inference_on_image(image): # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG # encoding of the image. # Runs the softmax tensor by feeding the image_data as input to the graph. - softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') - predictions = sess.run(softmax_tensor, - {'DecodeJpeg/contents:0': image_data}) + softmax_tensor = sess.graph.get_tensor_by_name('zoo/softmax:0') + predictor = TFPredictor(sess, [softmax_tensor]) + predictions = predictor.predict().collect() + predictions = np.squeeze(predictions) # Creates node ID --> English string lookup. From ae57e1471b59d49323d10635c11d8cf56ad09b20 Mon Sep 17 00:00:00 2001 From: bozhou Date: Wed, 20 Mar 2019 10:28:34 +0800 Subject: [PATCH 2/3] Fix some bugs --- tensorflow/image_recognition/classify_image.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/image_recognition/classify_image.py b/tensorflow/image_recognition/classify_image.py index 4e66db8..2b238a7 100644 --- a/tensorflow/image_recognition/classify_image.py +++ b/tensorflow/image_recognition/classify_image.py @@ -122,7 +122,7 @@ def id_to_string(self, node_id): def create_graph(image_array): """Creates a graph from saved GraphDef file and returns a saver.""" # Creates graph from saved graph_def.pb. - w, h, c = image_array.shape + h, w, c = image_array.shape # get the TFDataset sc = init_nncontext() image_rdd = sc.parallelize(image_array[None, ...]).map(lambda x: [x]) @@ -139,7 +139,7 @@ def create_graph(image_array): graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(graph_def, input_map={'DecodeJpeg:0': image_tensor[0]}, - name='zoo') + name='') def run_inference_on_image(image): @@ -167,7 +167,7 @@ def run_inference_on_image(image): # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG # encoding of the image. # Runs the softmax tensor by feeding the image_data as input to the graph. - softmax_tensor = sess.graph.get_tensor_by_name('zoo/softmax:0') + softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') predictor = TFPredictor(sess, [softmax_tensor]) predictions = predictor.predict().collect() From 0f066af4f39b61e333839d27cfe35e543b93b237 Mon Sep 17 00:00:00 2001 From: bozhou Date: Thu, 21 Mar 2019 11:10:50 +0800 Subject: [PATCH 3/3] Update license --- tensorflow/image_recognition/classify_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/image_recognition/classify_image.py b/tensorflow/image_recognition/classify_image.py index 2b238a7..e60c355 100644 --- a/tensorflow/image_recognition/classify_image.py +++ b/tensorflow/image_recognition/classify_image.py @@ -1,4 +1,5 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors, 2019 Analytics Zoo Authors. +# All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.