1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
11 #include "test_graph.hpp"
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
19 using namespace ::testing;
21 using namespace mkldnn;
// Parameter bundle for the convolution graph tests in this file.
// NOTE(review): only the two trailing members are visible in this listing; the other
// fields used by the tests (in, krn_w, krn_h, str_w, str_h, pad_w, pad_h, out_c, grp_c,
// num_prim_desc, selectedType) are declared on lines that are not shown here.
24 struct conv_test_params {
// Implementation types that getModel writes into the PrimitivesPriority attribute of the layer.
45 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
// Per-descriptor checks: SetUp calls comp[j] on the j-th supported primitive descriptor
// of every Convolution node.
47 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Naive reference implementation of a grouped 2D convolution with bias, used to compute
// the expected output that the graph result is compared against.
//   src         - input blob; dims()[1], dims()[2], dims()[3] are read as channels, height, width.
//   weights     - kernel weights immediately followed by OC bias values.
//   weightsSize - element count of weights; asserted to equal KW*KH*OC*IC/GC + OC.
//   dst         - output blob; every visited element is first set to the bias and then accumulated.
//   prm         - kernel size, stride, padding, group count and output channel count.
// NOTE(review): no loop over the batch index is visible in this listing - presumably only
// batch size 1 is handled; confirm against the full source.
50 template <typename data_t>
51 void ref_conv(const InferenceEngine::TBlob<data_t> &src, const data_t *weights, const size_t weightsSize,
52 InferenceEngine::TBlob<data_t> &dst, conv_test_params prm) {
53 size_t KW = prm.krn_w;
54 size_t KH = prm.krn_h;
55 size_t GC = prm.grp_c;
57 size_t IC = src.dims()[1];
58 size_t IH = src.dims()[2];
59 size_t IW = src.dims()[3];
// Output spatial size: (in + 2*pad - kernel) / stride + 1, using unsigned integer division.
61 size_t OW = (IW + 2 * prm.pad_w - prm.krn_w) / prm.str_w + 1;
62 size_t OH = (IH + 2 * prm.pad_h - prm.krn_h) / prm.str_h + 1;
63 size_t OC = prm.out_c;
66 const data_t *src_data = src.readOnly();
67 const data_t *weights_data = weights;
// Biases start right after the KW*KH*OC*IC/GC kernel weights.
68 const data_t *bias_data = weights_data + KW * KH * OC * IC / GC;
69 data_t *dst_data = dst.data();
71 IE_ASSERT(KW * KH * OC * IC / GC + OC == weightsSize);
// NOTE(review): these checks read dims()[0] and dims()[1] of dst as width and height, while
// src above is indexed with dims()[1..3] as channels/height/width - confirm the index convention.
72 IE_ASSERT(OW == dst.dims()[0]);
73 IE_ASSERT(OH == dst.dims()[1]);
75 for (uint32_t g = 0; g < GC; g++) {
76 for (uint32_t oc = 0; oc < OC / GC; oc++) {
77 for (uint32_t oh = 0; oh < OH; oh++) {
78 for (uint32_t ow = 0; ow < OW; ow++) {
// Flat output index: group, then channel within the group, then row, then column.
79 size_t oidx = g * OC / GC * OH * OW
80 + oc * OH * OW + oh * OW + ow;
// Start the accumulator from the bias of this output channel.
81 dst_data[oidx] = bias_data[g * OC / GC + oc];
83 for (size_t ic = 0; ic < IC / GC; ic++) {
84 for (size_t kh = 0; kh < KH; kh++) {
85 for (size_t kw = 0; kw < KW; kw++) {
// Input coordinates of this kernel tap; they can fall outside the input because of padding.
// NOTE(review): the remainder of the bounds check below is on lines not shown in this listing.
86 int32_t iw = ow * prm.str_w - prm.pad_w + kw;
87 int32_t ih = oh * prm.str_h - prm.pad_h + kh;
88 if (iw < 0 || iw >= (int32_t)IW || ih < 0
91 size_t iidx = g * IC / GC * IH * IW
92 + ic * IH * IW + ih * IW + iw;
93 size_t widx = g * OC / GC * IC / GC * KH * KW
94 + oc * IC / GC * KH * KW
95 + ic * KH * KW + kh * KW + kw;
97 dst_data[ oidx] += src_data[iidx] * weights_data[widx];
// Value-parameterised fixture for a single-convolution network.
// SetUp builds the IR text from conv_test_params, creates an MKLDNN graph, checks the
// primitive descriptors of every Convolution node, runs inference and compares the
// output with ref_conv.
// NOTE(review): several lines of this class (end of the IR template, the try opener,
// closing braces) are not visible in this listing.
107 class MKLDNNGraphConvolutionTests: public TestsCommon,
108 public WithParamInterface<conv_test_params> {
// IR template; the _XX_ placeholders are substituted in getModel.
109 std::string model_t = R"V0G0N(
110 <Net Name="Convolution_Only" version="2" precision="FP32" batch="1">
112 <layer name="in1" type="Input" precision="FP32" id="0">
122 <layer name="conv1" id="1" type="Convolution" precision="FP32">
123 <convolution stride-x="_SW_" stride-y="_SH_"
124 pad-x="_PW_" pad-y="_PH_"
125 kernel-x="_KW_" kernel-y="_KH_"
126 output="_OC_" group="_GC_" PrimitivesPriority="_IMPLS_"/>
128 <weights offset="0" size="_S1_" />
129 <biases offset="_S1_" size="_S2_" />
150 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Builds the IR text for p: starts from a copy of model_t and substitutes each placeholder.
// NOTE(review): the return statement is on a line not shown here.
156 std::string getModel(conv_test_params p) {
157 std::string model = model_t;
158 REPLACE_WITH_NUM(model, "_IW_", p.in.w);
159 REPLACE_WITH_NUM(model, "_IH_", p.in.h);
160 REPLACE_WITH_NUM(model, "_IC_", p.in.c);
161 REPLACE_WITH_NUM(model, "_IN_", p.in.n);
163 REPLACE_WITH_NUM(model, "_KW_", p.krn_w);
164 REPLACE_WITH_NUM(model, "_KH_", p.krn_h);
165 REPLACE_WITH_NUM(model, "_SW_", p.str_w);
166 REPLACE_WITH_NUM(model, "_SH_", p.str_h);
167 REPLACE_WITH_NUM(model, "_PW_", p.pad_w);
168 REPLACE_WITH_NUM(model, "_PH_", p.pad_h);
170 REPLACE_WITH_NUM(model, "_GC_", p.grp_c);
171 REPLACE_WITH_NUM(model, "_OC_", p.out_c);
// Output height and width: (in + 2*pad - kernel) / stride + 1.
172 REPLACE_WITH_NUM(model, "_OH_", (p.in.h + 2 * p.pad_h - p.krn_h) / p.str_h + 1);
173 REPLACE_WITH_NUM(model, "_OW_", (p.in.w + 2 * p.pad_w - p.krn_w) / p.str_w + 1);
// Byte sizes of the weights and bias sections; the bias offset in the IR equals the weights size.
175 size_t w_data_size = (p.krn_w * p.krn_h * p.out_c * p.in.c / p.grp_c) * sizeof(float);
176 size_t b_data_size = p.out_c * sizeof(float);
177 REPLACE_WITH_NUM(model, "_S1_", w_data_size);
178 REPLACE_WITH_NUM(model, "_S2_", b_data_size);
// Build the implementation priority list from preferTypes.
// NOTE(review): the declaration of impls and any separator handling are on lines not shown here.
180 for (const auto& preferType : p.preferTypes) {
183 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
185 REPLACE_WITH_STR(model, "_IMPLS_", impls);
189 virtual void TearDown() {
// Runs the whole test scenario; the TEST_P body for this fixture is empty.
192 virtual void SetUp() {
194 TestsCommon::SetUp();
195 conv_test_params p = ::testing::WithParamInterface<conv_test_params>::GetParam();
196 std::string model = getModel(p);
198 InferenceEngine::CNNNetReader net_reader;
199 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Raw U8 blob sized for the kernel weights plus out_c biases; it is filled as float data
// and ownership is handed to weights_ptr below.
201 InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {(p.krn_w * p.krn_h * p.out_c * p.in.c / p.grp_c + p.out_c)
204 fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
205 InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
207 net_reader.SetWeights(weights_ptr);
209 MKLDNNGraphTestClass graph;
210 graph.CreateGraph(net_reader.getNetwork());
// NOTE(review): nodes is bound by reference to graph.getNodes(), so the assignment on the
// following line looks like a redundant self-assignment - confirm and consider removing.
212 auto& nodes = graph.getNodes();
213 nodes = graph.getNodes();
214 for (auto &node : nodes) {
215 if (node->getType() == MKLDNNPlugin::Convolution) {
// The node must offer at least num_prim_desc descriptors; then run the per-descriptor checks.
216 ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
217 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
218 p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
// The selected implementation type must contain every bit set in selectedType.
220 ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
221 ASSERT_EQ(p.selectedType,
222 node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
// Input blob in NCHW layout with the shape taken from the test parameters.
226 InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
228 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
230 fill_data(src->buffer(), src->size());
232 auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
234 if (srcPtr == nullptr)
235 FAIL() << "Cannot cast blob to TBlob<float>.";
237 InferenceEngine::BlobMap srcs;
238 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
240 InferenceEngine::OutputsDataMap out;
241 out = net_reader.getNetwork().getOutputsInfo();
242 InferenceEngine::BlobMap outputBlobs;
// Only the first network output is used.
244 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
246 InferenceEngine::TBlob<float>::Ptr output;
247 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
249 outputBlobs[item.first] = output;
251 graph.Infer(srcs, outputBlobs);
// Compute the expected result with ref_conv and compare it with the graph output.
254 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
256 ref_conv(*srcPtr, (const float *)weights->buffer(), weights->size() / sizeof(float), dst_ref, p);
257 compare(*output, dst_ref);
// NOTE(review): the matching try and the handler body are on lines not shown here.
258 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: all checks are executed in SetUp of the fixture.
264 TEST_P(MKLDNNGraphConvolutionTests, TestsConvolution) {}
// Parameter sets for MKLDNNGraphConvolutionTests.
// NOTE(review): some lines of this call are not visible in this listing.
// Each entry starts with a four-value brace group, presumably the input shape in the
// n, c, h, w order that SetUp reads from p.in; the meaning of the following scalars
// depends on the member order of conv_test_params - confirm against the struct definition.
266 INSTANTIATE_TEST_CASE_P(
267 TestConvolution, MKLDNNGraphConvolutionTests,
269 conv_test_params{{1, 9, 16, 32},
270 1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
272 conv_test_params{{1, 9, 32, 16},
273 2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
274 conv_test_params{{1, 9, 32, 16},
275 2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
276 conv_test_params{{1, 3, 40, 40},
277 3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
278 conv_test_params{{1, 1, 40, 40},
279 3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
280 conv_test_params{{1, 1, 32, 16},
281 2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
// The remaining entries also pass a trailing list of implementation types
// (presumably preferTypes - confirm against the struct definition).
282 conv_test_params{{1, 9, 16, 32},
283 1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::gemm,
284 {MKLDNNPlugin::impl_desc_type::gemm_any,
285 MKLDNNPlugin::impl_desc_type::gemm_blas,
286 MKLDNNPlugin::impl_desc_type::gemm_avx512,
287 MKLDNNPlugin::impl_desc_type::gemm_avx2,
288 MKLDNNPlugin::impl_desc_type::gemm_sse42}
290 conv_test_params{{1, 9, 32, 16},
291 2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::ref_any, {MKLDNNPlugin::impl_desc_type::ref_any} }));
// Dynamic-batch variant of the convolution fixture.
// SetUp reshapes the network to batch size MB, creates the graph and calls checkDynBatch
// twice on the convolution-type nodes.
// NOTE(review): the declaration of MB, the try opener and the closing braces are on lines
// not shown in this listing.
293 class MKLDNNGraphDynBatchConvolutionTests: public MKLDNNGraphConvolutionTests {
// Runs the whole scenario; the TEST_P body for this fixture is empty.
295 virtual void SetUp() {
297 TestsCommon::SetUp();
298 conv_test_params p = ::testing::WithParamInterface<conv_test_params>::GetParam();
299 std::string model = getModel(p);
304 InferenceEngine::CNNNetReader net_reader;
305 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Raw U8 blob sized in bytes for the kernel weights plus out_c biases (float elements);
// ownership is handed to weights_ptr below.
307 InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C,
308 {(p.krn_w * p.krn_h * p.out_c * p.in.c / p.grp_c + p.out_c) * sizeof(float)});
310 fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
311 InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
313 net_reader.SetWeights(weights_ptr);
// Downcast to the implementation class in order to call setBatchSizeReshape.
314 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
315 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
316 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
317 InferenceEngine::ResponseDesc resp;
318 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
319 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
321 MKLDNNGraphTestClass graph;
322 graph.CreateGraph(net_reader.getNetwork());
// Input blob in NCHW layout; the batch dimension is MB rather than p.in.n.
324 InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
326 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
328 fill_data(src->buffer(), src->size());
330 auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
332 if (srcPtr == nullptr)
333 FAIL() << "Cannot cast blob to TBlob<float>.";
335 InferenceEngine::BlobMap srcs;
336 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
338 InferenceEngine::OutputsDataMap out;
339 out = net_reader.getNetwork().getOutputsInfo();
340 InferenceEngine::BlobMap outputBlobs;
// Only the first network output is used.
342 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
344 InferenceEngine::TBlob<float>::Ptr output;
345 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
347 outputBlobs[item.first] = output;
// Predicate selecting nodes of type Convolution, Convolution_Activation, Convolution_Sum
// or Convolution_Sum_Activation.
349 auto checkConvolution = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
350 return node->getType() == MKLDNNPlugin::Convolution ||
351 node->getType() == MKLDNNPlugin::Convolution_Activation ||
352 node->getType() == MKLDNNPlugin::Convolution_Sum ||
353 node->getType() == MKLDNNPlugin::Convolution_Sum_Activation;
// Two runs: first with MB, then with 1 as the third argument
// (presumably the current batch size, with MB as the maximum - confirm against checkDynBatch).
356 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkConvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
357 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkConvolution, MKLDNNGraphTestClass::CheckDynBatchType::Child);
// NOTE(review): the matching try and the handler body are on lines not shown here.
358 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: all checks are executed in SetUp of the fixture.
364 TEST_P(MKLDNNGraphDynBatchConvolutionTests, TestsDynBatchConvolution) {}
// Parameter sets for MKLDNNGraphDynBatchConvolutionTests.
// NOTE(review): some lines of this call are not visible in this listing.
// Each entry starts with a four-value brace group, presumably the input shape in
// n, c, h, w order; the dynamic-batch SetUp uses MB instead of the first value for the
// batch dimension of the input blob. The meaning of the following scalars depends on the
// member order of conv_test_params - confirm against the struct definition.
366 INSTANTIATE_TEST_CASE_P(
367 TestDynBatchConvolution, MKLDNNGraphDynBatchConvolutionTests,
369 conv_test_params{{1, 8, 16, 32},
370 1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
372 conv_test_params{{1, 9, 32, 16},
373 2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
374 conv_test_params{{1, 9, 32, 16},
375 2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
376 conv_test_params{{1, 3, 40, 40},
377 3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
378 conv_test_params{{1, 1, 40, 40},
379 3, 3, 1, 2, 0, 0, 20, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
380 conv_test_params{{1, 1, 32, 16},
381 2, 4, 2, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::jit },
// The remaining entries also pass a trailing list of implementation types
// (presumably preferTypes - confirm against the struct definition).
382 conv_test_params{{1, 9, 16, 32},
383 1, 1, 1, 1, 0, 0, 17, 1, 7, MKLDNNPlugin::impl_desc_type::gemm,
384 {MKLDNNPlugin::impl_desc_type::gemm_any,
385 MKLDNNPlugin::impl_desc_type::gemm_blas,
386 MKLDNNPlugin::impl_desc_type::gemm_avx512,
387 MKLDNNPlugin::impl_desc_type::gemm_avx2,
388 MKLDNNPlugin::impl_desc_type::gemm_sse42}
390 conv_test_params{{1, 9, 32, 16},
391 2, 4, 1, 1, 0, 0, 17, 1, 5, MKLDNNPlugin::impl_desc_type::ref_any, {MKLDNNPlugin::impl_desc_type::ref_any} }));