
Commit

support extra model and tokenizer configs during loading by from_pretrained in accelerate trainer (#568)

* support extra model and tokenizer configs during loading by from_pretrained

* support extra model and tokenizer configs during loading by from_pretrained

* style: satisfy black

---------

Co-authored-by: 聂靖入 <[email protected]>
Co-authored-by: maxreciprocate <[email protected]>
3 people authored Oct 23, 2023
1 parent 69b02b0 commit bcbcdac
Showing 6 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions trlx/data/configs.py
@@ -65,6 +65,7 @@ class ModelConfig:
     model_arch_type: str = "causal"
     num_layers_unfrozen: int = -1
     peft_config: Any = None
+    model_extra_configs: Dict[str, Any] = field(default_factory=dict)

     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
@@ -89,6 +90,7 @@ class TokenizerConfig:
     tokenizer_path: str
     padding_side: str = "left"
     truncation_side: str = "right"
+    tokenizer_extra_configs: Dict[str, Any] = field(default_factory=dict)

     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
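As a usage sketch (not part of this commit): the two new dataclass fields collect arbitrary keyword arguments that are later unpacked into the Hugging Face loading calls. The field names come from this diff; the model/tokenizer paths and the specific kwargs below are illustrative assumptions only.

from trlx.data.configs import ModelConfig, TokenizerConfig

# Hypothetical values; any kwargs accepted by the underlying loader should work here.
model_config = ModelConfig(
    model_path="gpt2",
    model_extra_configs={"trust_remote_code": True},  # forwarded as **kwargs when the model is loaded
)
tokenizer_config = TokenizerConfig(
    tokenizer_path="gpt2",
    tokenizer_extra_configs={"use_fast": False},  # forwarded as **kwargs to AutoTokenizer.from_pretrained
)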
4 changes: 3 additions & 1 deletion trlx/trainer/accelerate_base_trainer.py
@@ -67,7 +67,9 @@ def __init__(self, config, **kwargs):  # noqa: C901
         self.opt = self.setup_optimizer()
         self.scheduler = self.setup_scheduler()

-        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            config.tokenizer.tokenizer_path, **config.tokenizer.tokenizer_extra_configs
+        )
         self.tokenizer.padding_side = config.tokenizer.padding_side
         self.tokenizer.truncation_side = config.tokenizer.truncation_side
         self.tokenizer.sep_token = "<sep>"
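With the change above, the contents of tokenizer_extra_configs are unpacked directly into AutoTokenizer.from_pretrained. A minimal sketch of the effective call, assuming the hypothetical tokenizer_extra_configs={"use_fast": False} from the example above:

from transformers import AutoTokenizer

# Equivalent to AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_extra_configs)
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
tokenizer.padding_side = "left"
tokenizer.truncation_side = "right"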
1 change: 1 addition & 0 deletions trlx/trainer/accelerate_ilql_trainer.py
@@ -132,6 +132,7 @@ def get_arch(self, config):
             two_qs=config.method.two_qs,
             alpha=config.method.alpha,
             peft_config=self.config.model.peft_config,
+            **self.config.model.model_extra_configs,
         )

    def post_backward_callback(self):
1 change: 1 addition & 0 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -121,6 +121,7 @@ def get_arch(self, config: TRLConfig):
             num_layers_unfrozen=config.model.num_layers_unfrozen,
             num_value_layers_unfrozen=config.method.num_value_layers_unfrozen,
             peft_config=self.config.model.peft_config,
+            **self.config.model.model_extra_configs,
         )

    def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_rft_trainer.py
@@ -62,7 +62,7 @@ def get_arch(self, config):
         if issubclass(type(config.model.model_path), PretrainedConfig):
             from_fn = AutoModelForCausalLM.from_config

-        model = from_fn(config.model.model_path)
+        model = from_fn(config.model.model_path, **config.model.model_extra_configs)

         if config.model.peft_config is not None:
             # Initialize the peft adapter
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_sft_trainer.py
@@ -42,7 +42,7 @@ def get_arch(self, config):
         if issubclass(type(config.model.model_path), PretrainedConfig):
             from_fn = AutoModelForCausalLM.from_config

-        model = from_fn(config.model.model_path)
+        model = from_fn(config.model.model_path, **config.model.model_extra_configs)

         if config.model.peft_config is not None:
             # Initialize the peft adapter
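In the RFT and SFT trainers the same pattern applies to the model itself: model_extra_configs is unpacked into from_fn, which this diff shows is AutoModelForCausalLM.from_config when a PretrainedConfig is passed and otherwise a from_pretrained-style loader. A minimal sketch of the effective call under the hypothetical model_extra_configs={"trust_remote_code": True} from the first example:

from transformers import AutoModelForCausalLM

# Equivalent to from_fn(config.model.model_path, **config.model.model_extra_configs)
model = AutoModelForCausalLM.from_pretrained("gpt2", trust_remote_code=True)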
