// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include "mkldnn_plugin/mkldnn_graph.h"

#include "test_graph.hpp"

#include "single_layer_common.hpp"
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include <inference_engine/cnn_network_impl.hpp>
#include "ir_gen_helper.hpp"
#include "tests_common.hpp"


using namespace InferenceEngine;
using namespace ::testing;
using namespace std;
using namespace mkldnn;
using namespace single_layer_tests;


struct deconv_test_params {
    // Formats: NCHW, NCDHW
    vector<size_t> dims;
    // Formats: WH, WHD
    vector<size_t> kernel;
    vector<size_t> strides;
    vector<size_t> pads_begin;
    vector<size_t> pads_end;

    // Number of output channels and number of groups.
    size_t out_c;
    size_t grp_c;

    bool with_bias;
    string auto_pad;

    // Minimum number of supported primitive descriptors the Deconvolution node must report.
    size_t num_prim_desc;

    // Implementation types the selected primitive descriptor is allowed to match.
    std::vector<int> selectedTypes;
    // Implementation types forwarded to the PrimitivesPriority attribute of the IR.
    std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;

    // Optional per-descriptor checks applied to the supported primitive descriptors.
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
};

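// Reference (naive scalar) implementation of Deconvolution (transposed convolution).
// Output spatial size follows O = S * (I - 1) + K - 2 * P; the result is used below
// to validate the output produced by the MKLDNN graph.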
template <typename data_t>
void ref_deconv(const InferenceEngine::TBlob<data_t> &src, const InferenceEngine::Blob::Ptr &weights, const InferenceEngine::Blob::Ptr &bias,
                InferenceEngine::TBlob<data_t> &dst, deconv_test_params prm) {
    auto dims_size = src.dims().size();

    size_t G  = prm.grp_c;
    size_t KW = prm.kernel[X_AXIS];
    size_t KH = prm.kernel[Y_AXIS];
    size_t KD = prm.kernel.size() > Z_AXIS ? prm.kernel[Z_AXIS] : 1u;

    size_t PW = prm.pads_begin[X_AXIS];
    size_t PH = prm.pads_begin[Y_AXIS];
    size_t PD = prm.pads_begin.size() > Z_AXIS ? prm.pads_begin[Z_AXIS] : 0u;

    size_t SW = prm.strides[X_AXIS];
    size_t SH = prm.strides[Y_AXIS];
    size_t SD = prm.strides.size() > Z_AXIS ? prm.strides[Z_AXIS] : 1u;

    size_t IW = src.dims()[dims_size - 1];
    size_t IH = src.dims()[dims_size - 2];
    size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
    size_t IC = src.dims()[1];
    size_t MB = src.dims()[0];

    size_t OC = prm.out_c;

    // Deconvolution output size per spatial dimension: O = S * (I - 1) + K - 2 * P.
    size_t OW = SW * (IW - 1lu) + KW - 2lu * PW;
    size_t OH = SH * (IH - 1lu) + KH - 2lu * PH;
    size_t OD = dims_size == 5 ? (SD * (ID - 1) + KD - 2 * PD) : 1u;

    const data_t *src_data = src.readOnly();
    const data_t *weights_data = weights->buffer().as<data_t*>();
    const data_t *bias_data = bias->buffer().as<data_t*>();

    data_t *dst_data = dst.data();

    // Flattened index strides for the destination, source and weights tensors.
    size_t CS1 = OH * OW;
    size_t CS2 = CS1 * OD;
    size_t CS3 = CS2 * OC;

    size_t CI1 = IH * IW;
    size_t CI2 = CI1 * ID;
    size_t CI3 = CI2 * IC;

    size_t OC_G = OC / G;
    size_t IC_G = IC / G;

    size_t CK1 = KH * KW;
    size_t CK2 = CK1 * KD;
    size_t CK3 = CK2 * OC_G;
    size_t CK4 = CK3 * IC_G;

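    // Accumulate src * weights into dst: for each output element, gather every
    // (kd, kh, kw) tap that maps back to a valid input position
    // i = (o + pad - k) / stride (only when the division is exact and i is in range).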
    for (size_t g = 0lu; g < G; ++g) {
        size_t g_OC_G = g * OC_G;
        size_t g_IC_G = g * IC_G;
        size_t g_CK4 = g * CK4;
        for (size_t mb = 0lu; mb < MB; ++mb) {
            size_t mb_CS3 = mb * CS3;
            size_t mb_CI3 = mb * CI3;
            for (size_t oc = 0lu; oc < OC_G; ++oc) {
                size_t g_OC_G_oc = g_OC_G + oc;
                size_t mb_CS3_g_OC_G_oc_CS2 = mb_CS3 + g_OC_G_oc * CS2;
                size_t g_CK4_oc_CK2 = g_CK4 + oc * CK2;
                for (size_t od = 0lu; od < OD; ++od) {
                    size_t mb_CS3_g_OC_G_oc_CS2_od_CS1 = mb_CS3_g_OC_G_oc_CS2 + od * CS1;
                    size_t od_PD = od + PD;
                    for (size_t oh = 0lu; oh < OH; ++oh) {
                        size_t mb_CS3_g_OC_G_oc_CS2_od_CS1_oh_OW = mb_CS3_g_OC_G_oc_CS2_od_CS1 + oh * OW;
                        size_t oh_PH = oh + PH;
                        for (size_t ow = 0lu; ow < OW; ++ow) {
                            size_t didx = mb_CS3_g_OC_G_oc_CS2_od_CS1_oh_OW + ow;
                            size_t ow_PW = ow + PW;

                            dst_data[didx] = data_t(0);
                            if (prm.with_bias) dst_data[didx] += bias_data[g_OC_G_oc];

                            for (size_t ic = 0lu; ic < IC_G; ic++) {
                                size_t mb_CI3_g_IC_G_ic_CI2 = mb_CI3 + (g_IC_G + ic) * CI2;
                                size_t g_CK4_oc_CK2_ic_CK3 = g_CK4_oc_CK2 + ic * CK3;
                                for (size_t kd = 0lu; kd < KD; kd++) {
                                    if (od_PD < kd) continue;
                                    size_t id = od_PD - kd;
                                    if (id % SD != 0) continue;
                                    id /= SD;
                                    if (id >= ID) continue;
                                    size_t mb_CI3_g_IC_G_ic_CI2_id_CI1 = mb_CI3_g_IC_G_ic_CI2 + id * CI1;
                                    size_t g_CK4_oc_CK2_ic_CK3_kd_CK1 = g_CK4_oc_CK2_ic_CK3 + kd * CK1;
                                    for (size_t kh = 0lu; kh < KH; kh++) {
                                        if (oh_PH < kh) continue;
                                        size_t ih = oh_PH - kh;
                                        if (ih % SH != 0) continue;
                                        ih /= SH;
                                        if (ih >= IH) continue;
                                        size_t mb_CI3_g_IC_G_ic_CI2_id_CI1_ih_IW = mb_CI3_g_IC_G_ic_CI2_id_CI1 + ih * IW;
                                        size_t g_CK4_oc_CK2_ic_CK3_kd_CK1_kh_KW = g_CK4_oc_CK2_ic_CK3_kd_CK1 + kh * KW;
                                        for (size_t kw = 0lu; kw < KW; kw++) {
                                            if (ow_PW < kw) continue;
                                            size_t iw = ow_PW - kw;
                                            if (iw % SW != 0) continue;
                                            iw /= SW;
                                            if (iw >= IW) continue;

                                            size_t sidx = mb_CI3_g_IC_G_ic_CI2_id_CI1_ih_IW + iw;

                                            size_t widx = g_CK4_oc_CK2_ic_CK3_kd_CK1_kh_KW + kw;

                                            dst_data[didx] += src_data[sidx] * weights_data[widx];
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

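// Builds a single-layer Deconvolution IR from the template below, loads it into an
// MKLDNN graph, checks the supported/selected primitive descriptors against the test
// parameters, then infers and compares the output with ref_deconv().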
class MKLDNNGraphDeconvolutionalTests: public TestsCommon,
                                       public WithParamInterface<deconv_test_params> {
    std::string layers_t = R"V0G0N(
        <layer name="deconv1" id="1" type="Deconvolution" precision="FP32">
            <deconvolution _AP_ kernel="_K_"
                         pads_begin="_PB_"  pads_end="_PE_"
                         strides="_KS_"
                         output="_OC_" group="_GC_" PrimitivesPriority="_IMPLS_"/>

            <weights offset="0" size="_S1_" />
            <biases offset="_S1_" size="_S2_" />

            <input>
                <port id="1">
                    __SRC_DIMS__
                </port>
            </input>
            <output>
                <port id="2">
                    <dim>_IN_</dim>
                    <dim>_OC_</dim>
                    __DST_DIMS__
                </port>
            </output>
        </layer>
)V0G0N";

    std::string edges_t = R"V0G0N(
        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
)V0G0N";

protected:
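    // Produces the IR XML for a single Deconvolution layer by substituting the
    // placeholders in layers_t (kernel, strides, pads, group, output channels,
    // weight/bias byte sizes). Output dims are derived as stride * (in - 1) + kernel - 2 * pad.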
    std::string getModel(deconv_test_params p) {
        std::string model = layers_t;

        std::string s_dims;
        for (auto& dim : p.dims) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS__", s_dims);

        s_dims = "";
        int k_len = p.kernel.size();
        for (size_t i = 2; i < p.dims.size(); i++) {
            size_t inx = k_len - i + 1;
            size_t dim = p.strides[inx] * (p.dims[i] - 1) + p.kernel[inx] - 2 * p.pads_begin[inx];
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__DST_DIMS__", s_dims);
        REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);

        if (!p.with_bias) REMOVE_LINE(model, "<biases offset=\"_S1_\" size=\"_S2_\" />");

        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_K_", p.kernel);
        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_KS_", p.strides);
        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_PB_", p.pads_begin);
        REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_PE_", p.pads_end);
        REPLACE_WITH_NUM(model, "_GC_", p.grp_c);
        REPLACE_WITH_NUM(model, "_OC_", p.out_c);
        string auto_pad;
        if (!p.auto_pad.empty()) auto_pad = string("auto_pad=") + string("\"") + p.auto_pad + string("\"");
        REPLACE_WITH_STR(model, "_AP_", auto_pad);

        size_t blob_size = p.out_c * (p.dims[1] / p.grp_c);
        for (auto k : p.kernel) {
            blob_size *= k;
        }
        size_t w_data_size = blob_size * sizeof(float);
        REPLACE_WITH_NUM(model, "_S1_", w_data_size);

        size_t b_data_size = p.out_c * sizeof(float);
        REPLACE_WITH_NUM(model, "_S2_", b_data_size);

        std::string impls;
        for (const auto& preferType : p.preferTypes) {
            if (!impls.empty())
                impls += ",";
            impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
        }
        REPLACE_WITH_STR(model, "_IMPLS_", impls);

        model = IRTemplateGenerator::getIRTemplate("Deconvolution_Only", p.dims, "FP32", model, edges_t);

        return model;
    }

    virtual void TearDown() {
    }

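    // Reads the generated IR, attaches randomly filled weights/biases, builds the
    // MKLDNN graph, validates the Deconvolution node's primitive descriptors, then
    // infers and compares the result against the reference implementation.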
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            deconv_test_params p = ::testing::WithParamInterface<deconv_test_params>::GetParam();
            std::string model = getModel(p);

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            size_t blob_size = p.out_c * (p.dims[1] / p.grp_c);
            for (auto k : p.kernel) {
                blob_size *= k;
            }
            InferenceEngine::SizeVector dims_weights = { blob_size };

            std::vector<InferenceEngine::Blob::Ptr> blob_to_model;
            InferenceEngine::Blob::Ptr weights = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, dims_weights);
            weights->allocate();
            fill_data(weights->buffer().as<float*>(), weights->size());
            blob_to_model.push_back(weights);

            InferenceEngine::Blob::Ptr bias = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, {p.out_c});
            bias->allocate();
            fill_data(bias->buffer().as<float*>(), bias->size());
            blob_to_model.push_back(bias);

            // Pack weights and biases into a single U8 blob consumed by SetWeights().
            size_t total_size_in_bytes = 0;
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) total_size_in_bytes += blb->byteSize();

            InferenceEngine::TBlob<uint8_t>::Ptr model_blob =
                    InferenceEngine::make_shared_blob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {total_size_in_bytes});
            model_blob->allocate();
            uint8_t* model_blob_ptr = model_blob->buffer().as<uint8_t*>();
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) {
                memcpy(model_blob_ptr, blb->buffer().as<uint8_t*>(), blb->byteSize());
                model_blob_ptr += blb->byteSize();
            }
            net_reader.SetWeights(model_blob);

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());
            auto& nodes = graph.getNodes();
            for (auto &node : nodes) {
                if (node->getType() == MKLDNNPlugin::Deconvolution) {
                    ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
                    for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                        p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
                    }
                    ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
                    // The selected implementation must match at least one of the expected types.
                    bool good_prim = false;
                    for (auto & selected : p.selectedTypes)
                        if (selected == (node->getSelectedPrimitiveDescriptor()->getImplementationType() & selected))
                            good_prim = true;
                    ASSERT_TRUE(good_prim);
                }
            }

            InferenceEngine::SizeVector dims_src = p.dims;

            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(
                    InferenceEngine::Precision::FP32, InferenceEngine::TensorDesc::getLayoutByDims(p.dims), dims_src);
            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());

            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            graph.Infer(srcs, outputBlobs);

            InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
            dst_ref.allocate();

            ref_deconv(*srcPtr, weights, bias, dst_ref, p);

            compare(*output, dst_ref, 0.0002f);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphDeconvolutionalTests, TestsDeconvolution) {}

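// Parameter order in each initializer below: {dims, kernel, strides, pads_begin, pads_end,
// out_c, grp_c, with_bias, auto_pad, num_prim_desc, selectedTypes[, preferTypes]}.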
INSTANTIATE_TEST_CASE_P(
        TestDeconvolution, MKLDNNGraphDeconvolutionalTests,
        ::testing::Values(
        /*0*/   deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
        /*5*/   deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::ref_any},
                                    {MKLDNNPlugin::impl_desc_type::ref_any}},
        /*11*/  deconv_test_params{{2, 8, 5, 5}, {1, 3}, {1, 1}, {0, 1}, {0, 1}, 8, 8, true, "", 2,
                    {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
                deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
                    {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
#ifdef USE_MKL
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, true, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
                deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas}},
                deconv_test_params{{1, 64, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 32, 1, true, "", 4,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas}},
                deconv_test_params{{1, 32, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 16, 1, true, "", 4,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas} },
                deconv_test_params{{1, 25, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, true, "valid", 3,
                    {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 32, 16, 16, 16}, {4, 4, 4}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, 1, 1, true, "same_upper", 3,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas} },
                deconv_test_params{{1, 64, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 32, 1, true, "same_upper", 3,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas} },
                deconv_test_params{{1, 50, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 128, 1, true, "", 3,
                    {MKLDNNPlugin::impl_desc_type::gemm_blas}, {MKLDNNPlugin::impl_desc_type::gemm_blas}},
#endif
                deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
                deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
                deconv_test_params{{2, 72, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
                deconv_test_params{{1, 12, 2, 2}, {4, 4}, {2, 2}, {1, 1}, {1, 1}, 12, 12, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
                // 5D
                deconv_test_params{{1, 2, 8, 5, 5}, {3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 4, 1, true, "", 4,
                    {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any} }
                // Blocked, with biases
                // TODO support on jit
//                deconv_test_params{{2, 24, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
//                deconv_test_params{{2, 24, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
//                deconv_test_params{{2, 72, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}}
        ));

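// Same layer setup as above, but the network is reshaped to a larger batch with
// KEY_DYN_BATCH_ENABLED, and inference is checked through checkDynBatch for both the
// full and a reduced batch size.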
class MKLDNNGraphDynBatchDeconvolutionalTests: public MKLDNNGraphDeconvolutionalTests {
protected:
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            deconv_test_params p = ::testing::WithParamInterface<deconv_test_params>::GetParam();
            std::string model = getModel(p);
            size_t MB = p.dims[0];
            if (MB < 2)
                MB = 2;

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            size_t blob_size = 1;
            for (auto k : p.kernel) {
                blob_size *= k;
            }
            InferenceEngine::SizeVector dims_weights = {blob_size * p.out_c * (p.dims[1] / p.grp_c)};

            std::vector<InferenceEngine::Blob::Ptr> blob_to_model;
            InferenceEngine::Blob::Ptr weights = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, dims_weights);
            weights->allocate();
            fill_data(weights->buffer().as<float*>(), weights->size());
            blob_to_model.push_back(weights);

            InferenceEngine::Blob::Ptr bias = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, {p.out_c});
            bias->allocate();
            fill_data(bias->buffer().as<float*>(), bias->size());
            blob_to_model.push_back(bias);

            size_t total_size_in_bytes = 0;
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) total_size_in_bytes += blb->byteSize();

            InferenceEngine::TBlob<uint8_t>::Ptr model_blob =
                    InferenceEngine::make_shared_blob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {total_size_in_bytes});
            model_blob->allocate();
            uint8_t* model_blob_ptr = model_blob->buffer().as<uint8_t*>();
            for (InferenceEngine::Blob::Ptr blb : blob_to_model) {
                memcpy(model_blob_ptr, blb->buffer().as<uint8_t*>(), blb->byteSize());
                model_blob_ptr += blb->byteSize();
            }
            net_reader.SetWeights(model_blob);

            InferenceEngine::CNNNetwork network = net_reader.getNetwork();
            auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
            ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
            InferenceEngine::ResponseDesc resp;
            InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
            ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;


            MKLDNNGraphTestClass graph;
            graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
            graph.CreateGraph(net_reader.getNetwork());

            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(
                    InferenceEngine::Precision::FP32, InferenceEngine::TensorDesc::getLayoutByDims(p.dims), p.dims);
            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            auto checkDeconvolution = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
                return node->getType() == MKLDNNPlugin::Deconvolution;
            };

            graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkDeconvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
            graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkDeconvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphDynBatchDeconvolutionalTests, TestsDynBatchDeconvolutional) {}

INSTANTIATE_TEST_CASE_P(
        TestsDynBatchDeconvolutional, MKLDNNGraphDynBatchDeconvolutionalTests,
        ::testing::Values(
                deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
#ifdef USE_MKL
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 4, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
#endif
                deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
                deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}}
        ));