From 166a9b62405f1dee2b467db4241e25f5b7ae85fa Mon Sep 17 00:00:00 2001 From: Yongjoo Ahn Date: Thu, 23 Feb 2023 18:29:47 +0900 Subject: [PATCH] [test] Add a tflite inference test with 32 in/output model - Add a simple tflite model file with 32 inputs whose return values are 32 outputs (add 1.0 to each input) - Add a unittest to test inference of the model. Signed-off-by: Yongjoo Ahn --- .../unittest_filter_tensorflow2_lite.cc | 89 +++++++++++++++++++++ .../test_models/models/simple_32_in_32_out.tflite | Bin 0 -> 10976 bytes 2 files changed, 89 insertions(+) create mode 100644 tests/test_models/models/simple_32_in_32_out.tflite diff --git a/tests/nnstreamer_filter_tensorflow2_lite/unittest_filter_tensorflow2_lite.cc b/tests/nnstreamer_filter_tensorflow2_lite/unittest_filter_tensorflow2_lite.cc index c479886..a778cef 100644 --- a/tests/nnstreamer_filter_tensorflow2_lite/unittest_filter_tensorflow2_lite.cc +++ b/tests/nnstreamer_filter_tensorflow2_lite/unittest_filter_tensorflow2_lite.cc @@ -14,6 +14,8 @@ #include #include +#include "nnstreamer_plugin_api.h" +#include "nnstreamer_plugin_api_util.h" /** * @brief internal function to get model file path @@ -32,6 +34,9 @@ _GetModelFilePath (gchar ** model_file, int option) case 1: model_name = "mobilenet_v2_1.0_224.tflite"; break; + case 2: + model_name = "simple_32_in_32_out.tflite"; + break; default: break; } @@ -275,6 +280,90 @@ TEST (nnstreamerFilterTensorFlow2Lite, floatModelXNNPACKResult) } /** + * @brief Signal to validate the result in tensor_sink of 32 input/output model. + */ +static void +check_output_many (GstElement *element, GstBuffer *buffer, gpointer user_data) +{ + GstMemory *mem_res; + GstMapInfo info_res; + gboolean mapped; + UNUSED (element); + + GstTensorsInfo ts_info; + gst_tensors_info_init (&ts_info); + ts_info.num_tensors = 32; + + guint *data_received = (guint *) user_data; + (*data_received)++; + + for (guint i = 0; i < 32; i++) { + mem_res = gst_tensor_buffer_get_nth_memory (buffer, &ts_info, i); + mapped = gst_memory_map (mem_res, &info_res, GST_MAP_READ); + ASSERT_TRUE (mapped); + gfloat *output = (gfloat *) info_res.data; + EXPECT_EQ (17.f, *output); + gst_memory_unmap (mem_res, &info_res); + gst_memory_unref (mem_res); + } +} + +/** + * @brief Check result of tflite model with 32 input/output tensors. + */ +TEST (nnstreamerFilterTensorFlow2Lite, manyInOutModel) +{ + gchar *pipeline; + GstElement *gstpipe; + GError *err = NULL; + gchar *model_file; + + ASSERT_TRUE (_GetModelFilePath (&model_file, 2)); + + /* make 32 "t. ! queue ! mux.sink_## " */ + gchar *tee_queue_mux = g_strdup (""); + for (int i = 0; i < 32; i++) { + gchar *aux = g_strdup (tee_queue_mux); + g_free (tee_queue_mux); + tee_queue_mux = g_strdup_printf ("%s t. ! queue ! mux.sink_%d ", aux, i); + g_free (aux); + } + + /* create a nnstreamer pipeline */ + pipeline = g_strdup_printf ( + "videotestsrc pattern=2 num-buffers=10 is-live=true ! " + "videoscale ! videoconvert ! video/x-raw,format=GRAY8,width=1,height=1,framerate=30/1 ! " + "tensor_converter ! tensor_transform mode=typecast option=float32 ! tee name=t " + "%s" + "tensor_mux name=mux ! other/tensors,format=static,num_tensors=32 ! " + "tensor_filter framework=tensorflow2-lite model=\"%s\" ! tensor_sink name=sinkx", + tee_queue_mux, model_file); + + g_free (tee_queue_mux); + + gstpipe = gst_parse_launch (pipeline, &err); + ASSERT_TRUE (gstpipe != nullptr); + + GstElement *sink_handle = gst_bin_get_by_name (GST_BIN (gstpipe), "sinkx"); + ASSERT_TRUE (sink_handle != nullptr); + + guint data_received = 0U; + g_signal_connect (sink_handle, "new-data", (GCallback) check_output_many, &data_received); + + EXPECT_EQ (setPipelineStateSync (gstpipe, GST_STATE_PLAYING, UNITTEST_STATECHANGE_TIMEOUT * 10), 0); + g_usleep (1000 * 1000 * 5); // wait for 5 seconds to check all output is valid + + EXPECT_EQ (setPipelineStateSync (gstpipe, GST_STATE_NULL, UNITTEST_STATECHANGE_TIMEOUT), 0); + + EXPECT_EQ (10U, data_received); + + gst_object_unref (sink_handle); + gst_object_unref (gstpipe); + g_free (pipeline); + g_free (model_file); +} + +/** * @brief Main gtest */ int diff --git a/tests/test_models/models/simple_32_in_32_out.tflite b/tests/test_models/models/simple_32_in_32_out.tflite new file mode 100644 index 0000000000000000000000000000000000000000..933880ae62903cbb65e168d20caee7792a925f84 GIT binary patch literal 10976 zcmai)e`uXo8OKj++~O8z*4=H{jPqT0*JWGG?!CGB(OI{&%|-@U7it(mifPNuOe9M~ zdN+y;DI-Y!C{jj{{81!~AVG?h5v0hFFoFarQbv&eQ6$bdgVq^moKdUS&-dodynWtt z-#6padv~Aj{l4cs?>Wyo?@1b+bN$cvwx(R*8eM}+xq4T}XSJ?|_m0n;Tlt)G&3Euh zU=1!j;oN8Vd>7P%b9a=#rSBjse_!9b@LA`+!0YLuiPuNQzBO1Fdga7xqeZUd+yJp= zaIjM2+&sJieUOGNE^q~Ap$I+D2phYdTY?!Fg)T_JI*oZ3&O-q@p$=BL!~#sgF~|V+ z?&>F{Qdju?eqM{OvR)(c!kxJ)5X&cfmhOrDqa#ZLSC1 z*AjJkH;(Q=qOR4=p!;f~F6S1}J)Nk_x^;B>6LlF^zf~%ACh6MuyH0c+iMmdA99?^& zuER~CYfIF%yG!WuiMlp-4P9%Z4yS*XO1VT`t7}4+P1NOFAG)2fSatlhXWbcePbK`=3o*E&<$y*fi>!?YV90MLIJuV4K=VveO-b%n1lj!LmFyejcY8y985w1 zx*-iU(0{X3()ytHEHQ3mZ2Yz2U^cILHcsqbPfT;`9I@LxG0mrWVp>*=XRB)35;2+` ziK&jQ5^M0pRJ(2xqdSrBsa`ewr&Rj9C#D*eC$`5EQ(fvK_K+v0S~N;b>!?BR(aVZ_Mj)GywsOU zr3XAQ<)xXJ_CcndS>>gNnChR6DK7`a*>*x-L?=?mazjWl~=Z=2;g_n*DzW7}K!R~|o2iXtSmG^`iCxwU@b#gI(<7(G%5!{<3tWI{I0uvP1{6X2wgMc3e&~g6H~@KQhBP!n3hJN+ZZZis z;f87pT!$670!yI%-vV5OInZA4ES!OH7>4680DaH{U7)>U2AUv%2B?QxaInS82CTy> zT!X8y42y6XF2Ov^!VH{;DHsOjcsDuLYd>fV*Sfnt&1XITp9be{D}T0A+sYrAy)|hl z-#d9pmutypTQbg_4E^PmZE4d-pZ?VSmbnNiXn;ltAPr5>3>nBnCmet-=!PEXg+Azq z0XPQ7p#Z}$3gb|OGw=qSg-MtK+b^}9JM-)uv3a-%mtX-d!@IBuORx-A;3}-ZHMkC| zumMgq_8f*G!!5Y zT6b5t$#YNy%?r%|jdzXhN$7_(Y;vt7n1%x60bBQD_A(j={M@OrruX%~+g`)Ec#mEA zxh&xI*B`#A_533(X(?;J;n6OJT0H|v zD+2hQM|&gG>N!Zd1_JN>eBYzp47FNIq}5mm;Ab9f@cZpNXv|2f zeM0~rdbGWvR_l$lS|bA3!k&i0to^n}do|Q%(8yQ& zjsSk((cTQTdTx?dZ4BT&kGAQ$jJnzw747GZWl2-ee046-zYoS)pQqpQ)6TlBW+L}LX_f5}J(yD$1@Jo+26Kb1? zORIfP0H1iY$3pF6#HH0fD1g1JK{gLlp>`i}X|*p3;D|@N5NaPKF0J-S0er`!T@AHq z;?ip06u^%>+Pd}aJUl{NTKxp@zDL^`YWEVCR(q=els%0NosQiZt_8NQ#qhfiqtkb_ z_XBcs@WRLnyYHF?ve7TF{_%Vaieuo{Z&3x+f)OT z53LQFi@H~;yBZ4`I~v2PRT`hFZ>lc`p%(EpU+RSG7@UPPEMqtd^{~JN2H++an}u#z z=i=ue11mJ<3^c+bjVnMcT%@smut~#b-~g=Bkx6KRD|GAvWZ*g-oq{~9((!Z932P+c zJRE>^5;F~5aDzn6fY$R35_bW5Kx_Cc^uiX2)witMb13WraOjy~O*-PVQH?xe@B%j1_qr+XLi z%iQYsj>k{q&$h2n#ZM2F>c>v|cHfBhS7YA#EmZSkqEz}F_4sLCS-%=|zW(arwbJix zkDunC_3N+Z$7HPZ``F{Bd1?LTs`)YbEBy}rIhsGsb?diT%`bne-&-C()e7s^jn13D z)?58P@c5~=S-BE$;?`E{WX^-DZ6+hLYO1}w@-ybb!BUXJ(8 z+4g?Zsya5E)O}=-k*AWwNJOc!|}ektns~z->=&|s9n}~G2S=ZX?$Pt`06># z`Zl18$wRio`2O7E`+&z+f3FbZ+iHF9^2YgdVH_O`{wedy{~(GQyyRay+T|b_8Z?{d3^8j_@0gT zZME&)$2qjkgZ3_V{9cRq&DweVU5~G7zV*$Zi_3!@XF4~x?N!aTzEknOdD~u{n_FLw zIU>ILdxe~uuzI7ho^?2W` z9cMZ>x4s+=McUhiE+!9ITOV|8ZhbjsiuleX_}coQb93v*9)DSTVL%& z?0xG&7n6saUFUUfZhbjQi}=pQ`{wOBuXA(j%duRc+?e+TT;82|tP literal 0 HcmV?d00001 -- 2.7.4