diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a32f142dfcedfe..26ab0877d0c171 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2995,10 +2995,11 @@ def _load_callback_state(self): if not isinstance(data, list): data = [data] if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks): - matches = [ + # We can load/restore from multiple callbacks of the same type. + duplicates = [ callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback ] - for callback, callback_data in zip(matches, data): + for callback, callback_data in zip(duplicates, data): args = callback_data.get("args", {}) attributes = callback_data.get("attributes", {}) new_callback = type(callback)(**args) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 4bdf52ed48e0aa..a21c46fea9fe2a 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -127,11 +127,13 @@ def __post_init__(self): ) name = callback.__class__.__name__ if name in stateful_callbacks: + # We can have multiple versions of the same callback + # if so, we store them as a list of states to restore if not isinstance(stateful_callbacks[name], list): stateful_callbacks[name] = [stateful_callbacks[name]] stateful_callbacks[name].append(callback.state()) else: - stateful_callbacks[callback.__class__.__name__] = callback.state() + stateful_callbacks[name] = callback.state() self.stateful_callbacks = stateful_callbacks def save_to_json(self, json_path: str):