Merge branch 'main' into warnings
qgallouedec authored Nov 26, 2024
2 parents 02f5b17 + c10cc89 commit f08c609
Showing 44 changed files with 475 additions and 465 deletions.
examples/datasets/hh-rlhf-helpful-base.py (2 additions & 2 deletions)

@@ -14,7 +14,7 @@
 import re
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Optional

 from datasets import load_dataset
 from transformers import HfArgumentParser
@@ -51,7 +51,7 @@ def common_start(str1: str, str2: str) -> str:
     return "".join(common_chars)


-def extract_dialogue(example: str) -> List[Dict[str, str]]:
+def extract_dialogue(example: str) -> list[dict[str, str]]:
    # Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
    prompt_text = common_start(example["chosen"], example["rejected"])
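Note: the typing changes repeated across this commit rely on PEP 585 (Python 3.9+), under which builtin containers are directly subscriptable in annotations, so typing.Dict and typing.List become redundant while Optional/Union still come from typing. A minimal sketch with a hypothetical helper, not taken from this diff:

    from typing import Optional  # Optional is still imported from typing on 3.9

    # PEP 585 style: annotate with the builtin containers themselves.
    def count_fields(rows: list[dict[str, str]], limit: Optional[int] = None) -> dict[str, int]:
        return {row["name"]: len(row) for row in rows[:limit]}

    print(count_fields([{"name": "a", "x": "1"}, {"name": "b"}]))  # {'a': 2, 'b': 1}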
(file path not shown in this extract)

@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 import evaluate
 import numpy as np
@@ -236,7 +236,7 @@ class RewardDataCollatorWithPadding:
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"

-    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         features_j = []
         features_k = []
         for feature in features:
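Note: the __call__ above pads the chosen ("j") and rejected ("k") halves of each preference pair as two separate batches. A stripped-down sketch of that split, assuming hypothetical input_ids_j/input_ids_k feature keys (the real collator also carries attention masks and padding options):

    from typing import Any

    def split_pair_features(features: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        # Each side is collected separately so it can be padded on its own.
        features_j = [{"input_ids": f["input_ids_j"]} for f in features]
        features_k = [{"input_ids": f["input_ids_k"]} for f in features]
        return features_j, features_k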
examples/research_projects/stack_llama_2/scripts/dpo_llama2.py (5 additions & 5 deletions)

@@ -15,7 +15,7 @@
 # 0. imports
 import os
 from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Optional

 import torch
 from accelerate import Accelerator
@@ -109,9 +109,9 @@ def get_stack_exchange_paired(
     The dataset is converted to a dictionary with the following structure:
     {
-        'prompt': List[str],
-        'chosen': List[str],
-        'rejected': List[str],
+        'prompt': list[str],
+        'chosen': list[str],
+        'rejected': list[str],
     }

     Prompts are structured as follows:
@@ -126,7 +126,7 @@
     )
     original_columns = dataset.column_names

-    def return_prompt_and_responses(samples) -> Dict[str, str]:
+    def return_prompt_and_responses(samples) -> dict[str, str]:
         return {
             "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
             "chosen": samples["response_j"],
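Note: return_prompt_and_responses is a batched datasets.map callback, so every value it returns is a column (a list); dict[str, list[str]] would arguably be the tighter annotation. A runnable sketch of the pattern with made-up rows:

    from datasets import Dataset

    # Toy rows mirroring the Stack Exchange paired format used above.
    dataset = Dataset.from_dict(
        {
            "question": ["What is DPO?"],
            "response_j": ["Direct Preference Optimization."],  # preferred answer
            "response_k": ["No idea."],  # rejected answer
        }
    )

    def return_prompt_and_responses(samples) -> dict[str, list[str]]:
        return {
            "prompt": ["Question: " + q + "\n\nAnswer: " for q in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    paired = dataset.map(return_prompt_and_responses, batched=True, remove_columns=dataset.column_names)
    print(paired[0])  # {'prompt': 'Question: ...', 'chosen': '...', 'rejected': '...'}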
examples/scripts/sft_video_llm.py (3 additions & 3 deletions)

@@ -45,7 +45,7 @@
 import os
 import random
 from dataclasses import dataclass
-from typing import Any, Dict, List
+from typing import Any

 import requests
 import torch
@@ -90,7 +90,7 @@ def download_video(url: str, cache_dir: str) -> str:
         raise Exception(f"Failed to download video: {e}") from e


-def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[Dict[str, Any]]]:
+def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str, list[dict[str, Any]]]:
     """Prepare dataset example for training."""
     video_url = example["video_url"]
     timecoded_cc = example["timecoded_cc"]
@@ -120,7 +120,7 @@ def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[Dict[str, Any]]]:
     return {"messages": messages}


-def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
     """Collate batch of examples for training."""
     texts = []
     video_inputs = []
tests/slow/test_dpo_slow.py (3 additions & 0 deletions)

@@ -62,6 +62,7 @@ def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
         """
         model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = DPOConfig(
@@ -116,6 +117,7 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
         """
         model = AutoModelForCausalLM.from_pretrained(model_id)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = DPOConfig(
@@ -180,6 +182,7 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra

         model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = DPOConfig(
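Note: the one-liner added in each test above is a pad-token fallback; checkpoints such as GPT-2 ship without a pad token, and padding-based collators fail unless one is set. The same idiom as an explicit guard (checkpoint name is illustrative; the tests are parametrized over model ids):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:  # same effect as the conditional expression in the diff
        tokenizer.pad_token = tokenizer.eos_token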
tests/slow/test_sft_slow.py (10 additions & 1 deletion)

@@ -104,6 +104,7 @@ def test_sft_trainer_transformers(self, model_name, packing):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -138,6 +139,7 @@ def test_sft_trainer_peft(self, model_name, packing):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -174,6 +176,7 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -209,6 +212,7 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -245,6 +249,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -288,6 +293,7 @@ def test_sft_trainer_transformers_mp_gc_device_map(

         model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -327,6 +333,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr

         model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -370,7 +377,9 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
         model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
         tokenizer = AutoTokenizer.from_pretrained(model_name)

-        model, tokenizer = setup_chat_format(model, tokenizer)
+        if tokenizer.chat_template is None:
+            model, tokenizer = setup_chat_format(model, tokenizer)
+        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
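Note: the last hunk replaces an unconditional setup_chat_format call with a guard. setup_chat_format installs TRL's default ChatML template and resizes the embeddings, which is only wanted when the tokenizer does not already define chat_template (and can raise when one is set). A sketch of the guarded flow, with an assumed checkpoint name:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import setup_chat_format

    model_id = "Qwen/Qwen2.5-0.5B"  # assumed checkpoint, for illustration only
    model = AutoModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if tokenizer.chat_template is None:  # only install ChatML when no template exists
        model, tokenizer = setup_chat_format(model, tokenizer)
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token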
trl/commands/cli_utils.py (1 addition & 1 deletion)

@@ -158,7 +158,7 @@ def __init__(self, parsers, ignore_extra_args=False):
         with the processed parsers.

         Args:
-            parsers (`List[argparse.ArgumentParser`]):
+            parsers (`list[argparse.ArgumentParser`]):
                 List of parsers.
             ignore_extra_args (`bool`):
                 Whether to ignore extra arguments passed by the config
trl/core.py (9 additions & 9 deletions)

@@ -15,7 +15,7 @@
 import random
 import warnings
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union

 import numpy as np
 import torch
@@ -70,10 +70,10 @@ def top_k_top_p_filtering(
     return logits


-def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
+def flatten_dict(nested: dict, sep: str = "/") -> dict:
     """Flatten dictionary and concatenate nested keys with separator."""

-    def recurse(nest: Dict, prefix: str, into: Dict) -> None:
+    def recurse(nest: dict, prefix: str, into: dict) -> None:
         for k, v in nest.items():
             if sep in k:
                 raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
@@ -87,7 +87,7 @@ def recurse(nest: Dict, prefix: str, into: Dict) -> None:
     return flat


-def convert_to_scalar(stats: Dict) -> Dict:
+def convert_to_scalar(stats: dict) -> dict:
     """
     Converts the stats from a flattened dict to single scalar dicts
     """
@@ -103,7 +103,7 @@
     return tensorboard_stats


-def stack_dicts(stats_dicts: List[Dict]) -> Dict:
+def stack_dicts(stats_dicts: list[dict]) -> dict:
     """Stack the values of a dict."""
     results = dict()
     for k in stats_dicts[0]:
@@ -185,7 +185,7 @@ def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
     return entropy


-def stats_to_np(stats_dict: Dict) -> Dict:
+def stats_to_np(stats_dict: dict) -> dict:
     """Cast all torch.tensors in dict to numpy arrays."""
     new_dict = dict()
     for k, v in stats_dict.items():
@@ -202,7 +202,7 @@ def stats_to_np(stats_dict: Dict) -> Dict:


 def respond_to_batch(
-    model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
+    model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
 ) -> torch.LongTensor:
     """Sample text from language model."""
     input_ids = queries
@@ -271,8 +271,8 @@ def empty_device_cache(cls):


 def randn_tensor(
-    shape: Union[Tuple, List],
-    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
+    shape: Union[tuple, list],
+    generator: Optional[Union[list[torch.Generator], torch.Generator]] = None,
     device: Optional[torch.device] = None,
     dtype: Optional[torch.dtype] = None,
     layout: Optional[torch.layout] = None,
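Note: a quick usage sketch for flatten_dict, with behavior inferred from the hunk above (nested keys joined with the separator):

    from trl.core import flatten_dict

    stats = {"loss": {"policy": 0.12, "value": 0.03}, "lr": 1e-5}
    assert flatten_dict(stats) == {"loss/policy": 0.12, "loss/value": 0.03, "lr": 1e-5}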
trl/data_utils.py (13 additions & 13 deletions)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Sequence, TypeVar
+from typing import Any, Optional, Sequence, TypeVar

 from datasets import Dataset, DatasetDict
 from transformers import PreTrainedTokenizer
@@ -20,12 +20,12 @@
 DatasetType = TypeVar("DatasetType", Dataset, DatasetDict)


-def is_conversational(example: Dict[str, Any]) -> bool:
+def is_conversational(example: dict[str, Any]) -> bool:
     r"""
     Check if the example is in a conversational format.

     Args:
-        example (`Dict[str, Any]`):
+        example (`dict[str, Any]`):
             A single data entry of a dataset. The example can have different keys depending on the
             dataset type.
@@ -60,7 +60,7 @@ def is_conversational(example: Dict[str, Any]) -> bool:
     return False


-def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer) -> Dict[str, str]:
+def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer) -> dict[str, str]:
     r"""
     Apply a chat template to a conversational example.
@@ -139,13 +139,13 @@ def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer) -> Dict[str, str]:


 def maybe_apply_chat_template(
-    example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer
-) -> Dict[str, str]:
+    example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer
+) -> dict[str, str]:
     r"""
     If the example is in a conversational format, apply a chat template to it.

     Args:
-        example (`Dict[str, List[Dict[str, str]]`):
+        example (`dict[str, list[dict[str, str]]`):
             Dictionary representing a single data entry of a conversational dataset. Each data entry can have different
             keys depending on the dataset type. The supported dataset types are:
@@ -163,7 +163,7 @@ def maybe_apply_chat_template(
             The tokenizer to apply the chat template with.

     Returns:
-        `Dict[str, str]`: The formatted example with the chat template applied.
+        `dict[str, str]`: The formatted example with the chat template applied.

     Note:
         This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
@@ -188,7 +188,7 @@
     return example


-def _unpair_row(examples: List[Dict[str, List[Dict[str, str]]]]) -> List[Dict[str, List[Dict[str, str]]]]:
+def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list[dict[str, list[dict[str, str]]]]:
     batch_size = len(examples["chosen"])
     new_rows = {
         "completion": examples["chosen"] + examples["rejected"],
@@ -288,7 +288,7 @@ def maybe_unpair_preference_dataset(
     return dataset


-def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
+def extract_prompt(example: dict[str, Sequence]) -> dict[str, Sequence]:
     r"""
     Extracts the shared prompt from a preference data example, where the prompt is implicit within both
     the chosen and rejected completions.
@@ -307,7 +307,7 @@ def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
     }


-def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
+def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
     r"""
     Extracts the shared prompt from a preference data example, where the prompt is implicit within both
     the chosen and rejected completions.
@@ -318,12 +318,12 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
     "rejected" completions.

     Args:
-        example (`Dict[str, List]`):
+        example (`dict[str, list]`):
            A dictionary representing a single data entry in the preference dataset. It must contain the keys
            `"chosen"` and `"rejected"`, where each value is either conversational or standard (`str`).

     Returns:
-        `Dict[str, List]`: A dictionary containing:
+        `dict[str, list]`: A dictionary containing:
         - `"prompt"`: The longest common prefix between the "chosen" and "rejected" completions.
         - `"chosen"`: The remainder of the "chosen" completion, with the prompt removed.
         - `"rejected"`: The remainder of the "rejected" completion, with the prompt removed.
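Note: based on the docstring above, a usage sketch for maybe_extract_prompt on an implicit-prompt preference pair:

    from trl.data_utils import maybe_extract_prompt

    example = {
        "chosen": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is green."},
        ],
    }
    result = maybe_extract_prompt(example)
    # Expected shape: the shared user turn moves to result["prompt"], while
    # result["chosen"] / result["rejected"] keep only the diverging assistant turns.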