1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8 #include "test_graph.hpp"
9 #include "single_layer_common.hpp"
10 #include <mkldnn_plugin/mkldnn_extension_utils.h>
11 #include <inference_engine/cnn_network_impl.hpp>
12 #include "tests_common.hpp"
14 using namespace ::testing;
16 using namespace mkldnn;
// Parameters of one activation test case.
// Besides the members visible below, the tests read p.alpha, p.beta, p.dims and
// p.num_prim_desc, so the aggregate also declares those (not shown in this view).
// Judging from the instantiations, e.g.
//   {eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, impl_desc_type::jit},
// the initialization order is presumably alg, alpha, beta, dims, num_prim_desc,
// selectedType[, preferTypes[, comp]] -- TODO confirm against the full struct.
18 struct activation_test_params {
// Eltwise algorithm under test; ref_activation() and getModel() handle
// eltwise_relu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu and eltwise_tanh.
19 mkldnn::algorithm alg;
23 // Formats: NCHW, NCDHW
// Implementation type the selected primitive descriptor is expected to have:
// SetUp() asserts (selected implementation type & selectedType) == selectedType.
28 MKLDNNPlugin::impl_desc_type selectedType;
// Implementation types emitted into the layer's PrimitivesPriority attribute
// as "cpu:<type>" entries by getModel().
29 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
// Optional checks on the Activation node's supported primitive descriptors:
// the j-th check is applied to the j-th descriptor (only while j < num_prim_desc).
31 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Reference leaky ReLU: s for s > 0, otherwise s * alpha.
// getModel() passes alpha to the IR as the ReLU negative_slope attribute.
34 template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
35 return s > 0 ? s : static_cast<T>(s * alpha);
// Reference ELU: s for s > 0, otherwise alpha * (exp(s) - 1).
// NOTE(review): the exponential is computed with the float ::expf regardless of T.
38 template <typename T, typename A> T elu_fwd(T s, A alpha) {
39 return s > 0 ? s : static_cast<T>(alpha * (::expf(s) - 1));
// Reference bounded ReLU. getModel() emits it as a ReLU6 layer with n = alpha.
// Values above alpha are clamped to alpha.
// NOTE(review): one line of the body is not visible in this view; presumably it
// clamps negative inputs to 0 before this upper bound is applied -- confirm.
48 template <typename T, typename A>
49 T bounded_relu_fwd(T s, A alpha) {
51 return s > alpha ? (T)(alpha) : s;
// Reference tanh, evaluated in float with ::tanhf and cast back to T.
54 template <typename T> T tanh_fwd(T s) {
55 return static_cast<T>(::tanhf((float)s));
// Reference implementation used to validate the MKLDNN output.
// It applies prm.alg (with prm.alpha for the algorithms that take one) to each
// element of src and writes the result at the same offset in dst.
// The depth extent ID is 1 unless src has 5 dimensions.
58 template <typename data_t>
59 void ref_activation(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, activation_test_params prm) {
60 auto dims_size = src.dims().size();
// NOTE(review): the names assume src.dims() is ordered N, C, [D,] H, W. Some
// Inference Engine versions return Blob::dims() in reversed order -- confirm.
// Because the operation is elementwise, presumably only the product of the
// extents matters for visiting every element.
62 size_t IW = src.dims()[dims_size - 1];
63 size_t IH = src.dims()[dims_size - 2];
64 size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
65 size_t IC = src.dims()[1];
66 size_t MB = src.dims()[0];
68 const data_t *src_data = src.readOnly();
69 data_t *dst_data = dst.data();
// NOTE(review): int loop counters are compared against size_t extents (signed/unsigned).
71 for(int mb = 0; mb < MB; mb++) {
72 for(int c = 0; c < IC; c++) {
73 for(int d = 0; d < ID; d++) {
74 for(int h = 0; h < IH; h++) {
75 for(int w = 0; w < IW; w++) {
// Flat element offset; the remaining terms of this expression are not visible in this view.
76 int idx = mb * IC * ID * IH * IW
83 case eltwise_relu: dst_data[idx] = relu_fwd(src_data[idx], prm.alpha); break;
84 case eltwise_elu: dst_data[idx] = elu_fwd(src_data[idx], prm.alpha); break;
85 case eltwise_logistic: dst_data[idx] = logistic_fwd(src_data[idx]); break;
86 case eltwise_bounded_relu: dst_data[idx] = bounded_relu_fwd(src_data[idx], prm.alpha); break;
87 case eltwise_tanh: dst_data[idx] = tanh_fwd(src_data[idx]); break;
88 default: assert(!"unknown alg_kind");
// Value-parameterized fixture for the activation tests.
// getModel() renders an IR containing an Input layer "in1" (id 0) connected to
// an activation layer (id 1). SetUp() then:
//   - builds an MKLDNN graph from that IR;
//   - checks the Activation node's primitive descriptors;
//   - runs inference and compares the output with ref_activation().
// The TEST_P body itself is empty.
97 class MKLDNNGraphActivationTests: public TestsCommon,
98 public WithParamInterface<activation_test_params> {
// IR template. getModel() replaces these placeholders:
//   _LT_ (layer type), _P1_/_P2_ (layer attributes),
//   _IMPLS_ (PrimitivesPriority), _IN_/_IC_/_ID_/_IH_/_IW_ (input dims).
99 std::string model_t = R"V0G0N(
100 <Net Name="Activation" version="3" precision="FP32" batch="1">
102 <layer name="in1" type="Input" precision="FP32" id="0">
113 <layer name="activation" id="1" type="_LT_" precision="FP32">
114 <data _P1_ _P2_ PrimitivesPriority="_IMPLS_"/>
136 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
142 virtual void TearDown() {
// Renders model_t for test case p:
//   - input dims from p.dims;
//   - the IR layer type and attributes for p.alg;
//   - a PrimitivesPriority list built from p.preferTypes.
145 std::string getModel(activation_test_params p) {
146 std::string model = model_t;
147 auto dims_size = p.dims.size();
// Presumably drops the H / D dim entries when p.dims has too few dimensions
// (the guarding conditions are not visible in this view).
151 REMOVE_LINE(model, "<dim>_IH_</dim>");
153 REMOVE_LINE(model, "<dim>_ID_</dim>");
// IR layer type for each algorithm. tanh is expressed as a generic "Activation"
// layer whose type="tanh" attribute is set through _P1_ below.
157 case eltwise_relu: REPLACE_WITH_STR(model, "_LT_", "ReLU"); break;
158 case eltwise_elu: REPLACE_WITH_STR(model, "_LT_", "ELU"); break;
159 case eltwise_logistic: REPLACE_WITH_STR(model, "_LT_", "Sigmoid"); break;
160 case eltwise_bounded_relu: REPLACE_WITH_STR(model, "_LT_", "ReLU6"); break;
161 case eltwise_tanh: REPLACE_WITH_STR(model, "_LT_", "Activation"); break;
162 default: assert(!"unknown alg_kind");
// Layer attributes. alpha is written as:
//   - negative_slope for ReLU;
//   - n for ReLU6;
//   - alpha in the fallback branch.
// beta is written in the relu, bounded_relu and fallback branches.
166 if (p.alg == eltwise_relu) {
167 P1 = string("negative_slope=\"") + to_string(p.alpha) + string("\"");
168 P2 = string("beta=\"") + to_string(p.beta) + string("\"");
169 } else if (p.alg == eltwise_bounded_relu) {
170 P1 = string("n=\"") + to_string(p.alpha) + string("\"");
171 P2 = string("beta=\"") + to_string(p.beta) + string("\"");
172 } else if (p.alg == eltwise_tanh) {
173 P1 = string("type=\"tanh\"");
175 P1 = string("alpha=\"") + to_string(p.alpha) + string("\"");
176 P2 = string("beta=\"") + to_string(p.beta) + string("\"");
178 REPLACE_WITH_STR(model, "_P1_", P1);
179 REPLACE_WITH_STR(model, "_P2_", P2);
// Input dims: W is the last entry of p.dims; N and C are the first two.
181 REPLACE_WITH_NUM(model, "_IW_", p.dims[dims_size - 1]);
182 REPLACE_WITH_NUM(model, "_IC_", p.dims[1]);
183 REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
186 REPLACE_WITH_NUM(model, "_ID_", p.dims[dims_size - 3]);
188 REPLACE_WITH_NUM(model, "_IH_", p.dims[dims_size - 2]);
// PrimitivesPriority value: one "cpu:<impl type>" entry per preferred implementation.
192 for (const auto& preferType : p.preferTypes) {
195 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
197 REPLACE_WITH_STR(model, "_IMPLS_", impls);
// Per-test setup that performs the whole test:
//   - renders the IR for the current parameter and builds an MKLDNN graph;
//   - validates the Activation node's primitive descriptors;
//   - runs inference on filled input data;
//   - compares the result with ref_activation() (tolerance 0.0005f).
202 virtual void SetUp() {
204 TestsCommon::SetUp();
205 activation_test_params p = ::testing::WithParamInterface<activation_test_params>::GetParam();
206 std::string model = getModel(p);
208 InferenceEngine::CNNNetReader net_reader;
209 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
211 MKLDNNGraphTestClass graph;
212 graph.CreateGraph(net_reader.getNetwork());
213 auto& nodes = graph.getNodes();
// NOTE(review): int index compared against nodes.size() (signed/unsigned).
214 for (int i = 0; i < nodes.size(); i++) {
215 if (nodes[i]->getType() == MKLDNNPlugin::Activation) {
// The node must offer at least num_prim_desc descriptors.
// The optional p.comp checks run on the first of them.
216 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
217 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
218 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
// A descriptor must be selected, and its implementation type must include every bit of p.selectedType.
220 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
221 ASSERT_EQ(p.selectedType,
222 nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
// Input blob; its layout is chosen from the rank of p.dims (NCHW or NCDHW).
226 InferenceEngine::SizeVector dims_src = p.dims;
227 InferenceEngine::Layout layout = InferenceEngine::ANY;
228 switch (p.dims.size()) {
230 layout = InferenceEngine::NCHW;
233 layout = InferenceEngine::NCDHW;
237 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
239 fill_data(src->buffer(), src->size());
241 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
243 if (srcPtr == nullptr)
244 FAIL() << "Cannot cast blob to TBlob<float>.";
246 InferenceEngine::BlobMap srcs;
247 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Allocate the output blob for the first entry of the network's outputs map.
249 InferenceEngine::OutputsDataMap out;
250 out = net_reader.getNetwork().getOutputsInfo();
251 InferenceEngine::BlobMap outputBlobs;
253 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
255 InferenceEngine::TBlob<float>::Ptr output;
256 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
258 outputBlobs[item.first] = output;
260 graph.Infer(srcs, outputBlobs);
// Expected output: the reference activation, written into a blob with the same tensor descriptor.
262 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
265 ref_activation(*srcPtr, dst_ref, p);
267 compare(*output, dst_ref, 0.0005f);
268 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// All checks run in the fixture's SetUp(); the test body is empty.
274 TEST_P(MKLDNNGraphActivationTests, TestsActivation) {}
// Test cases: {alg, alpha, beta, dims, num_prim_desc, selectedType[, preferTypes]}.
// The first group expects a jit implementation to be selected. The second group
// passes {ref_any} as preferTypes and expects a ref implementation.
// The final case uses a 5D (NCDHW) input with tanh.
276 INSTANTIATE_TEST_CASE_P(
277 TestsActivation, MKLDNNGraphActivationTests,
279 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
280 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
281 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
282 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
283 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
284 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
285 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
286 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
287 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
288 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
289 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
290 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
291 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
292 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
// Same algorithms and shapes, with ref_any requested and a ref implementation expected.
293 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
294 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
295 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
296 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
297 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
298 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
299 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
300 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
301 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
302 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
303 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
304 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
305 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
306 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
// 5D input (N=1, C=1, D=H=W=64) with tanh, reference implementation only.
308 activation_test_params{eltwise_tanh, 0.f, 0.f, {1, 1, 64, 64, 64}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
// Dynamic-batch variant of MKLDNNGraphActivationTests. It reuses getModel() and
// the same parameter struct, then:
//   - reshapes the network to a batch MB initialized from p.dims[0];
//   - builds the graph with KEY_DYN_BATCH_ENABLED set to YES;
//   - runs checkDynBatch() against the Activation nodes.
311 class MKLDNNGraphDynBatchActivationTests: public MKLDNNGraphActivationTests {
313 virtual void SetUp() {
315 TestsCommon::SetUp();
316 activation_test_params p = ::testing::WithParamInterface<activation_test_params>::GetParam();
317 std::string model = getModel(p);
318 size_t MB = p.dims[0];
322 InferenceEngine::CNNNetReader net_reader;
323 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
324 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
// setBatchSizeReshape() is reached through the CNNNetworkImpl behind the network interface.
325 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
326 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
327 InferenceEngine::ResponseDesc resp;
328 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
329 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
// Enable dynamic batch in the graph configuration before creating the graph.
331 MKLDNNGraphTestClass graph;
332 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
333 graph.CreateGraph(net_reader.getNetwork());
// Input blob; its layout is chosen from the rank of p.dims (NCHW or NCDHW).
335 InferenceEngine::SizeVector dims_src = p.dims;
336 InferenceEngine::Layout layout = InferenceEngine::ANY;
337 switch (p.dims.size()) {
339 layout = InferenceEngine::NCHW;
342 layout = InferenceEngine::NCDHW;
346 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
348 fill_data(src->buffer(), src->size());
350 auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
352 if (srcPtr == nullptr)
353 FAIL() << "Cannot cast blob to TBlob<float>.";
355 InferenceEngine::BlobMap srcs;
356 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Allocate the output blob for the first entry of the network's outputs map.
358 InferenceEngine::OutputsDataMap out;
359 out = net_reader.getNetwork().getOutputsInfo();
360 InferenceEngine::BlobMap outputBlobs;
362 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
364 InferenceEngine::TBlob<float>::Ptr output;
365 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
367 outputBlobs[item.first] = output;
// Node predicate handed to checkDynBatch(): selects Activation nodes.
369 auto checkActivation = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
370 return node->getType() == MKLDNNPlugin::Activation;
// Checked once with (MB, MB) and once with (1, MB). These are presumably
// (batch to run, max batch) -- confirm in MKLDNNGraphTestClass::checkDynBatch.
373 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkActivation);
374 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkActivation);
375 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// All checks run in the fixture's SetUp(); the test body is empty.
381 TEST_P(MKLDNNGraphDynBatchActivationTests, TestsDynBatchActivation) {}
// Dynamic-batch cases: the same algorithms as TestsActivation (without the 5D tanh case).
// The first case uses batch 2; the others use batch 1 or 4.
// The first group expects jit. The second passes {ref_any} as preferTypes and expects ref.
384 INSTANTIATE_TEST_CASE_P(
385 TestsDynBatchActivation, MKLDNNGraphDynBatchActivationTests,
387 activation_test_params{eltwise_relu, 0.0f, 0.0f, {2, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
388 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
389 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
390 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
391 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
392 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
393 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
394 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
395 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
396 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
397 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
398 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
399 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
400 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
// Same algorithms and shapes, with ref_any requested and a ref implementation expected.
401 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
402 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
403 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
404 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
405 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
406 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
407 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
408 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
409 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
410 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
411 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
412 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
413 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
414 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}