diff --git a/autodist/kernel/synchronization/all_reduce_synchronizer.py b/autodist/kernel/synchronization/all_reduce_synchronizer.py index 5d63982..18f230f 100644 --- a/autodist/kernel/synchronization/all_reduce_synchronizer.py +++ b/autodist/kernel/synchronization/all_reduce_synchronizer.py @@ -18,6 +18,7 @@ from tensorflow.python import ops from tensorflow.python.framework import device_spec from tensorflow.python.ops import collective_ops +from tensorflow.python.framework.ops import Tensor # pylint: disable=unused-import import autodist from autodist.const import ENV diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 4878d48..cce74f5 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -13,13 +13,12 @@ # limitations under the License. """Gradient Compressors for All-Reduce.""" +import copy from abc import ABC, abstractmethod from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import Tensor -from tensorflow.python.ops import collective_ops, math_ops - -#from tensorflow.python.ops import array_ops, collective_ops, linalg_ops, math_ops, random_ops -#from autodist.kernel.synchronization.collective_key import get_collective_keys +from tensorflow.python.ops import collective_ops, math_ops, random_ops, array_ops, linalg_ops +from autodist.kernel.synchronization.collective_key import get_collective_keys #from autodist.utils import logging @@ -205,80 +204,108 @@ class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works becaus """Horovod's Compression but with Error Feedback.""" -# class PowerSGDCompressor(CompressorEF): -# """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" - -# def __init__(self, var_op_name, rank=1): -# self.rank = rank -# self.og_shape, self.ndims, self.new_shape, self.compressor = None, None, None, None -# super().__init__(var_op_name) - -# def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): -# """ -# Compress, reduce, and decompress a given tensor. - -# Args: -# tensor (Tensor): the Tensor to reduce. -# conf (CollectiveOpsConfig): the config for Collective Ops. - -# Returns: -# Reduced Tensor -# """ -# if self.og_shape is None: -# self.og_shape = tensor.shape -# self.ndims = len(self.og_shape) - -# # Check if rank 1 tensor (this shouldn't be called with sparse tensors) -# # Just reduce it if it is, no need to compress -# if self._is_1d: -# return self._all_reduce(tensor, conf) - -# logging.info(f"Compressing tensor {tensor.name} (var {self.var_op_name}) with shape {tensor.shape}") -# if self.ndims > 2: -# tensor = array_ops.reshape(tensor, [self.og_shape[0], -1]) - -# if self.compressor is None: -# self.new_shape = array_ops.shape_v2(tensor) -# self.compressor = random_ops.random_normal([self.new_shape[1], self.rank]) - -# if self.error is not None: -# tensor += self.error - -# compressed_tensor = self._compress(tensor) -# self.error = tensor - self._decompress(compressed_tensor) - -# # all reduce mean p -# reduced = self._all_reduce(compressed_tensor, conf) -# reduced = self._orthogonalize(reduced) - -# # update compressor -# self.compressor = math_ops.matmul(tensor, reduced, transpose_a=True) -# conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + "/compressor") -# self.compressor = self._all_reduce(self.compressor, conf) -# return array_ops.reshape(self._decompress(reduced), self.og_shape) \ -# if self.ndims > 2 else self._decompress(reduced) - -# def _compress(self, tensor: Tensor): -# return math_ops.matmul(tensor, self.compressor) - -# def _decompress(self, compressed_tensor: Tensor): -# return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) - -# @property -# def _is_1d(self): -# return self.ndims <= 1 or ( -# self.ndims == 2 and any(d == 1 for d in self.og_shape) -# ) - -# @staticmethod -# def _orthogonalize(matrix): -# _, m = matrix.shape -# for i in range(m): -# v = matrix[:, i] -# v /= linalg_ops.norm_v2(v) -# v = array_ops.expand_dims_v2(v, 1) - -# begin, rest = matrix[:, :i], matrix[:, (i + 1):] -# rest -= math_ops.matmul(v, rest, transpose_a=True) * v -# matrix = array_ops.concat([begin, v, rest], 1) -# return matrix +class PowerSGDCompressor(CompressorEF): + """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" + + def __init__(self, var_op_name, rank=1): + self.rank = rank + self.og_shape, self.ndims = None, None + self.compressor, self.compressor_conf = None, None # compressor is the Q in paper + self.var_op_name = var_op_name + super().__init__(var_op_name) + + def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + """ + Compress, reduce, and decompress a given tensor. + + Args: + tensor (Tensor): the Tensor to reduce. + conf (CollectiveOpsConfig): the config for Collective Ops. + + Returns: + Reduced Tensor + """ + if self.og_shape is None: + self.og_shape = array_ops.shape_v2(tensor) + if self.og_shape.shape[0] is None: + self.ndims = 0 + else: + self.ndims = self.og_shape.shape[0] + + # rank <= 1 + if self.ndims <= 1 or (self.ndims == 2 and any([i == 1 for i in tensor.get_shape().as_list()])): + return self._all_reduce(tensor, conf) + + og_dtype = tensor.dtype + tensor = array_ops.reshape(math_ops.cast(tensor, dtypes.float32), [self.og_shape[0], -1]) + + # compressor init + if self.compressor is None: + self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank], + seed=1000, dtype=dtypes.float32) + + self.compressor_conf = copy.copy(conf) + self.compressor_conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') + + if self.error is not None: + tensor += self.error + + compressed_tensor = self._compress(tensor) + self.error = tensor - self._decompress(compressed_tensor) + + reduced_tensor = self._all_reduce(compressed_tensor, conf) + + orthonormal_reduced_tensor = self._modified_gram_schmidt(reduced_tensor) + + self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr + + # all reduce mean compressor + self.compressor = self._all_reduce(self.compressor, self.compressor_conf) + + return math_ops.cast(array_ops.reshape(self._decompress(orthonormal_reduced_tensor), self.og_shape), og_dtype) + + def _compress(self, tensor: Tensor): + """ + Compress a given tensor. + + Args: + tensor (Tensor): the Tensor to compress. + + Returns: + Tensor + """ + return math_ops.matmul(tensor, self.compressor) # nxm * mxr => nxr + + def _decompress(self, compressed_tensor: Tensor): + """ + Decompress a given tensor. + + Args: + compressed_tensor (Tensor): the Tensor to decompress. + + Returns: + Tensor, Context + """ + return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm => nxm + + @staticmethod + def _modified_gram_schmidt(matrix): + """ + Apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns. + + Args: + matrix (Tensor): the Tensor to orthogonalize. + + Returns: + matrix (Tensor) + """ + _, m = matrix.shape + + for i in range(m): + v = matrix[:, i:(i + 1)] + v /= linalg_ops.norm_v2(v, axis=0) + + rest = matrix[:, (i + 1):] + rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v + matrix = array_ops.concat([matrix[:, :i], v, rest], 1) + return matrix diff --git a/autodist/proto/synchronizers.proto b/autodist/proto/synchronizers.proto index f70996f..1813e3a 100644 --- a/autodist/proto/synchronizers.proto +++ b/autodist/proto/synchronizers.proto @@ -47,7 +47,7 @@ message AllReduceSynchronizer { NoneCompressor = 0; // No compression HorovodCompressor = 1; // Horovod's Compression HorovodCompressorEF = 2; // Horovod's Compression but with Error Feedback. - // PowerSGDCompressor = 3; // PowerSGD compression algorithm (arxiv.org/abs/1905.13727) + PowerSGDCompressor = 3; // PowerSGD compression algorithm (arxiv.org/abs/1905.13727) } Compressor compressor = 2; // One of the compressors to choose diff --git a/tests/integration/test_all.py b/tests/integration/test_all.py index 24bd863..bda23ca 100644 --- a/tests/integration/test_all.py +++ b/tests/integration/test_all.py @@ -38,6 +38,7 @@ AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='NoneCompressor'), AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='HorovodCompressor'), AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='HorovodCompressorEF'), + AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='PowerSGDCompressor'), PSLoadBalancing(local_proxy_variable=True), Parallax(local_proxy_variable=True), PartitionedAR(),