d3e6bd41956d5ddd724334e906291e82aa9f4d3c
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / resample.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <unordered_set>
9 #include <memory>
10 #include <set>
11 #include <string>
12
13 namespace vpu {
14
15 VPU_DECLARE_ENUM(ResampleType,
16     Nearest  = 0,
17     Bilinear = 1
18 )
19
20 namespace {
21
22 class ResampleStage final : public StageNode {
23 private:
24     StagePtr cloneImpl() const override {
25         return std::make_shared<ResampleStage>(*this);
26     }
27
28     void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
29         auto input = inputEdge(0)->input();
30
31         orderInfo.setOutput(outputEdge(0), input->desc().dimsOrder());
32     }
33
34     void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
35     }
36
37     void finalizeDataLayoutImpl() override {
38     }
39
40     void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
41         batchInfo.setInput(inputEdge(0), BatchSupport::Split);
42         batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
43     }
44
45     void initialCheckImpl() const override {
46         assertInputsOutputsTypes(this, {{DataType::FP16}}, {{DataType::FP16}});
47     }
48
49     void serializeParamsImpl(BlobSerializer& serializer) const override {
50         auto antialias = attrs().get<bool>("antialias");
51         auto factor = attrs().get<float>("factor");
52         auto sampleType = attrs().get<ResampleType>("type");
53
54         serializer.append(static_cast<int32_t>(antialias));
55         serializer.append(static_cast<float>(factor));
56         serializer.append(static_cast<uint32_t>(sampleType));
57     }
58
59     void serializeDataImpl(BlobSerializer& serializer) const override {
60         auto input = inputEdge(0)->input();
61         auto output = outputEdge(0)->output();
62
63         input->serializeOldBuffer(handle_from_this(), serializer);
64         output->serializeOldBuffer(handle_from_this(), serializer);
65     }
66 };
67
68 }  // namespace
69
70 void FrontEnd::parseResample(
71         const Model::Ptr& model,
72         const ie::CNNLayerPtr& layer,
73         const DataVector& inputs,
74         const DataVector& outputs) {
75     IE_ASSERT(inputs.size() == 1);
76     IE_ASSERT(outputs.size() == 1);
77
78     ie::details::CaselessEq<std::string> cmp;
79
80     auto stage = model->addNewStage<ResampleStage>(
81         layer->name,
82         StageType::Resample,
83         layer,
84         inputs,
85         outputs);
86
87     stage->attrs().set<bool>("antialias", layer->GetParamAsInt("antialias", 0));
88     stage->attrs().set<float>("factor", layer->GetParamAsInt("factor", -1.0f));
89
90     auto method = layer->GetParamAsString("type", "caffe.ResampleParameter.NEAREST");
91     if (cmp(method, "caffe.ResampleParameter.NEAREST")) {
92         stage->attrs().set<ResampleType>("type", ResampleType::Nearest);
93     } else {
94         stage->attrs().set<ResampleType>("type", ResampleType::Bilinear);
95     }
96 }
97
98 }  // namespace vpu