[IE][VPU]: Refactoring of SpecialStageProcessor (#2885)
authorRoman Vyunov (Intel) <roman.vyunov@intel.com>
Tue, 17 Nov 2020 13:30:30 +0000 (16:30 +0300)
committerGitHub <noreply@github.com>
Tue, 17 Nov 2020 13:30:30 +0000 (16:30 +0300)
* SpecialStageProcessor refactoring
* Fix for Yolo-v3-pytorch and related test

inference-engine/src/vpu/graph_transformer/src/middleend/special_stage_processor.cpp
inference-engine/tests/functional/plugin/myriad/subgraph_tests/concat_split_transpose.cpp [new file with mode: 0644]

index ae3d0f3..e1ed4d1 100644 (file)
 #include "vpu/middleend/special_stage_processor.hpp"
 
 #include <vector>
-#include <set>
 #include <utility>
 
 namespace vpu {
 
-void SpecialStageProcessor::processSplit(
-        const Model& model,
-        const Stage& stage) {
-    IE_ASSERT(stage->type() == StageType::Split);
+namespace {
 
-    auto input = stage->input(0);
+struct NeedCopyDesc {
+    bool isCopyNeed = false;
+    bool isCopyOptional = false;
+};
 
-    const auto& offsets = stage->attrs().get<std::vector<DimValues>>("offsets");
-    IE_ASSERT(offsets.size() == checked_cast<size_t>(stage->numOutputs()));
+NeedCopyDesc isOutputCopyRequired(
+                const Stage& stage,
+                const StageOutput& outputEdge,
+                const Data& inputData) {
+    NeedCopyDesc needCopyDesc;
+    auto output = outputEdge->output();
+    if (output->usage() != DataUsage::Intermediate) {
+        needCopyDesc.isCopyNeed = true;
+    } else if (output->parentDataToDataEdge() != nullptr) {
+        needCopyDesc.isCopyNeed = true;
+    } else {
+        //
+        // Check output StridesRequirement
+        //
 
-    for (const auto& outEdge : stage->outputEdges()) {
-        IE_ASSERT(outEdge->portInd() >= 0);
-        IE_ASSERT(checked_cast<size_t>(outEdge->portInd()) < offsets.size());
+        IE_ASSERT(output->checkStrides(output->requiredStrides()));
+        if (!checkStrides(output->desc(), inputData->strides(), output->requiredStrides())) {
+            needCopyDesc.isCopyNeed = true;
+        }
 
-        auto output = outEdge->output();
-        const auto& offsetFromInput = offsets[checked_cast<size_t>(outEdge->portInd())];
+        //
+        // Check consumers StridesRequirement.
+        //
 
-        IE_ASSERT(input->desc().dimsOrder() == output->desc().dimsOrder());
-        IE_ASSERT(offsetFromInput.size() <= checked_cast<size_t>(input->desc().numDims()));
-        for (const auto& p : offsetFromInput) {
-            IE_ASSERT(input->desc().dimsOrder().hasDim(p.first));
-            IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
+        if (!needCopyDesc.isCopyNeed) {
+            for (const auto& consumerEdge : output->consumerEdges()) {
+                const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
+                if (consumerInfo.hasInput(consumerEdge)) {
+                    const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
+                    IE_ASSERT(output->checkStrides(consumerStrideReqs));
+                    if (!checkStrides(output->desc(), inputData->strides(), consumerStrideReqs)) {
+                        needCopyDesc.isCopyNeed = true;
+                        break;
+                    }
+                }
+            }
         }
+    }
+    return needCopyDesc;
+}
 
+NeedCopyDesc isInputCopyRequired(
+                const Stage& stage,
+                const StageInput& inputEdge,
+                const Data& outputData) {
+    auto input = inputEdge->input();
+    NeedCopyDesc needCopyDesc;
+    if (input->usage() != DataUsage::Intermediate) {
+        needCopyDesc.isCopyNeed = true;
+    } else if (input->parentDataToDataEdge() != nullptr) {
+        needCopyDesc.isCopyNeed = true;
+    } else {
         //
-        // Check if we need to insert Copy stage
+        // Check input StridesRequirement.
         //
 
-        bool needCopy = false;
-        if (output->usage() != DataUsage::Intermediate) {
-            needCopy = true;
-        } else if (output->parentDataToDataEdge() != nullptr) {
-            needCopy = true;
-        } else {
-            //
-            // Check output StridesRequirement.
-            //
-
-            IE_ASSERT(output->checkStrides(output->requiredStrides()));
-            if (!checkStrides(output->desc(), input->strides(), output->requiredStrides())) {
-                needCopy = true;
-            }
+        IE_ASSERT(input->checkStrides(input->requiredStrides()));
+        if (!checkStrides(input->desc(), outputData->strides(), input->requiredStrides())) {
+            needCopyDesc.isCopyNeed = true;
+        }
 
-            //
-            // Check consumers StridesRequirement.
-            //
+        //
+        // Check consumers StridesRequirement.
+        //
 
-            if (!needCopy) {
-                for (const auto& consumerEdge : output->consumerEdges()) {
-                    const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
+        if (!needCopyDesc.isCopyNeed) {
+            for (const auto& consumerEdge : input->consumerEdges()) {
+                const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
 
-                    if (consumerInfo.hasInput(consumerEdge)) {
-                        const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
-                        IE_ASSERT(output->checkStrides(consumerStrideReqs));
+                if (consumerInfo.hasInput(consumerEdge)) {
+                    const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
+                    IE_ASSERT(input->checkStrides(consumerStrideReqs));
 
-                        if (!checkStrides(output->desc(), input->strides(), consumerStrideReqs)) {
-                            needCopy = true;
-                            break;
-                        }
+                    if (!checkStrides(input->desc(), outputData->strides(), consumerStrideReqs)) {
+                        needCopyDesc.isCopyNeed = true;
                     }
                 }
             }
         }
 
         //
-        // Insert Copy if needed
+        // Check producer StridesRequirement.
         //
 
-        if (needCopy) {
-            auto outputCopy = model->duplicateData(output, "@copy");
-            outputCopy->resetRequiredStrides();
+        if (!needCopyDesc.isCopyNeed) {
+            if (auto producerEdge = input->producerEdge()) {
+                const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
 
-            auto outPortInd = outEdge->portInd();
+                if (producerInfo.hasOutput(producerEdge)) {
+                    const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
+                    IE_ASSERT(input->checkStrides(producerStrideReqs));
 
-            model->replaceStageOutput(outEdge, outputCopy);
+                    if (!checkStrides(input->desc(), outputData->strides(), producerStrideReqs)) {
+                        needCopyDesc.isCopyNeed = true;
+                    }
+                }
 
-            auto copyStage = _stageBuilder->addCopyStage(
-                model,
-                formatString("%s@output=%d@copy-for-split", stage->name(), outPortInd),
-                stage->origLayer(),
-                outputCopy,
-                output,
-                "special::split");
-            if (stage->attrs().has("batchInd")) {
-                copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
+                if (!needCopyDesc.isCopyNeed) {
+                    //
+                    // To reduce the size of HW output (still can be optimized).
+                    //
+
+                    if (producerEdge->producer()->category() == StageCategory::HW) {
+                        needCopyDesc.isCopyNeed = true;
+                        needCopyDesc.isCopyOptional = true;
+                    }
+                }
             }
+        }
+    }
+
+    return needCopyDesc;
+}
+
+Data insertCopyOfInput(const Model& model,
+                       const Stage& stage,
+                       const StageInput& edge,
+                       const StageBuilder::Ptr& _stageBuilder,
+                       const NeedCopyDesc& desc) {
+    auto data = edge->input();
+
+    Data copy;
+    if (data->usage() == DataUsage::Const) {
+        copy = model->addNewData(data->name() + "@copy", data->desc());
+    } else {
+        copy = model->duplicateData(data, "@copy");
+        copy->resetRequiredStrides();
+    }
+    if (stage->type() == StageType::Reshape)
+        copy->updateRequiredStrides(StridesRequirement::compact());
+
+    bool hasMultipleInputs = stage->numInputs() > 1;
+    auto inputNumStr = hasMultipleInputs ? formatString("@input=%d", edge->portInd()) : "";
+    std::stringstream typeAsString;
+    typeAsString << stage->type();
+
+    auto copyStage = _stageBuilder->addCopyStage(
+            model,
+            formatString("%s%s@copy-for-%s", stage->name(), inputNumStr, typeAsString),
+            stage->origLayer(),
+            data,
+            copy,
+            formatString("special::%s", typeAsString));
+    if (stage->type() != StageType::Reshape) {
+        copyStage->attrs().set<bool>("optional", desc.isCopyOptional);
+    }
+    if (stage->attrs().has("batchInd")) {
+        copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
+    }
+
+    model->replaceStageInput(edge, copy);
+
+    return copy;
+}
+
+Data insertCopyOfOutput(const Model& model,
+                        const Stage& stage,
+                        const StageOutput& edge,
+                        const StageBuilder::Ptr& _stageBuilder) {
+    auto data = edge->output();
+    auto copy = model->duplicateData(data, "@copy");
+    copy->resetRequiredStrides();
+
+    model->replaceStageOutput(edge, copy);
 
-            output = outputCopy;
+    bool hasMultipleOutputs = stage->numOutputs() > 1;
+    auto outputNumStr = hasMultipleOutputs ? formatString("@output=%d", edge->portInd()) : "";
+    std::stringstream typeAsString;
+    typeAsString << stage->type();
+
+    auto copyStage = _stageBuilder->addCopyStage(
+            model,
+            formatString("%s%s@copy-for-%s", stage->name(), outputNumStr, typeAsString),
+            stage->origLayer(),
+            copy,
+            data,
+            formatString("special::%s", typeAsString));
+    if (stage->attrs().has("batchInd")) {
+        copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
+    }
+
+    return copy;
+}
+
+} // namespace
+
+
+void SpecialStageProcessor::processSplit(
+        const Model& model,
+        const Stage& stage) {
+    IE_ASSERT(stage->type() == StageType::Split);
+    auto input = stage->input(0);
+
+    const auto& offsets = stage->attrs().get<std::vector<DimValues>>("offsets");
+    IE_ASSERT(offsets.size() == checked_cast<size_t>(stage->numOutputs()));
+
+    for (const auto& outEdge : stage->outputEdges()) {
+        IE_ASSERT(outEdge->portInd() >= 0);
+        IE_ASSERT(checked_cast<size_t>(outEdge->portInd()) < offsets.size());
+
+        auto output = outEdge->output();
+        const auto& offsetFromInput = offsets[checked_cast<size_t>(outEdge->portInd())];
+
+        IE_ASSERT(input->desc().dimsOrder() == output->desc().dimsOrder());
+        IE_ASSERT(offsetFromInput.size() <= checked_cast<size_t>(input->desc().numDims()));
+        for (const auto& p : offsetFromInput) {
+            IE_ASSERT(input->desc().dimsOrder().hasDim(p.first));
+            IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
+        }
+
+        auto desc = isOutputCopyRequired(stage, outEdge, input);
+        if (desc.isCopyNeed) {
+            output = insertCopyOfOutput(model, stage, outEdge, _stageBuilder);
         }
 
         //
@@ -136,113 +267,9 @@ void SpecialStageProcessor::processConcat(
             IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
         }
 
-        //
-        // Check if we need to insert Copy stage
-        //
-
-        bool needCopy = false;
-        bool optionalCopy = false;
-        if (input->usage() != DataUsage::Intermediate) {
-            needCopy = true;
-            optionalCopy = false;
-        } else if (input->parentDataToDataEdge() != nullptr) {
-            needCopy = true;
-            optionalCopy = false;
-        } else {
-            //
-            // Check input StridesRequirement.
-            //
-
-            IE_ASSERT(input->checkStrides(input->requiredStrides()));
-            if (!checkStrides(input->desc(), output->strides(), input->requiredStrides())) {
-                needCopy = true;
-                optionalCopy = false;
-            }
-
-            //
-            // Check consumers StridesRequirement.
-            //
-
-            if (!needCopy) {
-                for (const auto& consumerEdge : input->consumerEdges()) {
-                    const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
-
-                    if (consumerInfo.hasInput(consumerEdge)) {
-                        const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
-                        IE_ASSERT(input->checkStrides(consumerStrideReqs));
-
-                        if (!checkStrides(input->desc(), output->strides(), consumerStrideReqs)) {
-                            needCopy = true;
-                            optionalCopy = false;
-                        }
-                    }
-                }
-            }
-
-            //
-            // Check producer StridesRequirement.
-            //
-
-            if (!needCopy) {
-                if (auto producerEdge = input->producerEdge()) {
-                    const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
-
-                    if (producerInfo.hasOutput(producerEdge)) {
-                        const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
-                        IE_ASSERT(input->checkStrides(producerStrideReqs));
-
-                        if (!checkStrides(input->desc(), output->strides(), producerStrideReqs)) {
-                            needCopy = true;
-                            optionalCopy = false;
-                        }
-                    }
-
-                    if (!needCopy) {
-                        //
-                        // To reduce the size of HW output (still can be optimized).
-                        //
-
-                        if (producerEdge->producer()->category() == StageCategory::HW) {
-                            needCopy = true;
-                            optionalCopy = true;
-                        }
-                    }
-                }
-            }
-        }
-
-        //
-        // Insert Copy if needed
-        //
-
-        if (needCopy) {
-            Data inputCopy;
-            if (input->usage() == DataUsage::Const) {
-                inputCopy = model->addNewData(
-                    input->name() + "@copy",
-                    input->desc());
-            } else {
-                inputCopy = model->duplicateData(
-                    input,
-                    "@copy");
-                inputCopy->resetRequiredStrides();
-            }
-
-            auto copyStage = _stageBuilder->addCopyStage(
-                model,
-                formatString("%s@input=%d@copy-for-concat", stage->name(), inEdge->portInd()),
-                stage->origLayer(),
-                input,
-                inputCopy,
-                "special::concat");
-            copyStage->attrs().set<bool>("optional", optionalCopy);
-            if (stage->attrs().has("batchInd")) {
-                copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
-            }
-
-            model->replaceStageInput(inEdge, inputCopy);
-
-            input = inputCopy;
+        NeedCopyDesc desc = isInputCopyRequired(stage, inEdge, output);
+        if (desc.isCopyNeed) {
+            input = insertCopyOfInput(model, stage, inEdge, _stageBuilder, desc);
         }
 
         //
@@ -272,50 +299,12 @@ void SpecialStageProcessor::processReshape(
     IE_ASSERT(output->desc().dimsOrder() == DimsOrder::fromNumDims(output->desc().numDims()));
     IE_ASSERT(output->checkStrides(StridesRequirement::compact()));
 
-    //
-    // Check if we need to insert Copy stage
-    //
-
-    bool needCopy = false;
-    if (input->usage() != DataUsage::Intermediate &&
-        output->usage() != DataUsage::Intermediate) {
-        needCopy = true;
-    } else if (input->parentDataToDataEdge() != nullptr &&
-               output->parentDataToDataEdge() != nullptr) {
-        needCopy = true;
-    }
-
-    //
-    // Insert Copy if needed
-    //
-
-    if (needCopy) {
-        Data inputCopy;
-        if (input->usage() == DataUsage::Const) {
-            inputCopy = model->addNewData(
-                input->name() + "@copy",
-                input->desc());
-        } else {
-            inputCopy = model->duplicateData(
-                input,
-                "@copy");
-        }
-        inputCopy->updateRequiredStrides(StridesRequirement::compact());
-
-        auto copyStage = _stageBuilder->addCopyStage(
-            model,
-            formatString("%s@copy-for-reshape", stage->name()),
-            stage->origLayer(),
-            input,
-            inputCopy,
-            "special::reshape");
-        if (stage->attrs().has("batchInd")) {
-            copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
-        }
-
-        model->replaceStageInput(stage->inputEdge(0), inputCopy);
-
-        input = inputCopy;
+    NeedCopyDesc desc;
+    if ((input->usage() != DataUsage::Intermediate || input->parentDataToDataEdge() != nullptr) &&
+        (output->usage() != DataUsage::Intermediate || output->parentDataToDataEdge() != nullptr))
+        desc.isCopyNeed = true;
+    if (desc.isCopyNeed) {
+        input = insertCopyOfInput(model, stage, stage->inputEdge(0), _stageBuilder, desc);
     }
 
     //
@@ -330,16 +319,19 @@ void SpecialStageProcessor::processReshape(
             .mode(SharedDataMode::Reshape)
             .order(SharedDataOrder::ChildWritesToParent)
             .done();
-    } else {
-        IE_ASSERT(output->usage() == DataUsage::Intermediate);
-        IE_ASSERT(output->parentDataToDataEdge() == nullptr);
-
+    } else if (output->usage() == DataUsage::Intermediate &&
+               output->parentDataToDataEdge() == nullptr) {
         model->connectDataWithData()
             .parent(input)
             .child(output)
             .mode(SharedDataMode::Reshape)
             .order(SharedDataOrder::ParentWritesToChild)
             .done();
+    } else {
+        IE_ASSERT(input->usage() == DataUsage::Intermediate &&
+                  input->parentDataToDataEdge() == nullptr);
+        IE_ASSERT(output->usage() == DataUsage::Intermediate &&
+                  output->parentDataToDataEdge() == nullptr);
     }
 }
 
@@ -359,113 +351,9 @@ void SpecialStageProcessor::processExpand(
         IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
     }
 
-    //
-    // Check if we need to insert Copy stage
-    //
-
-    bool needCopy = false;
-    bool optionalCopy = false;
-    if (input->usage() != DataUsage::Intermediate) {
-        needCopy = true;
-        optionalCopy = false;
-    } else if (input->parentDataToDataEdge() != nullptr) {
-        needCopy = true;
-        optionalCopy = false;
-    } else {
-        //
-        // Check input StridesRequirement.
-        //
-
-        IE_ASSERT(input->checkStrides(input->requiredStrides()));
-        if (!checkStrides(input->desc(), output->strides(), input->requiredStrides())) {
-            needCopy = true;
-            optionalCopy = false;
-        }
-
-        //
-        // Check consumers StridesRequirement.
-        //
-
-        if (!needCopy) {
-            for (const auto& consumerEdge : input->consumerEdges()) {
-                const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
-
-                if (consumerInfo.hasInput(consumerEdge)) {
-                    const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
-                    IE_ASSERT(input->checkStrides(consumerStrideReqs));
-
-                    if (!checkStrides(input->desc(), output->strides(), consumerStrideReqs)) {
-                        needCopy = true;
-                        optionalCopy = false;
-                    }
-                }
-            }
-        }
-
-        //
-        // Check producer StridesRequirement.
-        //
-
-        if (!needCopy) {
-            if (auto producerEdge = input->producerEdge()) {
-                const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
-
-                if (producerInfo.hasOutput(producerEdge)) {
-                    const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
-                    IE_ASSERT(input->checkStrides(producerStrideReqs));
-
-                    if (!checkStrides(input->desc(), output->strides(), producerStrideReqs)) {
-                        needCopy = true;
-                        optionalCopy = false;
-                    }
-                }
-
-                if (!needCopy) {
-                    //
-                    // To reduce the size of HW output (still can be optimized).
-                    //
-
-                    if (producerEdge->producer()->category() == StageCategory::HW) {
-                        needCopy = true;
-                        optionalCopy = true;
-                    }
-                }
-            }
-        }
-    }
-
-    //
-    // Insert Copy if needed
-    //
-
-    if (needCopy) {
-        Data inputCopy;
-        if (input->usage() == DataUsage::Const) {
-            inputCopy = model->addNewData(
-                input->name() + "@copy",
-                input->desc());
-        } else {
-            inputCopy = model->duplicateData(
-                input,
-                "@copy");
-            inputCopy->resetRequiredStrides();
-        }
-
-        auto copyStage = _stageBuilder->addCopyStage(
-            model,
-            formatString("%s@copy-for-expand", stage->name()),
-            stage->origLayer(),
-            input,
-            inputCopy,
-            "special::expand");
-        copyStage->attrs().set<bool>("optional", optionalCopy);
-        if (stage->attrs().has("batchInd")) {
-            copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
-        }
-
-        model->replaceStageInput(stage->inputEdge(0), inputCopy);
-
-        input = inputCopy;
+    auto desc = isInputCopyRequired(stage, stage->inputEdge(0), output);
+    if (desc.isCopyNeed) {
+        input = insertCopyOfInput(model, stage, stage->inputEdge(0), _stageBuilder, desc);
     }
 
     //
@@ -497,76 +385,11 @@ void SpecialStageProcessor::processCrop(
         IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
     }
 
-    //
-    // Check if we need to insert Copy for output
-    //
-
-    bool needCopy = false;
-    if (output->usage() != DataUsage::Intermediate) {
-        needCopy = true;
-    } else if (output->parentDataToDataEdge() != nullptr) {
-        needCopy = true;
-    } else {
-        //
-        // Check output StridesRequirement.
-        //
-
-        IE_ASSERT(output->checkStrides(output->requiredStrides()));
-        if (!checkStrides(output->desc(), input->strides(), output->requiredStrides())) {
-            needCopy = true;
-        }
-
-        //
-        // Check consumers StridesRequirement.
-        //
-
-        if (!needCopy) {
-            for (const auto& consumerEdge : output->consumerEdges()) {
-                const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
-
-                if (consumerInfo.hasInput(consumerEdge)) {
-                    const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
-                    IE_ASSERT(output->checkStrides(consumerStrideReqs));
-
-                    if (!checkStrides(output->desc(), input->strides(), consumerStrideReqs)) {
-                        needCopy = true;
-                        break;
-                    }
-                }
-            }
-        }
-    }
-
-    //
-    // Insert output Copy if needed
-    //
-
-    if (needCopy) {
-        auto outputCopy = model->duplicateData(
-            output,
-            "@copy");
-        outputCopy->resetRequiredStrides();
-
-        model->replaceStageOutput(stage->outputEdge(0), outputCopy);
-
-        auto copyStage = _stageBuilder->addCopyStage(
-            model,
-            formatString("%s@copy-output-for-crop", stage->name()),
-            stage->origLayer(),
-            outputCopy,
-            output,
-            "special::crop");
-        if (stage->attrs().has("batchInd")) {
-            copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
-        }
-
-        output = outputCopy;
+    auto desc = isOutputCopyRequired(stage, stage->outputEdge(0), input);
+    if (desc.isCopyNeed) {
+        output = insertCopyOfOutput(model, stage, stage->outputEdge(0), _stageBuilder);
     }
 
-    //
-    // Add Data<->Data edge
-    //
-
     model->connectDataWithData()
         .parent(input)
         .child(output)
diff --git a/inference-engine/tests/functional/plugin/myriad/subgraph_tests/concat_split_transpose.cpp b/inference-engine/tests/functional/plugin/myriad/subgraph_tests/concat_split_transpose.cpp
new file mode 100644 (file)
index 0000000..58ac64b
--- /dev/null
@@ -0,0 +1,86 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <functional_test_utils/layer_test_utils.hpp>
+#include "vpu/private_plugin_config.hpp"
+
+#include <ngraph_functions/builders.hpp>
+#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
+#include <vpu/myriad_plugin_config.hpp>
+
+namespace {
+
+using DataType = ngraph::element::Type_t;
+using DataDims = std::vector<std::vector<std::size_t>>;
+
+using Parameters = std::tuple<
+        DataType,
+        DataDims,
+        std::int64_t,
+        std::vector<std::size_t>,
+        LayerTestsUtils::TargetDevice>;
+
+class Concat_Split_Transpose : public testing::WithParamInterface<Parameters>, virtual public LayerTestsUtils::LayerTestsCommon {
+protected:
+    void SetUp() override {
+        SetRefMode(LayerTestsUtils::RefMode::CONSTANT_FOLDING);
+        configuration[InferenceEngine::MYRIAD_DISABLE_CONVERT_STAGES] = CONFIG_VALUE(YES);
+        configuration[InferenceEngine::MYRIAD_DETECT_NETWORK_BATCH] = CONFIG_VALUE(NO);
+
+        const auto& dataType = std::get<0>(GetParam());
+        const auto& dataDims = std::get<1>(GetParam());
+        const auto& axis = std::get<2>(GetParam());
+        const auto& length = std::get<3>(GetParam());
+        targetDevice = std::get<4>(GetParam());
+
+        auto params = ngraph::builder::makeParams(dataType, dataDims);
+        auto paramOuts = ngraph::helpers::convert2OutputVector(
+                ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
+
+        auto concat = std::make_shared<ngraph::opset1::Concat>(paramOuts, axis);
+
+        const auto lengthData = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
+                                                                            ngraph::Shape{length.size()},
+                                                                            length);
+        const auto axisData = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
+                                                                          ngraph::Shape{1},
+                                                                          axis);
+        auto split = std::make_shared<ngraph::opset3::VariadicSplit>(concat, axisData, lengthData);
+
+        auto permutation = std::vector<std::int64_t>(split->get_output_shape(0).size());
+        std::iota(permutation.rbegin(), permutation.rend(), 0);
+        const auto transposition = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
+                                                                              ngraph::Shape{split->get_output_shape(0).size()},
+                                                                              permutation);
+
+        ngraph::ResultVector results;
+        for (int i = 0; i < 2; i++) {
+            const auto transpose = std::make_shared<ngraph::opset3::Transpose>(split->output(i), transposition);
+            results.push_back(std::make_shared<ngraph::opset1::Result>(transpose));
+        }
+        function = std::make_shared<ngraph::Function>(results, params, "concat-split-transpose");
+    }
+};
+
+TEST_P(Concat_Split_Transpose, CompareWithRefs) {
+    Run();
+}
+
+std::vector<DataDims> dims = {
+        {{400, 1}, {600, 1}}
+};
+
+std::vector<std::vector<std::size_t>> length = {
+        {500, 500}
+};
+
+INSTANTIATE_TEST_CASE_P(SpecialStages, Concat_Split_Transpose,
+                        ::testing::Combine(
+                                ::testing::Values(ngraph::element::i32),
+                                ::testing::ValuesIn(dims),
+                                ::testing::Values(0),
+                                ::testing::ValuesIn(length),
+                                ::testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
+
+}  // namespace