From 5af7af154995d0d2292ea0b19eb8a824748da4b7 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Thu, 25 Apr 2024 11:00:09 -0400 Subject: [PATCH] Introduce Stateful Callbacks (#29666) * Introduce saveable callbacks * Add note * Test for non-present and flag * Support early stopping and refusing to train further * Update docstring * More saving * Import oopsie * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Make it go through TrainerArguments * Document * Fix test * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Rework to allow for duplicates * CLean * Fix failing tests --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/trainer.py | 58 +++++++- src/transformers/trainer_callback.py | 98 +++++++++++++- src/transformers/training_args.py | 9 ++ tests/trainer/test_trainer_callback.py | 179 ++++++++++++++++++++++++- 4 files changed, 337 insertions(+), 7 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 52beb6c1e56ff5..26ab0877d0c171 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -78,6 +78,7 @@ from .trainer_callback import ( CallbackHandler, DefaultFlowCallback, + ExportableState, PrinterCallback, ProgressCallback, TrainerCallback, @@ -649,12 +650,15 @@ def __init__( else: self.label_smoother = None + self.control = TrainerControl() + self.state = TrainerState( is_local_process_zero=self.is_local_process_zero(), is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], ) - - self.control = TrainerControl() # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then # returned to 0 every time flos need to be logged self.current_flos = 0 @@ -1499,6 +1503,8 @@ def _tune_save_checkpoint(self, checkpoint_dir: str): output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") self.save_model(output_dir, _internal_call=True) if self.args.should_save: + # Update the `TrainerControl` state to where we are currently + self.state.stateful_callbacks["TrainerControl"] = self.control.state() self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -1970,7 +1976,11 @@ def _inner_training_loop( if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) - self.state = TrainerState() + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size @@ -2079,6 +2089,7 @@ def _inner_training_loop( ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() epochs_trained = self.state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) @@ -2786,6 +2797,8 @@ def _save_checkpoint(self, model, trial, metrics=None): # Save the Trainer state if self.args.should_save: + # Update the `TrainerControl` state to where we are currently + self.state.stateful_callbacks["TrainerControl"] = self.control.state() self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) if self.args.push_to_hub: @@ -2970,6 +2983,45 @@ def opt_load_hook(mod, opt): self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) reissue_pt_warnings(caught_warnings) + def _load_callback_state(self): + """If callback states exist and were passed in, restore their states if enabled""" + if not self.args.restore_callback_states_from_checkpoint: + return + # Callback states are stored in stateful_callbacks + not_found = [] + new_callbacks = [] + original_callbacks = self.callback_handler.callbacks + [self.control] + for stored_callback, data in self.state.stateful_callbacks.items(): + if not isinstance(data, list): + data = [data] + if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks): + # We can load/restore from multiple callbacks of the same type. + duplicates = [ + callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback + ] + for callback, callback_data in zip(duplicates, data): + args = callback_data.get("args", {}) + attributes = callback_data.get("attributes", {}) + new_callback = type(callback)(**args) + for attribute, value in attributes.items(): + setattr(new_callback, attribute, value) + if isinstance(callback, TrainerControl): + # Specifically for restoring the `control` state + self.control = new_callback + else: + new_callbacks.append(new_callback) + # We remove the existing callback and add it to the list of new callbacks + self.callback_handler.remove_callback(type(new_callback)) + logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in") + else: + not_found.append(stored_callback) + if len(not_found) > 0: + logger.warning( + f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})" + ) + for callback in new_callbacks: + self.callback_handler.add_callback(callback) + def hyperparameter_search( self, hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 53eb49401d8da4..a21c46fea9fe2a 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -84,6 +84,9 @@ class TrainerState: is_hyper_param_search (`bool`, *optional*, defaults to `False`): Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will impact the way data will be logged in TensorBoard. + stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*): + Callbacks attached to the `Trainer` that should have their states be saved or restored. + Relevent callbacks should implement a `state` and `from_state` function. """ epoch: Optional[float] = None @@ -104,10 +107,34 @@ class TrainerState: is_hyper_param_search: bool = False trial_name: str = None trial_params: Dict[str, Union[str, float, int, bool]] = None + stateful_callbacks: List["TrainerCallback"] = None def __post_init__(self): if self.log_history is None: self.log_history = [] + if self.stateful_callbacks is None: + self.stateful_callbacks = {} + elif isinstance(self.stateful_callbacks, dict): + # We are loading the callbacks in from the state file, no need to process them + pass + else: + # Saveable callbacks get stored as dict of kwargs + stateful_callbacks = {} + for callback in self.stateful_callbacks: + if not isinstance(callback, (ExportableState)): + raise TypeError( + f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}" + ) + name = callback.__class__.__name__ + if name in stateful_callbacks: + # We can have multiple versions of the same callback + # if so, we store them as a list of states to restore + if not isinstance(stateful_callbacks[name], list): + stateful_callbacks[name] = [stateful_callbacks[name]] + stateful_callbacks[name].append(callback.state()) + else: + stateful_callbacks[name] = callback.state() + self.stateful_callbacks = stateful_callbacks def save_to_json(self, json_path: str): """Save the content of this instance in JSON format inside `json_path`.""" @@ -123,8 +150,52 @@ def load_from_json(cls, json_path: str): return cls(**json.loads(text)) +class ExportableState: + """ + A class for objects that include the ability to have its state + be saved during `Trainer._save_checkpoint` and loaded back in during + `Trainer._load_from_checkpoint`. + + These must implement a `state` function that gets called during the respective + Trainer function call. It should only include parameters and attributes needed to + recreate the state at a particular time, to avoid utilizing pickle/maintain standard + file IO writing. + + Example: + + ```python + class EarlyStoppingCallback(TrainerCallback, ExportableState): + def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0): + self.early_stopping_patience = early_stopping_patience + self.early_stopping_threshold = early_stopping_threshold + # early_stopping_patience_counter denotes the number of times validation metrics failed to improve. + self.early_stopping_patience_counter = 0 + + def state(self) -> dict: + return { + "args": { + "early_stopping_patience": self.early_stopping_patience, + "early_stopping_threshold": self.early_stopping_threshold, + }, + "attributes": { + "early_stopping_patience_counter": self.early_stopping_patience_counter, + } + } + ```""" + + def state(self) -> dict: + raise NotImplementedError("You must implement a `state` function to utilize this class.") + + @classmethod + def from_state(cls, state): + instance = cls(**state["args"]) + for k, v in state["attributes"].items(): + setattr(instance, k, v) + return instance + + @dataclass -class TrainerControl: +class TrainerControl(ExportableState): """ A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some switches in the training loop. @@ -172,6 +243,18 @@ def _new_step(self): self.should_evaluate = False self.should_log = False + def state(self) -> dict: + return { + "args": { + "should_training_stop": self.should_training_stop, + "should_epoch_stop": self.should_epoch_stop, + "should_save": self.should_save, + "should_evaluate": self.should_evaluate, + "should_log": self.should_log, + }, + "attributes": {}, + } + class TrainerCallback: # no-format @@ -546,7 +629,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): print(logs) -class EarlyStoppingCallback(TrainerCallback): +class EarlyStoppingCallback(TrainerCallback, ExportableState): """ A [`TrainerCallback`] that handles early stopping. @@ -605,3 +688,14 @@ def on_evaluate(self, args, state, control, metrics, **kwargs): self.check_metric_value(args, state, control, metric_value) if self.early_stopping_patience_counter >= self.early_stopping_patience: control.should_training_stop = True + + def state(self) -> dict: + return { + "args": { + "early_stopping_patience": self.early_stopping_patience, + "early_stopping_threshold": self.early_stopping_threshold, + }, + "attributes": { + "early_stopping_patience_counter": self.early_stopping_patience_counter, + }, + } diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 91472eed9b0314..12ae77908ebfae 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -357,6 +357,9 @@ class TrainingArguments: Note that when this is true, you won't be able to resume training from checkpoint. This enables you to save storage by not storing the optimizer, scheduler & rng state. You can only load the model using `from_pretrained` with this option set to `True`. + restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`): + Whether to restore the callback states from the checkpoint. If `True`, will override + callbacks passed to the `Trainer` if they exist in the checkpoint." use_cpu (`bool`, *optional*, defaults to `False`): Whether or not to use cpu. If set to False, we will use cuda or mps device if available. seed (`int`, *optional*, defaults to 42): @@ -951,6 +954,12 @@ class TrainingArguments: ) }, ) + restore_callback_states_from_checkpoint: bool = field( + default=False, + metadata={ + "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint." + }, + ) no_cuda: bool = field( default=False, metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."}, diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index b712edca385c25..8c0c9367d8d779 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import os import shutil import tempfile import unittest @@ -19,28 +21,44 @@ from transformers import ( DefaultFlowCallback, + EarlyStoppingCallback, IntervalStrategy, PrinterCallback, ProgressCallback, Trainer, TrainerCallback, + TrainerState, TrainingArguments, is_torch_available, ) from transformers.testing_utils import require_torch +from transformers.trainer_callback import ExportableState if is_torch_available(): - from transformers.trainer import DEFAULT_CALLBACKS + from transformers.trainer import DEFAULT_CALLBACKS, TRAINER_STATE_NAME from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel +class MyTestExportableCallback(TrainerCallback, ExportableState): + def __init__(self, my_test_state="test"): + self.my_test_state = my_test_state + + def state(self): + return { + "args": { + "my_test_state": self.my_test_state, + }, + } + + class MyTestTrainerCallback(TrainerCallback): "A callback that registers the events that goes through." - def __init__(self): + def __init__(self, my_test_state="test"): self.events = [] + self.my_test_state = my_test_state def on_init_end(self, args, state, control, **kwargs): self.events.append("on_init_end") @@ -243,3 +261,160 @@ def test_event_flow(self): callbacks=[MyTestTrainerCallback, MyTestTrainerCallback], ) assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0] + + def test_stateful_callbacks(self): + # Use something with non-defaults + cb = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2) + trainer = self.get_trainer( + callbacks=[cb], + load_best_model_at_end=True, + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + ) + trainer.train() + + # Create a new trainer with defaults + trainer = self.get_trainer( + callbacks=[EarlyStoppingCallback()], + load_best_model_at_end=True, + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + restore_callback_states_from_checkpoint=True, + ) + # Load it back in and verify values + checkpoint = os.path.join(self.output_dir, "checkpoint-2") + trainer.train(resume_from_checkpoint=checkpoint) + cb = [ + callback for callback in trainer.callback_handler.callbacks if isinstance(callback, EarlyStoppingCallback) + ][0] + assert cb.early_stopping_patience == 5 + assert cb.early_stopping_threshold == 0.2 + + def test_stateful_mixed_callbacks(self): + # Use two callbacks, one stateful one not + # Use something with non-defaults + cbs = [ + MyTestTrainerCallback(my_test_state="another value"), + EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.2), + ] + trainer = self.get_trainer( + callbacks=cbs, + load_best_model_at_end=True, + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + ) + trainer.train() + + # Create a new trainer with defaults + trainer = self.get_trainer( + callbacks=[EarlyStoppingCallback(), MyTestTrainerCallback()], + load_best_model_at_end=True, + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + restore_callback_states_from_checkpoint=True, + ) + # Load it back in and verify values + checkpoint = os.path.join(self.output_dir, "checkpoint-2") + trainer.train(resume_from_checkpoint=checkpoint) + cbs = [ + callback + for callback in trainer.callback_handler.callbacks + if isinstance(callback, (EarlyStoppingCallback, MyTestTrainerCallback)) + ] + assert len(cbs) == 2 + my_test, early_stopping = cbs + assert early_stopping.early_stopping_patience == 5 + assert early_stopping.early_stopping_threshold == 0.2 + assert my_test.my_test_state == "test" + + def test_stateful_duplicate_callbacks(self): + # Use something with non-defaults + cbs = [MyTestExportableCallback("first"), MyTestExportableCallback("second")] + trainer = self.get_trainer( + callbacks=cbs, + load_best_model_at_end=True, + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + ) + trainer.train() + + # Create a new trainer with defaults + trainer = self.get_trainer( + callbacks=[MyTestExportableCallback(), MyTestExportableCallback()], + load_best_model_at_end=True, + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + restore_callback_states_from_checkpoint=True, + ) + # Load it back in and verify values + checkpoint = os.path.join(self.output_dir, "checkpoint-2") + trainer.train(resume_from_checkpoint=checkpoint) + cbs = [ + callback + for callback in trainer.callback_handler.callbacks + if isinstance(callback, MyTestExportableCallback) + ] + assert len(cbs) == 2 + assert cbs[0].my_test_state == "first" + assert cbs[1].my_test_state == "second" + + def test_missing_stateful_callback(self): + cb = EarlyStoppingCallback() + trainer = self.get_trainer( + callbacks=[cb], + load_best_model_at_end=True, + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + ) + trainer.train() + + # Create a new trainer with defaults + trainer = self.get_trainer( + save_strategy="steps", + eval_strategy="steps", + save_steps=2, + eval_steps=2, + max_steps=2, + restore_callback_states_from_checkpoint=True, + ) + # Load it back in and verify values + checkpoint = os.path.join(self.output_dir, "checkpoint-2") + # warning should be emitted for not-present callbacks + with patch("transformers.trainer.logger.warning") as warn_mock: + trainer.train(resume_from_checkpoint=checkpoint) + assert "EarlyStoppingCallback" in warn_mock.call_args[0][0] + + def test_stateful_control(self): + trainer = self.get_trainer( + max_steps=2, + save_strategy="steps", + save_steps=2, + ) + trainer.train() + # Load it back in and verify values + trainer = self.get_trainer(max_steps=2, restore_callback_states_from_checkpoint=True) + checkpoint = os.path.join(self.output_dir, "checkpoint-2") + trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME)) + trainer._load_callback_state() + assert trainer.control.should_training_stop