diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 16ff55c29..8926cefb6 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -65,6 +65,7 @@ class ModelConfig: model_arch_type: str = "causal" num_layers_unfrozen: int = -1 peft_config: Any = None + model_extra_configs: Dict[str, Any] = field(default_factory=dict) @classmethod def from_dict(cls, config: Dict[str, Any]): @@ -89,6 +90,7 @@ class TokenizerConfig: tokenizer_path: str padding_side: str = "left" truncation_side: str = "right" + tokenizer_extra_configs: Dict[str, Any] = field(default_factory=dict) @classmethod def from_dict(cls, config: Dict[str, Any]): diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index e58254ef9..8bf19b251 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -67,7 +67,9 @@ def __init__(self, config, **kwargs): # noqa: C901 self.opt = self.setup_optimizer() self.scheduler = self.setup_scheduler() - self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) + self.tokenizer = AutoTokenizer.from_pretrained( + config.tokenizer.tokenizer_path, **config.tokenizer.tokenizer_extra_configs + ) self.tokenizer.padding_side = config.tokenizer.padding_side self.tokenizer.truncation_side = config.tokenizer.truncation_side self.tokenizer.sep_token = "" diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py index 60001ee55..b2aae2929 100644 --- a/trlx/trainer/accelerate_ilql_trainer.py +++ b/trlx/trainer/accelerate_ilql_trainer.py @@ -132,6 +132,7 @@ def get_arch(self, config): two_qs=config.method.two_qs, alpha=config.method.alpha, peft_config=self.config.model.peft_config, + **self.config.model.model_extra_configs, ) def post_backward_callback(self): diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index cd0b62ab6..730c23248 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -121,6 +121,7 @@ def get_arch(self, config: TRLConfig): num_layers_unfrozen=config.model.num_layers_unfrozen, num_value_layers_unfrozen=config.method.num_value_layers_unfrozen, peft_config=self.config.model.peft_config, + **self.config.model.model_extra_configs, ) def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]: diff --git a/trlx/trainer/accelerate_rft_trainer.py b/trlx/trainer/accelerate_rft_trainer.py index 6dde427fc..2f072343f 100644 --- a/trlx/trainer/accelerate_rft_trainer.py +++ b/trlx/trainer/accelerate_rft_trainer.py @@ -62,7 +62,7 @@ def get_arch(self, config): if issubclass(type(config.model.model_path), PretrainedConfig): from_fn = AutoModelForCausalLM.from_config - model = from_fn(config.model.model_path) + model = from_fn(config.model.model_path, **config.model.model_extra_configs) if config.model.peft_config is not None: # Initialize the peft adapter diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index 11c88a1c9..1e79ce9b9 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -42,7 +42,7 @@ def get_arch(self, config): if issubclass(type(config.model.model_path), PretrainedConfig): from_fn = AutoModelForCausalLM.from_config - model = from_fn(config.model.model_path) + model = from_fn(config.model.model_path, **config.model.model_extra_configs) if config.model.peft_config is not None: # Initialize the peft adapter