1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8 #include "single_layer_common.hpp"
10 #include "test_graph.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
16 using namespace ::testing;
18 using namespace mkldnn;
// Parameters describing one MKLDNNGraphReshapeTests case.
struct reshape_test_params {
    // Input dimensions: substituted for __SRC_DIMS__ in the IR template and
    // used to create the "in1" input blob in SetUp().
    InferenceEngine::SizeVector in;
    // Output dimensions: substituted for __DST_DIMS__ in the IR template.
    InferenceEngine::SizeVector out;
    // Target shape; its entries are joined into the string substituted for
    // _SHAPE_ (the Reshape layer's `dim` attribute).
    std::vector<size_t> shape;
    // Implementation type that the selected primitive descriptor of the
    // Reshape node is expected to report.
    MKLDNNPlugin::impl_desc_type selectedType;
    // Per-descriptor checks: comp[j] is called with the j-th supported
    // primitive descriptor of the Reshape node.
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
36 template <typename data_t>
37 void ref_reshape(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst) {
38 const data_t *src_data = src.readOnly();
39 data_t *dst_data = dst.data();
41 for (int i=0; i < src.size(); i++)
42 dst_data[i] = src_data[i];
// Value-parameterized fixture. SetUp() builds a Reshape network from
// reshape_test_params, checks the Reshape node's primitive descriptors, runs
// inference and compares the output against ref_reshape().
class MKLDNNGraphReshapeTests: public TestsCommon,
                               public WithParamInterface<reshape_test_params> {
    // IR template; getModel() replaces the __SRC_DIMS__, __DST_DIMS__, _AX_,
    // _NAX_ and _SHAPE_ placeholders with values from the test parameters.
    std::string model_t = R"V0G0N(
<Net Name="Reshape_Only" version="2" precision="FP32" batch="1">
<layer name="in1" type="Input" precision="FP32" id="0">
<layer name="norm" id="1" type="Reshape" precision="FP32">
<data dim="_SHAPE_" axis="_AX_" num_axes="_NAX_"/>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
    // Instantiates model_t for the given parameters and returns the IR text.
    // NOTE(review): `p` is taken by value, which copies its vectors and the
    // std::function callbacks; a const reference would suffice.
    std::string getModel(reshape_test_params p) {
        std::string model = model_t;

        // Build the __SRC_DIMS__ replacement from the input dimensions.
        for (auto& dim : p.in) {
            src_dims += std::to_string(dim) + "</dim>\n";
        REPLACE_WITH_STR(model, "__SRC_DIMS__", src_dims);

        // Build the __DST_DIMS__ replacement: one "\t\t<dim>N</dim>\n" entry per output dimension.
        for (auto& dim : p.out) {
            dst_dims += "\t\t<dim>";
            dst_dims += std::to_string(dim) + "</dim>\n";
        REPLACE_WITH_STR(model, "__DST_DIMS__", dst_dims);

        REPLACE_WITH_NUM(model, "_AX_", p.axis);
        REPLACE_WITH_NUM(model, "_NAX_", p.num_axes);

        // Serialize the target shape into the Reshape layer's `dim` attribute.
        std::string shape_str;
        for (auto& dim : p.shape) {
            if (!shape_str.empty())
            shape_str += std::to_string(dim);
        REPLACE_WITH_STR(model, "_SHAPE_", shape_str);
    // gtest per-test teardown hook.
    virtual void TearDown() {
    // Runs the whole test case for the current parameter set:
    //  1. builds the IR via getModel() and reads it,
    //  2. creates the MKLDNN graph and validates every Reshape node's supported
    //     and selected primitive descriptors,
    //  3. runs inference on a filled input blob and compares the output with
    //     ref_reshape().
    // Failures are reported through gtest assertions.
    virtual void SetUp() {
        TestsCommon::SetUp();
        reshape_test_params p = ::testing::WithParamInterface<reshape_test_params>::GetParam();
        std::string model = getModel(p);

        InferenceEngine::CNNNetReader net_reader;
        ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

        MKLDNNGraphTestClass graph;
        graph.CreateGraph(net_reader.getNetwork());

        // Validate the primitive descriptors of each Reshape node in the graph.
        // NOTE(review): `int i` is compared against nodes.size() (unsigned) —
        // a signed/unsigned comparison; size_t would be cleaner.
        auto& nodes = graph.getNodes();
        for (int i = 0; i < nodes.size(); i++) {
            if (nodes[i]->getType() == MKLDNNPlugin::Reshape) {
                ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
                // Apply the per-descriptor checks, bounded by both num_prim_desc and comp.size().
                for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                    p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
                ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
                ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());

        // FP32 input blob with dimensions p.in, filled by fill_data().
        InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::ANY, p.in);

        fill_data(src->buffer(), src->size());

        // Typed view of the input, needed by ref_reshape().
        InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());

        if (srcPtr == nullptr)
            FAIL() << "Cannot cast blob to TBlob<float>.";

        // Bind the input blob to the "in1" Input layer of the IR.
        InferenceEngine::BlobMap srcs;
        srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));

        // Output blob created from the tensor descriptor of the network's first output.
        InferenceEngine::OutputsDataMap out;
        out = net_reader.getNetwork().getOutputsInfo();
        InferenceEngine::BlobMap outputBlobs;

        std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

        InferenceEngine::TBlob<float>::Ptr output;
        output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());

        outputBlobs[item.first] = output;

        graph.Infer(srcs, outputBlobs);

        // Expected result: input elements copied into a blob with the output descriptor.
        InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());

        ref_reshape(*srcPtr, dst_ref);

        compare(*output, dst_ref);
    } catch (const InferenceEngine::details::InferenceEngineException &e) {
// Intentionally empty: every check for a parameter set runs in
// MKLDNNGraphReshapeTests::SetUp().
TEST_P(MKLDNNGraphReshapeTests, TestsReshape) {}
175 INSTANTIATE_TEST_CASE_P(
176 TestsReshape, MKLDNNGraphReshapeTests,
178 reshape_test_params{ {1, 3, 228, 228}, {1, 24, 2, 3249}, {1, 24, 2, 3249}, 0, -1, 1,
179 MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
180 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
181 ASSERT_EQ(1, impl.getConfig().inConfs.size());
182 ASSERT_EQ(1, impl.getConfig().outConfs.size());
183 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
184 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
186 reshape_test_params{ { 4 },{ 2, 2 },{ 2, 2 }, 0, -1, 1,
187 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
188 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
189 ASSERT_EQ(1, impl.getConfig().inConfs.size());
190 ASSERT_EQ(1, impl.getConfig().outConfs.size());
191 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
192 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
194 reshape_test_params{ { 4 },{ 1, 2, 2 },{ 1, 2, 2 }, 0, -1, 1,
195 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
196 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
197 ASSERT_EQ(1, impl.getConfig().inConfs.size());
198 ASSERT_EQ(1, impl.getConfig().outConfs.size());
199 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
200 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
202 reshape_test_params{ { 4 },{ 1, 4, 1, 1 },{ 1, 4, 1, 1 }, 0, -1, 1,
203 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
204 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
205 ASSERT_EQ(1, impl.getConfig().inConfs.size());
206 ASSERT_EQ(1, impl.getConfig().outConfs.size());
207 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
208 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
210 reshape_test_params{ { 4, 4 },{ 1, 4, 4 },{ 1, 4, 4 }, 0, -1, 1,
211 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
212 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
213 ASSERT_EQ(1, impl.getConfig().inConfs.size());
214 ASSERT_EQ(1, impl.getConfig().outConfs.size());
215 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
216 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
218 reshape_test_params{ { 4, 4 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 1,
219 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
220 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
221 ASSERT_EQ(1, impl.getConfig().inConfs.size());
222 ASSERT_EQ(1, impl.getConfig().outConfs.size());
223 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
224 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
226 reshape_test_params{ { 4, 2, 2 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 1,
227 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
228 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
229 ASSERT_EQ(1, impl.getConfig().inConfs.size());
230 ASSERT_EQ(1, impl.getConfig().outConfs.size());
231 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
232 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
234 reshape_test_params{ { 2, 2 },{ 4 },{ 4 }, 0, -1, 1,
235 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
236 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
237 ASSERT_EQ(1, impl.getConfig().inConfs.size());
238 ASSERT_EQ(1, impl.getConfig().outConfs.size());
239 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
240 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
242 reshape_test_params{ { 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 1,
243 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
244 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
245 ASSERT_EQ(1, impl.getConfig().inConfs.size());
246 ASSERT_EQ(1, impl.getConfig().outConfs.size());
247 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
248 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
250 reshape_test_params{ { 1, 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 1,
251 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
252 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
253 ASSERT_EQ(1, impl.getConfig().inConfs.size());
254 ASSERT_EQ(1, impl.getConfig().outConfs.size());
255 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
256 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
258 reshape_test_params{ { 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 1,
259 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
260 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
261 ASSERT_EQ(1, impl.getConfig().inConfs.size());
262 ASSERT_EQ(1, impl.getConfig().outConfs.size());
263 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
264 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
266 reshape_test_params{ { 1, 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 1,
267 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
268 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
269 ASSERT_EQ(1, impl.getConfig().inConfs.size());
270 ASSERT_EQ(1, impl.getConfig().outConfs.size());
271 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
272 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
274 reshape_test_params{ { 1, 4, 2, 2 },{ 4, 2, 2 },{ 4, 2, 2 }, 0, -1, 1,
275 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
276 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
277 ASSERT_EQ(1, impl.getConfig().inConfs.size());
278 ASSERT_EQ(1, impl.getConfig().outConfs.size());
279 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
280 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
282 reshape_test_params{ { 1, 4, 2, 2 }, { 4, 2, 2, 1, 1 }, { 4, 2, 2, 1, 1 }, 0, -1, 1,
283 MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
284 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
285 ASSERT_EQ(1, impl.getConfig().inConfs.size());
286 ASSERT_EQ(1, impl.getConfig().outConfs.size());
287 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
288 ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());
290 reshape_test_params{ { 4, 2, 2, 1, 1 }, { 1, 4, 2, 2 }, { 1, 4, 2, 2 }, 0, -1, 1,
291 MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
292 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
293 ASSERT_EQ(1, impl.getConfig().inConfs.size());
294 ASSERT_EQ(1, impl.getConfig().outConfs.size());
295 ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().inConfs.at(0).desc.getLayout());
296 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
298 reshape_test_params{ { 1, 200 }, { 1, 200, 1, 1, 1 }, { 1, 200, 1, 1, 1 }, 0, -1, 1,
299 MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
300 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
301 ASSERT_EQ(1, impl.getConfig().inConfs.size());
302 ASSERT_EQ(1, impl.getConfig().outConfs.size());
303 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
304 ASSERT_EQ(InferenceEngine::Layout::NCDHW, impl.getConfig().outConfs.at(0).desc.getLayout());