// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <ie_iextension.h>
#include <ie_common.h>
#include <ie_layers.h>
#include <tests_common.hpp>
#include <mkldnn_plugin/mkldnn_extension_mngr.h>
#include "graph/test_graph.hpp"

using namespace ::testing;

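// ConstLayerImpl is a minimal custom-layer implementation whose single output
// is constant: execute() fills the output blob with the layer's "const_val"
// parameter, so the plugin can precompute (constant-fold) everything fed by it.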
class ConstLayerImpl : public InferenceEngine::ILayerExecImpl {
public:
    explicit ConstLayerImpl(const InferenceEngine::CNNLayer *layer): cnnLayer(*layer) {}
    InferenceEngine::StatusCode getSupportedConfigurations(std::vector<InferenceEngine::LayerConfig>& conf, InferenceEngine::ResponseDesc *resp) noexcept override {
        InferenceEngine::LayerConfig config;
        config.dynBatchSupport = false;
        // Exactly one input and one output are required.
        if (cnnLayer.outData.size() != 1 || cnnLayer.insData.size() != 1)
            return InferenceEngine::GENERAL_ERROR;
        InferenceEngine::DataConfig cfg;
        cfg.constant = true;
        cfg.inPlace = 0;
        InferenceEngine::SizeVector order;
        for (size_t i = 0; i < cnnLayer.outData[0]->getTensorDesc().getDims().size(); i++) {
            order.push_back(i);
        }
        cfg.desc = InferenceEngine::TensorDesc(cnnLayer.outData[0]->getTensorDesc().getPrecision(),
                                               cnnLayer.outData[0]->getTensorDesc().getDims(),
                                               {cnnLayer.outData[0]->getTensorDesc().getDims(), order});
        config.outConfs.push_back(cfg);
        config.inConfs.push_back(cfg);
        conf.push_back(config);
        return InferenceEngine::OK;
    }

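    // init() accepts a configuration only if dynamic batching is off and every
    // input/output port is marked constant.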
    InferenceEngine::StatusCode init(InferenceEngine::LayerConfig& config, InferenceEngine::ResponseDesc *resp) noexcept override {
        value = cnnLayer.GetParamAsInt("const_val", 1);
        if (config.dynBatchSupport)
            return InferenceEngine::NOT_IMPLEMENTED;
        for (const auto& input : config.inConfs) {
            if (!input.constant)
                return InferenceEngine::GENERAL_ERROR;
        }
        for (const auto& output : config.outConfs) {
            if (!output.constant)
                return InferenceEngine::GENERAL_ERROR;
        }
        return InferenceEngine::OK;
    }
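    // Broadcast the scalar const_val into every element of the output blob;
    // the input data is never read.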
    InferenceEngine::StatusCode execute(std::vector<InferenceEngine::Blob::Ptr>& inputs, std::vector<InferenceEngine::Blob::Ptr>& outputs, InferenceEngine::ResponseDesc *resp) noexcept override {
        float *dst_data = outputs[0]->buffer();

        size_t data_size = outputs[0]->size();
        for (size_t i = 0; i < data_size; i++) {
            dst_data[i] = value;
        }
        return InferenceEngine::OK;
    }

private:
    InferenceEngine::CNNLayer cnnLayer;
    int value = 0;
};

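// Factory for "ConstLayer" nodes: the output shape mirrors the input shape,
// and a single ConstLayerImpl is offered as the implementation.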
class ConstLayerFactory : public InferenceEngine::ILayerImplFactory {
public:
    ConstLayerFactory(const InferenceEngine::CNNLayer *layer): cnnLayer(*layer) {}
    // Set output shapes from the input shapes.
    InferenceEngine::StatusCode getShapes(const std::vector<InferenceEngine::TensorDesc>& inShapes, std::vector<InferenceEngine::TensorDesc>& outShapes, InferenceEngine::ResponseDesc *resp) noexcept override {
        outShapes.push_back(inShapes[0]);
        return InferenceEngine::OK;
    }
    // Implementations are tried in order; earlier entries take priority.
    InferenceEngine::StatusCode getImplementations(std::vector<InferenceEngine::ILayerImpl::Ptr>& impls, InferenceEngine::ResponseDesc *resp) noexcept override {
        impls.push_back(InferenceEngine::ILayerImpl::Ptr(new ConstLayerImpl(&cnnLayer)));
        return InferenceEngine::OK;
    }

private:
    InferenceEngine::CNNLayer cnnLayer;
};

using fake_ext_factory = std::function<InferenceEngine::ILayerImplFactory*(const InferenceEngine::CNNLayer *)>;

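// A stub IExtension ("Fabric" is the test's original spelling) that hands the
// MKLDNN plugin a ConstLayerFactory whenever it asks about the "ConstLayer" type.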
class FakeConstExtensionFabric : public InferenceEngine::IExtension {
public:
    FakeConstExtensionFabric() {
        factories["ConstLayer"] = [](const InferenceEngine::CNNLayer * cnnLayer) -> InferenceEngine::ILayerImplFactory* { return new ConstLayerFactory(cnnLayer); };
    }

    virtual ~FakeConstExtensionFabric() {
        factories.clear();
    }

    void GetVersion(const InferenceEngine::Version *&versionInfo) const noexcept override {
        versionInfo = nullptr;  // the stub reports no version info
    }
    void SetLogCallback(InferenceEngine::IErrorListener &listener) noexcept override {}
    void Unload() noexcept override {}
    void Release() noexcept override {
        delete this;
    }
    InferenceEngine::StatusCode getPrimitiveTypes(char**& types, unsigned int& size, InferenceEngine::ResponseDesc* resp) noexcept override {
        types = new char *[factories.size()];
        size = factories.size();  // report how many type names were returned
        size_t count = 0;
        for (auto it = factories.begin(); it != factories.end(); it++, count++) {
            types[count] = new char[it->first.size() + 1];
            std::copy(it->first.begin(), it->first.end(), types[count]);
            types[count][it->first.size()] = '\0';
        }
        return InferenceEngine::OK;
    }
    InferenceEngine::StatusCode getFactoryFor(InferenceEngine::ILayerImplFactory *&factory,
                                              const InferenceEngine::CNNLayer *cnnLayer,
                                              InferenceEngine::ResponseDesc *resp) noexcept override {
        if (factories.find(cnnLayer->type) == factories.end()) {
            std::string errorMsg = std::string("Factory for ") + cnnLayer->type + " wasn't found!";
            size_t copied = errorMsg.copy(resp->msg, sizeof(resp->msg) - 1);
            resp->msg[copied] = '\0';  // std::string::copy does not null-terminate
            return InferenceEngine::NOT_FOUND;
        }
        factory = factories[cnnLayer->type](cnnLayer);
        return InferenceEngine::OK;
    }

    InferenceEngine::StatusCode getShapeInferImpl(InferenceEngine::IShapeInferImpl::Ptr& impl, const char* type,
                                                  InferenceEngine::ResponseDesc* resp) noexcept override {
        return InferenceEngine::NOT_IMPLEMENTED;
    }

private:
    std::map<std::string, fake_ext_factory> factories;
};

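// Fixture: registers the fake extension with a fresh MKLDNN extension manager
// before each test.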
class MKLDNNConstantPropagationTests: public TestsCommon {
protected:
    void SetUp() override {
        TestsCommon::SetUp();
        extension.reset(new FakeConstExtensionFabric());
        extMgr.reset(new MKLDNNPlugin::MKLDNNExtensionManager());
        extMgr->AddExtension(extension);
    }
    MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr;
    std::shared_ptr<InferenceEngine::IExtension> extension;
};

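// The network below concatenates two ConstLayer outputs (const_val defaults to
// 1 for const1 and is set to 4 for const2) along axis 2. Since both inputs are
// constant, the Concat node itself should be marked constant and its result
// precomputed during graph compilation, before Infer() is ever called.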
TEST_F(MKLDNNConstantPropagationTests, ConcatAfterConstLayers) {
    std::string model = R"V0G0N(
        <Net Name="CustomConcat_Only" version="2" precision="FP32" batch="1">
            <layers>
                <layer name="in1" type="Input" precision="FP32" id="0">
                    <output>
                        <port id="0">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>10</dim>
                            <dim>5</dim>
                        </port>
                    </output>
                </layer>
                <layer name="in2" type="Input" precision="FP32" id="1">
                    <output>
                        <port id="0">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>5</dim>
                            <dim>5</dim>
                        </port>
                    </output>
                </layer>
                <layer name="const1" type="ConstLayer" precision="FP32" id="2">
                    <input>
                        <port id="0">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>10</dim>
                            <dim>5</dim>
                        </port>
                    </input>
                    <output>
                        <port id="1">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>10</dim>
                            <dim>5</dim>
                        </port>
                    </output>
                </layer>
                <layer name="const2" type="ConstLayer" precision="FP32" id="3">
                    <data const_val="4"/>
                    <input>
                        <port id="0">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>5</dim>
                            <dim>5</dim>
                        </port>
                    </input>
                    <output>
                        <port id="1">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>5</dim>
                            <dim>5</dim>
                        </port>
                    </output>
                </layer>
                <layer name="con" id="4" type="Concat" precision="FP32">
                    <concat_data axis="2"/>
                    <input>
                        <port id="1">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>10</dim>
                            <dim>5</dim>
                        </port>
                        <port id="2">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>5</dim>
                            <dim>5</dim>
                        </port>
                    </input>
                    <output>
                        <port id="3">
                            <dim>1</dim>
                            <dim>2</dim>
                            <dim>15</dim>
                            <dim>5</dim>
                        </port>
                    </output>
                </layer>
            </layers>
            <edges>
                <edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
                <edge from-layer="1" from-port="0" to-layer="3" to-port="0"/>
                <edge from-layer="2" from-port="1" to-layer="4" to-port="1"/>
                <edge from-layer="3" from-port="1" to-layer="4" to-port="2"/>
            </edges>
        </Net>
        )V0G0N";

    InferenceEngine::CNNNetReader net_reader;
    ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

    MKLDNNGraphTestClass graph;
    graph.CreateGraph(net_reader.getNetwork(), extMgr);

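    // The input blobs are allocated but deliberately left unwritten: the
    // ConstLayer implementation ignores its input data.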
    InferenceEngine::SizeVector dims_src1 = {1, 2, 10, 5};

    InferenceEngine::Blob::Ptr src1 =
            InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
    src1->allocate();

    InferenceEngine::SizeVector dims_src2 = {1, 2, 5, 5};

    InferenceEngine::Blob::Ptr src2 =
            InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
    src2->allocate();

    InferenceEngine::BlobMap srcs;
    srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
    srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));

    InferenceEngine::OutputsDataMap out;
    out = net_reader.getNetwork().getOutputsInfo();
    InferenceEngine::BlobMap outputBlobs;

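    // Allocate an output blob matching the network's single output.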
    auto it = out.begin();

    std::pair<std::string, InferenceEngine::DataPtr> item = *it;

    InferenceEngine::TensorDesc outputDesc1 = item.second->getTensorDesc();
    InferenceEngine::TBlob<float>::Ptr output;
    output = InferenceEngine::make_shared_blob<float>(outputDesc1);
    output->allocate();
    outputBlobs[item.first] = output;

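    // Every Concatenation or Generic (i.e. ConstLayer) node in the compiled
    // graph must have been marked constant by the constant-propagation pass.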
    auto& nodes = graph.getNodes();
    bool existConcat = false;
    for (auto& node : nodes) {
        if (node->getType() != MKLDNNPlugin::Concatenation && node->getType() != MKLDNNPlugin::Generic)
            continue;
        if (node->getName() == "con" && node->getType() == MKLDNNPlugin::Concatenation)
            existConcat = true;
        ASSERT_TRUE(node->isConstant());
    }

    ASSERT_TRUE(existConcat);

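    // Inference should only copy out the precomputed constant result.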
    graph.Infer(srcs, outputBlobs);

    // Compare the output against the expected constant pattern.
    float *dst_ptr = output->buffer();

    int len1 = 1, len2 = 1, cycles;
    for (size_t dim = 2; dim < output->dims().size(); dim++) {
        len1 *= src1->dims()[dim];
        len2 *= src2->dims()[dim];
    }
    cycles = 2;

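    // In NCHW layout, concat along axis 2 interleaves per-(n, c) slices:
    // len1 = 10 * 5 = 50 ones from const1, then len2 = 5 * 5 = 25 fours from
    // const2, repeated cycles = N * C = 2 times.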
    int index1 = 0, index2 = 0, index = 0;
    for (int cycle = 0; cycle < cycles; cycle++) {
        for (int i1 = 0; i1 < len1; i1++) {
            if (1 != dst_ptr[index]) {
                FAIL() << "index: " << index << " src: " << 1 << ", dst: " << dst_ptr[index];
            }
            index1++; index++;
        }
        for (int i2 = 0; i2 < len2; i2++) {
            if (4 != dst_ptr[index]) {
                FAIL() << "index: " << index << " src: " << 4 << ", dst: " << dst_ptr[index];
            }
            index2++; index++;
        }
    }
}