From aa985f09ccb4b26818e14d286bf4cb83d2488886 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Wed, 21 Aug 2024 08:50:05 -0400 Subject: [PATCH] removed forward method override --- llmfoundry/models/hf/hf_base.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/llmfoundry/models/hf/hf_base.py b/llmfoundry/models/hf/hf_base.py index ee1fd77816..49bd22dd68 100644 --- a/llmfoundry/models/hf/hf_base.py +++ b/llmfoundry/models/hf/hf_base.py @@ -122,19 +122,6 @@ def __init__( should_save_peft_only=should_save_peft_only, ) - def forward(self, batch: Mapping): - if isinstance(batch, dict) or isinstance(batch, UserDict): - # Further input validation is left to the huggingface forward call - batch = { - k: v for k, v in batch.items() if k in self.model_forward_args - } - output = self.model(**batch) # type: ignore (thirdparty) - else: - raise ValueError( - 'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model', - ) - return output - def loss(self, outputs: ModelOutput, batch: Mapping): if self.config.use_return_dict: return outputs['loss']