Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_deconv_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
15
16
17 using namespace InferenceEngine;
18 using namespace ::testing;
19 using namespace std;
20 using namespace mkldnn;
21
22
// Parameters for one deconvolution test case; consumed via aggregate
// (brace) initialization in the INSTANTIATE_* lists below, so the field
// order here is part of the test-data contract and must not change.
struct deconv_test_params {
    // Input tensor dimensions. Formats: NCHW, NCDHW
    vector<size_t> dims;
    // Kernel spatial sizes. Formats: WH, WHD (X first — see X_AXIS/Y_AXIS/Z_AXIS)
    vector<size_t> kernel;
    // Strides, same WH/WHD ordering as kernel.
    vector<size_t> strides;
    // Paddings at the beginning/end of each spatial axis, WH/WHD ordering.
    vector<size_t> pads_begin;
    vector<size_t> pads_end;

    // Number of output channels.
    size_t out_c;
    // Number of groups (grp_c == channels => depthwise).
    size_t grp_c;

    // Whether the generated IR contains a <biases> blob.
    bool with_bias;
    // Value for the layer's auto_pad attribute; empty string omits it.
    string auto_pad;

    // Minimum number of supported primitive descriptors expected from the node.
    size_t num_prim_desc;

    // Acceptable implementation types for the selected primitive descriptor
    // (matched as a bitmask against impl_desc_type).
    std::vector<int> selectedTypes;
    // Implementations forced via the PrimitivesPriority attribute.
    std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;

    // Optional per-descriptor checks, applied in order to the first
    // num_prim_desc supported primitive descriptors.
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
};
45
46 template <typename data_t>
47 void ref_deconv(const InferenceEngine::TBlob<data_t> &src, const InferenceEngine::Blob::Ptr &weights, const InferenceEngine::Blob::Ptr &bias,
48                 InferenceEngine::TBlob<data_t> &dst, deconv_test_params prm) {
49     auto dims_size = src.dims().size();
50
51     size_t G  = prm.grp_c;
52     size_t KW = prm.kernel[X_AXIS];
53     size_t KH = prm.kernel[Y_AXIS];
54     size_t KD = prm.kernel.size() > Z_AXIS ? prm.kernel[Z_AXIS] : 1u;
55
56     size_t PW = prm.pads_begin[X_AXIS];
57     size_t PH = prm.pads_begin[Y_AXIS];
58     size_t PD = prm.pads_begin.size() > Z_AXIS ? prm.pads_begin[Z_AXIS] : 0u;
59
60     size_t SW = prm.strides[X_AXIS];
61     size_t SH = prm.strides[Y_AXIS];
62     size_t SD = prm.strides.size() > Z_AXIS ? prm.strides[Z_AXIS] : 1u;
63
64     size_t IW = src.dims()[dims_size - 1];
65     size_t IH = src.dims()[dims_size - 2];
66     size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
67     size_t IC = src.dims()[1];
68     size_t MB = src.dims()[0];
69
70     size_t OC = prm.out_c;
71
72     size_t OW = SW * (IW - 1) + KW - 2 * PW;
73     size_t OH = SH * (IH - 1) + KH - 2 * PH;
74     size_t OD = dims_size == 5 ? (SD * (ID - 1) + KD - 2 * PD) : 1u;
75
76     const data_t *src_data = src.readOnly();
77     const data_t *weights_data = weights->buffer().as<data_t*>();
78     const data_t *bias_data = bias->buffer().as<data_t*>();
79
80     data_t *dst_data = dst.data();
81
82     size_t CS1 = OH * OW;
83     size_t CS2 = CS1 * OD;
84     size_t CS3 = CS2 * OC;
85
86     size_t CI1 = IH * IW;
87     size_t CI2 = CI1 * ID;
88     size_t CI3 = CI2 * IC;
89
90     size_t CK1 = KH * KW;
91     size_t CK2 = CK1 * KD;
92     size_t CK3 = CK2 * (OC / G);
93     size_t CK4 = CK3 * (IC / G);
94
95     for (int g = 0; g < G; ++g) {
96         for (int mb = 0; mb < MB; ++mb) {
97             for (int oc = 0; oc < OC / G; ++oc) {
98                 for (int od = 0; od < OD; ++od) {
99                     for (int oh = 0; oh < OH; ++oh) {
100                         for (int ow = 0; ow < OW; ++ow) {
101                             size_t didx = mb * CS3
102                                           + (g * OC / G + oc) * CS2
103                                           + od * CS1
104                                           + oh * OW
105                                           + ow;
106
107                             dst_data[didx] = data_t(0);
108                             if (prm.with_bias) dst_data[didx] += bias_data[g * OC / G + oc];
109
110                             for (int ic = 0; ic < IC / G; ic++) {
111                                 for (int kd = 0; kd < KD; kd++) {
112                                     for (int kh = 0; kh < KH; kh++) {
113                                         for (int kw = 0; kw < KW; kw++) {
114                                             if (ow + PW < kw || oh + PH < kh || od + PD < kd)
115                                                 continue;
116
117                                             size_t iw = ow - kw + PW;
118                                             size_t ih = oh - kh + PH;
119                                             size_t id = od - kd + PD;
120
121                                             if (iw % SW != 0 || ih % SH != 0 || id % SD != 0)
122                                                 continue;
123
124                                             iw /= SW;
125                                             ih /= SH;
126                                             id /= SD;
127
128                                             if (ih < IH && iw < IW && id < ID) {
129                                                 size_t sidx = mb * CI3
130                                                               + (g * IC / G + ic) * CI2
131                                                               + id * CI1
132                                                               + ih * IW
133                                                               + iw;
134
135                                                 size_t widx = g * CK4
136                                                               + ic * CK3
137                                                               + oc * CK2
138                                                               + kd * CK1
139                                                               + kh * KW
140                                                               + kw;
141
142                                                 dst_data[didx] += src_data[sidx] * weights_data[widx];
143                                             }
144                                         }
145                                     }
146                                 }
147                             }
148                         }
149                     }
150                 }
151             }
152         }
153     }
154 }
155
// Parameterized fixture: builds a one-layer Deconvolution IR from the test
// params, runs it through the MKLDNN graph, and compares the output against
// the ref_deconv() reference implementation.
class MKLDNNGraphDeconvolutionalTests: public TestsCommon,
                                     public WithParamInterface<deconv_test_params> {
    // IR template for a single Deconvolution layer.  Placeholders
    // (__SRC_DIMS__, _K_, _KS_, _PB_, _PE_, _OC_, _GC_, _AP_, _S1_, _S2_,
    // _IMPLS_, __DST_DIMS__, _IN_) are substituted in getModel().
    std::string model_t_5D = R"V0G0N(
<net name="Deconvolution_Only" version="3" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="0">
            <output>
                <port id="0">__SRC_DIMS__
                </port>
            </output>
        </layer>
        <layer name="deconv1" id="1" type="Deconvolution" precision="FP32">
            <deconvolution _AP_ kernel="_K_"
                         pads_begin="_PB_"  pads_end="_PE_"
                         strides="_KS_"
                         output="_OC_" group="_GC_" PrimitivesPriority="_IMPLS_"/>

            <weights offset="0" size="_S1_" />
            <biases offset="_S1_" size="_S2_" />

            <input>
                <port id="1">__SRC_DIMS__
                </port>
            </input>
            <output>
                <port id="2">
                    <dim>_IN_</dim>
                    <dim>_OC_</dim>__DST_DIMS__
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
    </edges>
</net>
)V0G0N";

protected:
    // Renders the IR template above into concrete XML for the given params.
    std::string getModel(deconv_test_params p) {
        std::string model = model_t_5D;
        auto dims_size = p.dims.size();
        // Input dims as a list of <dim> elements.
        std::string s_dims;
        for (auto& dim : p.dims) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS__", s_dims);

        // Output spatial dims, computed with the deconvolution formula
        // O = S * (I - 1) + K - 2 * P.  dims are NCHW/NCDHW while the
        // kernel/stride/pad vectors are WH/WHD, hence the reversed index.
        s_dims = "";
        int k_len = p.kernel.size();
        for (size_t i = 2; i < p.dims.size(); i++) {
            size_t inx = k_len - i + 1;
            size_t dim = p.strides[inx] * (p.dims[i] - 1) + p.kernel[inx] - 2 * p.pads_begin[inx];
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__DST_DIMS__", s_dims);
        REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);

        // Drop the <biases> element entirely for bias-less cases.
        if (!p.with_bias) REMOVE_LINE(model, "<biases offset=\"_S1_\" size=\"_S2_\" />");

        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_K_", p.kernel);
        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_KS_", p.strides);
        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_PB_", p.pads_begin);
        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_PE_", p.pads_end);
        REPLACE_WITH_NUM(model, "_GC_", p.grp_c);
        REPLACE_WITH_NUM(model, "_OC_", p.out_c);
        // auto_pad attribute is emitted only when requested by the params.
        string auto_pad;
        if (!p.auto_pad.empty()) auto_pad = string("auto_pad=") + string("\"") + p.auto_pad + string("\"");
        REPLACE_WITH_STR(model, "_AP_", auto_pad);

        // Weights byte size: OC * (IC / G) * prod(kernel) floats.
        size_t blob_size = p.out_c * (p.dims[1] / p.grp_c);
        for (auto k : p.kernel) {
            blob_size *= k;
        }
        size_t w_data_size = blob_size * sizeof(float);
        REPLACE_WITH_NUM(model, "_S1_", w_data_size);

        // Bias byte size: one float per output channel.
        size_t b_data_size = p.out_c * sizeof(float);
        REPLACE_WITH_NUM(model, "_S2_", b_data_size);

        // PrimitivesPriority: comma-separated "cpu:<impl>" list.
        std::string impls;
        for (const auto& preferType : p.preferTypes) {
            if (!impls.empty())
                impls += ",";
            impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
        }
        REPLACE_WITH_STR(model, "_IMPLS_", impls);

        return model;
    }

    virtual void TearDown() {
    }

    // The whole test body lives in SetUp (legacy pattern for these tests);
    // the TEST_P body itself is empty.
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            deconv_test_params p = ::testing::WithParamInterface<deconv_test_params>::GetParam();
            std::string model = getModel(p);

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            // Allocate and randomly fill the weights blob
            // (size must match _S1_ computed in getModel()).
            size_t blob_size = p.out_c * (p.dims[1] / p.grp_c);
            for (auto k : p.kernel) {
                blob_size *= k;
            }
            InferenceEngine::SizeVector dims_weights = { blob_size };

            std::vector<InferenceEngine::Blob::Ptr> blob_to_model;
            InferenceEngine::Blob::Ptr weights = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, dims_weights);
            weights->allocate();
            fill_data(weights->buffer().as<float*>(), weights->size());
            blob_to_model.push_back(weights);

            // Bias blob is always created and concatenated into the model
            // blob; the IR simply has no <biases> element when with_bias
            // is false, so the extra bytes are ignored by the reader.
            InferenceEngine::Blob::Ptr bias = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, {p.out_c});
            bias->allocate();
            fill_data(bias->buffer().as<float*>(), bias->size());
            blob_to_model.push_back(bias);

            // Pack weights + bias back-to-back into a single U8 model blob,
            // matching the offsets written into the IR.
            size_t total_size_in_bytes = 0;
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) total_size_in_bytes += blb->byteSize();

            InferenceEngine::TBlob<uint8_t>::Ptr model_blob =
                    InferenceEngine::make_shared_blob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {total_size_in_bytes});
            model_blob->allocate();
            uint8_t* model_blob_ptr = model_blob->buffer().as<uint8_t*>();
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) {
                memcpy(model_blob_ptr, blb->buffer().as<uint8_t*>(), blb->byteSize());
                model_blob_ptr += blb->byteSize();
            }
            net_reader.SetWeights(model_blob);

            // Build the MKLDNN graph and validate the deconvolution node's
            // primitive descriptors against the expectations in the params.
            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());
            auto& nodes = graph.getNodes();
            for (auto &node : nodes) {
                if (node->getType() == MKLDNNPlugin::Deconvolution) {
                    ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
                    for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                        p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
                    }
                    ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
                    // Selected implementation must match at least one of the
                    // expected types (bitmask comparison).
                    bool good_prim = false;
                    for (auto & selected : p.selectedTypes)
                        if (selected == (node->getSelectedPrimitiveDescriptor()->getImplementationType() & selected))
                            good_prim = true;
                    ASSERT_TRUE(good_prim);
                }
            }

            // Prepare a random input blob with the layout matching dims.
            InferenceEngine::SizeVector dims_src = p.dims;

            InferenceEngine::Layout layout = ANY;
            switch (p.dims.size()) {
                case 4:
                    layout = InferenceEngine::NCHW;
                    break;
                case 5:
                    layout = InferenceEngine::NCDHW;
                    break;
            }
            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());

            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            // Run inference and compare with the naive reference.
            graph.Infer(srcs, outputBlobs);

            InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
            dst_ref.allocate();

            ref_deconv(*srcPtr, weights, bias, dst_ref, p);

            compare(*output, dst_ref, 0.0002f);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};
356
357 TEST_P(MKLDNNGraphDeconvolutionalTests, TestsDeconvolution) {}
358
359
// Param field order (see deconv_test_params):
//   dims, kernel, strides, pads_begin, pads_end, out_c, grp_c,
//   with_bias, auto_pad, num_prim_desc, selectedTypes[, preferTypes]
INSTANTIATE_TEST_CASE_P(
        TestDeconvolution, MKLDNNGraphDeconvolutionalTests,
        ::testing::Values(
        /*0*/   deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
        // Cases 8-15 mirror 0-7 but with biases enabled.
        /*8*/   deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::ref_any},
                                    {MKLDNNPlugin::impl_desc_type::ref_any}},
        // Reference-implementation and grouped/asymmetric-padding cases.
        /*17*/  deconv_test_params{{2, 8, 5, 5}, {1, 3}, {1, 1}, {0, 1}, {0, 1}, 8, 8, true, "", 2,
                    {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
                deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
                    {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
                deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
                deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
                deconv_test_params{{2, 72, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
                deconv_test_params{{1, 12, 2, 2}, {4, 4}, {2, 2}, {1, 1}, {1, 1}, 12, 12, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
#ifdef USE_MKL
                // gemm/gemm_blas implementations require an MKL build.
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, true, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
                deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas}},
                deconv_test_params{{1, 64, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 32, 1, true, "", 4,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas}},
                deconv_test_params{{1, 32, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 16, 1, true, "", 4,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas} },
                deconv_test_params{{1, 25, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, true, "valid", 3,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas} },
                deconv_test_params{{1, 32, 16, 16, 16}, {4, 4, 4}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, 1, 1, true, "same_upper", 3,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas} },
                deconv_test_params{{1, 64, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 32, 1, true, "same_upper", 3,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas} },
                deconv_test_params{{1, 50, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 128, 1, true, "", 3,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas}, {MKLDNNPlugin::impl_desc_type::gemm_blas}},
#endif
                // 5D
                deconv_test_params{{1, 2, 8, 5, 5}, {3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 4, 1, true, "", 4,
                    {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any} }

                // Blocked, with biases
                // TODO support on jit
//                deconv_test_params{{2, 24, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
//                deconv_test_params{{2, 24, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
//                deconv_test_params{{2, 72, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}}
        ));
417
// Dynamic-batch variant: builds the same one-layer network, reshapes it to a
// larger batch via CNNNetworkImpl, and uses graph.checkDynBatch() instead of
// comparing against the reference implementation.
class MKLDNNGraphDynBatchDeconvolutionalTests: public MKLDNNGraphDeconvolutionalTests {
protected:
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            deconv_test_params p = ::testing::WithParamInterface<deconv_test_params>::GetParam();
            std::string model = getModel(p);
            // Ensure the reshaped batch is at least 2 so the dynamic-batch
            // path is actually exercised.
            size_t MB = p.dims[0];
            if (MB < 2)
                MB = 2;

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            // Weights size: prod(kernel) * OC * (IC / G) floats.
            size_t blob_size = 1;
            for (auto k : p.kernel) {
                blob_size *= k;
            }
            InferenceEngine::SizeVector dims_weights = {blob_size * p.out_c * (p.dims[1] / p.grp_c)};

            std::vector<InferenceEngine::Blob::Ptr> blob_to_model;
            InferenceEngine::Blob::Ptr weights = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, dims_weights);
            weights->allocate();
            fill_data(weights->buffer().as<float*>(), weights->size());
            blob_to_model.push_back(weights);

            // Bias blob is appended unconditionally; unused when the IR has
            // no <biases> element.
            InferenceEngine::Blob::Ptr bias = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, {p.out_c});
            bias->allocate();
            fill_data(bias->buffer().as<float*>(), bias->size());
            blob_to_model.push_back(bias);

            // Concatenate weights + bias into the single model blob the
            // reader expects (offsets match _S1_/_S2_ in the IR).
            size_t total_size_in_bytes = 0;
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) total_size_in_bytes += blb->byteSize();

            InferenceEngine::TBlob<uint8_t>::Ptr model_blob =
                    InferenceEngine::make_shared_blob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {total_size_in_bytes});
            model_blob->allocate();
            uint8_t* model_blob_ptr = model_blob->buffer().as<uint8_t*>();
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) {
                memcpy(model_blob_ptr, blb->buffer().as<uint8_t*>(), blb->byteSize());
                model_blob_ptr += blb->byteSize();
            }
            net_reader.SetWeights(model_blob);

            // Reshape the network to batch MB through the impl interface.
            InferenceEngine::CNNNetwork network = net_reader.getNetwork();
            auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
            ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
            InferenceEngine::ResponseDesc resp;
            InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
            ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;


            // Build the graph with dynamic batching enabled.
            MKLDNNGraphTestClass graph;
            graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
            graph.CreateGraph(net_reader.getNetwork());

            InferenceEngine::SizeVector dims_src = p.dims;

            InferenceEngine::Layout layout = ANY;
            switch (p.dims.size()) {
                case 4:
                    layout = InferenceEngine::NCHW;
                    break;
                case 5:
                    layout = InferenceEngine::NCDHW;
                    break;
            }
            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            // Predicate selecting the deconvolution node for the dyn-batch checks.
            auto checkDeconvolution = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
                return node->getType() == MKLDNNPlugin::Deconvolution;
            };

            // Verify inference at full batch MB and at batch 1.
            graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkDeconvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
            graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkDeconvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};
518
519 TEST_P(MKLDNNGraphDynBatchDeconvolutionalTests, TestsDynBatchDeconvolutional) {}
520
// Param field order (see deconv_test_params):
//   dims, kernel, strides, pads_begin, pads_end, out_c, grp_c,
//   with_bias, auto_pad, num_prim_desc, selectedTypes[, preferTypes]
INSTANTIATE_TEST_CASE_P(
        TestsDynBatchDeconvolutional, MKLDNNGraphDynBatchDeconvolutionalTests,
        ::testing::Values(
                deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 4, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}}
        ));