Skip to content

Commit

Permalink
fix error in nv-torch-latest
Browse files Browse the repository at this point in the history
  • Loading branch information
delock committed Mar 11, 2024
1 parent ad19171 commit f4fe02b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
5 changes: 4 additions & 1 deletion tests/unit/checkpoint/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,10 @@ def checkpoint_correctness_verification(config_dict,
empty_tag=False,
seq_dataloader=False,
load_module_only=False,
dtype=preferred_dtype()):
dtype=None):
if dtype is None:
    dtype = preferred_dtype()

ds_model = create_deepspeed_model(config_dict=config_dict, model=models[0], base_optimizer=base_optimizers[0])

if seq_dataloader:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/checkpoint/test_latest_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def test_existing_latest(self, tmpdir):
load_optimizer_states=True,
load_lr_scheduler_states=False,
fp16=False,
empty_tag=True)
empty_tag=True,
dtype=torch.float)

def test_missing_latest(self, tmpdir):
config_dict = {
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/checkpoint/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def test_checkpoint_pipe_engine(self, zero_stage, tmpdir):
fp16=config_dict['fp16']['enabled'],
load_optimizer_states=True,
load_lr_scheduler_states=True,
train_batch=True)
train_batch=True,
dtype=torch.float16 if zero_stage > 0 else torch.float32)

@pytest.mark.parametrize(
"base_topo,test_topo",
Expand Down

0 comments on commit f4fe02b

Please sign in to comment.