Skip to content

Commit

Permalink
media token fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Lo committed Dec 1, 2023
1 parent 2053d1f commit 6425b29
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion open_flamingo/train/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,17 @@
from data_utils import DataInfo
import random
import numpy as np
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the module held inside a parallel wrapper, if there is one.

    Args:
        model: Either a plain module, or a module wrapped in
            ``nn.DataParallel`` or ``nn.parallel.DistributedDataParallel``.

    Returns:
        ``model.module`` when ``model`` is an instance of one of those two
        wrapper classes; otherwise ``model`` itself, unchanged. Only a
        single wrapper layer is removed.
    """
    parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    if isinstance(model, parallel_wrappers):
        # Both wrapper classes expose the wrapped network as ``.module``.
        return model.module
    return model

def train_one_epoch(
args,
Expand Down Expand Up @@ -77,8 +87,11 @@ def train_one_epoch(
batch_metadata_to_log[
f"{datasets[dataset_ix].name}_num_tokens"
] = attention_mask.sum().item()
model = unwrap_model(model)
model.media_token_id = 400
model = DDP(model)
batch_metadata_to_log[f"{datasets[dataset_ix].name}_num_images"] = (
(input_ids == model.media_token_id).sum().item()
(input_ids == model.module.media_token_id).sum().item()
)

# forward pass
Expand Down

0 comments on commit 6425b29

Please sign in to comment.