Skip to content

Commit

Permalink
Image/fix sparse image warp unknown batch size - r0.12 (#2311)
Browse files Browse the repository at this point in the history
* Fix sparse_image_warp unknown batch size

* More tests

Co-authored-by: Tzu-Wei Sung <[email protected]>
  • Loading branch information
seanpmorgan and WindQAQ authored Dec 22, 2020
1 parent 6f44559 commit d26e2ed
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tensorflow_addons/image/sparse_image_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _add_zero_flow_controls_at_boundary(
merged_control_point_flows: augmented set of control point flows.
"""

batch_size = tf.compat.dimension_value(control_point_locations.shape[0])
batch_size = tf.shape(control_point_locations)[0]

boundary_point_locations = _get_boundary_locations(
image_height, image_width, boundary_points_per_edge
Expand Down
44 changes: 38 additions & 6 deletions tensorflow_addons/image/tests/sparse_image_warp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Tests for sparse_image_warp."""

from collections import namedtuple

import numpy as np
import pytest
import tensorflow as tf
Expand Down Expand Up @@ -247,23 +249,53 @@ def test_that_backprop_runs():
assert np.sum(np.abs(gradients)) != 0


ShapeConfig = namedtuple(
"ShapeConfig",
[
"image",
"source_control_point_locations",
"dest_control_point_locations",
"input",
],
)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("shape", [(9, 12), (9, 12, 3), (1, 9, 12, 3)])
@pytest.mark.parametrize(
"shape",
[
ShapeConfig(None, None, None, (1, 9, 12, 3)),
ShapeConfig(None, [1, 1, 2], [1, 1, 2], (9, 12)),
ShapeConfig(None, [1, 1, 2], [1, 1, 2], (9, 12, 3)),
ShapeConfig(None, [1, 1, 2], [1, 1, 2], (1, 9, 12, 3)),
ShapeConfig([None, 9, 12, 3], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)),
ShapeConfig([None, None, None, 3], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)),
ShapeConfig(
[None, None, None, None], [None, 1, 2], [None, 1, 2], (1, 9, 12, 3)
),
],
)
@pytest.mark.parametrize("interpolation_order", [1, 2, 3])
@pytest.mark.parametrize("num_boundary_points", [1, 2, 3])
def test_unknown_shape(shape, interpolation_order, num_boundary_points):
def test_partially_or_fully_unknown_shape(
shape, interpolation_order, num_boundary_points
):
control_point_locations = np.asarray([3.0, 3.0]).reshape(1, 1, 2).astype(np.float32)
control_point_displacements = (
np.asarray([0.25, -0.5]).reshape(1, 1, 2).astype(np.float32)
)
fn = tf.function(sparse_image_warp).get_concrete_function(
image=tf.TensorSpec(shape=None, dtype=tf.float32),
source_control_point_locations=tf.TensorSpec(shape=[1, 1, 2], dtype=tf.float32),
dest_control_point_locations=tf.TensorSpec(shape=[1, 1, 2], dtype=tf.float32),
image=tf.TensorSpec(shape=shape.image, dtype=tf.float32),
source_control_point_locations=tf.TensorSpec(
shape=shape.source_control_point_locations, dtype=tf.float32
),
dest_control_point_locations=tf.TensorSpec(
shape=shape.dest_control_point_locations, dtype=tf.float32
),
interpolation_order=interpolation_order,
num_boundary_points=num_boundary_points,
)
image = tf.ones(shape=shape, dtype=tf.float32)
image = tf.ones(shape=shape.input, dtype=tf.float32)
expected_output = sparse_image_warp(
image,
control_point_locations,
Expand Down

0 comments on commit d26e2ed

Please sign in to comment.