Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / cnn_network_impl.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <map>
8 #include <memory>
9 #include <ie_icnn_network.hpp>
10 #include "ie_common.h"
11 #include "ie_data.h"
12 #include "ie_blob.h"
13 #include "ie_api.h"
14 #include "ie_input_info.hpp"
15 #include "description_buffer.hpp"
16 #include <string>
17 #include <vector>
18
19 #include "cnn_network_stats_impl.hpp"
20
21 namespace InferenceEngine {
22 namespace ShapeInfer {
23 class Reshaper;
24
25 using ReshaperPtr = std::shared_ptr<Reshaper>;
26 }  // namespace ShapeInfer
27 namespace details {
28 class INFERENCE_ENGINE_API_CLASS(CNNNetworkImpl) : public ICNNNetwork {
29 public:
30     CNNNetworkImpl();
31     ~CNNNetworkImpl() override;
32     Precision getPrecision() const noexcept override {
33         return precision;
34     }
35
36     void setPrecision(Precision::ePrecision  prec) {
37         precision = prec;
38     }
39
40     void getOutputsInfo(std::map<std::string, DataPtr> &out) const noexcept override;
41
42     void getInputsInfo(InputsDataMap& inputs) const noexcept override;
43
44     InputInfo::Ptr getInput(const std::string& inputName) const noexcept override {
45         auto it = _inputData.find(inputName);
46         if (it == _inputData.end()) {
47             return nullptr;
48         }
49         return it->second;
50     }
51
52     void setInputInfo(InputInfo::Ptr data) {
53         _inputData[data->name()] = data;
54     }
55
56     void removeInputInfo(const std::string& name) {
57         _inputData.erase(name);
58     }
59
60     void getName(char* pName, size_t len) const noexcept override {
61         // Description buffer will preserve garbage if external pointer not initialized
62         if (len < 1) return;
63         memset(pName, 0, len);
64         DescriptionBuffer(pName, len) << _name;
65     }
66
67     const std::string& getName() const noexcept override {
68         return _name;
69     }
70
71     void setName(const std::string& name) {
72         _name = name;
73     }
74
75     const std::map<std::string, CNNLayerPtr>& allLayers() const {
76         return _layers;
77     }
78
79     size_t layerCount()  const noexcept override {
80         return _layers.size();
81     }
82
83     DataPtr& getData(const char* name) noexcept override  {
84         return _data[name];
85     }
86
87     DataPtr& getData(const std::string& name) {
88         return getData(name.c_str());
89     }
90
91     void addLayer(const CNNLayerPtr& layer) noexcept override;
92
93     void removeLayer(const std::string& layerName);
94
95     void removeData(const std::string& dataName);
96
97     StatusCode getLayerByName(const char* layerName, CNNLayerPtr& out, ResponseDesc* resp) const noexcept override;
98
99     // deprecated, as there is no ResponseDesc to put error message
100     StatusCode setBatchSize(const size_t size) noexcept override;
101
102     // public version
103     StatusCode setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept override;
104
105     // for internal usage (e.g. setBatch via reshape in tests)
106     StatusCode setBatchSizeReshape(size_t size, ResponseDesc* responseDesc) noexcept;
107
108     size_t getBatchSize() const noexcept override;
109
110     void setTargetDevice(TargetDevice device) noexcept override {
111         _targetDevice = device;
112     }
113
114     TargetDevice getTargetDevice() const noexcept override {
115         return _targetDevice;
116     }
117
118     StatusCode addOutput(const std::string& layerName, size_t outputIndex, ResponseDesc* resp) noexcept override;
119
120     void resolveOutput();
121
122     void addOutput(const std::string& dataName);
123
124     StatusCode getStats(ICNNNetworkStats** stats, ResponseDesc* resp) const noexcept override {
125         if (stats == nullptr) return StatusCode::PARAMETER_MISMATCH;
126         *stats = _stats.get();
127         return StatusCode::OK;
128     }
129
130     void Release() noexcept override {
131         delete this;
132     }
133
134     virtual void validate(int = 2);
135
136     StatusCode reshape(const std::map<std::string, std::vector<size_t>> &inputShapes, ResponseDesc* resp) noexcept override;
137
138     StatusCode
139     AddExtension(const InferenceEngine::IShapeInferExtensionPtr &extension, InferenceEngine::ResponseDesc *resp) noexcept override;
140
141     StatusCode serialize(const std::string &xmlPath, const std::string &binPath, ResponseDesc* resp) const noexcept override;
142
143 protected:
144     Precision precision {Precision::MIXED};
145     std::map<std::string, DataPtr> _data;
146     std::map<std::string, CNNLayerPtr> _layers;
147     InferenceEngine::InputsDataMap _inputData;
148     std::map<std::string, DataPtr> _outputData;
149     std::string _name;
150     /// @brief
151     TargetDevice _targetDevice;
152     DataPtr _emptyData;
153     ShapeInfer::ReshaperPtr _reshaper;
154     CNNNetworkStatsImplPtr _stats;
155 };
156
157
158 typedef std::shared_ptr<CNNNetworkImpl> CNNNetworkImplPtr;
159 }  // namespace details
160 }  // namespace InferenceEngine