
Commit

add more useful info to state (mosaicml#1848)
dakinggg authored Jan 9, 2023
1 parent c184d75 commit 5a51b79
Showing 7 changed files with 36 additions and 15 deletions.
15 changes: 11 additions & 4 deletions composer/core/state.py
@@ -100,7 +100,7 @@ class State(Serializable):
     """The state of the trainer.

     Contains variables that the trainer tracks throughout the training loop. Note that all the necessary parts (i.e.,
-    :attr:`serialized_attributes`) of state are serialized when the trainer is checkpointed so that it can be used
+    :attr:`serialized_attributes`) of state are serialized when the trainer is checkpointed so that it can be used to
     restore the trainer and continue training from a checkpoint. :mod:`~composer.algorithms` are able to modify an
     instance of this class in-place.
@@ -622,6 +622,13 @@ def _get_state_metadata(self) -> Dict[str, Any]:
         """
         metadata_dict = {}
         metadata_dict['composer_env_info'] = get_composer_env_dict()
+        metadata_dict['device'] = self.device.name
+        metadata_dict['precision'] = self.precision.value
+        metadata_dict['world_size'] = dist.get_world_size()
+        metadata_dict['device_train_microbatch_size'] = self.device_train_microbatch_size
+
+        if self._train_dataloader is not None and hasattr(self._train_dataloader, 'batch_size'):
+            metadata_dict['train_dataloader_batch_size'] = self._train_dataloader.batch_size  # type: ignore

         return metadata_dict

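For reference, these new fields land under the top-level `metadata` key of `State.state_dict()`, alongside the existing `composer_env_info`. A minimal sketch of reading them back from a saved state dict (the file path is hypothetical, and `train_dataloader_batch_size` only appears when the train dataloader exposes a `batch_size`):

```python
import torch

# Load a previously saved State.state_dict() (the path is hypothetical).
state_dict = torch.load('state_dict.pt', map_location='cpu')
metadata = state_dict['metadata']

print(metadata['device'])                        # e.g. 'cpu' or 'gpu'
print(metadata['precision'])                     # e.g. 'amp_fp16'
print(metadata['world_size'])                    # e.g. 1
print(metadata['device_train_microbatch_size'])  # e.g. 2

# Only present when the train dataloader has a batch_size attribute:
print(metadata.get('train_dataloader_batch_size'))
```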
@@ -663,7 +670,7 @@ def state_dict(self) -> Dict[str, Any]:
                 serialized_value = self._dataset_state_dict()
             elif attribute_name == 'model':
                 # Save model directly instead of by class name, since model may be wrapped by DistributedDataParallel
-                # If it is DDP wrapped, do not save the `module.` prefix, as that is an implmentation detail
+                # If it is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail
                 with get_fsdp_rank0_cpu_save_context(
                         attribute_value) if self.fsdp_enabled else contextlib.nullcontext():
                     model_state = attribute_value.state_dict()
@@ -756,7 +763,7 @@ def _apply_required_algorithms(
                         f"loaded checkpoint but is now specified in the following forms: {', '.join(current_algos[type(algo)])}."
                         'Potential parameter discrepancies for this required_on_load algorithm may lead to '
                         'unexpected behavior, including failing to load weights for some layers.'))
-            # Otherwise, queue algorithm to be autoappled
+            # Otherwise, queue algorithm to be autoapplied
             elif type(algo) not in current_algos:
                 missing_algos.add(algo)
                 missing_algo_names.append(algo_name)
@@ -1007,7 +1014,7 @@ def dataloader_len(self):
         .. note::

-            If not explicitely specified, this value is an approximation, as it depends on ``len(self.dataloader)``.
+            If not explicitly specified, this value is an approximation, as it depends on ``len(self.dataloader)``.
             See the :doc:`PyTorch DataLoader Documentation <torch:data>` for more information.

         Returns:
1 change: 1 addition & 0 deletions composer/devices/device.py
@@ -30,6 +30,7 @@ class Device(Serializable, ABC):
     """

     dist_backend: str = ''
+    name: str = ''

     @abstractmethod
     def module_to_device(self, module: T_nnModule) -> T_nnModule:
1 change: 1 addition & 0 deletions composer/devices/device_cpu.py
@@ -26,6 +26,7 @@ class DeviceCPU(Device):
     """

     dist_backend = 'gloo'
+    name = 'cpu'
     _device = torch.device('cpu')

     def module_to_device(self, module: T_nnModule) -> T_nnModule:
1 change: 1 addition & 0 deletions composer/devices/device_gpu.py
@@ -32,6 +32,7 @@ class DeviceGPU(Device):
     For more information, see :ref:`torch:tf32_on_ampere`.
     """
     dist_backend = 'nccl'
+    name = 'gpu'

     def __init__(
         self,
1 change: 1 addition & 0 deletions composer/devices/device_mps.py
@@ -25,6 +25,7 @@ class DeviceMPS(Device):
     This class takes no arguments.
     """
     dist_backend = ''
+    name = 'mps'

     def __init__(self):
         if version.parse(torch.__version__) < version.parse('1.12.0'):
2 changes: 2 additions & 0 deletions composer/devices/device_tpu.py
@@ -26,6 +26,8 @@ class DeviceTPU(Device):
     More details.
     """

+    name = 'tpu'
+
     def __init__(self):
         import torch_xla.core.xla_model as xm
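Taken together, every built-in device now advertises a plain-string `name` (`'cpu'`, `'gpu'`, `'mps'`, `'tpu'`; the base class defaults to `''`), which is what `_get_state_metadata` reads via `self.device.name`. A quick sketch of the attribute in use:

```python
from composer.devices import DeviceCPU

device = DeviceCPU()
print(device.name)          # 'cpu' -- this string ends up in checkpoint metadata
print(device.dist_backend)  # 'gloo'
```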
30 changes: 19 additions & 11 deletions tests/test_state.py
@@ -23,22 +23,24 @@ def random_tensor(size=(4, 10)):

 def get_dummy_state(request: pytest.FixtureRequest):
     model = SimpleModel()
+    dataset = RandomClassificationDataset()
+    dataloader = DataLoader(dataset, batch_size=4)
     optimizers = torch.optim.Adadelta(model.parameters())
     device = None
     for item in request.session.items:
         device = DeviceCPU() if item.get_closest_marker('gpu') is None else DeviceGPU()
         break
     assert device != None
-    state = State(
-        model=model,
-        device=device,
-        run_name=f'{random.randint(0, 100)}',
-        grad_accum=random.randint(0, 100),
-        rank_zero_seed=random.randint(0, 100),
-        precision=Precision.AMP_FP16,
-        max_duration=f'{random.randint(0, 100)}ep',
-        optimizers=optimizers,
-    )
+    state = State(model=model,
+                  device=device,
+                  train_dataloader=dataloader,
+                  run_name=f'{random.randint(0, 100)}',
+                  grad_accum=random.randint(0, 100),
+                  rank_zero_seed=random.randint(0, 100),
+                  precision=Precision.AMP_FP16,
+                  max_duration=f'{random.randint(0, 100)}ep',
+                  optimizers=optimizers,
+                  device_train_microbatch_size=2)
     state.schedulers = torch.optim.lr_scheduler.StepLR(optimizers, step_size=3)
     state.loss = random_tensor()
     state.batch = (random_tensor(), random_tensor())
@@ -134,7 +136,7 @@ def test_state_batch_set_item(batch, key, val, request: pytest.FixtureRequest):
     assert state.batch_get_item(key) == val


-def test_composer_env_info_in_state_dict(tmp_path, request: pytest.FixtureRequest):
+def test_composer_metadata_in_state_dict(tmp_path, request: pytest.FixtureRequest):
     state = get_dummy_state(request)
     save_path = pathlib.Path(tmp_path) / 'state_dict.pt'
     with open(save_path, 'wb') as _tmp_file:
@@ -148,3 +150,9 @@ def test_composer_env_info_in_state_dict(tmp_path, request: pytest.FixtureRequest):
     actual_env_info_keys = set(loaded_state_dict['metadata']['composer_env_info'].keys())
     assert expected_env_info_keys == actual_env_info_keys
     assert loaded_state_dict['metadata']['composer_env_info']['composer_version'] == composer.__version__
+
+    assert loaded_state_dict['metadata']['device'] == 'cpu'
+    assert loaded_state_dict['metadata']['precision'] == 'amp_fp16'
+    assert loaded_state_dict['metadata']['world_size'] == 1
+    assert loaded_state_dict['metadata']['device_train_microbatch_size'] == 2
+    assert loaded_state_dict['metadata']['train_dataloader_batch_size'] == 4
