
Add get optimizer method #5149

Closed

wants to merge 4 commits

Conversation

ShellyNR
Contributor

Add a get_optimizer method to the accelerators for use during optimizer configuration, instead of checking for the specific accelerator in the engine.py code.

In all accelerators other than HPU, the current implementation returns None so that the previous flow is not affected.

In hpu_accelerator.py, the method returns an optimizer if the configured optimizer has an HPU-specific implementation or requirements; otherwise, it returns None.
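For illustration, here is a minimal sketch of that interface, assuming the get_optimizer signature shown in the diff further down; the default return value and the engine-side fallback are paraphrased, not the PR's exact code:

```python
# Sketch only: default accelerator behavior and how engine.py could consume it.
# The method name and signature follow this PR; the class name and the commented
# engine-side fallback are simplified for illustration.

class SomeAccelerator:
    def get_optimizer(self, optimizer_name, cpu_optimization, model_parameters, **optimizer_parameters):
        # All accelerators except HPU: no accelerator-specific optimizer, defer to engine.py.
        return None

# In engine.py (paraphrased): ask the accelerator first, fall back to the previous flow.
# optimizer = get_accelerator().get_optimizer(name, cpu_optimization, params, **kwargs)
# if optimizer is None:
#     optimizer = self._configure_basic_optimizer(params)
```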

@ShellyNR
Contributor Author

@ShellyNR please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree

@@ -294,3 +294,22 @@ def build_extension(self):

    def export_envs(self):
        return []

    def get_optimizer(self, optimizer_name, cpu_optimization, model_parameters, **optimizer_parameters):
        from habana_frameworks.torch.hpex.optimizers import FusedAdamW
Contributor

I am not convinced that this much code is needed.

It seems the goal is to replace AdamW on the HPU accelerator as below:

replace from deepspeed.ops.adam import FusedAdam
with from habana_frameworks.torch.hpex.optimizers import FusedAdamW

Is this correct?

Contributor

@tjruwase Yes, but also to prepare the ground for other fused optimizers that are addressed in this function, for example FusedLamb, OneBit, etc.
Do you have something else in mind?
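As a rough illustration of that direction (the optimizer-name strings and the cpu_optimization check are assumptions, not the PR's actual HPU code), the HPU override could dispatch on the optimizer name and later grow to cover other fused optimizers:

```python
class HPUAcceleratorSketch:
    """Hypothetical stand-in for the accelerator class in hpu_accelerator.py."""

    def get_optimizer(self, optimizer_name, cpu_optimization, model_parameters, **optimizer_parameters):
        if cpu_optimization:
            return None  # CPU-offload optimizers keep the existing engine.py path
        if optimizer_name in ("adam", "adamw"):
            from habana_frameworks.torch.hpex.optimizers import FusedAdamW
            return FusedAdamW(model_parameters, **optimizer_parameters)
        # Other HPU-specific fused optimizers (e.g., a Lamb variant) could be added here.
        return None
```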

Contributor

tjruwase commented Feb 29, 2024

@nelyahu, thanks for the clarification. I do agree that deferring to the accelerator to instantiate a fused optimizer when available is the way to go. My concern is that duplicating engine.py's optimizer selection logic here will be difficult to maintain long term.

I think the complication is that engine.py mixes optimizer name refinement (especially for Adam variants) with optimizer instantiation. By name refinement, I mean that something like ADAM_OPTIMIZER, which is a user-facing config value, internally maps to one of many optimizers (e.g., torch.optim.Adam, DeepSpeedCPUAdam, etc.). So, my thought is to:

  1. Create an internal naming convention for the various optimizer instantiations. For example,
  • _TORCH_ADAM_OPTIMIZER for torch.optim.Adam,
  • _DS_FUSED_ADAM for deepspeed.ops.adam.FusedAdam,
  • _DS_ADAM_OPTIMIZER for deepspeed.ops.adam.DeepSpeedCPUAdam.
  2. Create a function that maps an external name to an internal one based on other configuration values. For example, ADAM_OPTIMIZER would map to one of _TORCH_ADAM, _DS_FUSED_ADAM, _DS_CPUADAM.

Based on the above, the following changes could now be made:

  1. Simplify get_optimizer() to accept an internal optimizer name and create an optimizer if the name is one of the supported internal optimizer names. Also, cpu_optimization is no longer needed.
  2. Simplify _configure_basic_optimizer() to (i) convert to the internal name, (ii) call accelerator().get_optimizer() with the internal name to create the optimizer, and (iii) otherwise fall through to the existing optimizer creation logic.

Looking forward to your feedback. Thanks!
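A minimal sketch of the name-refinement idea from the list above; the constants, helper name, and the configuration flags it inspects are illustrative assumptions, not existing DeepSpeed API:

```python
# Illustrative only: internal names and the refinement helper are assumptions.
_TORCH_ADAM = "_TORCH_ADAM"
_DS_FUSED_ADAM = "_DS_FUSED_ADAM"
_DS_CPU_ADAM = "_DS_CPU_ADAM"


def refine_optimizer_name(external_name, cpu_offload=False, fused_available=True):
    """Map a user-facing optimizer name (e.g. the value behind ADAM_OPTIMIZER) to an internal one."""
    if external_name == "adam":
        if cpu_offload:
            return _DS_CPU_ADAM      # deepspeed.ops.adam.DeepSpeedCPUAdam
        if fused_available:
            return _DS_FUSED_ADAM    # deepspeed.ops.adam.FusedAdam
        return _TORCH_ADAM           # torch.optim.Adam
    return external_name             # names that need no refinement pass through
```

With a mapping like this, each accelerator's get_optimizer() only needs to recognize internal names, and _configure_basic_optimizer() can perform the refinement once before delegating.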

Contributor Author

@tjruwase I applied the changes you suggested; however, some optimizers still require specific code. Which one of the commits do you think is best?

Contributor

tjruwase commented Mar 12, 2024

@ShellyNR, thanks for making the changes. This looks better to me.

When you say 'some optimizers still require specific code', is that referring to _ADAM in hpu_accelerator, which maps to torch.optim.Adam()? If so, I think that is fine. My goal is to avoid duplication between accelerator code and engine.py.

ShellyNR force-pushed the add_get_optimizer branch from 03ee38a to 1be755b on March 11, 2024 15:36
"_TORCH_ADAMW": lambda arg1, **arg2: torch.optim.AdamW(arg1, **arg2),
"_CPU_ADAM": lambda arg1, **arg2: DeepSpeedCPUAdam(arg1, **arg2, adamw_mode=False),
"_CPU_ADAMW": lambda arg1, **arg2: DeepSpeedCPUAdam(arg1, **arg2, adamw_mode=True),
"_ADAM": lambda arg1, **arg2: FusedAdam(arg1, **arg2, adam_w_mode=False),
Contributor

Thanks for creating this dict. Based on this, I think mappings of cuda-specific optimizers, such as Fused[Adam|Lamb|Lion], should move to cuda_accelerator.py.

What do you think?
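For illustration, moving the CUDA-only entries into cuda_accelerator.py might look roughly like the following; the class, method placement, and the set of keys handled are a sketch under the naming used in the dict quoted above, not the PR's code:

```python
class CUDAAcceleratorSketch:
    """Hypothetical sketch of a CUDA-side get_optimizer in cuda_accelerator.py."""

    def get_optimizer(self, optimizer_name, model_parameters, **optimizer_parameters):
        # Imported lazily so non-CUDA setups never touch these ops.
        from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

        builders = {
            "_ADAM": lambda params, **kw: FusedAdam(params, **kw, adam_w_mode=False),
            "_ADAMW": lambda params, **kw: FusedAdam(params, **kw, adam_w_mode=True),
            "_CPU_ADAM": lambda params, **kw: DeepSpeedCPUAdam(params, **kw, adamw_mode=False),
            "_CPU_ADAMW": lambda params, **kw: DeepSpeedCPUAdam(params, **kw, adamw_mode=True),
        }
        builder = builders.get(optimizer_name)
        # Unrecognized names fall back to engine.py's existing creation logic.
        return builder(model_parameters, **optimizer_parameters) if builder else None
```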

@tjruwase
Contributor

@ShellyNR, thanks for your work on this PR. I wanted to check whether you need more clarification to address my comments.

@nelyahu
Contributor

nelyahu commented Apr 15, 2024

@tjruwase, sorry for not updating this PR; @ShellyNR was OOO for a few days. We are considering another direction for supporting FusedAdam via the existing HPU implementation.

@tjruwase
Contributor


@nelyahu, thanks for the update. I am excited to see the new approach :).

@ShellyNR
Contributor Author

I'm closing this PR; a new PR that addresses this issue without using a get_optimizer method will be pushed soon.
@tjruwase thanks for your review and comments (:

ShellyNR closed this on Apr 17, 2024