forked from AkariAsai/OpenScholar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_sft.py
149 lines (124 loc) · 7.46 KB
/
_sft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Mapping, Optional
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, PromptTemplate
from torchtune.modules.transforms import Transform
class SFTDataset(Dataset):
    """
    Primary class for creating any dataset for supervised fine-tuning either from
    Hugging Face Hub, local files, or remote files. This class supports instruct,
    chat, tool, or multimodal data for fine-tuning. At a high level, this class
    will load the data from source and apply the following pre-processing steps
    when a sample is retrieved:

    1. Dataset-specific transform. This is typically unique to each dataset and extracts
       the necessary columns into torchtune's :class:`~torchtune.data.Message` format,
       a standardized API for all model tokenizers.
    2. If specified, apply a prompt template for the task you are fine-tuning for.
    3. Model-specific transform or tokenization

    All datasets are formatted into a list of :class:`~torchtune.data.Message`
    because for fine-tuning, datasets can be considered as "conversations" with the model,
    or AI assistant. Thus, we can standardize all text content as messages in a conversation
    assigned to a role:

    - ``"system"`` messages contain the system prompt
    - ``"user"`` messages contain the input prompt into the model
    - ``"assistant"`` messages are the response of the model and what you actually want
      to train for and compute loss directly against
    - ``"ipython"`` messages are the return from a tool call

    Chat datasets are multiple rounds of user-assistant messages. Instruct datasets
    are typically a single round involving a specific instruction and the model's response.
    Tool datasets are a type of chat dataset that includes ipython messages. Multimodal
    datasets are a type of chat dataset that incorporates media into the user messages.

    The :class:`~torchtune.data.Message` forms the core data unit that all tokenizer
    APIs expect. The key component of this class that ensures any dataset is transformed
    into this format is the ``message_transform``. This is a callable class that takes
    in a sample dictionary - typically a single row from the source dataset - that
    processes the sample in any configurable way to output a list of messages::

        [
            Message(
                role=<system|user|assistant|ipython>,
                content=<message>,
            ),
            ...
        ]

    For any custom dataset, use the ``message_transform`` to contain all pre-processing to
    return the list of messages.

    Any model-specific pre-processing that needs to happen can be configured with the
    ``model_transform`` parameter. This is another callable class that contains any custom
    logic tied to the model you are fine-tuning and will carry over to inference. For
    example, text + image multimodal datasets requires processing the images in a way
    specific to the vision encoder being used by the model and is agnostic to the specific
    dataset.

    Tokenization is handled by the ``model_transform``. All
    :class:`~torchtune.modules.tokenizers.ModelTokenizer` can be treated as a
    ``model_transform`` since it uses the model-specific tokenizer to transform the list
    of messages outputted from the ``message_transform`` into tokens used by the model
    for training. Text-only datasets will simply pass the
    :class:`~torchtune.modules.tokenizers.ModelTokenizer` into ``model_transform``.

    Args:
        source (str): path to dataset repository on Hugging Face. For local datasets,
            define source as the data file type (e.g. "json", "csv", "text") and pass
            in the filepath in ``data_files``. See `Hugging Face's
            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
            ``load_dataset`` for more details.
        message_transform (Transform): callable that keys into the desired fields in the sample
            and converts text content to a list of :class:`~torchtune.data.Message`. It is expected that the final list
            of messages are stored in the ``"messages"`` key.
        model_transform (Transform): callable that applies model-specific pre-processing to the sample after the list of
            messages is created from ``message_transform``. This includes tokenization and any modality-specific
            transforms. It is expected to return at minimum ``"tokens"`` and ``"mask"`` keys.
        prompt_template (Optional[PromptTemplate]): template used to format the messages based on their role. This is used
            to add structured text around the actual messages. The structured text is used in three scenarios:

            - Task-specific templates to gear models for a particular task that it will expect after training
            - Model-specific templates that are required whenever the model is prompted, such as the [INST]
              tags in Llama2 and in Mistral
            - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`

            The extra text will still get tokenized as normal text, not as special tokens.
        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
            details.
        **load_dataset_kwargs (Any): additional keyword arguments to pass to ``load_dataset``. See Hugging
            Face's `API ref <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset>`_
            for more details.

    Raises:
        ValueError: if, for a retrieved sample, the tokenized ``"tokens"`` and the
            derived ``"labels"`` end up with different lengths (indicates a
            malformed ``model_transform`` output).
    """

    def __init__(
        self,
        *,
        source: str,
        message_transform: Transform,
        model_transform: Transform,
        prompt_template: Optional[PromptTemplate] = None,
        filter_fn: Optional[Callable] = None,
        # PEP 484: **kwargs is annotated with the type of each VALUE, which here
        # is arbitrary (split names, data_files, etc.) — so ``Any``, not
        # ``Dict[str, Any]`` (the original annotation declared every kwarg to be
        # a dict, which is wrong).
        **load_dataset_kwargs: Any,
    ) -> None:
        self._message_transform = message_transform
        self._prompt_template = prompt_template
        self._model_transform = model_transform
        # Loads from the Hub, a local file type ("json", "csv", ...), or a
        # remote URL depending on ``source`` / ``load_dataset_kwargs``.
        self._data = load_dataset(source, **load_dataset_kwargs)
        if filter_fn is not None:
            # Filter BEFORE any per-sample transforms so filtered rows are
            # never tokenized.
            self._data = self._data.filter(filter_fn)

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        sample = self._data[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
        """Transform one raw row into a tokenized dict with training labels.

        Pipeline: ``message_transform`` -> optional ``prompt_template`` ->
        ``model_transform`` (tokenization) -> label construction.
        """
        transformed_sample = self._message_transform(sample)
        if self._prompt_template is not None:
            transformed_sample["messages"] = self._prompt_template(
                transformed_sample["messages"]
            )
        tokenized_dict = self._model_transform(transformed_sample)

        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX so no loss is
        # computed on that position. Otherwise keep the token id as the label.
        tokenized_dict["labels"] = list(
            np.where(
                tokenized_dict["mask"],
                CROSS_ENTROPY_IGNORE_IDX,
                tokenized_dict["tokens"],
            )
        )
        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would silently skip this sanity check.
        # A mismatch here means ``model_transform`` returned tokens/mask of
        # different lengths (np.where broadcast them rather than failing).
        if len(tokenized_dict["tokens"]) != len(tokenized_dict["labels"]):
            raise ValueError(
                "tokens and labels must be the same length: got "
                f"{len(tokenized_dict['tokens'])} tokens and "
                f"{len(tokenized_dict['labels'])} labels"
            )
        return tokenized_dict