inference-engine/tests/unit/engines/mkldnn/graph/layers/internal/graph_fullyconnected_test.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include "mkldnn_plugin/mkldnn_graph.h"

#include "test_graph.hpp"

#include "single_layer_common.hpp"
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include <inference_engine/cnn_network_impl.hpp>
#include "tests_common.hpp"


using namespace ::testing;
using namespace std;
using namespace mkldnn;

struct fc_test_params {
    // Input dims; formats: NCHW, NCDHW
    vector<size_t> in_dims;

    // FullyConnected output size (the "out-size" attribute)
    size_t out_c;

    // Expected minimum number of supported primitive descriptors
    size_t num_prim_desc;

    // Expected implementation type of the selected primitive descriptor
    int selectedType;
    // Optional PrimitivesPriority list used to force an implementation
    std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;

    // Optional per-descriptor validation callbacks
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
};


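// Reference implementation of InnerProduct (FullyConnected): for every batch
// element n and output channel oc, dst[n][oc] = bias[oc] + sum over (ic, d, h, w)
// of src[n][ic][d][h][w] * weights[oc][ic][d][h][w]. The bias values are stored
// directly after the weights in the same buffer.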
template <typename data_t>
void ref_innerproduct(const InferenceEngine::TBlob<data_t> &src, const data_t *weights, const size_t weightsSize,
                      InferenceEngine::TBlob<data_t> &dst, fc_test_params prm) {
    auto dims_size = src.dims().size();

    size_t IB = src.dims()[0];
    size_t IC = src.dims()[1];
    size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
    size_t IH = src.dims()[dims_size - 2];
    size_t IW = src.dims()[dims_size - 1];

    size_t OC = prm.out_c;

    const data_t *src_data = src.readOnly();
    const data_t *weights_data = weights;
    const data_t *bias_data = weights_data + IW*IH*ID*IC*OC;
    data_t *dst_data = dst.data();

    IE_ASSERT( IW*IH*ID*IC*OC + OC == weightsSize );
    IE_ASSERT( OC == dst.dims()[0] );

    for (size_t n = 0; n < IB; n++) {
        for (size_t oc = 0; oc < OC; oc++) {
            dst_data[n*OC + oc] = bias_data[oc];
            for (size_t ic = 0; ic < IC; ic++) {
                for (size_t kd = 0; kd < ID; kd++) {
                    for (size_t kh = 0; kh < IH; kh++) {
                        for (size_t kw = 0; kw < IW; kw++) {
                            size_t iidx = n * IC * ID * IH * IW
                                        + ic * ID * IH * IW
                                        + kd * IH * IW
                                        + kh * IW
                                        + kw;
                            size_t widx = oc * IC * ID * IH * IW
                                          + ic * ID * IH * IW
                                          + kd * IH * IW
                                          + kh * IW
                                          + kw;

                            dst_data[n*OC + oc] += src_data[iidx] * weights_data[widx];
                        }
                    }
                }
            }
        }
    }
}

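// Single-layer test fixture: builds a one-layer IR (Input -> InnerProduct) from
// the model template below, runs it through the MKLDNN graph, and compares the
// result against ref_innerproduct().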
class MKLDNNGraphFullyConnectedTests: public TestsCommon,
                                      public WithParamInterface<fc_test_params> {
    std::string model_t = R"V0G0N(
<Net Name="FullyConnected_Only" version="3" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="0">
            <output>
                <port id="0">__SRC_DIMS__
                </port>
            </output>
        </layer>
        <layer name="FullyConnected" id="1" type="InnerProduct" precision="FP32">
            <fc out-size="_OC_" PrimitivesPriority="_IMPLS_"/>

            <weights offset="0" size="_S1_" />
            <biases offset="_S1_" size="_S2_" />

            <input>
                <port id="1">__SRC_DIMS__
                </port>
            </input>
            <output>
                <port id="2">
                    <dim>_IN_</dim>
                    <dim>_OC_</dim>
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
    </edges>
</Net>
)V0G0N";

protected:
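    // Fills the IR template: expands __SRC_DIMS__, the output size, the weight and
    // bias blob sizes (in bytes), and the PrimitivesPriority string built from the
    // requested preferTypes.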
    std::string getModel(fc_test_params p) {
        std::string model = model_t;
        std::string s_dims;
        for (auto& dim : p.in_dims) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS__", s_dims);

        REPLACE_WITH_NUM(model, "_IN_", p.in_dims[0]);
        REPLACE_WITH_NUM(model, "_OC_", p.out_c);

        size_t w_data_size = p.out_c * sizeof(float);
        for (int i = 1; i < p.in_dims.size(); i++)
            w_data_size *= p.in_dims[i];
        size_t b_data_size = p.out_c * sizeof(float);
        REPLACE_WITH_NUM(model, "_S1_", w_data_size);
        REPLACE_WITH_NUM(model, "_S2_", b_data_size);
        std::string impls;
        for (const auto& preferType : p.preferTypes) {
            if (!impls.empty())
                impls += ",";
            impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
        }
        REPLACE_WITH_STR(model, "_IMPLS_", impls);
        return model;
    }

    virtual void TearDown() {
    }

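    // SetUp acts as the test body: it parses the generated IR, attaches randomly
    // filled weights, validates the supported/selected primitive descriptors of the
    // FullyConnected node, runs inference, and compares the output against
    // ref_innerproduct().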
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            fc_test_params p = ::testing::WithParamInterface<fc_test_params>::GetParam();
            std::string model = getModel(p);

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            size_t weights_size = p.out_c;
            for (int i = 1; i < p.in_dims.size(); i++) {
                weights_size *= p.in_dims[i];
            }
            weights_size = (weights_size + p.out_c) * sizeof(float);
            InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {weights_size});
            weights->allocate();
            fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
            InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);

            net_reader.SetWeights(weights_ptr);

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());
            auto& nodes = graph.getNodes();
            for (int i = 0; i < nodes.size(); i++) {
                if (nodes[i]->getType() == MKLDNNPlugin::FullyConnected) {
                    ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
                    for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                        p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
                    }
                    ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
                    ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
                }
            }

            InferenceEngine::SizeVector dims_src = p.in_dims;
            InferenceEngine::Layout layout = InferenceEngine::ANY;
            switch (p.in_dims.size()) {
                case 4:
                    layout = InferenceEngine::NCHW;
                    break;
                case 5:
                    layout = InferenceEngine::NCDHW;
                    break;
            }

            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());

            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            graph.Infer(srcs, outputBlobs);

            InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
            dst_ref.allocate();

            ref_innerproduct(*srcPtr, (const float *)weights->buffer(), weights->size() / sizeof(float), dst_ref, p);

            compare(*output, dst_ref, 0.9f);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphFullyConnectedTests, TestsFullyConnected) {}


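// Each fc_test_params entry lists: input dims, out-size, the minimum number of
// primitive descriptors expected, the implementation type the plugin should
// select, and (optionally) a PrimitivesPriority list forcing that implementation.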
INSTANTIATE_TEST_CASE_P(
        TestsFullyConnected, MKLDNNGraphFullyConnectedTests,
        ::testing::Values(
                fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::gemm },
                fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::gemm },
                fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::gemm },
                fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
                fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
                fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
                // 5D cases
                fc_test_params{{1, 4, 32, 32, 32}, 10, 6, MKLDNNPlugin::impl_desc_type::gemm },
                fc_test_params{{1, 3, 32, 32, 32}, 96, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}));

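// Dynamic-batch variant: the network batch is reshaped up front, the graph is
// created with KEY_DYN_BATCH_ENABLED, and checkDynBatch() verifies that the
// FullyConnected node produces correct results for a smaller runtime batch.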
class MKLDNNGraphDynBatchFullyConnectedTests: public MKLDNNGraphFullyConnectedTests {
    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            fc_test_params p = ::testing::WithParamInterface<fc_test_params>::GetParam();
            std::string model = getModel(p);
            size_t MB = p.in_dims[0];
            if (MB < 2)
                MB = 2;

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            size_t weights_size = p.out_c;
            for (int i = 1; i < p.in_dims.size(); i++) {
                weights_size *= p.in_dims[i];
            }
            weights_size = (weights_size + p.out_c) * sizeof(float);
            InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {weights_size});
            weights->allocate();
            fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
            InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
            net_reader.SetWeights(weights_ptr);
            InferenceEngine::CNNNetwork network = net_reader.getNetwork();
            auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
            ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
            InferenceEngine::ResponseDesc resp;
            InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
            ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;

            MKLDNNGraphTestClass graph;
            graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
            graph.CreateGraph(net_reader.getNetwork());

            InferenceEngine::SizeVector dims_src = p.in_dims;
            InferenceEngine::Layout layout = InferenceEngine::ANY;
            switch (p.in_dims.size()) {
                case 4:
                    layout = InferenceEngine::NCHW;
                    break;
                case 5:
                    layout = InferenceEngine::NCDHW;
                    break;
            }

            InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
            src->allocate();
            fill_data(src->buffer(), src->size());

            InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());

            if (srcPtr == nullptr)
                FAIL() << "Cannot cast blob to TBlob<float>.";

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            auto checkFC = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
                return node->getType() == MKLDNNPlugin::FullyConnected;
            };

            graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkFC);
            graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkFC);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphDynBatchFullyConnectedTests, TestsDynBatchFullyConnected) {}

INSTANTIATE_TEST_CASE_P(
        TestsDynBatchFullyConnected, MKLDNNGraphDynBatchFullyConnectedTests,
        ::testing::Values(
                fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::gemm },
                fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::gemm },
                fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::gemm },
                fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
                fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
                fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}));