1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include "tests_common.hpp"
16 using namespace ::testing;
18 using namespace mkldnn;
20 struct shuffle_channels_test_params {
21 InferenceEngine::SizeVector in_out_shape;
25 std::vector<float> reference;
26 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
29 void ref_shuffle_channels(
30 InferenceEngine::TBlob<float> &src,
31 InferenceEngine::TBlob<float> &dst,
36 const float *src_data = src.data();
37 InferenceEngine::SizeVector src_dims = src.getTensorDesc().getDims();
38 InferenceEngine::SizeVector srcStrides = src.getTensorDesc().getBlockingDesc().getStrides();
39 float* dst_data = dst.data();
40 InferenceEngine::SizeVector dst_dims = dst.getTensorDesc().getDims();
41 InferenceEngine::SizeVector dstStrides = dst.getTensorDesc().getBlockingDesc().getStrides();
44 axis += dst_dims.size();
46 if (axis < 0 || axis >= dst_dims.size())
47 FAIL() << "Incorrect input parameters dimensions and axis number!";
49 if (dst_dims[axis] % group)
50 FAIL() << "Group parameter must evenly divide the channel dimension!";
52 // Find number of dictionaries, index range and data length
53 size_t numDictionaries = 1;
54 for (i = 0; i <= axis; i++)
55 numDictionaries *= dst_dims[i];
57 size_t channelsNum = dst_dims[axis] / group;
59 size_t dataLength = 1;
60 for (i = axis + 1; i < dst_dims.size(); i++)
61 dataLength *= dst_dims[i];
64 FAIL() << "Incorrect input parameters dimension!";
67 for (j = 0, k = 0; j < numDictionaries; j += dst_dims[axis]) {
68 for (i = 0; i < (dst_dims[axis] * channelsNum); i += channelsNum, k += dataLength) {
69 int idx = j + i / dst_dims[axis] + i % dst_dims[axis];
70 memcpy(&dst_data[k], &src_data[dataLength * idx], sizeof(float) * dataLength);
75 class MKLDNNCPUExtShuffleChannelsTests : public TestsCommon, public WithParamInterface<shuffle_channels_test_params> {
76 std::string model_t = R"V0G0N(
77 <net Name="ShuffleChannels_net" version="2" precision="FP32" batch="1">
79 <layer name="input" type="Input" precision="FP32" id="1">
86 <layer name="output" id="2" type="ShuffleChannels" precision="FP32">
87 <data axis="_AX_" group="_GR_"/>
101 <edge from-layer="1" from-port="1" to-layer="2" to-port="1"/>
106 std::string getModel(shuffle_channels_test_params p) {
107 std::string model = model_t;
108 std::string in_out_shape;
110 for (size_t i = 0; i < p.in_out_shape.size(); i++) {
111 in_out_shape += "<dim>";
112 in_out_shape += std::to_string(p.in_out_shape[i]) + "</dim>\n";
114 REPLACE_WITH_STR(model, "_IN_OUT_", in_out_shape);
115 REPLACE_WITH_NUM(model, "_AX_", p.axis);
116 REPLACE_WITH_NUM(model, "_GR_", p.group);
122 virtual void TearDown() {
125 virtual void SetUp() {
127 TestsCommon::SetUp();
128 shuffle_channels_test_params p = ::testing::WithParamInterface<shuffle_channels_test_params>::GetParam();
129 std::string model = getModel(p);
130 ////std::cout << model;
131 InferenceEngine::CNNNetReader net_reader;
132 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
134 MKLDNNGraphTestClass graph;
135 graph.CreateGraph(net_reader.getNetwork());
138 InferenceEngine::OutputsDataMap out;
139 out = net_reader.getNetwork().getOutputsInfo();
140 InferenceEngine::BlobMap outputBlobs;
142 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
144 InferenceEngine::TBlob<float>::Ptr output;
145 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
147 outputBlobs[item.first] = output;
150 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
154 InferenceEngine::Blob::Ptr src;
155 src = InferenceEngine::make_shared_blob<float>({ InferenceEngine::Precision::FP32, p.in_out_shape, InferenceEngine::TensorDesc::getLayoutByDims(p.in_out_shape) });
157 fill_data_dbgval(src->buffer(), src->size());
158 auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
159 if (srcPtr == nullptr)
160 FAIL() << "Cannot cast blob to TBlob<float>.";
163 InferenceEngine::SizeVector out_dims;
164 ref_shuffle_channels(*srcPtr, dst_ref, p.axis, p.group);
167 if (memcmp(dst_ref.data(), &p.reference[0], p.reference.size() * sizeof(float)) != 0)
168 FAIL() << "Wrong result with compare TF reference!";
170 InferenceEngine::BlobMap srcs;
171 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("input", src));
174 graph.Infer(srcs, outputBlobs);
175 compare(*output, dst_ref);
176 } catch (const InferenceEngine::details::InferenceEngineException &e) {
183 TEST_P(MKLDNNCPUExtShuffleChannelsTests, TestsShuffleChannels) {}
// Expected ShuffleChannels outputs, paired with shapes in the instantiation list.

// Shapes {1, 15, 2, 2} / {15, 2, 2}, 15 channels, group 5: one 4-value channel per line.
static std::vector<float> test0 = {
     0.f,  1.f,  2.f,  3.f,
    12.f, 13.f, 14.f, 15.f,
    24.f, 25.f, 26.f, 27.f,
    36.f, 37.f, 38.f, 39.f,
    48.f, 49.f, 50.f, 51.f,
     4.f,  5.f,  6.f,  7.f,
    16.f, 17.f, 18.f, 19.f,
    28.f, 29.f, 30.f, 31.f,
    40.f, 41.f, 42.f, 43.f,
    52.f, 53.f, 54.f, 55.f,
     8.f,  9.f, 10.f, 11.f,
    20.f, 21.f, 22.f, 23.f,
    32.f, 33.f, 34.f, 35.f,
    44.f, 45.f, 46.f, 47.f,
    56.f, 57.f, 58.f, 59.f
};

// Shape {2, 2, 6}, axis -1, group 3: one 6-channel row per line.
static std::vector<float> test4 = {
     0.f,  2.f,  4.f,  1.f,  3.f,  5.f,
     6.f,  8.f, 10.f,  7.f,  9.f, 11.f,
    12.f, 14.f, 16.f, 13.f, 15.f, 17.f,
    18.f, 20.f, 22.f, 19.f, 21.f, 23.f
};

// Shape {2, 6, 2}, axis -2, group 3: one leading-index slice per line.
static std::vector<float> test5 = {
     0.f,  1.f,  4.f,  5.f,  8.f,  9.f,  2.f,  3.f,  6.f,  7.f, 10.f, 11.f,
    12.f, 13.f, 16.f, 17.f, 20.f, 21.f, 14.f, 15.f, 18.f, 19.f, 22.f, 23.f
};

// Shape {2, 2, 6}, axis -1, group 2: one 6-channel row per line.
static std::vector<float> test6 = {
     0.f,  3.f,  1.f,  4.f,  2.f,  5.f,
     6.f,  9.f,  7.f, 10.f,  8.f, 11.f,
    12.f, 15.f, 13.f, 16.f, 14.f, 17.f,
    18.f, 21.f, 19.f, 22.f, 20.f, 23.f
};

// Shape {2, 6, 2}, axis -2, group 2: one leading-index slice per line.
static std::vector<float> test7 = {
     0.f,  1.f,  6.f,  7.f,  2.f,  3.f,  8.f,  9.f,  4.f,  5.f, 10.f, 11.f,
    12.f, 13.f, 18.f, 19.f, 14.f, 15.f, 20.f, 21.f, 16.f, 17.f, 22.f, 23.f
};

// Shape {6}, axis 0, group 2.
static std::vector<float> test8 = {
    0.f, 3.f, 1.f, 4.f, 2.f, 5.f
};
195 INSTANTIATE_TEST_CASE_P(
196 TestsShuffleChannels, MKLDNNCPUExtShuffleChannelsTests,
198 // Params: in_out_shape, axis, group, reference
199 /* 0 */ shuffle_channels_test_params{ { 1, 15, 2, 2 }, 1, 5, test0 },
200 shuffle_channels_test_params{ { 1, 15, 2, 2 }, -3, 5, test0 },
201 shuffle_channels_test_params{ { 15, 2, 2 }, 0, 5, test0 },
202 shuffle_channels_test_params{ { 15, 2, 2 }, -3, 5, test0 },
203 shuffle_channels_test_params{ { 2, 2, 6 }, -1, 3, test4 },
204 /* 5 */ shuffle_channels_test_params{ { 2, 6, 2 }, -2, 3, test5 },
205 shuffle_channels_test_params{ { 2, 2, 6 }, -1, 2, test6 },
206 shuffle_channels_test_params{ { 2, 6, 2 }, -2, 2, test7 },
207 shuffle_channels_test_params{ { 6 }, 0, 2, test8 }