Commit: skip some special case
delock committed Mar 11, 2024
1 parent 41ced03 commit ad19171
Showing 2 changed files with 6 additions and 19 deletions.
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/partition_parameters.py
@@ -365,7 +365,8 @@ def _set_dtype(self, ds_config, dtype):
             else:
                 self.dtype = torch.float
         else:
-            self.dtype = dtype or torch.half
+            self.dtype = dtype or torch.float16 if get_accelerator().is_fp16_supported(
+            ) else torch.bfloat16 if get_accelerator().is_bf16_supported else torch.float32
 
     def patch_init_and_builtins(self):

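The new default chains conditional expressions on the accelerator's capability checks. A standalone sketch of the intended fallback order — an explicit dtype first, then fp16, then bf16, then fp32 — is shown below; the helper name is hypothetical, the get_accelerator() capability calls are the ones used in the diff, and the branches are written as separate statements because Python's conditional expression binds more loosely than `or`. This is an illustration, not part of the commit.

from deepspeed.accelerator import get_accelerator
import torch

def resolve_default_dtype(dtype=None):
    # Hypothetical helper sketching the fallback order implied by the diff above:
    # honor an explicit dtype, otherwise prefer fp16, then bf16, then fp32,
    # based on what the current accelerator reports it supports.
    if dtype is not None:
        return dtype
    if get_accelerator().is_fp16_supported():
        return torch.float16
    if get_accelerator().is_bf16_supported():
        return torch.bfloat16
    return torch.float32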
22 changes: 4 additions & 18 deletions tests/unit/checkpoint/test_zero_optimizer.py
@@ -23,8 +23,6 @@ class TestZeROCheckpoint(DistributedTest):
 
     @pytest.mark.parametrize('zero_stage', [3])
     def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
         config_dict = {
             "train_batch_size": 2,
             "optimizer": {
@@ -53,8 +51,6 @@ def test_pipeline_checkpoint_loading(self, tmpdir, zero_stage):
     def test_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
         if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
             pytest.skip("cpu-adam is not compatible")
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
 
         config_dict = {
             "train_batch_size": 2,
@@ -95,8 +91,6 @@ def test_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     def test_not_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
         if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
             pytest.skip("cpu-adam is not compatible")
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
 
         config_dict = {
             "train_batch_size": 2,
@@ -133,8 +127,6 @@ def test_not_load_optimizer_state(self, tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
 
     @pytest.mark.parametrize('zero_stage', [1, 2])
     def test_hybrid_optimizer_state(self, tmpdir, zero_stage):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
         config_dict = {
             "train_micro_batch_size_per_gpu": 2,
             "gradient_accumulation_steps": 2,
@@ -161,8 +153,8 @@ def test_hybrid_optimizer_state(self, tmpdir, zero_stage):
 
     @pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
     def test_load_module_only(self, tmpdir, zero_stage):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
+        if zero_stage == 0 and get_accelerator().device_name() == "cpu":
+            pytest.skip("CPU Accelerator does not support this test")
         config_dict = {
             "train_batch_size": 2,
             "optimizer": {
@@ -336,8 +328,8 @@ def test_immediate_save_load(self, tmpdir, zero_stage):
 
     @pytest.mark.parametrize('zero_stage', [0, 1, 2, 3])
     def test_load_immediate_save(self, tmpdir, zero_stage):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
+        if zero_stage == 0 and get_accelerator().device_name() == "cpu":
+            pytest.skip("CPU Accelerator does not support this test")
         config_dict = {
             "train_batch_size": 4,
             "optimizer": {
@@ -421,8 +413,6 @@ class TestZeROCheckpointFrozenWeights(DistributedTest):
     @pytest.mark.parametrize('zero_stage', [1, 2, 3])
     def test_load_optimizer_state(self, tmpdir, zero_stage):
 
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
         config_dict = {
             "train_batch_size": 2,
             "steps_per_print": 1,
@@ -454,8 +444,6 @@ def test_load_optimizer_state(self, tmpdir, zero_stage):
     @pytest.mark.parametrize('zero_stage', [1, 2, 3])
     def test_not_load_optimizer_state(self, tmpdir, zero_stage):
 
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
         config_dict = {
             "train_batch_size": 2,
             "steps_per_print": 1,
@@ -485,8 +473,6 @@ def test_not_load_optimizer_state(self, tmpdir, zero_stage):
 
     @pytest.mark.parametrize('zero_stage', [1, 2, 3])
     def test_load_module_only(self, tmpdir, zero_stage):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
         config_dict = {
             "train_batch_size": 2,
             "optimizer": {
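With the blanket fp16 guards removed, these tests rely on the default-dtype fallback in partition_parameters.py above to pick a precision the accelerator can run, and the zero_stage 0 cases on the CPU accelerator are skipped explicitly instead. A rough, hypothetical check of which dtype a given machine would end up with (not part of this commit) could look like:

from deepspeed.accelerator import get_accelerator
import torch

acc = get_accelerator()
# Mirror the fallback order from partition_parameters.py above.
if acc.is_fp16_supported():
    dtype = torch.float16
elif acc.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32
print(f"{acc.device_name()} -> default dtype {dtype}")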
