use mark parametrize for test
inkcherry committed Nov 11, 2024
1 parent 00ac4eb commit 25df962
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions tests/unit/runtime/zero/test_zero_leaf_module.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import pytest
import deepspeed.comm as dist
import torch

@@ -190,14 +191,14 @@ def test_set_no_match_class(self):
pass


@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000])
class TestZ3LeafOptimization(DistributedTest):
world_size = 2
reuse_dist_env = True

def test_finegrained_optimization(self):
def test_finegrained_optimization(self, module_granularity_threshold: int):
hidden_dim = 128
num_block = 16
stage3_module_granularity_threshold_list = [0, 100, 12100, 10000000]
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
@@ -248,19 +249,14 @@ def bench_loss_and_time(config):
model.destroy()
return loss_list, duration

result_loss_list = []
result_duration = []

baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict)

for threshold in stage3_module_granularity_threshold_list:
config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = threshold
loss_list, duration = bench_loss_and_time(config_dict)
result_duration.append(duration)
result_loss_list.append(loss_list)
config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = module_granularity_threshold
loss, duration = bench_loss_and_time(config_dict)

if dist.get_rank() == 0:
print(f"baseline exec time:", baseline_exec_time)
for idx, threshold in enumerate(stage3_module_granularity_threshold_list):
if dist.get_rank() == 0:
print(f"finegrained optimziation exec time: {result_duration[idx]},granularity threshold:{threshold} ")
assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}"
print(
f"finegrained optimziation exec time: {duration},granularity threshold:{module_granularity_threshold} "
)
assert baseline_loss_list == loss, f"incorrect loss value with threshold:{module_granularity_threshold}"
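
For reference, a minimal, self-contained sketch (not part of the commit) of the pattern this change adopts: applying @pytest.mark.parametrize at class level makes pytest pass each listed value to every test method in the class and collect one independent test case per value, which is why the new test body handles a single module_granularity_threshold instead of looping over a list. The class name and assertion below are illustrative placeholders, not DeepSpeed code.

import pytest


@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000])
class TestThresholdSweep:
    # A class-level parametrize decorator applies to every test method in the class,
    # so pytest collects four separate cases from this single method.

    def test_receives_each_value(self, module_granularity_threshold: int):
        # Each collected case receives exactly one of the four threshold values.
        assert module_granularity_threshold in (0, 100, 12100, 10000000)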

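Also hedged, a sketch of the general shape of the config dict the test builds, using only keys visible in this diff plus a standard ZeRO stage-3 block; the values are illustrative, not the test's actual settings.

def build_config(module_granularity_threshold: int) -> dict:
    # Illustrative shape only: the parametrized threshold is written into the
    # zero_optimization block before each benchmarked run.
    return {
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 1,
        "zero_optimization": {
            "stage": 3,
            "stage3_module_granularity_threshold": module_granularity_threshold,
        },
    }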