Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / concat.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/frontend/frontend.hpp>
6
7 #include <vector>
8 #include <limits>
9 #include <string>
10 #include <algorithm>
11 #include <memory>
12 #include <set>
13 #include <unordered_set>
14 #include <utility>
15
16 #include <vpu/utils/numeric.hpp>
17
18 namespace vpu {
19
20 namespace {
21
22 class ConcatStage final : public StageNode {
23 protected:
24     StagePtr cloneImpl() const override {
25         return std::make_shared<ConcatStage>(*this);
26     }
27
28     void propagateScaleFactorsImpl(
29             const SmallVector<float>& inputScales,
30             ScalePropagationStep step,
31             StageDataInfo<float>& scaleInfo) override {
32         auto output = outputEdge(0)->output();
33
34         if (step == ScalePropagationStep::Propagate) {
35             // Keep the largest input scale factor.
36             auto maxScale = std::numeric_limits<float>::lowest();
37             for (const auto& inEdge : inputEdges()) {
38                 maxScale = std::max(maxScale, inputScales[inEdge->portInd()]);
39             }
40
41             IE_ASSERT(maxScale > 0.0f);
42
43             for (const auto& inEdge : inputEdges()) {
44                 auto curScale = inputScales[inEdge->portInd()];
45
46                 if (!isFloatEqual(curScale, maxScale)) {
47                     scaleInfo.setInput(inEdge, maxScale / curScale);
48                 }
49             }
50
51             scaleInfo.setOutput(outputEdge(0), maxScale);
52         } else {
53             // Concat can only propagate scaling.
54             for (const auto& inEdge : inputEdges()) {
55                 scaleInfo.setInput(inEdge, 1.0f);
56             }
57
58             scaleInfo.setOutput(outputEdge(0), 1.0f);
59         }
60     }
61
62     void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
63         auto output = outputEdge(0)->output();
64
65         DimsOrderMap<int> dimsOrderVotes;
66         for (const auto& inEdge : inputEdges()) {
67             dimsOrderVotes[inEdge->input()->desc().dimsOrder()]++;
68         }
69
70         // Select DimsOrder with most votes.
71         // For equal votes : HCW > CHW > HWC.
72
73         DimsOrder finalOrder;
74         int curVotes = -1;
75         for (const auto& p : dimsOrderVotes) {
76             if (p.second > curVotes) {
77                 finalOrder = p.first;
78                 curVotes = p.second;
79             } else if (p.second == curVotes) {
80                 if (p.first.numDims() >= 3) {
81                     if (p.first.dimInd(Dim::C) == 2) {
82                         finalOrder = p.first;
83                     } else if (p.first.dimInd(Dim::C) == 3 &&
84                                finalOrder.dimInd(Dim::C) != 2) {
85                         finalOrder = p.first;
86                     }
87                 }
88             }
89         }
90
91         IE_ASSERT(finalOrder.numDims() > 0);
92         IE_ASSERT(curVotes > 0);
93
94         for (const auto& inEdge : inputEdges()) {
95             orderInfo.setInput(inEdge, finalOrder);
96         }
97
98         orderInfo.setOutput(outputEdge(0), finalOrder);
99     }
100
101     void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
102         auto output = outputEdge(0)->output();
103
104         auto dimsOrder = output->desc().dimsOrder();
105
106         //
107         // Get smallest Dim over which Concat is done.
108         //
109
110         auto minConcatDimInd = dimsOrder.numDims() - 1;
111
112         for (const auto& inEdge : inputEdges()) {
113             auto input = inEdge->input();
114
115             for (const auto& p : output->desc().dims()) {
116                 if (input->desc().dim(p.first) != p.second) {
117                     minConcatDimInd = std::min(minConcatDimInd, dimsOrder.dimInd(p.first));
118                 }
119             }
120         }
121
122         IE_ASSERT(minConcatDimInd < dimsOrder.numDims());
123
124         //
125         // Initial StridesRequirement for inputs and output.
126         //
127
128         auto outputReqs = output->requiredStrides();
129
130         auto inputReqs = outputReqs;
131         for (int i = minConcatDimInd + 1; i < dimsOrder.numDims(); ++i) {
132             inputReqs.remove(i);
133         }
134
135         //
136         // Merge input StridesRequirement.
137         //
138
139         for (const auto& inEdge : inputEdges()) {
140             auto curInput = inEdge->input();
141             auto curInputReqs = curInput->requiredStrides();
142
143             for (int i = 0; i < minConcatDimInd + 1; ++i) {
144                 if (outputReqs.get(i) == DimStride::Any) {
145                     if (curInputReqs.get(i) != DimStride::Any) {
146                         inputReqs.add(i, curInputReqs.get(i));
147                         outputReqs.add(i, curInputReqs.get(i));
148                     }
149                 }
150             }
151         }
152
153         //
154         // Merge output consumers StridesRequirement.
155         //
156
157         for (const auto& consumerEdge : output->consumerEdges()) {
158             const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
159
160             if (consumerInfo.hasInput(consumerEdge)) {
161                 const auto& consumerReqs = consumerInfo.getInput(consumerEdge);
162
163                 for (int i = 0; i < minConcatDimInd + 1; ++i) {
164                     if (outputReqs.get(i) == DimStride::Any) {
165                         if (consumerReqs.get(i) != DimStride::Any) {
166                             inputReqs.add(i, consumerReqs.get(i));
167                             outputReqs.add(i, consumerReqs.get(i));
168                         }
169                     }
170                 }
171             }
172         }
173
174         //
175         // Return merged StridesRequirement.
176         //
177
178         for (const auto& inEdge : inputEdges()) {
179             stridesInfo.setInput(inEdge, inputReqs);
180         }
181         stridesInfo.setOutput(outputEdge(0), outputReqs);
182     }
183
184     void finalizeDataLayoutImpl() override {
185     }
186
187     void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
188     }
189
190     void initialCheckImpl() const override {
191         IE_ASSERT(numInputs() > 0);
192         IE_ASSERT(numOutputs() == 1);
193
194         const auto& firstInputPrecision = input(0)->desc().type();
195         assertAllInputsOutputsTypes(this, {firstInputPrecision}, {firstInputPrecision});
196     }
197
198     void serializeParamsImpl(BlobSerializer&) const override {
199         VPU_THROW_EXCEPTION << "Must never be called";
200     }
201
202     void serializeDataImpl(BlobSerializer&) const override {
203         VPU_THROW_EXCEPTION << "Must never be called";
204     }
205 };
206
207 }  // namespace
208
209 void FrontEnd::parseConcat(
210         const Model::Ptr& model,
211         const ie::CNNLayerPtr& _layer,
212         const DataVector& inputs,
213         const DataVector& outputs) {
214     IE_ASSERT(!inputs.empty());
215     IE_ASSERT(outputs.size() == 1);
216
217     auto output = outputs[0];
218
219     auto layer = std::dynamic_pointer_cast<ie::ConcatLayer>(_layer);
220     IE_ASSERT(layer != nullptr);
221
222     IE_ASSERT(layer->_axis < output->desc().numDims());
223
224     auto perm = DimsOrder::fromNumDims(output->desc().numDims()).toPermutation();
225     auto axis = perm[output->desc().numDims() - 1 - layer->_axis];
226
227     _stageBuilder->addConcatStage(model, layer->name, layer, axis, inputs, output);
228 }
229
230 Stage StageBuilder::addConcatStage(
231         const Model::Ptr& model,
232         const std::string& name,
233         const ie::CNNLayerPtr& layer,
234         Dim axis,
235         const DataVector& inputs,
236         const Data& output) {
237     std::vector<DimValues> offsets;
238     offsets.reserve(inputs.size());
239
240     DimValues curOffset({{axis, 0}});
241     for (const auto& input : inputs) {
242         offsets.emplace_back(curOffset);
243         curOffset.set(axis, curOffset[axis] + input->desc().dim(axis));
244     }
245
246     auto stage = addConcatStage(model, name, layer, std::move(offsets), inputs, output);
247
248     stage->attrs().set("axis", axis);
249
250     return stage;
251 }
252
253 Stage StageBuilder::addConcatStage(
254         const Model::Ptr& model,
255         const std::string& name,
256         const ie::CNNLayerPtr& layer,
257         std::vector<DimValues>&& offsets,
258         const DataVector& inputs,
259         const Data& output) {
260     IE_ASSERT(offsets.size() == inputs.size());
261
262     auto stage = model->addNewStage<ConcatStage>(
263         name,
264         StageType::Concat,
265         layer,
266         inputs,
267         {output});
268
269     stage->attrs().set("offsets", std::move(offsets));
270
271     return stage;
272 }
273
274 }  // namespace vpu