diff --git a/examples/hh/dpo_hh.py b/examples/hh/dpo_hh.py
new file mode 100644
index 000000000..acddecc82
--- /dev/null
+++ b/examples/hh/dpo_hh.py
@@ -0,0 +1,99 @@
+import json
+import sys
+from collections import defaultdict
+
+import tqdm
+from datasets import Dataset, load_dataset
+
+import trlx
+from trlx.data.default_configs import (
+    DPOConfig,
+    ModelConfig,
+    OptimizerConfig,
+    SchedulerConfig,
+    TokenizerConfig,
+    TrainConfig,
+    TRLConfig,
+)
+
+default_config = TRLConfig(
+    train=TrainConfig(
+        seq_length=1024,
+        epochs=100,
+        total_steps=1000,
+        batch_size=1,
+        checkpoint_interval=10000,
+        eval_interval=100,
+        pipeline="PromptPipeline",
+        trainer="AccelerateDPOTrainer",
+        checkpoint_dir="checkpoints/dpo_hh",
+    ),
+    model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
+    tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
+    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)),
+    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # train.total_steps
+    method=DPOConfig(
+        name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), beta=0.1
+    ),
+)
+
+
+def get_hh(split: str, sanity_check=False, silent=False):
+    dataset = load_dataset("Anthropic/hh-rlhf", split=split)
+    if sanity_check:
+        dataset = dataset.select(range(min(len(dataset), 1000)))
+
+    def extract_anthropic_prompt(prompt_and_response):
+        """Extract the anthropic prompt from a prompt and response pair."""
+        search_term = "\n\nAssistant:"
+        search_term_idx = prompt_and_response.rfind(search_term)
+        assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
+        return prompt_and_response[: search_term_idx + len(search_term)]
+
+    def split_prompt_and_responses(ex):
+        prompt = extract_anthropic_prompt(ex["chosen"])
+        chosen_response = ex["chosen"][len(prompt) :]
+        rejected_response = ex["rejected"][len(prompt) :]
+        return prompt, chosen_response, rejected_response
+
+    data = defaultdict(lambda: defaultdict(list))
+    for row in tqdm.tqdm(dataset, desc="Processing HH", disable=silent):
+        prompt, chosen, rejected = split_prompt_and_responses(row)
+        responses = [chosen, rejected]
+        n_responses = len(data[prompt]["responses"])
+        data[prompt]["pairs"].append((n_responses, n_responses + 1))
+        data[prompt]["responses"].extend(responses)
+        data[prompt]["sft_target"] = chosen
+
+    def gen():
+        for prompt, values in data.items():
+            yield {
+                "prompt": prompt,
+                "responses": values["responses"],
+                "pairs": values["pairs"],
+            }
+
+    return Dataset.from_generator(gen)
+
+
+def preprocess(sample):
+    pass
+
+
+def main(hparams={}):
+    config = TRLConfig.update(default_config, hparams)
+
+    dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess)
+
+    trlx.train(
+        config=config,
+        samples=dataset["train"],
+        eval_prompts=dataset["test"]["prompt"][:280],
+        # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)},
+        stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
+    )
+
+
+if __name__ == "__main__":
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)
diff --git a/trlx/pipeline/__init__.py b/trlx/pipeline/__init__.py
index c7dba9e97..7e927cd49 100644
--- a/trlx/pipeline/__init__.py
+++ b/trlx/pipeline/__init__.py
@@ -166,8 +166,8 @@ def __next__(self):  # noqa: C901
                 minibatch = BatchEncoding(sliced_data)
             elif is_dataclass(batch):
                 minibatch = batch.__class__(**sliced_data)
-            # else:
-            #     minibatch = sliced_data
+            else:
+                minibatch = sliced_data
 
             minibatches.append(minibatch)
 
diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 1974dacde..0d0256ffd 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -288,9 +288,17 @@ class DPOPreferences:
 
 
 class DPOStore(BaseRolloutStore):
     # Adapted from TRL
-    def __init__(self, preferences: List[DPOPreferences], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        preferences: List[DPOPreferences],
+        tokenizer: PreTrainedTokenizer,
+        label_pad_token_id: int,
+        padding_value: int,
+    ):
         super().__init__()
         self.tokenizer = tokenizer
+        self.label_pad_token_id = label_pad_token_id
+        self.padding_value = padding_value
         self.history = [
             self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences
@@ -298,9 +306,9 @@ def __init__(self, preferences: List[DPOPreferences], tokenizer: PreTrainedToken
 
     @staticmethod
    def tokenize_preferences(samples, tokenizer, max_length=2048):
-        chosen_tokens = tokenizer(samples[0], add_special_tokens=False)
-        rejected_tokens = tokenizer(samples[1], add_special_tokens=False)
-        prompt_tokens = tokenizer(samples[2], add_special_tokens=False)
+        chosen_tokens = tokenizer(samples["chosen"], add_special_tokens=False)
+        rejected_tokens = tokenizer(samples["rejected"], add_special_tokens=False)
+        prompt_tokens = tokenizer(samples["prompt"], add_special_tokens=False)
 
         chosen_tokens["input_ids"].append(tokenizer.eos_token_id)
         chosen_tokens["attention_mask"].append(1)
@@ -313,14 +321,14 @@ def tokenize_preferences(samples, tokenizer, max_length=2048):
         # if combined sequence is too long, truncate the prompt only
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
             if tokenizer.truncation_side == "right":
-                prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[:max_length] for k, v in prompt_tokens.items()}
             elif tokenizer.truncation_side == "left":
-                prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()}
+                prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()}
 
         # if that's still too long, truncate the response
         if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
-            chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()}
-            rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()}
+            chosen_tokens = {k: v[: max_length - len(prompt_tokens["input_ids"])] for k, v in chosen_tokens.items()}
+            rejected_tokens = {k: v[: max_length - len(prompt_tokens["input_ids"])] for k, v in rejected_tokens.items()}
 
         return DPOPreferences(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens)
 
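Note on the `tokenize_preferences` change above: samples are now looked up by key rather than by position, so each DPO training sample is expected to be a mapping with `prompt`, `chosen` and `rejected` fields (the layout of `Dahoas/full-hh-rlhf`). A minimal sketch of that contract, assuming this branch is installed; the tokenizer choice and the literal strings are only illustrative:

```python
# Illustration only, not part of the patch: the per-sample layout that the
# revised DPOStore.tokenize_preferences expects.
from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import DPOStore

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # same tokenizer as the example config

sample = {
    "prompt": "\n\nHuman: How do I bake bread?\n\nAssistant:",
    "chosen": " Mix flour, water, yeast and salt, then let the dough rise.",
    "rejected": " I don't know.",
}

# Returns a DPOPreferences record holding the prompt/chosen/rejected token dicts,
# truncating the prompt (and, if needed, the responses) to fit within max_length.
preference = DPOStore.tokenize_preferences(sample, tokenizer, max_length=1024)
```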
diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py
index 6e895ab18..2ae2b1750 100644
--- a/trlx/trainer/accelerate_dpo_trainer.py
+++ b/trlx/trainer/accelerate_dpo_trainer.py
@@ -25,7 +25,9 @@ class DPOConfig(MethodConfig):
     """
 
     gen_kwargs: dict
-    beta: float = 0.1
+    beta: float = 0.1  # Beta value for DPO loss calculation
+    label_pad_token_id: int = -100  # -100 is the ignore index for CELoss
+    padding_value: int = 0
 
 
 @register_trainer
@@ -33,11 +35,10 @@ class AccelerateDPOTrainer(AccelerateRLTrainer):
     def __init__(self, config: TRLConfig, **kwargs):
         super().__init__(config, **kwargs)
 
-        # Set up a reference model when hydra heads are not used
-        if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
-            self.ref_model = self.get_arch(self.config)
-            self.ref_model.to(self.accelerator.device)
-            self.ref_model.eval()
+        # TODO: Avoid setting up a reference model when hydra heads are used
+        self.ref_model = self.get_arch(self.config)
+        self.ref_model.to(self.accelerator.device)
+        self.ref_model.eval()
 
         self.generate_kwargs = dict(
             config.method.gen_kwargs,
@@ -47,6 +48,8 @@ def __init__(self, config: TRLConfig, **kwargs):
 
         # `beta` corresponding to the DPO hyperparameter
         self.beta = config.method.beta
+        self.label_pad_token_id = config.method.label_pad_token_id
+        self.padding_value = config.method.padding_value
 
     def get_arch(self, config):
         from_fn = AutoModelForCausalLM.from_pretrained
@@ -250,4 +253,4 @@ def prepare_learning(self):
 
     def make_experience(self, samples, seq_length):
         preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples]
-        self.store = DPOStore(preferences, self.tokenizer)
+        self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value)
diff --git a/trlx/trlx.py b/trlx/trlx.py
index 7fbce94f4..6d98019fd 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -64,7 +64,7 @@ def train(  # noqa: C901
         config = default_ppo_config()
     elif rewards:
         config = default_ilql_config()
-    else:
+    else:  # This could also be DPO, but it is ignored since passing `config` implicitly is deprecated
         config = default_sft_config()
 
     set_seed(config.train.seed)
@@ -102,7 +102,7 @@ def train(  # noqa: C901
         if eval_prompts is None:
             eval_prompts = prompts[:batch_size]
 
-    # Offline training from the collected samples (e.g. SFT, ILQL)
+    # Offline training from the collected samples (e.g. SFT, ILQL, DPO)
     elif samples:
         if rewards is not None:
             if len(samples) != len(rewards):
diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py
index 9c7dccf76..3dc5a4c52 100644
--- a/trlx/utils/loading.py
+++ b/trlx/utils/loading.py
@@ -6,6 +6,7 @@
 
 # Register load trainers via module import
 from trlx.trainer import _TRAINERS, register_trainer
+from trlx.trainer.accelerate_dpo_trainer import AccelerateDPOTrainer
 from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer
 from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer
 from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer
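For context on the new `DPOConfig` fields: the trainer's loss computation is not part of these hunks, but the standard DPO objective they feed into looks roughly like the sketch below. This is a hedged illustration, not the trainer's actual code; the function and argument names are made up for the example.

```python
# Sketch of the DPO objective: `beta` scales the preference margin, and
# `label_pad_token_id` (-100, the CrossEntropyLoss ignore index) masks prompt
# and padding positions when token log-probs are reduced to sequence log-probs.
import torch
import torch.nn.functional as F


def sequence_logps(logits: torch.Tensor, labels: torch.Tensor, label_pad_token_id: int = -100) -> torch.Tensor:
    """Sum the log-probs of the label tokens, skipping positions labelled with label_pad_token_id."""
    logits = logits[:, :-1, :]  # tokens < n predict token n
    labels = labels[:, 1:].clone()
    mask = labels != label_pad_token_id
    labels[~mask] = 0  # any valid index; these positions are masked out below
    per_token = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
    return (per_token * mask).sum(-1)


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # Implicit rewards are log-ratios of the policy against the frozen reference model
    chosen_rewards = policy_chosen_logps - ref_chosen_logps
    rejected_rewards = policy_rejected_logps - ref_rejected_logps
    # Maximize the chosen-vs-rejected margin; beta controls how sharply it is enforced
    return -F.logsigmoid(beta * (chosen_rewards - rejected_rewards)).mean()
```

This is also why the trainer's `__init__` above instantiates a frozen, eval-mode copy of the model as `self.ref_model`.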