1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
17 using namespace ::testing;
19 using namespace mkldnn;
// Parameter set describing one fully-connected (InnerProduct) test case.
// NOTE(review): the tests below also read p.out_c, p.num_prim_desc and
// p.selectedType; those field declarations are not visible in this view.
21 struct fc_test_params {
22 // Formats: NCHW, NCDHW
23 vector<size_t> in_dims;  // input dimensions; index 0 is used as the batch size (MB) by the dyn-batch test
30 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;  // implementation types written into the layer's PrimitivesPriority attribute
32 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;  // comp[j] is invoked on the j-th supported primitive descriptor of the FC node
// Naive reference inner product used to validate the plugin's output.
// For every batch item n and output channel oc, dst[n*OC + oc] starts from
// bias[oc] and accumulates src[iidx] * weights[widx] over all (ic, kd, kh, kw).
//
// src          input blob; dims()[0] is read as batch, dims()[1] as channels,
//              the trailing dims as (depth,) height, width
// weights      weight values immediately followed by OC bias values
// weightsSize  total number of elements behind `weights` (weights + biases)
// dst          output blob, written in place
// prm          test parameters; only prm.out_c is read in the visible code
36 template <typename data_t>
37 void ref_innerproduct(const InferenceEngine::TBlob<data_t> &src, const data_t *weights, const size_t weightsSize,
38 InferenceEngine::TBlob<data_t> &dst, fc_test_params prm) {
39 auto dims_size = src.dims().size();
41 size_t IB = src.dims()[0];
42 size_t IC = src.dims()[1];
// Depth exists only for 5-D input; 4-D input is treated as depth 1.
43 size_t ID = dims_size == 5 ? src.dims()[dims_size - 3] : 1u;
44 size_t IH = src.dims()[dims_size - 2];
45 size_t IW = src.dims()[dims_size - 1];
47 size_t OC = prm.out_c;
49 const data_t *src_data = src.readOnly();
50 const data_t *weights_data = weights;
// Biases are stored directly after the IW*IH*ID*IC*OC weight elements.
51 const data_t *bias_data = weights_data + IW*IH*ID*IC*OC;
52 data_t *dst_data = dst.data();
// Sanity checks: the buffer must hold exactly weights + biases, and the
// output blob's first dimension must equal the number of output channels.
54 IE_ASSERT( IW*IH*ID*IC*OC + OC == weightsSize );
55 IE_ASSERT( OC == dst.dims()[0] );
57 for (size_t n = 0; n < IB; n++) {
58 for (size_t oc = 0; oc < OC; oc++) {
// Seed the accumulator with the bias before summing the products.
59 dst_data[n*OC + oc] = bias_data[oc];
60 for (size_t ic = 0; ic < IC; ic++) {
61 for (size_t kd = 0; kd < ID; kd++) {
62 for (size_t kh = 0; kh < IH; kh++) {
63 for (size_t kw = 0; kw < IW; kw++) {
64 size_t iidx = n * IC * ID * IH * IW
69 size_t widx = oc * IC * ID * IH * IW
75 dst_data[n*OC + oc] += src_data[iidx] * weights_data[widx];
// Value-parameterized fixture for a single-layer FullyConnected network.
// All checks run inside SetUp(): the IR template is instantiated from the test
// parameters, the graph is built, the FC node's primitive descriptors are
// verified, inference is run and the result is compared with ref_innerproduct.
84 class MKLDNNGraphFullyConnectedTests: public TestsCommon,
85 public WithParamInterface<fc_test_params> {
86 std::string model_t = R"V0G0N(
87 <Net Name="FullyConnected_Only" version="3" precision="FP32" batch="1">
89 <layer name="in1" type="Input" precision="FP32" id="0">
91 <port id="0">__SRC_DIMS__
95 <layer name="FullyConnected" id="1" type="InnerProduct" precision="FP32">
96 <fc out-size="_OC_" PrimitivesPriority="_IMPLS_"/>
98 <weights offset="0" size="_S1_" />
99 <biases offset="_S1_" size="_S2_" />
102 <port id="1">__SRC_DIMS__
114 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
120 std::string getModel(fc_test_params p) {
// Substitutes the placeholders in a copy of model_t with values taken from p.
121 std::string model = model_t;
// One <dim> element is emitted per input dimension.
123 for (auto& dim : p.in_dims) {
124 s_dims += "\n <dim>";
125 s_dims += std::to_string(dim) + "</dim>";
127 REPLACE_WITH_STR(model, "__SRC_DIMS__", s_dims);
129 REPLACE_WITH_NUM(model, "_IN_", p.in_dims[0]);
130 REPLACE_WITH_NUM(model, "_OC_", p.out_c);
// Weights size in bytes: out_c * product of all non-batch input dims * sizeof(float).
132 size_t w_data_size = p.out_c * sizeof(float);
// NOTE(review): `int i` is compared against an unsigned size() here.
133 for (int i = 1; i < p.in_dims.size(); i++)
134 w_data_size *= p.in_dims[i];
// Bias size in bytes: one float per output channel; biases start at offset _S1_.
135 size_t b_data_size = p.out_c * sizeof(float);
136 REPLACE_WITH_NUM(model, "_S1_", w_data_size);
137 REPLACE_WITH_NUM(model, "_S2_", b_data_size);
// Each preferred implementation type is appended as "cpu:<type>".
139 for (const auto& preferType : p.preferTypes) {
142 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
144 REPLACE_WITH_STR(model, "_IMPLS_", impls);
148 virtual void TearDown() {
151 virtual void SetUp() {
153 TestsCommon::SetUp();
154 fc_test_params p = ::testing::WithParamInterface<fc_test_params>::GetParam();
155 std::string model = getModel(p);
157 InferenceEngine::CNNNetReader net_reader;
158 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Element count of weights (out_c * product of non-batch dims) plus out_c
// biases, converted to bytes because the blob below is allocated as U8.
160 size_t weights_size = p.out_c;
161 for (int i = 1; i < p.in_dims.size(); i++) {
162 weights_size *= p.in_dims[i];
164 weights_size = (weights_size + p.out_c) * sizeof(float);
// NOTE(review): raw `new`; the pointer is wrapped into a TBlob<uint8_t>::Ptr
// two statements below — presumably a shared_ptr that takes ownership, confirm.
165 InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {weights_size});
// The byte buffer is filled as an array of floats.
167 fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
168 InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
170 net_reader.SetWeights(weights_ptr);
172 MKLDNNGraphTestClass graph;
173 graph.CreateGraph(net_reader.getNetwork());
// Inspect every FullyConnected node: it must expose at least num_prim_desc
// supported descriptors, pass the per-descriptor checks in p.comp, and have a
// selected descriptor whose implementation type contains the selectedType bits.
174 auto& nodes = graph.getNodes();
175 for (int i = 0; i < nodes.size(); i++) {
176 if (nodes[i]->getType() == MKLDNNPlugin::FullyConnected) {
177 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
178 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
179 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
181 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
182 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
// Input layout is chosen from the rank of in_dims; the case labels are not
// visible here — presumably 4 -> NCHW and 5 -> NCDHW, confirm.
186 InferenceEngine::SizeVector dims_src = p.in_dims;
187 InferenceEngine::Layout layout = InferenceEngine::ANY;
188 switch (p.in_dims.size()) {
190 layout = InferenceEngine::NCHW;
193 layout = InferenceEngine::NCDHW;
197 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
199 fill_data(src->buffer(), src->size());
// A typed view of the input blob is needed for the reference computation.
201 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
203 if (srcPtr == nullptr)
204 FAIL() << "Cannot cast blob to TBlob<float>.";
// The input blob is registered under the name of the IR's input layer, "in1".
206 InferenceEngine::BlobMap srcs;
207 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
209 InferenceEngine::OutputsDataMap out;
210 out = net_reader.getNetwork().getOutputsInfo();
211 InferenceEngine::BlobMap outputBlobs;
// Only the first entry of the outputs map is used.
213 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
215 InferenceEngine::TBlob<float>::Ptr output;
216 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
218 outputBlobs[item.first] = output;
220 graph.Infer(srcs, outputBlobs);
// Compute the expected result with the naive reference and compare.
222 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
225 ref_innerproduct(*srcPtr, (const float *)weights->buffer(), weights->size() / sizeof(float), dst_ref, p);
// NOTE(review): 0.9f is presumably the allowed difference — looks loose, confirm against compare().
227 compare(*output, dst_ref, 0.9f);
228 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: the fixture's SetUp() performs every check.
234 TEST_P(MKLDNNGraphFullyConnectedTests, TestsFullyConnected) {}
// Parameter sets cover 4-D and 5-D inputs with gemm and ref implementations.
// NOTE(review): initializer order is presumably
// {in_dims, out_c, num_prim_desc, selectedType, preferTypes} — confirm against
// the full fc_test_params declaration.
237 INSTANTIATE_TEST_CASE_P(
238 TestsFullyConnected, MKLDNNGraphFullyConnectedTests,
240 fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::gemm },
241 fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::gemm },
242 fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::gemm },
243 fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
244 fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
245 fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
247 fc_test_params{{1, 4, 32, 32, 32}, 10, 6, MKLDNNPlugin::impl_desc_type::gemm },
248 fc_test_params{{1, 3, 32, 32, 32}, 96, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}));
// Dynamic-batch variant of the FullyConnected fixture. It reuses getModel()
// from the base fixture, reshapes the network to batch MB, enables
// KEY_DYN_BATCH_ENABLED on the graph and runs graph.checkDynBatch() twice.
250 class MKLDNNGraphDynBatchFullyConnectedTests: public MKLDNNGraphFullyConnectedTests {
251 virtual void SetUp() {
// Calls TestsCommon::SetUp() directly, so the base fixture's SetUp() is not run.
253 TestsCommon::SetUp();
254 fc_test_params p = ::testing::WithParamInterface<fc_test_params>::GetParam();
255 std::string model = getModel(p);
// The first input dimension is used as the batch size.
256 size_t MB = p.in_dims[0];
260 InferenceEngine::CNNNetReader net_reader;
261 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Element count of weights plus out_c biases, converted to bytes for the U8 blob.
263 size_t weights_size = p.out_c;
// NOTE(review): `int i` is compared against an unsigned size() here.
264 for (int i = 1; i < p.in_dims.size(); i++) {
265 weights_size *= p.in_dims[i];
267 weights_size = (weights_size + p.out_c) * sizeof(float);
// NOTE(review): raw `new`; wrapped into a TBlob<uint8_t>::Ptr two statements
// below — presumably a shared_ptr that takes ownership, confirm.
268 InferenceEngine::TBlob<uint8_t> *weights = new InferenceEngine::TBlob<uint8_t>(InferenceEngine::Precision::U8, InferenceEngine::C, {weights_size});
270 fill_data((float *) weights->buffer(), weights->size() / sizeof(float));
271 InferenceEngine::TBlob<uint8_t>::Ptr weights_ptr = InferenceEngine::TBlob<uint8_t>::Ptr(weights);
272 net_reader.SetWeights(weights_ptr);
// setBatchSizeReshape() is reached by downcasting the network to CNNNetworkImpl.
273 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
274 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
275 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
276 InferenceEngine::ResponseDesc resp;
277 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
278 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
// Dynamic batching is switched on before the graph is created.
280 MKLDNNGraphTestClass graph;
281 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
282 graph.CreateGraph(net_reader.getNetwork());
// Input layout is chosen from the rank of in_dims; the case labels are not
// visible here — presumably 4 -> NCHW and 5 -> NCDHW, confirm.
284 InferenceEngine::SizeVector dims_src = p.in_dims;
285 InferenceEngine::Layout layout = InferenceEngine::ANY;
286 switch (p.in_dims.size()) {
288 layout = InferenceEngine::NCHW;
291 layout = InferenceEngine::NCDHW;
295 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
297 fill_data(src->buffer(), src->size());
// NOTE(review): srcPtr is only null-checked in the visible code of this fixture.
299 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
301 if (srcPtr == nullptr)
302 FAIL() << "Cannot cast blob to TBlob<float>.";
// The input blob is registered under the name of the IR's input layer, "in1".
304 InferenceEngine::BlobMap srcs;
305 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
307 InferenceEngine::OutputsDataMap out;
308 out = net_reader.getNetwork().getOutputsInfo();
309 InferenceEngine::BlobMap outputBlobs;
// Only the first entry of the outputs map is used.
311 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
313 InferenceEngine::TBlob<float>::Ptr output;
314 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
316 outputBlobs[item.first] = output;
// Predicate handed to checkDynBatch: true for FullyConnected nodes.
318 auto checkFC = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
319 return node->getType() == MKLDNNPlugin::FullyConnected;
// Run the check with batch MB and with batch 1.
// NOTE(review): the third/fourth arguments are presumably (batch, max batch) — confirm against checkDynBatch().
322 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkFC);
323 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkFC);
324 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// The test body is intentionally empty: the fixture's SetUp() performs every check.
330 TEST_P(MKLDNNGraphDynBatchFullyConnectedTests, TestsDynBatchFullyConnected) {}
// Only 4-D parameter sets are instantiated for the dynamic-batch fixture.
// NOTE(review): initializer order is presumably
// {in_dims, out_c, num_prim_desc, selectedType, preferTypes} — confirm against
// the full fc_test_params declaration.
332 INSTANTIATE_TEST_CASE_P(
333 TestsDynBatchFullyConnected, MKLDNNGraphDynBatchFullyConnectedTests,
335 fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::gemm },
336 fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::gemm },
337 fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::gemm },
338 fc_test_params{{1, 3, 227, 227}, 96, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
339 fc_test_params{{1, 4, 227, 227}, 8, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
340 fc_test_params{{1, 4, 227, 227}, 10, 6, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}));