1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "ir_gen_helper.hpp"
15 #include "tests_common.hpp"
18 using namespace InferenceEngine;
19 using namespace ::testing;
21 using namespace mkldnn;
22 using namespace single_layer_tests;
25 struct deconv_test_params {
26 // Formats: NCHW, NCDHW
29 vector<size_t> kernel;
30 vector<size_t> strides;
31 vector<size_t> pads_begin;
32 vector<size_t> pads_end;
42 std::vector<int> selectedTypes;
43 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
45 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
48 template <typename data_t>
49 void ref_deconv(const InferenceEngine::TBlob<data_t> &src, const InferenceEngine::Blob::Ptr &weights, const InferenceEngine::Blob::Ptr &bias,
50 InferenceEngine::TBlob<data_t> &dst, deconv_test_params prm) {
51 auto dims_size = src.dims().size();
54 size_t KW = prm.kernel[X_AXIS];
55 size_t KH = prm.kernel[Y_AXIS];
56 size_t KD = prm.kernel.size() > Z_AXIS ? prm.kernel[Z_AXIS] : 1u;
58 size_t PW = prm.pads_begin[X_AXIS];
59 size_t PH = prm.pads_begin[Y_AXIS];
60 size_t PD = prm.pads_begin.size() > Z_AXIS ? prm.pads_begin[Z_AXIS] : 0u;
62 size_t SW = prm.strides[X_AXIS];
63 size_t SH = prm.strides[Y_AXIS];
64 size_t SD = prm.strides.size() > Z_AXIS ? prm.strides[Z_AXIS] : 1u;
66 size_t IW = src.dims()[dims_size - 1];
67 size_t IH = src.dims()[dims_size - 2];
68 size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
69 size_t IC = src.dims()[1];
70 size_t MB = src.dims()[0];
72 size_t OC = prm.out_c;
74 size_t OW = SW * (IW - 1lu) + KW - 2lu * PW;
75 size_t OH = SH * (IH - 1lu) + KH - 2lu * PH;
76 size_t OD = dims_size == 5 ? (SD * (ID - 1) + KD - 2 * PD) : 1u;
78 const data_t *src_data = src.readOnly();
79 const data_t *weights_data = weights->buffer().as<data_t*>();
80 const data_t *bias_data = bias->buffer().as<data_t*>();
82 data_t *dst_data = dst.data();
85 size_t CS2 = CS1 * OD;
86 size_t CS3 = CS2 * OC;
89 size_t CI2 = CI1 * ID;
90 size_t CI3 = CI2 * IC;
96 size_t CK2 = CK1 * KD;
97 size_t CK3 = CK2 * OC_G;
98 size_t CK4 = CK3 * IC_G;
100 for (size_t g = 0lu; g < G; ++g) {
101 size_t g_OC_G = g * OC_G;
102 size_t g_IC_G = g * IC_G;
103 size_t g_CK4 = g * CK4;
104 for (size_t mb = 0lu; mb < MB; ++mb) {
105 size_t mb_CS3 = mb * CS3;
106 size_t mb_CI3 = mb * CI3;
107 for (size_t oc = 0lu; oc < OC_G; ++oc) {
108 size_t g_OC_G_oc = g_OC_G + oc;
109 size_t mb_CS3_g_OC_G_oc_CS2 = mb_CS3 + g_OC_G_oc * CS2;
110 size_t g_CK4_oc_CK2 = g_CK4 + oc * CK2;
111 for (size_t od = 0lu; od < OD; ++od) {
112 size_t mb_CS3_g_OC_G_oc_CS2_od_CS1 = mb_CS3_g_OC_G_oc_CS2 + od * CS1;
113 size_t od_PD = od + PD;
114 for (size_t oh = 0lu; oh < OH; ++oh) {
115 size_t mb_CS3_g_OC_G_oc_CS2_od_CS1_oh_OW = mb_CS3_g_OC_G_oc_CS2_od_CS1 + oh * OW;
116 size_t oh_PH = oh + PH;
117 for (size_t ow = 0lu; ow < OW; ++ow) {
118 size_t didx = mb_CS3_g_OC_G_oc_CS2_od_CS1_oh_OW + ow;
119 size_t ow_PW = ow + PW;
121 dst_data[didx] = data_t(0);
122 if (prm.with_bias) dst_data[didx] += bias_data[g_OC_G_oc];
124 for (size_t ic = 0lu; ic < IC_G; ic++) {
125 size_t mb_CI3_g_IC_G_ic_CI2 = mb_CI3 + (g_IC_G + ic) * CI2;
126 size_t g_CK4_oc_CK2_ic_CK3 = g_CK4_oc_CK2 + ic * CK3;
127 for (int kd = 0lu; kd < KD; kd++) {
128 if (od_PD < kd) continue;
129 size_t id = od_PD - kd;
130 if (id % SD != 0) continue;
132 if (id >= ID) continue;
133 size_t mb_CI3_g_IC_G_ic_CI2_id_CI1 = mb_CI3_g_IC_G_ic_CI2 + id * CI1;
134 size_t g_CK4_oc_CK2_ic_CK3_kd_CK1 = g_CK4_oc_CK2_ic_CK3 + kd * CK1;
135 for (size_t kh = 0lu; kh < KH; kh++) {
136 if (oh_PH < kh) continue;
137 size_t ih = oh_PH - kh;
138 if (ih % SH != 0) continue;
140 if (ih >= IH) continue;
141 size_t mb_CI3_g_IC_G_ic_CI2_id_CI1_ih_IW = mb_CI3_g_IC_G_ic_CI2_id_CI1 + ih * IW;
142 size_t g_CK4_oc_CK2_ic_CK3_kd_CK1_kh_KW = g_CK4_oc_CK2_ic_CK3_kd_CK1 + kh * KW;
143 for (size_t kw = 0lu; kw < KW; kw++) {
144 if (ow_PW < kw) continue;
145 size_t iw = ow_PW - kw;
146 if (iw % SW != 0) continue;
148 if (iw >= IW) continue;
150 size_t sidx = mb_CI3_g_IC_G_ic_CI2_id_CI1_ih_IW + iw;
152 size_t widx = g_CK4_oc_CK2_ic_CK3_kd_CK1_kh_KW + kw;
154 dst_data[didx] += src_data[sidx] * weights_data[widx];
167 class MKLDNNGraphDeconvolutionalTests: public TestsCommon,
168 public WithParamInterface<deconv_test_params> {
169 std::string layers_t = R"V0G0N(
170 <layer name="deconv1" id="1" type="Deconvolution" precision="FP32">
171 <deconvolution _AP_ kernel="_K_"
172 pads_begin="_PB_" pads_end="_PE_"
174 output="_OC_" group="_GC_" PrimitivesPriority="_IMPLS_"/>
176 <weights offset="0" size="_S1_" />
177 <biases offset="_S1_" size="_S2_" />
194 std::string edges_t = R"V0G0N(
195 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
199 std::string getModel(deconv_test_params p) {
200 std::string model = layers_t;
203 for (auto& dim : p.dims) {
204 s_dims += "\n <dim>";
205 s_dims += std::to_string(dim) + "</dim>";
207 REPLACE_WITH_STR(model, "__SRC_DIMS__", s_dims);
210 int k_len = p.kernel.size();
211 for (size_t i = 2; i < p.dims.size(); i++) {
212 size_t inx = k_len - i + 1;
213 size_t dim = p.strides[inx] * (p.dims[i] - 1) + p.kernel[inx] - 2 * p.pads_begin[inx];
214 s_dims += "\n <dim>";
215 s_dims += std::to_string(dim) + "</dim>";
217 REPLACE_WITH_STR(model, "__DST_DIMS__", s_dims);
218 REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
220 if (!p.with_bias) REMOVE_LINE(model, "<biases offset=\"_S1_\" size=\"_S2_\" />");
222 REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_K_", p.kernel);
223 REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_KS_", p.strides);
224 REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_PB_", p.pads_begin);
225 REPLACE_WITH_NUM_VECTOR_REVERSE(model, "_PE_", p.pads_end);
226 REPLACE_WITH_NUM(model, "_GC_", p.grp_c);
227 REPLACE_WITH_NUM(model, "_OC_", p.out_c);
229 if (!p.auto_pad.empty()) auto_pad = string("auto_pad=") + string("\"") + p.auto_pad + string("\"");
230 REPLACE_WITH_STR(model, "_AP_", auto_pad);
232 size_t blob_size = p.out_c * (p.dims[1] / p.grp_c);
233 for (auto k : p.kernel) {
236 size_t w_data_size = blob_size * sizeof(float);
237 REPLACE_WITH_NUM(model, "_S1_", w_data_size);
239 size_t b_data_size = p.out_c * sizeof(float);
240 REPLACE_WITH_NUM(model, "_S2_", b_data_size);
243 for (const auto& preferType : p.preferTypes) {
246 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
248 REPLACE_WITH_STR(model, "_IMPLS_", impls);
250 model = IRTemplateGenerator::getIRTemplate("Deconvolution_Only", p.dims, "FP32", model, edges_t);
255 virtual void TearDown() {
258 virtual void SetUp() {
260 TestsCommon::SetUp();
261 deconv_test_params p = ::testing::WithParamInterface<deconv_test_params>::GetParam();
262 std::string model = getModel(p);
264 InferenceEngine::CNNNetReader net_reader;
265 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
267 size_t blob_size = p.out_c * (p.dims[1] / p.grp_c);
268 for (auto k : p.kernel) {
271 InferenceEngine::SizeVector dims_weights = { blob_size };
273 std::vector<InferenceEngine::Blob::Ptr> blob_to_model;
274 InferenceEngine::Blob::Ptr weights = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, dims_weights);
276 fill_data(weights->buffer().as<float*>(), weights->size());
277 blob_to_model.push_back(weights);
279 InferenceEngine::Blob::Ptr bias = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, {p.out_c});
281 fill_data(bias->buffer().as<float*>(), bias->size());
282 blob_to_model.push_back(bias);
284 size_t total_size_in_bytes = 0;
285 for (InferenceEngine::Blob::Ptr blb : blob_to_model) total_size_in_bytes += blb->byteSize();
287 InferenceEngine::TBlob<uint8_t>::Ptr model_blob =
288 InferenceEngine::make_shared_blob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {total_size_in_bytes});
289 model_blob->allocate();
290 uint8_t* model_blob_ptr = model_blob->buffer().as<uint8_t*>();
291 for (InferenceEngine::Blob::Ptr blb : blob_to_model) {
292 memcpy(model_blob_ptr, blb->buffer().as<uint8_t*>(), blb->byteSize());
293 model_blob_ptr += blb->byteSize();
295 net_reader.SetWeights(model_blob);
297 MKLDNNGraphTestClass graph;
298 graph.CreateGraph(net_reader.getNetwork());
299 auto& nodes = graph.getNodes();
300 for (auto &node : nodes) {
301 if (node->getType() == MKLDNNPlugin::Deconvolution) {
302 ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
303 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
304 p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
306 ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
307 bool good_prim = false;
308 for (auto & selected : p.selectedTypes)
309 if (selected == (node->getSelectedPrimitiveDescriptor()->getImplementationType() & selected))
311 ASSERT_TRUE(good_prim);
315 InferenceEngine::SizeVector dims_src = p.dims;
317 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(
318 InferenceEngine::Precision::FP32, InferenceEngine::TensorDesc::getLayoutByDims(p.dims), dims_src);
320 fill_data(src->buffer(), src->size());
322 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
324 if (srcPtr == nullptr)
325 FAIL() << "Cannot cast blob to TBlob<float>.";
327 InferenceEngine::BlobMap srcs;
328 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
330 InferenceEngine::OutputsDataMap out;
331 out = net_reader.getNetwork().getOutputsInfo();
332 InferenceEngine::BlobMap outputBlobs;
334 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
336 InferenceEngine::TBlob<float>::Ptr output;
337 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
339 outputBlobs[item.first] = output;
341 graph.Infer(srcs, outputBlobs);
343 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
346 ref_deconv(*srcPtr, weights, bias, dst_ref, p);
348 compare(*output, dst_ref, 0.0002f);
349 } catch (const InferenceEngine::details::InferenceEngineException &e) {
355 TEST_P(MKLDNNGraphDeconvolutionalTests, TestsDeconvolution) {}
358 INSTANTIATE_TEST_CASE_P(
359 TestDeconvolution, MKLDNNGraphDeconvolutionalTests,
361 /*0*/ deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
362 deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
363 deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
364 deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
365 deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
366 /*5*/ deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
367 deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
368 deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
369 deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
370 deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
371 deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::ref_any},
372 {MKLDNNPlugin::impl_desc_type::ref_any}},
373 /*11*/ deconv_test_params{{2, 8, 5, 5}, {1, 3}, {1, 1}, {0, 1}, {0, 1}, 8, 8, true, "", 2,
374 {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
375 deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
376 {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
378 deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
379 deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
380 deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
381 deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
382 deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
383 deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
384 deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
385 deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, true, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
386 deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
387 {MKLDNNPlugin::impl_desc_type::gemm_blas}},
388 deconv_test_params{{1, 64, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 32, 1, true, "", 4,
389 {MKLDNNPlugin::impl_desc_type::gemm_blas}},
390 deconv_test_params{{1, 32, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 16, 1, true, "", 4,
391 {MKLDNNPlugin::impl_desc_type::gemm_blas} },
392 deconv_test_params{{1, 25, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, true, "valid", 3,
393 {MKLDNNPlugin::impl_desc_type::jit} },
394 deconv_test_params{{1, 32, 16, 16, 16}, {4, 4, 4}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, 1, 1, true, "same_upper", 3,
395 {MKLDNNPlugin::impl_desc_type::gemm_blas} },
396 deconv_test_params{{1, 64, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 32, 1, true, "same_upper", 3,
397 {MKLDNNPlugin::impl_desc_type::gemm_blas} },
398 deconv_test_params{{1, 50, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 128, 1, true, "", 3,
399 {MKLDNNPlugin::impl_desc_type::gemm_blas}, {MKLDNNPlugin::impl_desc_type::gemm_blas}},
401 deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
402 deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
403 deconv_test_params{{2, 72, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
404 deconv_test_params{{1, 12, 2, 2}, {4, 4}, {2, 2}, {1, 1}, {1, 1}, 12, 12, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
406 deconv_test_params{{1, 2, 8, 5, 5}, {3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 4, 1, true, "", 4,
407 {MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any} }
408 // Blocked, with biases
409 // TODO support on jit
410 // deconv_test_params{{2, 24, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
411 // deconv_test_params{{2, 24, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
412 // deconv_test_params{{2, 72, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}}
415 class MKLDNNGraphDynBatchDeconvolutionalTests: public MKLDNNGraphDeconvolutionalTests {
417 virtual void SetUp() {
419 TestsCommon::SetUp();
420 deconv_test_params p = ::testing::WithParamInterface<deconv_test_params>::GetParam();
421 std::string model = getModel(p);
422 size_t MB = p.dims[0];
426 InferenceEngine::CNNNetReader net_reader;
427 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
429 size_t blob_size = 1;
430 for (auto k : p.kernel) {
433 InferenceEngine::SizeVector dims_weights = {blob_size * p.out_c * (p.dims[1] / p.grp_c)};
435 std::vector<InferenceEngine::Blob::Ptr> blob_to_model;
436 InferenceEngine::Blob::Ptr weights = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, dims_weights);
438 fill_data(weights->buffer().as<float*>(), weights->size());
439 blob_to_model.push_back(weights);
441 InferenceEngine::Blob::Ptr bias = InferenceEngine::make_shared_blob<float>(InferenceEngine::Precision::FP32, InferenceEngine::C, {p.out_c});
443 fill_data(bias->buffer().as<float*>(), bias->size());
444 blob_to_model.push_back(bias);
446 size_t total_size_in_bytes = 0;
447 for (InferenceEngine::Blob::Ptr blb : blob_to_model) total_size_in_bytes += blb->byteSize();
449 InferenceEngine::TBlob<uint8_t>::Ptr model_blob =
450 InferenceEngine::make_shared_blob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {total_size_in_bytes});
451 model_blob->allocate();
452 uint8_t* model_blob_ptr = model_blob->buffer().as<uint8_t*>();
453 for (InferenceEngine::Blob::Ptr blb : blob_to_model) {
454 memcpy(model_blob_ptr, blb->buffer().as<uint8_t*>(), blb->byteSize());
455 model_blob_ptr += blb->byteSize();
457 net_reader.SetWeights(model_blob);
459 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
460 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
461 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
462 InferenceEngine::ResponseDesc resp;
463 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
464 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
467 MKLDNNGraphTestClass graph;
468 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
469 graph.CreateGraph(net_reader.getNetwork());
471 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(
472 InferenceEngine::Precision::FP32, InferenceEngine::TensorDesc::getLayoutByDims(p.dims), p.dims);
473 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
474 if (srcPtr == nullptr)
475 FAIL() << "Cannot cast blob to TBlob<float>.";
478 fill_data(src->buffer(), src->size());
480 InferenceEngine::BlobMap srcs;
481 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
483 InferenceEngine::OutputsDataMap out;
484 out = net_reader.getNetwork().getOutputsInfo();
485 InferenceEngine::BlobMap outputBlobs;
487 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
489 InferenceEngine::TBlob<float>::Ptr output;
490 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
492 outputBlobs[item.first] = output;
494 auto checkDeconvolution = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
495 return node->getType() == MKLDNNPlugin::Deconvolution;
498 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkDeconvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
499 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkDeconvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
500 } catch (const InferenceEngine::details::InferenceEngineException &e) {
506 TEST_P(MKLDNNGraphDynBatchDeconvolutionalTests, TestsDynBatchDeconvolutional) {}
508 INSTANTIATE_TEST_CASE_P(
509 TestsDynBatchDeconvolutional, MKLDNNGraphDynBatchDeconvolutionalTests,
511 deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
512 deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
514 deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 4, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
515 deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
516 deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
517 deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
519 deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
520 deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
521 deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}}