Simplify create_decoder_preprocessor and add argument for `input_feature` that supports targets-only scoring.

PiperOrigin-RevId: 555491380
adarob authored and t5-copybara committed Aug 14, 2023
1 parent 97d3bbf commit 7fe71f4
Showing 1 changed file with 24 additions and 48 deletions.
t5x/export_lib.py: 24 additions & 48 deletions
@@ -647,6 +647,7 @@ def create_decoder_preprocessor(
     output_features: Mapping[str, seqio.Feature],
     task_feature_lengths: Mapping[str, int],
     tokenized_inputs: bool = False,
+    input_feature: str = 'inputs',
 ) -> PreprocessorFn:
   """Returns a function to tokenize and featurize inputs for decoder only models.
 
@@ -657,59 +658,41 @@ def create_decoder_preprocessor(
     tokenized_inputs: specifies whether the input is expected to be
       pre-tokenized. If so, the preprocessor expects an int32 tensor padded with
       0s to shape of [B, N] rather than a string tensor of shape [B].
+    input_feature: Name of the feature provided by `input_texts`, e.g., 'inputs'
+      or 'targets'.
   """
 
   def preprocess(input_texts: tf.Tensor) -> Mapping[str, tf.Tensor]:
     """TF-based preprocessor that takes a batch of text and converts it to model features."""
 
-    if tokenized_inputs:
-      inputs = input_texts  # actually an int32 tensor of shape [B, N].
-      targets = tf.broadcast_to(
-          tf.constant(0, dtype=tf.int32), tf.shape(input_texts))
-    else:
-      inputs = input_texts
-      targets = tf.broadcast_to(tf.constant(''), tf.shape(input_texts))
-
-    def tokenize(text, k):
-      vocab = output_features[k].vocabulary  # type: seqio.Vocabulary
+    def tokenize(text):
+      feature = output_features[input_feature]
+      vocab = feature.vocabulary  # type: seqio.Vocabulary
       if not tokenized_inputs:  # if inputs are tokenized, we don't re-tokenize.
         t = vocab.encode_tf(text)
       else:
         t = text
-      if output_features[k].add_eos:
+      if feature.add_eos:
         t = tf.concat([t, [vocab.eos_id]], axis=-1)
       return t
 
-    decoder_input_tokens = tf.map_fn(
-        functools.partial(tokenize, k='inputs'),
-        inputs,
+    decoder_tokens = tf.map_fn(
+        tokenize,
+        input_texts,
         fn_output_signature=(tf.int32),
     )
 
-    decoder_target_tokens = tf.map_fn(
-        functools.partial(tokenize, k='targets'),
-        targets,
-        fn_output_signature=(tf.int32),
-    )
-
-    decoder_target_tokens = tf.concat(
-        [decoder_input_tokens, decoder_target_tokens], axis=-1
-    )
-
-    # Create 'inputs_width' tensor in the same shape as decoder_target_tokens.
-    # It is the length of 'inputs' (excluding padding 0 values) tiled across
-    # length dimension and 'inputs_width_add_pos' is the same except that it
-    # has one additional position tensor.
-    ragged_input_tokens = tf.RaggedTensor.from_tensor(
-        decoder_input_tokens, padding=0
-    )
-    inputs_length = tf.cast(ragged_input_tokens.row_lengths(), dtype=tf.int32)
-    inputs_length = tf.expand_dims(inputs_length, -1)
-    if output_features['inputs'].add_eos:
-      inputs_length -= 1
-    ones_like_target = tf.ones(tf.shape(decoder_target_tokens), dtype=tf.int32)
-    inputs_width = tf.multiply(ones_like_target, inputs_length)
-    inputs_width_add_pos = tf.multiply(ones_like_target, inputs_length + 1)
+    if input_feature == 'inputs':
+      # 'inputs_width' is the length of 'inputs' (excluding padding 0).
+      ragged_input_tokens = tf.RaggedTensor.from_tensor(
+          decoder_tokens, padding=0
+      )
+      inputs_length = tf.cast(ragged_input_tokens.row_lengths(), dtype=tf.int32)
+      inputs_width = tf.expand_dims(inputs_length, -1)
+      inputs_width_add_pos = inputs_width + 1
+    else:
+      inputs_width = tf.zeros(tf.shape(decoder_tokens)[0], dtype=tf.int32)
+      inputs_width_add_pos = inputs_width
 
     def featurize(text, length):
       text = text[:length]
@@ -721,18 +704,10 @@ def featurize(text, length):
       return text, ar_inputs, loss_weights
 
     targets_length = sum(task_feature_lengths.values())
-    inputs_width, _, _ = tf.map_fn(
-        functools.partial(featurize, length=targets_length),
-        inputs_width,
-        fn_output_signature=(tf.int32, tf.int32, tf.int32))
-    inputs_width_add_pos, _, _ = tf.map_fn(
-        functools.partial(featurize, length=targets_length),
-        inputs_width_add_pos,
-        fn_output_signature=(tf.int32, tf.int32, tf.int32))
     decoder_target_tokens, decoder_input_tokens, decoder_loss_weights = (
         tf.map_fn(
            functools.partial(featurize, length=targets_length),
-            decoder_target_tokens,
+            decoder_tokens,
            fn_output_signature=(tf.int32, tf.int32, tf.int32),
        )
    )
@@ -742,7 +717,8 @@ def featurize(text, length):
         axis=0)
 
     decoder_causal_attention = tf.cast(
-        positions < inputs_width_add_pos, dtype=decoder_target_tokens.dtype)
+        positions < inputs_width_add_pos, dtype=decoder_target_tokens.dtype
+    )
 
     inputs = positions < inputs_width
     padding_mask = tf.cast(decoder_loss_weights, dtype=tf.bool)
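The new `input_feature` argument is what enables the targets-only scoring mentioned in the commit message. Below is a minimal sketch, not part of this commit, of how the preprocessor might be built for that case; the vocabulary path, feature lengths, and example text are placeholder values.

```python
import seqio
import tensorflow as tf
from t5x import export_lib

# Placeholder vocabulary; any seqio.Vocabulary would do here.
vocab = seqio.SentencePieceVocabulary('/path/to/spm.model')
output_features = {
    'inputs': seqio.Feature(vocabulary=vocab, add_eos=True),
    'targets': seqio.Feature(vocabulary=vocab, add_eos=True),
}

# With input_feature='targets' the preprocessor takes the `else` branch above,
# so inputs_width is all zeros and no position is treated as a bidirectionally
# attended prefix: the text is scored as a bare target sequence.
preprocessor = export_lib.create_decoder_preprocessor(
    output_features=output_features,
    task_feature_lengths={'targets': 128},  # illustrative length
    tokenized_inputs=False,
    input_feature='targets',
)

features = preprocessor(tf.constant(['text to score']))
# `features` is a mapping of decoder features (e.g. decoder_target_tokens,
# decoder_loss_weights, decoder_causal_attention).
```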

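For the default input_feature='inputs' path, the width tensors computed above drive the prefix-LM masks at the end of the diff. The following standalone sketch reproduces that arithmetic with made-up lengths; `positions` is built in an elided part of the file and is assumed here to be a simple per-batch range.

```python
import tensorflow as tf

targets_length = 8                       # sum(task_feature_lengths.values())
inputs_width = tf.constant([[4]])        # per-example input length, shape [B, 1]
inputs_width_add_pos = inputs_width + 1  # prefix plus one extra position

# Assumed layout: positions 0..targets_length-1, broadcast over the batch.
positions = tf.range(targets_length)[tf.newaxis, :]

decoder_causal_attention = tf.cast(positions < inputs_width_add_pos, tf.int32)
inputs_mask = positions < inputs_width

print(decoder_causal_attention.numpy())  # [[1 1 1 1 1 0 0 0]]
print(inputs_mask.numpy())               # [[ True  True  True  True False False False False]]
```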