From f5546858a688742ec6763181ba7b8ffc6e933758 Mon Sep 17 00:00:00 2001 From: hyeonseok lee Date: Wed, 26 Oct 2022 15:32:59 +0900 Subject: [PATCH] [split layer] make a unittest to test split input dimension by split number - Make a unittest to test split input dimension by split number - Conv2d layer is added to make model has a weight Signed-off-by: hyeonseok lee --- nntrainer/layers/common_properties.h | 4 +- nntrainer/layers/split_layer.cpp | 12 ++-- packaging/unittest_models_multiout.tar.gz | Bin 2062 -> 6862 bytes test/input_gen/genModelsMultiout_v2.py | 42 ++++++++++++++ test/unittest/models/unittest_models_multiout.cpp | 65 ++++++++++++++++++++++ 5 files changed, 116 insertions(+), 7 deletions(-) diff --git a/nntrainer/layers/common_properties.h b/nntrainer/layers/common_properties.h index a5a9c46..f99fd96 100644 --- a/nntrainer/layers/common_properties.h +++ b/nntrainer/layers/common_properties.h @@ -250,8 +250,8 @@ public: }; /** - * @brief split number property, split number is used to split input dimension - * in split layer + * @brief split number property, split number indicates how many numbers of outs + * are generated by spliting the input dimension * */ class SplitNumber : public PositiveIntegerProperty { diff --git a/nntrainer/layers/split_layer.cpp b/nntrainer/layers/split_layer.cpp index 4e55d81..2b15f07 100644 --- a/nntrainer/layers/split_layer.cpp +++ b/nntrainer/layers/split_layer.cpp @@ -34,20 +34,22 @@ void SplitLayer::finalize(InitLayerContext &context) { unsigned int split_dimension = std::get(split_props); + const TensorDim &in_dim = context.getInputDimensions()[0]; + if (std::get(split_props).empty()) { std::get(split_props) - .set(context.getNumRequestedOutputs()); + .set(in_dim.getTensorDim(split_dimension)); } unsigned int split_number = std::get(split_props); /** * The split is only done along the split_dimension dimension. + * (Assumes input data is continous) * For example, consider input dimension [b,c,h,w], split_number = n - * 1. axis = 1, output_dim = [b,n,h,w], num_outputs = c//n - * 2. axis = 2, output_dim = [b,c,n,w], num_outputs = h//n - * 3. axis = 3, output_dim = [b,c,h,n], num_outputs = w//n + * 1. axis = 1, output_dim = [b,c//n,h,w], num_outputs = n + * 2. axis = 2, output_dim = [b,c,h//n,w], num_outputs = n + * 3. axis = 3, output_dim = [b,c,h,w//n], num_outputs = n */ - const TensorDim &in_dim = context.getInputDimensions()[0]; NNTR_THROW_IF(split_number != context.getNumRequestedOutputs(), std::invalid_argument) << "Given split number does not match with number of outputs"; diff --git a/packaging/unittest_models_multiout.tar.gz b/packaging/unittest_models_multiout.tar.gz index 689f8dba893efacc5f8099322785a32b2c16b25f..ce44df379cd8b08caaf67e961d93aa0a8b2c9a52 100644 GIT binary patch literal 6862 zcmV;<8ZqS`iwFP!000001MFCPIF(u7W<()M4k>NqlnUdtCHC{Xg(4*pN(Pa0PN66r z94a%FqSJH|9i;R5FtPWuBaLIJC^dsAr*u#>ok;bv=X<}t?|r|y=8tct`M&GwzOHqx zXYXgN_59ZTtovTS^+ZPn`Nc>3MfyiAF^YLOkeT1U;xE21x%x?2WI{_1yO63GA|YLC|vB!VnKgg z+CpU26{k>T90QxD@z|>{0L@1AwB;h7$GI<#b;V!f=5je|*EWE>-)sY|b2!UxsLp2j? z9+^K*)PF|mc8W>(fyrqV$05@p&<`mP{7tXwFF{;PL|D9^e^j7fcyw5lU!Z?fNW`CQ z1%IG_6Yhk6)qmezlP~)J9O&N@`*k+0wUgKYXUf~;L0wf`()K$~`JrKmf89W)n=R;k zh7Beb%wRY(4l~E3)tQf$&uG<_k?7U4qW=D)e;I&XRa~s9jDW$Sf!Ldqh+@^jIPvP1 zb>)y3l$kGy{CYKNP8x#&nv@{`2r1L@=AH8Kh9?zDzu?fNYkEbK>z#UWZ^HQ$RZ6q)KjC(P(E$0 zSBG?jJe>_JrEMBoSiV%5b{H1XZ6{6i_aF3dF_R;}wA#Igu4R6WNW)AKPh6fM*57zfj^8?=O?e)jmI5fHor(wE zBp=fVqEPdtG85W;1`6gc>_f}oj%#gx*Q!Pl+Wqp^~d?}-4OcPwu01+ zM&rzp(R9?Tkk%KDLUp+k`S)CK;MitBYtzt~#y+45wX}uF!=k8X}&$(_^tO z*c*1!Zu2O%i;-}1hk2D8)t4#Y-JX{uA(c%v^&w1jZ8XXvLJ-kniK7e0fUPbQ@pRdC z3APWgz@3<0nysBrzV;o=+q;QqNLf!eUVswiz9IKgTiCZO#GW8A%v>=RZyHb2Px%s} z^;DnsiKiY(Syn3e6HkT@&QhwuC!Ut8|HhNR^wWC!5c=DC^4+*Z_(&`mqr0R;JY8t- zV7F_i39BM6iFi8J&Y_2HMf8lrLAjO!Jr1v;DRVTjZI&7dmR%v+A`Vo7RVe3DDJgwW zhd@`Ic$!7@T%Lo^sq^2%fA{~tCI|c@|2O%!^IzYU#TWno9QbGcm*Y7cKFf0`hqC}b zEp}msj}4*hbC<|Sv4+mwT8MBrCCC%V@~kTIWulSlEltlCCt*t#N2FIq%?*R%3{lZ? zomM9M(n0Ad*!h(gf_4ss{KcEJzGo}X%1V_UwL4>M#~tc8ew>oEx6`w06GUgn^SYI! zsM^D3R2~cLFnHnvQZD!L^HJLAJe`nMnyo`27qU#K{KhCpPl7xFIk7KBgF@uEHYzbQpa86R5 zsg1Cy;VfQS8sr!!kj|wz(L7C2gbp0sb3^Q70BVEExWs^EJ=(ozu}`7wx`*e8|wvZ}x+>u)S1EZmo#w<=1hy1H3Ecs5Ae99FdUKr1q z%r+#}H-k~>3c<5st+ed^JtkqsIK1B?hUB!HROVxaL`!2Da<_!$o%@zFtFx(b1DDzr z-jFVLG>o6cF*kQlrO=lfsL+gy@&oy#bNL?SEii{>s}`;AzE4-xMq~Gl2=Myev1aK7 zDm>$g;QQJ*AvXcG(hQDiP z*O~}O_s#ijn>6I>Gbr6VP*9fD%1;?^h+Jc@u~*mGlU$n}TQTscX#bns>c-|ZKH^>2 z??C-r)mH8$+h^n0!CM^Rel-m2?gNxL^#%?2phS-^2pRp^5`>bYB$rkou#7L}?^2sd z0~B-VZMrl&ZT&))m!?3$9!kQW=X%pfyL7r4U;{y?4EBcT!tmp86zrXb(2h{}z88c2 zkqoSG`%6bb999Y1rza0cs)owbtZwa40 zC&hL;+~+qwmSLV;Z4n5pr_=d0r3N!K+Ah`ub;gthqT8amjqfY|4C?A6>&+GtLKhl!K9b+6%E- zW3bEZ5WUNmg+pE(IU4%mbhkP6UJ;O1ye+g#c9U}<2aOr?=z@Xw+%@1^AWjuI}lC=KJI`YzT zfipJ&w;cq`G;d`{Z>gqHR>A1ncbut6nn*j>xG>yFJ86T^mUh1%%=6j$73JtYWES>X zp-_DhIbRD#kLe+@FY-gBc`RKkDWVNEGEkHqfL|JFXw8UHs>2C7+jWgxWFAw<@DkdT zsEdq59fOJUYN z2XvoGrio%5Wc6SUUT!i(>>NHVm?lRu%yn887KI>p1Ca+f<=r9eEefc)JC1qqUhdt#k@Y=|JR61SFc`R9`;7b8?~e zRI7O0=^6>W)H~K|tv=Ftn-ShBlUR%8uyRK^=c{X-(!N-C3Q3> z&4cbe38g2C(`kRD2QBilBP|adtjZ~2UN$oDo+pLm?c($}IfCR~|46pFE0`g6F#>JR zY}%MtL=NJ!nWZ<e_tfaXC%^lk-4u8{vf z@jc=TJ}-VR|2O}9`M;Sl_kYO$zkL7sEcj>sf5kNf8?-l2epnh6X_&#XZ|}3KNkYRg z9TdDiL<%ly7`ns_ZpE6&OX*@Z`fj3xUCSV*+&}}x72(+}h0Gu)T5@tMJ+cvl#ERvZ zcBz~;ms`L_)136nGnp45qbVU;mmXMVQC5&Nx?fC&sjvS(+OhMm8t^_2;At=F-aFc8 zB@#+nXuIFfyV54^Oi@NjyB6veC52Y&q)eDhWJ%UarWr!%F3BLm*g_^*LYikZQ4->z zbDr~@=RCj956^R7f5GRR@8@$q@9zgX7N!{I9)z7EN1=PKHR9H)q3A#)=4j7jb(!7} z9MPi2gd`?7u#`2m`?CsNMN}-fMrtpz&>Q%Gb?rG!7nf$?LC6)FH8ujg<2)3rq(F2e z0yDC_U@=pkO%|r&ma-BiU+||mO(Nqskg3 zq=hxHr`6N3+BX8LCh;(TdnUeDv8VROUi=g9U(oCgLX0e$k6H6y(LIeAgf598^|}(e zqh=3Ih$8|%74!AxZKZ~tT2S>CLT*bHjeFV0j6T0{@+3v@KiNR|yqJ==2IKwX7_>LI zV}f)Yl~NMgeUiY=5W2_5LG6!a6dXTDQ~kd|d`JpL+n6HRG7b-|tTEVO#a=H8Ku4U2 zY|kd5W==o7oF_+V&r;ElzKJikwM5o64h5(i(~_BHLpyj*W*Hfki^w%ISDb%;9vL-d ziF^GVNhx`!_(=Um5`}iNFFQDBVT}0WV4-+(pB{awxpaauxZMS!m%o&YxzD`oS2wQ_ z)g5$~gxlpxcoyp=4;3a$+`2|c;(yb^iI{L4znX>ZP7a97RK;H=r`h^cj$}!RhuHH& zE^Ao#ibg-!Bq>w!mzb(|i_K#Al8eO=l8z4LRn1DJG!UrNIq-Fz}83ud99I#gCwk_h-HgW>aB{iEt-gRJ4lM*RU|a6pj+Ek;-IcCIh>sY?)Z^V@9<|y{p#5L zmp^3`)Umc>A=p0sA^$@E3h<5{re4ks8r`f;M-Anvrtct#`3T>b=)=bgm`2 zhvl%iK?QdC^Hj9lc*a_Tu2GG@1zTjBfT)B-n5KOL`8F5I2+~2?>==mmuYlZ~Ja%Ab z8SP$VgTzi#?0h{2!zTJ8cJp*_KmV;JoSsZ4&(6U3x2LGp`2_1Z=Y)*>aQG=sWZ!XZ zNnZC3eTY0pyF%?*q3v@v;Q5elZQjC?`h3w)yAa;)$~5P`2Wfn2r^3x?XwdgW++WI6 zz90cByXN9cf0<}Ck+g?HUHJj*L8pi{@FH=}%7Q$bYe^|<0(RdXjU4_Z@@kgQY>OPW zuJk_bg9DE4cE?gxFB-_b{aXV>M!%jVw?_xW^0yPI`BC}MOg#T1 zP;~yyT_y@UJ~YQVH?E;&`!<|CQGjEyIp{L6M(3IVEOY0&3`=amAL$$yMVn5Di**q_ zJyc}>t1`XtyG)yIJrePJb#O>A3YDC%Tr>}_g}~PpjeFm~YJU?9uSmfSZ^BFeyD+cP zhFp6ha)V1Kc!UCuNeX3m-Ty)V&VRQ5k^iRsefiJE{=e_Pwo`3n`S0)Wul!eK=uYe0 zjL~ZNkY;GiMAlI+xNX@)wfO~<mGyQQ8y9t&|=Y6jDg%jD2C6Z0L)s7uJh*sKs3 z&uORpKmm5ykECTjc32u}2le4$WO6PZo?(ya<;7U6eV4)-YWTFP-<5qjX90glJt%Ha z#kLe{rd7v*;L;IFJS&HL=1c?4^;R-ef3gq{9`#ER~*9Mbs)By*v_&{CL;{XKmoKK!}vBsDml@`KK+IOZ5S84cm?7~wSy8J}ihag!#v zt}8IRT8G>pyyC~G>A|by1@nwehH>K|dN;$9e4@K3XYUeJsfVL=UJgy0@(bm?bt3)W zqoLYv2ffd4?4#@i^Z00Zt?l1Y` zYDW<64Q!?F9IWX+UyaK4Wy0sGJ!KV@QT{3(8Yhk7cNRt?%VH%C6>1LcG{zq;ney?u zXwLfraoA6W;-D?^l3Gfjj0{DIs9Hx+ zG|E;L@0CPkb4MGyvei-FI53M{Z5o%y= z^Cnu8mJaRSNuP6VE-ksy#x5TB#7bXpq`!5;T%E3Wnpn`g7X+J`#^7q~q!fT^I?^P-XlZ$kVy4g-1C1YOmF2e>oHZTl-cRTT|0BQfT`d)HospXHn&Q8$CJ}cW-Z~GXAAhi- z<3IE9aily+-9%JYGaB*cH`$_C0WNTs!KZ8`l8VF8_>&1j?1X698hg~}twX(D`r$H2HtL?&lCi1*zsTHV)2dS_N4%34B)J!kRn z4~vJ=_o0}mcY66Lqxj$;@;%49^)r08Ip&jX_qq74Dd&(*By6Lp`qVwg7j0 zov@Kvk`#1_{hIdITE8P0&)I_WE{UU1qSa+Sbg7h zx>7ossfO6lR52G3>edvY!NbKO1i_@% literal 2062 zcmV+p2=VtHiwFP!000001MQc2P!nkyhXG{}2trV!tV*&vJ0J=skPwda`$|AAHx{Ui z7ZRcz0wEy50}ui6B!EE&(G_*YbyPr684uKS6LHkRbG%l(#ZgDrQB+iL2Hd2r)!m)4 z{KunatDmZ_u1@#c{r1!Oef@QcDpsDJB2QMRGTcY>SYG?QXN$6{sStV&k?cHWq8y!4b2V5hE`-^ z>M8{`)bKD$sKSdHZ~VgWdkt$Zbn+I0znFu{7(29@R}W>k*M+!h`14*Ur0rRQPpvuDunKo;G{lL9kzcR2^zx*^y(aTpFwXor8c*6NIX zJUH5+TczGil;=GsjOUng>pA*wZN|G9OYlZ8T~{rUI& zP5wRj9!&l}1;+f_8kYPy8Q7=I1v`J74@qaQ5yxr+F*7U}Guov%mXnBiYn-tlmyOn| zM?+DAJN9l+W5xMtsG2m7&SLS(aO4$P;}NTSQ2fg@Y<9|^v)FdmpQtOC3Fp3Z0@ij) z*K=$N**DdMGG6z{<7Vh5-2=w@`#z~yow5!}xei$7^)nQQod#oHr|$PfN7VtSQf`7( zedOqRIufG(B*A3&18`m>!$?IiS`>OD9H4ZlK`>v-%O<%w3ox9Rhqm2^q8FYBPx z&KGB8KLzFZ9boJ$tFVMvTeS$j(gDO%B|7%T5R%{HA!+O@xAr-l?bb-w&v%Z;x;i0_ zsj$a+Hj}Z?JPkk7qxvYQQP`Aq&l z1;%%xF=vUa<)C@V#XZN`pqD%t22u^MxiXTN1^1vbCJU=CbP^reZA6nTflDTG@NC9o z7#(|s2;7%Q=hQJK4);Z0fEBKmcqa9@@1CDm!>k=Q;Kc3$XknX#L3|!c4+_ytWLl@<&ARvwpC6a*QrK?IbyAN<7(q*AkDF+&TE2CIGrRwPeq$ zMUdBh3)tNCC%YSJ8ZD||NR4#zX|Ib@qA0gc<4D<@FLws?{) zfjN}1PuccDnAlDM-)0f)JEaBx$0Ybp)j(`%F~rQHU~1kHxUi@Yg3an7`tm~Xb0~(W zwG?n}EvE0&N-_XO<+E|Ow*_9*s&V7vXFzT?!KPG8^w}GZoF@r{w|pq-6#>W!5@GhM zQJDCBPqg!!jmG*v@8Ced^(T^e%YGQN<~|WPU5(s^Z{VeWAcV%XK;HT!+-9Y~%Mp1X za@z$3?LN@|Y8>{sAw;7mw}{p^o~8>=lC4wX-gxpW7zwS_Z#{Vgy!K>z?yaX1pAg!U zo-088y53~`ibUGeg765c!^wkmosdm?>Wo?jf=Bg`o3;doZqUGlD|* z4-an^K(vbnTHE#D&d!HdOP4_XhC|SnoIn z5cU^mTqj&fOdG8r8Z$P*4@3H*!|F-c)Uh0vdI{+ID-uhn*f0-6Pl+L?v;qpV*x2CW zAo)ofhdIMC@f&SVyuWrEA>E$z zL?EzI-+F3kl+vDVjH-U)$$3~P?a5RmK-1awqw$TLbN)?h9{}_o!+7uXZc0a5s z7>HSlzY@6@55mZ174ED53I|6W2d|b8oZwXfEh;uWPb4JPRCLWi(yi;BZc>REHRhKh zvJR5L*k5*bcS^T|Lw~+%Y$xaH0&$#l0XcRERU=LbS4u2Hs&M21Yl-!(-EN z!OL_)zgq(3FZ_sChlk-f?|yWB_r*PMMY@3Kz26pPjawk)km23oe1~Ao=#TZc?2x_b zn(tOSN4%YzfQ|A{wCTvk9_cf22pLLyaw@Z?a)Y?!55iWRZnhbfcet2LzL7+Gy3}G$ zaqqnLG+enJ=KlIEDW1!ry_jzs4E8m4*yp4pyEU> zZqc4v&WwR08?V9k#8dDfe;4o~i($whCp;g19kh2Nz)B>+tKZLs!{ZOZ*rl(a)sKP) sRyojA split_axis3_split_number5() { + std::unique_ptr nn(new NeuralNetwork()); + nn->setProperty({"batch_size=2"}); + + auto graph = makeGraph({ + {"conv2d", + {"name=conv", "input_shape=3:4:5", "filters=3", "kernel_size=1,1"}}, + {"split", {"name=split", "input_layers=conv", "axis=3", "split_number=5"}}, + {"addition", + {"name=add", "input_layers=split(0),split(1),split(2),split(3),split(4)"}}, + {"mse", {"name=loss", "input_layers=add"}}, + }); + for (auto &node : graph) { + nn->addLayer(node); + } + + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"})); + return nn; +} + +static std::unique_ptr split_axis2_split_number4() { + std::unique_ptr nn(new NeuralNetwork()); + nn->setProperty({"batch_size=2"}); + + auto graph = makeGraph({ + {"conv2d", + {"name=conv", "input_shape=3:4:5", "filters=3", "kernel_size=1,1"}}, + {"split", {"name=split", "input_layers=conv", "axis=2", "split_number=4"}}, + {"addition", + {"name=add", "input_layers=split(0),split(1),split(2),split(3)"}}, + {"mse", {"name=loss", "input_layers=add"}}, + }); + for (auto &node : graph) { + nn->addLayer(node); + } + + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"})); + return nn; +} + +static std::unique_ptr split_axis2_split_number2() { + std::unique_ptr nn(new NeuralNetwork()); + nn->setProperty({"batch_size=2"}); + + auto graph = makeGraph({ + {"conv2d", + {"name=conv", "input_shape=3:4:5", "filters=3", "kernel_size=1,1"}}, + {"split", {"name=split", "input_layers=conv", "axis=2", "split_number=2"}}, + {"addition", {"name=add", "input_layers=split(0),split(1)"}}, + {"mse", {"name=loss", "input_layers=add"}}, + }); + for (auto &node : graph) { + nn->addLayer(node); + } + + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"})); + return nn; +} + /// A has two output tensor a1, a2 and B, C takes it /// A /// (a0, a1) @@ -176,6 +235,12 @@ static std::unique_ptr split_and_join_dangle() { GTEST_PARAMETER_TEST( multiInoutModels, nntrainerModelTest, ::testing::ValuesIn({ + mkModelTc_V2(split_axis3_split_number5, "split_axis3_split_number5", + ModelTestOption::ALL_V2), + mkModelTc_V2(split_axis2_split_number4, "split_axis2_split_number4", + ModelTestOption::ALL_V2), + mkModelTc_V2(split_axis2_split_number2, "split_axis2_split_number2", + ModelTestOption::ALL_V2), mkModelTc_V2(split_and_join, "split_and_join", ModelTestOption::ALL_V2), mkModelTc_V2(one_to_one, "one_to_one", ModelTestOption::ALL_V2), mkModelTc_V2(one_to_one_reversed, "one_to_one__reversed", -- 2.7.4