1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
16 using namespace ::testing;
18 using namespace mkldnn;
20 struct activation_test_params {
21 mkldnn::algorithm alg;
34 MKLDNNPlugin::impl_desc_type selectedType;
35 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
37 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
40 template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
41 return s > 0 ? s : static_cast<T>(s * alpha);
44 template <typename T, typename A> T elu_fwd(T s, A alpha) {
45 return s > 0 ? s : static_cast<T>(alpha * (::expf(s) - 1));
54 template <typename T, typename A>
55 T bounded_relu_fwd(T s, A alpha) {
57 return s > alpha ? (T)(alpha) : s;
60 template <typename data_t>
61 void ref_activation(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, activation_test_params prm) {
62 size_t IW = src.dims()[3];
63 size_t IH = src.dims()[2];
64 size_t IC = src.dims()[1];
65 size_t MB = src.dims()[0];
67 const data_t *src_data = src.readOnly();
68 data_t *dst_data = dst.data();
70 for(int mb = 0; mb < MB; mb++) {
71 for(int c = 0; c < IC; c++) {
72 for(int h = 0; h < IH; h++) {
73 for(int w = 0; w < IW; w++) {
74 int idx = mb * IC * IH * IW
79 case eltwise_relu: dst_data[idx] = relu_fwd(src_data[idx], prm.alpha); break;
80 case eltwise_elu: dst_data[idx] = elu_fwd(src_data[idx], prm.alpha); break;
81 case eltwise_logistic: dst_data[idx] = logistic_fwd(src_data[idx]); break;
82 case eltwise_bounded_relu: dst_data[idx] = bounded_relu_fwd(src_data[idx], prm.alpha); break;
83 default: assert(!"unknown alg_kind");
91 class MKLDNNGraphActivationTests: public TestsCommon,
92 public WithParamInterface<activation_test_params> {
93 std::string model_t = R"V0G0N(
94 <Net Name="Activation" version="2" precision="FP32" batch="1">
96 <layer name="in1" type="Input" precision="FP32" id="0">
106 <layer name="activation" id="1" type="_LT_" precision="FP32">
107 <data _P1_NAME_="_P1_VAL_" _P2_NAME_="_P2_VAL_" PrimitivesPriority="_IMPLS_"/>
127 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
133 virtual void TearDown() {
136 std::string getModel(activation_test_params p) {
137 std::string model = model_t;
140 case eltwise_relu: REPLACE_WITH_STR(model, "_LT_", "ReLU"); break;
141 case eltwise_elu: REPLACE_WITH_STR(model, "_LT_", "ELU"); break;
142 case eltwise_logistic: REPLACE_WITH_STR(model, "_LT_", "Sigmoid"); break;
143 case eltwise_bounded_relu: REPLACE_WITH_STR(model, "_LT_", "ReLU6"); break;
144 default: assert(!"unknown alg_kind");
147 if (p.alg == eltwise_relu)
148 REPLACE_WITH_STR(model, "_P1_NAME_", "negative_slope");
149 else if (p.alg == eltwise_bounded_relu)
150 REPLACE_WITH_STR(model, "_P1_NAME_", "n");
152 REPLACE_WITH_STR(model, "_P1_NAME_", "alpha");
153 REPLACE_WITH_NUM(model, "_P1_VAL_", p.alpha);
155 REPLACE_WITH_STR(model, "_P2_NAME_", "beta");
156 REPLACE_WITH_NUM(model, "_P2_VAL_", p.beta);
158 REPLACE_WITH_NUM(model, "_IW_", p.in.w);
159 REPLACE_WITH_NUM(model, "_IH_", p.in.h);
160 REPLACE_WITH_NUM(model, "_IC_", p.in.c);
161 REPLACE_WITH_NUM(model, "_IN_", p.in.n);
164 for (const auto& preferType : p.preferTypes) {
167 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
169 REPLACE_WITH_STR(model, "_IMPLS_", impls);
174 virtual void SetUp() {
176 TestsCommon::SetUp();
177 activation_test_params p = ::testing::WithParamInterface<activation_test_params>::GetParam();
178 std::string model = getModel(p);
180 InferenceEngine::CNNNetReader net_reader;
181 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
183 MKLDNNGraphTestClass graph;
184 graph.CreateGraph(net_reader.getNetwork());
185 auto& nodes = graph.getNodes();
186 for (int i = 0; i < nodes.size(); i++) {
187 if (nodes[i]->getType() == MKLDNNPlugin::Activation) {
188 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
189 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
190 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
192 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
193 ASSERT_EQ(p.selectedType,
194 nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
198 InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
200 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
202 fill_data(src->buffer(), src->size());
204 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
206 if (srcPtr == nullptr)
207 FAIL() << "Cannot cast blob to TBlob<float>.";
209 InferenceEngine::BlobMap srcs;
210 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
212 InferenceEngine::OutputsDataMap out;
213 out = net_reader.getNetwork().getOutputsInfo();
214 InferenceEngine::BlobMap outputBlobs;
216 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
218 InferenceEngine::TBlob<float>::Ptr output;
219 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
221 outputBlobs[item.first] = output;
223 graph.Infer(srcs, outputBlobs);
225 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
228 ref_activation(*srcPtr, dst_ref, p);
230 compare(*output, dst_ref);
231 } catch (const InferenceEngine::details::InferenceEngineException &e) {
237 TEST_P(MKLDNNGraphActivationTests, TestsActivation) {}
239 INSTANTIATE_TEST_CASE_P(
240 TestsActivation, MKLDNNGraphActivationTests,
242 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
243 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
244 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
245 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
246 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
247 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
248 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
249 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
250 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
251 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
252 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
253 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
254 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
255 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
256 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
257 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
258 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
259 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
260 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
261 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
262 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
263 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
264 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
265 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
266 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
267 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
268 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
269 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
272 class MKLDNNGraphDynBatchActivationTests: public MKLDNNGraphActivationTests {
274 virtual void SetUp() {
276 TestsCommon::SetUp();
277 activation_test_params p = ::testing::WithParamInterface<activation_test_params>::GetParam();
278 std::string model = getModel(p);
283 InferenceEngine::CNNNetReader net_reader;
284 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
285 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
286 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
287 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
288 InferenceEngine::ResponseDesc resp;
289 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
290 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
292 MKLDNNGraphTestClass graph;
293 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
294 graph.CreateGraph(net_reader.getNetwork());
296 InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
298 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
300 fill_data(src->buffer(), src->size());
302 auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
304 if (srcPtr == nullptr)
305 FAIL() << "Cannot cast blob to TBlob<float>.";
307 InferenceEngine::BlobMap srcs;
308 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
310 InferenceEngine::OutputsDataMap out;
311 out = net_reader.getNetwork().getOutputsInfo();
312 InferenceEngine::BlobMap outputBlobs;
314 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
316 InferenceEngine::TBlob<float>::Ptr output;
317 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
319 outputBlobs[item.first] = output;
321 auto checkActivation = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
322 return node->getType() == MKLDNNPlugin::Activation;
325 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkActivation);
326 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkActivation);
327 } catch (const InferenceEngine::details::InferenceEngineException &e) {
333 TEST_P(MKLDNNGraphDynBatchActivationTests, TestsDynBatchActivation) {}
336 INSTANTIATE_TEST_CASE_P(
337 TestsDynBatchActivation, MKLDNNGraphDynBatchActivationTests,
339 activation_test_params{eltwise_relu, 0.0f, 0.0f, {2, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
340 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
341 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
342 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
343 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
344 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
345 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
346 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
347 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
348 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
349 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
350 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
351 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::jit},
352 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::jit},
353 activation_test_params{eltwise_relu, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
354 activation_test_params{eltwise_relu, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
355 activation_test_params{eltwise_relu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
356 activation_test_params{eltwise_relu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
357 activation_test_params{eltwise_elu, 0.5f, 0.5f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
358 activation_test_params{eltwise_elu, 0.5f, 0.5f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
359 activation_test_params{eltwise_elu, 1.0f, 1.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
360 activation_test_params{eltwise_elu, 1.0f, 1.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
361 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
362 activation_test_params{eltwise_logistic, 0.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
363 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
364 activation_test_params{eltwise_bounded_relu, 6.0f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
365 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {1, 32, 128, 256}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
366 activation_test_params{eltwise_bounded_relu, 0.1f, 0.0f, {4, 3, 228, 228}, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}