diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 5e24eec14..15c804ae6 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -4,7 +4,7 @@
 import sys
 from abc import abstractmethod
 from time import time
-from typing import Dict, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -196,15 +196,16 @@ def generate_eval(self, input_ids, attention_mask=None, **kwargs):
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )
 
-    def save(self, directory=None):
-        """Creates checkpoint of optimizer, scheduler and a model"""
+    def save(self, directory: Optional[str] = None):
+        """Creates a checkpoint of the optimizer, scheduler and model"""
         self.accelerator.save_state(directory or self.config.train.checkpoint_dir)
-        if directory:
-            self.model.base_model.save_pretrained(f"hf_model_{directory}")
-        else:
-            self.model.base_model.save_pretrained(
-                f"hf_model_{self.config.train.checkpoint_dir}"
-            )
+
+    @abstractmethod
+    def save_pretrained(self, directory: Optional[str] = None):
+        """Save the model and its configuration file to a directory, so that it can be re-loaded with the
+        `transformers.PreTrainedModel.from_pretrained` method.
+        """
+        pass
 
     def load(self, directory=None):
         """Load checkpoint of optimizer, scheduler and a model"""
diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py
index 66b88e330..a5191697b 100644
--- a/trlx/trainer/accelerate_ilql_trainer.py
+++ b/trlx/trainer/accelerate_ilql_trainer.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Union, cast
+from typing import Optional, Sequence, Union, cast
 
 import torch
 
@@ -92,3 +92,14 @@ def prepare_learning(self):
         self.n_updates_per_batch = 1
         self.total_steps = self.config.train.epochs * len(train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
+
+    def save_pretrained(self, directory: Optional[str] = None):
+        # TODO: Support saving with `transformers.PreTrainedModel.save_pretrained`.
+        # This is currently not supported because `nn.ilql_models.CausalLMWithValueHeads`
+        # requires a custom `generate` method using its (value/q) heads to steer
+        # sampling - something that is not possible with the default
+        # `transformers.PreTrainedModel.generate`.
+        raise NotImplementedError(
+            "`AccelerateILQLTrainer` does not currently support automatic saving "
+            "with `transformers.PreTrainedModel.save_pretrained`."
+        )
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 85da0bd16..06593b2a2 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -1,7 +1,7 @@
 import json
 import os
 import uuid
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torchtyping import TensorType
@@ -218,3 +218,8 @@ def prepare_learning(self):
             * len(self.train_dataloader)
         )
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
+
+    def save_pretrained(self, directory: Optional[str] = None):
+        directory = f"{directory or self.config.train.checkpoint_dir}/hf_model"
+        self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory)
+        self.tokenizer.save_pretrained(directory)
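A minimal usage sketch of the new hook, assuming an already-trained `AcceleratePPOTrainer` instance named `trainer` and a hypothetical output path; since the PPO implementation writes a standard `transformers` model and tokenizer under an `hf_model` subdirectory, the artifacts can be reloaded without trlx:

```python
# Sketch only: `trainer` is assumed to be a trained AcceleratePPOTrainer,
# and "ckpts/ppo_example" is a hypothetical output directory.
trainer.save_pretrained("ckpts/ppo_example")

# The PPO trainer saves the base model and tokenizer to
# "ckpts/ppo_example/hf_model", which can be reloaded with plain transformers.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ckpts/ppo_example/hf_model")
tokenizer = AutoTokenizer.from_pretrained("ckpts/ppo_example/hf_model")
```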