# -*- coding: utf-8 -*-
import numpy as np
import math

from keras import backend as K
import tensorflow as tf

from keras.layers import InputSpec, Layer, Dense, Conv2D
from keras import constraints
from keras import initializers
# Binarization functions
from layers.binary_ops import binarize, binarize_exp, binarize_ssb
from layers.binary_ops import binary_sigmoid_p
# Analog MAC operator
from models.MAC_current import MAC_op_se_ana as MAC_op_se
from models.MAC_current import MAC_op_diff_ana as MAC_op_diff
from models.CONV_current import CONV_op_se_ana as CONV_op_se
from models.CONV_current import CONV_op_diff_ana as CONV_op_diff
# ADC model
from models.ADC import quant_uni
# Hardware parameters generation
from utils.config_hardware_model import genHardware
# Temporary dir
import tempfile
import sys
import subprocess
import time
# Modeling files
import os
scriptpath = "../lib_modelcim/"
sys.path.append(os.path.abspath(scriptpath));
from preProc_wrapper import preProcSat as getHardwareData
from fit_spice import DP_fit

class Clip(constraints.Constraint):
    '''Weight constraint clipping every value into [min_value, max_value].

    Arguments:
        min_value: lower clipping bound.
        max_value: upper clipping bound. When omitted (None), the range is
            made symmetric by using -min_value.

    If the two bounds are given in the wrong order they are swapped, so the
    stored min_value is never larger than the stored max_value.
    '''
    def __init__(self, min_value, max_value=None):
        self.min_value = min_value
        # Test against None explicitly: a truthiness test would treat a
        # legitimate upper bound of 0 as "not provided" and silently replace
        # it with -min_value.
        self.max_value = -min_value if max_value is None else max_value
        if self.min_value > self.max_value:
            self.min_value, self.max_value = self.max_value, self.min_value

    def __call__(self, p):
        '''Return p with its values clipped to the configured range.'''
        return K.clip(p, self.min_value, self.max_value)

    def get_config(self):
        '''Return the constraint name and its two clipping bounds.'''
        # Report the class name; the bound method name would always read
        # "__call__" and carries no information.
        return {"name": self.__class__.__name__,
                "min_value": self.min_value,
                "max_value": self.max_value}


class BinaryDense(Dense):
    ''' Binarized Dense layer evaluated through an analog CIM-SRAM MAC model.

    The kernel is binarized at call time and the matrix product is delegated
    to the analog MAC operator selected by the SRAM cell architecture
    ('6T': differential operator, '8T': single-ended operator). The result is
    optionally digitized by the ADC model.

    Arguments:
        units: number of output units (forwarded to Dense).
        H: clipping bound of the kernel and scale passed to `binarize`;
            the string 'Glorot' derives it from the layer fan-in/fan-out.
        sramInfo: hardware description object; required before `build`.
        EN_NOISE: flag forwarded to the MAC operator.
        EN_QUANT: when truthy, the output goes through `quant_uni`.
        kernel_lr_multiplier: value stored in `lr_multipliers`; 'Glorot'
            derives it from the layer fan-in/fan-out.
        bias_lr_multiplier: value stored in `lr_multipliers` when a bias is used.
        **kwargs: forwarded to Dense.

    References:
    "BinaryNet: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1" [http://arxiv.org/abs/1602.02830]
    '''
    def __init__(self, units, H=1.,sramInfo=None, EN_NOISE=0, EN_QUANT=1, kernel_lr_multiplier='Glorot', bias_lr_multiplier=None, **kwargs):
        # The base constructor is called exactly once, before any attribute
        # is attached to the layer.
        super(BinaryDense, self).__init__(units, **kwargs)
        self.H = H
        self.kernel_lr_multiplier = kernel_lr_multiplier
        self.bias_lr_multiplier = bias_lr_multiplier

        self.EN_NOISE = EN_NOISE
        self.EN_QUANT = EN_QUANT

        self.sramInfo = sramInfo
        # Filled in by build()
        self.hardware = None
        self.Vt_noise = None
        self.input_dim = None

    @staticmethod
    def _num_macros(input_dim, n_rows):
        '''Return how many CIM-SRAM macros the input vector is split over (at least 1).'''
        # max(1, ...) guards the input_dim == 1 case, for which the ceil()
        # expression evaluates to 0 and would later divide by zero.
        return max(1, int(math.ceil((input_dim-1)/n_rows)))

    def _run_spice_characterization(self, path_dir):
        '''Run the Spice pre-processing in a temporary folder and return the updated sramInfo.

        The reference netlists are copied from `path_dir`/RefFolder into a
        temporary 'SimFolder_*' directory created under `path_dir`; the
        directory is removed when the simulation wrapper has returned.
        Exits the interpreter (sys.exit) on an unsupported simulator or when
        the folder cannot be prepared.
        '''
        with tempfile.TemporaryDirectory(dir=path_dir,prefix='SimFolder_') as path_to_file:
            # Files are listed explicitly rather than copied with a '*' wildcard
            if(self.sramInfo.simulator == "eldo"):
                file_table = ['MAC_DC.cir','MAC_NL.cir','MAC_satCal.cir','MAC_time.cir','MAC_train_MC.cir']
            elif(self.sramInfo.simulator == "spectre"):
                file_table = ['MAC_DC.scs','MAC_satCal.scs','MAC_time.scs',
                              'MAC_DC.mdl','MAC_satCal.mdl','MAC_time.mdl']
            else:
                sys.exit('Error: selected simulator not supported !\n')
            for file_temp in file_table:
                proc = subprocess.run(['cp',path_dir+'RefFolder/'+file_temp,path_to_file+'/'])
                if(proc.returncode != 0):
                    sys.exit('Error: could not copy reference files into temporary sim folder !\n')
            # Create temporary data folder
            proc = subprocess.run(['mkdir',path_to_file+'/data'])
            if(proc.returncode != 0):
                sys.exit('Error: could not create data folder in temporary sim folder !\n')
            # Perform Spice simulations
            return getHardwareData(path_to_file,self.sramInfo)

    def build(self, input_shape):
        '''Create the layer weights and characterize the hardware model.

        Raises:
            ValueError: if no `sramInfo` was supplied to the constructor.
        '''
        assert len(input_shape) >= 2
        if self.sramInfo is None:
            raise ValueError('BinaryDense requires a `sramInfo` hardware description.')
        input_dim = input_shape[1]
        self.input_dim = input_dim

        if self.H == 'Glorot':
            self.H = np.float32(np.sqrt(1.5 / (input_dim + self.units)))
        if self.kernel_lr_multiplier == 'Glorot':
            self.kernel_lr_multiplier = np.float32(1. / np.sqrt(1.5 / (input_dim + self.units)))

        # Symmetric weight range, whatever the cell architecture
        self.kernel_constraint = Clip(-self.H, self.H)
        self.kernel_initializer = initializers.RandomUniform(-self.H, self.H)

        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                     initializer=self.kernel_initializer,
                                     name='kernel',
                                     regularizer=self.kernel_regularizer,
                                     constraint=self.kernel_constraint)

        if self.use_bias:
            self.lr_multipliers = [self.kernel_lr_multiplier, self.bias_lr_multiplier]
            # One bias per output unit (`units` is the attribute Dense defines).
            self.bias = self.add_weight(shape=(self.units,),
                                     initializer=self.bias_initializer,
                                     name='bias',
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)
        else:
            self.lr_multipliers = [self.kernel_lr_multiplier]
            self.bias = None

        # Get DP electrical quantities for this layer
        Nrows = self.sramInfo.Nrows.data
        N_cim = self._num_macros(input_dim, Nrows)
        self.sramInfo.NB.data = int(input_dim/N_cim)
        print(f'######## FC layer with {self.sramInfo.NB.data} cells/op supplied at {self.sramInfo.VDD.data:.2f}V ######## ')
        # NOTE(review): simulation root is hard-coded to one workstation layout.
        path_dir = '/export/home/adkneip/Documents/PhD/ELDO/IMC_PYTHON/CURRENT_MAC/'+self.sramInfo.arch.name+'_CELL/'
        self.sramInfo = self._run_spice_characterization(path_dir)
        # Generate hardware parameters
        hardware = genHardware(self.sramInfo)
        # Compute the appropriate curve-fitting factors
        print('######## Performing three-parametric best curve-fitting ######## ')
        self.hardware = DP_fit(path_dir,'early',hardware)
        # V_th mismatch is currently disabled (stddev=0); use
        # stddev=self.hardware.sig_Vth to re-enable it.
        self.Vt_noise = K.random_normal(shape=(self.units,),mean=0,stddev=0)

        # Perform build
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        '''Evaluate the layer through the analog MAC model.

        Returns the output of `quant_uni` when EN_QUANT is truthy, the
        single-ended analog output for the '8T' architecture otherwise, and
        for the unquantized '6T' case the two bit-line outputs concatenated
        along a new leading axis.

        Raises:
            NameError: if the cell architecture is neither '6T' nor '8T'.
        '''
        # Binarize weights
        W_bin = binarize(self.kernel, H=self.H)
        # Check if a single CIM-SRAM is sufficient, or ideal charge-share of their analog outputs
        Nrows = self.hardware.sramInfo.Nrows.data
        N_cim = self._num_macros(self.input_dim, Nrows)
        # Retrieve architecture type (diff or se) and derive flag
        archType = self.hardware.sramInfo.arch.name
        IS_SE_OUT = (archType == '8T') or  self.EN_QUANT
        # Wrap correct MAC_op function
        if(archType == '6T'):
            MAC_op = MAC_op_diff
        elif(archType == '8T'):
            MAC_op = MAC_op_se
        else:
            raise NameError('Error: selected architecture (cell type) not supported during FC layer compute !\n')
        # Emulate CIM-SRAM analog MAC operation, possibly with parallel macros
        if(N_cim > 1):
            # Separate inputs and weights in sub-matrices (one per macro)
            rows_per_macro = int(self.input_dim/N_cim)
            inputs = tf.unstack(K.reshape(inputs,(-1,rows_per_macro,N_cim)),axis=-1)
            W_bin = K.permute_dimensions(K.reshape(K.permute_dimensions(W_bin,(1,0)),(-1,rows_per_macro,N_cim)),(1,2,0))
            W_bin = tf.unstack(W_bin,axis=1)
            # Perform CIM-SRAM operations over all sub-matrices (i.e. different CIM-SRAMs)
            V_DP = [MAC_op(self.hardware,inputs[i],W_bin[i],self.Vt_noise,self.EN_NOISE,self.EN_QUANT)
                    for i in range(N_cim)]
            # Combine the result as if ideal charge-sharing (--> could implement actual charge-sharing !)
            if(IS_SE_OUT):
                V_DP = K.sum(tf.stack(V_DP,axis=2),axis=2)/N_cim
            else:
                # Each macro returns a (V_BL, V_BLB) pair, as in the
                # single-macro branch below: average each line over the
                # macros rather than mixing BL and BLB of one macro.
                V_BL  = K.sum(tf.stack([pair[0] for pair in V_DP],axis=2),axis=2)/N_cim
                V_BLB = K.sum(tf.stack([pair[1] for pair in V_DP],axis=2),axis=2)/N_cim
        else:
            if(IS_SE_OUT):
                V_DP = MAC_op(self.hardware,inputs,W_bin,self.Vt_noise,self.EN_NOISE,self.EN_QUANT)
            else:
                (V_BL,V_BLB) = MAC_op(self.hardware,inputs,W_bin,self.Vt_noise,self.EN_NOISE,self.EN_QUANT)
        # Add bias to PA
        if self.use_bias:
            if(IS_SE_OUT):
                V_DP = K.bias_add(V_DP, self.bias)
            else:
                V_BL  = K.bias_add(V_BL,self.bias)
                V_BLB = K.bias_add(V_BLB,self.bias)

        if(self.EN_QUANT):
            # Quantify the PA to get the digitized OA
            IAres = self.hardware.sramInfo.IAres
            OAres = self.hardware.sramInfo.OAres
            NB = self.hardware.sramInfo.NB.data
            PAmax = (2**IAres-1)*NB
            DRval = self.hardware.sramInfo.DR.data
            VDD = self.hardware.sramInfo.VDD.data
            # Return quantized output
            return quant_uni(V_DP,PAmax,DRval,VDD,OAres,0.5*DRval/PAmax,archType)
        elif(archType == '8T'):
            return V_DP
        else:
            # Return unquantized differential output
            return K.concatenate([V_BL[np.newaxis,...],V_BLB[np.newaxis,...]],axis=0)

    def get_config(self):
        '''Return the Dense configuration extended with the binarization settings.

        `sramInfo` is not included in the returned dictionary.
        '''
        config = {'H': self.H,
                  'kernel_lr_multiplier': self.kernel_lr_multiplier,
                  'bias_lr_multiplier': self.bias_lr_multiplier,
                  'EN_NOISE': self.EN_NOISE,
                  'EN_QUANT': self.EN_QUANT}
        base_config = super(BinaryDense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class BinaryConv2D(Conv2D):
    '''Binarized Convolution2D layer evaluated through an analog CIM-SRAM model.

    The kernel is binarized at call time and the convolution is delegated to
    the analog operator selected by the SRAM cell architecture ('6T':
    differential operator, '8T': single-ended operator). The result is
    optionally digitized by the ADC model.

    Arguments:
        filters: number of output filters (forwarded to Conv2D).
        kernel_regularizer, activity_regularizer: forwarded to Conv2D.
        kernel_lr_multiplier: gradient scaling factor applied in `call`;
            'Glorot' derives it from the layer fan-in/fan-out.
        bias_lr_multiplier: value stored in `lr_multipliers` when a bias is used.
        H: clipping bound of the kernel and scale passed to `binarize`;
            'Glorot' derives it from the layer fan-in/fan-out.
        sramInfo: hardware description object; required before `build`.
        EN_NOISE: flag forwarded to the convolution operator.
        EN_QUANT: when truthy, the output goes through `quant_uni`.
        **kwargs: forwarded to Conv2D (must provide `kernel_size`).

    References: 
    "BinaryNet: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1" [http://arxiv.org/abs/1602.02830]
    '''
    def __init__(self, filters, kernel_regularizer=None,activity_regularizer=None, kernel_lr_multiplier='Glorot',
                 bias_lr_multiplier=None, H=1.,sramInfo=None, EN_NOISE=0, EN_QUANT=1, **kwargs):
        # Hand the regularizers to the base constructor so that they are
        # processed the same way as every other Conv2D argument instead of
        # being attached to the layer as raw values.
        super(BinaryConv2D, self).__init__(filters,
                                           kernel_regularizer=kernel_regularizer,
                                           activity_regularizer=activity_regularizer,
                                           **kwargs)
        self.H = H
        self.kernel_lr_multiplier = kernel_lr_multiplier
        self.bias_lr_multiplier = bias_lr_multiplier

        self.sramInfo = sramInfo
        # Filled in by build()
        self.hardware = None
        self.Vt_noise = None

        self.EN_NOISE = EN_NOISE
        self.EN_QUANT = EN_QUANT

    def _run_spice_characterization(self, path_dir):
        '''Run the Spice pre-processing in a temporary folder and return the updated sramInfo.

        The reference netlists are copied from `path_dir`/RefFolder into a
        temporary 'SimFolder_*' directory created under `path_dir`; the
        directory is removed when the simulation wrapper has returned.
        Exits the interpreter (sys.exit) on an unsupported simulator or when
        the folder cannot be prepared.
        '''
        with tempfile.TemporaryDirectory(dir=path_dir,prefix='SimFolder_') as path_to_file:
            # Files are listed explicitly rather than copied with a '*' wildcard
            if(self.sramInfo.simulator == "eldo"):
                file_table = ['MAC_DC.cir','MAC_NL.cir','MAC_satCal.cir','MAC_time.cir','MAC_train_MC.cir']
            elif(self.sramInfo.simulator == "spectre"):
                file_table = ['MAC_DC.scs','MAC_satCal.scs','MAC_time.scs',
                              'MAC_DC.mdl','MAC_satCal.mdl','MAC_time.mdl']
            else:
                sys.exit('Error: selected simulator not supported !\n')
            for file_temp in file_table:
                proc = subprocess.run(['cp',path_dir+'RefFolder/'+file_temp,path_to_file+'/'])
                if(proc.returncode != 0):
                    sys.exit('Error: could not copy reference files into temporary sim folder !\n')
            # Create temporary data folder
            proc = subprocess.run(['mkdir',path_to_file+'/data'])
            if(proc.returncode != 0):
                sys.exit('Error: could not create data folder in temporary sim folder !\n')
            # Perform Spice simulations
            return getHardwareData(path_to_file,self.sramInfo)

    def build(self, input_shape):
        '''Create the layer weights and characterize the hardware model.

        Raises:
            ValueError: if the channel dimension of `input_shape` is None, or
                if no `sramInfo` was supplied to the constructor.
        '''
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        if self.sramInfo is None:
            raise ValueError('BinaryConv2D requires a `sramInfo` hardware description.')

        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        # Number of kernel positions per channel
        base = self.kernel_size[0] * self.kernel_size[1]
        if self.H == 'Glorot':
            nb_input = int(input_dim * base)
            nb_output = int(self.filters * base)
            self.H = np.float32(np.sqrt(1.5 / (nb_input + nb_output)))

        if self.kernel_lr_multiplier == 'Glorot':
            nb_input = int(input_dim * base)
            nb_output = int(self.filters * base)
            self.kernel_lr_multiplier = np.float32(1. / np.sqrt(1.5/ (nb_input + nb_output)))

        self.kernel_constraint = Clip(-self.H, self.H)
        self.kernel_initializer = initializers.RandomUniform(-self.H, self.H)
        self.kernel = self.add_weight(shape=kernel_shape,
                                 initializer=self.kernel_initializer,
                                 name='kernel',
                                 regularizer=self.kernel_regularizer,
                                 constraint=self.kernel_constraint)

        if self.use_bias:
            self.lr_multipliers = [self.kernel_lr_multiplier, self.bias_lr_multiplier]
            # The shape must be passed by keyword: the first positional
            # parameter of add_weight is the weight name, which is also
            # given here as name='bias'.
            self.bias = self.add_weight(shape=(self.filters,),
                                     initializer=self.bias_initializer,
                                     name='bias',
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)
        else:
            self.lr_multipliers = [self.kernel_lr_multiplier]
            self.bias = None

        # Get DP electrical quantities for this layer
        self.sramInfo.NB.data = base*input_dim
        print(f'######## 2D-CONV layer with {self.sramInfo.NB.data} cells/op supplied at {self.sramInfo.VDD.data:.2f}V ######## ')
        # NOTE(review): simulation root is hard-coded to one workstation layout.
        path_dir = '/export/home/adkneip/Documents/PhD/ELDO/IMC_PYTHON/CURRENT_MAC/'+self.sramInfo.arch.name+'_CELL/'
        self.sramInfo = self._run_spice_characterization(path_dir)
        # Generate hardware parameters
        hardware = genHardware(self.sramInfo)
        # Compute the appropriate curve-fitting factors
        print('######## Performing three-parametric best curve-fitting ######## ')
        self.hardware = DP_fit(path_dir,'early',hardware)
        # V_th mismatch is currently disabled (stddev=0); use
        # stddev=self.hardware.sig_Vth to re-enable it.
        self.Vt_noise = K.random_normal(shape=(input_dim,),mean=0,stddev=0)

        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs):
        '''Evaluate the layer through the analog convolution model.

        Returns the output of `quant_uni` when EN_QUANT is truthy, the
        single-ended analog output for the '8T' architecture otherwise, and
        the tuple (V_BL, V_BLB) for the unquantized '6T' case.

        Raises:
            NameError: if the cell architecture is neither '6T' nor '8T'.
        '''
        binary_kernel = binarize(self.kernel, H=self.H)
        # Retrieve architecture type (diff or se) and derive flag
        archType = self.hardware.sramInfo.arch.name
        IS_SE_OUT = (archType == '8T') or  self.EN_QUANT
        # Wrap correct CONV_op function
        if(archType == '6T'):
            CONV_op = CONV_op_diff
        elif(archType == '8T'):
            CONV_op = CONV_op_se
        else:
            raise NameError('Error: selected architecture (cell type) not supported during 2DCONV layer compute !\n')

        # Gradient-scaling trick: `x - (1 - 1/k) * stop_gradient(x)) * k`
        # leaves the forward value of x unchanged while multiplying its
        # gradient by k. It is applied with k = 1/lr_multiplier on the inputs
        # and k = lr_multiplier on the outputs.
        inverse_kernel_lr_multiplier = 1./self.kernel_lr_multiplier
        inputs_bnn_gradient = (inputs - (1. - 1./inverse_kernel_lr_multiplier) * K.stop_gradient(inputs))\
                  * inverse_kernel_lr_multiplier

        outputs_bnn_gradient = CONV_op(
            self.hardware,
            inputs_bnn_gradient,
            binary_kernel,
            self.Vt_noise,
            self.data_format,
            self.EN_NOISE,
            self.EN_QUANT)

        if(IS_SE_OUT):
            V_DP = (outputs_bnn_gradient - (1. - 1./self.kernel_lr_multiplier) * K.stop_gradient(outputs_bnn_gradient))\
                    * self.kernel_lr_multiplier
        else:
            V_BL  = (outputs_bnn_gradient[0] - (1. - 1./self.kernel_lr_multiplier) * K.stop_gradient(outputs_bnn_gradient[0]))\
                    * self.kernel_lr_multiplier
            V_BLB = (outputs_bnn_gradient[1] - (1. - 1./self.kernel_lr_multiplier) * K.stop_gradient(outputs_bnn_gradient[1]))\
                    * self.kernel_lr_multiplier

        if self.use_bias:
            if(IS_SE_OUT):
                V_DP = K.bias_add(V_DP,self.bias,data_format=self.data_format)
            else:
                V_BL  = K.bias_add(V_BL,self.bias,data_format=self.data_format)
                V_BLB = K.bias_add(V_BLB,self.bias,data_format=self.data_format)

        if(self.EN_QUANT):
            # Quantify the PA to get the digitized OA
            IAres = self.hardware.sramInfo.IAres
            OAres = self.hardware.sramInfo.OAres
            NB = self.hardware.sramInfo.NB.data
            PAmax = (2**IAres-1)*NB
            DRval = self.hardware.sramInfo.DR.data
            VDD = self.hardware.sramInfo.VDD.data
            # Return digitized output
            return quant_uni(V_DP,PAmax,DRval,VDD,OAres,0.5*DRval/PAmax,archType)
        elif(archType == '8T'):
            return V_DP
        else:
            # Return unquantized differential output
            return (V_BL,V_BLB)

    def get_config(self):
        '''Return the Conv2D configuration extended with the binarization settings.

        `sramInfo` is not included in the returned dictionary.
        '''
        config = {'H': self.H,
                  'kernel_lr_multiplier': self.kernel_lr_multiplier,
                  'bias_lr_multiplier': self.bias_lr_multiplier,
                  'EN_NOISE': self.EN_NOISE,
                  'EN_QUANT': self.EN_QUANT}
        base_config = super(BinaryConv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Aliases
# BinaryConvolution2D is a second module-level name bound to the BinaryConv2D class.

BinaryConvolution2D = BinaryConv2D