Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / sw / post_op_stage.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/sw/post_op_stage.hpp>
6
7 #include <memory>
8
9 #include <vpu/model/edges.hpp>
10 #include <vpu/model/data.hpp>
11
12 namespace vpu {
13
14 DataMap<float> PostOpStage::propagateScaleFactorsImpl(
15         const DataMap<float>&,
16         ScalePropagationStep) {
17     IE_ASSERT(!_inputEdges.empty());
18     IE_ASSERT(_outputEdges.size() == 1);
19
20     auto output = _outputEdges[0]->output();
21
22     DataMap<float> out;
23
24     // By default, assume no scale propagation.
25     for (const auto& inEdge : _inputEdges) {
26         out[inEdge->input()] = 1.0f;
27     }
28     out[output] = 1.0f;
29
30     return out;
31 }
32
33 DataMap<DimsOrder> PostOpStage::propagateDataOrderImpl() const {
34     IE_ASSERT(!_inputEdges.empty());
35     IE_ASSERT(_outputEdges.size() == 1);
36
37     // Non-zero-port inputs are constant (scales/biases).
38     for (const auto& inEdge : _inputEdges) {
39         if (inEdge->portInd() > 0) {
40             IE_ASSERT(inEdge->input()->usage() == DataUsage::Const);
41         }
42     }
43
44     auto input = _inputEdges[0]->input();
45     auto output = _outputEdges[0]->output();
46
47     DataMap<DimsOrder> out;
48
49     auto inDimsOrder = input->desc().dimsOrder();
50
51     // TODO: support HCW on firmware side
52     if (inDimsOrder.dimInd(Dim::C) == 1) {
53         inDimsOrder = inDimsOrder.createMovedDim(Dim::C, 2);  // CHW
54         out[input] = inDimsOrder;
55     }
56
57     out[output] = inDimsOrder;
58
59     return out;
60 }
61
62 DataMap<StridesRequirement> PostOpStage::getDataStridesRequirementsImpl() const {
63     IE_ASSERT(!_inputEdges.empty());
64     IE_ASSERT(_outputEdges.size() == 1);
65
66     // Non-zero-port inputs are constant (scales/biases).
67     for (const auto& inEdge : _inputEdges) {
68         if (inEdge->portInd() > 0) {
69             IE_ASSERT(inEdge->input()->usage() == DataUsage::Const);
70         }
71     }
72
73     auto input = _inputEdges[0]->input();
74     auto output = _outputEdges[0]->output();
75
76     DataMap<StridesRequirement> out;
77
78     StridesRequirement reqs;
79
80     // Current PostOp implementation requires Compact major stride.
81     reqs.add(2, DimStride::Compact);
82
83     if (input->desc().dim(Dim::N, 1) > 1) {
84         // To merge batch into previous dimension.
85         reqs.add(input->desc().dimsOrder().dimInd(Dim::N), DimStride::Compact);
86     }
87
88     out[input] = reqs;
89     out[output] = reqs;
90
91     return out;
92 }
93
94 void PostOpStage::finalizeDataLayoutImpl() {
95 }
96
97 DataMap<BatchSupport> PostOpStage::getBatchSupportInfoImpl() const {
98     IE_ASSERT(!_inputEdges.empty());
99     IE_ASSERT(_outputEdges.size() == 1);
100
101     // Non-zero-port inputs are constant (scales/biases).
102     for (const auto& inEdge : _inputEdges) {
103         if (inEdge->portInd() > 0) {
104             IE_ASSERT(inEdge->input()->usage() == DataUsage::Const);
105         }
106     }
107
108     auto mainDesc = _inputEdges[0]->input()->desc();
109
110     DataMap<BatchSupport> out;
111
112     // PostOp will support batch by merging it with previous dimension.
113     for (const auto& inEdge : _inputEdges) {
114         auto input = inEdge->input();
115
116         if (inEdge->portInd() == 0)
117             continue;
118
119         if (input->desc().dimsOrder().dimInd(Dim::C) == input->desc().numDims() - 2) {
120             IE_ASSERT(input->desc().totalDimSize() == input->desc().dim(Dim::C));
121             out[input] = BatchSupport::ReplicateConstContent;
122         }
123     }
124
125     return out;
126 }
127
128 StageSHAVEsRequirements PostOpStage::getSHAVEsRequirementsImpl() const {
129     // TODO: more SHAVEs leads to hang on public MTCNN network with U8 input
130     return StageSHAVEsRequirements::TwoOrOne;
131 }
132
133 void PostOpStage::finalCheckImpl() const {
134 }
135
136 void PostOpStage::serializeDataImpl(BlobSerializer& serializer) const {
137     IE_ASSERT(!_inputEdges.empty());
138     IE_ASSERT(_outputEdges.size() == 1);
139     IE_ASSERT(_tempBufferEdges.empty());
140
141     auto input = _inputEdges[0]->input();
142     auto output = _outputEdges[0]->output();
143
144     if (input->desc().dimsOrder() == DimsOrder::NC) {
145         input->serializeOldBuffer(
146             handle_from_this(),
147             serializer,
148             DimsOrder::HWC,
149             {
150                 {Dim::W, {Dim::N}},
151                 {Dim::C, {Dim::C}}
152             });
153
154         output->serializeOldBuffer(
155             handle_from_this(),
156             serializer,
157             DimsOrder::HWC,
158             {
159                 {Dim::W, {Dim::N}},
160                 {Dim::C, {Dim::C}}
161             });
162     } else if (input->desc().dim(Dim::N, 1) > 1) {
163         auto perm = input->desc().dimsOrder().toPermutation();
164         IE_ASSERT(perm.size() == 4);
165
166         input->serializeOldBuffer(
167             handle_from_this(),
168             serializer,
169             DimsOrder::HWC,
170             {
171                 {Dim::H, {perm[2], perm[3]}},
172                 {Dim::W, {perm[1]}},
173                 {Dim::C, {perm[0]}}
174             });
175
176         output->serializeOldBuffer(
177             handle_from_this(),
178             serializer,
179             DimsOrder::HWC,
180             {
181                 {Dim::H, {perm[2], perm[3]}},
182                 {Dim::W, {perm[1]}},
183                 {Dim::C, {perm[0]}}
184             });
185     } else {
186         input->serializeOldBuffer(handle_from_this(), serializer);
187
188         output->serializeOldBuffer(handle_from_this(), serializer);
189     }
190
191     for (int i = 1; i < _inputEdges.size(); ++i) {
192         _inputEdges[i]->input()->serializeOldBuffer(handle_from_this(), serializer);
193     }
194 }
195
196 }  // namespace vpu