Conversation style appears tied to the dataset rather than the model #2096
@RdoubleA do you mind taking a look?
@RonanKMcGovern thanks for creating the issue. Rafi will probably have the best answer here, but I can weigh in as well. First, one thing worth clarifying: the use of conversation_style is primarily to get data from its raw format into an intermediate format that can be understood by torchtune. E.g. you can see the input and output formats of ShareGPTToMessages here: https://pytorch.org/torchtune/main/generated/torchtune.data.ShareGPTToMessages.html -- note that nothing is actually being tokenized, nor do we do anything involving special tokens.

> What happens if the tokenizer+model do not have the tokens required for a given conversation style? Are those special tokens created? I assume not.

Your understanding is correct. In general if we are formatting the prompt in a certain way I'm not sure it would be easy to infer what formatting comes from special tokens vs what doesn't. Also:

> Is there an option whereby one can default to using tokenizer.chat_template for the conversation style? (most models on huggingface have this defined)

This we don't currently support (mainly because it is hard to map 1:1 from Hugging Face to torchtune tokenization logic, though this is likely something we can work to improve).

> The practical task I'm interested in is fine-tuning llama 3 and qwen 2.5 using conversation styles that match their chat templates (so as to minimise the re-training/over-writing that I'm doing).

There are different entry points depending on the degree of customization you're looking for, but most of them should be accessible directly through the tokenizer (i.e. not through the dataset, so you do have access to any special tokens).

1. As a first step, you may just be able to use the tokenizer as-is -- all our model tokenizers have a tokenize_messages method which will take in a list of messages (i.e. those returned by the conversation_style mentioned above) and tokenize them with the default formatting and special tokens for that model. E.g. for Llama3 you can just use llama3_tokenizer (https://pytorch.org/torchtune/main/generated/torchtune.models.llama3.llama3_tokenizer.html#llama3-tokenizer), and you can see here https://github.com/pytorch/torchtune/blob/32e265d5749fd592711a03247486eafa6c898d94/torchtune/models/llama3/_tokenizer.py#L261 that its tokenize_messages method will iterate over the messages and apply any formatting with special tokens that is unique to Llama 3 (see e.g. the standard Llama 3 chat header added here: https://github.com/pytorch/torchtune/blob/32e265d5749fd592711a03247486eafa6c898d94/torchtune/models/llama3/_tokenizer.py#L199-L208).

2. If you want to do some formatting of the prompt you can use a predefined prompt template or write your own (see this page in our live docs for how to do this: https://pytorch.org/torchtune/main/basics/prompt_templates.html). This then plugs into the tokenizer to format each message before it gets tokenized. You can also combine this with the special_tokens_path tokenizer argument (e.g. see the Qwen 2.5 tokenizer API reference: https://pytorch.org/torchtune/main/generated/torchtune.models.qwen2_5.qwen2_5_tokenizer.html#torchtune.models.qwen2_5.qwen2_5_tokenizer). Combining both of these should get you custom formatting with any additional special tokens, but just in case..

3. Finally, if you find that the existing tokenizer class is just not cutting it or there's some additional customization you need, you can always modify the tokenizer yourself! Probably the quickest way to do this is to inherit from whatever model tokenizer class you'd like to modify (e.g. Llama3Tokenizer: https://github.com/pytorch/torchtune/blob/32e265d5749fd592711a03247486eafa6c898d94/torchtune/models/llama3/_tokenizer.py#L43) and override tokenize_messages or any other methods you need to customize. But this should be more of a last resort -- the hope is that through appropriate customization of prompt_template and/or special_tokens in (2) you never need to do this for an existing model.

P.S. sorry, I am realizing that some of the model tokenizer classes are not linked in our API reference and they contain a lot of important info (e.g. the details about Llama3 here: https://github.com/pytorch/torchtune/blob/32e265d5749fd592711a03247486eafa6c898d94/torchtune/models/llama3/_tokenizer.py#L43-L63). So will make sure we update that!
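To make option (1) above concrete, a minimal sketch of calling tokenize_messages directly; the tokenizer path is a placeholder and exact keyword arguments may vary slightly between torchtune versions:

```python
from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

# Placeholder path to the downloaded Llama 3 tiktoken model file.
tokenizer = llama3_tokenizer(path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")

messages = [
    Message(role="system", content="You are a helpful assistant."),
    Message(role="user", content="What is the capital of France?"),
    Message(role="assistant", content="Paris."),
]

# Returns token ids plus a parallel boolean mask marking positions excluded from the loss.
tokens, mask = tokenizer.tokenize_messages(messages)
print(len(tokens), mask[:5])
```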
Thanks very much for your detailed response. I'll need to dig in a bit more, but this is giving me structure for finding my feet.

For background, I'm coming to this from a transformers/unsloth perspective, where I'm used to:

a) models having a tokenizer.chat_template;
b) either i) starting with data that is formatted as an array of system/user/assistant messages, OR ii) data that is in columns, say of "question" and "answer", or just "text";
c) either the trainer automatically applies the chat template as in b.i, OR I set up a formatting function for b.ii (which may just make use of the chat_template) to convert data in columns into a single text string.
If using a chat_template, it will typically add an EOS-type token automatically at the end of assistant responses, which ensures the model knows when to stop (and this token is typically not masked). I note in the torchtune docs that BOS and EOS are typically masked, and it seems there is then an option to add an additional ending token, which I presume is not masked.
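Concretely, the transformers-side flow I'm describing looks roughly like this (the model id is just an example, and I'm printing the untokenized template output):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

# The chat template turns the message list into a single training string,
# typically ending the assistant turn with an EOS/EOT-style token that is not loss-masked.
text = tok.apply_chat_template(messages, tokenize=False)
print(text)
```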
Lastly, there is the matter of whether one trains on completions only or not. This can be a bit messy with transformers/unsloth and often requires identifying the portion of the chat template that delineates the start of the assistant response, and using that to find where to position the loss mask.

I don't think the above approach is necessarily better than torchtune's. I'm just sharing it for context (apologies if already obvious).
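To illustrate the delimiter-based masking I mean, here is a generic sketch (not any particular library's collator); header_ids stands in for whatever token sequence your chat template uses to open the assistant turn:

```python
IGNORE_INDEX = -100  # PyTorch cross-entropy's default ignore_index

def completion_only_labels(input_ids: list[int], header_ids: list[int]) -> list[int]:
    # Find the last occurrence of the assistant header and mask everything up to and including it.
    for start in range(len(input_ids) - len(header_ids), -1, -1):
        if input_ids[start : start + len(header_ids)] == header_ids:
            boundary = start + len(header_ids)
            return [IGNORE_INDEX] * boundary + input_ids[boundary:]
    return [IGNORE_INDEX] * len(input_ids)  # header not found: mask the whole sequence

# Toy example: header token ids are [1, 2, 3]; only the completion tokens keep their labels.
print(completion_only_labels([5, 1, 2, 3, 7, 8, 9], header_ids=[1, 2, 3]))
# -> [-100, -100, -100, -100, 7, 8, 9]
```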
So, moving to torchtune, it seems like the philosophy is that models are more hard-coded (i.e. the prompt template is already there for common models in the fine-tuning library). For most use cases this is fine (or probably easier, because one can easily toggle completions-only training on or off, since the library has a firm understanding of the model/tokenizer syntax). For edge cases where the prompt is being changed and special tokens introduced (say, custom tool calling that doesn't follow the base model approach), this would require care [and that is fine].
Something I think is worth some thought is ensuring that the final tokenizer.json does have a tokenizer.chat_template that is consistent with how the training was done. This (unless I'm mistaken) is what vLLM uses to apply a chat template for inference.
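As a hypothetical spot-check of that consistency (model id and tokenizer path are placeholders, and the two tokenizations may legitimately differ in details like trailing EOS/EOT handling):

```python
from transformers import AutoTokenizer
from torchtune.data import Message
from torchtune.models.llama3 import llama3_tokenizer

hf_tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tt_tok = llama3_tokenizer(path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")

chat = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]

# Tokenize the same conversation with the HF chat template and with the torchtune tokenizer.
hf_ids = hf_tok.apply_chat_template(chat, tokenize=True)
tt_ids, _ = tt_tok.tokenize_messages([Message(role=m["role"], content=m["content"]) for m in chat])

print(hf_ids == tt_ids)  # a mismatch flags a template/tokenizer inconsistency worth inspecting
```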
Anyway, I'm keen to try out torchtune, mostly for two reasons: i) to see training speed, and ii) since it allows for multi-GPU quite easily (DDP and FSDP are messier on transformers and not possible on unsloth).
One last side-note: there was a major issue in transformers/unsloth whereby gradient accumulation was naively adding gradients without properly normalising for the number of unmasked tokens and sequence length. This led to major error in the loss when accumulating gradients. I read through the torchtune code today and I think it is correct, but it would probably be wise for someone smart on torchtune to read this if they haven't already: https://unsloth.ai/blog/gradient.
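To make the normalization issue concrete, a toy illustration (plain PyTorch, not code from any of the libraries mentioned) contrasting per-micro-batch averaging with normalizing by the total number of unmasked tokens:

```python
import torch

# Two micro-batches in one accumulation window: a short sequence with high per-token loss
# and a long sequence with low per-token loss. All tokens unmasked here for simplicity.
losses = [torch.full((2,), 4.0), torch.full((10,), 1.0)]
masks = [torch.ones(2), torch.ones(10)]

# Naive: average each micro-batch separately, then average the averages.
naive = torch.stack([(l * m).sum() / m.sum() for l, m in zip(losses, masks)]).mean()

# Correct: sum the masked token losses, divide by the total unmasked-token count.
correct = sum((l * m).sum() for l, m in zip(losses, masks)) / sum(m.sum() for m in masks)

print(naive.item())    # 2.5 -- short sequences are over-weighted
print(correct.item())  # 1.5 -- each token contributes equally
```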
@RonanKMcGovern to address some of these comments individually:
b) is good as a starting point for torchtune too. In both i) and ii) the flow is generally the same: given a raw data format, first apply a transform specific to your particular dataset to get the data into a standard format recognized by torchtune (i.e. a list of Message objects), then apply any custom formatting or model-specific logic (including tokenization).

So for case i) you would probably want to use either ShareGPTToMessages or OpenAIToMessages for the first step, depending on your exact input format (and of course if it's in a non-standard format you can always write your own version of these). I would split case ii) further: the case of multiple columns (e.g. "question" and "answer") should fall under instruct_dataset. This comes readymade with a transform into the message format: InputOutputToMessages. In this case you would just need to pass the appropriate column_map.

Then (c) basically corresponds to the second step -- apply custom formatting, tokenization, etc., anything that's unique to the model. I think I already covered this in my last comment, but lmk if there's more you're unclear on here.
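For example, a rough sketch of the instruct_dataset path (the dataset id, column names, and tokenizer path are placeholders; check the current torchtune docs for the exact signature):

```python
from torchtune.datasets import instruct_dataset
from torchtune.models.llama3 import llama3_tokenizer

# Placeholder tokenizer path.
tokenizer = llama3_tokenizer(path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")

ds = instruct_dataset(
    tokenizer=tokenizer,
    source="my_org/my_qa_dataset",  # hypothetical Hugging Face dataset id
    # Map the dataset's columns onto the "input"/"output" fields InputOutputToMessages expects.
    column_map={"input": "question", "output": "answer"},
    train_on_input=False,  # mask the prompt so loss is computed on completions only
    split="train",
)
```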
Yes, this is all determined by the tokenizer. We do always mask BOS and EOS, and in e.g. Llama3 there are tokens like EOT and EOM that will not be masked.
Yes, at least for our chat- and instruct-style datasets we have the train_on_input flag to control whether the prompt tokens are included in the loss.
Yeah this is a good point. Right now we do not really maintain a standard mapping between our tokenizer and Hugging Face chat templates, but we are planning on providing cleaner integration with vLLM soon so this is likely something we will need to support.
For training speed, I would recommend running with torch compile (set compile=True in your config).
Yes, thanks for mentioning this. We fixed this in #1917 (this also fixes the same problem for distributed training, where the accumulation of the number of tokens seen needs to be taken over all ranks to properly normalize the loss).
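The gist of that distributed normalization (just the idea, not the actual recipe code from #1917) is something like:

```python
import torch
import torch.distributed as dist

def normalize_loss(running_loss: torch.Tensor, local_num_tokens: torch.Tensor) -> torch.Tensor:
    """Divide the locally accumulated loss sum by the number of unmasked tokens seen on ALL ranks."""
    total_tokens = local_num_tokens.clone()
    dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM)  # token count must be global, not per-rank
    return running_loss / total_tokens
```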
Original issue description:

If I'm not mistaken, the conversation style that applies during a fine-tune is defined by the dataset defaults, rather than by the tokenizer being used (docs here).
What happens if the tokenizer+model do not have the tokens required for a given conversation style? Are those special tokens created? I assume not.
Is there an option whereby one can default to using tokenizer.chat_template for the conversation style? (most models on huggingface have this defined)
I'm guessing one issue here is that, since tokenizer.chat_template is not known in advance, it is hard to control the loss mask on the prompt vs. completions?
So maybe that's the dilemma? Either one can:
a) load a default conversation style from the model/tokenizer, but then it's hard to implement loss masks, or
b) load the default conversation style based on the dataset choice, but then there is a risk of token incompatibilities with the model/tokenizer being trained.
The practical task I'm interested in is fine-tuning llama 3 and qwen 2.5 using conversation styles that match their chat templates (so as to minimise the re-training/over-writing that I'm doing).