From 805daed95aae13ebf9833331c9241742e45dc1de Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Thu, 4 Apr 2024 12:19:55 -0400 Subject: [PATCH] CLean --- src/transformers/trainer.py | 5 +++-- src/transformers/trainer_callback.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a32f142dfcedfe..26ab0877d0c171 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2995,10 +2995,11 @@ def _load_callback_state(self): if not isinstance(data, list): data = [data] if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks): - matches = [ + # We can load/restore from multiple callbacks of the same type. + duplicates = [ callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback ] - for callback, callback_data in zip(matches, data): + for callback, callback_data in zip(duplicates, data): args = callback_data.get("args", {}) attributes = callback_data.get("attributes", {}) new_callback = type(callback)(**args) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 4bdf52ed48e0aa..a21c46fea9fe2a 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -127,11 +127,13 @@ def __post_init__(self): ) name = callback.__class__.__name__ if name in stateful_callbacks: + # We can have multiple versions of the same callback + # if so, we store them as a list of states to restore if not isinstance(stateful_callbacks[name], list): stateful_callbacks[name] = [stateful_callbacks[name]] stateful_callbacks[name].append(callback.state()) else: - stateful_callbacks[callback.__class__.__name__] = callback.state() + stateful_callbacks[name] = callback.state() self.stateful_callbacks = stateful_callbacks def save_to_json(self, json_path: str):