# ////////////////////////////////////////////////////////////////////////////////////////////////////////////
# /////////////////////////// Custom batchnorm implementing actual hardware ABN //////////////////////////////
# ////////////////////////////////////////////////////////////////////////////////////////////////////////////

# Inspired by https://stackoverflow.com/questions/54101593/conditional-batch-normalization-in-keras

import numpy as np
import math

import tensorflow as tf
import keras.backend as K

from keras import regularizers, initializers, constraints
#from keras.legacy import interfaces
from keras.layers import Layer, Input, InputSpec
from keras.models import Model

class Analog_BN(Layer):
    """ Analog batchnorm layer    
    """
    # /// Init layer ///
#    @interfaces.legacy_batchnorm_support
    def __init__(self, 
             axis=-1,
             momentum=0.99,
             epsilon=1e-5,
             center=True,
             scale=True,
             renorm = True,
             beta_initializer='zeros',
             gamma_initializer='ones',
             moving_mean_initializer='zeros',
             moving_variance_initializer='ones',
             beta_regularizer=None,
             gamma_regularizer=None,
             activity_regularizer=None,
             beta_constraint=None,
             gamma_constraint=None,
             hardware = None,
             NB = None,
             **kwargs):
             
        super(Analog_BN, self).__init__(**kwargs)
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.renorm = renorm
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = (initializers.get(moving_variance_initializer))
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
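        # Hardware-derived parameters: analog dynamic range (GND..VDD), expected
        # gamma range (scaling with sqrt(NB), NB presumably being the number of
        # accumulated bits per dot-product), number of programmable gamma/beta
        # codes, and whether the bitline readout is differential (6T bitcell)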
        self.DRlim = (hardware.sramInfo.GND.data,hardware.sramInfo.VDD.data);
        self.gamma_range = 4*math.sqrt(NB)
        self.ABNstates = (2**hardware.sramInfo.r_gamma,2**hardware.sramInfo.r_beta)
        self.IS_DIFF = (hardware.sramInfo.arch.name == '6T'); # Update with other arch types
    
    # /// Build layer ///
    def build(self,input_shape):
        if isinstance(input_shape,list): input_shape = input_shape[0];  # use the shape of the first input when called on [V_BL,V_BLB]
        dim = input_shape[self.axis];
        
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        shape = (dim,)

        if self.scale:
            # gamma_constraint = Clip(0.0,4.0)
        
            self.gamma = self.add_weight(shape = shape,
                                         name = 'gamma',
                                         initializer = self.gamma_initializer,
                                         regularizer = self.gamma_regularizer,
                                         constraint = self.gamma_constraint)
        else:
            self.gamma = None
                
        if self.center:
            # beta_constraint = Clip(-100.0,100.0);
        
            self.beta = self.add_weight(shape = shape,
                                        name = 'beta',
                                        initializer = self.beta_initializer,
                                        regularizer = self.beta_regularizer,
                                        constraint = self.beta_constraint)
        else:
            self.beta = None
            
        if self.renorm:
            self.moving_mean_DP = self.add_weight(
                                    shape=shape,
                                    name='moving_mean_DP',
                                    initializer=self.moving_mean_initializer,
                                    trainable=False)
            self.moving_variance_DP = self.add_weight(
                                    shape=shape,
                                    name='moving_variance_DP',
                                    initializer=self.moving_variance_initializer,
                                    trainable=False)
        else:
            self.moving_mean_DP = K.variable(0.0)
            self.moving_variance_DP = K.variable(1.0)

        super(Analog_BN, self).build(input_shape)

    # /// Call layer (train or inference) ///
    def call(self,inputs,training=None):
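        # 'inputs' is expected to hold the bitline voltage tensor(s), e.g. [V_BL,V_BLB] in differential (6T) mode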
    
        input_shape = K.int_shape(inputs[0])

        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])
        
        def normalize_inference():
            # Explicitely broadcast parameters when required.
            if needs_broadcasting:
                # Norm params
                if self.renorm:
                    broadcast_moving_mean_DP = K.reshape(self.moving_mean_DP,
                                                        broadcast_shape);
                    broadcast_moving_variance_DP = K.reshape(self.moving_variance_DP,
                                                        broadcast_shape);
                else:
                    broadcast_moving_mean_DP = None;
                    broadcast_moving_variance_DP = None;
                # Scale param
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma,broadcast_shape);
                else:
                    broadcast_gamma = None
                # Offset param
                if self.center:
                    broadcast_beta = K.reshape(self.beta,broadcast_shape);
                else:
                    broadcast_beta = None
                # Return batchnorm 
                return ABN(
                    inputs,
                    broadcast_moving_mean_DP,
                    broadcast_moving_variance_DP,
                    broadcast_beta,
                    broadcast_gamma,
                    axis = self.axis,
                    epsilon = self.epsilon,
                    DR_tuple = self.DRlim,
                    gamma_range = self.gamma_range,
                    ABNstates = self.ABNstates,
                    IS_DIFF = self.IS_DIFF,
                    training=training)
            else:
                return ABN(
                    inputs,
                    self.moving_mean_DP,
                    self.moving_variance_DP,
                    self.beta,
                    self.gamma,
                    axis = self.axis,
                    epsilon = self.epsilon,
                    DR_tuple = self.DRlim,
                    gamma_range = self.gamma_range,
                    ABNstates = self.ABNstates,
                    IS_DIFF = self.IS_DIFF,
                    training=training)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()


        # If the learning is either dynamic, or set to training:
        (normed_training,mean_DP,variance_DP) = \
                            norm_ABN_in_train(
                            inputs, self.beta, self.gamma, self.renorm, reduction_axes,
                            epsilon=self.epsilon,DR_tuple=self.DRlim,gamma_range=self.gamma_range,ABNstates=self.ABNstates,IS_DIFF=self.IS_DIFF,training=training)
        # Correct the batch variance into an unbiased estimator before updating the moving averages (as in stock Keras BatchNormalization)
        if K.backend() != 'cntk':
            sample_size = K.prod([K.shape(inputs[0])[axis]
                                  for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs[0]))
            if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                sample_size = K.cast(sample_size, dtype='float32')

            # sample variance - unbiased estimator of population variance
            variance_DP *= sample_size / (sample_size - (1.0 + self.epsilon))

        # Update moving mean and variance during training
        self.add_update([K.moving_average_update(self.moving_mean_DP,
                                                 mean_DP,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance_DP,
                                                 variance_DP,
                                                 self.momentum)])
             
        # Pick ABN result for either training or inference
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)


    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'renorm': self.renorm,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer':
                initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer':
                initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
            'DRlim': self.DRlim,
            'IS_DIFF': self.IS_DIFF
        }
        base_config = super(Analog_BN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
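        # Output keeps the shape of a single bitline input (the second element when called on [V_BL,V_BLB])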

        return input_shape[1]
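
# --- Hedged usage sketch (illustrative only) --------------------------------
# The layer expects a `hardware` object exposing sramInfo.GND.data, sramInfo.VDD.data,
# sramInfo.r_gamma, sramInfo.r_beta and sramInfo.arch.name, plus the number of
# accumulated bits NB. The SimpleNamespace stub below is a hypothetical stand-in for
# the project's real hardware description, only meant to show how the layer would be
# wired on a pair of bitline tensors:
#
#   from types import SimpleNamespace
#   sramInfo = SimpleNamespace(GND=SimpleNamespace(data=0.0),
#                              VDD=SimpleNamespace(data=0.9),
#                              r_gamma=5, r_beta=5,
#                              arch=SimpleNamespace(name='6T'))
#   hardware = SimpleNamespace(sramInfo=sramInfo)
#
#   V_BL  = Input(shape=(128,))
#   V_BLB = Input(shape=(128,))
#   V_ABN = Analog_BN(hardware=hardware, NB=128)([V_BL, V_BLB])
#   model = Model(inputs=[V_BL, V_BLB], outputs=V_ABN)
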
############################################## Internal functions ##################################################

# Perform ABN: fold the batchnorm statistics into the single analog scale/offset applied to the DP voltage
def ABN(x_in,mov_mean_DP=0.0,mov_variance_DP=1.0,beta=0.0,gamma=1.0,axis=-1,epsilon=1e-5,DR_tuple=None,gamma_range=None,ABNstates=None,IS_DIFF=True,training=False):
    # Retrieve differential or single-ended output
    if(IS_DIFF):
        V_BL  = x_in[0];
        V_BLB = x_in[1];
    else:
        V_BL = x_in;
        # tf.print("V_RBL",V_BL[0],summarize=10)

    # Get min and max DR limits
    minDR = DR_tuple[0];
    maxDR = DR_tuple[1];

    # Set 'None' parameters to their initial values
    if gamma is None:
        gamma = K.constant(1.0);
    if beta is None:
        beta = K.constant(0.0);
    if mov_mean_DP is None:
        mov_mean_DP  = K.constant(DR_tuple[1]);
    if mov_variance_DP is None:
        mov_variance_DP  = K.constant(1.0);

    # Target statistics of the renormalized DP voltage: mean maxDR/2, std maxDR (instead of zero mean/unit variance)
    mu_goal  = maxDR/2;
    sigma_goal = maxDR; var_goal = sigma_goal*sigma_goal;
    
    # Compute differential or single-ended DP with switched-cap unit
    if(IS_DIFF):
        V_DP = maxDR/2 + (V_BL-V_BLB)/2
    else:
        V_DP = V_BL;
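    # Fold the batchnorm transform gamma*(V_DP - mu)/sigma + beta into the single
    # scale/offset pair (gamma_eq, beta_eq) that the analog hardware applies.
    # The moving statistics are first retargeted so that, for gamma=1 and beta=0,
    # the normalized V_DP has mean mu_goal = maxDR/2 and standard deviation
    # sigma_goal = maxDR, i.e. it is rescaled to the analog dynamic range rather
    # than to zero mean/unit variance.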
    # Get custom renorm factors
    sigma_DP = K.sqrt(mov_variance_DP);
    mov_mean_DP_t = mov_mean_DP - mu_goal/sigma_goal*sigma_DP;
    mov_variance_DP_t = mov_variance_DP/var_goal;
    # Get equivalent coefficients
    sigma_DP_t = K.sqrt(mov_variance_DP_t); 
    gamma_eq = gamma/(sigma_DP_t + epsilon);
    beta_eq  = beta - gamma*mov_mean_DP_t/(sigma_DP_t + epsilon);
    beta_eq_norm = beta_eq/gamma_eq + maxDR/2;
    # Quantize gamma onto the available integer codes (beta quantization currently disabled)
    Ns_gamma = ABNstates[0];
    Ns_beta = ABNstates[1];
    gamma_eq = K.clip(floor_through(gamma_eq),0,Ns_gamma-1);
    # beta_eq_norm = K.clip(floor_through(beta_eq_norm/(2*maxDR/5)*256)*(maxDR)/256,-maxDR/2,maxDR/2) - maxDR/2;
    beta_eq_norm = beta_eq_norm - maxDR/2
    # Apply the (ideal, for now) equivalent coefficients to get the ABN result.
    V_ABN = gamma_eq*(V_DP+beta_eq_norm);
    # Return (unclipped) result
    return V_ABN;
        
# Compute mean and variance of the batch then perform ABN with it, when enabled
def norm_ABN_in_train(x_tuple,beta=0.0,gamma=1.0,renorm=True,axis=-1,epsilon=1e-5,DR_tuple=None,gamma_range=None,ABNstates=None,IS_DIFF=True,training=False):
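    # Returns the same (output, mean, variance) triplet as Keras' normalize_batch_in_training,
    # so that call() can reuse the batch statistics for its moving-average updates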
    # Retrieve differential or single-ended tensors (same convention as ABN())
    if(IS_DIFF):
        V_BL  = x_tuple[0];
        V_BLB = x_tuple[1];
    else:
        V_BL = x_tuple;
    # Retrieve max DR (VDD by default)
    maxDR = DR_tuple[1];
    # Compute mean and variance of each batch when desired
    if(renorm):
        # Compute differential or single-ended DP with switched-cap unit
        if(IS_DIFF):
            V_DP = maxDR/2 + (V_BL-V_BLB)/2
        else:
            V_DP = V_BL;
        # Get mean and variance
        mean_DP = K.mean(V_DP,axis=0);
        variance_DP = K.var(V_DP,axis=0);
    else:
        mean_DP = K.constant(0.0);
        variance_DP = K.constant(1.0);
    # Compute ABN with specified mean and variance
    V_DP_BN = ABN(x_tuple,mean_DP,variance_DP,beta,gamma,axis,epsilon,DR_tuple,gamma_range,ABNstates,IS_DIFF,training);
    # Return a tuple of BN_result, mean and variance
    return (V_DP_BN,mean_DP,variance_DP);
    
# Gamma & Beta constraints

class Clip(constraints.Constraint):
    def __init__(self, min_value, max_value=None):
        self.min_value = min_value
        self.max_value = max_value
        if self.max_value is None:
            self.max_value = -self.min_value
        if self.min_value > self.max_value:
            self.min_value, self.max_value = self.max_value, self.min_value

    def __call__(self, p):
        return K.clip(p, self.min_value, self.max_value)

    def get_config(self):
        return {"min_value": self.min_value,
                "max_value": self.max_value}


# Standard-normal PDF (phi_exp) and CDF (phi_erf) helpers for truncated-normal statistics
def phi_exp(x):
    return 1/math.sqrt(2*math.pi)*K.exp(-0.5*(x*x));
def phi_erf(x):
    return 0.5*(1+tf.math.erf(x/math.sqrt(2)));
    
def floor_through(x):
    '''Element-wise floor with a straight-through gradient (identity in the backward pass).
    A trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182)
    '''
    floored = tf.math.floor(x);
    floored_through = x + K.stop_gradient(floored - x);
    return floored_through;
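
# Minimal sanity check of the straight-through floor gradient (assumes TF 2.x eager
# execution; illustration only, not part of the original layer code).
if __name__ == "__main__":
    x = tf.constant([0.2, 1.7, 2.5])
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = floor_through(x)
    # Forward pass floors the values; backward pass sees the identity (gradients of 1)
    print(y.numpy())                    # [0. 1. 2.]
    print(tape.gradient(y, x).numpy())  # [1. 1. 1.]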