From a6bb5aa0375056332936be04692063f1fb928f98 Mon Sep 17 00:00:00 2001 From: Anton Dudchenko Date: Wed, 10 Jun 2020 17:30:37 +0300 Subject: [PATCH] [VPU][GT] Trivial permute optimization (#571) * Transformation to eliminate trivial permute * Minor changes in unit tests * Replace trivial permutation with copy if input and output dims is equal * Fix mergePermuteStages tests * Small changes in the loop * Add const modifier, change dimsVector type to SizeVector * Change loop condition, rename valiable * To reverse dimsVector --- .../src/middleend/passes/merge_permute_stages.cpp | 39 ++++++++++++++++------ .../vpu/common/myriad_merge_permute_tests.hpp | 8 +++++ 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/inference-engine/src/vpu/graph_transformer/src/middleend/passes/merge_permute_stages.cpp b/inference-engine/src/vpu/graph_transformer/src/middleend/passes/merge_permute_stages.cpp index 05e0b36..4ca636b 100644 --- a/inference-engine/src/vpu/graph_transformer/src/middleend/passes/merge_permute_stages.cpp +++ b/inference-engine/src/vpu/graph_transformer/src/middleend/passes/merge_permute_stages.cpp @@ -55,11 +55,24 @@ private: permuteStage->output(0)->desc().dimsOrder()); } - static bool isTrivialPermute(const PermutationIndexVector& permuteDims) { - for (size_t i = 0; i < permuteDims.size(); ++i) - if (i != permuteDims[i]) - return false; + static bool isTrivialPermute(const PermutationIndexVector& permutation, const vpu::DimValues& dims) { + InferenceEngine::SizeVector dimsVector(dims.size()); + for (const auto& dim : dims) { + auto index = dimToIeInd(dim.first, dims.size()); + dimsVector[dims.size() - 1 - index] = dim.second; + } + for (size_t i = 0; i < permutation.size() - 1; ++i) { + if (i != permutation[i]) { + bool swapAdjacentDims = permutation[i] == (i + 1) && permutation[i + 1] == i; + bool dimIsOne = dimsVector[i] == 1 || dimsVector[i + 1] == 1; + if (swapAdjacentDims && dimIsOne) { + i++; + } else { + return false; + } + } + } return true; } @@ -131,7 +144,6 @@ private: void PassImpl::run(const Model& model) { VPU_PROFILE(mergePermuteStages); const StageMergeGroupList stageMergeGroupList = prepareStagesForMerge(model); - for (const auto& stageMergeGroup : stageMergeGroupList) { const auto& firstPermuteStage = stageMergeGroup.first; auto resultPermutation = permuteVectorFromStageInternal(firstPermuteStage); @@ -157,8 +169,8 @@ void PassImpl::run(const Model& model) { if (!outputLayout.empty()) firstPermuteStage->attrs().set(outputOrderKey, outputLayout); - // if we have no actual permutation, replace it with copy. - if (isTrivialPermute(resultPermutation)) { + // if we have no actual permutation, replace it with copy or reshape. + if (isTrivialPermute(resultPermutation, firstPermuteStage->input(0)->desc().dims())) { auto permuteInput = firstPermuteStage->input(0); auto permuteOutput = firstPermuteStage->output(0); if (permuteInput->desc().dimsOrder() == permuteOutput->desc().dimsOrder()) { @@ -166,10 +178,15 @@ void PassImpl::run(const Model& model) { auto origLayer = firstPermuteStage->origLayer(); model->removeStage(firstPermuteStage); - auto copyStage = _stageBuilder->addCopyStage(model, stageName + "@merged-to-copy", - origLayer, permuteInput, permuteOutput, "Eliminated permute"); - // TODO: make this optional=true with corresponding fixes in eliminate_copy (it expects Special stages now). - copyStage->attrs().set("optional", false); + if (permuteInput->desc().dims() == permuteOutput->desc().dims()) { + auto copyStage = _stageBuilder->addCopyStage(model, stageName + "@merged-to-copy", + origLayer, permuteInput, permuteOutput, "Eliminated permute"); + // TODO: make this optional=true with corresponding fixes in eliminate_copy (it expects Special stages now). + copyStage->attrs().set("optional", false); + } else { + _stageBuilder->addReshapeStage(model, stageName + "@merged-to-reshape", origLayer, + permuteInput, permuteOutput); + } } } } diff --git a/inference-engine/tests_deprecated/functional/vpu/common/myriad_merge_permute_tests.hpp b/inference-engine/tests_deprecated/functional/vpu/common/myriad_merge_permute_tests.hpp index ec0194d..dd5757d 100644 --- a/inference-engine/tests_deprecated/functional/vpu/common/myriad_merge_permute_tests.hpp +++ b/inference-engine/tests_deprecated/functional/vpu/common/myriad_merge_permute_tests.hpp @@ -79,26 +79,34 @@ TEST_P(myriadLayersMergePermuteNDTests_nightly, Permute) { } static const std::vector s_inTensors_3D = { {5, 7, 11}, + {1, 3, 4}, }; static const std::vector s_permuteParams_3D = { + {{0, 1, 2}, {1, 0, 2}}, // trivial for case with dims {1, 3, 4} {{1, 2, 0}, {1, 2, 0}}, {{1, 2, 0}, {1, 2, 0}, {1, 2, 0}}, // trivial one. }; static const std::vector s_inTensors_4D = { {3, 5, 7, 11}, + {5, 1, 1, 7}, }; static const std::vector s_permuteParams_4D = { + {{0, 1, 2, 3}, {1, 0, 3, 2}}, // + {{0, 1, 2, 3}, {0, 1, 3, 2}}, // trivial for case with dims {5, 1, 1, 7} {{1, 2, 3, 0}, {1, 2, 3, 0}, {1, 2, 3, 0}}, {{1, 2, 3, 0}, {1, 2, 3, 0}, {1, 2, 3, 0}, {1, 2, 3, 0}}, // trivial one. }; static const std::vector s_inTensors_5D = { {2, 3, 5, 7, 11}, + {2, 3, 1, 7, 11}, }; static const std::vector s_permuteParams_5D = { + {{0, 1, 2, 3, 4}, {0, 1, 3, 2, 4}}, // + {{0, 1, 2, 3, 4}, {0, 2, 1, 3, 4}}, // trivial for case with dims {2, 3, 1, 7, 11} {{0, 4, 1, 2, 3}, {0, 2, 1, 3, 4}}, {{0, 3, 4, 1, 2}, {0, 1, 3, 2, 4}}, {{1, 2, 3, 4, 0}, {1, 2, 3, 4, 0}}, -- 2.7.4