// File: inference-engine/tests/unit/engines/mkldnn/graph/layers/internal/graph_power_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
17
18
19 using namespace ::testing;
20 using namespace std;
21 using namespace mkldnn;
22
23
// Parameters for a single Power-layer test case.
struct power_test_params {
    // Input tensor dimensions (NCHW order).
    struct {
        size_t n;
        size_t c;
        size_t h;
        size_t w;
    } in;

    // Power layer attributes: dst = (src * scale + shift) ^ power.
    float power;
    float scale;
    float shift;

    // Expected number of supported primitive descriptors for the Power node.
    size_t num_prim_desc;

    // Expected implementation type of the selected primitive descriptor.
    MKLDNNPlugin::impl_desc_type selectedType;

    // Optional per-descriptor checks; the j-th functor is applied to the
    // j-th supported primitive descriptor (only the first comp.size() ones).
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
};
42
43 template <typename data_t>
44 void ref_power(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, power_test_params prm) {
45     const data_t *src_data = src.readOnly();
46     data_t *dst_data = dst.data();
47
48 #pragma omp parallel for
49     for (int i=0; i < src.size(); i++)
50         dst_data[i] = pow(src_data[i]*prm.scale + prm.shift, prm.power);
51 }
52
// Value-parameterized fixture that builds a one-layer "Power" network from an
// IR template, runs it through the MKLDNN graph, validates the node's
// primitive descriptors, and compares the inference output with ref_power().
class MKLDNNGraphPowerTests: public TestsCommon,
                                     public WithParamInterface<power_test_params> {
    // IR (xml) template of a network with a single Input -> Power chain.
    // The _IN_/_IC_/_IH_/_IW_ and _POWER_/_SCALE_/_SHIFT_ placeholders are
    // substituted by getModel().
    std::string model_t = R"V0G0N(
<Net Name="Power_Only" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="0">
            <output>
                <port id="0">
                    <dim>_IN_</dim>
                    <dim>_IC_</dim>
                    <dim>_IH_</dim>
                    <dim>_IW_</dim>
                </port>
            </output>
        </layer>
        <layer name="power" id="1" type="Power" precision="FP32">
            <power_data power="_POWER_" scale="_SCALE_" shift="_SHIFT_"/>
            <input>
                <port id="1">
                    <dim>_IN_</dim>
                    <dim>_IC_</dim>
                    <dim>_IH_</dim>
                    <dim>_IW_</dim>
                </port>
            </input>
            <output>
                <port id="2">
                    <dim>_IN_</dim>
                    <dim>_IC_</dim>
                    <dim>_IH_</dim>
                    <dim>_IW_</dim>
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
    </edges>
</Net>
)V0G0N";

protected:
    // Substitutes the test parameters into the IR template and returns the
    // ready-to-parse model xml.
    std::string getModel(power_test_params p) {
        std::string model = model_t;

        REPLACE_WITH_NUM(model, "_IW_", p.in.w);
        REPLACE_WITH_NUM(model, "_IH_", p.in.h);
        REPLACE_WITH_NUM(model, "_IC_", p.in.c);
        REPLACE_WITH_NUM(model, "_IN_", p.in.n);
        REPLACE_WITH_NUM(model, "_POWER_", p.power);
        REPLACE_WITH_NUM(model, "_SCALE_", p.scale);
        REPLACE_WITH_NUM(model, "_SHIFT_", p.shift);

        return model;
    }

    virtual void TearDown() {
    }

    // The entire test body lives here; the TEST_P below is intentionally empty.
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            power_test_params p = ::testing::WithParamInterface<power_test_params>::GetParam();
            std::string model = getModel(p);

            // Parse the generated IR into a CNNNetwork.
            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());
            auto& nodes = graph.getNodes();
            // Validate the Power node: descriptor count, optional per-descriptor
            // checks from p.comp, and the selected implementation type.
            for (int i = 0; i < nodes.size(); i++) {
                if (nodes[i]->getType() == MKLDNNPlugin::Power) {
                    ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
                    for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                        p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
                    }
                    ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
                    ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
                }
            }

            InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};

            // Allocate and fill the input blob with synthetic data.
            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());

            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            // Prepare the output blob from the network's (single) output info.
            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            graph.Infer(srcs, outputBlobs);

            // Compute the reference result and compare element-wise.
            InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
            dst_ref.allocate();

            ref_power(*srcPtr, dst_ref, p);

            compare(*output, dst_ref);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};
173
// The whole check runs in the fixture's SetUp(); the test body is empty.
TEST_P(MKLDNNGraphPowerTests, TestsPower) {}


INSTANTIATE_TEST_CASE_P(
        TestsPower, MKLDNNGraphPowerTests,
        ::testing::Values(
                // First case additionally inspects all three supported
                // primitive descriptors: one planar NCHW and two blocked layouts.
                power_test_params{
                        {1, 3, 13, 13}, 1, 2, 0.5f, 3, MKLDNNPlugin::impl_desc_type::unknown, {
                                [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                                    ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                                    ASSERT_EQ(1, impl.getConfig().inConfs.size());
                                    ASSERT_EQ(1, impl.getConfig().outConfs.size());
                                    ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
                                    ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
                                },
                                [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                                    ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                                    ASSERT_EQ(1, impl.getConfig().inConfs.size());
                                    ASSERT_EQ(1, impl.getConfig().outConfs.size());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
                                },
                                [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                                    ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                                    ASSERT_EQ(1, impl.getConfig().inConfs.size());
                                    ASSERT_EQ(1, impl.getConfig().outConfs.size());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
                                }}},
                // Remaining cases only check descriptor count / selected type
                // and the numeric result (no per-descriptor functors).
                power_test_params{{1, 1, 23, 23}, 3, 8, 2, 3 },
                power_test_params{{1, 8, 23, 23}, 8, 2, 1, 3 }
        ));
206
// Dynamic-batch variant: reshapes the network to batch MB (>= 2), enables the
// plugin's dynamic-batch mode, and verifies inference at both the full batch
// and batch 1 via graph.checkDynBatch().
class MKLDNNGraphDynBatchPowerTests: public MKLDNNGraphPowerTests {
    // Same IR template as the base fixture. Duplicated here because the base
    // class keeps model_t/getModel private, so they are not accessible — this
    // getModel shadows (does not override) the base one.
    std::string model_t = R"V0G0N(
<Net Name="Power_Only" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="0">
            <output>
                <port id="0">
                    <dim>_IN_</dim>
                    <dim>_IC_</dim>
                    <dim>_IH_</dim>
                    <dim>_IW_</dim>
                </port>
            </output>
        </layer>
        <layer name="power" id="1" type="Power" precision="FP32">
            <power_data power="_POWER_" scale="_SCALE_" shift="_SHIFT_"/>
            <input>
                <port id="1">
                    <dim>_IN_</dim>
                    <dim>_IC_</dim>
                    <dim>_IH_</dim>
                    <dim>_IW_</dim>
                </port>
            </input>
            <output>
                <port id="2">
                    <dim>_IN_</dim>
                    <dim>_IC_</dim>
                    <dim>_IH_</dim>
                    <dim>_IW_</dim>
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
    </edges>
</Net>
)V0G0N";

    // Substitutes the test parameters into the IR template (same logic as the
    // base class; see the duplication note above).
    std::string getModel(power_test_params p) {
        std::string model = model_t;

        REPLACE_WITH_NUM(model, "_IW_", p.in.w);
        REPLACE_WITH_NUM(model, "_IH_", p.in.h);
        REPLACE_WITH_NUM(model, "_IC_", p.in.c);
        REPLACE_WITH_NUM(model, "_IN_", p.in.n);
        REPLACE_WITH_NUM(model, "_POWER_", p.power);
        REPLACE_WITH_NUM(model, "_SCALE_", p.scale);
        REPLACE_WITH_NUM(model, "_SHIFT_", p.shift);

        return model;
    }

protected:
    virtual void TearDown() {
    }

    // The entire test body lives here; the TEST_P below is intentionally empty.
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            power_test_params p = ::testing::WithParamInterface<power_test_params>::GetParam();
            std::string model = getModel(p);
            // Dynamic-batch testing needs at least two batches to vary.
            size_t MB = p.in.n;
            if (MB < 2)
                MB = 2;

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
            // Reshape the parsed network to the enlarged batch size via the
            // CNNNetworkImpl internal API.
            InferenceEngine::CNNNetwork network = net_reader.getNetwork();
            auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
            ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
            InferenceEngine::ResponseDesc resp;
            InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
            ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;

            // Build the graph with dynamic batch enabled in the plugin config.
            MKLDNNGraphTestClass graph;
            graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
            graph.CreateGraph(net_reader.getNetwork());

            InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};

            // Allocate and fill the input blob with synthetic data.
            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());

            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            // Prepare the output blob from the network's (single) output info.
            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            // Run dynamic-batch inference at full batch MB and at batch 1,
            // restricting the check to Power nodes.
            auto checkPower = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
                return node->getType() == MKLDNNPlugin::Power;
            };
            graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkPower);
            graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkPower);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};
322
// The whole check runs in the fixture's SetUp(); the test body is empty.
TEST_P(MKLDNNGraphDynBatchPowerTests, TestsDynBatchPower) {}

// Same parameter sets as the static-batch suite above; the fixture forces
// the effective batch to at least 2 for the dynamic-batch checks.
INSTANTIATE_TEST_CASE_P(
        TestsDynBatchPower, MKLDNNGraphDynBatchPowerTests,
        ::testing::Values(
                power_test_params{
                        {1, 3, 13, 13}, 1, 2, 0.5f, 3, MKLDNNPlugin::impl_desc_type::unknown, {
                                [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                                    ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                                    ASSERT_EQ(1, impl.getConfig().inConfs.size());
                                    ASSERT_EQ(1, impl.getConfig().outConfs.size());
                                    ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
                                    ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
                                },
                                [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                                    ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                                    ASSERT_EQ(1, impl.getConfig().inConfs.size());
                                    ASSERT_EQ(1, impl.getConfig().outConfs.size());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
                                },
                                [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                                    ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                                    ASSERT_EQ(1, impl.getConfig().inConfs.size());
                                    ASSERT_EQ(1, impl.getConfig().outConfs.size());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
                                    ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
                                }}},
                power_test_params{{1, 1, 23, 23}, 3, 8, 2, 3 },
                power_test_params{{1, 8, 23, 23}, 8, 2, 1, 3 }
        ));