1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <ie_iextension.h>
9 #include <tests_common.hpp>
10 #include <mkldnn_plugin/mkldnn_extension_mngr.h>
11 #include "graph/test_graph.hpp"
13 using namespace ::testing;
// Test-only execution implementation for the custom "ConstLayer" type.
// It advertises one layer configuration (plain strided layout, no dynamic
// batch) and on execute() fills the output blob with the layer's
// "const_val" parameter.
// NOTE(review): this listing has gaps (the embedded line numbers jump),
// so several statements and closing braces are not visible here.
15 class ConstLayerImpl : public InferenceEngine::ILayerExecImpl {
// Copies the CNNLayer by value so the impl owns its parameters
// independently of the originating network object.
17 explicit ConstLayerImpl(const InferenceEngine::CNNLayer *layer): cnnLayer(*layer) {}
// Reports the single supported in/out configuration for this layer.
18 InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf, InferenceEngine::ResponseDesc *resp) noexcept override {
19 InferenceEngine::LayerConfig config;
// Dynamic batching deliberately not supported (checked again in init()).
20 config.dynBatchSupport = 0;
// NOTE(review): with `&&` this only errors when BOTH counts differ from 1;
// "exactly one input and one output" would need `||` — confirm intent.
21 if (cnnLayer.outData.size() != 1 && cnnLayer.insData.size() != 1)
22 return InferenceEngine::GENERAL_ERROR;
23 InferenceEngine::DataConfig cfg;
// Build a natural (0,1,2,...) dimension order over the output dims;
// the loop body filling `order` is not visible in this listing.
26 InferenceEngine::SizeVector order;
27 for(size_t i = 0; i < cnnLayer.outData[0]->getTensorDesc().getDims().size(); i++) {
// Same plain-layout descriptor is reused for both the input and output side.
30 cfg.desc = InferenceEngine::TensorDesc(cnnLayer.outData[0]->getTensorDesc().getPrecision(),
31 cnnLayer.outData[0]->getTensorDesc().getDims(),
32 {cnnLayer.outData[0]->getTensorDesc().getDims(), order});
33 config.outConfs.push_back(cfg);
34 config.inConfs.push_back(cfg);
35 conf.push_back(config);
36 return InferenceEngine::OK;
// Validates the selected configuration and caches the constant value.
39 InferenceEngine::StatusCode init(InferenceEngine::LayerConfig& config, InferenceEngine::ResponseDesc *resp) noexcept override {
// "const_val" defaults to 1 when the layer XML omits the attribute.
40 value = cnnLayer.GetParamAsInt("const_val", 1);
41 if (config.dynBatchSupport)
42 return InferenceEngine::NOT_IMPLEMENTED;
// Reject configs whose descriptors don't match what we advertised;
// the actual checks inside these loops are not visible in this listing.
43 for(auto input : config.inConfs) {
45 return InferenceEngine::GENERAL_ERROR;
47 for(auto output : config.outConfs) {
49 return InferenceEngine::GENERAL_ERROR;
51 return InferenceEngine::OK;
// Fills the (single) output blob; the loop body writing `value` into
// dst_data is not visible in this listing.
53 InferenceEngine::StatusCode execute(std::vector<InferenceEngine::Blob::Ptr>& inputs, std::vector<InferenceEngine::Blob::Ptr>& outputs, InferenceEngine::ResponseDesc *resp) noexcept override {
54 float *dst_data = outputs[0]->buffer();
56 size_t data_size = outputs[0]->size();
57 for (size_t i = 0; i < data_size; i++) {
60 return InferenceEngine::OK;
// Copy of the layer descriptor taken at construction time.
64 InferenceEngine::CNNLayer cnnLayer;
// Factory producing ConstLayerImpl instances for a given CNNLayer.
// NOTE(review): listing is gappy — access labels and the closing brace of
// this class are not visible here.
68 class ConstLayerFactory : public InferenceEngine::ILayerImplFactory {
70 ConstLayerFactory(const InferenceEngine::CNNLayer *layer): cnnLayer(*layer) {}
71 // set output shapes by input shapes.
// ConstLayer is shape-preserving: output desc mirrors the first input desc.
72 InferenceEngine::StatusCode getShapes(const std::vector<InferenceEngine::TensorDesc>& inShapes, std::vector<InferenceEngine::TensorDesc>& outShapes, InferenceEngine::ResponseDesc *resp) noexcept override {
73 outShapes.push_back(inShapes[0]);
74 return InferenceEngine::OK;
76 // First implementation has more priority than next
// Single implementation is offered; ownership passes to the shared Ptr.
77 InferenceEngine::StatusCode getImplementations(std::vector<InferenceEngine::ILayerImpl::Ptr>& impls, InferenceEngine::ResponseDesc *resp) noexcept override {
78 impls.push_back(InferenceEngine::ILayerImpl::Ptr(new ConstLayerImpl(&cnnLayer)));
79 return InferenceEngine::OK;
// Copy of the layer descriptor handed to each created implementation.
83 InferenceEngine::CNNLayer cnnLayer;
// Callable that builds a layer-impl factory from a CNNLayer descriptor;
// used as the value type of FakeConstExtensionFabric's registry below.
86 using fake_ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer *)>;
// Minimal in-process IExtension ("Fabric" = factory) registering exactly one
// custom layer type, "ConstLayer", for the constant-propagation test below.
// NOTE(review): listing is gappy — destructor body, Release() body and the
// `size`/`count` setup in getPrimitiveTypes are not visible here.
88 class FakeConstExtensionFabric : public InferenceEngine::IExtension {
90 FakeConstExtensionFabric() {
// Registry maps layer-type name -> factory-creating lambda.
91 factories["ConstLayer"] = [](const InferenceEngine::CNNLayer * cnnLayer) -> InferenceEngine::ILayerImplFactory* { return new ConstLayerFactory(cnnLayer); };
94 virtual ~FakeConstExtensionFabric() {
// No-op metadata/lifecycle hooks: the test never queries version, logging
// or unload behavior.
98 void GetVersion(const InferenceEngine::Version *&versionInfo) const noexcept override {}
99 void SetLogCallback(InferenceEngine::IErrorListener &listener) noexcept override {}
100 void Unload() noexcept override {}
101 void Release() noexcept override {
// Returns a heap-allocated, NUL-terminated C-string array of registered
// type names; ownership of `types` passes to the caller per the IE API.
104 InferenceEngine::StatusCode getPrimitiveTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
105 types = new char *[factories.size()];
107 for (auto it = factories.begin(); it != factories.end(); it++, count ++) {
108 types[count] = new char[it->first.size() + 1];
109 std::copy(it->first.begin(), it->first.end(), types[count]);
110 types[count][it->first.size() ] = '\0';
112 return InferenceEngine::OK;
// Looks up the factory for the layer's type; NOT_FOUND for anything other
// than "ConstLayer".
114 InferenceEngine::StatusCode getFactoryFor(InferenceEngine::ILayerImplFactory *&factory,
115 const InferenceEngine::CNNLayer *cnnLayer,
116 InferenceEngine::ResponseDesc *resp) noexcept override {
117 if (factories.find(cnnLayer->type) == factories.end()) {
118 std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
// NOTE(review): std::string::copy does not NUL-terminate resp->msg; relies
// on the buffer being pre-zeroed — confirm ResponseDesc guarantees that.
119 errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
120 return InferenceEngine::NOT_FOUND;
// Raw pointer ownership passes to the caller per the IE extension API.
122 factory = factories[cnnLayer->type](cnnLayer);
123 return InferenceEngine::OK;
// Shape inference is intentionally unsupported; the factory's getShapes()
// covers the shape needs of this test.
126 InferenceEngine::StatusCode getShapeInferImpl(InferenceEngine::IShapeInferImpl::Ptr& impl, const char* type,
127 InferenceEngine::ResponseDesc* resp) noexcept override {
128 return InferenceEngine::NOT_IMPLEMENTED;
// Type-name -> factory-builder registry populated in the constructor.
132 std::map<std::string, fake_ext_factory> factories;
// Test fixture: wires the fake ConstLayer extension into a fresh MKLDNN
// extension manager before each test.
// NOTE(review): listing is gappy — access labels and the class's closing
// brace are not visible here.
135 class MKLDNNConstantPropagationTests: public TestsCommon {
137 virtual void SetUp() {
138 TestsCommon::SetUp();
// Order matters: the extension must exist before it is registered.
139 extension.reset(new FakeConstExtensionFabric());
140 extMgr.reset(new MKLDNNPlugin::MKLDNNExtensionManager());
141 extMgr->AddExtension(extension);
// Shared with the test body via the fixture; passed to CreateGraph().
143 MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr;
144 std::shared_ptr<InferenceEngine::IExtension> extension;
147 TEST_F(MKLDNNConstantPropagationTests, ConcatAfterConstLayers) {
148 std::string model = R"V0G0N(
149 <Net Name="CustomConcat_Only" version="2" precision="FP32" batch="1">
151 <layer name="in1" type="Input" precision="FP32" id="0">
161 <layer name="in2" type="Input" precision="FP32" id="1">
171 <layer name="const1" type="ConstLayer" precision="FP32" id="2">
189 <layer name="const2" type="ConstLayer" precision="FP32" id="3">
190 <data const_val="4"/>
208 <layer name="con" id="4" type="Concat" precision="FP32">
209 <concat_data axis="2"/>
235 <edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
236 <edge from-layer="1" from-port="0" to-layer="3" to-port="0"/>
237 <edge from-layer="2" from-port="1" to-layer="4" to-port="1"/>
238 <edge from-layer="3" from-port="1" to-layer="4" to-port="2"/>
243 InferenceEngine::CNNNetReader net_reader;
244 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
246 MKLDNNGraphTestClass graph;
247 graph.CreateGraph(net_reader.getNetwork(), extMgr);
249 InferenceEngine::SizeVector dims_src1 = {1, 2, 10, 5};
251 InferenceEngine::Blob::Ptr src1 =
252 InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
255 InferenceEngine::SizeVector dims_src2 = {1, 2, 5, 5};
257 InferenceEngine::Blob::Ptr src2 =
258 InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
261 InferenceEngine::BlobMap srcs;
262 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
263 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
265 InferenceEngine::OutputsDataMap out;
266 out = net_reader.getNetwork().getOutputsInfo();
267 InferenceEngine::BlobMap outputBlobs;
269 auto it = out.begin();
271 std::pair<std::string, InferenceEngine::DataPtr> item = *it;
273 InferenceEngine::TensorDesc outputDesc1 = item.second->getTensorDesc();
274 InferenceEngine::TBlob<float>::Ptr output;
275 output = InferenceEngine::make_shared_blob<float>(outputDesc1);
277 outputBlobs[item.first] = output;
279 auto& nodes = graph.getNodes();
280 bool existConcat = false;
281 for (auto& node : nodes) {
282 if (node->getType() != MKLDNNPlugin::Concatenation && node->getType() != MKLDNNPlugin::Generic)
284 if (node->getName() == "con" && node->getType() == MKLDNNPlugin::Concatenation)
286 ASSERT_TRUE(node->isConstant());
289 ASSERT_TRUE(existConcat);
291 graph.Infer(srcs, outputBlobs);
294 float *dst_ptr = output->buffer();
296 int len1 = 1, len2 = 1, cycles;
297 for (int dim = 2; dim < output->dims().size(); dim++) {
298 len1 *= src1->dims()[dim];
299 len2 *= src2->dims()[dim];
303 int index1 = 0, index2 = 0, index = 0;
304 for (int cycle = 0; cycle < cycles; cycle ++) {
305 for (int i1 = 0; i1 < len1; i1++) {
306 if (1 != dst_ptr[index]) {
307 FAIL() << "index: " << index << " src: " << 1 << ", dst: " << dst_ptr[index];
311 for (int i2 = 0; i2 < len2; i2++) {
312 if (4 != dst_ptr[index]) {
313 FAIL() << "index: " << index << " src: " << 4 << ", dst: " << dst_ptr[index];