
Commit

change preferred_dtype into a function
delock committed Mar 8, 2024
1 parent ae544e1 commit 4623622
Showing 9 changed files with 21 additions and 19 deletions.
13 changes: 7 additions & 6 deletions tests/unit/common.py
@@ -444,9 +444,10 @@ def get_test_path(filename):


 # fp16 > bf16 > fp32
-if get_accelerator().is_fp16_supported():
-    preferred_dtype = torch.float16
-elif get_accelerator().is_bf16_supported():
-    preferred_dtype = torch.bfloat16
-else:
-    preferred_dtype = torch.float32
+def preferred_dtype():
+    if get_accelerator().is_fp16_supported():
+        return torch.float16
+    elif get_accelerator().is_bf16_supported():
+        return torch.bfloat16
+    else:
+        return torch.float32
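
The substance of the change: as a module-level variable, preferred_dtype was computed once, when tests/unit/common.py was first imported; as a function, get_accelerator() is queried on every call, so the result reflects the accelerator's capabilities at call time. A minimal sketch of the difference, using a hypothetical FakeAccelerator stand-in rather than DeepSpeed's real accelerator API:

import torch

class FakeAccelerator:
    # Hypothetical stand-in for get_accelerator(); fp16 support is
    # "discovered" only after the module-level code has already run.
    fp16_supported = False

accel = FakeAccelerator()

# Variable: the condition is evaluated once, at import time.
dtype_var = torch.float16 if accel.fp16_supported else torch.float32

# Function: the condition is re-evaluated on every call.
def preferred_dtype():
    return torch.float16 if accel.fp16_supported else torch.float32

accel.fp16_supported = True   # capability becomes known later
print(dtype_var)              # torch.float32 -- stale
print(preferred_dtype())      # torch.float16 -- current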
2 changes: 1 addition & 1 deletion tests/unit/multi_output_model.py
@@ -32,7 +32,7 @@ def multi_output_dataloader(model, total_samples, hidden_dim, device, inputs, ta
         torch.full(size=(total_samples, hidden_dim),
                    fill_value=x,
                    device=device,
-                   dtype=preferred_dtype,
+                   dtype=preferred_dtype(),
                    requires_grad=True) for x in inputs
     ]

4 changes: 2 additions & 2 deletions tests/unit/runtime/test_multi_output_model.py
@@ -55,7 +55,7 @@ def test(self, tmpdir):
         inputs, targets = batch[:midpoint], batch[midpoint:]
         loss_tuple = model(inputs, targets)

-        expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype, device=model.device)
+        expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device)
         for loss in loss_tuple:
             assert loss.shape == torch.Size([])
             assert loss.item() == approx(expected_loss.item())
@@ -114,7 +114,7 @@ def test(self, tmpdir):
         loss_tuple = model(inputs, targets)
         assert len(loss_tuple) == 3

-        expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype, device=model.device)
+        expected_loss = torch.tensor(2.302734375, dtype=preferred_dtype(), device=model.device)

         for loss in loss_tuple:
             assert loss.shape == torch.Size([])
2 changes: 1 addition & 1 deletion tests/unit/runtime/zero/test_zero.py
@@ -953,7 +953,7 @@ def forward(self, x: Tensor) -> Tensor:

         ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg)

-        dtype = preferred_dtype
+        dtype = preferred_dtype()
         for _ in range(3):  # test multiple iterations to cover prefetching
             activations: List[Tensor] = ds_engine(torch.ones((param_sz, ), dtype=dtype, device=ds_engine.device))
             assert len(activations) == n_layers
2 changes: 1 addition & 1 deletion tests/unit/runtime/zero/test_zero_context.py
@@ -250,7 +250,7 @@ def forward(self, input):
         with deepspeed.zero.GatheredParameters(net.linear1.weight):
             assert net.linear1.weight.numel() == net.dim**2

-        input = torch.rand(net.dim).to(engine.device).to(preferred_dtype)
+        input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
         loss = engine(input)
         engine.backward(loss)
         engine.step()
6 changes: 3 additions & 3 deletions tests/unit/runtime/zero/test_zero_context_return.py
@@ -144,7 +144,7 @@ def test_ext_param_return(self):
         engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

         for _ in range(5):
-            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype)
+            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
             loss = engine(input)
             engine.backward(loss)
             engine.step()
@@ -160,7 +160,7 @@ def test_ext_param_returnobj(self):
         engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

         for _ in range(5):
-            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype)
+            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
             loss = engine(input)
             assert len(net._external_params) == 1
             assert len(net.dangler._external_params) == 0
@@ -178,7 +178,7 @@ def test_stage_3_output_type(self, output_type):
         engine, _, _, _ = deepspeed.initialize(args=args, model=net, model_parameters=net.parameters(), config=config)

         for _ in range(1):
-            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype)
+            input = torch.rand(net.dim).to(engine.device).to(preferred_dtype())
             loss = engine(input)
             if loss is not None:
                 if isinstance(loss, dict):
2 changes: 1 addition & 1 deletion tests/unit/runtime/zero/test_zero_leaf_module.py
@@ -108,7 +108,7 @@ def _test_set_z3_leaf_modules(self, cls, requires_grad):
         set_z3_leaf_modules(model, [cls])
         assert z3_leaf_module(model)

-        run_model(model, config_dict, hidden_dim, preferred_dtype, requires_grad)
+        run_model(model, config_dict, hidden_dim, preferred_dtype(), requires_grad)

     def test_choose_module_by_counter(self):
         self._test_set_z3_leaf_modules(ChooseModuleByCounter, True)
3 changes: 2 additions & 1 deletion tests/unit/runtime/zero/test_zero_tensor_fragment.py
@@ -140,7 +140,8 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, froz
         validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False)
         validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True)

-        run_fragmented_model(model, config_dict, hidden_dim, preferred_dtype, validate_after_bwd, validate_after_step)
+        run_fragmented_model(model, config_dict, hidden_dim, preferred_dtype(), validate_after_bwd,
+                             validate_after_step)

     def test_bf16_fragments(self, frozen_weights):
         if get_accelerator().device_name() == "cpu":
6 changes: 3 additions & 3 deletions tests/unit/simple_model.py
@@ -263,21 +263,21 @@ def forward(self, x, y, **kwargs):
         return hidden_dim


-def random_dataset(total_samples, hidden_dim, device, dtype=preferred_dtype):
+def random_dataset(total_samples, hidden_dim, device, dtype=preferred_dtype()):
     train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
     train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
     train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
     return train_dataset


-def random_dataloader(model, total_samples, hidden_dim, device, dtype=preferred_dtype):
+def random_dataloader(model, total_samples, hidden_dim, device, dtype=preferred_dtype()):
     batch_size = model.train_micro_batch_size_per_gpu()
     train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
     return train_loader


-def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=preferred_dtype):
+def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=preferred_dtype()):
     batch_size = model.train_micro_batch_size_per_gpu()
     train_data = torch.randn(total_samples, seq_len, hidden_dim, device=device, dtype=dtype)
     train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
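One caveat on the simple_model.py hunk above, standard Python behavior rather than anything specific to this commit: a default argument like dtype=preferred_dtype() is still evaluated only once, when the def statement runs at import time. If call-time resolution were wanted for the defaults as well, the conventional pattern is a None sentinel; a hypothetical variant for illustration, not code from this commit:

import torch

def random_dataset_deferred(total_samples, hidden_dim, device, dtype=None):
    # Resolve the dtype when the function is called, not when the
    # module defining it is imported.
    if dtype is None:
        dtype = preferred_dtype()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    return torch.utils.data.TensorDataset(train_data, train_label)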
