From d03fea7185b3d12ea8175ddfab86f47e5a312663 Mon Sep 17 00:00:00 2001
From: Max <56548574+maxreciprocate@users.noreply.github.com>
Date: Tue, 17 Oct 2023 20:32:34 +0300
Subject: [PATCH] fix(modeling_base): re-order `model.forward_kwargs` initialization (#566)

* fix(modeling_base): re-order `model.forward_kwargs` initialization

* fix(modeling_base): revert abstract `post_init` deletion
---
 trlx/models/modeling_base.py | 33 +++++++++++++++------------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/trlx/models/modeling_base.py b/trlx/models/modeling_base.py
index 1035a3109..2799b811c 100644
--- a/trlx/models/modeling_base.py
+++ b/trlx/models/modeling_base.py
@@ -69,13 +69,6 @@ class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin):
     def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, peft_config=None, **kwargs):
         super().__init__()
         self.base_model = base_model
-        # cache `forward` args for general use (avoids incompatible args across architectures)
-        if peft_config:
-            # keep all kwargs for peft
-            self.forward_kwargs = None
-        else:
-            self.forward_kwargs = inspect.getfullargspec(base_model.forward).args
-
         self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False)
         if self.is_loaded_in_8bit:
             # TODO(glerzing): Fully test and support loading in 8-bit
@@ -318,6 +311,16 @@ def from_pretrained(  # noqa: max-complexity
             state_dict = pretrained_model_name_or_path.state_dict()
 
         model.post_init(state_dict=state_dict)
+
+        # cache `forward` args for general use (avoids incompatible args across architectures)
+        if peft_config:
+            # Don't use the interface of the peft model,
+            # use the interface of the underlying transformer model instead.
+            # (peft adds 2 "base_model" layers)
+            model.forward_kwargs = inspect.getfullargspec(model.base_model.base_model.base_model.forward).args
+        else:
+            model.forward_kwargs = inspect.getfullargspec(model.base_model.forward).args
+
         return model
 
     def save_pretrained(self, *args, **kwargs):
@@ -349,20 +352,16 @@ def save_pretrained(self, *args, **kwargs):
 
         return self.base_model.save_pretrained(*args, **kwargs)
 
-    def state_dict(self, *args, **kwargs):
-        """Return the state_dict of the pretrained model."""
-        raise NotImplementedError
-
     def post_init(self, *args, **kwargs):
         """Post initialization method. This method is called after the
         model is instantiated and loaded from a checkpoint. It can be
         used to perform additional operations such as loading the state_dict.
         """
-        if self.peft_type:
-            # Don't use the interface of the peft model,
-            # use the interface of the underlying transformer model instead.
-            # (peft adds 2 "base_model" layers)
-            self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.forward).args
+        pass
+
+    def state_dict(self, *args, **kwargs):
+        """Return the state_dict of the pretrained model."""
+        raise NotImplementedError
 
     def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
         """Filter out arguments not supported by the specific instance of
@@ -370,6 +369,4 @@ def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
         """
         # FIXME: This is a hack to get around the fact that the `transformers`
         # architectures we use don't have a consistent API for `forward` parameters.
-        if self.forward_kwargs is None:
-            return kwargs
         return {k: v for k, v in kwargs.items() if k in self.forward_kwargs}
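
Note (not part of the patch): a minimal sketch of the caching-and-filtering idea this diff re-orders, assuming a toy stand-in class instead of a real `transformers` model. `TinyModel`, `input_ids`, `attention_mask`, and `position_ids` below are hypothetical names for illustration only; the patch itself computes `forward_kwargs` in `from_pretrained` (after `post_init`) and applies the filter in `get_compatible_forward_kwargs`.

import inspect


class TinyModel:
    # Hypothetical stand-in for one transformers architecture's `forward` signature.
    def forward(self, input_ids, attention_mask=None):
        return input_ids


model = TinyModel()

# Cache the accepted argument names once; the patch does this in
# `from_pretrained`, after `post_init`, so the peft and non-peft paths
# both end up with a concrete list instead of None.
forward_kwargs = inspect.getfullargspec(model.forward).args
print(forward_kwargs)  # ['self', 'input_ids', 'attention_mask']

# Drop arguments this architecture does not accept, mirroring
# `get_compatible_forward_kwargs`.
kwargs = {"input_ids": [1, 2, 3], "attention_mask": None, "position_ids": None}
compatible = {k: v for k, v in kwargs.items() if k in forward_kwargs}
print(compatible)  # `position_ids` is filtered out before calling forward()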