Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_softmax_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
15
16
17 using namespace ::testing;
18 using namespace std;
19 using namespace mkldnn;
20
21
22 struct softmax_test_params {
23     struct {
24         size_t n;
25         size_t c;
26         size_t h;
27         size_t w;
28     } in;
29
30     int axis;
31
32     size_t num_prim_desc;
33
34     int selectedType;
35     std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
36
37     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
38 };
39
40 template <typename data_t>
41 void check_softmax_fwd(const InferenceEngine::TBlob<data_t> &src, softmax_test_params prm)
42 {
43     const data_t *src_data = src.readOnly();
44
45     size_t W = prm.in.w;
46     size_t H = prm.in.h;
47     size_t C = prm.in.c;
48     size_t MB = prm.in.n;
49
50     auto off = [=](int n, int c, int h, int w)
51     {
52         return (n * W * H * C + c * W * H + h * W + w);
53     };
54
55     auto check_norm = [=](double res) {
56         if(res < 0.999f || res > 1.001) {
57             ASSERT_TRUE(res > 0.99f && res < 1.01);
58         }
59     };
60
61     if(prm.axis == 0) {
62         for (int c = 0; c < C; ++c) {
63             for (int h = 0; h < H; ++h) {
64                 for (int w = 0; w < W; ++w) {
65                     double result = 0.0f;
66
67                     for (int n = 0; n < MB; ++n) {
68                         result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
69                     }
70                     check_norm(result);
71                 }
72             }
73         }
74     }
75     else if(prm.axis == 1) {
76         for (int n = 0; n < MB; ++n) {
77             for (int h = 0; h < H; ++h) {
78                 for (int w = 0; w < W; ++w) {
79                     double result = 0.0f;
80
81                     for (int c = 0; c < C; ++c) {
82                         result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
83                     }
84
85                     check_norm(result);
86                 }
87             }
88         }
89     }
90     else if(prm.axis == 2) {
91         for (int n = 0; n < MB; ++n) {
92             for (int c = 0; c < C; ++c) {
93                 for (int w = 0; w < W; ++w) {
94                     double result = 0.0f;
95
96                     for (int h = 0; h < H; ++h) {
97                         result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
98                     }
99
100                     check_norm(result);
101                 }
102             }
103         }
104     }
105     else if(prm.axis == 3) {
106         for (int n = 0; n < MB; ++n) {
107             for (int c = 0; c < C; ++c) {
108                 for (int h = 0; h < H; ++h) {
109                     double result = 0.0f;
110
111                     for (int w = 0; w < W; ++w) {
112                         result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
113                     }
114
115                     check_norm(result);
116                 }
117             }
118         }
119     }
120 }
121
122 class MKLDNNGraphSoftMaxTests: public TestsCommon,
123                                      public WithParamInterface<softmax_test_params> {
124     std::string model_t = R"V0G0N(
125 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
126     <layers>
127         <layer name="in1" type="Input" precision="FP32" id="0">
128             <output>
129                 <port id="0">
130                     <dim>_IN_</dim>
131                     <dim>_IC_</dim>
132                     <dim>_IH_</dim>
133                     <dim>_IW_</dim>
134                 </port>
135             </output>
136         </layer>
137         <layer name="norm" id="1" type="Softmax" precision="FP32">
138             <data PrimitivesPriority="_IMPLS_" axis="_AX_"/>
139             <input>
140                 <port id="1">
141                     <dim>_IN_</dim>
142                     <dim>_IC_</dim>
143                     <dim>_IH_</dim>
144                     <dim>_IW_</dim>
145                 </port>
146             </input>
147             <output>
148                 <port id="2">
149                     <dim>_IN_</dim>
150                     <dim>_IC_</dim>
151                     <dim>_IH_</dim>
152                     <dim>_IW_</dim>
153                 </port>
154             </output>
155         </layer>
156     </layers>
157     <edges>
158         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
159     </edges>
160 </Net>
161 )V0G0N";
162
163 protected:
164     std::string getModel(softmax_test_params p) {
165         std::string model = model_t;
166
167         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
168         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
169         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
170         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
171         REPLACE_WITH_NUM(model, "_AX_", p.axis);
172         std::string impls;
173         for (const auto& preferType : p.preferTypes) {
174             if (!impls.empty())
175                 impls += ",";
176             impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
177         }
178         REPLACE_WITH_STR(model, "_IMPLS_", impls);
179
180         return model;
181     }
182
183     virtual void TearDown() {
184     }
185
186     virtual void SetUp() {
187         try {
188             TestsCommon::SetUp();
189             softmax_test_params p = ::testing::WithParamInterface<softmax_test_params>::GetParam();
190             std::string model = getModel(p);
191
192             InferenceEngine::CNNNetReader net_reader;
193             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
194
195             MKLDNNGraphTestClass graph;
196             graph.CreateGraph(net_reader.getNetwork());
197             auto& nodes = graph.getNodes();
198             for (int i = 0; i < nodes.size(); i++) {
199                 if (nodes[i]->getType() == MKLDNNPlugin::SoftMax) {
200                     ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
201                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
202                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
203                     }
204                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
205                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
206                 }
207             }
208
209             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
210
211             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
212             src->allocate();
213             fill_data(src->buffer(), src->size());
214
215             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
216
217             if (srcPtr == nullptr)
218                 FAIL() << "Cannot cast blob to TBlob<float>.";
219
220             InferenceEngine::BlobMap srcs;
221             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
222
223             InferenceEngine::OutputsDataMap out;
224             out = net_reader.getNetwork().getOutputsInfo();
225             InferenceEngine::BlobMap outputBlobs;
226
227             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
228
229             InferenceEngine::TBlob<float>::Ptr output;
230             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
231             output->allocate();
232             outputBlobs[item.first] = output;
233
234             graph.Infer(srcs, outputBlobs);
235
236             check_softmax_fwd(*output, p);
237         } catch (const InferenceEngine::details::InferenceEngineException &e) {
238             FAIL() << e.what();
239         }
240     }
241 };
242
243 TEST_P(MKLDNNGraphSoftMaxTests, TestsSoftMax) {}
244
245
246 INSTANTIATE_TEST_CASE_P(
247         TestsSoftMax, MKLDNNGraphSoftMaxTests,
248         ::testing::Values(
249                 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
250                 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
251                 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
252                 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
253                 softmax_test_params{{1, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
254                 softmax_test_params{{8, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
255                 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
256                 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
257 //                softmax_test_params{{8, 100, 81, 1}, 2, 2, MKLDNNPlugin::impl_desc_type::jit},
258                 softmax_test_params{{8, 100, 81, 1}, 2, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
259                 softmax_test_params{{1, 1, 1, 1}, 3, 1, MKLDNNPlugin::impl_desc_type::ref},
260 //                softmax_test_params{{1, 1, 1, 33}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
261                 softmax_test_params{{1, 1, 1, 33}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
262 //                softmax_test_params{{8, 1, 10, 81}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
263                 softmax_test_params{{8, 1, 10, 81}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
264                 ));
265
266 class MKLDNNGraphDynBatchSoftMaxTests: public MKLDNNGraphSoftMaxTests {
267 protected:
268     virtual void SetUp() {
269         try {
270             TestsCommon::SetUp();
271             softmax_test_params p = ::testing::WithParamInterface<softmax_test_params>::GetParam();
272             std::string model = getModel(p);
273             size_t MB = p.in.n;
274             if (MB < 2)
275                 MB = 2;
276
277             InferenceEngine::CNNNetReader net_reader;
278             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
279             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
280             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
281             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
282             InferenceEngine::ResponseDesc resp;
283             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
284             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
285
286             MKLDNNGraphTestClass graph;
287             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
288             graph.CreateGraph(net_reader.getNetwork());
289
290             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
291
292             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
293             src->allocate();
294             fill_data(src->buffer(), src->size());
295
296             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
297
298             if (srcPtr == nullptr)
299                 FAIL() << "Cannot cast blob to TBlob<float>.";
300
301             InferenceEngine::BlobMap srcs;
302             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
303
304             InferenceEngine::OutputsDataMap out;
305             out = net_reader.getNetwork().getOutputsInfo();
306             InferenceEngine::BlobMap outputBlobs;
307
308             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
309
310             InferenceEngine::TBlob<float>::Ptr output;
311             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
312             output->allocate();
313             outputBlobs[item.first] = output;
314
315             auto checkSoftmax = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
316                 return node->getType() == MKLDNNPlugin::SoftMax;
317             };
318
319             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkSoftmax);
320             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkSoftmax);
321         } catch (const InferenceEngine::details::InferenceEngineException &e) {
322             FAIL() << e.what();
323         }
324     }
325 };
326
327 TEST_P(MKLDNNGraphDynBatchSoftMaxTests, TestsDynBatchSoftMax) {}
328
329
330 INSTANTIATE_TEST_CASE_P(
331         TestsDynBatchSoftMax, MKLDNNGraphDynBatchSoftMaxTests,
332         ::testing::Values(
333                 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
334                 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
335                 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
336                 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
337                 softmax_test_params{{1, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
338                 softmax_test_params{{8, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
339                 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
340                 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
341 //                softmax_test_params{{8, 100, 81, 1}, 2, 2, MKLDNNPlugin::impl_desc_type::jit},
342                 softmax_test_params{{8, 100, 81, 1}, 2, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
343                 softmax_test_params{{1, 1, 1, 1}, 3, 1, MKLDNNPlugin::impl_desc_type::ref},
344 //                softmax_test_params{{1, 1, 1, 33}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
345                 softmax_test_params{{1, 1, 1, 33}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
346 //                softmax_test_params{{8, 1, 10, 81}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
347                 softmax_test_params{{8, 1, 10, 81}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
348                 ));