Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_permute_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
15
16
17 using namespace ::testing;
18 using namespace std;
19 using namespace mkldnn;
20
21
22 struct permute_test_params {
23     InferenceEngine::SizeVector dims;
24     InferenceEngine::SizeVector order;
25     size_t num_prim_desc;
26
27     MKLDNNPlugin::impl_desc_type selectedType;
28
29     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
30 };
31
32 template <typename data_t>
33 void ref_permute(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, permute_test_params prm) {
34     const data_t *src_data = src.readOnly();
35     data_t *dst_data = dst.data();
36
37     InferenceEngine::SizeVector orderedDims;
38     for (auto ord : prm.order) {
39         orderedDims.push_back(src.getTensorDesc().getDims()[ord]);
40     }
41     InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32, src.getTensorDesc().getDims(), {orderedDims, prm.order});
42
43     for (int i=0; i < src.size(); i++) {
44         dst_data[desc.offset(i)] = src_data[src.getTensorDesc().offset(i)];
45     }
46 }
47
48 class MKLDNNGraphPermuteTests: public TestsCommon,
49                                public WithParamInterface<permute_test_params> {
50     std::string model_t = R"V0G0N(
51 <Net Name="Power_Only" version="2" precision="FP32" batch="1">
52     <layers>
53         <layer name="in1" type="Input" precision="FP32" id="0">
54             <output>
55                 <port id="0">
56                     __DIMS__
57                 </port>
58             </output>
59         </layer>
60         <layer name="permute" id="1" type="Permute" precision="FP32">
61             <data order="_ORDER_"/>
62             <input>
63                 <port id="1">
64                     __DIMS__
65                 </port>
66             </input>
67             <output>
68                 <port id="2">
69                     __DST_DIMS__
70                 </port>
71             </output>
72         </layer>
73     </layers>
74     <edges>
75         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
76     </edges>
77 </Net>
78 )V0G0N";
79
80 protected:
81     std::string getModel(permute_test_params p) {
82         std::string model = model_t;
83         std::string dims;
84         std::string dst_dims;
85         for (auto& dim : p.dims) {
86             dims += "<dim>";
87             dims += std::to_string(dim) + "</dim>\n";
88         }
89
90         std::string order;
91         for (auto& ord : p.order) {
92             if (!order.empty())
93                 order += ",";
94             order += std::to_string(ord);
95             dst_dims += "<dim>";
96             dst_dims += std::to_string(p.dims[ord]) + "</dim>\n";
97         }
98
99         REPLACE_WITH_STR(model, "__DIMS__", dims);
100         REPLACE_WITH_STR(model, "__DST_DIMS__", dst_dims);
101         REPLACE_WITH_STR(model, "_ORDER_", order);
102
103         return model;
104     }
105
106     virtual void TearDown() {
107     }
108
109     virtual void SetUp() {
110         try {
111             TestsCommon::SetUp();
112             permute_test_params p = ::testing::WithParamInterface<permute_test_params>::GetParam();
113             std::string model = getModel(p);
114
115             InferenceEngine::CNNNetReader net_reader;
116             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
117
118             MKLDNNGraphTestClass graph;
119             graph.CreateGraph(net_reader.getNetwork());
120             auto& nodes = graph.getNodes();
121             for (int i = 0; i < nodes.size(); i++) {
122                 if (nodes[i]->getType() == MKLDNNPlugin::Permute) {
123                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
124                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
125                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
126                     }
127                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
128                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
129                 }
130             }
131
132             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>({InferenceEngine::Precision::FP32, p.dims, InferenceEngine::TensorDesc::getLayoutByDims(p.dims)});
133             src->allocate();
134             fill_data(src->buffer(), src->size());
135
136             auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
137
138             if (srcPtr == nullptr)
139                 FAIL() << "Cannot cast blob to TBlob<float>.";
140
141             InferenceEngine::BlobMap srcs;
142             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
143
144             InferenceEngine::OutputsDataMap out;
145             out = net_reader.getNetwork().getOutputsInfo();
146             InferenceEngine::BlobMap outputBlobs;
147
148             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
149
150             InferenceEngine::TBlob<float>::Ptr output;
151             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
152             output->allocate();
153             outputBlobs[item.first] = output;
154
155             graph.Infer(srcs, outputBlobs);
156
157             InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, p.dims, InferenceEngine::TensorDesc::getLayoutByDims(p.dims));
158             InferenceEngine::TBlob<float> dst_ref(td);
159             dst_ref.allocate();
160
161             ref_permute(*srcPtr, dst_ref, p);
162
163             compare(*output, dst_ref);
164         } catch (const InferenceEngine::details::InferenceEngineException &e) {
165             FAIL() << e.what();
166         }
167     }
168 };
169
170 TEST_P(MKLDNNGraphPermuteTests, TestsPermute) {}
171
172 INSTANTIATE_TEST_CASE_P(
173         TestsPermute, MKLDNNGraphPermuteTests,
174         ::testing::Values(
175                 permute_test_params{{2, 3, 4, 5}, {0, 1, 2, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
176                 permute_test_params{{2, 3, 4, 5}, {0, 2, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
177                 permute_test_params{{2, 3, 4, 5}, {3, 0, 1, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
178                 permute_test_params{{2, 3, 4, 5}, {1, 3, 2, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
179                 permute_test_params{{2, 3, 4, 5}, {3, 2, 1, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
180                 permute_test_params{{2, 3, 4, 5}, {0, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
181                 permute_test_params{{2, 3, 4}, {0, 1, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
182                 permute_test_params{{2, 3, 4}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
183                 permute_test_params{{2, 3, 4}, {2, 1, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
184                 permute_test_params{{2, 3, 4}, {1, 2, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
185                 permute_test_params{{2, 3, 4}, {2, 0, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
186                 permute_test_params{{2, 3, 4}, {1, 0, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
187                 permute_test_params{{2, 3}, {1, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
188                 permute_test_params{{2, 3}, {0, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
189                 permute_test_params{{2, 3, 4, 5, 6}, {0, 1, 2, 3, 4}, 1, MKLDNNPlugin::impl_desc_type::unknown},
190                 permute_test_params{{2, 3, 4, 5, 6}, {0, 4, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
191                 permute_test_params{{2, 3, 4, 5, 6}, {0, 2, 4, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
192                 permute_test_params{{2, 3, 4, 5, 6}, {0, 3, 2, 4, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
193                 permute_test_params{{2, 8, 2, 2, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
194                 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
195                 permute_test_params{{2, 8, 3, 4}, {3, 0, 1, 2}, 2, MKLDNNPlugin::impl_desc_type::unknown},
196                 permute_test_params{{2, 12, 9}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
197                 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 3, 4, 1, 5, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown}
198         ));
199
200 class MKLDNNGraphDynBatchPermuteTests: public MKLDNNGraphPermuteTests {
201 protected:
202     virtual void SetUp() {
203         try {
204             TestsCommon::SetUp();
205             permute_test_params p = ::testing::WithParamInterface<permute_test_params>::GetParam();
206             std::string model = getModel(p);
207             size_t MB = p.dims[0];
208             if (MB < 2)
209                 MB = 2;
210             p.dims[0] = MB;
211
212             InferenceEngine::CNNNetReader net_reader;
213             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
214             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
215             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
216             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
217             InferenceEngine::ResponseDesc resp;
218             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
219             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
220
221             MKLDNNGraphTestClass graph;
222             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
223             graph.CreateGraph(net_reader.getNetwork());
224
225             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>({InferenceEngine::Precision::FP32, p.dims, InferenceEngine::TensorDesc::getLayoutByDims(p.dims)});
226             src->allocate();
227             fill_data(src->buffer(), src->size());
228
229             auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
230
231             if (srcPtr == nullptr)
232                 FAIL() << "Cannot cast blob to TBlob<float>.";
233
234             InferenceEngine::BlobMap srcs;
235             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
236
237             InferenceEngine::OutputsDataMap out;
238             out = net_reader.getNetwork().getOutputsInfo();
239             InferenceEngine::BlobMap outputBlobs;
240
241             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
242
243             InferenceEngine::TBlob<float>::Ptr output;
244             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
245             output->allocate();
246             outputBlobs[item.first] = output;
247
248             auto checkPermute = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
249                 return node->getType() == MKLDNNPlugin::Permute;
250             };
251             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkPermute);
252             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkPermute);
253         } catch (const InferenceEngine::details::InferenceEngineException &e) {
254             FAIL() << e.what();
255         }
256     }
257 };
258
259 TEST_P(MKLDNNGraphDynBatchPermuteTests, TestsDynBatchPermute) {}
260
261 INSTANTIATE_TEST_CASE_P(
262         TestsDynBatchPermute, MKLDNNGraphDynBatchPermuteTests,
263         ::testing::Values(
264                 permute_test_params{{2, 3, 4, 5}, {0, 1, 2, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
265                 permute_test_params{{2, 3, 4, 5}, {0, 2, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
266                 permute_test_params{{2, 3, 4, 5}, {0, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
267                 permute_test_params{{2, 3, 4}, {0, 1, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
268                 permute_test_params{{2, 3, 4}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
269                 permute_test_params{{2, 3}, {0, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
270                 permute_test_params{{2, 3, 4, 5, 6}, {0, 1, 2, 3, 4}, 1, MKLDNNPlugin::impl_desc_type::unknown},
271                 permute_test_params{{2, 3, 4, 5, 6}, {0, 4, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
272                 permute_test_params{{2, 3, 4, 5, 6}, {0, 2, 4, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
273                 permute_test_params{{2, 3, 4, 5, 6}, {0, 3, 2, 4, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
274                 permute_test_params{{2, 8, 2, 2, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
275                 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
276                 permute_test_params{{2, 12, 9}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
277                 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 3, 4, 1, 5, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown}
278         ));