-// Copyright (C) 2018 Intel Corporation
+// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "single_layer_common.hpp"
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include <inference_engine/cnn_network_impl.hpp>
+#include "ir_gen_helper.hpp"
#include "tests_common.hpp"
using namespace ::testing;
using namespace std;
using namespace mkldnn;
+using namespace single_layer_tests;
struct deconv_test_params {
size_t OC = prm.out_c;
- size_t OW = SW * (IW - 1) + KW - 2 * PW;
- size_t OH = SH * (IH - 1) + KH - 2 * PH;
+ size_t OW = SW * (IW - 1lu) + KW - 2lu * PW;
+ size_t OH = SH * (IH - 1lu) + KH - 2lu * PH;
size_t OD = dims_size == 5 ? (SD * (ID - 1) + KD - 2 * PD) : 1u;
const data_t *src_data = src.readOnly();
size_t CI1 = IH * IW;
size_t CI2 = CI1 * ID;
size_t CI3 = CI2 * IC;
+
+ size_t OC_G = OC / G;
+ size_t IC_G = IC / G;
size_t CK1 = KH * KW;
size_t CK2 = CK1 * KD;
- size_t CK3 = CK2 * (OC / G);
- size_t CK4 = CK3 * (IC / G);
-
- for (int g = 0; g < G; ++g) {
- for (int mb = 0; mb < MB; ++mb) {
- for (int oc = 0; oc < OC / G; ++oc) {
- for (int od = 0; od < OD; ++od) {
- for (int oh = 0; oh < OH; ++oh) {
- for (int ow = 0; ow < OW; ++ow) {
- size_t didx = mb * CS3
- + (g * OC / G + oc) * CS2
- + od * CS1
- + oh * OW
- + ow;
+ size_t CK3 = CK2 * OC_G;
+ size_t CK4 = CK3 * IC_G;
+
+ for (size_t g = 0lu; g < G; ++g) {
+ size_t g_OC_G = g * OC_G;
+ size_t g_IC_G = g * IC_G;
+ size_t g_CK4 = g * CK4;
+ for (size_t mb = 0lu; mb < MB; ++mb) {
+ size_t mb_CS3 = mb * CS3;
+ size_t mb_CI3 = mb * CI3;
+ for (size_t oc = 0lu; oc < OC_G; ++oc) {
+ size_t g_OC_G_oc = g_OC_G + oc;
+ size_t mb_CS3_g_OC_G_oc_CS2 = mb_CS3 + g_OC_G_oc * CS2;
+ size_t g_CK4_oc_CK2 = g_CK4 + oc * CK2;
+ for (size_t od = 0lu; od < OD; ++od) {
+ size_t mb_CS3_g_OC_G_oc_CS2_od_CS1 = mb_CS3_g_OC_G_oc_CS2 + od * CS1;
+ size_t od_PD = od + PD;
+ for (size_t oh = 0lu; oh < OH; ++oh) {
+ size_t mb_CS3_g_OC_G_oc_CS2_od_CS1_oh_OW = mb_CS3_g_OC_G_oc_CS2_od_CS1 + oh * OW;
+ size_t oh_PH = oh + PH;
+ for (size_t ow = 0lu; ow < OW; ++ow) {
+ size_t didx = mb_CS3_g_OC_G_oc_CS2_od_CS1_oh_OW + ow;
+ size_t ow_PW = ow + PW;
dst_data[didx] = data_t(0);
- if (prm.with_bias) dst_data[didx] += bias_data[g * OC / G + oc];
-
- for (int ic = 0; ic < IC / G; ic++) {
- for (int kd = 0; kd < KD; kd++) {
- for (int kh = 0; kh < KH; kh++) {
- for (int kw = 0; kw < KW; kw++) {
- if (ow + PW < kw || oh + PH < kh || od + PD < kd)
- continue;
+ if (prm.with_bias) dst_data[didx] += bias_data[g_OC_G_oc];
+
+ for (size_t ic = 0lu; ic < IC_G; ic++) {
+ size_t mb_CI3_g_IC_G_ic_CI2 = mb_CI3 + (g_IC_G + ic) * CI2;
+ size_t g_CK4_oc_CK2_ic_CK3 = g_CK4_oc_CK2 + ic * CK3;
+ for (int kd = 0lu; kd < KD; kd++) {
+ if (od_PD < kd) continue;
+ size_t id = od_PD - kd;
+ if (id % SD != 0) continue;
+ id /= SD;
+ if (id >= ID) continue;
+ size_t mb_CI3_g_IC_G_ic_CI2_id_CI1 = mb_CI3_g_IC_G_ic_CI2 + id * CI1;
+ size_t g_CK4_oc_CK2_ic_CK3_kd_CK1 = g_CK4_oc_CK2_ic_CK3 + kd * CK1;
+ for (size_t kh = 0lu; kh < KH; kh++) {
+ if (oh_PH < kh) continue;
+ size_t ih = oh_PH - kh;
+ if (ih % SH != 0) continue;
+ ih /= SH;
+ if (ih >= IH) continue;
+ size_t mb_CI3_g_IC_G_ic_CI2_id_CI1_ih_IW = mb_CI3_g_IC_G_ic_CI2_id_CI1 + ih * IW;
+ size_t g_CK4_oc_CK2_ic_CK3_kd_CK1_kh_KW = g_CK4_oc_CK2_ic_CK3_kd_CK1 + kh * KW;
+ for (size_t kw = 0lu; kw < KW; kw++) {
+ if (ow_PW < kw) continue;
+ size_t iw = ow_PW - kw;
+ if (iw % SW != 0) continue;
+ iw /= SW;
+ if (iw >= IW) continue;
- size_t iw = ow - kw + PW;
- size_t ih = oh - kh + PH;
- size_t id = od - kd + PD;
+ size_t sidx = mb_CI3_g_IC_G_ic_CI2_id_CI1_ih_IW + iw;
- if (iw % SW != 0 || ih % SH != 0 || id % SD != 0)
- continue;
+ size_t widx = g_CK4_oc_CK2_ic_CK3_kd_CK1_kh_KW + kw;
- iw /= SW;
- ih /= SH;
- id /= SD;
-
- if (ih < IH && iw < IW && id < ID) {
- size_t sidx = mb * CI3
- + (g * IC / G + ic) * CI2
- + id * CI1
- + ih * IW
- + iw;
-
- size_t widx = g * CK4
- + ic * CK3
- + oc * CK2
- + kd * CK1
- + kh * KW
- + kw;
-
- dst_data[didx] += src_data[sidx] * weights_data[widx];
- }
+ dst_data[didx] += src_data[sidx] * weights_data[widx];
}
}
}
class MKLDNNGraphDeconvolutionalTests: public TestsCommon,
public WithParamInterface<deconv_test_params> {
- std::string model_t_5D = R"V0G0N(
-<net name="Deconvolution_Only" version="3" precision="FP32" batch="1">
- <layers>
- <layer name="in1" type="Input" precision="FP32" id="0">
- <output>
- <port id="0">__SRC_DIMS__
- </port>
- </output>
- </layer>
+ std::string layers_t = R"V0G0N(
<layer name="deconv1" id="1" type="Deconvolution" precision="FP32">
<deconvolution _AP_ kernel="_K_"
pads_begin="_PB_" pads_end="_PE_"
<biases offset="_S1_" size="_S2_" />
<input>
- <port id="1">__SRC_DIMS__
+ <port id="1">
+ __SRC_DIMS__
</port>
</input>
<output>
<port id="2">
<dim>_IN_</dim>
- <dim>_OC_</dim>__DST_DIMS__
+ <dim>_OC_</dim>
+ __DST_DIMS__
</port>
</output>
</layer>
- </layers>
- <edges>
+)V0G0N";
+
+ std::string edges_t = R"V0G0N(
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
- </edges>
-</net>
)V0G0N";
protected:
std::string getModel(deconv_test_params p) {
- std::string model = model_t_5D;
- auto dims_size = p.dims.size();
+ std::string model = layers_t;
+
std::string s_dims;
for (auto& dim : p.dims) {
s_dims += "\n <dim>";
}
REPLACE_WITH_STR(model, "_IMPLS_", impls);
+ model = IRTemplateGenerator::getIRTemplate("Deconvolution_Only", p.dims, "FP32", model, edges_t);
+
return model;
}
InferenceEngine::SizeVector dims_src = p.dims;
- InferenceEngine::Layout layout = ANY;
- switch (p.dims.size()) {
- case 4:
- layout = InferenceEngine::NCHW;
- break;
- case 5:
- layout = InferenceEngine::NCDHW;
- break;
- }
- InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
+ InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(
+ InferenceEngine::Precision::FP32, InferenceEngine::TensorDesc::getLayoutByDims(p.dims), dims_src);
src->allocate();
fill_data(src->buffer(), src->size());
::testing::Values(
/*0*/ deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
- deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
- deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
- deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
- /*8*/ deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
+ /*5*/ deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::jit} },
- deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
- deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
- deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::ref_any},
{MKLDNNPlugin::impl_desc_type::ref_any}},
- /*17*/ deconv_test_params{{2, 8, 5, 5}, {1, 3}, {1, 1}, {0, 1}, {0, 1}, 8, 8, true, "", 2,
+ /*11*/ deconv_test_params{{2, 8, 5, 5}, {1, 3}, {1, 1}, {0, 1}, {0, 1}, 8, 8, true, "", 2,
{MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
{MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any}},
- deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
- deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
- deconv_test_params{{2, 72, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
- deconv_test_params{{1, 12, 2, 2}, {4, 4}, {2, 2}, {1, 1}, {1, 1}, 12, 12, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
#ifdef USE_MKL
+ deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
+ deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
+ deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
+ deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
+ deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
+ deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, true, "", 2, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, true, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
deconv_test_params{{1, 6, 6, 5}, {3, 1}, {1, 1}, {1, 0}, {1, 0}, 9, 3, true, "", 2,
{MKLDNNPlugin::impl_desc_type::gemm_blas}},
deconv_test_params{{1, 32, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 16, 1, true, "", 4,
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
deconv_test_params{{1, 25, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, true, "valid", 3,
- {MKLDNNPlugin::impl_desc_type::gemm_blas} },
+ {MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{1, 32, 16, 16, 16}, {4, 4, 4}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, 1, 1, true, "same_upper", 3,
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
deconv_test_params{{1, 64, 12, 12, 2}, {2, 2, 2}, {2, 2, 2}, {0, 0, 0}, {1, 0, 0}, 32, 1, true, "same_upper", 3,
deconv_test_params{{1, 50, 1, 1, 1}, {4, 4, 4}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 128, 1, true, "", 3,
{MKLDNNPlugin::impl_desc_type::gemm_blas}, {MKLDNNPlugin::impl_desc_type::gemm_blas}},
#endif
+ deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
+ deconv_test_params{{2, 24, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 1, true, "", 3, {MKLDNNPlugin::impl_desc_type::jit}},
+ deconv_test_params{{2, 72, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 72, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
+ deconv_test_params{{1, 12, 2, 2}, {4, 4}, {2, 2}, {1, 1}, {1, 1}, 12, 12, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
// 5D
deconv_test_params{{1, 2, 8, 5, 5}, {3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 4, 1, true, "", 4,
{MKLDNNPlugin::impl_desc_type::ref_any}, {MKLDNNPlugin::impl_desc_type::ref_any} }
-
// Blocked, with biases
// TODO support on jit
// deconv_test_params{{2, 24, 5, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 24, 3, true, "", 4, {MKLDNNPlugin::impl_desc_type::jit}},
graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
graph.CreateGraph(net_reader.getNetwork());
- InferenceEngine::SizeVector dims_src = p.dims;
-
- InferenceEngine::Layout layout = ANY;
- switch (p.dims.size()) {
- case 4:
- layout = InferenceEngine::NCHW;
- break;
- case 5:
- layout = InferenceEngine::NCDHW;
- break;
- }
- InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
+ InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(
+ InferenceEngine::Precision::FP32, InferenceEngine::TensorDesc::getLayoutByDims(p.dims), p.dims);
InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
if (srcPtr == nullptr)
FAIL() << "Cannot cast blob to TBlob<float>.";
::testing::Values(
deconv_test_params{{1, 3, 3, 3}, {3, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{3, 3, 3, 3}, {4, 3}, {1, 1}, {0, 0}, {0, 0}, 2, 1, false, "", 5, {MKLDNNPlugin::impl_desc_type::jit} },
+#ifdef USE_MKL
deconv_test_params{{1, 3, 3, 3}, {4, 3}, {1, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 4, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{1, 3, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{4, 17, 3, 3}, {4, 3}, {2, 2}, {0, 0}, {0, 0}, 2, 1, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm, MKLDNNPlugin::impl_desc_type::jit} },
deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 2, false, "", 3, {MKLDNNPlugin::impl_desc_type::gemm}},
+#endif
deconv_test_params{{2, 8, 5, 5}, {4, 4}, {2, 2}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
deconv_test_params{{2, 8, 5, 5}, {8, 8}, {4, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}},
deconv_test_params{{2, 8, 5, 5}, {4, 8}, {2, 4}, {1, 1}, {0, 0}, 8, 8, false, "", 4, {MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_dw}}