forked from aws/amazon-sagemaker-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
iris_dnn_classifier.py
32 lines (26 loc) · 1.26 KB
/
iris_dnn_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import numpy as np
import tensorflow as tf
# Feature key shared by the feature column, the serving feature spec and the
# dict passed as `x` to the numpy input function.
INPUT_TENSOR_NAME = 'inputs'
def estimator_fn(run_config, params):
    """Build the DNN classifier for the Iris data.

    Args:
        run_config: configuration object forwarded to the estimator as
            ``config``.
        params: hyperparameters; not read by this function.

    Returns:
        A ``tf.estimator.DNNClassifier`` with hidden layers of 10, 20 and 10
        units and three output classes.
    """
    # One numeric column holding four values, keyed by the shared tensor name.
    iris_features = tf.feature_column.numeric_column(INPUT_TENSOR_NAME, shape=[4])
    classifier_kwargs = dict(
        feature_columns=[iris_features],
        hidden_units=[10, 20, 10],
        n_classes=3,
        config=run_config,
    )
    return tf.estimator.DNNClassifier(**classifier_kwargs)
def serving_input_fn():
    """Return the serving input receiver used at inference time.

    The receiver comes from TensorFlow's parsing-receiver builder and is
    configured with a single float32 feature of shape ``[4]`` keyed by
    ``INPUT_TENSOR_NAME``.
    """
    parsed_feature = tf.FixedLenFeature(dtype=tf.float32, shape=[4])
    receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
        {INPUT_TENSOR_NAME: parsed_feature}
    )
    # The builder returns a factory; call it to obtain the receiver itself.
    return receiver_fn()
def train_input_fn(training_dir, params):
    """Return the input function that feeds the model during training.

    Args:
        training_dir: directory containing ``iris_training.csv``.
        params: hyperparameters; not read by this function.

    Returns:
        The input function created by ``_generate_input_fn``.
    """
    training_csv = 'iris_training.csv'
    return _generate_input_fn(training_dir, training_csv)
def _generate_input_fn(training_dir, training_filename):
    """Create a numpy-backed input function from a CSV file with a header row.

    Args:
        training_dir: directory holding the CSV file.
        training_filename: name of the CSV file inside ``training_dir``.

    Returns:
        The function returned by ``tf.estimator.inputs.numpy_input_fn``,
        configured with ``shuffle=True`` and ``num_epochs=None``.
    """
    # ``np.int`` was only an alias for the builtin ``int``; it was deprecated
    # in NumPy 1.20 and removed in 1.24, so the builtin is used directly.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=os.path.join(training_dir, training_filename),
        target_dtype=int,
        features_dtype=np.float32)

    return tf.estimator.inputs.numpy_input_fn(
        x={INPUT_TENSOR_NAME: np.array(training_set.data)},
        y=np.array(training_set.target),
        num_epochs=None,
        shuffle=True)