From 2926ad3f5511fe46d1bc2e41ef5c3910bb76f016 Mon Sep 17 00:00:00 2001 From: Parichay Kapoor Date: Thu, 2 Dec 2021 13:06:45 +0900 Subject: [PATCH] [layer] Support filter masking in mol attention Add support for filter based masking in mol attention. Add corresponding unittest. Signed-off-by: Parichay Kapoor --- nntrainer/layers/mol_attention_layer.cpp | 18 ++++++++++++++++-- nntrainer/tensor/tensor.cpp | 20 ++++++++++++++++++++ nntrainer/tensor/tensor.h | 7 +++++++ packaging/unittest_models_v2.tar.gz | Bin 5706 -> 8118 bytes test/input_gen/genModelTests_v2.py | 26 ++++++++++++++++++++++++-- test/input_gen/recorder_v2.py | 25 ++++++++++++++----------- test/unittest/models/models_golden_test.cpp | 5 +++++ test/unittest/models/models_golden_test.h | 21 +++++++++++++++------ test/unittest/models/unittest_models.cpp | 24 ++++++++++++++++++++++++ 9 files changed, 125 insertions(+), 21 deletions(-) diff --git a/nntrainer/layers/mol_attention_layer.cpp b/nntrainer/layers/mol_attention_layer.cpp index c5e35a1..f0f92fd 100644 --- a/nntrainer/layers/mol_attention_layer.cpp +++ b/nntrainer/layers/mol_attention_layer.cpp @@ -31,6 +31,7 @@ enum AttentionParams { query = 0, value = 1, state = 2, + mask_len = 3, fc_w, fc_bias, fc_proj_w, @@ -46,8 +47,8 @@ enum AttentionParams { }; void MoLAttentionLayer::finalize(InitLayerContext &context) { - if (context.getNumInputs() != 3) - throw std::runtime_error("MoL Attention layer needs 3 inputs."); + if (context.getNumInputs() < 3 || context.getNumInputs() > 4) + throw std::runtime_error("MoL Attention layer needs 3-4 inputs."); auto const &all_dims = context.getInputDimensions(); auto const &query_dim = all_dims[AttentionParams::query]; @@ -57,6 +58,7 @@ void MoLAttentionLayer::finalize(InitLayerContext &context) { wt_idx[AttentionParams::query] = AttentionParams::query; wt_idx[AttentionParams::value] = AttentionParams::value; wt_idx[AttentionParams::state] = AttentionParams::state; + wt_idx[AttentionParams::mask_len] = 
AttentionParams::mask_len; softmax.setActiFunc(ActivationType::ACT_SOFTMAX); tanh.setActiFunc(ActivationType::ACT_TANH); @@ -225,6 +227,13 @@ void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) { Tensor prob_scaled = prob.multiply(alpha); prob_scaled.sum(3, scores); + if (context.getNumInputs() == 4) { + Tensor mask = Tensor(scores.getDim()); + mask.filter_mask(context.getInput(wt_idx[AttentionParams::mask_len]), + false); + scores.multiply_i(mask); + } + scores.dotBatched(value, output); } @@ -261,6 +270,11 @@ void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context, Tensor dscores = Tensor(TensorDim({value.batch(), 1, 1, value.height()})); dscores.dot_batched_deriv_wrt_1(value, derivative); dscores.reshape(TensorDim({scores.batch(), 1, scores.width(), 1})); + if (context.getNumInputs() == 4) { + Tensor mask = Tensor(dscores.getDim()); + mask.filter_mask(context.getInput(wt_idx[AttentionParams::mask_len])); + dscores.multiply_i(mask); + } Tensor dprob_scaled = Tensor(TensorDim({batch, 1, value.height(), mol_k})); dprob_scaled.setZero(); diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp index eee5081..d0cf702 100644 --- a/nntrainer/tensor/tensor.cpp +++ b/nntrainer/tensor/tensor.cpp @@ -1228,6 +1228,26 @@ void Tensor::dropout_mask(float dropout) { } } +void Tensor::filter_mask(const Tensor &mask_len, bool reverse) { + float fill_mask_val = 0.0; + float en_mask_val = 1.0 - fill_mask_val; + + if (reverse) { + fill_mask_val = 1.0; + en_mask_val = 1.0 - fill_mask_val; + } + + setValue(fill_mask_val); + if (mask_len.batch() != batch()) + throw std::invalid_argument("Number of filter masks mismatched"); + + for (unsigned int b = 0; b < batch(); b++) { + float *addr = getAddress(b, 0, 0, 0); + const uint *mask_len_val = mask_len.getAddress(b, 0, 0, 0); + std::fill(addr, addr + (*mask_len_val), en_mask_val); + } +} + int Tensor::apply_i(std::function f) { Tensor result = *this; apply(f, result); diff --git 
a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h index 53e5064..499ef6b 100644 --- a/nntrainer/tensor/tensor.h +++ b/nntrainer/tensor/tensor.h @@ -698,6 +698,13 @@ public: void dropout_mask(float dropout); /** + * @brief Calculate filter mask + * @param mask_len length of each mask along the last axis + * @param invert invert the mask + */ + void filter_mask(const Tensor &mask_len, bool reverse = false); + + /** * @brief sum all the Tensor elements according to the batch * @retval Calculated Tensor(batch, 1, 1, 1) */ diff --git a/packaging/unittest_models_v2.tar.gz b/packaging/unittest_models_v2.tar.gz index 48734780a63b9682c9450352a15415b7050d397c..c00c0d5ff96ecb53e80de661a6138ee619695f14 100644 GIT binary patch literal 8118 zcmV;nA4%XJiwFP!000001MHaxR8;4-$3YQ@rnqXeL!rnU_}HGP*E&VQ^>U?@y;1C@80+B^?CD_n)lXY zuQlsiv(C5AocaE<|NFPs1o`{KMl6mD3-gH%jR+14G>eD`j|>P53yus62#k>UAYma4 z8yjnG<=OHr{&5ur&(d0gXU(&+;#*qT@@*u9S8Z&05;_(i=HvhSh>ec%Tdboau`tSS zaj1WYU&1@~z3=+}(HJ0nNeW+On&W|6lO44wx8{#t#ByuMBNyb>l(u})+NRkh+V|BE zSJA#|qZH9TW4~>p>nm!!*+1#4^vC}og8p9rF@F9F0|R{g!$PB?0s}sv5xlqlZAA5N zEv$OJjnE4$EO~tE{`&tUA+GR)eN^z zP)HTk`|(YiMb}q4dx-LFf6pt7`_Fl%Yy0vtZmabB!~dfAr~3a00Pt_t|KIz+rGe?!DCdw1H<} zqlo?t${tMnl#aGH%YotCprhGyAiayUZ0NiJ`Z+}UFj^le0Ut8MD*KAU)oZ~_ zR2_;}hdSVIt0K|Bp;>ee=4>{exmSl0(q~{$doJEzpbTXN71(Cw58h*nV8w)opt)xa z%vRfiz4>EA=kWA!Iy1ZSPTiQNhE#i984MkN17uC^iq4_+oB5FBu&M5^iq)X`g)ge_ z1~?TZi4tkmV0kkQJ06>Y!d$?VbH^c3b|3yJKg9m{Q{q4T|D)00`~BbA>Yx0dZ(-H% z|DPhn{xA2LFX)^MqZ54Z!n3m>bnWXH(r{ED^1bXA4&?1g6FO~-8kJf%m~M~Zkd35P3p^wi@0i@{2~$`a6nLso=Ifmvh@t61AU2lA4PzVacd3spFPy=z8KYqrBRh z%~xCz07>Rt(>E0czc5Z@hO&%Lx zUJqJz>(M;<5jG`%D=>Q~$y&PV(&kAexGN$HQpfFMxSZa*XE3o=i{z{t%(@sf!dBJK zSqG;BpfV!~bnj76d0;@*ry8(%tqnL#Mvmp4A-<;;%2lk1&z<4u!OsQvo+(fuzfW{_ zM;GdliQ`mZ?YUJXAUOc$%`_*x`PwiyGKCahUI;^Hr4j7|Mo{ZrMf9EGMP6R4s73E3 zh0=3AZgj}2<#cdNEFFz*qJFYKej99jtxc=W_JXpH3uv;9P;_W6iGv2XQc$6c4I ze_evX$(eAl=b%W7^WA6gj z;`@wNrm#-{}*#gSaeMIM}>ch_xnH3y3hZ4JR2)3OIxA;+t}Fj 
z^ZzG_kHPUwfKs*PC8K1j{cpebX0UYCVyw5TIn`2QRo(shD z0LTB_EWq)<)YC{DEYa78xSqrHBChXneTeHhTrU#$ohIfR=243!F>txZ)MfdR;j?oY zLqjKW>ufXgrAY+o#MgBrWCg_dR32knn@B$NoWOn8=n^AVCBF=BKG0_Kob*`Vo)u8n za}gwZPeb~_Ty%3$KtnPP)tu6yd9@wW))0YBI<@5NuRYMv(sy2yGqZ8M8G{XZ!{E%0 z!x(wOkOsMJVh*-ngr5#X!KDZDp>$~x>^i&|U(K?Iy-&(u#Bh1$=aM&=fA{zEijl&4 z8EbX`8&3GO9+2priep1|f#|&M|v%Q$cm;lB*(qzSYpWv9;PC9j}pEO+VPBOM&LRIuav_b zV!~Fro6;&j4ZIz!%<2@JXI5q!Q{!+iY}O2BGNr@7cDVxkT3Qp_-+Uo@){{dD!P{#J zF4M0h30ZRV${sCNe#=8}$nnB`9?Bw)CPye9eU?Q!)K=jPnKNX-%JJAWJ(J+^)%e7` zk^Fem3|j(=NbvbMeAe1bKJ+}neb=Dpa@2LmF`%Q9VVRvOJtow4uUi9iZBkCH$+Qqi z=2Y!#L4uFY}qlCFs4~r|q>UFD#wY4j7(=R-^6En8#z2vg_f< z)jRn8)F!BzWQ6h4*Wk6&51DyRmKf-$OZ8@!;NrXA3Wh)E({{A+1aoI0R0B+IKLlC_4~9PG;laNSJXl{Zz>FPE@nDwWN8By!bl@cw!u8ZJ~}iPVis5t zC%t`O+u?*2HWpN~U@w-fwq#fR+NbS~Ykcf5F{U8+d z&Uta38JT5F6*hR`fg!<+d|Vi8T&%!GlxRTs%RX&GKRW=kQzl`vdpU_NkfT%VG}uv@ z-@%=6J~%u_Nu=$eZYO-Sc?UT?b2SEBJVi$BbQAX8?L_0pHR$!Ef#`R-q5XzTvRrE| zzMXU8gY?z#$S@zjn3%wbn9#@wpK!nEkI)C+#XU32z40v8_=%0X|&B=j~XoK0y%*OOkATxkB*U{ zH?~I8{XhH+SR{PzC|Nr8*c7VzV<;KpV@wDNzf*!-F+=aAE&X2auSOt8uN=SV<7DluO(ZQY8 z)J=LUH4F=)&P$x>Kw-Z4CHcF4=LbYmb^jY+*c=JoTZYm0*kSbT;%4X?qeBY|{a`%R zroQRDuq+~hdfs)S1yec%<(~7QGUBg9I?)RJ@7{u~{)UvRxx`%~am}>+Ix`&=Y;UQ+ z=3G-~sVRiiEDyY}X*Zsd_rT8VaQMNyKrm=jCeyfK9W&#&I%M1VFl9z>82u?X$wjAh z%rP1Q7b7lX{Qjdj-^m;J^|s83ij&L_gBl_CwH#y>q=8C`9GRYw2a{vUKz%yDwk~`H z9Ng;yi<0GV%h_1Qe*9D7(3J$Pu?HE2ZIQxRXgKO6=rZ}8i=emyVfNV3sNn4htUfs8vGLcrI0_q1{Z^` zz`<#hkQqiZnr=Gwc*PuU2XDud2X-?KIjziy>Kx{k%0Vbz=7-IuCg3$tMR4Y77cqOd z3e~3;;V_w0G-^4;2p-A_nV_f6{mpKyx)K5VJhn4=4O`%|1uifyXaRPc9)T|>T0?w6 zENaya#@@Tm=yNAa$SJ8%!mncPkWw5ravhu;q7S>AP2l@8gQ0tK0whNm!S?!bXm@o2 zt_v|{c0~*al^GJ?!Pf?zxC+oIm?-3s?+CG!KyR-r%&7T0gl7{Al`piwP``^LntYA& zJQe1-T0LW)9Eypxdj;`5_TXP>0Cv0*sB$lZZgWpqrf*5gPt9ib=!GI{s0Ih(nwY%D zo_Hc2h^Ms%q>j^NbnmgGT;(r_t7-Rd+|H>0pJ$7ikxfA z#liea_~~3Oj=6mZ?xaAHej;6o`J_27mqWE~!);1M+k^cqN^KW4lbD z(b)@R6N;fpdI`8kNx^{jE`fjVXHax+1dMEaN^}+<5Z$Y`Lyn;;Se5&+aX%iY26K| 
zD)1GgHJ{~31NNs2_SPtp1yc)`(n}@2Fz_Gxo4Ci_?w(EUZ)QV2b)*+86{-AG4f2;$iu8GjAvHZ@5BekB zspd%y>OqInx%#?vWwAMJ(T{`uM_uU%$2qiR@>sf_@I)S2lV^yrL)O-v_(B#p<&|Q9 zjRokuEP<&4f>$&Ou~lmd=2S#Nz?uOBTvC{aa_gDqTphUS`Xw{#ngXu-_6kufO~Zl1 zmEebIO?cDk2+lJMg3JgtMyaKfDJX1!zqN1X;aG0#9~VkQ2pJDbqaib;cipS(+KC30 zG$=xncA;JwN}_}enJHx`nq(HDcilpXT}mk$wlqk!+q*$_64mj(=jZ&+@A_S5pMTEY z`&|2U|GA#`xz@d|=e@tr`mX15Ck&H>{i(^I_gsr)%VmP)fkJRkSLZx^;05yj--5^` zF;qRamEC&#mY7=Xg-aqwm~$%^9BykOa;C9}CLf4QtAVLoCgA-Hdr+O9iF+1F!^7o7 zppn%<39M1*Dm#0yn2GMcjN>kz2a&UzU~QZcioT8&Xk_zIYtCz*lW<(z5}zN?1f?HJ zuw42FHe>l=vU)NXmzKGMzMmM%fA2?BZX874u@PVq7Ypk06EPuwA*jVzLU#N(fNQ{J zby`r(tx23p)Ggf@jgSnRl`tlKKZB=u&qaFOQ1C zt4fWm)lL$ny$t6X5+wTbOcmGlpoELdi*eu(a{T*oESlU~iA^ zRxxZzRU`~IFJ>N!)!6#p2gY6OAt!a@V7}uG(jB-4UZxvC^?izgB4&8FNS!$qD8j03 z60p5h8H%s~uIR2}g|Tw*uCp=G)V5x%Sp zalV@DxoQRZ^5QDzwU{c}=H&q#O~F&QiaB;$_Q0hbCxP6nAbKrHtZh*Qmi9#Q+-IL- zaif9;$0M-1?GUPp-63XEGH}r&5q5OsEmoo!O*(HEpl9wQq7glT2s`9rnqe6{cg@7` zyG0=1dsd)p|18FG8^Lnd7D)B=fh$?NAtdb-e2EYSn|p?EBUBeMdUlXcxB_IAOkj)T zk^hSCC4aW~KmPAGIsp7v{%_Fj{muUwnwlCLn)1*8jEug1|M^{F=-t@g_7+C0aiAHx z+H_V?F6bW@rJ3^ILzTm4&gC<5R5^DQtxw|8(za<-t^Y9L{c;1EvUJEuMLEJBxgI@0`&(`nrjcX~kU0G#yL3pE+&<R4FQEy;l(T=DiXAd+Eg||Vxt+I zrD+0F8skW}`#A9Pnh1}k$3wiEIPL46ON|_z=~!JYI$_RgnkJ`AUuPOpk5F-%oa0KD zD%g{%puzpjDS4>d;!PIjY$c5yb|hC(l;%2I2k{rXY0s($NwVuma7rq?$Drjyjoz&lS=6QNr8L@N9qr{bjgHL zVl3lI4=uK$mD`Vj$gU-H-0NYK|GW~@^Eq+v+hO9~POymafq_Un+I|Y?F0HXtZ`A;l zhxGwThydZoDiog2qN0(5UZHF_^$lDL#=Jo?HS-e0K4m5;JvrBy7r z`rl*0{Wc^qJ_hE?H{fbN#HhGP)Qr8v$}W1r+Q-SHr!S2q3Gc*#7AZ*Gm;&kzQbM6EUzkgH6*8 zLu30B9)B;>mTiku(rVeHecW?cdYRmX8CT=^~b>3_?z zCa;0_z577#`Z_rC3ngdco^tvXjUb@m4Daw&9w)*qnp|A<B>NMKr#y4hJ`_r-$xQ2SHa>(Yk>UJCp^w4X$X@2%p@14quI?cnBVCJt>a9Hv{)Vv z%bW-EdSl3DOC(-L(%23AF%Wir4%5D-#GKdVLYCJZ620#h>pNcn=kk@IEAs%Ev(btb zT)098e4p};D4OFKp<^&bJ`02sl6l`>*g{$d_ky~?cZvL)4A|*5xUaa82TqUnu>9+4 zxS}(LwX1fJV&7EsB%^q|t=X{Mnvz;!Gwk0H2e!{9z&dU+ipx-_pSqi@;{@Ql9Vuw^ z{Sv&;aGHh4ABE{h$3x*ZHS{amOr}4~V)xYUu)Apy;N)A&9#w1r^is$7;x?pfR4L){ 
zw6VnCXJV2z3dTR&0u3Bpuo@+Wk{M$$-(WK;zX|4%4X1f`e8zx5(qYn2NuWMNp7Z~u9Q1<`pz9Bst!Y&xH*+}oI#g=Iml76Pk@>!-QAA5WXgI zwM8lCixKsFb`ZD;YiR#tciK@ihfcmQ=oP`6c368;avtdGi5^t^$Da&)O|zUbDfC$+mCQ}#0;ATr6@;N zmX6-Nga$Z!(|6%>p*zrw>gO3zZyE|N+cwcGtC6&}bu`^*c^pKY^TAUq7j7l)pvM($ zVS24Lm6<3Zd&s%m>qlh^PC}3LO;EWi0eOEBhO%E; z;EIg8KodWDi$R)S0QLAV9c)fm(jVL3!;P#tG(_PrbZ7Z)UFfw&LNkYpQs3PYAhA}9sHZHa z;=I>H{Kw(6F4&2xMXQsrS%%d2?mNg^rU!gYtnv83GrCj^Ojc)Rn>oW*8-p>@wgoI? zSHY%p8*o8L4GWEH8~jg5Ls7jCTAt1!I$h^5t9cPTl4i&o%y#+FRFHRSXX3IOiE;iB zIPvZRhKupAqb~@XpA<5o9U-7T`~c}!iDEi-tFeFIUm)z=A#jmaV^hgtXqq<_QUkca zjjAS@A$1_E&2YVa8hem?8E3~B;j0(3u)?+ug9>Wc{K9u+zW;H2+LI2?*Uu1W;^#gN zb`J|-8zwZJ7acW{qy!e@N900eI~Nn?Ckr%DWz-4ilKd8ZqRrWrmun%>FdSSKt%k&R z8f5FsI*x0(4wPj);YpkFIF2r{Bx29mnQ@b%$+UsqnPSSKI43p)w|6RI&#a$$rr*gC z&8F#9?%U46nz9yF5M)G*c_FYz#D;VC=6J|f3c|aI$myw&#-Er9Nw~U~C3ZxD@|b;) z`KXUu z`3h0YDN7N`C3mw)X4_bDdp1P8EFdq9er8i_3qftK60BXapCq}OvDLB-k4d@K#URS0SIW9c$CCny>IBfN&Gu)p6LeDZT3 z5nI=CsAl+0$vWGuP*j=>In{`yc;j zU}9p-@BcM4H#0W*>i>R^_*eWN{~Yxn`M=RZzs~mk8rLT_crTy*UEHDmItXhlWh0b~R(`oJ$ZJns875ddfJssbMk|;@vPuiO zhs8teh9<(Ft8cF(T4~7QvMTbY>W|sx>G?nx+WcJ2C8Yys?E$(h}_5Uj}$V0&>&U zG3oSryk9Cpxqk8R@8)p%ISN0XS{*+(8+TYw##I&ig1q<7)2QC>Fs}NZb{zDbuoUEd z^GgG`b|?tu$+8c{W~)Gq8mU6V3|rxW`2yk{k;XPzW`IieOY*om8b0pbN*=tLNt*kH z{YUA?-|PR6>;Ez^H~xS1Z)#xhwf=vP_$~NPe$5PV{zE)Je{6{7AL8}}XU|E+0(PH& zeTeTL;`sSvL#+Q0)6c(Wyty=a5#l4r?K18Hr@5=)2gO8=(((W}X&4P5>sNqUwrq ztkqV#7-nRT@^Ezme8{MWb-^yMm@|ny6FvkA8AWinYXj_iFzx)LXGvfE+Sm9R QU*lKfugb7)H2_cm0QgW77XSbN literal 5706 zcmV-Q7PaXgiwFP!000001MHUzSdD4B$8$`HN=l(4Djl@eYORXab3cldw48F7j4A1e zT15vPP$7qqP;%%%BqSvxQSbehkP$iN5J^O0m~Upx9*1GR^{&3PzrFYO*}mBqX76iv zU)Q>xYhCYizw3RU|GNM8?+*0yiIOe~3kmUwm?sSi2^b-jhKBhEgan0!_y||cR6YxdC&iel%p{W0aL$30AcHXp3Uf;o{bLDk?IqIVP`{lR*`SVD?9Sio3uOq;eS#5L;ZgS0QeW{|L^@@z~k8p zI_v*S#0TsD*)Q_?w%WJ`|Ct}W&{;+QS^bwlQ@~$sBsk_Bhx(DbdhhN6@_yc69kB22u~IQ}y#oR3pPz%;l_ZH)LJL zny`ASEm%Qo5muDOgKg+VaB7Yd`3hegxfL&4sgz3&3Xl6WG6b1%w>lfZb;}0{7fqZmwf~Pgi419+}g}Oc8ul 
zdl7PT-XeEDX1WgeEFX>iGY*6Pvm%__bvQ((X<)EtF~n@&g}OT1p|Sf}rnGwl3=G_i z-2HyN=|a!G^rbVdo`jt#`ZV5m2T8n9DL)672U1#IWlD#5*MY;+@$@T)Ja{VAq$Ua8 z)NpDutoM@8EPE|j5EtUI^p}IFf?;&ke`DMn}xLxD^aOuE%dmXi;w1OK~a7QHro1u_s9ZRHtq@N?_2@1 zbW`zV-bnd5JU_CDnN`|SIr6zVHC$Z;eaGGa6{`pGbGUtYE+jgwulzTSa?l?*2laLU zoC)uON*l{TaBCypduk2pvjJC~KLPP7yYWx?BX-80694Z1pN;-L?*BsDfAD`1&$iS5 zzeFhfU$yrfFghJV$IW>FFU|$im2Vr>C zJx&4_8m}|yrh4qpv$sOIM4#ny-Wess7<(NW_ga+}s~v|l>ptw^$<>VNvMR`YsEN)F z%@E?RL8)3Pb~DLk@<9$NMJY)o+# z7*wu7o8)$^PyS9k;z<`)Fy5FpOe)0f(o9Ghy_?~3UOuSB_zDA(oz#nUGpm6#o&Kzo z%U;l)kqE{QDQG`7qk2=#*j=})v7fRk%RNI(TLlzr3yDw50F;Pw!J};o@G zJ@4aA`@CLCdqqak;pi^!C-c=d!@4(ywCvnV(DHEueYOS)4$dag&;|$ncjWcb zm*loOp=-Q!QYod~4Gk{Es4Ns0>*2WGkZtDa?<*6R* zg2tSRb#a$qRGbgKxoRfY;$c$@HczRC=D~cnSf+)Y&01_a*^cqQofjWE(S`N#Fs6DF z3o$4;1NOJ=mupeiT8poa84#!Cy;y6R7N=GEvv1t?g3shc;5Aav;_xF(!t(E^T=9G|#$k5LU|4JYCf4;E8|M`47TU&v>%>V7|>^k}X zi^S*Pe+^lVg8LP7fP()OY_H&d#XO*x0~CzU-9s@CDCPi;@j2e-82|0@a8S$zig|$J ze{L4w_`mz}Fzlt&(TBL6!}TJr?{IyH>p5I6QuLjA<}lNKQ$Wnz9x|1gzGT3xT*lno zh1@yUzznmJl4g8UIY>oJEYIv>>?`8Qr=AnIca12tV72p#@YZ8Pc9)9@JEv_KRJL6N zrI%-6)Bap^cT-1mG8%PVHbKL32d1%FiuFbnN$cmBexxieUPUQ<#6m~Kxe5|wF3ndiz&N!zgvjPglJM%2|Z zb=Sj)^`LD`(bpzq)Y|LJq$A&uKjn|z8UIpr`2S~wf2EK6zn%Sm^?!Swz^>E(zeIcv z{^z_%!SD*ESFk+C=^V2wSYE;Q3Z~~6p5uA$9vsUro_QCFVlvUj#sL)^ui$!)^VMgr zKx}xzY;qN#g69={&-oe0^BlW#OwaK<$MYP!b4<_iJNG*rw{r~7al8JA1I!MwD)lQZ zfPg4{R^NS>Bgga{-*ZfFanl8jqIopYIR~8!1+4f<3Uc=-6xpEz-;$;WtY+{^F|-gX zw#>ttmiZ3GrXVfWD8G(L$grfApg=1I`rz?)p!``+_t^*DUQ=+1 zX(@@#RHavT8n9}qPrxbL3wKMjYk z@Gn#A;p8L>jG4Xyubq9ueB&a(0B2)rGP4jDJ@`&M;Bkkxi{s9rr)?t4o~cj9o}<7f zXtDmL)p)ON4;(V>(6&dn0q9y%0ipb<(9iV%(|l5eR^vVxVxiAg+oU^2%*%zy`L@Kx zWH;E~bHNfj9@Wp!!J_2?HtAi5wp*`>@Sc?=wY$HH@wre8c8OMOEYF$>Ru97XFfI1$ zU+b75nU++2trza?6U3-Rhrqf;>a4VIFoeG9&~{$$y)bLlBy8{~CK36nbc(}Zc4)@; z&@$Qw2V`r?wLRGCg6$i&lCv|HqyNP-Wau_`S?}FK2LH4Iy`EJQ(`I*cT$@3b8mz?M zzdrRz`f6xch>ve%WPmhsUYPVVHGq%!zlbN2`@cxQvlH41M4bN%1fBfV4%koIa8Qum&tsCh^r zbzSUAIlV-4^{98>44ADv2)bnz!sO}ZG)K6V_)YXB=VSBX%J2bHZ~H8I`}Jm6>%fOQ 
zVim~Sy^Olo+e1P|9;}boq%{_LVy?#Gw-3P;)AA}u7-{3(nB& zdDAnS^c=qs)l!3@^1vs63l_t-T)(Eb+w#Yb)^TrY)|$dY_DVl7xCw3$UMZ3R>Je z!-$`#${3)n(&Oz8EW09w-I6WLuIg0iJ>LyR2hPV<>!UDif)HZzqtKwT7ruPpiasrw zGA>DhLQxshLJncSA*-0b8m@qr=s4xH{O9*)AOb+B1|u zA~FP{=n^o>pCIFo?+Fnop|{r+X6W3lva^YT(q9e0-1I()x0;A*d~N0j-73Z=c^<}B z^4{Q07qttu~&p#8g0v&&*bJuunocet|B9-5vl0!i6YI8)mI6}}Up|IQVV`K%5GzS|5< z7mvbWgR^k`yDDfNrb@#C&OwsPCE)J$BrT98SjEwOX?6`f~!sB+|h9~Zn7JVedSC&b>`Ego#t{+Oi9$Fy|+(<4K)JV zzb1m_UH7BXS<7hZ)+y9zLvMOkwGC3=`q2kwv9!K@05y+Xoc87fWN0|j=%qhH$jlO8 zG=|bwj$I%#-3E>x5!3ftU7^aEfxKe_X^8nX5)dq*wd#wgdtMw3-{wneijAnZP)cvC z8%c#R_EgE)oK$yffrtfZkh!uDL>h&pYppXq?9-J7jT%nfXB$()rwkZXw1Vc^>#*dP zUqH7ap8E4yC>S?Ee5XqY?Xl5^4%ht_f+ibNUD0ufdF};uMgE|9 z`y!nvFGot%%B!HU=QdaphA^nrlWHFcq-nK{kh(QPk2=!}0u8D*bujt%j0XLo(41NybOh5O9#sGIU@D<~ z>1D`Qvk!`@3d8t{q)3@-gfN>)DzGRZtSlf4=fyC~0#ZR!bQF~Z zOkpLy{KN!y`49*es72%})&ws3(bKhs?oaeblTjAI}Gh!$xgj)@};moa@lFr~5=r*1P!5U-Scv?hb*4-v? zEgNBuZ7~(Z$3Xb}*RcMO8yzo-1Y6HW7_rF`GjIAr&w>)n2_6pf6cSbw>-OYPe@zQD z2??U-!WwYgxBI}oEd+itkl>=AbXF7PEweqE+{>Y-aVmzN8w-l58g#aAry|3x!xVVFG!xz)Y=zxheBo5Q7ueg3g9#jXgiCWU zv7mt7%Q`~~%!}zO)#YHKScWEDwlMZN6OTpW%j5@4z_ktK19fEJed{*5M8U`K{6KBM zN)=9-xej8z<+Q##8yvl7L&Ovr)~&CC3ql9DU>A!S`Ge55z#X$>)9Ia36`yoeZ*xXg9Y8vlRlwu82-xvuq zyi1|7@F>`fe;?ko_>{y>ysJHY;$3W&i=p^rC-rgj!zNz8M?YTxR+3S)s<45iU*(aT zJtp{~dJh932^0d8SEDxB%3Cv=rGGg6iv|iJGl~n`Z<`+=2cMwGPt6;)aL(dJCS+FtoI7pTzK~t{{ zBqyx_u+ z&13zO-^4uoXU~bLD|)wsu6yzU?BHAJN>&PWvrI!FZp;xV+NTB8Hfvl_6@aq)a?r>! zXr*5g+JC4Z^ZU=UbtN3B#%!;JAkNibx~VICENI2hBqd~qoFL3veTj>4hQ&zqKS=&p z`KU;l>}lse9i3es-2Xc}Pj>n}|H<|L6X9|AzlOX&G(xWj%$Qy;^xD9z&FF?c|LIap zz}^>W;3PAMI+H=~E1e|#+}U_1wHmO_7k4)M(H#L8qLq z#&z!qF`&kQtvmHC3tj5MvAS~;v!HLg7h88re**+vG-cN*c~6ebGd5tFk-#%abQ;=S zeMm$~K3(Be1OoLn($Sm_cN0G#?Kj*>bKju? wrbT< ...] 
# Each iteration contains @@ -74,8 +77,8 @@ def record_v2(model, iteration, input_dims, label_dims, name, clip=False): optimizer = torch.optim.SGD(model.parameters(), lr=0.1) def record_iteration(write_fn): - inputs = _rand_like(*input_dims, rand="float") - labels = _rand_like(*label_dims, rand="float") + inputs = _rand_like(input_dims, dtype=input_dtype if input_dtype is not None else float) + labels = _rand_like(label_dims, dtype=float) write_fn(inputs) write_fn(labels) write_fn(list(t for _, t in params_translated(model))) diff --git a/test/unittest/models/models_golden_test.cpp b/test/unittest/models/models_golden_test.cpp index 4514727..dd26d89 100644 --- a/test/unittest/models/models_golden_test.cpp +++ b/test/unittest/models/models_golden_test.cpp @@ -78,6 +78,11 @@ TEST_P(nntrainerModelTest, model_test_optimized) { * @brief check given ini is failing/suceeding at validation */ TEST_P(nntrainerModelTest, model_test_validate) { + if (!shouldValidate()) { + std::cout << "[ SKIPPED ] option not enabled \n"; + return; + } + validate(true); /// add stub test for tcm EXPECT_TRUE(true); diff --git a/test/unittest/models/models_golden_test.h b/test/unittest/models/models_golden_test.h index 33870c9..daa43d1 100644 --- a/test/unittest/models/models_golden_test.h +++ b/test/unittest/models/models_golden_test.h @@ -31,14 +31,17 @@ class NeuralNetwork; * */ typedef enum { - NO_THROW_RUN = 0, /**< no comparison, only validate execution without throw */ - COMPARE = 1 << 0, /**< Set this to compare the numbers */ - SAVE_AND_LOAD_INI = 1 << 1, /**< Set this to check if saving and constructing + NO_THROW_RUN = + 1 << 0, /**< no comparison, only validate execution without throw */ + COMPARE_RUN = 1 << 1, /**< Set this to compare the numbers */ + SAVE_AND_LOAD_INI = 1 << 2, /**< Set this to check if saving and constructing a new model works okay (without weights) */ - USE_V2 = 1 << 2, /**< use v2 model format */ + USE_V2 = 1 << 3, /**< use v2 model format */ + COMPARE = 
COMPARE_RUN | NO_THROW_RUN, /**< Set this to compare the numbers */ - COMPARE_V2 = COMPARE | USE_V2, /**< compare v2 */ + COMPARE_RUN_V2 = COMPARE_RUN | USE_V2, /**< compare run v2 */ NO_THROW_RUN_V2 = NO_THROW_RUN | USE_V2, /**< no throw run with v2 */ + COMPARE_V2 = COMPARE | USE_V2, /**< compare v2 */ SAVE_AND_LOAD_V2 = SAVE_AND_LOAD_INI | USE_V2, /**< save and load with v2 */ ALL = COMPARE | SAVE_AND_LOAD_INI, /**< Set every option */ @@ -135,7 +138,13 @@ protected: * * @return bool true if test should be done */ - bool shouldCompare() { return options & (ModelTestOption::COMPARE); } + bool shouldCompare() { return options & (ModelTestOption::COMPARE_RUN); } + /** + * @brief query if validate (no-throw run) test should be conducted + * + * @return bool true if test should be done + */ + bool shouldValidate() { return options & (ModelTestOption::NO_THROW_RUN); } /** * @brief query if saveload ini test should be done diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp index be49067..c3ca770 100644 --- a/test/unittest/models/unittest_models.cpp +++ b/test/unittest/models/unittest_models.cpp @@ -61,6 +61,28 @@ static std::unique_ptr makeMolAttention() { return nn; } +static std::unique_ptr makeMolAttentionMasked() { + std::unique_ptr nn(new NeuralNetwork()); + nn->setProperty({"batch_size=3"}); + + auto outer_graph = makeGraph({ + {"input", {"name=in4", "input_shape=1:1:1"}}, + {"input", {"name=in3", "input_shape=1:1:5"}}, + {"input", {"name=in2", "input_shape=1:4:6"}}, + {"input", {"name=in1", "input_shape=1:1:6"}}, + {"mol_attention", + {"name=mol", "input_layers=in1,in2,in3,in4", "unit=8", "mol_k=5"}}, + {"constant_derivative", {"name=loss", "input_layers=mol"}}, + }); + + for (auto &node : outer_graph) { + nn->addLayer(node); + } + + nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate = 0.1"})); + return nn; +} + INSTANTIATE_TEST_CASE_P( model, nntrainerModelTest, ::testing::ValuesIn({ @@ -68,6 +90,8 @@ 
INSTANTIATE_TEST_CASE_P( ModelTestOption::COMPARE_V2), mkModelTc_V2(makeMolAttention, "mol_attention", ModelTestOption::COMPARE_V2), + mkModelTc_V2(makeMolAttentionMasked, "mol_attention_masked", + ModelTestOption::COMPARE_RUN_V2), }), [](const testing::TestParamInfo &info) { return std::get<1>(info.param); -- 2.7.4