// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include <cstring>
#include "mkldnn_plugin/mkldnn_graph.h"
#include "test_graph.hpp"
#include "single_layer_common.hpp"
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include <inference_engine/cnn_network_impl.hpp>
#include "tests_common.hpp"
17 using namespace ::testing;
19 using namespace mkldnn;
// Parameters for one Tile-layer test case.
// NOTE(review): several fields referenced by the tests below (`in` with
// .n/.c/.h/.w, `axis`, `tiles`, `num_prim_desc`) are declared on lines
// elided from this view — confirm against the full file.
struct tile_test_params {
    // Implementation type the Tile node is expected to select.
    MKLDNNPlugin::impl_desc_type selectedType;

    // Optional per-descriptor checks: entry j is invoked with the j-th
    // supported primitive descriptor reported by the Tile node.
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
41 template <typename data_t>
42 void ref_tile(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst_blob, tile_test_params prm) {
43 const float* m_src = src.readOnly();
47 for (int i=0; i < prm.axis; i++ ) m_outer_dim *= src.dims()[i];
48 for (int i=prm.axis; i < src.dims().size(); i++ ) m_inner_dim *= src.dims()[i];
50 float* dst = dst_blob.data();
52 for (int i = 0; i < m_outer_dim; ++i) {
53 for (int t = 0; t < prm.tiles; ++t) {
54 memcpy(dst, m_src, m_inner_dim* sizeof(float));
// Single-layer Input -> Tile graph test. SetUp() builds the IR from the
// parameterized tile_test_params, checks the Tile node's supported/selected
// primitive descriptors, runs inference, and compares the output against
// the ref_tile() reference implementation.
// NOTE(review): this chunk is an elided listing — some IR-template lines,
// braces, and access labels are not visible here and are left untouched.
class MKLDNNGraphTileTests: public TestsCommon,
                            public WithParamInterface<tile_test_params> {
    std::string model_t = R"V0G0N(
<Net Name="Tile_Only" version="2" precision="FP32" batch="1">
<layer name="in1" type="Input" precision="FP32" id="0">
<layer name="tile" id="1" type="Tile" precision="FP32">
<data axis="_AX_" tiles="_TL_"/>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
    // Substitute the test parameters into the IR template. The output dim
    // on p.axis is the input dim scaled by p.tiles; other dims pass through.
    std::string getModel(tile_test_params p) {
        std::string model = model_t;
        REPLACE_WITH_NUM(model, "_IW_", p.in.w);
        REPLACE_WITH_NUM(model, "_IH_", p.in.h);
        REPLACE_WITH_NUM(model, "_IC_", p.in.c);
        REPLACE_WITH_NUM(model, "_IN_", p.in.n);
        REPLACE_WITH_NUM(model, "_OW_", (p.axis == 3) ? p.in.w*p.tiles : p.in.w);
        REPLACE_WITH_NUM(model, "_OH_", (p.axis == 2) ? p.in.h*p.tiles : p.in.h);
        REPLACE_WITH_NUM(model, "_OC_", (p.axis == 1) ? p.in.c*p.tiles : p.in.c);
        REPLACE_WITH_NUM(model, "_ON_", (p.axis == 0) ? p.in.n*p.tiles : p.in.n);
        REPLACE_WITH_NUM(model, "_AX_", p.axis);
        REPLACE_WITH_NUM(model, "_TL_", p.tiles);
    virtual void TearDown() {
    virtual void SetUp() {
        TestsCommon::SetUp();
        tile_test_params p = ::testing::WithParamInterface<tile_test_params>::GetParam();
        std::string model = getModel(p);
        InferenceEngine::CNNNetReader net_reader;
        ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
        // Build the MKLDNN graph and validate the Tile node's descriptors.
        MKLDNNGraphTestClass graph;
        graph.CreateGraph(net_reader.getNetwork());
        auto& nodes = graph.getNodes();
        for (int i = 0; i < nodes.size(); i++) {
            if (nodes[i]->getType() == MKLDNNPlugin::Tile) {
                ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
                // Run any per-descriptor validation callbacks from the case.
                for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
                ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
                ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
        // Prepare an NCHW FP32 input blob with generated test data.
        InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
        InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
        fill_data(src->buffer(), src->size());
        InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
        if (srcPtr == nullptr)
            FAIL() << "Cannot cast blob to TBlob<float>.";
        InferenceEngine::BlobMap srcs;
        srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
        // Allocate the output blob from the network's (single) output info.
        InferenceEngine::OutputsDataMap out;
        out = net_reader.getNetwork().getOutputsInfo();
        InferenceEngine::BlobMap outputBlobs;
        std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
        InferenceEngine::TBlob<float>::Ptr output;
        output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
        outputBlobs[item.first] = output;
        graph.Infer(srcs, outputBlobs);
        // Compare the MKLDNN output against the scalar reference result.
        InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
        ref_tile(*srcPtr, dst_ref, p);
        compare(*output, dst_ref);
    } catch (const InferenceEngine::details::InferenceEngineException &e) {
// Body intentionally empty: all the work happens in the fixture's SetUp().
TEST_P(MKLDNNGraphTileTests, TestsTile) {}
// Test case: tile a 1x128x1x1 tensor 24x along axis 3 (width), expecting
// one supported primitive descriptor with plain NCHW in/out layouts.
// NOTE(review): the ::testing::Values(...) wrapper and closing tokens of
// this instantiation are elided from this view.
INSTANTIATE_TEST_CASE_P(
        TestsTile, MKLDNNGraphTileTests,
                {1, 128, 1, 1}, 3, 24, 1, MKLDNNPlugin::impl_desc_type::unknown, {
                    [](MKLDNNPlugin::PrimitiveDescInfo impl) {
                        ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
                        ASSERT_EQ(1, impl.getConfig().inConfs.size());
                        ASSERT_EQ(1, impl.getConfig().outConfs.size());
                        ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
                        ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
// Dynamic-batch variant of the Tile test: reshapes the network to batch MB,
// enables KEY_DYN_BATCH_ENABLED on the graph, and checks the Tile node's
// behavior at both the full batch (MB) and batch 1 via checkDynBatch.
// NOTE(review): `MB` is defined on a line elided from this view — confirm
// its declaration in the full file. Missing braces/labels are left as-is.
class MKLDNNGraphDynBatchTileTests: public MKLDNNGraphTileTests {
    virtual void SetUp() {
        TestsCommon::SetUp();
        tile_test_params p = ::testing::WithParamInterface<tile_test_params>::GetParam();
        std::string model = getModel(p);
        InferenceEngine::CNNNetReader net_reader;
        ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
        // Reshape the network to the max dynamic batch before graph creation.
        InferenceEngine::CNNNetwork network = net_reader.getNetwork();
        auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
        ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
        InferenceEngine::ResponseDesc resp;
        InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
        ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
        // Create the graph with dynamic batching enabled.
        MKLDNNGraphTestClass graph;
        graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
        graph.CreateGraph(net_reader.getNetwork());
        // Input blob sized for the full batch MB.
        InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
        InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
        fill_data(src->buffer(), src->size());
        InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
        if (srcPtr == nullptr)
            FAIL() << "Cannot cast blob to TBlob<float>.";
        InferenceEngine::BlobMap srcs;
        srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
        // Allocate the output blob from the network's (single) output info.
        InferenceEngine::OutputsDataMap out;
        out = net_reader.getNetwork().getOutputsInfo();
        InferenceEngine::BlobMap outputBlobs;
        std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
        InferenceEngine::TBlob<float>::Ptr output;
        output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
        outputBlobs[item.first] = output;
        // Predicate selecting the Tile node whose dyn-batch handling we check.
        auto checkTile = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
            return node->getType() == MKLDNNPlugin::Tile;
        // Run at the full batch and at batch 1.
        graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkTile);
        graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkTile);
    } catch (const InferenceEngine::details::InferenceEngineException &e) {
// Body intentionally empty: all the work happens in the fixture's SetUp().
TEST_P(MKLDNNGraphDynBatchTileTests, TestsDynBatchTile) {}
268 INSTANTIATE_TEST_CASE_P(
269 TestsDynBatchTile, MKLDNNGraphDynBatchTileTests,
272 {1, 128, 1, 1}, 3, 24, 1, MKLDNNPlugin::impl_desc_type::unknown, {
273 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
274 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
275 ASSERT_EQ(1, impl.getConfig().inConfs.size());
276 ASSERT_EQ(1, impl.getConfig().outConfs.size());
277 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
278 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());