1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
11 #include "test_graph.hpp"
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
19 using namespace ::testing;
21 using namespace mkldnn;
// Parameter bundle for the parameterized LRN graph tests below.
// NOTE(review): most fields are elided from this view (orig lines 25-40).
// Judging by the aggregate initializers in INSTANTIATE_TEST_CASE_P —
// {{n, c, h, w}, local_size, alpha, beta, k, num_prim_desc, selectedType, ...}
// — the struct presumably holds input dims plus the LRN hyper-parameters;
// confirm against the full file.
24 struct lrn_test_params {
// Optional per-primitive-descriptor checks: comp.at(j) is invoked with the
// j-th supported primitive descriptor of the LRN node during SetUp.
41 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
44 template <typename data_t>
// Scalar reference implementation of across-channel LRN, used as the golden
// result to compare against the MKLDNN graph output:
//   dst[c,h,w] = src[c,h,w] * (1 + alpha * sum / local_size)^(-beta)
// where sum accumulates the squared values of src over the clamped channel
// window [c - local_size/2, c - local_size/2 + local_size).
// NOTE(review): several lines are elided in this view — the IC/IH/IW setup,
// the batch loop (if any), the tail of the oidx expression, and the `sum`
// accumulation inside the inner loop (orig lines ~46-50, 58-59, 65, 69-72).
// Behaviour above is inferred from the visible code; confirm in the full file.
45 void ref_lrn(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, lrn_test_params prm)
51 const data_t *src_data = src.readOnly();
52 data_t *dst_data = dst.data();
54 for (uint32_t c = 0; c < IC; c++) {
55 for (uint32_t h = 0; h < IH; h++) {
56 for (uint32_t w = 0; w < IW; w++) {
// Linear offset of the output element; expression continues on an elided line.
57 uint32_t oidx = c * IH * IW
// Channel window [c_start, c_end) centered on c, clamped to [0, IC).
60 uint32_t sz = prm.local_size;
61 int32_t c_start = c - sz / 2;
62 int32_t c_end = c_start + sz;
63 if (c_start < 0) c_start = 0;
64 if (c_end > (int32_t)IC) c_end = IC;
66 for (int32_t c1 = c_start; c1 < c_end; c1++) {
67 uint32_t idx = c1 * IH * IW + h * IW + w;
68 data_t s = src_data[idx];
// `sum` is accumulated (presumably sum += s * s) on elided lines ~69-72.
// NOTE(review): `1.` is a double literal inside powf (float) — harmless for
// these tolerances but a silent float<->double promotion; flag, don't fix here.
73 data_t norm_coef = powf(1. + prm.alpha * sum / sz, -prm.beta);
74 dst_data[oidx] = norm_coef * src_data[oidx];
// Parameterized test fixture: builds a minimal Input -> LRN network from an
// XML template, creates the MKLDNN graph, validates the LRN node's supported
// primitive descriptors, runs inference, and compares against ref_lrn.
// NOTE(review): many lines of the fixture are elided in this view (XML body,
// closing braces, try-block opening, FAIL handler) — comments below describe
// only what the visible lines demonstrate.
80 class MKLDNNGraphLrnTests: public TestsCommon,
81 public WithParamInterface<lrn_test_params> {
// IR (v2) template with _IW_/_IH_/_IC_/_IN_/_LS_/_A_/_B_/_K_ placeholders
// substituted by getModel() below.
82 std::string model_t = R"V0G0N(
83 <Net Name="Lrn_Only" version="2" precision="FP32" batch="1">
85 <layer name="in1" type="Input" precision="FP32" id="0">
95 <layer name="norm" id="1" type="LRN" precision="FP32">
96 <lrn local_size="_LS_" alpha="_A_" beta="_B_" k="_K_" region="ACROSS" />
117 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
// Materializes the IR by substituting the test parameters into the template.
123 std::string getModel(lrn_test_params p) {
124 std::string model = model_t;
126 REPLACE_WITH_NUM(model, "_IW_", p.in.w);
127 REPLACE_WITH_NUM(model, "_IH_", p.in.h);
128 REPLACE_WITH_NUM(model, "_IC_", p.in.c);
129 REPLACE_WITH_NUM(model, "_IN_", p.in.n);
131 REPLACE_WITH_NUM(model, "_LS_", p.local_size);
132 REPLACE_WITH_NUM(model, "_A_", p.alpha);
133 REPLACE_WITH_NUM(model, "_B_", p.beta);
134 REPLACE_WITH_NUM(model, "_K_", p.k);
139 virtual void TearDown() {
// All test logic lives in SetUp (a pattern used throughout these legacy graph
// tests); TEST_P bodies below are intentionally empty.
142 virtual void SetUp() {
144 TestsCommon::SetUp();
145 lrn_test_params p = ::testing::WithParamInterface<lrn_test_params>::GetParam();
146 std::string model = getModel(p);
148 InferenceEngine::CNNNetReader net_reader;
149 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
151 MKLDNNGraphTestClass graph;
152 graph.CreateGraph(net_reader.getNetwork());
// Validate the LRN node: enough primitive descriptors, per-descriptor checks
// from p.comp, a descriptor was selected, and its impl type matches.
// NOTE(review): `int i` vs size_t nodes.size() is a signed/unsigned mix —
// benign here, but worth flagging.
153 auto& nodes = graph.getNodes();
154 for (int i = 0; i < nodes.size(); i++) {
155 if (nodes[i]->getType() == MKLDNNPlugin::Lrn) {
156 ASSERT_LE(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
157 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
158 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
160 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
// Bitmask check: selectedType may be a mask; the selected impl must contain it.
161 ASSERT_EQ(p.selectedType,
162 nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
// Expected topology: input node + LRN node + output node.
165 ASSERT_EQ(3, nodes.size());
// Build an NCHW FP32 input blob with deterministic fill_data contents.
167 InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
169 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
171 fill_data(src->buffer(), src->size());
173 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
175 if (srcPtr == nullptr)
176 FAIL() << "Cannot cast blob to TBlob<float>.";
178 InferenceEngine::BlobMap srcs;
179 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
// Allocate an output blob matching the network's (single) output descriptor.
181 InferenceEngine::OutputsDataMap out;
182 out = net_reader.getNetwork().getOutputsInfo();
183 InferenceEngine::BlobMap outputBlobs;
185 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
187 InferenceEngine::TBlob<float>::Ptr output;
188 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
190 outputBlobs[item.first] = output;
// Run the graph and compare against the scalar reference implementation.
192 graph.Infer(srcs, outputBlobs);
194 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
197 ref_lrn(*srcPtr, dst_ref, p);
199 compare(*output, dst_ref);
200 } catch (const InferenceEngine::details::InferenceEngineException &e) {
206 TEST_P(MKLDNNGraphLrnTests, TestsLrn) {}
// Test cases: a ref_any case (with per-descriptor layout checks expecting one
// planar NCHW descriptor followed by two BLOCKED ones) and a jit case for 16
// channels. NOTE(review): the opening of the first lrn_test_params literal
// (orig lines 210-212) is elided from this view.
208 INSTANTIATE_TEST_CASE_P(
209 TestsLrn, MKLDNNGraphLrnTests,
213 5, 0.0001f, 0.75f, 1, 3, MKLDNNPlugin::impl_desc_type::ref_any, {
// Descriptor 0: reference impl, single in/out config, planar NCHW layout.
214 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
215 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref_any, impl.getImplementationType());
216 ASSERT_EQ(1, impl.getConfig().inConfs.size());
217 ASSERT_EQ(1, impl.getConfig().outConfs.size());
218 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
219 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
// Descriptor 1: reference impl over a blocked (e.g. nChw8c/nChw16c) layout.
221 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
222 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref_any, impl.getImplementationType());
223 ASSERT_EQ(1, impl.getConfig().inConfs.size());
224 ASSERT_EQ(1, impl.getConfig().outConfs.size());
225 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
226 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
// Descriptor 2: same expectations as descriptor 1 (second blocked variant).
228 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
229 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref_any, impl.getImplementationType());
230 ASSERT_EQ(1, impl.getConfig().inConfs.size());
231 ASSERT_EQ(1, impl.getConfig().outConfs.size());
232 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
233 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
// 16-channel input: expects a jit implementation to be selected.
236 lrn_test_params{{1, 16, 228, 228}, 5, 0.0001f, 0.75f, 1, 3, MKLDNNPlugin::impl_desc_type::jit}));
// Dynamic-batch variant: same Input -> LRN network, but the network batch is
// reshaped to MB, KEY_DYN_BATCH_ENABLED is set on the graph, and inference is
// verified via checkDynBatch at both the full batch (MB) and batch 1.
// NOTE(review): the declaration of MB (orig lines ~245-248) is elided from
// this view — presumably derived from p.in.n; confirm in the full file.
238 class MKLDNNGraphDynBatchLrnTests: public MKLDNNGraphLrnTests {
240 virtual void SetUp() {
242 TestsCommon::SetUp();
243 lrn_test_params p = ::testing::WithParamInterface<lrn_test_params>::GetParam();
244 std::string model = getModel(p);
249 InferenceEngine::CNNNetReader net_reader;
250 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
// Reshape the parsed network's batch to MB via the (internal) CNNNetworkImpl.
251 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
252 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
253 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
254 InferenceEngine::ResponseDesc resp;
255 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
256 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
// Build the graph with dynamic batch support enabled.
258 MKLDNNGraphTestClass graph;
259 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
260 graph.CreateGraph(net_reader.getNetwork());
// Input blob sized for the full batch MB.
262 InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
264 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
266 fill_data(src->buffer(), src->size());
268 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
270 if (srcPtr == nullptr)
271 FAIL() << "Cannot cast blob to TBlob<float>.";
273 InferenceEngine::BlobMap srcs;
274 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
276 InferenceEngine::OutputsDataMap out;
277 out = net_reader.getNetwork().getOutputsInfo();
278 InferenceEngine::BlobMap outputBlobs;
280 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
282 InferenceEngine::TBlob<float>::Ptr output;
283 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
285 outputBlobs[item.first] = output;
// Predicate that limits the dynamic-batch check to LRN nodes only.
287 auto checkLRN = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
288 return node->getType() == MKLDNNPlugin::Lrn;
// Verify inference at the full batch and at the minimum batch of 1.
290 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkLRN);
291 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkLRN);
292 } catch (const InferenceEngine::details::InferenceEngineException &e) {
298 TEST_P(MKLDNNGraphDynBatchLrnTests, TestsDynBatchLrn) {}
// Dynamic-batch cases mirror the static ones: 3-channel ref_any and
// 16-channel jit. NOTE(review): the ::testing::Values( opener (orig line 302)
// is elided from this view.
300 INSTANTIATE_TEST_CASE_P(
301 TestsDynBatchLrn, MKLDNNGraphDynBatchLrnTests,
303 lrn_test_params{{1, 3, 228, 228}, 5, 0.0001f, 0.75f, 1, 3, MKLDNNPlugin::impl_desc_type::ref_any},
304 lrn_test_params{{1, 16, 228, 228}, 5, 0.0001f, 0.75f, 1, 3, MKLDNNPlugin::impl_desc_type::jit}));