#####################################################################################
################ Write C header file dedicated to the CNN params ####################
#####################################################################################
import sys, os
import math

import numpy as np

def create_C_header(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VEC,BETA_FP_VEC,GAMMA_FP_VEC):
  """Write a C header file declaring the CIM-QNN network parameters.

  The header holds ``#define`` s for the CIM array size and initialized
  global variables: image size, per-CIM-layer channels (and their truncated
  log2, exact for powers of two), precisions, timing configuration, layer
  counts, CIM gains, and FP gains/offsets (as 2-D hex arrays, one row per FP
  layer).

  Args:
    filename: Path of the header file; it is created or overwritten.
    network_info: Indexable ``(Nlayers_cim, Nlayers_fp, Nimg)``.
    cim_dim: Indexable ``(N_ROWS, N_COLS)``.
    D_VEC: Indexable ``(H_IMG, W_IMG, C_IN, C_OUT)``; ``C_IN`` and ``C_OUT``
      are per-layer channel counts of which the first ``Nlayers_cim`` are
      written.
    P_VEC: Indexable ``(R_IN, R_W, R_OUT, R_BETA, R_GAMMA)``.
    TIME_CONF: Indexable ``(T_DP, T_PRE, T_MBIT, T_ADC)``.
    GAMMA_VEC: CIM gains; the first ``Nlayers_cim`` entries are written.
    BETA_FP_VEC: FP offsets, reshaped to ``(Nlayers_fp, -1)``.
    GAMMA_FP_VEC: FP gains, reshaped to ``(Nlayers_fp, -1)``.

  Raises:
    IndexError: If ``C_IN``, ``C_OUT`` or ``GAMMA_VEC`` has fewer than
      ``Nlayers_cim`` entries.
    ValueError: If an FP vector cannot be split into ``Nlayers_fp`` rows.
  """
  # // Retrieve variables //
  Nlayers_cim, Nlayers_fp, Nimg = network_info[0], network_info[1], network_info[2]
  N_ROWS, N_COLS = cim_dim[0], cim_dim[1]
  H_IMG, W_IMG, C_IN, C_OUT = D_VEC[0], D_VEC[1], D_VEC[2], D_VEC[3]
  R_IN, R_W, R_OUT, R_BETA, R_GAMMA = P_VEC[0], P_VEC[1], P_VEC[2], P_VEC[3], P_VEC[4]
  T_DP, T_PRE, T_MBIT, T_ADC = TIME_CONF[0], TIME_CONF[1], TIME_CONF[2], TIME_CONF[3]

  # // Reshape FP parameters: one row per FP layer
  GAMMA_FP_VEC = np.reshape(GAMMA_FP_VEC, (Nlayers_fp, -1))
  BETA_FP_VEC = np.reshape(BETA_FP_VEC, (Nlayers_fp, -1))
  Ngamma_fp = np.shape(GAMMA_FP_VEC)[-1]
  Nbeta_fp = np.shape(BETA_FP_VEC)[-1]

  # The per-layer C arrays are declared with Nlayers_cim entries, so emit
  # exactly that many initializers.
  c_in = [C_IN[i] for i in range(Nlayers_cim)]
  c_out = [C_OUT[i] for i in range(Nlayers_cim)]
  gamma = [GAMMA_VEC[i] for i in range(Nlayers_cim)]

  def c_list(values):
    """Return a comma-separated initializer body without spaces, e.g. ``4,2``."""
    return ','.join(f'{v}' for v in values)

  def c_log2_list(values):
    """Return ``c_list`` of the truncated log2 of each value."""
    return c_list(int(math.log2(v)) for v in values)

  def hex_rows(matrix):
    """Return a 2-D initializer body: one ``{0x..,0x..}`` row per line."""
    last = len(matrix) - 1
    return ''.join('{' + ','.join(hex(v) for v in row) + ('}\n' if r == last else '},\n')
                   for r, row in enumerate(matrix))

  # // Write header file // (the with-block closes the file even on error)
  with open(filename, 'w') as fileID:
    # Header
    fileID.write('/*\n')
    fileID.write(' *-----------------------------------\n')
    fileID.write(' * Header file for CIM-QNN parameters\n')
    fileID.write(' *-----------------------------------\n')
    fileID.write('*/\n')
    fileID.write('\n')
    # Pre-processor statements
    fileID.write(f'#define N_ROWS {N_ROWS}\n')
    fileID.write(f'#define N_COLS {N_COLS}\n')
    fileID.write('\n')

    # Input img size
    fileID.write('// Input img size\n')
    fileID.write(f'uint8_t H_IMG = {H_IMG};\n')
    fileID.write(f'uint8_t W_IMG = {W_IMG};\n')

    # Layers & channels
    fileID.write('// Networks channels\n')
    fileID.write(f'uint16_t C_IN[{Nlayers_cim}] = {{{c_list(c_in)}}};\n')
    fileID.write(f'uint16_t C_OUT[{Nlayers_cim}] = {{{c_list(c_out)}}};\n')
    fileID.write(f'uint8_t C_IN_LOG[{Nlayers_cim}] = {{{c_log2_list(c_in)}}};\n')
    fileID.write(f'uint8_t C_OUT_LOG[{Nlayers_cim}] = {{{c_log2_list(c_out)}}};\n')

    # Precision
    fileID.write('// Computing precision\n')
    fileID.write(f'uint8_t R_IN  = {R_IN}; uint8_t R_IN_LOG  = {int(math.log2(R_IN))};\n')
    fileID.write(f'uint8_t R_W   = {R_W}; uint8_t R_W_LOG   = {int(math.log2(R_W))};\n')
    fileID.write(f'uint8_t R_OUT = {R_OUT}; uint8_t R_OUT_LOG = {int(math.log2(R_OUT))};\n')
    fileID.write(f'uint8_t R_BETA  = {R_BETA};\n')
    fileID.write(f'uint8_t R_GAMMA = {R_GAMMA};\n')
    fileID.write('\n')

    # Timing configs
    fileID.write('// Timing configuration\n')
    fileID.write(f'uint8_t T_DP_CONF  = {T_DP};\n')
    fileID.write(f'uint8_t T_PRE_CONF = {T_PRE};\n')
    fileID.write(f'uint8_t T_MBIT_CONF = {T_MBIT};\n')
    fileID.write(f'uint8_t T_ADC_CONF  = {T_ADC};\n')
    fileID.write('\n')

    # Number of samples and layers
    fileID.write(f'uint8_t Nimg = {Nimg};\n')
    fileID.write(f'uint8_t Nlayers_cim = {Nlayers_cim};\n')
    fileID.write(f'uint8_t Nlayers_fp = {Nlayers_fp};\n')
    fileID.write('\n')

    # ABN CIM gains (from GAMMA_VEC, not the FP gains)
    fileID.write('// ABN CIM gain \n')
    fileID.write(f'uint8_t GAMMA[{Nlayers_cim}] = {{{c_list(gamma)}}};\n')
    fileID.write('\n')

    # ABN FP params: gains then offsets, one row per FP layer
    fileID.write('// ABN FP parameters\n')
    fileID.write(f'uint32_t GAMMA_FP[{Nlayers_fp}][{Ngamma_fp}] = {{')
    fileID.write(hex_rows(GAMMA_FP_VEC))
    fileID.write('};\n')
    fileID.write('\n')
    fileID.write(f'uint32_t BETA_FP[{Nlayers_fp}][{Nbeta_fp}] = {{\n')
    fileID.write(hex_rows(BETA_FP_VEC))
    fileID.write('};\n')
    fileID.write('\n')

  return
  
  
def create_C_header_subset(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VEC,BETA_FP_VEC,GAMMA_FP_VEC,data_cim,START_LAYER):
  """Write a C header with CIM-QNN parameters and data for CIM layers START_LAYER onward.

  Besides the configuration (CIM array size, image size, channels,
  precisions, timings, layer counts), the header embeds the input data, the
  weights/gains/packed offsets of CIM layers ``START_LAYER .. Nlayers_cim-1``,
  the FP weights/gains/offsets and the reference inference results.

  Args:
    filename: Path of the header file; it is created or overwritten.
    network_info: Indexable ``(Nlayers_cim, Nlayers_fp, Nimg)``.
    cim_dim: Indexable ``(N_ROWS, N_COLS)``.
    D_VEC: Indexable ``(H_IMG, W_IMG, C_IN, C_OUT)`` with per-layer channel
      counts (at least ``Nlayers_cim`` entries each).
    P_VEC: Indexable ``(R_IN, R_W, R_OUT, R_BETA, R_GAMMA)``.
    TIME_CONF: Indexable ``(T_DP, T_PRE, T_MBIT, T_ADC)``.
    GAMMA_VEC: Per-CIM-layer gains.
    BETA_FP_VEC: FP offsets, reshaped to ``(Nlayers_fp, -1)``.
    GAMMA_FP_VEC: FP gains, reshaped to ``(Nlayers_fp, -1)``.
    data_cim: Indexable whose items ``[1]``..``[4]`` are the per-layer CIM
      weights, per-layer CIM offsets, per-FP-layer weights and per-image
      inference results; ``data_cim[-1][START_LAYER-1]`` is the input data.
    START_LAYER: Index of the first CIM layer to include.

  Raises:
    IndexError: If a per-layer sequence is shorter than the layer counts.
    ValueError: If input data or FP vectors cannot be reshaped to
      ``Nimg`` / ``Nlayers_fp`` rows.
  """
  # // Retrieve variables //
  Nlayers_cim, Nlayers_fp, Nimg = network_info[0], network_info[1], network_info[2]
  N_ROWS, N_COLS = cim_dim[0], cim_dim[1]
  H_IMG, W_IMG, C_IN, C_OUT = D_VEC[0], D_VEC[1], D_VEC[2], D_VEC[3]
  R_IN, R_W, R_OUT, R_BETA, R_GAMMA = P_VEC[0], P_VEC[1], P_VEC[2], P_VEC[3], P_VEC[4]
  T_DP, T_PRE, T_MBIT, T_ADC = TIME_CONF[0], TIME_CONF[1], TIME_CONF[2], TIME_CONF[3]
  # CIM data, starting at the chosen layer index
  # NOTE(review): with START_LAYER == 0 the index START_LAYER-1 wraps to the
  # last entry of data_cim[-1]; confirm that is the intended input source.
  data_in = data_cim[-1][START_LAYER-1]
  data_w = data_cim[1]
  data_b = data_cim[2]
  data_w_fp = data_cim[3]
  data_inf = data_cim[4]

  # One row per image
  data_in = np.reshape(data_in, (Nimg, -1))
  img_size = np.shape(data_in)[1]

  # Pack the CIM offsets: bit-plane k of every channel (each channel repeated
  # over R_W columns, zero-padded to 32) becomes one 32-bit word; the first
  # R_BETA bit-planes are kept. Assumes data_b[i] is 1-D with C_OUT[i]
  # entries (np.pad below is 2-D) -- TODO confirm.
  # Explicit int64 keeps 2**31 from overflowing where the default int is 32-bit.
  bit_weights = 2 ** np.arange(32, dtype=np.int64)
  packed_b = []
  for i in range(Nlayers_cim):
    bits = np.unpackbits(np.expand_dims(data_b[i].astype("uint8"), axis=-1), axis=-1)
    bits = np.flip(bits, axis=-1)           # LSB first along the last axis
    bits = np.swapaxes(bits, 0, 1)          # one row per bit-plane
    bits = np.repeat(bits, R_W, axis=-1)    # repeat beta values in R_W cols
    if R_W*C_OUT[i] < 32:
      bits = np.pad(bits, ((0, 0), (0, 32 - R_W*C_OUT[i])))
    packed_b.append(np.dot(np.reshape(bits[:R_BETA, ...], (-1, 32)), bit_weights))
  data_b = packed_b

  # // Reshape FP parameters: one row per FP layer
  GAMMA_FP_VEC = np.reshape(GAMMA_FP_VEC, (Nlayers_fp, -1))
  BETA_FP_VEC = np.reshape(BETA_FP_VEC, (Nlayers_fp, -1))
  Ngamma_fp = np.shape(GAMMA_FP_VEC)[-1]
  Nbeta_fp = np.shape(BETA_FP_VEC)[-1]

  # Layers actually emitted; every per-layer C array is sized n_cim, so
  # exactly that many initializers are written.
  cim_layers = range(START_LAYER, Nlayers_cim)
  n_cim = Nlayers_cim - START_LAYER
  c_in = [C_IN[i] for i in cim_layers]
  c_out = [C_OUT[i] for i in cim_layers]
  # NOTE(review): FP layer i reads channel entry Nlayers_cim-1+i (the last CIM
  # layer for i == 0); verify against how D_VEC is built.
  c_in_fp = [C_IN[Nlayers_cim-1+i] for i in range(Nlayers_fp)]
  c_out_fp = [C_OUT[Nlayers_cim-1+i] for i in range(Nlayers_fp)]
  w_rows = [data_w[i] for i in cim_layers]
  b_rows = [data_b[i] for i in cim_layers]
  w_fp_rows = [data_w_fp[i] for i in range(Nlayers_fp)]
  # Row width = largest emitted layer, so no layer is silently truncated;
  # smaller layers are zero-padded.
  max_w = max((np.size(row) for row in w_rows), default=0)
  max_b = max((np.size(row) for row in b_rows), default=0)
  max_w_fp = max((np.size(row) for row in w_fp_rows), default=0)

  def c_list(values):
    """Return a comma-separated initializer body without spaces, e.g. ``4,2``."""
    return ','.join(f'{v}' for v in values)

  def c_log2_list(values):
    """Return ``c_list`` of the truncated log2 of each value."""
    return c_list(int(math.log2(v)) for v in values)

  def table_rows(rows, width):
    """Return a 2-D initializer body with one element per line.

    Each row is ``{`` + ``width`` comma-separated entries + ``}``; entries
    past the end of a shorter row are written as ``0x0``.
    """
    last_row = len(rows) - 1
    text = []
    for r, row in enumerate(rows):
      flat = np.ravel(row)
      text.append('{')
      for j in range(width):
        entry = f'{flat[j]}' if j < flat.size else '0x0'
        text.append(entry + ('\n' if j == width - 1 else ',\n'))
      text.append('}\n' if r == last_row else '},\n')
    return ''.join(text)

  def hex_rows(matrix):
    """Return a 2-D initializer body: one ``{0x..,0x..}`` row per line."""
    last = len(matrix) - 1
    return ''.join('{' + ','.join(hex(v) for v in row) + ('}\n' if r == last else '},\n')
                   for r, row in enumerate(matrix))

  # // Write header file // (the with-block closes the file even on error)
  with open(filename, 'w') as fileID:
    # Header
    fileID.write('/*\n')
    fileID.write(' *-----------------------------------\n')
    fileID.write(' * Header file for CIM-QNN parameters\n')
    fileID.write(' *-----------------------------------\n')
    fileID.write('*/\n')
    fileID.write('\n')
    # Pre-processor statements
    fileID.write(f'#define N_ROWS {N_ROWS}\n')
    fileID.write(f'#define N_COLS {N_COLS}\n')
    fileID.write('\n')

    # Input img size
    fileID.write('// Input img size\n')
    fileID.write(f'uint8_t H_IMG = {H_IMG};\n')
    fileID.write(f'uint8_t W_IMG = {W_IMG};\n')

    # Layers & channels
    fileID.write('// Networks channels\n')
    fileID.write(f'uint16_t C_IN[{n_cim}] = {{{c_list(c_in)}}};\n')
    fileID.write(f'uint16_t C_OUT[{n_cim}] = {{{c_list(c_out)}}};\n')
    fileID.write(f'uint8_t C_IN_LOG[{n_cim}] = {{{c_log2_list(c_in)}}};\n')
    fileID.write(f'uint8_t C_OUT_LOG[{n_cim}] = {{{c_log2_list(c_out)}}};\n')
    fileID.write('// FP channels \n')
    fileID.write(f'uint16_t C_IN_FP[{Nlayers_fp}] = {{{c_list(c_in_fp)}}};\n')
    fileID.write(f'uint16_t C_OUT_FP[{Nlayers_fp}] = {{{c_list(c_out_fp)}}};\n')

    # Precision
    fileID.write('// Computing precision\n')
    fileID.write(f'uint8_t R_IN  = {R_IN}; uint8_t R_IN_LOG  = {int(math.log2(R_IN))};\n')
    fileID.write(f'uint8_t R_W   = {R_W}; uint8_t R_W_LOG   = {int(math.log2(R_W))};\n')
    fileID.write(f'uint8_t R_OUT = {R_OUT}; uint8_t R_OUT_LOG = {int(math.log2(R_OUT))};\n')
    fileID.write(f'uint8_t R_BETA  = {R_BETA};\n')
    fileID.write(f'uint8_t R_GAMMA = {R_GAMMA};\n')
    fileID.write('\n')

    # Timing configs
    fileID.write('// Timing configuration\n')
    fileID.write(f'uint8_t T_DP_CONF  = {T_DP};\n')
    fileID.write(f'uint8_t T_PRE_CONF = {T_PRE};\n')
    fileID.write(f'uint8_t T_MBIT_IN_CONF = {T_MBIT};\n')
    fileID.write(f'uint8_t T_MBIT_W_CONF = {T_MBIT};\n')
    fileID.write(f'uint8_t T_ADC_CONF  = {T_ADC};\n')
    fileID.write(f'uint8_t T_REF_CONF  = {T_ADC};\n')
    fileID.write('\n')

    # Number of samples and layers
    fileID.write(f'uint8_t Nimg = {Nimg};\n')
    fileID.write(f'uint8_t Nlayers_cim = {n_cim};\n')
    fileID.write(f'uint8_t Nlayers_fp = {Nlayers_fp};\n')
    fileID.write('\n')

    # Inputs
    fileID.write('// Input data \n')
    fileID.write(f'uint32_t DATA_IN[{Nimg}][{img_size}] = {{')
    fileID.write(table_rows(data_in, img_size))
    fileID.write('};\n')
    fileID.write('\n')

    # Weights
    fileID.write('// Weight data \n')
    fileID.write(f'uint32_t W_CIM[{n_cim}][{max_w}] = {{')
    fileID.write(table_rows(w_rows, max_w))
    fileID.write('};\n')
    fileID.write('\n')

    # ABN CIM gains
    fileID.write('// ABN CIM gain \n')
    fileID.write(f'uint8_t GAMMA[{n_cim}] = {{{c_list(GAMMA_VEC[i] for i in cim_layers)}}};\n')
    fileID.write('\n')

    # ABN CIM offsets (packed above)
    fileID.write('// ABN CIM offset \n')
    fileID.write(f'uint32_t B_CIM[{n_cim}][{max_b}] = {{')
    fileID.write(table_rows(b_rows, max_b))
    fileID.write('};\n')
    fileID.write('\n')

    # FP weights
    fileID.write('// FP weights \n')
    fileID.write(f'uint32_t W_FP[{Nlayers_fp}][{max_w_fp}] = {{')
    fileID.write(table_rows(w_fp_rows, max_w_fp))
    fileID.write('};\n')
    fileID.write('\n')

    # ABN FP params: gains (own column count) then offsets
    fileID.write('// ABN FP parameters\n')
    fileID.write(f'uint32_t GAMMA_FP[{Nlayers_fp}][{Ngamma_fp}] = {{')
    fileID.write(hex_rows(GAMMA_FP_VEC))
    fileID.write('};\n')
    fileID.write('\n')
    fileID.write(f'uint32_t BETA_FP[{Nlayers_fp}][{Nbeta_fp}] = {{\n')
    fileID.write(hex_rows(BETA_FP_VEC))
    fileID.write('};\n')
    fileID.write('\n')

    # Reference inference results, one per image
    fileID.write('// Inference results \n')
    fileID.write(f'uint8_t inf_result[{Nimg}] = {{{c_list(data_inf[i] for i in range(Nimg))}}};\n')
    fileID.write('\n')

  return