Skip to content

Commit

Permalink
removed forward method override
Browse files Browse the repository at this point in the history
  • Loading branch information
jdchang1 committed Aug 21, 2024
1 parent f7cabf5 commit aa985f0
Showing 1 changed file with 0 additions and 13 deletions.
13 changes: 0 additions & 13 deletions llmfoundry/models/hf/hf_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,6 @@ def __init__(
should_save_peft_only=should_save_peft_only,
)

def forward(self, batch: Mapping):
if isinstance(batch, dict) or isinstance(batch, UserDict):
# Further input validation is left to the huggingface forward call
batch = {
k: v for k, v in batch.items() if k in self.model_forward_args
}
output = self.model(**batch) # type: ignore (thirdparty)
else:
raise ValueError(
'Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model',
)
return output

def loss(self, outputs: ModelOutput, batch: Mapping):
if self.config.use_return_dict:
return outputs['loss']
Expand Down

0 comments on commit aa985f0

Please sign in to comment.