Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚜 Use field in dataclasses #2494

Merged
merged 39 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7f7c51e
in hh-rlhf-helpful-base
qgallouedec Dec 17, 2024
9f7be40
delete tokenize ds
qgallouedec Dec 17, 2024
0c2d219
dataset scripts
qgallouedec Dec 17, 2024
7902808
alignprop
qgallouedec Dec 17, 2024
835a10b
judge tldr
qgallouedec Dec 17, 2024
362e722
ddpo
qgallouedec Dec 17, 2024
89eeb60
zen
qgallouedec Dec 17, 2024
9eabe63
sft video
qgallouedec Dec 17, 2024
e20af16
literal to choices
qgallouedec Dec 17, 2024
0514967
chat
qgallouedec Dec 17, 2024
64546bb
script args
qgallouedec Dec 17, 2024
8784ef9
alignprop
qgallouedec Dec 17, 2024
4b28cf0
bco
qgallouedec Dec 17, 2024
2595aff
better help format
qgallouedec Dec 17, 2024
c57eac0
cpo
qgallouedec Dec 21, 2024
a952f28
ddpo
qgallouedec Dec 21, 2024
835ca04
whether or not -> whether
qgallouedec Dec 21, 2024
ed6954f
dpo
qgallouedec Dec 21, 2024
b21570f
dont set the possible values
qgallouedec Dec 21, 2024
7c5df16
`Optional[...]` to ... or `None`
qgallouedec Dec 21, 2024
609d081
xpo
qgallouedec Dec 21, 2024
64543ca
gkd
qgallouedec Dec 21, 2024
9b7764e
kto
qgallouedec Dec 21, 2024
3a89ee1
nash
qgallouedec Dec 21, 2024
61ecd49
online dpo
qgallouedec Dec 21, 2024
20c6992
Merge branch 'main' into field
qgallouedec Dec 21, 2024
55ac06a
Merge branch 'field' of https://github.com/huggingface/trl into field
qgallouedec Dec 21, 2024
5edf56b
Fix typo in learning rate help message
qgallouedec Dec 21, 2024
79d49ae
orpo
qgallouedec Dec 21, 2024
0d811f1
more ... or `None`
qgallouedec Dec 21, 2024
54df2af
model config
qgallouedec Dec 21, 2024
140539e
ppo
qgallouedec Dec 21, 2024
cd9fa63
prm
qgallouedec Dec 21, 2024
c4e5441
reward
qgallouedec Dec 21, 2024
1d74dc7
rloo
qgallouedec Dec 21, 2024
3f9b81f
sft
qgallouedec Dec 21, 2024
2ea9cb9
online policy config
qgallouedec Dec 21, 2024
b0b6f51
Merge branch 'main' into field
qgallouedec Jan 6, 2025
1c82235
make style
qgallouedec Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions examples/datasets/hh-rlhf-helpful-base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import re
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
Expand All @@ -30,13 +30,20 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/hh-rlhf-helpful-base"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/hh-rlhf-helpful-base", metadata={"help": "Hugging Face repository ID to push the dataset to."}
)
dataset_num_proc: Optional[int] = field(
default=None, metadata={"help": "Number of workers to use for dataset processing."}
)


def common_start(str1: str, str2: str) -> str:
Expand Down
19 changes: 14 additions & 5 deletions examples/datasets/lm-human-preferences-descriptiveness.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
Expand All @@ -29,13 +29,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-descriptiveness"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/lm-human-preferences-descriptiveness"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/lm-human-preferences-descriptiveness",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


# Edge cases handling: remove the cases where all samples are the same
Expand Down
19 changes: 14 additions & 5 deletions examples/datasets/lm-human-preferences-sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
Expand All @@ -29,13 +29,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-sentiment"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/lm-human-preferences-sentiment"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/lm-human-preferences-sentiment",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


def to_prompt_completion(example, tokenizer):
Expand Down
19 changes: 14 additions & 5 deletions examples/datasets/math_shepherd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import re
from dataclasses import dataclass
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

Expand All @@ -31,13 +31,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/math_shepherd"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/math_shepherd",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


def process_example(example):
Expand Down
19 changes: 14 additions & 5 deletions examples/datasets/prm800k.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
Expand All @@ -29,13 +29,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/prm800k"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/prm800k"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/prm800k",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


def process_example(example):
Expand Down
19 changes: 14 additions & 5 deletions examples/datasets/rlaif-v.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import features, load_dataset
Expand All @@ -29,13 +29,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/rlaif-v"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/rlaif-v",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


def to_conversational(example):
Expand Down
19 changes: 14 additions & 5 deletions examples/datasets/tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
Expand All @@ -29,13 +29,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/tldr"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/tldr"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/tldr",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


def to_prompt_completion(example):
Expand Down
19 changes: 14 additions & 5 deletions examples/datasets/tldr_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
Expand All @@ -29,13 +29,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/tldr-preference"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/tldr-preference"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/tldr-preference",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


def to_preference(example):
Expand Down
54 changes: 0 additions & 54 deletions examples/datasets/tokenize_ds.py

This file was deleted.

19 changes: 14 additions & 5 deletions examples/datasets/ultrafeedback-prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
Expand All @@ -29,13 +29,22 @@ class ScriptArguments:
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-prompt"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/ultrafeedback-prompt"
dataset_num_proc: Optional[int] = None
push_to_hub: bool = field(
default=False,
metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
)
repo_id: str = field(
default="trl-lib/ultrafeedback-prompt",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)


def to_unpaired_preference(example):
Expand Down
Loading
Loading