1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
9 #include <ie_icnn_network.hpp>
10 #include "ie_common.h"
14 #include "ie_input_info.hpp"
15 #include "description_buffer.hpp"
19 #include "cnn_network_stats_impl.hpp"
21 namespace InferenceEngine {
22 namespace ShapeInfer {
25 using ReshaperPtr = std::shared_ptr<Reshaper>;
26 } // namespace ShapeInfer
28 class INFERENCE_ENGINE_API_CLASS(CNNNetworkImpl) : public ICNNNetwork {
31 ~CNNNetworkImpl() override;
32 Precision getPrecision() const noexcept override {
36 void setPrecision(Precision::ePrecision prec) {
40 void getOutputsInfo(std::map<std::string, DataPtr> &out) const noexcept override;
42 void getInputsInfo(InputsDataMap& inputs) const noexcept override;
44 InputInfo::Ptr getInput(const std::string& inputName) const noexcept override {
45 auto it = _inputData.find(inputName);
46 if (it == _inputData.end()) {
52 void setInputInfo(InputInfo::Ptr data) {
53 _inputData[data->name()] = data;
56 void removeInputInfo(const std::string& name) {
57 _inputData.erase(name);
60 void getName(char* pName, size_t len) const noexcept override {
61 // Description buffer will preserve garbage if external pointer not initialized
63 memset(pName, 0, len);
64 DescriptionBuffer(pName, len) << _name;
67 const std::string& getName() const noexcept override {
71 void setName(const std::string& name) {
75 const std::map<std::string, CNNLayerPtr>& allLayers() const {
79 size_t layerCount() const noexcept override {
80 return _layers.size();
83 DataPtr& getData(const char* name) noexcept override {
87 DataPtr& getData(const std::string& name) {
88 return getData(name.c_str());
91 void addLayer(const CNNLayerPtr& layer) noexcept override;
93 void removeLayer(const std::string& layerName);
95 void removeData(const std::string& dataName);
97 StatusCode getLayerByName(const char* layerName, CNNLayerPtr& out, ResponseDesc* resp) const noexcept override;
// Deprecated: this overload takes no ResponseDesc, so it cannot report an error message.
100 StatusCode setBatchSize(const size_t size) noexcept override;
103 StatusCode setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept override;
// For internal use only (e.g. setting the batch size via reshape in tests).
106 StatusCode setBatchSizeReshape(size_t size, ResponseDesc* responseDesc) noexcept;
108 size_t getBatchSize() const noexcept override;
110 void setTargetDevice(TargetDevice device) noexcept override {
111 _targetDevice = device;
114 TargetDevice getTargetDevice() const noexcept override {
115 return _targetDevice;
118 StatusCode addOutput(const std::string& layerName, size_t outputIndex, ResponseDesc* resp) noexcept override;
120 void resolveOutput();
122 void addOutput(const std::string& dataName);
124 StatusCode getStats(ICNNNetworkStats** stats, ResponseDesc* resp) const noexcept override {
125 if (stats == nullptr) return StatusCode::PARAMETER_MISMATCH;
126 *stats = _stats.get();
127 return StatusCode::OK;
130 void Release() noexcept override {
134 virtual void validate(int = 2);
136 StatusCode reshape(const std::map<std::string, std::vector<size_t>> &inputShapes, ResponseDesc* resp) noexcept override;
139 AddExtension(const InferenceEngine::IShapeInferExtensionPtr &extension, InferenceEngine::ResponseDesc *resp) noexcept override;
141 StatusCode serialize(const std::string &xmlPath, const std::string &binPath, ResponseDesc* resp) const noexcept override;
144 Precision precision {Precision::MIXED};
145 std::map<std::string, DataPtr> _data;
146 std::map<std::string, CNNLayerPtr> _layers;
147 InferenceEngine::InputsDataMap _inputData;
148 std::map<std::string, DataPtr> _outputData;
151 TargetDevice _targetDevice;
153 ShapeInfer::ReshaperPtr _reshaper;
154 CNNNetworkStatsImplPtr _stats;
158 typedef std::shared_ptr<CNNNetworkImpl> CNNNetworkImplPtr;
159 } // namespace details
160 } // namespace InferenceEngine