Skip to content
Extraits de code Groupes Projets
create_C_header.py 4,68 ko
Newer Older
  • Learn to ignore specific revisions
    #####################################################################################
    ################ Write C header file dedicated to the CNN params ####################
    #####################################################################################
    import sys, os
    import math
    
    import numpy as np
    
    def _c_list(values, fmt=format):
        """Render each item of *values* with *fmt* and join them with bare commas.

        Produces the body of a C array initializer, e.g. ``[4, 8]`` -> ``"4,8"``.
        """
        return ",".join(fmt(v) for v in values)


    def create_C_header(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VEC,BETA_FP_VEC,GAMMA_FP_VEC):
        """Write a C header file holding the CIM-QNN / CNN parameters.

        Parameters
        ----------
        filename : str or path-like
            Destination of the generated header; truncated/overwritten if it exists.
        network_info : sequence
            ``(Nlayers_cim, Nlayers_fp, Nimg)``: number of CIM layers, number of
            FP layers and number of samples.
        cim_dim : sequence
            ``(N_ROWS, N_COLS)`` of the CIM array, emitted as ``#define``s.
        D_VEC : sequence
            ``(H_IMG, W_IMG, C_IN, C_OUT)``. ``C_IN`` / ``C_OUT`` are per-layer
            channel counts; their log2 (truncated to int) is emitted too, so they
            are presumably powers of two -- TODO confirm with callers.
        P_VEC : sequence
            ``(R_IN, R_W, R_OUT, R_BETA, R_GAMMA)`` precisions; truncated log2 of
            the first three is emitted as well.
        TIME_CONF : sequence
            ``(T_DP, T_PRE, T_MBIT, T_ADC)`` timing configuration values.
        GAMMA_VEC : sequence
            CIM gains; the first ``Nlayers_cim`` entries fill ``GAMMA``.
        BETA_FP_VEC : array-like
            FP offsets, reshaped to ``(Nlayers_fp, -1)`` and emitted in hex as the
            2-D array ``BETA_FP`` (entries must be integers for ``hex``).
        GAMMA_FP_VEC : sequence
            FP gains; the first ``Nlayers_fp`` entries fill ``GAMMA_FP``.

        Returns
        -------
        None
        """
        # // Retrieve variables //
        # CNN network info
        Nlayers_cim = network_info[0]
        Nlayers_fp = network_info[1]
        Nimg = network_info[2]
        # CIM dims
        N_ROWS = cim_dim[0]
        N_COLS = cim_dim[1]
        # Image size and per-layer channels
        H_IMG = D_VEC[0]
        W_IMG = D_VEC[1]
        C_IN = D_VEC[2]
        C_OUT = D_VEC[3]
        # Precisions
        R_IN = P_VEC[0]
        R_W = P_VEC[1]
        R_OUT = P_VEC[2]
        R_BETA = P_VEC[3]
        R_GAMMA = P_VEC[4]
        # Timings
        T_DP = TIME_CONF[0]
        T_PRE = TIME_CONF[1]
        T_MBIT = TIME_CONF[2]
        T_ADC = TIME_CONF[3]

        # One row of FP offsets per FP layer.
        BETA_FP_VEC = np.reshape(BETA_FP_VEC, (Nlayers_fp, -1))
        Nbeta_fp = np.shape(BETA_FP_VEC)[-1]

        def log2_int(x):
            # Integer (truncated) log2 of x, rendered as text.
            return format(int(math.log2(x)))

        # Array initializer bodies.
        c_in = _c_list(C_IN)
        c_out = _c_list(C_OUT)
        c_in_log = _c_list(C_IN, log2_int)
        c_out_log = _c_list(C_OUT, log2_int)
        # GAMMA holds the CIM gains (GAMMA_VEC); GAMMA_FP holds the FP gains.
        gamma_cim = _c_list(GAMMA_VEC[i] for i in range(Nlayers_cim))
        gamma_fp = _c_list(GAMMA_FP_VEC[i] for i in range(Nlayers_fp))
        beta_rows = [f'{{{_c_list(row, hex)}}}' for row in BETA_FP_VEC]

        out = [
            # Banner comment
            '/*\n',
            ' *-----------------------------------\n',
            ' * Header file for CIM-QNN parameters\n',
            ' *-----------------------------------\n',
            '*/\n',
            '\n',
            # Pre-processor statements
            f'#define N_ROWS {N_ROWS}\n',
            f'#define N_COLS {N_COLS}\n',
            '\n',
            # Input img size
            '// Input img size\n',
            f'uint8_t H_IMG = {H_IMG};\n',
            f'uint8_t W_IMG = {W_IMG};\n',
            '\n',
            # Layers & channels; each initializer needs a closing ';' to be valid C
            '// Networks channels\n',
            f'uint16_t C_IN[{Nlayers_cim}] = {{{c_in}}};\n',
            f'uint16_t C_OUT[{Nlayers_cim}] = {{{c_out}}};\n',
            f'uint8_t C_IN_LOG[{Nlayers_cim}] = {{{c_in_log}}};\n',
            f'uint8_t C_OUT_LOG[{Nlayers_cim}] = {{{c_out_log}}};\n',
            '\n',
            # Precision
            '// Computing precision\n',
            f'uint8_t R_IN  = {R_IN}; uint8_t R_IN_LOG  = {int(math.log2(R_IN))};\n',
            f'uint8_t R_W   = {R_W}; uint8_t R_W_LOG   = {int(math.log2(R_W))};\n',
            f'uint8_t R_OUT = {R_OUT}; uint8_t R_OUT_LOG = {int(math.log2(R_OUT))};\n',
            f'uint8_t R_BETA  = {R_BETA};\n',
            f'uint8_t R_GAMMA = {R_GAMMA};\n',
            '\n',
            # Timing configs
            '// Timing configuration\n',
            f'uint8_t T_DP_CONF  = {T_DP};\n',
            f'uint8_t T_PRE_CONF = {T_PRE};\n',
            f'uint8_t T_MBIT_CONF = {T_MBIT};\n',
            f'uint8_t T_ADC_CONF  = {T_ADC};\n',
            '\n',
            # Number of samples
            f'uint8_t Nimg = {Nimg};\n',
            '\n',
            # ABN CIM gains
            '// ABN CIM gain \n',
            f'uint8_t GAMMA[{Nlayers_cim}] = {{{gamma_cim}}};\n',
            '\n',
            # ABN FP gains
            '// ABN FP parameters\n',
            f'uint32_t GAMMA_FP[{Nlayers_fp}] = {{{gamma_fp}}};\n',
            '\n',
            # ABN FP offsets (hex, one row per FP layer)
            f'uint32_t BETA_FP[{Nlayers_fp}][{Nbeta_fp}] = {{\n',
        ]
        # Rows are comma-separated; the comma is omitted only after the LAST
        # row (driven by the row count, not by any precision value).
        if beta_rows:
            out.append(',\n'.join(beta_rows) + '\n')
        out.append('};\n')
        out.append('\n')

        # All text is formatted before opening the file, so a failure above
        # leaves any existing header untouched; `with` guarantees the close.
        with open(filename, 'w') as fileID:
            fileID.write(''.join(out))