Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_eltwise_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
17
18 using namespace ::testing;
19 using namespace std;
20 using namespace mkldnn;
21
22 struct eltwise_test_params {
23     struct {
24         size_t n;
25         size_t c;
26         size_t h;
27         size_t w;
28     } in;
29
30     enum opType {
31         Sum = 0, Prod = 1, Max = 2
32     };
33
34     opType op;
35
36     std::string scales;
37
38     size_t num_prim_desc;
39
40     MKLDNNPlugin::impl_desc_type selectedType;
41
42     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
43 };
44
45 template<typename data_t>
46 void ref_eltwise(const std::vector<InferenceEngine::TBlob<data_t>> &src, InferenceEngine::TBlob<data_t> &dst, eltwise_test_params prm) {
47     std::vector<float> scales;
48     if (prm.scales != "") {
49         std::istringstream stream(prm.scales);
50         std::string str;
51         while (getline(stream, str, ',')) {
52             float val = std::stof(str);
53             scales.push_back(val);
54         }
55     } else {
56         for (int i = 0; i < src.size(); i++) {
57             scales.push_back(1.0f);
58         }
59     }
60
61     data_t *dst_data = dst.data();
62
63     const data_t *src_data = src[0].readOnly();
64
65     for (int i = 0; i < src[0].size(); i++) {
66         switch (prm.op) {
67             case eltwise_test_params::Sum: {
68                 dst_data[i] = scales[0]*src_data[i];
69             }
70                 break;
71             default: {
72                 dst_data[i] = src_data[i];
73             }
74         }
75     }
76
77     for (int n = 1; n < src.size(); n++) {
78         src_data = src[n].readOnly();
79
80         for (int i = 0; i < src[n].size(); i++) {
81             switch (prm.op) {
82                 case eltwise_test_params::Sum: {
83                     dst_data[i] += scales[n]*src_data[i];
84                 }
85                     break;
86
87                 case eltwise_test_params::Prod: {
88                     dst_data[i] *= src_data[i];
89                 }
90                     break;
91
92                 case eltwise_test_params::Max: {
93                     dst_data[i] = (std::max)(dst_data[i], src_data[i]);
94                 }
95                     break;
96             }
97         }
98     }
99 }
100
101 class MKLDNNGraphEltwiseTests: public TestsCommon,
102                                      public WithParamInterface<eltwise_test_params> {
103     std::string model_t = R"V0G0N(
104 <net name="EltwiseOnly" version="2" precision="FP32" batch="1">
105     <layers>
106         <layer name="in1" type="Input" precision="FP32" id="1">
107             <output>
108                 <port id="1">
109                     <dim>_IN_</dim>
110                     <dim>_IC_</dim>
111                     <dim>_IH_</dim>
112                     <dim>_IW_</dim>
113                 </port>
114             </output>
115         </layer>
116         <layer name="in2" type="Input" precision="FP32" id="2">
117             <output>
118                 <port id="2">
119                     <dim>_IN_</dim>
120                     <dim>_IC_</dim>
121                     <dim>_IH_</dim>
122                     <dim>_IW_</dim>
123                 </port>
124             </output>
125         </layer>
126         <layer name="in3" type="Input" precision="FP32" id="3">
127             <output>
128                 <port id="3">
129                     <dim>_IN_</dim>
130                     <dim>_IC_</dim>
131                     <dim>_IH_</dim>
132                     <dim>_IW_</dim>
133                 </port>
134             </output>
135         </layer>
136         <layer name="con" id="4" type="Eltwise" precision="FP32">
137             <elementwise_data operation="_OP_" coeff="_COEFF_"/>
138             <input>
139                 <port id="1">
140                     <dim>_IN_</dim>
141                     <dim>_IC_</dim>
142                     <dim>_IH_</dim>
143                     <dim>_IW_</dim>
144                 </port>
145                 <port id="2">
146                     <dim>_IN_</dim>
147                     <dim>_IC_</dim>
148                     <dim>_IH_</dim>
149                     <dim>_IW_</dim>
150                 </port>
151                 <port id="3">
152                     <dim>_IN_</dim>
153                     <dim>_IC_</dim>
154                     <dim>_IH_</dim>
155                     <dim>_IW_</dim>
156                 </port>
157             </input>
158             <output>
159                 <port id="4">
160                     <dim>_IN_</dim>
161                     <dim>_IC_</dim>
162                     <dim>_IH_</dim>
163                     <dim>_IW_</dim>
164                 </port>
165             </output>
166         </layer>
167     </layers>
168     <edges>
169         <edge from-layer="1" from-port="1" to-layer="4" to-port="1"/>
170         <edge from-layer="2" from-port="2" to-layer="4" to-port="2"/>
171         <edge from-layer="3" from-port="3" to-layer="4" to-port="3"/>
172     </edges>
173 </net>
174 )V0G0N";
175
176 protected:
177     std::string getModel(eltwise_test_params p) {
178         std::string model = model_t;
179         std::string op;
180
181         if (p.op == 0) {
182             op = "sum";
183         } else if (p.op == 1) {
184             op = "mul";
185         } else if (p.op == 2) {
186             op = "max";
187         }
188
189         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
190         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
191         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
192         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
193         REPLACE_WITH_STR(model, "_OP_", op);
194         REPLACE_WITH_STR(model, "_COEFF_", p.scales);
195         return model;
196     }
197
198     virtual void TearDown() {
199     }
200
201     virtual void SetUp() {
202         try {
203             TestsCommon::SetUp();
204             eltwise_test_params p = ::testing::WithParamInterface<eltwise_test_params>::GetParam();
205             std::string model = getModel(p);
206
207             InferenceEngine::CNNNetReader net_reader;
208             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
209
210             MKLDNNGraphTestClass graph;
211             graph.CreateGraph(net_reader.getNetwork());
212
213             auto& nodes = graph.getNodes();
214             for (int i = 0; i < nodes.size(); i++) {
215                 if (nodes[i]->getType() == MKLDNNPlugin::Eltwise) {
216                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
217                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
218                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
219                     }
220                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
221                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
222                 }
223             }
224
225             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
226
227             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
228             src1->allocate();
229
230             InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
231
232             if (srcPtr1 == nullptr)
233                 FAIL() << "Cannot cast blob to TBlob<float>.";
234
235             fill_data(src1->buffer(), src1->size());
236             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
237             src2->allocate();
238
239             InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
240
241             if (srcPtr2 == nullptr)
242                 FAIL() << "Cannot cast blob to TBlob<float>.";
243             fill_data(src2->buffer(), src2->size());
244             InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
245             src3->allocate();
246
247             InferenceEngine::TBlob<float>* srcPtr3 = dynamic_cast<InferenceEngine::TBlob<float>*>(src3.get());
248
249             if (srcPtr3 == nullptr)
250                 FAIL() << "Cannot cast blob to TBlob<float>.";
251             fill_data(src3->buffer(), src3->size());
252             InferenceEngine::BlobMap srcs;
253             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
254             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
255             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));
256
257             InferenceEngine::OutputsDataMap out;
258             out = net_reader.getNetwork().getOutputsInfo();
259             InferenceEngine::BlobMap outputBlobs;
260
261             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
262
263             InferenceEngine::TBlob<float>::Ptr output;
264             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
265             output->allocate();
266             outputBlobs[item.first] = output;
267
268             graph.Infer(srcs, outputBlobs);
269
270             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
271             dst_ref.allocate();
272
273             std::vector<InferenceEngine::TBlob<float>> src_vec = {*srcPtr1, *srcPtr2, *srcPtr3};
274
275             ref_eltwise(src_vec, dst_ref, p);
276
277             compare(*output, dst_ref);
278         } catch (const InferenceEngine::details::InferenceEngineException &e) {
279             FAIL() << e.what();
280         }
281     }
282 };
283
284 TEST_P(MKLDNNGraphEltwiseTests, TestsEltwise) {}
285
286
287 INSTANTIATE_TEST_CASE_P(
288         TestsEltwise, MKLDNNGraphEltwiseTests,
289         ::testing::Values(
290                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Sum, "", 3, MKLDNNPlugin::impl_desc_type::ref, {
291                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
292                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
293                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
294                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
295                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
296                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
297                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
298                         }
299                 } },
300                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Sum, "1.0,1.0,1.0", 3, MKLDNNPlugin::impl_desc_type::ref, {
301                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
302                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
303                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
304                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
305                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
306                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
307                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
308                         }
309                 } },
310                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Sum, "1.5,0.5,-2.0", 3, MKLDNNPlugin::impl_desc_type::ref, {
311                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
312                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
313                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
314                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
315                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
316                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
317                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
318                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
319                         }
320                 } },
321                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Prod, "", 3, MKLDNNPlugin::impl_desc_type::ref, {
322                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
323                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
324                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
325                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
326                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
327                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
328                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
329                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
330                         }
331                 } },
332                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Max, "", 3, MKLDNNPlugin::impl_desc_type::ref, {
333                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
334                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
335                             ASSERT_EQ(3, impl.getConfig().inConfs.size());
336                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
337                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
338                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(1).desc.getLayout());
339                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(2).desc.getLayout());
340                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
341                         }
342                 } }
343         ));
344
345 class MKLDNNGraphDynBatchEltwiseTests: public MKLDNNGraphEltwiseTests {
346 protected:
347     virtual void SetUp() {
348         try {
349             TestsCommon::SetUp();
350             eltwise_test_params p = ::testing::WithParamInterface<eltwise_test_params>::GetParam();
351             std::string model = getModel(p);
352             size_t MB = p.in.n;
353             if (MB < 2)
354                 MB = 2;
355
356             InferenceEngine::CNNNetReader net_reader;
357             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
358             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
359             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
360             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
361             InferenceEngine::ResponseDesc resp;
362             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
363             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
364
365             MKLDNNGraphTestClass graph;
366             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
367             graph.CreateGraph(net_reader.getNetwork());
368
369             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
370
371             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
372             src1->allocate();
373
374             InferenceEngine::TBlob<float>* srcPtr1 = dynamic_cast<InferenceEngine::TBlob<float>*>(src1.get());
375
376             if (srcPtr1 == nullptr)
377                 FAIL() << "Cannot cast blob to TBlob<float>.";
378
379             fill_data(src1->buffer(), src1->size());
380             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
381             src2->allocate();
382
383             InferenceEngine::TBlob<float>* srcPtr2 = dynamic_cast<InferenceEngine::TBlob<float>*>(src2.get());
384
385             if (srcPtr2 == nullptr)
386                 FAIL() << "Cannot cast blob to TBlob<float>.";
387             fill_data(src2->buffer(), src2->size());
388             InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
389             src3->allocate();
390
391             InferenceEngine::TBlob<float>* srcPtr3 = dynamic_cast<InferenceEngine::TBlob<float>*>(src3.get());
392
393             if (srcPtr3 == nullptr)
394                 FAIL() << "Cannot cast blob to TBlob<float>.";
395             fill_data(src3->buffer(), src3->size());
396             InferenceEngine::BlobMap srcs;
397             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
398             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
399             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));
400
401             InferenceEngine::OutputsDataMap out;
402             out = net_reader.getNetwork().getOutputsInfo();
403             InferenceEngine::BlobMap outputBlobs;
404
405             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
406
407             InferenceEngine::TBlob<float>::Ptr output;
408             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
409             output->allocate();
410             outputBlobs[item.first] = output;
411
412
413             auto checkDepthwise = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
414                 return node->getType() == MKLDNNPlugin::Eltwise;
415             };
416
417             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkDepthwise);
418             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkDepthwise);
419         } catch (const InferenceEngine::details::InferenceEngineException &e) {
420             FAIL() << e.what();
421         }
422     }
423 };
424
425 TEST_P(MKLDNNGraphDynBatchEltwiseTests, TestsDynBatchEltwise) {}
426
427 INSTANTIATE_TEST_CASE_P(
428         TestsDynBatchEltwise, MKLDNNGraphDynBatchEltwiseTests,
429         ::testing::Values(
430                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Sum, "", 3, MKLDNNPlugin::impl_desc_type::ref},
431                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Sum, "1.0,1.0,1.0", 3, MKLDNNPlugin::impl_desc_type::ref},
432                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Sum, "1.5,0.5,-2.0", 3, MKLDNNPlugin::impl_desc_type::ref},
433                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Prod, "", 3, MKLDNNPlugin::impl_desc_type::ref},
434                 eltwise_test_params{{1, 3, 3, 3}, eltwise_test_params::opType::Max, "", 3, MKLDNNPlugin::impl_desc_type::ref}));