diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c394c0e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,106 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+*config.py*
+*secrets.py*
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e966724
--- /dev/null
+++ b/README.md
@@ -0,0 +1,67 @@
+tf.Transform example for building a digital twin
+====================
+
+This repository is designed to quickly get you started with Machine Learning projects on Google Cloud Platform using tf.Transform.
+This code repository is linked to [this blogpost](https://www.google.com).
+
+### Functionalities
+- preprocessing pipeline using tf.Transform (with Apache Beam) that runs on Cloud Dataflow or locally
+- model training (with TensorFlow) that runs locally or on ML Engine
+- ready-to-deploy SavedModels for serving on ML Engine
+- starter code to use the saved model on ML Engine
+
+### Install dependencies
+**Note**: You will need a Linux or Mac environment with Python 2.7.x to install the dependencies ([1](#myfootnote1)).
+
+Install the following dependencies:
+ * Install [Cloud SDK](https://cloud.google.com/sdk/)
+ * Install [gcloud](https://cloud.google.com/sdk/gcloud/)
+ * ```pip install -r requirements.txt```
+
+# Getting started
+
+You need to complete the following parts to run the code:
+- add trainer/secrets.py with your `PROJECT_ID` and `BUCKET` variables (see the example below)
+- upload the data to your bucket; you can upload data/input_data.csv and data/output_data.csv to test this code
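+
+For example, a minimal trainer/secrets.py could look like this (the values are placeholders; the file matches the `*secrets.py*` pattern in .gitignore, so it won't be committed):
+```
+PROJECT_ID = 'my-gcp-project'  # your GCP project id
+BUCKET = 'gs://my-bucket'  # GCS bucket for the data, TFRecords and models
+```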
+
+## Preprocess
+
+You can run preprocess.py in the cloud using:
+```
+python preprocess.py --cloud
+```
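+This submits a Dataflow job (region `europe-west1`, see `get_cloud_pipeline_options` in preprocess.py) that joins the input and output CSVs by `BatchId`, writes the transformed examples as TFRecords and saves the `transform_fn` to `MODEL_DIR`.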
+
+To iterate/test your code, you can also run it locally on a sample of the dataset:
+```
+python preprocess.py
+```
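+Without `--cloud`, the pipeline runs with Beam's default local runner (DirectRunner).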
+
+## Training the TensorFlow model
+You can submit an ML Engine training job with:
+```
+gcloud ml-engine jobs submit training my_job \
+ --module-name trainer.task \
+ --staging-bucket gs:// \
+ --package-path trainer
+```
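+Fill in your staging bucket after `gs://` (for example the `BUCKET` you configured in trainer/secrets.py).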
+Testing it locally:
+```
+gcloud ml-engine local train --package-path trainer \
+ --module-name trainer.task
+```
+
+## Deploy your trained model
+To deploy your model to ML Engine:
+```
+gcloud ml-engine models create digitaltwin
+gcloud ml-engine versions create v1 --model=digitaltwin --origin=ORIGIN
+```
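+Here `ORIGIN` is the GCS path of the exported SavedModel, i.e. the timestamped directory that `export_savedmodel` (see trainer/task.py) writes under `MODEL_DIR`.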
+To test the deployed model:
+```
+python predict.py
+```
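+The response contains the predicted `TotalVolume` for each instance (the output defined in trainer/model.py).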
+
+
+<a name="myfootnote1">1</a>: This code requires both TensorFlow and Apache Beam. Currently TensorFlow on Windows only supports Python 3.5.x, and Apache Beam doesn't support Python 3.x yet.
\ No newline at end of file
diff --git a/data/input_data.csv b/data/input_data.csv
new file mode 100644
index 0000000..0d3cccd
--- /dev/null
+++ b/data/input_data.csv
@@ -0,0 +1,16 @@
+BatchId;ButterMass;ButterTemperature;SugarMass;SugarHumidity;FlourMass;FlourHumidity;HeatingTime;MixingSpeed;MixingTime
+1;121;20;200;0.22;50;0.23;50;Max Speed;200
+2;244;23;410;0.19;99;0.21;80;Medium Speed;450
+3;110;26;190;0.20;46;0.19;33;Low Speed;210
+4;121;20;200;0.22;50;0.23;50;Max Speed;200
+5;244;23;410;0.19;99;0.21;80;Medium Speed;450
+6;110;26;190;0.20;46;0.19;33;Low Speed;210
+7;121;20;200;0.22;50;0.23;50;Max Speed;200
+8;244;23;410;0.19;99;0.21;80;Medium Speed;450
+9;110;26;190;0.20;46;0.19;33;Low Speed;210
+10;121;20;200;0.22;50;0.23;50;Max Speed;200
+11;244;23;410;0.19;99;0.21;80;Medium Speed;450
+12;110;26;190;0.20;46;0.19;33;Low Speed;210
+13;121;20;200;0.22;50;0.23;50;Max Speed;200
+14;244;23;410;0.19;99;0.21;80;Medium Speed;450
+15;110;26;190;0.20;46;0.19;33;Low Speed;210
\ No newline at end of file
diff --git a/data/output_data.csv b/data/output_data.csv
new file mode 100644
index 0000000..7ec5bb1
--- /dev/null
+++ b/data/output_data.csv
@@ -0,0 +1,16 @@
+BatchId;TotalVolume;Density;Temperature;Humidity;Energy;Problems
+1;305;1.2;45;0.26;0.302;No
+2;603;1.4;47;0.24;0.599;Yes, some chunks remain
+3;301;1.1;42;0.24;0.312;No
+4;305;1.2;45;0.26;0.302;No
+5;603;1.4;47;0.24;0.599;Yes, some chunks remain
+6;301;1.1;42;0.24;0.312;No
+7;305;1.2;45;0.26;0.302;No
+8;603;1.4;47;0.24;0.599;Yes, some chunks remain
+9;301;1.1;42;0.24;0.312;No
+10;305;1.2;45;0.26;0.302;No
+11;603;1.4;47;0.24;0.599;Yes, some chunks remain
+12;301;1.1;42;0.24;0.312;No
+13;305;1.2;45;0.26;0.302;No
+14;603;1.4;47;0.24;0.599;Yes, some chunks remain
+15;301;1.1;42;0.24;0.312;No
\ No newline at end of file
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..74ee2c0
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,57 @@
+from googleapiclient import discovery
+
+from trainer.config import PROJECT_ID
+
+
+def get_predictions(project, model, instances, version=None):
+ """Send json data to a deployed model for prediction.
+
+ Args:
+ project (str): GCP project where the ML Engine Model is deployed.
+ model (str): model name
+ instances ([Mapping[str: Any]]): Keys should be the names of Tensors
+ your deployed model expects as inputs. Values should be datatypes
+ convertible to Tensors, or (potentially nested) lists of datatypes
+ convertible to tensors.
+        version (str): version of the model to target
+
+ Returns:
+        Mapping[str: Any]: dictionary of prediction results defined by the
+            model.
+
+ """
+ service = discovery.build('ml', 'v1')
+ name = 'projects/{}/models/{}'.format(project, model)
+
+ if version is not None:
+ name += '/versions/{}'.format(version)
+
+ response = service.projects().predict(
+ name=name,
+ body={'instances': instances}
+ ).execute()
+
+ if 'error' in response:
+ raise RuntimeError(response['error'])
+
+ return response['predictions']
+
+
+if __name__ == "__main__":
+ predictions = get_predictions(
+ project=PROJECT_ID,
+ model="digitaltwin",
+ instances=[
+ {
+                'ButterMass': 120,
+ 'ButterTemperature': 20,
+ 'SugarMass': 200,
+ 'SugarHumidity': 0.22,
+ 'FlourMass': 50,
+ 'FlourHumidity': 0.23,
+ 'HeatingTime': 50,
+ 'MixingSpeed': 'Max Speed',
+ 'MixingTime': 200,
+ }]
+ )
+ print(predictions)
diff --git a/preprocess.py b/preprocess.py
new file mode 100644
index 0000000..4a20833
--- /dev/null
+++ b/preprocess.py
@@ -0,0 +1,199 @@
+#!/usr/bin/python
+import argparse
+import logging
+import os
+import sys
+import tempfile
+from datetime import datetime
+
+import apache_beam as beam
+import tensorflow as tf
+import tensorflow_transform as tft
+from apache_beam.io import tfrecordio
+from tensorflow_transform import coders
+from tensorflow_transform.beam import impl as beam_impl
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.tf_metadata import dataset_metadata
+
+from trainer.config import PROJECT_ID, BUCKET, TRAIN_INPUT_DATA, TRAIN_OUTPUT_DATA, TFRECORD_DIR, \
+    MODEL_DIR, input_schema, output_schema, example_schema
+
+delimiter = ';'
+converter_input = coders.CsvCoder(
+ ['BatchId', 'ButterMass', 'ButterTemperature', 'SugarMass', 'SugarHumidity', 'FlourMass', 'FlourHumidity',
+ 'HeatingTime', 'MixingSpeed', 'MixingTime'],
+ input_schema,
+ delimiter=delimiter)
+converter_output = coders.CsvCoder(
+ ['BatchId', 'TotalVolume', 'Density', 'Temperature', 'Humidity', 'Energy', 'Problems'],
+ output_schema,
+ delimiter=delimiter)
+input_metadata = dataset_metadata.DatasetMetadata(schema=example_schema)
+
+
+def extract_batchkey(record):
+ """Extracts the BatchId out of the record
+ Args:
+ record (dict): record of decoded CSV line
+ Returns:
+ tuple: tuple of BatchId and record
+ """
+ return (record['BatchId'], record)
+
+
+def remove_keys(item):
+    """Cleans a CoGroupByKey result by dropping the key and merging the records.
+    Args:
+        item: (key, (input_records, output_records)) tuple from CoGroupByKey
+    Yields:
+        dict: merged input/output record without the key
+    """
+ key, vals = item
+ if len(vals[0]) == 1 and len(vals[1]) == 1:
+ example = vals[0][0]
+ example.update(vals[1][0])
+ yield example
+
+
+def preprocessing_fn(inputs):
+ """
+ Preprocess input columns into transformed columns.
+ Args:
+ inputs (dict): dict of input columns
+ Returns:
+ output dict of transformed columns
+ """
+ outputs = {}
+ # Encode categorical column:
+ outputs['MixingSpeed'] = tft.string_to_int(inputs['MixingSpeed'])
+ outputs['ButterMass'] = inputs['ButterMass']
+ # Calculate Derived Features:
+ outputs['TotalMass'] = inputs['ButterMass'] + inputs['SugarMass'] + inputs['FlourMass']
+ for ingredient in ['Butter', 'Sugar', 'Flour']:
+ ingredient_percentage = inputs['{}Mass'.format(ingredient)] / outputs['TotalMass']
+ outputs['Norm{}perc'.format(ingredient)] = tft.scale_to_z_score(ingredient_percentage)
+ # Keep absolute numeric columns
+ for key in ['TotalVolume', 'Energy']:
+ outputs[key] = inputs[key]
+ # Normalize other numeric columns
+ for key in [
+ 'ButterTemperature',
+ 'SugarHumidity',
+ 'FlourHumidity',
+ 'HeatingTime',
+ 'MixingTime',
+ 'Density',
+ 'Temperature',
+ 'Humidity',
+ ]:
+ outputs[key] = tft.scale_to_z_score(inputs[key])
+ # Extract Specific Problems
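+    # map any free-text Problems value mentioning "chunk" to a binary 0/1 feature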
+ chunks_detected_str = tf.regex_replace(
+ input=inputs['Problems'],
+ pattern='.*chunk.*',
+ rewrite='chunk',
+ name='DetectChunk')
+ outputs['Chunks'] = tf.cast(tf.equal(chunks_detected_str, 'chunk'), tf.float32)
+ return outputs
+
+
+def parse_arguments(argv):
+ """Parse command line arguments
+ Args:
+ argv (list): list of command line arguments including program name
+ Returns:
+ The parsed arguments as returned by argparse.ArgumentParser
+ """
+ parser = argparse.ArgumentParser(description='Runs Preprocessing.')
+ parser.add_argument('--cloud',
+ action='store_true',
+ help='Run preprocessing on the cloud.')
+ args, _ = parser.parse_known_args(args=argv[1:])
+ return args
+
+
+def get_cloud_pipeline_options(project):
+    """Get apache beam pipeline options to run with Dataflow on the cloud
+    Args:
+        project (str): GCP project to which the job will be submitted
+    Returns:
+        beam.pipeline.PipelineOptions
+    """
+ logging.warning('Start running in the cloud')
+
+ options = {
+ 'runner': 'DataflowRunner',
+ 'job_name': ('preprocessdigitaltwin-{}'.format(
+ datetime.now().strftime('%Y%m%d%H%M%S'))),
+ 'staging_location': os.path.join(BUCKET, 'staging'),
+ 'temp_location': os.path.join(BUCKET, 'tmp'),
+ 'project': project,
+ 'region': 'europe-west1',
+ 'zone': 'europe-west1-d',
+ 'autoscaling_algorithm': 'THROUGHPUT_BASED',
+ 'save_main_session': True,
+ 'setup_file': './setup.py',
+ }
+
+ return beam.pipeline.PipelineOptions(flags=[], **options)
+
+
+def main(argv=None):
+ """Run preprocessing as a Dataflow pipeline.
+ Args:
+ argv (list): list of arguments
+ """
+ args = parse_arguments(sys.argv if argv is None else argv)
+
+    if args.cloud:
+        pipeline_options = get_cloud_pipeline_options(PROJECT_ID)
+ else:
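+        # run locally with Beam's default DirectRunner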
+ pipeline_options = None
+
+ p = beam.Pipeline(options=pipeline_options)
+ with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
+ # read data and join by key
+ raw_data_input = (
+ p
+ | 'ReadInputData' >> beam.io.ReadFromText(TRAIN_INPUT_DATA, skip_header_lines=1)
+ | 'ParseInputCSV' >> beam.Map(converter_input.decode)
+ | 'ExtractBatchKeyIn' >> beam.Map(extract_batchkey)
+ )
+
+ raw_data_output = (
+ p
+ | 'ReadOutputData' >> beam.io.ReadFromText(TRAIN_OUTPUT_DATA, skip_header_lines=1)
+ | 'ParseOutputCSV' >> beam.Map(converter_output.decode)
+ | 'ExtractBatchKeyOut' >> beam.Map(extract_batchkey)
+ )
+
+ raw_data = (
+ (raw_data_input, raw_data_output)
+ | 'JoinData' >> beam.CoGroupByKey()
+ | 'RemoveKeys' >> beam.FlatMap(remove_keys)
+ )
+
+ # analyse and transform dataset
+ raw_dataset = (raw_data, input_metadata)
+ transform_fn = raw_dataset | beam_impl.AnalyzeDataset(preprocessing_fn)
+ transformed_dataset = (raw_dataset, transform_fn) | beam_impl.TransformDataset()
+ transformed_data, transformed_metadata = transformed_dataset
+
+ # save data and serialize TransformFn
+ transformed_data_coder = tft.coders.ExampleProtoCoder(
+ transformed_metadata.schema)
+ _ = (transformed_data
+ | 'EncodeData' >> beam.Map(transformed_data_coder.encode)
+ | 'WriteData' >> tfrecordio.WriteToTFRecord(
+ os.path.join(TFRECORD_DIR, 'records')))
+ _ = (transform_fn
+ | "WriteTransformFn" >>
+ transform_fn_io.WriteTransformFn(MODEL_DIR))
+
+ p.run().wait_until_finish()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5ffbb8a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+apache-beam[gcp]==2.4.0
+tensorflow==1.8.0
+tensorflow-transform==0.6.0
+google-api-python-client==1.6.4
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..27304fc
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,15 @@
+#!/usr/bin/python
+from setuptools import find_packages
+from setuptools import setup
+
+setup(
+ name='tft-demo',
+ version='0.1',
+ author='Matthias Feys',
+ author_email='matthiasfeys@gmail.com',
+ install_requires=['tensorflow==1.8.0',
+ 'tensorflow-transform==0.6.0'],
+ packages=find_packages(exclude=['data']),
+ description='tf.Transform demo for digital twin',
+ url='https://github.com/Fematich/mlengine-boilerplate'
+)
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..693737f
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1 @@
+#!/usr/bin/python
\ No newline at end of file
diff --git a/trainer/config.py b/trainer/config.py
new file mode 100644
index 0000000..54a7127
--- /dev/null
+++ b/trainer/config.py
@@ -0,0 +1,52 @@
+#!/usr/bin/python
+import tensorflow as tf
+from tensorflow_transform.tf_metadata import dataset_schema
+from secrets import PROJECT_ID, BUCKET
+
+DATA_DIR = BUCKET + '/data'
+TRAIN_INPUT_DATA = DATA_DIR + '/input_data.csv'
+TRAIN_OUTPUT_DATA = DATA_DIR + '/output_data.csv'
+TFRECORD_DIR = BUCKET + '/tfrecords2'
+MODEL_DIR = BUCKET + '/model2'
+BATCH_SIZE = 64
+
+input_schema = dataset_schema.from_feature_spec({
+ 'BatchId': tf.FixedLenFeature(shape=[], dtype=tf.string),
+ 'ButterMass': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'ButterTemperature': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'SugarMass': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'SugarHumidity': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'FlourMass': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'FlourHumidity': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'HeatingTime': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'MixingSpeed': tf.FixedLenFeature(shape=[], dtype=tf.string),
+ 'MixingTime': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+})
+
+output_schema = dataset_schema.from_feature_spec({
+ 'BatchId': tf.FixedLenFeature(shape=[], dtype=tf.string),
+ 'TotalVolume': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Density': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Temperature': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Humidity': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Energy': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Problems': tf.FixedLenFeature(shape=[], dtype=tf.string),
+})
+
+example_schema = dataset_schema.from_feature_spec({
+ 'ButterMass': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'ButterTemperature': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'SugarMass': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'SugarHumidity': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'FlourMass': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'FlourHumidity': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'HeatingTime': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'MixingSpeed': tf.FixedLenFeature(shape=[], dtype=tf.string),
+ 'MixingTime': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'TotalVolume': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Density': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Temperature': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Humidity': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Energy': tf.FixedLenFeature(shape=[], dtype=tf.float32),
+ 'Problems': tf.FixedLenFeature(shape=[], dtype=tf.string),
+})
diff --git a/trainer/model.py b/trainer/model.py
new file mode 100644
index 0000000..be94d14
--- /dev/null
+++ b/trainer/model.py
@@ -0,0 +1,100 @@
+#!/usr/bin/python
+import tensorflow as tf
+
+
+def inference(features):
+ """Creates the predictions of the model
+
+ Args:
+ features (dict): A dictionary of tensors keyed by the feature name.
+
+ Returns:
+ A dict of tensors that represents the predictions
+
+ """
+ with tf.variable_scope('model'):
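+        # stack the nine (tft-transformed) scalar features into a [batch_size, 9] matrix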
+ input_features = tf.concat([tf.expand_dims(features[key],1) for key in
+ [u'MixingTime', u'SugarHumidity', u'TotalMass', u'ButterTemperature',
+ u'NormButterperc', u'ButterMass', u'HeatingTime',
+ u'NormSugarperc', u'FlourHumidity']], 1)
+ hidden = tf.layers.dense(inputs=input_features,
+ units=5,
+ name='dense_weights_1',
+ use_bias=True)
+ predictions = tf.layers.dense(inputs=hidden,
+ units=1,
+ name='dense_weights_2',
+ use_bias=True)
+ predictions_squeezed = tf.squeeze(predictions)
+ return {'TotalVolume': predictions_squeezed}
+
+
+def loss(predictions, labels):
+ """Function that calculates the loss based on the predictions and labels
+
+ Args:
+ predictions (dict): A dictionary of tensors representing the predictions
+ labels (dict): A dictionary of tensors representing the labels.
+
+ Returns:
+ A tensor representing the loss
+
+ """
+ with tf.variable_scope('loss'):
+        return tf.losses.mean_squared_error(labels['TotalVolume'], predictions['TotalVolume'])
+
+
+def build_model_fn():
+ """Build model function as input for estimator.
+
+ Returns:
+ function: model function
+
+ """
+
+ def _model_fn(features, labels, mode, params):
+ """Creates the prediction and its loss.
+
+ Args:
+ features (dict): A dictionary of tensors keyed by the feature name.
+ labels (dict): A dictionary of tensors representing the labels.
+ mode: The execution mode, defined in tf.estimator.ModeKeys.
+
+ Returns:
+ tf.estimator.EstimatorSpec: EstimatorSpec object containing mode,
+ predictions, loss, train_op and export_outputs.
+
+ """
+ predictions = inference(features)
+ loss_op = None
+ train_op = None
+
+ if mode != tf.estimator.ModeKeys.PREDICT:
+ loss_op = loss(predictions, labels)
+
+ if mode == tf.estimator.ModeKeys.TRAIN:
+ train_op = tf.contrib.layers.optimize_loss(
+ loss=loss_op,
+ global_step=tf.train.get_global_step(),
+ learning_rate=params['learning_rate'],
+ optimizer='Adagrad',
+ summaries=[
+ 'learning_rate',
+ 'loss',
+ 'gradients',
+ 'gradient_norm',
+ ],
+ name='train')
+
+ export_outputs = {
+ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ tf.estimator.export.PredictOutput(predictions)}
+
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions=predictions,
+ loss=loss_op,
+ train_op=train_op,
+ export_outputs=export_outputs)
+
+ return _model_fn
diff --git a/trainer/task.py b/trainer/task.py
new file mode 100644
index 0000000..080ec19
--- /dev/null
+++ b/trainer/task.py
@@ -0,0 +1,26 @@
+#!/usr/bin/python
+import logging
+import tensorflow as tf
+
+from trainer.util import build_training_input_fn, build_serving_input_fn
+from trainer.model import build_model_fn
+from trainer.config import MODEL_DIR
+
+if __name__ == '__main__':
+ logger = logging.getLogger('task')
+ logging.basicConfig(format='%(asctime)s %(message)s')
+ logger.setLevel('INFO')
+ estimator = tf.estimator.Estimator(
+ model_fn=build_model_fn(),
+ model_dir=MODEL_DIR,
+ params={'learning_rate': 0.001})
+
+ train_spec = tf.estimator.TrainSpec(input_fn=build_training_input_fn(), max_steps=200)
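+    # NOTE: for simplicity, evaluation reuses the training input_fn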
+ eval_spec = tf.estimator.EvalSpec(input_fn=build_training_input_fn(), steps=64)
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+ logger.info('model is trained')
+
+ serving_input_fn = build_serving_input_fn()
+ estimator.export_savedmodel(
+ MODEL_DIR, serving_input_fn)
+ logger.info('model is saved')
diff --git a/trainer/util.py b/trainer/util.py
new file mode 100644
index 0000000..eba66ae
--- /dev/null
+++ b/trainer/util.py
@@ -0,0 +1,70 @@
+#!/usr/bin/python
+import os
+import tensorflow as tf
+
+from tensorflow_transform.beam.tft_beam_io import transform_fn_io
+from tensorflow_transform.saved import saved_transform_io
+from tensorflow_transform.tf_metadata import metadata_io
+
+from trainer.config import TFRECORD_DIR, BATCH_SIZE, MODEL_DIR
+from trainer.config import input_schema
+
+def build_training_input_fn():
+    """Creates an input function reading from the transformed TFRecords.
+    Returns:
+        The input function for training or eval.
+    """
+ transformed_metadata = metadata_io.read_metadata(
+ os.path.join(
+ MODEL_DIR, transform_fn_io.TRANSFORMED_METADATA_DIR))
+ transformed_feature_spec = transformed_metadata.schema.as_feature_spec()
+
+ def input_fn():
+ """Input function for training and eval."""
+ dataset = tf.contrib.data.make_batched_features_dataset(
+            file_pattern=os.path.join(TFRECORD_DIR, 'records*'),
+ batch_size=BATCH_SIZE,
+ features=transformed_feature_spec,
+ reader=tf.data.TFRecordDataset,
+ shuffle=True)
+ transformed_features = dataset.make_one_shot_iterator().get_next()
+ # Extract features and labels from the transformed tensors.
+ label_cols = set(['TotalVolume', 'Density', 'Temperature', 'Humidity', 'Energy', 'Problems'])
+ transformed_labels = {key: value for (key, value) in transformed_features.items() if key in label_cols}
+ transformed_features = {key: value for (key, value) in transformed_features.items() if key not in label_cols}
+ return transformed_features, transformed_labels
+
+ return input_fn
+
+def build_serving_input_fn():
+    """Creates an input function for serving that reads raw examples.
+    Returns:
+        The serving input function.
+    """
+ raw_feature_spec = input_schema.as_feature_spec()
+ raw_feature_spec.pop('BatchId')
+
+ def serving_input_fn():
+ """Input function for serving."""
+ # Get raw features by generating the basic serving input_fn and calling it.
+ # Here we generate an input_fn that expects a parsed Example proto to be fed
+ # to the model at serving time. See also
+ # tf.estimator.export.build_raw_serving_input_receiver_fn.
+ raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
+ raw_feature_spec)
+        raw_features, _, _ = raw_input_fn()
+ # Apply the transform function that was used to generate the materialized
+ # data.
+ _, transformed_features = (
+ saved_transform_io.partially_apply_saved_transform(
+ os.path.join(MODEL_DIR, transform_fn_io.TRANSFORM_FN_DIR),
+ raw_features))
+
+ return tf.estimator.export.ServingInputReceiver(transformed_features, raw_features)
+
+ return serving_input_fn
+
+