1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 * @brief A header file that provides wrapper for ICNNNetwork object
7 * @file ie_cnn_network.h
11 #include <details/ie_exception_conversion.hpp>
12 #include <details/ie_cnn_network_iterator.hpp>
13 #include <ie_icnn_network.hpp>
14 #include <ie_icnn_net_reader.h>
15 #include "ie_common.h"
24 namespace InferenceEngine {
27 * @brief This class contains all the information about the Neural Network and the related binary information
32 * @brief A default constructor
    // Default-constructed wrapper: `actual` stays nullptr (see member initializer below);
    // initialize via another constructor before calling the pass-through methods.
    CNNNetwork() = default;
37 * @brief Initialises helper class from externally managed pointer
38 * @deprecated use shared_pointers based version of CNNNetworks constructor
39 * @param actual Pointer to the network object
    // Non-owning wrapper: the caller keeps ownership of `actual`, which must outlive this object.
    explicit CNNNetwork(ICNNNetwork* actual) : actual(actual) {
        // Fail fast here; later wrapper calls dereference `actual` unchecked.
        if (actual == nullptr) {
            THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
48 * @brief Allows helper class to manage lifetime of network object
49 * @param network Pointer to the network object
    // Shared-ownership constructor. NOTE(review): the member-init list is not visible in this
    // excerpt — confirm the shared_ptr is retained in the `network` member to keep it alive.
    explicit CNNNetwork(std::shared_ptr<ICNNNetwork> network)
        actual = network.get();  // cache the raw pointer used by all pass-through methods
        if (actual == nullptr) {
            THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
60 * @brief A constructor from ICNNNetReader object
61 * @param reader Pointer to the ICNNNetReader object
    // Constructs from a reader; presumably the reader is retained in the `reader` member
    // (init list partially elided here) so the network it returned stays valid.
    explicit CNNNetwork(std::shared_ptr<ICNNNetReader> reader)
        , actual(reader->getNetwork(nullptr)) {  // nullptr: presumably an optional response/out parameter — TODO confirm
        if (actual == nullptr) {
            THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
74 virtual ~CNNNetwork() {}
77 * @brief Wraps original method
78 * ICNNNetwork::getPrecision
    virtual Precision getPrecision() const {
        return actual->getPrecision();  // pass-through; assumes a constructor validated `actual`
85 * @brief Wraps original method
86 * ICNNNetwork::getOutputsInfo
88 virtual OutputsDataMap getOutputsInfo() const {
89 OutputsDataMap outputs;
90 actual->getOutputsInfo(outputs);
91 return std::move(outputs);
95 * @brief Wraps original method
96 * ICNNNetwork::getInputsInfo
98 virtual InputsDataMap getInputsInfo() const {
100 actual->getInputsInfo(inputs);
101 return std::move(inputs);
105 * @brief Wraps original method
106 * ICNNNetwork::layerCount
    size_t layerCount() const {
        return actual->layerCount();  // pass-through to the wrapped network
113 * @brief Wraps original method
114 * ICNNNetwork::getName
    const std::string& getName() const noexcept {
        // Returns a reference owned by the wrapped network; valid only while `actual` lives.
        return actual->getName();
121 * @brief Wraps original method
122 * ICNNNetwork::setBatchSize
    virtual void setBatchSize(const size_t size) {
        // CALL_STATUS_FNC presumably converts an error status into an exception
        // (macro from ie_exception_conversion.hpp) — TODO confirm.
        CALL_STATUS_FNC(setBatchSize, size);
129 * @brief Wraps original method
130 * ICNNNetwork::getBatchSize
    virtual size_t getBatchSize() const {
        return actual->getBatchSize();  // pass-through to the wrapped network
137 * @brief An overloaded operator & to get current network
138 * @return An instance of the current network
    // Conversion operator exposing the wrapped network by (non-const) reference.
    operator ICNNNetwork &() const {
145 * @brief Sets tha target device
146 * @param device Device instance to set
    void setTargetDevice(TargetDevice device) {
        actual->setTargetDevice(device);  // pass-through; no status-to-exception conversion here
153 * @brief Wraps original method
154 * ICNNNetwork::addOutput
    void addOutput(const std::string &layerName, size_t outputIndex = 0) {
        // Marks the given layer output as a network output; presumably throws on bad status.
        CALL_STATUS_FNC(addOutput, layerName, outputIndex);
161 * @brief Wraps original method
162 * ICNNNetwork::getLayerByName
    CNNLayerPtr getLayerByName(const char *layerName) const {
        // NOTE(review): `layer` is an out-param whose declaration is not visible in this
        // excerpt — confirm it is declared (and returned) in the elided lines.
        CALL_STATUS_FNC(getLayerByName, layerName, layer);
171 * @brief Begin layer iterator
172 * Order of layers is implementation specific,
173 * and can be changed in future
    details::CNNNetworkIterator begin() const {
        // Iteration order over layers is implementation-specific (per the note above).
        return details::CNNNetworkIterator(actual);
180 * @brief End layer iterator
    details::CNNNetworkIterator end() const {
        return details::CNNNetworkIterator();  // default-constructed iterator is the end sentinel
187 * @brief number of layers in network object
    size_t size() const {
        // O(n): walks the whole begin()/end() iterator range rather than asking the network.
        return std::distance(std::begin(*this), std::end(*this));
195 * @brief Registers extension within the plugin
196 * @param extension Pointer to already loaded reader extension with shape propagation implementations
    void AddExtension(InferenceEngine::IShapeInferExtensionPtr extension) {
        // Registers a shape-inference extension on the wrapped network; presumably throws on bad status.
        CALL_STATUS_FNC(AddExtension, extension);
203 * @brief - Helper method to get collect all input shapes with names of corresponding Data objects
204 * @return Map of pairs: input's name and its dimension.
206 virtual ICNNNetwork::InputShapes getInputShapes() const {
207 ICNNNetwork::InputShapes shapes;
208 InputsDataMap inputs;
209 actual->getInputsInfo(inputs);
210 for (const auto& pair : inputs) {
211 auto info = pair.second;
213 auto data = info->getInputData();
215 shapes[data->name] = data->getTensorDesc().getDims();
219 return std::move(shapes);
223 * @brief Run shape inference with new input shapes for the network
224 * @param inputShapes - map of pairs: name of corresponding data and its dimension.
    virtual void reshape(const ICNNNetwork::InputShapes &inputShapes) {
        // Runs shape inference with the new input dimensions; presumably throws on bad status.
        CALL_STATUS_FNC(reshape, inputShapes);
231 * @brief Serialize network to IR and weights files.
232 * @param xmlPath Path to output IR file.
233 * @param binPath Path to output weights file. The parameter is skipped in case
234 * of executable graph info serialization.
236 void serialize(const std::string &xmlPath, const std::string &binPath = "") const {
237 CALL_STATUS_FNC(serialize, xmlPath, binPath);
242 * @brief reader extra reference, might be nullptr
    std::shared_ptr<ICNNNetReader> reader;  // presumably set by the reader constructor to keep it alive
246 * @brief network extra interface, might be nullptr
    std::shared_ptr<ICNNNetwork> network;  // owning reference when constructed from a shared_ptr
251 * @brief A pointer to the current network
    ICNNNetwork *actual = nullptr;  // non-owning; every pass-through method dereferences this
255 * @brief A pointer to output data
260 } // namespace InferenceEngine