Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_power_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
15
16
17 using namespace ::testing;
18 using namespace std;
19 using namespace mkldnn;
20
21
22 struct power_test_params {
23     struct {
24         size_t n;
25         size_t c;
26         size_t h;
27         size_t w;
28     } in;
29
30     float power;
31     float scale;
32     float shift;
33
34     size_t num_prim_desc;
35
36     MKLDNNPlugin::impl_desc_type selectedType;
37
38     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
39 };
40
41 template <typename data_t>
42 void ref_power(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, power_test_params prm) {
43     const data_t *src_data = src.readOnly();
44     data_t *dst_data = dst.data();
45
46     for (int i=0; i < src.size(); i++)
47         dst_data[i] = pow(src_data[i]*prm.scale + prm.shift, prm.power);
48 }
49
50 class MKLDNNGraphPowerTests: public TestsCommon,
51                                      public WithParamInterface<power_test_params> {
52     std::string model_t = R"V0G0N(
53 <Net Name="Power_Only" version="2" precision="FP32" batch="1">
54     <layers>
55         <layer name="in1" type="Input" precision="FP32" id="0">
56             <output>
57                 <port id="0">
58                     <dim>_IN_</dim>
59                     <dim>_IC_</dim>
60                     <dim>_IH_</dim>
61                     <dim>_IW_</dim>
62                 </port>
63             </output>
64         </layer>
65         <layer name="power" id="1" type="Power" precision="FP32">
66             <power_data power="_POWER_" scale="_SCALE_" shift="_SHIFT_"/>
67             <input>
68                 <port id="1">
69                     <dim>_IN_</dim>
70                     <dim>_IC_</dim>
71                     <dim>_IH_</dim>
72                     <dim>_IW_</dim>
73                 </port>
74             </input>
75             <output>
76                 <port id="2">
77                     <dim>_IN_</dim>
78                     <dim>_IC_</dim>
79                     <dim>_IH_</dim>
80                     <dim>_IW_</dim>
81                 </port>
82             </output>
83         </layer>
84     </layers>
85     <edges>
86         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
87     </edges>
88 </Net>
89 )V0G0N";
90
91 protected:
92     std::string getModel(power_test_params p) {
93         std::string model = model_t;
94
95         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
96         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
97         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
98         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
99         REPLACE_WITH_NUM(model, "_POWER_", p.power);
100         REPLACE_WITH_NUM(model, "_SCALE_", p.scale);
101         REPLACE_WITH_NUM(model, "_SHIFT_", p.shift);
102
103         return model;
104     }
105
106     virtual void TearDown() {
107     }
108
109     virtual void SetUp() {
110         try {
111             TestsCommon::SetUp();
112             power_test_params p = ::testing::WithParamInterface<power_test_params>::GetParam();
113             std::string model = getModel(p);
114
115             InferenceEngine::CNNNetReader net_reader;
116             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
117
118             MKLDNNGraphTestClass graph;
119             graph.CreateGraph(net_reader.getNetwork());
120             auto& nodes = graph.getNodes();
121             for (int i = 0; i < nodes.size(); i++) {
122                 if (nodes[i]->getType() == MKLDNNPlugin::Power) {
123                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
124                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
125                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
126                     }
127                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
128                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
129                 }
130             }
131
132             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
133
134             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
135             src->allocate();
136             fill_data(src->buffer(), src->size());
137
138             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
139
140             if (srcPtr == nullptr)
141                 FAIL() << "Cannot cast blob to TBlob<float>.";
142
143             InferenceEngine::BlobMap srcs;
144             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
145
146             InferenceEngine::OutputsDataMap out;
147             out = net_reader.getNetwork().getOutputsInfo();
148             InferenceEngine::BlobMap outputBlobs;
149
150             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
151
152             InferenceEngine::TBlob<float>::Ptr output;
153             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
154             output->allocate();
155             outputBlobs[item.first] = output;
156
157             graph.Infer(srcs, outputBlobs);
158
159             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
160             dst_ref.allocate();
161
162             ref_power(*srcPtr, dst_ref, p);
163
164             compare(*output, dst_ref);
165         } catch (const InferenceEngine::details::InferenceEngineException &e) {
166             FAIL() << e.what();
167         }
168     }
169 };
170
171 TEST_P(MKLDNNGraphPowerTests, TestsPower) {}
172
173
174 INSTANTIATE_TEST_CASE_P(
175         TestsPower, MKLDNNGraphPowerTests,
176         ::testing::Values(
177                 power_test_params{
178                         {1, 3, 13, 13}, 1, 2, 0.5f, 3, MKLDNNPlugin::impl_desc_type::unknown, {
179                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
180                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
181                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
182                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
183                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
184                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
185                                 },
186                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
187                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
188                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
189                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
190                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
191                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
192                                 },
193                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
194                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
195                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
196                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
197                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
198                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
199                                 }}},
200                 power_test_params{{1, 1, 23, 23}, 3, 8, 2, 3 },
201                 power_test_params{{1, 8, 23, 23}, 8, 2, 1, 3 }
202         ));
203
204 class MKLDNNGraphDynBatchPowerTests: public MKLDNNGraphPowerTests {
205     std::string model_t = R"V0G0N(
206 <Net Name="Power_Only" version="2" precision="FP32" batch="1">
207     <layers>
208         <layer name="in1" type="Input" precision="FP32" id="0">
209             <output>
210                 <port id="0">
211                     <dim>_IN_</dim>
212                     <dim>_IC_</dim>
213                     <dim>_IH_</dim>
214                     <dim>_IW_</dim>
215                 </port>
216             </output>
217         </layer>
218         <layer name="power" id="1" type="Power" precision="FP32">
219             <power_data power="_POWER_" scale="_SCALE_" shift="_SHIFT_"/>
220             <input>
221                 <port id="1">
222                     <dim>_IN_</dim>
223                     <dim>_IC_</dim>
224                     <dim>_IH_</dim>
225                     <dim>_IW_</dim>
226                 </port>
227             </input>
228             <output>
229                 <port id="2">
230                     <dim>_IN_</dim>
231                     <dim>_IC_</dim>
232                     <dim>_IH_</dim>
233                     <dim>_IW_</dim>
234                 </port>
235             </output>
236         </layer>
237     </layers>
238     <edges>
239         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
240     </edges>
241 </Net>
242 )V0G0N";
243
244     std::string getModel(power_test_params p) {
245         std::string model = model_t;
246
247         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
248         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
249         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
250         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
251         REPLACE_WITH_NUM(model, "_POWER_", p.power);
252         REPLACE_WITH_NUM(model, "_SCALE_", p.scale);
253         REPLACE_WITH_NUM(model, "_SHIFT_", p.shift);
254
255         return model;
256     }
257
258 protected:
259     virtual void TearDown() {
260     }
261
262     virtual void SetUp() {
263         try {
264             TestsCommon::SetUp();
265             power_test_params p = ::testing::WithParamInterface<power_test_params>::GetParam();
266             std::string model = getModel(p);
267             size_t MB = p.in.n;
268             if (MB < 2)
269                 MB = 2;
270
271             InferenceEngine::CNNNetReader net_reader;
272             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
273             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
274             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
275             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
276             InferenceEngine::ResponseDesc resp;
277             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
278             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
279
280             MKLDNNGraphTestClass graph;
281             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
282             graph.CreateGraph(net_reader.getNetwork());
283
284             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
285
286             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
287             src->allocate();
288             fill_data(src->buffer(), src->size());
289
290             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
291
292             if (srcPtr == nullptr)
293                 FAIL() << "Cannot cast blob to TBlob<float>.";
294
295             InferenceEngine::BlobMap srcs;
296             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
297
298             InferenceEngine::OutputsDataMap out;
299             out = net_reader.getNetwork().getOutputsInfo();
300             InferenceEngine::BlobMap outputBlobs;
301
302             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
303
304             InferenceEngine::TBlob<float>::Ptr output;
305             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
306             output->allocate();
307             outputBlobs[item.first] = output;
308
309             auto checkPower = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
310                 return node->getType() == MKLDNNPlugin::Power;
311             };
312             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkPower);
313             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkPower);
314         } catch (const InferenceEngine::details::InferenceEngineException &e) {
315             FAIL() << e.what();
316         }
317     }
318 };
319
320 TEST_P(MKLDNNGraphDynBatchPowerTests, TestsDynBatchPower) {}
321
322 INSTANTIATE_TEST_CASE_P(
323         TestsDynBatchPower, MKLDNNGraphDynBatchPowerTests,
324         ::testing::Values(
325                 power_test_params{
326                         {1, 3, 13, 13}, 1, 2, 0.5f, 3, MKLDNNPlugin::impl_desc_type::unknown, {
327                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
328                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
329                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
330                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
331                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
332                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
333                                 },
334                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
335                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
336                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
337                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
338                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
339                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
340                                 },
341                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
342                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
343                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
344                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
345                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
346                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
347                                 }}},
348                 power_test_params{{1, 1, 23, 23}, 3, 8, 2, 3 },
349                 power_test_params{{1, 8, 23, 23}, 8, 2, 1, 3 }
350         ));