Add DPO training example and fix minor bugs
sandeepchittilla committed Sep 12, 2023
1 parent cd923c1 commit deb71c1
Showing 6 changed files with 130 additions and 19 deletions.
99 changes: 99 additions & 0 deletions examples/hh/dpo_hh.py
@@ -0,0 +1,99 @@
import json
import sys
from collections import defaultdict

import tqdm
from datasets import Dataset, load_dataset

import trlx
from trlx.data.default_configs import (
DPOConfig,
ModelConfig,
OptimizerConfig,
SchedulerConfig,
TokenizerConfig,
TrainConfig,
TRLConfig,
)

default_config = TRLConfig(
train=TrainConfig(
seq_length=1024,
epochs=100,
total_steps=1000,
batch_size=1,
checkpoint_interval=10000,
eval_interval=100,
pipeline="PromptPipeline",
trainer="AccelerateDPOTrainer",
checkpoint_dir="checkpoints/dpo_hh",
),
model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps
method=DPOConfig(
name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), beta=0.1
),
)


def get_hh(split: str, sanity_check=False, silent=False):
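    """Load an Anthropic/hh-rlhf split and regroup it into {"prompt", "responses", "pairs"}
    records, where each pair holds the (chosen, rejected) indices into that prompt's responses."""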
dataset = load_dataset("Anthropic/hh-rlhf", split=split)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))

def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]

def split_prompt_and_responses(ex):
prompt = extract_anthropic_prompt(ex["chosen"])
chosen_response = ex["chosen"][len(prompt) :]
rejected_response = ex["rejected"][len(prompt) :]
return prompt, chosen_response, rejected_response

data = defaultdict(lambda: defaultdict(list))
for row in tqdm.tqdm(dataset, desc="Processing HH", disable=silent):
prompt, chosen, rejected = split_prompt_and_responses(row)
responses = [chosen, rejected]
n_responses = len(data[prompt]["responses"])
data[prompt]["pairs"].append((n_responses, n_responses + 1))
data[prompt]["responses"].extend(responses)
data[prompt]["sft_target"] = chosen

def gen():
for prompt, values in data.items():
yield {
"prompt": prompt,
"responses": values["responses"],
"pairs": values["pairs"],
}

return Dataset.from_generator(gen)


def preprocess(sample):
    # "Dahoas/full-hh-rlhf" already ships "prompt", "chosen" and "rejected" columns,
    # so this is currently a pass-through hook for any extra per-sample processing.
    return sample


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess)

trlx.train(
config=config,
samples=dataset["train"],
eval_prompts=dataset["test"]["prompt"][:280],
# metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
)


if __name__ == "__main__":
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
main(hparams)
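
As a usage note, hyperparameter overrides are read from the command line as a single JSON argument and merged into default_config via TRLConfig.update. A minimal sketch, assuming the override dictionary is keyed by config section (the values below are illustrative, not tuned recommendations):

# python examples/hh/dpo_hh.py '{"train": {"total_steps": 200, "batch_size": 4}, "method": {"beta": 0.05}}'
#
# which is equivalent to calling main() with the already-parsed dictionary:
main({"train": {"total_steps": 200, "batch_size": 4}, "method": {"beta": 0.05}})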
4 changes: 2 additions & 2 deletions trlx/pipeline/__init__.py
@@ -166,8 +166,8 @@ def __next__(self): # noqa: C901
minibatch = BatchEncoding(sliced_data)
elif is_dataclass(batch):
minibatch = batch.__class__(**sliced_data)
-            # else:
-            #     minibatch = sliced_data
+            else:
+                minibatch = sliced_data

minibatches.append(minibatch)

24 changes: 16 additions & 8 deletions trlx/pipeline/offline_pipeline.py
@@ -288,19 +288,27 @@ class DPOPreferences:

class DPOStore(BaseRolloutStore):
# Adapted from TRL
-    def __init__(self, preferences: List[DPOPreferences], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        preferences: List[DPOPreferences],
+        tokenizer: PreTrainedTokenizer,
+        label_pad_token_id: int,
+        padding_value: int,
+    ):
super().__init__()
self.tokenizer = tokenizer
+        self.label_pad_token_id = label_pad_token_id
+        self.padding_value = padding_value

self.history = [
self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences
]

@staticmethod
def tokenize_preferences(samples, tokenizer, max_length=2048):
-        chosen_tokens = tokenizer(samples[0], add_special_tokens=False)
-        rejected_tokens = tokenizer(samples[1], add_special_tokens=False)
-        prompt_tokens = tokenizer(samples[2], add_special_tokens=False)
+        chosen_tokens = tokenizer(samples["chosen"], add_special_tokens=False)
+        rejected_tokens = tokenizer(samples["rejected"], add_special_tokens=False)
+        prompt_tokens = tokenizer(samples["prompt"], add_special_tokens=False)

chosen_tokens["input_ids"].append(tokenizer.eos_token_id)
chosen_tokens["attention_mask"].append(1)
@@ -313,14 +321,14 @@ def tokenize_preferences(samples, tokenizer, max_length=2048):
# if combined sequence is too long, truncate the prompt only
if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
if tokenizer.truncation_side == "right":
-                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[:max_length] for k, v in prompt_tokens.items()}
elif tokenizer.truncation_side == "left":
-                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()}

# if that's still too long, truncate the response
if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
-            chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()}
-            rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()}
+            # leave whatever length budget remains after the (possibly truncated) prompt for the responses
+            response_budget = max_length - len(prompt_tokens["input_ids"])
+            chosen_tokens = {k: v[:response_budget] for k, v in chosen_tokens.items()}
+            rejected_tokens = {k: v[:response_budget] for k, v in rejected_tokens.items()}

return DPOPreferences(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)

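
As a usage note, tokenize_preferences now reads samples by key ("chosen", "rejected", "prompt") rather than by position. A rough sketch of tokenizing a single preference pair (the sample text and the GPT-2 tokenizer are illustrative assumptions, not part of the commit):

from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DPOStore

tokenizer = AutoTokenizer.from_pretrained("gpt2")
sample = {
    "prompt": "\n\nHuman: How do I boil an egg?\n\nAssistant:",
    "chosen": " Cover it with cold water, bring to a boil, then simmer for ten minutes.",
    "rejected": " I have no idea.",
}
preference = DPOStore.tokenize_preferences(sample, tokenizer, max_length=1024)
# preference is a DPOPreferences holding prompt_tokens, chosen_tokens and rejected_tokens,
# each with "input_ids" and "attention_mask"; the responses end with the tokenizer's EOS token.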
17 changes: 10 additions & 7 deletions trlx/trainer/accelerate_dpo_trainer.py
@@ -25,19 +25,20 @@ class DPOConfig(MethodConfig):
"""

gen_kwargs: dict
-    beta: float = 0.1
+    beta: float = 0.1  # Beta value for DPO loss calculation
+    label_pad_token_id: int = -100  # -100 is the ignore index for CELoss
+    padding_value: int = 0


@register_trainer
class AccelerateDPOTrainer(AccelerateRLTrainer):
def __init__(self, config: TRLConfig, **kwargs):
super().__init__(config, **kwargs)

-        # Set up a reference model when hydra heads are not used
-        if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
-            self.ref_model = self.get_arch(self.config)
-            self.ref_model.to(self.accelerator.device)
-            self.ref_model.eval()
+        # TODO: Avoid setting up a reference model when hydra heads are used
+        self.ref_model = self.get_arch(self.config)
+        self.ref_model.to(self.accelerator.device)
+        self.ref_model.eval()

self.generate_kwargs = dict(
config.method.gen_kwargs,
@@ -47,6 +48,8 @@ def __init__(self, config: TRLConfig, **kwargs):

# `beta` corresponding to the DPO hyperparameter
self.beta = config.method.beta
+        self.label_pad_token_id = config.method.label_pad_token_id
+        self.padding_value = config.method.padding_value

def get_arch(self, config):
from_fn = AutoModelForCausalLM.from_pretrained
@@ -250,4 +253,4 @@ def prepare_learning(self):

def make_experience(self, samples, seq_length):
preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples]
-        self.store = DPOStore(preferences, self.tokenizer)
+        self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value)
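
For context, beta and the reference model enter the standard DPO objective, which compares the policy's preference margin between chosen and rejected responses against the frozen reference model's margin. A hedged sketch of that loss (not the trainer's exact implementation; the per-sequence log-probabilities are assumed to be summed over response tokens only):

import torch
import torch.nn.functional as F


def dpo_loss_sketch(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # How much more the policy prefers the chosen response over the rejected one...
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    # ...relative to how much the frozen reference model already prefers it.
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    # beta scales the margin; a larger beta keeps the policy closer to the reference.
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()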
4 changes: 2 additions & 2 deletions trlx/trlx.py
@@ -64,7 +64,7 @@ def train( # noqa: C901
config = default_ppo_config()
elif rewards:
config = default_ilql_config()
-        else:
+        else:  # Could also be DPO, but that is ignored here since picking `config` implicitly is deprecated
config = default_sft_config()

set_seed(config.train.seed)
@@ -102,7 +102,7 @@ def train( # noqa: C901
if eval_prompts is None:
eval_prompts = prompts[:batch_size]

-    # Offline training from the collected samples (e.g. SFT, ILQL)
+    # Offline training from the collected samples (e.g. SFT, ILQL, DPO)
elif samples:
if rewards is not None:
if len(samples) != len(rewards):
1 change: 1 addition & 0 deletions trlx/utils/loading.py
@@ -6,6 +6,7 @@

# Register load trainers via module import
from trlx.trainer import _TRAINERS, register_trainer
+from trlx.trainer.accelerate_dpo_trainer import AccelerateDPOTrainer
from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer
from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer
from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer
