1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
16 using namespace ::testing;
18 using namespace mkldnn;
21 struct relu_test_params {
// Parameter bundle for one ReLU graph test case.
// NOTE(review): this excerpt has gaps — the fields referenced elsewhere in the
// file (p.dims input shape, p.n_clope negative slope, p.num_prim_desc expected
// primitive-descriptor count) are declared on lines not visible here.
22 // Formats: NCHW, NCDHW
// Implementation type the plugin is expected to select (e.g. jit, ref_any).
29 MKLDNNPlugin::impl_desc_type selectedType;
// Optional per-primitive-descriptor checkers; applied in order during SetUp().
31 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Reference (leaky-)ReLU computed on the CPU, used as the expected result
// against which the MKLDNN graph output is compared.
// NOTE(review): listing has gaps — the opening brace, loop-closing braces and
// the true-branch of the ternary (presumably `src_data[oidx] :`) are on lines
// not visible in this excerpt.
34 template <typename data_t>
35 void ref_relu(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, relu_test_params prm)
// Dims are ordered [..., (D,) H, W]; 4D input -> NCHW, 5D input -> NCDHW.
37 auto dims_size = src.dims().size();
39 size_t IW = src.dims()[dims_size - 1];
40 size_t IH = src.dims()[dims_size - 2];
// Depth dimension only exists for 5D (NCDHW) inputs; collapse to 1 otherwise.
41 size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
42 size_t IC = src.dims()[1];
44 const data_t *src_data = src.readOnly();
45 data_t *dst_data = dst.data();
// Walk every element of (a single batch of) the tensor in C-D-H-W order.
// NOTE(review): batch dimension (dims()[0]) is not iterated here — presumably
// the tests always use batch == 1; confirm against the parameter table.
47 for (uint32_t c = 0; c < IC; c++) {
48 for (uint32_t d = 0; d < ID; d++) {
49 for (uint32_t h = 0; h < IH; h++) {
50 for (uint32_t w = 0; w < IW; w++) {
// Flat index into the contiguous buffer for element (c, d, h, w).
51 uint32_t oidx = c * ID * IH * IW
// Leaky ReLU: pass non-negative values through, scale negatives by n_clope
// (n_clope == 0 gives plain ReLU).
56 dst_data[oidx] = src_data[oidx] >= 0.0 ?
58 src_data[oidx] * prm.n_clope;
// Parameterized fixture: builds a one-layer (Input -> ReLU) IR model, runs it
// through the MKLDNN graph, and compares the output against ref_relu().
// NOTE(review): listing has gaps — much of the IR template, switch cases and
// closing braces are on lines not visible in this excerpt.
65 class MKLDNNGraphReluTests: public TestsCommon,
66 public WithParamInterface<relu_test_params> {
// IR (xml) template; _IN_/_IC_/_ID_/_IH_/_IW_ placeholders are substituted
// per test case in getModel().
67 std::string model_t = R"V0G0N(
68 <Net Name="Relu_Only" version="3" precision="FP32" batch="1">
70 <layer name="in1" type="Input" precision="FP32" id="0">
81 <layer name="norm" id="1" type="ReLU" precision="FP32">
103 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Instantiates the IR template for the given test parameters: strips the
// depth/height dims when the input is not 5D, then substitutes the sizes.
108 std::string getModel(relu_test_params p) {
109 std::string model = model_t;
110 auto dims_size = p.dims.size();
// For 4D input the _ID_ (and for 3D also _IH_) placeholder lines are removed
// rather than substituted.
114 REMOVE_LINE(model, "<dim>_IH_</dim>");
116 REMOVE_LINE(model, "<dim>_ID_</dim>");
119 REPLACE_WITH_NUM(model, "_IW_", p.dims[dims_size - 1]);
120 REPLACE_WITH_NUM(model, "_IC_", p.dims[1]);
121 REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
124 REPLACE_WITH_NUM(model, "_ID_", p.dims[dims_size - 3]);
126 REPLACE_WITH_NUM(model, "_IH_", p.dims[dims_size - 2]);
133 virtual void TearDown() {
// Builds the graph, validates the Activation node's primitive descriptors,
// then infers and compares against the reference implementation.
136 virtual void SetUp() {
138 TestsCommon::SetUp();
139 relu_test_params p = ::testing::WithParamInterface<relu_test_params>::GetParam();
140 std::string model = getModel(p);
142 InferenceEngine::CNNNetReader net_reader;
143 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
145 MKLDNNGraphTestClass graph;
146 graph.CreateGraph(net_reader.getNetwork());
// Locate the ReLU (Activation) node and check its supported descriptors.
147 auto& nodes = graph.getNodes();
148 for (int i = 0; i < nodes.size(); i++) {
149 if (nodes[i]->getType() == MKLDNNPlugin::Activation) {
150 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
// Run each user-supplied checker against the matching descriptor.
151 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
152 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
154 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
// NOTE(review): bitwise OR here makes the assertion vacuously true whenever
// either operand is nonzero — this almost certainly should be `&` to test
// that the selected implementation type contains the expected bits. Left
// unchanged because this excerpt is incomplete; fix alongside the same
// pattern in the INSTANTIATE_TEST_CASE_P lambdas below.
155 ASSERT_TRUE(nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() | p.selectedType);
// Prepare the input blob in the layout matching the requested rank.
159 InferenceEngine::SizeVector dims_src = p.dims;
160 InferenceEngine::Layout layout = InferenceEngine::ANY;
161 switch (p.dims.size()) {
163 layout = InferenceEngine::NCHW;
166 layout = InferenceEngine::NCDHW;
170 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
172 fill_data(src->buffer(), src->size());
174 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
176 if (srcPtr == nullptr)
177 FAIL() << "Cannot cast blob to TBlob<float>.";
179 InferenceEngine::BlobMap srcs;
180 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Allocate the output blob from the network's (single) output descriptor.
182 InferenceEngine::OutputsDataMap out;
183 out = net_reader.getNetwork().getOutputsInfo();
184 InferenceEngine::BlobMap outputBlobs;
186 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
188 InferenceEngine::TBlob<float>::Ptr output;
189 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
191 outputBlobs[item.first] = output;
193 graph.Infer(srcs, outputBlobs);
// Compare graph output to the CPU reference within a small tolerance.
195 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
198 ref_relu(*srcPtr, dst_ref, p);
200 compare(*output, dst_ref, 0.0005f);
201 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// All checks run in SetUp(); the test body itself is intentionally empty.
207 TEST_P(MKLDNNGraphReluTests, TestsRelu) {}
// Test-case table: a 4D (NCHW) case expecting a jit implementation and a 5D
// (NCDHW) case expecting a reference implementation. Each case supplies
// checker lambdas run against the node's supported primitive descriptors.
// NOTE(review): the ASSERT_TRUE(... | ...) pattern in these lambdas is
// vacuously true for any nonzero implementation type — almost certainly
// intended to be `&`. Left unchanged because this excerpt is incomplete;
// fix together with the identical pattern in SetUp().
210 INSTANTIATE_TEST_CASE_P(
211 TestsRelu, MKLDNNGraphReluTests,
// Case 1: 1x3x228x228, slope 0.0 (plain ReLU), expect >= 5 descriptors, jit.
214 {1, 3, 228, 228}, 0.0f, 5, MKLDNNPlugin::impl_desc_type::jit, {
215 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
216 ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::jit);
217 ASSERT_EQ(1, impl.getConfig().inConfs.size());
218 ASSERT_EQ(1, impl.getConfig().outConfs.size());
219 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
220 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
222 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
223 ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::jit);
224 ASSERT_EQ(1, impl.getConfig().inConfs.size());
225 ASSERT_EQ(1, impl.getConfig().outConfs.size());
226 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
227 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
// Case 2: 1x64x32x32x32 (5D), slope 0.0, expect >= 3 descriptors, ref_any.
231 {1, 64, 32, 32, 32}, 0.0f, 3, MKLDNNPlugin::impl_desc_type::ref_any, {
232 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
233 ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::ref_any);
234 ASSERT_EQ(1, impl.getConfig().inConfs.size());
235 ASSERT_EQ(1, impl.getConfig().outConfs.size());
236 ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().inConfs.at(0).desc.getLayout());
237 ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());
239 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
240 ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::ref_any);
241 ASSERT_EQ(1, impl.getConfig().inConfs.size());
242 ASSERT_EQ(1, impl.getConfig().outConfs.size());
243 ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().inConfs.at(0).desc.getLayout());
244 ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());