diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 893b2b2ee6..718ec4ec85 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -68,7 +68,7 @@ Try those commands below: Addons provides `make code-format` command to format your changes automatically, don't forget to use it before pushing your codes. -Please see our [Style Guide](SYLE_GUIDE.md) for more details. +Please see our [Style Guide](STYLE_GUIDE.md) for more details. ## Code Testing #### CI Testing diff --git a/README.md b/README.md index efdb28a8db..db6bcc18b9 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,17 @@ # TensorFlow Addons +[![PyPI Status Badge](https://badge.fury.io/py/tensorflow-addons.svg)](https://pypi.org/project/tensorflow-addons/) +[![Gitter chat](https://img.shields.io/badge/chat-on%20gitter-46bc99.svg)](https://gitter.im/tensorflow/sig-addons) + +### Official Builds + +| Build Type | Status | +| --- | --- | +| **Linux Py2 CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py2.html) | +| **Linux Py3 CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py3.html) | +| **Linux Py2 GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py2.html) | +| **Linux Py3 GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.html) | + TensorFlow Addons is a repository of contributions that conform to well-established API patterns, but implement new functionality not available in core TensorFlow. 
TensorFlow natively supports @@ -13,9 +25,14 @@ developments that cannot be integrated into core TensorFlow | Sub-Package | Addon | Reference | |:----------------------- |:----------- |:---------------------------- | | tfa.activations | Sparsemax | https://arxiv.org/abs/1602.02068 | +| tfa.image | adjust_hsv_in_yiq | | +| tfa.image | random_hsv_in_yiq | | | tfa.image | transform | | +| tfa.layers | GroupNormalization | https://arxiv.org/abs/1803.08494 | +| tfa.layers | InstanceNormalization | https://arxiv.org/abs/1607.08022 | +| tfa.layers | LayerNormalization | https://arxiv.org/abs/1607.06450 | | tfa.layers | Maxout | https://arxiv.org/abs/1302.4389 | -| tfa.layers | PoinareNormalize | https://arxiv.org/abs/1705.08039 | +| tfa.layers | PoincareNormalize | https://arxiv.org/abs/1705.08039 | | tfa.layers | WeightNormalization | https://arxiv.org/abs/1602.07868 | | tfa.losses | LiftedStructLoss | https://arxiv.org/abs/1511.06452 | | tfa.losses | SparsemaxLoss | https://arxiv.org/abs/1602.02068 | @@ -33,9 +50,9 @@ the list we adhere to: 1) [Layers](tensorflow_addons/layers/README.md) -1) [Optimizers](tensorflow_addons/optimizers/README.md) -1) [Losses](tensorflow_addons/losses/README.md) -1) [Custom Ops](tensorflow_addons/custom_ops/README.md) +2) [Optimizers](tensorflow_addons/optimizers/README.md) +3) [Losses](tensorflow_addons/losses/README.md) +4) [Custom Ops](tensorflow_addons/custom_ops/README.md) #### Periodic Evaluation Based on the nature of this repository, there will be contributions that @@ -44,7 +61,6 @@ maintainable, SIG-Addons will perform periodic reviews and deprecate contributions which will be slated for removal. More information will be available after we submit a formal request for comment. - ## Examples See [`tensorflow_addons/examples/`](tensorflow_addons/examples/) for end-to-end examples of various addons. 
@@ -56,6 +72,15 @@ To install the latest version, run the following: pip install tensorflow-addons ``` +**Note:** You will also need [TensorFlow 2.0 or higher](https://www.tensorflow.org/alpha). + +To use addons: + +```python +import tensorflow as tf +import tensorflow_addons as tfa +``` + #### Installing from Source You can also install from source. This requires the [Bazel]( https://bazel.build/) build system. diff --git a/tensorflow_addons/activations/BUILD b/tensorflow_addons/activations/BUILD index db0ac9bc21..dabcbf35a0 100644 --- a/tensorflow_addons/activations/BUILD +++ b/tensorflow_addons/activations/BUILD @@ -17,7 +17,7 @@ py_library( py_test( name = "sparsemax_py_test", - size = "small", + size = "medium", srcs = [ "python/sparsemax_test.py", ], diff --git a/tensorflow_addons/activations/README.md b/tensorflow_addons/activations/README.md index 24399bff3e..dd533d8644 100644 --- a/tensorflow_addons/activations/README.md +++ b/tensorflow_addons/activations/README.md @@ -19,7 +19,9 @@ must: #### Testing Requirements * Simple unittests that demonstrate the layer is behaving as expected. * When applicable, run all unittests with TensorFlow's - `@run_all_in_graph_and_eager_modes` decorator. + `@run_in_graph_and_eager_modes` (for test method) + or `run_all_in_graph_and_eager_modes` (for TestCase subclass) + decorator. * Add a `py_test` to this sub-package's BUILD file. #### Documentation Requirements diff --git a/tensorflow_addons/custom_ops/README.md b/tensorflow_addons/custom_ops/README.md index 1f445f6451..5cc79861f8 100644 --- a/tensorflow_addons/custom_ops/README.md +++ b/tensorflow_addons/custom_ops/README.md @@ -20,7 +20,9 @@ must: * Simple unittests that demonstrate the custom op is behaving as expected. * When applicable, run all unittests with TensorFlow's - `@run_all_in_graph_and_eager_modes` decorator. + `@run_in_graph_and_eager_modes` (for test method) + or `run_all_in_graph_and_eager_modes` (for TestCase subclass) + decorator. 
* Add a `py_test` to the custom-op's BUILD file. #### Documentation Requirements diff --git a/tensorflow_addons/custom_ops/image/BUILD b/tensorflow_addons/custom_ops/image/BUILD index df9bb998e3..575f76203e 100644 --- a/tensorflow_addons/custom_ops/image/BUILD +++ b/tensorflow_addons/custom_ops/image/BUILD @@ -2,6 +2,22 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +cc_binary( + name = "python/_distort_image_ops.so", + srcs = [ + "cc/kernels/adjust_hsv_in_yiq_op.cc", + "cc/kernels/adjust_hsv_in_yiq_op.h", + "cc/ops/distort_image_ops.cc", + ], + linkshared = 1, + deps = [ + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], + copts = ["-pthread", "-std=c++11", "-D_GLIBCXX_USE_CXX11_ABI=0"] +) + + cc_binary( name = "python/_image_ops.so", srcs = [ @@ -26,18 +42,34 @@ py_library( srcs = ([ "__init__.py", "python/__init__.py", + "python/distort_image_ops.py", "python/transform.py", ]), data = [ + ":python/_distort_image_ops.so", ":python/_image_ops.so", + "//tensorflow_addons/utils:utils_py", ], srcs_version = "PY2AND3", ) +py_test( + name = "distort_image_ops_test", + size = "small", + srcs = [ + "python/distort_image_ops_test.py", + ], + main = "python/distort_image_ops_test.py", + deps = [ + ":images_ops_py", + ], + srcs_version = "PY2AND3" +) + # TODO: use cuda_py_test later. 
py_test( name = "transform_ops_test", - size = "small", + size = "medium", srcs = [ "python/transform_test.py", ], diff --git a/tensorflow_addons/custom_ops/image/__init__.py b/tensorflow_addons/custom_ops/image/__init__.py index 2d8efa2e13..c39f840d3e 100644 --- a/tensorflow_addons/custom_ops/image/__init__.py +++ b/tensorflow_addons/custom_ops/image/__init__.py @@ -17,5 +17,7 @@ from __future__ import division from __future__ import print_function +from tensorflow_addons.custom_ops.image.python.distort_image_ops import adjust_hsv_in_yiq +from tensorflow_addons.custom_ops.image.python.distort_image_ops import random_hsv_in_yiq # Transforms from tensorflow_addons.custom_ops.image.python.transform import transform diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.cc new file mode 100644 index 0000000000..3cfc87fb13 --- /dev/null +++ b/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.cc @@ -0,0 +1,169 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/work_sharder.h" +#include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +class AdjustHsvInYiqOpBase : public OpKernel { + protected: + explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context) + : OpKernel(context) {} + + struct ComputeOptions { + const Tensor* input = nullptr; + Tensor* output = nullptr; + const Tensor* delta_h = nullptr; + const Tensor* scale_s = nullptr; + const Tensor* scale_v = nullptr; + int64 channel_count = 0; + }; + + virtual void DoCompute(OpKernelContext* context, + const ComputeOptions& options) = 0; + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& delta_h = context->input(1); + const Tensor& scale_s = context->input(2); + const Tensor& scale_v = context->input(3); + OP_REQUIRES(context, input.dims() >= 3, + errors::InvalidArgument("input must be at least 3-D, got shape", + input.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()), + errors::InvalidArgument("delta_h must be scalar: ", + delta_h.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()), + errors::InvalidArgument("scale_s must be scalar: ", + scale_s.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()), + errors::InvalidArgument("scale_v must be scalar: ", + scale_v.shape().DebugString())); + auto channels = 
input.dim_size(input.dims() - 1); + OP_REQUIRES( + context, channels == kChannelSize, + errors::InvalidArgument("input must have 3 channels but instead has ", + channels, " channels.")); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + + if (input.NumElements() > 0) { + const int64 channel_count = input.NumElements() / channels; + ComputeOptions options; + options.input = &input; + options.delta_h = &delta_h; + options.scale_s = &scale_s; + options.scale_v = &scale_v; + options.output = output; + options.channel_count = channel_count; + DoCompute(context, options); + } + } +}; + +template +class AdjustHsvInYiqOp; + +template <> +class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { + public: + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) + : AdjustHsvInYiqOpBase(context) {} + + void DoCompute(OpKernelContext* context, + const ComputeOptions& options) override { + const Tensor* input = options.input; + Tensor* output = options.output; + const int64 channel_count = options.channel_count; + auto input_data = input->shaped({channel_count, kChannelSize}); + const float delta_h = options.delta_h->scalar()(); + const float scale_s = options.scale_s->scalar()(); + const float scale_v = options.scale_v->scalar()(); + auto output_data = output->shaped({channel_count, kChannelSize}); + float tranformation_matrix[kChannelSize * kChannelSize] = {0}; + internal::compute_tranformation_matrix( + delta_h, scale_s, scale_v, tranformation_matrix); + const int kCostPerChannel = 10; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, channel_count, + kCostPerChannel, [&input_data, &output_data, &tranformation_matrix]( + int64 start_channel, int64 end_channel) { + // Applying projection matrix to input RGB vectors. 
+ const float* p = input_data.data() + start_channel * kChannelSize; + float* q = output_data.data() + start_channel * kChannelSize; + for (int i = start_channel; i < end_channel; i++) { + for (int q_index = 0; q_index < kChannelSize; q_index++) { + q[q_index] = 0; + for (int p_index = 0; p_index < kChannelSize; p_index++) { + q[q_index] += + p[p_index] * + tranformation_matrix[q_index + kChannelSize * p_index]; + } + } + p += kChannelSize; + q += kChannelSize; + } + }); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint("T"), + AdjustHsvInYiqOp); + +#if GOOGLE_CUDA +template <> +class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { + public: + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) + : AdjustHsvInYiqOpBase(context) {} + + void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { + const int64 number_of_elements = options.input->NumElements(); + if (number_of_elements <= 0) { + return; + } + const float* delta_h = options.delta_h->flat().data(); + const float* scale_s = options.scale_s->flat().data(); + const float* scale_v = options.scale_v->flat().data(); + functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, + delta_h, scale_s, scale_v, options.output); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint("T"), + AdjustHsvInYiqOp); +#endif + +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h b/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h new file mode 100644 index 0000000000..6f587c9790 --- /dev/null +++ b/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +static constexpr int kChannelSize = 3; + +namespace internal { + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix( + const float delta_h, const float scale_s, const float scale_v, + float* matrix) { + static_assert(MATRIX_SIZE == kChannelSize * kChannelSize, + "Size of matrix should be 9."); + // Projection matrix from RGB to YIQ. Numbers from wikipedia + // https://en.wikipedia.org/wiki/YIQ + Eigen::Matrix3f yiq; + /* clang-format off */ + yiq << 0.299, 0.587, 0.114, + 0.596, -0.274, -0.322, + 0.211, -0.523, 0.312; + Eigen::Matrix3f yiq_inverse; + yiq_inverse << 1, 0.95617069, 0.62143257, + 1, -0.2726886, -0.64681324, + 1, -1.103744, 1.70062309; + /* clang-format on */ + // Construct hsv linear transformation matrix in YIQ space. 
+ // https://beesbuzz.biz/code/hsv_color_transforms.php + float vsu = scale_v * scale_s * std::cos(delta_h); + float vsw = scale_v * scale_s * std::sin(delta_h); + Eigen::Matrix3f hsv_transform; + /* clang-format off */ + hsv_transform << scale_v, 0, 0, + 0, vsu, -vsw, + 0, vsw, vsu; + /* clang-format on */ + // Compute final transformation matrix = inverse_yiq * hsv_transform * yiq + Eigen::Map> eigen_matrix(matrix); + eigen_matrix = yiq_inverse * hsv_transform * yiq; +} +} // namespace internal + +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +struct AdjustHsvInYiqGPU { + void operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, const float* const delta_h, + const float* const scale_s, const float* const scale_v, + Tensor* const output); +}; + +} // namespace functor + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc new file mode 100644 index 0000000000..f2a6bcc713 --- /dev/null +++ b/tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +namespace internal { + +__global__ void compute_tranformation_matrix_cuda(const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + float* const matrix, + const int matrix_size) { + if (matrix_size == kChannelSize * kChannelSize) { + compute_tranformation_matrix( + *delta_h, *scale_s, *scale_v, matrix); + } +} +} // namespace internal + +namespace functor { + +void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, + const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + Tensor* const output) { + const uint64 m = channel_count; + const uint64 k = kChannelSize; + const uint64 n = kChannelSize; + auto* cu_stream = ctx->eigen_device().stream(); + OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available.")); + Tensor tranformation_matrix; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), + &tranformation_matrix)); + // TODO(huangyp): It takes about 3.5 us to compute tranformation_matrix + // with one thread. Improve its performance if necessary. + TF_CHECK_OK(CudaLaunchKernel(internal::compute_tranformation_matrix_cuda, 1, + 1, 0, cu_stream, delta_h, scale_s, scale_v, + tranformation_matrix.flat().data(), + tranformation_matrix.flat().size())); + // Call cuBlas C = A * B directly. 
+ auto no_transpose = se::blas::Transpose::kNoTranspose; + auto a_ptr = + AsDeviceMemory(input->flat().data(), input->flat().size()); + auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + auto c_ptr = AsDeviceMemory(output->flat().data(), + output->flat().size()); + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op. + bool blas_launch_status = + stream + ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, ", n=", + n, ", k=", k)); + } +} +} // namespace functor +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc b/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc index ee59b9403e..2936f218c8 100644 --- a/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc +++ b/tensorflow_addons/custom_ops/image/cc/kernels/image_projective_transform_op.cc @@ -50,12 +50,12 @@ using generator::INTERPOLATION_NEAREST; using generator::ProjectiveGenerator; template -class ImageProjectiveTransform : public OpKernel { +class ImageProjectiveTransformV2 : public OpKernel { private: Interpolation interpolation_; public: - explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) { + explicit ImageProjectiveTransformV2(OpKernelConstruction* ctx) : OpKernel(ctx) { string interpolation_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str)); if (interpolation_str == "NEAREST") { @@ -118,10 +118,10 @@ class ImageProjectiveTransform : public OpKernel { }; #define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \ + 
REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \ .Device(DEVICE_CPU) \ .TypeConstraint("dtype"), \ - ImageProjectiveTransform) + ImageProjectiveTransformV2) TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); @@ -157,11 +157,11 @@ TF_CALL_double(DECLARE_FUNCTOR); } // end namespace functor #define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \ + REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \ .Device(DEVICE_GPU) \ .TypeConstraint("dtype") \ .HostMemory("output_shape"), \ - ImageProjectiveTransform) + ImageProjectiveTransformV2) TF_CALL_uint8(REGISTER); TF_CALL_int32(REGISTER); diff --git a/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc new file mode 100644 index 0000000000..82357ea606 --- /dev/null +++ b/tensorflow_addons/custom_ops/image/cc/ops/distort_image_ops.cc @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; + +// -------------------------------------------------------------------------- +REGISTER_OP("AdjustHsvInYiq") + .Input("images: T") + .Input("delta_h: float") + .Input("scale_s: float") + .Input("scale_v: float") + .Output("output: T") + .Attr("T: {uint8, int8, int16, int32, int64, half, float, double}") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); + }) + .Doc(R"Doc( +Adjust the hue, saturation, and value of one or more images in YIQ space. + +`images` is a tensor of at least 3 dimensions. The last dimension is +interpreted as channels, and must be three. + +We use the linear transformation described in: + beesbuzz.biz/code/hsv_color_transforms.php +The input image is considered in the RGB colorspace. Conceptually, the RGB +colors are first mapped into YIQ space, rotated around the Y channel by +delta_h in radians, scaled in the chrominance channels (I, Q) by scale_s, +scaled in all channels (Y, I, Q) by scale_v, and then remapped back to RGB +colorspace. Each operation described above is a linear transformation. + +images: Images to adjust. At least 3-D. +delta_h: A float scalar that represents the hue rotation amount, in radians. + delta_h can be any float value. +scale_s: A float scalar that represents the factor to multiply the saturation by. + scale_s needs to be non-negative. +scale_v: A float scalar that represents the factor to multiply the value by. + scale_v needs to be non-negative. +output: The hsv-adjusted image or images. No clipping will be done in this op. + The client can clip them using additional ops in their graph.
+)Doc"); + +} // namespace tensorflow diff --git a/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc b/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc index c564c9241f..6de221db78 100644 --- a/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc +++ b/tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc @@ -92,7 +92,7 @@ the `transforms` to the `images`. Satisfies the description above. } // namespace // V2 op supports output_shape. -REGISTER_OP("ImageProjectiveTransform") +REGISTER_OP("ImageProjectiveTransformV2") .Input("images: dtype") .Input("transforms: float32") .Input("output_shape: int32") diff --git a/tensorflow_addons/custom_ops/image/python/distort_image_ops.py b/tensorflow_addons/custom_ops/image/python/distort_image_ops.py new file mode 100644 index 0000000000..80f082e22e --- /dev/null +++ b/tensorflow_addons/custom_ops/image/python/distort_image_ops.py @@ -0,0 +1,143 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python layer for distort_image_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.platform import resource_loader + +_distort_image_ops = tf.load_op_library( + resource_loader.get_path_to_datafile("_distort_image_ops.so")) + + +# pylint: disable=invalid-name +@tf.function +def random_hsv_in_yiq(image, + max_delta_hue=0, + lower_saturation=1, + upper_saturation=1, + lower_value=1, + upper_value=1, + seed=None): + """Adjust hue, saturation, value of an RGB image randomly in YIQ color + space. + + Equivalent to `adjust_hsv_in_yiq()` but uses a `delta_hue` randomly + picked in the interval `[-max_delta_hue, max_delta_hue]`, a + `scale_saturation` randomly picked in the interval + `[lower_saturation, upper_saturation]`, and a `scale_value` + randomly picked in the interval `[lower_value, upper_value]`. + + Args: + image: RGB image or images. Size of the last dimension must be 3. + max_delta_hue: float. Maximum value for the random delta_hue. Passing 0 + disables adjusting hue. + lower_saturation: float. Lower bound for the random scale_saturation. + upper_saturation: float. Upper bound for the random scale_saturation. + lower_value: float. Lower bound for the random scale_value. + upper_value: float. Upper bound for the random scale_value. + seed: An operation-specific seed. It will be used in conjunction + with the graph-level seed to determine the real seeds that will be + used in this operation. Please see the documentation of + set_random_seed for its interaction with the graph-level random seed. + + Returns: + Adjusted image(s), same shape and dtype as `image`. + + Raises: + ValueError: if `max_delta_hue`, `lower_saturation`, `upper_saturation`, + `lower_value`, or `upper_value` is invalid.
+ """ + if max_delta_hue < 0: + raise ValueError("max_delta_hue must be non-negative.") + + if lower_saturation < 0: + raise ValueError("lower_saturation must be non-negative.") + + if lower_value < 0: + raise ValueError("lower_value must be non-negative.") + + if lower_saturation > upper_saturation: + raise ValueError("lower_saturation must be <= upper_saturation.") + + if lower_value > upper_value: + raise ValueError("lower_value must be <= upper_value.") + + if max_delta_hue == 0: + delta_hue = 0 + else: + delta_hue = tf.random.uniform([], + -max_delta_hue, + max_delta_hue, + seed=seed) + if lower_saturation == upper_saturation: + scale_saturation = lower_saturation + else: + scale_saturation = tf.random.uniform([], + lower_saturation, + upper_saturation, + seed=seed) + if lower_value == upper_value: + scale_value = lower_value + else: + scale_value = tf.random.uniform([], + lower_value, + upper_value, + seed=seed) + return adjust_hsv_in_yiq(image, delta_hue, scale_saturation, scale_value) + + +@tf.function +def adjust_hsv_in_yiq(image, + delta_hue=0, + scale_saturation=1, + scale_value=1, + name="adjust_hsv_in_yiq"): + """Adjust hue, saturation, value of an RGB image in YIQ color space. + + This is a convenience method that converts an RGB image to float + representation, converts it to YIQ, rotates the color around the + Y channel by delta_hue in radians, scales the chrominance channels + (I, Q) by scale_saturation, scales all channels (Y, I, Q) by scale_value, + converts back to RGB, and then back to the original data type. + + `image` is an RGB image. The image hue is adjusted by converting the + image to YIQ, rotating around the luminance channel (Y) by + `delta_hue` in radians, multiplying the chrominance channels (I, Q) by + `scale_saturation`, and multiplying all channels (Y, I, Q) by + `scale_value`. The image is then converted back to RGB. + + Args: + image: RGB image or images. Size of the last dimension must be 3.
+ delta_hue: float, the hue rotation amount, in radians. + scale_saturation: float, factor to multiply the saturation by. + scale_value: float, factor to multiply the value by. + name: A name for this operation (optional). + + Returns: + Adjusted image(s), same shape and DType as `image`. + """ + with tf.name_scope(name): + image = tf.convert_to_tensor(image, name="image") + # Remember original dtype so we can convert back if needed + orig_dtype = image.dtype + flt_image = tf.image.convert_image_dtype(image, tf.dtypes.float32) + + rgb_altered = _distort_image_ops.adjust_hsv_in_yiq( + flt_image, delta_hue, scale_saturation, scale_value) + + return tf.image.convert_image_dtype(rgb_altered, orig_dtype) diff --git a/tensorflow_addons/custom_ops/image/python/distort_image_ops_test.py b/tensorflow_addons/custom_ops/image/python/distort_image_ops_test.py new file mode 100644 index 0000000000..5e92e71c9c --- /dev/null +++ b/tensorflow_addons/custom_ops/image/python/distort_image_ops_test.py @@ -0,0 +1,341 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Tests for python distort_image_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +import tensorflow as tf +from tensorflow_addons.custom_ops.image.python import distort_image_ops +from tensorflow_addons.utils.python import test_utils + + +class AdjustHueInYiqTest(tf.test.TestCase): + def _adjust_hue_in_yiq_np(self, x_np, delta_h): + """Rotate hue in YIQ space. + + Mathematically we first convert rgb color to yiq space, rotate the hue + degrees, and then convert back to rgb. + + Args: + x_np: input x with last dimension = 3. + delta_h: degree of hue rotation, in radians. + + Returns: + Adjusted y with the same shape as x_np. + """ + self.assertEqual(x_np.shape[-1], 3) + x_v = x_np.reshape([-1, 3]) + y_v = np.ndarray(x_v.shape, dtype=x_v.dtype) + u = np.cos(delta_h) + w = np.sin(delta_h) + # Projection matrix from RGB to YIQ. Numbers from wikipedia + # https://en.wikipedia.org/wiki/YIQ + tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.322], + [0.211, -0.523, 0.312]]) + y_v = np.dot(x_v, tyiq.T) + # Hue rotation matrix in YIQ space. + hue_rotation = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]]) + y_v = np.dot(y_v, hue_rotation.T) + # Projecting back to RGB space. 
+ y_v = np.dot(y_v, np.linalg.inv(tyiq).T) + return y_v.reshape(x_np.shape) + + def _adjust_hue_in_yiq_tf(self, x_np, delta_h): + x = tf.constant(x_np) + y = distort_image_ops.adjust_hsv_in_yiq(x, delta_h, 1, 1) + return y + + @test_utils.run_in_graph_and_eager_modes + def test_adjust_random_hue_in_yiq(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + "all_random", + "rg_same", + "rb_same", + "gb_same", + "rgb_same", + ] + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + delta_h = (np.random.rand() * 2.0 - 1.0) * np.pi + if test_style == "all_random": + pass + elif test_style == "rg_same": + x_np[..., 1] = x_np[..., 0] + elif test_style == "rb_same": + x_np[..., 2] = x_np[..., 0] + elif test_style == "gb_same": + x_np[..., 2] = x_np[..., 1] + elif test_style == "rgb_same": + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError( + "Invalid test style: %s" % (test_style)) + y_np = self._adjust_hue_in_yiq_np(x_np, delta_h) + y_tf = self._adjust_hue_in_yiq_tf(x_np, delta_h) + self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) + + # TODO: run in both graph and eager modes + def test_invalid_shapes(self): + x_np = np.random.rand(2, 3) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesRegexp(ValueError, + "Shape must be at least rank 3"): + self._adjust_hue_in_yiq_tf(x_np, delta_h) + x_np = np.random.rand(4, 2, 4) * 255. 
+ delta_h = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesOpError("input must have 3 channels " + "but instead has 4 channels"): + self._adjust_hue_in_yiq_tf(x_np, delta_h) + + +class AdjustValueInYiqTest(tf.test.TestCase): + def _adjust_value_in_yiq_np(self, x_np, scale): + return x_np * scale + + def _adjust_value_in_yiq_tf(self, x_np, scale): + x = tf.constant(x_np) + y = distort_image_ops.adjust_hsv_in_yiq(x, 0, 1, scale) + return y + + @test_utils.run_in_graph_and_eager_modes + def test_adjust_random_value_in_yiq(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + "all_random", + "rg_same", + "rb_same", + "gb_same", + "rgb_same", + ] + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + scale = np.random.rand() * 2.0 - 1.0 + if test_style == "all_random": + pass + elif test_style == "rg_same": + x_np[..., 1] = x_np[..., 0] + elif test_style == "rb_same": + x_np[..., 2] = x_np[..., 0] + elif test_style == "gb_same": + x_np[..., 2] = x_np[..., 1] + elif test_style == "rgb_same": + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError( + "Invalid test style: %s" % (test_style)) + y_np = self._adjust_value_in_yiq_np(x_np, scale) + y_tf = self._adjust_value_in_yiq_tf(x_np, scale) + self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) + + # TODO: run in both graph and eager modes + def test_invalid_shapes(self): + x_np = np.random.rand(2, 3) * 255. + scale = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesRegexp(ValueError, + "Shape must be at least rank 3"): + self._adjust_value_in_yiq_tf(x_np, scale) + x_np = np.random.rand(4, 2, 4) * 255. 
+ scale = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesOpError("input must have 3 channels " + "but instead has 4 channels"): + self._adjust_value_in_yiq_tf(x_np, scale) + + +class AdjustSaturationInYiqTest(tf.test.TestCase): + def _adjust_saturation_in_yiq_tf(self, x_np, scale): + x = tf.constant(x_np) + y = distort_image_ops.adjust_hsv_in_yiq(x, 0, scale, 1) + return y + + def _adjust_saturation_in_yiq_np(self, x_np, scale): + """Adjust saturation using linear interpolation.""" + rgb_weights = np.array([0.299, 0.587, 0.114]) + gray = np.sum(x_np * rgb_weights, axis=-1, keepdims=True) + y_v = x_np * scale + gray * (1 - scale) + return y_v + + @test_utils.run_in_graph_and_eager_modes + def test_adjust_random_saturation_in_yiq(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + "all_random", + "rg_same", + "rb_same", + "gb_same", + "rgb_same", + ] + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + scale = np.random.rand() * 2.0 - 1.0 + if test_style == "all_random": + pass + elif test_style == "rg_same": + x_np[..., 1] = x_np[..., 0] + elif test_style == "rb_same": + x_np[..., 2] = x_np[..., 0] + elif test_style == "gb_same": + x_np[..., 2] = x_np[..., 1] + elif test_style == "rgb_same": + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError( + "Invalid test style: %s" % (test_style)) + y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale) + y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale) + self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4) + + # TODO: run in both graph and eager modes + def test_invalid_shapes(self): + x_np = np.random.rand(2, 3) * 255. + scale = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesRegexp(ValueError, + "Shape must be at least rank 3"): + self._adjust_saturation_in_yiq_tf(x_np, scale) + x_np = np.random.rand(4, 2, 4) * 255. 
+ scale = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesOpError("input must have 3 channels " + "but instead has 4 channels"): + self._adjust_saturation_in_yiq_tf(x_np, scale) + + +# TODO: get rid of sessions +class AdjustHueInYiqBenchmark(tf.test.Benchmark): + def _benchmark_adjust_hue_in_yiq(self, device, cpu_count): + image_shape = [299, 299, 3] + warmup_rounds = 100 + benchmark_rounds = 1000 + config = tf.compat.v1.ConfigProto() + if cpu_count is not None: + config.inter_op_parallelism_threads = 1 + config.intra_op_parallelism_threads = cpu_count + with self.cached_session("", graph=tf.Graph(), config=config) as sess: + with tf.device(device): + inputs = tf.Variable( + tf.random.uniform(image_shape, dtype=tf.dtypes.float32) * + 255, + trainable=False, + dtype=tf.dtypes.float32) + delta = tf.constant(0.1, dtype=tf.dtypes.float32) + outputs = distort_image_ops.adjust_hsv_in_yiq( + inputs, delta, 1, 1) + run_op = tf.group(outputs) + sess.run(tf.compat.v1.global_variables_initializer()) + for i in xrange(warmup_rounds + benchmark_rounds): + if i == warmup_rounds: + start = time.time() + sess.run(run_op) + end = time.time() + step_time = (end - start) / benchmark_rounds + tag = device + "_%s" % (cpu_count if cpu_count is not None else "all") + print("benchmarkadjust_hue_in_yiq_299_299_3_%s step_time: %.2f us" % + (tag, step_time * 1e6)) + self.report_benchmark( + name="benchmarkadjust_hue_in_yiq_299_299_3_%s" % (tag), + iters=benchmark_rounds, + wall_time=step_time) + + def benchmark_adjust_hue_in_yiqCpu1(self): + self._benchmark_adjust_hue_in_yiq("/cpu:0", 1) + + def benchmark_adjust_hue_in_yiqCpuAll(self): + self._benchmark_adjust_hue_in_yiq("/cpu:0", None) + + def benchmark_adjust_hue_in_yiq_gpu_all(self): + self._benchmark_adjust_hue_in_yiq(tf.test.gpu_device_name(), None) + + +# TODO: get rid of sessions +class AdjustSaturationInYiqBenchmark(tf.test.Benchmark): + def _benchmark_adjust_saturation_in_yiq(self, device, cpu_count): + image_shape = [299, 299, 
3] + warmup_rounds = 100 + benchmark_rounds = 1000 + config = tf.compat.v1.ConfigProto() + if cpu_count is not None: + config.inter_op_parallelism_threads = 1 + config.intra_op_parallelism_threads = cpu_count + with self.cached_session("", graph=tf.Graph(), config=config) as sess: + with tf.device(device): + inputs = tf.Variable( + tf.random.uniform(image_shape, dtype=tf.dtypes.float32) * + 255, + trainable=False, + dtype=tf.dtypes.float32) + scale = tf.constant(0.1, dtype=tf.dtypes.float32) + outputs = distort_image_ops.adjust_hsv_in_yiq( + inputs, 0, scale, 1) + run_op = tf.group(outputs) + sess.run(tf.compat.v1.global_variables_initializer()) + for _ in xrange(warmup_rounds): + sess.run(run_op) + start = time.time() + for _ in xrange(benchmark_rounds): + sess.run(run_op) + end = time.time() + step_time = (end - start) / benchmark_rounds + tag = "%s" % (cpu_count) if cpu_count is not None else "_all" + print( + "benchmarkAdjustSaturationInYiq_299_299_3_cpu%s step_time: %.2f us" + % (tag, step_time * 1e6)) + self.report_benchmark( + name="benchmarkAdjustSaturationInYiq_299_299_3_cpu%s" % (tag), + iters=benchmark_rounds, + wall_time=step_time) + + def benchmark_adjust_saturation_in_yiq_cpu1(self): + self._benchmark_adjust_saturation_in_yiq("/cpu:0", 1) + + def benchmark_adjust_saturation_in_yiq_cpu_all(self): + self._benchmark_adjust_saturation_in_yiq("/cpu:0", None) + + def benchmark_adjust_saturation_in_yiq_gpu_all(self): + self._benchmark_adjust_saturation_in_yiq(tf.test.gpu_device_name(), + None) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/custom_ops/image/python/transform.py b/tensorflow_addons/custom_ops/image/python/transform.py index e40c5f3b20..ca470b0778 100644 --- a/tensorflow_addons/custom_ops/image/python/transform.py +++ b/tensorflow_addons/custom_ops/image/python/transform.py @@ -30,7 +30,7 @@ tf.dtypes.float32, tf.dtypes.float64 ]) -ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) 
+ops.RegisterShape("ImageProjectiveTransformV2")(common_shapes.call_cpp_shape_fn) @tf.function @@ -108,7 +108,7 @@ def transform(images, else: raise TypeError("Transforms should have rank 1 or 2.") - output = _image_ops_so.image_projective_transform( + output = _image_ops_so.image_projective_transform_v2( images, output_shape=output_shape, transforms=transforms, @@ -121,6 +121,7 @@ def transform(images, return output +@tf.function def compose_transforms(*transforms): """Composes the transforms tensors. @@ -144,6 +145,7 @@ def compose_transforms(*transforms): return matrices_to_flat_transforms(composed) +@tf.function def flat_transforms_to_matrices(transforms): """Converts projective transforms to affine matrices. @@ -176,6 +178,7 @@ def flat_transforms_to_matrices(transforms): tf.constant([-1, 3, 3])) +@tf.function def matrices_to_flat_transforms(transform_matrices): """Converts affine matrices to projective transforms. @@ -208,6 +211,7 @@ def matrices_to_flat_transforms(transform_matrices): return transforms[:, :8] +@tf.function def angles_to_projective_transforms(angles, image_height, image_width, @@ -256,7 +260,7 @@ def angles_to_projective_transforms(angles, axis=1) -@ops.RegisterGradient("ImageProjectiveTransform") +@ops.RegisterGradient("ImageProjectiveTransformV2") def _image_projective_transform_grad(op, grad): """Computes the gradient for ImageProjectiveTransform.""" images = op.inputs[0] @@ -280,7 +284,7 @@ def _image_projective_transform_grad(op, grad): transforms = flat_transforms_to_matrices(transforms=transforms) inverse = tf.linalg.inv(transforms) transforms = matrices_to_flat_transforms(inverse) - output = _image_ops_so.image_projective_transform( + output = _image_ops_so.image_projective_transform_v2( images=grad, transforms=transforms, output_shape=tf.shape(image_or_images)[1:3], diff --git a/tensorflow_addons/custom_ops/image/python/transform_test.py b/tensorflow_addons/custom_ops/image/python/transform_test.py index f5bcda944f..d381b8eced 100644 
--- a/tensorflow_addons/custom_ops/image/python/transform_test.py +++ b/tensorflow_addons/custom_ops/image/python/transform_test.py @@ -21,9 +21,9 @@ import numpy as np import tensorflow as tf -from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.ops import gradient_checker from tensorflow_addons.custom_ops.image.python import transform as transform_ops +from tensorflow_addons.utils.python import test_utils _DTYPES = set([ tf.dtypes.uint8, tf.dtypes.int32, tf.dtypes.int64, tf.dtypes.float16, @@ -32,7 +32,7 @@ class ImageOpsTest(tf.test.TestCase): - @tf_test_util.run_all_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes def test_compose(self): for dtype in _DTYPES: image = tf.constant( @@ -51,7 +51,7 @@ def test_compose(self): [[0, 0, 0, 0], [0, 1, 0, 1], [0, 1, 0, 1], [0, 1, 1, 1]], image_transformed) - @tf_test_util.run_all_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes def test_extreme_projective_transform(self): for dtype in _DTYPES: image = tf.constant( @@ -64,7 +64,6 @@ def test_extreme_projective_transform(self): [[1, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]], image_transformed) - @tf_test_util.run_all_in_graph_and_eager_modes def test_transform_static_output_shape(self): image = tf.constant([[1., 2.], [3., 4.]]) result = transform_ops.transform( @@ -118,7 +117,7 @@ def _test_grad_different_shape(self, input_shape, output_shape): self.assertLess(left_err, 1e-10) # TODO: switch to TF2 later. 
- @tf_test_util.run_deprecated_v1 + @test_utils.run_deprecated_v1 def test_grad(self): self._test_grad([16, 16]) self._test_grad([4, 12, 12]) @@ -127,16 +126,15 @@ def test_grad(self): self._test_grad_different_shape([4, 12, 3], [8, 24, 3]) self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3]) - @tf_test_util.run_all_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes def test_transform_data_types(self): for dtype in _DTYPES: image = tf.constant([[1, 2], [3, 4]], dtype=dtype) - with self.test_session(use_gpu=True): - self.assertAllEqual( - np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype()), - transform_ops.transform(image, [1] * 8)) + self.assertAllEqual( + np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype()), + transform_ops.transform(image, [1] * 8)) - @tf_test_util.run_all_in_graph_and_eager_modes + @test_utils.run_in_graph_and_eager_modes def test_transform_eager(self): image = tf.constant([[1., 2.], [3., 4.]]) self.assertAllEqual( diff --git a/tensorflow_addons/custom_ops/text/BUILD b/tensorflow_addons/custom_ops/text/BUILD index bf4d3399e2..a356984c15 100644 --- a/tensorflow_addons/custom_ops/text/BUILD +++ b/tensorflow_addons/custom_ops/text/BUILD @@ -27,6 +27,7 @@ py_library( ]), data = [ ":python/_skip_gram_ops.so", + "//tensorflow_addons/utils:utils_py", ], srcs_version = "PY2AND3", ) diff --git a/tensorflow_addons/custom_ops/text/python/skip_gram_ops_test.py b/tensorflow_addons/custom_ops/text/python/skip_gram_ops_test.py index 7008405a44..9532689632 100644 --- a/tensorflow_addons/custom_ops/text/python/skip_gram_ops_test.py +++ b/tensorflow_addons/custom_ops/text/python/skip_gram_ops_test.py @@ -22,12 +22,12 @@ import os import tensorflow as tf -from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.framework import random_seed from tensorflow.python.ops import lookup_ops from tensorflow.python.platform import test from tensorflow_addons.custom_ops import text from 
tensorflow_addons.custom_ops.text.python import skip_gram_ops +from tensorflow_addons.utils.python import test_utils class SkipGramOpsTest(tf.test.TestCase): @@ -232,7 +232,7 @@ def test_skip_gram_sample_non_string_input(self): self.assertAllEqual(expected_tokens, tokens) self.assertAllEqual(expected_labels, labels) - @tf_test_util.run_deprecated_v1 + @test_utils.run_deprecated_v1 def test_skip_gram_sample_errors_v1(self): """Tests various errors raised by skip_gram_sample().""" # input_tensor must be of rank 1. diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index b8de59fbe0..3ae079bbf1 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -8,6 +8,7 @@ py_library( "__init__.py", "python/__init__.py", "python/maxout.py", + "python/normalizations.py", "python/poincare.py", "python/sparsemax.py", "python/wrappers.py", @@ -29,7 +30,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - ], + ] ) py_test( @@ -55,18 +56,18 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - ], + ] ) py_test( - name = "poincare_py_test", - size = "small", + name = "layers_normalizations_py_test", + size= "small", srcs = [ - "python/poincare_test.py", + "python/normalizations_test.py", ], - main = "python/poincare_test.py", + main = "python/normalizations_test.py", srcs_version = "PY2AND3", deps = [ - ":layers_py", - ], + ":layers_py", + ] ) diff --git a/tensorflow_addons/layers/README.md b/tensorflow_addons/layers/README.md index ae4b3121b0..f186a0422b 100644 --- a/tensorflow_addons/layers/README.md +++ b/tensorflow_addons/layers/README.md @@ -3,8 +3,11 @@ ## Contents | Layer | Reference | |:----------------------- |:-----------------------------| +| GroupNormalization | https://arxiv.org/abs/1803.08494 | +| InstanceNormalization | https://arxiv.org/abs/1607.08022 | +| LayerNormalization | https://arxiv.org/abs/1607.06450 | | Maxout | https://arxiv.org/abs/1302.4389 | -| PoinareNormalize | 
https://arxiv.org/abs/1705.08039 | +| PoincareNormalize | https://arxiv.org/abs/1705.08039 | | WeightNormalization | https://arxiv.org/abs/1602.07868 | @@ -20,8 +23,10 @@ must: #### Testing Requirements * Simple unittests that demonstrate the layer is behaving as expected. * When applicable, run all unittests with TensorFlow's - `@run_all_in_graph_and_eager_modes` decorator. - * Run `keras.testing_utils.layer_test` on the layer. + `@run_in_graph_and_eager_modes` (for test method) + or `run_all_in_graph_and_eager_modes` (for TestCase subclass) + decorator. + * Run `layer_test` on the layer. * Add a `py_test` to this sub-package's BUILD file. #### Documentation Requirements diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 0e06709ac7..c5e0497726 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -19,6 +19,9 @@ from __future__ import print_function from tensorflow_addons.layers.python.maxout import Maxout +from tensorflow_addons.layers.python.normalizations import GroupNormalization +from tensorflow_addons.layers.python.normalizations import InstanceNormalization +from tensorflow_addons.layers.python.normalizations import LayerNormalization from tensorflow_addons.layers.python.poincare import PoincareNormalize from tensorflow_addons.layers.python.sparsemax import Sparsemax from tensorflow_addons.layers.python.wrappers import WeightNormalization diff --git a/tensorflow_addons/layers/python/maxout_test.py b/tensorflow_addons/layers/python/maxout_test.py index 19a5de772e..f86d21f1e5 100644 --- a/tensorflow_addons/layers/python/maxout_test.py +++ b/tensorflow_addons/layers/python/maxout_test.py @@ -21,17 +21,18 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import testing_utils as keras_test_util from tensorflow_addons.layers.python.maxout import Maxout +from tensorflow_addons.utils.python import test_utils +@test_utils.run_all_in_graph_and_eager_modes 
class MaxOutTest(tf.test.TestCase): def test_simple(self): - keras_test_util.layer_test( + test_utils.layer_test( Maxout, kwargs={'num_units': 3}, input_shape=(5, 4, 2, 18)) def test_nchw(self): - keras_test_util.layer_test( + test_utils.layer_test( Maxout, kwargs={ 'num_units': 4, @@ -39,7 +40,7 @@ def test_nchw(self): }, input_shape=(2, 20, 3, 6)) - keras_test_util.layer_test( + test_utils.layer_test( Maxout, kwargs={ 'num_units': 4, @@ -49,13 +50,13 @@ def test_nchw(self): def test_unknown(self): inputs = np.random.random((5, 4, 2, 18)).astype('float32') - keras_test_util.layer_test( + test_utils.layer_test( Maxout, kwargs={'num_units': 3}, input_shape=(5, 4, 2, None), input_data=inputs) - keras_test_util.layer_test( + test_utils.layer_test( Maxout, kwargs={'num_units': 3}, input_shape=(None, None, None, None), @@ -63,7 +64,7 @@ def test_unknown(self): def test_invalid_shape(self): with self.assertRaisesRegexp(ValueError, r'number of features'): - keras_test_util.layer_test( + test_utils.layer_test( Maxout, kwargs={'num_units': 3}, input_shape=(5, 4, 2, 7)) diff --git a/tensorflow_addons/layers/python/normalizations.py b/tensorflow_addons/layers/python/normalizations.py new file mode 100644 index 0000000000..2a07a3d802 --- /dev/null +++ b/tensorflow_addons/layers/python/normalizations.py @@ -0,0 +1,361 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Original implementation from keras_contrib/layer/normalization +# ============================================================================= +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import tensorflow as tf +from tensorflow_addons.utils.python import keras_utils + + +@keras_utils.register_keras_custom_object +class GroupNormalization(tf.keras.layers.Layer): + """Group normalization layer. + + Group Normalization divides the channels into groups and computes + within each group the mean and variance for normalization. + Empirically, its accuracy is more stable than batch norm in a wide + range of small batch sizes, if learning rate is adjusted linearly + with batch sizes. + + Relation to Layer Normalization: + If the number of groups is set to 1, then this operation becomes identical + to Layer Normalization. + + Relation to Instance Normalization: + If the number of groups is set to the + input dimension (number of groups is equal + to number of channels), then this operation becomes + identical to Instance Normalization. + + Arguments + groups: Integer, the number of groups for Group Normalization. + Can be in the range [1, N] where N is the input dimension. + The input dimension must be divisible by the number of groups. + axis: Integer, the axis that should be normalized. + epsilon: Small float added to variance to avoid dividing by zero. + center: If True, add offset of `beta` to normalized tensor. + If False, `beta` is ignored. + scale: If True, multiply by `gamma`. + If False, `gamma` is not used. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + beta_constraint: Optional constraint for the beta weight. + gamma_constraint: Optional constraint for the gamma weight. 
+ + Input shape + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape + Same shape as input. + References + - [Group Normalization](https://arxiv.org/abs/1803.08494) + """ + + def __init__(self, + groups=2, + axis=-1, + epsilon=1e-5, + center=True, + scale=True, + beta_initializer='zeros', + gamma_initializer='ones', + beta_regularizer=None, + gamma_regularizer=None, + beta_constraint=None, + gamma_constraint=None, + **kwargs): + super(GroupNormalization, self).__init__(**kwargs) + self.supports_masking = True + self.groups = groups + self.axis = axis + self.epsilon = epsilon + self.center = center + self.scale = scale + self.beta_initializer = tf.keras.initializers.get(beta_initializer) + self.gamma_initializer = tf.keras.initializers.get(gamma_initializer) + self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer) + self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer) + self.beta_constraint = tf.keras.constraints.get(beta_constraint) + self.gamma_constraint = tf.keras.constraints.get(gamma_constraint) + self._check_axis() + + def build(self, input_shape): + + self._check_if_input_shape_is_none(input_shape) + self._set_number_of_groups_for_instance_norm(input_shape) + self._check_size_of_dimensions(input_shape) + self._create_input_spec(input_shape) + + self._add_gamma_weight(input_shape) + self._add_beta_weight(input_shape) + self.built = True + super(GroupNormalization, self).build(input_shape) + + def call(self, inputs): + + input_shape = tf.keras.backend.int_shape(inputs) + tensor_input_shape = tf.shape(inputs) + + reshaped_inputs, group_shape = self._reshape_into_groups( + inputs, input_shape, tensor_input_shape) + + normalized_inputs = self._apply_normalization(reshaped_inputs, + input_shape) + + outputs = tf.reshape(normalized_inputs, tensor_input_shape) + + return outputs + + def get_config(self): + 
config = { + 'groups': + self.groups, + 'axis': + self.axis, + 'epsilon': + self.epsilon, + 'center': + self.center, + 'scale': + self.scale, + 'beta_initializer': + tf.keras.initializers.serialize(self.beta_initializer), + 'gamma_initializer': + tf.keras.initializers.serialize(self.gamma_initializer), + 'beta_regularizer': + tf.keras.regularizers.serialize(self.beta_regularizer), + 'gamma_regularizer': + tf.keras.regularizers.serialize(self.gamma_regularizer), + 'beta_constraint': + tf.keras.constraints.serialize(self.beta_constraint), + 'gamma_constraint': + tf.keras.constraints.serialize(self.gamma_constraint) + } + base_config = super(GroupNormalization, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + return input_shape + + def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): + + group_shape = [tensor_input_shape[i] for i in range(len(input_shape))] + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(1, self.groups) + group_shape = tf.stack(group_shape) + reshaped_inputs = tf.reshape(inputs, group_shape) + return reshaped_inputs, group_shape + + def _apply_normalization(self, reshaped_inputs, input_shape): + + group_shape = tf.keras.backend.int_shape(reshaped_inputs) + group_reduction_axes = list(range(len(group_shape))) + # Remember the ordering of the tensor is [batch, group , steps]. 
Skip + # the first 2 axes to calculate the variance and the mean + mean, variance = tf.nn.moments( + reshaped_inputs, group_reduction_axes[2:], keepdims=True) + + gamma, beta = self._get_reshaped_weights(input_shape) + normalized_inputs = tf.nn.batch_normalization( + reshaped_inputs, + mean=mean, + variance=variance, + scale=gamma, + offset=beta, + variance_epsilon=self.epsilon) + return normalized_inputs + + def _get_reshaped_weights(self, input_shape): + broadcast_shape = self._create_broadcast_shape(input_shape) + gamma = None + beta = None + if self.scale: + gamma = tf.reshape(self.gamma, broadcast_shape) + + if self.center: + beta = tf.reshape(self.beta, broadcast_shape) + return gamma, beta + + def _check_if_input_shape_is_none(self, input_shape): + dim = input_shape[self.axis] + if dim is None: + raise ValueError('Axis ' + str(self.axis) + ' of ' + 'input tensor should have a defined dimension ' + 'but the layer received an input with shape ' + + str(input_shape) + '.') + + def _set_number_of_groups_for_instance_norm(self, input_shape): + dim = input_shape[self.axis] + + if self.groups == -1: + self.groups = dim + + def _check_size_of_dimensions(self, input_shape): + + dim = input_shape[self.axis] + if dim < self.groups: + raise ValueError( + 'Number of groups (' + str(self.groups) + ') cannot be ' + 'more than the number of channels (' + str(dim) + ').') + + if dim % self.groups != 0: + raise ValueError( + 'Number of groups (' + str(self.groups) + ') must be a ' + 'divisor of the number of channels (' + str(dim) + ').') + + def _check_axis(self): + + if self.axis == 0: + raise ValueError( + "You are trying to normalize your batch axis. 
Do you want to " + "use tf.keras.layers.BatchNormalization instead?") + + def _create_input_spec(self, input_shape): + + dim = input_shape[self.axis] + self.input_spec = tf.keras.layers.InputSpec( + ndim=len(input_shape), axes={self.axis: dim}) + + def _add_gamma_weight(self, input_shape): + + dim = input_shape[self.axis] + shape = (dim,) + + if self.scale: + self.gamma = self.add_weight( + shape=shape, + name='gamma', + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint) + else: + self.gamma = None + + def _add_beta_weight(self, input_shape): + + dim = input_shape[self.axis] + shape = (dim,) + + if self.center: + self.beta = self.add_weight( + shape=shape, + name='beta', + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint) + else: + self.beta = None + + def _create_broadcast_shape(self, input_shape): + broadcast_shape = [1] * len(input_shape) + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(1, self.groups) + return broadcast_shape + + +@keras_utils.register_keras_custom_object +class LayerNormalization(GroupNormalization): + """Layer normalization layer. + + Layer Normalization is a specific case of ```GroupNormalization``` since it + normalizes all features of a layer. The number of groups is 1. + Empirically, its accuracy is more stable than batch norm in a wide + range of small batch sizes, if learning rate is adjusted linearly + with batch sizes. + + Arguments + axis: Integer, the axis that should be normalized. + epsilon: Small float added to variance to avoid dividing by zero. + center: If True, add offset of `beta` to normalized tensor. + If False, `beta` is ignored. + scale: If True, multiply by `gamma`. + If False, `gamma` is not used. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + beta_regularizer: Optional regularizer for the beta weight. 
+ gamma_regularizer: Optional regularizer for the gamma weight. + beta_constraint: Optional constraint for the beta weight. + gamma_constraint: Optional constraint for the gamma weight. + + Input shape + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. + + Output shape + Same shape as input. + + References + - [Layer Normalization](https://arxiv.org/abs/1607.06450) + """ + + def __init__(self, **kwargs): + if "groups" in kwargs: + logging.warning("The given value for groups will be overwritten.") + kwargs["groups"] = 1 + super(LayerNormalization, self).__init__(**kwargs) + + +@keras_utils.register_keras_custom_object +class InstanceNormalization(GroupNormalization): + """Instance normalization layer. + + Instance Normalization is a specific case of ```GroupNormalization``` since + it normalizes all features of one channel. The number of groups is equal to + the number of channels. Empirically, its accuracy is more stable than batch + norm in a wide range of small batch sizes, if learning rate is adjusted + linearly with batch sizes. + + Arguments + axis: Integer, the axis that should be normalized. + epsilon: Small float added to variance to avoid dividing by zero. + center: If True, add offset of `beta` to normalized tensor. + If False, `beta` is ignored. + scale: If True, multiply by `gamma`. + If False, `gamma` is not used. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + beta_constraint: Optional constraint for the beta weight. + gamma_constraint: Optional constraint for the gamma weight. + + Input shape + Arbitrary. Use the keyword argument `input_shape` + (tuple of integers, does not include the samples axis) + when using this layer as the first layer in a model. 
+ + Output shape + Same shape as input. + + References + - [Instance Normalization: The Missing Ingredient for Fast Stylization] + (https://arxiv.org/abs/1607.08022) + """ + + def __init__(self, **kwargs): + if "groups" in kwargs: + logging.warning("The given value for groups will be overwritten.") + + kwargs["groups"] = -1 + super(InstanceNormalization, self).__init__(**kwargs) diff --git a/tensorflow_addons/layers/python/normalizations_test.py b/tensorflow_addons/layers/python/normalizations_test.py new file mode 100644 index 0000000000..f3bf95afae --- /dev/null +++ b/tensorflow_addons/layers/python/normalizations_test.py @@ -0,0 +1,282 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow_addons.layers.python.normalizations import GroupNormalization +from tensorflow_addons.layers.python.normalizations import InstanceNormalization +from tensorflow_addons.layers.python.normalizations import LayerNormalization +from tensorflow_addons.utils.python import test_utils + + +class NormalizationTest(tf.test.TestCase): + + # ------------Tests to ensure proper inheritance. 
If these succeed you can + # test for Instance norm and Layernorm by setting Groupnorm groups = -1 or 1 + def test_inheritance(self): + self.assertTrue(issubclass(LayerNormalization, GroupNormalization)) + self.assertTrue(issubclass(InstanceNormalization, GroupNormalization)) + self.assertTrue(LayerNormalization.build == GroupNormalization.build) + self.assertTrue( + InstanceNormalization.build == GroupNormalization.build) + self.assertTrue(LayerNormalization.call == GroupNormalization.call) + self.assertTrue(InstanceNormalization.call == GroupNormalization.call) + + def test_groups_after_init(self): + layers = InstanceNormalization() + self.assertTrue(layers.groups == -1) + layers = LayerNormalization() + self.assertTrue(layers.groups == 1) + + # ------------------------------------------------------------------------------ + + def test_reshape(self): + def run_reshape_test(axis, group, input_shape, expected_shape): + group_layer = GroupNormalization(groups=group, axis=axis) + group_layer._set_number_of_groups_for_instance_norm(input_shape) + + inputs = np.ones(input_shape) + tensor_input_shape = tf.convert_to_tensor(input_shape) + reshaped_inputs, group_shape = group_layer._reshape_into_groups( + inputs, (10, 10, 10), tensor_input_shape) + for i in range(len(expected_shape)): + self.assertEqual(int(group_shape[i]), expected_shape[i]) + + input_shape = (10, 10, 10) + expected_shape = [10, 5, 10, 2] + run_reshape_test(2, 5, input_shape, expected_shape) + + input_shape = (10, 10, 10) + expected_shape = [10, 2, 5, 10] + run_reshape_test(1, 2, input_shape, expected_shape) + + input_shape = (10, 10, 10) + expected_shape = [10, 10, 1, 10] + run_reshape_test(1, -1, input_shape, expected_shape) + + input_shape = (10, 10, 10) + expected_shape = [10, 1, 10, 10] + run_reshape_test(1, 1, input_shape, expected_shape) + + def test_feature_input(self): + shape = (10, 100) + for center in [True, False]: + for scale in [True, False]: + for groups in [-1, 1, 2, 5]: +
self._test_random_shape_on_all_axis_except_batch( + shape, groups, center, scale) + + def test_picture_input(self): + shape = (10, 30, 30, 3) + for center in [True, False]: + for scale in [True, False]: + for groups in [-1, 1, 3]: + self._test_random_shape_on_all_axis_except_batch( + shape, groups, center, scale) + + def _test_random_shape_on_all_axis_except_batch(self, shape, groups, + center, scale): + inputs = tf.random.normal((shape)) + for axis in range(1, len(shape)): + self._test_specific_layer(inputs, axis, groups, center, scale) + + def _test_specific_layer(self, inputs, axis, groups, center, scale): + + input_shape = inputs.shape + + # Get Output from Keras model + layer = GroupNormalization( + axis=axis, groups=groups, center=center, scale=scale) + model = tf.keras.models.Sequential() + model.add(layer) + outputs = model.predict(inputs) + self.assertFalse(np.isnan(outputs).any()) + + # Create shapes + if groups == -1: + groups = input_shape[axis] + np_inputs = inputs.numpy() + reshaped_dims = list(np_inputs.shape) + reshaped_dims[axis] = reshaped_dims[axis] // groups + reshaped_dims.insert(1, groups) + reshaped_inputs = np.reshape(np_inputs, tuple(reshaped_dims)) + + # Calculate mean and variance + mean = np.mean( + reshaped_inputs, + axis=tuple(range(2, len(reshaped_dims))), + keepdims=True) + variance = np.var( + reshaped_inputs, + axis=tuple(range(2, len(reshaped_dims))), + keepdims=True) + + # Get gamma and beta initialized by layer + gamma, beta = layer._get_reshaped_weights(input_shape) + if gamma is None: + gamma = 1.0 + if beta is None: + beta = 0.0 + + # Get output from Numpy + zeroed = reshaped_inputs - mean + rsqrt = 1 / np.sqrt(variance + 1e-5) + output_test = gamma * zeroed * rsqrt + beta + + # compare outputs + output_test = np.reshape(output_test, input_shape.as_list()) + self.assertAlmostEqual(np.mean(output_test - outputs), 0, places=7) + + def _create_and_fit_Sequential_model(self, layer, shape): + # Helper function for quick evaluation +
model = tf.keras.models.Sequential() + model.add(layer) + model.add(tf.keras.layers.Dense(32)) + model.add(tf.keras.layers.Dense(1)) + + model.compile( + optimizer=tf.keras.optimizers.RMSprop(0.01), + loss="categorical_crossentropy") + layer_shape = (10,) + shape + input_batch = np.random.rand(*layer_shape) + output_batch = np.random.rand(*(10, 1)) + model.fit(x=input_batch, y=output_batch, epochs=1, batch_size=1) + return model + + @test_utils.run_in_graph_and_eager_modes + def test_weights(self): + # Check if weights get initialized correctly + layer = GroupNormalization(groups=1, scale=False, center=False) + layer.build((None, 3, 4)) + self.assertEqual(len(layer.trainable_weights), 0) + self.assertEqual(len(layer.weights), 0) + + layer = LayerNormalization() + layer.build((None, 3, 4)) + self.assertEqual(len(layer.trainable_weights), 2) + self.assertEqual(len(layer.weights), 2) + + layer = InstanceNormalization() + layer.build((None, 3, 4)) + self.assertEqual(len(layer.trainable_weights), 2) + self.assertEqual(len(layer.weights), 2) + + def test_apply_normalization(self): + + input_shape = (1, 4) + expected_shape = (1, 2, 2) + reshaped_inputs = tf.constant([[[2.0, 2.0], [3.0, 3.0]]]) + layer = GroupNormalization(groups=2, axis=1, scale=False, center=False) + normalized_input = layer._apply_normalization(reshaped_inputs, + input_shape) + self.assertTrue( + tf.reduce_all( + tf.equal(normalized_input, + tf.constant([[[0.0, 0.0], [0.0, 0.0]]])))) + + def test_axis_error(self): + + with self.assertRaises(ValueError): + GroupNormalization(axis=0) + + @test_utils.run_in_graph_and_eager_modes + def test_groupnorm_flat(self): + # Check basic usage of groupnorm_flat + # Testing for 1 == LayerNorm, 16 == GroupNorm, -1 == InstanceNorm + + groups = [-1, 16, 1] + shape = (64,) + for i in groups: + model = self._create_and_fit_Sequential_model( + GroupNormalization(groups=i), shape) + self.assertTrue(hasattr(model.layers[0], 'gamma')) + self.assertTrue(hasattr(model.layers[0], 
'beta')) + + @test_utils.run_in_graph_and_eager_modes + def test_layernorm_flat(self): + # Check basic usage of layernorm + + model = self._create_and_fit_Sequential_model(LayerNormalization(), + (64,)) + self.assertTrue(hasattr(model.layers[0], 'gamma')) + self.assertTrue(hasattr(model.layers[0], 'beta')) + + @test_utils.run_in_graph_and_eager_modes + def test_instancenorm_flat(self): + # Check basic usage of instancenorm + + model = self._create_and_fit_Sequential_model(InstanceNormalization(), + (64,)) + self.assertTrue(hasattr(model.layers[0], 'gamma')) + self.assertTrue(hasattr(model.layers[0], 'beta')) + + @test_utils.run_in_graph_and_eager_modes + def test_initializer(self): + # Check if the initializer for gamma and beta is working correctly + + layer = GroupNormalization( + groups=32, + beta_initializer='random_normal', + beta_constraint='NonNeg', + gamma_initializer='random_normal', + gamma_constraint='NonNeg') + + model = self._create_and_fit_Sequential_model(layer, (64,)) + + weights = np.array(model.layers[0].get_weights()) + negativ = weights[weights < 0.0] + self.assertTrue(len(negativ) == 0) + + @test_utils.run_in_graph_and_eager_modes + def test_regularizations(self): + + layer = GroupNormalization( + gamma_regularizer='l1', beta_regularizer='l1', groups=4, axis=2) + layer.build((None, 4, 4)) + self.assertEqual(len(layer.losses), 2) + max_norm = tf.keras.constraints.max_norm + layer = GroupNormalization( + gamma_constraint=max_norm, beta_constraint=max_norm) + layer.build((None, 3, 4)) + self.assertEqual(layer.gamma.constraint, max_norm) + self.assertEqual(layer.beta.constraint, max_norm) + + @test_utils.run_in_graph_and_eager_modes + def test_groupnorm_conv(self): + # Check if Axis is working for CONV nets + # Testing for 1 == LayerNorm, 5 == GroupNorm, -1 == InstanceNorm + + groups = [-1, 5, 1] + for i in groups: + model = tf.keras.models.Sequential() + model.add( + GroupNormalization(axis=1, groups=i, input_shape=(20, 20, 3))) + 
model.add(tf.keras.layers.Conv2D(5, (1, 1), padding='same')) + model.add(tf.keras.layers.Flatten()) + model.add(tf.keras.layers.Dense(1, activation='softmax')) + model.compile( + optimizer=tf.keras.optimizers.RMSprop(0.01), loss='mse') + x = np.random.randint(1000, size=(10, 20, 20, 3)) + y = np.random.randint(1000, size=(10, 1)) + a = model.fit(x=x, y=y, epochs=1) + self.assertTrue(hasattr(model.layers[0], 'gamma')) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/layers/python/poincare_test.py b/tensorflow_addons/layers/python/poincare_test.py index 1f7d3ea386..b1e668bd17 100644 --- a/tensorflow_addons/layers/python/poincare_test.py +++ b/tensorflow_addons/layers/python/poincare_test.py @@ -21,10 +21,11 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import testing_utils as keras_test_util from tensorflow_addons.layers.python.poincare import PoincareNormalize +from tensorflow_addons.utils.python import test_utils +@test_utils.run_all_in_graph_and_eager_modes class PoincareNormalizeTest(tf.test.TestCase): def _PoincareNormalize(self, x, dim, epsilon=1e-5): if isinstance(dim, list): @@ -48,7 +49,7 @@ def testPoincareNormalize(self): for dim in range(len(x_shape)): outputs_expected = self._PoincareNormalize(inputs, dim, epsilon) - outputs = keras_test_util.layer_test( + outputs = test_utils.layer_test( PoincareNormalize, kwargs={ 'axis': dim, @@ -70,7 +71,7 @@ def testPoincareNormalizeDimArray(self): outputs_expected = self._PoincareNormalize(inputs, dim, epsilon) - outputs = keras_test_util.layer_test( + outputs = test_utils.layer_test( PoincareNormalize, kwargs={ 'axis': dim, diff --git a/tensorflow_addons/layers/python/sparsemax_test.py b/tensorflow_addons/layers/python/sparsemax_test.py index 3cb375418e..6796982c2f 100644 --- a/tensorflow_addons/layers/python/sparsemax_test.py +++ b/tensorflow_addons/layers/python/sparsemax_test.py @@ -60,7 +60,7 @@ def test_sparsemax_layer_against_numpy(self, 
dtype=None): z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) test_utils.layer_test( - layer_cls=Sparsemax, + Sparsemax, input_data=z, expected_output=_np_sparsemax(z).astype(dtype)) diff --git a/tensorflow_addons/layers/python/wrappers_test.py b/tensorflow_addons/layers/python/wrappers_test.py index 371b9991f8..4e278b4ddc 100644 --- a/tensorflow_addons/layers/python/wrappers_test.py +++ b/tensorflow_addons/layers/python/wrappers_test.py @@ -20,13 +20,12 @@ import numpy as np import tensorflow as tf -from tensorflow.python.framework import test_util as tf_test_util -from tensorflow.python.keras import testing_utils as keras_test_util from tensorflow_addons.layers.python import wrappers +from tensorflow_addons.utils.python import test_utils +@test_utils.run_all_in_graph_and_eager_modes class WeightNormalizationTest(tf.test.TestCase): - @tf_test_util.run_all_in_graph_and_eager_modes def test_weightnorm_dense_train(self): model = tf.keras.models.Sequential() model.add( @@ -43,7 +42,6 @@ def test_weightnorm_dense_train(self): batch_size=10) self.assertTrue(hasattr(model.layers[0].layer, 'g')) - @tf_test_util.run_all_in_graph_and_eager_modes def test_weightnorm_dense_train_notinit(self): model = tf.keras.models.Sequential() model.add( @@ -60,7 +58,6 @@ def test_weightnorm_dense_train_notinit(self): batch_size=10) self.assertTrue(hasattr(model.layers[0].layer, 'g')) - @tf_test_util.run_all_in_graph_and_eager_modes def test_weightnorm_conv2d(self): model = tf.keras.models.Sequential() model.add( @@ -79,7 +76,6 @@ def test_weightnorm_conv2d(self): self.assertTrue(hasattr(model.layers[0].layer, 'g')) - @tf_test_util.run_all_in_graph_and_eager_modes def test_weightnorm_tflayers(self): images = tf.random.uniform((2, 4, 4, 3)) wn_wrapper = wrappers.WeightNormalization( @@ -87,13 +83,11 @@ def test_weightnorm_tflayers(self): wn_wrapper.apply(images) self.assertTrue(hasattr(wn_wrapper.layer, 'g')) - @tf_test_util.run_all_in_graph_and_eager_modes def 
test_weightnorm_nonlayer(self): images = tf.random.uniform((2, 4, 43)) with self.assertRaises(ValueError): wrappers.WeightNormalization(images) - @tf_test_util.run_all_in_graph_and_eager_modes def test_weightnorm_nokernel(self): with self.assertRaises(ValueError): wrappers.WeightNormalization(tf.keras.layers.MaxPooling2D( @@ -101,7 +95,7 @@ def test_weightnorm_nokernel(self): def test_weightnorm_keras(self): input_data = np.random.random((10, 3, 4)).astype(np.float32) - outputs = keras_test_util.layer_test( + outputs = test_utils.layer_test( wrappers.WeightNormalization, kwargs={ 'layer': tf.keras.layers.Dense(2), diff --git a/tensorflow_addons/losses/README.md b/tensorflow_addons/losses/README.md index 1298efc712..390d828680 100644 --- a/tensorflow_addons/losses/README.md +++ b/tensorflow_addons/losses/README.md @@ -21,7 +21,9 @@ must: * Simple unittests that demonstrate the loss is behaving as expected on some set of known inputs and outputs. * When applicable, run all tests with TensorFlow's - `@run_all_in_graph_and_eager_modes` decorator. + `@run_in_graph_and_eager_modes` (for test method) + or `run_all_in_graph_and_eager_modes` (for TestCase subclass) + decorator. * Add a `py_test` to this sub-package's BUILD file. 
#### Documentation Requirements diff --git a/tensorflow_addons/losses/python/lifted_test.py b/tensorflow_addons/losses/python/lifted_test.py index 8bbd9a0f98..51bb95cc6e 100644 --- a/tensorflow_addons/losses/python/lifted_test.py +++ b/tensorflow_addons/losses/python/lifted_test.py @@ -21,8 +21,8 @@ import numpy as np import tensorflow as tf -from tensorflow.python.framework import test_util as tf_test_util from tensorflow_addons.losses.python import lifted +from tensorflow_addons.utils.python import test_utils def pairwise_distance_np(feature, squared=False): @@ -52,8 +52,8 @@ def pairwise_distance_np(feature, squared=False): return pairwise_distances +@test_utils.run_all_in_graph_and_eager_modes class LiftedStructLossTest(tf.test.TestCase): - @tf_test_util.run_all_in_graph_and_eager_modes def testLiftedStruct(self): num_data = 10 feat_dim = 6 diff --git a/tensorflow_addons/losses/python/triplet_test.py b/tensorflow_addons/losses/python/triplet_test.py index 9f9f947ccc..bf100d89d7 100644 --- a/tensorflow_addons/losses/python/triplet_test.py +++ b/tensorflow_addons/losses/python/triplet_test.py @@ -20,8 +20,8 @@ import numpy as np import tensorflow as tf -from tensorflow.python.framework import test_util as tf_test_util from tensorflow_addons.losses.python import triplet +from tensorflow_addons.utils.python import test_utils def pairwise_distance_np(feature, squared=False): @@ -51,7 +51,7 @@ def pairwise_distance_np(feature, squared=False): return pairwise_distances -@tf_test_util.run_all_in_graph_and_eager_modes +@test_utils.run_all_in_graph_and_eager_modes class TripletSemiHardLossTest(tf.test.TestCase): def test_unweighted(self): num_data = 10 diff --git a/tensorflow_addons/optimizers/README.md b/tensorflow_addons/optimizers/README.md index 85dea6e067..ec4ad82574 100644 --- a/tensorflow_addons/optimizers/README.md +++ b/tensorflow_addons/optimizers/README.md @@ -17,7 +17,9 @@ must: #### Testing Requirements * When applicable, run all tests with TensorFlow's - 
`@run_all_in_graph_and_eager_modes` decorator. + `@run_in_graph_and_eager_modes` (for test method) + or `run_all_in_graph_and_eager_modes` (for TestCase subclass) + decorator. * Add a `py_test` to this sub-package's BUILD file. #### Documentation Requirements diff --git a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py index e09b78a079..5528130d6e 100644 --- a/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py +++ b/tensorflow_addons/optimizers/python/lazy_adam_optimizer_test.py @@ -22,9 +22,9 @@ import tensorflow as tf from tensorflow.python.eager import context -from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.ops import variables from tensorflow_addons.optimizers.python import lazy_adam_optimizer +from tensorflow_addons.utils.python import test_utils def adam_update_numpy(param, @@ -57,7 +57,7 @@ def get_beta_accumulators(opt, dtype): class LazyAdamOptimizerTest(tf.test.TestCase): # TODO: remove v1 tests (keep pace with adam_test.py in keras). 
- @tf_test_util.run_deprecated_v1 + @test_utils.run_deprecated_v1 def testSparse(self): for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: with self.cached_session(): @@ -109,7 +109,7 @@ def testSparse(self): self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - @tf_test_util.run_deprecated_v1 + @test_utils.run_deprecated_v1 def testSparseDevicePlacement(self): for index_dtype in [tf.dtypes.int32, tf.dtypes.int64]: with self.cached_session(force_gpu=tf.test.is_gpu_available()): @@ -123,7 +123,7 @@ def testSparseDevicePlacement(self): self.evaluate(variables.global_variables_initializer()) self.evaluate(minimize_op) - @tf_test_util.run_deprecated_v1 + @test_utils.run_deprecated_v1 def testSparseRepeatedIndices(self): for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: with self.cached_session(): @@ -218,7 +218,7 @@ def doTestBasic(self, use_callable_params=False): self.assertEqual("var0_%d/m:0" % (i,), opt.get_slot(var0, "m").name) - @tf_test_util.run_in_graph_and_eager_modes(reset_test=True) + @test_utils.run_in_graph_and_eager_modes(reset_test=True) def testResourceBasic(self): self.doTestBasic() @@ -226,7 +226,7 @@ def testBasicCallableParams(self): with context.eager_mode(): self.doTestBasic(use_callable_params=True) - @tf_test_util.run_deprecated_v1 + @test_utils.run_deprecated_v1 def testTensorLearningRate(self): for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: with self.cached_session(): @@ -270,7 +270,7 @@ def testTensorLearningRate(self): self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) - @tf_test_util.run_deprecated_v1 + @test_utils.run_deprecated_v1 def testSharing(self): for dtype in [tf.dtypes.half, tf.dtypes.float32, tf.dtypes.float64]: with self.cached_session(): diff --git a/tensorflow_addons/utils/python/test_utils.py b/tensorflow_addons/utils/python/test_utils.py index 0248536844..af2c9ccd15 100644 --- a/tensorflow_addons/utils/python/test_utils.py +++ 
b/tensorflow_addons/utils/python/test_utils.py @@ -20,9 +20,15 @@ import inspect import unittest +# yapf: disable +# pylint: disable=unused-import # TODO: find public API alternative to these -from tensorflow.python.keras.testing_utils import layer_test # pylint: disable=unused-import -from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes # pylint: disable=unused-import +from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes +from tensorflow.python.framework.test_util import run_deprecated_v1 +from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes +from tensorflow.python.keras.testing_utils import layer_test +# pylint: enable=unused-import +# yapf: enable def run_all_with_types(dtypes): diff --git a/tools/ci_testing/addons_gpu.sh b/tools/ci_testing/addons_gpu.sh new file mode 100644 index 0000000000..115bd0b023 --- /dev/null +++ b/tools/ci_testing/addons_gpu.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ============================================================================== +# Make sure we're in the project root path. +SCRIPT_DIR=$( cd ${0%/*} && pwd -P ) +ROOT_DIR=$( cd "$SCRIPT_DIR/.." && pwd -P ) +if [[ ! 
-d "tensorflow_addons" ]]; then + echo "ERROR: PWD: $PWD is not project root" + exit 1 +fi + +set -x + +N_JOBS=$(grep -c ^processor /proc/cpuinfo) + +echo "" +echo "Bazel will use ${N_JOBS} concurrent job(s)." +echo "" + +export CC_OPT_FLAGS='-mavx' +export TF_NEED_CUDA=1 + +export PYTHON_BIN_PATH=`which python` +# Use default configuration here. +yes 'y' | ./configure.sh + +## Run bazel test command. Double test timeouts to avoid flakes. +bazel test -c opt -k \ + --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ + --test_output=errors --local_test_jobs=8 \ + //tensorflow_addons/... + +exit $?
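
Note on the normalization layers added in this change: the docstrings and tests treat `LayerNormalization` and `InstanceNormalization` as `GroupNormalization` with `groups=1` and `groups=-1` respectively. The following is a minimal NumPy sketch of that relationship, not the layer's actual implementation. The function name `group_norm_reference` is hypothetical, it omits `gamma` and `beta`, and it assumes channels are split into contiguous groups along the chosen axis.

```python
import numpy as np


def group_norm_reference(x, groups, axis=-1, epsilon=1e-5):
    """Hypothetical NumPy reference for group normalization (no gamma/beta)."""
    channels = x.shape[axis]
    # groups == -1 follows the convention in the diff: one group per channel.
    if groups == -1:
        groups = channels
    if channels % groups != 0:
        raise ValueError("Number of channels must be divisible by groups.")

    # Move the channel axis next to the batch axis, then split it into
    # (groups, channels_per_group).
    moved = np.moveaxis(x, axis, 1)
    grouped = moved.reshape(
        (moved.shape[0], groups, channels // groups) + moved.shape[2:])

    # Normalize over everything except the batch and group axes.
    reduce_axes = tuple(range(2, grouped.ndim))
    mean = grouped.mean(axis=reduce_axes, keepdims=True)
    variance = grouped.var(axis=reduce_axes, keepdims=True)
    normalized = (grouped - mean) / np.sqrt(variance + epsilon)

    # Undo the grouping and restore the original axis order.
    return np.moveaxis(normalized.reshape(moved.shape), 1, axis)


x = np.random.RandomState(0).normal(size=(4, 8, 8, 6))
layer_like = group_norm_reference(x, groups=1)      # one group: layer norm
instance_like = group_norm_reference(x, groups=-1)  # one group per channel
```

With `groups=1` each sample is normalized over all of its features, and with `groups=-1` each channel of each sample is normalized on its own, which matches the two subclasses setting `kwargs["groups"]` to `1` and `-1`.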