Skip to content

Commit

Permalink
Refactor curriculum learning callback
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Jun 5, 2024
1 parent fb9a225 commit 3845a79
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 203 deletions.
256 changes: 190 additions & 66 deletions llmfoundry/callbacks/curriculum_learning_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,27 @@
"""

import logging
from typing import Any, Dict
from typing import Any

from composer.core import State
from composer.loggers import Logger
from composer import DataSpec
from composer.core import State, TimeUnit, ensure_time
from composer.loggers import Logger, MosaicMLLogger
from composer.trainer.trainer import _get_initial_device_train_microbatch_size
from streaming import StreamingDataset
from streaming.base.util import clean_stale_shared_memory
from torch.utils.data import DataLoader

from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.utils.warnings import experimental_class
from llmfoundry.utils.builder_tokenizer import build_tokenizer
from llmfoundry.utils.config_utils import calculate_batch_size_info
from llmfoundry.utils.exceptions import BaseContextualError, TrainDataLoaderLocation

log = logging.getLogger(__name__)

__all__ = ['CurriculumLearning']


@experimental_class('CurriculumLearning callback')
class CurriculumLearning(CallbackWithConfig):
"""Starts an epoch with a different dataset when resuming from a checkpoint.
Expand All @@ -34,20 +39,190 @@ class CurriculumLearning(CallbackWithConfig):
dataset_index (int): The index of the dataset currently being used.
"""

def __init__(self, train_config: Dict, dataset_index: int):
self.dataset_index = dataset_index
self.saved_dataset_index = 0
self.all_dataset_configs = []
self.current_dataset_state = {}
# The current dataset config is resolved and passed in train.py
self.current_dataset_config = train_config['train_loader']
def __init__(
self, train_config: dict[str, Any], duration: str | int | TimeUnit,
schedule: list[dict[str, Any]]
):
non_positive_error = ValueError('The duration must be positive.')
unit_error = ValueError(
'Schedules can only be defined in terms of epochs or tokens.'
)

# Ensure all duration values are positive
# Ensure all duration units are in epochs or tokens
self._duration = ensure_time(duration, TimeUnit.EPOCH)
if self._duration.value <= 0:
raise non_positive_error
if self._duration.unit != TimeUnit.EPOCH and self._duration.unit != TimeUnit.TOKEN:
raise unit_error

self._schedule = schedule
for datamix in self._schedule:
assert 'duration' in datamix, 'Each datamix must have a duration.'
datamix['duration'] = ensure_time(
datamix['duration'], TimeUnit.EPOCH
)
if datamix['duration'].value <= 0:
raise non_positive_error
if datamix['duration'].unit != TimeUnit.EPOCH and datamix[
'duration'].unit != TimeUnit.TOKEN:
raise unit_error
assert 'train_loader' in datamix, 'Each datamix must have a train_loader.'

self._schedule_index = -1

# Copied from llmfoundry/utils/config_utils.py
self.device_train_batch_size, _, _ = calculate_batch_size_info(
train_config['global_train_batch_size'],
train_config['device_train_microbatch_size'],
data_replication_degree=1,
)

# Copied from scripts/train/train.py
tokenizer_name = train_config['tokenizer']['name']
tokenizer_kwargs = train_config['tokenizer'].get('kwargs', {})
self.tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

def before_load(self, state: State, logger: Logger):
del logger

# Save the current dataset state so we can restore it correctly
# if we are resuming with a new dataset.
train_loader = state.train_dataloader
# Ensure all duration units are the same as max_duration
units_match = True
if self._duration.unit != state.max_duration.unit:
units_match = False
for datamix in self._schedule:
if datamix['duration'].unit != state.max_duration.unit:
units_match = False
if not units_match:
raise ValueError((
'All durations in the schedule must have the same units as '
'the max_duration.'
))

# Ensure schedule duration is greater than max_duration
schedule_duration = self._duration
for datamix in self._schedule:
schedule_duration += datamix['duration']
if schedule_duration < state.max_duration:
raise ValueError((
'The sum of all durations in the schedule must be greater than '
'the max_duration.'
))

self._validate_dataloader(state.train_dataloader)

def after_load(self, state: State, logger: Logger):
del logger

self._validate_dataloader(state.train_dataloader)

# Check if adding a new datamix to a run that didn't use this callback
if self._schedule_index == -1 and state.timestamp >= self._duration:
self._schedule_index = 0
state.timestamp = state.timestamp.to_next_iteration()
# If checkpoint was saved before iteration was incremented, we need to increment it now
elif ((
self._schedule[self._schedule_index]['duration'].unit
== TimeUnit.TOKEN and state.timestamp.token_in_iteration
>= self._schedule[self._schedule_index]['duration'].value
) or (
self._schedule[self._schedule_index]['duration'].unit
== TimeUnit.EPOCH and state.timestamp.epoch_in_iteration
>= self._schedule[self._schedule_index]['duration'].value
)):
log.warning((
'The CurriculumLearning callback has detected that the previous run did not correctly '
'increment the iteration.'
))
self._schedule_index += 1
state.timestamp = state.timestamp.to_next_iteration()

def iteration_start(self, state: State, logger: Logger):
# Reset and initialize state train dataloader
log.warning(
'trainer._train_data_spec should be updated whenever the dataloader is updated'
)

# Swap the dataset if starting a new iteration that's not the original datamix
if self._schedule_index >= 0:
data_spec = self._build_train_loader(
train_loader_config=self._schedule[self._schedule_index]
['train_loader'],
logger=logger,
)
state.set_dataloader(
dataloader=data_spec.dataloader,
dataloader_label='train',
)
state.train_dataloader = state.dataloader
state.device_train_microbatch_size = _get_initial_device_train_microbatch_size(
state.device_train_microbatch_size,
state.auto_microbatching,
state.train_dataloader,
)
self._validate_dataloader(state.train_dataloader)

# Make sure the epoch in the loaded state dict is not negative.
# Since `__iter__` has not yet been called on the dataset, the
# epoch index in the dataset will still be -1. We need to ensure
# that we set the epoch correctly to 0 in this case.
dataset = state.train_dataloader.dataset
dataset_state = dataset.state_dict( # type: ignore
num_samples=0, from_beginning=False
)
if dataset_state['epoch'] < 0:
dataset_state['epoch'] = 0
dataset.load_state_dict(dataset_state) # type: ignore

# Set the length of the new iteration
if self._schedule_index == -1:
state._iteration_length = self._duration
else:
state._iteration_length = self._schedule[self._schedule_index
]['duration']

def iteration_end(self, state: State, logger: Logger):
del state, logger # unused

self._schedule_index += 1
clean_stale_shared_memory()

def state_dict(self):
return {
'duration': self._duration,
'schedule': self._schedule,
'schedule_index': self._schedule_index,
}

def load_state_dict(self, state: dict[str, Any]):
# Ensure that the schedule has not changed on already trained datamixes
assert self._duration == state['duration']
for idx in range(state['schedule_index'] + 1):
assert self._schedule[idx] == state['schedule'][idx]

self._schedule_index = state['schedule_index']

def _build_train_loader(
self, train_loader_config: dict[str, Any], logger: Logger
) -> DataSpec:
# Copied from scripts/train/train.py
log.info(
f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}'
)
try:
return build_dataloader(
train_loader_config,
self.tokenizer,
self.device_train_batch_size,
)
except BaseContextualError as e:
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
e.location = TrainDataLoaderLocation
destination.log_exception(e)
raise e

def _validate_dataloader(self, train_loader: Any):
# Check if we are using a DataLoader and StreamingDataset
if not isinstance(train_loader, DataLoader):
raise ValueError(
Expand All @@ -61,54 +236,3 @@ def before_load(self, state: State, logger: Logger):
f'because it requires loading and saving dataset state. ',
f'Instead, got a dataset of type {type(dataset)}',
)
assert isinstance(dataset, StreamingDataset)
# Save the current dataset state so we can restore it if needed.
self.current_dataset_state = dataset.state_dict( # type: ignore
num_samples=0, from_beginning=False)

def after_load(self, state: State, logger: Logger):
del logger

# As saved_dataset_index is loaded from state_dict, this only runs when
# a user explicitly increments the dataset_index and not on any other
# resumption, including autoresume.
train_loader = state._train_dataloader
assert isinstance(
train_loader,
DataLoader,
), 'CurriculumLearning callback requires a DataLoader.'
dataset = train_loader.dataset
assert isinstance(
dataset,
StreamingDataset,
), 'CurriculumLearning callback requires a StreamingDataset.'
if self.saved_dataset_index < self.dataset_index:
# Ignore the dataset state that was read in from the checkpoint, and
# replace with the new dataset state. This preserves resumption info.
if self.current_dataset_state['epoch'] < 0:
# Make sure the epoch in the loaded state dict is not negative.
# Since `__iter__` has not yet been called on the dataset, the
# epoch index in the dataset will still be -1. We need to ensure
# that we set the epoch correctly to 0 in this case.
self.current_dataset_state['epoch'] = 0
dataset.load_state_dict( # type: ignore
self.current_dataset_state)
# Start a new epoch since we are using a new dataset.
# This will also reset the sample_in_epoch written to checkpoint,
# making sure that subsequent resumptions proceed correctly.
state.timestamp = state.timestamp.to_next_epoch()
# Append the new dataset config to the list of all dataset configs.
self.all_dataset_configs.append(self.current_dataset_config)
elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0:
# Make sure to track our current dataset config if we are just starting training.
self.all_dataset_configs.append(self.current_dataset_config)

def state_dict(self):
return {
'dataset_index': self.dataset_index,
'all_dataset_configs': self.all_dataset_configs,
}

def load_state_dict(self, state: Dict[str, Any]):
self.saved_dataset_index = state.get('dataset_index', 0)
self.all_dataset_configs = state.get('all_dataset_configs', [])
4 changes: 2 additions & 2 deletions llmfoundry/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@
build_metric,
build_optimizer,
build_scheduler,
build_tokenizer,
)
from llmfoundry.utils.builder_tokenizer import build_tokenizer
from llmfoundry.utils.checkpoint_conversion_helpers import (
convert_and_save_ft_weights,
get_hf_tokenizer_from_composer_state_dict,
load_tokenizer,
)
from llmfoundry.utils.config_dist_utils import process_init_device
from llmfoundry.utils.config_utils import (
calculate_batch_size_info,
log_config,
pop_config,
process_init_device,
update_batch_size_info,
)
from llmfoundry.utils.data_prep_utils import (
Expand Down
54 changes: 1 addition & 53 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,13 @@
from omegaconf import OmegaConf as om
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase

from llmfoundry import registry
from llmfoundry.callbacks import EvalGauntlet
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.eval.datasets.in_context_learning_evaluation import \
get_icl_task_dataloader
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.utils.config_utils import to_dict_container, to_list_container
from llmfoundry.utils.registry_utils import construct_from_registry

Expand All @@ -50,7 +49,6 @@
'build_logger',
'build_optimizer',
'build_scheduler',
'build_tokenizer',
'build_composer_model',
'build_metric',
]
Expand Down Expand Up @@ -446,56 +444,6 @@ def build_scheduler(
)


def build_tokenizer(
tokenizer_name: str,
tokenizer_kwargs: Dict[str, Any],
) -> PreTrainedTokenizerBase:
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
# Make sure the tokenizer files are downloaded and cached first by local rank 0
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass

if tokenizer_name.startswith('tiktoken'):
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
**tokenizer_kwargs,
)

# HuggingFace does not respect the model_max_length kwarg, and overrides it with
# min(kwargs['model_max_length'], original_config['model_max_length']), so we
# explicitly set it here
tokenizer.model_max_length = tokenizer_kwargs.get(
'model_max_length',
int(1e30),
)

if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
raise ValueError(
f'The tokenizer {tokenizer_name} must have an eos_token.',
)

if dist.is_available() and dist.is_initialized(
) and dist.get_world_size() > 1:
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')

dist.barrier()

if dist.get_local_rank() == 0:
os.remove(signal_file_path)

return tokenizer


def build_icl_evaluators(
icl_tasks: Union[str, List[Dict[str, Any]]],
tokenizer: PreTrainedTokenizerBase,
Expand Down
Loading

0 comments on commit 3845a79

Please sign in to comment.