add batchnorm layer (train only)
calad0i committed Jul 22, 2024
1 parent 0049e96 commit a4ffb35
Showing 1 changed file with 111 additions and 1 deletion.
112 changes: 111 additions & 1 deletion src/HGQ/layers/batchnorm_base.py
@@ -1,3 +1,4 @@
import numpy as np
import tensorflow as tf
from keras.src.utils import tf_utils
from tensorflow import keras
@@ -43,6 +44,7 @@ def __init__(
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = tf.keras.constraints.get(beta_constraint)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
        self.norm_shape = None

    def post_build(self, input_shape):
        super().post_build(input_shape)
@@ -52,7 +54,7 @@ def post_build(self, input_shape):
    def _post_build(self, input_shape):
        self._reduction_axis = tuple([i for i in range(len(input_shape)) if i not in self.axis])
        output_shape = self.compute_output_shape(input_shape)
        shape = tuple([output_shape[i] for i in self.axis])
        shape = self.norm_shape or tuple([output_shape[i] for i in self.axis])

        if self.center and not getattr(self, "use_bias", False):
            warn(f'`center` in fused BatchNorm can only be used if `use_bias` is True. Setting center to False.', stacklevel=3)
@@ -122,3 +124,111 @@ def adapt_fused_bn_kernel_bw_bits(self, x: tf.Tensor):
        scale = self.bn_gamma * tf.math.rsqrt(var + self.epsilon)
        fused_kernel = self.kernel * scale
        self.kq.adapt_bw_bits(fused_kernel)


class FakeObj:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class HBatchNormalization(HBatchNormBase):

    def __init__(
        self,
        axis=-1,
        momentum=0.99,
        epsilon=1e-3,
        center=True,
        scale=True,
        beta_initializer="zeros",
        gamma_initializer="ones",
        moving_mean_initializer="zeros",
        moving_variance_initializer="ones",
        beta_regularizer=None,
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        **kwargs,
    ):
        super().__init__(
            axis=axis,
            momentum=momentum,
            epsilon=epsilon,
            center=center,
            scale=scale,
            beta_initializer=beta_initializer,
            gamma_initializer=gamma_initializer,
            moving_mean_initializer=moving_mean_initializer,
            moving_variance_initializer=moving_variance_initializer,
            beta_regularizer=beta_regularizer,
            gamma_regularizer=gamma_regularizer,
            beta_constraint=beta_constraint,
            gamma_constraint=gamma_constraint,
            **kwargs,
        )
        self._delayed_kernel_bits_adaption = True

    def post_build(self, input_shape):
        self.step_counter = tf.Variable(1, trainable=False, dtype=tf.int32, name="step_counter")
        self.use_bias = True
        axis: list = tf_utils.validate_axis(self.axis, input_shape)
        reduction_axis = tuple([i for i in range(len(input_shape)) if i not in axis])
        ker_shape = tuple(1 if i in reduction_axis else n_inp for i, n_inp in enumerate(input_shape))
        self._reduction_axis = reduction_axis
        self.norm_shape = ker_shape
        self.kernel = FakeObj(shape=ker_shape)
        r = super().post_build(input_shape)
        delattr(self, "use_bias")
        delattr(self, "kernel")
        return r

    def forward(self, x, training=None, record_minmax=False):

        if training:
            self.step_counter.assign_add(1)
            var = tf.math.reduce_variance(x, axis=self._reduction_axis, keepdims=True)
            mean = tf.math.reduce_mean(x, axis=self._reduction_axis, keepdims=True)
            self.moving_mean.assign_sub((self.moving_mean - mean) * (1 - self.momentum))
            self.moving_variance.assign_sub((self.moving_variance - var) * (1 - self.momentum))
        else:
            # correction = 1 - tf.pow(self.momentum, tf.cast(self.step_counter, self.dtype))
            mean = self.moving_mean
            var = self.moving_variance

        ker = self.bn_gamma * tf.math.rsqrt(var + self.epsilon)
        bias = self.bn_beta - mean * ker

        if self._do_adapt_kernel_bits and self._delayed_kernel_bits_adaption:
            self.kq.adapt_bw_bits(ker)
            self._delayed_kernel_bits_adaption = False

        qker = self.kq(ker, training=training) # type: ignore
        qbias = self.paq(bias, training=training) # type: ignore
        z = qker * x + qbias

        input_bw = self.input_bw
        if input_bw is not None:
            kernel_bw = self._kernel_bw(qker)
            bops = tf.reduce_sum(input_bw * kernel_bw)
            self.bops.assign(bops)
            bops_loss = tf.cast(bops, tf.float32) * self.beta
            self.add_loss(bops_loss)

        return self.paq(z, training=training, record_minmax=record_minmax) # type: ignore

    def compute_output_shape(self, input_shape):
        return input_shape

    @property
    def compute_exact_bops(self):
        mean = self.moving_mean
        var = self.moving_variance
        ker = self.bn_gamma * tf.math.rsqrt(var + self.epsilon)
        qker = self.kq(ker, training=False) # type: ignore
        kn, int_bits, fb = self.kq.get_bits_exact(qker)
        kernel_bw = int_bits + fb # sign not considered for kernel
        input_bw = self.input_bw_exact
        bops = int(np.sum(input_bw * kernel_bw))
        self.bops.assign(tf.constant(bops, dtype=tf.float32))
        return bops
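
Below the diff, a minimal standalone sketch (not part of the commit) of the math that forward() implements: batch normalization folded into a single affine transform, y = ker * x + bias with ker = gamma * rsqrt(var + eps) and bias = beta - mean * ker, checked against TensorFlow's reference tf.nn.batch_normalization. Tensor shapes, epsilon, and tolerances are arbitrary illustrative choices.

# Standalone check of the fused affine form used in forward() above.
import numpy as np
import tensorflow as tf

x = tf.random.normal((256, 16))
gamma = tf.random.uniform((16,), 0.5, 1.5)   # stands in for bn_gamma
beta = tf.random.normal((16,))               # stands in for bn_beta
eps = 1e-3

mean = tf.reduce_mean(x, axis=0, keepdims=True)
var = tf.math.reduce_variance(x, axis=0, keepdims=True)

ker = gamma * tf.math.rsqrt(var + eps)       # fused scale
bias = beta - mean * ker                     # fused offset
y_fused = ker * x + bias

# Reference batch normalization from TensorFlow
y_ref = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
np.testing.assert_allclose(y_fused.numpy(), y_ref.numpy(), rtol=1e-5, atol=1e-5)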

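A hypothetical usage sketch (also not part of the commit): it assumes HQuantize and HDense from HGQ.layers and the beta BOPs-regularization keyword behave as elsewhere in the library, and imports HBatchNormalization directly from the module modified here. The architecture and numbers are illustrative only; per the commit title, only the training path is implemented.

# Hypothetical usage sketch; layer choices and hyperparameters are illustrative.
import tensorflow as tf
from HGQ.layers import HDense, HQuantize
from HGQ.layers.batchnorm_base import HBatchNormalization

model = tf.keras.Sequential([
    HQuantize(beta=3e-6),             # quantize the raw input
    HDense(32, beta=3e-6),
    HBatchNormalization(beta=3e-6),   # layer added in this commit (train only)
    HDense(10, beta=3e-6),
])
model.compile(optimizer='adam', loss='mse')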