# ////////////////////////////////////////////////////////////////////////////////////////////////////////////
# /////////////////////////// Custom batchnorm implementing actual hardware ABN //////////////////////////////
# ////////////////////////////////////////////////////////////////////////////////////////////////////////////
# Inspired by https://stackoverflow.com/questions/54101593/conditional-batch-normalization-in-keras
import numpy as np
import math

import tensorflow as tf
import keras.backend as K

from keras import regularizers, initializers, constraints
# from keras.legacy import interfaces
from keras.layers import Layer, Input, InputSpec
from keras.models import Model


class Analog_BN(Layer):
    """ Analog batchnorm layer """

    # /// Init layer ///
    # @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 renorm=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 activity_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 hardware=None,
                 NB=None,
                 **kwargs):
        super(Analog_BN, self).__init__(**kwargs)
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.renorm = renorm
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.DRlim = (hardware.sramInfo.GND.data, hardware.sramInfo.VDD.data)
        self.gamma_range = 4*math.sqrt(NB)
        self.ABNstates = (2**hardware.sramInfo.r_gamma, 2**hardware.sramInfo.r_beta)
        self.IS_DIFF = (hardware.sramInfo.arch.name == '6T')  # Update with other arch types

    # /// Build layer ///
    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        shape = (dim,)

        if self.scale:
            # gamma_constraint = Clip(0.0,4.0)
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            # beta_constraint = Clip(-100.0,100.0)
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        if self.renorm:
            self.moving_mean_DP = self.add_weight(
                shape=shape,
                name='moving_mean_DP',
                initializer=self.moving_mean_initializer,
                trainable=False)
            self.moving_variance_DP = self.add_weight(
                shape=shape,
                name='moving_variance_DP',
                initializer=self.moving_variance_initializer,
                trainable=False)
        else:
            self.moving_mean_DP = K.variable(0.0)
            self.moving_variance_DP = K.variable(1.0)

        super(Analog_BN, self).build(input_shape)

    # /// Call layer (train or inference) ///
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs[0])

        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determine whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            # Explicitly broadcast parameters when required.
            if needs_broadcasting:
                # Norm params
                if self.renorm:
                    broadcast_moving_mean_DP = K.reshape(self.moving_mean_DP, broadcast_shape)
                    broadcast_moving_variance_DP = K.reshape(self.moving_variance_DP, broadcast_shape)
                else:
                    broadcast_moving_mean_DP = None
                    broadcast_moving_variance_DP = None
                # Scale param
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                else:
                    broadcast_gamma = None
                # Offset param
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                # Return batchnorm
                return ABN(
                    inputs,
                    broadcast_moving_mean_DP,
                    broadcast_moving_variance_DP,
                    broadcast_beta,
                    broadcast_gamma,
                    axis=self.axis,
                    epsilon=self.epsilon,
                    DR_tuple=self.DRlim,
                    gamma_range=self.gamma_range,
                    ABNstates=self.ABNstates,
                    IS_DIFF=self.IS_DIFF,
                    training=training)
            else:
                return ABN(
                    inputs,
                    self.moving_mean_DP,
                    self.moving_variance_DP,
                    self.beta,
                    self.gamma,
                    axis=self.axis,
                    epsilon=self.epsilon,
                    DR_tuple=self.DRlim,
                    gamma_range=self.gamma_range,
                    ABNstates=self.ABNstates,
                    IS_DIFF=self.IS_DIFF,
                    training=training)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning phase is either dynamic, or set to training:
        (normed_training, mean_DP, variance_DP) = norm_ABN_in_train(
            inputs, self.beta, self.gamma, self.renorm, reduction_axes,
            epsilon=self.epsilon, DR_tuple=self.DRlim, gamma_range=self.gamma_range,
            ABNstates=self.ABNstates, IS_DIFF=self.IS_DIFF, training=training)

        # Unbiased-variance correction and moving-statistics update (mirrors the stock Keras BatchNormalization)
        if K.backend() != 'cntk':
            sample_size = K.prod([K.shape(inputs[0])[axis] for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs[0]))
            if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                sample_size = K.cast(sample_size, dtype='float32')
            # Sample variance - unbiased estimator of the population variance
            variance_DP *= sample_size / (sample_size - (1.0 + self.epsilon))

        # Update moving mean and variance during training
        self.add_update([K.moving_average_update(self.moving_mean_DP, mean_DP, self.momentum),
                         K.moving_average_update(self.moving_variance_DP, variance_DP, self.momentum)])

        # Pick ABN result for either training or inference
        return K.in_train_phase(normed_training, normalize_inference, training=training)

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'renorm': self.renorm,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
            'DRlim': self.DRlim,
            'IS_DIFF': self.IS_DIFF
        }
        base_config = super(Analog_BN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape[1]


############################################## Internal functions ##################################################

# Perform ABN
def ABN(x_in, mov_mean_DP=0.0, mov_variance_DP=1.0, beta=0.0, gamma=0.0, axis=-1, epsilon=1e-5,
        DR_tuple=None, gamma_range=None, ABNstates=None, IS_DIFF=True, training=False):
    # Retrieve differential or single-ended output
    if IS_DIFF:
        V_BL = x_in[0]
        V_BLB = x_in[1]
    else:
        V_BL = x_in
    # tf.print("V_RBL",V_BL[0],summarize=10)

    # Get min and max DR limits
    minDR = DR_tuple[0]
    maxDR = DR_tuple[1]

    # Set 'None' parameters to their initial values
    if gamma is None:
        gamma = K.constant(1.0)
    if beta is None:
        beta = K.constant(0.0)
    if mov_mean_DP is None:
        mov_mean_DP = K.constant(DR_tuple[1])
    if mov_variance_DP is None:
        mov_variance_DP = K.constant(1.0)

    # Specify non-center-normalized correction factors
    mu_goal = maxDR/2
    sigma_goal = maxDR
    var_goal = sigma_goal*sigma_goal

    # Compute differential or single-ended DP with switched-cap unit
    if IS_DIFF:
        V_DP = maxDR/2 + (V_BL - V_BLB)/2
    else:
        V_DP = V_BL

    # Get custom renorm factors
    sigma_DP = K.sqrt(mov_variance_DP)
    mov_mean_DP_t = mov_mean_DP - mu_goal/sigma_goal*sigma_DP
    mov_variance_DP_t = mov_variance_DP/var_goal

    # Get equivalent coefficients
    sigma_DP_t = K.sqrt(mov_variance_DP_t)
    gamma_eq = gamma/(sigma_DP_t + epsilon)
    beta_eq = beta - gamma*mov_mean_DP_t/(sigma_DP_t + epsilon)
    beta_eq_norm = beta_eq/gamma_eq + maxDR/2

    # Quantize gamma and beta
    Ns_gamma = ABNstates[0]
    Ns_beta = ABNstates[1]
    gamma_eq = K.clip(floor_through(gamma_eq), 0, Ns_gamma - 1)
    # beta_eq_norm = K.clip(floor_through(beta_eq_norm/(2*maxDR/5)*256)*(maxDR)/256,-maxDR/2,maxDR/2) - maxDR/2;
    beta_eq_norm = beta_eq_norm - maxDR/2

    # Apply (ideal, for now) equivalent coefficient to get ABN result.
    V_ABN = gamma_eq*(V_DP + beta_eq_norm)

    # Return (unclipped) result
    return V_ABN


# Compute mean and variance of the batch, then perform ABN with them, when enabled
def norm_ABN_in_train(x_tuple, beta=0.0, gamma=1.0, renorm=True, axis=-1, epsilon=1e-5,
                      DR_tuple=None, gamma_range=None, ABNstates=None, IS_DIFF=True, training=False):
    # Retrieve differential (or single-ended) tensors
    if IS_DIFF:
        V_BL = x_tuple[0]
        V_BLB = x_tuple[1]
    else:
        V_BL = x_tuple
    # Retrieve max DR (VDD by default)
    maxDR = DR_tuple[1]
    # Compute mean and variance of each batch when desired
    if renorm:
        # Compute differential or single-ended DP with switched-cap unit
        if IS_DIFF:
            V_DP = maxDR/2 + (V_BL - V_BLB)/2
        else:
            V_DP = V_BL
        # Get mean and variance
        mean_DP = K.mean(V_DP, axis=0)
        variance_DP = K.var(V_DP, axis=0)
    else:
        mean_DP = K.constant(0.0)
        variance_DP = K.constant(1.0)
    # Compute ABN with the specified mean and variance
    V_DP_BN = ABN(x_tuple, mean_DP, variance_DP, beta, gamma, axis, epsilon,
                  DR_tuple, gamma_range, ABNstates, IS_DIFF, training)
    # Return a tuple of the BN result, mean and variance
    return (V_DP_BN, mean_DP, variance_DP)


# Gamma & beta constraints
class Clip(constraints.Constraint):
    def __init__(self, min_value, max_value=None):
        self.min_value = min_value
        self.max_value = max_value
        if not self.max_value:
            self.max_value = -self.min_value
        if self.min_value > self.max_value:
            self.min_value, self.max_value = self.max_value, self.min_value

    def __call__(self, p):
        return K.clip(p, self.min_value, self.max_value)

    def get_config(self):
        return {"name": self.__call__.__name__,
                "min_value": self.min_value,
                "max_value": self.max_value}


# Truncated normal phi functions (phi_exp is the standard normal pdf, phi_erf its cdf)
def phi_exp(x):
    return 1/math.sqrt(2*math.pi)*K.exp(-0.5*(x*x))


def phi_erf(x):
    return 0.5*(1 + tf.math.erf(x/math.sqrt(2)))


def floor_through(x):
    '''Element-wise floor with full gradient propagation (straight-through estimator).
    A trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182)
    '''
    floored = tf.math.floor(x)
    floored_through = x + K.stop_gradient(floored - x)
    return floored_through
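

# ------------------------------------------------------------------------------------------------------------
# Illustrative sanity check (not part of the layer, safe to delete): ignoring the quantization of gamma_eq and
# treating epsilon as negligible, the equivalent-coefficient form applied in ABN() reduces to a conventional
# affine batchnorm whose output is rescaled to the analog dynamic range maxDR and re-centered at mid-rail
# (maxDR/2). The NumPy sketch below mirrors the gamma_eq/beta_eq derivation of ABN() and compares it against
# that closed form; all names are local to this example and no hardware non-ideality is modelled.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    maxDR, eps = 1.0, 1e-12
    V_DP = rng.uniform(0.0, maxDR, size=1000)        # pre-ABN dot-product voltages
    mu, var = V_DP.mean(), V_DP.var()
    sigma = np.sqrt(var)
    gamma, beta = 1.3, 0.1                           # learned BN scale and offset

    # Equivalent coefficients, mirroring ABN() with mu_goal = maxDR/2 and sigma_goal = maxDR (quantization skipped)
    mean_t = mu - 0.5*sigma
    sigma_t = np.sqrt(var/(maxDR*maxDR))
    gamma_eq = gamma/(sigma_t + eps)
    beta_eq = beta - gamma*mean_t/(sigma_t + eps)
    V_ABN = gamma_eq*V_DP + beta_eq                  # = gamma_eq*(V_DP + beta_eq_norm) with beta_eq_norm = beta_eq/gamma_eq

    # Same operation written as a standard batchnorm, scaled to maxDR and centered at maxDR/2
    V_ref = gamma*maxDR*(V_DP - mu)/sigma + gamma*maxDR/2 + beta
    print('max |V_ABN - V_ref| =', np.abs(V_ABN - V_ref).max())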