1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
11 #include "test_graph.hpp"
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
19 using namespace ::testing;
21 using namespace mkldnn;
// Parameter bundle for one pooling test case.
// Only part of the struct is visible in this excerpt; other members (e.g. the
// input dims `in`, krn_w/krn_h, str_w/str_h, pad_w/pad_h, num_prim_desc) are
// referenced by the tests below but declared on lines not shown here.
23 struct pooling_test_params {
// Implementation type the test expects the Pooling node to have selected.
40 MKLDNNPlugin::impl_desc_type selectedType;
// Implementation types written into the layer's PrimitivesPriority attribute by getModel().
41 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
// Optional per-descriptor checks; comp[j] is invoked on the j-th supported primitive descriptor.
43 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Reference pooling used to validate the plugin output.
//   src - input blob, read through readOnly()
//   dst - output blob, written through data()
//   prm - kernel / stride / padding / input geometry of the test case
// For every (channel, oh, ow) output element the kernel window is walked,
// taps falling outside the input are skipped, and the result is stored in dst.
// NOTE(review): the reduction step itself is on lines not visible in this
// excerpt; presumably it keeps the maximum (the model uses method="MAX") - confirm.
// NOTE(review): the loops cover channels and spatial positions only; no batch
// loop is visible here.
46 template <typename data_t>
47 void ref_pool(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, pooling_test_params prm)
49 size_t KW = prm.krn_w;
50 size_t KH = prm.krn_h;
// Output extent uses integer (floor) division.
// NOTE(review): the network template declares round="Ceil"; floor and ceil agree
// only when the division is exact - confirm this holds for all instantiated cases.
55 size_t OW = (IW + 2 * prm.pad_w - prm.krn_w) / prm.str_w + 1;
56 size_t OH = (IH + 2 * prm.pad_h - prm.krn_h) / prm.str_h + 1;
59 const data_t *src_data = src.readOnly();
60 data_t *dst_data = dst.data();
// Sanity check of the channel count against the destination blob
// (assumes dims()[2] is the channel dimension - TODO confirm the dims ordering).
62 IE_ASSERT( OC == dst.dims()[2]);
64 for (size_t c = 0; c < OC; c++) {
65 for (size_t oh = 0; oh < OH; oh++) {
66 for (size_t ow = 0; ow < OW; ow++) {
// Flat offset of the output element (expression continues on a line not shown here).
67 size_t oidx = c * OH * OW
// Accumulator starts at zero and is flagged uninitialized so that the first
// in-bounds tap can seed it instead of being compared with zero.
69 data_t out_ref = data_t(0);
70 bool is_initialized = false;
71 for (uint32_t kh = 0; kh < KH; kh++) {
72 for (uint32_t kw = 0; kw < KW; kw++) {
// Input coordinates of the current tap; they can be negative inside the padding.
// NOTE(review): `ow`/`oh` are size_t, so the arithmetic is unsigned and the
// result is narrowed to int32_t; negative values rely on that conversion.
73 int32_t iw = ow * prm.str_w - prm.pad_w + kw;
74 int32_t ih = oh * prm.str_h - prm.pad_h + kh;
// Out-of-bounds taps are rejected (rest of the condition is on lines not shown).
75 if (iw < 0 || iw >= IW || ih < 0
// Flat offset of the tap within the input (channel-major, then row, then column).
78 uint32_t iidx = c * IH * IW + ih * IW + iw;
80 data_t d = src_data[iidx];
81 if (!is_initialized) {
83 is_initialized = true;
90 dst_data[oidx] = out_ref;
// Value-parameterized fixture (parameter type: pooling_test_params) for a
// network consisting of an Input layer feeding a single Pooling layer.
// `model_t` is the network description template; placeholders such as _IW_,
// _KW_, _SW_, _PW_ and _IMPLS_ are substituted by getModel().
// The pooling layer in the template is declared with method="MAX" and round="Ceil".
96 class MKLDNNGraphPoolingTests: public TestsCommon,
97 public WithParamInterface<pooling_test_params> {
98 std::string model_t = R"V0G0N(
99 <Net Name="Pooling_Only" version="2" precision="FP32" batch="1">
101 <layer name="in1" type="Input" precision="FP32" id="0">
111 <layer name="pool" id="1" type="Pooling" precision="FP32">
113 <pooling stride-x="_SW_" stride-y="_SH_"
114 pad-x="_PW_" pad-y="_PH_"
115 kernel-x="_KW_" kernel-y="_KH_"
116 method="MAX" round="Ceil" PrimitivesPriority="_IMPLS_"/>
137 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Builds the concrete network description for test case `p` by substituting
// every placeholder of model_t with the corresponding parameter value.
143 std::string getModel(pooling_test_params p) {
144 std::string model = model_t;
// Input tensor dimensions.
146 REPLACE_WITH_NUM(model, "_IW_", p.in.w);
147 REPLACE_WITH_NUM(model, "_IH_", p.in.h);
148 REPLACE_WITH_NUM(model, "_IC_", p.in.c);
149 REPLACE_WITH_NUM(model, "_IN_", p.in.n);
// Kernel, stride and padding of the pooling layer.
151 REPLACE_WITH_NUM(model, "_KW_", p.krn_w);
152 REPLACE_WITH_NUM(model, "_KH_", p.krn_h);
153 REPLACE_WITH_NUM(model, "_SW_", p.str_w);
154 REPLACE_WITH_NUM(model, "_SH_", p.str_h);
155 REPLACE_WITH_NUM(model, "_PW_", p.pad_w);
156 REPLACE_WITH_NUM(model, "_PH_", p.pad_h);
// Output spatial size, using the same formula as ref_pool (integer division).
// NOTE(review): the template declares round="Ceil" while this rounds down;
// the two agree only when the division is exact - confirm for new test cases.
158 REPLACE_WITH_NUM(model, "_OW_", (p.in.w + 2 * p.pad_w - p.krn_w) / p.str_w + 1);
159 REPLACE_WITH_NUM(model, "_OH_", (p.in.h + 2 * p.pad_h - p.krn_h) / p.str_h + 1);
// Assemble the PrimitivesPriority value: one "cpu:<type>" entry per preferred
// implementation (the declaration of `impls` and any separator handling are on
// lines not visible in this excerpt).
162 for (const auto& preferType : p.preferTypes) {
165 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
167 REPLACE_WITH_STR(model, "_IMPLS_", impls);
// Per-test teardown hook; its body is not visible in this excerpt.
171 virtual void TearDown() {
// Runs the whole pooling test: the TEST_P body for this fixture is empty, so
// every check is performed here.
// Steps: build the model, create the graph, verify the Pooling node's
// primitive descriptors, run inference and compare against ref_pool().
174 virtual void SetUp() {
176 TestsCommon::SetUp();
177 pooling_test_params p = ::testing::WithParamInterface<pooling_test_params>::GetParam();
178 std::string model = getModel(p);
// Parse the generated network description from memory.
180 InferenceEngine::CNNNetReader net_reader;
181 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
183 MKLDNNGraphTestClass graph;
184 graph.CreateGraph(net_reader.getNetwork());
// Inspect every Pooling node of the created graph.
185 auto& nodes = graph.getNodes();
// NOTE(review): `int i` is compared with nodes.size(), presumably an unsigned
// type - a signed/unsigned comparison; confirm and consider size_t.
186 for (int i = 0; i < nodes.size(); i++) {
187 if (nodes[i]->getType() == MKLDNNPlugin::Pooling) {
// The node must expose at least num_prim_desc supported descriptors.
188 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
// Run the user-supplied check comp[j] on the j-th supported descriptor.
189 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
190 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
192 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
// NOTE(review): this uses bitwise OR, so the expression is non-zero whenever
// either operand is non-zero and the assertion cannot detect a wrong
// implementation type. A mask test such as `(type & p.selectedType) == p.selectedType`
// was probably intended - confirm and fix.
193 ASSERT_TRUE(nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() | p.selectedType);
// Create the FP32 NCHW input blob and fill it with test data.
197 InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
199 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
201 fill_data(src->buffer(), src->size());
// Typed view of the input, needed later by ref_pool().
203 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
205 if (srcPtr == nullptr)
206 FAIL() << "Cannot cast blob to TBlob<float>.";
// Bind the input blob to the input layer name used in the model ("in1").
208 InferenceEngine::BlobMap srcs;
209 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Take the first network output and create a blob for it from its tensor descriptor.
211 InferenceEngine::OutputsDataMap out;
212 out = net_reader.getNetwork().getOutputsInfo();
213 InferenceEngine::BlobMap outputBlobs;
215 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
217 InferenceEngine::TBlob<float>::Ptr output;
218 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
220 outputBlobs[item.first] = output;
222 graph.Infer(srcs, outputBlobs);
// Compute the reference result with the same tensor descriptor and compare.
224 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
227 ref_pool(*srcPtr, dst_ref, p);
229 compare(*output, dst_ref);
// Inference Engine exceptions are caught here (handler body not visible in this excerpt).
230 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: all work happens in the fixture's SetUp().
236 TEST_P(MKLDNNGraphPoolingTests, TestsPooling) {}
// Test cases for the pooling fixture. Initializers are positional; presumably
// {input n,c,h,w}, kernel w/h, stride w/h, pad w/h, num_prim_desc, selectedType
// and an optional preferTypes list - confirm against the struct definition.
// The first three cases expect a jit implementation; the last three request
// ref_any through preferTypes and expect a ref implementation.
238 INSTANTIATE_TEST_CASE_P(
239 TestsPooling, MKLDNNGraphPoolingTests,
241 pooling_test_params{{1, 3, 228, 228}, 2, 2, 2, 2, 0, 0, 6, MKLDNNPlugin::impl_desc_type::jit},
242 pooling_test_params{{1, 3, 228, 228}, 4, 2, 2, 2, 0, 0, 4, MKLDNNPlugin::impl_desc_type::jit},
243 pooling_test_params{{1, 3, 228, 228}, 4, 2, 2, 1, 0, 0, 4, MKLDNNPlugin::impl_desc_type::jit},
244 pooling_test_params{{1, 3, 228, 228}, 2, 2, 2, 2, 0, 0, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
245 pooling_test_params{{1, 3, 228, 228}, 4, 2, 2, 2, 0, 0, 4, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
246 pooling_test_params{{1, 3, 228, 228}, 4, 2, 2, 1, 0, 0, 4, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}));
// Dynamic-batch variant of the pooling fixture. It reuses getModel() from
// MKLDNNGraphPoolingTests but replaces SetUp(): the network is reshaped to
// batch size MB, the graph is created with dynamic batch enabled, and
// checkDynBatch() is run for the Pooling node.
// NOTE(review): unlike the base SetUp(), no primitive-descriptor checks are
// visible here, so selectedType/num_prim_desc appear unused by this fixture - confirm.
249 class MKLDNNGraphDynBatchPoolingTests: public MKLDNNGraphPoolingTests {
251 virtual void SetUp() {
253 TestsCommon::SetUp();
254 pooling_test_params p = ::testing::WithParamInterface<pooling_test_params>::GetParam();
255 std::string model = getModel(p);
// Parse the generated network description from memory.
260 InferenceEngine::CNNNetReader net_reader;
261 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Reshape the network to batch size MB (MB is declared on a line not visible
// in this excerpt). This requires access to the concrete CNNNetworkImpl.
// NOTE(review): the C-style cast to ICNNNetwork& relies on a conversion
// provided by CNNNetwork; a named cast would make the intent explicit.
262 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
263 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
264 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
265 InferenceEngine::ResponseDesc resp;
266 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
267 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
// Create the graph with the dynamic-batch option switched on.
269 MKLDNNGraphTestClass graph;
270 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
271 graph.CreateGraph(net_reader.getNetwork());
// Input blob uses MB as its batch dimension instead of p.in.n.
273 InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
275 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
277 fill_data(src->buffer(), src->size());
// Fail early if the blob is not a float blob.
279 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
281 if (srcPtr == nullptr)
282 FAIL() << "Cannot cast blob to TBlob<float>.";
// Bind the input blob to the input layer name used in the model ("in1").
284 InferenceEngine::BlobMap srcs;
285 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Take the first network output and create a blob for it from its tensor descriptor.
287 InferenceEngine::OutputsDataMap out;
288 out = net_reader.getNetwork().getOutputsInfo();
289 InferenceEngine::BlobMap outputBlobs;
291 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
293 InferenceEngine::TBlob<float>::Ptr output;
294 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
296 outputBlobs[item.first] = output;
// Predicate selecting the nodes that checkDynBatch() should examine.
298 auto checkPooling = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
299 return node->getType() == MKLDNNPlugin::Pooling;
// Run the dynamic-batch check twice: once with batch MB and once with batch 1,
// both against a maximum of MB (argument meaning inferred from the values
// passed - confirm against checkDynBatch's declaration).
301 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkPooling);
302 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkPooling);
// Inference Engine exceptions are caught here (handler body not visible in this excerpt).
303 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: all work happens in the fixture's SetUp().
309 TEST_P(MKLDNNGraphDynBatchPoolingTests, TestsDynBatchPooling) {}
// Test cases for the dynamic-batch fixture. Initializers are positional and
// follow the same layout as the TestsPooling cases (presumably input dims,
// kernel w/h, stride w/h, pad w/h, num_prim_desc, selectedType, preferTypes -
// confirm against the struct definition).
311 INSTANTIATE_TEST_CASE_P(
312 TestsDynBatchPooling, MKLDNNGraphDynBatchPoolingTests,
314 pooling_test_params{{1, 3, 228, 228}, 4, 2, 2, 1, 0, 0, 4, MKLDNNPlugin::impl_desc_type::jit},
315 pooling_test_params{{1, 3, 228, 228}, 2, 2, 2, 2, 0, 0, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
316 pooling_test_params{{1, 3, 228, 228}, 4, 2, 2, 2, 0, 0, 4, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
317 pooling_test_params{{1, 3, 228, 228}, 4, 2, 2, 1, 0, 0, 4, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}));