forked from AkariAsai/OpenScholar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_grammar.py
83 lines (70 loc) · 3.76 KB
/
_grammar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, Union

from torchtune.data import InputOutputToMessages
from torchtune.data._prompt_templates import (
    GrammarErrorCorrectionTemplate,
    PromptTemplate,
)
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.transforms import Transform
def grammar_dataset(
    model_transform: Transform,
    *,
    source: str = "liweili/c4_200m",
    column_map: Optional[Dict[str, str]] = None,
    prompt_template: Optional[PromptTemplate] = GrammarErrorCorrectionTemplate(),
    train_on_input: bool = False,
    packed: bool = False,
    split: str = "train",
    **load_dataset_kwargs: Any,
) -> Union[SFTDataset, PackedDataset]:
    """
    Support for grammar correction datasets and their variants from Hugging Face Datasets.

    Here is an `example <https://huggingface.co/datasets/liweili/c4_200m>`_ of a grammar correction dataset.

    The prompt template mirrors what is used in the `llama_recipes codebase
    <https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py#L50>`_
    where ``input`` and ``output`` are fields from the dataset.

    Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is
    set to ``False`` by default.

    - If ``train_on_input`` is True, the prompt is used during training and
      contributes to the loss.
    - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100).

    Args:
        model_transform (Transform): model specific transform to convert a list of messages
            output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`.
            When ``packed=True`` it must expose a non-None ``max_seq_len`` attribute, which is
            used as the pack length.
        source (str): path to dataset repository on Hugging Face. For local datasets,
            define source as the data file type (e.g. "json", "csv", "text") and pass
            in the filepath in ``data_files`` (forwarded via ``load_dataset_kwargs``). See `Hugging Face's
            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
            ``load_dataset`` for more details. Default is ``liweili/c4_200m``.
        column_map (Optional[Dict[str, str]]): a mapping from the expected columns (``input`` and
            ``output``) to the actual column names in the dataset. If None, assume these are identical.
        prompt_template (Optional[PromptTemplate]): optional template used to format the prompt. Default
            is :class:`~torchtune.data.GrammarErrorCorrectionTemplate`. Note that the default
            instance is created once, when this function is defined, and shared across calls.
        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
        packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
        **load_dataset_kwargs (Any): additional keyword arguments forwarded to
            :class:`~torchtune.datasets.SFTDataset` alongside ``split`` (for example ``data_files``
            for local datasets).

    Returns:
        Union[SFTDataset, PackedDataset]: dataset configured with source data and template.

    Raises:
        ValueError: if ``packed=True`` and ``model_transform`` has no ``max_seq_len`` set.

    Example:
        >>> grammar_ds = grammar_dataset(model_transform=tokenizer)
        >>> for batch in DataLoader(grammar_ds, batch_size=8):
        ...     print(f"Batch size: {len(batch)}")
        Batch size: 8
    """
    message_transform = InputOutputToMessages(
        train_on_input=train_on_input, column_map=column_map
    )
    ds = SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=model_transform,
        prompt_template=prompt_template,
        split=split,
        # Lets callers pass e.g. ``data_files`` for local sources, as documented above.
        **load_dataset_kwargs,
    )
    if not packed:
        return ds

    # Packing needs an explicit target length. A bare ``PackedDataset(ds)`` never
    # supplied one, so take it from the tokenizer. getattr() keeps this a clear
    # ValueError for transforms that lack the attribute, rather than an AttributeError.
    max_seq_len = getattr(model_transform, "max_seq_len", None)
    if max_seq_len is None:
        raise ValueError(
            "PackedDataset requires a max_seq_len to be set on the tokenizer."
        )
    return PackedDataset(ds, max_seq_len=max_seq_len)