Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_depthwise_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
#include <functional>
#include <memory>
#include <stdexcept>

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include "mkldnn_plugin/mkldnn_graph.h"
#include "mock_mkldnn_primitive.hpp"
#include "test_graph.hpp"
#include "single_layer_common.hpp"
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include <inference_engine/cnn_network_impl.hpp>
#include "tests_common.hpp"
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace mkldnn;
19
20 struct depthwise_test_params {
21     mkldnn::algorithm alg;
22
23     struct {
24         size_t n;
25         size_t c;
26         size_t h;
27         size_t w;
28     } in;
29
30     bool isBroadcast;
31
32     size_t num_prim_desc;
33
34     MKLDNNPlugin::impl_desc_type selectedType;
35     std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
36
37     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
38 };
39
40 template <typename data_t>
41 void ref_depthwise(const InferenceEngine::TBlob<data_t> &src, const data_t *weights, const size_t weightsSize,
42                    InferenceEngine::TBlob<data_t> &dst, depthwise_test_params prm) {
43     size_t IW = src.dims()[3];
44     size_t IH = src.dims()[2];
45     size_t IC = src.dims()[1];
46     size_t MB = src.dims()[0];
47
48     const data_t *src_data = src.readOnly();
49     const data_t *weights_data = weights;
50     size_t bias_offset = prm.isBroadcast ? 1 : IC;
51     const data_t *bias_data = weights_data + bias_offset;
52     data_t *dst_data = dst.data();
53
54     for(int mb = 0; mb < MB; mb++) {
55         for(int c = 0; c < IC; c++) {
56             for(int h = 0; h < IH; h++) {
57                 for(int w = 0; w < IW; w++) {
58                     int idx = mb * IC * IH * IW
59                               + c * IH * IW
60                               + h * IW + w;
61
62                     int widx = prm.isBroadcast ? 0 : c;
63                     int bidx = prm.isBroadcast ? 0 : c;
64
65                     if (prm.alg == depthwise_scale_shift)
66                         dst_data[idx] = src_data[idx] * weights_data[widx] + bias_data[bidx];
67                     else if (prm.alg == depthwise_prelu)
68                         dst_data[idx] = src_data[idx] > 0 ? src_data[idx] : src_data[idx]*weights_data[widx];
69                 }
70             }
71         }
72     }
73 }
74
75 class MKLDNNGraphDepthwiseTests: public TestsCommon,
76                                      public WithParamInterface<depthwise_test_params> {
77     std::string model_t = R"V0G0N(
78 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
79     <layers>
80         <layer name="in1" type="Input" precision="FP32" id="0">
81             <output>
82                 <port id="0">
83                     <dim>_IN_</dim>
84                     <dim>_IC_</dim>
85                     <dim>_IH_</dim>
86                     <dim>_IW_</dim>
87                 </port>
88             </output>
89         </layer>
90         <layer name="depthwise" id="1" type="_LT_" precision="FP32">
91             <data _P_NAME_="_P_VAL_"  PrimitivesPriority="_IMPLS_"/>
92             <weights offset="0" size="_S1_" />
93             <biases offset="_S1_" size="_S2_" />
94
95             <input>
96                 <port id="1">
97                     <dim>_IN_</dim>
98                     <dim>_IC_</dim>
99                     <dim>_IH_</dim>
100                     <dim>_IW_</dim>
101                 </port>
102             </input>
103             <output>
104                 <port id="2">
105                     <dim>_IN_</dim>
106                     <dim>_IC_</dim>
107                     <dim>_IH_</dim>
108                     <dim>_IW_</dim>
109                 </port>
110             </output>
111         </layer>
112     </layers>
113     <edges>
114         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
115     </edges>
116 </Net>
117 )V0G0N";
118
119 protected:
120     std::string getModel(depthwise_test_params p) {
121         std::string model = model_t;
122
123         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
124         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
125         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
126         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
127
128         if (p.alg == depthwise_scale_shift) {
129             REPLACE_WITH_STR(model, "_LT_", "ScaleShift");
130             REPLACE_WITH_STR(model, "_P_NAME_", "broadcast");
131             REPLACE_WITH_NUM(model, "_P_VAL_", p.isBroadcast ? 1 : 0);
132         }
133         else if (p.alg == depthwise_prelu) {
134             REPLACE_WITH_STR(model, "_LT_", "PReLU");
135             REPLACE_WITH_STR(model, "_P_NAME_", "channel_shared");
136             REPLACE_WITH_NUM(model, "_P_VAL_", p.isBroadcast ? 1 : 0);
137         }
138
139         size_t array_size =  p.isBroadcast ? 1 : p.in.c;
140         size_t w_data_size = array_size * sizeof(float);
141         size_t b_data_size = array_size * sizeof(float);
142         REPLACE_WITH_NUM(model, "_S1_", w_data_size);
143         REPLACE_WITH_NUM(model, "_S2_", b_data_size);
144
145         std::string impls;
146         for (const auto& preferType : p.preferTypes) {
147             if (!impls.empty())
148                 impls += ",";
149             impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
150         }
151         REPLACE_WITH_STR(model, "_IMPLS_", impls);
152
153         return model;
154     }
155
156     virtual void SetUp() {
157         try {
158             TestsCommon::SetUp();
159             depthwise_test_params p = ::testing::WithParamInterface<depthwise_test_params>::GetParam();
160             std::string model = getModel(p);
161
162             InferenceEngine::CNNNetReader net_reader;
163             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
164
165             size_t weightSize = 2*p.in.c*sizeof(float);
166             InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {weightSize});
167             weights->allocate();
168             fill_data( weights->data().as<float*>(), weights->size() / sizeof(float));
169
170             InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
171
172             net_reader.SetWeights(weights_ptr);
173
174             MKLDNNGraphTestClass graph;
175             graph.CreateGraph(net_reader.getNetwork());
176             auto& nodes = graph.getNodes();
177             for (int i = 0; i < nodes.size(); i++) {
178                 if (nodes[i]->getType() == MKLDNNPlugin::Depthwise) {
179                     ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
180                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
181                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
182                     }
183                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
184                     ASSERT_EQ(p.selectedType,
185                               nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
186                 }
187             }
188
189             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
190
191             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
192             src->allocate();
193             fill_data(src->buffer(), src->size());
194
195             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
196
197             if (srcPtr == nullptr)
198                 FAIL() << "Cannot cast blob to TBlob<float>.";
199
200             InferenceEngine::BlobMap srcs;
201             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
202
203             InferenceEngine::OutputsDataMap out;
204             out = net_reader.getNetwork().getOutputsInfo();
205             InferenceEngine::BlobMap outputBlobs;
206
207             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
208
209             InferenceEngine::TBlob<float>::Ptr output;
210             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
211             output->allocate();
212             outputBlobs[item.first] = output;
213
214             graph.Infer(srcs, outputBlobs);
215
216             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
217             dst_ref.allocate();
218
219             ref_depthwise(*srcPtr, weights->readOnly().as<const float*>(), weights->size() / sizeof(float), dst_ref, p);
220
221             compare(*output, dst_ref);
222         } catch (const InferenceEngine::details::InferenceEngineException &e) {
223             FAIL() << e.what();
224         }
225     }
226 };
227
228 TEST_P(MKLDNNGraphDepthwiseTests, TestsDepthwise) {}
229
230 INSTANTIATE_TEST_CASE_P(
231         TestsDepthwise, MKLDNNGraphDepthwiseTests,
232         ::testing::Values(
233                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
234                 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
235                 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
236                 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
237                 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
238                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
239                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::jit},
240                 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
241                 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
242                 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
243                 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
244                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
245                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
246                 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
247                 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
248                 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
249                 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
250                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
251                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
252                 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
253                 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
254                 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
255                 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
256                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
257         ));
258
259 class MKLDNNGraphDynBatchDepthwiseTests: public MKLDNNGraphDepthwiseTests {
260 protected:
261
262     virtual void SetUp() {
263         try {
264             TestsCommon::SetUp();
265             depthwise_test_params p = ::testing::WithParamInterface<depthwise_test_params>::GetParam();
266             std::string model = getModel(p);
267             size_t MB = p.in.n;
268             if (MB < 2)
269                 MB = 2;
270
271             InferenceEngine::CNNNetReader net_reader;
272             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
273
274             InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {p.in.c * 4 * sizeof(float)});
275             weights->allocate();
276             fill_data( weights->data().as<float*>(), weights->size() / sizeof(float));
277             float * data = weights->buffer();
278             for (size_t i = 0; i < weights->size() / sizeof(float); i++) {
279                 if (data[i] < 0) {
280                     data[i] *= -1;
281                 }
282             }
283             InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
284             net_reader.SetWeights(weights_ptr);
285             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
286             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
287             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
288             InferenceEngine::ResponseDesc resp;
289             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
290             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
291
292
293             MKLDNNGraphTestClass graph;
294             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
295             graph.CreateGraph(net_reader.getNetwork());
296
297             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
298             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
299             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
300             if (srcPtr == nullptr)
301                 FAIL() << "Cannot cast blob to TBlob<float>.";
302
303             src->allocate();
304             fill_data(src->buffer(), src->size());
305
306             InferenceEngine::BlobMap srcs;
307             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
308
309             InferenceEngine::OutputsDataMap out;
310             out = net_reader.getNetwork().getOutputsInfo();
311             InferenceEngine::BlobMap outputBlobs;
312
313             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
314
315             InferenceEngine::TBlob<float>::Ptr output;
316             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
317             output->allocate();
318             outputBlobs[item.first] = output;
319
320             auto checkDepthwise = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
321                 return node->getType() == MKLDNNPlugin::Depthwise;
322             };
323
324             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkDepthwise);
325             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkDepthwise);
326         } catch (const InferenceEngine::details::InferenceEngineException &e) {
327             FAIL() << e.what();
328         }
329     }
330 };
331
332 TEST_P(MKLDNNGraphDynBatchDepthwiseTests, TestsDynBatchDepthwise) {}
333
334 INSTANTIATE_TEST_CASE_P(
335         TestsDynBatchDepthwise, MKLDNNGraphDynBatchDepthwiseTests,
336         ::testing::Values(
337                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
338                 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
339                 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
340                 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
341                 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
342                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
343                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::jit},
344                 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
345                 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
346                 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
347                 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
348                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
349                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
350                 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
351                 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
352                 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
353                 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
354                 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
355                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
356                 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
357                 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
358                 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
359                 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
360                 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
361         ));