Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_crop_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
17
18
19 using namespace ::testing;
20 using namespace std;
21 using namespace mkldnn;
22
23 struct crop_test_params {
24     struct {
25         size_t n;
26         size_t c;
27         size_t h;
28         size_t w;
29     } in;
30
31     std::vector<int> axis;
32     std::vector<int> offsets;
33     std::vector<int> dims;
34
35     size_t num_prim_desc;
36
37     MKLDNNPlugin::impl_desc_type selectedType;
38
39     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
40 };
41
42
43
44 template <typename data_t>
45 void ref_crop(InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, crop_test_params prm) {
46     data_t *dst_ptr = dst.data();
47
48     std::vector<int> offsets(4);
49     for (size_t i = 0; i < prm.offsets.size(); i++) {
50         offsets[prm.axis[i]] = prm.offsets[i];
51     }
52     int OFFSET_N = offsets.at(0);
53     int OFFSET_C = offsets.at(1);
54     int OFFSET_H = offsets.at(2);
55     int OFFSET_W = offsets.at(3);
56
57     const int ON = dst.dims().at(3);
58     const int OC = dst.dims().at(2);
59     const int OH = dst.dims().at(1);
60     const int OW = dst.dims().at(0);
61
62     const int _IN = src.dims().at(0);
63     const int IC = src.dims().at(1);
64     const int IH = src.dims().at(2);
65     const int IW = src.dims().at(3);
66
67     auto dst_off = [=](int n, int c, int h, int w) {
68         return (n * OW * OH * OC + c * OW * OH + h * OW + w);
69     };
70     auto src_off = [=](int n, int c, int h, int w) {
71         return (n * IW * IH * IC + c * IW * IH + h * IW + w);
72     };
73
74     ASSERT_GE(_IN - OFFSET_N, ON);
75     ASSERT_GE(IC - OFFSET_C, OC);
76     ASSERT_GE(IH - OFFSET_H, OH);
77     ASSERT_GE(IW - OFFSET_W, OW);
78
79     data_t* src_ptr = src.data();
80     for (int n = 0; n < ON; ++n) {
81         for (int c = 0; c < OC; ++c) {
82             for (int h = 0; h < OH; ++h) {
83                 for (int w = 0; w < OW; ++w) {
84                     dst_ptr[dst_off(n, c, h, w)] = src_ptr[src_off(n + OFFSET_N, c + OFFSET_C,
85                                                                    h + OFFSET_H, w + OFFSET_W)];
86                 }
87             }
88         }
89     }
90 }
91
92 class MKLDNNGraphCropTests: public TestsCommon,
93                                      public WithParamInterface<crop_test_params> {
94     std::string model_t = R"V0G0N(
95 <Net Name="Crop_Only" version="2" precision="FP32" batch="1">
96     <layers>
97         <layer name="in1" type="Input" precision="FP32" id="0">
98             <output>
99                 <port id="0">
100                     <dim>_IN_</dim>
101                     <dim>_IC_</dim>
102                     <dim>_IH_</dim>
103                     <dim>_IW_</dim>
104                 </port>
105             </output>
106         </layer>
107         <layer name="crop" id="1" type="Crop" precision="FP32">
108             <data axis="_AXC_" offset="_OFC_" dim="_DIMC_" />
109             <input>
110                 <port id="1">
111                     <dim>_IN_</dim>
112                     <dim>_IC_</dim>
113                     <dim>_IH_</dim>
114                     <dim>_IW_</dim>
115                 </port>
116             </input>
117             <output>
118                 <port id="2">
119                     <dim>_DN_</dim>
120                     <dim>_DC_</dim>
121                     <dim>_DH_</dim>
122                     <dim>_DW_</dim>
123                 </port>
124             </output>
125         </layer>
126     </layers>
127     <edges>
128         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
129     </edges>
130 </Net>
131 )V0G0N";
132
133 protected:
134     std::string getModel(crop_test_params p) {
135         std::string model = model_t;
136
137         std::string axis, offset, dim;
138         std::vector<size_t> outDims = {p.in.n, p.in.c, p.in.h, p.in.w};
139         for (size_t i = 0; i < p.offsets.size(); i++) {
140             if (!axis.empty())
141                 axis += ",";
142             axis += std::to_string(p.axis[i]);
143             if (!offset.empty())
144                 offset += ",";
145             offset += std::to_string(p.offsets[i]);
146             if (!dim.empty())
147                 dim += ",";
148             dim += std::to_string(p.dims[i]);
149             outDims[p.axis[i]] = p.dims[i];
150         }
151
152         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
153         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
154         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
155         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
156         REPLACE_WITH_NUM(model, "_DN_", outDims[0]);
157         REPLACE_WITH_NUM(model, "_DC_", outDims[1]);
158         REPLACE_WITH_NUM(model, "_DH_", outDims[2]);
159         REPLACE_WITH_NUM(model, "_DW_", outDims[3]);
160         REPLACE_WITH_STR(model, "_AXC_", axis);
161         REPLACE_WITH_STR(model, "_OFC_", offset);
162         REPLACE_WITH_STR(model, "_DIMC_", dim);
163         return model;
164     }
165
166     virtual void TearDown() {
167     }
168
169     virtual void SetUp() {
170         try {
171             TestsCommon::SetUp();
172             crop_test_params p = ::testing::WithParamInterface<crop_test_params>::GetParam();
173             std::string model = getModel(p);
174
175             InferenceEngine::CNNNetReader net_reader;
176             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
177
178             MKLDNNGraphTestClass graph;
179             graph.CreateGraph(net_reader.getNetwork());
180
181             auto& nodes = graph.getNodes();
182             for (int i = 0; i < nodes.size(); i++) {
183                 if (nodes[i]->getType() == MKLDNNPlugin::Crop) {
184                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
185                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
186                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
187                     }
188                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
189                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
190                 }
191             }
192
193             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
194
195             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW,  dims_src);
196             src->allocate();
197             fill_data(src->buffer(), src->size());
198
199             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
200
201             if (srcPtr == nullptr)
202                 FAIL() << "Cannot cast blob to TBlob<float>.";
203
204             InferenceEngine::BlobMap srcs;
205             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
206
207             InferenceEngine::OutputsDataMap out;
208             out = net_reader.getNetwork().getOutputsInfo();
209             InferenceEngine::BlobMap outputBlobs;
210
211             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
212
213             InferenceEngine::TBlob<float>::Ptr output;
214             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
215             output->allocate();
216             outputBlobs[item.first] = output;
217
218             graph.Infer(srcs, outputBlobs);
219
220             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
221             dst_ref.allocate();
222
223             ref_crop(*srcPtr, dst_ref, p);
224
225             compare(*output, dst_ref);
226         } catch (const InferenceEngine::details::InferenceEngineException &e) {
227             FAIL() << e.what();
228         }
229     }
230 };
231
232 TEST_P(MKLDNNGraphCropTests, TestCrop) {}
233
234
235 INSTANTIATE_TEST_CASE_P(
236         TestCrop, MKLDNNGraphCropTests,
237         ::testing::Values(
238                 crop_test_params{{1, 5, 32, 32}, {1, 2, 3}, {2, 5, 4}, {2, 23, 23}, 1, MKLDNNPlugin::impl_desc_type::unknown, {
239                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
240                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
241                             ASSERT_EQ(1, impl.getConfig().inConfs.size());
242                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
243                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
244                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
245                         }}},
246                 crop_test_params{{3, 8, 32, 32}, {0, 1, 2, 3}, {1, 0, 20, 20}, {2, 8, 5, 5}, 2, MKLDNNPlugin::impl_desc_type::unknown, {
247                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
248                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
249                             ASSERT_EQ(1, impl.getConfig().inConfs.size());
250                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
251                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
252                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
253                         }} },
254                 crop_test_params{{1, 5, 32, 32}, {3}, {10}, {20}, 1, MKLDNNPlugin::impl_desc_type::unknown },
255                 crop_test_params{{1, 5, 32, 20}, {2, 3}, {30, 10}, {2, 10}, 1, MKLDNNPlugin::impl_desc_type::unknown }));
256
257 class MKLDNNGraphDynBatchCropTests: public MKLDNNGraphCropTests {
258 protected:
259
260     virtual void SetUp() {
261         try {
262             TestsCommon::SetUp();
263             crop_test_params p = ::testing::WithParamInterface<crop_test_params>::GetParam();
264             std::string model = getModel(p);
265             size_t MB = p.in.n;
266             if (MB < 2)
267                 MB = 2;
268
269             InferenceEngine::CNNNetReader net_reader;
270             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
271             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
272             network.setBatchSize(MB);
273
274             MKLDNNGraphTestClass graph;
275             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
276             graph.CreateGraph(network);
277
278             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
279             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
280             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
281             if (srcPtr == nullptr)
282                 FAIL() << "Cannot cast blob to TBlob<float>.";
283
284             src->allocate();
285             fill_data(src->buffer(), src->size());
286
287             InferenceEngine::BlobMap srcs;
288             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
289
290             InferenceEngine::OutputsDataMap out;
291             out = net_reader.getNetwork().getOutputsInfo();
292             InferenceEngine::BlobMap outputBlobs;
293
294             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
295
296             InferenceEngine::TBlob<float>::Ptr output;
297             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
298             output->allocate();
299             outputBlobs[item.first] = output;
300
301             auto checkCrop = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
302                 return node->getType() == MKLDNNPlugin::Crop;
303             };
304
305             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkCrop);
306             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkCrop);
307         } catch (const InferenceEngine::details::InferenceEngineException &e) {
308             FAIL() << e.what();
309         }
310     }
311 };
312
313 TEST_P(MKLDNNGraphDynBatchCropTests, TestsDynBatchCrop) {}
314
315 INSTANTIATE_TEST_CASE_P(
316         TestsDynBatchCrop, MKLDNNGraphDynBatchCropTests,
317         ::testing::Values(
318                 crop_test_params{{1, 5, 32, 32}, {1, 2, 3}, {2, 5, 4}, {2, 23, 23}, 1, MKLDNNPlugin::impl_desc_type::unknown, {
319                         [](MKLDNNPlugin::PrimitiveDescInfo impl) {
320                             ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
321                             ASSERT_EQ(1, impl.getConfig().inConfs.size());
322                             ASSERT_EQ(1, impl.getConfig().outConfs.size());
323                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
324                             ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
325                         }}},
326                 crop_test_params{{1, 5, 32, 32}, {3}, {10}, {20}, 1, MKLDNNPlugin::impl_desc_type::unknown },
327                 crop_test_params{{1, 5, 32, 20}, {2, 3}, {30, 10}, {2, 10}, 1, MKLDNNPlugin::impl_desc_type::unknown }));