-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modernize MosaicBERT #440
base: main
Are you sure you want to change the base?
Modernize MosaicBERT #440
Conversation
617db70
to
c9ee668
Compare
c9ee668
to
b809a7b
Compare
if convert_dtype: | ||
# Triton implementation only supports fp16 and bf16 | ||
orig_dtype = qkv.dtype | ||
qkv = qkv.to(torch.float16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this to be in torch.float16
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we do not, this code was here before though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How should we select between bfloat16 and float16 though?
@@ -266,8 +261,6 @@ def build_text_dataloader( | |||
cfg.dataset.get('validate_hash', None), | |||
keep_zip=stream.get('keep_zip', None) or | |||
cfg.dataset.get('keep_zip', False), | |||
keep_raw=stream.get('keep_raw', None) or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just noting that this is correct and that keep_raw
is no longer a flag in mosaicml-streaming
(see Streaming docs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you check that the defaults here match the defaults currently set in llm foundry?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The defaults in llm foundry are a bit different. Should we update this function whole-hog?
From llmfoundry text_data.py
def __init__(self,
tokenizer: PreTrainedTokenizerBase,
max_seq_len: int,
streams: Optional[Sequence[Stream]] = None,
remote: Optional[str] = None,
local: Optional[str] = None,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[Union[int, str]] = None,
predownload: Optional[int] = None,
cache_limit: Optional[Union[int, str]] = None,
partition_algo: str = 'relaxed',
num_canonical_nodes: Optional[int] = None,
batch_size: Optional[int] = None,
shuffle: bool = False,
shuffle_algo: str = 'py1e',
shuffle_seed: int = 9176,
shuffle_block_size: Optional[int] = None,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
**kwargs: Any):
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it's still text data, this should be good!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just linting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just linting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just linting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just linting
Should be close to done @dakinggg, the two failed pytests were
|
@@ -425,6 +499,7 @@ def __init__(self, config): | |||
(1, self.num_attention_heads, self._current_alibi_size, | |||
self._current_alibi_size)) | |||
self.rebuild_alibi_tensor(size=config.alibi_starting_size) | |||
self.slopes = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @Skylion007 many thanks for this PR! I am currently testing it (with own dataset) and training is working (8x H100).
I had to remove this line, because:
this.slopes
is set in therebuild_alibi_tensor
function before- it is later needed in line 583
Setting to None
will then cause an error in line 583.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was on me trying to appease the linting gods. Thanks for catching! Should be removed now
UPDATE on 1/8/24: This was not an issue for me on a clean machine, so this is unlikely to be a real issue, and VERY unlikely to be an issue with this PR. ============== Details: Env: (This is in WSL for Windows, but most of the time that's equivalent to a Ubuntu environment, and I don't think it's the source of this error.) I just checked out the branch and created a clean conda env. Then, I did the
I tried adding so I
re-ran
FA2 is assuming that torch is already installed, but it's being installed as a sibling, so it's not a module yet! Then I installed FA2 by running that: This |
One more bug that I'll report here just in case it is not just a "my machine" thing. I didn't see NVidia Apex mentioned on the requirements, but when I get to the point where I am running this:
It looks like I need to have NVidia Apex installed:
|
An update on the above: Once I installed Apex from source, the command worked. You have already recommended the MosaicML Pytorch base image, which presumably comes with Apex pre-installed. I decided to ignore that handy tip and run from my existing WSL environment. Something that would have helped me would be to clarify that if the user does not use the recommended Pytorch base image, they will need to install Apex after pip installing the requirements.txt. If I'm not the target audience, or this is opening you up to way too much config specification, I get it. |
With regards to my comment :
This was not an issue for me on a clean machine, so this is unlikely to be a real issue, and VERY unlikely to be an issue with this PR. |
I believe that one of the test yamls is missing: algorithms:
fused_layernorm: {} I say that because in the README, it explains you can do a test run of training a Mosaic model by running: # Run the pre-training script with the test config and MosaicBERT
composer main.py yamls/test/main.yaml model.name=mosaic_bert However, yamls/test/main.yaml doesn't have these lines: algorithms:
fused_layernorm: {} But That means that the first time it tries to load Apex's fused_layernorm is when you get to this section:
I noticed this because I got an error when it tried to load Apex and my environment didn't have it installed. I was surprised because all of my "tests" from the README worked. |
Hi @Taytay, Thanks for pointing this out. The MosaicML Composer library for a while used Fused Layernorm as a Composer "algorithm" to speed up pretraining. It relies on NVIDIA Apex and enables a faster kernel for LayerNorm. More recently, we've been using Low Precision LayerNorm which does not rely on APEX and works just as well as Fused LayerNorm. From the Composer docs:
In the yaml, you can replace algorithms:
low_precision_layernorm: {} I've updated the mosaicbert pretraining and finetuning yamls to use |
Thanks @jacobfulano. That's good news. It's worth mentioning that I ran into a bug in this branch that is fixed by #443 |
This PR modernizes the MosaicBERT codebase with Flash Attention 2, PyTorch 2 (
torch==2.1.1
), and an updated version of composer (mosaicml>=0.17
).In particular, this updates MosaicBERT to be compatible with Flash Attention 2 (
flash-attn==4.2.4
), which now supports ALiBi slopes (PR#540).Context:
triton
in https://github.com/mosaicml/examples/blob/v0.0.4/examples/bert/src/flash_attn_triton.py. This version oftriton
also required PyTorch 1.13. This is also the kernel used for the MosaicBERT NeurIPS submission.triton
implementationSee w&b runs here
Note that changes to files outside of
examples/benchmarks/bert
are simply formatting changes due to linting.