8-bit inference (#512) #513

Open · wants to merge 6 commits into base: main

Conversation

glerzing (Contributor)

No description provided.

@Dahoas (Collaborator) commented Jul 10, 2023

@glerzing Do you have an example run using 8-bit?

@glerzing (Contributor, Author)

There are a few things to improve, I'm working on it. I'll also add an example.

@PhungVanDuy (Collaborator)

> There are a few things to improve, I'm working on it. I'll also add an example.

@glerzing Thank you for the great PR. Do you have any update on this, or anything you need help with?

@glerzing (Contributor, Author) commented Jul 17, 2023

I added from_pretrained_kwargs to the model config to add some flexibility in how the model is loaded.
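For illustration, here is a minimal sketch of how the new field could be used in an example script (assuming from_pretrained_kwargs is a dict on the model config that is forwarded to transformers' from_pretrained, and using the default_ppo_config helper from the existing examples; the exact kwargs are just an example):

```python
import torch
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.model.model_path = "lvwerra/gpt2-imdb"
# Forwarded as-is to AutoModelForCausalLM.from_pretrained, so any
# loading option (quantization, dtype, device placement) fits here.
config.model.from_pretrained_kwargs = {
    "load_in_8bit": True,   # requires bitsandbytes + accelerate
    "device_map": "auto",
}
```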

When testing, I ran into 2 problems with ppo_sentiments_8bit.py when it executes the generate function:

1. RuntimeError: probability tensor contains either 'inf', 'nan' or element < 0 (with the latest library versions)
2. RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half (with the versions listed in requirements.txt)

In both cases, it doesn't look related to trlx. Quantization can introduce bugs because it additionally relies on accelerate and bitsandbytes, which have their own dependencies, and there can be conflicts between the versions of different libraries. With the library versions listed in requirements.txt, I run into the 2nd problem; with the latest versions, I run into the 1st one.

@Dahoas (Collaborator) commented Jul 21, 2023

@PhungVanDuy If you have time, can you help debug this? I think having lower-precision inference and training options would be very useful.

@Dahoas (Collaborator) commented Jul 21, 2023

@glerzing Are you able to get quantized model inference working with our package requirements? (but without any training)

@glerzing (Contributor, Author) commented Jul 21, 2023

No. With version 4.28.1 of the transformers library, as pinned in trlx, I get RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half. With >= 4.30.0, I get RuntimeError: probability tensor contains either 'inf', 'nan' or element < 0, which happens further along in the processing (my guess is that this bug is also present with version 4.28.1, but the processing doesn't get that far).

@glerzing (Contributor, Author)

Actually, adding the argument torch_dtype=torch.bfloat16 to from_pretrained and using a more recent version of the transformers library solves the issue and makes it possible to run ppo_sentiments_8bit.py.
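Concretely, the loading path that works amounts to something like this (a sketch of the underlying transformers call; the model name is just the one from the sentiment examples):

```python
import torch
from transformers import AutoModelForCausalLM

# transformers >= 4.30 plus a bf16 compute dtype avoids the fp16
# sampling errors described above when the model is loaded in 8-bit.
model = AutoModelForCausalLM.from_pretrained(
    "lvwerra/gpt2-imdb",
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```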

@PhungVanDuy (Collaborator)

@glerzing @Dahoas I tried to run inference with 8-bit, but I don't think this approach makes inference faster:
https://wandb.ai/pvduy/trlx/reports/8bit-Sentiment-Rollout--Vmlldzo0OTUxOTM5

This is also mentioned by the author here:

> The main purpose of the LLM.int8() method is to make large models more accessible without performance degradation. But the method would be less useful if it is very slow. So we benchmarked the generation speed of multiple models. We find that BLOOM-176B with LLM.int8() is about 15% to 23% slower than the fp16 version – which is still quite acceptable.

Let's come up with another idea, like using vLLM; in my experiments vLLM actually speeds up inference. I will work in that direction.
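To make the vLLM idea concrete, here is a minimal offline-generation sketch using its public API (model and sampling parameters are placeholders; this only illustrates the rollout-generation side, not an actual trlx integration):

```python
from vllm import LLM, SamplingParams

# vLLM batches requests and manages the KV cache (PagedAttention) itself.
llm = LLM(model="lvwerra/gpt2-imdb")
params = SamplingParams(temperature=1.0, top_p=0.95, max_tokens=64)

prompts = ["The movie was", "I really disliked"]
for out in llm.generate(prompts, params):
    print(out.prompt, "->", out.outputs[0].text)
```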

@Dahoas (Collaborator) commented Jul 24, 2023

Thanks for checking this. Were you able to run this experiment with trlX's pinned transformers version, or will we need to update it?

On the inference speedup side, vLLM seems like a good idea. In general, implementing some kind of asynchronous PPO like V-trace seems promising.

@PhungVanDuy (Collaborator)

> Thanks for checking this. Were you able to run this experiment with trlX's pinned transformers version, or will we need to update it?
>
> On the inference speedup side, vLLM seems like a good idea. In general, implementing some kind of asynchronous PPO like V-trace seems promising.

I have to update that one. I guess we should also update the transformers version to support LLaMA 2.

I am checking vLLM to see how hard it would be to integrate. Thank you for your suggestion on asynchronous PPO.

@glerzing (Contributor, Author) commented Aug 7, 2023

I was wondering if there should be an example of how to train 16-bit models, because now that there is the config argument from_pretrained_kwargs, you can easily set torch_dtype=torch.bfloat16, which may not be obvious to newcomers. On the other hand, I'm not sure whether it's worth adding another file, ppo_sentiments_16bit.py, just to show that we can easily do that.
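If we document it instead of adding a file, the bf16 case could be shown as a two-line tweak on top of the existing ppo_sentiments.py config (same assumption as above, namely that from_pretrained_kwargs is forwarded to from_pretrained):

```python
import torch
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
# Load the base model in bfloat16 instead of the default fp32.
config.model.from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
```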

@Dahoas (Collaborator) commented Aug 28, 2023

@glerzing Checking in on the state of this PR. Do you have any more features you would like to add? If not, let's get it merged sometime this week.
