5ce5b44cb3990d643771f4f0d5ea4bdb586847aa
[platform/upstream/dldt.git] / inference-engine / src / legacy_api / include / cnn_network_impl.hpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <ie_icnn_network.hpp>
8 #include <map>
9 #include <memory>
10 #include <string>
11 #include <vector>
12
13 #include "ie_layers.h"
14 #include "ie_ishape_infer_extension.hpp"
15 #include "description_buffer.hpp"
16 #include "ie_api.h"
17 #include "ie_blob.h"
18 #include "ie_common.h"
19 #include "ie_data.h"
20 #include "ie_input_info.hpp"
21
22 namespace InferenceEngine {
23 namespace ShapeInfer {
24 class Reshaper;
25
26 using ReshaperPtr = std::shared_ptr<Reshaper>;
27 }  // namespace ShapeInfer
28 namespace details {
29
30 class INFERENCE_ENGINE_API_CLASS(CNNNetworkImpl): public ICNNNetwork {
31 public:
32     CNNNetworkImpl();
33     ~CNNNetworkImpl() override;
34
35     std::shared_ptr<::ngraph::Function> getFunction() noexcept override {
36         return nullptr;
37     }
38
39     std::shared_ptr<const ::ngraph::Function> getFunction() const noexcept override {
40         return nullptr;
41     }
42
43     void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;
44
45     void getInputsInfo(InputsDataMap& inputs) const noexcept override;
46
47     InputInfo::Ptr getInput(const std::string& inputName) const noexcept override {
48         auto it = _inputData.find(inputName);
49         if (it == _inputData.end()) {
50             return nullptr;
51         }
52         return it->second;
53     }
54
55     void setInputInfo(InputInfo::Ptr data) {
56         _inputData[data->name()] = data;
57     }
58
59     void removeInputInfo(const std::string& name) {
60         _inputData.erase(name);
61     }
62
63     const std::string& getName() const noexcept override {
64         return _name;
65     }
66
67     void setName(const std::string& name) {
68         _name = name;
69     }
70
71     const std::map<std::string, CNNLayerPtr>& allLayers() const {
72         return _layers;
73     }
74
75     size_t layerCount() const noexcept override {
76         return _layers.size();
77     }
78
79     DataPtr& getData(const char* name) noexcept {
80         return _data[name];
81     }
82
83     void addData(const char* name, DataPtr data) noexcept {
84         _data.emplace(name, data);
85     }
86
87     DataPtr& getData(const std::string& name) {
88         return getData(name.c_str());
89     }
90
91     void addLayer(const CNNLayerPtr& layer) noexcept;
92
93     void removeLayer(const std::string& layerName);
94
95     // renames layer, statistics is not supported
96     void renameLayer(const std::string& currentName, const std::string& newName);
97
98     void removeData(const std::string& dataName);
99
100     StatusCode getLayerByName(const char* layerName, CNNLayerPtr& out, ResponseDesc* resp) const noexcept;
101
102     // public version
103     StatusCode setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept override;
104
105     // for internal usage (e.g. setBatch via reshape in tests)
106     StatusCode setBatchSizeReshape(size_t size, ResponseDesc* responseDesc) noexcept;
107
108     size_t getBatchSize() const noexcept override;
109
110     StatusCode addOutput(const std::string& layerName, size_t outputIndex, ResponseDesc* resp) noexcept override;
111
112     void resolveOutput();
113
114     void addOutput(const std::string& dataName);
115
116     void removeOutput(const std::string& dataName);
117
118     void Release() noexcept override {
119         delete this;
120     }
121
122     virtual void validate(int = 2);
123
124     StatusCode reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
125                        ResponseDesc* resp) noexcept override;
126
127     StatusCode AddExtension(const InferenceEngine::IShapeInferExtensionPtr& extension,
128                             InferenceEngine::ResponseDesc* resp) noexcept;
129
130     StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
131         noexcept override;
132
133 protected:
134     std::map<std::string, DataPtr> _data;
135     std::map<std::string, CNNLayerPtr> _layers;
136     InferenceEngine::InputsDataMap _inputData;
137     std::map<std::string, DataPtr> _outputData;
138     std::string _name;
139     DataPtr _emptyData;
140     ShapeInfer::ReshaperPtr _reshaper;
141 };
142
143 typedef std::shared_ptr<CNNNetworkImpl> CNNNetworkImplPtr;
144 }  // namespace details
145 }  // namespace InferenceEngine