diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index a352401b4bc9..39675fa3305d 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -5,6 +5,7 @@ import abc from abc import ABC +from .constants import * class DeepSpeedAccelerator(ABC): @@ -12,12 +13,8 @@ class DeepSpeedAccelerator(ABC): def __init__(self): self._name = None self._communication_backend_name = None - self._capabilities: dict[str, bool] = { - "zero1": False, - "zero2": False, - "zero3": False, - "sparse_attn": False - } + self._capabilities: dict[str, bool] = {ZERO_1: False, ZERO_2: False, ZERO_3: False, SPARSE_ATTN: False} + @abc.abstractmethod def is_synchronized_device(self): ... @@ -297,4 +294,4 @@ def get_capability(self, key): return self._capabilities[key] def set_capability(self, key, value): - self._capabilities[key] = value \ No newline at end of file + self._capabilities[key] = value diff --git a/accelerator/constants.py b/accelerator/constants.py new file mode 100644 index 000000000000..5470248af7d5 --- /dev/null +++ b/accelerator/constants.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +#A list of constants used in the DeepSpeed capabilities dictionary +ZERO_1= "zero1" +ZERO_2= "zero2" +ZERO_3= "zero3" +SPARSE_ATTN= "sparse_attn" diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index b205728da2ed..2a40580e3bad 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -9,6 +9,8 @@ import importlib from .abstract_accelerator import DeepSpeedAccelerator +from .constants import * + # During setup stage torch may not be installed, pass on no torch will # allow op builder related API to be executed. try: @@ -28,10 +30,10 @@ def __init__(self): self._communication_backend_name = 'nccl' if pynvml is None: self._init_pynvml() - self.set_capability('zero1', True) - self.set_capability('zero2', True) - self.set_capability('zero3', True) - self.set_capability('sparse_attn', True) + self.set_capability(ZERO_1, True) + self.set_capability(ZERO_2, True) + self.set_capability(ZERO_3, True) + self.set_capability(SPARSE_ATTN, True) def _init_pynvml(self): global pynvml