3588814fd9ccbc60d7ca4c23e1ec82d77e87df71
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / passes / sw_conv_adaptation.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/pass_manager.hpp>
6 #include <limits>
7
8 #include <vector>
9 #include <string>
10 #include <memory>
11 #include <unordered_set>
12 #include <set>
13
14 #include <vpu/sw/utility.hpp>
15
16 #define REFERENCE_CONVOLUTION 0
17
18 namespace vpu {
19
20 namespace {
21
22 class ConvIm2ColWeightsContent final : public CalculatedDataContent {
23 public:
24     explicit ConvIm2ColWeightsContent(const DataContent::Ptr& origContent) :
25             CalculatedDataContent({origContent}) {
26     }
27
28 protected:
29     void fillTempBuf(const SmallVector<DataContent::Ptr, 2>& baseContents, void* tempBuf) const override {
30         VPU_PROFILE(ConvIm2ColWeightsContent);
31         kchw_to_khwc(baseContents[0]->get<fp16_t>(), static_cast<fp16_t*>(tempBuf), _desc);
32     }
33 };
34
35 class Conv3x3WeightsContent final : public CalculatedDataContent {
36 public:
37     explicit Conv3x3WeightsContent(const DataContent::Ptr& origContent) :
38             CalculatedDataContent({origContent}) {
39     }
40
41 protected:
42     void fillTempBuf(const SmallVector<DataContent::Ptr, 2>& baseContents, void* tempBuf) const override {
43         VPU_PROFILE(Conv3x3WeightsContent);
44         kchw_to_hwkc(baseContents[0]->get<fp16_t>(), static_cast<fp16_t*>(tempBuf), _desc);
45     }
46 };
47
48 class ConvCHWWeightsContent final : public CalculatedDataContent {
49 public:
50     explicit ConvCHWWeightsContent(const DataContent::Ptr& origContent) :
51             CalculatedDataContent({origContent}) {
52     }
53
54 protected:
55     void fillTempBuf(const SmallVector<DataContent::Ptr, 2>& baseContents, void* tempBuf) const override {
56         VPU_PROFILE(ConvCHWWeightsContent);
57         kchw_to_hwkc(baseContents[0]->get<fp16_t>(), static_cast<fp16_t*>(tempBuf), _desc);
58     }
59 };
60
61 class ConvStage final : public StageNode {
62 private:
63     StagePtr cloneImpl() const override {
64         return std::make_shared<ConvStage>(*this);
65     }
66
67     void propagateScaleFactorsImpl(
68             const SmallVector<float>&,
69             ScalePropagationStep) override {
70         VPU_THROW_EXCEPTION << "Must never be called";
71     }
72
73     void propagateDataOrderImpl() const override {
74         IE_ASSERT(_inputEdges.size() == 3);
75         IE_ASSERT(_outputEdges.size() == 1);
76
77         auto input = _inputEdges[0]->input();
78         auto weights = _inputEdges[1]->input();
79         auto output = _outputEdges[0]->output();
80
81         auto finalOrder = input->desc().dimsOrder();
82         if (finalOrder.dimInd(Dim::C) == 1) {
83             // HCW -> CHW
84             finalOrder.moveDim(Dim::C, 2);
85         }
86
87         if (_type == StageType::Conv ||
88             _type == StageType::Im2ColConvolution) {
89             if (finalOrder != input->desc().dimsOrder()) {
90                 _orderInfo.setInput(_inputEdges[0], finalOrder);
91             }
92             _orderInfo.setOutput(_outputEdges[0], finalOrder);
93         } else if (_type == StageType::DepthConv) {
94             if (finalOrder != input->desc().dimsOrder()) {
95                 _orderInfo.setInput(_inputEdges[0], finalOrder);
96             }
97             _orderInfo.setOutput(_outputEdges[0], finalOrder);
98         } else {
99             _orderInfo.setInput(_inputEdges[0], finalOrder.createMovedDim(Dim::C, 0));
100             _orderInfo.setOutput(_outputEdges[0], finalOrder.createMovedDim(Dim::C, 0));
101         }
102     }
103
104     void getDataStridesRequirementsImpl() const override {
105         IE_ASSERT(_inputEdges.size() == 3);
106         IE_ASSERT(_outputEdges.size() == 1);
107
108         if (_type != StageType::DepthConv) {
109             _stridesInfo.setInput(_inputEdges[0], StridesRequirement::compact());
110             _stridesInfo.setOutput(_outputEdges[0], StridesRequirement::compact());
111         }
112     }
113
114     void finalizeDataLayoutImpl() override {
115         IE_ASSERT(_inputEdges.size() == 3);
116         IE_ASSERT(_outputEdges.size() == 1);
117
118         auto input = _inputEdges[0]->input();
119         auto weights = _inputEdges[1]->input();
120         auto output = _outputEdges[0]->output();
121
122         auto kernelSizeX = attrs().get<int>("kernelSizeX");
123         auto kernelSizeY = attrs().get<int>("kernelSizeY");
124
125         Data swWeights;
126
127         if (_type == StageType::DepthConv) {
128             swWeights = weights->attrs().getOrDefault<Data>("swWeights", nullptr);
129             if (swWeights == nullptr) {
130                 DataDesc newWeightsDesc({
131                     kernelSizeX * kernelSizeY,
132                     1,
133                     output->desc().dim(Dim::C)});
134
135                 swWeights = _model->duplicateData(
136                     weights,
137                     "@SW",
138                     newWeightsDesc,
139                     std::make_shared<DefaultSwWeightsContent>(weights->content()));
140
141                 weights->attrs().set<Data>("swWeights", swWeights);
142             }
143         } else if (input->desc().dimsOrder().dimInd(Dim::C) == 0) {
144             //
145             // HWC case
146             //
147
148             auto isSpatialConv = attrs().get<bool>("isSpatialConv");
149             auto isConv1x1 = attrs().get<bool>("isConv1x1");
150             auto isConv3x3 = attrs().get<bool>("isConv3x3");
151
152             swWeights = weights->attrs().getOrDefault<Data>("swWeights", nullptr);
153             if (swWeights == nullptr) {
154                 DataDesc newWeightsDesc({
155                     kernelSizeX * kernelSizeY,
156                     input->desc().dim(Dim::C),
157                     output->desc().dim(Dim::C)});
158
159                 if (isSpatialConv) {
160                     swWeights = _model->duplicateData(
161                         weights,
162                         "@SW",
163                         newWeightsDesc,
164                         std::make_shared<DefaultSwWeightsContent>(weights->content()));
165                 } else if (isConv1x1) {
166                     swWeights = _model->duplicateData(
167                         weights,
168                         "@SW",
169                         newWeightsDesc,
170                         weights->content());
171                 } else if (isConv3x3) {
172                     swWeights = _model->duplicateData(
173                         weights,
174                         "@SW",
175                         newWeightsDesc,
176                         std::make_shared<Conv3x3WeightsContent>(weights->content()));
177                 } else {
178                     swWeights = _model->duplicateData(
179                         weights,
180                         "@SW",
181                         newWeightsDesc,
182                         std::make_shared<ConvIm2ColWeightsContent>(weights->content()));
183                 }
184
185                 weights->attrs().set<Data>("swWeights", swWeights);
186             }
187         } else if (input->desc().dimsOrder().dimInd(Dim::C) == 2) {
188             //
189             // CHW case
190             //
191
192             auto isConv1x1 = attrs().get<bool>("isConv1x1");
193
194             if (_type == StageType::Im2ColConvolution) {
195                 // Transform CHW "Im2ColConvolution" into CHW "Conv"
196                 _type = StageType::Conv;
197             }
198
199             swWeights = weights->attrs().getOrDefault<Data>("swWeights", nullptr);
200             if (swWeights == nullptr) {
201                 DataDesc newWeightsDesc({
202                     kernelSizeX * kernelSizeY,
203                     input->desc().dim(Dim::C),
204                     output->desc().dim(Dim::C)});
205
206                 if (isConv1x1) {
207                     swWeights = _model->duplicateData(
208                         weights,
209                         "@SW",
210                         newWeightsDesc,
211                         weights->content());
212                 } else {
213                     swWeights = _model->duplicateData(
214                         weights,
215                         "@SW",
216                         newWeightsDesc,
217                         std::make_shared<ConvCHWWeightsContent>(weights->content()));
218                 }
219
220                 weights->attrs().set<Data>("swWeights", swWeights);
221             }
222         }
223
224         IE_ASSERT(swWeights != nullptr);
225
226         _model->replaceStageInput(_inputEdges[1], swWeights);
227     }
228
229     void getBatchSupportInfoImpl() const  override {
230         IE_ASSERT(_inputEdges.size() == 3);
231         IE_ASSERT(_outputEdges.size() == 1);
232
233         _batchInfo.setInput(_inputEdges[0], BatchSupport::Split);
234         _batchInfo.setOutput(_outputEdges[0], BatchSupport::Split);
235     }
236
237     void finalCheckImpl() const override {
238     }
239
240     void serializeParamsImpl(BlobSerializer& serializer) const override {
241         auto kernelSizeX = attrs().get<int>("kernelSizeX");
242         auto kernelSizeY = attrs().get<int>("kernelSizeY");
243         auto kernelStrideX = attrs().get<int>("kernelStrideX");
244         auto kernelStrideY = attrs().get<int>("kernelStrideY");
245         auto padLeft = attrs().get<int>("padLeft");
246         auto padTop = attrs().get<int>("padTop");
247         auto dilationX = attrs().get<int>("dilationX");
248         auto dilationY = attrs().get<int>("dilationY");
249
250         serializer.append(static_cast<uint32_t>(kernelSizeX));
251         serializer.append(static_cast<uint32_t>(kernelSizeY));
252         serializer.append(static_cast<uint32_t>(kernelStrideX));
253         serializer.append(static_cast<uint32_t>(kernelStrideY));
254         serializer.append(static_cast<uint32_t>(padLeft));
255         serializer.append(static_cast<uint32_t>(padTop));
256         serializer.append(static_cast<uint32_t>(dilationX));
257         serializer.append(static_cast<uint32_t>(dilationY));
258     }
259
260     void serializeDataImpl(BlobSerializer& serializer) const override {
261         IE_ASSERT(_inputEdges.size() == 3);
262         IE_ASSERT(_outputEdges.size() == 1);
263
264         auto input = _inputEdges[0]->input();
265         auto weights = _inputEdges[1]->input();
266         auto biases = _inputEdges[2]->input();
267         auto output = _outputEdges[0]->output();
268
269         input->serializeOldBuffer(handle_from_this(), serializer);
270         output->serializeOldBuffer(handle_from_this(), serializer);
271         weights->serializeOldBuffer(handle_from_this(), serializer);
272
273         if (!_tempBufferEdges.empty()) {
274             _tempBufferEdges[0]->tempBuffer()->serializeOldBuffer(handle_from_this(), serializer);
275         }
276
277         // TODO: remove this
278         biases->serializeOldBuffer(handle_from_this(), serializer);
279     }
280 };
281
282 class PassImpl final : public Pass {
283 public:
284     explicit PassImpl(const StageBuilder::Ptr& stageBuilder) : _stageBuilder(stageBuilder) {}
285
286     void run(const Model::Ptr& model) override;
287
288 private:
289     StageBuilder::Ptr _stageBuilder;
290 };
291
292 void PassImpl::run(const Model::Ptr& model) {
293     VPU_PROFILE(swConvAdaptation);
294
295     for (const auto& stage : model->getStages()) {
296         if (stage->type() != StageType::StubConv)
297             continue;
298
299         auto origStageName = stage->name();
300         auto origLayer = stage->origLayer();
301
302         auto input = stage->input(0);
303         auto weights = stage->input(1);
304         auto biases = stage->input(2);
305         auto output = stage->output(0);
306
307         auto kernelSizeX = stage->attrs().get<int>("kernelSizeX");
308         auto kernelSizeY = stage->attrs().get<int>("kernelSizeY");
309         auto kernelStrideX = stage->attrs().get<int>("kernelStrideX");
310         auto kernelStrideY = stage->attrs().get<int>("kernelStrideY");
311         auto padLeft = stage->attrs().get<int>("padLeft");
312         auto padRight = stage->attrs().get<int>("padRight");
313         auto padTop = stage->attrs().get<int>("padTop");
314         auto padBottom = stage->attrs().get<int>("padBottom");
315         auto dilationX = stage->attrs().get<int>("dilationX");
316         auto dilationY = stage->attrs().get<int>("dilationY");
317         auto groupSize = stage->attrs().get<int>("groupSize");
318
319         model->removeStage(stage);
320
321         bool isFC = (
322             kernelSizeX == 1 && kernelSizeY == 1 &&
323             kernelStrideX == 1 && kernelStrideY == 1 &&
324             padLeft == 0 && padRight == 0 && padTop == 0 && padBottom == 0 &&
325             dilationX == 1 && dilationY == 1 &&
326             input->desc().dim(Dim::W) == 1 && input->desc().dim(Dim::H) == 1 &&
327             output->desc().dim(Dim::W) == 1 && output->desc().dim(Dim::H) == 1);
328
329         bool isConv1x1 = (
330             kernelSizeX == 1 && kernelSizeY == 1 &&
331             dilationX == 1 && dilationY == 1 &&
332             !isFC);
333
334         bool isConv3x3 = (
335             kernelSizeX == 3 && kernelSizeY == 3 &&
336             (input->desc().dim(Dim::C) / groupSize) > 3 &&
337             ((input->desc().dim(Dim::C) / groupSize) * (output->desc().dim(Dim::C) / groupSize)) > 256);
338
339         bool iskernelSizeMatchSpatial = (
340             kernelSizeX > 1 && kernelSizeX < 12 && kernelSizeX % 2 == 1);
341
342         bool isSpatialConv = (
343             iskernelSizeMatchSpatial &&
344             kernelSizeY != 1 &&  // kernelSizeX != 1 was checked in iskernelSizeMatchSpatial condition
345             ((input->desc().dim(Dim::C) / groupSize) * (output->desc().dim(Dim::C) / groupSize)) <= 256 &&
346             groupSize == 1);
347
348 #if REFERENCE_CONVOLUTION
349         isSpatialConv  = false;
350         isConv3x3 = false;
351         isConv1x1 = false;
352 #endif
353
354         if (groupSize == 1) {
355             if (isFC) {
356                 _stageBuilder->addSwFullyConnectedStage(
357                     model,
358                     origStageName,
359                     origLayer,
360                     input,
361                     weights,
362                     biases,
363                     output);
364             } else {
365                 if (biases->usage() != DataUsage::Fake) {
366                     auto tempOutput = model->duplicateData(
367                         output,
368                         "@temp");
369
370                     _stageBuilder->addBiasStage(
371                         model,
372                         origStageName + "@biases",
373                         origLayer,
374                         tempOutput, biases,
375                         output);
376
377                     output = tempOutput;
378                 }
379
380                 Stage swStage;
381                 if (isConv1x1 || isSpatialConv || isConv3x3) {
382                     swStage = model->addNewStage<ConvStage>(
383                         origStageName,
384                         StageType::Conv,
385                         origLayer,
386                         {input, weights, biases},
387                         {output});
388                 } else {
389                     swStage = model->addNewStage<ConvStage>(
390                         origStageName,
391 #if REFERENCE_CONVOLUTION
392                         StageType::RefConvolution,
393 #else
394                         StageType::Im2ColConvolution,
395 #endif
396                         origLayer,
397                         {input, weights, biases},
398                         {output});
399
400                     double im2ColBufSizeF = static_cast<double>(kernelSizeX) * kernelSizeY *
401                         output->desc().dim(Dim::W) * output->desc().dim(Dim::H) * input->desc().dim(Dim::C)
402                         + 32;
403
404                     if (im2ColBufSizeF >= std::numeric_limits<int>::max()) {
405                         VPU_THROW_EXCEPTION << "stage: " << origStageName << ", im2col bufferSize cannot fit 32s: "
406                             << std::setprecision(0) << std::fixed << im2ColBufSizeF
407                             << "(" << kernelSizeX << "x" << kernelSizeY << "x"
408                             << output->desc().dim(Dim::W) << "x" << output->desc().dim(Dim::H) << "x" << output->desc().dim(Dim::C) << ")";
409                     }
410
411                     model->addTempBuffer(swStage, DataDesc({static_cast<int>(im2ColBufSizeF)}));
412                 }
413
414                 swStage->attrs().set<int>("kernelSizeX", kernelSizeX);
415                 swStage->attrs().set<int>("kernelSizeY", kernelSizeY);
416
417                 swStage->attrs().set<int>("kernelStrideX", kernelStrideX);
418                 swStage->attrs().set<int>("kernelStrideY", kernelStrideY);
419
420                 swStage->attrs().set<int>("padLeft", padLeft);
421                 swStage->attrs().set<int>("padRight", padRight);
422                 swStage->attrs().set<int>("padTop", padTop);
423                 swStage->attrs().set<int>("padBottom", padBottom);
424
425                 swStage->attrs().set<int>("dilationX", dilationX);
426                 swStage->attrs().set<int>("dilationY", dilationY);
427
428                 swStage->attrs().set<bool>("isSpatialConv", isSpatialConv);
429                 swStage->attrs().set<bool>("isConv1x1", isConv1x1);
430                 swStage->attrs().set<bool>("isConv3x3", isConv3x3);
431             }
432         } else if (groupSize == input->desc().dim(Dim::C) &&
433                    groupSize == output->desc().dim(Dim::C)) {
434             if (biases->usage() != DataUsage::Fake) {
435                 auto tempOutput = model->duplicateData(
436                     output,
437                     "@temp");
438
439                 _stageBuilder->addBiasStage(
440                     model,
441                     origStageName + "@biases",
442                     origLayer,
443                     tempOutput, biases,
444                     output);
445
446                 output = tempOutput;
447             }
448
449             auto swStage = model->addNewStage<ConvStage>(
450                 origStageName,
451                 StageType::DepthConv,
452                 origLayer,
453                 {input, weights, biases},
454                 {output});
455
456             swStage->attrs().set<int>("kernelSizeX", kernelSizeX);
457             swStage->attrs().set<int>("kernelSizeY", kernelSizeY);
458
459             swStage->attrs().set<int>("kernelStrideX", kernelStrideX);
460             swStage->attrs().set<int>("kernelStrideY", kernelStrideY);
461
462             swStage->attrs().set<int>("padLeft", padLeft);
463             swStage->attrs().set<int>("padRight", padRight);
464             swStage->attrs().set<int>("padTop", padTop);
465             swStage->attrs().set<int>("padBottom", padBottom);
466
467             swStage->attrs().set<int>("dilationX", dilationX);
468             swStage->attrs().set<int>("dilationY", dilationY);
469         } else {
470             VPU_THROW_EXCEPTION << "Internal error : grouped convolution was not processed";
471         }
472     }
473 }
474
475 }  // namespace
476
477 Pass::Ptr PassManager::swConvAdaptation() {
478     return std::make_shared<PassImpl>(_stageBuilder);
479 }
480
481 }  // namespace vpu