[CPU] fixed MergePermuteAndReorder optimization (#3317)
author Anton Voronov <anton.voronov@intel.com>
Tue, 24 Nov 2020 16:14:30 +0000 (19:14 +0300)
committer GitHub <noreply@github.com>
Tue, 24 Nov 2020 16:14:30 +0000 (19:14 +0300)
inference-engine/src/mkldnn_plugin/mkldnn_graph.cpp
inference-engine/src/mkldnn_plugin/mkldnn_graph.h
inference-engine/src/mkldnn_plugin/mkldnn_graph_optimizer.cpp
inference-engine/tests/functional/plugin/cpu/subgraph_tests/include/fuse_permute_reorder.hpp
inference-engine/tests/functional/plugin/cpu/subgraph_tests/src/fuse_permute_reorder.cpp

index 7883277..2969029 100644 (file)
@@ -1083,7 +1083,7 @@ void MKLDNNGraph::RemoveDroppedEdges() {
     }
 }
 
-void MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const TensorDesc& inDesc, const TensorDesc& outDesc,
+MKLDNNNodePtr MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const TensorDesc& inDesc, const TensorDesc& outDesc,
                                 bool isOptimized, InferenceEngine::Blob::Ptr scales) {
     CNNLayerPtr layer(new CNNLayer({layerName,
                                     "Reorder",
@@ -1133,6 +1133,7 @@ void MKLDNNGraph::InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const
     }
 
     graphNodes.push_back(newReorder);
+    return newReorder;
 }
 
 void MKLDNNGraph::dumpToDotFile(std::string file) const {
index b97cf9d..27d2480 100644 (file)
@@ -109,10 +109,10 @@ public:
      * optimization flag; if isOptimized is true then Reorder node does nothing
      * @param scales
      * pointer to the blob containing scales
-     * @return none.
+     * @return pointer to the new Reorder node.
      */
-    void InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const InferenceEngine::TensorDesc& inDesc, const InferenceEngine::TensorDesc& outDesc,
-                       bool isOptimized = false, InferenceEngine::Blob::Ptr scales = nullptr);
+    MKLDNNNodePtr InsertReorder(MKLDNNEdgePtr edge, std::string layerName, const InferenceEngine::TensorDesc& inDesc,
+            const InferenceEngine::TensorDesc& outDesc, bool isOptimized = false, InferenceEngine::Blob::Ptr scales = nullptr);
 
     InferenceEngine::CNNNetwork dump() const;
 
index 05cec5d..ebda579 100644 (file)
@@ -2312,8 +2312,8 @@ void MKLDNNGraphOptimizer::MergePermuteAndReorder(MKLDNNGraph &graph) {
         graph.DropNode(parentNode);
         graph.DropNode(childNode);
 
-        auto inDesc = parentParentNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc;
-        auto outDesc = childChildNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc;
+        auto inDesc = parentNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc;
+        auto outDesc = childNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc;
 
         auto inPrec = inDesc.getPrecision();
         auto outPrec = outDesc.getPrecision();
@@ -2333,13 +2333,12 @@ void MKLDNNGraphOptimizer::MergePermuteAndReorder(MKLDNNGraph &graph) {
             }
         }
 
-        graph.InsertReorder(edge, reorderlayerName, reorderInDesc, reorderOutDesc, true);
+        auto reorderNode = graph.InsertReorder(edge, reorderlayerName, reorderInDesc, reorderOutDesc, true);
 
         // case 2
         if (inPrec != outPrec) {
-            auto reorderNode = parentParentNode->getChildEdgeAt(0)->getChild();
-            auto reorderInDesc2 = TensorDesc(reorderNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].desc);
-            auto reorderOutDesc2 = TensorDesc(childChildNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].desc);
+            auto reorderInDesc2 = TensorDesc(reorderOutDesc);
+            auto reorderOutDesc2 = TensorDesc(outDesc);
 
             std::string reorderLayerName2 = reorderNode->getName() + "_" +
                                     MKLDNNExtensionUtils::getReorderArgs(reorderInDesc2, reorderOutDesc2) + "_" + childChildNode->getName();
index cfbd70c..22a57ef 100644 (file)
@@ -18,8 +18,8 @@ using namespace CPUTestUtils;
 namespace LayerTestsDefinitions {
 
 using FusePermuteAndReorderParams = std::tuple<
-        InferenceEngine::SizeVector, // Input shape
-        InferenceEngine::Precision   // Input precision
+        InferenceEngine::SizeVector,    // Input shape
+        InferenceEngine::Precision      // Input precision
 >;
 
 class FusePermuteAndReorderTest : public testing::WithParamInterface<FusePermuteAndReorderParams>, public CPUTestsBase,
@@ -29,7 +29,21 @@ public:
 
 protected:
     void SetUp() override;
-    std::string pluginTypeNode;
+    virtual void CreateGraph();
+    void CheckPermuteCount(size_t expectedPermuteCount);
+
+    InferenceEngine::SizeVector inputShape;
+    InferenceEngine::Precision inPrec;
+};
+
+class FusePermuteAndReorderTest1 : public FusePermuteAndReorderTest {
+protected:
+    void CreateGraph() override;
+};
+
+class FusePermuteAndReorderTest2 : public FusePermuteAndReorderTest {
+protected:
+    void CreateGraph() override;
 };
 
 } // namespace LayerTestsDefinitions
index 6f1fb7d..e5b734b 100644 (file)
@@ -21,25 +21,69 @@ std::string FusePermuteAndReorderTest::getTestCaseName(testing::TestParamInfo<Fu
     return result.str();
 }
 
+void FusePermuteAndReorderTest::CheckPermuteCount(size_t expectedPermuteCount) {
+    InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
+    auto function = execGraphInfo.getFunction();
+    ASSERT_NE(nullptr, function);
+    size_t actualPermuteCount = 0;
+    for (const auto &node : function->get_ops()) {
+        const auto & rtInfo = node->get_rt_info();
+        auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
+            auto it = rtInfo.find(paramName);
+            IE_ASSERT(rtInfo.end() != it);
+            auto value = std::dynamic_pointer_cast<ngraph::VariantImpl<std::string>>(it->second);
+            IE_ASSERT(nullptr != value);
+            return value->get();
+        };
+        if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == "Permute") {
+            actualPermuteCount++;
+        }
+    }
+
+    ASSERT_EQ(expectedPermuteCount, actualPermuteCount);
+}
+
 void FusePermuteAndReorderTest::SetUp() {
     targetDevice = CommonTestUtils::DEVICE_CPU;
-    SizeVector inputShape;
-    Precision inPrec;
 
     std::tie(inputShape, inPrec) = this->GetParam();
+    CreateGraph();
+}
 
+const auto fusePermuteAndReorderCommonParams = ::testing::Combine(
+        ::testing::Values(SizeVector{1, 2, 3, 4}, SizeVector{1, 2, 3, 4, 5}),
+        ::testing::Values(Precision::I8, Precision::U8)
+);
+
+/*  FusePermuteAndReorderTest graph
+      ---------
+      |Input  |
+      ---------
+          |
+    -------------
+    | --------- |
+    | |Permute| |
+    | --------- |
+    |     |     |
+    | --------- |
+    | |Reorder| |
+    | --------- |
+    |-----------|
+          |
+      ---------
+      |Output |
+      ---------
+*/
+
+void FusePermuteAndReorderTest::CreateGraph() {
     auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
     auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
-    auto paramOuts = ngraph::helpers::convert2OutputVector(
-            ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
 
     auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
     auto memFmt = inputShape.size() == 5 ? ndhwc : nhwc;
 
     auto constOrder = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
-
-    auto permute = std::make_shared<ngraph::opset5::Transpose>(paramOuts[0], constOrder);
-
+    auto permute = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder);
     permute->get_rt_info() = setCPUInfo({memFmt}, {memFmt}, {});
 
     ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(permute)};
@@ -50,33 +94,146 @@ TEST_P(FusePermuteAndReorderTest, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
 
     Run();
+    CheckPermuteCount(0);
+}
 
-    InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
-    auto function = execGraphInfo.getFunction();
-    ASSERT_NE(nullptr, function);
-    bool permuteFound = false;
-    for (const auto &node : function->get_ops()) {
-        const auto & rtInfo = node->get_rt_info();
-        auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
-            auto it = rtInfo.find(paramName);
-            IE_ASSERT(rtInfo.end() != it);
-            auto value = std::dynamic_pointer_cast<ngraph::VariantImpl<std::string>>(it->second);
-            IE_ASSERT(nullptr != value);
-            return value->get();
-        };
-        if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == "Permute") {
-            permuteFound = true;
-            break;
-        }
-    }
-    ASSERT_TRUE(!permuteFound);
+INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
+
+
+/*  FusePermuteAndReorderTest1 graph
+             ---------
+             |Input  |
+             ---------
+                 |
+             ---------
+             |Permute|
+             ---------
+                 |
+        -------------------
+        |                 |
+        |           -------------
+        |           | --------- |
+        |           | |Permute| |
+    ---------       | --------- |
+    |Reshape|       |     |     |
+    ---------       | --------- |
+        |           | |Reorder| |
+        |           | --------- |
+        |           |-----------|
+        |                 |
+        |             ---------
+        |             |Permute|
+        |             ---------
+        |                 |
+        --------   --------
+               |   |
+             ---------
+             |Concat |
+             ---------
+                 |
+             ---------
+             |Output |
+             ---------
+*/
+
+void FusePermuteAndReorderTest1::CreateGraph() {
+    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
+    auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
+
+    auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 2, 3, 4, 1} : std::vector<int64_t>{0, 2, 3, 1};
+
+    auto constOrder1 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
+    auto permute1 = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder1);
+    auto memFmt1 = inputShape.size() == 5 ? ndhwc : nhwc;
+    permute1->get_rt_info() = setCPUInfo({memFmt1}, {memFmt1}, {});
+
+    auto constOrder2 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
+    auto permute2 = std::make_shared<ngraph::opset5::Transpose>(permute1, constOrder2);
+    auto memFmt2 = inputShape.size() == 5 ? ndhwc : nhwc;
+    permute2->get_rt_info() = setCPUInfo({memFmt2}, {memFmt2}, {});
+
+    auto constOrder3 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
+    auto permute3 = std::make_shared<ngraph::opset5::Transpose>(permute2, constOrder3);
+    auto memFmt3 = inputShape.size() == 5 ? ncdhw : nchw;
+    permute3->get_rt_info() = setCPUInfo({memFmt3}, {memFmt3}, {});
+
+    auto shape = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, permute3->get_output_shape(0));
+    auto reshape = std::make_shared<ngraph::opset5::Reshape>(permute1, shape, false);
+
+    auto concat = ngraph::builder::makeConcat({permute3, reshape}, 1);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(concat)};
+    function = std::make_shared<ngraph::Function>(results, params, "Permute_PermuteReorderPermute_Reshape_Concat");
 }
 
-const auto fusePermuteAndReorderParams = ::testing::Combine(
-        ::testing::Values(SizeVector{1, 2, 3, 4}, SizeVector{1, 2, 3, 4, 5}),
-        ::testing::Values(Precision::I8, Precision::U8)
-);
+TEST_P(FusePermuteAndReorderTest1, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+
+    Run();
+    CheckPermuteCount(2);
+}
+
+INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest1, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
+
+
+/*  FusePermuteAndReorderTest2 graph
+    ---------         ---------
+    |Input  |         |Input  |
+    ---------         ---------
+        |                 |
+        |           -------------
+    ---------       | --------- |
+    |Reorder|       | |Permute| |
+    ---------       | --------- |
+        |           |     |     |
+    ---------       | --------- |
+    |Permute|       | |Reorder| |
+    ---------       | --------- |
+        |           |-----------|
+        |                 |
+        --------   --------
+               |   |
+             ---------
+             |Concat |
+             ---------
+                 |
+             ---------
+             |Output |
+             ---------
+*/
+
+void FusePermuteAndReorderTest2::CreateGraph() {
+    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inPrec);
+
+    auto inputShape2(inputShape);
+    inputShape2[inputShape2.size() - 1] *= 2;
+    auto params = ngraph::builder::makeParams(ngPrc, {inputShape, inputShape2});
+
+    auto order = inputShape.size() == 5 ? std::vector<int64_t>{0, 4, 1, 2, 3} : std::vector<int64_t>{0, 3, 1, 2};
+
+    auto constOrder1 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
+    auto permute1 = std::make_shared<ngraph::opset5::Transpose>(params[0], constOrder1);
+    auto memFmt1 = inputShape.size() == 5 ? ndhwc : nhwc;
+    permute1->get_rt_info() = setCPUInfo({memFmt1}, {memFmt1}, {});
+
+    auto constOrder2 = ngraph::builder::makeConstant(ngraph::element::i64, {inputShape.size()}, order);
+    auto permute2 = std::make_shared<ngraph::opset5::Transpose>(params[1], constOrder2);
+    auto memFmt2 = inputShape.size() == 5 ? ncdhw : nchw;
+    permute2->get_rt_info() = setCPUInfo({memFmt2}, {memFmt2}, {});
+
+    auto concat = ngraph::builder::makeConcat({permute1, permute2}, 1);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset5::Result>(concat)};
+    function = std::make_shared<ngraph::Function>(results, params, "Permute_Permute_Concat");
+}
+
+TEST_P(FusePermuteAndReorderTest2, CompareWithRefs) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+
+    Run();
+    CheckPermuteCount(1);
+}
 
-INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest, fusePermuteAndReorderParams, FusePermuteAndReorderTest::getTestCaseName);
+INSTANTIATE_TEST_CASE_P(smoke_Basic, FusePermuteAndReorderTest2, fusePermuteAndReorderCommonParams, FusePermuteAndReorderTest::getTestCaseName);
 
 }  // namespace LayerTestsDefinitions