1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <extension/ext_list.hpp>
14 #include "tests_common.hpp"
17 using namespace ::testing;
19 using namespace mkldnn;
22 struct math_test_params {
23 std::string math_function;
24 InferenceEngine::SizeVector in_out;
25 std::vector<float> input_tensor;
26 std::vector<float> alpha;
27 std::vector<float> beta;
28 std::vector<float> gamma;
29 std::vector<float> reference;
31 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
35 std::string math_function,
36 InferenceEngine::TBlob<float> &src,
37 std::vector<float> alpha,
38 std::vector<float> beta,
39 std::vector<float> gamma,
40 InferenceEngine::TBlob<float> &dst
43 float* src_data = src.data();
44 float *dst_data = dst.data();
45 size_t dst_size = dst.size();
47 if (math_function == "Erf") {
48 for (i = 0; i < dst_size; i++) {
49 dst_data[i] = std::erf(src_data[i]);
51 } else if (math_function == "Abs") {
52 for (i = 0; i < dst_size; i++) {
53 dst_data[i] = (std::abs)(src_data[i]);
55 } else if (math_function == "Acos") {
56 for (i = 0; i < dst_size; i++) {
57 dst_data[i] = acosf(src_data[i]);
59 } else if (math_function == "Acosh") {
60 for (i = 0; i < dst_size; i++) {
61 dst_data[i] = acoshf(src_data[i]);
63 } else if (math_function == "Asin") {
64 for (i = 0; i < dst_size; i++) {
65 dst_data[i] = asinf(src_data[i]);
67 } else if (math_function == "Asinh") {
68 for (i = 0; i < dst_size; i++) {
69 dst_data[i] = asinhf(src_data[i]);
71 } else if (math_function == "Atan") {
72 for (i = 0; i < dst_size; i++) {
73 dst_data[i] = atanf(src_data[i]);
75 } else if (math_function == "Atanh") {
76 for (i = 0; i < dst_size; i++) {
77 dst_data[i] = atanhf(src_data[i]);
79 } else if (math_function == "Ceil") {
80 for (i = 0; i < dst_size; i++) {
81 dst_data[i] = ceilf(src_data[i]);
83 } else if (math_function == "Cos") {
84 for (i = 0; i < dst_size; i++) {
85 dst_data[i] = cosf(src_data[i]);
87 } else if (math_function == "Cosh") {
88 for (i = 0; i < dst_size; i++) {
89 dst_data[i] = coshf(src_data[i]);
91 } else if (math_function == "Floor") {
92 for (i = 0; i < dst_size; i++) {
93 dst_data[i] = floorf(src_data[i]);
95 } else if (math_function == "HardSigmoid") {
96 alpha[0] = (alpha[0] == 0.0f) ? 0.2f : alpha[0];
97 beta[0] = (beta[0] == 0.0f) ? 0.5f : beta[0];
98 for (i = 0; i < dst_size; i++) {
99 dst_data[i] = (std::max)(0.f, (std::min)(1.f, alpha[0] * src_data[i] + beta[0]));
101 } else if (math_function == "Log") {
102 for (i = 0; i < dst_size; i++) {
103 dst_data[i] = logf(src_data[i]);
105 } else if (math_function == "Neg") {
106 for (i = 0; i < dst_size; i++) {
107 dst_data[i] = -src_data[i];
109 } else if (math_function == "Reciprocal") {
110 for (i = 0; i < dst_size; i++) {
111 dst_data[i] = 1.0f / src_data[i];
113 } else if (math_function == "Selu") {
114 alpha[0] = (alpha[0] == 0.0f) ? 1.67326f : alpha[0];
115 gamma[0] = (gamma[0] == 0.0f) ? 1.0507f : gamma[0];
116 for (i = 0; i < dst_size; i++) {
117 float x = src_data[i];
118 dst_data[i] = (x > 0.0f) ? (gamma[0] * x) : (gamma[0] * alpha[0] * (exp(x) - 1.0f));
120 } else if (math_function == "Sign") {
121 for (i = 0; i < dst_size; i++) {
122 if (src_data[i] > 0.0f) dst_data[i] = 1.0f;
123 else if (src_data[i] < 0.0f) dst_data[i] = -1.0f;
124 else dst_data[i] = 0.0f;
126 } else if (math_function == "Sin") {
127 for (i = 0; i < dst_size; i++) {
128 dst_data[i] = sinf(src_data[i]);
130 } else if (math_function == "Sinh") {
131 for (i = 0; i < dst_size; i++) {
132 dst_data[i] = sinhf(src_data[i]);
134 } else if (math_function == "Softplus") {
135 for (i = 0; i < dst_size; i++) {
136 dst_data[i] = logf(expf(src_data[i]) + 1);
138 } else if (math_function == "Softsign") {
139 for (i = 0; i < dst_size; i++) {
140 float x = src_data[i];
141 dst_data[i] = x / (1.f + (std::abs)(x));
143 } else if (math_function == "Tan") {
144 for (i = 0; i < dst_size; i++) {
145 dst_data[i] = tanf(src_data[i]);
// Parameterized fixture for the element-wise math layers. For each
// math_test_params it builds a one-layer IR from model_t via getModel(), creates
// an MKLDNNGraphTestClass graph with the "cpu_extension" library registered,
// runs inference, and compares the output against ref_math(). All of that
// happens in SetUp(); the TEST_P body itself is empty.
150 class MKLDNNCPUExtMathTests: public TestsCommon, public WithParamInterface<math_test_params> {
151 std::string model_t = R"V0G0N(
152 <net Name="Math_net" version="2" precision="FP32" batch="1">
154 <layer name="Input" type="Input" precision="FP32" id="1">
161 <layer name="math" id="2" type="_MATH_FUNCTION_" precision="FP32">
162 <data _ALPHA_ _BETA_ _GAMMA_/>
176 <edge from-layer="1" from-port="1" to-layer="2" to-port="1"/>
// Instantiates model_t for one parameter set: substitutes the dims of p.in_out
// for _IN_OUT_, p.math_function for the layer type, and alpha/beta/gamma
// attributes built from element 0 of the corresponding vectors.
// alpha and gamma are guarded by a size check; beta presumably is too
// (its guard line is not visible here) -- confirm.
// NOTE(review): std::to_string(float) prints fixed 6-decimal text, so attribute
// values are truncated to that precision in the IR.
181 std::string getModel(math_test_params p) {
182 std::string model = model_t;
// One <dim> entry per dimension of p.in_out.
188 for (auto& dst : p.in_out) {
190 in_out += std::to_string(dst) + "</dim>\n";
193 REPLACE_WITH_STR(model, "_IN_OUT_", in_out);
194 REPLACE_WITH_STR(model, "_MATH_FUNCTION_", p.math_function);
196 if (p.alpha.size()) {
197 alpha = "alpha=\"" + std::to_string(p.alpha[0]) + "\"";
199 REPLACE_WITH_STR(model, "_ALPHA_", alpha);
202 beta = "beta=\"" + std::to_string(p.beta[0]) + "\"";
204 REPLACE_WITH_STR(model, "_BETA_", beta);
206 if (p.gamma.size()) {
207 gamma = "gamma=\"" + std::to_string(p.gamma[0]) + "\"";
209 REPLACE_WITH_STR(model, "_GAMMA_", gamma);
// Debug fill: data[i] = i masked to the low bits, so values cycle through
// 0 .. 8*sizeof(data_t)-1 (0..31 for float).
// NOTE(review): not called in the visible part of this file.
213 template <typename data_t>
214 static void fill_data_dbgval(data_t *data, size_t size) {
215 for (size_t i = 0; i < size; i++) {
216 data[i] = static_cast<data_t>(i & (sizeof(data_t) * 8 - 1));
220 virtual void TearDown() {
// Runs the whole check for the current parameter set:
// build IR -> create graph -> prepare input -> compute reference ->
// optionally validate the reference -> infer -> compare.
// InferenceEngineException is caught by the handler at the end of the method.
223 virtual void SetUp() {
225 TestsCommon::SetUp();
226 math_test_params p = ::testing::WithParamInterface<math_test_params>::GetParam();
227 std::string model = getModel(p);
229 InferenceEngine::CNNNetReader net_reader;
230 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Register the CPU extension library. cpuExt is a local object, so the
// IExtensionPtr is given a no-op deleter instead of taking ownership.
232 InferenceEngine::Extension cpuExt(make_so_name("cpu_extension"));
233 MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr(new MKLDNNPlugin::MKLDNNExtensionManager());
234 extMgr->AddExtension(InferenceEngine::IExtensionPtr(&cpuExt, [](InferenceEngine::IExtension*){}));
236 MKLDNNGraphTestClass graph;
237 graph.CreateGraph(net_reader.getNetwork(), extMgr);
// Input blob with dims p.in_out (FP32, layout derived from the dims).
240 InferenceEngine::Blob::Ptr srcData = InferenceEngine::make_shared_blob<float>({ InferenceEngine::Precision::FP32, p.in_out, InferenceEngine::TensorDesc::getLayoutByDims(p.in_out) });
// Use the explicit input values when given; otherwise generate data
// (sine-based for Erf, fill_data for everything else).
// NOTE(review): memcpy copies input_tensor.size() floats without checking that
// this does not exceed srcData->size(); a mismatched test row would overflow.
242 if (p.input_tensor.size())
243 memcpy(srcData->buffer(), &p.input_tensor[0], sizeof(float)*p.input_tensor.size());
245 if (p.math_function == "Erf")
246 fill_data_sine(srcData->buffer(), srcData->size(), 0.f, 3.f, 1.f);
248 fill_data(srcData->buffer(), srcData->size());
// ref_math() needs the concrete TBlob<float> interface.
250 auto * srcDataPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(srcData.get());
251 if (srcDataPtr == nullptr)
252 FAIL() << "Cannot cast blob to TBlob<float>.";
// Single-output network: bind an output blob to its first (only) output.
255 InferenceEngine::OutputsDataMap out;
256 out = net_reader.getNetwork().getOutputsInfo();
257 InferenceEngine::BlobMap outputBlobs;
259 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
261 InferenceEngine::TBlob<float>::Ptr output;
262 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
264 outputBlobs[item.first] = output;
// Reference output with the same TensorDesc as the plugin output
// (presumably allocated on a line not visible here -- confirm).
267 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
271 ref_math(p.math_function, *srcDataPtr, p.alpha, p.beta, p.gamma, dst_ref);
// When hand-computed values are supplied, first validate the reference itself.
272 if (p.reference.size()) {
273 for (size_t i = 0; i < p.reference.size(); i++) {
274 ASSERT_NEAR(dst_ref.data()[i], p.reference[i], 0.00001f);
278 InferenceEngine::BlobMap srcs;
279 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("Input", srcData));
282 graph.Infer(srcs, outputBlobs);
// Erf is compared with a looser tolerance (1e-4 vs 1e-5), presumably because
// the plugin evaluates it with an approximation -- confirm.
283 float threshold = p.math_function == "Erf" ? 0.0001f : 0.00001f;
284 compare(*output, dst_ref, threshold);
285 } catch (const InferenceEngine::details::InferenceEngineException &e) {
291 TEST_P(MKLDNNCPUExtMathTests, TestsMath) {}
293 INSTANTIATE_TEST_CASE_P(
294 TestsMath, MKLDNNCPUExtMathTests,
296 // Params: math_function, in_out, input_tensor, alpha, beta, gamma, reference
297 math_test_params{ "Erf", { 1, 1, 12, 256 }, {},{},{},{}, {} },
298 math_test_params{ "Erf", { 12, 256, 3 },{},{},{},{},{} },
299 math_test_params{ "Erf", { 3, 4 },{},{},{},{},{} },
300 math_test_params{ "Erf", { 20 },{},{},{},{},{} },
301 math_test_params{ "Erf", { 12, 4, 9, 8 },{},{},{},{},{} },
302 math_test_params{ "Erf", { 6, 12, 4, 9, 8, 10, 3 },{},{},{},{},{} },
303 math_test_params{ "Abs",{ 3 },{ -1, 0, 1 },{},{},{},{ 1, 0, 1 } },
304 math_test_params{ "Acos",{ 3 },{ -0.5f, 0.f, 0.5f },{},{},{},{ 2.09439516f, 1.57079637f, 1.04719758f } },
305 math_test_params{ "Acosh",{ 3 },{ 1.f, 2.0f, 3.0f },{},{},{},{} },
306 math_test_params{ "Asin",{ 3 },{ -0.5f, 0.f, 0.5f },{},{},{},{ -0.523598790f, 0.0f, 0.523598790f } },
307 math_test_params{ "Asinh",{ 3 },{ -0.5f, 0.f, 0.5f },{},{},{},{ } },
308 math_test_params{ "Atan",{ 3 },{ -1, 0, 1 },{},{},{},{ -0.785398185f, 0.0f, 0.785398185f } },
309 math_test_params{ "Atanh",{ 3 },{ -0.5f, 0.f, 0.5f },{},{},{},{ } },
310 math_test_params{ "Ceil",{ 2 },{ -1.5f, 1.2f },{},{},{},{ -1, 2 } },
311 math_test_params{ "Cos",{ 3 },{ -1, 0, 1 },{},{},{},{ 0.540302336f, 1.0f, 0.540302336f } },
312 math_test_params{ "Cosh",{ 3 },{ -0.5f, 0.f, 0.5f },{},{},{},{ } },
313 math_test_params{ "Floor",{ 3 },{-1.5f, 1.2f, 2.f},{},{},{},{-2, 1, 2} },
314 math_test_params{ "HardSigmoid",{ 3 },{ -1, 0, 1 },{0.5f},{0.6f},{},{ 0.1f, 0.6f, 1.f } },
315 math_test_params{ "Log",{ 2 },{ 1, 10 },{},{},{},{ 0.f, 2.30258512f } },
316 math_test_params{ "Neg",{ 3 },{ -1, 0, 1 },{},{},{},{ 1, 0, -1 } },
317 math_test_params{ "Reciprocal",{ 3 },{ -1, 0.1, 1 },{2},{},{3},{-1, 10, 1} },
318 math_test_params{ "Selu",{ 3 },{ -1, 0, 1 },{2},{},{3},{ -3.79272318f, 0.f, 3.f } },
319 math_test_params{ "Sign",{ 3 },{ -0.5f, 0.f, 0.5f },{},{},{},{-1, 0, 1} },
320 math_test_params{ "Sin",{ 3 },{ -1, 0, 1 },{},{},{},{ -0.841470957f, 0.0f, 0.841470957f } },
321 math_test_params{ "Sinh",{ 3 },{ -0.5f, 0.f, 0.5f },{},{},{},{ } },
322 math_test_params{ "Softplus",{ 3 },{ -1, 0, 1 },{},{},{},{ 0.31326166f, 0.69314718f, 1.31326163f } },
323 math_test_params{ "Softsign",{ 3 },{ -1, 0, 1 },{},{},{},{ -0.5f, 0.f, 0.5f } },
324 math_test_params{ "Tan",{ 3 },{ -1, 0, 1 },{},{},{},{ -1.55740774f, 0.0f, 1.55740774f } }