platform/upstream/dldt.git: inference-engine/tests/unit/engines/mkldnn/graph/layers/internal/graph_input_test.cpp
// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include "mkldnn_plugin/mkldnn_graph.h"
#include "mock_mkldnn_primitive.hpp"

#include "test_graph.hpp"

#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include "tests_common.hpp"


using namespace ::testing;
using namespace std;
using namespace mkldnn;


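// Parameters for the parameterized input test below: the expected number of supported
// primitive descriptors per Input/Output node, the implementation type expected to be
// selected, and a set of callbacks that validate individual primitive descriptors.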
struct input_test_params {
    size_t num_prim_desc;

    MKLDNNPlugin::impl_desc_type selectedType;

    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
};

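// Builds an MKLDNN graph from an IR with three Input layers (two 4D, one 2D), each
// feeding a Power layer, and checks the supported and selected primitive descriptors
// of every Input/Output node against the test parameters.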
class MKLDNNGraphInputTests: public TestsCommon,
                                     public WithParamInterface<input_test_params> {
    std::string model_t = R"V0G0N(
<net name="InputsOnly" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="1">
            <output>
                <port id="1">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                </port>
            </output>
        </layer>
        <layer name="in2" type="Input" precision="FP32" id="2">
            <output>
                <port id="2">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                </port>
            </output>
        </layer>
        <layer name="in3" type="Input" precision="FP32" id="3">
            <output>
                <port id="3">
                    <dim>1</dim>
                    <dim>3</dim>
                </port>
            </output>
        </layer>
        <layer name="power1" id="4" type="Power" precision="FP32">
            <power_data power="1" scale="1" shift="1"/>
            <input>
                <port id="4">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                </port>
            </input>
            <output>
                <port id="5">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                </port>
            </output>
        </layer>
        <layer name="power2" id="5" type="Power" precision="FP32">
            <power_data power="1" scale="1" shift="1"/>
            <input>
                <port id="6">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                </port>
            </input>
            <output>
                <port id="7">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                </port>
            </output>
        </layer>
        <layer name="power3" id="6" type="Power" precision="FP32">
            <power_data power="1" scale="1" shift="1"/>
            <input>
                <port id="8">
                    <dim>1</dim>
                    <dim>3</dim>
                </port>
            </input>
            <output>
                <port id="9">
                    <dim>1</dim>
                    <dim>3</dim>
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="1" from-port="1" to-layer="4" to-port="4"/>
        <edge from-layer="2" from-port="2" to-layer="5" to-port="6"/>
        <edge from-layer="3" from-port="3" to-layer="6" to-port="8"/>
    </edges>
</net>
)V0G0N";

    std::string getModel(input_test_params p) {
        return model_t;
    }

protected:
    virtual void TearDown() {
    }

    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            input_test_params p = ::testing::WithParamInterface<input_test_params>::GetParam();
            std::string model = getModel(p);

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());

            auto& nodes = graph.getNodes();
            for (size_t i = 0; i < nodes.size(); i++) {
                if (nodes[i]->getType() == MKLDNNPlugin::Input || nodes[i]->getType() == MKLDNNPlugin::Output) {
                    ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
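                    // Pick the comparator for this node: 4D inputs use comparator 0 and
                    // their outputs comparator 2; the 2D input "in3" uses comparator 1
                    // and its output node "out_power3" comparator 3.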
                    size_t count = (nodes[i]->getType() == MKLDNNPlugin::Input) ? 0 : 2;
                    if (nodes[i]->getName() == "in3") {
                        count = 1;
                    }
                    if (nodes[i]->getName() == "out_power3") {
                        count = 3;
                    }
                    for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                        p.comp.at(count)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
                    }
                    ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
                    ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
                }
            }
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphInputTests, TestsInput) {}


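// One parameter set: each Input/Output node must expose a single "unknown" primitive
// descriptor; the comparators check that inputs have only an output config and outputs
// only an input config, with NCHW layout for the 4D tensors and NC for the 2D one.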
INSTANTIATE_TEST_CASE_P(
        TestsInput, MKLDNNGraphInputTests,
        ::testing::Values(
                input_test_params{1, MKLDNNPlugin::impl_desc_type::unknown, {
                        [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                            ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                            ASSERT_EQ(0, impl.getConfig().inConfs.size());
                            ASSERT_EQ(1, impl.getConfig().outConfs.size());
                            ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
                        },
                        [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                            ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                            ASSERT_EQ(0, impl.getConfig().inConfs.size());
                            ASSERT_EQ(1, impl.getConfig().outConfs.size());
                            ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
                        },
                        [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                            ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                            ASSERT_EQ(1, impl.getConfig().inConfs.size());
                            ASSERT_EQ(0, impl.getConfig().outConfs.size());
                            ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
                        },
                        [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                            ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                            ASSERT_EQ(1, impl.getConfig().inConfs.size());
                            ASSERT_EQ(0, impl.getConfig().outConfs.size());
                            ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
                        }
                } }
        ));

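// Builds a graph where a runtime Input and a Const layer (backed by the weights blob)
// are concatenated along axis 2, runs inference with only "in1" supplied, and verifies
// that the output interleaves the data of both sources.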
class MKLDNNGraphConstInputTests: public TestsCommon {
    std::string model_t = R"V0G0N(
<net name="ConcatOnly" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="1">
            <output>
                <port id="1">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </output>
            <blobs>
                <custom offset="0" size="48"/>
            </blobs>
        </layer>
        <layer name="in2" type="Const" precision="FP32" id="2">
            <output>
                <port id="2">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>1</dim>
                    <dim>2</dim>
                </port>
            </output>
            <blobs>
                <custom offset="48" size="24"/>
            </blobs>
        </layer>
        <layer name="con" id="3" type="Concat" precision="FP32">
            <concat_data axis="2"/>
            <input>
                <port id="1">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
                <port id="2">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>1</dim>
                    <dim>2</dim>
                </port>
            </input>
            <output>
                <port id="3">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
        <edge from-layer="2" from-port="2" to-layer="3" to-port="2"/>
    </edges>
</net>
)V0G0N";

protected:
    virtual void TearDown() {
    }

    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            std::string model = model_t;

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

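            // 72 bytes of weights back the two IR blobs: bytes [0, 48) hold the 12 FP32
            // values of "in1" (1x3x2x2), bytes [48, 72) the 6 FP32 values of the Const
            // layer "in2" (1x3x1x2), matching the offsets/sizes declared in the model.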
            InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {72});
            weights->allocate();
            float * data = weights->buffer();

            std::cout << weights->size() << std::endl;

            InferenceEngine::SizeVector dims_src1 = {1, 3, 2, 2};
            InferenceEngine::SizeVector dims_src2 = {1, 3, 1, 2};
            InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
            src1->allocate();
            float *srcData = src1->buffer();
            for (size_t i = 0; i < 12; i++, data++, srcData++) {
                *data = 1;
                *srcData = 1;
            }

            InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
            src2->allocate();
            srcData = src2->buffer();
            for (size_t i = 0; i < 6; i++, data++, srcData++) {
                *data = 2;
                *srcData = 2;
            }
            InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);

            net_reader.SetWeights(weights_ptr);

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());
            auto& nodes = graph.getNodes();
            ASSERT_LE(3, nodes.size());

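            // Only "in1" is fed at inference time; "in2" is a Const layer whose data
            // comes from the weights blob set above.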
            InferenceEngine::BlobMap srcs;
            srcs["in1"] = src1;
            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            graph.Infer(srcs, outputBlobs);

            // Compare
            float *src1_ptr = src1->buffer();
            size_t src1_size = src1->size();
            float *src2_ptr = src2->buffer();
            size_t src2_size = src2->size();
            float *dst_ptr = output->buffer();
            size_t dst_size = output->size();

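            // len1/len2 are the number of elements taken from src1 and src2 per slice of
            // the concatenated output; the loop below walks the output and checks that
            // each slice matches its source values (1s from src1, 2s from src2).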
            int len1 = 1, len2 = 1, cycles;
            for (size_t dim = 2; dim < output->dims().size(); dim++) {
                len1 *= src1->dims()[dim];
                len2 *= src2->dims()[dim];
            }
            cycles = 2;

            int index1 = 0, index2 = 0, index = 0;
            for (int cycle = 0; cycle < cycles; cycle++) {
                for (int i1 = 0; i1 < len1; i1++) {
                    if (src1_ptr[index1] != dst_ptr[index]) {
                        FAIL() << "index: " << index << " src: " << src1_ptr[index1] << ", dst: " << dst_ptr[index];
                    }
                    index1++; index++;
                }
                for (int i2 = 0; i2 < len2; i2++) {
                    if (src2_ptr[index2] != dst_ptr[index]) {
                        FAIL() << "index: " << index << " src: " << src2_ptr[index2] << ", dst: " << dst_ptr[index];
                    }
                    index2++; index++;
                }
            }
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_F(MKLDNNGraphConstInputTests, TestsConstInput) {}