1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
17 using namespace ::testing;
19 using namespace mkldnn;
22 struct permute_test_params {
23 InferenceEngine::SizeVector dims;
24 InferenceEngine::SizeVector order;
27 MKLDNNPlugin::impl_desc_type selectedType;
29 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
32 template <typename data_t>
33 void ref_permute(const InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst, permute_test_params prm) {
34 const data_t *src_data = src.readOnly();
35 data_t *dst_data = dst.data();
37 InferenceEngine::SizeVector orderedDims;
38 for (auto ord : prm.order) {
39 orderedDims.push_back(src.getTensorDesc().getDims()[ord]);
41 InferenceEngine::TensorDesc desc(InferenceEngine::Precision::FP32, src.getTensorDesc().getDims(), {orderedDims, prm.order});
43 for (int i=0; i < src.size(); i++) {
44 dst_data[desc.offset(i)] = src_data[src.getTensorDesc().offset(i)];
48 class MKLDNNGraphPermuteTests: public TestsCommon,
49 public WithParamInterface<permute_test_params> {
50 std::string model_t = R"V0G0N(
51 <Net Name="Power_Only" version="2" precision="FP32" batch="1">
53 <layer name="in1" type="Input" precision="FP32" id="0">
60 <layer name="permute" id="1" type="Permute" precision="FP32">
61 <data order="_ORDER_"/>
75 <edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
81 std::string getModel(permute_test_params p) {
82 std::string model = model_t;
85 for (auto& dim : p.dims) {
87 dims += std::to_string(dim) + "</dim>\n";
91 for (auto& ord : p.order) {
94 order += std::to_string(ord);
96 dst_dims += std::to_string(p.dims[ord]) + "</dim>\n";
99 REPLACE_WITH_STR(model, "__DIMS__", dims);
100 REPLACE_WITH_STR(model, "__DST_DIMS__", dst_dims);
101 REPLACE_WITH_STR(model, "_ORDER_", order);
106 virtual void TearDown() {
109 virtual void SetUp() {
111 TestsCommon::SetUp();
112 permute_test_params p = ::testing::WithParamInterface<permute_test_params>::GetParam();
113 std::string model = getModel(p);
115 InferenceEngine::CNNNetReader net_reader;
116 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
118 MKLDNNGraphTestClass graph;
119 graph.CreateGraph(net_reader.getNetwork());
120 auto& nodes = graph.getNodes();
121 for (int i = 0; i < nodes.size(); i++) {
122 if (nodes[i]->getType() == MKLDNNPlugin::Permute) {
123 ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
124 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
125 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
127 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
128 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
132 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>({InferenceEngine::Precision::FP32, p.dims, InferenceEngine::TensorDesc::getLayoutByDims(p.dims)});
134 fill_data(src->buffer(), src->size());
136 auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
138 if (srcPtr == nullptr)
139 FAIL() << "Cannot cast blob to TBlob<float>.";
141 InferenceEngine::BlobMap srcs;
142 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
144 InferenceEngine::OutputsDataMap out;
145 out = net_reader.getNetwork().getOutputsInfo();
146 InferenceEngine::BlobMap outputBlobs;
148 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
150 InferenceEngine::TBlob<float>::Ptr output;
151 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
153 outputBlobs[item.first] = output;
155 graph.Infer(srcs, outputBlobs);
157 InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, p.dims, InferenceEngine::TensorDesc::getLayoutByDims(p.dims));
158 InferenceEngine::TBlob<float> dst_ref(td);
161 ref_permute(*srcPtr, dst_ref, p);
163 compare(*output, dst_ref);
164 } catch (const InferenceEngine::details::InferenceEngineException &e) {
170 TEST_P(MKLDNNGraphPermuteTests, TestsPermute) {}
172 INSTANTIATE_TEST_CASE_P(
173 TestsPermute, MKLDNNGraphPermuteTests,
175 permute_test_params{{2, 3, 4, 5}, {0, 1, 2, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
176 permute_test_params{{2, 3, 4, 5}, {0, 2, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
177 permute_test_params{{2, 3, 4, 5}, {3, 0, 1, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
178 permute_test_params{{2, 3, 4, 5}, {1, 3, 2, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
179 permute_test_params{{2, 3, 4, 5}, {3, 2, 1, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
180 permute_test_params{{2, 3, 4, 5}, {0, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
181 permute_test_params{{2, 3, 4}, {0, 1, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
182 permute_test_params{{2, 3, 4}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
183 permute_test_params{{2, 3, 4}, {2, 1, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
184 permute_test_params{{2, 3, 4}, {1, 2, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
185 permute_test_params{{2, 3, 4}, {2, 0, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
186 permute_test_params{{2, 3, 4}, {1, 0, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
187 permute_test_params{{2, 3}, {1, 0}, 1, MKLDNNPlugin::impl_desc_type::unknown},
188 permute_test_params{{2, 3}, {0, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
189 permute_test_params{{2, 3, 4, 5, 6}, {0, 1, 2, 3, 4}, 1, MKLDNNPlugin::impl_desc_type::unknown},
190 permute_test_params{{2, 3, 4, 5, 6}, {0, 4, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
191 permute_test_params{{2, 3, 4, 5, 6}, {0, 2, 4, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
192 permute_test_params{{2, 3, 4, 5, 6}, {0, 3, 2, 4, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
193 permute_test_params{{2, 8, 2, 2, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
194 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
195 permute_test_params{{2, 8, 3, 4}, {3, 0, 1, 2}, 2, MKLDNNPlugin::impl_desc_type::unknown},
196 permute_test_params{{2, 12, 9}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
197 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 3, 4, 1, 5, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown}
200 class MKLDNNGraphDynBatchPermuteTests: public MKLDNNGraphPermuteTests {
202 virtual void SetUp() {
204 TestsCommon::SetUp();
205 permute_test_params p = ::testing::WithParamInterface<permute_test_params>::GetParam();
206 std::string model = getModel(p);
207 size_t MB = p.dims[0];
212 InferenceEngine::CNNNetReader net_reader;
213 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
214 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
215 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
216 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
217 InferenceEngine::ResponseDesc resp;
218 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
219 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
221 MKLDNNGraphTestClass graph;
222 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
223 graph.CreateGraph(net_reader.getNetwork());
225 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float>({InferenceEngine::Precision::FP32, p.dims, InferenceEngine::TensorDesc::getLayoutByDims(p.dims)});
227 fill_data(src->buffer(), src->size());
229 auto * srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
231 if (srcPtr == nullptr)
232 FAIL() << "Cannot cast blob to TBlob<float>.";
234 InferenceEngine::BlobMap srcs;
235 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
237 InferenceEngine::OutputsDataMap out;
238 out = net_reader.getNetwork().getOutputsInfo();
239 InferenceEngine::BlobMap outputBlobs;
241 std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
243 InferenceEngine::TBlob<float>::Ptr output;
244 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
246 outputBlobs[item.first] = output;
248 auto checkPermute = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
249 return node->getType() == MKLDNNPlugin::Permute;
251 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkPermute);
252 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkPermute);
253 } catch (const InferenceEngine::details::InferenceEngineException &e) {
259 TEST_P(MKLDNNGraphDynBatchPermuteTests, TestsDynBatchPermute) {}
261 INSTANTIATE_TEST_CASE_P(
262 TestsDynBatchPermute, MKLDNNGraphDynBatchPermuteTests,
264 permute_test_params{{2, 3, 4, 5}, {0, 1, 2, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
265 permute_test_params{{2, 3, 4, 5}, {0, 2, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
266 permute_test_params{{2, 3, 4, 5}, {0, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
267 permute_test_params{{2, 3, 4}, {0, 1, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown},
268 permute_test_params{{2, 3, 4}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
269 permute_test_params{{2, 3}, {0, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
270 permute_test_params{{2, 3, 4, 5, 6}, {0, 1, 2, 3, 4}, 1, MKLDNNPlugin::impl_desc_type::unknown},
271 permute_test_params{{2, 3, 4, 5, 6}, {0, 4, 2, 1, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
272 permute_test_params{{2, 3, 4, 5, 6}, {0, 2, 4, 3, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
273 permute_test_params{{2, 3, 4, 5, 6}, {0, 3, 2, 4, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
274 permute_test_params{{2, 8, 2, 2, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
275 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 1, 4, 2, 5, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown},
276 permute_test_params{{2, 12, 9}, {0, 2, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown},
277 permute_test_params{{2, 8, 3, 3, 4, 5}, {0, 3, 4, 1, 5, 2}, 1, MKLDNNPlugin::impl_desc_type::unknown}