1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
10 #include "cnn_network_impl.hpp"
11 #include "ie_layers.h"
13 #include "details/caseless.hpp"
16 namespace InferenceEngine {
18 struct WeightSegment {
20 // offset in bytes of the global weights array
// Returns the end offset of the segment, computed as start + size.
// The addition is not checked for overflow; the comment on inRange() recommends
// inRange() when validating a segment against a buffer size.
25 inline size_t getEnd() const { return start + size; }
27 // checks if this segment is in the range of 0 to rangeSize, safer than using getEnd() to avoid int overflow
// Returns true only when start < rangeSize AND size <= rangeSize - start.
// The first test short-circuits, so the subtraction is evaluated only when
// start < rangeSize (assumes start/size are unsigned like rangeSize; their
// declarations are not visible in this excerpt - TODO confirm).
// Note: a segment with start == rangeSize is rejected even if size is 0.
28 inline bool inRange(size_t rangeSize) const {
29 return start < rangeSize && (rangeSize - start) >= size;
// Bundle of per-layer data used while parsing a layer description.
// NOTE(review): parts of this struct (including closing braces and the fields of
// LayerPortData) are not visible in this excerpt; the comments below describe only
// what the visible declarations show.
33 struct LayerParseParameters {
// Description of a single port; its members are declared in lines not shown here.
34 struct LayerPortData {
// Layer parameters of type InferenceEngine::LayerParams.
// NOTE(review): presumably a member of LayerParseParameters rather than of
// LayerPortData (the nested struct's closing brace is not visible) - confirm.
39 InferenceEngine::LayerParams prms;
// Port descriptions for the layer's inputs and outputs, respectively.
41 std::vector<LayerPortData> inputPorts;
42 std::vector<LayerPortData> outputPorts;
// Weight segments keyed by string; presumably the key is the blob name - confirm.
43 std::map<std::string, WeightSegment> blobs;
// Callback that receives the weights blob (TBlob<uint8_t>::Ptr by const reference).
// Default-constructed (empty) unless assigned; who invokes it is not visible here.
45 std::function<void(const TBlob<uint8_t>::Ptr &weights)> internalWeightSet;
// Integer version marker, initialised to 0. Presumably the IR version the layer
// was parsed under - confirm against the code that sets it.
47 int underIRVersion = 0;
// Declared here, defined elsewhere. Each takes a port description by const
// reference; presumably appends it to outputPorts / inputPorts - confirm.
49 void addOutputPort(const LayerPortData &port);
50 void addInputPort(const LayerPortData &port);
// Constructs a creator for the given layer type name; the name is copied into type_.
// explicit: prevents implicit conversion from std::string to BaseCreator.
56 explicit BaseCreator(const std::string& type) : type_(type) {}
// Virtual destructor with an empty body, so derived creators can be destroyed
// through a BaseCreator pointer.
59 virtual ~BaseCreator() {}
// Pure virtual factory method: derived creators build a CNNLayer from the given XML
// node and the layer's parse parameters. Both arguments are passed by non-const
// reference; whether implementations modify them is not visible here.
62 virtual CNNLayer::Ptr CreateLayer(pugi::xml_node& node, LayerParseParameters& layerParsePrms) = 0;
// Tells whether this creator handles the given node type: returns the result of
// comparing nodeType with the stored type_ using details::CaselessEq
// (presumably a case-insensitive equality test, judging by its name - confirm
// in details/caseless.hpp).
64 bool shouldCreate(const std::string& nodeType) const {
65 InferenceEngine::details::CaselessEq<std::string> comparator;
66 return comparator(nodeType, type_);
// Parser class implementing the IFormatParser interface; exported through the
// INFERENCE_ENGINE_API_CLASS macro.
// NOTE(review): access specifiers and some members fall in lines not visible in
// this excerpt, so the visibility of the declarations below is not stated here.
70 class INFERENCE_ENGINE_API_CLASS(FormatParser) : public IFormatParser {
// Constructs a parser from an integer version (presumably the IR format version - confirm).
72 explicit FormatParser(int version);
// IFormatParser override: takes the XML root node and returns a CNNNetworkImplPtr.
74 CNNNetworkImplPtr Parse(pugi::xml_node& root) override;
// Returns a Blob::Ptr derived from the given weights blob and weight segment.
// Const member: does not modify the parser. Whether the result copies or aliases
// the weights data is not visible here.
76 Blob::Ptr GetBlobFromSegment(const TBlob<uint8_t>::Ptr& weights, const WeightSegment & weight_segment) const;
// IFormatParser override: receives the weights blob for the parsed network.
77 void SetWeights(const TBlob<uint8_t>::Ptr& weights) override;
// Fills the output parameter dims from the given XML node (dims is a non-const
// reference, node is const).
78 void ParseDims(SizeVector& dims, const pugi::xml_node &node) const;
// Looks up data by layer id and port id and returns a const reference to the DataPtr.
// Behaviour for unknown ids is not visible here.
79 const DataPtr& GetDataBy(int layer_id, int port_id) const;
// Parse parameters per layer, keyed by string (presumably the layer name - confirm).
82 std::map<std::string, LayerParseParameters> layersParseInfo;
// Precision value held by the parser; by its name, the default precision - confirm.
86 Precision _defPrecision;
// Data objects keyed by string; presumably the key encodes a layer/port pair as
// used by GetDataBy - confirm against the implementation.
87 std::map<std::string, DataPtr> _portsToData;
// Pointer to a network implementation (same type that Parse returns).
89 CNNNetworkImplPtr _network;
// Lists of weight segments keyed by string, related to pre-processing by name;
// the exact key meaning is not visible here.
90 std::map<std::string, std::vector<WeightSegment>> _preProcessSegments;
// Returns a const reference to the list of layer creators (shared_ptr<BaseCreator>).
91 const std::vector<std::shared_ptr<BaseCreator> > &getCreators() const;
// Fills the output parameter port from the given XML node.
92 void ParsePort(LayerParseParameters::LayerPortData& port, pugi::xml_node &node) const;
// Reads generic layer parameters from the XML node into layerParsePrms.
93 void ParseGenericParams(pugi::xml_node& node, LayerParseParameters& layerParsePrms) const;
// Builds a CNNLayer from the XML node and parse parameters; presumably dispatches
// to a matching BaseCreator via shouldCreate() - confirm in the implementation.
94 CNNLayer::Ptr CreateLayer(pugi::xml_node& node, LayerParseParameters& prms) const;
// Connects the data identified by the string `data` to input port `inputPort` of
// targetLayer within the given network. Non-const: may modify parser state.
96 void SetLayerInput(CNNNetworkImpl& network, const std::string& data, CNNLayerPtr& targetLayer, int inputPort);
// Parses input data from the XML root node and returns it as a DataPtr.
98 DataPtr ParseInputData(pugi::xml_node& root) const;
// Parses the pre-processing section from the given XML node. Non-const; presumably
// populates _preProcessSegments - confirm.
100 void ParsePreProcess(pugi::xml_node& node);
// Parses the statistics section from the given (const) XML node.
101 void ParseStatisticSection(const pugi::xml_node& statNode);
103 } // namespace details
104 } // namespace InferenceEngine