Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / include / cpp / ie_cnn_network.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief A header file that provides wrapper for ICNNNetwork object
7  * @file ie_cnn_network.h
8  */
9 #pragma once
10
11 #include <details/ie_exception_conversion.hpp>
12 #include <details/ie_cnn_network_iterator.hpp>
13 #include <ie_icnn_network.hpp>
14 #include <ie_icnn_net_reader.h>
15 #include "ie_common.h"
16 #include "ie_data.h"
17 #include "ie_blob.h"
18 #include <vector>
19 #include <string>
20 #include <map>
21 #include <utility>
22 #include <memory>
23
24 namespace InferenceEngine {
25
26 /**
27  * @brief This class contains all the information about the Neural Network and the related binary information
28  */
29 class CNNNetwork {
30 public:
31     /**
32      * @brief A default constructor
33      */
34     CNNNetwork() = default;
35
36     /**
37      * @brief Initialises helper class from externally managed pointer
38      * @deprecated use shared_pointers based version of CNNNetworks constructor
39      * @param actual Pointer to the network object
40      */
41     explicit CNNNetwork(ICNNNetwork* actual) : actual(actual) {
42         if (actual == nullptr) {
43             THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
44         }
45     }
46
47     /**
48      * @brief Allows helper class to manage lifetime of network object
49      * @param network Pointer to the network object
50      */
51     explicit CNNNetwork(std::shared_ptr<ICNNNetwork> network)
52         : network(network) {
53         actual = network.get();
54         if (actual == nullptr) {
55             THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
56         }
57     }
58
59     /**
60      * @brief A constructor from ICNNNetReader object
61      * @param reader Pointer to the ICNNNetReader object
62      */
63     explicit CNNNetwork(std::shared_ptr<ICNNNetReader> reader)
64             : reader(reader)
65             , actual(reader->getNetwork(nullptr)) {
66         if (actual == nullptr) {
67             THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
68         }
69     }
70
71     /**
72      * @brief A destructor
73      */
74     virtual ~CNNNetwork() {}
75
76     /**
77      * @brief Wraps original method
78      * ICNNNetwork::getPrecision
79      */
80     virtual Precision getPrecision() const {
81         return actual->getPrecision();
82     }
83
84     /**
85      * @brief Wraps original method
86      * ICNNNetwork::getOutputsInfo
87      */
88     virtual OutputsDataMap getOutputsInfo() const {
89         OutputsDataMap outputs;
90         actual->getOutputsInfo(outputs);
91         return std::move(outputs);
92     }
93
94     /**
95      * @brief Wraps original method
96      * ICNNNetwork::getInputsInfo
97      */
98     virtual InputsDataMap getInputsInfo() const {
99         InputsDataMap inputs;
100         actual->getInputsInfo(inputs);
101         return std::move(inputs);
102     }
103
104     /**
105      * @brief Wraps original method
106      * ICNNNetwork::layerCount
107      */
108     size_t layerCount() const {
109         return actual->layerCount();
110     }
111
112     /**
113      * @brief Wraps original method
114      * ICNNNetwork::getName
115      */
116     const std::string& getName() const noexcept {
117         return actual->getName();
118     }
119
120     /**
121      * @brief Wraps original method
122      * ICNNNetwork::setBatchSize
123      */
124     virtual void setBatchSize(const size_t size) {
125         CALL_STATUS_FNC(setBatchSize, size);
126     }
127
128     /**
129      * @brief Wraps original method
130      * ICNNNetwork::getBatchSize
131      */
132     virtual size_t getBatchSize() const {
133         return actual->getBatchSize();
134     }
135
136     /**
137      * @brief An overloaded operator & to get current network
138      * @return An instance of the current network
139      */
140     operator ICNNNetwork &() const {
141         return *actual;
142     }
143
144     /**
145      * @brief Sets tha target device
146      * @param device Device instance to set
147      */
148     void setTargetDevice(TargetDevice device) {
149         actual->setTargetDevice(device);
150     }
151
152     /**
153      * @brief Wraps original method
154      * ICNNNetwork::addOutput
155      */
156     void addOutput(const std::string &layerName, size_t outputIndex = 0) {
157         CALL_STATUS_FNC(addOutput, layerName, outputIndex);
158     }
159
160     /**
161      * @brief Wraps original method
162      * ICNNNetwork::getLayerByName
163      */
164     CNNLayerPtr getLayerByName(const char *layerName) const {
165         CNNLayerPtr layer;
166         CALL_STATUS_FNC(getLayerByName, layerName, layer);
167         return layer;
168     }
169
170     /**
171      * @brief Begin layer iterator
172      * Order of layers is implementation specific,
173      * and can be changed in future
174      */
175     details::CNNNetworkIterator begin() const {
176         return details::CNNNetworkIterator(actual);
177     }
178
179     /**
180      * @brief End layer iterator
181      */
182     details::CNNNetworkIterator end() const {
183         return details::CNNNetworkIterator();
184     }
185
186     /**
187      * @brief number of layers in network object
188      * @return
189      */
190     size_t size() const {
191         return std::distance(std::begin(*this), std::end(*this));
192     }
193
194     /**
195      * @brief Registers extension within the plugin
196      * @param extension Pointer to already loaded reader extension with shape propagation implementations
197      */
198     void AddExtension(InferenceEngine::IShapeInferExtensionPtr extension) {
199         CALL_STATUS_FNC(AddExtension, extension);
200     }
201
202     /**
203      * @brief - Helper method to get collect all input shapes with names of corresponding Data objects
204      * @return Map of pairs: input's name and its dimension.
205      */
206     virtual ICNNNetwork::InputShapes getInputShapes() const {
207         ICNNNetwork::InputShapes shapes;
208         InputsDataMap inputs;
209         actual->getInputsInfo(inputs);
210         for (const auto& pair : inputs) {
211             auto info = pair.second;
212             if (info) {
213                 auto data = info->getInputData();
214                 if (data) {
215                     shapes[data->name] = data->getTensorDesc().getDims();
216                 }
217             }
218         }
219         return std::move(shapes);
220     }
221
222     /**
223      * @brief Run shape inference with new input shapes for the network
224      * @param inputShapes - map of pairs: name of corresponding data and its dimension.
225      */
226     virtual void reshape(const ICNNNetwork::InputShapes &inputShapes) {
227         CALL_STATUS_FNC(reshape, inputShapes);
228     }
229
230     /**
231      * @brief Serialize network to IR and weights files.
232      * @param xmlPath Path to output IR file.
233      * @param binPath Path to output weights file. The parameter is skipped in case
234      * of executable graph info serialization.
235      */
236     void serialize(const std::string &xmlPath, const std::string &binPath = "") const {
237         CALL_STATUS_FNC(serialize, xmlPath, binPath);
238     }
239
240 protected:
241     /**
242      * @brief reader extra reference, might be nullptr
243      */
244     std::shared_ptr<ICNNNetReader> reader;
245     /**
246      * @brief network extra interface, might be nullptr
247      */
248     std::shared_ptr<ICNNNetwork> network;
249
250     /**
251      * @brief A pointer to the current network
252      */
253     ICNNNetwork *actual = nullptr;
254     /**
255      * @brief A pointer to output data
256      */
257     DataPtr output;
258 };
259
260 }  // namespace InferenceEngine