Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ie_format_parser.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <string>
8 #include <map>
9 #include <memory>
10 #include "cnn_network_impl.hpp"
11 #include "ie_layers.h"
12 #include "parsers.h"
13 #include "details/caseless.hpp"
14 #include <vector>
15
16 namespace InferenceEngine {
17 namespace details {
18 struct WeightSegment {
19     Precision precision;
20     // offset in bytes of the global weights array
21     size_t start = 0;
22     // size in bytes
23     size_t size = 0;
24
25     inline size_t getEnd() const { return start + size; }
26
27     // checks if this segment is in the range of 0 to rangeSize, safer than using getEnd() to avoid int overflow
28     inline bool inRange(size_t rangeSize) const {
29         return start < rangeSize && (rangeSize - start) >= size;
30     }
31 };
32
33 struct LayerParseParameters {
34     struct LayerPortData {
35         int           portId;
36         Precision     precision;
37         SizeVector    dims;
38     };
39     InferenceEngine::LayerParams prms;
40     int layerId = -1;
41     std::vector<LayerPortData> inputPorts;
42     std::vector<LayerPortData> outputPorts;
43     std::map<std::string, WeightSegment> blobs;
44
45     std::function<void(const TBlob<uint8_t>::Ptr &weights)> internalWeightSet;
46
47     int underIRVersion = 0;
48
49     void addOutputPort(const LayerPortData &port);
50     void addInputPort(const LayerPortData &port);
51 };
52
53 class BaseCreator {
54     std::string type_;
55 protected:
56     explicit BaseCreator(const std::string& type) : type_(type) {}
57
58 public:
59     virtual ~BaseCreator() {}
60     static int version_;
61
62     virtual CNNLayer::Ptr CreateLayer(pugi::xml_node& node, LayerParseParameters& layerParsePrms) = 0;
63
64     bool shouldCreate(const std::string& nodeType) const {
65         InferenceEngine::details::CaselessEq<std::string> comparator;
66         return comparator(nodeType, type_);
67     }
68 };
69
70 class INFERENCE_ENGINE_API_CLASS(FormatParser) : public IFormatParser {
71 public:
72     explicit FormatParser(int version);
73
74     CNNNetworkImplPtr Parse(pugi::xml_node& root) override;
75
76     Blob::Ptr GetBlobFromSegment(const TBlob<uint8_t>::Ptr& weights, const WeightSegment & weight_segment) const;
77     void SetWeights(const TBlob<uint8_t>::Ptr& weights) override;
78     void ParseDims(SizeVector& dims, const pugi::xml_node &node) const;
79     const DataPtr& GetDataBy(int layer_id, int port_id) const;
80
81 protected:
82     std::map<std::string, LayerParseParameters> layersParseInfo;
83
84 private:
85     int _version;
86     Precision _defPrecision;
87     std::map<std::string, DataPtr> _portsToData;
88
89     CNNNetworkImplPtr _network;
90     std::map<std::string, std::vector<WeightSegment>> _preProcessSegments;
91     const std::vector<std::shared_ptr<BaseCreator> > &getCreators() const;
92     void ParsePort(LayerParseParameters::LayerPortData& port, pugi::xml_node &node) const;
93     void ParseGenericParams(pugi::xml_node& node, LayerParseParameters& layerParsePrms) const;
94     CNNLayer::Ptr CreateLayer(pugi::xml_node& node, LayerParseParameters& prms) const;
95
96     void SetLayerInput(CNNNetworkImpl& network, const std::string& data, CNNLayerPtr& targetLayer, int inputPort);
97
98     DataPtr ParseInputData(pugi::xml_node& root) const;
99
100     void ParsePreProcess(pugi::xml_node& node);
101     void ParseStatisticSection(const pugi::xml_node& statNode);
102 };
103 }  // namespace details
104 }  // namespace InferenceEngine