f5d6e106c5847946ee4d33716583eb185dc575df
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ie_layer_parsers.h
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
7 #include <debug.h>
8 #include <memory>
9 #include "ie_format_parser.h"
10 #include "xml_parse_utils.h"
11 #include "range_iterator.hpp"
12 #include "details/caseless.hpp"
13 #include <vector>
14 #include <string>
15 #include <map>
16
17 inline pugi::xml_node GetChild(const pugi::xml_node& node, std::vector<std::string> tags, bool failIfMissing = true) {
18     for (auto tag : tags) {
19         pugi::xml_node dn = node.child(tag.c_str());
20         if (!dn.empty()) return dn;
21     }
22     if (failIfMissing)
23         THROW_IE_EXCEPTION << "missing <" << InferenceEngine::details::dumpVec(tags)
24                            << "> Tags at offset :" << node.offset_debug();
25     return pugi::xml_node();
26 }
27
28 using namespace XMLParseUtils;
29
30 namespace InferenceEngine {
31 namespace details {
32 template<class LT>
33 class LayerCreator : public BaseCreator {
34 public:
35     explicit LayerCreator(const std::string& type) : BaseCreator(type) {}
36
37     CNNLayer::Ptr CreateLayer(pugi::xml_node& node, LayerParseParameters& layerParsePrms) override {
38         auto res = std::make_shared<LT>(layerParsePrms.prms);
39
40         if (res->type == "FakeQuantize")
41             res->type = "Quantize";
42
43         if (std::is_same<LT, FullyConnectedLayer>::value) {
44             layerChild[res->name] = {"fc", "fc_data", "data"};
45         } else if (std::is_same<LT, NormLayer>::value) {
46             layerChild[res->name] = {"lrn", "norm", "norm_data", "data"};
47         } else if (std::is_same<LT, CropLayer>::value) {
48             layerChild[res->name] = {"crop", "crop-data", "data"};
49         } else if (std::is_same<LT, BatchNormalizationLayer>::value) {
50             layerChild[res->name] = {"batch_norm", "batch_norm_data", "data"};
51         } else if ((std::is_same<LT, EltwiseLayer>::value)) {
52             layerChild[res->name] = {"elementwise", "elementwise_data", "data"};
53         } else {
54             layerChild[res->name] = {"data", tolower(res->type) + "_data", tolower(res->type)};
55         }
56
57         pugi::xml_node dn = GetChild(node, layerChild[res->name], false);
58
59         if (!dn.empty()) {
60             if (dn.child("crop").empty()) {
61                 for (auto ait = dn.attributes_begin(); ait != dn.attributes_end(); ++ait) {
62                     pugi::xml_attribute attr = *ait;
63                     res->params.emplace(attr.name(), attr.value());
64                 }
65             } else {
66                 if (std::is_same<LT, CropLayer>::value) {
67                     auto crop_res = std::dynamic_pointer_cast<CropLayer>(res);
68                     if (!crop_res) {
69                         THROW_IE_EXCEPTION << "Crop layer is nullptr";
70                     }
71                     std::string axisStr, offsetStr, dimStr;
72                     FOREACH_CHILD(_cn, dn, "crop") {
73                         int axis = GetIntAttr(_cn, "axis", 0);
74                         crop_res->axis.push_back(axis);
75                         axisStr +=  std::to_string(axis) + ",";
76                         int offset = GetIntAttr(_cn, "offset", 0);
77                         crop_res->offset.push_back(offset);
78                         offsetStr +=  std::to_string(offset) + ",";
79                     }
80                     if (!axisStr.empty() && !offsetStr.empty() && !dimStr.empty()) {
81                         res->params["axis"] = axisStr.substr(0, axisStr.size() - 1);
82                         res->params["offset"] = offsetStr.substr(0, offsetStr.size() - 1);
83                     }
84                 }
85             }
86         }
87         return res;
88     }
89
90     std::map <std::string, std::vector<std::string>> layerChild;
91 };
92
93 class ActivationLayerCreator : public BaseCreator {
94  public:
95     explicit ActivationLayerCreator(const std::string& type) : BaseCreator(type) {}
96     CNNLayer::Ptr CreateLayer(pugi::xml_node& node, LayerParseParameters& layerParsePrms) override;
97 };
98
99 class TILayerCreator : public BaseCreator {
100 public:
101     explicit TILayerCreator(const std::string& type) : BaseCreator(type) {}
102     CNNLayer::Ptr CreateLayer(pugi::xml_node& node, LayerParseParameters& layerParsePrms) override;
103 };
104 }  // namespace details
105 }  // namespace InferenceEngine
106
107 /***********************************************************************************/
108 /******* End of Layer Parsers ******************************************************/
109 /***********************************************************************************/