1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
7 #include <ie_icnn_network.hpp>
13 #include "ie_layers.h"
14 #include "ie_ishape_infer_extension.hpp"
15 #include "description_buffer.hpp"
18 #include "ie_common.h"
20 #include "ie_input_info.hpp"
22 namespace InferenceEngine {
23 namespace ShapeInfer {
// Shared-ownership pointer to a ShapeInfer::Reshaper; CNNNetworkImpl holds one in its _reshaper member.
26 using ReshaperPtr = std::shared_ptr<Reshaper>;
27 } // namespace ShapeInfer
30 class INFERENCE_ENGINE_API_CLASS(CNNNetworkImpl): public ICNNNetwork {
// Destructor, defined out of line. Marked `override`, so ICNNNetwork's destructor is virtual
// and a CNNNetworkImpl can be destroyed through an ICNNNetwork pointer.
33 ~CNNNetworkImpl() override;
35 std::shared_ptr<::ngraph::Function> getFunction() noexcept override {
39 std::shared_ptr<const ::ngraph::Function> getFunction() const noexcept override {
// ICNNNetwork override: fills `out` with the network outputs, keyed by string (presumably the output data name).
// NOTE(review): presumably copied from _outputData, which has the same map type; confirm in the definition.
43 void getOutputsInfo(std::map<std::string, DataPtr>& out) const noexcept override;
// ICNNNetwork override: fills `inputs` with the network input descriptions.
// NOTE(review): presumably copied from _inputData, which is also an InputsDataMap; confirm in the definition.
45 void getInputsInfo(InputsDataMap& inputs) const noexcept override;
47 InputInfo::Ptr getInput(const std::string& inputName) const noexcept override {
48 auto it = _inputData.find(inputName);
49 if (it == _inputData.end()) {
55 void setInputInfo(InputInfo::Ptr data) {
56 _inputData[data->name()] = data;
59 void removeInputInfo(const std::string& name) {
60 _inputData.erase(name);
63 const std::string& getName() const noexcept override {
67 void setName(const std::string& name) {
71 const std::map<std::string, CNNLayerPtr>& allLayers() const {
75 size_t layerCount() const noexcept override {
76 return _layers.size();
79 DataPtr& getData(const char* name) noexcept {
83 void addData(const char* name, DataPtr data) noexcept {
84 _data.emplace(name, data);
87 DataPtr& getData(const std::string& name) {
88 return getData(name.c_str());
// Adds `layer` to the network. Declared noexcept.
// NOTE(review): presumably stored in _layers keyed by the layer name; confirm in the definition.
91 void addLayer(const CNNLayerPtr& layer) noexcept;
// Removes the layer named `layerName` from the network.
// Unlike addLayer this is not declared noexcept, so it is permitted to throw.
93 void removeLayer(const std::string& layerName);
95 // Renames the layer `currentName` to `newName`. Layer statistics are not supported by this operation.
// Not declared noexcept, so it is permitted to throw.
96 void renameLayer(const std::string& currentName, const std::string& newName);
// Removes the data object named `dataName`.
// NOTE(review): presumably erases the entry from _data; confirm in the definition whether references held by layers are updated too.
98 void removeData(const std::string& dataName);
// Looks up the layer named `layerName` and stores it in `out`.
// Returns a StatusCode; `resp` presumably receives an error description on failure (confirm).
100 StatusCode getLayerByName(const char* layerName, CNNLayerPtr& out, ResponseDesc* resp) const noexcept;
// ICNNNetwork override: sets the network batch size to `size`.
// Returns a StatusCode; `responseDesc` presumably receives error details on failure (confirm).
103 StatusCode setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept override;
105 // For internal use only (e.g. setBatch via reshape in tests).
// Takes the same arguments as setBatchSize but is not marked `override`.
106 StatusCode setBatchSizeReshape(size_t size, ResponseDesc* responseDesc) noexcept;
// ICNNNetwork override: returns the network batch size.
108 size_t getBatchSize() const noexcept override;
// ICNNNetwork override: adds output port `outputIndex` of layer `layerName` to the network outputs
// (presumably recorded in _outputData; confirm). Returns a StatusCode; `resp` presumably receives error details.
110 StatusCode addOutput(const std::string& layerName, size_t outputIndex, ResponseDesc* resp) noexcept override;
// NOTE(review): behavior is not visible in this header. Presumably (re)derives the network outputs;
// confirm against the definition and document accordingly.
112 void resolveOutput();
// Overload that adds an output by data-object name. Unlike the StatusCode-returning overload,
// it returns nothing and is not declared noexcept.
114 void addOutput(const std::string& dataName);
// Removes the output identified by `dataName` (presumably from _outputData; confirm in the definition).
116 void removeOutput(const std::string& dataName);
118 void Release() noexcept override {
// Validates the network. Virtual, so subclasses may override the checks.
// NOTE(review): the unnamed int parameter defaults to 2 and its meaning is not visible here.
// Name it in this declaration once its role is confirmed from the definition.
122 virtual void validate(int = 2);
// ICNNNetwork override: reshapes the network to `inputShapes`, a map from string
// (presumably the input name) to the new dimensions. Returns a StatusCode;
// `resp` presumably receives error details on failure (confirm).
124 StatusCode reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
125 ResponseDesc* resp) noexcept override;
// Registers a shape-inference extension with this network. Returns a StatusCode;
// `resp` presumably receives error details on failure.
// NOTE(review): presumably forwarded to _reshaper; confirm in the definition.
127 StatusCode AddExtension(const InferenceEngine::IShapeInferExtensionPtr& extension,
128 InferenceEngine::ResponseDesc* resp) noexcept;
130 StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
// Data objects keyed by name; addData() emplaces into this map.
134 std::map<std::string, DataPtr> _data;
// Layers keyed by string (presumably the layer name); layerCount() returns this map's size.
135 std::map<std::string, CNNLayerPtr> _layers;
// Network inputs keyed by InputInfo::name(), as written by setInputInfo().
136 InferenceEngine::InputsDataMap _inputData;
// Network outputs keyed by string (presumably the output data name).
137 std::map<std::string, DataPtr> _outputData;
// Shape-inference reshaper (shared ownership).
// NOTE(review): presumably used by reshape() and AddExtension(); confirm in the definitions.
140 ShapeInfer::ReshaperPtr _reshaper;
// Shared-ownership handle to a CNNNetworkImpl. Written as an alias declaration rather than a
// typedef for consistency with ShapeInfer::ReshaperPtr; the aliased type is unchanged.
143 using CNNNetworkImplPtr = std::shared_ptr<CNNNetworkImpl>;
144 } // namespace details
145 } // namespace InferenceEngine