From eeeb37c1e9306daa0cd90273f14f2ea5d348a03e Mon Sep 17 00:00:00 2001
From: Adrian Kneip <adrian.kneip@uclouvain.be>
Date: Thu, 2 Nov 2023 13:03:19 +0100
Subject: [PATCH] Update models + Algo-to-C toolchain

---
 chip_files/cim_config.h              | 1593 +++++++++++++++++++++++++-
 chip_files/create_C_header.py        |  316 ++++-
 config/config_cim_cnn_param.py       |   43 +-
 config/config_sweep_param.py         |   24 +-
 layers/analog_BN_charge_interp_PL.py |   58 +-
 models/Analog_DP.py                  |    5 +-
 models/MAC_charge.py                 |    2 +-
 models/MBIT_unit.py                  |   18 +-
 models/makeModel.py                  |  182 ++-
 models/model_IMC.py                  |    7 +-
 my_datasets/my_cifar10.py            |    8 +-
 sw_to_chip.py                        |   87 +-
 train_cim_qnn.py                     |  117 +-
 utils/config_hardware_model.py       |   28 +-
 utils/linInterp.py                   |   18 +
 15 files changed, 2327 insertions(+), 179 deletions(-)

diff --git a/chip_files/cim_config.h b/chip_files/cim_config.h
index fffa515..d8aa162 100644
--- a/chip_files/cim_config.h
+++ b/chip_files/cim_config.h
@@ -5,16 +5,19 @@
 */
 
 #define N_ROWS 1152
-#define N_COLS 512
+#define N_COLS 256
 
 // Input img size
 uint8_t H_IMG = 28;
 uint8_t W_IMG = 28;
 // Networks channels
-uint16_t C_IN[2] = {9,72}
-uint16_t C_OUT[2] = {8,16}
-uint8_t C_IN_LOG[2] = {3,6}
-uint8_t C_OUT_LOG[2] = {3,4}
+uint16_t C_IN[1] = {128};
+uint16_t C_OUT[1] = {64};
+uint8_t C_IN_LOG[1] = {7};
+uint8_t C_OUT_LOG[1] = {6};
+// FP channels 
+uint16_t C_IN_FP[1] = {128};
+uint16_t C_OUT_FP[1] = {64};
 // Computing precision
 uint8_t R_IN  = 1; uint8_t R_IN_LOG  = 0;
 uint8_t R_W   = 1; uint8_t R_W_LOG   = 0;
@@ -25,19 +28,1587 @@ uint8_t R_GAMMA = 5;
 // Timing configuration
 uint8_t T_DP_CONF  = 3;
 uint8_t T_PRE_CONF = 3;
-uint8_t T_MBIT_CONF = 3;
+uint8_t T_MBIT_IN_CONF = 3;
+uint8_t T_MBIT_W_CONF = 3;
 uint8_t T_ADC_CONF  = 3;
+uint8_t T_REF_CONF  = 3;
 
 uint8_t Nimg = 128;
+uint8_t Nlayers_cim = 1;
+uint8_t Nlayers_fp = 1;
+
+// Input data 
+uint32_t DATA_IN[128][4] = {{960103632,
+3573579370,
+2714061332,
+42131640
+},
+{244710137,
+700473429,
+93289491,
+1141399995
+},
+{8686712,
+353011961,
+631092483,
+1114375241
+},
+{511198834,
+1894442824,
+711542922,
+638745498
+},
+{2285375286,
+383706058,
+2478895598,
+397756106
+},
+{8389752,
+352979193,
+630568195,
+1114374217
+},
+{2282245430,
+383707851,
+3552768494,
+468010184
+},
+{671616006,
+1442737738,
+2479477732,
+174383752
+},
+{2317045296,
+583721410,
+130675099,
+370172304
+},
+{671616022,
+1442737994,
+2479477732,
+182764232
+},
+{511267440,
+1894441800,
+711542922,
+637696922
+},
+{1113083450,
+384557464,
+131724642,
+2803124186
+},
+{671878454,
+1442770890,
+2479477732,
+720691912
+},
+{511263346,
+1894441800,
+711543434,
+637696922
+},
+{8430712,
+352979193,
+633189635,
+1114374153
+},
+{2517232680,
+887285845,
+1657634378,
+1611166396
+},
+{674827318,
+1442736970,
+2479477740,
+585425608
+},
+{691668176,
+3576201066,
+2714061332,
+42131672
+},
+{512354473,
+885712989,
+59594322,
+1212705340
+},
+{2284326326,
+383708106,
+2478895598,
+334841546
+},
+{673729815,
+1442770762,
+2478916324,
+182772456
+},
+{1113345592,
+116580784,
+132773218,
+2807318418
+},
+{1113083450,
+385606040,
+133723490,
+2274641882
+},
+{2518666416,
+1692592707,
+1120780250,
+1916826548
+},
+{2285374902,
+383708106,
+2478895598,
+328550090
+},
+{511197810,
+1894446920,
+727795850,
+638745498
+},
+{959972434,
+3576332106,
+2748402308,
+44228824
+},
+{2285374902,
+400485322,
+2478895598,
+334841546
+},
+{511267442,
+1894441800,
+711543434,
+637696922
+},
+{8389752,
+352979193,
+633189635,
+1114378249
+},
+{780791017,
+1709206109,
+562942546,
+1212705340
+},
+{8397944,
+352979193,
+666744067,
+1114378264
+},
+{512354536,
+616589917,
+831378258,
+1214806068
+},
+{192434034,
+315549642,
+443146666,
+387205018
+},
+{748372208,
+3842014050,
+2177192469,
+42131676
+},
+{244833008,
+2847957201,
+1171224851,
+1444438427
+},
+{716918864,
+3506995018,
+2714190353,
+42655960
+},
+{8397944,
+355076345,
+633189635,
+1114378249
+},
+{244447344,
+566255697,
+80566546,
+1174954905
+},
+{8422520,
+355076345,
+1740485891,
+1114378312
+},
+{8422512,
+355076345,
+633189635,
+1110179849
+},
+{683279440,
+3506995050,
+2713830149,
+42655896
+},
+{2282229046,
+400485322,
+2478895590,
+469059272
+},
+{177610424,
+197419472,
+97513779,
+387728283
+},
+{109964344,
+633367025,
+95278931,
+1112302360
+},
+{370011193,
+618850389,
+42858330,
+2017489852
+},
+{8408168,
+352979161,
+631092483,
+1248591881
+},
+{43802296,
+197157329,
+98497843,
+387728283
+},
+{2284326198,
+383708106,
+2478895598,
+468010696
+},
+{2285375414,
+400485322,
+2478895598,
+330647242
+},
+{1113345592,
+117105072,
+132756851,
+2270451602
+},
+{243920105,
+618719325,
+562942546,
+1346923068
+},
+{2671381840,
+1961553483,
+583115146,
+306213296
+},
+{2483932218,
+887810655,
+579436226,
+1648386740
+},
+{1179929146,
+921428376,
+133952866,
+2803124122
+},
+{511263346,
+1894507336,
+711583882,
+104888730
+},
+{2285375414,
+383708106,
+2478895598,
+330647242
+},
+{8651888,
+355076329,
+631092483,
+1114374217
+},
+{672664854,
+1442737994,
+2479477732,
+719643352
+},
+{422708562,
+3576364619,
+3805358080,
+304104600
+},
+{1060767058,
+3576332106,
+3805366272,
+33736856
+},
+{177297632,
+365060824,
+365941027,
+439902936
+},
+{671960118,
+1442720738,
+2479477733,
+585425608
+},
+{244181033,
+901933277,
+30373202,
+1208509176
+},
+{680199376,
+1426358122,
+2714096469,
+48816840
+},
+{2284326198,
+383708106,
+2478895598,
+334841544
+},
+{1113083450,
+384557464,
+66155842,
+2266253274
+},
+{2318929846,
+383706058,
+2478895598,
+397756122
+},
+{243919081,
+901999197,
+26202706,
+1212705404
+},
+{528110192,
+1894442312,
+778651786,
+638745498
+},
+{1027212624,
+3576201034,
+3822143488,
+42123448
+},
+{511267442,
+1894441800,
+711542922,
+637696922
+},
+{244701944,
+2847957073,
+97482771,
+1443389851
+},
+{680346723,
+1439592026,
+2211304004,
+183296200
+},
+{8397944,
+355076345,
+666744067,
+1114378249
+},
+{700056912,
+3574103626,
+2748432644,
+310952152
+},
+{243919977,
+901802077,
+562910802,
+1212705340
+},
+{512971376,
+3918034009,
+1171212291,
+1176134043
+},
+{671992886,
+368708299,
+2479486820,
+182772424
+},
+{691668306,
+3574234954,
+3822143492,
+44220632
+},
+{960103506,
+3580559178,
+2747353732,
+42131672
+},
+{1113345592,
+116056496,
+132740555,
+2803136402
+},
+{244800184,
+734290385,
+97480979,
+370697115
+},
+{960103762,
+3576201034,
+3821095556,
+42131672
+},
+{2390952242,
+902720219,
+332198146,
+443113672
+},
+{2284326326,
+383708106,
+2478895598,
+334841544
+},
+{697959632,
+3574103658,
+3787833349,
+42131672
+},
+{378137848,
+887875669,
+596475474,
+1078487356
+},
+{1113345594,
+385606064,
+133789027,
+2803124122
+},
+{8424568,
+353044729,
+631092483,
+1114374217
+},
+{243919081,
+901802077,
+562910802,
+1212705340
+},
+{1113345594,
+385540528,
+133821795,
+2807318418
+},
+{2282229046,
+366914250,
+2479485286,
+334841544
+},
+{243917929,
+903899741,
+562910802,
+1212704828
+},
+{8422522,
+353044731,
+1203614979,
+1114378440
+},
+{2318929714,
+383706058,
+2479419822,
+397756122
+},
+{3192,
+355109081,
+647607555,
+1114378249
+},
+{680133848,
+1426620258,
+2173555509,
+40690248
+},
+{106450680,
+650895760,
+130643219,
+119038874
+},
+{672730134,
+1442737994,
+2613695468,
+719643352
+},
+{1113083450,
+384557464,
+131691874,
+2803124122
+},
+{511263346,
+1894441800,
+711543434,
+637696926
+},
+{3189476592,
+1693118019,
+1657637786,
+574648500
+},
+{2285375422,
+383706058,
+2478895598,
+397756106
+},
+{714688566,
+1439591234,
+47043560,
+577036952
+},
+{672664630,
+1440640842,
+2479477732,
+719643336
+},
+{245062392,
+734290384,
+97472787,
+370697114
+},
+{2155875448,
+336300249,
+831829251,
+1114374217
+},
+{671894582,
+1442720586,
+3553219556,
+719635144
+},
+{2285375286,
+383706058,
+2613113262,
+397756106
+},
+{244402274,
+902523609,
+1406136578,
+443048152
+},
+{679806064,
+1426590435,
+2244842241,
+40427721
+},
+{244443241,
+901999325,
+26169938,
+1214802492
+},
+{672992278,
+1442737994,
+2479477732,
+720691912
+},
+{681182416,
+1158184930,
+2244842293,
+47243976
+},
+{1210584374,
+385805258,
+2481517030,
+334841546
+},
+{2282245430,
+383691466,
+2478895598,
+469059272
+},
+{169348534,
+383708106,
+2478895470,
+334841546
+},
+{675089718,
+1440639946,
+2479477732,
+48554696
+},
+{244666992,
+4043989459,
+1204769027,
+101343642
+},
+{2651572280,
+619375187,
+1624112026,
+1917350068
+},
+{137889590,
+400483274,
+2481582502,
+334841546
+},
+{1027278034,
+3249078122,
+2713832064,
+42133656
+},
+{1113345594,
+385540496,
+133821795,
+2803124122
+},
+{135793974,
+2498423754,
+3553716294,
+65488072
+},
+{671681846,
+1440639946,
+2479477732,
+318005960
+},
+{511267442,
+1894441800,
+711542922,
+637696922
+},
+{2518387824,
+1693118019,
+1657645978,
+1382052276
+}
+};
+
+// Weight data 
+uint32_t W_CIM[1][256] = {{2937347042,
+582711007,
+2812568234,
+3806148925,
+2668995834,
+2946988591,
+2027081042,
+3592476372,
+2422448829,
+3091261544,
+2094285888,
+4700509,
+1248488721,
+1947715281,
+181634821,
+1025697358,
+2459141933,
+4181128815,
+998744400,
+3921943325,
+4095755968,
+2167003098,
+873739106,
+468532042,
+3071780336,
+3990125135,
+1867828498,
+2461529823,
+1864987605,
+2002454937,
+2417329846,
+2594858514,
+2269002158,
+1878853761,
+1883987700,
+493012722,
+392950781,
+2113044658,
+4278727405,
+2150211759,
+1613435404,
+1576798978,
+3420873869,
+3868720383,
+1867019857,
+668721021,
+1221307795,
+2132076228,
+1198242749,
+4294608547,
+1863932784,
+1736313089,
+1807244154,
+2139032455,
+2166296843,
+4287474181,
+1517969061,
+947036919,
+3178116067,
+3266938800,
+1867022295,
+1205722489,
+334859900,
+972453119,
+1371295022,
+439632850,
+3516808877,
+3634274722,
+4011097411,
+2998925023,
+2308212994,
+2525457885,
+1825300306,
+1464981272,
+4007814869,
+3247830365,
+2180361258,
+3626000583,
+2962738010,
+4111759994,
+2940711129,
+3976312869,
+3264098701,
+3806182567,
+1194015663,
+4294576033,
+1090157742,
+1346299298,
+3807061981,
+4044747902,
+134877452,
+1567328800,
+1208621384,
+1533739552,
+1787581037,
+158232831,
+1860919263,
+2138946969,
+4154748426,
+182266363,
+2543289778,
+2880596559,
+4158940795,
+2229230077,
+3213173498,
+2812099391,
+3458875129,
+82496757,
+488461098,
+796596747,
+1008556840,
+259995951,
+2503454098,
+4153757504,
+3163227984,
+3848075902,
+4096285286,
+321717690,
+924825012,
+2146717253,
+3405049099,
+3265260741,
+1673091851,
+2138966663,
+388741295,
+786664609,
+2699794389,
+4250269272,
+1967845781,
+4253060432,
+177786704,
+744896527,
+3049660589,
+3264085294,
+4154748717,
+3893785079,
+2018155348,
+3177367074,
+2635441144,
+2944477998,
+3172735974,
+2276177456,
+2540973741,
+3136184815,
+882167530,
+2482792826,
+3207740134,
+2477434784,
+1961385846,
+421697498,
+1517004309,
+2102587983,
+882148690,
+2207283544,
+2962313044,
+3277549950,
+174043521,
+4152462853,
+280285034,
+1138522566,
+3710529766,
+3672869792,
+597662074,
+2144611137,
+1089198358,
+1918158528,
+1060485986,
+803867521,
+1852571383,
+1998767569,
+1875174419,
+626523485,
+371864237,
+3445364783,
+2430081606,
+1361318658,
+2228691346,
+3540631296,
+4290014056,
+2264742207,
+928221008,
+2109759320,
+803529915,
+2933386743,
+3956551018,
+3806161373,
+1245367871,
+274874578,
+4043022204,
+4185142098,
+3041337516,
+2190326826,
+871867827,
+2109787989,
+998879577,
+777328983,
+2048614244,
+1034565182,
+2306056451,
+3867528704,
+1815632743,
+1476332160,
+2568454708,
+446834210,
+3991305674,
+3265119696,
+3601666732,
+3654846498,
+1027437172,
+767695423,
+4010029218,
+2259363469,
+3905885021,
+4114592830,
+1804430659,
+650084037,
+2133158517,
+746549759,
+2369468847,
+2080082561,
+1782152,
+1350535722,
+2909778423,
+4156131312,
+3000290321,
+2687956575,
+3941316061,
+4246074973,
+3046409938,
+2239785340,
+149663123,
+4078036944,
+974885680,
+1030109727,
+2448935342,
+4074531283,
+966019882,
+50280765,
+2781225199,
+3265133612,
+948896596,
+4251826750,
+284144701,
+3963619624,
+1057480785,
+2906733373,
+4019480930,
+1660903191,
+9899308,
+2067083870,
+4151607036,
+3187294719,
+2436206262,
+446850770,
+1867085783,
+1205723001
+}
+};
 
 // ABN CIM gain 
-uint8_t GAMMA[2] = {61,53};
+uint8_t GAMMA[1] = {6};
+
+// ABN CIM offset 
+uint32_t B_CIM[1][10] = {{1895480764,
+3071804790,
+2076846997,
+3163626902,
+1616540092,
+978912944,
+1912606736,
+873465076,
+1362645008,
+805897392
+}
+};
+
+// FP weights 
+uint32_t W_FP[1][640] = {{144,
+-8,
+-70,
+-48,
+-21,
+-103,
+-24,
+210,
+-6,
+-119,
+-4,
+24,
+-48,
+82,
+-43,
+-95,
+-20,
+-50,
+64,
+85,
+-52,
+-33,
+-18,
+-122,
+-41,
+252,
+-10,
+165,
+-32,
+-87,
+26,
+-17,
+-36,
+-135,
+24,
+67,
+-59,
+3,
+97,
+31,
+-39,
+82,
+120,
+-93,
+-47,
+-28,
+80,
+-32,
+-47,
+1,
+-32,
+-4,
+1,
+-81,
+-57,
+208,
+-31,
+-71,
+-63,
+165,
+-60,
+-32,
+-117,
+-15,
+-91,
+-57,
+375,
+29,
+-1,
+-25,
+-57,
+-34,
+-37,
+-36,
+-97,
+-49,
+-3,
+163,
+-25,
+200,
+-25,
+-50,
+116,
+135,
+-33,
+-118,
+-47,
+-33,
+131,
+-41,
+-48,
+-33,
+-76,
+-101,
+-50,
+268,
+143,
+-24,
+-28,
+-31,
+27,
+332,
+-66,
+-86,
+-29,
+-41,
+-43,
+-86,
+-83,
+-9,
+-31,
+34,
+-66,
+58,
+-28,
+82,
+-29,
+-54,
+79,
+-43,
+0,
+-52,
+5,
+21,
+-97,
+11,
+26,
+24,
+39,
+44,
+98,
+-5,
+-33,
+-46,
+-83,
+-83,
+-40,
+69,
+-39,
+146,
+40,
+25,
+-96,
+22,
+-46,
+32,
+43,
+-43,
+52,
+-51,
+-9,
+112,
+209,
+-97,
+-14,
+-32,
+-73,
+-40,
+-68,
+4,
+-31,
+-32,
+426,
+-86,
+-20,
+-29,
+-83,
+-47,
+-65,
+34,
+-39,
+-39,
+195,
+-62,
+-30,
+-41,
+-23,
+205,
+-42,
+-85,
+-42,
+-19,
+-9,
+-11,
+182,
+-49,
+-61,
+-119,
+-90,
+242,
+7,
+-59,
+24,
+26,
+-62,
+28,
+30,
+39,
+47,
+-71,
+-43,
+4,
+27,
+35,
+34,
+51,
+-71,
+-80,
+21,
+40,
+22,
+24,
+-38,
+-125,
+-60,
+69,
+-39,
+2,
+75,
+53,
+-11,
+136,
+-70,
+-82,
+-27,
+-53,
+-11,
+223,
+-61,
+-82,
+3,
+50,
+-50,
+-102,
+100,
+-47,
+-43,
+-25,
+179,
+-90,
+125,
+19,
+-13,
+-44,
+-89,
+-36,
+-52,
+-79,
+-39,
+198,
+410,
+21,
+-74,
+-61,
+-20,
+-102,
+-81,
+-62,
+-34,
+-93,
+-36,
+-6,
+-42,
+-11,
+-109,
+-35,
+140,
+-52,
+-49,
+222,
+68,
+-15,
+-84,
+183,
+-15,
+-105,
+106,
+-47,
+-76,
+-30,
+-27,
+26,
+32,
+-77,
+29,
+81,
+-1,
+-48,
+-60,
+45,
+-56,
+-41,
+-93,
+-19,
+-93,
+-57,
+373,
+26,
+-1,
+-29,
+20,
+10,
+-80,
+20,
+-40,
+9,
+44,
+39,
+34,
+-75,
+-49,
+-19,
+-17,
+-22,
+100,
+-66,
+-27,
+100,
+-83,
+104,
+416,
+20,
+-72,
+-49,
+-24,
+-111,
+-88,
+-65,
+-37,
+-93,
+-16,
+-12,
+-23,
+-74,
+121,
+251,
+-68,
+-40,
+-56,
+-68,
+147,
+4,
+-53,
+-1,
+206,
+-70,
+-75,
+-55,
+-51,
+-101,
+-31,
+-29,
+-73,
+16,
+185,
+-84,
+156,
+-1,
+-41,
+-105,
+-9,
+112,
+-85,
+-30,
+-52,
+-41,
+144,
+-27,
+-34,
+-14,
+29,
+-37,
+-57,
+-79,
+9,
+84,
+49,
+39,
+-67,
+27,
+28,
+353,
+-69,
+-89,
+-34,
+-41,
+-47,
+-85,
+-97,
+-9,
+20,
+19,
+-86,
+27,
+-84,
+34,
+-46,
+14,
+41,
+48,
+-13,
+26,
+-38,
+-78,
+-48,
+-42,
+60,
+-27,
+99,
+54,
+-37,
+-42,
+193,
+305,
+9,
+-115,
+-75,
+-45,
+-151,
+3,
+-16,
+53,
+-70,
+180,
+-1,
+-114,
+-5,
+116,
+-114,
+-53,
+40,
+-32,
+48,
+-94,
+20,
+-89,
+19,
+-60,
+116,
+44,
+-30,
+-24,
+-75,
+12,
+185,
+-64,
+137,
+-1,
+-37,
+-116,
+-24,
+-16,
+-20,
+41,
+447,
+-38,
+-73,
+-69,
+-69,
+-182,
+28,
+343,
+-64,
+-86,
+-35,
+-46,
+-42,
+-81,
+-93,
+-15,
+-57,
+-49,
+-67,
+-31,
+-21,
+-73,
+25,
+448,
+-2,
+-168,
+9,
+-32,
+-62,
+70,
+18,
+-100,
+28,
+32,
+28,
+9,
+69,
+-22,
+-58,
+161,
+-6,
+136,
+-46,
+-73,
+-134,
+-39,
+-15,
+-17,
+-60,
+-93,
+111,
+-52,
+-26,
+-25,
+251,
+-77,
+12,
+-47,
+-89,
+46,
+2,
+21,
+21,
+23,
+31,
+-19,
+-19,
+-33,
+-74,
+-171,
+-62,
+-56,
+-9,
+-8,
+516,
+-67,
+-56,
+-50,
+-60,
+255,
+2,
+218,
+-41,
+-61,
+-135,
+-38,
+30,
+-31,
+39,
+-157,
+-52,
+106,
+44,
+-58,
+131,
+-40,
+-39,
+-4,
+26,
+-15,
+-122,
+3,
+-13,
+-123,
+-53,
+387,
+-52,
+-40,
+190,
+-75,
+-71,
+-29,
+160,
+-5,
+-39,
+3,
+-25,
+-48,
+-57,
+139,
+-61,
+-86,
+-8,
+-74,
+106,
+133,
+100,
+-7,
+106,
+-55,
+-43,
+-75,
+90,
+-49,
+-55,
+-19,
+49,
+-7,
+119,
+-105,
+-14,
+158,
+-54,
+-68,
+-54,
+-14,
+-57,
+3,
+33,
+25,
+-53,
+75,
+-48,
+51,
+63,
+-78,
+104,
+-15,
+170,
+-64,
+-14,
+-47,
+-49,
+-43,
+-43,
+-4,
+-32,
+-39,
+94,
+-82,
+-30,
+-46,
+-20,
+91,
+143,
+-56,
+-52,
+-31,
+-20,
+-30,
+115,
+-68,
+-41,
+136,
+-85,
+90
+}
+};
 
 // ABN FP parameters
-uint32_t GAMMA_FP[2] = {61,53};
+uint32_t GAMMA_FP[1][10] = {{0xb0fd,0xcecb,0x8352,0x732e,0x7fca,0x797d,0x824d,0x9373,0x7bc0,0x80e9}
+};
 
-uint32_t BETA_FP[2][47] = {
-{0xfffffffffffff01a,0xc79,0xffffffffffffe8e0,0xfffffffffffffef2,0xfffffffffffff91b,0xfffffffffffff6a9,0x5b8,0xfffffffffffffd0e,0xfffffffffffff4a3,0xfffffffffffff245,0xffffffffffffeea5,0xf04,0xfffffffffffff672,0xfffffffffffffc89,0xfffffffffffff74a,0xfffffffffffff9b4,0xfffffffffffffb62,0xfffffffffffff673,0xfffffffffffff130,0xe1b,0x1124,0xe06,0xd1e,0xffffffffffffec4f,0xffffffffffffff7f,0xfffffffffffffda8,0xfffffffffffff50f,0xffffffffffffee80,0x242,0x132,0xfffffffffffff64c,0xfffffffffffffd11,0xfffffffffffff21f,0xfffffffffffff268,0xfffffffffffff6a8,0xfffffffffffff757,0xfffffffffffff75d,0xe49,0xffffffffffffee5f,0xfffffffffffff54b,0xfffffffffffffbee,0xfffffffffffff851,0xfffffffffffff4c7,0xfffffffffffff9a8,0xfffffffffffffe62,0xfffffffffffffd99,0x46e},
-{0xfffffffffffffc13,0xfffffffffffffa72,0xfffffffffffff122,0x34b,0xfffffffffffff97b,0x12b,0xfffffffffffff1a1,0xfffffffffffffa05,0xfffffffffffff4e5,0xb27,0xfffffffffffffeb0,0xffffffffffffedc9,0xffffffffffffefe7,0xfffffffffffffec8,0x2be,0xfffffffffffff819,0xffffffffffffff2f,0xfffffffffffff27a,0xfffffffffffffcf0,0xfffffffffffff16a,0xfffffffffffff967,0xfffffffffffff092,0x3b1,0x8,0xfffffffffffff1ce,0xfffffffffffff1cc,0x38a,0xffffffffffffeea1,0xfffffffffffff7c0,0xfffffffffffff189,0x79e,0xfffffffffffffbe9,0xfffffffffffffb22,0xfffffffffffff608,0x306,0xffffffffffffe93f,0xfffffffffffff4ed,0xfffffffffffffeb4,0xfffffffffffffe4e,0xfffffffffffffee9,0xfffffffffffffeb2,0xfffffffffffffe8d,0xffffffffffffff12,0xfffffffffffffe67,0xfffffffffffffe8e,0xfffffffffffffed8,0xfffffffffffffec4},
+uint32_t BETA_FP[1][10] = {
+{-0x822,-0x63c,-0x932,-0xb03,-0x9b9,-0xb6b,-0xae0,-0x955,-0xc0b,-0xacf}
 };
 
+// Inference results 
+uint8_t inf_result[128] = {7,2,1,0,4,1,4,9,5,9,0,6,9,0,1,5,9,7,3,4,9,6,6,5,4,0,7,4,0,1,3,1,3,4,7,2,7,1,2,1,1,7,4,2,3,5,1,2,4,4,6,3,5,5,6,0,4,1,9,5,7,8,9,3,7,4,6,4,3,0,7,0,2,9,1,7,3,2,9,7,7,6,2,7,8,4,7,3,6,1,3,6,9,3,1,4,1,7,6,9,6,0,5,4,9,9,2,1,9,4,8,7,3,9,7,4,4,4,9,2,5,4,7,6,7,9,0,5};
+
diff --git a/chip_files/create_C_header.py b/chip_files/create_C_header.py
index 9a8d0d1..7422d4c 100644
--- a/chip_files/create_C_header.py
+++ b/chip_files/create_C_header.py
@@ -33,6 +33,7 @@ def create_C_header(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VE
   T_ADC   = TIME_CONF[3];
   
   # // Reshape FP beta-offset
+  GAMMA_FP_VEC = np.reshape(GAMMA_FP_VEC,(Nlayers_fp,-1));
   BETA_FP_VEC = np.reshape(BETA_FP_VEC,(Nlayers_fp,-1));
   Nbeta_fp = np.shape(BETA_FP_VEC)[-1];
   
@@ -58,28 +59,28 @@ def create_C_header(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VE
 
   # Layers & channels
   fileID.write('// Networks channels\n');
-  fileID.write(f'uint16_t C_IN[{Nlayers_cim}] = {{'); 
+  fileID.write(f'uint16_t C_IN[{Nlayers_cim-START_LAYER}] = {{'); 
   for i in range(len(C_IN)):
     if(i == 0):
       fileID.write(f'{C_IN[i]}');
     else:
       fileID.write(f',{C_IN[i]}');
   fileID.write('}\n');
-  fileID.write(f'uint16_t C_OUT[{Nlayers_cim}] = {{');
+  fileID.write(f'uint16_t C_OUT[{Nlayers_cim-START_LAYER}] = {{');
   for i in range(len(C_OUT)):
     if(i == 0):
       fileID.write(f'{C_OUT[i]}');
     else:
       fileID.write(f',{C_OUT[i]}');
   fileID.write('}\n');
-  fileID.write(f'uint8_t C_IN_LOG[{Nlayers_cim}] = {{'); 
+  fileID.write(f'uint8_t C_IN_LOG[{Nlayers_cim-START_LAYER}] = {{'); 
   for i in range(len(C_IN)):
     if(i == 0):
       fileID.write(f'{int(math.log2(C_IN[i]))}');
     else:
       fileID.write(f',{int(math.log2(C_IN[i]))}');
   fileID.write('}\n');
-  fileID.write(f'uint8_t C_OUT_LOG[{Nlayers_cim}] = {{');
+  fileID.write(f'uint8_t C_OUT_LOG[{Nlayers_cim-START_LAYER}] = {{');
   for i in range(len(C_OUT)):
     if(i == 0):
       fileID.write(f'{int(math.log2(C_OUT[i]))}');
@@ -104,8 +105,10 @@ def create_C_header(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VE
   fileID.write(f'uint8_t T_ADC_CONF  = {T_ADC};\n');
   fileID.write('\n');
   
-  # Number of samples
+  # Number of samples and layers
   fileID.write(f'uint8_t Nimg = {Nimg};\n');
+  fileID.write(f'uint8_t Nlayers_cim = {Nlayers_cim};\n');
+  fileID.write(f'uint8_t Nlayers_fp = {Nlayers_fp};\n');
   fileID.write('\n');
   
   # ABN params
@@ -151,4 +154,307 @@ def create_C_header(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VE
   fileID.close();
   
   
+  return;
+  
+  
+def create_C_header_subset(filename,network_info,cim_dim,D_VEC,P_VEC,TIME_CONF,GAMMA_VEC,BETA_FP_VEC,GAMMA_FP_VEC,data_cim,START_LAYER):
+  # // Retrieve variables //
+  # CNN network info
+  Nlayers_cim = network_info[0];
+  Nlayers_fp  = network_info[1];
+  Nimg        = network_info[2];
+  # CIM dims
+  N_ROWS = cim_dim[0];
+  N_COLS = cim_dim[1];
+  # Channels
+  H_IMG = D_VEC[0];
+  W_IMG = D_VEC[1];
+  C_IN  = D_VEC[2];
+  C_OUT = D_VEC[3];
+  # Precisions
+  R_IN    = P_VEC[0];
+  R_W     = P_VEC[1];
+  R_OUT   = P_VEC[2];
+  R_BETA  = P_VEC[3];
+  R_GAMMA = P_VEC[4];
+  # Timings
+  T_DP    = TIME_CONF[0];
+  T_PRE   = TIME_CONF[1];
+  T_MBIT  = TIME_CONF[2];
+  T_ADC   = TIME_CONF[3];
+  # CIM data, starting at the chosen layer index
+  data_in = data_cim[-1][START_LAYER-1];
+  data_w  = data_cim[1];
+  data_b  = data_cim[2];
+  data_w_fp = data_cim[3];
+  data_inf = data_cim[4];
+  
+  # Reshape input data
+  data_in = np.reshape(data_in,(Nimg,-1));
+  
+  # Reshape CIM offset
+  beta_conf_list = [];
+  for i in range(Nlayers_cim):
+    beta_conf_temp = np.expand_dims(data_b[i].astype("uint8"),axis=-1); 
+    beta_unpacked = np.flip(np.unpackbits(beta_conf_temp,axis=-1),axis=-1);
+    # swap axes
+    beta_unpacked = np.swapaxes(beta_unpacked,0,1);
+    # Repeat beta values in r_w cols
+    beta_unpacked = np.repeat(beta_unpacked,R_W,axis=-1);
+    if(R_W*C_OUT[i] < 32):
+      beta_unpacked = np.pad(beta_unpacked,((0,0),(0,32-R_W*C_OUT[i])));
+    beta_conf_temp = np.dot(np.reshape(beta_unpacked[:R_BETA,...],(-1,32)),2**np.arange(32));
+    beta_conf_list.append(beta_conf_temp);
+  #Stack results along a single dimension
+  data_b = beta_conf_list;
+  
+  # // Reshape FP beta-offset
+  GAMMA_FP_VEC = np.reshape(GAMMA_FP_VEC,(Nlayers_fp,-1));
+  BETA_FP_VEC = np.reshape(BETA_FP_VEC,(Nlayers_fp,-1));
+  Nbeta_fp = np.shape(BETA_FP_VEC)[-1];
+  
+  # // Write header file //
+  # Open file
+  fileID = open(filename,'w');
+  # Header
+  fileID.write('/*\n');
+  fileID.write(' *-----------------------------------\n');
+  fileID.write(' * Header file for CIM-QNN parameters\n');
+  fileID.write(' *-----------------------------------\n');
+  fileID.write('*/\n');
+  fileID.write('\n');
+  # Pre-processor statements
+  fileID.write(f'#define N_ROWS {N_ROWS}\n');
+  fileID.write(f'#define N_COLS {N_COLS}\n');
+  fileID.write('\n');
+
+  # Input img size
+  fileID.write('// Input img size\n');
+  fileID.write(f'uint8_t H_IMG = {H_IMG};\n')
+  fileID.write(f'uint8_t W_IMG = {W_IMG};\n')
+
+  # Layers & channels
+  fileID.write('// Networks channels\n');
+  fileID.write(f'uint16_t C_IN[{Nlayers_cim-START_LAYER}] = {{'); 
+  for i in range(START_LAYER,len(C_IN)):
+    if(i == START_LAYER):
+      fileID.write(f'{C_IN[i]}');
+    else:
+      fileID.write(f',{C_IN[i]}');
+  fileID.write('};\n');
+  fileID.write(f'uint16_t C_OUT[{Nlayers_cim-START_LAYER}] = {{');
+  for i in range(START_LAYER,len(C_OUT)):
+    if(i == START_LAYER):
+      fileID.write(f'{C_OUT[i]}');
+    else:
+      fileID.write(f',{C_OUT[i]}');
+  fileID.write('};\n');
+  fileID.write(f'uint8_t C_IN_LOG[{Nlayers_cim-START_LAYER}] = {{'); 
+  for i in range(START_LAYER,len(C_IN)):
+    if(i == START_LAYER):
+      fileID.write(f'{int(math.log2(C_IN[i]))}');
+    else:
+      fileID.write(f',{int(math.log2(C_IN[i]))}');
+  fileID.write('};\n');
+  fileID.write(f'uint8_t C_OUT_LOG[{Nlayers_cim-START_LAYER}] = {{');
+  for i in range(START_LAYER,len(C_OUT)):
+    if(i == START_LAYER):
+      fileID.write(f'{int(math.log2(C_OUT[i]))}');
+    else:
+      fileID.write(f',{int(math.log2(C_OUT[i]))}');
+  fileID.write('};\n');
+  fileID.write('// FP channels \n');
+  fileID.write(f'uint16_t C_IN_FP[{Nlayers_fp}] = {{'); 
+  for i in range(Nlayers_fp):
+    if(i == 0):
+      fileID.write(f'{C_IN[Nlayers_cim-1+i]}');
+    else:
+      fileID.write(f',{C_IN[Nlayers_cim-1+i]}');
+  fileID.write('};\n');
+  fileID.write(f'uint16_t C_OUT_FP[{Nlayers_fp}] = {{');
+  for i in range(Nlayers_fp):
+    if(i == 0):
+      fileID.write(f'{C_OUT[Nlayers_cim-1+i]}');
+    else:
+      fileID.write(f',{C_OUT[Nlayers_cim-1+i]}');
+  fileID.write('};\n');
+
+  # Precision
+  fileID.write('// Computing precision\n');
+  fileID.write(f'uint8_t R_IN  = {R_IN}; uint8_t R_IN_LOG  = {int(math.log2(R_IN))};\n');
+  fileID.write(f'uint8_t R_W   = {R_W}; uint8_t R_W_LOG   = {int(math.log2(R_W))};\n');
+  fileID.write(f'uint8_t R_OUT = {R_OUT}; uint8_t R_OUT_LOG = {int(math.log2(R_OUT))};\n'); 
+  fileID.write(f'uint8_t R_BETA  = {R_BETA};\n');
+  fileID.write(f'uint8_t R_GAMMA = {R_GAMMA};\n');
+  fileID.write('\n');
+  
+  # Timing configs
+  fileID.write('// Timing configuration\n');
+  fileID.write(f'uint8_t T_DP_CONF  = {T_DP};\n');
+  fileID.write(f'uint8_t T_PRE_CONF = {T_PRE};\n');
+  fileID.write(f'uint8_t T_MBIT_IN_CONF = {T_MBIT};\n');
+  fileID.write(f'uint8_t T_MBIT_W_CONF = {T_MBIT};\n');
+  fileID.write(f'uint8_t T_ADC_CONF  = {T_ADC};\n');
+  fileID.write(f'uint8_t T_REF_CONF  = {T_ADC};\n');
+  fileID.write('\n');
+  
+  # Number of samples and layers
+  fileID.write(f'uint8_t Nimg = {Nimg};\n');
+  fileID.write(f'uint8_t Nlayers_cim = {Nlayers_cim-START_LAYER};\n');
+  fileID.write(f'uint8_t Nlayers_fp = {Nlayers_fp};\n');
+  fileID.write('\n');
+  
+  # Inputs
+  fileID.write('// Input data \n');
+  fileID.write(f'uint32_t DATA_IN[{Nimg}][{np.shape(data_in)[1]}] = {{');
+  img_size = np.shape(data_in)[1];
+  for i in range(Nimg):
+    fileID.write('{');
+    for j in range(img_size):
+      if(j==img_size-1):
+        fileID.write(f'{data_in[i,j]}\n');
+      else:
+        fileID.write(f'{data_in[i,j]},\n');
+    if(i == Nimg-1):
+      fileID.write('}\n');
+    else:
+      fileID.write('},\n');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+ 
+  # Weights
+  fileID.write('// Weight data \n');
+  max_w = np.size(data_w[START_LAYER]); # ! Only valid for FC networks
+  fileID.write(f'uint32_t W_CIM[{Nlayers_cim-START_LAYER}][{max_w}] = {{');
+  for i in range(START_LAYER,Nlayers_cim):
+    fileID.write('{');
+    layer_size = np.size(data_w[i]);
+    for j in range(max_w):
+      if(j<layer_size):
+        if(j==max_w-1):
+          fileID.write(f'{data_w[i][j]}\n');
+        else:
+          fileID.write(f'{data_w[i][j]},\n');
+      else:
+        if(j==max_w-1):
+          fileID.write(f'0x0\n');
+        else:
+          fileID.write(f'0x0,\n');
+    if(i == Nlayers_cim-1):
+      fileID.write('}\n');
+    else:
+      fileID.write('},\n');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+ 
+  # ABN params
+  fileID.write('// ABN CIM gain \n');
+  # Gain values
+  fileID.write(f'uint8_t GAMMA[{Nlayers_cim-START_LAYER}] = {{');
+  for i in range(START_LAYER,Nlayers_cim):
+    if(i==START_LAYER):
+      fileID.write(f'{GAMMA_VEC[i]}');
+    else:
+      fileID.write(f',{GAMMA_VEC[i]}');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+  
+  fileID.write('// ABN CIM offset \n');
+  max_b = np.size(data_b[START_LAYER]); # ! Only valid for FC networks
+  fileID.write(f'uint32_t B_CIM[{Nlayers_cim-START_LAYER}][{max_b}] = {{');
+  for i in range(START_LAYER,Nlayers_cim):
+    fileID.write('{');
+    layer_size = np.size(data_b[i]);
+    for j in range(max_b):
+      if(j<layer_size):
+        if(j==max_b-1):
+          fileID.write(f'{data_b[i][j]}\n');
+        else:
+          fileID.write(f'{data_b[i][j]},\n');
+      else:
+        if(j==max_b-1):
+          fileID.write(f'0x0\n');
+        else:
+          fileID.write(f'0x0,\n');
+    if(i == Nlayers_cim-1):
+      fileID.write('}\n');
+    else:
+      fileID.write('},\n');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+  
+  fileID.write('// FP weights \n');
+  max_w = np.size(data_w_fp[0]); # ! Only valid for FC networks
+  fileID.write(f'uint32_t W_FP[{Nlayers_fp}][{max_w}] = {{');
+  for i in range(Nlayers_fp):
+    fileID.write('{');
+    layer_size = np.size(data_w_fp[i]);
+    for j in range(max_w):
+      if(j<layer_size):
+        if(j==max_w-1):
+          fileID.write(f'{data_w_fp[i][j]}\n');
+        else:
+          fileID.write(f'{data_w_fp[i][j]},\n');
+      else:
+        if(j==max_w-1):
+          fileID.write(f'0x0\n');
+        else:
+          fileID.write(f'0x0,\n');
+    if(i == Nlayers_fp-1):
+      fileID.write('}\n');
+    else:
+      fileID.write('},\n');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+  
+  # ABN params
+  fileID.write('// ABN FP parameters\n');
+  # Gain values
+  fileID.write(f'uint32_t GAMMA_FP[{Nlayers_fp}][{Nbeta_fp}] = {{');
+  print(GAMMA_FP_VEC); print(BETA_FP_VEC)
+  for i in range(Nlayers_fp):
+    fileID.write(f'{{');
+    for j in range(Nbeta_fp):
+      if(j==0):
+        fileID.write(f'{hex(GAMMA_FP_VEC[i][j])}');
+      else:
+        fileID.write(f',{hex(GAMMA_FP_VEC[i][j])}');
+    if(i==Nlayers_fp-1):
+      fileID.write(f'}}\n');
+    else:
+      fileID.write(f'}},\n');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+  # Offsets value
+  fileID.write(f'uint32_t BETA_FP[{Nlayers_fp}][{Nbeta_fp}] = {{\n');
+  for i in range(Nlayers_fp):
+    fileID.write(f'{{');
+    for j in range(Nbeta_fp):
+      if(j==0):
+        fileID.write(f'{hex(BETA_FP_VEC[i][j])}');
+      else:
+        fileID.write(f',{hex(BETA_FP_VEC[i][j])}');
+    if(i==Nlayers_fp-1):
+      fileID.write(f'}}\n');
+    else:
+      fileID.write(f'}},\n');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+  
+  # ABN params
+  fileID.write('// Inference results \n');
+  # Gain values
+  fileID.write(f'uint8_t inf_result[{Nimg}] = {{');
+  for i in range(Nimg):
+    if(i==0):
+      fileID.write(f'{data_inf[i]}');
+    else:
+      fileID.write(f',{data_inf[i]}');
+  fileID.write(f'}};\n');
+  fileID.write('\n');
+  
+  # Close file and return
+  fileID.close();
+  
+  
   return;
\ No newline at end of file
diff --git a/config/config_cim_cnn_param.py b/config/config_cim_cnn_param.py
index 84ee05d..cd70d89 100644
--- a/config/config_cim_cnn_param.py
+++ b/config/config_cim_cnn_param.py
@@ -4,26 +4,26 @@
 ########################################
 # // Dataset //
 config_path = "config_cim_cnn_param"
-dataset_name = "MNIST";
-dim=28
-channels=1
+dataset_name = "CIFAR-10";
+dim=32
+channels=3
 classes=10
 # // Network structure //
-network_type = "full-qnn";
-# network_struct = "1C1D"
-network_struct = "LeNet-5"
+network_type = "float";
+#network_struct = "MLP_512_256_32_32_10"
+network_struct = "VGG-8"
 OP_TYPE = "CONV-2D";
-C_IN_VEC = [1024,128];
-C_OUT_VEC = [128,64];
+C_IN_VEC = [1024,256,512,32,32];
+C_OUT_VEC = [512,256,32,32,10];
 Nl_fp = 1;
 
 # // Conv. kernel size //
 kern_size = 3
 # // Regularization //
-kernel_regularizer=0.
+kernel_regularizer=0.001
 activity_regularizer=0.
 # // Training iterations & savings //
-Niter = 10;
+Niter = 1;
 Nimg_save = 128;
 
 #####################################
@@ -33,11 +33,12 @@ Nimg_save = 128;
 epochs = 30
 batch_size = 128
 # batch_size = 128
-lr = 0.001
-decay = 0.000025
+lr = 0.005
+# decay = 0.000025
+decay = 5e-4
 # Decay & lr factors
-decay_at_epoch = [15, 75, 150 ]
-factor_at_epoch = [.25, .25, .1]
+decay_at_epoch = [1, 10, 30 ]
+factor_at_epoch = [.1, .1, .1]
 kernel_lr_multiplier = 10
 # Debug and logging
 progress_logging = 1 # can be 0 = no std logging, 1 = progress bar logging, 2 = one log line per epoch
@@ -54,19 +55,19 @@ tech = 'GF22nmFDX';
 typeT = 'RVT';
 # SUPPLY and BACK-BIAS
 VDD  = 0.8;
-Vmax_beta = 0.05;
+Vmax_beta = 0.02;
 BBN  = 0;
 BBP  = 0;
 # CIM-SRAM I/O RESOLUTION
-IAres = 1;
+IAres = 4;
 Wres  = 1;
-OAres = 1;
+OAres = 4;
 # ABN resolution (if enabled)
 r_gamma = 5;
 r_beta  = 5;
 # MAXIMUM INPUT VECTOR SIZE for ALL layers
 Nrows = 1152;
-Ncols = 512;
+Ncols = 256;
 # Timing configuration (! these should be updated with the last of the conf setup)
 T_DP    = 0x3;
 T_PRE   = 0x3;
@@ -79,7 +80,7 @@ T_ADC   = 0x3;
 # Simulator (techno-dependent)
 simulator = "spectre"
 # Enable noisy training
-EN_NOISE = 0;
+EN_NOISE = 1;
 # Enable synmaic-rnage scaling (charge-domain)
 EN_SCALE = 1;
 # Enable analog BN
@@ -108,7 +109,7 @@ path_to_out = "./saved_models/";
 acc_file_template = "accuracy/acc_IMC_{}_{}_{}_{}_IA{}bW{}bOA{}b_{}b{}bABN_{}iter_{}SCALE_{}ABN_{}noise.txt";
 w_file_template = "weights/weights_IMC_{}_{}_{}_{}_IA{}bW{}bOA{}b_{}b{}bABN_{}iter_{}SCALE_{}ABN_{}noise.hdf5";
 in_file_template = "inputs/in_IMC_{}_{}_{}_IA{}b.txt";
-out_file_template = "outputs/out_IMC_{}_{}_{}_{}_IA{}bW{}bOA{}b_{}b{}bABN_{}iter_{}SCALE_{}ABN_{}noise";
+out_file_template = "outputs/out_IMC_{}_{}_{}_{}_IA{}bW{}bOA{}b_{}b{}bABN_{}iter_{}SCALE_{}ABN_{}noise.txt";
 inference_file_template = "outputs/inference_IMC_{}_{}_{}_{}_IA{}bW{}bOA{}b_{}b{}bABN_{}iter_{}SCALE_{}ABN_{}noise.txt";
 
 # On-chip inference files
@@ -126,7 +127,7 @@ chip_beta_FP_template   = "fp_bn/beta_fp_{}_{}_{}_{}_IA{}bW{}bOA{}b_noise{}";
 # FPGA files
 path_to_fpga = "./chip_files/fpga/"
 
-fS_beta_fp = 128;
+fS_beta_fp = 1024;
 fS_gamma_fp = 64;
 
 # // CPU-only training //
diff --git a/config/config_sweep_param.py b/config/config_sweep_param.py
index 129e93c..84d9328 100644
--- a/config/config_sweep_param.py
+++ b/config/config_sweep_param.py
@@ -10,10 +10,10 @@ channels=1
 classes=10
 # // Network structure //
 network_type = "full-qnn";
-network_struct = "Jia_2020_reduced"
-# network_struct = "MLP_three_stage_abn"
-OP_TYPE = "CONV-2D";
-# OP_TYPE = "FC";
+# network_struct = "Jia_2020_reduced"
+network_struct = "MLP_three_stage_abn"
+# OP_TYPE = "CONV-2D";
+OP_TYPE = "FC";
 
 C_IN_VEC = [1024,128];
 C_OUT_VEC = [128,64];
@@ -33,9 +33,9 @@ Nimg_save = 128;
 #####################################
 # Main hyper-params
 epochs = 30
-batch_size = 32*1
-# batch_size = 128
-lr = 0.001
+# batch_size = 32*1
+batch_size = 4*128
+lr = 0.01
 decay = 0.000025
 # Decay & lr factors
 decay_at_epoch = [15, 75, 150 ]
@@ -56,7 +56,7 @@ tech = 'GF22nmFDX';
 typeT = 'RVT';
 # SUPPLY and BACK-BIAS
 VDD  = 0.8;
-Vmax_beta = 0.1;
+Vmax_beta = 0.05;
 BBN  = 0;
 BBP  = 0;
 # CIM-SRAM I/O RESOLUTION
@@ -65,7 +65,7 @@ Wres  = 1;
 OAres = IAres;
 # ABN resolution (if enabled)
 r_gamma = 5;
-r_beta  = 5;
+r_beta  = 8;
 # MAXIMUM INPUT VECTOR SIZE for ALL layers
 Nrows = 1152;
 Ncols = 512;
@@ -73,9 +73,9 @@ Ncols = 512;
 #######################################################################
 ######### Sweep vectors (comment out related HW info above !) #########
 #######################################################################
-IAres_vec = [1];
-# r_gamma_vec = [1,2,3,4,5,6,7,8];
-r_gamma_vec = [1,2,3,4];
+IAres_vec = [1,2,4,6,8];
+r_gamma_vec = [1,2,3,4,5,6,7,8];
+# r_gamma_vec = [7];
 
 ########################################
 ########## Simulation flags ############
diff --git a/layers/analog_BN_charge_interp_PL.py b/layers/analog_BN_charge_interp_PL.py
index ab959f9..c0a63a0 100644
--- a/layers/analog_BN_charge_interp_PL.py
+++ b/layers/analog_BN_charge_interp_PL.py
@@ -278,13 +278,16 @@ def ABN(V_DP,ABN_lookup,sig_ABN_lookup,V_DP_half_LUT,devGainLUT,mov_mean_DP=0.0,
     # Get hardware parameters
     VDD = hardware.sramInfo.VDD.data;
     
-    r_gamma = hardware.sramInfo.r_gamma; Ns_gamma = 2**r_gamma;
-    r_beta  = hardware.sramInfo.r_beta;  Ns_beta  = 2**r_beta;
+    r_gamma = hardware.sramInfo.r_gamma;
+    r_beta  = hardware.sramInfo.r_beta;
     OAres   = hardware.sramInfo.OAres;
+    # Get number of states
+    Ns_gamma = 2**r_gamma;
     
     Vmax_beta = hardware.sramInfo.Vmax_beta;
-    Vlsb_beta = Vmax_beta/2**(r_beta-1)
-    
+    Vlsb_beta = Vmax_beta/2**(r_beta-1);
+    Vadc_step = VDD/(2**OAres);
+
     # Set 'None' parameters to their initial values
     if gamma is None:
         gamma = K.constant(1.0);
@@ -295,38 +298,37 @@ def ABN(V_DP,ABN_lookup,sig_ABN_lookup,V_DP_half_LUT,devGainLUT,mov_mean_DP=0.0,
     if mov_variance_DP is None:
         mov_variance_DP  = K.constant(1.0);
 
-    # // Specify non-centernormalized correction factors //
-    mu_goal  = 0;
+    # Specify non-center-normalized correction factors
+#    mu_goal  = VDD/2;
     sigma_goal = VDD/m_sigma; var_goal = sigma_goal*sigma_goal;
-    
-    # // Equivalent gain computation //
-    # Get custom renorm factors (single gain for all columns)
+
+#    # Get custom renorm factors
+#    sigma_DP = K.sqrt(mov_variance_DP);
+#    mov_mean_DP_t = mov_mean_DP - mu_goal/sigma_goal*sigma_DP;
+#    # mov_mean_DP_t = K.zeros_like(mov_mean_DP);
     mov_variance_DP_t = K.mean(mov_variance_DP)/var_goal;
-    sigma_DP_t = K.sqrt(mov_variance_DP_t); 
-    # Get equivalent coefficients
-    gamma_eq = gamma/(sigma_DP_t + epsilon);
-    # Add Bernouilli matrices to regularize gain training (not mandatory)
-#    bern_matrix = tf.random.uniform(shape=tf.shape(gamma_eq),maxval=1);
-#    bern_matrix = tf.math.greater(bern_matrix,0.2); bern_matrix = tf.cast(bern_matrix,dtype='float32');
-#    gamma_eq = bern_matrix*round_through(gamma_eq)+(1-bern_matrix)*gamma_eq;  
-    
-    # // Equivalent offset computation //
-    sigma_DP = K.sqrt(mov_variance_DP);
-    mov_mean_DP_t = mov_mean_DP - mu_goal/sigma_goal*sigma_DP;
-    beta_eq = beta/gamma_eq - mov_mean_DP;
-    # Convert into voltage domain and add to ABN input
-    V_beta  = K.clip(round_through(beta_eq/Vlsb_beta)*Vlsb_beta,-Vmax_beta,Vmax_beta);
-    V_DP = V_DP + V_beta;
+#    mov_variance_DP_t = mov_variance_DP/var_goal;
+#    # Get equivalent coefficients
+#    sigma_DP_t = K.sqrt(mov_variance_DP_t); 
+
+    gamma_eq = gamma/(K.sqrt(mov_variance_DP_t) + epsilon);
+    beta_eq  = beta/gamma_eq - mov_mean_DP;
     
     # Restrict gain factor to power-of-2
-    log_gamma_eq = round_through(tf.math.log2(gamma_eq));
-    gamma_eq = K.pow(2,log_gamma_eq);
+    log_gamma_eq = round_through(tf.math.log(gamma_eq)/tf.math.log(2.));
+    gamma_eq = K.pow(2.,log_gamma_eq);
     
+    # Quantize results
+    gamma_eq = K.clip(round_through(gamma_eq),1,2**r_gamma);
+    V_beta  = K.clip(round_through(beta_eq/Vlsb_beta)*Vlsb_beta,-Vmax_beta,Vmax_beta);
+       
+    # Apply quantized offset
+    V_ABN_temp = V_DP+V_beta;
         
     # // Get ABN distribution from LUTs based on the gain/offset mapping //
-    D_OUT = doInterpABN(ABN_lookup,gamma_eq,V_DP,Ns_gamma,Ns_gamma,VDD,Npoints);
+    D_OUT = doInterpABN(ABN_lookup,gamma_eq,V_ABN_temp,Ns_gamma,Ns_gamma,VDD,Npoints);
     if(EN_NOISE):
-        sig_D_OUT = doInterpABN(sig_ABN_lookup,gamma_eq,V_DP,Ns_gamma,Ns_gamma,VDD,Npoints);
+        sig_D_OUT = doInterpABN(sig_ABN_lookup,gamma_eq,V_ABN_temp,Ns_gamma,Ns_gamma,VDD,Npoints);
         sig_D_OUT = sig_D_OUT*K.random_normal(shape=tf.shape(D_OUT),mean=0.,stddev=1.,dtype='float32');
         D_OUT   = D_OUT + sig_D_OUT;
     
diff --git a/models/Analog_DP.py b/models/Analog_DP.py
index 196cbc9..7eca0c2 100644
--- a/models/Analog_DP.py
+++ b/models/Analog_DP.py
@@ -262,9 +262,12 @@ def int_DP_cap(hardware,IA,W,Nunit,sig_MoM,EN_NOISE,EN_SCALE):
     IA   = tf.linalg.matrix_transpose(IA);
     V_DP = VDDL*(1+(Cc/Cp)*K.dot(IA,W));
     V_DP = tf.linalg.matrix_transpose(V_DP);
-    
     # Accumulate multi-bit inputs
     V_MBIT = MBIT_IN_actual(V_DP,hardware.sramInfo);
+    # Add spatial+temporal noise on the DP result (estimated post-silicon)
+    if(EN_NOISE):
+      sig_DP_mat = K.in_train_phase(K.random_normal(shape=tf.shape(V_MBIT),mean=0.,stddev=hardware.sramInfo.sig_Vdp,dtype='float32'),K.zeros_like(V_MBIT));
+      V_MBIT = V_MBIT + sig_DP_mat;
     
     # Debugs
     # tf.print("IA",IA[0:8,0]);
diff --git a/models/MAC_charge.py b/models/MAC_charge.py
index 32465db..75b9b0b 100644
--- a/models/MAC_charge.py
+++ b/models/MAC_charge.py
@@ -53,7 +53,7 @@ def MAC_op_se_ana(hardware,IA,W,sig_MoM_inf,T_DP_conf,EN_NOISE,EN_SCALE):
     sparse_mode = (hardware.sramInfo.input_mode == 'sparse');
     # Retrieve inputs dimensions & select desired LUT
     dim_IA = K.int_shape(IA);
-    Nunit = 2**np.ceil(np.log(np.ceil(dim_IA[-1]//C_unit))/np.log(2));
+    Nunit = 2**np.ceil(np.log(np.ceil(dim_IA[-1]/C_unit))/np.log(2));
     Nunit = Nunit.astype("int32");
     
     # One-hot input vector for bit-serial input processing
diff --git a/models/MBIT_unit.py b/models/MBIT_unit.py
index 3c5783d..f7f2b7c 100644
--- a/models/MBIT_unit.py
+++ b/models/MBIT_unit.py
@@ -7,6 +7,8 @@ import keras.backend as K
 import tensorflow as tf
 import math
 
+from utils.linInterp import doInterp_2D as doInterpMBIT 
+
 ## --- Sequential input accumulation ---
 # /// Ideal model ///
 def MBIT_IN_ideal(V_DP,sramInfo):
@@ -39,10 +41,20 @@ def MBIT_IN_actual(V_DP,sramInfo):
     return V_DP;
     
 # /// LUT-based numerical model ///
-def MBIT_IN_num(V_DP,sramInfo):
-  # do stuff
+def MBIT_IN_num(V_DP,sramInfo,Npoints=401):
+  # Retrieve hardware information
+  IAres = sramInfo.IAres;
+  VDD   = sramInfo.VDD.data;
+  MBIT_IN_lookup = sramInfo.MBIT_IN_LUT;
+  
+  # Get data by 2D interpolation, reusing each time the previous data
+  Vacc = K.zeros_like(V_DP[...,0]);
+  for i in range(IAres):
+    Vacc = doInterpMBIT(MBIT_IN_lookup,V_DP[...,i],Vacc,VDD,Npoints,VDD,Npoints);
+  # Final accumulation voltage is the result
+  V_MBIT_IN = Vacc;
   
-  return V_DP
+  return V_MBIT_IN
 
 
 ## --- Spatial binary weighting model ---
diff --git a/models/makeModel.py b/models/makeModel.py
index beda857..a099702 100644
--- a/models/makeModel.py
+++ b/models/makeModel.py
@@ -101,7 +101,7 @@ def make_model(model_type,cf,Conv_,Conv,Dens_,Dens,Act,Quant,BatchNormalization,
         model.add(Activation('softmax'))
         
     elif(model_type == 'MLP_128_64_10'):
-        print('MLP_three_stage toplogy selected...\n')
+        print('MLP_small topology selected...\n')
         
         model.add(Dens_(128,cf.dim,cf.channels,6.))
         model.add(BatchNormalization(cf.dim*cf.dim*cf.channels,4))
@@ -117,6 +117,55 @@ def make_model(model_type,cf,Conv_,Conv,Dens_,Dens,Act,Quant,BatchNormalization,
         model.add(BatchNormalization_FP())
         model.add(Activation('softmax'))
         
+    elif(model_type == 'MLP_512_128_64_10'):
+        print('MLP_512_128_64_10 topology selected...\n')
+        
+        model.add(Dens_(512,cf.dim,cf.channels,6.))
+        model.add(BatchNormalization(cf.dim*cf.dim*cf.channels,4))
+        model.add(Act())
+        
+        model.add(Dropout(0.05))
+        model.add(Dens(128,2.))
+        model.add(BatchNormalization(512,2))
+        model.add(Act())
+  
+        model.add(Dropout(0.05))
+        model.add(Dens(64,2.))
+        model.add(BatchNormalization(512,2))
+        model.add(Act())
+  
+        model.add(Dropout(0.1))
+        model.add(Dens_FP(cf.classes))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('softmax'))
+        
+    elif(model_type == 'MLP_512_256_32_32_10'):
+        print('MLP_512_256_32_32_10 topology selected...\n')
+        
+        model.add(Dens_(512,cf.dim,cf.channels,6.))
+        model.add(BatchNormalization(cf.dim*cf.dim*cf.channels,4))
+        model.add(Act())
+        
+        model.add(Dropout(0.0))
+        model.add(Dens(256,2.))
+        model.add(BatchNormalization(512,2))
+        model.add(Act())
+
+        model.add(Dropout(0.0))
+        model.add(Dens(32,2.))
+        model.add(BatchNormalization(512,2))
+        model.add(Act())
+  
+        model.add(Dropout(0.0))
+        model.add(Dens(32,2.))
+        model.add(BatchNormalization(32,2))
+        model.add(Act())
+  
+        model.add(Dropout(0.0))
+        model.add(Dens_FP(cf.classes))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('softmax'))
+        
     # MLP with hidden layers of size 512
     elif(model_type == 'MLP_512'):
         print('MLP_512 toplogy selected...\n')
@@ -375,71 +424,174 @@ def make_model(model_type,cf,Conv_,Conv,Dens_,Dens,Act,Quant,BatchNormalization,
         model.add(Activation('softmax'))
         
     # VGG-16 network
+    elif(model_type == 'VGG-8'):
+        print('VGG-8 network topology selected...')
+        
+        model.add(Conv_(cf.kern_size, 64,cf.dim,cf.channels,6,0))
+        model.add(BatchNormalization(cf.dim*cf.dim*cf.channels,4))
+        model.add(Act())
+        #model.add(Dropout(0.3))
+        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        
+        model.add(Conv(cf.kern_size, 128,6,0))
+        model.add(BatchNormalization(16*16*128,4))
+        model.add(Act())
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(Dropout(0.2))
+        
+        model.add(Conv(cf.kern_size, 256,6,0))
+        model.add(BatchNormalization(8*8*256,4))
+        model.add(Act())
+        model.add(Dropout(0.2))
+        model.add(Conv(cf.kern_size, 256,6,0))
+        model.add(BatchNormalization(8*8*256,4))
+        model.add(Act())
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+
+        model.add(Conv(cf.kern_size, 512,6,0))
+        model.add(BatchNormalization(2*2*512,4))
+        model.add(Act())
+        model.add(Dropout(0.2))
+        model.add(Conv(cf.kern_size, 512,6,0))
+        model.add(BatchNormalization(2*2*512,4))
+        model.add(Act())
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+
+        model.add(Flatten())
+        model.add(Dropout(0.5))
+        model.add(Dens_FP(4096))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('relu'))
+        model.add(Dens_FP(10))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('softmax'))      
+        
+        # VGG-11 network
+    elif(model_type == 'VGG-11'):
+        print('VGG-11 network topology selected...')
+        
+        model.add(Conv_(cf.kern_size, 64,cf.dim,cf.channels,6,0))
+        model.add(BatchNormalization(cf.dim*cf.dim*cf.channels,4))
+        model.add(Act())
+        #model.add(Dropout(0.3))
+        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        
+        model.add(Conv(cf.kern_size, 128,6,0))
+        model.add(BatchNormalization(16*16*128,4))
+        model.add(Act())
+        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        #model.add(Dropout(0.4))
+        
+        model.add(Conv(cf.kern_size, 256,6,0))
+        model.add(BatchNormalization(8*8*256,4))
+        model.add(Act())
+        #model.add(Dropout(0.4))
+        model.add(Conv(cf.kern_size, 256,6,0))
+        model.add(BatchNormalization(8*8*256,4))
+        model.add(Act())
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+
+        model.add(Conv(cf.kern_size, 512,6,0))
+        model.add(BatchNormalization(4*4*512,4))
+        model.add(Act())
+        #model.add(Dropout(0.4))
+        model.add(Conv(cf.kern_size, 512,6,0))
+        model.add(BatchNormalization(4*4*512,4))
+        model.add(Act())
+        # model.add(MaxPooling2D(pool_size=(2, 2)))
+
+        model.add(Conv(cf.kern_size, 512,6,0))
+        model.add(BatchNormalization(2*2*512,4))
+        model.add(Act())
+        #model.add(Dropout(0.4))
+        model.add(Conv(cf.kern_size, 512,6,0))
+        model.add(BatchNormalization(2*2*512,4))
+        model.add(Act())
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+
+        model.add(Flatten())
+        model.add(Dropout(0.5))
+        model.add(Dens_FP(4096))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('relu'))
+        model.add(Dropout(0.5))
+        model.add(Dens_FP(4096))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('relu'))
+        model.add(Dens_FP(10))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('softmax'))      
+        
+        # VGG-16 network
     elif(model_type == 'VGG-16'):
         print('VGG-16 network topology selected...')
         
         model.add(Conv_(cf.kern_size, 64,cf.dim,cf.channels,6,0))
         model.add(BatchNormalization(cf.dim*cf.dim*cf.channels,4))
         model.add(Act())
-        model.add(Dropout(0.3))
+        # model.add(Dropout(0.3))
         
         model.add(Conv(cf.kern_size, 64,6,0))
         model.add(BatchNormalization(32*32*cf.channels,4))
         model.add(Act())
-        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(MaxPooling2D(pool_size=(2, 2)))
         
         model.add(Conv(cf.kern_size, 128,6,0))
         model.add(BatchNormalization(16*16*64,4))
         model.add(Act())
-        model.add(Dropout(0.4))
+        # model.add(Dropout(0.4))
 
         model.add(Conv(cf.kern_size, 128,6,0))
         model.add(BatchNormalization(16*16*128,4))
         model.add(Act())
-        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(MaxPooling2D(pool_size=(2, 2)))
 
         model.add(Conv(cf.kern_size, 256,6,0))
         model.add(BatchNormalization(8*8*128,4))
         model.add(Act())
-        model.add(Dropout(0.4))
+        # model.add(Dropout(0.4))
         model.add(Conv(cf.kern_size, 256,6,0))
         model.add(BatchNormalization(8*8*256,4))
         model.add(Act())
-        model.add(Dropout(0.4))
+        # model.add(Dropout(0.4))
         model.add(Conv(cf.kern_size, 256,6,0))
         model.add(BatchNormalization(8*8*256,4))
         model.add(Act())
-        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(MaxPooling2D(pool_size=(2, 2)))
 
         model.add(Conv(cf.kern_size, 512,6,0))
         model.add(BatchNormalization(4*4*256,4))
         model.add(Act())
-        model.add(Dropout(0.4))
+        # model.add(Dropout(0.4))
         model.add(Conv(cf.kern_size, 512,6,0))
         model.add(BatchNormalization(4*4*512,4))
         model.add(Act())
-        model.add(Dropout(0.4))
+        # model.add(Dropout(0.4))
         model.add(Conv(cf.kern_size, 512,6,0))
         model.add(BatchNormalization(4*4*512,4))
         model.add(Act())
-        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(MaxPooling2D(pool_size=(2, 2)))
 
         model.add(Conv(cf.kern_size, 512,6,0))
         model.add(BatchNormalization(2*2*256,4))
         model.add(Act())
-        model.add(Dropout(0.4))
+        # model.add(Dropout(0.4))
         model.add(Conv(cf.kern_size, 512,6,0))
         model.add(BatchNormalization(2*2*512,4))
         model.add(Act())
-        model.add(Dropout(0.4))
+        # model.add(Dropout(0.4))
         model.add(Conv(cf.kern_size, 512,6,0))
         model.add(BatchNormalization(2*2*512,4))
         model.add(Act())
-        # model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(MaxPooling2D(pool_size=(2, 2)))
 
         model.add(Flatten())
         model.add(Dropout(0.5))
-        model.add(Dens_FP(512))
+        model.add(Dens_FP(4096))
+        model.add(BatchNormalization_FP())
+        model.add(Activation('relu'))
+        model.add(Dropout(0.5))
+        model.add(Dens_FP(4096))
         model.add(BatchNormalization_FP())
         model.add(Activation('relu'))
         model.add(Dens_FP(10))
diff --git a/models/model_IMC.py b/models/model_IMC.py
index 9b5e8b0..7c8c279 100644
--- a/models/model_IMC.py
+++ b/models/model_IMC.py
@@ -63,13 +63,14 @@ def build_model(cf,model_type,sramInfo,EN_NOISE,FLAGS):
                                    kernel_regularizer=l2(cf.kernel_regularizer),use_bias=False)
         Conv_FP_ = lambda s, f, i, c: Conv2D(kernel_size=(s, s), filters=f, strides=(1, 1), padding='same', activation='linear',
                                    kernel_regularizer=l2(cf.kernel_regularizer),input_shape = (i,i,c),use_bias=False)
-        Act = lambda: LeakyReLU()
+        #Act = lambda: LeakyReLU()
+        Act = lambda: Activation('relu')
         
         Quant = lambda n: Activation(lambda x: quant_uni(x,maxVal=n,dynRange=dynRange,OAres=OAres,offset=0.5*dynRange/n))
         
-        Dens_FP = lambda n: Dense(n,use_bias=False)
+        Dens_FP = lambda n: Dense(n,use_bias=True)
         
-        Dens = lambda n: Dense(n,use_bias=False)
+        Dens = lambda n: Dense(n,use_bias=True)
         
         Dens_ = lambda n,i,c:  Dense(n,use_bias=False,activation='linear',input_shape=(i*i*c,))
     elif cf.network_type=='qnn':
diff --git a/my_datasets/my_cifar10.py b/my_datasets/my_cifar10.py
index 7ecd4c3..cc1e402 100644
--- a/my_datasets/my_cifar10.py
+++ b/my_datasets/my_cifar10.py
@@ -35,11 +35,11 @@ def load_data():
   Returns:
       Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
   """
-  dirname = 'cifar-10-batches-py'
-  origin = './'
+  #dirname = 'cifar-10-batches-py'
+  #origin = './'
   #path = get_file(dirname, origin=origin, untar=True)
-  path = '/export/home/adkneip/Documents/PhD/Python3/IMC_Modeling/qnn/my_datasets/cifar-10-batches-py';
-
+  HOME = os.environ["HOME"];
+  path = HOME+'/Documents/PhD/cim_qnn_training/my_datasets/cifar-10-batches-py';
   num_train_samples = 50000
 
   x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
diff --git a/sw_to_chip.py b/sw_to_chip.py
index 3898362..50aad98 100644
--- a/sw_to_chip.py
+++ b/sw_to_chip.py
@@ -5,6 +5,7 @@ import sys,os
 import h5py
 import numpy as np
 import tensorflow as tf
+import keras.backend as K
 from keras.models import load_model
 
 from ctypes import c_uint32, c_uint64
@@ -14,7 +15,7 @@ from layers.binary_ops import binarize as binarize
 
 from utils.config_hardware_model import SramInfo_charge as SramInfo
 
-from chip_files.create_C_header import create_C_header
+from chip_files.create_C_header import create_C_header, create_C_header_subset
 from chip_files.create_fpga_files import create_fpga_files
 
 #################################################
@@ -31,7 +32,10 @@ R_GAMMA = r_gamma;
 # Network length
 Nlayers = len(C_IN_VEC); 
 # Flags for test files generation
-OUT_EN = 0; # 1: output files per layer exist ; 0: they do not, prevent storage and comparison
+OUT_EN = 1; # 1: output files per layer exist ; 0: they do not, prevent storage and comparison
+# Enable mapping of a network sub-part, and select the starting layer
+SUB_EN = 1;
+START_LAYER = 2;
 
 # Create CIMU structure
 sramInfo = SramInfo(arch,tech,typeT,VDD,BBN,BBP,IAres,Wres,OAres,r_gamma,r_beta,Nrows,[0,0]);     
@@ -99,7 +103,7 @@ with open(in_file,"r") as f:
 # // Transform outputs sub-set into 32b words for SRAM encoding //
 if(OUT_EN):
   outputs_list = []; outputs_list_test = [];
-  for i in range(Nlayers):
+  for i in range(Nlayers-Nl_fp):
     C_OUT = C_OUT_VEC[i];
     # Get outputs (only ADC outputs)
     with open(out_file+"_layer_{}.txt".format(i),"r") as f:
@@ -132,7 +136,7 @@ if(OUT_EN):
 
   for i in range(Nl_fp):
     # Get outputs
-    with open(out_file+"_layer_{}.txt".format(Nlayers+i),"r") as f:
+    with open(out_file+"_layer_{}.txt".format(Nlayers-Nl_fp+i),"r") as f:
       outputs = np.genfromtxt(f,delimiter=" ");
       # Transform into FP
       outputs = np.int32(np.round(outputs*(2**16-1)*(2**15)/fS_beta_fp/fS_gamma_fp));
@@ -152,24 +156,25 @@ c_in_vec = []; c_out_vec = [];
 Nlayers_cim = 0;  Nlayers_fp = 0;
 with h5py.File(w_file,"r") as f:
   # List all groups
-  list_of_keys = list(f.keys())
+  list_of_keys = list(f.keys()); print(list_of_keys);
   # print(list_of_keys)
   for key in list_of_keys:
     # // Different cases depending upon the layer type (key) //
     # CIM-QNN layer
     if(('cim_charge_conv2d' in key) or ('cim_charge_dense' in key)):
-      dataset = f[key][key];
+      dataset = f[key][key]; print(f[key].keys())
       local_keys = list(dataset.keys());
-      w_data = dataset[local_keys[0]][()];
-      # Binarize weights
+      w_data = dataset[local_keys[1]][()];
       w_data = tf.cast((binarize(w_data,H=1.)+np.ones_like(w_data))/2,dtype="int32");
       # Get weights shape (detect FC or CONV)
-      w_shape = tf.shape(w_data);
+      w_shape = K.int_shape(w_data);
       if(len(w_shape)>1):
         w_data    = tf.reshape(w_data,(-1,w_shape[-1]));
-        w_shape   = tf.shape(w_data);
+        w_shape   = K.int_shape(w_data);
+      print(w_data); print(w_shape)
       # Pad with zeros to reach the full array size
-      w_data = np.pad(w_data,((0,Nrows-w_shape[0]),(0,Ncols-w_shape[1])));
+      if(not(SUB_EN)):
+        w_data = np.pad(w_data,((0,Nrows-w_shape[0]),(0,Ncols-w_shape[1])));
       # Store layer dimensions
       c_in_vec.append(w_shape[-2]); c_out_vec.append(w_shape[-1]); 
       Nlayers_cim += 1;
@@ -188,7 +193,7 @@ with h5py.File(w_file,"r") as f:
       w_data = np.int32(w_data);
       # w_data = w_data*(2**15)/fS_beta_fp;
       # Store weights
-      weights_FP_list.append(np.reshape(w_data,(-1,1)));
+      weights_FP_list.append(np.reshape(w_data,(-1,)));
       # Count one more FP layer
       Nlayers_fp += 1;
     # Analog BN
@@ -197,10 +202,10 @@ with h5py.File(w_file,"r") as f:
       local_keys = list(dataset.keys());
       beta  = dataset[local_keys[0]][()];
       gamma = dataset[local_keys[1]][()];
-      #m_sigma  = dataset[local_keys[2]][()]; # to be corrected with updated training, if necesseay
-      m_sigma = 1;
-      mov_mean = dataset[local_keys[2]][()];
-      mov_var  = dataset[local_keys[3]][()];
+      m_sigma  = dataset[local_keys[2]][()]; # to be corrected with updated training, if necessary
+      #m_sigma = 1;
+      mov_mean = dataset[local_keys[3]][()];
+      mov_var  = dataset[local_keys[4]][()];
       
       # // Retrieve hardware parameters //
       Vmax_beta = sramInfo.Vmax_beta;
@@ -213,14 +218,14 @@ with h5py.File(w_file,"r") as f:
       mov_variance_DP_t = np.mean(mov_var)/var_goal;
       sigma_DP_t = np.sqrt(mov_variance_DP_t); 
       # Get equivalent coefficients
-      gamma_eq = gamma/(sigma_DP_t + epsilon);
+      gamma_eq = gamma/(sigma_DP_t + epsilon); print(gamma_eq)
       # Get gamma encoding
-      gamma_code = np.round(np.log2(gamma_eq));
+      gamma_code = 1+np.clip(np.round(np.log(gamma_eq)/np.log(2)),0,r_gamma);
       
       # // Equivalent offset computation //
       beta_eq = beta/gamma_eq - mov_mean;
       # Get beta encoding
-      beta_code = np.round(beta_eq/Vlsb_beta);
+      beta_code = np.clip(np.round(beta_eq/Vlsb_beta)+2**(r_beta-1),0,2**r_beta-1);
       print(beta_code)
       
       # // Append gamma & beta configs (uint8, encoding done during C mapping) 
@@ -286,12 +291,12 @@ if(OUT_EN):
   print(outputs_list_test[-1]);
   # Detailed comptuation below
   in_FP = np.reshape(outputs_list_test[-1],(Nimg_save,-1));
-  w_FP  = np.reshape(weights_FP_list[0],(C_OUT_VEC[-1],10));
+  w_FP  = np.reshape(weights_FP_list[0],(C_IN_VEC[-Nl_fp],10));
   gamma_FP = gamma_FP_list[0]; beta_FP = beta_FP_list[0];
   mac_val = np.zeros((Nimg_save,10),dtype="int32");
   for m in range(Nimg_save):
     # Perform MAC operations
-    for i in range(C_OUT_VEC[-1]):
+    for i in range(C_IN_VEC[-1]):
       # Fetch input
       inputs = in_FP[m][i]; 
       for j in range(10):
@@ -299,11 +304,13 @@ if(OUT_EN):
         weights = w_FP[i][j];
         # MAC operation
         mac_val[m][j] = mac_val[m][j] + inputs*weights;
+        if(m == 0):
+          print("Img {}, mac {} with input {}, weight {}: {}".format(m,10*i+j,inputs,weights,mac_val[m][j]));
         #if(m==0 and (i==0 or i==1)):
-        if(m==0 and i<8 and j==0):
-          print('Input {} is {}'.format(j,inputs));
-          print('Weight {} is {}'.format(j,weights));
-          print('DP {} at iter {} is {}'.format(j,i,mac_val[m][j]));
+        #if(m==0 and i<8 and j==0):
+        #  print('Input {} is {}'.format(j,inputs));
+        #  print('Weight {} is {}'.format(j,weights));
+        #  print('DP {} at iter {} is {}'.format(j,i,mac_val[m][j]));
     # Print final DP value
     for j in range(10):
       if(m==0):
@@ -344,11 +351,12 @@ np.savetxt(file_out_inputs,data_in,fmt='%x');
 if(OUT_EN):
   cim_outputs = np.concatenate(outputs_list,axis=None).astype("uint64");
   for i in range(len(outputs_list)):
-    if(i<Nlayers):
+    if(i<Nlayers_cim):
       np.savetxt(file_out_outputs+'_layer_{}.txt'.format(i),outputs_list[i].astype("uint64"),fmt='%x');
     else:
       np.savetxt(file_out_outputs+'_layer_{}.txt'.format(i),outputs_list[i].astype("uint64"),fmt='%x');
   
+  
 # Inference results
 np.savetxt(file_out_inference,np.array([inf_results]).astype("uint64"),fmt='%x');
 # CIM weights
@@ -361,11 +369,11 @@ np.savetxt(file_out_gamma+'.txt',gamma_cim,fmt='%x');
 beta_cim = np.concatenate(beta_list,axis=None);
 np.savetxt(file_out_beta+'.txt',beta_cim,fmt='%x');
 # FP FC/CONV weights
-weights_fp = np.concatenate(weights_FP_list,axis=None).astype("uint64");
+weights_fp = np.concatenate(weights_FP_list,axis=None).astype("int32");
 np.savetxt(file_out_weights_FP+'.txt',weights_fp,fmt='%x');
 # FP BN weights
-gamma_fp = np.concatenate(gamma_FP_list,axis=None).astype("uint64");
-beta_fp  = np.concatenate(beta_FP_list,axis=None).astype("uint64");
+gamma_fp = np.concatenate(gamma_FP_list,axis=None).astype("int32");
+beta_fp  = np.concatenate(beta_FP_list,axis=None).astype("int32");
 np.savetxt(file_out_gamma_FP+'.txt',gamma_fp,fmt='%x'); 
 np.savetxt(file_out_beta_FP+'.txt',beta_fp,fmt='%x'); 
   
@@ -385,14 +393,23 @@ D_VEC         = (dim,dim,c_in_vec,c_out_vec);
 P_VEC         = (R_IN,R_W,R_OUT,R_BETA,R_GAMMA);
 T_VEC         = (T_DP,T_PRE,T_MBIT,T_ADC);
 # Data for FPGA
-data_fpga = [data_in,weights_cim,beta_list,weights_fp,inf_results.astype("int32")];
-if(OUT_EN):
-  data_fpga.append(cim_outputs);
+if(SUB_EN):
+  data_cim = [data_in,weights_list,beta_list,weights_FP_list,inf_results.astype("int32"),outputs_list];
+else:
+  data_cim = [data_in,weights_cim,beta_list,weights_fp,inf_results.astype("int32")];
+  if(OUT_EN):
+    data_cim.append(outputs_list_test);
 
 # // Generate C header file with hardware params //
-create_C_header(filename_c,network_info,cim_dim,D_VEC,P_VEC,T_VEC,gamma_cim,beta_fp,gamma_fp);
-# // Generate off-chip FPGA memory files //
-create_fpga_files(filename_fpga,network_info,cim_dim,D_VEC,P_VEC,data_fpga);
+if(SUB_EN):
+  # // Generate C header file only, with inputs from specified layer //
+  create_C_header_subset(filename_c,network_info,cim_dim,D_VEC,P_VEC,T_VEC,gamma_cim,beta_fp,gamma_fp,data_cim,START_LAYER);
+  
+else:
+  # // Generate C header file with hardware params //
+  create_C_header(filename_c,network_info,cim_dim,D_VEC,P_VEC,T_VEC,gamma_cim,beta_fp,gamma_fp);
+  # // Generate off-chip FPGA memory files //
+  create_fpga_files(filename_fpga,network_info,cim_dim,D_VEC,P_VEC,data_cim);
 
  
 print('///////////////////////////////////////////////////////');
diff --git a/train_cim_qnn.py b/train_cim_qnn.py
index 80ccf90..57dd253 100644
--- a/train_cim_qnn.py
+++ b/train_cim_qnn.py
@@ -2,6 +2,7 @@ from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Learnin
 from tensorflow.keras.optimizers import SGD, Adam
 from keras.losses import squared_hinge, categorical_crossentropy
 from keras.models import Model
+from keras.preprocessing.image import ImageDataGenerator
 import tensorflow as tf
 import keras.backend as K
 import numpy as np
@@ -13,6 +14,8 @@ from utils.load_data import load_dataset
 from utils.config_hardware_model import SramInfo_current, SramInfo_charge
 from config.config_cim_cnn_param import*
 
+import matplotlib.pyplot as plt
+
 # // Override configuration //
 override = {}
 override_dir = {}
@@ -83,12 +86,12 @@ def generate_model(data_files,cf,network_struct,sramInfo,FLAGS):
     lr_decay = LearningRateScheduler(scheduler)
 
 
-    #sgd = SGD(lr=cf.lr, decay=cf.decay, momentum=0.9, nesterov=True)
-    adam= Adam(lr=cf.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=cf.decay)
+    sgd = SGD(learning_rate=cf.lr, decay=cf.decay, momentum=0.9, nesterov=True)
+    adam= Adam(learning_rate=cf.lr, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=cf.decay)
 
     # Perform training and validation on ideal model
     print('Compiling the network\n')
-    model.compile(loss=categorical_crossentropy,optimizer=adam,metrics=['accuracy'])
+    model.compile(loss=categorical_crossentropy,optimizer=sgd,metrics=['accuracy'])
     if cf.finetune:
         print('Load previous weights\n')
         model.load_weights(w_file)
@@ -116,9 +119,13 @@ def process_input(dataset,IS_FL_MLP,precisions):
         x_test = test_data[0].reshape(test_data[0].shape[0],test_data[0].shape[1]*test_data[0].shape[2])
         train_data = (x_train,train_data[1])
         test_data = (x_test,test_data[1])
-    # Quantize inputs    
-    x_train = quant_input(train_data[0],IAres);
-    x_test = quant_input(test_data[0],IAres);
+    # Quantize inputs     
+    if(cf.network_type == "float"):
+        x_train = train_data[0];
+        x_test  = test_data[0];
+    else:
+        x_train = quant_input(train_data[0],IAres);
+        x_test  = quant_input(test_data[0],IAres);
     train_data = (x_train,train_data[1])
     test_data = (x_test,test_data[1])
     return(train_data,test_data);
@@ -166,19 +173,61 @@ def train_eval_model(data_files,model,precisions,input_data,Niter,SAVE_EN):
             return K.get_value(model.optimizer.lr)
         # Create LR scheduler using this custom scheduling
         lr_decay = LearningRateScheduler(scheduler)
-
+        
+        def my_scheduler(epoch):
+          if(epoch < 1):
+            return cf.lr
+          elif(epoch<10):
+            return 0.1*cf.lr
+          else:
+            return 0.1*cf.lr * (0.5 ** ((epoch-10) // 20))
+                
+        #lr_decay = LearningRateScheduler(my_scheduler)
+        lr_decay = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',factor = 0.5,patience=10)
+        
+        # Data augmentation
+        if(cf.dataset_name == 'CIFAR-10'):
+          print('Using real-time data augmentation.')
+          datagen_train = ImageDataGenerator(horizontal_flip=True,
+                  width_shift_range=0.25,height_shift_range=0.25,fill_mode='constant',cval=0.,validation_split=0.15)
+          # datagen_train = ImageDataGenerator(horizontal_flip=False,fill_mode='constant',cval=0.,validation_split=0.15);
+          datagen_val   = ImageDataGenerator(horizontal_flip=False,fill_mode='constant',cval=0.,validation_split=0.15);
+        else:
+          print('No data augmentation');
+          datagen_train = ImageDataGenerator(horizontal_flip=False,fill_mode='constant',cval=0.,validation_split=0.15);
+          datagen_val   = ImageDataGenerator(horizontal_flip=False,fill_mode='constant',cval=0.,validation_split=0.15);
+        
         # // Train the model //
         print('### Training the network ###\n')
-        history = model.fit(train_data[0],train_data[1],
-                    batch_size = batch_size,
-                    epochs = epochs,
-                    verbose = progress_logging,
-                    callbacks = [checkpoint, tensorboard,lr_decay],
-                    validation_split = 0.15,
-                    workers = 4,
-                    use_multiprocessing = True
-                    );              
-         
+        if(cf.dataset_name == 'CIFAR-10'):
+          x_train = train_data[0]; y_train = train_data[1];
+          datagen_train.fit(x_train);
+          datagen_val.fit(x_train);
+
+          history = model.fit(datagen_train.flow(x_train,y_train,
+                      batch_size = batch_size,subset='training'),
+                      # batch_size = batch_size),
+                      epochs = epochs,
+                      verbose = progress_logging,
+                      callbacks = [checkpoint, tensorboard,lr_decay],
+                      # validation_split = 0.2,
+                      validation_data = datagen_val.flow(x_train,y_train,
+                      batch_size = batch_size,subset='validation'),
+                      workers = 4,
+                      use_multiprocessing = True,
+                      steps_per_epoch=4*len(x_train)/batch_size
+                      );              
+        else:
+          history = model.fit(train_data[0],train_data[1],
+              batch_size = batch_size,
+              epochs = epochs,
+              verbose = progress_logging,
+              callbacks = [checkpoint, tensorboard,lr_decay],
+              validation_split = 0.15,
+              workers = 4,
+              use_multiprocessing = True
+              );  
+           
         # Test model
         print('### Training done ! Evaluating on test data... ###')
         history_eval = model.evaluate(test_data[0],test_data[1],
@@ -219,20 +268,26 @@ def train_eval_model(data_files,model,precisions,input_data,Niter,SAVE_EN):
         # Save inputs
         with open(in_file,"w") as f:
           np.savetxt(f,np.reshape(test_data[0][0:Nimg_save],(-1,1)),fmt='%d');
-#        # Save outputs
-#        Nlayers = len(best_model.layers); indL = 0;
-#        for i in range(Nlayers):
-#          # Get desired layer outputs
-#          partial_model = Model(best_model.input,best_model.layers[i].output);
-#          data_out = partial_model(test_data[0][0:Nimg_save],training=False);
-#          # Write outputs to file, if ADC output only
-#          #if(i==6 or i==7 or i==8):
-#          #  print(data_out)
-#          if(i==2 or i==6 or i==9):
-#            out_file_temp = out_file+"_layer_{}.txt".format(indL);
-#            indL = indL+1;
-#            with open(out_file_temp,"w") as f:
-#              np.savetxt(f,np.reshape(data_out,(-1,1)),fmt='%f');
+        # Save outputs
+        Nlayers = len(best_model.layers); indL = 0;
+        for i in range(Nlayers):
+          # Get desired layer outputs
+          partial_model = Model(best_model.input,best_model.layers[i].output);
+          data_out = partial_model(test_data[0][0:Nimg_save],training=False);
+          # Write outputs to file, if ADC output only
+          #if(i==6 or i==7 or i==8):
+          #  print(data_out)
+          if(i==2 or i==6 or i==10 or i==13):
+            out_file_temp = out_file+"_layer_{}.txt".format(indL);
+            indL = indL+1;
+            with open(out_file_temp,"w") as f:
+              np.savetxt(f,np.reshape(data_out,(-1,1)),fmt='%d');
+          # Store output decision
+          if(i==Nlayers-1):
+            out_file_temp = out_file+"_layer_softmax.txt";
+            with open(out_file_temp,"w") as f:
+              np.savetxt(f,np.reshape(data_out,(-1,1)),fmt='%d');
+              
         # Save inference result
         with open(inference_file,"w") as f:
           indResult = np.argmax(test_data[1][0:Nimg_save],axis=-1);
diff --git a/utils/config_hardware_model.py b/utils/config_hardware_model.py
index 0b37968..ab95aff 100644
--- a/utils/config_hardware_model.py
+++ b/utils/config_hardware_model.py
@@ -414,6 +414,8 @@ class SramInfo_charge:
         # Computation and parasitic caps
         self.Cc = 0.697e-15;
         self.C_array = 32.5e-15; # to be checked
+        # Estimated pre-ADC voltage-equivalent deviation
+        self.sig_Vdp = 2.5*self.VDD.data/256;
         # --- DP post-layout LUT ---
         C_in_conf = 5;
         # - Nominal result -
@@ -479,20 +481,27 @@ class SramInfo_charge:
         # Load capacitance
         self.C_adc = 25.4e-15;
 
-#        path_dir = dir_LUT+'/ABN/mean_TF_ABN_PL_int.txt';
-#        #path_dir = dir_LUT+'<path_to_ABN_TF_filename>.txt';
+        # --- Post-layout LUT ---
+        #path_dir = dir_LUT+'<path_to_ABN_TF_filename>.txt';
+        #path_dir = dir_LUT+'/ABN/mean_TF_ABN_PL_int.txt';
+#        path_dir = dir_LUT+'/ABN/TF_ADC_gamma_PL_with_lad.txt';
+#        Nabn = np.floor(1.2*2*2^8+1);
 #        # Get & reshape data
 #        data_temp = np.genfromtxt(path_dir,delimiter=" ",skip_header=1,skip_footer=0);
-#        temp_LUT = np.reshape(data_temp,(2**r_gamma,Nabn,8));
 #        if(ABN_INC_ADC):
-#          temp_LUT = temp_LUT[...,4:8];
+#          temp_LUT = data_temp[...,-4];
 #        else:
 #          raise NameError('ABN with zoom-ADC must include ADC !');
-#        temp_LUT = temp_LUT.astype("float32");  temp_LUT = np.flip(temp_LUT,axis=-1);
+#        temp_LUT = np.reshape(data_temp,(2**r_gamma,Nabn));
+#        temp_LUT = temp_LUT.astype("float32");
 #        # Make 2D lookup of linear interpolations
-#        self.ABN_LUT = makeLookup2D(temp_LUT,np.arange(0,2**r_gamma),np.linspace(0,self.VDD,Nabn));
-#        # ABN mismatch
-#        path_dir = dir_LUT+'/ABN/std_TF_ABN_PL_int.txt';
+#        self.ABN_LUT = np.zeros((r_gamma+1,Nabn));
+#        for i in range(r_gamma+1):
+#          ABN_LUT_part = makeLookup2D(temp_LUT,2**i,np.linspace(self.VDD/2*(1-1.2/(2**i)),self.VDD/2*(1+1.2/(2**i)),Nabn));
+#          #ABN_LUT_part = makeLookup2D(temp_LUT,2**np.arange(0,r_gamma+1),np.linspace(self.VDD/2*(1-1.2/(2**i)),self.VDD/2*(1+1.2/(2**i)),Nabn));
+#          self.ABN_LUT[i,::] = ABN_LUT_part;
+        # ABN mismatch
+#        path_dir = dir_LUT+'/ABN/std_TF_ADC_gamma_PL_with_lad.txt';
 #        #path_dir = dir_LUT+'<path_to_ABN_dev_filename>.txt';
 #        data_temp = np.genfromtxt(path_dir,delimiter=" ",skip_header=1,skip_footer=0);
 #        temp_LUT = np.reshape(data_temp,(2**r_gamma,Nabn,8));
@@ -503,7 +512,8 @@ class SramInfo_charge:
 #        temp_LUT = temp_LUT.astype("float32");  temp_LUT = np.flip(temp_LUT,axis=-1);
 #        # Make 2D lookup of linear interpolations
 #        self.sig_ABN_LUT = makeLookup2D(temp_LUT,np.arange(0,2**r_gamma),np.linspace(0,self.VDD,Nabn));
-        
+#        self.sig_ABN_LUT = np.zeros_like(self.ABN_LUT); # we don't have the gamma-wise dist.
+          
         self.ABN_LUT = None;
         self.sig_ABN_LUT = None;
 
diff --git a/utils/linInterp.py b/utils/linInterp.py
index 3043874..db1e132 100644
--- a/utils/linInterp.py
+++ b/utils/linInterp.py
@@ -118,4 +118,22 @@ def doInterpDP_2D(LUT,x1,x2,x1_vec,x2_max,N2):
     f_int = tf.reshape(f_int,x2_shape);
     # Return interpolated result
     return f_int
+    
+# /// 2D interpolation from a LUT - specific to numerical analog DP ///
+def doInterp_2D(LUT,x1,x2,x1_max,N1,x2_max,N2):
+    # Possibly reshape x2 if CONV layer
+    x2_shape = tf.shape(x2);
+    Vlsb = x2_max/(N2-1);
+    x2 = floor_through(tf.reshape(x2/Vlsb,(-1,x2_shape[-1])))*Vlsb;
+    # Get indices
+    ind_x1 = K.clip(tf.math.floor(x1/x1_max*N1),0,(N1-1)-1); ind_x1 = K.cast(ind_x1,"int32");
+    ind_x2 = K.clip(tf.math.floor(x2/x2_max*(N2-1)),0,(N2-1)-1); ind_x2 = K.cast(ind_x2,"int32"); 
+    # Get corresponding coefficients
+    coef_vec = tf.gather_nd(LUT,tf.stack([ind_x1*K.ones_like(ind_x2),ind_x2],axis=2));
+    # Perform interpolation
+    f_int = coef_vec[::,::,0]+coef_vec[::,::,1]*x1+coef_vec[::,::,2]*x2+coef_vec[::,::,3]*x1*x2;
+    # Reshape result back, if needed
+    f_int = tf.reshape(f_int,x2_shape);
+    # Return interpolated result
+    return f_int
     
\ No newline at end of file
-- 
GitLab