diff --git a/t5x/models.py b/t5x/models.py
index 09e86a858..ac52e0944 100644
--- a/t5x/models.py
+++ b/t5x/models.py
@@ -23,6 +23,7 @@
 import dataclasses
 import functools
 import inspect
+import typing
 from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union

 from absl import logging
@@ -210,7 +211,6 @@ def predict_batch_with_aux(
       predictions: the model predictions
       aux: auxiliary data
     """
-    pass

   @abc.abstractmethod
   def score_batch(
@@ -218,9 +218,8 @@
       params: PyTree,
       batch: Mapping[str, jnp.ndarray],
       return_intermediates: bool = False,
-  ) -> jnp.ndarray:
+  ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
     """Computes scores for batch."""
-    pass

   @abc.abstractmethod
   def get_initial_variables(
@@ -230,7 +229,6 @@
       input_types: Optional[Mapping[str, jnp.dtype]] = None,
   ) -> flax_scope.FrozenVariableDict:
     """Returns the initial variables of the model."""
-    pass


 class BaseTransformerModel(BaseModel):
@@ -281,9 +279,8 @@ def _compute_logits(
       self,
       params: PyTree,
       batch: Mapping[str, jnp.ndarray],
       dropout_rng: Optional[jax.random.KeyArray] = None,
-  ) -> jnp.ndarray:
+  ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]:
     """Computes logits via a forward pass of the model."""
-    pass

   def loss_fn(
@@ -384,9 +381,8 @@ def __init__(
       default_decoder_params: Optional[DecoderParams] = None,
   ):
     if feature_converter_cls is not None:
-      self.FEATURE_CONVERTER_CLS = (
-          feature_converter_cls  # pylint: disable=invalid-name
-      )
+      # pylint: disable-next=invalid-name
+      self.FEATURE_CONVERTER_CLS = feature_converter_cls
     self._default_decoder_params = default_decoder_params or DecoderParams()
     super().__init__(
         module=module,
@@ -455,7 +451,7 @@ def get_initial_variables(
     )
     return initial_variables

-  def _compute_logits(  # pytype: disable=signature-mismatch  # jax-ndarray
+  def _compute_logits(
       self,
       params: PyTree,
       batch: Mapping[str, jnp.ndarray],
@@ -665,7 +661,7 @@ def predict_batch_with_aux(
     else:
       decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens)

-    encoded_inputs = self.module.apply(
+    encoded_inputs: jnp.ndarray = self.module.apply(
         {'params': params},
         encoder_input_tokens,
         enable_dropout=False,
@@ -759,7 +755,7 @@
     else:
       return decodes[:, -1, :], {'scores': scores[:, -1]}

-  def score_batch(  # pytype: disable=signature-mismatch  # jax-ndarray
+  def score_batch(
       self,
       params: PyTree,
       batch: Mapping[str, jnp.ndarray],
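Note on the pattern these hunks set up: the abstract signatures above now advertise `Union` return types instead of carrying `pytype: disable=signature-mismatch` suppressions, and the hunks below narrow that union back at each call site with `typing.cast`. A minimal sketch of the narrowing idiom, using a toy stand-in function rather than actual t5x code:

import typing
from typing import Any, Mapping, Tuple, Union

import jax.numpy as jnp


def compute(
    mutable: bool,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
  # Stand-in for _compute_logits: the variables mapping is only returned
  # when some collection is marked mutable.
  logits = jnp.zeros((2, 4))
  return (logits, {'intermediates': {}}) if mutable else logits


# typing.cast does nothing at runtime; it only narrows the static type so
# checkers accept array-only operations on the result.
logits = typing.cast(jnp.ndarray, compute(mutable=False))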
@@ -773,23 +769,9 @@ def score_batch(  # pytype: disable=signature-mismatch  # jax-ndarray
       logits, modified_variables = self._compute_logits(
           params=params, batch=batch, mutable=['intermediates']
       )
-
-      # Inside self.module, we called nn.Module.sow to track various
-      # intermediate values. We extract them here.
-      intermediates = flax_core.unfreeze(
-          modified_variables.get('intermediates', {})
-      )
-
-      # Track per-token labels and loss weights as well. These are not
-      # intermediate values of logit computation, so we manually add them here.
-      intermediates.setdefault('decoder', {})
-      intermediates['decoder']['target_tokens'] = (target_tokens,)
-      intermediates['decoder']['loss_weights'] = (weights,)
-      # Note that the values are singleton tuples. This is because values inside
-      # `intermediates` should be tuples tracking all instantiations of a value.
-      # These values each have just one instantiation, hence singletons.
     else:
-      logits = self._compute_logits(params, batch)  # type: jnp.ndarray  # pytype: disable=annotation-type-mismatch  # jax-ndarray
+      logits = typing.cast(jnp.ndarray, self._compute_logits(params, batch))
+      modified_variables = {}

     # Purposefully don't use config.z_loss because that term is for training
     # stability and shouldn't affect our reported scores.
@@ -803,12 +785,26 @@ def score_batch(  # pytype: disable=signature-mismatch  # jax-ndarray
         )[0]
         * weights
     )
-    if return_intermediates:
-      intermediates['decoder']['token_scores'] = (token_scores,)

     sequence_scores = token_scores.sum(-1)

     if return_intermediates:
+
+      # Inside self.module, we called nn.Module.sow to track various
+      # intermediate values. We extract them here.
+      intermediates = flax_core.unfreeze(
+          modified_variables.get('intermediates', {})
+      )
+
+      # Track per-token labels and loss weights as well. These are not
+      # intermediate values of logit computation, so we manually add them here.
+      intermediates.setdefault('decoder', {})
+      intermediates['decoder']['target_tokens'] = (target_tokens,)
+      intermediates['decoder']['loss_weights'] = (weights,)
+      # Note that the values are singleton tuples. This is because values inside
+      # `intermediates` should be tuples tracking all instantiations of a value.
+      # These values each have just one instantiation, hence singletons.
+      intermediates['decoder']['token_scores'] = (token_scores,)
       return sequence_scores, intermediates
     return sequence_scores

@@ -847,9 +843,8 @@ def __init__(
       ] = None,
   ):
     if feature_converter_cls is not None:
-      self.FEATURE_CONVERTER_CLS = (
-          feature_converter_cls  # pylint: disable=invalid-name
-      )
+      # pylint: disable-next=invalid-name
+      self.FEATURE_CONVERTER_CLS = feature_converter_cls
     self._inputs_bidirectional_attention = inputs_bidirectional_attention
     super().__init__(
         module,
@@ -901,7 +896,7 @@ def _compute_logits(
       dropout_rng: Optional[jax.random.KeyArray] = None,
       mutable: flax_scope.CollectionFilter = False,
       other_variables: Optional[PyTree] = None,
-  ) -> jnp.ndarray:
+  ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]:
     """Computes logits via a forward pass of `self.module`."""
     rngs = {'dropout': dropout_rng} if dropout_rng is not None else None
     decoder_causal_attention = self._get_decoder_causal_attention(batch)
@@ -954,7 +949,7 @@ def score_batch(
       params: PyTree,
       batch: Mapping[str, jnp.ndarray],
       return_intermediates: bool = False,
-  ) -> jnp.ndarray:
+  ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]:
     """Compute log likelihood score on a batch."""

     decoder_target_tokens = batch['decoder_target_tokens']
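Note on the Flax behavior the remaining hunks rely on: calling `Module.apply` with `mutable=['intermediates']` makes it return an `(output, modified_variables)` pair, and every value recorded via `nn.Module.sow` comes back as a tuple of all its instantiations, which is why the code above and below stores singleton tuples. A small self-contained illustration with a toy module (not t5x code):

import flax.linen as nn
import jax
import jax.numpy as jnp


class Toy(nn.Module):

  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    # Each sow call appends one entry, so a value sown once comes back as a
    # singleton tuple in the 'intermediates' collection.
    self.sow('intermediates', 'h', h)
    return h


model = Toy()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
# With mutable=['intermediates'], apply returns (output, modified_variables).
y, modified = model.apply(variables, x, mutable=['intermediates'])
print(modified['intermediates']['h'])  # a singleton tuple: (h,)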
@@ -967,7 +962,26 @@ def score_batch(
           dropout_rng=None,
           mutable=['intermediates'],
       )
+    else:
+      logits = typing.cast(
+          jnp.ndarray,
+          self._compute_logits(params=params, batch=batch, dropout_rng=None),
+      )
+      modified_variables = {}
+    token_scores = (
+        -losses.cross_entropy_with_logits(
+            logits,
+            common_utils.onehot(
+                decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0
+            ),
+            z_loss=0.0,
+        )[0]
+        * weights
+    )
+    sequence_scores = token_scores.sum(-1)
+
+    if return_intermediates:
       # Inside self.module, we called nn.Module.sow to track various
       # intermediate values. We extract them here.
       intermediates = flax_core.unfreeze(
           modified_variables.get('intermediates', {})
@@ -982,28 +996,9 @@ def score_batch(
       )

       # Track per-token labels and loss weights as well. These are not
       # intermediate values of logit computation, so we manually add them here.
       intermediates.setdefault('decoder', {})
       intermediates['decoder']['target_tokens'] = (target_tokens,)
       intermediates['decoder']['loss_weights'] = (weights,)
       # Note that the values are singleton tuples. This is because values inside
       # `intermediates` should be tuples tracking all instantiations of a value.
       # These values each have just one instantiation, hence singletons.
-    else:
-      logits = self._compute_logits(
-          params=params, batch=batch, dropout_rng=None
-      )
-    token_scores = (
-        -losses.cross_entropy_with_logits(
-            logits,
-            common_utils.onehot(
-                decoder_target_tokens, logits.shape[-1], on_value=1, off_value=0
-            ),
-            z_loss=0.0,
-        )[0]
-        * weights
-    )
-    if return_intermediates:
       intermediates['decoder']['token_scores'] = (token_scores,)
-
-    sequence_scores = token_scores.sum(-1)
-
-    if return_intermediates:
-      return sequence_scores, intermediates  # pytype: disable=bad-return-type  # jax-ndarray
+      return sequence_scores, intermediates
     return sequence_scores
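With both implementations restructured the same way, `score_batch` now has a single contract: an array of per-sequence log-likelihood scores, or an `(array, intermediates)` pair when `return_intermediates=True`. A usage sketch, assuming `model`, `params`, and `batch` are prepared as elsewhere in t5x:

# Scores only: one log-likelihood score per sequence in the batch.
sequence_scores = model.score_batch(params, batch)

# Scores plus sown intermediates: note the tuple return, and that values
# inside the intermediates mapping are singleton tuples.
sequence_scores, intermediates = model.score_batch(
    params, batch, return_intermediates=True
)
token_scores = intermediates['decoder']['token_scores'][0]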