Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_relu_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
14
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace mkldnn;
19
20
21 struct relu_test_params {
22     // Formats: NCHW, NCDHW
23     vector<size_t> dims;
24
25     float n_clope;
26
27     size_t num_prim_desc;
28
29     MKLDNNPlugin::impl_desc_type selectedType;
30
31     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
32 };
33
34 template <typename data_t>
35 void ref_relu(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, relu_test_params prm)
36 {
37     auto dims_size = src.dims().size();
38     
39     size_t IW = src.dims()[dims_size - 1];
40     size_t IH = src.dims()[dims_size - 2];
41     size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
42     size_t IC = src.dims()[1];
43
44     const data_t *src_data = src.readOnly();
45     data_t *dst_data = dst.data();
46
47     for (uint32_t c = 0; c < IC; c++) {
48         for (uint32_t d = 0; d < ID; d++) {
49             for (uint32_t h = 0; h < IH; h++) {
50                 for (uint32_t w = 0; w < IW; w++) {
51                     uint32_t oidx = c * ID * IH * IW
52                                     + d * IH * IW
53                                     + h * IW
54                                     + w;
55
56                     dst_data[oidx] = src_data[oidx] >= 0.0 ?
57                                      src_data[oidx] :
58                                      src_data[oidx] * prm.n_clope;
59                 }
60             }
61         }
62     }
63 }
64
65 class MKLDNNGraphReluTests: public TestsCommon,
66                                      public WithParamInterface<relu_test_params> {
67     std::string model_t = R"V0G0N(
68 <Net Name="Relu_Only" version="3" precision="FP32" batch="1">
69     <layers>
70         <layer name="in1" type="Input" precision="FP32" id="0">
71             <output>
72                 <port id="0">
73                     <dim>_IN_</dim>
74                     <dim>_IC_</dim>
75                     <dim>_ID_</dim>
76                     <dim>_IH_</dim>
77                     <dim>_IW_</dim>
78                 </port>
79             </output>
80         </layer>
81         <layer name="norm" id="1" type="ReLU" precision="FP32">
82             <input>
83                 <port id="1">
84                     <dim>_IN_</dim>
85                     <dim>_IC_</dim>
86                     <dim>_ID_</dim>
87                     <dim>_IH_</dim>
88                     <dim>_IW_</dim>
89                 </port>
90             </input>
91             <output>
92                 <port id="2">
93                     <dim>_IN_</dim>
94                     <dim>_IC_</dim>
95                     <dim>_ID_</dim>
96                     <dim>_IH_</dim>
97                     <dim>_IW_</dim>
98                 </port>
99             </output>
100         </layer>
101     </layers>
102     <edges>
103         <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
104     </edges>
105 </Net>
106 )V0G0N";
107
108     std::string getModel(relu_test_params p) {
109         std::string model = model_t;
110         auto dims_size = p.dims.size();
111
112         switch (dims_size) {
113             case 3:
114                 REMOVE_LINE(model, "<dim>_IH_</dim>");
115             case 4:
116                 REMOVE_LINE(model, "<dim>_ID_</dim>");
117         }
118
119         REPLACE_WITH_NUM(model, "_IW_", p.dims[dims_size - 1]);
120         REPLACE_WITH_NUM(model, "_IC_", p.dims[1]);
121         REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
122         switch (dims_size) {
123             case 5:
124                 REPLACE_WITH_NUM(model, "_ID_", p.dims[dims_size - 3]);
125             case 4:
126                 REPLACE_WITH_NUM(model, "_IH_", p.dims[dims_size - 2]);
127         }
128
129         return model;
130     }
131
132 protected:
133     virtual void TearDown() {
134     }
135
136     virtual void SetUp() {
137         try {
138             TestsCommon::SetUp();
139             relu_test_params p = ::testing::WithParamInterface<relu_test_params>::GetParam();
140             std::string model = getModel(p);
141
142             InferenceEngine::CNNNetReader net_reader;
143             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
144
145             MKLDNNGraphTestClass graph;
146             graph.CreateGraph(net_reader.getNetwork());
147             auto& nodes = graph.getNodes();
148             for (int i = 0; i < nodes.size(); i++) {
149                 if (nodes[i]->getType() == MKLDNNPlugin::Activation) {
150                     ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
151                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
152                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
153                     }
154                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
155                     ASSERT_TRUE(nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() | p.selectedType);
156                 }
157             }
158
159             InferenceEngine::SizeVector dims_src = p.dims;
160             InferenceEngine::Layout layout = InferenceEngine::ANY;
161             switch (p.dims.size()) {
162                 case 4:
163                     layout = InferenceEngine::NCHW;
164                     break;
165                 case 5:
166                     layout = InferenceEngine::NCDHW;
167                     break;
168             }
169
170             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
171             src->allocate();
172             fill_data(src->buffer(), src->size());
173
174             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
175
176             if (srcPtr == nullptr)
177                 FAIL() << "Cannot cast blob to TBlob<float>.";
178
179             InferenceEngine::BlobMap srcs;
180             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
181
182             InferenceEngine::OutputsDataMap out;
183             out = net_reader.getNetwork().getOutputsInfo();
184             InferenceEngine::BlobMap outputBlobs;
185
186             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
187
188             InferenceEngine::TBlob<float>::Ptr output;
189             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
190             output->allocate();
191             outputBlobs[item.first] = output;
192
193             graph.Infer(srcs, outputBlobs);
194
195             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
196             dst_ref.allocate();
197
198             ref_relu(*srcPtr, dst_ref, p);
199
200             compare(*output, dst_ref, 0.0005f);
201         } catch (const InferenceEngine::details::InferenceEngineException &e) {
202             FAIL() << e.what();
203         }
204     }
205 };
206
207 TEST_P(MKLDNNGraphReluTests, TestsRelu) {}
208
209
210 INSTANTIATE_TEST_CASE_P(
211         TestsRelu, MKLDNNGraphReluTests,
212         ::testing::Values(
213                 relu_test_params{
214                         {1, 3, 228, 228}, 0.0f, 5, MKLDNNPlugin::impl_desc_type::jit, {
215                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
216                                     ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::jit);
217                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
218                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
219                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
220                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
221                                 },
222                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
223                                     ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::jit);
224                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
225                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
226                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
227                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
228                                 }
229                         }},
230                 relu_test_params{
231                         {1, 64, 32, 32, 32}, 0.0f, 3, MKLDNNPlugin::impl_desc_type::ref_any, {
232                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
233                                     ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::ref_any);
234                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
235                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
236                                     ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().inConfs.at(0).desc.getLayout());
237                                     ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());
238                                 },
239                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
240                                     ASSERT_TRUE(impl.getImplementationType() | MKLDNNPlugin::impl_desc_type::ref_any);
241                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
242                                     ASSERT_EQ(1, impl.getConfig().outConfs.size());
243                                     ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().inConfs.at(0).desc.getLayout());
244                                     ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());
245                                 }
246                         }}
247         ));