From 25df9623cdc8ad138fb5221f6a4d779a2ab2ffae Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 11 Nov 2024 02:57:03 +0000 Subject: [PATCH] use mark parametrize for test --- .../runtime/zero/test_zero_leaf_module.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/unit/runtime/zero/test_zero_leaf_module.py b/tests/unit/runtime/zero/test_zero_leaf_module.py index a1eacabc44ba..74c709883645 100644 --- a/tests/unit/runtime/zero/test_zero_leaf_module.py +++ b/tests/unit/runtime/zero/test_zero_leaf_module.py @@ -3,6 +3,7 @@ # DeepSpeed Team +import pytest import deepspeed.comm as dist import torch @@ -190,14 +191,14 @@ def test_set_no_match_class(self): pass +@pytest.mark.parametrize("module_granularity_threshold", [0, 100, 12100, 10000000]) class TestZ3LeafOptimization(DistributedTest): world_size = 2 reuse_dist_env = True - def test_finegrained_optimization(self): + def test_finegrained_optimization(self, module_granularity_threshold: int): hidden_dim = 128 num_block = 16 - stage3_module_granularity_threshold_list = [0, 100, 12100, 10000000] config_dict = { "train_micro_batch_size_per_gpu": 1, "steps_per_print": 1, @@ -248,19 +249,14 @@ def bench_loss_and_time(config): model.destroy() return loss_list, duration - result_loss_list = [] - result_duration = [] - baseline_loss_list, baseline_exec_time = bench_loss_and_time(config_dict) - for threshold in stage3_module_granularity_threshold_list: - config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = threshold - loss_list, duration = bench_loss_and_time(config_dict) - result_duration.append(duration) - result_loss_list.append(loss_list) + config_dict["zero_optimization"]["stage3_module_granularity_threshold"] = module_granularity_threshold + loss, duration = bench_loss_and_time(config_dict) + if dist.get_rank() == 0: print(f"baseline exec time:", baseline_exec_time) - for idx, threshold in enumerate(stage3_module_granularity_threshold_list): - if dist.get_rank() == 0: - print(f"finegrained optimziation exec time: {result_duration[idx]},granularity threshold:{threshold} ") - assert baseline_loss_list == result_loss_list[idx], f"incorrect loss value with threshold:{threshold}" + print( + f"finegrained optimziation exec time: {duration},granularity threshold:{module_granularity_threshold} " + ) + assert baseline_loss_list == loss, f"incorrect loss value with threshold:{module_granularity_threshold}"