Skip to content

Commit

Permalink
Remove incorrect eos token label mask
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-awadalla authored Sep 29, 2023
1 parent c693798 commit b51cbdc
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions open_flamingo/train/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def train_one_epoch(
# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
labels[labels == tokenizer.eos_token] = -100
labels[labels == media_token_id] = -100
labels = labels.to(device_id)

Expand All @@ -127,7 +126,6 @@ def train_one_epoch(
# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
labels[labels == tokenizer.eos_token] = -100
for i in range(labels.shape[0]):
# remove loss for any token before the first <image> token
label_idx = 0
Expand Down

0 comments on commit b51cbdc

Please sign in to comment.