Small refactor for update batch size (#1293)
dakinggg authored Jun 21, 2024
1 parent 8241f9c commit 78e4cc6
Showing 2 changed files with 45 additions and 7 deletions.
37 changes: 30 additions & 7 deletions llmfoundry/utils/config_utils.py
@@ -452,13 +452,20 @@ def calculate_batch_size_info(
     return device_batch_size, device_microbatch_size, device_grad_accum
 
 
-def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
-    data_replication_degree = 1
-    device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
-        cfg['global_train_batch_size'],
-        cfg['device_train_microbatch_size'],
-        data_replication_degree=data_replication_degree,
-    )
+def update_config_with_batch_size_info(
+    cfg: Dict[str, Any],
+    device_train_batch_size: Union[int, float],
+    device_train_microbatch_size: Union[int, float, Literal['auto']],
+    device_train_grad_accum: Union[int, Literal['auto']],
+) -> Dict[str, Any]:
+    """Update the config with batch size information.
+
+    Args:
+        cfg (Dict[str, Any]): The config to update.
+
+    Returns:
+        Dict[str, Any]: The updated config.
+    """
     cfg['n_gpus'] = dist.get_world_size()
     cfg['device_train_batch_size'] = device_train_batch_size
     cfg['device_train_microbatch_size'] = device_train_microbatch_size
@@ -473,6 +480,22 @@ def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return cfg
 
 
+def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    data_replication_degree = 1
+    device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info(
+        cfg['global_train_batch_size'],
+        cfg['device_train_microbatch_size'],
+        data_replication_degree=data_replication_degree,
+    )
+    cfg = update_config_with_batch_size_info(
+        cfg,
+        device_train_batch_size,
+        device_train_microbatch_size,
+        device_train_grad_accum,
+    )
+    return cfg
+
+
 def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]):
     # Restrict model init_device to 'meta' and 'cpu',
     # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors
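For readers skimming the diff, here is a minimal sketch (not part of the commit) of how the two pieces compose after the refactor: calculate_batch_size_info derives the per-device values from the global batch size, and update_config_with_batch_size_info writes them back into the config, which is the pair of steps update_batch_size_info now performs in one call. The sketch assumes a single-process run (so dist.get_world_size() resolves to 1) and uses made-up batch size numbers.

# Sketch only, not from the commit: compose the two helpers by hand.
# Assumes a single-process run, i.e. dist.get_world_size() == 1.
from llmfoundry.utils.config_utils import (
    calculate_batch_size_info,
    update_config_with_batch_size_info,
)

cfg = {
    'global_train_batch_size': 16,
    'device_train_microbatch_size': 4,
}

# Derive per-device batch size, microbatch size, and grad accum steps.
batch_size, microbatch_size, grad_accum = calculate_batch_size_info(
    cfg['global_train_batch_size'],
    cfg['device_train_microbatch_size'],
    data_replication_degree=1,
)

# Write the derived values back into the config; update_batch_size_info
# now performs exactly these two steps internally.
cfg = update_config_with_batch_size_info(
    cfg,
    batch_size,
    microbatch_size,
    grad_accum,
)

# On a single process this yields 16 per device with 4 grad accum steps.
print(cfg['device_train_batch_size'], cfg['device_train_grad_accum'])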
15 changes: 15 additions & 0 deletions tests/utils/test_config_utils.py
@@ -0,0 +1,15 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry.utils.config_utils import update_config_with_batch_size_info
+
+
+def test_update_config_with_batch_size_info():
+    config = {}
+    config = update_config_with_batch_size_info(config, 1, 2, 3)
+
+    assert config['n_gpus'] == 1
+    assert config['device_train_batch_size'] == 1
+    assert config['device_train_microbatch_size'] == 2
+    assert config['device_train_grad_accum'] == 3
+    assert config['device_eval_batch_size'] == 2
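The new test exercises update_config_with_batch_size_info directly; its n_gpus == 1 assertion presumes a single-process run, and the device_eval_batch_size == 2 assertion suggests the collapsed portion of the diff also derives the eval batch size from the train microbatch size. To run just this module, something like the following should work from the repository root (a sketch; assumes pytest is installed and is equivalent to invoking pytest on the file from a shell):

# Run only the new test module (sketch; assumes pytest is available).
import pytest

raise SystemExit(pytest.main(['-q', 'tests/utils/test_config_utils.py']))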
