Skip to content

Commit

Permalink
Fix inhomogeneous shape error in example (#30434)
Browse files Browse the repository at this point in the history
Fix inhomogeneous shape error in example.
  • Loading branch information
Lu Teng authored May 21, 2024
1 parent d24097e commit 5bf9caa
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions examples/flax/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,8 @@ def eval_data_collator(dataset: Dataset, batch_size: int):

for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
# Ignore `offset_mapping` to avoid numpy/JAX array conversion issue.
batch = {k: np.array(v) for k, v in batch.items() if k != "offset_mapping"}

yield batch

Expand Down Expand Up @@ -1000,7 +1001,6 @@ def eval_step(state, batch):
position=2,
):
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = pad_shard_unpad(p_eval_step)(
state, batch, min_device_batch=per_device_eval_batch_size
)
Expand Down Expand Up @@ -1055,7 +1055,6 @@ def eval_step(state, batch):
eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2
):
_ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
start_logits = np.array(predictions[0])
end_logits = np.array(predictions[1])
Expand Down

0 comments on commit 5bf9caa

Please sign in to comment.