Skip to content

Commit

Permalink
Changes for dataset swap callback (#1569)
Browse files Browse the repository at this point in the history
  • Loading branch information
gupta-abhay authored Oct 4, 2024
1 parent 4bbb4a5 commit 788c1f5
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions llmfoundry/callbacks/dataset_swap_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

import logging
from typing import Any
from dataclasses import dataclass

from composer.core import State
from composer.loggers import Logger
Expand All @@ -23,6 +23,12 @@
__all__ = ['DatasetSwap']


@dataclass
class DatasetSwapStateDict:
    """Snapshot of the ``DatasetSwap`` callback's persisted state.

    Instances are built by ``DatasetSwap.state_dict`` and read back by
    ``DatasetSwap.load_state_dict``.
    """

    # Value of the callback's ``dataset_index`` when the snapshot was taken.
    dataset_index: int
    # Value of the callback's ``all_dataset_configs`` when the snapshot was taken.
    all_dataset_configs: list


@experimental_class('DatasetSwap callback')
class DatasetSwap(CallbackWithConfig):
"""Starts an epoch with a different dataset when resuming from a checkpoint.
Expand Down Expand Up @@ -105,10 +111,18 @@ def after_load(self, state: State, logger: Logger):

def state_dict(self):
    """Build the dictionary describing this callback's current state.

    The current ``dataset_index`` and ``all_dataset_configs`` are exposed
    both as top-level entries and bundled into a ``DatasetSwapStateDict``
    stored under ``'callback_state'``.
    """
    snapshot = DatasetSwapStateDict(
        dataset_index=self.dataset_index,
        all_dataset_configs=self.all_dataset_configs,
    )
    return {
        'dataset_index': self.dataset_index,
        'all_dataset_configs': self.all_dataset_configs,
        'callback_state': snapshot,
    }

def load_state_dict(self, state: dict[str, Any]):
    """Restore the saved dataset index and dataset configs from ``state``.

    Two layouts are accepted:

    * a ``'callback_state'`` entry holding an object with ``dataset_index``
      and ``all_dataset_configs`` attributes (e.g. a
      ``DatasetSwapStateDict``), which takes precedence when present;
    * flat ``'dataset_index'`` / ``'all_dataset_configs'`` entries.

    When neither layout is present the index falls back to ``0`` and the
    configs to an empty list.

    Args:
        state: Mapping previously produced by ``state_dict``.
    """
    callback_state = state.get('callback_state')
    if callback_state is not None:
        self.saved_dataset_index = callback_state.dataset_index
        self.all_dataset_configs = callback_state.all_dataset_configs
        return

    # No bundled snapshot: read the flat keys so that state stored in the
    # flat layout is not silently reset to the defaults.
    self.saved_dataset_index = state.get('dataset_index', 0)
    self.all_dataset_configs = state.get('all_dataset_configs', [])

0 comments on commit 788c1f5

Please sign in to comment.