from keras.models import Sequential, Model
from keras import regularizers
from keras.layers import Reshape, Activation, Conv2D, Input, MaxPooling2D, BatchNormalization, Flatten, Dense, Lambda, concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.regularizers import l2
import numpy as np

from layers.custom_regu import Reg_abn_out, Reg_l2_p
#from layers.analog_BN_current_model import Analog_BN
from layers.analog_BN_current_interp_PL import Analog_BN
from layers.binary_layers_IMC import BinaryConv2D,BinaryDense
from layers.quantized_layers_IMC import QuantizedConv2D,QuantizedDense
from layers.quantized_layers_IMC_ABN import QuantizedDenseABN
from layers.quantized_ops import my_quantized_relu as quantize_op
from layers.binary_ops import binary_tanh as binary_tanh_op
from layers.binary_ops import binary_sigmoid as binary_sigmoid_op
from layers.binary_ops import binary_sigmoid_abn, binary_sigmoid_p, binary_tanh, binary_tanh_p
from models.ADC import quant_uni,Quant_train
from models.makeModel import make_model
# Hardware parameters generation
from utils.config_hardware_model import genHardware

from copy import deepcopy


def build_model(cf,model_type,sramInfo,EN_NOISE,EN_QUANT,ABN_INC_ADC):
    """Build a CIM-SRAM network for ``cf.network_type`` via ``make_model``.

    Layer factories (lambdas that create a fresh Keras layer on each call) are
    chosen according to ``cf.network_type`` and handed to ``make_model``, which
    assembles the topology selected by ``model_type``.

    Args:
        cf: Configuration object. Read here: ``network_type``,
            ``kernel_regularizer`` and, for the quantized/binary types,
            ``kernel_lr_multiplier``. Also forwarded to ``make_model``.
        model_type: Topology identifier, forwarded unchanged to ``make_model``.
        sramInfo: Hardware description. Read here: ``IAres``, ``Wres``,
            ``OAres``, ``VDD.data`` and ``arch.name``. Every IMC layer factory
            call receives its own ``deepcopy`` of it.
        EN_NOISE: Flag forwarded to the IMC layers and to ``Analog_BN``.
        EN_QUANT: When truthy, activations use ``quant_relu`` (or the binary
            sigmoid for ``'full-bnn'``) and batch norm is the standard
            ``BatchNormalization``; otherwise ``Quant_train`` and
            ``Analog_BN`` (or a linear pass-through for
            ``'full-qnn-embedded'``) are used.
        ABN_INC_ADC: Forwarded to ``Quant_train``; only used by ``'full-qnn'``.

    Returns:
        Whatever ``make_model`` returns (presumably the Keras model).

    Raises:
        ValueError: If ``cf.network_type`` is not a supported network type.
    """
    # Useful build variables
    IAres = sramInfo.IAres
    Wres = sramInfo.Wres
    OAres = sramInfo.OAres
    # VDD minus two hard-coded offsets (0.108 and 0.04).
    # TODO: incorporate quantization directly inside the IMC layer, with an EN flag.
    dynRange = sramInfo.VDD.data - 0.108 - 0.04

    print('###################################################')
    print('########### BUILDING CIM-SRAM NETWORK #############')
    print('###################################################')

    def binary_sigmoid(x):
        return binary_sigmoid_op(x)

    def quant_relu(x, IAres):
        return quantize_op(x=x, IAres=IAres)

    def quantized_conv_factories():
        # (first-layer, inner-layer) QuantizedConv2D factories shared by the
        # 'qnn', 'full-qnn' and 'full-qnn-embedded' network types.
        first = lambda s, f, i, c, m, k: QuantizedConv2D(kernel_size=(s, s), H=1, m_T_DP=m, nRep=k, nb=Wres, filters=f, strides=(1, 1), padding='same',
                                                         activation='linear', kernel_regularizer=l2(cf.kernel_regularizer),
                                                         kernel_lr_multiplier=cf.kernel_lr_multiplier, input_shape=(i, i, c), use_bias=False,
                                                         sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        inner = lambda s, f, m, k: QuantizedConv2D(kernel_size=(s, s), H=1, m_T_DP=m, nRep=k, nb=Wres, filters=f, strides=(1, 1), padding='same',
                                                   activation='linear', kernel_regularizer=l2(cf.kernel_regularizer),
                                                   kernel_lr_multiplier=cf.kernel_lr_multiplier, use_bias=False,
                                                   sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        return first, inner

    def binary_conv_factories():
        # (first-layer, inner-layer) BinaryConv2D factories shared by the
        # 'bnn' and 'full-bnn' network types.
        first = lambda s, f, i, c: BinaryConv2D(kernel_size=(s, s), H=1, filters=f, strides=(1, 1), padding='same',
                                                activation='linear', kernel_regularizer=l2(cf.kernel_regularizer),
                                                kernel_lr_multiplier=cf.kernel_lr_multiplier, input_shape=(i, i, c), use_bias=False,
                                                sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        inner = lambda s, f: BinaryConv2D(kernel_size=(s, s), H=1, filters=f, strides=(1, 1), padding='same',
                                          activation='linear', kernel_regularizer=l2(cf.kernel_regularizer),
                                          kernel_lr_multiplier=cf.kernel_lr_multiplier, use_bias=False,
                                          sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        return first, inner

    # Full-precision factories, identical for every network type. Conv_FP_ used
    # to exist only for 'full-qnn', so make_model(...) below failed with
    # UnboundLocalError for every other network type.
    Conv_FP_ = lambda s, f, i, c: Conv2D(kernel_size=(s, s), filters=f, strides=(1, 1), padding='same', activation='linear',
                                         kernel_regularizer=l2(cf.kernel_regularizer), input_shape=(i, i, c), use_bias=False)
    Dens_FP = lambda n: Dense(n, use_bias=False)

    network_type = cf.network_type
    if network_type == 'float':
        Conv_ = lambda s, f, i, c: Conv2D(kernel_size=(s, s), filters=f, strides=(1, 1), padding='same', activation='linear',
                                          kernel_regularizer=l2(cf.kernel_regularizer), input_shape=(i, i, c), use_bias=False,
                                          sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        Conv = lambda s, f: Conv2D(kernel_size=(s, s), filters=f, strides=(1, 1), padding='same', activation='linear',
                                   kernel_regularizer=l2(cf.kernel_regularizer), use_bias=False,
                                   sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        Act = lambda: LeakyReLU()

        Quant = lambda n: Activation(lambda x: quant_uni(x, maxVal=n, dynRange=dynRange, OAres=OAres, offset=0.5*dynRange/n))

        Dens = lambda n: Dense(n, use_bias=False, sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        Dens_ = lambda n, i, c: Dense(n, use_bias=False, activation='linear', input_shape=(i*i*c,),
                                      sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
    elif network_type == 'qnn':
        Conv_, Conv = quantized_conv_factories()
        if EN_QUANT:
            Act = lambda: Activation(lambda x: quant_relu(x, IAres=IAres))
        else:
            # Alternative: Activation(lambda x: binary_sigmoid_abn(x, sramInfo.VDD.data))
            Act = lambda: Quant_train(sramInfo)

        Quant = lambda p: Activation('linear')

        Dens = lambda n, m: QuantizedDense(n, nb=Wres, m_T_DP=m, use_bias=False, activation='linear',
                                           sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        Dens_ = lambda n, i, c, m: QuantizedDense(n, nb=Wres, m_T_DP=m, use_bias=False, activation='linear', input_shape=(i*i*c,),
                                                  sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
    elif network_type == 'full-qnn':
        Conv_, Conv = quantized_conv_factories()
        if EN_QUANT:
            Act = lambda: Activation(lambda x: quant_relu(x, IAres=IAres))
        else:
            # Alternative: Activation(lambda x: binary_sigmoid_abn(x, sramInfo.VDD.data))
            Act = lambda: Quant_train(sramInfo, ABN_INC_ADC)

        # Previously used offset=0.5*dynRange/n instead of 0.
        Quant = lambda n: Activation(lambda x: quant_uni(x, maxVal=n, dynRange=dynRange, VDD=sramInfo.VDD.data, OAres=OAres,
                                                         offset=0., archType=sramInfo.arch.name))

        Dens = lambda n, m: QuantizedDense(n, nb=Wres, m_T_DP=m, use_bias=False, activation='linear',
                                           sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        Dens_ = lambda n, i, c, m: QuantizedDense(n, nb=Wres, m_T_DP=m, use_bias=False, activation='linear', input_shape=(i*i*c,),
                                                  sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
    elif network_type == 'full-qnn-embedded':
        Conv_, Conv = quantized_conv_factories()
        if EN_QUANT:
            Act = lambda: Activation(lambda x: quant_relu(x, IAres=IAres))
        else:
            # Alternative: Activation(lambda x: binary_sigmoid_abn(x, sramInfo.VDD.data))
            Act = lambda: Quant_train(sramInfo)

        Quant = lambda n: Activation(lambda x: quant_uni(x, maxVal=n, dynRange=dynRange, VDD=sramInfo.VDD.data, OAres=OAres,
                                                         offset=0., archType=sramInfo.arch.name))

        # Dense layers with the ABN folded in (QuantizedDenseABN, m_sigma=4).
        Dens = lambda n, m: QuantizedDenseABN(n, nb=Wres, m_T_DP=m, use_bias=False, activation='linear',
                                              sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT, m_sigma=4)
        Dens_ = lambda n, i, c, m: QuantizedDenseABN(n, nb=Wres, m_T_DP=m, use_bias=False, activation='linear', input_shape=(i*i*c,),
                                                     sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT, m_sigma=4)
    elif network_type == 'bnn':
        Conv_, Conv = binary_conv_factories()
        Act = lambda: LeakyReLU()

        Quant = lambda p: Activation('linear')

        Dens = lambda n: BinaryDense(n, use_bias=False, activation='linear',
                                     sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        Dens_ = lambda n, i, c: BinaryDense(n, use_bias=False, activation='linear', input_shape=(i*i*c,),
                                            sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
    elif network_type == 'full-bnn':
        Conv_, Conv = binary_conv_factories()
        if EN_QUANT:
            Act = lambda: Activation(lambda x: binary_sigmoid(x))
        else:
            # Alternative: Activation(lambda x: binary_sigmoid_abn(x, sramInfo.VDD.data))
            Act = lambda: Quant_train(sramInfo)

        Quant = lambda n: Activation(lambda x: quant_uni(x, maxVal=n, dynRange=dynRange, VDD=sramInfo.VDD.data, OAres=OAres,
                                                         offset=0., archType=sramInfo.arch.name))

        Dens = lambda n: BinaryDense(n, use_bias=False, activation='linear',
                                     sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
        Dens_ = lambda n, i, c: BinaryDense(n, use_bias=False, activation='linear', input_shape=(i*i*c,),
                                            sramInfo=deepcopy(sramInfo), EN_NOISE=EN_NOISE, EN_QUANT=EN_QUANT)
    else:
        # Fail fast: previously this only printed and then crashed later on an
        # undefined factory name.
        raise ValueError(
            'wrong network type %r, the supported network types in this repo are '
            'float, qnn, full-qnn, full-qnn-embedded, bnn and full-bnn' % (network_type,))

    # NOTE(review): the EN_QUANT factory takes no arguments while the analog
    # ones take (n, m); make_model must call BatchNorm accordingly.
    if EN_QUANT:
        BatchNorm = lambda: BatchNormalization(momentum=0.1, epsilon=1e-5)
    elif network_type == 'full-qnn-embedded':
        # The ABN is already part of QuantizedDenseABN, so batch norm is a no-op here.
        BatchNorm = lambda n, m: Activation('linear')
    else:
        BatchNorm = lambda n, m: Analog_BN(momentum=0.1, epsilon=1e-5, renorm=True, hardware=genHardware(sramInfo),
                                           NB=n, m_sigma=m, EN_NOISE=EN_NOISE
                                           # Optional extras tried previously:
                                           # center=False, scale=False,
                                           # gamma_regularizer=l2(0.001), beta_regularizer=l2(0.001)
                                           # activity_regularizer=Reg_abn_out(1e-5, sramInfo.VDD.data)
                                           # activity_regularizer=Reg_l2_p(0., 0.5)
                                           )

    BatchNorm_FP = lambda: BatchNormalization(momentum=0.1, epsilon=1e-5)

    model = make_model(model_type, cf, Conv_, Conv, Dens_, Dens, Act, Quant, BatchNorm, Dens_FP, BatchNorm_FP, Conv_FP_)
    return model


def load_weights(model, weight_reader):
    """Load conv (and trailing batch-norm) weights from a flat weight stream.

    The reader is rewound, then for every layer whose name contains ``'conv'``:

    1. If the *next* layer's name contains ``'batch'``, read four vectors in
       the order beta, gamma, mean, var, each the size of that layer's first
       weight. Store them as ``[gamma, beta, mean, var]``, the order Keras'
       ``BatchNormalization.get_weights()`` uses when center/scale are on.
    2. Read the conv bias (only if the layer has one), then the kernel.

    NOTE(review): this read order appears to follow the Darknet/YOLO weight
    file layout; confirm against the file producer.

    Args:
        model: Object exposing ``layers``; each layer has ``name``,
            ``get_weights()`` and ``set_weights()``.
        weight_reader: Exposes ``reset()`` and ``read_bytes(n)``. The latter
            presumably returns the next ``n`` values as a flat numpy array,
            since ``reshape`` is called on it.

    Returns:
        The same ``model``, with weights set in place.
    """
    weight_reader.reset()

    layers = model.layers
    for i, conv_layer in enumerate(layers):
        if 'conv' not in conv_layer.name:
            continue

        # Guard the look-ahead: a conv layer may be the last layer of the model.
        if i + 1 < len(layers) and 'batch' in layers[i + 1].name:
            norm_layer = layers[i + 1]
            size = np.prod(norm_layer.get_weights()[0].shape)

            beta = weight_reader.read_bytes(size)
            gamma = weight_reader.read_bytes(size)
            mean = weight_reader.read_bytes(size)
            var = weight_reader.read_bytes(size)

            norm_layer.set_weights([gamma, beta, mean, var])

        conv_weights = conv_layer.get_weights()
        kernel_shape = conv_weights[0].shape
        has_bias = len(conv_weights) > 1

        # The bias precedes the kernel in the stream.
        if has_bias:
            bias = weight_reader.read_bytes(np.prod(conv_weights[1].shape))
        kernel = weight_reader.read_bytes(np.prod(kernel_shape))
        # The kernel is read in the reverse of the Keras shape, then permuted
        # back with axes [2, 3, 1, 0]. NOTE(review): the result matches
        # kernel_shape only when the two spatial dims are equal (square kernels).
        kernel = kernel.reshape(list(reversed(kernel_shape))).transpose([2, 3, 1, 0])

        conv_layer.set_weights([kernel, bias] if has_bias else [kernel])
    return model