From 5548add49c684c77105c120e339df515ff6dc39d Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Thu, 4 Apr 2024 12:17:25 -0400
Subject: [PATCH] Rework to allow for duplicates

---
 src/transformers/trainer.py            | 40 ++++++++++++---------
 src/transformers/trainer_callback.py   |  8 ++++-
 tests/trainer/test_trainer_callback.py | 50 ++++++++++++++++++++++++++
 3 files changed, 80 insertions(+), 18 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 136f40f6f17a40..a32f142dfcedfe 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2989,24 +2989,28 @@ def _load_callback_state(self):
             return
         # Callback states are stored in stateful_callbacks
         not_found = []
+        new_callbacks = []
+        original_callbacks = self.callback_handler.callbacks + [self.control]
         for stored_callback, data in self.state.stateful_callbacks.items():
-            if any(
-                callback.__class__.__name__ == stored_callback
-                for callback in self.callback_handler.callbacks + [self.control]
-            ):
-                for callback in self.callback_handler.callbacks + [self.control]:
-                    if callback.__class__.__name__ == stored_callback:
-                        args, attributes = data.values()
-                        new_callback = type(callback)(**args)
-                        for attribute, value in attributes.items():
-                            setattr(new_callback, attribute, value)
-                        if isinstance(callback, TrainerControl):
-                            # Specifically for restoring the `control` state
-                            self.control = new_callback
-                        else:
-                            # We remove the existing callback and add a new one
-                            self.callback_handler.remove_callback(callback)
-                            self.callback_handler.add_callback(new_callback)
+            if not isinstance(data, list):
+                data = [data]
+            if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
+                matches = [
+                    callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
+                ]
+                for callback, callback_data in zip(matches, data):
+                    args = callback_data.get("args", {})
+                    attributes = callback_data.get("attributes", {})
+                    new_callback = type(callback)(**args)
+                    for attribute, value in attributes.items():
+                        setattr(new_callback, attribute, value)
+                    if isinstance(callback, TrainerControl):
+                        # Specifically for restoring the `control` state
+                        self.control = new_callback
+                    else:
+                        new_callbacks.append(new_callback)
+                    # We remove the existing callback and add it to the list of new callbacks
+                    self.callback_handler.remove_callback(type(new_callback))
                 logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
             else:
                 not_found.append(stored_callback)
@@ -3014,6 +3018,8 @@ def _load_callback_state(self):
             logger.warning(
                 f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
             )
+        for callback in new_callbacks:
+            self.callback_handler.add_callback(callback)
 
     def hyperparameter_search(
         self,
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 7664a2d510af54..4bdf52ed48e0aa 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -125,7 +125,13 @@ def __post_init__(self):
                     raise TypeError(
                         f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
                     )
-                stateful_callbacks[callback.__class__.__name__] = callback.state()
+                name = callback.__class__.__name__
+                if name in stateful_callbacks:
+                    if not isinstance(stateful_callbacks[name], list):
+                        stateful_callbacks[name] = [stateful_callbacks[name]]
+                    stateful_callbacks[name].append(callback.state())
+                else:
+                    stateful_callbacks[callback.__class__.__name__] = callback.state()
             self.stateful_callbacks = stateful_callbacks
 
     def save_to_json(self, json_path: str):
diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py
index 9db293b9b65da0..e08ec5065fd891 100644
--- a/tests/trainer/test_trainer_callback.py
+++ b/tests/trainer/test_trainer_callback.py
@@ -32,6 +32,7 @@
     is_torch_available,
 )
 from transformers.testing_utils import require_torch
+from transformers.trainer_callback import ExportableState
 
 
 if is_torch_available():
@@ -40,6 +41,18 @@
     from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
 
 
+class MyTestExportableCallback(TrainerCallback, ExportableState):
+    def __init__(self, my_test_state="test"):
+        self.my_test_state = my_test_state
+
+    def state(self):
+        return {
+            "args": {
+                "my_test_state": self.my_test_state,
+            },
+        }
+
+
 class MyTestTrainerCallback(TrainerCallback):
     "A callback that registers the events that goes through."
 
@@ -326,6 +339,43 @@ def test_stateful_mixed_callbacks(self):
         assert early_stopping.early_stopping_threshold == 0.2
         assert my_test.my_test_state == "test"
 
+    def test_stateful_duplicate_callbacks(self):
+        # Use something with non-defaults
+        cbs = [MyTestExportableCallback("first"), MyTestExportableCallback("second")]
+        trainer = self.get_trainer(
+            callbacks=cbs,
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            evaluation_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+        )
+        trainer.train()
+
+        # Create a new trainer with defaults
+        trainer = self.get_trainer(
+            callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
+            load_best_model_at_end=True,
+            save_strategy="steps",
+            evaluation_strategy="steps",
+            save_steps=2,
+            eval_steps=2,
+            max_steps=2,
+            restore_callback_states_from_checkpoint=True,
+        )
+        # Load it back in and verify values
+        checkpoint = os.path.join(self.output_dir, "checkpoint-2")
+        trainer.train(resume_from_checkpoint=checkpoint)
+        cbs = [
+            callback
+            for callback in trainer.callback_handler.callbacks
+            if isinstance(callback, MyTestExportableCallback)
+        ]
+        assert len(cbs) == 2
+        assert cbs[0].my_test_state == "first"
+        assert cbs[1].my_test_state == "second"
+
     def test_missing_stateful_callback(self):
         cb = EarlyStoppingCallback()
         trainer = self.get_trainer(
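Note for reviewers: a minimal standalone sketch of the duplicate-folding rule that the `TrainerState.__post_init__` hunk above introduces when serializing callback states. `FakeCallback` and `fold_states` are illustrative names, not part of the patch; the rule itself mirrors the hunk: the first instance of a class is stored as a plain dict, and any later instance of the same class promotes the entry to a list, preserving registration order.

    # Sketch of the new serialization rule (illustrative names, runnable as-is).
    class FakeCallback:
        def __init__(self, value="default"):
            self.value = value

        def state(self):
            # Same shape the patch expects: constructor args plus extra attributes
            return {"args": {"value": self.value}, "attributes": {}}


    def fold_states(callbacks):
        stateful_callbacks = {}
        for callback in callbacks:
            name = callback.__class__.__name__
            if name in stateful_callbacks:
                # A duplicate class name promotes the stored entry to a list
                if not isinstance(stateful_callbacks[name], list):
                    stateful_callbacks[name] = [stateful_callbacks[name]]
                stateful_callbacks[name].append(callback.state())
            else:
                stateful_callbacks[name] = callback.state()
        return stateful_callbacks


    print(fold_states([FakeCallback("first"), FakeCallback("second")]))
    # -> {'FakeCallback': [{'args': {'value': 'first'}, 'attributes': {}},
    #                      {'args': {'value': 'second'}, 'attributes': {}}]}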
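And the matching restore-side sketch, mirroring the reworked `_load_callback_state`: a stored entry is first normalized to a list, then zipped positionally against the live callbacks of the same class, so the i-th saved state lands on the i-th registered instance. Again, `FakeCallback` and `restore` are illustrative stand-ins, not the patch's own API.

    # Sketch of the restore rule (illustrative names, runnable as-is).
    class FakeCallback:
        def __init__(self, value="default"):
            self.value = value


    def restore(callbacks, stateful_callbacks):
        new_callbacks = []
        for stored_name, data in stateful_callbacks.items():
            if not isinstance(data, list):
                data = [data]  # single-instance entries are normalized to a list
            matches = [cb for cb in callbacks if cb.__class__.__name__ == stored_name]
            # zip() pairs the i-th live instance with the i-th stored state, so
            # duplicates must be registered in the same order as when saved
            for callback, callback_data in zip(matches, data):
                new_callback = type(callback)(**callback_data.get("args", {}))
                for attribute, value in callback_data.get("attributes", {}).items():
                    setattr(new_callback, attribute, value)
                new_callbacks.append(new_callback)
        return new_callbacks


    saved = {"FakeCallback": [{"args": {"value": "first"}}, {"args": {"value": "second"}}]}
    restored = restore([FakeCallback(), FakeCallback()], saved)
    print([cb.value for cb in restored])  # -> ['first', 'second']

This positional pairing is also why `test_stateful_duplicate_callbacks` registers its default-initialized duplicates in the same order ("first", then "second") as the run that produced the checkpoint.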