e6088954a2b89eb4bec91a1fb229a7bb78085538
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / grn.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <unordered_set>
9 #include <memory>
10 #include <set>
11
12 namespace vpu {
13
14 namespace {
15
16 class GRNStage final : public StageNode {
17 private:
18     StagePtr cloneImpl() const override {
19         return std::make_shared<GRNStage>(*this);
20     }
21
22     void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
23         auto input = inputEdge(0)->input();
24
25         orderInfo.setOutput(outputEdge(0), input->desc().dimsOrder());
26     }
27
28     void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
29     }
30
31     void finalizeDataLayoutImpl() override {
32     }
33
34     void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
35         batchInfo.setInput(inputEdge(0), BatchSupport::Split);
36         batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
37     }
38
39     void initialCheckImpl() const override {
40         assertInputsOutputsTypes(this, {{DataType::FP16}}, {{DataType::FP16}});
41     }
42
43     void serializeParamsImpl(BlobSerializer& serializer) const override {
44         auto bias = attrs().get<float>("bias");
45
46         serializer.append(static_cast<float>(bias));
47     }
48
49     void serializeDataImpl(BlobSerializer& serializer) const override {
50         auto input = inputEdge(0)->input();
51         auto output = outputEdge(0)->output();
52
53         input->serializeNewBuffer(serializer);
54         output->serializeNewBuffer(serializer);
55     }
56 };
57
58 }  // namespace
59
60 void FrontEnd::parseGRN(
61         const Model::Ptr& model,
62         const ie::CNNLayerPtr& _layer,
63         const DataVector& inputs,
64         const DataVector& outputs) {
65     IE_ASSERT(inputs.size() == 1);
66     IE_ASSERT(outputs.size() == 1);
67
68     auto layer = std::dynamic_pointer_cast<ie::GRNLayer>(_layer);
69     IE_ASSERT(layer != nullptr);
70
71     auto stage = model->addNewStage<GRNStage>(
72         layer->name,
73         StageType::GRN,
74         layer,
75         inputs,
76         outputs);
77
78     stage->attrs().set<float>("bias", layer->bias);
79 }
80
81 }  // namespace vpu