Fix seed for FSDP wrap (mosaicml#2833)
* first try

* add context

* lint

* more lint

* remove comment

---------

Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Your Name <[email protected]>
3 people authored Jan 10, 2024
1 parent eb4fbd0 commit c48e6fe
Showing 4 changed files with 20 additions and 10 deletions.
8 changes: 5 additions & 3 deletions composer/core/state.py
@@ -27,7 +27,8 @@
from composer.core.serializable import Serializable
from composer.core.time import Time, Timestamp, TimeUnit
from composer.devices import Device
from composer.utils import batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed
from composer.utils import (batch_get, batch_set, dist, ensure_tuple, get_composer_env_dict, is_model_deepspeed,
reproducibility)
from composer.utils.misc import using_torch_2

if TYPE_CHECKING:
@@ -1264,8 +1265,9 @@ def load_model_and_optimizer_state(
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
self.auto_microbatching)
with reproducibility.seed_context(self.rank_zero_seed):
prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device,
self.auto_microbatching)
log.debug('Finished wrapping model with FSDP.')

# Legacy optimizer state load must happen after FSDP monolith
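
For context on what the new with reproducibility.seed_context(self.rank_zero_seed): block buys here: the FSDP wrap now draws from a fixed, seeded RNG stream and then hands the original stream back, so any randomness consumed while (re)initializing modules no longer shifts the rest of the run. A minimal sketch of the pattern, with a toy init_step standing in for prepare_fsdp_module and an arbitrary seed value:

import torch
from composer.utils import reproducibility

def init_step() -> torch.Tensor:
    # Stand-in for prepare_fsdp_module: anything that consumes RNG while
    # building or (re)initializing module parameters.
    return torch.randn(4)

rank_zero_seed = 42  # arbitrary value for illustration
with reproducibility.seed_context(rank_zero_seed):
    wrapped = init_step()  # drawn from the fixed, seeded stream

# Outside the context the previously saved RNG state is restored, so
# dataloader shuffling, dropout, etc. continue exactly as they would have
# if the wrap had never consumed any randomness.
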
5 changes: 0 additions & 5 deletions composer/trainer/dist_strategy.py
@@ -243,11 +243,6 @@ def prepare_fsdp_module(
'gpu and some ranks are on meta. Either keep all ranks on the same '
"device or set fsdp_config['sync_module_states'] = True. Otherwise, "
'some weights may be randomly initialized when loading a checkpoint.')
# Comment out while we debug deadlock
# if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
# raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
# 'fsdp_config["sync_module_states"] = True or different replicas will '
# 'have different weights.')

# Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
# may happen when close to memory limit or with uneven memory usage across ranks. Since we
6 changes: 4 additions & 2 deletions composer/trainer/trainer.py
@@ -1293,7 +1293,8 @@ def __init__(

# FSDP wrap if not using monolith checkpoint on rank 0 only
if self.state.fsdp_config is not None and fsdp_auto_wrap and not self.state.load_fsdp_monolith_rank0_only:
prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
with reproducibility.seed_context(self.state.rank_zero_seed):
prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)

# Configure Deepspeed
if self.state.deepspeed_config is not None:
@@ -1440,7 +1441,8 @@ def __init__(
# FSDP wrap if model is not yet wrapped and FSDP is enabled. This can happen if
# load_fsdp_monolith_rank0_only=True but no checkpoint was loaded.
if not self.state.fsdp_enabled and self.state.fsdp_config is not None and self.state.fsdp_auto_wrap and self.state.load_fsdp_monolith_rank0_only:
prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
with reproducibility.seed_context(self.state.rank_zero_seed):
prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)

self.engine.run_event(Event.AFTER_LOAD)

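
Both trainer.py call sites pass state.rank_zero_seed rather than a per-rank seed, so every rank reseeds to the same value for the duration of the wrap; RNG-dependent parameter initialization during wrapping therefore matches across ranks, and each rank's own stream resumes afterwards. A rough sketch of that behavior (the rank_zero_seed + rank offset below reflects Composer's usual per-rank seeding convention and is shown only for illustration, not taken from this diff):

from composer.utils import dist, reproducibility

rank_zero_seed = 42  # arbitrary value for illustration
per_rank_seed = rank_zero_seed + dist.get_global_rank()
reproducibility.seed_all(per_rank_seed)  # normal per-rank RNG streams

with reproducibility.seed_context(rank_zero_seed):
    # All ranks are seeded identically in here, so RNG-dependent parameter
    # initialization during the FSDP wrap comes out the same on every rank.
    pass

# On exit, each rank gets its own saved (per-rank) RNG state back.
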
11 changes: 11 additions & 0 deletions composer/utils/reproducibility.py
@@ -53,6 +53,7 @@
import textwrap
import time
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List

import numpy as np
@@ -62,6 +63,7 @@
from composer.utils import dist

__all__ = [
'seed_context',
'configure_deterministic_mode',
'get_random_seed',
'seed_all',
@@ -76,6 +78,15 @@
MAX_SEED = 2**32 - 1


@contextmanager
def seed_context(seed: int):
"""Context manager to store rng_state and reseed for duration of context."""
rng_state = get_rng_state()
seed_all(seed)
yield
load_rng_state(rng_state)


def configure_deterministic_mode():
"""Configure PyTorch deterministic mode.
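
The new seed_context helper saves the current RNG state via get_rng_state, reseeds everything with seed_all, runs the body, and then restores the saved state with load_rng_state. A torch-only analogue, written here just to make the save/reseed/restore semantics concrete (the real helper also covers the python, numpy, and CUDA RNGs through get_rng_state/load_rng_state):

from contextlib import contextmanager
import torch

@contextmanager
def torch_seed_context(seed: int):
    # torch-only analogue of reproducibility.seed_context, for illustration.
    state = torch.random.get_rng_state()  # save the current torch RNG state
    torch.manual_seed(seed)               # reseed for the duration of the context
    yield
    torch.random.set_rng_state(state)     # restore the previous stream

torch.manual_seed(7)
a = torch.rand(1)
with torch_seed_context(0):
    b = torch.rand(1)  # deterministic: depends only on the seed 0
c = torch.rand(1)      # continues the original stream, unaffected by b
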
