From 27907c81bb6ae1936537c24b2e31a9e2b0bfccf9 Mon Sep 17 00:00:00 2001
From: Parichay Kapoor
Date: Thu, 25 Nov 2021 21:07:57 +0900
Subject: [PATCH] [gradclip] hot fix + unittests

gradient clipping works by extending the execution order of the
gradients to the last node, where the global norm is calculated and the
gradients are clipped and applied.
However, weight sharing also relies on the last access of the gradient,
and gradient clipping disturbs this last-access balance. As a quick fix,
if gradient clipping is enabled, the last access is replaced with the
second-to-last access.
A better way would be for clipping to be a layer; the last access would
then be performed by the clipping layer itself, which is a valid access,
and the balance of the system would be maintained.

Unittests for gradient clipping are added, with and without weight
sharing.

Signed-off-by: Parichay Kapoor
---
 nntrainer/graph/network_graph.cpp                  | 15 +++++-
 nntrainer/tensor/manager.cpp                       | 20 ++++++++
 nntrainer/tensor/manager.h                         | 21 ++++++++
 nntrainer/tensor/weight.h                          |  4 +-
 packaging/unittest_models_v2.tar.gz                | Bin 3108 -> 3187 bytes
 test/input_gen/genModelsRecurrent_v2.py            |  9 ++++
 test/input_gen/recorder_v2.py                      |  4 +-
 test/unittest/models/unittest_models_recurrent.cpp | 54 +++++++++++++++++++++
 test/unittest/unittest_nntrainer_models.cpp        | 26 ++++++++++
 9 files changed, 149 insertions(+), 4 deletions(-)

diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index bb30cac..c6c9745 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -840,8 +840,21 @@ int NetworkGraph::initialize(
                                       first_grad_access)) {
       rc.getWeightObject(i).setAsGradientFirstAccess();
     }
+    /**
+     * If the gradient is to be clipped by global norm, then the last access
+     * is by the clipping itself. However, as clipping is not a layer and
+     * does not contain any weights, such weights never get assigned
+     * gradient_last_access. This is a quick hotfix.
+     * TODO: make an independent clipping layer which will execute at the
+     * end, and will share ownership of the weights which it will clip. This
+     * will remove this hotfix, and also remove the checks of whether
+     * weights require clipping.
+     */
     if (tensor_manager->isLastAccess(rc.getWeightGrad(i).getName(),
-                                     last_grad_access)) {
+                                     last_grad_access) ||
+        (rc.isGradientClipByGlobalNorm(i) &&
+         tensor_manager->isSecondLastAccess(rc.getWeightGrad(i).getName(),
+                                            last_grad_access))) {
       rc.getWeightObject(i).setAsGradientLastAccess();
     }
   }
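For reference, the Manager helper introduced in the next diff boils down to
deduplicating the recorded execution orders and picking the second-largest
one. Below is a minimal standalone sketch of that selection logic; the
function name and the sample orders are illustrative, not part of nntrainer.

    #include <algorithm>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Return the second-largest *unique* execution order. Tensor-pool orders
    // can record the same order several times, so duplicates are removed
    // before indexing.
    unsigned int secondLastOrder(std::vector<unsigned int> orders) {
      std::sort(orders.begin(), orders.end());
      orders.erase(std::unique(orders.begin(), orders.end()), orders.end());
      if (orders.size() < 2)
        throw std::runtime_error("fewer than 2 unique exec orders");
      return orders[orders.size() - 2];
    }

    int main() {
      // A shared gradient accessed at orders 3, 7, 7 and 9, where 9 is the
      // clip-and-apply step appended at the end of the graph:
      std::cout << secondLastOrder({3, 7, 7, 9}) << '\n'; // prints 7
      return 0;
    }

With clipping enabled, the true last access (9 here) belongs to the clipping
step, so the hotfix treats the second-to-last unique order (7) as the last
"real" access by a layer.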
diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp
index 58d2198..6a4047e 100644
--- a/nntrainer/tensor/manager.cpp
+++ b/nntrainer/tensor/manager.cpp
@@ -544,6 +544,20 @@ Manager::getMinMaxTensorExecutionOrder(const std::string &name,
   return {*min_, *max_};
 }
 
+unsigned int Manager::getSecondMaxTensorExecutionOrder(const std::string &name,
+                                                       bool is_weight) {
+
+  auto orders = is_weight ? weight_pool.getExecutionOrder(name)
+                          : tensor_pool.getExecutionOrder(name);
+  if (orders.size() < 2)
+    throw std::runtime_error(
+      "Requesting second last access with less than 2 exec orders");
+  /** tensor pool exec order can have the same exec order multiple times */
+  std::sort(orders.begin(), orders.end());
+  orders.erase(std::unique(orders.begin(), orders.end()), orders.end());
+  return orders[orders.size() - 2];
+}
+
 bool Manager::isFirstAccess(const std::string &name, unsigned current_execution,
                             bool is_weight) {
   /// @todo add cache mechanism, eg) sort at finalizing requesting
@@ -558,6 +572,12 @@ bool Manager::isLastAccess(const std::string &name, unsigned current_execution,
          current_execution;
 }
 
+bool Manager::isSecondLastAccess(const std::string &name,
+                                 unsigned current_execution, bool is_weight) {
+  /// @todo add cache mechanism, eg) sort at finalizing requesting
+  return getSecondMaxTensorExecutionOrder(name, is_weight) == current_execution;
+}
+
 /**
  * @brief Create tensors with the given spec
  *
diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h
index ed1c00c..87eb2c4 100644
--- a/nntrainer/tensor/manager.h
+++ b/nntrainer/tensor/manager.h
@@ -263,6 +263,16 @@ public:
   getMinMaxTensorExecutionOrder(const std::string &name, bool is_weight);
 
   /**
+   * @brief Get the second max of a tensor execution order
+   *
+   * @param name name of the tensor
+   * @param is_weight check if this should be queried in the weight pool
+   * @return second max execution order value
+   */
+  unsigned int getSecondMaxTensorExecutionOrder(const std::string &name,
+                                                bool is_weight);
+
+  /**
    * @brief check if given execution order is the first access
    *
    * @param name tensor name
@@ -285,6 +295,17 @@ public:
                     bool is_weight = false);
 
   /**
+   * @brief check if the given execution order is the second last access
+   *
+   * @param name tensor name
+   * @param current_execution current execution
+   * @param is_weight check if this should be queried in the weight pool
+   * @return true if the given execution order is the second last access
+   */
+  bool isSecondLastAccess(const std::string &name, unsigned current_execution,
+                          bool is_weight = false);
+
+  /**
    * @brief Check if the manager has allocated tensors
    *
    * @return true if tensors allocated, else false
diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h
index 957203b..c5c6841 100644
--- a/nntrainer/tensor/weight.h
+++ b/nntrainer/tensor/weight.h
@@ -256,12 +256,12 @@ public:
    * @param global_norm the global norm for all the weights
    */
   void clipGradientByGlobalNorm(const float global_norm) {
-    if (global_norm > clip_by_global_norm)
+    if ((global_norm + epsilon) > clip_by_global_norm)
       grad->multiply_i(clip_by_global_norm / (global_norm + epsilon));
   }
 
 private:
-  static constexpr float epsilon = 1e-8; /**< epsilon for zero comparison */
+  static constexpr float epsilon = 1e-6; /**< epsilon for zero comparison */
 
   WeightRegularizer regularizer; /**< regularizer for this variable */
   float regularizer_constant;    /**< constant factor for regularization */
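The weight.h change above pads the trigger with epsilon and relaxes epsilon
itself from 1e-8 to 1e-6. A self-contained sketch of the clipping rule it
guards, with made-up gradient values; only the constants mirror the patch.

    #include <cmath>
    #include <iostream>
    #include <vector>

    int main() {
      // Global norm across all gradients: sqrt of the summed squared norms.
      std::vector<std::vector<float>> grads = {{0.3f, -0.4f}, {1.2f, 0.5f}};
      float sq_sum = 0.0f;
      for (const auto &g : grads)
        for (float v : g)
          sq_sum += v * v;
      const float global_norm = std::sqrt(sq_sum);

      constexpr float epsilon = 1e-6f;     // new value in weight.h
      constexpr float clip_norm = 0.0001f; // as used by the new unittests

      // Mirrors Weight::clipGradientByGlobalNorm: when the padded norm
      // exceeds the threshold, scale every gradient by clip/(norm + eps).
      if ((global_norm + epsilon) > clip_norm) {
        const float scale = clip_norm / (global_norm + epsilon);
        for (auto &g : grads)
          for (float &v : g)
            v *= scale;
      }
      std::cout << "global_norm=" << global_norm << '\n';
      return 0;
    }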
diff --git a/packaging/unittest_models_v2.tar.gz b/packaging/unittest_models_v2.tar.gz
index dea2ea14bd8709b9d69c2966d8f0faf6c1905531..c2a657f95e805880082642695df38d05c8a89e74 100644
GIT binary patch
literal 3187
[base85 binary payload elided; regenerated golden-output archive]

literal 3108
[base85 binary payload elided; previous golden-output archive]
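For context on what the regenerated archive and the tests below exercise:
clip_grad_by_norm is a per-layer property. A rough application-side sketch,
assuming nntrainer's usual ccapi entry points (createModel, createLayer,
createOptimizer); this is illustrative, not part of the patch.

    #include <layer.h>
    #include <model.h>
    #include <optimizer.h>

    int main() {
      // assumed ccapi usage; property strings follow the tests in this patch
      auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET);
      model->addLayer(
        ml::train::createLayer("input", {"name=input", "input_shape=1:1:3"}));
      // clip_grad_by_norm is the property the new unittests pin down
      model->addLayer(ml::train::createLayer(
        "fully_connected",
        {"name=dense", "unit=5", "clip_grad_by_norm=10000.0"}));
      model->setProperty({"batch_size=3"});
      model->setOptimizer(
        ml::train::createOptimizer("sgd", {"learning_rate=0.1"}));
      model->compile();
      model->initialize();
      return 0;
    }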
diff --git a/test/input_gen/recorder_v2.py b/test/input_gen/recorder_v2.py
--- a/test/input_gen/recorder_v2.py
+++ b/test/input_gen/recorder_v2.py
@@ -84,6 +84,8 @@ def record_v2(model, iteration, input_dims, label_dims, name):
 
         optimizer.zero_grad()
         loss.backward()
+        if clip:
+            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.0001)
         optimizer.step()
 
     with open(file_name, "wb") as f:
diff --git a/test/unittest/models/unittest_models_recurrent.cpp b/test/unittest/models/unittest_models_recurrent.cpp
index 75d044f..25deff6 100644
--- a/test/unittest/models/unittest_models_recurrent.cpp
+++ b/test/unittest/models/unittest_models_recurrent.cpp
@@ -58,6 +58,24 @@ IniWrapper fc_unroll_single__1(
     constant_loss,
   });
 
+IniWrapper fc_unroll_single__2(
+  "fc_unroll_single__2",
+  {
+    nn_base,
+    sgd_base + "learning_rate=0.1",
+    IniSection("fc_1") + fc_base +
+      "unit=1 | input_shape=1:1:1 | clip_grad_by_norm = 10000.0",
+    IniSection("fc_2") + fc_base +
+      "unit=1 | shared_from = fc_1 | clip_grad_by_norm = 10000.0",
+    IniSection("fc_3") + fc_base +
+      "unit=1 | shared_from = fc_1 | clip_grad_by_norm = 10000.0",
+    IniSection("fc_4") + fc_base +
+      "unit=1 | shared_from = fc_1 | clip_grad_by_norm = 10000.0",
+    IniSection("fc_5") + fc_base +
+      "unit=1 | shared_from = fc_1 | clip_grad_by_norm = 10000.0",
+    constant_loss,
+  });
+
 std::unique_ptr<NeuralNetwork> makeFC() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
   nn->setProperty({"batch_size=1"});
@@ -89,6 +107,38 @@ std::unique_ptr<NeuralNetwork> makeFC() {
   return nn;
 }
 
+std::unique_ptr<NeuralNetwork> makeFCClipped() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty({"batch_size=1"});
+
+  auto outer_graph = makeGraph({
+    {"input", {"name=input", "input_shape=1:1:1"}},
+    /// here the recurrent fc layers are being inserted
+    {"constant_derivative", {"name=loss", "input_layers=recurrent/a2"}},
+  });
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  auto fcfc = makeGraph({
+    {"Fully_connected", {"name=a1", "unit=1", "clip_grad_by_norm=0.0001"}},
+    {"Fully_connected",
+     {"name=a2", "unit=1", "input_layers=a1", "clip_grad_by_norm=0.0001"}},
+  });
+
+  nn->addWithReferenceLayers(fcfc, "recurrent", {"input"}, {"a1"}, {"a2"},
+                             ml::train::ReferenceLayersType::RECURRENT,
+                             {
+                               "unroll_for=2",
+                               "return_sequences=false",
+                               "recurrent_input=a1",
+                               "recurrent_output=a2",
+                             });
+
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"}));
+  return nn;
+}
+
 static std::unique_ptr<NeuralNetwork> makeSingleLSTM() {
   std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
   nn->setProperty({"batch_size=3"});
@@ -340,7 +390,11 @@ INSTANTIATE_TEST_CASE_P(
                  ModelTestOption::COMPARE_V2),
     mkModelIniTc(fc_unroll_single__1, DIM_UNUSED, NOT_USED_,
                  ModelTestOption::COMPARE_V2),
+    mkModelIniTc(fc_unroll_single__2, DIM_UNUSED, NOT_USED_,
+                 ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeFC, "fc_unroll_stacked", ModelTestOption::COMPARE_V2),
+    mkModelTc_V2(makeFCClipped, "fc_unroll_stacked_clipped",
+                 ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeSingleLSTM, "lstm_single", ModelTestOption::COMPARE_V2),
     mkModelTc_V2(makeSingleLSTMCell, "lstm_single__1",
                  ModelTestOption::COMPARE_V2),
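The INI models added next pin the two boundary settings of clip_grad_by_norm:
0.0, which forces clipping on every step, and 10000.0, which should never fire
in practice. A small sketch of the scale factor the guarded rule yields for
each setting; clipScale is a hypothetical helper, not nntrainer code.

    #include <iostream>

    // Mirrors the guard in Weight::clipGradientByGlobalNorm.
    float clipScale(float global_norm, float clip_by_global_norm) {
      constexpr float epsilon = 1e-6f;
      if ((global_norm + epsilon) > clip_by_global_norm)
        return clip_by_global_norm / (global_norm + epsilon); // clip fires
      return 1.0f; // gradients left untouched
    }

    int main() {
      const float gn = 2.5f; // some nonzero global norm
      // clip at 0.0: always fires and scales gradients to (almost) zero
      std::cout << clipScale(gn, 0.0f) << '\n';     // ~0
      // clip at 10000.0: effectively never fires for realistic norms
      std::cout << clipScale(gn, 10000.0f) << '\n'; // 1
      return 0;
    }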
+ "unit = 5" + "clip_grad_by_norm = 0.0", + I("act") + sigmoid_base, + I("dense_1") + fc_base + "unit = 10" + "clip_grad_by_norm = 0.0"}); + +INI fc_sigmoid_mse__2 = + INI("fc_sigmoid_mse__2") + fc_sigmoid_baseline_clipped_at_0 + softmax_base + I("loss", mse_base); + +INI fc_sigmoid_baseline_clipped_too_high( + "fc_sigmoid", + {nn_base + "batch_size = 3", + sgd_base + "learning_rate = 1", + I("input") + input_base + "input_shape = 1:1:3", + I("dense") + fc_base + "unit = 5" + "clip_grad_by_norm = 10000.0", + I("act") + sigmoid_base, + I("dense_1") + fc_base + "unit = 10" + "clip_grad_by_norm = 10000.0"}); + +INI fc_sigmoid_mse__3 = + INI("fc_sigmoid_mse__3") + fc_sigmoid_baseline_clipped_too_high + softmax_base + I("loss", mse_base); + INI fc_sigmoid_cross = INI("fc_sigmoid_cross") + fc_sigmoid_baseline + softmax_base + "model/loss=cross"; @@ -844,6 +868,8 @@ INSTANTIATE_TEST_CASE_P( { mkModelIniTc(fc_sigmoid_mse, "3:1:1:10", 10, ModelTestOption::ALL), mkModelIniTc(fc_sigmoid_mse__1, "3:1:1:10", 1, ModelTestOption::ALL), + mkModelIniTc(fc_sigmoid_mse__2, "3:1:1:10", 10, ModelTestOption::ALL), + mkModelIniTc(fc_sigmoid_mse__3, "3:1:1:10", 10, ModelTestOption::ALL), mkModelIniTc(fc_sigmoid_cross, "3:1:1:10", 10, ModelTestOption::ALL), mkModelIniTc(fc_sigmoid_cross__1, "3:1:1:10", 1, ModelTestOption::ALL), mkModelIniTc(fc_relu_mse, "3:1:1:2", 10, ModelTestOption::ALL), -- 2.7.4