1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <vpu/frontend/frontend.hpp>
13 #include <unordered_set>
16 #include <vpu/utils/numeric.hpp>
// ConcatStage: StageNode subclass used for Concat layers.
// The visible overrides propagate scale factors, pick a common DimsOrder by
// voting, and merge strides requirements between the inputs, the output and
// the output's consumers. Both serialization hooks throw unconditionally.
// NOTE(review): the embedded source line numbers skip values (e.g. 38 -> 41,
// 71 -> 75, 131 -> 136) and no closing braces are visible, so this listing
// appears to be missing lines. The code is kept exactly as extracted; comments
// about non-visible parts are marked as assumptions - confirm against the
// original file.
22 class ConcatStage final : public StageNode {
// Creates a copy of this stage through the ConcatStage copy constructor.
24 StagePtr cloneImpl() const override {
25 return std::make_shared<ConcatStage>(*this);
// Fills scaleInfo with per-input and output scale factors.
//   inputScales - current scale of each input, indexed by the edge's portInd().
//   step        - compared against ScalePropagationStep::Propagate below.
//   scaleInfo   - receives the results via setInput()/setOutput().
28 void propagateScaleFactorsImpl(
29 const SmallVector<float>& inputScales,
30 ScalePropagationStep step,
31 StageDataInfo<float>& scaleInfo) override {
// NOTE(review): `output` is not referenced by any visible line of this method.
32 auto output = outputEdge(0)->output();
34 if (step == ScalePropagationStep::Propagate) {
35 // Keep the largest input scale factor.
// Start from the lowest float so that any input scale replaces it.
36 auto maxScale = std::numeric_limits<float>::lowest();
37 for (const auto& inEdge : inputEdges()) {
38 maxScale = std::max(maxScale, inputScales[inEdge->portInd()]);
// The largest scale must be strictly positive; it is also used as a dividend below.
41 IE_ASSERT(maxScale > 0.0f);
// Inputs whose scale differs from the maximum get the ratio maxScale / curScale;
// inputs that already match are not set here.
43 for (const auto& inEdge : inputEdges()) {
44 auto curScale = inputScales[inEdge->portInd()];
46 if (!isFloatEqual(curScale, maxScale)) {
47 scaleInfo.setInput(inEdge, maxScale / curScale);
// The output is reported with the common (largest) scale.
51 scaleInfo.setOutput(outputEdge(0), maxScale);
// NOTE(review): the line introducing this branch is not visible; presumably
// this is the `else` of the step check above - confirm.
53 // Concat can only propagate scaling.
// Every input and the output are reported with a factor of 1.0f.
54 for (const auto& inEdge : inputEdges()) {
55 scaleInfo.setInput(inEdge, 1.0f);
58 scaleInfo.setOutput(outputEdge(0), 1.0f);
// Chooses one DimsOrder and reports it for all inputs and the output.
62 void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
// NOTE(review): `output` is not referenced by any visible line of this method.
63 auto output = outputEdge(0)->output();
// Count how many inputs currently use each DimsOrder.
65 DimsOrderMap<int> dimsOrderVotes;
66 for (const auto& inEdge : inputEdges()) {
67 dimsOrderVotes[inEdge->input()->desc().dimsOrder()]++;
70 // Select DimsOrder with most votes.
71 // For equal votes : HCW > CHW > HWC.
// NOTE(review): the declarations of `finalOrder` and `curVotes` and the
// statements executed inside the branches below are not visible here;
// presumably each branch records p.first (and its vote count) as the
// current winner - confirm against the original file.
75 for (const auto& p : dimsOrderVotes) {
76 if (p.second > curVotes) {
// Tie on vote count: only orders with at least 3 dims are examined, and the
// decision is based on the index of Dim::C in the candidate order.
79 } else if (p.second == curVotes) {
80 if (p.first.numDims() >= 3) {
81 if (p.first.dimInd(Dim::C) == 2) {
// A candidate with C at index 3 is considered only when the current choice
// does not have C at index 2.
83 } else if (p.first.dimInd(Dim::C) == 3 &&
84 finalOrder.dimInd(Dim::C) != 2) {
// A non-empty order with a positive vote count must have been selected.
91 IE_ASSERT(finalOrder.numDims() > 0);
92 IE_ASSERT(curVotes > 0);
// Report the same order for every input and for the output.
94 for (const auto& inEdge : inputEdges()) {
95 orderInfo.setInput(inEdge, finalOrder);
98 orderInfo.setOutput(outputEdge(0), finalOrder);
// Computes the strides requirements reported for the inputs and the output.
101 void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
102 auto output = outputEdge(0)->output();
104 auto dimsOrder = output->desc().dimsOrder();
107 // Get smallest Dim over which Concat is done.
// Start from the last index of the output order and lower it to the index of
// any dimension whose size differs between an input and the output.
110 auto minConcatDimInd = dimsOrder.numDims() - 1;
112 for (const auto& inEdge : inputEdges()) {
113 auto input = inEdge->input();
115 for (const auto& p : output->desc().dims()) {
116 if (input->desc().dim(p.first) != p.second) {
117 minConcatDimInd = std::min(minConcatDimInd, dimsOrder.dimInd(p.first));
122 IE_ASSERT(minConcatDimInd < dimsOrder.numDims());
125 // Initial StridesRequirement for inputs and output.
// Inputs start from a copy of the output's currently required strides.
128 auto outputReqs = output->requiredStrides();
130 auto inputReqs = outputReqs;
// NOTE(review): the body of this loop over indices above minConcatDimInd is
// not visible; presumably it adjusts inputReqs for those dimensions - confirm.
131 for (int i = minConcatDimInd + 1; i < dimsOrder.numDims(); ++i) {
136 // Merge input StridesRequirement.
// For indices 0..minConcatDimInd: where the output has no requirement
// (DimStride::Any) but an input has one, adopt it for both inputs and output.
139 for (const auto& inEdge : inputEdges()) {
140 auto curInput = inEdge->input();
141 auto curInputReqs = curInput->requiredStrides();
143 for (int i = 0; i < minConcatDimInd + 1; ++i) {
144 if (outputReqs.get(i) == DimStride::Any) {
145 if (curInputReqs.get(i) != DimStride::Any) {
146 inputReqs.add(i, curInputReqs.get(i));
147 outputReqs.add(i, curInputReqs.get(i));
154 // Merge output consumers StridesRequirement.
// Same merge as above, using what each consumer of the output reports for the
// edge it consumes; consumers without an entry for that edge are skipped.
157 for (const auto& consumerEdge : output->consumerEdges()) {
158 const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
160 if (consumerInfo.hasInput(consumerEdge)) {
161 const auto& consumerReqs = consumerInfo.getInput(consumerEdge);
163 for (int i = 0; i < minConcatDimInd + 1; ++i) {
164 if (outputReqs.get(i) == DimStride::Any) {
165 if (consumerReqs.get(i) != DimStride::Any) {
166 inputReqs.add(i, consumerReqs.get(i));
167 outputReqs.add(i, consumerReqs.get(i));
175 // Return merged StridesRequirement.
// All inputs share inputReqs; the output gets outputReqs.
178 for (const auto& inEdge : inputEdges()) {
179 stridesInfo.setInput(inEdge, inputReqs);
181 stridesInfo.setOutput(outputEdge(0), outputReqs);
// No statements are visible for this override.
184 void finalizeDataLayoutImpl() override {
// No statements are visible for this override.
187 void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
// Sanity check: at least one input and exactly one output.
190 void initialCheckImpl() const override {
191 IE_ASSERT(numInputs() > 0);
192 IE_ASSERT(numOutputs() == 1);
// The type of the first input is passed as the only allowed type for both
// inputs and outputs (presumably enforcing one common type - confirm against
// assertAllInputsOutputsTypes).
194 const auto& firstInputPrecision = input(0)->desc().type();
195 assertAllInputsOutputsTypes(this, {firstInputPrecision}, {firstInputPrecision});
// Throws unconditionally: this stage is not expected to be serialized.
198 void serializeParamsImpl(BlobSerializer&) const override {
199 VPU_THROW_EXCEPTION << "Must never be called";
// Throws unconditionally: this stage is not expected to be serialized.
202 void serializeDataImpl(BlobSerializer&) const override {
203 VPU_THROW_EXCEPTION << "Must never be called";
// Parses a Concat layer and adds the corresponding stage to the model.
//   model   - model passed on to the stage builder.
//   _layer  - generic layer pointer; must be castable to ie::ConcatLayer.
//   inputs  - data objects to concatenate; must not be empty.
//   outputs - must contain exactly one data object.
// Violations of these expectations trigger IE_ASSERT.
209 void FrontEnd::parseConcat(
210 const Model::Ptr& model,
211 const ie::CNNLayerPtr& _layer,
212 const DataVector& inputs,
213 const DataVector& outputs) {
214 IE_ASSERT(!inputs.empty());
215 IE_ASSERT(outputs.size() == 1);
217 auto output = outputs[0];
// Downcast to the concrete layer type to get access to _axis.
219 auto layer = std::dynamic_pointer_cast<ie::ConcatLayer>(_layer);
220 IE_ASSERT(layer != nullptr);
// The layer axis must be a valid dimension index of the output.
222 IE_ASSERT(layer->_axis < output->desc().numDims());
// Translate the layer axis into an entry of the default permutation for this
// number of dims; the index is taken from the end (numDims - 1 - _axis).
// NOTE(review): presumably the permutation is ordered opposite to the layer's
// axis numbering, hence the reversed index - confirm.
224 auto perm = DimsOrder::fromNumDims(output->desc().numDims()).toPermutation();
225 auto axis = perm[output->desc().numDims() - 1 - layer->_axis];
227 _stageBuilder->addConcatStage(model, layer->name, layer, axis, inputs, output);
// Adds a Concat stage that joins `inputs` along a single axis.
// Per-input offsets are derived from the inputs' sizes along that axis and
// passed to the offsets-based overload; the axis is then stored as the "axis"
// attribute of the created stage.
// NOTE(review): the declaration of the `axis` parameter is not visible in this
// listing (the embedded numbering skips from 233 to 235), and neither is the
// return statement - confirm against the original file.
230 Stage StageBuilder::addConcatStage(
231 const Model::Ptr& model,
232 const std::string& name,
233 const ie::CNNLayerPtr& layer,
235 const DataVector& inputs,
236 const Data& output) {
// One offset entry per input.
237 std::vector<DimValues> offsets;
238 offsets.reserve(inputs.size());
// Running offset along `axis`: starts at 0 and grows by each input's size
// along that axis, so every input is recorded at the sum of the sizes of the
// inputs before it.
240 DimValues curOffset({{axis, 0}});
241 for (const auto& input : inputs) {
242 offsets.emplace_back(curOffset);
243 curOffset.set(axis, curOffset[axis] + input->desc().dim(axis));
// Delegate to the overload that takes explicit offsets.
246 auto stage = addConcatStage(model, name, layer, std::move(offsets), inputs, output);
248 stage->attrs().set("axis", axis);
// Adds a Concat stage with explicitly provided per-input offsets.
//   offsets - one entry per input (asserted below); moved into the stage's
//             "offsets" attribute.
// NOTE(review): the arguments of the addNewStage call and the end of this
// function are not visible in this listing - confirm against the original file.
253 Stage StageBuilder::addConcatStage(
254 const Model::Ptr& model,
255 const std::string& name,
256 const ie::CNNLayerPtr& layer,
257 std::vector<DimValues>&& offsets,
258 const DataVector& inputs,
259 const Data& output) {
// There must be exactly one offset per input.
260 IE_ASSERT(offsets.size() == inputs.size());
262 auto stage = model->addNewStage<ConcatStage>(
269 stage->attrs().set("offsets", std::move(offsets));