Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / extensions / resample_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
14
15 using namespace ::testing;
16 using namespace std;
17 using namespace mkldnn;
18
19 struct resample_test_params {
20     struct {
21         size_t n;
22         size_t c;
23         size_t h;
24         size_t w;
25     } in;
26
27     float factor;
28     int antialias;
29     std::string type;
30
31     size_t num_prim_desc;
32     bool isBlockedFormat;
33     int selectedType;
34
35     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
36 };
37
38
// Tent (triangle) filter: 1 - |x| for |x| < 1, and 0 everywhere else.
static inline float triangleCoeff(float x) {
    const float distance = std::abs(x);
    return distance < 1.0f ? 1.0f - distance : 0.0f;
}
42
// Factory for the fake test extensions; defined in another translation unit.
// Presumably it is what provides the "FakeLayerPLN" / "FakeLayerBLK" layer types used as the
// producer layer in the test model — confirm against its definition.
extern InferenceEngine::IExtensionPtr make_FakeExtensions();
44
45 template <typename data_t>
46 void ref_resample(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, resample_test_params prm) {
47     const data_t *src_data = src.readOnly();
48     data_t *dst_data = dst.data();
49
50     size_t N = prm.in.n;
51     size_t C = prm.in.c;
52     size_t IH = prm.in.h;
53     size_t IW = prm.in.w;
54     size_t OH = prm.in.h / prm.factor;
55     size_t OW = prm.in.w / prm.factor;
56
57     float fx = static_cast<float>(IW) / static_cast<float>(OW);
58     float fy = static_cast<float>(IH) / static_cast<float>(OH);
59
60     if (prm.type == "caffe.ResampleParameter.NEAREST") {
61         for (size_t b = 0; b < N; b++) {
62             for (size_t c = 0; c < C; c++) {
63                 const float *in_ptr = src_data + IW * IH * C * b + IW * IH * c;
64                 float *out_ptr = dst_data + OW * OH * C * b + OW * OH * c;
65
66                 for (size_t oy = 0; oy < OH; oy++) {
67                     for (size_t ox = 0; ox < OW; ox++) {
68                         float ix = ox * fx + fy / 2.0f - 0.5f;
69                         float iy = oy * fy + fx / 2.0f - 0.5f;
70
71                         size_t ix_r = static_cast<size_t>(round(ix));
72                         size_t iy_r = static_cast<size_t>(round(iy));
73
74                         out_ptr[oy * OW + ox] = in_ptr[iy_r * IW + ix_r];
75                     }
76                 }
77             }
78         }
79     } else if (prm.type == "caffe.ResampleParameter.LINEAR") {
80         size_t kernel_width = 2;
81         bool isDownsample = (fx > 1) || (fy > 1);
82         bool antialias = isDownsample && prm.antialias;
83
84         for (size_t b = 0; b < N; b++) {
85             for (size_t c = 0; c < C; c++) {
86                 const float *in_ptr = src_data + IW * IH * C * b + IW * IH * c;
87                 float *out_ptr = dst_data + OW * OH * C * b + OW * OH * c;
88
89                 for (size_t oy = 0; oy < OH; oy++) {
90                     for (size_t ox = 0; ox < OW; ox++) {
91                         float ix = ox * fx + fy / 2.0f - 0.5f;
92                         float iy = oy * fy + fx / 2.0f - 0.5f;
93
94                         int ix_r = static_cast<int>(round(ix));
95                         int iy_r = static_cast<int>(round(iy));
96
97                         float sum = 0;
98                         float wsum = 0;
99
100                         float ax = 1.0f / (antialias ? fx : 1.0f);
101                         float ay = 1.0f / (antialias ? fy : 1.0f);
102
103                         int rx = (fx < 1.0f) ? 2 : ceil(static_cast<float>(kernel_width) / ax);
104                         int ry = (fy < 1.0f) ? 2 : ceil(static_cast<float>(kernel_width) / ay);
105
106                         for (int y = iy_r - ry; y <= iy_r + ry; y++) {
107                             for (int x = ix_r - rx; x <= ix_r + rx; x++) {
108                                 if (y < 0 || x < 0 || y >= static_cast<int>(IH) || x >= static_cast<int>(IW))
109                                     continue;
110
111                                 float dx = ix - x;
112                                 float dy = iy - y;
113
114                                 float w = ax * triangleCoeff(ax * dx) * ay * triangleCoeff(ay * dy);
115
116                                 sum += w * in_ptr[y * IW + x];
117                                 wsum += w;
118                             }
119                         }
120
121                         out_ptr[oy * OW + ox] = (!wsum) ? 0 : (sum / wsum);
122                     }
123                 }
124             }
125         }
126     } else {
127         assert(!"Unsupported resample operation type");
128     }
129 }
130
131 class MKLDNNCPUExtResampleTests: public TestsCommon, public WithParamInterface<resample_test_params> {
132     std::string model_t = R"V0G0N(
133 <Net Name="MVN_net" version="2" precision="FP32" batch="1">
134     <layers>
135         <layer name="in1" type="Input" precision="FP32" id="0">
136             <output>
137                 <port id="0">
138                     <dim>_IN_</dim>
139                     <dim>_IC_</dim>
140                     <dim>_IH_</dim>
141                     <dim>_IW_</dim>
142                 </port>
143             </output>
144         </layer>
145         <layer name="fakeLayer" id="1" type="_FL_" precision="FP32">
146             <input>
147                 <port id="1">
148                     <dim>_IN_</dim>
149                     <dim>_IC_</dim>
150                     <dim>_IH_</dim>
151                     <dim>_IW_</dim>
152                 </port>
153             </input>
154             <output>
155                 <port id="2">
156                     <dim>_IN_</dim>
157                     <dim>_IC_</dim>
158                     <dim>_IH_</dim>
159                     <dim>_IW_</dim>
160                 </port>
161             </output>
162         </layer>
163         <layer name="resample" id="2" type="Resample" precision="FP32">
164             <data antialias="_AN_" factor="_F_" type="_T_"/>
165             <input>
166                 <port id="3">
167                     <dim>_IN_</dim>
168                     <dim>_IC_</dim>
169                     <dim>_IH_</dim>
170                     <dim>_IW_</dim>
171                 </port>
172             </input>
173             <output>
174                 <port id="4">
175                     <dim>_IN_</dim>
176                     <dim>_IC_</dim>
177                     <dim>_OH_</dim>
178                     <dim>_OW_</dim>
179                 </port>
180             </output>
181         </layer>
182     </layers>
183     <edges>
184         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
185         <edge from-layer="1" from-port="2" to-layer="2" to-port="3"/>
186     </edges>
187 </Net>
188 )V0G0N";
189
190     std::string getModel(resample_test_params p) {
191         std::string model = model_t;
192         if (p.isBlockedFormat)
193             REPLACE_WITH_STR(model, "_FL_", "FakeLayerBLK");
194         else
195             REPLACE_WITH_STR(model, "_FL_", "FakeLayerPLN");
196
197         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
198         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
199         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
200         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
201
202         REPLACE_WITH_NUM(model, "_OW_", (int)(p.in.w / p.factor));
203         REPLACE_WITH_NUM(model, "_OH_", (int)(p.in.h / p.factor));
204
205         REPLACE_WITH_NUM(model, "_AN_", p.antialias);
206         REPLACE_WITH_NUM(model, "_F_", p.factor);
207         REPLACE_WITH_STR(model, "_T_", p.type);
208
209         return model;
210     }
211
212 protected:
213     virtual void TearDown() {
214     }
215
216     virtual void SetUp() {
217         try {
218             TestsCommon::SetUp();
219             resample_test_params p = ::testing::WithParamInterface<resample_test_params>::GetParam();
220             std::string model = getModel(p);
221
222             InferenceEngine::CNNNetReader net_reader;
223             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
224
225             InferenceEngine::Extension cpuExt(make_so_name("cpu_extension"));
226             MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr(new MKLDNNPlugin::MKLDNNExtensionManager());
227             extMgr->AddExtension(InferenceEngine::IExtensionPtr(&cpuExt, [](InferenceEngine::IExtension*){}));
228             extMgr->AddExtension(make_FakeExtensions());
229
230             MKLDNNGraphTestClass graph;
231             graph.CreateGraph(net_reader.getNetwork(), extMgr);
232
233             auto& nodes = graph.getNodes();
234             nodes = graph.getNodes();
235
236             for (auto &node : nodes) {
237                 if (node->getName() == "resample") {
238                     ASSERT_EQ(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
239                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
240                         p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
241                     }
242                     ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
243                     ASSERT_EQ(p.selectedType,
244                               node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
245                 }
246             }
247
248             if (p.isBlockedFormat)
249                 ASSERT_EQ(6, nodes.size());
250             else
251                 ASSERT_EQ(4, nodes.size());
252
253             InferenceEngine::SizeVector dims_src = {p.in.w, p.in.h, p.in.c, p.in.n};
254
255             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NHWC, dims_src);
256             src->allocate();
257             fill_data(src->buffer(), src->size());
258
259             auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
260
261             if (srcPtr == nullptr)
262                 FAIL() << "Cannot cast blob to TBlob<float>.";
263
264             InferenceEngine::BlobMap srcs;
265             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
266
267             InferenceEngine::OutputsDataMap out;
268             out = net_reader.getNetwork().getOutputsInfo();
269             InferenceEngine::BlobMap outputBlobs;
270
271             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
272
273             InferenceEngine::TBlob<float>::Ptr output;
274             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
275             output->allocate();
276             outputBlobs[item.first] = output;
277
278             graph.Infer(srcs, outputBlobs);
279
280             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
281             dst_ref.allocate();
282             ref_resample(*srcPtr, dst_ref, p);
283             compare(*output, dst_ref);
284         } catch (const InferenceEngine::details::InferenceEngineException &e) {
285             FAIL() << e.what();
286         }
287     }
288 };
289
// Graph construction, descriptor checks, inference and the comparison with ref_resample() all
// happen in the fixture's SetUp(); the test body is intentionally empty.
TEST_P(MKLDNNCPUExtResampleTests, TestsResample) {}
291
292 INSTANTIATE_TEST_CASE_P(
293         TestsResample, MKLDNNCPUExtResampleTests,
294         ::testing::Values(
295                 resample_test_params{{2, 64, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 2, false, MKLDNNPlugin::impl_desc_type::unknown },
296                 resample_test_params{{2, 64, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 2, true, MKLDNNPlugin::impl_desc_type::unknown },
297                 resample_test_params{{2, 64, 15, 25}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
298                 resample_test_params{{2, 64, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 2, false, MKLDNNPlugin::impl_desc_type::unknown },
299                 resample_test_params{{2, 64, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 2, true, MKLDNNPlugin::impl_desc_type::unknown },
300                 resample_test_params{{2, 64, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
301                 resample_test_params{{2, 64, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 2, false, MKLDNNPlugin::impl_desc_type::unknown },
302                 resample_test_params{{2, 64, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 2, true, MKLDNNPlugin::impl_desc_type::unknown },
303                 resample_test_params{{2, 64, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
304                 resample_test_params{{2, 3, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 2, false, MKLDNNPlugin::impl_desc_type::unknown },
305                 resample_test_params{{2, 3, 15, 25}, 1.f, 0, "caffe.ResampleParameter.NEAREST", 2, true, MKLDNNPlugin::impl_desc_type::unknown },
306                 resample_test_params{{2, 3, 15, 25}, 1.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
307                 resample_test_params{{2, 3, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 2, false, MKLDNNPlugin::impl_desc_type::unknown },
308                 resample_test_params{{2, 3, 10, 20}, 0.25f, 0, "caffe.ResampleParameter.NEAREST", 2, true, MKLDNNPlugin::impl_desc_type::unknown },
309                 resample_test_params{{2, 3, 10, 20}, 0.25f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown },
310                 resample_test_params{{2, 3, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 2, false, MKLDNNPlugin::impl_desc_type::unknown },
311                 resample_test_params{{2, 3, 10, 20}, 4.f, 0, "caffe.ResampleParameter.NEAREST", 2, true, MKLDNNPlugin::impl_desc_type::unknown },
312                 resample_test_params{{2, 3, 10, 20}, 4.f, 1, "caffe.ResampleParameter.LINEAR", 1, false, MKLDNNPlugin::impl_desc_type::unknown }));