Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ie_cnn_layer_builder.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <ie_cnn_layer_builder.h>
6
7 using namespace InferenceEngine;
8
9 std::map<std::string, std::string> Builder::convertParameters2Strings(const std::map<std::string, Parameter>& parameters) {
10     std::map<std::string, std::string> oldParams;
11     for (const auto& param : parameters) {
12         // skip blobs and ports
13         if (param.second.is<Blob::CPtr>() || param.second.is<Blob::Ptr>() || param.second.is<std::vector<Port>>()
14                 || param.second.is<PreProcessInfo>())
15             continue;
16         if (param.second.is<std::string>() || param.second.is<std::vector<std::string>>()) {
17             oldParams[param.first] = Builder::convertParameter2String<std::string>(param.second);
18         } else if (param.second.is<int>() || param.second.is<std::vector<int>>()) {
19             oldParams[param.first] = Builder::convertParameter2String<int>(param.second);
20         } else if (param.second.is<float>() || param.second.is<std::vector<float>>()) {
21             oldParams[param.first] = Builder::convertParameter2String<float>(param.second);
22         } else if (param.second.is<unsigned int>() || param.second.is<std::vector<unsigned int>>()) {
23             oldParams[param.first] = Builder::convertParameter2String<unsigned int>(param.second);
24         } else if (param.second.is<size_t>() || param.second.is<std::vector<size_t>>()) {
25             oldParams[param.first] = Builder::convertParameter2String<size_t>(param.second);
26         } else if (param.second.is<bool>() || param.second.is<std::vector<bool>>()) {
27             oldParams[param.first] = Builder::convertParameter2String<bool>(param.second);
28         } else {
29             THROW_IE_EXCEPTION << "Parameter " << param.first << " has unsupported parameter type!";
30         }
31     }
32     return oldParams;
33 }
34
35 Builder::Layer Builder::builderFromCNNLayer(const CNNLayerPtr& cnnLayer) {
36     Builder::Layer layer(cnnLayer->type, cnnLayer->name);
37     std::vector<Port> inputPorts;
38     for (const auto& data : cnnLayer->insData) {
39         auto lockedData = data.lock();
40         if (!lockedData)
41             continue;
42         inputPorts.emplace_back(lockedData->getTensorDesc().getDims());
43     }
44
45     std::vector<Port> outputPorts;
46     for (const auto& data : cnnLayer->outData) {
47         outputPorts.emplace_back(data->getTensorDesc().getDims());
48     }
49
50     size_t inputsCount = inputPorts.size();
51     std::map<std::string, Blob::Ptr> blobs = cnnLayer->blobs;
52     if (blobs.find("weights") != blobs.end()) {
53         auto port = Port();
54         port.setParameter("type", "weights");
55         inputPorts.push_back(port);
56     }
57     if (blobs.find("biases") != blobs.end()) {
58         if (inputsCount == inputPorts.size()) {
59             auto port = Port();
60             port.setParameter("type", "weights");
61             inputPorts.push_back(port);
62         }
63
64         auto port = Port();
65         port.setParameter("type", "biases");
66         inputPorts.push_back(port);
67     }
68     for (const auto& it : blobs) {
69         if (it.first == "weights" || it.first == "biases")
70             continue;
71         auto port = Port();
72         port.setParameter("type", it.first);
73         inputPorts.emplace_back(port);
74     }
75
76     std::map<std::string, Parameter> params;
77     for (const auto& it : cnnLayer->params) {
78         params[it.first] = it.second;
79     }
80
81     layer.setInputPorts(inputPorts).setOutputPorts(outputPorts).setParameters(params);
82
83     Builder::ConverterRegister::convert(cnnLayer, layer);
84
85     return layer;
86 }
87
88 Builder::ConverterRegister::ConverterRegister(const std::string& type, const std::function<void(const CNNLayerPtr&, Layer&)>& converter) {
89     if (getConvertersHolder().converters.find(type) == getConvertersHolder().converters.end())
90         getConvertersHolder().converters[type] = converter;
91 }
92
93 Builder::ConvertersHolder &Builder::ConverterRegister::getConvertersHolder() {
94     static Builder::ConvertersHolder holder;
95     return holder;
96 }