Optimize analyzers and mappers for faster inference in TF2
PiperOrigin-RevId: 604944502
tf-transform-team authored and tfx-copybara committed Feb 7, 2024
Parent: c4e6066 · Commit: fb7688c
Showing 5 changed files with 7 additions and 22 deletions.

RELEASE.md — 2 changes: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 
 ## Breaking Changes
 
-* Existing `tft.vocabulary` cache is automatically invalidated.
+* Existing analyzer cache is automatically invalidated.
 
 ## Deprecations
 
tensorflow_transform/analyzer_nodes.py — 6 changes: 1 addition & 5 deletions

@@ -281,11 +281,7 @@ def _bind_future_as_tensor_v2(
           replaced_result)
       return replaced_result
     else:
-      # Without the identity wrapper some V2 tests fail with AttributeError:
-      # Tensor.name is meaningless when eager execution is enabled.
-      # TODO(b/149997088): Remove the identity wrapper once we no longer rely on
-      # tensor names.
-      return tf.identity(replaced_result)
+      return replaced_result
   else:
     graph.add_to_collection(TENSOR_REPLACEMENTS, tensor_sink)
     eager_asset_path = temporary_analyzer_info.eager_asset_path
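
For context on the deleted comment: eager tensors in TF2 carry no graph name, so reading `Tensor.name` raises exactly the `AttributeError` quoted above, which is why graph-era code that keyed off tensor names needed the `tf.identity` wrapper. A minimal standalone sketch of that behavior (plain TF, not tft code):

import tensorflow as tf

# Reading `.name` on an eager tensor fails with:
# "Tensor.name is meaningless when eager execution is enabled."
eager_value = tf.constant(1.0)
try:
  print(eager_value.name)
except AttributeError as err:
  print(err)

# Inside a graph, the attribute works, and tf.identity gives the
# wrapped value its own named op (the overhead this change removes).
graph = tf.Graph()
with graph.as_default():
  graph_value = tf.identity(tf.constant(1.0))
  print(graph_value.name)  # e.g. "Identity:0"

Dropping the wrapper saves one op per bound tensor in the traced TF2 graph.
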
tensorflow_transform/analyzers.py — 5 changes: 0 additions & 5 deletions

@@ -1973,15 +1973,10 @@ def _get_vocabulary_analyzer_inputs(
   elif vocab_ordering_type == _VocabOrderingType.WEIGHTED_FREQUENCY:
     reduced_batch = tf_utils.reduce_batch_weighted_counts(
         x, weights, filter_regex=filter_regex)
-    assert reduced_batch.summed_positive_per_x_and_y is None
-    assert reduced_batch.counts_per_x is None
     return [reduced_batch.unique_x, reduced_batch.summed_weights_per_x]
   else:
     reduced_batch = tf_utils.reduce_batch_weighted_counts(
         x, filter_regex=filter_regex)
-    assert reduced_batch.summed_weights_per_x is None
-    assert reduced_batch.summed_positive_per_x_and_y is None
-    assert reduced_batch.counts_per_x is None
     return [reduced_batch.unique_x]
 
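
The removed lines were Python-level sanity asserts on fields of the reduction result that are unused for each ordering type. For reference, a hypothetical stand-in mirroring the field names referenced above (the concrete `ReducedBatchWeightedCounts` lives in `tf_utils.py` and may differ):

import collections

# Hypothetical stand-in; field names taken from the diff above.
ReducedBatchWeightedCounts = collections.namedtuple(
    'ReducedBatchWeightedCounts',
    ['unique_x',                     # unique values seen in the batch
     'summed_weights_per_x',         # total weight per unique value
     'summed_positive_per_x_and_y',  # only set for co-occurrence orderings
     'counts_per_x'])                # raw occurrence counts

# For WEIGHTED_FREQUENCY only the first two fields are populated:
batch = ReducedBatchWeightedCounts(
    unique_x=['a', 'b'], summed_weights_per_x=[2.0, 1.5],
    summed_positive_per_x_and_y=None, counts_per_x=None)
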
tensorflow_transform/mappers.py — 5 changes: 3 additions & 2 deletions

@@ -1449,11 +1449,12 @@ def _deduplicate_row(dedup_row_loop_vars):
 
     # Keep track of the maximum number of unique elements in a row, as this
     # will determine the resulting dense shape.
+    num_unique_values = tf.shape(row_values)[0]
     max_unique = tf.cast(
-        tf.maximum(tf.cast(tf.shape(row_values)[0], tf.int64), max_unique),
+        tf.maximum(tf.cast(num_unique_values, tf.int64), max_unique),
         tf.int64)
     column_indices = tf.cast(
-        tf.expand_dims(tf.range(tf.shape(row_values)[0]), axis=1), tf.int64)
+        tf.expand_dims(tf.range(num_unique_values), axis=1), tf.int64)
     row_indices = tf.fill(tf.shape(column_indices), tf.cast(index, tf.int64))
     values = values.write(index, row_values)
     indices = indices.write(index, tf.concat([row_indices, column_indices], 1))
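
The change computes `tf.shape(row_values)[0]` once and reuses it, so the traced loop body contains a single shape op instead of two. A standalone sketch of the index-building pattern with toy values (not the tft while-loop itself):

import tensorflow as tf

row_values = tf.constant([4, 7, 9])     # deduplicated values of one row
index = tf.constant(2, dtype=tf.int64)  # row being emitted

# Compute the row length once and reuse it (the optimization above).
num_unique_values = tf.shape(row_values)[0]

column_indices = tf.cast(
    tf.expand_dims(tf.range(num_unique_values), axis=1), tf.int64)
row_indices = tf.fill(tf.shape(column_indices), index)
sparse_indices = tf.concat([row_indices, column_indices], 1)
print(sparse_indices.numpy())  # [[2 0] [2 1] [2 2]]
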
tensorflow_transform/tf_utils.py — 11 changes: 2 additions & 9 deletions

@@ -187,8 +187,6 @@ def reduce_batch_weighted_counts(
   else:
     # TODO(b/112916494): Always do batch wise reduction once possible.
     return ReducedBatchWeightedCounts(flat_x, None, None, None)
-  # TODO(b/134075780): Revisit expected weights shape when input is composite.
-  x, weights = assert_same_shape(x, weights)
   weights = filter_fn(tf.reshape(weights, [-1]))
   unique_x_values, unique_idx, _ = tf.unique_with_counts(
       flat_x, out_idx=tf.int64)
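
`assert_same_shape` adds shape-assertion ops to the graph on every call, so removing it here trims a per-batch check from the inference path, in line with the commit title. The surrounding reduction is a unique-then-segment-sum; a standalone sketch of that pattern (the segment-sum line is illustrative, not the exact tft code):

import tensorflow as tf

flat_x = tf.constant(['a', 'b', 'a', 'c', 'b', 'a'])
weights = tf.constant([1.0, 2.0, 0.5, 1.0, 1.0, 1.0])

unique_x_values, unique_idx, _ = tf.unique_with_counts(
    flat_x, out_idx=tf.int64)
# Sum each occurrence's weight into its unique value's bucket.
summed_weights_per_x = tf.math.unsorted_segment_sum(
    weights, unique_idx, tf.size(unique_x_values))
print(unique_x_values.numpy())       # [b'a' b'b' b'c']
print(summed_weights_per_x.numpy())  # [2.5 3.  1. ]
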
@@ -410,7 +408,6 @@ def _preprocess_tensors_for_cooccurences(
   x, weights_input = assert_same_shape(x, weights_input)
   weights = weights_input
   y = _broadcast_to_x_shape(x, y)
-  x, y = assert_same_shape(x, y)
   x = tf.reshape(x, [-1])
   filter_fn = _make_regex_filter_fn(x, filter_regex)
   x = filter_fn(x)
@@ -593,8 +590,7 @@ def _broadcast_to_x_shape(x, y):
   y_shape = tf.shape(input=y)
   assert_eq = tf.compat.v1.assert_equal(x_shape[0], y_shape[0])
   with tf.control_dependencies([assert_eq]):
-    y = tf.identity(y)
-  rank_delta = tf.rank(x) - tf.rank(y)
+    rank_delta = tf.rank(x) - tf.rank(y)
   target_shape = tf.concat(
       [tf.shape(y), tf.ones(rank_delta, dtype=tf.int32)], axis=0)
   matched_rank = tf.reshape(y, target_shape)
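
With the `tf.identity` hop gone, the control dependency on the shape assertion now rides on the `rank_delta` computation instead of a copy of `y`. The broadcast itself pads `y` with trailing singleton dimensions until its rank matches `x`; a toy sketch of that reshape trick (standalone, simplified):

import tensorflow as tf

x = tf.zeros([2, 3, 4])
y = tf.constant([10.0, 20.0])  # rank 1, matches x's leading dim

rank_delta = tf.rank(x) - tf.rank(y)  # 2
target_shape = tf.concat(
    [tf.shape(y), tf.ones(rank_delta, dtype=tf.int32)], axis=0)  # [2, 1, 1]
matched_rank = tf.reshape(y, target_shape)
broadcast = tf.broadcast_to(matched_rank, tf.shape(x))
print(broadcast.shape)  # (2, 3, 4)
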
@@ -1756,7 +1752,7 @@ def reduce_batch_minus_min_and_max(
 
     x_batch_max = tf.reduce_max(input_tensor=x)
     x_batch_minus_min = tf.reduce_max(input_tensor=tf.zeros_like(x) - x)
-    return assert_same_shape(x_batch_minus_min, x_batch_max)
+    return x_batch_minus_min, x_batch_max
 
   elif isinstance(x, tf.SparseTensor):
     return _sparse_minus_reduce_min_and_reduce_max(x)
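
The helper computes `-min(x)` as `max(0 - x)` so both halves of the pair are the same reduction op, and the removed `assert_same_shape` no longer wraps two scalars in assertion ops on every batch. A quick numeric check of the identity, as a sketch:

import tensorflow as tf

x = tf.constant([3.0, -5.0, 2.0])
x_batch_max = tf.reduce_max(input_tensor=x)                           # 3.0
x_batch_minus_min = tf.reduce_max(input_tensor=tf.zeros_like(x) - x)  # 5.0
assert x_batch_minus_min.numpy() == -tf.reduce_min(x).numpy()
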
@@ -1820,9 +1816,6 @@ def get_batch_max_per_key(tensor, key_uniques):  # pylint: disable=missing-docst
   x_batch_maxes = get_batch_max_per_key(x, unique)
   x_batch_minus_mins = get_batch_max_per_key(-x, unique)
 
-  x_batch_minus_mins, x_batch_maxes = assert_same_shape(x_batch_minus_mins,
-                                                        x_batch_maxes)
-
   return (unique.y, x_batch_minus_mins, x_batch_maxes)
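
The same negate-and-max trick extends per key: `get_batch_max_per_key(-x, unique)` yields each key's `-min`. A standalone sketch using sorted segment ops (illustrative only; the tft helper's internals may differ):

import tensorflow as tf

x = tf.constant([3.0, -5.0, 2.0, 7.0])
keys = tf.constant([0, 0, 1, 1])  # sorted segment id per element

per_key_max = tf.math.segment_max(x, keys)         # [3. 7.]
per_key_minus_min = tf.math.segment_max(-x, keys)  # [5. -2.]
print(per_key_max.numpy(), per_key_minus_min.numpy())
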
