improve the way to determine whether a variable is None (#4782)
refactor: improve the way to decide whether a variable is None
fix: type mismatch when checking whether the current accelerator is in SUPPORTED_ACCELERATOR_LIST

---------

Co-authored-by: ryan <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
3 people authored Dec 8, 2023
1 parent 2bdf061 commit 7b818ee
Showing 15 changed files with 24 additions and 24 deletions.
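
A quick illustration of why identity checks are preferred (a sketch for context, not part of the commit): x == None goes through __eq__, which a class may override, while x is None is an identity check that cannot be fooled and is what PEP 8 recommends.

class AlwaysEqual:
    # A pathological __eq__ makes equality comparison against None unreliable.
    def __eq__(self, other):
        return True

obj = AlwaysEqual()
print(obj == None)  # True  -- __eq__ is consulted and gives a misleading answer
print(obj is None)  # False -- identity check, always unambiguous
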
4 changes: 2 additions & 2 deletions accelerator/cpu_accelerator.py
@@ -63,7 +63,7 @@ def random(self):
         return torch.random
 
     def set_rng_state(self, new_state, device_index=None):
-        if device_index == None:
+        if device_index is None:
             return torch.set_rng_state(new_state)
         return torch.set_rng_state(new_state, device_index)
 
@@ -253,7 +253,7 @@ def on_accelerator(self, tensor):
     # create an instance of op builder and return, name specified by class_name
     def create_op_builder(self, op_name):
         builder_class = self.get_op_builder(op_name)
-        if builder_class != None:
+        if builder_class is not None:
             return builder_class()
         return None

4 changes: 2 additions & 2 deletions accelerator/cuda_accelerator.py
@@ -44,7 +44,7 @@ def is_synchronized_device(self):
 
     # Device APIs
     def device_name(self, device_index=None):
-        if device_index == None:
+        if device_index is None:
             return 'cuda'
         return 'cuda:{}'.format(device_index)
 
@@ -280,7 +280,7 @@ def op_builder_dir(self):
     class_dict = None
 
     def _lazy_init_class_dict(self):
-        if self.class_dict != None:
+        if self.class_dict is not None:
             return
         else:
             self.class_dict = {}

4 changes: 2 additions & 2 deletions accelerator/mps_accelerator.py
@@ -26,7 +26,7 @@ def is_synchronized_device(self):
 
     # Device APIs
     def device_name(self, device_index=None):
-        if device_index == None:
+        if device_index is None:
             return "mps"
         return "mps:{}".format(device_index)
 
@@ -221,7 +221,7 @@ def op_builder_dir(self):
     # create an instance of op builder, specified by class_name
     def create_op_builder(self, op_name):
         builder_class = self.get_op_builder(op_name)
-        if builder_class != None:
+        if builder_class is not None:
             return builder_class()
         return None

2 changes: 1 addition & 1 deletion accelerator/npu_accelerator.py
@@ -30,7 +30,7 @@ def is_synchronized_device(self):
 
     # Device APIs
     def device_name(self, device_index=None):
-        if device_index == None:
+        if device_index is None:
             return 'npu'
         return 'npu:{}'.format(device_index)

2 changes: 1 addition & 1 deletion accelerator/real_accelerator.py
@@ -45,7 +45,7 @@ def _validate_accelerator(accel_obj):
 
 
 def is_current_accelerator_supported():
-    return get_accelerator() in SUPPORTED_ACCELERATOR_LIST
+    return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST
 
 
 def get_accelerator():

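For context on the type-mismatch fix above: get_accelerator() returns an accelerator object, while SUPPORTED_ACCELERATOR_LIST holds device-name strings, so the old membership test could never succeed. A minimal sketch, with the list contents assumed for illustration:

SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps']  # contents assumed for this sketch

accel = get_accelerator()                            # an accelerator object, not a string
accel in SUPPORTED_ACCELERATOR_LIST                  # always False: object compared against strings
accel.device_name() in SUPPORTED_ACCELERATOR_LIST    # compares a string against strings, as intended
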
2 changes: 1 addition & 1 deletion deepspeed/inference/quantization/layers.py
@@ -86,7 +86,7 @@ def __init__(self, config: Dict, pre_quant_layer: nn.Embedding) -> None:
             device=pre_quant_layer.weight.device,
             dtype=pre_quant_layer.weight.dtype)
 
-        assert pre_quant_layer.max_norm == None, 'Not supported'
+        assert pre_quant_layer.max_norm is None, 'Not supported'
         assert pre_quant_layer.norm_type == 2, 'Not supported'
         assert pre_quant_layer.scale_grad_by_freq == False, 'Not supported'
         assert pre_quant_layer.sparse == False, 'Not supported'

@@ -32,7 +32,7 @@ def supports_config(config: DSEmbeddingsConfig) -> bool:
         if config.use_token_type:
             return False
 
-        if config.output_normalization != None:
+        if config.output_normalization is not None:
             return False
 
         try:

2 changes: 1 addition & 1 deletion deepspeed/module_inject/auto_tp.py
@@ -451,7 +451,7 @@ def get_model_num_kv_heads(self, config):
         for name in kv_head_names:
             if hasattr(config, name):
                 num_kv_heads = getattr(config, name)
-                if num_kv_heads != None:
+                if num_kv_heads is not None:
                     break
         return num_kv_heads

2 changes: 1 addition & 1 deletion deepspeed/module_inject/fusedqkv_utils.py
@@ -28,7 +28,7 @@ def require_tp_fused_qkvw(name, mp_size):
 
 
 def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
-    if src == None:
+    if src is None:
         return
     fused_type_dict = {
         'CodeGenBlock': 'codegentype',

2 changes: 1 addition & 1 deletion deepspeed/module_inject/replace_module.py
@@ -597,7 +597,7 @@ def skip_level_0_prefix(model, state_dict):
     if key is None:
         key = re.match(r"(.*?)Model", model)
     # if keys start with 'model.', don't skip level 0 prefix
-    if state_dict != None:
+    if state_dict is not None:
         for item in state_dict.keys():
             if re.match("^model[.]", item):
                 return False

4 changes: 2 additions & 2 deletions deepspeed/module_inject/tp_shard.py
@@ -20,8 +20,8 @@ def get_num_kv_heads():
 def get_shard_size(total_size, mp_size, rank=None):
     global num_kv_heads
     # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
-    if num_kv_heads != None:
-        if (rank == None):
+    if num_kv_heads is not None:
+        if rank is None:
             rank = dist.get_rank()
         my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
         return total_size * my_slices // num_kv_heads

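A worked example of the uneven sharding arithmetic above (values chosen for illustration, not taken from the commit):

def shard_size(total_size, mp_size, rank, num_kv_heads):
    # Mirrors the uneven-division branch of get_shard_size() shown above.
    my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
    return total_size * my_slices // num_kv_heads

# 5 KV heads split across 4 ranks over a 5120-wide projection:
# rank 0 gets 2048 columns, ranks 1-3 get 1024 each, summing back to 5120.
assert [shard_size(5120, 4, r, 5) for r in range(4)] == [2048, 1024, 1024, 1024]
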
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -1984,7 +1984,7 @@ def step(self, closure=None):
         # warn user about caching allocator flushes
         memory_stats = get_accelerator().memory_stats()
         alloc_retries = memory_stats.get("num_alloc_retries")
-        if alloc_retries == None:
+        if alloc_retries is None:
             alloc_retries = 0
         if alloc_retries > self.n_caching_allocator_flushes:
             if dist.get_rank() == 0:
@@ -2541,7 +2541,7 @@ def load_state_dict(self,
         # when use loading checkpoint serial, after finish loading, we need to
         # delete the temp state_dict_list variable to save memory, then trigger
         # the next rank's loading
-        if load_serial != None:
+        if load_serial is not None:
             load_serial += 1
         rank = dist.get_rank(group=self.dp_process_group)
         local_rank = dist.get_local_rank()

2 changes: 1 addition & 1 deletion op_builder/cpu/comm.py
@@ -35,7 +35,7 @@ def is_compatible(self, verbose=True):
 
     def extra_ldflags(self):
         ccl_root_path = os.environ.get("CCL_ROOT")
-        if ccl_root_path == None:
+        if ccl_root_path is None:
             raise ValueError(
                 "Didn't find CCL_ROOT, install oneCCL from https://github.com/oneapi-src/oneCCL and source its environment variable"
             )

4 changes: 2 additions & 2 deletions tests/unit/launcher/test_ds_arguments.py
@@ -40,7 +40,7 @@ def test_no_ds_arguments():
     assert args.deepspeed == False
 
     assert hasattr(args, 'deepspeed_config')
-    assert args.deepspeed_config == None
+    assert args.deepspeed_config is None
 
 
 def test_no_ds_enable_argument():
@@ -74,7 +74,7 @@ def test_no_ds_config_argument():
     assert args.deepspeed == True
 
     assert hasattr(args, 'deepspeed_config')
-    assert args.deepspeed_config == None
+    assert args.deepspeed_config is None
 
 
 def test_no_ds_parser():

8 changes: 4 additions & 4 deletions tests/unit/runtime/zero/test_zero_config.py
@@ -48,12 +48,12 @@ def test_zero_config_overlapcomm():
 
 def test_zero_config_offload_configs():
     config = DeepSpeedZeroConfig()
-    assert config.offload_param == None
-    assert config.offload_optimizer == None
+    assert config.offload_param is None
+    assert config.offload_optimizer is None
 
     config = DeepSpeedZeroConfig(**{"offload_param": None, "offload_optimizer": None})
-    assert config.offload_param == None
-    assert config.offload_optimizer == None
+    assert config.offload_param is None
+    assert config.offload_optimizer is None
 
     config = DeepSpeedZeroConfig(**{"offload_param": {}, "offload_optimizer": {}})
     assert isinstance(config.offload_param, DeepSpeedZeroOffloadParamConfig)

