Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_conv_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
17
18
19 using namespace ::testing;
20 using namespace std;
21 using namespace mkldnn;
22
23
24 struct conv_test_params {
25     struct {
26         size_t n;
27         size_t c;
28         size_t h;
29         size_t w;
30     } in;
31
32     size_t krn_w;
33     size_t krn_h;
34     size_t str_w;
35     size_t str_h;
36     size_t pad_w;
37     size_t pad_h;
38
39     size_t out_c;
40     size_t grp_c;
41
42     size_t num_prim_desc;
43
44     int selectedType;
45     std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
46
47     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
48 };
49
50 template <typename data_t>
51 void ref_conv(const InferenceEngine::TBlob<data_t> &src, const data_t *weights, const size_t weightsSize,
52                 InferenceEngine::TBlob<data_t> &dst, conv_test_params prm) {
53     size_t KW = prm.krn_w;
54     size_t KH = prm.krn_h;
55     size_t GC = prm.grp_c;
56
57     size_t IC = src.dims()[1];
58     size_t IH = src.dims()[2];
59     size_t IW = src.dims()[3];
60
61     size_t OW = (IW + 2 * prm.pad_w - prm.krn_w) / prm.str_w + 1;
62     size_t OH = (IH + 2 * prm.pad_h - prm.krn_h) / prm.str_h + 1;
63     size_t OC = prm.out_c;
64
65
66     const data_t *src_data = src.readOnly();
67     const data_t *weights_data = weights;
68     const data_t *bias_data = weights_data + KW * KH * OC * IC / GC;
69     data_t *dst_data = dst.data();
70
71     IE_ASSERT(KW * KH * OC * IC / GC + OC == weightsSize);
72     IE_ASSERT(OW == dst.dims()[0]);
73     IE_ASSERT(OH == dst.dims()[1]);
74
75     for (uint32_t g = 0; g < GC; g++) {
76         for (uint32_t oc = 0; oc < OC / GC; oc++) {
77             for (uint32_t oh = 0; oh < OH; oh++) {
78                 for (uint32_t ow = 0; ow < OW; ow++) {
79                     size_t oidx = g * OC / GC * OH * OW
80                                   + oc * OH * OW + oh * OW + ow;
81                     dst_data[oidx] = bias_data[g * OC / GC + oc];
82
83                     for (size_t ic = 0; ic < IC / GC; ic++) {
84                         for (size_t kh = 0; kh < KH; kh++) {
85                             for (size_t kw = 0; kw < KW; kw++) {
86                                 int32_t iw = ow * prm.str_w - prm.pad_w + kw;
87                                 int32_t ih = oh * prm.str_h - prm.pad_h + kh;
88                                 if (iw < 0 || iw >= (int32_t)IW || ih < 0
89                                     || ih >= (int32_t)IH)
90                                     continue;
91                                 size_t iidx = g * IC / GC * IH * IW
92                                               + ic * IH * IW + ih * IW + iw;
93                                 size_t widx = g * OC / GC * IC / GC * KH * KW
94                                               + oc * IC / GC * KH * KW
95                                               + ic * KH * KW + kh * KW + kw;
96
97                                 dst_data[ oidx] += src_data[iidx] * weights_data[widx];
98                             }
99                         }
100                     }
101                 }
102             }
103         }
104     }
105 }
106
107 class MKLDNNGraphConvolutionTests: public TestsCommon,
108                                    public WithParamInterface<conv_test_params> {
109     std::string model_t = R"V0G0N(
110 <Net Name="Convolution_Only" version="2" precision="FP32" batch="1">
111     <layers>
112         <layer name="in1" type="Input" precision="FP32" id="0">
113             <output>
114                 <port id="0">
115                     <dim>_IN_</dim>
116                     <dim>_IC_</dim>
117                     <dim>_IH_</dim>
118                     <dim>_IW_</dim>
119                 </port>
120             </output>
121         </layer>
122         <layer name="conv1" id="1" type="Convolution" precision="FP32">
123             <convolution stride-x="_SW_" stride-y="_SH_"
124                          pad-x="_PW_"    pad-y="_PH_"
125                          kernel-x="_KW_" kernel-y="_KH_"
126                          output="_OC_"   group="_GC_" PrimitivesPriority="_IMPLS_"/>
127
128             <weights offset="0" size="_S1_" />
129             <biases offset="_S1_" size="_S2_" />
130
131             <input>
132                 <port id="1">
133                     <dim>_IN_</dim>
134                     <dim>_IC_</dim>
135                     <dim>_IH_</dim>
136                     <dim>_IW_</dim>
137                 </port>
138             </input>
139             <output>
140                 <port id="2">
141                     <dim>_IN_</dim>
142                     <dim>_OC_</dim>
143                     <dim>_OH_</dim>
144                     <dim>_OW_</dim>
145                 </port>
146             </output>
147         </layer>
148     </layers>
149     <edges>
150         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
151     </edges>
152 </Net>
153 )V0G0N";
154
155 protected:
156     std::string getModel(conv_test_params p) {
157         std::string model = model_t;
158         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
159         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
160         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
161         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
162
163         REPLACE_WITH_NUM(model, "_KW_", p.krn_w);
164         REPLACE_WITH_NUM(model, "_KH_", p.krn_h);
165         REPLACE_WITH_NUM(model, "_SW_", p.str_w);
166         REPLACE_WITH_NUM(model, "_SH_", p.str_h);
167         REPLACE_WITH_NUM(model, "_PW_", p.pad_w);
168         REPLACE_WITH_NUM(model, "_PH_", p.pad_h);
169
170         REPLACE_WITH_NUM(model, "_GC_", p.grp_c);
171         REPLACE_WITH_NUM(model, "_OC_", p.out_c);
172         REPLACE_WITH_NUM(model, "_OH_", (p.in.h + 2 * p.pad_h - p.krn_h) / p.str_h + 1);
173         REPLACE_WITH_NUM(model, "_OW_", (p.in.w + 2 * p.pad_w - p.krn_w) / p.str_w + 1);
174
175         size_t w_data_size = (p.krn_w * p.krn_h * p.out_c * p.in.c / p.grp_c) * sizeof(float);
176         size_t b_data_size = p.out_c * sizeof(float);
177         REPLACE_WITH_NUM(model, "_S1_", w_data_size);
178         REPLACE_WITH_NUM(model, "_S2_", b_data_size);
179         std::string impls;
180         for (const auto& preferType : p.preferTypes) {
181             if (!impls.empty())
182                 impls += ",";
183             impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
184         }
185         REPLACE_WITH_STR(model, "_IMPLS_", impls);
186         return model;
187     }
188
189     virtual void TearDown() {
190     }
191
192     virtual void SetUp() {
193         try {
194             TestsCommon::SetUp();
195             conv_test_params p = ::testing::WithParamInterface<conv_test_params>::GetParam();
196             std::string model = getModel(p);
197
198             InferenceEngine::CNNNetReader net_reader;
199             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
200
201             InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {(p.krn_w * p.krn_h * p.out_c * p.in.c / p.grp_c + p.out_c)
202                                                               * sizeof(float)});
203             weights->allocate();
204             fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
205             InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
206
207             net_reader.SetWeights(weights_ptr);
208
209             MKLDNNGraphTestClass graph;
210             graph.CreateGraph(net_reader.getNetwork());
211
212             auto& nodes = graph.getNodes();
213             nodes = graph.getNodes();
214             for (auto &node : nodes) {
215                 if (node->getType() == MKLDNNPlugin::Convolution) {
216                     ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
217                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
218                         p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
219                     }
220                     ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
221                     ASSERT_EQ(p.selectedType,
222                               node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
223                 }
224             }
225
226             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
227
228             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
229             src->allocate();
230             fill_data(src->buffer(), src->size());
231
232             auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
233
234             if (srcPtr == nullptr)
235                 FAIL() << "Cannot cast blob to TBlob<float>.";
236
237             InferenceEngine::BlobMap srcs;
238             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
239
240             InferenceEngine::OutputsDataMap out;
241             out = net_reader.getNetwork().getOutputsInfo();
242             InferenceEngine::BlobMap outputBlobs;
243
244             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
245
246             InferenceEngine::TBlob<float>::Ptr output;
247             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
248             output->allocate();
249             outputBlobs[item.first] = output;
250
251             graph.Infer(srcs, outputBlobs);
252
253
254             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
255             dst_ref.allocate();
256             ref_conv(*srcPtr, (const float *)weights->buffer(), weights->size() / sizeof(float), dst_ref, p);
257             compare(*output, dst_ref);
258         } catch (const InferenceEngine::details::InferenceEngineException &e) {
259             FAIL() << e.what();
260         }
261     }
262 };
263
264 TEST_P(MKLDNNGraphConvolutionTests, TestsConvolution) {}
265
266 INSTANTIATE_TEST_CASE_P(
267         TestConvolution, MKLDNNGraphConvolutionTests,
268         ::testing::Values(
269                 conv_test_params{{1, 9, 16, 32},
270                                  1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
271                 },
272                 conv_test_params{{1, 9, 32, 16},
273                                  2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
274                 conv_test_params{{1, 9, 32, 16},
275                                  2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
276                 conv_test_params{{1, 3, 40, 40},
277                                  3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
278                 conv_test_params{{1, 1, 40, 40},
279                                  3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
280                 conv_test_params{{1, 1, 32, 16},
281                                  2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
282                 conv_test_params{{1, 9, 16, 32},
283                                  1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::gemm,
284                                  {MKLDNNPlugin::impl_desc_type::gemm_any,
285                                   MKLDNNPlugin::impl_desc_type::gemm_blas,
286                                   MKLDNNPlugin::impl_desc_type::gemm_avx512,
287                                   MKLDNNPlugin::impl_desc_type::gemm_avx2,
288                                   MKLDNNPlugin::impl_desc_type::gemm_sse42}
289                 },
290                 conv_test_params{{1, 9, 32, 16},
291                                  2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::ref_any, {MKLDNNPlugin::impl_desc_type::ref_any} }));
292
293 class MKLDNNGraphDynBatchConvolutionTests: public MKLDNNGraphConvolutionTests {
294 protected:
295     virtual void SetUp() {
296         try {
297             TestsCommon::SetUp();
298             conv_test_params p = ::testing::WithParamInterface<conv_test_params>::GetParam();
299             std::string model = getModel(p);
300             size_t MB = p.in.n;
301             if (MB < 2)
302                 MB = 2;
303
304             InferenceEngine::CNNNetReader net_reader;
305             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
306
307             InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C,
308                     {(p.krn_w * p.krn_h * p.out_c * p.in.c / p.grp_c + p.out_c) * sizeof(float)});
309             weights->allocate();
310             fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
311             InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
312
313             net_reader.SetWeights(weights_ptr);
314             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
315             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
316             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
317             InferenceEngine::ResponseDesc resp;
318             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
319             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
320
321             MKLDNNGraphTestClass graph;
322             graph.CreateGraph(net_reader.getNetwork());
323
324             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
325
326             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
327             src->allocate();
328             fill_data(src->buffer(), src->size());
329
330             auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
331
332             if (srcPtr == nullptr)
333                 FAIL() << "Cannot cast blob to TBlob<float>.";
334
335             InferenceEngine::BlobMap srcs;
336             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
337
338             InferenceEngine::OutputsDataMap out;
339             out = net_reader.getNetwork().getOutputsInfo();
340             InferenceEngine::BlobMap outputBlobs;
341
342             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
343
344             InferenceEngine::TBlob<float>::Ptr output;
345             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
346             output->allocate();
347             outputBlobs[item.first] = output;
348
349             auto checkConvolution = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
350                 return node->getType() == MKLDNNPlugin::Convolution ||
351                        node->getType() == MKLDNNPlugin::Convolution_Activation ||
352                        node->getType() == MKLDNNPlugin::Convolution_Sum ||
353                        node->getType() == MKLDNNPlugin::Convolution_Sum_Activation;
354             };
355
356             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkConvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
357             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkConvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
358         } catch (const InferenceEngine::details::InferenceEngineException &e) {
359             FAIL() << e.what();
360         }
361     }
362 };
363
364 TEST_P(MKLDNNGraphDynBatchConvolutionTests, TestsDynBatchConvolution) {}
365
366 INSTANTIATE_TEST_CASE_P(
367         TestDynBatchConvolution, MKLDNNGraphDynBatchConvolutionTests,
368         ::testing::Values(
369                 conv_test_params{{1, 8, 16, 32},
370                                  1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
371                 },
372                 conv_test_params{{1, 9, 32, 16},
373                                  2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
374                 conv_test_params{{1, 9, 32, 16},
375                                  2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
376                 conv_test_params{{1, 3, 40, 40},
377                                  3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
378                 conv_test_params{{1, 1, 40, 40},
379                                  3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
380                 conv_test_params{{1, 1, 32, 16},
381                                  2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
382                 conv_test_params{{1, 9, 16, 32},
383                                  1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::gemm,
384                                  {MKLDNNPlugin::impl_desc_type::gemm_any,
385                                   MKLDNNPlugin::impl_desc_type::gemm_blas,
386                                   MKLDNNPlugin::impl_desc_type::gemm_avx512,
387                                   MKLDNNPlugin::impl_desc_type::gemm_avx2,
388                                   MKLDNNPlugin::impl_desc_type::gemm_sse42}
389                 },
390                 conv_test_params{{1, 9, 32, 16},
391                                  2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::ref_any, {MKLDNNPlugin::impl_desc_type::ref_any} }));