Skip to content

Commit

Permalink
CLean
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Apr 25, 2024
1 parent 5548add commit 805daed
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2995,10 +2995,11 @@ def _load_callback_state(self):
if not isinstance(data, list):
data = [data]
if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
matches = [
# We can load/restore from multiple callbacks of the same type.
duplicates = [
callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
]
for callback, callback_data in zip(matches, data):
for callback, callback_data in zip(duplicates, data):
args = callback_data.get("args", {})
attributes = callback_data.get("attributes", {})
new_callback = type(callback)(**args)
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,13 @@ def __post_init__(self):
)
name = callback.__class__.__name__
if name in stateful_callbacks:
# We can have multiple versions of the same callback
# if so, we store them as a list of states to restore
if not isinstance(stateful_callbacks[name], list):
stateful_callbacks[name] = [stateful_callbacks[name]]
stateful_callbacks[name].append(callback.state())
else:
stateful_callbacks[callback.__class__.__name__] = callback.state()
stateful_callbacks[name] = callback.state()
self.stateful_callbacks = stateful_callbacks

def save_to_json(self, json_path: str):
Expand Down

0 comments on commit 805daed

Please sign in to comment.