From efc687db62d95eb832636592b3ee21457172b93d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 12 Dec 2024 12:53:32 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20Update=20tests=20and=20?= =?UTF-8?q?fix=20PPO=20(#2463)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [bugfix] critic not update * Update ppo_trainer.py * Update ppo_trainer.py * add failing test * test both policy and critic * formatting * fix tests * formatting * Update tests/test_ppo_trainer.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * fix test --------- Co-authored-by: NINGBENZHE <53843873+NINGBENZHE@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- tests/test_ppo_trainer.py | 293 ++++++++++++++++++++++++++----------- trl/trainer/ppo_trainer.py | 66 ++++----- 2 files changed, 237 insertions(+), 122 deletions(-) diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index ad9b19a7b9..5560aa80a9 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -12,94 +12,213 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform -import subprocess +import tempfile +import unittest +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer from transformers.testing_utils import require_peft +from trl import PPOConfig, PPOTrainer +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -def test(): - command = """\ -python examples/scripts/ppo/ppo.py \ - --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ - --dataset_train_split descriptiveness \ - --learning_rate 3e-6 \ - --output_dir models/minimal/ppo \ - --per_device_train_batch_size 4 \ - --gradient_accumulation_steps 1 \ - --total_episodes 10 \ - --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \ - --reward_model_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \ - --sft_model_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \ - --missing_eos_penalty 1.0 \ - --save_strategy no \ - --stop_token eos -""" - if platform.system() == "Windows": - # windows CI does not work with subprocesses for some reason - # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 - return - subprocess.run( - command, - shell=True, - check=True, - ) - - -def test_num_train_epochs(): - command = """\ -python examples/scripts/ppo/ppo.py \ - --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ - --dataset_train_split descriptiveness \ - --learning_rate 3e-6 \ - --output_dir models/minimal/ppo \ - --per_device_train_batch_size 4 \ - --gradient_accumulation_steps 1 \ - --num_train_epochs 0.003 \ - --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \ - --reward_model_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \ - --sft_model_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \ - --missing_eos_penalty 1.0 \ - --save_strategy no \ - --stop_token eos -""" - if platform.system() == "Windows": - # windows CI does not work with subprocesses for some reason - # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 - return - subprocess.run( - command, - shell=True, - check=True, - ) - - -@require_peft -def test_peft_support(): - command = """\ -python examples/scripts/ppo/ppo.py \ - --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ - --dataset_train_split descriptiveness \ - --learning_rate 3e-6 \ - --output_dir models/minimal/ppo \ - --per_device_train_batch_size 4 \ - --gradient_accumulation_steps 1 \ - --total_episodes 10 \ - --model_name_or_path EleutherAI/pythia-14m \ - --missing_eos_penalty 1.0 \ - --save_strategy no \ - --stop_token eos \ - --use_peft \ - --lora_r 32 \ - --lora_alpha 16 \ - --lora_target_modules query_key_value dense -""" - if platform.system() == "Windows": - # windows CI does not work with subprocesses for some reason - # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 - return - subprocess.run( - command, - shell=True, - check=True, - ) + +class TestPPOTrainer(unittest.TestCase): + def setUp(self): + # Set up the models and tokenizer using the test model + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left") + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + if self.tokenizer.chat_template is None: + self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + # Add reward and value models as in ppo.py + self.value_model = AutoModelForSequenceClassification.from_pretrained( + self.model_id, trust_remote_code=True, num_labels=1 + ) + self.reward_model = AutoModelForSequenceClassification.from_pretrained( + self.model_id, trust_remote_code=True, num_labels=1 + ) + + # Load dataset + raw_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + self.raw_dataset = raw_dataset.map(lambda x: self.tokenizer(x["prompt"]), remove_columns=["prompt"]) + + def test_basic_training(self): + """Test basic PPO training configuration and verify model updates.""" + with tempfile.TemporaryDirectory() as tmp_dir: + # Capture initial weights + initial_critic_weights = {} + initial_policy_weights = {} + for name, param in self.value_model.named_parameters(): + initial_critic_weights[name] = param.clone().detach() + for name, param in self.model.named_parameters(): + initial_policy_weights[name] = param.clone().detach() + + # Configure training args similar to example script + training_args = PPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=4, + per_device_eval_batch_size=2, + report_to="none", + missing_eos_penalty=1.0, + vf_coef=1.0, # Increase value function coefficient + num_ppo_epochs=4, # Increase number of PPO epochs + ) + + # Create trainer + trainer = PPOTrainer( + args=training_args, + processing_class=self.tokenizer, + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=self.raw_dataset["train"], + eval_dataset=self.raw_dataset["test"], + ) + + # Train + trainer.train() + + # Check if critic weights have been updated + critic_weights_updated = False + for name, param in trainer.model.value_model.named_parameters(): + if not torch.allclose(initial_critic_weights[name], param.to("cpu")): + critic_weights_updated = True + break + + # Check if policy weights have been updated + policy_weights_updated = False + for name, param in trainer.model.policy.named_parameters(): + if not torch.allclose(initial_policy_weights[name], param.to("cpu")): + policy_weights_updated = True + break + + self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") + self.assertTrue(policy_weights_updated, "Policy weights were not updated during training") + + @require_peft + def test_peft_training(self): + """Test PPO training with PEFT configuration and verify model updates.""" + from peft import LoraConfig + + with tempfile.TemporaryDirectory() as tmp_dir: + # Capture initial weights + initial_critic_weights = {} + initial_policy_weights = {} + for name, param in self.value_model.named_parameters(): + initial_critic_weights[name] = param.clone().detach() + for name, param in self.model.named_parameters(): + initial_policy_weights[name] = param.clone().detach() + + # Configure training args + training_args = PPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=4, + per_device_eval_batch_size=2, + report_to="none", + missing_eos_penalty=1.0, + ) + + # Configure PEFT + peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # Create trainer with PEFT + trainer = PPOTrainer( + args=training_args, + processing_class=self.tokenizer, + model=self.model, + ref_model=None, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=self.raw_dataset["train"], + eval_dataset=self.raw_dataset["test"], + peft_config=peft_config, + ) + + # Train + trainer.train() + + # Check if critic weights have been updated + critic_weights_updated = False + for name, param in trainer.model.value_model.named_parameters(): + if name in initial_critic_weights and not torch.allclose( + initial_critic_weights[name], param.to("cpu") + ): + critic_weights_updated = True + break + + # Check if policy weights have been updated - for PEFT we check the LoRA weights + policy_weights_updated = False + for name, param in trainer.model.policy.named_parameters(): + if "lora" in name.lower() and param.requires_grad: # Only check LoRA weights + # New weights should be non-zero if they've been updated + if not torch.allclose(param, torch.zeros_like(param)): + policy_weights_updated = True + break + + self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") + self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training") + + def test_with_num_train_epochs(self): + """Test PPO training with num_train_epochs configuration.""" + with tempfile.TemporaryDirectory() as tmp_dir: + # Capture initial weights + initial_critic_weights = {} + initial_policy_weights = {} + for name, param in self.value_model.named_parameters(): + initial_critic_weights[name] = param.clone().detach() + for name, param in self.model.named_parameters(): + initial_policy_weights[name] = param.clone().detach() + + # Configure training args + training_args = PPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=4, + per_device_eval_batch_size=2, + report_to="none", + missing_eos_penalty=1.0, + ) + + # Create trainer + trainer = PPOTrainer( + args=training_args, + processing_class=self.tokenizer, + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=self.raw_dataset["train"], + eval_dataset=self.raw_dataset["test"], + ) + + # Train and verify no exceptions are raised + trainer.train() + + # Check if critic weights have been updated + critic_weights_updated = False + for name, param in trainer.model.value_model.named_parameters(): + if not torch.allclose(initial_critic_weights[name], param.to("cpu")): + critic_weights_updated = True + break + + # Check if policy weights have been updated + policy_weights_updated = False + for name, param in trainer.model.policy.named_parameters(): + if not torch.allclose(initial_policy_weights[name], param.to("cpu")): + policy_weights_updated = True + break + + self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") + self.assertTrue(policy_weights_updated, "Policy weights were not updated during training") diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index b51c29c014..3827caae30 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -130,16 +130,16 @@ def __init__( self.args = args self.processing_class = processing_class - self.model = model + self.policy_model = model # Define the collator if not provided if data_collator is None: data_collator = DataCollatorWithPadding(self.processing_class) - self.model.generation_config.eos_token_id = ( + self.policy_model.generation_config.eos_token_id = ( None # disable `pad_token_id` and `eos_token_id` because we just want to ) - self.model.generation_config.pad_token_id = None # generate tokens without truncation / padding + self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding # peft support if not is_peft_available() and peft_config is not None: @@ -148,15 +148,15 @@ def __init__( ) elif is_peft_available() and peft_config is not None: # if model is a peft model and we have a peft_confg, we merge and unload it first - if isinstance(self.model, PeftModel): - self.model = self.model.merge_and_unload() + if isinstance(self.policy_model, PeftModel): + self.policy_model = self.policy_model.merge_and_unload() # get peft model with the given config - self.model = get_peft_model(self.model, peft_config) - if args.bf16 and getattr(self.model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(self.model) + self.policy_model = get_peft_model(self.policy_model, peft_config) + if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(self.policy_model) - self.is_peft_model = is_peft_available() and isinstance(self.model, PeftModel) + self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel) self.model_adapter_name = args.model_adapter_name self.ref_adapter_name = args.ref_adapter_name @@ -165,7 +165,7 @@ def __init__( elif self.is_peft_model: self.ref_model = None else: - self.ref_model = create_reference_model(self.model) + self.ref_model = create_reference_model(self.policy_model) self.reward_model = reward_model self.train_dataset = train_dataset @@ -215,13 +215,13 @@ def __init__( ######### # setup model, optimizer, and others ######### - for module in [self.model, self.ref_model, self.value_model, self.reward_model]: + for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: if module is not None: disable_dropout_in_model(module) if args.stop_token and args.stop_token == "eos": args.stop_token_id = processing_class.eos_token_id - self.policy_and_value = PolicyAndValueWrapper(self.model, self.value_model) - self.policy_and_value.config = self.model.config # needed for pushing to hub + self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) + self.model.config = self.policy_model.config # needed for pushing to hub self.create_optimizer_and_scheduler( num_training_steps=args.num_total_batches ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level @@ -232,7 +232,7 @@ def __init__( default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks self.callback_handler = CallbackHandler( - self.callbacks, self.policy_and_value, self.processing_class, self.optimizer, self.lr_scheduler + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) self.control = TrainerControl() @@ -255,8 +255,8 @@ def __init__( os.makedirs(self.args.output_dir, exist_ok=True) # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.policy_and_value, "add_model_tags"): - self.policy_and_value.add_model_tags(self._tag_names) + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) ######### ### setup dataloader @@ -271,9 +271,7 @@ def __init__( # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c torch.manual_seed(args.seed) - self.policy_and_value, self.optimizer, self.dataloader = accelerator.prepare( - self.policy_and_value, self.optimizer, self.dataloader - ) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) torch.manual_seed(self.local_seed) # reset the local seed again self.eval_dataloader = DataLoader( @@ -314,25 +312,25 @@ def get_eval_dataloader(self) -> DataLoader: def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" with self.accelerator.unwrap_model( - self.policy_and_value.policy + self.model.policy ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): if self.ref_adapter_name: - self.policy_and_value.policy.set_adapter(self.ref_adapter_name) + self.model.policy.set_adapter(self.ref_adapter_name) yield if self.ref_adapter_name: - self.policy_and_value.policy.set_adapter(self.model_adapter_name or "default") + self.model.policy.set_adapter(self.model_adapter_name or "default") def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): - backup_model = self.policy_and_value - self.policy_and_value = self.policy_and_value.policy # save only the policy + backup_model = self.model + self.model = self.model.policy # save only the policy if self.is_deepspeed_enabled: backup_deepspeed = self.deepspeed - self.deepspeed = self.policy_and_value + self.deepspeed = self.model super().save_model(output_dir, _internal_call) - self.policy_and_value = backup_model + self.model = backup_model if self.is_deepspeed_enabled: self.deepspeed = backup_deepspeed @@ -341,7 +339,7 @@ def train(self): args = self.args accelerator = self.accelerator optimizer = self.optimizer - model = self.policy_and_value + model = self.model ref_policy = self.ref_model reward_model = self.reward_model processing_class = self.processing_class @@ -398,8 +396,8 @@ def repeat_generator(): # backward compatibility if self.is_deepspeed_enabled: - self.deepspeed = self.policy_and_value - self.model_wrapped = self.policy_and_value + self.deepspeed = self.model + self.model_wrapped = self.model for update in range(1, args.num_total_batches + 1): self.state.episode += 1 * args.batch_size @@ -686,7 +684,7 @@ def generate_completions(self, sampling: bool = False): ) table = defaultdict(list) - with unwrap_model_for_generation(self.policy_and_value, self.accelerator) as unwrapped_model: + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: for batch in self.eval_dataloader: query = batch["input_ids"] with torch.no_grad(): @@ -749,10 +747,8 @@ def create_model_card( if not self.is_world_process_zero(): return - if hasattr(self.policy_and_value.config, "_name_or_path") and not os.path.isdir( - self.policy_and_value.config._name_or_path - ): - base_model = self.policy_and_value.config._name_or_path + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path else: base_model = None @@ -760,7 +756,7 @@ def create_model_card( if isinstance(tags, str): tags = [tags] - if hasattr(self.policy_and_value.config, "unsloth_version"): + if hasattr(self.model.config, "unsloth_version"): tags.append("unsloth") citation = textwrap.dedent("""\