Skip to content
This repository has been archived by the owner on Jun 4, 2024. It is now read-only.

modify example with analytics-zoo #6

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions tensorflow/image_recognition/classify_image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015 The TensorFlow Authors, 2019 Analytics Zoo Authors.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,6 +45,8 @@
import numpy as np
from six.moves import urllib
import tensorflow as tf
from zoo import init_nncontext
from zoo.pipeline.api.net import TFDataset, TFPredictor

FLAGS = None

Expand Down Expand Up @@ -117,14 +120,27 @@ def id_to_string(self, node_id):
return self.node_lookup[node_id]


def create_graph():
def create_graph(image_array):
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
h, w, c = image_array.shape
# get the TFDataset
sc = init_nncontext()
image_rdd = sc.parallelize(image_array[None, ...]).map(lambda x: [x])
image_dataset = TFDataset.from_rdd(image_rdd,
names=['features'],
shapes=[[w, h, c]],
types=[tf.uint8],
batch_per_thread=1,
hard_code_batch_size=True)
image_tensor = image_dataset.tensors[0]
with tf.gfile.FastGFile(os.path.join(
FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
_ = tf.import_graph_def(graph_def,
input_map={'DecodeJpeg:0': image_tensor[0]},
name='')


def run_inference_on_image(image):
Expand All @@ -138,10 +154,10 @@ def run_inference_on_image(image):
"""
if not tf.gfile.Exists(image):
tf.logging.fatal('File does not exist %s', image)
image_data = tf.gfile.FastGFile(image, 'rb').read()

from PIL import Image
image_array = np.array(Image.open(image))[:, :, 0:3]
# Creates graph from saved GraphDef.
create_graph()
create_graph(image_array)

with tf.Session() as sess:
# Some useful tensors:
Expand All @@ -153,8 +169,9 @@ def run_inference_on_image(image):
# encoding of the image.
# Runs the softmax tensor by feeding the image_data as input to the graph.
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
predictor = TFPredictor(sess, [softmax_tensor])
predictions = predictor.predict().collect()

predictions = np.squeeze(predictions)

# Creates node ID --> English string lookup.
Expand Down