1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10 #include "single_layer_common.hpp"
12 #include "test_graph.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include "tests_common.hpp"
18 using namespace ::testing;
20 using namespace mkldnn;
23 struct reshape_test_params {
24 InferenceEngine::SizeVector in;
25 InferenceEngine::SizeVector out;
26 std::vector<size_t> shape;
33 MKLDNNPlugin::impl_desc_type selectedType;
35 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
38 template <typename data_t>
39 void ref_reshape(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst) {
40 const data_t *src_data = src.readOnly();
41 data_t *dst_data = dst.data();
43 #pragma omp parallel for
44 for (int i=0; i < src.size(); i++)
45 dst_data[i] = src_data[i];
48 class MKLDNNGraphReshapeTests: public TestsCommon,
49 public WithParamInterface<reshape_test_params> {
50 std::string model_t = R"V0G0N(
51 <Net Name="Reshape_Only" version="2" precision="FP32" batch="1">
53 <layer name="in1" type="Input" precision="FP32" id="0">
60 <layer name="norm" id="1" type="Reshape" precision="FP32">
61 <data dim="_SHAPE_" axis="_AX_" num_axes="_NAX_"/>
76 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
81 std::string getModel(reshape_test_params p) {
82 std::string model = model_t;
85 for (auto& dim : p.in) {
87 src_dims += std::to_string(dim) + "</dim>\n";
89 REPLACE_WITH_STR(model, "__SRC_DIMS__", src_dims);
92 for (auto& dim : p.out) {
94 dst_dims += std::to_string(dim) + "</dim>\n";
96 REPLACE_WITH_STR(model, "__DST_DIMS__", dst_dims);
98 REPLACE_WITH_NUM(model, "_AX_", p.axis);
99 REPLACE_WITH_NUM(model, "_NAX_", p.num_axes);
101 std::string shape_str;
102 for (auto& dim : p.shape) {
103 if (!shape_str.empty())
105 shape_str += std::to_string(dim);
107 REPLACE_WITH_STR(model, "_SHAPE_", shape_str);
112 virtual void TearDown() {
115 virtual void SetUp() {
117 TestsCommon::SetUp();
118 reshape_test_params p = ::testing::WithParamInterface<reshape_test_params>::GetParam();
119 std::string model = getModel(p);
121 InferenceEngine::CNNNetReader net_reader;
122 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
124 MKLDNNGraphTestClass graph;
125 graph.CreateGraph(net_reader.getNetwork());
126 auto& nodes = graph.getNodes();
127 for (int i = 0; i < nodes.size(); i++) {
128 if (nodes[i]->getType() == MKLDNNPlugin::Reshape) {
129 ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
130 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
131 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
133 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
134 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
138 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::ANY, p.in);
140 fill_data(src->buffer(), src->size());
142 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
144 if (srcPtr == nullptr)
145 FAIL() << "Cannot cast blob to TBlob<float>.";
147 InferenceEngine::BlobMap srcs;
148 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
150 InferenceEngine::OutputsDataMap out;
151 out = net_reader.getNetwork().getOutputsInfo();
152 InferenceEngine::BlobMap outputBlobs;
154 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
156 InferenceEngine::TBlob<float>::Ptr output;
157 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
159 outputBlobs[item.first] = output;
161 graph.Infer(srcs, outputBlobs);
163 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
166 ref_reshape(*srcPtr, dst_ref);
168 compare(*output, dst_ref);
169 } catch (const InferenceEngine::details::InferenceEngineException &e) {
175 TEST_P(MKLDNNGraphReshapeTests, TestsReshape) {}
178 INSTANTIATE_TEST_CASE_P(
179 TestsReshape, MKLDNNGraphReshapeTests,
181 reshape_test_params{ {1, 3, 228, 228}, {1, 24, 2, 3249}, {1, 24, 2, 3249}, 0, -1, 2,
182 MKLDNNPlugin::impl_desc_type::unknown, { [](MKLDNNPlugin::PrimitiveDescInfo impl) {
183 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
184 ASSERT_EQ(1, impl.getConfig().inConfs.size());
185 ASSERT_EQ(1, impl.getConfig().outConfs.size());
186 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
187 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
189 reshape_test_params{ { 4 },{ 2, 2 },{ 2, 2 }, 0, -1, 2,
190 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
191 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
192 ASSERT_EQ(1, impl.getConfig().inConfs.size());
193 ASSERT_EQ(1, impl.getConfig().outConfs.size());
194 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
195 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
197 reshape_test_params{ { 4 },{ 1, 2, 2 },{ 1, 2, 2 }, 0, -1, 2,
198 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
199 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
200 ASSERT_EQ(1, impl.getConfig().inConfs.size());
201 ASSERT_EQ(1, impl.getConfig().outConfs.size());
202 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
203 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
205 reshape_test_params{ { 4 },{ 1, 4, 1, 1 },{ 1, 4, 1, 1 }, 0, -1, 2,
206 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
207 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
208 ASSERT_EQ(1, impl.getConfig().inConfs.size());
209 ASSERT_EQ(1, impl.getConfig().outConfs.size());
210 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().inConfs.at(0).desc.getLayout());
211 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
213 reshape_test_params{ { 4, 4 },{ 1, 4, 4 },{ 1, 4, 4 }, 0, -1, 2,
214 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
215 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
216 ASSERT_EQ(1, impl.getConfig().inConfs.size());
217 ASSERT_EQ(1, impl.getConfig().outConfs.size());
218 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
219 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
221 reshape_test_params{ { 4, 4 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 2,
222 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
223 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
224 ASSERT_EQ(1, impl.getConfig().inConfs.size());
225 ASSERT_EQ(1, impl.getConfig().outConfs.size());
226 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
227 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
229 reshape_test_params{ { 4, 2, 2 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 2,
230 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
231 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
232 ASSERT_EQ(1, impl.getConfig().inConfs.size());
233 ASSERT_EQ(1, impl.getConfig().outConfs.size());
234 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
235 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
237 reshape_test_params{ { 2, 2 },{ 4 },{ 4 }, 0, -1, 2,
238 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
239 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
240 ASSERT_EQ(1, impl.getConfig().inConfs.size());
241 ASSERT_EQ(1, impl.getConfig().outConfs.size());
242 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().inConfs.at(0).desc.getLayout());
243 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
245 reshape_test_params{ { 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 2,
246 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
247 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
248 ASSERT_EQ(1, impl.getConfig().inConfs.size());
249 ASSERT_EQ(1, impl.getConfig().outConfs.size());
250 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
251 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
253 reshape_test_params{ { 1, 1, 2, 2 },{ 4 },{ 4 }, 0, -1, 2,
254 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
255 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
256 ASSERT_EQ(1, impl.getConfig().inConfs.size());
257 ASSERT_EQ(1, impl.getConfig().outConfs.size());
258 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
259 ASSERT_EQ(InferenceEngine::Layout::C, impl.getConfig().outConfs.at(0).desc.getLayout());
261 reshape_test_params{ { 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 2,
262 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
263 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
264 ASSERT_EQ(1, impl.getConfig().inConfs.size());
265 ASSERT_EQ(1, impl.getConfig().outConfs.size());
266 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().inConfs.at(0).desc.getLayout());
267 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
269 reshape_test_params{ { 1, 4, 2, 2 },{ 4, 4 },{ 4, 4 }, 0, -1, 2,
270 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
271 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
272 ASSERT_EQ(1, impl.getConfig().inConfs.size());
273 ASSERT_EQ(1, impl.getConfig().outConfs.size());
274 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
275 ASSERT_EQ(InferenceEngine::Layout::NC, impl.getConfig().outConfs.at(0).desc.getLayout());
277 reshape_test_params{ { 1, 4, 2, 2 },{ 4, 2, 2 },{ 4, 2, 2 }, 0, -1, 2,
278 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
279 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
280 ASSERT_EQ(1, impl.getConfig().inConfs.size());
281 ASSERT_EQ(1, impl.getConfig().outConfs.size());
282 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
283 ASSERT_EQ(InferenceEngine::Layout::CHW, impl.getConfig().outConfs.at(0).desc.getLayout());
285 reshape_test_params{ { 1, 4, 2, 2 },{ 4, 2, 2, 1, 1 },{ 4, 2, 2, 1, 1 }, 0, -1, 2,
286 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
287 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
288 ASSERT_EQ(1, impl.getConfig().inConfs.size());
289 ASSERT_EQ(1, impl.getConfig().outConfs.size());
290 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
291 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
293 reshape_test_params{ { 4, 2, 2, 1, 1 },{ 1, 4, 2, 2 },{ 1, 4, 2, 2 }, 0, -1, 2,
294 MKLDNNPlugin::impl_desc_type::unknown,{ [](MKLDNNPlugin::PrimitiveDescInfo impl) {
295 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
296 ASSERT_EQ(1, impl.getConfig().inConfs.size());
297 ASSERT_EQ(1, impl.getConfig().outConfs.size());
298 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
299 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());