1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
17 using namespace ::testing;
19 using namespace mkldnn;
// Parameter bundle for the value-parameterized softmax graph tests below.
// NOTE(review): only part of the struct is visible in this view; fields the
// tests reference ("in" dims, "axis", "num_prim_desc", "selectedType") are
// declared on lines not shown here -- confirm against the full file.
22 struct softmax_test_params {
// Implementation types injected into the layer's PrimitivesPriority attribute
// (see getModel(), which renders each as "cpu:<type>").
35 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
// Optional per-descriptor validators; SetUp() applies comp[j] to the j-th
// supported primitive descriptor of the SoftMax node.
37 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Verifies a softmax forward result: along the axis selected by prm.axis
// (0=batch, 1=channel, 2=height, 3=width) the values of `src` must sum to
// ~1.0 at every position of the remaining three dimensions, since softmax
// produces a probability distribution along the normalized axis.
// @param src  output blob produced by the graph (read-only access only).
// @param prm  test parameters; prm.axis selects the normalization axis.
// NOTE(review): MB/C/H/W used below are presumably derived from prm.in /
// the src dims on lines not visible in this view -- confirm in the full file.
40 template <typename data_t>
41 void check_softmax_fwd(const InferenceEngine::TBlob<data_t> &src, softmax_test_params prm)
43 const data_t *src_data = src.readOnly();
// Linear offset into an NCHW-shaped buffer: ((n*C + c)*H + h)*W + w.
50 auto off = [=](int n, int c, int h, int w)
52 return (n * W * H * C + c * W * H + h * W + w);
// Asserts the per-position sum is ~1. NOTE(review): the outer guard uses a
// tighter tolerance (0.999f..1.001) than the inner assertion (0.99f..1.01)
// and mixes float/double literals -- looks unintentional; confirm intent.
55 auto check_norm = [=](double res) {
56 if(res < 0.999f || res > 1.001) {
57 ASSERT_TRUE(res > 0.99f && res < 1.01);
// Presumably the prm.axis == 0 branch (its `if` line is outside this view):
// accumulate over the batch dimension N for each fixed (c, h, w).
62 for (int c = 0; c < C; ++c) {
63 for (int h = 0; h < H; ++h) {
64 for (int w = 0; w < W; ++w) {
67 for (int n = 0; n < MB; ++n) {
68 result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
// axis == 1: accumulate over channels C for each fixed (n, h, w).
75 else if(prm.axis == 1) {
76 for (int n = 0; n < MB; ++n) {
77 for (int h = 0; h < H; ++h) {
78 for (int w = 0; w < W; ++w) {
81 for (int c = 0; c < C; ++c) {
82 result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
// axis == 2: accumulate over height H for each fixed (n, c, w).
90 else if(prm.axis == 2) {
91 for (int n = 0; n < MB; ++n) {
92 for (int c = 0; c < C; ++c) {
93 for (int w = 0; w < W; ++w) {
96 for (int h = 0; h < H; ++h) {
97 result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
// axis == 3: accumulate over width W for each fixed (n, c, h).
105 else if(prm.axis == 3) {
106 for (int n = 0; n < MB; ++n) {
107 for (int c = 0; c < C; ++c) {
108 for (int h = 0; h < H; ++h) {
// NOTE(review): double initialized from a float literal (0.0f) -- harmless
// but inconsistent; the other branches' initializers are not visible here.
109 double result = 0.0f;
111 for (int w = 0; w < W; ++w) {
112 result += src_data[off(n, c, h, w)];//dst_ptr[map_index(dst_pd, off(n, c, h, w))];
// Value-parameterized fixture: builds a tiny Input->Softmax IR from the XML
// template below, runs it through the MKLDNN graph, validates the chosen
// primitive implementation, then checks the output with check_softmax_fwd().
122 class MKLDNNGraphSoftMaxTests: public TestsCommon,
123 public WithParamInterface<softmax_test_params> {
// IR v2 model template; _IW_/_IH_/_IC_/_IN_, _AX_ and _IMPLS_ placeholders
// are substituted in getModel().
124 std::string model_t = R"V0G0N(
125 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
127 <layer name="in1" type="Input" precision="FP32" id="0">
137 <layer name="norm" id="1" type="Softmax" precision="FP32">
138 <data PrimitivesPriority="_IMPLS_" axis="_AX_"/>
158 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Renders model_t with the test parameters: input dims, softmax axis, and a
// "cpu:<impl>" priority list built from p.preferTypes.
164 std::string getModel(softmax_test_params p) {
165 std::string model = model_t;
167 REPLACE_WITH_NUM(model, "_IW_", p.in.w);
168 REPLACE_WITH_NUM(model, "_IH_", p.in.h);
169 REPLACE_WITH_NUM(model, "_IC_", p.in.c);
170 REPLACE_WITH_NUM(model, "_IN_", p.in.n);
171 REPLACE_WITH_NUM(model, "_AX_", p.axis);
173 for (const auto& preferType : p.preferTypes) {
176 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
178 REPLACE_WITH_STR(model, "_IMPLS_", impls);
183 virtual void TearDown() {
// The whole test body lives in SetUp(); the TEST_P body below is empty.
186 virtual void SetUp() {
188 TestsCommon::SetUp();
189 softmax_test_params p = ::testing::WithParamInterface<softmax_test_params>::GetParam();
190 std::string model = getModel(p);
// Parse the generated IR and build the MKLDNN test graph from it.
192 InferenceEngine::CNNNetReader net_reader;
193 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
195 MKLDNNGraphTestClass graph;
196 graph.CreateGraph(net_reader.getNetwork());
// Validate the SoftMax node: enough supported descriptors, any per-descriptor
// checks from p.comp pass, and the selected impl type matches p.selectedType
// (bitmask compare -- selectedType may be a flag subset of the actual type).
197 auto& nodes = graph.getNodes();
198 for (int i = 0; i < nodes.size(); i++) {
199 if (nodes[i]->getType() == MKLDNNPlugin::SoftMax) {
200 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
201 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
202 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
204 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
205 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
// Build an NCHW FP32 input blob, fill it with test data.
209 InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
211 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
213 fill_data(src->buffer(), src->size());
215 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
217 if (srcPtr == nullptr)
218 FAIL() << "Cannot cast blob to TBlob<float>.";
220 InferenceEngine::BlobMap srcs;
221 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Allocate an output blob matching the network's (single) output, infer,
// then verify the softmax normalization property on the result.
223 InferenceEngine::OutputsDataMap out;
224 out = net_reader.getNetwork().getOutputsInfo();
225 InferenceEngine::BlobMap outputBlobs;
227 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
229 InferenceEngine::TBlob<float>::Ptr output;
230 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
232 outputBlobs[item.first] = output;
234 graph.Infer(srcs, outputBlobs);
236 check_softmax_fwd(*output, p);
237 } catch (const InferenceEngine::details::InferenceEngineException &e) {
243 TEST_P(MKLDNNGraphSoftMaxTests, TestsSoftMax) {}
// Instantiates the softmax suite over {dims, axis, num_prim_desc,
// selectedType[, preferTypes]} combinations: jit-implementation cases and
// matching ref_any-forced cases. Commented-out lines are jit cases that are
// presumably disabled for axes 2/3 -- confirm why in the project history.
246 INSTANTIATE_TEST_CASE_P(
247 TestsSoftMax, MKLDNNGraphSoftMaxTests,
249 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
250 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
251 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
252 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
253 softmax_test_params{{1, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
254 softmax_test_params{{8, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
255 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
256 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
257 // softmax_test_params{{8, 100, 81, 1}, 2, 2, MKLDNNPlugin::impl_desc_type::jit},
258 softmax_test_params{{8, 100, 81, 1}, 2, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
259 softmax_test_params{{1, 1, 1, 1}, 3, 1, MKLDNNPlugin::impl_desc_type::ref},
260 // softmax_test_params{{1, 1, 1, 33}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
261 softmax_test_params{{1, 1, 1, 33}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
262 // softmax_test_params{{8, 1, 10, 81}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
263 softmax_test_params{{8, 1, 10, 81}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
// Dynamic-batch variant: same model as the base fixture, but reshapes the
// network batch to MB, enables KEY_DYN_BATCH_ENABLED, and verifies inference
// at both the full and a reduced batch via graph.checkDynBatch().
266 class MKLDNNGraphDynBatchSoftMaxTests: public MKLDNNGraphSoftMaxTests {
268 virtual void SetUp() {
270 TestsCommon::SetUp();
271 softmax_test_params p = ::testing::WithParamInterface<softmax_test_params>::GetParam();
272 std::string model = getModel(p);
// NOTE(review): MB used below is set on a line outside this view (presumably
// derived from p.in.n) -- confirm against the full file.
277 InferenceEngine::CNNNetReader net_reader;
278 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Reshape the parsed network's batch size to MB through the impl interface.
279 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
280 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
281 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
282 InferenceEngine::ResponseDesc resp;
283 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
284 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
// Build the graph with dynamic batch support enabled.
286 MKLDNNGraphTestClass graph;
287 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
288 graph.CreateGraph(net_reader.getNetwork());
// Input blob uses the reshaped batch MB in place of p.in.n.
290 InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
292 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
294 fill_data(src->buffer(), src->size());
296 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
298 if (srcPtr == nullptr)
299 FAIL() << "Cannot cast blob to TBlob<float>.";
301 InferenceEngine::BlobMap srcs;
302 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
304 InferenceEngine::OutputsDataMap out;
305 out = net_reader.getNetwork().getOutputsInfo();
306 InferenceEngine::BlobMap outputBlobs;
308 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
310 InferenceEngine::TBlob<float>::Ptr output;
311 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
313 outputBlobs[item.first] = output;
// Predicate selecting the SoftMax node(s) whose dyn-batch behavior to check.
315 auto checkSoftmax = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
316 return node->getType() == MKLDNNPlugin::SoftMax;
// Run dyn-batch checks at the full batch MB and at batch 1.
319 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkSoftmax);
320 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkSoftmax);
321 } catch (const InferenceEngine::details::InferenceEngineException &e) {
327 TEST_P(MKLDNNGraphDynBatchSoftMaxTests, TestsDynBatchSoftMax) {}
// Dynamic-batch instantiation: mirrors the parameter set of the static
// softmax suite above (same jit/ref pairs, same disabled jit cases).
330 INSTANTIATE_TEST_CASE_P(
331 TestsDynBatchSoftMax, MKLDNNGraphDynBatchSoftMaxTests,
333 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
334 softmax_test_params{{1, 3, 228, 228}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
335 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
336 softmax_test_params{{1, 100, 6, 1}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
337 softmax_test_params{{1, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
338 softmax_test_params{{8, 1000, 1, 1}, 1, 1, MKLDNNPlugin::impl_desc_type::ref},
339 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::jit},
340 softmax_test_params{{1, 19, 128, 128}, 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
341 // softmax_test_params{{8, 100, 81, 1}, 2, 2, MKLDNNPlugin::impl_desc_type::jit},
342 softmax_test_params{{8, 100, 81, 1}, 2, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
343 softmax_test_params{{1, 1, 1, 1}, 3, 1, MKLDNNPlugin::impl_desc_type::ref},
344 // softmax_test_params{{1, 1, 1, 33}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
345 softmax_test_params{{1, 1, 1, 33}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}},
346 // softmax_test_params{{8, 1, 10, 81}, 3, 2, MKLDNNPlugin::impl_desc_type::jit},
347 softmax_test_params{{8, 1, 10, 81}, 3, 1, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}