diff --git a/open_flamingo/train/losses.py b/open_flamingo/train/losses.py index 0ed76648..0f86e3a4 100644 --- a/open_flamingo/train/losses.py +++ b/open_flamingo/train/losses.py @@ -88,7 +88,7 @@ def __call__( shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLossWithZLoss(eps=z_loss_eps) - shift_logits = shift_logits.view(-1, model.lang_model.config.vocab_size) + shift_logits = shift_logits.view(-1, unwrap_model(model).lang_model.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device)