Skip to content

Commit

Permalink
Rework to allow for duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Apr 25, 2024
1 parent de4c603 commit 5548add
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 18 deletions.
40 changes: 23 additions & 17 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2989,31 +2989,37 @@ def _load_callback_state(self):
return
# Callback states are stored in stateful_callbacks
not_found = []
new_callbacks = []
original_callbacks = self.callback_handler.callbacks + [self.control]
for stored_callback, data in self.state.stateful_callbacks.items():
if any(
callback.__class__.__name__ == stored_callback
for callback in self.callback_handler.callbacks + [self.control]
):
for callback in self.callback_handler.callbacks + [self.control]:
if callback.__class__.__name__ == stored_callback:
args, attributes = data.values()
new_callback = type(callback)(**args)
for attribute, value in attributes.items():
setattr(new_callback, attribute, value)
if isinstance(callback, TrainerControl):
# Specifically for restoring the `control` state
self.control = new_callback
else:
# We remove the existing callback and add a new one
self.callback_handler.remove_callback(callback)
self.callback_handler.add_callback(new_callback)
if not isinstance(data, list):
data = [data]
if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
matches = [
callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
]
for callback, callback_data in zip(matches, data):
args = callback_data.get("args", {})
attributes = callback_data.get("attributes", {})
new_callback = type(callback)(**args)
for attribute, value in attributes.items():
setattr(new_callback, attribute, value)
if isinstance(callback, TrainerControl):
# Specifically for restoring the `control` state
self.control = new_callback
else:
new_callbacks.append(new_callback)
# We remove the existing callback and add it to the list of new callbacks
self.callback_handler.remove_callback(type(new_callback))
logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
else:
not_found.append(stored_callback)
if len(not_found) > 0:
logger.warning(
f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
)
for callback in new_callbacks:
self.callback_handler.add_callback(callback)

def hyperparameter_search(
self,
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,13 @@ def __post_init__(self):
raise TypeError(
f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
)
stateful_callbacks[callback.__class__.__name__] = callback.state()
name = callback.__class__.__name__
if name in stateful_callbacks:
if not isinstance(stateful_callbacks[name], list):
stateful_callbacks[name] = [stateful_callbacks[name]]
stateful_callbacks[name].append(callback.state())
else:
stateful_callbacks[callback.__class__.__name__] = callback.state()
self.stateful_callbacks = stateful_callbacks

def save_to_json(self, json_path: str):
Expand Down
50 changes: 50 additions & 0 deletions tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
is_torch_available,
)
from transformers.testing_utils import require_torch
from transformers.trainer_callback import ExportableState


if is_torch_available():
Expand All @@ -40,6 +41,18 @@
from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel


class MyTestExportableCallback(TrainerCallback, ExportableState):
    """A minimal stateful callback used to exercise checkpoint state restore.

    Its entire exportable state is the single ``my_test_state`` string, which
    is round-tripped through the ``args`` entry of :meth:`state`.
    """

    def __init__(self, my_test_state="test"):
        self.my_test_state = my_test_state

    def state(self):
        # Only constructor arguments are exported; there are no additional
        # attributes to restore after re-instantiation.
        exported = {"args": {"my_test_state": self.my_test_state}}
        return exported


class MyTestTrainerCallback(TrainerCallback):
"A callback that registers the events that goes through."

Expand Down Expand Up @@ -326,6 +339,43 @@ def test_stateful_mixed_callbacks(self):
assert early_stopping.early_stopping_threshold == 0.2
assert my_test.my_test_state == "test"

def test_stateful_duplicate_callbacks(self):
    """Duplicate stateful callbacks must each have their own state restored."""
    common_args = dict(
        load_best_model_at_end=True,
        save_strategy="steps",
        evaluation_strategy="steps",
        save_steps=2,
        eval_steps=2,
        max_steps=2,
    )
    # Train once with two instances of the same callback class, each carrying
    # a distinct non-default state.
    trainer = self.get_trainer(
        callbacks=[MyTestExportableCallback("first"), MyTestExportableCallback("second")],
        **common_args,
    )
    trainer.train()

    # Build a fresh trainer with default-state duplicates; resuming from the
    # checkpoint should restore each callback's saved state, in order.
    trainer = self.get_trainer(
        callbacks=[MyTestExportableCallback(), MyTestExportableCallback()],
        restore_callback_states_from_checkpoint=True,
        **common_args,
    )
    trainer.train(resume_from_checkpoint=os.path.join(self.output_dir, "checkpoint-2"))

    restored = [
        cb
        for cb in trainer.callback_handler.callbacks
        if isinstance(cb, MyTestExportableCallback)
    ]
    assert len(restored) == 2
    assert restored[0].my_test_state == "first"
    assert restored[1].my_test_state == "second"

def test_missing_stateful_callback(self):
cb = EarlyStoppingCallback()
trainer = self.get_trainer(
Expand Down

0 comments on commit 5548add

Please sign in to comment.