1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
11 #include "test_graph.hpp"
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
19 using namespace ::testing;
21 using namespace mkldnn;
// Parameters describing one Power-layer test case.
// NOTE(review): this view of the file is elided — only part of the struct's
// members are visible. Code below also reads fields `in` (n/c/h/w), `power`,
// `scale`, `shift` and `num_prim_desc`; confirm against the full file.
struct power_test_params {
// Implementation type the plugin is expected to select for the Power node.
MKLDNNPlugin::impl_desc_type selectedType;
// Per-primitive-descriptor checkers, applied in order to the node's
// supported primitive descriptors (see SetUp below).
std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Scalar reference implementation of the Power layer:
//   dst[i] = (src[i] * scale + shift) ^ power, elementwise over the blob.
// Used to validate the MKLDNN graph output in the tests below.
template <typename data_t>
void ref_power(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, power_test_params prm) {
const data_t *src_data = src.readOnly();
data_t *dst_data = dst.data();
// NOTE(review): `int i` vs size_t src.size() is a signed/unsigned compare;
// presumably kept signed because OpenMP 2.x requires a signed loop index.
// Also, unqualified pow() promotes float operands to double — harmless for
// a reference implementation, but worth confirming tolerance in compare().
#pragma omp parallel for
for (int i=0; i < src.size(); i++)
dst_data[i] = pow(src_data[i]*prm.scale + prm.shift, prm.power);
// Fixture for single-layer Power tests on the MKLDNN plugin graph: builds a
// minimal IR (Input -> Power) from the template below, creates the graph,
// validates the Power node's supported/selected primitive descriptors, then
// runs inference and compares the output against ref_power().
// NOTE(review): this view of the file is elided — the IR template, several
// closing braces and the raw-string terminator are not fully visible here.
class MKLDNNGraphPowerTests: public TestsCommon,
public WithParamInterface<power_test_params> {
std::string model_t = R"V0G0N(
<Net Name="Power_Only" version="2" precision="FP32" batch="1">
<layer name="in1" type="Input" precision="FP32" id="0">
<layer name="power" id="1" type="Power" precision="FP32">
<power_data power="_POWER_" scale="_SCALE_" shift="_SHIFT_"/>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Fills the _IW_/_IH_/_IC_/_IN_/_POWER_/_SCALE_/_SHIFT_ placeholders of the
// IR template with the values of one test case.
std::string getModel(power_test_params p) {
std::string model = model_t;
REPLACE_WITH_NUM(model, "_IW_", p.in.w);
REPLACE_WITH_NUM(model, "_IH_", p.in.h);
REPLACE_WITH_NUM(model, "_IC_", p.in.c);
REPLACE_WITH_NUM(model, "_IN_", p.in.n);
REPLACE_WITH_NUM(model, "_POWER_", p.power);
REPLACE_WITH_NUM(model, "_SCALE_", p.scale);
REPLACE_WITH_NUM(model, "_SHIFT_", p.shift);
virtual void TearDown() {
// The whole test body lives in SetUp(); the TEST_P below is intentionally
// empty.
virtual void SetUp() {
TestsCommon::SetUp();
power_test_params p = ::testing::WithParamInterface<power_test_params>::GetParam();
std::string model = getModel(p);
// Parse the generated IR and build the MKLDNN test graph from it.
InferenceEngine::CNNNetReader net_reader;
ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
MKLDNNGraphTestClass graph;
graph.CreateGraph(net_reader.getNetwork());
// Validate the Power node: descriptor count, per-descriptor checks from
// p.comp, and the implementation type the plugin actually selected.
auto& nodes = graph.getNodes();
for (int i = 0; i < nodes.size(); i++) {
if (nodes[i]->getType() == MKLDNNPlugin::Power) {
ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
// Prepare a randomly filled FP32 NCHW input blob of the requested shape.
InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
fill_data(src->buffer(), src->size());
InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
if (srcPtr == nullptr)
FAIL() << "Cannot cast blob to TBlob<float>.";
InferenceEngine::BlobMap srcs;
srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Allocate the output blob from the network's first output's tensor desc.
InferenceEngine::OutputsDataMap out;
out = net_reader.getNetwork().getOutputsInfo();
InferenceEngine::BlobMap outputBlobs;
std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
InferenceEngine::TBlob<float>::Ptr output;
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
outputBlobs[item.first] = output;
// Run inference and compare against the scalar reference implementation.
graph.Infer(srcs, outputBlobs);
InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
ref_power(*srcPtr, dst_ref, p);
compare(*output, dst_ref);
} catch (const InferenceEngine::details::InferenceEngineException &e) {
174 TEST_P(MKLDNNGraphPowerTests, TestsPower) {}
// Test cases. The first case also supplies three per-descriptor checkers
// (one plain NCHW layout plus two blocked layouts); the remaining cases run
// only the numeric comparison.
INSTANTIATE_TEST_CASE_P(
TestsPower, MKLDNNGraphPowerTests,
{1, 3, 13, 13}, 1, 2, 0.5f, 3, MKLDNNPlugin::impl_desc_type::unknown, {
// Descriptor 0: single input/output config, planar NCHW on both sides.
[](MKLDNNPlugin::PrimitiveDescInfo impl) {
ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
ASSERT_EQ(1, impl.getConfig().inConfs.size());
ASSERT_EQ(1, impl.getConfig().outConfs.size());
ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
// Descriptor 1: blocked layout on both sides.
[](MKLDNNPlugin::PrimitiveDescInfo impl) {
ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
ASSERT_EQ(1, impl.getConfig().inConfs.size());
ASSERT_EQ(1, impl.getConfig().outConfs.size());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
// Descriptor 2: also blocked on both sides (presumably a different block
// size than descriptor 1 — not distinguishable from this view).
[](MKLDNNPlugin::PrimitiveDescInfo impl) {
ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
ASSERT_EQ(1, impl.getConfig().inConfs.size());
ASSERT_EQ(1, impl.getConfig().outConfs.size());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
power_test_params{{1, 1, 23, 23}, 3, 8, 2, 3 },
power_test_params{{1, 8, 23, 23}, 8, 2, 1, 3 }
// Dynamic-batch variant of the Power tests: reshapes the network to batch MB,
// enables the plugin's dynamic-batch mode, and checks inference at both the
// full batch (MB) and a reduced batch (1) via graph.checkDynBatch().
// NOTE(review): MB is not defined in this view of the file — presumably a
// constant/member provided elsewhere; confirm against the full file.
class MKLDNNGraphDynBatchPowerTests: public MKLDNNGraphPowerTests {
std::string model_t = R"V0G0N(
<Net Name="Power_Only" version="2" precision="FP32" batch="1">
<layer name="in1" type="Input" precision="FP32" id="0">
<layer name="power" id="1" type="Power" precision="FP32">
<power_data power="_POWER_" scale="_SCALE_" shift="_SHIFT_"/>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Fills the IR template placeholders with one test case's values (shadows
// the base-class getModel with this class's own template).
std::string getModel(power_test_params p) {
std::string model = model_t;
REPLACE_WITH_NUM(model, "_IW_", p.in.w);
REPLACE_WITH_NUM(model, "_IH_", p.in.h);
REPLACE_WITH_NUM(model, "_IC_", p.in.c);
REPLACE_WITH_NUM(model, "_IN_", p.in.n);
REPLACE_WITH_NUM(model, "_POWER_", p.power);
REPLACE_WITH_NUM(model, "_SCALE_", p.scale);
REPLACE_WITH_NUM(model, "_SHIFT_", p.shift);
virtual void TearDown() {
// Entire test body lives in SetUp(); the TEST_P below is intentionally empty.
virtual void SetUp() {
TestsCommon::SetUp();
power_test_params p = ::testing::WithParamInterface<power_test_params>::GetParam();
std::string model = getModel(p);
// Parse the IR, then reshape the network to batch size MB through the
// CNNNetworkImpl internal API before building the graph.
InferenceEngine::CNNNetReader net_reader;
ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
InferenceEngine::CNNNetwork network = net_reader.getNetwork();
auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
InferenceEngine::ResponseDesc resp;
InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
// Build the graph with dynamic batching enabled.
MKLDNNGraphTestClass graph;
graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
graph.CreateGraph(net_reader.getNetwork());
// Random FP32 NCHW input sized with the reshaped batch MB.
InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
fill_data(src->buffer(), src->size());
InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
if (srcPtr == nullptr)
FAIL() << "Cannot cast blob to TBlob<float>.";
InferenceEngine::BlobMap srcs;
srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Allocate the output blob from the network's first output's tensor desc.
InferenceEngine::OutputsDataMap out;
out = net_reader.getNetwork().getOutputsInfo();
InferenceEngine::BlobMap outputBlobs;
std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
InferenceEngine::TBlob<float>::Ptr output;
output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
outputBlobs[item.first] = output;
// Predicate selecting the Power node for the dynamic-batch check.
auto checkPower = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
return node->getType() == MKLDNNPlugin::Power;
// Verify inference at the full batch and at a reduced batch of 1.
graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkPower);
graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkPower);
} catch (const InferenceEngine::details::InferenceEngineException &e) {
323 TEST_P(MKLDNNGraphDynBatchPowerTests, TestsDynBatchPower) {}
// Dynamic-batch test cases; same parameter sets and descriptor checkers as
// the static-batch instantiation above.
INSTANTIATE_TEST_CASE_P(
TestsDynBatchPower, MKLDNNGraphDynBatchPowerTests,
{1, 3, 13, 13}, 1, 2, 0.5f, 3, MKLDNNPlugin::impl_desc_type::unknown, {
// Descriptor 0: single input/output config, planar NCHW on both sides.
[](MKLDNNPlugin::PrimitiveDescInfo impl) {
ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
ASSERT_EQ(1, impl.getConfig().inConfs.size());
ASSERT_EQ(1, impl.getConfig().outConfs.size());
ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
// Descriptor 1: blocked layout on both sides.
[](MKLDNNPlugin::PrimitiveDescInfo impl) {
ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
ASSERT_EQ(1, impl.getConfig().inConfs.size());
ASSERT_EQ(1, impl.getConfig().outConfs.size());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
// Descriptor 2: also blocked on both sides (presumably a different block
// size than descriptor 1 — not distinguishable from this view).
[](MKLDNNPlugin::PrimitiveDescInfo impl) {
ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
ASSERT_EQ(1, impl.getConfig().inConfs.size());
ASSERT_EQ(1, impl.getConfig().outConfs.size());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
power_test_params{{1, 1, 23, 23}, 3, 8, 2, 3 },
power_test_params{{1, 8, 23, 23}, 8, 2, 1, 3 }