From d7c67b84e6391313ca3b5593b9b180884dec55d5 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 4 Nov 2024 18:47:50 +0700 Subject: [PATCH 1/6] refactor(optimizer): use `optimizer_cls_and_kwargs` for custom optimizers --- docs/config.qmd | 24 +- src/axolotl/core/trainer_builder.py | 308 ++++++++---------- .../config/models/input/v0_4_1/__init__.py | 10 + .../utils/optimizers/embedding_scaled.py | 68 ++++ 4 files changed, 235 insertions(+), 175 deletions(-) create mode 100644 src/axolotl/utils/optimizers/embedding_scaled.py diff --git a/docs/config.qmd b/docs/config.qmd index f01a2ce267..d0bbf7e26f 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -397,7 +397,7 @@ lr_div_factor: # Learning rate div factor # Specify optimizer # Valid values are driven by the Transformers OptimizerNames class, see: -# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 +# https://github.com/huggingface/transformers/blob/9f06fb05059a973048f5865e7e385c9db5d6daa4/src/transformers/training_args.py#L145-L187 # # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used @@ -408,26 +408,48 @@ lr_div_factor: # Learning rate div factor # - adamw_torch # - adamw_torch_fused # - adamw_torch_xla +# - adamw_torch_npu_fused # - adamw_apex_fused # - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1) # - adafactor # - adamw_anyprecision +# - adamw_torch_4bit +# - ademamix # - sgd # - adagrad # - adamw_bnb_8bit +# - adamw_8bit # alias for adamw_bnb_8bit +# - ademamix_8bit # - lion_8bit # - lion_32bit # - paged_adamw_32bit # - paged_adamw_8bit +# - paged_ademamix_32bit +# - paged_ademamix_8bit # - paged_lion_32bit # - paged_lion_8bit +# - rmsprop +# - rmsprop_bnb +# - rmsprop_bnb_8bit +# - rmsprop_bnb_32bit # - galore_adamw # - galore_adamw_8bit # - galore_adafactor # - galore_adamw_layerwise # - galore_adamw_8bit_layerwise # - galore_adafactor_layerwise +# - lomo +# - adalomo +# - grokadamw +# - schedule_free_adamw +# - schedule_free_sgd +# +# Additional custom optimizers include: +# - optimi_adamw +# - ao_adamw_8bit +# - ao_adamw_fp8 optimizer: + # Dictionary of arguments to pass to the optimizer optim_args: # For Galore Optimizers the following optim_args are available diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 691437bc65..e511b6df86 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -34,6 +34,7 @@ TrainingArguments, ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker +from transformers.training_args import OptimizerNames from transformers.utils import is_sagemaker_mp_enabled from trl import ( CPOConfig, @@ -74,6 +75,7 @@ ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype +from axolotl.utils.optimizers.embedding_scaled import create_embedding_scaled_optimizer from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( get_cosine_schedule_with_min_lr, @@ -268,12 +270,6 @@ class AxolotlTrainingMixins: default=None, metadata={"help": "whether to use sequential sampling for curriculum learning"}, ) - alternate_optimizer: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an 
alternate optimizer to the HF trainer" - }, - ) alternate_lr_scheduler_type: Optional[str] = field( default=None, metadata={ @@ -460,132 +456,46 @@ def _wrap_model(self, model, training=True, dataloader=None): return super()._wrap_model(model, training=training, dataloader=dataloader) def create_optimizer(self): - if ( - self.args.loraplus_lr_ratio is None + # For all other cases, use parent implementation + if (self.args.loraplus_lr_ratio is None and self.args.embedding_lr_scale is None and self.args.embedding_lr is None - and self.args.alternate_optimizer - not in [ - "optimi_adamw", - "ao_adamw_8bit", - "ao_adamw_4bit", - "ao_adamw_fp8", - "adopt_adamw", - ] ): return super().create_optimizer() opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - decay_parameters = self.get_decay_parameter_names(opt_model) - params = { - "to_weight_decay": {}, # LayerNorm and bias - "embeddings": {}, # lm_head, embed_tokens, - "no_weight_decay": {}, - } - - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, - ) + if self.optimizer is not None: + raise ValueError("Optimizer already created.") - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - if name.endswith("modules_to_save.default.weight") or any( - embed_name in name for embed_name in ["embed_tokens", "lm_head"] - ): - params["embeddings"][name] = param - elif name in decay_parameters: - params["to_weight_decay"][name] = param - else: - params["no_weight_decay"][name] = param - optimizer_grouped_parameters = [] - if params["to_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["to_weight_decay"].values()), - "weight_decay": self.args.weight_decay, - "lr": optimizer_kwargs["lr"], - } - ) - if params["embeddings"]: - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - if self.args.embedding_lr_scale: - lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name - elif self.args.embedding_lr: - lr = self.args.embedding_lr # pylint: disable=invalid-name - optimizer_grouped_parameters.append( - { - "params": list(params["embeddings"].values()), - "weight_decay": 0.0, - "lr": lr, - } - ) - if params["no_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["no_weight_decay"].values()), - "weight_decay": 0.0, - "lr": optimizer_kwargs["lr"], - } - ) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) - if self.args.loraplus_lr_ratio is not None: - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", 1e-6 - ) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + if self.args.loraplus_lr_ratio is not None: + self.optimizer = ( + create_loraplus_optimizer( opt_model, optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, + loraplus_lr_ratio=self.args.loraplus_lr_ratio, + loraplus_lr_embedding=self.args.loraplus_lr_embedding, **optimizer_kwargs, ) - elif ( - self.args.embedding_lr_scale is not None - or self.args.embedding_lr is not None - ): - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "optimi_adamw": - from optimi import AdamW - - 
self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW( - optimizer_grouped_parameters, foreach=False, **optimizer_kwargs - ) - ) - elif self.args.alternate_optimizer == "ao_adamw_4bit": - from torchao.prototype.low_bit_optim import AdamW4bit - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "ao_adamw_8bit": - from torchao.prototype.low_bit_optim import AdamW8bit - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "ao_adamw_fp8": - from torchao.prototype.low_bit_optim import AdamWFp8 - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "adopt_adamw": - from axolotl.utils.optimizers.adopt import ADOPT + ) + elif ( + self.args.embedding_lr_scale is not None + or self.args.embedding_lr is not None + ): + decay_parameters = self.get_decay_parameter_names(opt_model) - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - ADOPT( - optimizer_grouped_parameters, - decouple=True, - **optimizer_kwargs, - ) - ) + self.optimizer = create_embedding_scaled_optimizer( + opt_model, + embedding_lr_scale=self.args.embedding_lr_scale, + embedding_lr=self.args.embedding_lr, + decay_parameters=decay_parameters, + optimizer_cls=optimizer_cls, + optimizer_kwargs=optimizer_kwargs, + ) if is_sagemaker_mp_enabled(): self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init @@ -594,6 +504,7 @@ def create_optimizer(self): return self.optimizer + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.args.sample_packing and not self.args.pretraining: if self.args.multipack_real_batches: @@ -1078,27 +989,25 @@ def __init__(self, *args, dataset_tags=None, **kwargs): self.optimizer = None def create_optimizer(self): + # For all other cases, use parent implementation if self.args.loraplus_lr_ratio is None: return super().create_optimizer() + # Handle LoRA Plus opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, - ) - - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - if loraplus_lr_ratio: - print("Using lora+") - loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + create_loraplus_optimizer( opt_model, optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, + loraplus_lr_ratio=self.args.loraplus_lr_ratio, + loraplus_lr_embedding=self.args.loraplus_lr_embedding, **optimizer_kwargs, ) + ) if is_sagemaker_mp_enabled(): self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init @@ -1734,27 +1643,6 @@ def build(self, total_num_steps): training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name else: training_arguments_kwargs["run_name"] = None - training_arguments_kwargs["optim"] 
= ( - self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" - ) - if self.cfg.optim_args: - if isinstance(self.cfg.optim_args, dict): - optim_args = ",".join( - [f"{key}={value}" for key, value in self.cfg.optim_args.items()] - ) - else: - optim_args = self.cfg.optim_args - training_arguments_kwargs["optim_args"] = optim_args - if self.cfg.optim_target_modules: - training_arguments_kwargs[ - "optim_target_modules" - ] = self.cfg.optim_target_modules - training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio - training_arguments_kwargs[ - "loraplus_lr_embedding" - ] = self.cfg.loraplus_lr_embedding - training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr - training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: training_arguments_kwargs["lr_scheduler_type"] = "cosine" @@ -1843,46 +1731,118 @@ def build(self, total_num_steps): if self.cfg.reward_model: trainer_kwargs["max_length"] = self.cfg.sequence_len - # pylint: disable=duplicate-code + # Handle custom optimizer if self.cfg.optimizer in [ "optimi_adamw", "ao_adamw_4bit", "ao_adamw_8bit", "ao_adamw_fp8", "adopt_adamw", + "lion_pytorch", ]: - # Set default so transformers doesn't throw - training_arguments_kwargs["optim"] = "adamw_hf" - training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer + # Common optimizer kwargs + optimizer_kwargs = { + "lr": training_arguments_kwargs.get("learning_rate"), + "weight_decay": training_arguments_kwargs.get("weight_decay"), + } - if self.cfg.optimizer == "lion_pytorch": - from lion_pytorch import Lion + # Adam-specific kwargs + adam_kwargs = { + "betas": ( + training_arguments_kwargs.get("adam_beta1"), + training_arguments_kwargs.get("adam_beta2"), + ), + "eps": training_arguments_kwargs.get("adam_epsilon"), + } - lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]} - if "weight_decay" in training_arguments_kwargs: - lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"] + if self.cfg.optimizer == "optimi_adamw": + from optimi import AdamW - if ( - "adam_beta1" in training_arguments_kwargs - and "adam_beta2" in training_arguments_kwargs - ): - lion_kwargs["betas"] = ( - training_arguments_kwargs["adam_beta1"], - training_arguments_kwargs["adam_beta2"], + optimizer_kwargs["foreach"] = False + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "ao_adamw_4bit": + from torchao.prototype.low_bit_optim import AdamW4bit + + optimizer_cls = AdamW4bit + optimizer_kwargs.update(adam_kwargs) + + LOG.warning( + f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead." + ) + elif self.cfg.optimizer == "ao_adamw_8bit": + from torchao.prototype.low_bit_optim import AdamW8bit + + optimizer_cls = AdamW8bit + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "ao_adamw_fp8": + from torchao.prototype.low_bit_optim import AdamWFp8 + + optimizer_cls = AdamWFp8 + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer = "adopt_adamw": + from axolotl.utils.optimizers.adopt import ADOPT + + optimizer_cls = ADOPT + adam_kwargs["decouple"] = True + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "lion_pytorch": + from lion_pytorch import Lion + + optimizer_cls = Lion + optimizer_kwargs.update({"betas": adam_kwargs["betas"]}) + + LOG.warning( + f"`lion_pytorch` will be deprecated soon. Please use `{OptimizerNames.LION}` instead." 
) - trainer_kwargs["optimizers"] = ( - Lion(params=self.model.parameters(), **lion_kwargs), - None, + # Parse any additional optimizer args from config + if self.cfg.optim_args: + if isinstance(self.cfg.optim_args, dict): + optimizer_kwargs.update(self.cfg.optim_args) + else: + # Parse string format "key1=value1,key2=value2" + for mapping in self.cfg.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optimizer_kwargs[key] = value + + training_arguments_kwargs["optimizer_cls_and_kwargs"] = ( + optimizer_cls, + optimizer_kwargs, ) - # Set default so transformers doesn't throw - training_arguments_kwargs["optim"] = "adamw_hf" + + else: + # Use transformers' optimizer + training_arguments_kwargs["optim"] = self.cfg.optimizer + + # Parse any additional optimizer args from config + if self.cfg.optim_args: + if isinstance(self.cfg.optim_args, dict): + optim_args = ",".join( + [f"{key}={value}" for key, value in self.cfg.optim_args.items()] + ) + else: + optim_args = self.cfg.optim_args + training_arguments_kwargs["optim_args"] = optim_args if self.cfg.optimizer == "adamw_anyprecision": if Path(self.cfg.torchdistx_path).exists(): sys.path.append(self.cfg.torchdistx_path) importlib.import_module("torchdistx") + if self.cfg.optim_target_modules: + training_arguments_kwargs[ + "optim_target_modules" + ] = self.cfg.optim_target_modules + + training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr + training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale + + training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio + training_arguments_kwargs[ + "loraplus_lr_embedding" + ] = self.cfg.loraplus_lr_embedding + if self.cfg.accelerator_config: training_arguments_kwargs[ "accelerator_config" diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 24ea62c77f..15531a1e66 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -71,6 +71,8 @@ class DeprecatedParameters(BaseModel): dpo_beta: Optional[float] = None evaluation_strategy: Optional[str] = None + alternate_optimizer: Optional[str] = None + @field_validator("max_packed_sequence_len") @classmethod def validate_max_packed_sequence_len(cls, max_packed_sequence_len): @@ -108,6 +110,14 @@ def validate_evaluation_strategy(cls, evaluation_strategy): LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead") return evaluation_strategy + @field_validator("alternate_optimizer") + @classmethod + def validate_alternate_optimizer(cls, alternate_optimizer): + if alternate_optimizer: + raise DeprecationWarning( + "alternate_optimizer is deprecated, use optimizer instead" + ) + class RemappedParameters(BaseModel): """parameters that have been remapped to other names""" diff --git a/src/axolotl/utils/optimizers/embedding_scaled.py b/src/axolotl/utils/optimizers/embedding_scaled.py new file mode 100644 index 0000000000..34ac9ae304 --- /dev/null +++ b/src/axolotl/utils/optimizers/embedding_scaled.py @@ -0,0 +1,68 @@ +""" +Scales the learning rate of the embedding layer by a factor of `embedding_lr_scale` or to `embedding_lr` if set. + +Applies weight decay to parameters in `decay_parameters` and no weight decay to the rest. 
+""" + + +def create_embedding_scaled_optimizer( + opt_model, + embedding_lr_scale, + embedding_lr, + weight_decay, + decay_parameters, + optimizer_cls, + optimizer_kwargs, +): + params = { + "embeddings": {}, # lm_head, embed_tokens, + "to_weight_decay": {}, # LayerNorm and bias + "no_weight_decay": {}, + } + + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue + if name.endswith("modules_to_save.default.weight") or any( + embed_name in name for embed_name in ["embed_tokens", "lm_head"] + ): + params["embeddings"][name] = param + elif name in decay_parameters: + params["to_weight_decay"][name] = param + else: + params["no_weight_decay"][name] = param + + optimizer_grouped_parameters = [] + if params["to_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["to_weight_decay"].values()), + "weight_decay": weight_decay, + "lr": optimizer_kwargs["lr"], + } + ) + + if params["embeddings"]: + lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + if embedding_lr_scale: + lr *= embedding_lr_scale # pylint: disable=invalid-name + elif embedding_lr: + lr = embedding_lr # pylint: disable=invalid-name + optimizer_grouped_parameters.append( + { + "params": list(params["embeddings"].values()), + "weight_decay": 0.0, + "lr": lr, + } + ) + + if params["no_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["no_weight_decay"].values()), + "weight_decay": 0.0, + "lr": optimizer_kwargs["lr"], + } + ) + + return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) From d54ed305e83167439ce80ffdf11807568a1a6be6 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 6 Nov 2024 15:02:39 +0700 Subject: [PATCH 2/6] fix: remove unneeded param --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 15531a1e66..db72019347 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -71,8 +71,6 @@ class DeprecatedParameters(BaseModel): dpo_beta: Optional[float] = None evaluation_strategy: Optional[str] = None - alternate_optimizer: Optional[str] = None - @field_validator("max_packed_sequence_len") @classmethod def validate_max_packed_sequence_len(cls, max_packed_sequence_len): From b3738184229478d9fad8a3abc31c027de0d27a81 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 9 Dec 2024 18:09:39 +0700 Subject: [PATCH 3/6] fix: missing equal, initialize empty optimizer, and lint --- src/axolotl/core/trainer_builder.py | 46 +++++++++++++---------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index e511b6df86..4e55691857 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -437,6 +437,7 @@ def __init__( self.bench_data_collator = bench_data_collator self.eval_data_collator = eval_data_collator self.dataset_tags = dataset_tags + self.optimizer = None super().__init__(*_args, **kwargs) self.train_data_collator = self.data_collator self._stored_metrics = defaultdict(lambda: defaultdict(list)) @@ -457,7 +458,8 @@ def _wrap_model(self, model, training=True, dataloader=None): def create_optimizer(self): # For all other cases, use parent implementation - if (self.args.loraplus_lr_ratio is None + if ( + self.args.loraplus_lr_ratio 
is None and self.args.embedding_lr_scale is None and self.args.embedding_lr is None ): @@ -473,14 +475,12 @@ def create_optimizer(self): ) if self.args.loraplus_lr_ratio is not None: - self.optimizer = ( - create_loraplus_optimizer( - opt_model, - optimizer_cls, - loraplus_lr_ratio=self.args.loraplus_lr_ratio, - loraplus_lr_embedding=self.args.loraplus_lr_embedding, - **optimizer_kwargs, - ) + self.optimizer = create_loraplus_optimizer( + opt_model, + optimizer_cls, + loraplus_lr_ratio=self.args.loraplus_lr_ratio, + loraplus_lr_embedding=self.args.loraplus_lr_embedding, + **optimizer_kwargs, ) elif ( self.args.embedding_lr_scale is not None @@ -493,18 +493,16 @@ def create_optimizer(self): embedding_lr_scale=self.args.embedding_lr_scale, embedding_lr=self.args.embedding_lr, decay_parameters=decay_parameters, + weight_decay=self.args.weight_decay, optimizer_cls=optimizer_cls, optimizer_kwargs=optimizer_kwargs, ) if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) + self.optimizer = smp.DistributedOptimizer(self.optimizer) return self.optimizer - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.args.sample_packing and not self.args.pretraining: if self.args.multipack_real_batches: @@ -984,9 +982,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): tag_names = ["axolotl", "dpo"] def __init__(self, *args, dataset_tags=None, **kwargs): - super().__init__(*args, **kwargs) self.dataset_tags = dataset_tags self.optimizer = None + super().__init__(*args, **kwargs) def create_optimizer(self): # For all other cases, use parent implementation @@ -999,20 +997,16 @@ def create_optimizer(self): self.args, opt_model, ) - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - create_loraplus_optimizer( - opt_model, - optimizer_cls, - loraplus_lr_ratio=self.args.loraplus_lr_ratio, - loraplus_lr_embedding=self.args.loraplus_lr_embedding, - **optimizer_kwargs, - ) + self.optimizer = create_loraplus_optimizer( + opt_model, + optimizer_cls, + loraplus_lr_ratio=self.args.loraplus_lr_ratio, + loraplus_lr_embedding=self.args.loraplus_lr_embedding, + **optimizer_kwargs, ) if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) + self.optimizer = smp.DistributedOptimizer(self.optimizer) return self.optimizer @@ -1780,7 +1774,7 @@ def build(self, total_num_steps): optimizer_cls = AdamWFp8 optimizer_kwargs.update(adam_kwargs) - elif self.cfg.optimizer = "adopt_adamw": + elif self.cfg.optimizer == "adopt_adamw": from axolotl.utils.optimizers.adopt import ADOPT optimizer_cls = ADOPT From 119514b571f0e4b266bfc13597867655706e436c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 9 Dec 2024 18:42:18 +0700 Subject: [PATCH 4/6] fix: refactor custom optimizer into enum --- src/axolotl/core/trainer_builder.py | 11 +++------ .../config/models/input/v0_4_1/__init__.py | 23 ++++++++++--------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4e55691857..428ad308fb 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -74,6 +74,7 @@ V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers from axolotl.utils.models import ensure_dtype from 
axolotl.utils.optimizers.embedding_scaled import create_embedding_scaled_optimizer from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -1726,14 +1727,8 @@ def build(self, total_num_steps): trainer_kwargs["max_length"] = self.cfg.sequence_len # Handle custom optimizer - if self.cfg.optimizer in [ - "optimi_adamw", - "ao_adamw_4bit", - "ao_adamw_8bit", - "ao_adamw_fp8", - "adopt_adamw", - "lion_pytorch", - ]: + custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers] + if self.cfg.optimizer in custom_supported_optimizers: # Common optimizer kwargs optimizer_kwargs = { "lr": training_arguments_kwargs.get("learning_rate"), diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index db72019347..65de966a34 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -62,6 +62,17 @@ class ChatTemplate(str, Enum): metharme = "metharme" # pylint: disable=invalid-name +class CustomSupportedOptimizers(str, Enum): + """Custom supported optimizers""" + + optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name + ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name + ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name + ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name + adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name + lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name + + class DeprecatedParameters(BaseModel): """configurations that are deprecated""" @@ -445,17 +456,7 @@ class HyperparametersConfig(BaseModel): embedding_lr_scale: Optional[float] = None weight_decay: Optional[float] = 0.0 optimizer: Optional[ - Union[ - OptimizerNames, - Literal[ - "lion_pytorch", - "optimi_adamw", - "ao_adamw_4bit", - "ao_adamw_8bit", - "ao_adamw_fp8", - "adopt_adamw", - ], - ] + Union[OptimizerNames, CustomSupportedOptimizers] ] = OptimizerNames.ADAMW_HF.value optim_args: Optional[Union[str, Dict[str, Any]]] = Field( default=None, From c24ca9bc388aa39fa46dd6588b68878430b2b2e6 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 9 Dec 2024 18:54:32 +0700 Subject: [PATCH 5/6] fix: remove alternate_optimizer --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 65de966a34..bb94913eb2 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -119,14 +119,6 @@ def validate_evaluation_strategy(cls, evaluation_strategy): LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead") return evaluation_strategy - @field_validator("alternate_optimizer") - @classmethod - def validate_alternate_optimizer(cls, alternate_optimizer): - if alternate_optimizer: - raise DeprecationWarning( - "alternate_optimizer is deprecated, use optimizer instead" - ) - class RemappedParameters(BaseModel): """parameters that have been remapped to other names""" From b56bec7810b1ec320455d0fb8648695c4fe36f54 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 13 Dec 2024 12:40:06 +0700 Subject: [PATCH 6/6] fix: no module found --- src/axolotl/utils/optimizers/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/axolotl/utils/optimizers/__init__.py diff --git 
a/src/axolotl/utils/optimizers/__init__.py b/src/axolotl/utils/optimizers/__init__.py new file mode 100644 index 0000000000..e69de29bb2
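
Usage sketch (illustrative, not taken from the patches; the hyperparameter values are
assumptions): after this series, a custom optimizer is still selected through the regular
`optimizer` config key, and the builder resolves it into an `optimizer_cls_and_kwargs`
tuple for the trainer instead of routing it through the removed `alternate_optimizer`
workaround. A minimal config excerpt, assuming torchao is installed:

    # custom optimizer handled by trainer_builder, not by transformers' OptimizerNames
    optimizer: ao_adamw_8bit
    # values below are illustrative, not defaults from the PR
    learning_rate: 2.0e-5
    weight_decay: 0.0
    # extra kwargs merged into the optimizer kwargs (dict form shown; "key1=value1,key2=value2" also works)
    optim_args:
      eps: 1.0e-8

With this config the builder resolves roughly to
(torchao.prototype.low_bit_optim.AdamW8bit, {"lr": 2.0e-5, "weight_decay": 0.0,
"betas": (adam_beta1, adam_beta2), "eps": 1e-8}) and passes that tuple along as
`optimizer_cls_and_kwargs`, so the trainer instantiates the optimizer itself rather than
axolotl special-casing it inside `create_optimizer`.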