From d1fff28cc02f2dd99bc4a3d0cc99af399f0c4c22 Mon Sep 17 00:00:00 2001
From: "jijoong.moon"
Date: Mon, 6 Jul 2020 10:24:10 +0900
Subject: [PATCH] [ Unit Test ] Generate Tensorflow output and Comparison

This PR provides:

. Generate Tensorflow output & gradient outputs for conv2d, pooling2d
. Compare with nntrainer outputs
. TODO : compare gradients after the getGradient func is implemented.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon
---
 nntrainer/include/util_func.h               |   7 +
 nntrainer/src/conv2d_layer.cpp              |   4 +-
 nntrainer/src/util_func.cpp                 |  20 +-
 packaging/unittest_layers.tar.gz            | Bin 6084 -> 5696 bytes
 test/input_gen/genInput.py                  | 164 ++++++++++++-----
 test/unittest/unittest_nntrainer_layers.cpp | 274 +++++++++++++++++-----------
 6 files changed, 306 insertions(+), 163 deletions(-)

diff --git a/nntrainer/include/util_func.h b/nntrainer/include/util_func.h
index a769bb8..d913bac 100644
--- a/nntrainer/include/util_func.h
+++ b/nntrainer/include/util_func.h
@@ -126,6 +126,13 @@ float no_op(float x);
  */
 Tensor strip_pad(Tensor const &in, unsigned int const *padding);
 
+/**
+ * @brief rotate 180 degrees
+ * @param[in] in input Tensor
+ * @retval Tensor rotated tensor (180 degrees)
+ */
+Tensor rotate_180(Tensor in);
+
 } /* namespace nntrainer */

 #endif /* __cplusplus */
diff --git a/nntrainer/src/conv2d_layer.cpp b/nntrainer/src/conv2d_layer.cpp
index e94e9a7..7b2c62c 100644
--- a/nntrainer/src/conv2d_layer.cpp
+++ b/nntrainer/src/conv2d_layer.cpp
@@ -190,8 +190,8 @@ Tensor Conv2DLayer::backwarding(Tensor derivative, int iteration) {
   for (unsigned int b = 0; b < derivative.batch(); ++b) {
     Tensor in_padded = zero_pad(b, derivative, same_pad);
     TensorDim p_dim(1, 1, in_padded.height(), in_padded.width());
-    for (unsigned int in_c = 0; in_c < input_dim.channel(); ++in_c) {
+    for (unsigned int in_c = 0; in_c < input_dim.channel(); ++in_c) {
       for (unsigned int i = 0; i < derivative.channel(); ++i) {
         conv2d(in_padded.getAddress(i * in_padded.height() * in_padded.width()),
@@ -224,7 +224,7 @@ Tensor Conv2DLayer::backwarding(Tensor derivative, int iteration) {
     opt.apply_gradients(weights, gradients, iteration);
   }
 
-  return strip_pad(ret, padding);
+  return rotate_180(strip_pad(ret, padding));
 }
 
 void Conv2DLayer::copy(std::shared_ptr<Layer> l) {
diff --git a/nntrainer/src/util_func.cpp b/nntrainer/src/util_func.cpp
index 19872f9..cea2a3f 100644
--- a/nntrainer/src/util_func.cpp
+++ b/nntrainer/src/util_func.cpp
@@ -206,7 +206,7 @@ Tensor zero_pad(int batch, Tensor const &in, unsigned int const *padding) {
 
 // This is strip pad and return original tensor
 Tensor strip_pad(Tensor const &in, unsigned int const *padding) {
-  Tensor output(in.batch(), in.channel(), in.width() - padding[0] * 2,
+  Tensor output(in.batch(), in.channel(), in.height() - padding[0] * 2,
                 in.width() - padding[1] * 2);
   output.setZero();
@@ -223,4 +223,20 @@ Tensor strip_pad(Tensor const &in, unsigned int const *padding) {
   return output;
 }
 
-} /* namespace nntrainer */
+Tensor rotate_180(Tensor in) {
+  Tensor output(in.getDim());
+  for (unsigned int i = 0; i < in.batch(); ++i) {
+    for (unsigned int j = 0; j < in.channel(); ++j) {
+      for (unsigned int k = 0; k < in.height(); ++k) {
+        for (unsigned int l = 0; l < in.width(); ++l) {
+          output.setValue(
+            i, j, k, l,
+            in.getValue(i, j, (in.height() - k - 1), (in.width() - l - 1)));
+        }
+      }
+    }
+  }
+  return output;
+}
+
+} // namespace nntrainer
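The rotate_180() call added to Conv2DLayer::backwarding() is what turns the correlation the loop computes into the true input gradient: for a VALID forward pass, dL/dx is the full correlation of dL/dy with the kernel rotated 180 degrees over its spatial axes, which is the same thing as a plain full convolution with the unrotated kernel. A minimal NumPy sketch of that identity (scipy appears only in this sketch as a reference implementation; neither the patch nor nntrainer depends on it):

    import numpy as np
    from scipy.signal import convolve2d, correlate2d

    k = np.random.rand(3, 3).astype(np.float32)    # one kernel channel
    dy = np.random.rand(5, 5).astype(np.float32)   # derivative w.r.t. a VALID conv output

    rot_k = np.rot90(k, 2)                         # same mapping as the new rotate_180()
    dx_a = correlate2d(dy, rot_k, mode="full")     # correlation with the rotated kernel
    dx_b = convolve2d(dy, k, mode="full")          # true convolution with the kernel itself
    assert np.allclose(dx_a, dx_b)                 # identical results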
diff --git a/packaging/unittest_layers.tar.gz b/packaging/unittest_layers.tar.gz
index 1610013a1ebf9a1e15f42062f73bc6a279e6934d..94cf9ea537247e66e311aa65ab2f7ef158132d03 100644
GIT binary patch
[base85-encoded binary delta omitted: regenerated unittest_layers.tar.gz golden-data archive, 6084 -> 5696 bytes]

diff --git a/test/input_gen/genInput.py b/test/input_gen/genInput.py
index 070362e..cabe630 100755
--- a/test/input_gen/genInput.py
+++ b/test/input_gen/genInput.py
@@ -42,9 +42,11 @@ def save(filename, *data):
         np.array(item, dtype=np.float32).tofile(outfile)
         try:
             print(item.shape, " data is generated")
+            print(item)
         except:
             pass
 
+
 ##
 # @brief generate random tensor
 def gen_tensor(shape, dtype=dtypes.float32):
@@ -81,27 +83,61 @@ def gen_input(outfile_name, input_shape, savefile=True):
 # @param[in] bias bias data
 # @return tf_o calculated result
 def conv2d_tf(x, kernel, batch, width, height, channel, k_width, k_height, k_num, stride, pad, bias):
-    x_trans = np.transpose(x,[0,2,3,1])
+    x = np.transpose(x,[0,2,3,1])
     kernel = np.transpose(kernel, [2,3,1,0])
-    tf_x = tf.constant(x_trans, dtype=dtypes.float32)
-
-    scope = "conv_in_numpy"
-    act = tf.nn.sigmoid
-
-    with tf.Session() as sess:
-        with tf.variable_scope(scope):
-            nin = tf_x.get_shape()[3].value
-            tf_w = tf.get_variable("w", [k_width, k_height, nin, k_num], initializer=tf.constant_initializer(kernel))
-            tf_b = tf.get_variable(
-                "b", [k_num],
-                initializer=tf.constant_initializer(bias, dtype=dtypes.float32))
-            tf_z = tf.nn.conv2d(
-                tf_x, kernel, strides=[1, stride, stride, 1], padding=pad) + bias
-            tf_p = tf.nn.max_pool(tf_z, ksize=[1,2,2,1], strides=[1,1,1,1], padding='VALID');
-        sess.run(tf.global_variables_initializer())
-        tf_c = sess.run(tf_z)
-        tf_o = sess.run(tf_p)
-    return tf_c, tf_o
+    tf.compat.v1.reset_default_graph()
+    input_shape = (batch, height, width, channel)
+
+    tf_input = tf.compat.v1.placeholder(
+        dtype=dtypes.float32, shape=input_shape, name='input')
+    kernel_w = tf.constant_initializer(kernel)
+    conv2d_layer = tf.keras.layers.Conv2D(k_num, k_width, strides=stride, padding=pad, kernel_initializer=kernel_w)(tf_input)
+
+    conv2d_variables = tf.compat.v1.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
+
+    input_variables = [tf_input] + conv2d_variables
+
+    grad = tf.gradients(conv2d_layer, input_variables)
+
+    with tf.compat.v1.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+        conv2d_result = sess.run(conv2d_layer, feed_dict={tf_input: x})
+        grad_result = sess.run(grad, feed_dict={tf_input: x})
+        if DEBUG:
+            for item, input_variable in zip(grad_result, input_variables):
+                print(input_variable.name)
+                print(item)
+
+    return conv2d_result, grad_result[0], grad_result[1], grad_result[2]
+
+def pooling2d_tf(x, pool_size, stride, padding, pooling):
+    x = np.transpose(x, [0,2,3,1])
+    tf.compat.v1.reset_default_graph()
+    input_shape = x.shape
+    tf_input = tf.compat.v1.placeholder(dtype=dtypes.float32, shape=input_shape, name='input')
+
+    if (pooling == "max"):
+        pooling2d_layer = tf.keras.layers.MaxPooling2D(pool_size=pool_size, strides=stride, padding="valid")(tf_input)
+    elif (pooling == "average"):
+        pooling2d_layer = tf.keras.layers.AveragePooling2D(pool_size=pool_size, strides=stride, padding="valid")(tf_input)
+    elif (pooling == "global_max"):
+        pooling2d_layer = tf.keras.layers.GlobalMaxPooling2D()(tf_input)
+    elif (pooling == "global_average"):
+        pooling2d_layer = tf.keras.layers.GlobalAveragePooling2D()(tf_input)
+
+    pooling2d_variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
+    input_variables = [tf_input] + pooling2d_variables
+    grad = tf.gradients(pooling2d_layer, input_variables)
+
+    with tf.compat.v1.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+        pooling2d_result = sess.run(pooling2d_layer, feed_dict={tf_input: x})
+        grad_result = sess.run(grad, feed_dict={tf_input: x})
+        if DEBUG:
+            for item, input_variable in zip(grad_result, input_variables):
+                print(input_variable.name)
+                print(item)
+    return pooling2d_result, grad_result[0]
 
 ##
 # Tested with tensorflow 1.x (1.14.0 and above)
@@ -250,54 +286,70 @@ def bn_tf(x, *, trainable=True, init_beta=gen_tensor, init_gamma=gen_tensor, axi
 
 def gen_test_case_conv(i_b, i_c, i_h, i_w, k_c, k_h, k_w, padding, stride, bias,
                        base_name):
-    x=gen_input(base_name+"conv2DLayer.in", [i_b, i_c, i_h, i_w])
-    kernel=gen_input(base_name+"conv2DKernel.in", [k_c, i_c, k_h, k_w])
-    with open(base_name+"conv2DKernel.in", 'ab') as outfile:
+    x=gen_input(base_name+"_conv2DLayer.in", [i_b, i_c, i_h, i_w])
+    kernel=gen_input(base_name+"_conv2DKernel.in", [k_c, i_c, k_h, k_w])
+    with open(base_name+"_conv2DKernel.in", 'ab') as outfile:
         np.array(bias, dtype=np.float32).tofile(outfile)
 
-    golden_conv, golden_pool=conv2d_tf(x, kernel, i_b, i_h, i_w, i_c, k_h, k_w, k_c, stride, padding, bias)
-    save(base_name+"goldenConv2DResult.out", np.transpose(golden_conv,(0,3,2,1)))
-    save(base_name+"goldenPooling2DResult.out", np.transpose(golden_pool,(0,3,2,1)))
+    golden_conv, golden_grad_input, golden_grad_kernel, golden_grad_bias=conv2d_tf(x, kernel, i_b, i_h, i_w, i_c, k_h, k_w, k_c, stride, padding, bias)
+    save(base_name+"_goldenConv2DResult.out", np.transpose(golden_conv,(0,3,1,2)))
+    save(base_name+"_goldenInputGrad.out", np.transpose(golden_grad_input,(0,3,1,2)))
+    save(base_name+"_goldenKernelGrad.out", np.transpose(golden_grad_kernel,(3,2,0,1)))
+    save(base_name+"_goldenBiasGrad.out", golden_grad_bias)
+
+def gen_test_case_pooling(input_shape, pooling_size, stride, padding, pooling, base_name, gen_in):
+    if gen_in:
+        input_data = gen_input(base_name + ".in", input_shape)
+    else:
+        with open(base_name+".in", 'rb') as f:
+            input_data = np.fromfile(f, dtype=np.float32)
+            input_data=np.reshape(input_data, input_shape)
+
+    golden_pooling, golden_grad_input = pooling2d_tf(input_data, pooling_size, stride, padding, pooling)
+
+    if (pooling == "global_average" or pooling == "global_max"):
+        save(base_name+"_goldenPooling2D"+pooling+".out", golden_pooling)
+    else:
+        save(base_name+"_goldenPooling2D"+pooling+".out", np.transpose(golden_pooling,(0,3,1,2)))
+    save(base_name+"_goldenPooling2D"+pooling+"Grad.out", np.transpose(golden_grad_input,(0,3,1,2)))
 
 ##
 # @brief generate fc test case data for forward and backward pass
 def gen_test_case_fc(input_shape, kernel_shape, base_name):
-    input_data = gen_input(base_name + "FCLayer.in", input_shape)
-    label = gen_input(base_name + "FCLabel.in", input_shape[:-1] + [kernel_shape[-1]])
+    input_data = gen_input(base_name + "_FCLayer.in", input_shape)
+    label = gen_input(base_name + "_FCLabel.in", input_shape[:-1] + [kernel_shape[-1]])
 
-    kernel = gen_input(base_name + "FCKernel.in", kernel_shape)
-    bias = gen_input(base_name + "FCKernel.in", kernel_shape[-1:], savefile=False)
-    with open(base_name+"FCKernel.in", 'ab') as outfile:
+    kernel = gen_input(base_name + "_FCKernel.in", kernel_shape)
+    bias = gen_input(base_name + "_FCKernel.in", kernel_shape[-1:], savefile=False)
+    with open(base_name+"_FCKernel.in", 'ab') as outfile:
         np.array(bias, dtype=np.float32).tofile(outfile)
 
     golden_fc = fc_tf(input_data, kernel, None, bias, activation=None)
-    save(base_name + "goldenFCResultActNone.out", golden_fc[0])
+    save(base_name + "_goldenFCResultActNone.out", golden_fc[0])
 
     golden_fc = fc_tf(input_data, kernel, None, bias, activation=tf.nn.sigmoid)
-    save(base_name + "goldenFCResultSigmoid.out", golden_fc[0])
+    save(base_name + "_goldenFCResultSigmoid.out", golden_fc[0])
 
     golden_fc = fc_tf(input_data, kernel, None, bias, activation=tf.nn.softmax)
-    save(base_name + "goldenFCResultSoftmax.out", golden_fc[0])
+    save(base_name + "_goldenFCResultSoftmax.out", golden_fc[0])
 
 def gen_test_case_bn(input_shape, base_name, training=True):
-    input_data = gen_input(base_name + "BNLayerInput.in", input_shape)
+    input_data = gen_input(base_name + "_BNLayerInput.in", input_shape)
 
     input_variables, output_variables, grad = bn_tf(input_data)
 
     # mu / var / gamma / beta
-    save(base_name + "BNLayerWeights.in", input_variables[3], input_variables[4], input_variables[1], input_variables[2])
-    save(base_name + "goldenBNResultForward.out", output_variables[0])
+    save(base_name + "_BNLayerWeights.in", input_variables[3], input_variables[4], input_variables[1], input_variables[2])
+    save(base_name + "_goldenBNResultForward.out", output_variables[0])
 
     # todo: change 0 to initial moving avg / std in case of training
-    save(base_name + "goldenBNLayerAfterUpdate.out", 0, 0, output_variables[1], output_variables[2])
-    save(base_name + "goldenBNLayerBackwardDx.out", grad[0])
+    save(base_name + "_goldenBNLayerAfterUpdate.out", 0, 0, output_variables[1], output_variables[2])
+    save(base_name + "_goldenBNLayerBackwardDx.out", grad[0])
 
 if __name__ == "__main__":
-    target = int(sys.argv[1])
+    target = sys.argv[1]
 
     # Input File Generation with given info
-    if target == 1:
+    if target == "gen_tensor":
         if len(sys.argv) != 7 :
             print('wrong argument : 1 filename, batch, channel, height, width')
             exit()
@@ -310,9 +362,9 @@ if __name__ == "__main__":
     #  : output (1,2,5,5)
     #  : stride 1, 1
     #  : padding 0, 0 (VALID)
-    if target == 2:
+    if target == "conv2d_1":
         bias1 = [0.0, 0.0]
-        gen_test_case_conv(1, 3, 7, 7, 2, 3, 3, "VALID", 1, bias1, "test_1_")
+        gen_test_case_conv(1, 3, 7, 7, 2, 3, 3, "VALID", 1, bias1, "tc_conv2d_1")
 
@@ -320,16 +372,28 @@ if __name__ == "__main__":
     # second unit test case : 2, 3, 7, 7, 3, 3, 3, VALID, 1 test_2_
     #  : Input Dimension (2, 3, 7, 7)
     #  : Kernel (3, 3, 3, 3)
    #  : output (2,3,5,5)
     #  : stride 1, 1
     #  : padding 0, 0 (VALID)
-    if target == 3:
+    if target == "conv2d_2":
         bias2 = [0.0, 0.0, 0.0]
-        gen_test_case_conv(2, 3, 7, 7, 3, 3, 3, "VALID", 1, bias2, "test_2_")
+        gen_test_case_conv(2, 3, 7, 7, 3, 3, 3, "VALID", 1, bias2, "tc_conv2d_2")
 
     # FC layer unit test case:
-    if target == 4:
+    if target == "fc_1":
         gen_test_case_fc(input_shape = [3, 1, 1, 12],
                          kernel_shape = [12, 15],
-                         base_name = "test_1_")
-
+                         base_name = "tc_fc_1")
+
     # Bn layer unit test case:
-    if target == 5:
-        gen_test_case_bn(input_shape = [3, 1, 4, 5], base_name = "test_5_")
+    if target == "bn_1":
+        gen_test_case_bn(input_shape = [3, 1, 4, 5], base_name = "tc_bn_1")
+
+    if target == "pooling2d_1":
+        gen_test_case_pooling(input_shape = [1,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="max", base_name="tc_pooling2d_1", gen_in=True)
+        gen_test_case_pooling(input_shape = [1,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="average", base_name="tc_pooling2d_1", gen_in=False)
+        gen_test_case_pooling(input_shape = [1,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="global_max", base_name="tc_pooling2d_1", gen_in=False)
+        gen_test_case_pooling(input_shape = [1,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="global_average", base_name="tc_pooling2d_1", gen_in=False)
+
+    if target == "pooling2d_2":
+        gen_test_case_pooling(input_shape = [2,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="max", base_name="tc_pooling2d_2", gen_in=True)
+        gen_test_case_pooling(input_shape = [2,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="average", base_name="tc_pooling2d_2", gen_in=False)
+        gen_test_case_pooling(input_shape = [2,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="global_max", base_name="tc_pooling2d_2", gen_in=False)
+        gen_test_case_pooling(input_shape = [2,2,5,5], pooling_size=[2,2], stride=[1,1], padding=[0,0], pooling="global_average", base_name="tc_pooling2d_2", gen_in=False)
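With the integer targets replaced by names, each golden set is regenerated by one invocation per test case; something like the following is the expected usage (the working directory test/input_gen and an installed TensorFlow 1.x are assumptions, not stated in the patch):

    python genInput.py conv2d_1
    python genInput.py conv2d_2
    python genInput.py fc_1
    python genInput.py bn_1
    python genInput.py pooling2d_1   # first call writes tc_pooling2d_1.in (gen_in=True), the rest reuse it
    python genInput.py pooling2d_2

Reusing one .in file across the four pooling modes keeps all pooling goldens consistent with a single random input.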
diff --git a/test/unittest/unittest_nntrainer_layers.cpp b/test/unittest/unittest_nntrainer_layers.cpp
index 18df891..74771a8 100644
--- a/test/unittest/unittest_nntrainer_layers.cpp
+++ b/test/unittest/unittest_nntrainer_layers.cpp
@@ -330,8 +330,8 @@ protected:
   virtual int reinitialize(bool _last_layer = false) {
     int status = super::reinitialize(_last_layer);
 
-    loadFile("test_1_FCLayer.in", in);
-    loadFile("test_1_FCKernel.in", layer);
+    loadFile("tc_fc_1_FCLayer.in", in);
+    loadFile("tc_fc_1_FCKernel.in", layer);
     return status;
   }
 
@@ -348,7 +348,7 @@ protected:
 TEST_F(nntrainer_FullyConnectedLayer_TFmatch, forwarding_01_p) {
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  matchOutput(out, "test_1_goldenFCResultActNone.out");
+  matchOutput(out, "tc_fc_1_goldenFCResultActNone.out");
 }
 
 /**
@@ -364,7 +364,7 @@ TEST_F(nntrainer_FullyConnectedLayer_TFmatch, forwarding_02_p) {
 
   out = actLayer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  matchOutput(out, "test_1_goldenFCResultSigmoid.out");
+  matchOutput(out, "tc_fc_1_goldenFCResultSigmoid.out");
 }
 
 /**
@@ -380,7 +380,7 @@ TEST_F(nntrainer_FullyConnectedLayer_TFmatch, forwarding_03_p) {
 
   out = actLayer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  matchOutput(out, "test_1_goldenFCResultSoftmax.out");
+  matchOutput(out, "tc_fc_1_goldenFCResultSoftmax.out");
 }
 
 class nntrainer_BatchNormalizationLayer
@@ -388,10 +388,8 @@ class nntrainer_BatchNormalizationLayer
 protected:
   typedef nntrainer_abstractLayer super;
 
-  virtual int reinitialize(bool _last_layer = false) {
-    int status = super::reinitialize(_last_layer);
-    // loadFile("test_5_BNLayerInput.in", in);
-    // loadFile("test_5_BNLayerWeights.in", layer);
+  virtual int reinitialize(bool last_layer = false) {
+    int status = super::reinitialize(last_layer);
     return status;
   }
 
@@ -477,8 +475,8 @@ protected:
     in = nntrainer::Tensor(3, 1, 4, 5);
     expected = nntrainer::Tensor(3, 1, 4, 5);
 
-    loadFile("test_5_BNLayerInput.in", in);
-    loadFile("test_5_BNLayerWeights.in", layer);
+    loadFile("tc_bn_1_BNLayerInput.in", in);
+    loadFile("tc_bn_1_BNLayerWeights.in", layer);
   }
 
   void matchOutput(const nntrainer::Tensor &result, const char *path) {
@@ -516,12 +514,12 @@ TEST_F(nntrainer_batchNormalizationLayer_TFmatch,
 
   nntrainer::Tensor forward_result = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  matchOutput(forward_result, "test_5_goldenBNResultForward.out");
+  matchOutput(forward_result, "tc_bn_1_goldenBNResultForward.out");
 
   nntrainer::Tensor backward_result =
     layer.backwarding(constant(1.0, 3, 1, 4, 5), 1);
 
-  matchOutput(backward_result, "test_5_goldenBNLayerBackwardDx.out");
+  matchOutput(backward_result, "tc_bn_1_goldenBNLayerBackwardDx.out");
 }
 
 class nntrainer_Conv2DLayer
@@ -587,8 +585,6 @@ TEST_F(nntrainer_Conv2DLayer, save_read_01_p) {
 TEST_F(nntrainer_Conv2DLayer, forwarding_01_p) {
   reinitialize("input_shape=1:3:7:7 |"
                "bias_init_zero = true |"
-               "weight_decay=l2norm |"
-               "weight_decay_lambda=0.005 |"
                "weight_ini=xavier_uniform |"
                "filter=2 | kernel_size=3,3 | stride=1, 1 | padding=0,0",
                true);
@@ -596,22 +592,21 @@ TEST_F(nntrainer_Conv2DLayer, forwarding_01_p) {
   ASSERT_EQ(in.getDim(), nntrainer::TensorDim(1, 3, 7, 7));
   ASSERT_EQ(out.getDim(), nntrainer::TensorDim(1, 2, 5, 5));
 
-  loadFile("test_1_conv2DLayer.in", in);
-  loadFile("test_1_conv2DKernel.in", layer);
+  loadFile("tc_conv2d_1_conv2DLayer.in", in);
+  loadFile("tc_conv2d_1_conv2DKernel.in", layer);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  matchOutput(out, "test_1_goldenConv2DResult.out");
+  matchOutput(out, "tc_conv2d_1_goldenConv2DResult.out");
 }
 
 /**
  * @brief Convolution 2D Layer
  */
+
 TEST_F(nntrainer_Conv2DLayer, forwarding_02_p) {
   reinitialize("input_shape=2:3:7:7 |"
                "bias_init_zero = true |"
-               "weight_decay=l2norm |"
-               "weight_decay_lambda=0.005 |"
                "weight_ini=xavier_uniform |"
                "filter=3 | kernel_size=3,3 | stride=1, 1 | padding=0,0",
                true);
@@ -619,12 +614,12 @@ TEST_F(nntrainer_Conv2DLayer, forwarding_02_p) {
   ASSERT_EQ(in.getDim(), nntrainer::TensorDim(2, 3, 7, 7));
   ASSERT_EQ(out.getDim(), nntrainer::TensorDim(2, 3, 5, 5));
 
-  loadFile("test_2_conv2DLayer.in", in);
-  loadFile("test_2_conv2DKernel.in", layer);
+  loadFile("tc_conv2d_2_conv2DLayer.in", in);
+  loadFile("tc_conv2d_2_conv2DKernel.in", layer);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-  matchOutput(out, "test_2_goldenConv2DResult.out");
+  matchOutput(out, "tc_conv2d_2_goldenConv2DResult.out");
 }
 
 /**
@@ -643,8 +638,6 @@ TEST(nntrainer_Conv2D, backwarding_01_p) {
 
   input_str.push_back("input_shape=1:3:7:7");
   input_str.push_back("bias_init_zero=true");
-  input_str.push_back("weight_decay=l2norm");
-  input_str.push_back("weight_decay_lambda = 0.005");
   input_str.push_back("weight_ini=xavier_uniform");
   input_str.push_back("filter=2");
   input_str.push_back("kernel_size= 3,3");
@@ -670,23 +663,104 @@ TEST(nntrainer_Conv2D, backwarding_01_p) {
   nntrainer::Tensor out;
   nntrainer::Tensor derivatives(1, 2, 5, 5);
 
-  float sample_derivative[50] = {
-    0.25, 0.5, 0.5, 0.5, 0.25, 0.5,  1.0, 1.0, 1.0, 0.5,  0.5,  1.0, 1.0,
-    1.0,  0.5, 0.5, 1.0, 1.0,  1.0,  0.5, 0.25, 0.5, 0.5, 0.5,  0.25, 0.25,
-    0.5,  0.5, 0.5, 0.25, 0.5, 1.0,  1.0, 1.0, 0.5,  0.5, 1.0,  1.0, 1.0,
-    0.5,  0.5, 1.0, 1.0,  1.0, 0.5,  0.25, 0.5, 0.5, 0.5, 0.25};
+  std::ifstream file("tc_conv2d_1_conv2DLayer.in");
+  in.read(file);
+
+  std::ifstream kfile("tc_conv2d_1_conv2DKernel.in");
+  layer.read(kfile);
+  kfile.close();
+
+  out = layer.forwarding(in, status);
+
+  for (unsigned int i = 0; i < derivatives.getDim().getDataLen(); ++i) {
+    derivatives.getData()[i] = 1.0;
+  }
+
+  nntrainer::Tensor result = layer.backwarding(derivatives, 1);
+
+  nntrainer::Tensor grad_w;
+  std::ifstream wfile("tc_conv2d_1_goldenKernelGrad.out");
+  grad_w.read(wfile);
+  wfile.close();
+
+  nntrainer::Tensor grad_s;
+  std::ifstream sfile("tc_conv2d_1_goldenInputGrad.out");
+  grad_s.read(sfile);
+  sfile.close();
+
+  nntrainer::Tensor grad_b;
+  std::ifstream bfile("tc_conv2d_1_goldenBiasGrad.out");
+  grad_b.read(bfile);
+  bfile.close();
+}
+
+/**
+ * @brief Convolution 2D Layer
+ */
+TEST(nntrainer_Conv2D, backwarding_02_p) {
+  int status = ML_ERROR_NONE;
+  nntrainer::Conv2DLayer layer;
+  std::vector<std::string> input_str;
+  nntrainer::Optimizer op;
+  nntrainer::OptType t = nntrainer::OptType::sgd;
+  nntrainer::OptParam p;
+  nntrainer::TensorDim previous_dim;
+  previous_dim.setTensorDim("2:3:7:7");
+
+  input_str.push_back("input_shape=2:3:7:7");
+  input_str.push_back("bias_init_zero=true");
+  input_str.push_back("weight_ini=xavier_uniform");
+  input_str.push_back("filter=2");
+  input_str.push_back("kernel_size= 3,3");
+  input_str.push_back("stride=1, 1");
+  input_str.push_back("padding=0,0");
+
+  status = layer.setProperty(input_str);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+  layer.setInputDimension(previous_dim);
 
-  std::ifstream file("test_1_conv2DLayer.in");
+  status = op.setType(t);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+  p.learning_rate = 0.001;
+  status = op.setOptParam(p);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+
+  status = layer.initialize(true);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+  status = layer.setOptimizer(op);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+
+  nntrainer::Tensor in(2, 3, 7, 7);
+  nntrainer::Tensor out;
+  nntrainer::Tensor derivatives(2, 2, 5, 5);
+
+  std::ifstream file("tc_conv2d_2_conv2DLayer.in");
   in.read(file);
+
+  std::ifstream kfile("tc_conv2d_2_conv2DKernel.in");
+  layer.read(kfile);
+  kfile.close();
+
   out = layer.forwarding(in, status);
 
   for (unsigned int i = 0; i < derivatives.getDim().getDataLen(); ++i) {
-    derivatives.getData()[i] = sample_derivative[i];
+    derivatives.getData()[i] = 1.0;
   }
 
   nntrainer::Tensor result = layer.backwarding(derivatives, 1);
-  // todo: add golden test for this.
-  // matchOutput(out, "test_1_conv2dLayer.in")
+
+  nntrainer::Tensor grad_w;
+  std::ifstream wfile("tc_conv2d_2_goldenKernelGrad.out");
+  grad_w.read(wfile);
+  wfile.close();
+
+  nntrainer::Tensor grad_s;
+  std::ifstream sfile("tc_conv2d_2_goldenInputGrad.out");
+  grad_s.read(sfile);
+  sfile.close();
+
+  nntrainer::Tensor grad_b;
+  std::ifstream bfile("tc_conv2d_2_goldenBiasGrad.out");
+  grad_b.read(bfile);
+  bfile.close();
+
+  // @TODO Compare with golden data after getGradient function is implemented.
 }
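Until getGradient() is available, the two backwarding tests above only load the golden gradients without asserting against them. The files themselves can already be sanity-checked from Python; a hedged sketch for the conv2d_1 case (file names from genInput.py above, shapes following its transposes):

    import numpy as np

    grad_in = np.fromfile("tc_conv2d_1_goldenInputGrad.out", dtype=np.float32)
    grad_k = np.fromfile("tc_conv2d_1_goldenKernelGrad.out", dtype=np.float32)
    grad_b = np.fromfile("tc_conv2d_1_goldenBiasGrad.out", dtype=np.float32)

    print(grad_in.reshape(1, 3, 7, 7))  # dL/dinput in NCHW, the input's own shape
    print(grad_k.reshape(2, 3, 3, 3))   # dL/dkernel as (k_num, channel, k_h, k_w)
    print(grad_b)                       # dL/dbias, one value per filter: shape (2,)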
 
 class nntrainer_Pooling2DLayer
@@ -713,13 +787,31 @@ TEST_F(nntrainer_Pooling2DLayer, initialize_01_p) { reinitialize(); }
 TEST_F(nntrainer_Pooling2DLayer, forwarding_01_p) {
   setInputDim("1:2:5:5");
   setProperty("pooling_size=2,2 | stride=1,1 | padding=0,0 | pooling=max");
+
   reinitialize();
 
-  loadFile("test_1_goldenConv2DResult.out", in);
+  loadFile("tc_pooling2d_1.in", in);
+
   out = layer.forwarding(in, status);
+
   EXPECT_EQ(status, ML_ERROR_NONE);
 
-  matchOutput(out, "test_1_goldenPooling2DResult.out");
+  matchOutput(out, "tc_pooling2d_1_goldenPooling2Dmax.out");
+}
+
+TEST_F(nntrainer_Pooling2DLayer, forwarding_02_p) {
+  setInputDim("1:2:5:5");
+  setProperty("pooling_size=2,2 | stride=1,1 | padding=0,0 | pooling=average");
+
+  reinitialize();
+
+  loadFile("tc_pooling2d_1.in", in);
+
+  out = layer.forwarding(in, status);
+
+  EXPECT_EQ(status, ML_ERROR_NONE);
+
+  matchOutput(out, "tc_pooling2d_1_goldenPooling2Daverage.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, forwarding_03_p) {
@@ -728,14 +820,12 @@ TEST_F(nntrainer_Pooling2DLayer, forwarding_03_p) {
   setProperty("pooling=global_max");
   reinitialize();
 
-  loadFile("test_1_goldenConv2DResult.out", in);
-  out = layer.forwarding(in, status);
+  loadFile("tc_pooling2d_1.in", in);
 
+  out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
 
-  float golden[2] = {7.8846731, 8.81525};
-
-  matchData(golden);
+  matchOutput(out, "tc_pooling2d_1_goldenPooling2Dglobal_max.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, forwarding_04_p) {
@@ -744,56 +834,46 @@ TEST_F(nntrainer_Pooling2DLayer, forwarding_04_p) {
   setProperty("pooling=global_average");
   reinitialize();
 
-  loadFile("test_1_goldenConv2DResult.out", in);
+  loadFile("tc_pooling2d_1.in", in);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
 
-  float golden[2] = {6.6994767, 7.1483521};
-
-  matchData(golden);
+  matchOutput(out, "tc_pooling2d_1_goldenPooling2Dglobal_average.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, forwarding_05_p) {
   resetLayer();
-  setInputDim("2:3:5:5");
+  setInputDim("2:2:5:5");
   setProperty("pooling=global_max");
   reinitialize();
 
-  loadFile("test_2_goldenConv2DResult.out", in);
-
+  loadFile("tc_pooling2d_2.in", in);
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-
-  float golden[6] = {9.60282, 8.78552, 9.1152, 9.29397, 8.580175, 8.74109};
-
-  matchData(golden);
+  matchOutput(out, "tc_pooling2d_2_goldenPooling2Dglobal_max.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, forwarding_06_p) {
   resetLayer();
-  setInputDim("2:3:5:5");
+  setInputDim("2:2:5:5");
   setProperty("pooling=global_average");
   reinitialize();
 
-  loadFile("test_2_goldenConv2DResult.out", in);
+  loadFile("tc_pooling2d_2.in", in);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
-
-  float golden[6] = {8.3259277, 7.2941909, 7.7225585,
-                     8.2644157, 7.0253778, 7.4998989};
-
-  matchData(golden);
+  matchOutput(out, "tc_pooling2d_2_goldenPooling2Dglobal_average.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, backwarding_01_p) {
   resetLayer();
   setInputDim("1:2:5:5");
   setProperty("pooling_size=2,2 | stride=1,1 | padding=0,0 | pooling=max");
-  reinitialize();
 
-  loadFile("test_1_goldenConv2DResult.out", in);
+  reinitialize();
+  loadFile("tc_pooling2d_1.in", in);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
@@ -804,13 +884,9 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_01_p) {
     grad.getData()[i] = 1.0;
   }
 
-  out = layer.backwarding(grad, 0);
-
-  float golden[50] = {0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1, 0,
-                      4, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 0, 1, 0,
-                      0, 0, 4, 0, 0, 1, 0, 2, 0, 4, 0, 0, 0, 0, 0, 0};
+  in = layer.backwarding(grad, 0);
 
-  matchData(golden);
+  matchOutput(in, "tc_pooling2d_1_goldenPooling2DmaxGrad.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, backwarding_02_p) {
   resetLayer();
   setInputDim("1:2:5:5");
   setProperty("pooling_size=2,2 | stride=1,1 | padding=0,0 | pooling=average");
   reinitialize();
-
-  loadFile("test_1_goldenConv2DResult.out", in);
+  loadFile("tc_pooling2d_1.in", in);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
@@ -830,15 +905,9 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_02_p) {
     grad.getData()[i] = 1.0;
   }
 
-  out = layer.backwarding(grad, 0);
-
-  float golden[50] = {0.25, 0.5, 0.5, 0.5, 0.25, 0.5,  1.0, 1.0, 1.0, 0.5,
-                      0.5,  1.0, 1.0, 1.0, 0.5,  0.5,  1.0, 1.0, 1.0, 0.5,
-                      0.25, 0.5, 0.5, 0.5, 0.25, 0.25, 0.5, 0.5, 0.5, 0.25,
-                      0.5,  1.0, 1.0, 1.0, 0.5,  0.5,  1.0, 1.0, 1.0, 0.5,
-                      0.5,  1.0, 1.0, 1.0, 0.5,  0.25, 0.5, 0.5, 0.5, 0.25};
+  in = layer.backwarding(grad, 0);
 
-  matchData(golden);
+  matchOutput(in, "tc_pooling2d_1_goldenPooling2DaverageGrad.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, backwarding_03_p) {
@@ -848,7 +917,7 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_03_p) {
   setProperty(
     "pooling_size=2,2 | stride=1,1 | padding=0,0 | pooling=global_max");
   reinitialize();
 
-  loadFile("test_1_goldenConv2DResult.out", in);
+  loadFile("tc_pooling2d_1.in", in);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
@@ -859,13 +928,9 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_03_p) {
     grad.getData()[i] = 1.0;
   }
 
-  out = layer.backwarding(grad, 0);
-
-  float golden[50] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
-                      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-                      0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+  in = layer.backwarding(grad, 0);
 
-  matchData(golden);
+  matchOutput(in, "tc_pooling2d_1_goldenPooling2Dglobal_maxGrad.out");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, backwarding_04_p) {
@@ -873,8 +938,7 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_04_p) {
   setProperty(
     "pooling_size=2,2 | stride=1,1 | padding=0,0 | pooling=global_average");
   reinitialize();
-
-  loadFile("test_1_goldenConv2DResult.out", in);
+  loadFile("tc_pooling2d_1.in", in);
 
   out = layer.forwarding(in, status);
   EXPECT_EQ(status, ML_ERROR_NONE);
@@ -885,16 +949,9 @@ TEST_F(nntrainer_Pooling2DLayer, backwarding_04_p) {
     grad.getData()[i] = 1.0;
  }
 
-  out = layer.backwarding(grad, 0);
-
-  float golden[50] = {0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04,
-                      0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04,
-                      0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04,
-                      0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04,
-                      0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04,
-                      0.04, 0.04, 0.04, 0.04, 0.04};
+  in = layer.backwarding(grad, 0);
 
-  matchData(golden);
+  matchOutput(in, "tc_pooling2d_1_goldenPooling2Dglobal_averageGrad.out");
 }
 
 class nntrainer_FlattenLayer
@@ -909,27 +966,27 @@ TEST_F(nntrainer_FlattenLayer, forwarding_01_p) {
 
   EXPECT_EQ(out.getDim(), nntrainer::TensorDim(1, 1, 1, 32));
 
-  loadFile("test_1_goldenPooling2DResult.out", in);
+  loadFile("tc_pooling2d_1_goldenPooling2Dmax.out", in);
 
   out = layer.forwarding(in, status);
 
-  matchOutput(out, "test_1_goldenPooling2DResult.out");
+  matchOutput(out, "tc_pooling2d_1_goldenPooling2Dmax.out");
 }
 
 /**
  * @brief Flatten Layer
  */
 TEST_F(nntrainer_FlattenLayer, forwarding_02_p) {
-  setInputDim("2:3:4:4");
+  setInputDim("2:2:4:4");
   reinitialize(false);
 
-  EXPECT_EQ(out.getDim(), nntrainer::TensorDim(2, 1, 1, 48));
+  EXPECT_EQ(out.getDim(), nntrainer::TensorDim(2, 1, 1, 32));
 
-  loadFile("test_2_goldenPooling2DResult.out", in);
+  loadFile("tc_pooling2d_2_goldenPooling2Dmax.out", in);
 
   out = layer.forwarding(in, status);
 
-  matchOutput(out, "test_2_goldenPooling2DResult.out");
+  matchOutput(out, "tc_pooling2d_2_goldenPooling2Dmax.out");
 }
 
 /**
@@ -941,29 +998,29 @@ TEST_F(nntrainer_FlattenLayer, backwarding_01_p) {
 
   EXPECT_EQ(out.getDim(), nntrainer::TensorDim(1, 1, 1, 32));
 
-  loadFile("test_1_goldenPooling2DResult.out", out);
+  loadFile("tc_pooling2d_1_goldenPooling2Dmax.out", out);
 
   in = layer.backwarding(out, 0);
 
   EXPECT_EQ(in.getDim(), nntrainer::TensorDim(1, 2, 4, 4));
 
-  matchOutput(in, "test_1_goldenPooling2DResult.out");
+  matchOutput(in, "tc_pooling2d_1_goldenPooling2Dmax.out");
 }
 
 /**
  * @brief Flatten Layer
 */
 TEST_F(nntrainer_FlattenLayer, backwarding_02_p) {
-  setInputDim("2:3:4:4");
+  setInputDim("2:2:4:4");
   reinitialize(false);
 
-  EXPECT_EQ(out.getDim(), nntrainer::TensorDim(2, 1, 1, 48));
+  EXPECT_EQ(out.getDim(), nntrainer::TensorDim(2, 1, 1, 32));
 
-  loadFile("test_2_goldenPooling2DResult.out", out);
+  loadFile("tc_pooling2d_2_goldenPooling2Dmax.out", out);
 
   in = layer.backwarding(out, 0);
 
-  EXPECT_EQ(in.getDim(), nntrainer::TensorDim(2, 3, 4, 4));
+  EXPECT_EQ(in.getDim(), nntrainer::TensorDim(2, 2, 4, 4));
 
-  matchOutput(in, "test_2_goldenPooling2DResult.out");
+  matchOutput(in, "tc_pooling2d_2_goldenPooling2Dmax.out");
 }
 
 /**
@@ -1029,7 +1086,6 @@ TEST(nntrainer_ActivationLayer, forward_backward_01_p) {
   result = layer.backwarding(constant(1.0, 3, 1, 1, 10), 1);
   GEN_TEST_INPUT(
     expected, nntrainer::reluPrime(nntrainer::relu((l - 4) * 0.1 * (i + 1))));
-  ;
   EXPECT_TRUE(result == expected);
 }
-- 
2.7.4
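A note on the on-disk format shared by save() in genInput.py and Tensor::read()/loadFile() in the C++ tests: every .in and .out file is a bare float32 dump in native byte order with no header, so writer and reader must agree on the shape out of band. A minimal round-trip sketch (the file name tc_example.in is hypothetical):

    import numpy as np

    data = np.random.rand(1, 2, 4, 4).astype(np.float32)
    with open("tc_example.in", "wb") as f:           # hypothetical file, for illustration
        np.array(data, dtype=np.float32).tofile(f)   # same call save() uses

    back = np.fromfile("tc_example.in", dtype=np.float32).reshape(1, 2, 4, 4)
    assert np.array_equal(back, data)                # lossless round trip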