From b4d48e3649e4d2537a9ad1b79b10448a369e69d5 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Mon, 26 Oct 2020 23:28:06 +0800 Subject: [PATCH 01/12] Add PowerSGDCompressor --- autodist/kernel/synchronization/compressor.py | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 4878d48..1478746 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -205,6 +205,104 @@ class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works becaus """Horovod's Compression but with Error Feedback.""" +class PowerSGDCompressor(CompressorEF): + """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" + def __init__(self, var_op_name, rank=1): + self.rank = rank + self.og_shape, self.ndims, self.compressor = None, None, None # compressor is the Q in paper + self.var_op_name = var_op_name + super.__init__(var_op_name) + + def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): + """ + Compress, reduce, and decompress a given tensor. + + Args: + tensor (Tensor): the Tensor to reduce. + conf (CollectiveOpsConfig): the config for Collective Ops. + + Returns: + Reduced Tensor + """ + if self.og_shape is None: + self.og_shape = tensor.shape + self.ndims = len(self.og_shape) + + # rank <= 1 + if self.ndims <= 1 or (self.ndims==2 and any([d == 1 for d in self.og_shape])): + return self._all_reduce(tensor, conf) + + # compressor init + if self.compressor is None: + self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank]) + + if self.error is not None: + tensor += self.error + + compressed_tensor = self._compress(tensor) + self.error = tensor - self._decompress(compressed_tensor) + + reduced_tensor = self._all_reduce(compressed_tensor, conf) + + orthonormal_reduced_tensor = self._modified_gram_schmidt(reduced_tensor) + + self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr + + # all reduce mean compressor + instance_key = conf.instance_key + conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') + self.compressor = self._all_reduce(self.compressor, conf) + conf.instance_key = instance_key + + return self._decompress(orthonormal_reduced_tensor) + + def _compress(self, tensor: Tensor): + """ + Compress a given tensor. + + Args: + tensor (Tensor): the Tensor to compress. + + Returns: + Tensor + """ + return math_ops.matmul(tensor, self.compressor) # nxm * mxr => nxr + + def _decompress(self, compressed_tensor: Tensor): + """ + Decompress a given tensor. + + Args: + compressed_tensor (Tensor): the Tensor to decompress. + + Returns: + Tensor, Context + """ + return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm = nxm + + @staticmethod + def _modified_gram_schmidt(matrix): + ''' + apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns + + Args: + matrix (Tensor): the Tensor to orthogonalize. + + Returns: + matrix (Tensor) + ''' + n, m = matrix.shape + + for i in range(m): + v = matrix[:, i:i+1] + v /= linalg_ops.norm_v2(v, axis=0) + + rest = matrix[:,i+1:] + rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v + matrix = array_ops.concat([matrix[:,:i], v, rest],axis=1) + return matrix + + # class PowerSGDCompressor(CompressorEF): # """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" From c774e460825ef932ebbbd0de9cee2fa7f63b46f6 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Wed, 28 Oct 2020 14:24:30 +0800 Subject: [PATCH 02/12] lintering * linter * update import * delete the commented code * synchronize the compressor at first --- autodist/kernel/synchronization/compressor.py | 124 ++++-------------- 1 file changed, 24 insertions(+), 100 deletions(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 1478746..1548878 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -13,13 +13,12 @@ # limitations under the License. """Gradient Compressors for All-Reduce.""" +import copy from abc import ABC, abstractmethod from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import Tensor -from tensorflow.python.ops import collective_ops, math_ops - -#from tensorflow.python.ops import array_ops, collective_ops, linalg_ops, math_ops, random_ops -#from autodist.kernel.synchronization.collective_key import get_collective_keys +from tensorflow.python.ops import collective_ops, math_ops, random_ops, array_ops, linalg_ops +from autodist.kernel.synchronization.collective_key import get_collective_keys #from autodist.utils import logging @@ -207,11 +206,13 @@ class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works becaus class PowerSGDCompressor(CompressorEF): """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" + def __init__(self, var_op_name, rank=1): self.rank = rank - self.og_shape, self.ndims, self.compressor = None, None, None # compressor is the Q in paper + self.og_shape, self.ndims = None, None + self.compressor, self.compressor_conf = None, None # compressor is the Q in paper self.var_op_name = var_op_name - super.__init__(var_op_name) + super().__init__(var_op_name) def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): """ @@ -229,13 +230,18 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): self.ndims = len(self.og_shape) # rank <= 1 - if self.ndims <= 1 or (self.ndims==2 and any([d == 1 for d in self.og_shape])): + if self.ndims <= 1 or (self.ndims == 2 and any([d == 1 for d in self.og_shape])): return self._all_reduce(tensor, conf) # compressor init if self.compressor is None: self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank]) + # synchronize compressor init statue + self.compressor_conf = copy.copy(conf) + self.conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') + self.compressor = self._all_reduce(self.compressor, self.compressor_conf) + if self.error is not None: tensor += self.error @@ -246,13 +252,10 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): orthonormal_reduced_tensor = self._modified_gram_schmidt(reduced_tensor) - self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr + self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr # all reduce mean compressor - instance_key = conf.instance_key - conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') - self.compressor = self._all_reduce(self.compressor, conf) - conf.instance_key = instance_key + self.compressor = self._all_reduce(self.compressor, self.compressor_conf) return self._decompress(orthonormal_reduced_tensor) @@ -278,105 +281,26 @@ def _decompress(self, compressed_tensor: Tensor): Returns: Tensor, Context """ - return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm = nxm + return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm => nxm @staticmethod def _modified_gram_schmidt(matrix): - ''' - apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns + """ + Apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns. Args: matrix (Tensor): the Tensor to orthogonalize. Returns: - matrix (Tensor) - ''' - n, m = matrix.shape + matrix (Tensor) + """ + _, m = matrix.shape for i in range(m): - v = matrix[:, i:i+1] + v = matrix[:, i:(i + 1)] v /= linalg_ops.norm_v2(v, axis=0) - rest = matrix[:,i+1:] + rest = matrix[:, (i + 1):] rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v - matrix = array_ops.concat([matrix[:,:i], v, rest],axis=1) + matrix = array_ops.concat([matrix[:, :i], v, rest], axis=1) return matrix - - -# class PowerSGDCompressor(CompressorEF): -# """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" - -# def __init__(self, var_op_name, rank=1): -# self.rank = rank -# self.og_shape, self.ndims, self.new_shape, self.compressor = None, None, None, None -# super().__init__(var_op_name) - -# def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): -# """ -# Compress, reduce, and decompress a given tensor. - -# Args: -# tensor (Tensor): the Tensor to reduce. -# conf (CollectiveOpsConfig): the config for Collective Ops. - -# Returns: -# Reduced Tensor -# """ -# if self.og_shape is None: -# self.og_shape = tensor.shape -# self.ndims = len(self.og_shape) - -# # Check if rank 1 tensor (this shouldn't be called with sparse tensors) -# # Just reduce it if it is, no need to compress -# if self._is_1d: -# return self._all_reduce(tensor, conf) - -# logging.info(f"Compressing tensor {tensor.name} (var {self.var_op_name}) with shape {tensor.shape}") -# if self.ndims > 2: -# tensor = array_ops.reshape(tensor, [self.og_shape[0], -1]) - -# if self.compressor is None: -# self.new_shape = array_ops.shape_v2(tensor) -# self.compressor = random_ops.random_normal([self.new_shape[1], self.rank]) - -# if self.error is not None: -# tensor += self.error - -# compressed_tensor = self._compress(tensor) -# self.error = tensor - self._decompress(compressed_tensor) - -# # all reduce mean p -# reduced = self._all_reduce(compressed_tensor, conf) -# reduced = self._orthogonalize(reduced) - -# # update compressor -# self.compressor = math_ops.matmul(tensor, reduced, transpose_a=True) -# conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + "/compressor") -# self.compressor = self._all_reduce(self.compressor, conf) -# return array_ops.reshape(self._decompress(reduced), self.og_shape) \ -# if self.ndims > 2 else self._decompress(reduced) - -# def _compress(self, tensor: Tensor): -# return math_ops.matmul(tensor, self.compressor) - -# def _decompress(self, compressed_tensor: Tensor): -# return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) - -# @property -# def _is_1d(self): -# return self.ndims <= 1 or ( -# self.ndims == 2 and any(d == 1 for d in self.og_shape) -# ) - -# @staticmethod -# def _orthogonalize(matrix): -# _, m = matrix.shape -# for i in range(m): -# v = matrix[:, i] -# v /= linalg_ops.norm_v2(v) -# v = array_ops.expand_dims_v2(v, 1) - -# begin, rest = matrix[:, :i], matrix[:, (i + 1):] -# rest -= math_ops.matmul(v, rest, transpose_a=True) * v -# matrix = array_ops.concat([begin, v, rest], 1) -# return matrix From 835ff73fcae7432dc05c75458af7ef237258e4b1 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Wed, 28 Oct 2020 14:53:33 +0800 Subject: [PATCH 03/12] Revert "lintering" This reverts commit c774e460825ef932ebbbd0de9cee2fa7f63b46f6. --- autodist/kernel/synchronization/compressor.py | 124 ++++++++++++++---- 1 file changed, 100 insertions(+), 24 deletions(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 1548878..1478746 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -13,12 +13,13 @@ # limitations under the License. """Gradient Compressors for All-Reduce.""" -import copy from abc import ABC, abstractmethod from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import Tensor -from tensorflow.python.ops import collective_ops, math_ops, random_ops, array_ops, linalg_ops -from autodist.kernel.synchronization.collective_key import get_collective_keys +from tensorflow.python.ops import collective_ops, math_ops + +#from tensorflow.python.ops import array_ops, collective_ops, linalg_ops, math_ops, random_ops +#from autodist.kernel.synchronization.collective_key import get_collective_keys #from autodist.utils import logging @@ -206,13 +207,11 @@ class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works becaus class PowerSGDCompressor(CompressorEF): """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" - def __init__(self, var_op_name, rank=1): self.rank = rank - self.og_shape, self.ndims = None, None - self.compressor, self.compressor_conf = None, None # compressor is the Q in paper + self.og_shape, self.ndims, self.compressor = None, None, None # compressor is the Q in paper self.var_op_name = var_op_name - super().__init__(var_op_name) + super.__init__(var_op_name) def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): """ @@ -230,18 +229,13 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): self.ndims = len(self.og_shape) # rank <= 1 - if self.ndims <= 1 or (self.ndims == 2 and any([d == 1 for d in self.og_shape])): + if self.ndims <= 1 or (self.ndims==2 and any([d == 1 for d in self.og_shape])): return self._all_reduce(tensor, conf) # compressor init if self.compressor is None: self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank]) - # synchronize compressor init statue - self.compressor_conf = copy.copy(conf) - self.conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') - self.compressor = self._all_reduce(self.compressor, self.compressor_conf) - if self.error is not None: tensor += self.error @@ -252,10 +246,13 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): orthonormal_reduced_tensor = self._modified_gram_schmidt(reduced_tensor) - self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr + self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr # all reduce mean compressor - self.compressor = self._all_reduce(self.compressor, self.compressor_conf) + instance_key = conf.instance_key + conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') + self.compressor = self._all_reduce(self.compressor, conf) + conf.instance_key = instance_key return self._decompress(orthonormal_reduced_tensor) @@ -281,26 +278,105 @@ def _decompress(self, compressed_tensor: Tensor): Returns: Tensor, Context """ - return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm => nxm + return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm = nxm @staticmethod def _modified_gram_schmidt(matrix): - """ - Apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns. + ''' + apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns Args: matrix (Tensor): the Tensor to orthogonalize. Returns: - matrix (Tensor) - """ - _, m = matrix.shape + matrix (Tensor) + ''' + n, m = matrix.shape for i in range(m): - v = matrix[:, i:(i + 1)] + v = matrix[:, i:i+1] v /= linalg_ops.norm_v2(v, axis=0) - rest = matrix[:, (i + 1):] + rest = matrix[:,i+1:] rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v - matrix = array_ops.concat([matrix[:, :i], v, rest], axis=1) + matrix = array_ops.concat([matrix[:,:i], v, rest],axis=1) return matrix + + +# class PowerSGDCompressor(CompressorEF): +# """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" + +# def __init__(self, var_op_name, rank=1): +# self.rank = rank +# self.og_shape, self.ndims, self.new_shape, self.compressor = None, None, None, None +# super().__init__(var_op_name) + +# def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): +# """ +# Compress, reduce, and decompress a given tensor. + +# Args: +# tensor (Tensor): the Tensor to reduce. +# conf (CollectiveOpsConfig): the config for Collective Ops. + +# Returns: +# Reduced Tensor +# """ +# if self.og_shape is None: +# self.og_shape = tensor.shape +# self.ndims = len(self.og_shape) + +# # Check if rank 1 tensor (this shouldn't be called with sparse tensors) +# # Just reduce it if it is, no need to compress +# if self._is_1d: +# return self._all_reduce(tensor, conf) + +# logging.info(f"Compressing tensor {tensor.name} (var {self.var_op_name}) with shape {tensor.shape}") +# if self.ndims > 2: +# tensor = array_ops.reshape(tensor, [self.og_shape[0], -1]) + +# if self.compressor is None: +# self.new_shape = array_ops.shape_v2(tensor) +# self.compressor = random_ops.random_normal([self.new_shape[1], self.rank]) + +# if self.error is not None: +# tensor += self.error + +# compressed_tensor = self._compress(tensor) +# self.error = tensor - self._decompress(compressed_tensor) + +# # all reduce mean p +# reduced = self._all_reduce(compressed_tensor, conf) +# reduced = self._orthogonalize(reduced) + +# # update compressor +# self.compressor = math_ops.matmul(tensor, reduced, transpose_a=True) +# conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + "/compressor") +# self.compressor = self._all_reduce(self.compressor, conf) +# return array_ops.reshape(self._decompress(reduced), self.og_shape) \ +# if self.ndims > 2 else self._decompress(reduced) + +# def _compress(self, tensor: Tensor): +# return math_ops.matmul(tensor, self.compressor) + +# def _decompress(self, compressed_tensor: Tensor): +# return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) + +# @property +# def _is_1d(self): +# return self.ndims <= 1 or ( +# self.ndims == 2 and any(d == 1 for d in self.og_shape) +# ) + +# @staticmethod +# def _orthogonalize(matrix): +# _, m = matrix.shape +# for i in range(m): +# v = matrix[:, i] +# v /= linalg_ops.norm_v2(v) +# v = array_ops.expand_dims_v2(v, 1) + +# begin, rest = matrix[:, :i], matrix[:, (i + 1):] +# rest -= math_ops.matmul(v, rest, transpose_a=True) * v +# matrix = array_ops.concat([begin, v, rest], 1) +# return matrix From c228567f4dadd0bc829cad95c9ca149157759fa0 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Wed, 28 Oct 2020 14:58:15 +0800 Subject: [PATCH 04/12] Update PowerSGDCompressor --- autodist/kernel/synchronization/compressor.py | 124 ++++-------------- 1 file changed, 24 insertions(+), 100 deletions(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 1478746..1548878 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -13,13 +13,12 @@ # limitations under the License. """Gradient Compressors for All-Reduce.""" +import copy from abc import ABC, abstractmethod from tensorflow.python.framework import dtypes from tensorflow.python.framework.ops import Tensor -from tensorflow.python.ops import collective_ops, math_ops - -#from tensorflow.python.ops import array_ops, collective_ops, linalg_ops, math_ops, random_ops -#from autodist.kernel.synchronization.collective_key import get_collective_keys +from tensorflow.python.ops import collective_ops, math_ops, random_ops, array_ops, linalg_ops +from autodist.kernel.synchronization.collective_key import get_collective_keys #from autodist.utils import logging @@ -207,11 +206,13 @@ class HorovodCompressorEF(CompressorEF, HorovodCompressor): # This works becaus class PowerSGDCompressor(CompressorEF): """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" + def __init__(self, var_op_name, rank=1): self.rank = rank - self.og_shape, self.ndims, self.compressor = None, None, None # compressor is the Q in paper + self.og_shape, self.ndims = None, None + self.compressor, self.compressor_conf = None, None # compressor is the Q in paper self.var_op_name = var_op_name - super.__init__(var_op_name) + super().__init__(var_op_name) def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): """ @@ -229,13 +230,18 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): self.ndims = len(self.og_shape) # rank <= 1 - if self.ndims <= 1 or (self.ndims==2 and any([d == 1 for d in self.og_shape])): + if self.ndims <= 1 or (self.ndims == 2 and any([d == 1 for d in self.og_shape])): return self._all_reduce(tensor, conf) # compressor init if self.compressor is None: self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank]) + # synchronize compressor init statue + self.compressor_conf = copy.copy(conf) + self.conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') + self.compressor = self._all_reduce(self.compressor, self.compressor_conf) + if self.error is not None: tensor += self.error @@ -246,13 +252,10 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): orthonormal_reduced_tensor = self._modified_gram_schmidt(reduced_tensor) - self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr + self.compressor = math_ops.matmul(tensor, orthonormal_reduced_tensor, transpose_a=True) # mxn * nxr => mxr # all reduce mean compressor - instance_key = conf.instance_key - conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') - self.compressor = self._all_reduce(self.compressor, conf) - conf.instance_key = instance_key + self.compressor = self._all_reduce(self.compressor, self.compressor_conf) return self._decompress(orthonormal_reduced_tensor) @@ -278,105 +281,26 @@ def _decompress(self, compressed_tensor: Tensor): Returns: Tensor, Context """ - return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm = nxm + return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) # nxr * rxm => nxm @staticmethod def _modified_gram_schmidt(matrix): - ''' - apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns + """ + Apply modified Gram-Schmidt procedure to orthogonalize a matrix in columns. Args: matrix (Tensor): the Tensor to orthogonalize. Returns: - matrix (Tensor) - ''' - n, m = matrix.shape + matrix (Tensor) + """ + _, m = matrix.shape for i in range(m): - v = matrix[:, i:i+1] + v = matrix[:, i:(i + 1)] v /= linalg_ops.norm_v2(v, axis=0) - rest = matrix[:,i+1:] + rest = matrix[:, (i + 1):] rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v - matrix = array_ops.concat([matrix[:,:i], v, rest],axis=1) + matrix = array_ops.concat([matrix[:, :i], v, rest], axis=1) return matrix - - -# class PowerSGDCompressor(CompressorEF): -# """An implementation of the PowerSGD compression algorithm (arxiv.org/abs/1905.13727).""" - -# def __init__(self, var_op_name, rank=1): -# self.rank = rank -# self.og_shape, self.ndims, self.new_shape, self.compressor = None, None, None, None -# super().__init__(var_op_name) - -# def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): -# """ -# Compress, reduce, and decompress a given tensor. - -# Args: -# tensor (Tensor): the Tensor to reduce. -# conf (CollectiveOpsConfig): the config for Collective Ops. - -# Returns: -# Reduced Tensor -# """ -# if self.og_shape is None: -# self.og_shape = tensor.shape -# self.ndims = len(self.og_shape) - -# # Check if rank 1 tensor (this shouldn't be called with sparse tensors) -# # Just reduce it if it is, no need to compress -# if self._is_1d: -# return self._all_reduce(tensor, conf) - -# logging.info(f"Compressing tensor {tensor.name} (var {self.var_op_name}) with shape {tensor.shape}") -# if self.ndims > 2: -# tensor = array_ops.reshape(tensor, [self.og_shape[0], -1]) - -# if self.compressor is None: -# self.new_shape = array_ops.shape_v2(tensor) -# self.compressor = random_ops.random_normal([self.new_shape[1], self.rank]) - -# if self.error is not None: -# tensor += self.error - -# compressed_tensor = self._compress(tensor) -# self.error = tensor - self._decompress(compressed_tensor) - -# # all reduce mean p -# reduced = self._all_reduce(compressed_tensor, conf) -# reduced = self._orthogonalize(reduced) - -# # update compressor -# self.compressor = math_ops.matmul(tensor, reduced, transpose_a=True) -# conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + "/compressor") -# self.compressor = self._all_reduce(self.compressor, conf) -# return array_ops.reshape(self._decompress(reduced), self.og_shape) \ -# if self.ndims > 2 else self._decompress(reduced) - -# def _compress(self, tensor: Tensor): -# return math_ops.matmul(tensor, self.compressor) - -# def _decompress(self, compressed_tensor: Tensor): -# return math_ops.matmul(compressed_tensor, self.compressor, transpose_b=True) - -# @property -# def _is_1d(self): -# return self.ndims <= 1 or ( -# self.ndims == 2 and any(d == 1 for d in self.og_shape) -# ) - -# @staticmethod -# def _orthogonalize(matrix): -# _, m = matrix.shape -# for i in range(m): -# v = matrix[:, i] -# v /= linalg_ops.norm_v2(v) -# v = array_ops.expand_dims_v2(v, 1) - -# begin, rest = matrix[:, :i], matrix[:, (i + 1):] -# rest -= math_ops.matmul(v, rest, transpose_a=True) * v -# matrix = array_ops.concat([begin, v, rest], 1) -# return matrix From 508dc322b26b65412333834bd01045aab8a2f077 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Wed, 28 Oct 2020 16:06:31 +0800 Subject: [PATCH 05/12] Add random seed for initialization of compressor --- autodist/kernel/synchronization/compressor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 1548878..868b542 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -235,12 +235,10 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): # compressor init if self.compressor is None: - self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank]) + self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank], seed=1000) - # synchronize compressor init statue self.compressor_conf = copy.copy(conf) self.conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') - self.compressor = self._all_reduce(self.compressor, self.compressor_conf) if self.error is not None: tensor += self.error From 414b91bbb3fe8c2a6966ea4a159baf248deecd55 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Sun, 1 Nov 2020 15:21:10 +0800 Subject: [PATCH 06/12] Update compressor.py --- autodist/kernel/synchronization/compressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 868b542..917f553 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -300,5 +300,5 @@ def _modified_gram_schmidt(matrix): rest = matrix[:, (i + 1):] rest -= math_ops.reduce_sum_v1(v * rest, axis=0, keepdims=True) * v - matrix = array_ops.concat([matrix[:, :i], v, rest], axis=1) + matrix = array_ops.concat([matrix[:, :i], v, rest], 1) return matrix From 64575ac9d85ab8d7930fc405c8aae16633412c5d Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Tue, 3 Nov 2020 10:04:27 +0800 Subject: [PATCH 07/12] Import Tensor --- autodist/kernel/synchronization/all_reduce_synchronizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autodist/kernel/synchronization/all_reduce_synchronizer.py b/autodist/kernel/synchronization/all_reduce_synchronizer.py index b186f51..4e11afd 100644 --- a/autodist/kernel/synchronization/all_reduce_synchronizer.py +++ b/autodist/kernel/synchronization/all_reduce_synchronizer.py @@ -18,6 +18,7 @@ from tensorflow.python import ops from tensorflow.python.framework import device_spec from tensorflow.python.ops import collective_ops +from tensorflow.python.framework.ops import Tensor import autodist from autodist.const import ENV From bfc289e4ab0d6ae9657d44a38eb344ad3511b51c Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Wed, 4 Nov 2020 09:41:58 +0800 Subject: [PATCH 08/12] disable unused-import error --- autodist/kernel/synchronization/all_reduce_synchronizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autodist/kernel/synchronization/all_reduce_synchronizer.py b/autodist/kernel/synchronization/all_reduce_synchronizer.py index 4e11afd..e8e3891 100644 --- a/autodist/kernel/synchronization/all_reduce_synchronizer.py +++ b/autodist/kernel/synchronization/all_reduce_synchronizer.py @@ -18,7 +18,7 @@ from tensorflow.python import ops from tensorflow.python.framework import device_spec from tensorflow.python.ops import collective_ops -from tensorflow.python.framework.ops import Tensor +from tensorflow.python.framework.ops import Tensor # pylint: disable=unused-import import autodist from autodist.const import ENV From cbc19a56eb252b697f148701e4bf5a81b4d27e79 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Sat, 7 Nov 2020 16:39:29 +0800 Subject: [PATCH 09/12] add testing for PowerSGDCompressor --- autodist/kernel/synchronization/compressor.py | 15 ++++++++++----- autodist/proto/synchronizers.proto | 2 +- tests/integration/test_all.py | 1 + 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 917f553..c1a429b 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -226,19 +226,24 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): Reduced Tensor """ if self.og_shape is None: - self.og_shape = tensor.shape - self.ndims = len(self.og_shape) + self.og_shape = array_ops.shape_v2(tensor) + if self.og_shape.shape[0] is None: + self.ndims = 0 + else: + self.ndims = self.og_shape.shape[0] # rank <= 1 - if self.ndims <= 1 or (self.ndims == 2 and any([d == 1 for d in self.og_shape])): + if self.ndims <= 1: return self._all_reduce(tensor, conf) + tensor = array_ops.reshape(tensor, [array_ops.shape_v2(tensor)[0], -1]) + # compressor init if self.compressor is None: self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank], seed=1000) self.compressor_conf = copy.copy(conf) - self.conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') + self.compressor_conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') if self.error is not None: tensor += self.error @@ -255,7 +260,7 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): # all reduce mean compressor self.compressor = self._all_reduce(self.compressor, self.compressor_conf) - return self._decompress(orthonormal_reduced_tensor) + return array_ops.reshape(self._decompress(orthonormal_reduced_tensor), self.og_shape) def _compress(self, tensor: Tensor): """ diff --git a/autodist/proto/synchronizers.proto b/autodist/proto/synchronizers.proto index f70996f..1813e3a 100644 --- a/autodist/proto/synchronizers.proto +++ b/autodist/proto/synchronizers.proto @@ -47,7 +47,7 @@ message AllReduceSynchronizer { NoneCompressor = 0; // No compression HorovodCompressor = 1; // Horovod's Compression HorovodCompressorEF = 2; // Horovod's Compression but with Error Feedback. - // PowerSGDCompressor = 3; // PowerSGD compression algorithm (arxiv.org/abs/1905.13727) + PowerSGDCompressor = 3; // PowerSGD compression algorithm (arxiv.org/abs/1905.13727) } Compressor compressor = 2; // One of the compressors to choose diff --git a/tests/integration/test_all.py b/tests/integration/test_all.py index 24bd863..bda23ca 100644 --- a/tests/integration/test_all.py +++ b/tests/integration/test_all.py @@ -38,6 +38,7 @@ AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='NoneCompressor'), AllReduce(chunk_size=1, all_reduce_spec='NCCL', compressor='HorovodCompressor'), AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='HorovodCompressorEF'), + AllReduce(chunk_size=1, all_reduce_spec='RING', compressor='PowerSGDCompressor'), PSLoadBalancing(local_proxy_variable=True), Parallax(local_proxy_variable=True), PartitionedAR(), From caba899183a3e3eee51af1e6f47a9813f6881128 Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Sun, 15 Nov 2020 19:39:20 +0800 Subject: [PATCH 10/12] Update compressor.py --- autodist/kernel/synchronization/compressor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index c1a429b..69335ad 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -233,10 +233,10 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): self.ndims = self.og_shape.shape[0] # rank <= 1 - if self.ndims <= 1: + if self.ndims <= 1 or (self.ndims == 2 and any([i == 1 for i in tensor.get_shape().as_list()])): return self._all_reduce(tensor, conf) - tensor = array_ops.reshape(tensor, [array_ops.shape_v2(tensor)[0], -1]) + tensor = array_ops.reshape(tensor, [self.og_shape[0], -1]) # compressor init if self.compressor is None: From c44183a53ce536370a231483988c7b74663ae9db Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Mon, 16 Nov 2020 11:11:38 +0800 Subject: [PATCH 11/12] use float32 --- autodist/kernel/synchronization/compressor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 69335ad..1d8d160 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -236,11 +236,13 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): if self.ndims <= 1 or (self.ndims == 2 and any([i == 1 for i in tensor.get_shape().as_list()])): return self._all_reduce(tensor, conf) - tensor = array_ops.reshape(tensor, [self.og_shape[0], -1]) + og_dtype = tensor.dtype + tensor = array_ops.reshape(math_ops.cast(tensor, dtype=dtypes.float32), [self.og_shape[0], -1]) # compressor init if self.compressor is None: - self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank], seed=1000) + self.compressor = random_ops.random_normal([array_ops.shape_v2(tensor)[1], self.rank], + seed=1000, dtype=dtypes.float32) self.compressor_conf = copy.copy(conf) self.compressor_conf.instance_key = get_collective_keys().get_instance_key(self.var_op_name + '/compressor') @@ -260,7 +262,7 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): # all reduce mean compressor self.compressor = self._all_reduce(self.compressor, self.compressor_conf) - return array_ops.reshape(self._decompress(orthonormal_reduced_tensor), self.og_shape) + return math_ops.cast(array_ops.reshape(self._decompress(orthonormal_reduced_tensor), self.og_shape), og_dtype) def _compress(self, tensor: Tensor): """ From 0c8b72b8ffed78a43d4a8df241a5a9dfe6b212ae Mon Sep 17 00:00:00 2001 From: Ezra-H <44772185+Ezra-H@users.noreply.github.com> Date: Tue, 17 Nov 2020 10:00:57 +0800 Subject: [PATCH 12/12] delete keyword dtype --- autodist/kernel/synchronization/compressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autodist/kernel/synchronization/compressor.py b/autodist/kernel/synchronization/compressor.py index 1d8d160..cce74f5 100644 --- a/autodist/kernel/synchronization/compressor.py +++ b/autodist/kernel/synchronization/compressor.py @@ -237,7 +237,7 @@ def reduce(self, tensor: Tensor, conf: CollectiveOpsConfig): return self._all_reduce(tensor, conf) og_dtype = tensor.dtype - tensor = array_ops.reshape(math_ops.cast(tensor, dtype=dtypes.float32), [self.og_shape[0], -1]) + tensor = array_ops.reshape(math_ops.cast(tensor, dtypes.float32), [self.og_shape[0], -1]) # compressor init if self.compressor is None: