From 8cd23d57f485ba17407a883c470ee37482c2856e Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Wed, 15 May 2024 16:35:18 -0700 Subject: [PATCH] using self.shift_labels instead of self.model.transformer.shift_labels in loss (#1211) Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/mpt/modeling_mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8726879208..9c22e90678 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -1157,7 +1157,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping) -> Union[dict, torch.Tensor]: - if self.model.transformer.shift_labels: + if self.shift_labels: targets = self.get_targets(batch) else: targets = batch['labels']