Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_activation_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace mkldnn;
19
20 struct activation_test_params {
21     mkldnn::algorithm alg;
22     float alpha;
23     float beta;
24
25     struct {
26         size_t n;
27         size_t c;
28         size_t h;
29         size_t w;
30     } in;
31
32     size_t num_prim_desc;
33
34     MKLDNNPlugin::impl_desc_type selectedType;
35     std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
36
37     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
38 };
39
// Reference (leaky) ReLU: positive inputs pass through unchanged, everything
// else is scaled by alpha (alpha == 0 gives the plain ReLU).
template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
    if (!(s > 0))
        return static_cast<T>(s * alpha);
    return s;
}
43
// Reference ELU: identity for positive inputs, alpha * (e^s - 1) otherwise.
template <typename T, typename A> T elu_fwd(T s, A alpha) {
    if (s > 0)
        return s;
    return static_cast<T>(alpha * (::expf(s) - 1));
}
47
// Reference logistic (sigmoid) function: 1 / (1 + e^-s).
//
// Evaluated in this form rather than e^s / (e^s + 1): once e^s overflows
// (s above roughly 88.7 for float) the latter becomes inf / inf = NaN, while
// here an overflowing exponential only drives the result to 0 (and an
// underflowing one to 1), which is the correct limit in both directions.
template <typename T>
T logistic_fwd(T s) {
    return static_cast<T>(1) / (static_cast<T>(1) + ::expf(-s));
}
53
// Reference bounded ReLU (ReLU6 when alpha == 6): clamps s to [0, alpha].
template <typename T, typename A>
T bounded_relu_fwd(T s, A alpha) {
    const T non_negative = (s > 0) ? s : static_cast<T>(0);
    if (non_negative > alpha)
        return static_cast<T>(alpha);
    return non_negative;
}
59
60 template <typename data_t>
61 void ref_activation(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, activation_test_params prm) {
62     size_t IW = src.dims()[3];
63     size_t IH = src.dims()[2];
64     size_t IC = src.dims()[1];
65     size_t MB = src.dims()[0];
66
67     const data_t *src_data = src.readOnly();
68     data_t *dst_data = dst.data();
69
70     for(int mb = 0; mb < MB; mb++) {
71         for(int c = 0; c < IC; c++) {
72             for(int h = 0; h < IH; h++) {
73                 for(int w = 0; w < IW; w++) {
74                     int idx = mb * IC * IH * IW
75                               + c * IH * IW
76                               + h * IW + w;
77
78                     switch (prm.alg) {
79                         case eltwise_relu:         dst_data[idx] = relu_fwd(src_data[idx], prm.alpha);         break;
80                         case eltwise_elu:          dst_data[idx] = elu_fwd(src_data[idx], prm.alpha);          break;
81                         case eltwise_logistic:     dst_data[idx] = logistic_fwd(src_data[idx]);                break;
82                         case eltwise_bounded_relu: dst_data[idx] = bounded_relu_fwd(src_data[idx], prm.alpha); break;
83                         default: assert(!"unknown alg_kind");
84                     }
85                 }
86             }
87         }
88     }
89 }
90
91 class MKLDNNGraphActivationTests: public TestsCommon,
92                                      public WithParamInterface<activation_test_params> {
93     std::string model_t = R"V0G0N(
94 <Net Name="Activation" version="2" precision="FP32" batch="1">
95     <layers>
96         <layer name="in1" type="Input" precision="FP32" id="0">
97             <output>
98                 <port id="0">
99                     <dim>_IN_</dim>
100                     <dim>_IC_</dim>
101                     <dim>_IH_</dim>
102                     <dim>_IW_</dim>
103                 </port>
104             </output>
105         </layer>
106         <layer name="activation" id="1" type="_LT_" precision="FP32">
107             <data _P1_NAME_="_P1_VAL_" _P2_NAME_="_P2_VAL_" PrimitivesPriority="_IMPLS_"/>
108             <input>
109                 <port id="1">
110                     <dim>_IN_</dim>
111                     <dim>_IC_</dim>
112                     <dim>_IH_</dim>
113                     <dim>_IW_</dim>
114                 </port>
115             </input>
116             <output>
117                 <port id="2">
118                     <dim>_IN_</dim>
119                     <dim>_IC_</dim>
120                     <dim>_IH_</dim>
121                     <dim>_IW_</dim>
122                 </port>
123             </output>
124         </layer>
125     </layers>
126     <edges>
127         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
128     </edges>
129 </Net>
130 )V0G0N";
131
132 protected:
133     virtual void TearDown() {
134     }
135
136     std::string getModel(activation_test_params p) {
137         std::string model = model_t;
138
139         switch (p.alg) {
140             case eltwise_relu:         REPLACE_WITH_STR(model, "_LT_", "ReLU"); break;
141             case eltwise_elu:          REPLACE_WITH_STR(model, "_LT_", "ELU"); break;
142             case eltwise_logistic:     REPLACE_WITH_STR(model, "_LT_", "Sigmoid"); break;
143             case eltwise_bounded_relu: REPLACE_WITH_STR(model, "_LT_", "ReLU6"); break;
144             default: assert(!"unknown alg_kind");
145         }
146
147         if (p.alg == eltwise_relu)
148             REPLACE_WITH_STR(model, "_P1_NAME_", "negative_slope");
149         else if (p.alg == eltwise_bounded_relu)
150             REPLACE_WITH_STR(model, "_P1_NAME_", "n");
151         else
152             REPLACE_WITH_STR(model, "_P1_NAME_", "alpha");
153         REPLACE_WITH_NUM(model, "_P1_VAL_", p.alpha);
154
155         REPLACE_WITH_STR(model, "_P2_NAME_", "beta");
156         REPLACE_WITH_NUM(model, "_P2_VAL_", p.beta);
157
158         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
159         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
160         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
161         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
162
163         std::string impls;
164         for (const auto& preferType : p.preferTypes) {
165             if (!impls.empty())
166                 impls += ",";
167             impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
168         }
169         REPLACE_WITH_STR(model, "_IMPLS_", impls);
170
171         return model;
172     }
173
174     virtual void SetUp() {
175         try {
176             TestsCommon::SetUp();
177             activation_test_params p = ::testing::WithParamInterface<activation_test_params>::GetParam();
178             std::string model = getModel(p);
179
180             InferenceEngine::CNNNetReader net_reader;
181             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
182
183             MKLDNNGraphTestClass graph;
184             graph.CreateGraph(net_reader.getNetwork());
185             auto& nodes = graph.getNodes();
186             for (int i = 0; i < nodes.size(); i++) {
187                 if (nodes[i]->getType() == MKLDNNPlugin::Activation) {
188                     ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
189                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
190                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
191                     }
192                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
193                     ASSERT_EQ(p.selectedType,
194                               nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
195                 }
196             }
197
198             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
199
200             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
201             src->allocate();
202             fill_data(src->buffer(), src->size());
203
204             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
205
206             if (srcPtr == nullptr)
207                 FAIL() << "Cannot cast blob to TBlob<float>.";
208
209             InferenceEngine::BlobMap srcs;
210             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
211
212             InferenceEngine::OutputsDataMap out;
213             out = net_reader.getNetwork().getOutputsInfo();
214             InferenceEngine::BlobMap outputBlobs;
215
216             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
217
218             InferenceEngine::TBlob<float>::Ptr output;
219             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
220             output->allocate();
221             outputBlobs[item.first] = output;
222
223             graph.Infer(srcs, outputBlobs);
224
225             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
226             dst_ref.allocate();
227
228             ref_activation(*srcPtr, dst_ref, p);
229
230             compare(*output, dst_ref);
231         } catch (const InferenceEngine::details::InferenceEngineException &e) {
232             FAIL() << e.what();
233         }
234     }
235 };
236
237 TEST_P(MKLDNNGraphActivationTests, TestsActivation) {}
238
239 INSTANTIATE_TEST_CASE_P(
240         TestsActivation, MKLDNNGraphActivationTests,
241         ::testing::Values(
242                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
243                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
244                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
245                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
246                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
247                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
248                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
249                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
250                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
251                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
252                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
253                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
254                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
255                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
256                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
257                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
258                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
259                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
260                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
261                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
262                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
263                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
264                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
265                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
266                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
267                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
268                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
269                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
270         ));
271
272 class MKLDNNGraphDynBatchActivationTests: public MKLDNNGraphActivationTests {
273 protected:
274     virtual void SetUp() {
275         try {
276             TestsCommon::SetUp();
277             activation_test_params p = ::testing::WithParamInterface<activation_test_params>::GetParam();
278             std::string model = getModel(p);
279             size_t MB = p.in.n;
280             if (MB < 2)
281                 MB = 2;
282
283             InferenceEngine::CNNNetReader net_reader;
284             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
285             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
286             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
287             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
288             InferenceEngine::ResponseDesc resp;
289             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
290             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
291
292             MKLDNNGraphTestClass graph;
293             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
294             graph.CreateGraph(net_reader.getNetwork());
295
296             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
297
298             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
299             src->allocate();
300             fill_data(src->buffer(), src->size());
301
302             auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
303
304             if (srcPtr == nullptr)
305                 FAIL() << "Cannot cast blob to TBlob<float>.";
306
307             InferenceEngine::BlobMap srcs;
308             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
309
310             InferenceEngine::OutputsDataMap out;
311             out = net_reader.getNetwork().getOutputsInfo();
312             InferenceEngine::BlobMap outputBlobs;
313
314             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
315
316             InferenceEngine::TBlob<float>::Ptr output;
317             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
318             output->allocate();
319             outputBlobs[item.first] = output;
320
321             auto checkActivation = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
322                 return node->getType() == MKLDNNPlugin::Activation;
323             };
324
325             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkActivation);
326             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkActivation);
327         } catch (const InferenceEngine::details::InferenceEngineException &e) {
328             FAIL() << e.what();
329         }
330     }
331 };
332
333 TEST_P(MKLDNNGraphDynBatchActivationTests, TestsDynBatchActivation) {}
334
335
336 INSTANTIATE_TEST_CASE_P(
337         TestsDynBatchActivation, MKLDNNGraphDynBatchActivationTests,
338         ::testing::Values(
339                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {2, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
340                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
341                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
342                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
343                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
344                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
345                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
346                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
347                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
348                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
349                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
350                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
351                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
352                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
353                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
354                 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
355                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
356                 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
357                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
358                 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
359                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
360                 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
361                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
362                 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
363                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
364                 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
365                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
366                 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
367         ));