Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_reshape_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8 #include "single_layer_common.hpp"
9
10 #include "test_graph.hpp"
11
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
14
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace mkldnn;
19
20
21 struct reshape_test_params {
22     InferenceEngine::SizeVector in;
23     InferenceEngine::SizeVector out;
24     std::vector<size_t> shape;
25
26     int axis;
27     int num_axes;
28
29     size_t num_prim_desc;
30
31     MKLDNNPlugin::impl_desc_type selectedType;
32
33     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
34 };
35
36 template <typename data_t>
37 void ref_reshape(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst) {
38     const data_t *src_data = src.readOnly();
39     data_t *dst_data = dst.data();
40
41     for (int i=0; i < src.size(); i++)
42         dst_data[i] = src_data[i];
43 }
44
45 class MKLDNNGraphReshapeTests: public TestsCommon,
46                                      public WithParamInterface<reshape_test_params> {
47     std::string model_t = R"V0G0N(
48 <Net Name="Reshape_Only" version="2" precision="FP32" batch="1">
49     <layers>
50         <layer name="in1" type="Input" precision="FP32" id="0">
51             <output>
52                 <port id="0">
53 __SRC_DIMS__
54                 </port>
55             </output>
56         </layer>
57         <layer name="norm" id="1" type="Reshape" precision="FP32">
58             <data dim="_SHAPE_" axis="_AX_" num_axes="_NAX_"/>
59
60             <input>
61                 <port id="1">
62 __SRC_DIMS__
63                 </port>
64             </input>
65             <output>
66                 <port id="2">
67 __DST_DIMS__
68                 </port>
69             </output>
70         </layer>
71     </layers>
72     <edges>
73         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
74     </edges>
75 </Net>
76 )V0G0N";
77
78     std::string getModel(reshape_test_params p) {
79         std::string model = model_t;
80
81                 std::string src_dims;
82                 for (auto& dim : p.in) {
83                         src_dims += "                    <dim>";
84                         src_dims += std::to_string(dim) + "</dim>\n";
85                 }
86                 REPLACE_WITH_STR(model, "__SRC_DIMS__", src_dims);
87
88                 std::string dst_dims;
89                 for (auto& dim : p.out) {
90                         dst_dims += "\t\t<dim>";
91                         dst_dims += std::to_string(dim) + "</dim>\n";
92                 }
93                 REPLACE_WITH_STR(model, "__DST_DIMS__", dst_dims);
94
95         REPLACE_WITH_NUM(model, "_AX_", p.axis);
96         REPLACE_WITH_NUM(model, "_NAX_", p.num_axes);
97
98         std::string shape_str;
99         for (auto& dim : p.shape) {
100             if (!shape_str.empty())
101                 shape_str += ",";
102             shape_str += std::to_string(dim);
103         }
104         REPLACE_WITH_STR(model, "_SHAPE_", shape_str);
105         return model;
106     }
107
108 protected:
109     virtual void TearDown() {
110     }
111
112     virtual void SetUp() {
113         try {
114             TestsCommon::SetUp();
115             reshape_test_params p = ::testing::WithParamInterface<reshape_test_params>::GetParam();
116             std::string model = getModel(p);
117
118             InferenceEngine::CNNNetReader net_reader;
119             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
120
121             MKLDNNGraphTestClass graph;
122             graph.CreateGraph(net_reader.getNetwork());
123             auto& nodes = graph.getNodes();
124             for (int i = 0; i < nodes.size(); i++) {
125                 if (nodes[i]->getType() == MKLDNNPlugin::Reshape) {
126                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
127                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
128                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
129                     }
130                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
131                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
132                 }
133             }
134
135             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::ANY, p.in);
136             src->allocate();
137             fill_data(src->buffer(), src->size());
138
139             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
140
141             if (srcPtr == nullptr)
142                 FAIL() << "Cannot cast blob to TBlob<float>.";
143
144             InferenceEngine::BlobMap srcs;
145             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
146
147             InferenceEngine::OutputsDataMap out;
148             out = net_reader.getNetwork().getOutputsInfo();
149             InferenceEngine::BlobMap outputBlobs;
150
151             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
152
153             InferenceEngine::TBlob<float>::Ptr output;
154             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
155             output->allocate();
156             outputBlobs[item.first] = output;
157
158             graph.Infer(srcs, outputBlobs);
159
160             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
161             dst_ref.allocate();
162
163             ref_reshape(*srcPtr, dst_ref);
164
165             compare(*output, dst_ref);
166         } catch (const InferenceEngine::details::InferenceEngineException &e) {
167             FAIL() << e.what();
168         }
169     }
170 };
171
172 TEST_P(MKLDNNGraphReshapeTests, TestsReshape) {}
173
174
175 INSTANTIATE_TEST_CASE_P(
176         TestsReshape, MKLDNNGraphReshapeTests,
177         ::testing::Values(
178         reshape_test_params{ {1, 3, 228, 228}, {1, 24, 2, 3249}, {1, 24, 2, 3249}, 0, -1, 1,
179             MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
180                 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
181                 ASSERT_EQ(1, impl.getConfig().inConfs.size());
182                 ASSERT_EQ(1, impl.getConfig().outConfs.size());
183                 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
184                 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
185         } } },
186         reshape_test_params{ { 4 },{ 2, 2 },{ 2, 2 }, 0, -1, 1,
187             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
188             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
189             ASSERT_EQ(1, impl.getConfig().inConfs.size());
190             ASSERT_EQ(1, impl.getConfig().outConfs.size());
191             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
192             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
193         } } },
194         reshape_test_params{ { 4 },{ 1, 2, 2 },{ 1, 2, 2 }, 0, -1, 1,
195             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
196             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
197             ASSERT_EQ(1, impl.getConfig().inConfs.size());
198             ASSERT_EQ(1, impl.getConfig().outConfs.size());
199             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
200             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
201         } } },
202         reshape_test_params{ { 4 },{ 1, 4, 1, 1 },{ 1, 4, 1, 1 }, 0, -1, 1,
203             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
204             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
205             ASSERT_EQ(1, impl.getConfig().inConfs.size());
206             ASSERT_EQ(1, impl.getConfig().outConfs.size());
207             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
208             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
209         } } },
210         reshape_test_params{ { 4, 4 },{ 1, 4, 4 },{ 1, 4, 4 }, 0, -1, 1,
211             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
212             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
213             ASSERT_EQ(1, impl.getConfig().inConfs.size());
214             ASSERT_EQ(1, impl.getConfig().outConfs.size());
215             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
216             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
217         } } },
218         reshape_test_params{ { 4, 4 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 1,
219             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
220             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
221             ASSERT_EQ(1, impl.getConfig().inConfs.size());
222             ASSERT_EQ(1, impl.getConfig().outConfs.size());
223             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
224             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
225         } } },
226         reshape_test_params{ { 4, 2, 2 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 1,
227             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
228             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
229             ASSERT_EQ(1, impl.getConfig().inConfs.size());
230             ASSERT_EQ(1, impl.getConfig().outConfs.size());
231             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
232             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
233         } } },
234         reshape_test_params{ { 2, 2 },{ 4 },{ 4 }, 0, -1, 1,
235             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
236             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
237             ASSERT_EQ(1, impl.getConfig().inConfs.size());
238             ASSERT_EQ(1, impl.getConfig().outConfs.size());
239             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
240             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
241         } } },
242         reshape_test_params{ { 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 1,
243             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
244             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
245             ASSERT_EQ(1, impl.getConfig().inConfs.size());
246             ASSERT_EQ(1, impl.getConfig().outConfs.size());
247             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
248             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
249         } } },
250         reshape_test_params{ { 1, 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 1,
251             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
252             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
253             ASSERT_EQ(1, impl.getConfig().inConfs.size());
254             ASSERT_EQ(1, impl.getConfig().outConfs.size());
255             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
256             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
257         } } },
258         reshape_test_params{ { 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 1,
259             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
260             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
261             ASSERT_EQ(1, impl.getConfig().inConfs.size());
262             ASSERT_EQ(1, impl.getConfig().outConfs.size());
263             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
264             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
265         } } },
266         reshape_test_params{ { 1, 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 1,
267             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
268             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
269             ASSERT_EQ(1, impl.getConfig().inConfs.size());
270             ASSERT_EQ(1, impl.getConfig().outConfs.size());
271             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
272             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
273         } } },
274         reshape_test_params{ { 1, 4, 2, 2 },{ 4, 2, 2 },{ 4, 2, 2 }, 0, -1, 1,
275             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
276             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
277             ASSERT_EQ(1, impl.getConfig().inConfs.size());
278             ASSERT_EQ(1, impl.getConfig().outConfs.size());
279             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
280             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
281         } } },
282         reshape_test_params{ { 1, 4, 2, 2 }, { 4, 2, 2, 1, 1 }, { 4, 2, 2, 1, 1 }, 0, -1, 1,
283             MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
284             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
285             ASSERT_EQ(1, impl.getConfig().inConfs.size());
286             ASSERT_EQ(1, impl.getConfig().outConfs.size());
287             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
288             ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());
289         } } },
290         reshape_test_params{ { 4, 2, 2, 1, 1 }, { 1, 4, 2, 2 }, { 1, 4, 2, 2 }, 0, -1, 1,
291             MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
292             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
293             ASSERT_EQ(1, impl.getConfig().inConfs.size());
294             ASSERT_EQ(1, impl.getConfig().outConfs.size());
295             ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().inConfs.at(0).desc.getLayout());
296             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
297         } } },
298         reshape_test_params{ { 1, 200 }, { 1, 200, 1, 1, 1 }, { 1, 200, 1, 1, 1 }, 0, -1, 1,
299             MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
300             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
301             ASSERT_EQ(1, impl.getConfig().inConfs.size());
302             ASSERT_EQ(1, impl.getConfig().outConfs.size());
303             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
304             ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());
305         } } }
306 ));