Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ddp #32

Merged
merged 10 commits into from
Jan 25, 2024
Merged

Ddp #32

merged 10 commits into from
Jan 25, 2024

Conversation

LauraGPT
Copy link
Collaborator

What does this PR do?

Fixes # (issue)

Feature/Issue validation/testing

Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

  • Test A
    Logs for Test A

  • Test B
    Logs for Test B

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Thanks for contributing 🎉!


train_config:
model_name: "PATH/to/LLAMA/7B"
enable_fsdp: false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enable_ddp: false

++train_config.lr=1e-4 \
++train_config.output_dir=$output_dir \
++train_config.peft_config.peft_method=lora \
++metric=acc \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metric to be classified to a certain class

++train_config.enable_fsdp=true \
++train_config.enable_ddp=false \
++train_config.use_fp16=true \
++metric=acc \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metric to be classified to a certain class

use_fp16: false
# sharding_strategy: "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
sharding_strategy: "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD to use DDP mode in FSDP
checkpoint_type: "StateDictType.SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To unify with only the str name, say, "SHARDED_STATE_DICT"

@@ -229,12 +233,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
)

else:
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
if not train_config.use_peft and fsdp_config.checkpoint_type == "StateDictType.FULL_STATE_DICT":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to modify


save_model_checkpoint(
model, optimizer, rank, train_config, epoch=epoch
)
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
elif not train_config.use_peft and fsdp_config.checkpoint_type == "StateDictType.SHARDED_STATE_DICT":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to modify

@ddlBoJack ddlBoJack merged commit 6d30313 into main Jan 25, 2024
2 of 4 checks passed
@LauraGPT LauraGPT deleted the ddp branch February 4, 2024 03:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants