From cf6dc53dc4ee0f2d847f3c09f1cd6338f13261e8 Mon Sep 17 00:00:00 2001 From: alexdrydew Date: Wed, 20 Jul 2022 19:14:48 +0300 Subject: [PATCH] Support for latent_space=constant --- export_model.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/export_model.py b/export_model.py index 553fd2d..a5a163f 100644 --- a/export_model.py +++ b/export_model.py @@ -1,6 +1,7 @@ import os import argparse import tensorflow as tf +import numpy as np import tf2onnx from pathlib import Path @@ -20,7 +21,7 @@ def main(): parser.add_argument('--latent_space', choices=['normal', 'uniform', 'constant', 'none'], default='normal') parser.add_argument('--latent_dim', type=int, default=32, required=False) - parser.add_argument('--constant_latent', type=float, default=None) + parser.add_argument('--constant_latent', type=float, default=0.5) parser.add_argument('--export_format', choices=['pbtxt', 'onnx'], default='pbtxt') @@ -87,6 +88,14 @@ def to_save(x): output_path=Path(args.output_path) / f'{args.checkpoint_name}.onnx', ) + if args.test_input: + test_output = to_save(tf.convert_to_tensor([args.test_input])) + print('Model test output:') + print('Input:') + print(args.test_input) + print('Output') + print(*(f'{num:.4f}' if num > 1e-16 else 0 for num in test_output.numpy().flatten())) + if args.upload_to_mlflow: import mlflow @@ -114,6 +123,11 @@ def latent_input_gen(batch_size): def latent_input_gen(batch_size): return tf.random.uniform(shape=(batch_size, args.latent_dim), dtype='float32') + elif args.latent_space == 'constant': + + def latent_input_gen(batch_size): + return tf.fill(dims=(batch_size, args.latent_dim), value=np.float32(args.constant_latent)) + if latent_input_gen is None: input_signature = [tf.TensorSpec(shape=[predefined_batch_size, 36], dtype=tf.float32)]