Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_reshape_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10 #include "single_layer_common.hpp"
11
12 #include "test_graph.hpp"
13
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include "tests_common.hpp"
16
17
18 using namespace ::testing;
19 using namespace std;
20 using namespace mkldnn;
21
22
/**
 * Parameters of one Reshape graph test case, consumed by MKLDNNGraphReshapeTests.
 * Instances are aggregate-initialized positionally in INSTANTIATE_TEST_CASE_P,
 * so the member order is part of this struct's interface.
 */
struct reshape_test_params {
    // Dims of the network input; used for the Input layer port, the Reshape
    // input port and the input blob fed to inference.
    InferenceEngine::SizeVector in;
    // Dims written into the Reshape layer's output port.
    InferenceEngine::SizeVector out;
    // Values joined with commas into the Reshape layer's "dim" attribute.
    std::vector<size_t> shape;

    // Value written into the Reshape layer's "axis" attribute.
    int axis;
    // Value written into the Reshape layer's "num_axes" attribute.
    int num_axes;

    // Expected number of supported primitive descriptors on the Reshape node.
    size_t num_prim_desc;

    // Expected implementation type of the Reshape node's selected primitive descriptor.
    MKLDNNPlugin::impl_desc_type selectedType;

    // comp[j] is invoked on supported primitive descriptor j,
    // for every j < min(num_prim_desc, comp.size()).
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
};
37
38 template <typename data_t>
39 void ref_reshape(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst) {
40     const data_t *src_data = src.readOnly();
41     data_t *dst_data = dst.data();
42
43 #pragma omp parallel for
44     for (int i=0; i < src.size(); i++)
45         dst_data[i] = src_data[i];
46 }
47
48 class MKLDNNGraphReshapeTests: public TestsCommon,
49                                      public WithParamInterface<reshape_test_params> {
50     std::string model_t = R"V0G0N(
51 <Net Name="Reshape_Only" version="2" precision="FP32" batch="1">
52     <layers>
53         <layer name="in1" type="Input" precision="FP32" id="0">
54             <output>
55                 <port id="0">
56 __SRC_DIMS__
57                 </port>
58             </output>
59         </layer>
60         <layer name="norm" id="1" type="Reshape" precision="FP32">
61             <data dim="_SHAPE_" axis="_AX_" num_axes="_NAX_"/>
62
63             <input>
64                 <port id="1">
65 __SRC_DIMS__
66                 </port>
67             </input>
68             <output>
69                 <port id="2">
70 __DST_DIMS__
71                 </port>
72             </output>
73         </layer>
74     </layers>
75     <edges>
76         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
77     </edges>
78 </Net>
79 )V0G0N";
80
81     std::string getModel(reshape_test_params p) {
82         std::string model = model_t;
83
84                 std::string src_dims;
85                 for (auto& dim : p.in) {
86                         src_dims += "<dim>";
87                         src_dims += std::to_string(dim) + "</dim>\n";
88                 }
89                 REPLACE_WITH_STR(model, "__SRC_DIMS__", src_dims);
90
91                 std::string dst_dims;
92                 for (auto& dim : p.out) {
93                         dst_dims += "<dim>";
94                         dst_dims += std::to_string(dim) + "</dim>\n";
95                 }
96                 REPLACE_WITH_STR(model, "__DST_DIMS__", dst_dims);
97
98         REPLACE_WITH_NUM(model, "_AX_", p.axis);
99         REPLACE_WITH_NUM(model, "_NAX_", p.num_axes);
100
101         std::string shape_str;
102         for (auto& dim : p.shape) {
103             if (!shape_str.empty())
104                 shape_str += ",";
105             shape_str += std::to_string(dim);
106         }
107         REPLACE_WITH_STR(model, "_SHAPE_", shape_str);
108         return model;
109     }
110
111 protected:
112     virtual void TearDown() {
113     }
114
115     virtual void SetUp() {
116         try {
117             TestsCommon::SetUp();
118             reshape_test_params p = ::testing::WithParamInterface<reshape_test_params>::GetParam();
119             std::string model = getModel(p);
120
121             InferenceEngine::CNNNetReader net_reader;
122             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
123
124             MKLDNNGraphTestClass graph;
125             graph.CreateGraph(net_reader.getNetwork());
126             auto& nodes = graph.getNodes();
127             for (int i = 0; i < nodes.size(); i++) {
128                 if (nodes[i]->getType() == MKLDNNPlugin::Reshape) {
129                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
130                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
131                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
132                     }
133                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
134                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
135                 }
136             }
137
138             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::ANY, p.in);
139             src->allocate();
140             fill_data(src->buffer(), src->size());
141
142             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
143
144             if (srcPtr == nullptr)
145                 FAIL() << "Cannot cast blob to TBlob<float>.";
146
147             InferenceEngine::BlobMap srcs;
148             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
149
150             InferenceEngine::OutputsDataMap out;
151             out = net_reader.getNetwork().getOutputsInfo();
152             InferenceEngine::BlobMap outputBlobs;
153
154             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
155
156             InferenceEngine::TBlob<float>::Ptr output;
157             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
158             output->allocate();
159             outputBlobs[item.first] = output;
160
161             graph.Infer(srcs, outputBlobs);
162
163             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
164             dst_ref.allocate();
165
166             ref_reshape(*srcPtr, dst_ref);
167
168             compare(*output, dst_ref);
169         } catch (const InferenceEngine::details::InferenceEngineException &e) {
170             FAIL() << e.what();
171         }
172     }
173 };
174
// The test body is intentionally empty: graph construction, primitive
// descriptor checks, inference and the comparison against ref_reshape all
// run in MKLDNNGraphReshapeTests::SetUp() for each parameter set.
TEST_P(MKLDNNGraphReshapeTests, TestsReshape) {}
176
177
178 INSTANTIATE_TEST_CASE_P(
179         TestsReshape, MKLDNNGraphReshapeTests,
180         ::testing::Values(
181         reshape_test_params{ {1, 3, 228, 228}, {1, 24, 2, 3249}, {1, 24, 2, 3249}, 0, -1, 2,
182             MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
183                 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
184                 ASSERT_EQ(1, impl.getConfig().inConfs.size());
185                 ASSERT_EQ(1, impl.getConfig().outConfs.size());
186                 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
187                 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
188         } } },
189         reshape_test_params{ { 4 },{ 2, 2 },{ 2, 2 }, 0, -1, 2,
190             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
191             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
192             ASSERT_EQ(1, impl.getConfig().inConfs.size());
193             ASSERT_EQ(1, impl.getConfig().outConfs.size());
194             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
195             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
196         } } },
197         reshape_test_params{ { 4 },{ 1, 2, 2 },{ 1, 2, 2 }, 0, -1, 2,
198             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
199             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
200             ASSERT_EQ(1, impl.getConfig().inConfs.size());
201             ASSERT_EQ(1, impl.getConfig().outConfs.size());
202             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
203             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
204         } } },
205         reshape_test_params{ { 4 },{ 1, 4, 1, 1 },{ 1, 4, 1, 1 }, 0, -1, 2,
206             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
207             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
208             ASSERT_EQ(1, impl.getConfig().inConfs.size());
209             ASSERT_EQ(1, impl.getConfig().outConfs.size());
210             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
211             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
212         } } },
213         reshape_test_params{ { 4, 4 },{ 1, 4, 4 },{ 1, 4, 4 }, 0, -1, 2,
214             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
215             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
216             ASSERT_EQ(1, impl.getConfig().inConfs.size());
217             ASSERT_EQ(1, impl.getConfig().outConfs.size());
218             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
219             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
220         } } },
221         reshape_test_params{ { 4, 4 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 2,
222             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
223             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
224             ASSERT_EQ(1, impl.getConfig().inConfs.size());
225             ASSERT_EQ(1, impl.getConfig().outConfs.size());
226             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
227             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
228         } } },
229         reshape_test_params{ { 4, 2, 2 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 2,
230             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
231             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
232             ASSERT_EQ(1, impl.getConfig().inConfs.size());
233             ASSERT_EQ(1, impl.getConfig().outConfs.size());
234             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
235             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
236         } } },
237         reshape_test_params{ { 2, 2 },{ 4 },{ 4 }, 0, -1, 2,
238             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
239             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
240             ASSERT_EQ(1, impl.getConfig().inConfs.size());
241             ASSERT_EQ(1, impl.getConfig().outConfs.size());
242             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
243             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
244         } } },
245             reshape_test_params{ { 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 2,
246             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
247             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
248             ASSERT_EQ(1, impl.getConfig().inConfs.size());
249             ASSERT_EQ(1, impl.getConfig().outConfs.size());
250             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
251             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
252         } } },
253         reshape_test_params{ { 1, 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 2,
254             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
255             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
256             ASSERT_EQ(1, impl.getConfig().inConfs.size());
257             ASSERT_EQ(1, impl.getConfig().outConfs.size());
258             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
259             ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
260         } } },
261         reshape_test_params{ { 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 2,
262             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
263             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
264             ASSERT_EQ(1, impl.getConfig().inConfs.size());
265             ASSERT_EQ(1, impl.getConfig().outConfs.size());
266             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
267             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
268         } } },
269         reshape_test_params{ { 1, 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 2,
270             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
271             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
272             ASSERT_EQ(1, impl.getConfig().inConfs.size());
273             ASSERT_EQ(1, impl.getConfig().outConfs.size());
274             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
275             ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
276         } } },
277         reshape_test_params{ { 1, 4, 2, 2 },{ 4, 2, 2 },{ 4, 2, 2 }, 0, -1, 2,
278             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
279             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
280             ASSERT_EQ(1, impl.getConfig().inConfs.size());
281             ASSERT_EQ(1, impl.getConfig().outConfs.size());
282             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
283             ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
284         } } },
285         reshape_test_params{ { 1, 4, 2, 2 },{ 4, 2, 2, 1, 1 },{ 4, 2, 2, 1, 1 }, 0, -1, 2,
286             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
287             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
288             ASSERT_EQ(1, impl.getConfig().inConfs.size());
289             ASSERT_EQ(1, impl.getConfig().outConfs.size());
290             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
291             ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
292         } } },
293         reshape_test_params{ { 4, 2, 2, 1, 1 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 2,
294             MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
295             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
296             ASSERT_EQ(1, impl.getConfig().inConfs.size());
297             ASSERT_EQ(1, impl.getConfig().outConfs.size());
298             ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
299             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
300         } } }
301 ));