Skip to content

Commit

Permalink
define symbolic constants
Browse files Browse the repository at this point in the history
  • Loading branch information
duli2012 committed Apr 17, 2024
1 parent e2327a1 commit 49cbda7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
11 changes: 4 additions & 7 deletions accelerator/abstract_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@

import abc
from abc import ABC
from .constants import *


class DeepSpeedAccelerator(ABC):

def __init__(self):
self._name = None
self._communication_backend_name = None
self._capabilities: dict[str, bool] = {
"zero1": False,
"zero2": False,
"zero3": False,
"sparse_attn": False
}
self._capabilities: dict[str, bool] = {ZERO_1: False, ZERO_2: False, ZERO_3: False, SPARSE_ATTN: False}

@abc.abstractmethod
def is_synchronized_device(self):
...
Expand Down Expand Up @@ -297,4 +294,4 @@ def get_capability(self, key):
return self._capabilities[key]

def set_capability(self, key, value):
self._capabilities[key] = value
self._capabilities[key] = value
10 changes: 10 additions & 0 deletions accelerator/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#A list of constants used in the DeepSpeed capabilities dictionary
ZERO_1= "zero1"
ZERO_2= "zero2"
ZERO_3= "zero3"
SPARSE_ATTN= "sparse_attn"
10 changes: 6 additions & 4 deletions accelerator/cuda_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import importlib

from .abstract_accelerator import DeepSpeedAccelerator
from .constants import *

# During setup stage torch may not be installed, pass on no torch will
# allow op builder related API to be executed.
try:
Expand All @@ -28,10 +30,10 @@ def __init__(self):
self._communication_backend_name = 'nccl'
if pynvml is None:
self._init_pynvml()
self.set_capability('zero1', True)
self.set_capability('zero2', True)
self.set_capability('zero3', True)
self.set_capability('sparse_attn', True)
self.set_capability(ZERO_1, True)
self.set_capability(ZERO_2, True)
self.set_capability(ZERO_3, True)
self.set_capability(SPARSE_ATTN, True)

def _init_pynvml(self):
global pynvml
Expand Down

0 comments on commit 49cbda7

Please sign in to comment.