From fc63914399b6f60512c720959f9182b02ae4a45c Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Tue, 10 Oct 2023 12:35:16 -0700 Subject: [PATCH] [JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. (#26703) `jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Co-authored-by: Peter Hawkins --- .../flax/image-captioning/run_image_captioning_flax.py | 2 +- examples/flax/language-modeling/run_clm_flax.py | 2 +- examples/flax/question-answering/run_qa.py | 2 +- .../run_flax_speech_recognition_seq2seq.py | 2 +- examples/flax/summarization/run_summarization_flax.py | 2 +- examples/flax/text-classification/run_flax_glue.py | 2 +- examples/flax/token-classification/run_flax_ner.py | 2 +- examples/flax/vision/run_image_classification.py | 2 +- .../jax-projects/hybrid_clip/run_hybrid_clip.py | 2 +- .../jax-projects/model_parallel/run_clm_mp.py | 2 +- src/transformers/models/bart/modeling_flax_bart.py | 2 +- src/transformers/models/bert/modeling_flax_bert.py | 2 +- .../models/big_bird/modeling_flax_big_bird.py | 2 +- .../models/blenderbot/modeling_flax_blenderbot.py | 2 +- .../blenderbot_small/modeling_flax_blenderbot_small.py | 2 +- src/transformers/models/electra/modeling_flax_electra.py | 8 ++++---- src/transformers/models/longt5/modeling_flax_longt5.py | 2 +- src/transformers/models/marian/modeling_flax_marian.py | 2 +- src/transformers/models/mt5/modeling_flax_mt5.py | 2 +- src/transformers/models/pegasus/modeling_flax_pegasus.py | 2 +- src/transformers/models/roberta/modeling_flax_roberta.py | 2 +- .../modeling_flax_roberta_prelayernorm.py | 2 +- src/transformers/models/t5/modeling_flax_t5.py | 2 +- .../models/xlm_roberta/modeling_flax_xlm_roberta.py | 2 +- .../modeling_flax_{{cookiecutter.lowercase_modelname}}.py | 2 +- 25 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index bbc79977a46793..d8c89c1a242f14 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 95e175d494bfe2..c61b24f4d7e615 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -326,7 +326,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index 9cd90f285a02c4..0d35f302f8f37a 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -389,7 +389,7 @@ def cross_entropy_loss(logits, labels): # region Create learning rate function def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py index 4a2915a31ac7dc..8af835b6a4b4d3 100644 --- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py +++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py @@ -360,7 +360,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( num_train_steps: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) decay_fn = optax.linear_schedule( diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index d57aa1769135db..782e9ee88c498f 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -409,7 +409,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index b42a2565310799..1535ff8492781b 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -288,7 +288,7 @@ def cross_entropy_loss(logits, labels): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index f4b40220ff12c7..e06a85cb67c0a0 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -340,7 +340,7 @@ def cross_entropy_loss(logits, labels): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py index 66505014ec57dd..4bed9b663f6f00 100644 --- a/examples/flax/vision/run_image_classification.py +++ b/examples/flax/vision/run_image_classification.py @@ -249,7 +249,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py index f54641408f80a2..c5a4a202534b87 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py @@ -283,7 +283,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py index a6da8729f0ce3b..bb297e3e0db6e4 100644 --- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py +++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py @@ -214,7 +214,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 9858eb2d1bf416..6abfcdc398422f 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -217,7 +217,7 @@ """ -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 99dfa2a0e2f9ab..bb2af0e0602aba 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -295,7 +295,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index afdac2645f2652..c6d8b7c1612ec9 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -316,7 +316,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 1035272fd05350..61239335be3b63 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -204,7 +204,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index 2bf8b59e2757bc..b5272fb3bca9e2 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -216,7 +216,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 32e76b8b586f4f..8fced6ff1ea25e 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -263,7 +263,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, @@ -1228,13 +1228,13 @@ def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): Compute a single vector summary of a sequence hidden states. Args: - hidden_states (`jnp.array` of shape `[batch_size, seq_len, hidden_size]`): + hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`): The hidden states of the last layer. - cls_index (`jnp.array` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. Returns: - `jnp.array`: The summary of the sequence hidden states. + `jnp.ndarray`: The summary of the sequence hidden states. """ # NOTE: this doest "first" type summary always output = hidden_states[:, 0] diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py index 91ca9c72c22c4e..36e273d5725a4f 100644 --- a/src/transformers/models/longt5/modeling_flax_longt5.py +++ b/src/transformers/models/longt5/modeling_flax_longt5.py @@ -56,7 +56,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index a713fdb05dcfd9..5197c906895917 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -227,7 +227,7 @@ def create_sinusoidal_positions(n_pos, dim): # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py index 86ddf477ffab56..0046e02ca7308e 100644 --- a/src/transformers/models/mt5/modeling_flax_mt5.py +++ b/src/transformers/models/mt5/modeling_flax_mt5.py @@ -27,7 +27,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index c5189746b1065f..17772251bf0629 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -210,7 +210,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 845fcea4429787..6bc72f12b40790 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -256,7 +256,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py index b7c347693d951b..e98897993742db 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py @@ -258,7 +258,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index db4ca90c270ccb..09575fdcc3b82e 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -56,7 +56,7 @@ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py index f6f39ee93ba687..fb03c390f6f419 100644 --- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py @@ -266,7 +266,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 83263a6a47ef11..63b5d83d308ad1 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -251,7 +251,7 @@ def __call__( hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False,