Merge branch 'patch-op-names' into r0.1
seanpmorgan committed Mar 17, 2019
2 parents a116402 + 12b5053 commit a9076f1
Showing 35 changed files with 1,735 additions and 80 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -68,7 +68,7 @@ Try those commands below:
Addons provides a `make code-format` command to format your changes
automatically; don't forget to use it before pushing your code.

Please see our [Style Guide](SYLE_GUIDE.md) for more details.
Please see our [Style Guide](STYLE_GUIDE.md) for more details.

## Code Testing
#### CI Testing
35 changes: 30 additions & 5 deletions README.md
@@ -1,5 +1,17 @@
# TensorFlow Addons

[![PyPI Status Badge](https://badge.fury.io/py/tensorflow-addons.svg)](https://pypi.org/project/tensorflow-addons/)
[![Gitter chat](https://img.shields.io/badge/chat-on%20gitter-46bc99.svg)](https://gitter.im/tensorflow/sig-addons)

### Official Builds

| Build Type | Status |
| --- | --- |
| **Linux Py2 CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py2.html) |
| **Linux Py3 CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-py3.html) |
| **Linux Py2 GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py2.html) |
| **Linux Py3 GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.html) |

TensorFlow Addons is a repository of contributions that conform to
well-established API patterns, but implement new functionality
not available in core TensorFlow. TensorFlow natively supports
@@ -13,9 +25,14 @@ developments that cannot be integrated into core TensorFlow
| Sub-Package | Addon | Reference |
|:----------------------- |:----------- |:---------------------------- |
| tfa.activations | Sparsemax | https://arxiv.org/abs/1602.02068 |
| tfa.image | adjust_hsv_in_yiq | |
| tfa.image | random_hsv_in_yiq | |
| tfa.image | transform | |
| tfa.layers | GroupNormalization | https://arxiv.org/abs/1803.08494 |
| tfa.layers | InstanceNormalization | https://arxiv.org/abs/1607.08022 |
| tfa.layers | LayerNormalization | https://arxiv.org/abs/1607.06450 |
| tfa.layers | Maxout | https://arxiv.org/abs/1302.4389 |
| tfa.layers | PoinareNormalize | https://arxiv.org/abs/1705.08039 |
| tfa.layers | PoincareNormalize | https://arxiv.org/abs/1705.08039 |
| tfa.layers | WeightNormalization | https://arxiv.org/abs/1602.07868 |
| tfa.losses | LiftedStructLoss | https://arxiv.org/abs/1511.06452 |
| tfa.losses | SparsemaxLoss | https://arxiv.org/abs/1602.02068 |
@@ -33,9 +50,9 @@ the list we adhere to:


1) [Layers](tensorflow_addons/layers/README.md)
1) [Optimizers](tensorflow_addons/optimizers/README.md)
1) [Losses](tensorflow_addons/losses/README.md)
1) [Custom Ops](tensorflow_addons/custom_ops/README.md)
2) [Optimizers](tensorflow_addons/optimizers/README.md)
3) [Losses](tensorflow_addons/losses/README.md)
4) [Custom Ops](tensorflow_addons/custom_ops/README.md)

#### Periodic Evaluation
Based on the nature of this repository, there will be contributions that
@@ -44,7 +61,6 @@ maintainable, SIG-Addons will perform periodic reviews and deprecate
contributions which will be slated for removal. More information will
be available after we submit a formal request for comment.


## Examples
See [`tensorflow_addons/examples/`](tensorflow_addons/examples/)
for end-to-end examples of various addons.
@@ -56,6 +72,15 @@ To install the latest version, run the following:
pip install tensorflow-addons
```

**Note:** You will also need [TensorFlow 2.0 or higher](https://www.tensorflow.org/alpha).

To use addons:

```python
import tensorflow as tf
import tensorflow_addons as tfa
```
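
As a quick illustration (not part of this commit), one of the addons from the table above can then be called directly; the exact `sparsemax` call signature is an assumption here:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Sparsemax (https://arxiv.org/abs/1602.02068) maps logits to a sparse
# probability distribution: each row sums to 1 and some entries are exactly 0.
logits = tf.constant([[2.0, 1.0, 0.1]])
probs = tfa.activations.sparsemax(logits)
print(probs)
```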

#### Installing from Source
You can also install from source. This requires the [Bazel](
https://bazel.build/) build system.
2 changes: 1 addition & 1 deletion tensorflow_addons/activations/BUILD
@@ -17,7 +17,7 @@ py_library(

py_test(
name = "sparsemax_py_test",
size = "small",
size = "medium",
srcs = [
"python/sparsemax_test.py",
],
4 changes: 3 additions & 1 deletion tensorflow_addons/activations/README.md
@@ -19,7 +19,9 @@ must:
#### Testing Requirements
* Simple unittests that demonstrate the layer is behaving as expected.
* When applicable, run all unittests with TensorFlow's
`@run_all_in_graph_and_eager_modes` decorator.
`@run_in_graph_and_eager_modes` (for test method)
or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
decorator (see the sketch after this list).
* Add a `py_test` to this sub-package's BUILD file.
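
A minimal sketch of the decorator requirement above, assuming the decorators still live in `tensorflow.python.framework.test_util` and using `tf.nn.relu` as a stand-in for the addon activation under test (both are assumptions, not part of this commit):

```python
import tensorflow as tf
from tensorflow.python.framework import test_util


class ExampleActivationTest(tf.test.TestCase):

    @test_util.run_in_graph_and_eager_modes
    def test_preserves_shape(self):
        # The same test body runs once in graph mode and once eagerly.
        x = tf.constant([[1.0, -2.0, 3.0]])
        y = tf.nn.relu(x)  # stand-in for the activation being tested
        self.assertAllEqual(self.evaluate(y).shape, (1, 3))


if __name__ == "__main__":
    tf.test.main()
```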

#### Documentation Requirements
4 changes: 3 additions & 1 deletion tensorflow_addons/custom_ops/README.md
@@ -20,7 +20,9 @@ must:
* Simple unittests that demonstrate the custom op is behaving as
expected.
* When applicable, run all unittests with TensorFlow's
`@run_all_in_graph_and_eager_modes` decorator.
`@run_in_graph_and_eager_modes` (for test method)
or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
decorator.
* Add a `py_test` to the custom-op's BUILD file.

#### Documentation Requirements
34 changes: 33 additions & 1 deletion tensorflow_addons/custom_ops/image/BUILD
@@ -2,6 +2,22 @@ licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//visibility:public"])

cc_binary(
name = "python/_distort_image_ops.so",
srcs = [
"cc/kernels/adjust_hsv_in_yiq_op.cc",
"cc/kernels/adjust_hsv_in_yiq_op.h",
"cc/ops/distort_image_ops.cc",
],
linkshared = 1,
deps = [
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
copts = ["-pthread", "-std=c++11", "-D_GLIBCXX_USE_CXX11_ABI=0"]
)


cc_binary(
name = "python/_image_ops.so",
srcs = [
@@ -26,18 +42,34 @@ py_library(
srcs = ([
"__init__.py",
"python/__init__.py",
"python/distort_image_ops.py",
"python/transform.py",
]),
data = [
":python/_distort_image_ops.so",
":python/_image_ops.so",
"//tensorflow_addons/utils:utils_py",
],
srcs_version = "PY2AND3",
)

py_test(
name = "distort_image_ops_test",
size = "small",
srcs = [
"python/distort_image_ops_test.py",
],
main = "python/distort_image_ops_test.py",
deps = [
":images_ops_py",
],
srcs_version = "PY2AND3"
)

# TODO: use cuda_py_test later.
py_test(
name = "transform_ops_test",
size = "small",
size = "medium",
srcs = [
"python/transform_test.py",
],
2 changes: 2 additions & 0 deletions tensorflow_addons/custom_ops/image/__init__.py
@@ -17,5 +17,7 @@
from __future__ import division
from __future__ import print_function

from tensorflow_addons.custom_ops.image.python.distort_image_ops import adjust_hsv_in_yiq
from tensorflow_addons.custom_ops.image.python.distort_image_ops import random_hsv_in_yiq
# Transforms
from tensorflow_addons.custom_ops.image.python.transform import transform
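
For context, here is a hedged sketch of how the newly exported image ops might be called. The keyword names follow the former `tf.contrib.image` versions and the `tfa.image` alias from the README; both are assumptions rather than something this commit documents:

```python
import tensorflow as tf
import tensorflow_addons as tfa

image = tf.random.uniform([64, 64, 3])  # RGB image with values in [0, 1]

# Deterministic adjustment: rotate hue by delta_hue (in YIQ space), then
# scale saturation and value.
adjusted = tfa.image.adjust_hsv_in_yiq(
    image, delta_hue=0.2, scale_saturation=0.8, scale_value=1.1)

# Random variant: the parameters are drawn from the given ranges per call.
randomized = tfa.image.random_hsv_in_yiq(
    image, max_delta_hue=0.5,
    lower_saturation=0.5, upper_saturation=1.5,
    lower_value=0.5, upper_value=1.5)
```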
169 changes: 169 additions & 0 deletions tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.cc
@@ -0,0 +1,169 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA

#include <memory>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/work_sharder.h"
#include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class AdjustHsvInYiqOpBase : public OpKernel {
protected:
explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context)
: OpKernel(context) {}

struct ComputeOptions {
const Tensor* input = nullptr;
Tensor* output = nullptr;
const Tensor* delta_h = nullptr;
const Tensor* scale_s = nullptr;
const Tensor* scale_v = nullptr;
int64 channel_count = 0;
};

virtual void DoCompute(OpKernelContext* context,
const ComputeOptions& options) = 0;

void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const Tensor& delta_h = context->input(1);
const Tensor& scale_s = context->input(2);
const Tensor& scale_v = context->input(3);
OP_REQUIRES(context, input.dims() >= 3,
errors::InvalidArgument("input must be at least 3-D, got shape",
input.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()),
errors::InvalidArgument("delta_h must be scalar: ",
delta_h.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()),
errors::InvalidArgument("scale_s must be scalar: ",
scale_s.shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()),
errors::InvalidArgument("scale_v must be scalar: ",
scale_v.shape().DebugString()));
auto channels = input.dim_size(input.dims() - 1);
OP_REQUIRES(
context, channels == kChannelSize,
errors::InvalidArgument("input must have 3 channels but instead has ",
channels, " channels."));

Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));

if (input.NumElements() > 0) {
const int64 channel_count = input.NumElements() / channels;
ComputeOptions options;
options.input = &input;
options.delta_h = &delta_h;
options.scale_s = &scale_s;
options.scale_v = &scale_v;
options.output = output;
options.channel_count = channel_count;
DoCompute(context, options);
}
}
};

template <class Device>
class AdjustHsvInYiqOp;

template <>
class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
public:
explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
: AdjustHsvInYiqOpBase(context) {}

void DoCompute(OpKernelContext* context,
const ComputeOptions& options) override {
const Tensor* input = options.input;
Tensor* output = options.output;
const int64 channel_count = options.channel_count;
auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
const float delta_h = options.delta_h->scalar<float>()();
const float scale_s = options.scale_s->scalar<float>()();
const float scale_v = options.scale_v->scalar<float>()();
auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
float tranformation_matrix[kChannelSize * kChannelSize] = {0};
internal::compute_tranformation_matrix<kChannelSize * kChannelSize>(
delta_h, scale_s, scale_v, tranformation_matrix);
const int kCostPerChannel = 10;
const DeviceBase::CpuWorkerThreads& worker_threads =
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
kCostPerChannel, [&input_data, &output_data, &tranformation_matrix](
int64 start_channel, int64 end_channel) {
// Applying projection matrix to input RGB vectors.
const float* p = input_data.data() + start_channel * kChannelSize;
float* q = output_data.data() + start_channel * kChannelSize;
for (int i = start_channel; i < end_channel; i++) {
for (int q_index = 0; q_index < kChannelSize; q_index++) {
q[q_index] = 0;
for (int p_index = 0; p_index < kChannelSize; p_index++) {
q[q_index] +=
p[p_index] *
tranformation_matrix[q_index + kChannelSize * p_index];
}
}
p += kChannelSize;
q += kChannelSize;
}
});
}
};

REGISTER_KERNEL_BUILDER(
Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"),
AdjustHsvInYiqOp<CPUDevice>);

#if GOOGLE_CUDA
template <>
class AdjustHsvInYiqOp<GPUDevice> : public AdjustHsvInYiqOpBase {
public:
explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
: AdjustHsvInYiqOpBase(context) {}

void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override {
const int64 number_of_elements = options.input->NumElements();
if (number_of_elements <= 0) {
return;
}
const float* delta_h = options.delta_h->flat<float>().data();
const float* scale_s = options.scale_s->flat<float>().data();
const float* scale_v = options.scale_v->flat<float>().data();
functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input,
delta_h, scale_s, scale_v, options.output);
}
};

REGISTER_KERNEL_BUILDER(
Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"),
AdjustHsvInYiqOp<GPUDevice>);
#endif

} // namespace tensorflow
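
To make the per-pixel work in `DoCompute` more concrete, here is a rough NumPy sketch of the same idea: build one 3x3 matrix from `(delta_h, scale_s, scale_v)` in YIQ space and apply it to every RGB vector. The YIQ constants and the factorization below are approximations for illustration, not the exact values produced by `compute_tranformation_matrix`:

```python
import numpy as np

# Approximate RGB <-> YIQ conversion matrices (illustrative constants).
RGB_TO_YIQ = np.array([[0.299,  0.587,  0.114],
                       [0.596, -0.274, -0.322],
                       [0.211, -0.523,  0.312]])
YIQ_TO_RGB = np.linalg.inv(RGB_TO_YIQ)


def hsv_in_yiq_matrix(delta_h, scale_s, scale_v):
    """Scale the Y (luma) channel by scale_v, and rotate the IQ (chroma)
    plane by delta_h while scaling it by scale_s."""
    c, s = np.cos(delta_h), np.sin(delta_h)
    adjust = np.array([[scale_v, 0.0,           0.0],
                       [0.0,     scale_s * c,  -scale_s * s],
                       [0.0,     scale_s * s,   scale_s * c]])
    return YIQ_TO_RGB @ adjust @ RGB_TO_YIQ


def adjust_hsv_in_yiq_reference(image, delta_h, scale_s, scale_v):
    """image: float array of shape [..., 3]; multiplies each RGB vector by
    the combined 3x3 matrix, like the sharded loop in the CPU kernel."""
    t = hsv_in_yiq_matrix(delta_h, scale_s, scale_v)
    return image @ t.T


image = np.random.rand(8, 8, 3).astype(np.float32)
out = adjust_hsv_in_yiq_reference(image, delta_h=0.2, scale_s=0.8, scale_v=1.1)
print(out.shape)  # (8, 8, 3)
```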