Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / passes / adjust_data_batch.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <vpu/pass_manager.hpp>
6
7 #include <tuple>
8 #include <vector>
9 #include <algorithm>
10 #include <limits>
11 #include <string>
12 #include <utility>
13 #include <cmath>
14 #include <list>
15 #include <set>
16 #include <unordered_map>
17 #include <memory>
18
19 namespace vpu {
20
21 namespace {
22
23 using ReplicatedDataMap = std::unordered_map<int, Data>;
24 using StagesOrderedSet = std::set<Stage, StageNode::StageIndexCmp>;
25 using BatchTilesMap = DataMap<DataVector>;
26
27 class PassImpl final : public Pass {
28 public:
29     explicit PassImpl(const StageBuilder::Ptr& stageBuilder) : _stageBuilder(stageBuilder) {}
30
31     void run(const Model::Ptr& model) override;
32
33 private:
34     StagesOrderedSet collectAllStageToSplit(const Model::Ptr& model);
35
36     StagesOrderedSet extractNextSubGraph(StagesOrderedSet& stagesToSplit);
37
38     void processStageInputs(
39             const Stage& stage,
40             const Model::Ptr& model,
41             const StagesOrderedSet& curSubGraph,
42             DataMap<DataVector>& subGraphInputTiles,
43             BatchTilesMap& batchTilesMap);
44     void splitStageInput(
45             const StageInput& inEdge,
46             const Model::Ptr& model,
47             const StagesOrderedSet& curSubGraph,
48             DataMap<DataVector>& subGraphInputTiles,
49             BatchTilesMap& batchTilesMap);
50     void replicateStageInput(
51             const StageInput& inEdge,
52             const Model::Ptr& model);
53
54     void processStageOutputs(
55             const Stage& stage,
56             const Model::Ptr& model,
57             const StagesOrderedSet& curSubGraph,
58             DataMap<DataVector>& subGraphOutputTiles,
59             BatchTilesMap& batchTilesMap);
60
61     void replicateStage(
62             const Stage& stage,
63             const Model::Ptr& model,
64             const BatchTilesMap& batchTilesMap);
65
66     void removeOriginalStages(
67             const StagesOrderedSet& curSubGraph,
68             const Model::Ptr& model);
69
70     void addSplitConcatPair(
71             const DataMap<DataVector>& subGraphInputTiles,
72             const DataMap<DataVector>& subGraphOutputTiles,
73             const Model::Ptr& model);
74
75 private:
76     StageBuilder::Ptr _stageBuilder;
77 };
78
79 void PassImpl::run(const Model::Ptr& model) {
80     VPU_PROFILE(adjustDataBatch);
81
82     auto stagesToSplit = collectAllStageToSplit(model);
83
84     while (!stagesToSplit.empty()) {
85         auto curSubGraph = extractNextSubGraph(stagesToSplit);
86         IE_ASSERT(!curSubGraph.empty());
87
88         DataMap<DataVector> subGraphInputTiles;
89         DataMap<DataVector> subGraphOutputTiles;
90         BatchTilesMap batchTilesMap;
91
92         for (const auto& stage : curSubGraph) {
93             processStageInputs(stage, model, curSubGraph, subGraphInputTiles, batchTilesMap);
94             processStageOutputs(stage, model, curSubGraph, subGraphOutputTiles, batchTilesMap);
95             replicateStage(stage, model, batchTilesMap);
96         }
97
98         removeOriginalStages(curSubGraph, model);
99
100         addSplitConcatPair(subGraphInputTiles, subGraphOutputTiles, model);
101     }
102 }
103
104 //
105 // Collect all stages that doesn't support batch
106 //
107
108 StagesOrderedSet PassImpl::collectAllStageToSplit(const Model::Ptr& model) {
109     StagesOrderedSet stagesToSplit;
110
111     for (const auto& stage : model->getStages()) {
112         //
113         // Get stage information
114         //
115
116         const auto& stageInfo = stage->getBatchSupportInfo();
117
118         if (stageInfo.empty()) {
119             continue;
120         }
121
122         //
123         // Get batch size
124         //
125
126         int batchSize = -1;
127
128         for (const auto& inEdge : stage->inputEdges()) {
129             if (!stageInfo.hasInput(inEdge)) {
130                 continue;
131             }
132
133             auto curReq = stageInfo.getInput(inEdge);
134
135             if (curReq == BatchSupport::Split) {
136                 if (batchSize < 0) {
137                     batchSize = inEdge->input()->desc().dim(Dim::N, 1);
138                 } else {
139                     IE_ASSERT(batchSize == inEdge->input()->desc().dim(Dim::N, 1));
140                 }
141             }
142         }
143
144         IE_ASSERT(batchSize > 0);
145
146         for (const auto& outEdge : stage->outputEdges()) {
147             IE_ASSERT(stageInfo.getOutput(outEdge) == BatchSupport::Split);
148             IE_ASSERT(batchSize == outEdge->output()->desc().dim(Dim::N, 1));
149         }
150
151         if (batchSize == 1) {
152             continue;
153         }
154
155         stage->attrs().set("batchSize", batchSize);
156         stagesToSplit.emplace(stage);
157     }
158
159     return stagesToSplit;
160 }
161
162 //
163 // Extract next sub-graph for process, it should be completely independent from other Stages
164 //
165
166 StagesOrderedSet PassImpl::extractNextSubGraph(StagesOrderedSet& stagesToSplit) {
167     //
168     // Add new Stage to the sub-graph only if it depends on Stages from sub-graph only
169     //
170
171     StagesOrderedSet subGraph;
172
173     for (const auto& stage : stagesToSplit) {
174         bool isInternalStage = true;
175         for (const auto& prevStage : stage->prevStages()) {
176             if (subGraph.count(prevStage) == 0) {
177                 isInternalStage = false;
178                 break;
179             }
180         }
181         if (isInternalStage || subGraph.empty()) {
182             subGraph.emplace(stage);
183         }
184
185         bool shouldStop = false;
186         for (const auto& nextStage : stage->nextStages()) {
187             if (stagesToSplit.count(nextStage) == 0) {
188                 shouldStop = true;
189                 break;
190             }
191         }
192         if (shouldStop) {
193             break;
194         }
195     }
196
197     for (const auto& stage : subGraph) {
198         stagesToSplit.erase(stage);
199     }
200
201     return subGraph;
202 }
203
204 void PassImpl::processStageInputs(
205         const Stage& stage,
206         const Model::Ptr& model,
207         const StagesOrderedSet& curSubGraph,
208         DataMap<DataVector>& subGraphInputTiles,
209         BatchTilesMap& batchTilesMap) {
210     const auto& stageInfo = stage->getBatchSupportInfo();
211
212     for (const auto& inEdge : stage->inputEdges()) {
213         if (!stageInfo.hasInput(inEdge)) {
214             continue;
215         }
216
217         auto curReq = stageInfo.getInput(inEdge);
218
219         if (curReq == BatchSupport::Split) {
220             splitStageInput(inEdge, model, curSubGraph, subGraphInputTiles, batchTilesMap);
221         } else if (curReq == BatchSupport::ReplicateConstContent) {
222             replicateStageInput(inEdge, model);
223         }
224     }
225 }
226
227 void PassImpl::splitStageInput(
228         const StageInput& inEdge,
229         const Model::Ptr& model,
230         const StagesOrderedSet& curSubGraph,
231         DataMap<DataVector>& subGraphInputTiles,
232         BatchTilesMap& batchTilesMap) {
233     const auto& input = inEdge->input();
234     const auto& stage = inEdge->consumer();
235
236     auto batchSize = stage->attrs().get<int>("batchSize");
237
238     auto newDesc = input->desc();
239     newDesc.setDim(Dim::N, 1);
240
241     auto& batchTiles = batchTilesMap[input];
242     if (!batchTiles.empty()) {
243         IE_ASSERT(batchTiles.size() == batchSize);
244         return;
245     }
246
247     batchTiles.resize(batchSize);
248     for (int batchInd = 0; batchInd < batchSize; ++batchInd) {
249         auto postfix = formatString("@batch=%d/%d", batchInd + 1, batchSize);
250
251         batchTiles[batchInd] = model->duplicateData(
252             input,
253             postfix,
254             newDesc);
255     }
256
257     bool isInternalInput = false;
258     if (auto producer = input->producer()) {
259         if (curSubGraph.count(producer) != 0) {
260             isInternalInput = true;
261         }
262     }
263     if (!isInternalInput) {
264         auto res = subGraphInputTiles.emplace(input, batchTiles);
265         IE_ASSERT(res.second);
266     }
267 }
268
269 void PassImpl::replicateStageInput(
270         const StageInput& inEdge,
271         const Model::Ptr& model) {
272     const auto& input = inEdge->input();
273     const auto& stage = inEdge->consumer();
274
275     IE_ASSERT(input->usage() == DataUsage::Const);
276     auto batchSize = stage->attrs().get<int>("batchSize");
277
278     auto& replicatedDatas = input->attrs().getOrSet<ReplicatedDataMap>("replicatedDatas", ReplicatedDataMap());
279     if (replicatedDatas.count(batchSize) == 0) {
280         auto content = input->content();
281         IE_ASSERT(content != nullptr);
282
283         auto perm = input->desc().dimsOrder().toPermutation();
284         auto dims = input->desc().dims();
285
286         int maxDimDigit = -1;
287         for (auto d : perm) {
288             maxDimDigit = std::max(maxDimDigit, static_cast<int>(d));
289         }
290         IE_ASSERT(maxDimDigit >= 0);
291
292         perm.emplace_back(static_cast<Dim>(maxDimDigit + 1));
293         dims.set(perm.back(), batchSize);
294
295         DataDesc newDesc(input->desc().type(), DimsOrder::fromPermutation(perm), dims);
296
297         replicatedDatas[batchSize] = model->duplicateData(
298             input,
299             formatString("@replicated=%d", batchSize),
300             newDesc,
301             replicateContent(content, batchSize));
302     }
303 }
304
305 void PassImpl::processStageOutputs(
306         const Stage& stage,
307         const Model::Ptr& model,
308         const StagesOrderedSet& curSubGraph,
309         DataMap<DataVector>& subGraphOutputTiles,
310         BatchTilesMap& batchTilesMap) {
311     auto batchSize = stage->attrs().get<int>("batchSize");
312
313     for (const auto& output : stage->outputs()) {
314         auto newDesc = output->desc();
315         newDesc.setDim(Dim::N, 1);
316
317         auto& batchTiles = batchTilesMap[output];
318         if (!batchTiles.empty()) {
319             IE_ASSERT(batchTiles.size() == batchSize);
320             continue;
321         }
322
323         batchTiles.resize(batchSize);
324         for (int batchInd = 0; batchInd < batchSize; ++batchInd) {
325             auto postfix = formatString("@batch=%d/%d", batchInd + 1, batchSize);
326
327             batchTiles[batchInd] = model->duplicateData(
328                 output,
329                 postfix,
330                 newDesc);
331         }
332
333         bool isInternalOutput = output->usage() == DataUsage::Intermediate;
334         for (const auto& consumer : output->consumers()) {
335             if (curSubGraph.count(consumer) == 0) {
336                 isInternalOutput = false;
337                 break;
338             }
339         }
340         if (!isInternalOutput) {
341             auto res = subGraphOutputTiles.emplace(output, batchTiles);
342             IE_ASSERT(res.second);
343         }
344     }
345 }
346
347 void PassImpl::replicateStage(
348         const Stage& stage,
349         const Model::Ptr& model,
350         const BatchTilesMap& batchTilesMap) {
351     const auto& stageInfo = stage->getBatchSupportInfo();
352     auto batchSize = stage->attrs().get<int>("batchSize");
353
354     for (int batchInd = 0; batchInd < batchSize; ++batchInd) {
355         auto postfix = formatString("@batch=%d/%d", batchInd + 1, batchSize);
356
357         DataVector newInputs;
358         for (const auto& inEdge : stage->inputEdges()) {
359             if (!stageInfo.hasInput(inEdge)) {
360                 newInputs.emplace_back(inEdge->input());
361                 continue;
362             }
363
364             auto curReq = stageInfo.getInput(inEdge);
365
366             if (curReq == BatchSupport::Split) {
367                 const auto& batchTiles = batchTilesMap.at(inEdge->input());
368                 IE_ASSERT(batchTiles.size() == batchSize);
369
370                 newInputs.emplace_back(batchTiles[batchInd]);
371             } else if (curReq == BatchSupport::ReplicateConstContent) {
372                 const auto& replicatedDatas = inEdge->input()->attrs().get<ReplicatedDataMap>("replicatedDatas");
373                 newInputs.emplace_back(replicatedDatas.at(batchSize));
374             }
375         }
376
377         DataVector newOutputs;
378         for (const auto& output : stage->outputs()) {
379             const auto& batchTiles = batchTilesMap.at(output);
380             IE_ASSERT(batchTiles.size() == batchSize);
381
382             newOutputs.emplace_back(batchTiles[batchInd]);
383         }
384
385         auto tileStage = model->duplicateStage(
386             stage,
387             postfix,
388             newInputs,
389             newOutputs);
390
391         tileStage->attrs().set<int>("batchInd", batchInd);
392
393         if (stage->type() == StageType::StubConv) {
394             tileStage->attrs().set("origConvOutput", newOutputs[0]->desc());
395         }
396     }
397 }
398
399 void PassImpl::removeOriginalStages(
400         const StagesOrderedSet& curSubGraph,
401         const Model::Ptr& model) {
402     for (const auto& stage : curSubGraph) {
403         model->removeStage(stage);
404     }
405 }
406
407 void PassImpl::addSplitConcatPair(
408         const DataMap<DataVector>& subGraphInputTiles,
409         const DataMap<DataVector>& subGraphOutputTiles,
410         const Model::Ptr& model) {
411     for (const auto& p : subGraphInputTiles) {
412         _stageBuilder->addSplitStage(
413             model,
414             p.first->name() + "@split-batch",
415             nullptr,
416             Dim::N,
417             p.first,
418             p.second);
419     }
420
421     for (const auto& p : subGraphOutputTiles) {
422         if (p.first->usage() == DataUsage::Intermediate) {
423             IE_ASSERT(p.first->numConsumers() > 0);
424         }
425
426         _stageBuilder->addConcatStage(
427             model,
428             p.first->name() + "@concat-batch",
429             nullptr,
430             Dim::N,
431             p.second,
432             p.first);
433     }
434 }
435
436 }  // namespace
437
438 Pass::Ptr PassManager::adjustDataBatch() {
439     return std::make_shared<PassImpl>(_stageBuilder);
440 }
441
442 }  // namespace vpu