1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8 #include "test_graph.hpp"
9 #include "single_layer_common.hpp"
10 #include <mkldnn_plugin/mkldnn_extension_utils.h>
11 #include <inference_engine/cnn_network_impl.hpp>
12 #include "tests_common.hpp"
14 using namespace ::testing;
16 using namespace mkldnn;
// Parameters of one depthwise (ScaleShift / PReLU) graph test case.
// The test instantiations aggregate-initialize it in the order
//   {alg, dims, isBroadcast, num_prim_desc, selectedType[, preferTypes]}.
// NOTE(review): dims / isBroadcast / num_prim_desc are used elsewhere in this file
// (p.dims, prm.isBroadcast, p.num_prim_desc) but their declarations are not shown here.
18 struct depthwise_test_params {
// Algorithm under test: depthwise_scale_shift (emitted as a ScaleShift layer)
// or depthwise_prelu (emitted as a PReLU layer).
19 mkldnn::algorithm alg;
21 // Formats: NCHW, NCDHW
// Implementation type bits that the selected primitive descriptor must contain
// (checked with a bitwise AND in SetUp()).
28 MKLDNNPlugin::impl_desc_type selectedType;
// Optional PrimitivesPriority list; each entry is written into the model as
// "cpu:<type>".
29 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
// Optional per-descriptor checks, applied in order to the depthwise node's
// supported primitive descriptors (at most num_prim_desc of them).
31 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Reference implementation used to validate the depthwise node's output.
//
// Visits every element of src (loops over N, C, D, H, W; ID is 1 unless the input
// is 5D) and writes the matching dst element:
//   depthwise_scale_shift: dst = src * weights[c] + bias[c]
//   depthwise_prelu:       dst = src > 0 ? src : src * weights[c]
// When prm.isBroadcast is set, index 0 is used for every channel instead of c.
//
// @param src          input blob, 4D (N,C,H,W) or 5D (N,C,D,H,W)
// @param weights      packed buffer: scale values followed by bias values; bias
//                     values start at offset 1 (broadcast) or IC (per-channel)
// @param weightsSize  NOTE(review): not read by the lines of this function shown here
// @param dst          output blob, written element by element
// @param prm          test parameters (alg, isBroadcast)
//
// NOTE(review): the loop counters are `int` while their bounds are size_t, which
// produces signed/unsigned comparison warnings.
34 template <typename data_t>
35 void ref_depthwise(const InferenceEngine::TBlob<data_t> &src, const data_t *weights, const size_t weightsSize,
36 InferenceEngine::TBlob<data_t> &dst, depthwise_test_params prm) {
37 auto dims_size = src.dims().size();
// Extents are read from the end of dims so 4D and 5D inputs share this code.
39 size_t IW = src.dims()[dims_size - 1];
40 size_t IH = src.dims()[dims_size - 2];
41 size_t ID = dims_size == 5 ? src.dims()[2] : 1u;
42 size_t IC = src.dims()[1];
43 size_t MB = src.dims()[0];
45 const data_t *src_data = src.readOnly();
46 const data_t *weights_data = weights;
// Bias values are stored right after the scale values in the same buffer.
47 size_t bias_offset = prm.isBroadcast ? 1 : IC;
48 const data_t *bias_data = weights_data + bias_offset;
49 data_t *dst_data = dst.data();
// m2..m4 accumulate the flat element offset one loop level at a time.
54 for (int mb = 0; mb < MB; mb++) {
56 for (int c = 0; c < IC; c++) {
57 size_t m2 = m1 + c * c1;
58 for (int d = 0; d < ID; d++) {
59 size_t m3 = m2 + d * c2;
60 for (int h = 0; h < IH; h++) {
61 size_t m4 = m3 + h * IW;
62 for (int w = 0; w < IW; w++) {
// In broadcast mode a single weight and a single bias apply to every channel.
65 int widx = prm.isBroadcast ? 0 : c;
66 int bidx = prm.isBroadcast ? 0 : c;
68 if (prm.alg == depthwise_scale_shift)
69 dst_data[idx] = src_data[idx] * weights_data[widx] + bias_data[bidx];
70 else if (prm.alg == depthwise_prelu)
71 dst_data[idx] = src_data[idx] > 0 ? src_data[idx] : src_data[idx]*weights_data[widx];
// Value-parameterized fixture: builds a single Input -> depthwise IR network from
// depthwise_test_params, runs it through MKLDNNGraphTestClass in SetUp() and
// compares the result with ref_depthwise().
79 class MKLDNNGraphDepthwiseTests: public TestsCommon,
80 public WithParamInterface<depthwise_test_params> {
// IR (version 2) template for 4D input. getModel() substitutes the shape
// placeholders plus _LT_, _P_NAME_, _P_VAL_, _IMPLS_, _S1_ and _S2_; the layer's
// weights occupy bytes [0, _S1_) and its biases start at byte _S1_.
// NOTE(review): the Net Name "Lrn_Only" looks like a copy-paste leftover from an
// LRN test; confirm nothing relies on it before renaming.
81 std::string model_t_4D = R"V0G0N(
82 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
84 <layer name="in1" type="Input" precision="FP32" id="0">
94 <layer name="depthwise" id="1" type="_LT_" precision="FP32">
95 <data _P_NAME_="_P_VAL_" PrimitivesPriority="_IMPLS_"/>
96 <weights offset="0" size="_S1_" />
97 <biases offset="_S1_" size="_S2_" />
118 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// IR template for 5D input: same Input -> depthwise structure, placeholders and
// weights/biases layout as model_t_4D (getModel() additionally fills _ID_,
// presumably only for 5D shapes).
124 std::string model_t_5D = R"V0G0N(
125 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
127 <layer name="in1" type="Input" precision="FP32" id="0">
138 <layer name="depthwise" id="1" type="_LT_" precision="FP32">
139 <data _P_NAME_="_P_VAL_" PrimitivesPriority="_IMPLS_"/>
140 <weights offset="0" size="_S1_" />
141 <biases offset="_S1_" size="_S2_" />
164 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Instantiates the IR template for the given test parameters:
//  - chooses the 4D or 5D template from p.dims.size();
//  - fills the shape placeholders from p.dims;
//  - maps p.alg to the layer type and its shared-weights attribute:
//      depthwise_scale_shift -> ScaleShift, broadcast="0|1"
//      depthwise_prelu       -> PReLU,      channel_shared="0|1"
//  - sets the weights/biases byte sizes: one float per channel, or a single
//    float when p.isBroadcast is set;
//  - writes p.preferTypes into PrimitivesPriority as "cpu:<type>" entries.
// NOTE(review): p is taken by value; a const reference would avoid copying the
// vectors it holds.
170 std::string getModel(depthwise_test_params p) {
172 auto dims_size = p.dims.size();
173 if (dims_size == 4) {
175 } else if (dims_size == 5) {
179 REPLACE_WITH_NUM(model, "_IW_", p.dims[dims_size - 1]);
180 REPLACE_WITH_NUM(model, "_IC_", p.dims[1]);
181 REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
185 REPLACE_WITH_NUM(model, "_ID_", p.dims[dims_size - 3]);
187 REPLACE_WITH_NUM(model, "_IH_", p.dims[dims_size - 2]);
190 if (p.alg == depthwise_scale_shift) {
191 REPLACE_WITH_STR(model, "_LT_", "ScaleShift");
192 REPLACE_WITH_STR(model, "_P_NAME_", "broadcast");
193 REPLACE_WITH_NUM(model, "_P_VAL_", p.isBroadcast ? 1 : 0);
195 else if (p.alg == depthwise_prelu) {
196 REPLACE_WITH_STR(model, "_LT_", "PReLU");
197 REPLACE_WITH_STR(model, "_P_NAME_", "channel_shared");
198 REPLACE_WITH_NUM(model, "_P_VAL_", p.isBroadcast ? 1 : 0);
// A single shared value in broadcast / channel_shared mode, otherwise one per channel.
201 size_t array_size = p.isBroadcast ? 1 : p.dims[1];
202 size_t w_data_size = array_size * sizeof(float);
203 size_t b_data_size = array_size * sizeof(float);
// _S1_ is both the weights size and the biases offset in the template.
204 REPLACE_WITH_NUM(model, "_S1_", w_data_size);
205 REPLACE_WITH_NUM(model, "_S2_", b_data_size);
208 for (const auto& preferType : p.preferTypes) {
211 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
213 REPLACE_WITH_STR(model, "_IMPLS_", impls);
// Builds and checks the graph for the current parameter set:
//  1. parses the IR from getModel() and attaches a weights blob filled with random floats;
//  2. for each Depthwise node, checks the number of supported primitive descriptors,
//     runs the optional p.comp checks and verifies the selected implementation type;
//  3. runs inference on a random FP32 input and compares the output with ref_depthwise().
// An InferenceEngineException thrown during these steps is caught at the end of the method.
218 virtual void SetUp() {
220 TestsCommon::SetUp();
221 depthwise_test_params p = ::testing::WithParamInterface<depthwise_test_params>::GetParam();
222 std::string model = getModel(p);
224 InferenceEngine::CNNNetReader net_reader;
225 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Room for per-channel scales and biases (2 * C floats); the model reads only the
// first _S1_ + _S2_ bytes, which is less in broadcast mode.
227 size_t weightSize = 2 * p.dims[1] * sizeof(float);
228 InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {weightSize});
// The raw U8 buffer is filled as floats.
230 fill_data( weights->data().as<float*>(), weights->size() / sizeof(float));
// weights_ptr takes ownership of the blob; `weights` remains usable as a
// non-owning pointer while weights_ptr is alive.
232 InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
234 net_reader.SetWeights(weights_ptr);
236 MKLDNNGraphTestClass graph;
237 graph.CreateGraph(net_reader.getNetwork());
238 auto& nodes = graph.getNodes();
// NOTE(review): `int i` is compared against nodes.size() (unsigned).
239 for (int i = 0; i < nodes.size(); i++) {
240 if (nodes[i]->getType() == MKLDNNPlugin::Depthwise) {
241 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
242 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
243 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
245 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
// The selected implementation type must include every bit of p.selectedType.
246 ASSERT_EQ(p.selectedType,
247 nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
// Input layout is chosen from the rank of p.dims (NCHW or NCDHW).
251 InferenceEngine::SizeVector dims_src = p.dims;
252 InferenceEngine::Layout layout = InferenceEngine::ANY;
253 switch (p.dims.size()) {
255 layout = InferenceEngine::NCHW;
258 layout = InferenceEngine::NCDHW;
262 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
264 fill_data(src->buffer(), src->size());
266 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
268 if (srcPtr == nullptr)
269 FAIL() << "Cannot cast blob to TBlob<float>.";
271 InferenceEngine::BlobMap srcs;
272 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
274 InferenceEngine::OutputsDataMap out;
275 out = net_reader.getNetwork().getOutputsInfo();
276 InferenceEngine::BlobMap outputBlobs;
// Only the first network output is allocated and compared.
278 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
280 InferenceEngine::TBlob<float>::Ptr output;
281 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
283 outputBlobs[item.first] = output;
285 graph.Infer(srcs, outputBlobs);
// The reference result uses the same input blob and the same weights buffer.
287 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
290 ref_depthwise(*srcPtr, weights->readOnly().as<const float*>(), weights->size() / sizeof(float), dst_ref, p);
292 compare(*output, dst_ref);
293 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The body is intentionally empty: all work and assertions run in the fixture's SetUp().
299 TEST_P(MKLDNNGraphDepthwiseTests, TestsDepthwise) {}
// Depthwise cases for ScaleShift and PReLU, per-channel and broadcast:
//  - 4D inputs expecting a jit implementation by default, and the same inputs
//    expecting a ref implementation when ref_any is forced through preferTypes;
//  - 5D inputs with the forced ref_any implementation.
// Entry order: {alg, dims, isBroadcast, num_prim_desc, selectedType[, preferTypes]}.
301 INSTANTIATE_TEST_CASE_P(
302 TestsDepthwise, MKLDNNGraphDepthwiseTests,
304 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
305 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
306 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
307 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
308 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
309 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
310 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::jit},
311 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
312 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
313 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
314 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
315 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
// Same 4D inputs with the reference implementation forced via PrimitivesPriority.
316 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
317 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
318 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
319 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
320 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
321 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
322 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
323 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
324 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
325 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
326 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
327 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
329 // mkl-dnn does not support 5D depthwise on jit yet
330 // depthwise_test_params{depthwise_scale_shift, {1, 32, 16, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
331 // depthwise_test_params{depthwise_scale_shift, {4, 3, 16, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
332 // depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
333 // depthwise_test_params{depthwise_scale_shift, {4, 4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
334 // depthwise_test_params{depthwise_scale_shift, {1, 32, 16, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
335 // depthwise_test_params{depthwise_scale_shift, {4, 3, 16, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
// Enabled 5D cases: ScaleShift with the reference implementation only.
336 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
337 depthwise_test_params{depthwise_scale_shift, {4, 4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
// Dynamic-batch variant of MKLDNNGraphDepthwiseTests. It reuses getModel() and:
//  - reshapes the network to batch MB (initialized from p.dims[0]);
//  - creates the graph with KEY_DYN_BATCH_ENABLED;
//  - runs MKLDNNGraphTestClass::checkDynBatch() on the Depthwise nodes, once with
//    MB and once with 1 as the first batch argument.
340 class MKLDNNGraphDynBatchDepthwiseTests: public MKLDNNGraphDepthwiseTests {
343 virtual void SetUp() {
345 TestsCommon::SetUp();
346 depthwise_test_params p = ::testing::WithParamInterface<depthwise_test_params>::GetParam();
347 std::string model = getModel(p);
348 size_t MB = p.dims[0];
352 InferenceEngine::CNNNetReader net_reader;
353 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// NOTE(review): allocates 4 * C floats, while the model reads at most 2 * C floats
// (scales + biases) and the base fixture allocates 2 * C.
355 InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {p.dims[1] * 4 * sizeof(float)});
357 fill_data( weights->data().as<float*>(), weights->size() / sizeof(float));
// Walks every float of the freshly filled weights buffer through `data`.
358 float * data = weights->buffer();
359 for (size_t i = 0; i < weights->size() / sizeof(float); i++) {
364 InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
365 net_reader.SetWeights(weights_ptr);
// setBatchSizeReshape() is called through the CNNNetworkImpl object behind the
// CNNNetwork wrapper; the cast result is checked before use.
366 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
367 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
368 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
369 InferenceEngine::ResponseDesc resp;
370 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
371 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
374 MKLDNNGraphTestClass graph;
375 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
376 graph.CreateGraph(net_reader.getNetwork());
// Input layout is chosen from the rank of p.dims (NCHW or NCDHW).
378 InferenceEngine::SizeVector dims_src = p.dims;
379 InferenceEngine::Layout layout = InferenceEngine::ANY;
380 switch (p.dims.size()) {
382 layout = InferenceEngine::NCHW;
385 layout = InferenceEngine::NCDHW;
388 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
389 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
390 if (srcPtr == nullptr)
391 FAIL() << "Cannot cast blob to TBlob<float>.";
394 fill_data(src->buffer(), src->size());
396 InferenceEngine::BlobMap srcs;
397 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
399 InferenceEngine::OutputsDataMap out;
400 out = net_reader.getNetwork().getOutputsInfo();
401 InferenceEngine::BlobMap outputBlobs;
// Only the first network output is allocated.
403 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
405 InferenceEngine::TBlob<float>::Ptr output;
406 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
408 outputBlobs[item.first] = output;
// Restricts the dynamic-batch check to Depthwise nodes.
410 auto checkDepthwise = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
411 return node->getType() == MKLDNNPlugin::Depthwise;
// Presumably (batch, max batch): first at the full batch MB, then at batch 1.
414 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkDepthwise);
415 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkDepthwise);
416 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The body is intentionally empty: all work and assertions run in the fixture's SetUp().
422 TEST_P(MKLDNNGraphDynBatchDepthwiseTests, TestsDynBatchDepthwise) {}
// The same 4D cases as the 4D entries of TestsDepthwise (jit by default, ref with
// ref_any forced), run through the dynamic-batch fixture. No 5D cases are listed here.
// Entry order: {alg, dims, isBroadcast, num_prim_desc, selectedType[, preferTypes]}.
424 INSTANTIATE_TEST_CASE_P(
425 TestsDynBatchDepthwise, MKLDNNGraphDynBatchDepthwiseTests,
427 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
428 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
429 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
430 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
431 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
432 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
433 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::jit},
434 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
435 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
436 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::jit},
437 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
438 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::jit},
439 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
440 depthwise_test_params{depthwise_scale_shift, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
441 depthwise_test_params{depthwise_scale_shift, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
442 depthwise_test_params{depthwise_scale_shift, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
443 depthwise_test_params{depthwise_scale_shift, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
444 depthwise_test_params{depthwise_scale_shift, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
445 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, false,3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
446 depthwise_test_params{depthwise_prelu, {4, 3, 228, 228}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
447 depthwise_test_params{depthwise_prelu, {1, 1, 1, 1}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
448 depthwise_test_params{depthwise_prelu, {1, 4, 5, 5}, false, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
449 depthwise_test_params{depthwise_prelu, {4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
450 depthwise_test_params{depthwise_prelu, {1, 32, 128, 256}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}