2f6ab2d91a4e19c95721f92ebcc21b2c897e5727
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / resample.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <unordered_set>
9 #include <memory>
10 #include <set>
11 #include <string>
12
13 namespace vpu {
14
15 VPU_DECLARE_ENUM(ResampleType,
16     Nearest  = 0,
17     Bilinear = 1
18 )
19
20 namespace {
21
22 class ResampleStage final : public StageNode {
23 private:
24     StagePtr cloneImpl() const override {
25         return std::make_shared<ResampleStage>(*this);
26     }
27
28     void propagateDataOrderImpl() const override {
29         IE_ASSERT(_inputEdges.size() == 1);
30         IE_ASSERT(_outputEdges.size() == 1);
31
32         auto input = _inputEdges[0]->input();
33
34         _orderInfo.setOutput(_outputEdges[0], input->desc().dimsOrder());
35     }
36
37     void getDataStridesRequirementsImpl() const override {
38     }
39
40     void finalizeDataLayoutImpl() override {
41     }
42
43     void getBatchSupportInfoImpl() const override {
44         IE_ASSERT(_inputEdges.size() == 1);
45         IE_ASSERT(_outputEdges.size() == 1);
46
47         _batchInfo.setInput(_inputEdges[0], BatchSupport::Split);
48         _batchInfo.setOutput(_outputEdges[0], BatchSupport::Split);
49     }
50
51     void finalCheckImpl() const override {
52     }
53
54     void serializeParamsImpl(BlobSerializer& serializer) const override {
55         auto antialias = attrs().get<bool>("antialias");
56         auto factor = attrs().get<float>("factor");
57         auto sampleType = attrs().get<ResampleType>("type");
58
59         serializer.append(static_cast<int32_t>(antialias));
60         serializer.append(static_cast<float>(factor));
61         serializer.append(static_cast<uint32_t>(sampleType));
62     }
63
64     void serializeDataImpl(BlobSerializer& serializer) const override {
65         IE_ASSERT(_inputEdges.size() == 1);
66         IE_ASSERT(_outputEdges.size() == 1);
67         IE_ASSERT(_tempBufferEdges.empty());
68
69         auto input = _inputEdges[0]->input();
70         auto output = _outputEdges[0]->output();
71
72         input->serializeOldBuffer(handle_from_this(), serializer);
73         output->serializeOldBuffer(handle_from_this(), serializer);
74     }
75 };
76
77 }  // namespace
78
79 void FrontEnd::parseResample(
80         const Model::Ptr& model,
81         const ie::CNNLayerPtr& layer,
82         const DataVector& inputs,
83         const DataVector& outputs) {
84     IE_ASSERT(inputs.size() == 1);
85     IE_ASSERT(outputs.size() == 1);
86
87     ie::details::CaselessEq<std::string> cmp;
88
89     auto stage = model->addNewStage<ResampleStage>(
90         layer->name,
91         StageType::Resample,
92         layer,
93         inputs,
94         outputs);
95
96     stage->attrs().set<bool>("antialias", layer->GetParamAsInt("antialias", 0));
97     stage->attrs().set<float>("factor", layer->GetParamAsInt("factor", -1.0f));
98
99     auto method = layer->GetParamAsString("type", "caffe.ResampleParameter.NEAREST");
100     if (cmp(method, "caffe.ResampleParameter.NEAREST")) {
101         stage->attrs().set<ResampleType>("type", ResampleType::Nearest);
102     } else {
103         stage->attrs().set<ResampleType>("type", ResampleType::Bilinear);
104     }
105 }
106
107 }  // namespace vpu