Fix broken type annotations, and replace pytype disables with typing.cast

PiperOrigin-RevId: 560130807
T5X Team authored and t5-copybara committed Aug 25, 2023
1 parent 828e910 commit 05ea768
Showing 1 changed file with 50 additions and 55 deletions.
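Before the file diff, a minimal sketch (not part of the commit) of the pattern being adopted: rather than silencing the checker with an inline # pytype: disable comment, the call site uses typing.cast to pick the relevant member of a Union return type. The function name and types below are invented for illustration only.

import typing
from typing import Mapping, Tuple, Union

def score(return_extra: bool = False) -> Union[float, Tuple[float, Mapping[str, int]]]:
  """Hypothetical helper with a Union return type, akin to score_batch below."""
  return (1.0, {'tokens': 3}) if return_extra else 1.0

# At a call site where only the narrow branch can occur, typing.cast tells the
# static checker which member of the Union applies; it is a no-op at runtime.
value = typing.cast(float, score())
print(value + 1.0)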
t5x/models.py: 105 changes (50 additions & 55 deletions)
@@ -23,6 +23,7 @@
import dataclasses
import functools
import inspect
import typing
from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union

from absl import logging
@@ -210,17 +211,15 @@ def predict_batch_with_aux(
predictions: the model predictions
aux: auxiliary data
"""
pass

@abc.abstractmethod
def score_batch(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
return_intermediates: bool = False,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
"""Computes scores for batch."""
pass

@abc.abstractmethod
def get_initial_variables(
@@ -230,7 +229,6 @@ def get_initial_variables(
input_types: Optional[Mapping[str, jnp.dtype]] = None,
) -> flax_scope.FrozenVariableDict:
"""Returns the initial variables of the model."""
pass


class BaseTransformerModel(BaseModel):
@@ -281,9 +279,8 @@ def _compute_logits(
params: PyTree,
batch: Mapping[str, jnp.ndarray],
dropout_rng: Optional[jax.random.KeyArray] = None,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]:
"""Computes logits via a forward pass of the model."""
pass

def loss_fn(
self,
@@ -384,9 +381,8 @@ def __init__(
default_decoder_params: Optional[DecoderParams] = None,
):
if feature_converter_cls is not None:
self.FEATURE_CONVERTER_CLS = (
feature_converter_cls # pylint: disable=invalid-name
)
# pylint: disable-next=invalid-name
self.FEATURE_CONVERTER_CLS = feature_converter_cls
self._default_decoder_params = default_decoder_params or DecoderParams()
super().__init__(
module=module,
@@ -455,7 +451,7 @@ def get_initial_variables(
)
return initial_variables

def _compute_logits( # pytype: disable=signature-mismatch # jax-ndarray
def _compute_logits(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
@@ -665,7 +661,7 @@ def predict_batch_with_aux(
else:
decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens)

encoded_inputs = self.module.apply(
encoded_inputs: jnp.ndarray = self.module.apply(
{'params': params},
encoder_input_tokens,
enable_dropout=False,
@@ -759,7 +755,7 @@ def predict_batch_with_aux(
else:
return decodes[:, -1, :], {'scores': scores[:, -1]}

def score_batch( # pytype: disable=signature-mismatch # jax-ndarray
def score_batch(
self,
params: PyTree,
batch: Mapping[str, jnp.ndarray],
@@ -773,23 +769,9 @@ def score_batch( # pytype: disable=signature-mismatch # jax-ndarray
logits, modified_variables = self._compute_logits(
params=params, batch=batch, mutable=['intermediates']
)

# Inside self.module, we called nn.Module.sow to track various
# intermediate values. We extract them here.
intermediates = flax_core.unfreeze(
modified_variables.get('intermediates', {})
)

# Track per-token labels and loss weights as well. These are not
# intermediate values of logit computation, so we manually add them here.
intermediates.setdefault('decoder', {})
intermediates['decoder']['target_tokens'] = (target_tokens,)
intermediates['decoder']['loss_weights'] = (weights,)
# Note that the values are singleton tuples. This is because values inside
# `intermediates` should be tuples tracking all instantiations of a value.
# These values each have just one instantiation, hence singletons.
else:
logits = self._compute_logits(params, batch) # type: jnp.ndarray # pytype: disable=annotation-type-mismatch # jax-ndarray
logits = typing.cast(jnp.ndarray, self._compute_logits(params, batch))
modified_variables = {}

# Purposefully don't use config.z_loss because that term is for training
# stability and shouldn't affect our reported scores.
@@ -803,12 +785,26 @@
)[0]
* weights
)
if return_intermediates:
intermediates['decoder']['token_scores'] = (token_scores,)

sequence_scores = token_scores.sum(-1)

if return_intermediates:

# Inside self.module, we called nn.Module.sow to track various
# intermediate values. We extract them here.
intermediates = flax_core.unfreeze(
modified_variables.get('intermediates', {})
)

# Track per-token labels and loss weights as well. These are not
# intermediate values of logit computation, so we manually add them here.
intermediates.setdefault('decoder', {})
intermediates['decoder']['target_tokens'] = (target_tokens,)
intermediates['decoder']['loss_weights'] = (weights,)
# Note that the values are singleton tuples. This is because values inside
# `intermediates` should be tuples tracking all instantiations of a value.
# These values each have just one instantiation, hence singletons.
intermediates['decoder']['token_scores'] = (token_scores,)
return sequence_scores, intermediates

return sequence_scores
@@ -847,9 +843,8 @@ def __init__(
] = None,
):
if feature_converter_cls is not None:
self.FEATURE_CONVERTER_CLS = (
feature_converter_cls # pylint: disable=invalid-name
)
# pylint: disable-next=invalid-name
self.FEATURE_CONVERTER_CLS = feature_converter_cls
self._inputs_bidirectional_attention = inputs_bidirectional_attention
super().__init__(
module,
@@ -901,7 +896,7 @@ def _compute_logits(
dropout_rng: Optional[jax.random.KeyArray] = None,
mutable: flax_scope.CollectionFilter = False,
other_variables: Optional[PyTree] = None,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]:
"""Computes logits via a forward pass of `self.module`."""
rngs = {'dropout': dropout_rng} if dropout_rng is not None else None
decoder_causal_attention = self._get_decoder_causal_attention(batch)
@@ -954,7 +949,7 @@ def score_batch(
params: PyTree,
batch: Mapping[str, jnp.ndarray],
return_intermediates: bool = False,
) -> jnp.ndarray:
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
"""Compute log likelihood score on a batch."""

decoder_target_tokens = batch['decoder_target_tokens']
@@ -967,7 +962,26 @@
dropout_rng=None,
mutable=['intermediates'],
)
else:
logits = typing.cast(
jnp.ndarray,
self._compute_logits(params=params, batch=batch, dropout_rng=None),
)
modified_variables = {}

token_scores = (
-losses.cross_entropy_with_logits(
logits,
common_utils.onehot(
decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0
),
z_loss=0.0,
)[0]
* weights
)
sequence_scores = token_scores.sum(-1)

if return_intermediates:
# Inside self.module, we called nn.Module.sow to track various
# intermediate values. We extract them here.
intermediates = flax_core.unfreeze(
@@ -982,28 +996,9 @@
# Note that the values are singleton tuples. This is because values inside
# `intermediates` should be tuples tracking all instantiations of a value.
# These values each have just one instantiation, hence singletons.
else:
logits = self._compute_logits(
params=params, batch=batch, dropout_rng=None
)

token_scores = (
-losses.cross_entropy_with_logits(
logits,
common_utils.onehot(
decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0
),
z_loss=0.0,
)[0]
* weights
)
if return_intermediates:
intermediates['decoder']['token_scores'] = (token_scores,)

sequence_scores = token_scores.sum(-1)

if return_intermediates:
return sequence_scores, intermediates # pytype: disable=bad-return-type # jax-ndarray
return sequence_scores, intermediates

return sequence_scores
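
For context on the sow/intermediates convention referenced in the comments above, here is a small, self-contained sketch (not part of this commit; the module and variable names are invented) of how flax.linen's Module.sow stores each sown value as a singleton tuple under the 'intermediates' collection, which apply(..., mutable=['intermediates']) then returns alongside the output.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyModel(nn.Module):
  """Hypothetical two-layer model used only for illustration."""

  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    # Sown values accumulate as a tuple of all instantiations; with a single
    # call there is exactly one entry, hence the singleton tuples noted above.
    self.sow('intermediates', 'hidden', h)
    return nn.Dense(2)(h)

model = TinyModel()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
# Requesting mutable=['intermediates'] makes apply() return the sown values
# alongside the model output, mirroring the calls in score_batch.
y, modified_variables = model.apply(variables, x, mutable=['intermediates'])
print(modified_variables['intermediates']['hidden'])  # a one-element tuple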
