Skip to content

Commit

Permalink
fix for "make fixup"
Browse files Browse the repository at this point in the history
  • Loading branch information
louie-tsai committed Dec 4, 2024
1 parent d7fc2d1 commit 64d76c8
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions examples/flax/language-modeling/run_bert_flax.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,56 @@
#!/usr/bin/env python3
import time
from argparse import ArgumentParser

import jax
from transformers import AutoTokenizer, FlaxBertModel, BertConfig
import numpy as np

from transformers import BertConfig, FlaxBertModel


parser = ArgumentParser()
parser.add_argument('--precision', type=str, choices=["float32", "bfloat16"], default="float32")
parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32")
args = parser.parse_args()

dtype = jax.numpy.float32
if args.precision == "bfloat16":
dtype = jax.numpy.bfloat16
dtype = jax.numpy.bfloat16

VOCAB_SIZE = 30522
BS = 32
SEQ_LEN = 128


def get_input_data(batch_size=1, seq_length=384):
shape = (batch_size, seq_length)
input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32)
token_type_ids = np.ones(shape).astype(np.int32)
attention_mask = np.ones(shape).astype(np.int32)
return { 'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask }
shape = (batch_size, seq_length)
input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32)
token_type_ids = np.ones(shape).astype(np.int32)
attention_mask = np.ones(shape).astype(np.int32)
return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}


inputs = get_input_data(BS, SEQ_LEN)
config = BertConfig.from_pretrained("bert-base-uncased", hidden_act="gelu_new")
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=dtype)


@jax.jit
def func():
outputs = model(**inputs)
return outputs
outputs = model(**inputs)
return outputs


(nwarmup, nbenchmark) = (5, 100)

# warmpup
for _ in range(nwarmup):
func()
func()

# benchmark
import time

start = time.time()
for _ in range(nbenchmark):
func()
func()
end = time.time()
print(end-start)
print(end - start)
print(f"Throughput: {((nbenchmark * BS)/(end-start)):.3f} examples/sec")

0 comments on commit 64d76c8

Please sign in to comment.