1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <vpu/pass_manager.hpp>
16 #include <unordered_map>
// Type aliases used throughout this pass.
// Maps a batch size -> the Data object whose content was replicated for that batch size.
23 using ReplicatedDataMap = std::unordered_map<int, Data>;
// Set of stages with a deterministic iteration order (via StageNode::StageIndexCmp).
24 using StagesOrderedSet = std::set<Stage, StageNode::StageIndexCmp>;
// Maps an original (batched) Data object -> its per-batch tile Data objects.
25 using BatchTilesMap = DataMap<DataVector>;
// Pass that rewrites stages which cannot execute with batched data:
// each such stage is replicated once per batch element, its batched
// inputs/outputs are split into per-batch tiles, and Split/Concat stages
// are inserted at the boundaries of each rewritten sub-graph.
// NOTE(review): this listing is elided (embedded line numbers jump) —
// access specifiers, the closing `};`, and two method-name lines are not visible here.
27 class PassImpl final : public Pass {
29     explicit PassImpl(const StageBuilder::Ptr& stageBuilder) : _stageBuilder(stageBuilder) {}
    // Pass entry point: drives the split/replicate/concat rewrite over the whole model.
31     void run(const Model::Ptr& model) override;
    // Collects every stage whose BatchSupportInfo requires a batch split.
34     StagesOrderedSet collectAllStageToSplit(const Model::Ptr& model);
    // Removes and returns the next sub-graph (independent of the remaining stages) from `stagesToSplit`.
36     StagesOrderedSet extractNextSubGraph(StagesOrderedSet& stagesToSplit);
    // Splits or replicates the inputs of one stage, recording sub-graph-boundary tiles.
38     void processStageInputs(
40         const Model::Ptr& model,
41         const StagesOrderedSet& curSubGraph,
42         DataMap<DataVector>& subGraphInputTiles,
43         BatchTilesMap& batchTilesMap);
    // NOTE(review): the declaration line for the next method is elided here —
    // judging by the definition below, it is `void splitStageInput(`.
45         const StageInput& inEdge,
46         const Model::Ptr& model,
47         const StagesOrderedSet& curSubGraph,
48         DataMap<DataVector>& subGraphInputTiles,
49         BatchTilesMap& batchTilesMap);
    // Replicates a Const input's content along an extra batch dimension.
50     void replicateStageInput(
51         const StageInput& inEdge,
52         const Model::Ptr& model);
    // Creates per-batch tiles for a stage's outputs, recording sub-graph-boundary tiles.
54     void processStageOutputs(
56         const Model::Ptr& model,
57         const StagesOrderedSet& curSubGraph,
58         DataMap<DataVector>& subGraphOutputTiles,
59         BatchTilesMap& batchTilesMap);
    // NOTE(review): the declaration line above these parameters is elided —
    // judging by the definition below, it is `void replicateStage(`.
63         const Model::Ptr& model,
64         const BatchTilesMap& batchTilesMap);
    // Removes the original (batched) stages once their per-batch replicas exist.
66     void removeOriginalStages(
67         const StagesOrderedSet& curSubGraph,
68         const Model::Ptr& model);
    // Inserts Split stages for boundary inputs and Concat stages for boundary outputs.
70     void addSplitConcatPair(
71         const DataMap<DataVector>& subGraphInputTiles,
72         const DataMap<DataVector>& subGraphOutputTiles,
73         const Model::Ptr& model);
    // Builder used to create the Split/Concat stages.
76     StageBuilder::Ptr _stageBuilder;
// Main driver: repeatedly extracts an independent sub-graph of
// batch-splittable stages and rewrites it into per-batch replicas,
// then brackets the sub-graph with Split/Concat stages.
79 void PassImpl::run(const Model::Ptr& model) {
80     VPU_PROFILE(adjustDataBatch);
82     auto stagesToSplit = collectAllStageToSplit(model);
84     while (!stagesToSplit.empty()) {
        // Each iteration consumes one maximal independent sub-graph.
85         auto curSubGraph = extractNextSubGraph(stagesToSplit);
86         IE_ASSERT(!curSubGraph.empty());
        // Per-sub-graph bookkeeping: boundary input/output tiles and the
        // per-Data tile map shared by all three processing steps below.
88         DataMap<DataVector> subGraphInputTiles;
89         DataMap<DataVector> subGraphOutputTiles;
90         BatchTilesMap batchTilesMap;
        // Inputs and outputs must be tiled before the stage itself is replicated,
        // since replicateStage() looks the tiles up in batchTilesMap.
92         for (const auto& stage : curSubGraph) {
93             processStageInputs(stage, model, curSubGraph, subGraphInputTiles, batchTilesMap);
94             processStageOutputs(stage, model, curSubGraph, subGraphOutputTiles, batchTilesMap);
95             replicateStage(stage, model, batchTilesMap);
98         removeOriginalStages(curSubGraph, model);
100         addSplitConcatPair(subGraphInputTiles, subGraphOutputTiles, model);
105 // Collect all stages that don't support batch processing
// Scans all stages and returns those whose BatchSupportInfo demands a
// batch split. Also validates that every split input/output of such a
// stage agrees on the same batch size, and caches that size in the
// stage's "batchSize" attribute for later steps.
// NOTE(review): listing is elided here — the `continue` statements, the
// `batchSize` declaration, and the closing braces are not visible.
108 StagesOrderedSet PassImpl::collectAllStageToSplit(const Model::Ptr& model) {
109     StagesOrderedSet stagesToSplit;
111     for (const auto& stage : model->getStages()) {
113         // Get stage information
116         const auto& stageInfo = stage->getBatchSupportInfo();
        // Empty info means the stage handles batch natively — skip it.
118         if (stageInfo.empty()) {
        // Determine the batch size from the inputs marked BatchSupport::Split;
        // all such inputs must report the same N dimension.
128         for (const auto& inEdge : stage->inputEdges()) {
129             if (!stageInfo.hasInput(inEdge)) {
133             auto curReq = stageInfo.getInput(inEdge);
135             if (curReq == BatchSupport::Split) {
137                 batchSize = inEdge->input()->desc().dim(Dim::N, 1);
139                 IE_ASSERT(batchSize == inEdge->input()->desc().dim(Dim::N, 1));
144         IE_ASSERT(batchSize > 0);
        // Every output of a batch-split stage must itself be split,
        // with the same batch size as the inputs.
146         for (const auto& outEdge : stage->outputEdges()) {
147             IE_ASSERT(stageInfo.getOutput(outEdge) == BatchSupport::Split);
148             IE_ASSERT(batchSize == outEdge->output()->desc().dim(Dim::N, 1));
        // Batch of 1 needs no splitting.
151         if (batchSize == 1) {
        // Cache the batch size; replicateStage/splitStageInput read it back later.
155         stage->attrs().set("batchSize", batchSize);
156         stagesToSplit.emplace(stage);
159     return stagesToSplit;
163 // Extract the next sub-graph to process; it must be completely independent from the other Stages
// Greedily builds the next sub-graph out of `stagesToSplit`: a stage is
// accepted only when all of its predecessors are already in the
// sub-graph (or when the sub-graph is still empty). Accepted stages are
// removed from `stagesToSplit`.
// NOTE(review): listing is elided — the break/stop handling after
// `shouldStop` and the `return subGraph;` line are not visible here.
166 StagesOrderedSet PassImpl::extractNextSubGraph(StagesOrderedSet& stagesToSplit) {
168     // Add a new Stage to the sub-graph only if it depends solely on Stages already in the sub-graph
171     StagesOrderedSet subGraph;
173     for (const auto& stage : stagesToSplit) {
        // `isInternalStage`: every predecessor is already inside the sub-graph.
174         bool isInternalStage = true;
175         for (const auto& prevStage : stage->prevStages()) {
176             if (subGraph.count(prevStage) == 0) {
177                 isInternalStage = false;
181         if (isInternalStage || subGraph.empty()) {
182             subGraph.emplace(stage);
        // A successor outside the candidate set means the sub-graph cannot
        // grow past this stage.
185         bool shouldStop = false;
186         for (const auto& nextStage : stage->nextStages()) {
187             if (stagesToSplit.count(nextStage) == 0) {
    // Consume the accepted stages so the caller's loop makes progress.
197     for (const auto& stage : subGraph) {
198         stagesToSplit.erase(stage);
// Dispatches each input edge of `stage` according to its batch
// requirement: Split inputs are tiled per batch element, and
// ReplicateConstContent inputs get their constant content replicated.
// Inputs without an entry in BatchSupportInfo are left untouched.
204 void PassImpl::processStageInputs(
206         const Model::Ptr& model,
207         const StagesOrderedSet& curSubGraph,
208         DataMap<DataVector>& subGraphInputTiles,
209         BatchTilesMap& batchTilesMap) {
210     const auto& stageInfo = stage->getBatchSupportInfo();
212     for (const auto& inEdge : stage->inputEdges()) {
        // Skip inputs the stage has no batch requirement for.
213         if (!stageInfo.hasInput(inEdge)) {
217         auto curReq = stageInfo.getInput(inEdge);
219         if (curReq == BatchSupport::Split) {
220             splitStageInput(inEdge, model, curSubGraph, subGraphInputTiles, batchTilesMap);
221         } else if (curReq == BatchSupport::ReplicateConstContent) {
222             replicateStageInput(inEdge, model);
// Creates one tile Data object per batch element for a Split input
// (each tile has N == 1). Tiles are memoized in `batchTilesMap`, so a
// Data consumed by several stages is only tiled once. If the input's
// producer lies outside the current sub-graph, the tiles are also
// recorded in `subGraphInputTiles` so a Split stage can feed them later.
// NOTE(review): listing is elided — the duplicateData() argument lines
// and the early-return for already-tiled inputs are not visible here.
227 void PassImpl::splitStageInput(
228         const StageInput& inEdge,
229         const Model::Ptr& model,
230         const StagesOrderedSet& curSubGraph,
231         DataMap<DataVector>& subGraphInputTiles,
232         BatchTilesMap& batchTilesMap) {
233     const auto& input = inEdge->input();
234     const auto& stage = inEdge->consumer();
    // Batch size was cached on the stage by collectAllStageToSplit().
236     auto batchSize = stage->attrs().get<int>("batchSize");
    // Each tile keeps the input's descriptor except that N becomes 1.
238     auto newDesc = input->desc();
239     newDesc.setDim(Dim::N, 1);
    // operator[] creates an empty vector on first access; a non-empty
    // vector means this Data was already tiled by another consumer.
241     auto& batchTiles = batchTilesMap[input];
242     if (!batchTiles.empty()) {
243         IE_ASSERT(batchTiles.size() == batchSize);
247     batchTiles.resize(batchSize);
248     for (int batchInd = 0; batchInd < batchSize; ++batchInd) {
        // Human-readable suffix, e.g. "@batch=2/8", for the duplicated Data name.
249         auto postfix = formatString("@batch=%d/%d", batchInd + 1, batchSize);
251         batchTiles[batchInd] = model->duplicateData(
    // The input is "internal" when its producer is part of the same
    // sub-graph — then no Split stage is needed at this boundary.
257     bool isInternalInput = false;
258     if (auto producer = input->producer()) {
259         if (curSubGraph.count(producer) != 0) {
260             isInternalInput = true;
263     if (!isInternalInput) {
        // Each boundary input must be registered exactly once.
264         auto res = subGraphInputTiles.emplace(input, batchTiles);
265         IE_ASSERT(res.second);
// Replicates a Const input's content `batchSize` times along a newly
// appended outermost dimension. The result is memoized per batch size
// in the input's "replicatedDatas" attribute, so several consumers (or
// several sub-graphs) with the same batch size share one replica.
// NOTE(review): listing is elided — some duplicateData() argument lines
// and closing braces are not visible here.
269 void PassImpl::replicateStageInput(
270         const StageInput& inEdge,
271         const Model::Ptr& model) {
272     const auto& input = inEdge->input();
273     const auto& stage = inEdge->consumer();
    // Only constant data can have its content replicated up front.
275     IE_ASSERT(input->usage() == DataUsage::Const);
276     auto batchSize = stage->attrs().get<int>("batchSize");
278     auto& replicatedDatas = input->attrs().getOrSet<ReplicatedDataMap>("replicatedDatas", ReplicatedDataMap());
279     if (replicatedDatas.count(batchSize) == 0) {
280         auto content = input->content();
281         IE_ASSERT(content != nullptr);
283         auto perm = input->desc().dimsOrder().toPermutation();
284         auto dims = input->desc().dims();
        // Find the highest dimension index in use so the replica's batch
        // dimension can be appended right after it.
286         int maxDimDigit = -1;
287         for (auto d : perm) {
288             maxDimDigit = std::max(maxDimDigit, static_cast<int>(d));
290         IE_ASSERT(maxDimDigit >= 0);
        // Append the new outermost dimension and size it to the batch.
292         perm.emplace_back(static_cast<Dim>(maxDimDigit + 1));
293         dims.set(perm.back(), batchSize);
295         DataDesc newDesc(input->desc().type(), DimsOrder::fromPermutation(perm), dims);
297         replicatedDatas[batchSize] = model->duplicateData(
299             formatString("@replicated=%d", batchSize),
301             replicateContent(content, batchSize));
// Creates one tile Data object per batch element for every output of
// `stage` (each tile has N == 1), memoized in `batchTilesMap`. Outputs
// consumed outside the current sub-graph (or not Intermediate) are also
// recorded in `subGraphOutputTiles` so a Concat stage can gather them later.
// NOTE(review): listing is elided — the duplicateData() argument lines
// and the early-continue for already-tiled outputs are not visible here.
305 void PassImpl::processStageOutputs(
307         const Model::Ptr& model,
308         const StagesOrderedSet& curSubGraph,
309         DataMap<DataVector>& subGraphOutputTiles,
310         BatchTilesMap& batchTilesMap) {
311     auto batchSize = stage->attrs().get<int>("batchSize");
313     for (const auto& output : stage->outputs()) {
        // Each tile keeps the output's descriptor except that N becomes 1.
314         auto newDesc = output->desc();
315         newDesc.setDim(Dim::N, 1);
        // Non-empty means this Data was already tiled elsewhere.
317         auto& batchTiles = batchTilesMap[output];
318         if (!batchTiles.empty()) {
319             IE_ASSERT(batchTiles.size() == batchSize);
323         batchTiles.resize(batchSize);
324         for (int batchInd = 0; batchInd < batchSize; ++batchInd) {
325             auto postfix = formatString("@batch=%d/%d", batchInd + 1, batchSize);
327             batchTiles[batchInd] = model->duplicateData(
        // The output is "internal" only if it is Intermediate AND every
        // consumer belongs to the same sub-graph; otherwise a Concat is needed.
333         bool isInternalOutput = output->usage() == DataUsage::Intermediate;
334         for (const auto& consumer : output->consumers()) {
335             if (curSubGraph.count(consumer) == 0) {
336                 isInternalOutput = false;
340         if (!isInternalOutput) {
            // Each boundary output must be registered exactly once.
341             auto res = subGraphOutputTiles.emplace(output, batchTiles);
342             IE_ASSERT(res.second);
// Duplicates `stage` once per batch element. Each replica is wired to
// the per-batch input/output tiles prepared earlier (or to replicated
// const data / untouched inputs, depending on the batch requirement).
// NOTE(review): listing is elided — the `stage` parameter line, the
// duplicateStage() argument lines, and closing braces are not visible here.
347 void PassImpl::replicateStage(
349         const Model::Ptr& model,
350         const BatchTilesMap& batchTilesMap) {
351     const auto& stageInfo = stage->getBatchSupportInfo();
352     auto batchSize = stage->attrs().get<int>("batchSize");
354     for (int batchInd = 0; batchInd < batchSize; ++batchInd) {
        // Suffix for the replica's name, e.g. "@batch=3/8".
355         auto postfix = formatString("@batch=%d/%d", batchInd + 1, batchSize);
357         DataVector newInputs;
358         for (const auto& inEdge : stage->inputEdges()) {
            // No batch requirement: the replica reuses the original input as-is.
359             if (!stageInfo.hasInput(inEdge)) {
360                 newInputs.emplace_back(inEdge->input());
364             auto curReq = stageInfo.getInput(inEdge);
366             if (curReq == BatchSupport::Split) {
                // Tiles were created by splitStageInput(); pick this batch's tile.
367                 const auto& batchTiles = batchTilesMap.at(inEdge->input());
368                 IE_ASSERT(batchTiles.size() == batchSize);
370                 newInputs.emplace_back(batchTiles[batchInd]);
371             } else if (curReq == BatchSupport::ReplicateConstContent) {
                // All replicas share the single content-replicated Const Data.
372                 const auto& replicatedDatas = inEdge->input()->attrs().get<ReplicatedDataMap>("replicatedDatas");
373                 newInputs.emplace_back(replicatedDatas.at(batchSize));
        // Every output was tiled by processStageOutputs(); pick this batch's tile.
377         DataVector newOutputs;
378         for (const auto& output : stage->outputs()) {
379             const auto& batchTiles = batchTilesMap.at(output);
380             IE_ASSERT(batchTiles.size() == batchSize);
382             newOutputs.emplace_back(batchTiles[batchInd]);
385         auto tileStage = model->duplicateStage(
        // Tag the replica with its batch index for downstream passes.
391         tileStage->attrs().set<int>("batchInd", batchInd);
        // Stub convolutions additionally remember the original (un-split)
        // output descriptor on each replica.
393         if (stage->type() == StageType::StubConv) {
394             tileStage->attrs().set("origConvOutput", newOutputs[0]->desc());
// Removes the original (batched) stages from the model — their work is
// now performed by the per-batch replicas created in replicateStage().
399 void PassImpl::removeOriginalStages(
400         const StagesOrderedSet& curSubGraph,
401         const Model::Ptr& model) {
402     for (const auto& stage : curSubGraph) {
403         model->removeStage(stage);
// Bridges the rewritten sub-graph to the rest of the model: a Split
// stage feeds the per-batch tiles of every boundary input, and a Concat
// stage gathers the tiles of every boundary output back into the
// original batched Data.
// NOTE(review): listing is elided — the addSplitStage()/addConcatStage()
// argument lines and the body of the Intermediate-usage branch are not
// fully visible here.
407 void PassImpl::addSplitConcatPair(
408         const DataMap<DataVector>& subGraphInputTiles,
409         const DataMap<DataVector>& subGraphOutputTiles,
410         const Model::Ptr& model) {
411     for (const auto& p : subGraphInputTiles) {
412         _stageBuilder->addSplitStage(
414             p.first->name() + "@split-batch",
421     for (const auto& p : subGraphOutputTiles) {
        // An Intermediate boundary output must still have consumers,
        // otherwise the Concat would produce dead data.
422         if (p.first->usage() == DataUsage::Intermediate) {
423             IE_ASSERT(p.first->numConsumers() > 0);
426         _stageBuilder->addConcatStage(
428             p.first->name() + "@concat-batch",
// Factory: builds the adjust-data-batch pass, sharing the manager's StageBuilder.
438 Pass::Ptr PassManager::adjustDataBatch() {
439     return std::make_shared<PassImpl>(_stageBuilder);