# analog_BN_charge_interp_PL.py
  • # ////////////////////////////////////////////////////////////////////////////////////////////////////////////
    # /////////////////////////// Custom batchnorm implementing actual hardware ABN //////////////////////////////
    # ////////////////////////////////////////////////////////////////////////////////////////////////////////////
    
# Inspired by https://stackoverflow.com/questions/54101593/conditional-batch-normalization-in-keras
    
    import numpy as np
    import math
    
    import tensorflow as tf
    import keras.backend as K
    
    from keras import regularizers, initializers, constraints
    #from keras.legacy import interfaces
    from keras.layers import Layer, Input, InputSpec
    from keras.models import Model
    
    # Current ABN model
    from models.ABN_charge import makeLookupABN, doInterpABN
    from models.ABN_charge import round_through, floor_through
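
# round_through/floor_through are quantization helpers that, by their naming, let gradients
# pass straight through (STE), which keeps the quantized gain/offset path below trainable.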
    
    
    class Analog_BN(Layer):
        """ Analog batchnorm layer    
        """
        # /// Init layer ///
    #    @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 renorm=True,
                 beta_initializer='zeros',
                 gamma_initializer=tf.keras.initializers.Constant(value=3),
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 activity_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 hardware=None,
                 NB=None,
                 m_sigma=1,
                 Npoints=401,
                 EN_NOISE=0,
                 **kwargs):
                 
            super(Analog_BN, self).__init__(**kwargs)
            self.axis = axis
            self.momentum = momentum
            self.epsilon = epsilon
            self.center = center
            self.scale = scale
            self.renorm = renorm
            self.beta_initializer = initializers.get(beta_initializer)
            self.gamma_initializer = initializers.get(gamma_initializer)
            self.moving_mean_initializer = initializers.get(moving_mean_initializer)
            self.moving_variance_initializer = (initializers.get(moving_variance_initializer))
            self.beta_regularizer = regularizers.get(beta_regularizer)
            self.gamma_regularizer = regularizers.get(gamma_regularizer)
            self.activity_regularizer = regularizers.get(activity_regularizer)
            self.beta_constraint = constraints.get(beta_constraint)
            self.gamma_constraint = constraints.get(gamma_constraint)
        self.hardware = hardware
        self.EN_NOISE = EN_NOISE
        self.m_sigma = m_sigma
        # -- Interpolation info --
        self.Npoints = Npoints
        self.ABN_lookup = None
        self.sig_ABN_lookup = None
            
        # /// Build layer ///
    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of input tensor '
                             'should have a defined dimension, but the layer '
                             'received an input with shape ' + str(input_shape) + '.')
        shape = (dim,)
    
            if self.scale:
                # gamma_constraint = Clip(0.0,4.0)
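            # A single shared gain (shape (1,)) mirrors the one hardware ABN gain setting;
            # the offset beta below remains per-channel.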
            
                self.gamma = self.add_weight(shape = (1,),
                                             name = 'gamma',
                                             initializer = self.gamma_initializer,
                                             regularizer = self.gamma_regularizer,
                                             constraint = self.gamma_constraint)
            else:
                self.gamma = None
                    
            if self.center:
                # beta_constraint = Clip(-100.0,100.0);
            
                self.beta = self.add_weight(shape = shape,
                                            name = 'beta',
                                            initializer = self.beta_initializer,
                                            regularizer = self.beta_regularizer,
                                            constraint = self.beta_constraint)
                                          
            else:
                self.beta = None
                
            if self.renorm:
                self.moving_mean_DP = self.add_weight(
                                        shape=shape,
                                        name='moving_mean_DP',
                                        initializer=self.moving_mean_initializer,
                                        trainable=False)
                self.moving_variance_DP = self.add_weight(
                                        shape=shape,
                                        name='moving_variance_DP',
                                        initializer=self.moving_variance_initializer,
                                        trainable=False)
            else:
                self.moving_mean_DP = K.variable(0.0)
                self.moving_variance_DP = K.variable(1.0)
                
        # Store m_sigma as a non-trainable weight so that it is saved with the model
        self.m_sigma = self.add_weight(shape=(1,),
                                       name='m_sigma',
                                       initializer=initializers.get(tf.keras.initializers.Constant(value=self.m_sigma)),
                                       trainable=False)
    
        # SPICE-extracted lookup tables linking D_OUT to V_DP and T_ABN
        # (the hardcoded hardware description has to match the CIM parameters)
        print('Retrieving actual charge-domain ABN response...')
        self.ABN_lookup = self.hardware.sramInfo.ABN_LUT
        self.sig_ABN_lookup = self.hardware.sramInfo.sig_ABN_LUT
        print('Done!')
            
            super(Analog_BN, self).build(input_shape)
    
        # /// Call layer (train or inference) ///
    def call(self, inputs, training=None):

        input_shape = K.int_shape(inputs)
    
            # Prepare broadcasting shape.
            ndim = len(input_shape)
            reduction_axes = list(range(len(input_shape)))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * len(input_shape)
            broadcast_shape[self.axis] = input_shape[self.axis]
    
            # Determines whether broadcasting is needed.
            needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
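        # e.g. a 4D NHWC input with axis=-1 yields reduction_axes == [0, 1, 2] == range(ndim)[:-1],
        # so no explicit broadcast is needed; with axis=1 (NCHW) it is.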
            
            def normalize_inference():
            # Explicitly broadcast parameters when required.
                if needs_broadcasting:
                    # Norm params
                if self.renorm:
                    broadcast_moving_mean_DP = K.reshape(self.moving_mean_DP, broadcast_shape)
                    broadcast_moving_variance_DP = K.reshape(self.moving_variance_DP, broadcast_shape)
                else:
                    broadcast_moving_mean_DP = None
                    broadcast_moving_variance_DP = None
                # Scale param
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                else:
                    broadcast_gamma = None
                # Offset param
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                  
                # Return batchnorm
                return ABN(
                    inputs,
                    self.ABN_lookup,
                    self.sig_ABN_lookup,
                    broadcast_moving_mean_DP,
                    broadcast_moving_variance_DP,
                    broadcast_beta,
                    broadcast_gamma,
                    axis=self.axis,
                    epsilon=self.epsilon,
                    m_sigma=self.m_sigma,
                    hardware=self.hardware,
                    Npoints=self.Npoints,
                    EN_NOISE=self.EN_NOISE,
                    training=training)
            else:
                return ABN(
                    inputs,
                    self.ABN_lookup,
                    self.sig_ABN_lookup,
                    self.moving_mean_DP,
                    self.moving_variance_DP,
                    self.beta,
                    self.gamma,
                    axis=self.axis,
                    epsilon=self.epsilon,
                    m_sigma=self.m_sigma,
                    hardware=self.hardware,
                    Npoints=self.Npoints,
                    EN_NOISE=self.EN_NOISE,
                    training=training)
    
            # If the learning phase is *static* and set to inference:
            if training in {0, False}:
            return normalize_inference()
    
    
        # If the learning phase is either dynamic, or set to training:
        (normed_training, mean_DP, variance_DP) = norm_ABN_in_train(
            inputs, self.ABN_lookup, self.sig_ABN_lookup, self.beta, self.gamma,
            self.renorm, reduction_axes, epsilon=self.epsilon, m_sigma=self.m_sigma,
            hardware=self.hardware, Npoints=self.Npoints, EN_NOISE=self.EN_NOISE,
            training=training)
        # Debias the batch variance with the same sample-size correction as Keras'
        # BatchNormalization (skipped on the CNTK backend, as in the reference implementation)
        if K.backend() != 'cntk':
                sample_size = K.prod([K.shape(inputs)[axis]
                                      for axis in reduction_axes])
                sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
                if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                    sample_size = K.cast(sample_size, dtype='float32')
    
                # sample variance - unbiased estimator of population variance
                variance_DP *= sample_size / (sample_size - (1.0 + self.epsilon))
    
            # Update moving mean and variance during training
            self.add_update([K.moving_average_update(self.moving_mean_DP,
                                                     mean_DP,
                                                     self.momentum),
                             K.moving_average_update(self.moving_variance_DP,
                                                     variance_DP,
                                                     self.momentum)])
                 
            # Pick ABN result for either training or inference
            return K.in_train_phase(normed_training,
                                    normalize_inference,
                                    training=training)
    
    
        def get_config(self):
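        # NB: hardware, m_sigma, Npoints and EN_NOISE are not serialized below and must be
        # passed again when the layer is re-instantiated from its config.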
            config = {
                'axis': self.axis,
                'momentum': self.momentum,
                'epsilon': self.epsilon,
                'center': self.center,
                'scale': self.scale,
                'renorm': self.renorm,
                'beta_initializer': initializers.serialize(self.beta_initializer),
                'gamma_initializer': initializers.serialize(self.gamma_initializer),
                'moving_mean_initializer':
                    initializers.serialize(self.moving_mean_initializer),
                'moving_variance_initializer':
                    initializers.serialize(self.moving_variance_initializer),
                'beta_regularizer': regularizers.serialize(self.beta_regularizer),
                'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
                'beta_constraint': constraints.serialize(self.beta_constraint),
                'gamma_constraint': constraints.serialize(self.gamma_constraint)
            }
            base_config = super(Analog_BN, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))
    
    def compute_output_shape(self, input_shape):
        # ABN is applied elementwise, so the output shape matches the input shape
        return input_shape
    ############################################## Internal functions ##################################################
    
# Perform ABN
def ABN(V_DP, ABN_lookup, sig_ABN_lookup, mov_mean_DP=0.0, mov_variance_DP=1.0,
        beta=0.0, gamma=1.0, axis=-1, epsilon=1e-5, m_sigma=1, hardware=None,
        Npoints=401, EN_NOISE=0, training=False):
        
    # Get hardware parameters
    VDD = hardware.sramInfo.VDD.data
    r_gamma = hardware.sramInfo.r_gamma
    r_beta = hardware.sramInfo.r_beta
    OAres = hardware.sramInfo.OAres

    # Get number of states
    Ns_gamma = 2**r_gamma

    # Offset DAC and ADC quantization steps
    Vmax_beta = hardware.sramInfo.Vmax_beta
    Vlsb_beta = Vmax_beta/2**(r_beta-1)
    Vadc_step = VDD/(2**OAres)
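    # e.g. r_beta = 4 with Vmax_beta = 0.4 V gives an offset LSB of 0.4/2**3 = 50 mV;
    # OAres = 5 with VDD = 0.9 V gives an ADC step of 0.9/32 ~ 28 mV.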
    
    
    # Set 'None' parameters to their default values
    if gamma is None:
        gamma = K.constant(1.0)
    if beta is None:
        beta = K.constant(0.0)
    if mov_mean_DP is None:
        mov_mean_DP = K.constant(0.0)   # assumed zero-mean default; DR_tuple was undefined here
    if mov_variance_DP is None:
        mov_variance_DP = K.constant(1.0)
    
    
    # Specify non-center-normalized correction factors
#    mu_goal = VDD/2
    sigma_goal = VDD/m_sigma
    var_goal = sigma_goal*sigma_goal

#    # Get custom renorm factors
#    sigma_DP = K.sqrt(mov_variance_DP)
#    mov_mean_DP_t = mov_mean_DP - mu_goal/sigma_goal*sigma_DP
#    # mov_mean_DP_t = K.zeros_like(mov_mean_DP)

    mov_variance_DP_t = K.mean(mov_variance_DP)/var_goal
#    mov_variance_DP_t = mov_variance_DP/var_goal
#    # Get equivalent coefficients
#    sigma_DP_t = K.sqrt(mov_variance_DP_t)

    gamma_eq = gamma/(K.sqrt(mov_variance_DP_t) + epsilon)
    beta_eq = beta/gamma_eq - mov_mean_DP
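    # Derivation note: the standard BN mapping y = gamma*(V_DP - mean)/sqrt(var_t) + beta
    # is folded into a single gain plus input offset,
    #     y = gamma_eq*(V_DP + beta_eq)
    # with gamma_eq = gamma/sqrt(var_t) and beta_eq = beta/gamma_eq - mean, since
    # gamma_eq*(V_DP + beta/gamma_eq - mean) = gamma_eq*(V_DP - mean) + beta.
    # The hardware thus only needs one programmable gain and one offset voltage.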
    
        
    # Restrict the gain factor to a power of 2
    log_gamma_eq = round_through(tf.math.log(gamma_eq)/tf.math.log(2.))
    gamma_eq = K.pow(2., log_gamma_eq)
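    # e.g. an ideal gain of 5.3 gives log2(5.3) ~ 2.41, which rounds to 2,
    # so the realized hardware gain is 2**2 = 4.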
    
    # Quantize to the available hardware levels: gain within [1, 2**r_gamma],
    # offset on the beta-DAC grid and clipped to +/- Vmax_beta
    gamma_eq = K.clip(round_through(gamma_eq), 1, 2**r_gamma)
    V_beta = K.clip(round_through(beta_eq/Vlsb_beta)*Vlsb_beta, -Vmax_beta, Vmax_beta)
           
    # Apply the quantized offset
    V_ABN_temp = V_DP + V_beta
    
            
    # // Get the ABN distribution from the LUTs based on the gain/offset mapping //
    D_OUT = doInterpABN(ABN_lookup, gamma_eq, V_ABN_temp, Ns_gamma, Ns_gamma, VDD, Npoints)

    # Add the LUT-based stochastic deviation when noise emulation is enabled
    if EN_NOISE:
        sig_D_OUT = doInterpABN(sig_ABN_lookup, gamma_eq, V_ABN_temp, Ns_gamma, Ns_gamma, VDD, Npoints)
        sig_D_OUT = sig_D_OUT*K.random_normal(shape=tf.shape(D_OUT), mean=0., stddev=1., dtype='float32')
        D_OUT = D_OUT + sig_D_OUT

    # Return the ABN output
    return D_OUT
         
# Compute the batch mean and variance, then perform ABN with them, when enabled
def norm_ABN_in_train(V_DP, ABN_lookup, sig_ABN_lookup, beta=0.0, gamma=1.0,
                      renorm=True, axis=-1, epsilon=1e-5, m_sigma=1,
                      hardware=None, Npoints=401, EN_NOISE=0, training=False):
    # Compute the mean and variance of each batch when desired
    if renorm:
        # Reshape V_DP if needed (e.g. for the output of a CONV2D operation)
        Ncols = K.int_shape(V_DP)[-1]
        V_DP_flat = tf.reshape(V_DP, (-1, Ncols))
        # Get mean and variance
        mean_DP = K.mean(V_DP_flat, axis=0)
        variance_DP = K.var(V_DP_flat, axis=0)
    else:
        mean_DP = K.constant(0.0)
        variance_DP = K.constant(1.0)
    # Compute the ABN with the specified mean and variance
    V_DP_BN = ABN(V_DP, ABN_lookup, sig_ABN_lookup, mean_DP, variance_DP, beta, gamma,
                  axis, epsilon, m_sigma, hardware, Npoints, EN_NOISE, training)
    # Return a tuple with the BN result, mean and variance
    return (V_DP_BN, mean_DP, variance_DP)
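
# ------------------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustration only): Analog_BN drops in after a CIM dot-product layer like a regular
# BatchNormalization. The `hardware` object is assumed to expose `sramInfo` with the fields read above
# (VDD, r_gamma, r_beta, OAres, Vmax_beta, ABN_LUT, sig_ABN_LUT); `make_hardware()` is a hypothetical
# stand-in for however the surrounding project constructs that description.
#
#   from keras.layers import Input, Activation
#   from keras.models import Model
#
#   hardware = make_hardware()    # hypothetical project-specific hardware description
#   x_in = Input(shape=(128,))    # e.g. 128 dot-product voltages V_DP from the CIM array
#   x = Analog_BN(hardware=hardware, m_sigma=2, EN_NOISE=1)(x_in)
#   model = Model(x_in, Activation('softmax')(x))
# ------------------------------------------------------------------------------------------------------------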