96cdc1f3a03dfe8cc118c8a7d60102c313a1de68
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / roipooling.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <cstdio>
8
9 #include <vector>
10 #include <string>
11 #include <unordered_set>
12 #include <memory>
13 #include <set>
14
15 namespace vpu {
16
// Pooling mode of the ROIPooling stage. The numeric value of the selected
// enumerator is written into the stage parameters as uint32_t
// (see ROIPoolingStage::serializeParamsImpl).
// NOTE(review): presumably these values must match what the device-side
// kernel expects — do not renumber without checking the firmware.
VPU_DECLARE_ENUM(ROIPoolingMethod,
    Max = 0,
    Bilinear = 1
)
21
22 namespace {
23
24 class ROIPoolingStage final : public StageNode {
25 private:
26     StagePtr cloneImpl() const override {
27         return std::make_shared<ROIPoolingStage>(*this);
28     }
29
30     void propagateDataOrderImpl() const override {
31         IE_ASSERT(_inputEdges.size() == 2);
32         IE_ASSERT(_outputEdges.size() == 1);
33
34         auto input0 = _inputEdges[0]->input();
35         auto output = _outputEdges[0]->output();
36
37         _orderInfo.setInput(_inputEdges[0], input0->desc().dimsOrder().createMovedDim(Dim::C, 2));
38         _orderInfo.setOutput(_outputEdges[0], output->desc().dimsOrder().createMovedDim(Dim::C, 2));
39     }
40
41     void getDataStridesRequirementsImpl() const override {
42         IE_ASSERT(_inputEdges.size() == 2);
43         IE_ASSERT(_outputEdges.size() == 1);
44
45         _stridesInfo.setInput(_inputEdges[0], StridesRequirement::compact());
46         _stridesInfo.setInput(_inputEdges[1], StridesRequirement::compact());
47         _stridesInfo.setOutput(_outputEdges[0], StridesRequirement::compact());
48     }
49
50     void finalizeDataLayoutImpl() override {
51     }
52
53     void getBatchSupportInfoImpl() const override {
54     }
55
56     void finalCheckImpl() const override {
57     }
58
59     void serializeParamsImpl(BlobSerializer& serializer) const override {
60         auto pooled_w = attrs().get<int>("pooled_w");
61         auto pooled_h = attrs().get<int>("pooled_h");
62         auto spatial_scale = attrs().get<float>("spatial_scale");
63         auto method = attrs().get<ROIPoolingMethod>("method");
64
65         serializer.append(static_cast<uint32_t>(pooled_w));
66         serializer.append(static_cast<uint32_t>(pooled_h));
67         serializer.append(static_cast<float>(spatial_scale));
68         serializer.append(static_cast<uint32_t>(method));
69     }
70
71     void serializeDataImpl(BlobSerializer& serializer) const override {
72         IE_ASSERT(_inputEdges.size() == 2);
73         IE_ASSERT(_outputEdges.size() == 1);
74         IE_ASSERT(_tempBufferEdges.empty());
75
76         auto input0 = _inputEdges[0]->input();
77         auto input1 = _inputEdges[1]->input();
78         auto output = _outputEdges[0]->output();
79
80         input0->serializeNewBuffer(serializer);
81         output->serializeNewBuffer(serializer);
82         input1->serializeNewBuffer(serializer);
83     }
84 };
85
86 }  // namespace
87
88 void FrontEnd::parseROIPooling(
89         const Model::Ptr& model,
90         const ie::CNNLayerPtr& layer,
91         const DataVector& inputs,
92         const DataVector& outputs) {
93     ie::details::CaselessEq<std::string> cmp;
94
95     IE_ASSERT(inputs.size() == 2);
96     IE_ASSERT(outputs.size() == 1);
97
98     auto stage = model->addNewStage<ROIPoolingStage>(
99         layer->name,
100         StageType::ROIPooling,
101         layer,
102         inputs,
103         outputs);
104
105     stage->attrs().set<int>("pooled_w", layer->GetParamAsInt("pooled_w", 7));
106     stage->attrs().set<int>("pooled_h", layer->GetParamAsInt("pooled_h", 7));
107     stage->attrs().set<float>("spatial_scale", layer->GetParamAsFloat("spatial_scale", 0.0625f));
108
109     auto method = layer->GetParamAsString("method", "max");
110     if (cmp(method, "bilinear")) {
111         stage->attrs().set("method", ROIPoolingMethod::Bilinear);
112     } else {
113         stage->attrs().set("method", ROIPoolingMethod::Max);
114     }
115 }
116
117 }  // namespace vpu