Fix HuggingFace model.save_pretrained for DDP (#181)
* Fix HuggingFace `model.save_pretrained` for DDP

* Add abstract method for `save_pretrained`

* Clarify `TODO` comment
jon-tow authored Jan 12, 2023
1 parent 7ed923c commit 400dcfd
Showing 3 changed files with 28 additions and 11 deletions.
19 changes: 10 additions & 9 deletions trlx/trainer/accelerate_base_trainer.py
@@ -4,7 +4,7 @@
 import sys
 from abc import abstractmethod
 from time import time
-from typing import Dict, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -196,15 +196,16 @@ def generate_eval(self, input_ids, attention_mask=None, **kwargs):
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
         )

-    def save(self, directory=None):
-        """Creates checkpoint of optimizer, scheduler and a model"""
+    def save(self, directory: Optional[str] = None):
+        """Creates a checkpoint of the optimizer, scheduler and model"""
         self.accelerator.save_state(directory or self.config.train.checkpoint_dir)
-        if directory:
-            self.model.base_model.save_pretrained(f"hf_model_{directory}")
-        else:
-            self.model.base_model.save_pretrained(
-                f"hf_model_{self.config.train.checkpoint_dir}"
-            )
+
+    @abstractmethod
+    def save_pretrained(self, directory: Optional[str] = None):
+        """Save the model and its configuration file to a directory, so that it can be re-loaded with the
+        `transformers.PreTrainedModel.from_pretrained` method.
+        """
+        pass

     def load(self, directory=None):
         """Load checkpoint of optimizer, scheduler and a model"""
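
For orientation, the new abstract method makes HuggingFace export part of every trainer's contract. A minimal usage sketch follows; the `trlx.train` arguments and the toy reward function are assumptions for illustration, not part of this commit:

import trlx

# Toy reward function; any callable mapping samples to scalar scores works here.
trainer = trlx.train(
    "gpt2",
    reward_fn=lambda samples, **kwargs: [float(len(s)) for s in samples],
)

# Export the underlying HuggingFace model so it can later be re-loaded with
# `transformers.PreTrainedModel.from_pretrained`.
trainer.save_pretrained("ckpts/hf_model")
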
13 changes: 12 additions & 1 deletion trlx/trainer/accelerate_ilql_trainer.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Union, cast
+from typing import Optional, Sequence, Union, cast

 import torch

@@ -92,3 +92,14 @@ def prepare_learning(self):
         self.n_updates_per_batch = 1
         self.total_steps = self.config.train.epochs * len(train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
+
+    def save_pretrained(self, directory: Optional[str] = None):
+        # TODO: Support saving with `transformers.PreTrainedModel.save_pretrained`.
+        # This is currently not supported because `nn.ilql_models.CausalLMWithValueHeads`
+        # requires a custom `generate` method that uses its (value/q) heads to steer
+        # sampling - something that is not possible with the default
+        # `transformers.PreTrainedModel.generate`.
+        raise NotImplementedError(
+            "`AccelerateILQLTrainer` does not currently support automatic saving "
+            "with `transformers.PreTrainedModel.save_pretrained`."
+        )
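
Until ILQL export is supported, callers can fall back to the base trainer's `save`, which checkpoints everything through `accelerator.save_state`. A sketch of that workaround, assuming `trainer` is an `AccelerateILQLTrainer` instance (the paths are placeholders):

try:
    trainer.save_pretrained("ckpts/ilql_hf")
except NotImplementedError:
    # Full accelerate checkpoint (model, optimizer, scheduler); re-load it
    # with `trainer.load(...)` rather than `from_pretrained`.
    trainer.save("ckpts/ilql_state")
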
7 changes: 6 additions & 1 deletion trlx/trainer/accelerate_ppo_trainer.py
@@ -1,7 +1,7 @@
 import json
 import os
 import uuid
-from typing import Tuple
+from typing import Optional, Tuple

 import torch
 from torchtyping import TensorType
@@ -218,3 +218,8 @@ def prepare_learning(self):
             * len(self.train_dataloader)
         )
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
+
+    def save_pretrained(self, directory: Optional[str] = None):
+        directory = f"{directory or self.config.train.checkpoint_dir}/hf_model"
+        self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory)
+        self.tokenizer.save_pretrained(directory)
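
The `accelerator.unwrap_model` call is the substance of the DDP fix: under a multi-process launch, Accelerate wraps the model in `torch.nn.parallel.DistributedDataParallel`, which does not expose HuggingFace's `save_pretrained`. An illustrative standalone sketch, not part of the commit (the `gpt2` checkpoint and output path are placeholders):

from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("gpt2"))

# Under DDP, `model` is now a DistributedDataParallel wrapper, so calling
# `model.save_pretrained(...)` directly would raise an AttributeError.
# Unwrapping first recovers the underlying `PreTrainedModel` in every
# launch mode, including single-process runs where it is a no-op.
accelerator.unwrap_model(model).save_pretrained("ckpts/hf_model")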
