[ nG transformation ] Const -> FQ -> Reshape fuse (#2388)
authorEvgenya Stepyreva <evgenya.stepyreva@intel.com>
Thu, 24 Sep 2020 08:44:08 +0000 (11:44 +0300)
committerGitHub <noreply@github.com>
Thu, 24 Sep 2020 08:44:08 +0000 (11:44 +0300)
* [ nG transformation ] Const -> FQ -> Reshape fuse
Ticket: 39124

* fix dtype incompatibility: uint64 vs size_t

* Review comments adressed

inference-engine/src/transformations/include/transformations/common_optimizations/fq_reshape_fusion.hpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp [new file with mode: 0644]
inference-engine/src/transformations/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp
inference-engine/tests/functional/inference_engine/transformations/fq_reshape_fusion.cpp [new file with mode: 0644]

diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/fq_reshape_fusion.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/fq_reshape_fusion.hpp
new file mode 100644 (file)
index 0000000..de5fd0d
--- /dev/null
@@ -0,0 +1,32 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API FakeQuantizeReshapeFusion;
+
+} // namespace pass
+} // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief This transformation looks for a FQ + Reshape pair in the graph and moves
+ * the Reshape operation above the FQ node. Shapes of limit inputs are updated
+ * following FQ broadcasting semantics
+ */
+
+class ngraph::pass::FakeQuantizeReshapeFusion : public ngraph::pass::MatcherPass {
+public:
+    FakeQuantizeReshapeFusion();
+};
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/fq_reshape_fusion.cpp
new file mode 100644 (file)
index 0000000..0f7127c
--- /dev/null
@@ -0,0 +1,70 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/common_optimizations/fq_reshape_fusion.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+
+ngraph::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
+    const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
+            {ngraph::pattern::wrap_type<opset4::Constant>(), // for weights only
+             ngraph::pattern::any_input(),
+             ngraph::pattern::any_input(),
+             ngraph::pattern::any_input(),
+             ngraph::pattern::any_input()},
+            pattern::consumers_count(1));
+    const auto reshape_node_p = ngraph::pattern::wrap_type<opset4::Reshape>(
+            {fq_node_p, ngraph::pattern::any_input()});
+
+    ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
+        const auto &pattern_map = m.get_pattern_value_map();
+        const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
+        if (fq_node->is_dynamic())
+            return false;
+        const auto &reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr();
+        const auto &original_data_rank = fq_node->get_input_shape(0).size();
+        OutputVector renewed_inputs = {reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})};
+        for (auto i = 1; i < 5; ++i) {
+            Output<Node> limit_input = fq_node->input_value(i);
+            auto limit_shape = limit_input.get_shape();
+            NGRAPH_CHECK(limit_shape.size() <= original_data_rank, "FakeQuantize limit input has unexpected rank");
+            if (limit_shape.size() < original_data_rank) // aligning limit rank with data rank
+                limit_shape.insert(limit_shape.begin(), original_data_rank - limit_shape.size(), uint64_t(1));
+            NGRAPH_CHECK(limit_shape.size() == original_data_rank, "FakeQuantize limit input has unexpected rank");
+            const auto &limit_size = shape_size(limit_shape);
+            const auto &max_element = *std::max_element(limit_shape.begin(), limit_shape.end());
+            if (max_element == limit_size) { // per-tensor / per-channel limit
+                auto new_limit_shape = reshape_node->get_output_shape(0);
+                std::transform(new_limit_shape.begin(), new_limit_shape.end(), new_limit_shape.begin(),
+                               [max_element](size_t &dim) { return dim == max_element ? max_element : 1; });
+                const auto &new_limit_size = shape_size(new_limit_shape);
+                if (new_limit_size == limit_size) { // we tracked future channel placement
+                    if (new_limit_shape == limit_input.get_shape())
+                        renewed_inputs.push_back(limit_input);
+                    else
+                        renewed_inputs.push_back(reshape_node->copy_with_new_inputs(
+                                {limit_input, opset4::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}));
+                    continue;
+                }
+            }
+            // resulting FQ will become or already is more than per-tensor / per-channel
+            return false;
+        }
+        for (auto &new_input : renewed_inputs)
+            copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr());
+        const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs);
+        replace_node(reshape_node, new_fq_node);
+        new_fq_node->set_friendly_name(fq_node->get_friendly_name());
+        copy_runtime_info({fq_node, reshape_node}, new_fq_node);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_node_p, "FakeQuantizeReshapeFusion");
+    this->register_matcher(m, callback);
+}
index abc219538bd2b9cd003a2a9bcff3fe2b5a39cc8a..82a2e5ce6ec423305d5d6683f736a235a9a3636b 100644 (file)
@@ -52,6 +52,7 @@
 #include <transformations/reduce_l1_decomposition.hpp>
 #include <transformations/reduce_l2_decomposition.hpp>
 #include <transformations/common_optimizations/fq_mul_fusion.hpp>
+#include <transformations/common_optimizations/fq_reshape_fusion.hpp>
 
 #include <ngraph/pass/constant_folding.hpp>
 #include <ngraph/pass/manager.hpp>
@@ -97,7 +98,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
     decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
     decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
-    decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
     decomp->set_name("ngraph::pass::Decompositions");
 
     // CF is required after all decompositions
@@ -112,9 +112,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
     manager.register_pass<ngraph::pass::ConstantFolding>();
 
-    // Multiply the thrird and fourth input instead of the output of FQ with all const inputs
-    manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
-
     // Convolution/Deconvolution/FullyConnected fusions
     auto convert_convolutions = manager.register_pass<ngraph::pass::GraphRewrite>();
     convert_convolutions->add_matcher<ngraph::pass::ConvertConvolution>();
@@ -123,6 +120,12 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     convert_convolutions->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
     convert_convolutions->set_name("ngraph::pass::ConvertConvolutions");
 
+    auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
+    fq_fusions->add_matcher<FakeQuantizeMulFusion>();
+    fq_fusions->add_matcher<FakeQuantizeReshapeFusion>();
+    fq_fusions->add_matcher<PullTransposeThroughFQUp>();
+    fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
+
     // Convolution/Deconvolution/FullyConnected fusions
     auto fusion = manager.register_pass<ngraph::pass::GraphRewrite>();
     fusion->add_matcher<ngraph::pass::ConvAddFusion>();
diff --git a/inference-engine/tests/functional/inference_engine/transformations/fq_reshape_fusion.cpp b/inference-engine/tests/functional/inference_engine/transformations/fq_reshape_fusion.cpp
new file mode 100644 (file)
index 0000000..aea06e1
--- /dev/null
@@ -0,0 +1,125 @@
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <map>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/function.hpp>
+#include <common_test_utils/ngraph_test_utils.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/common_optimizations/fq_reshape_fusion.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "cnn_network_ngraph_impl.hpp"
+
+using namespace testing;
+using namespace InferenceEngine;
+
+namespace {
+
+ngraph::Shape DO_NOT_RESHAPE = ngraph::Shape{0};
+
+struct FQReshapeFusionTestCase {
+    ngraph::Shape data_shape, il_shape, ih_shape, ol_shape, oh_shape;
+    std::vector<int64_t> reshape_pattern;
+    ngraph::Shape new_il_shape, new_ih_shape, new_ol_shape, new_oh_shape;
+    bool is_negative;
+};
+
+class nGraphFQReshapeFusionTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<FQReshapeFusionTestCase>> {
+public:
+    std::shared_ptr<ngraph::Function> f, ref_f;
+
+    void SetUp() override {
+        const auto& parameters = GetParam();
+        const auto& test_case = std::get<0>(GetParam());
+        f = get_initial_function(test_case);
+        if (test_case.is_negative)
+            ref_f = get_initial_function(test_case);
+        else
+            ref_f = get_reference_function(test_case);
+    }
+
+private:
+    std::shared_ptr<ngraph::Function> get_initial_function(const FQReshapeFusionTestCase & test_case) {
+        const auto & data =  std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
+        auto il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
+        auto ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
+        auto ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
+        auto oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
+
+        auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, il, ih, ol, oh, 42);
+
+        auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
+                ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
+        auto reshape = std::make_shared<ngraph::opset4::Reshape>(fq, reshape_pattern, true);
+
+        auto result = std::make_shared<ngraph::op::Result>(reshape);
+        ngraph::ParameterVector params = {il, ih, ol, oh};
+        ngraph::ResultVector results = {result};
+        return std::make_shared<ngraph::Function>(results, params);
+    }
+
+    std::shared_ptr<ngraph::Function> get_reference_function(const FQReshapeFusionTestCase & test_case) {
+        const auto & data =  std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
+        const auto & reshaped_data = std::make_shared<ngraph::opset4::Reshape>(
+                data,
+                std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern),
+                true);
+
+        const auto & p_il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
+        ngraph::Output<ngraph::Node> il = p_il;
+        const auto & p_ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
+        ngraph::Output<ngraph::Node> ih = p_ih;
+        const auto & p_ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
+        ngraph::Output<ngraph::Node> ol = p_ol;
+        const auto & p_oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
+        ngraph::Output<ngraph::Node> oh = p_oh;
+
+        if (test_case.new_il_shape != DO_NOT_RESHAPE)
+            il = std::make_shared<ngraph::opset4::Reshape>(
+                    il, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_il_shape.size()}, test_case.new_il_shape), true);
+        if (test_case.new_ih_shape != DO_NOT_RESHAPE)
+            ih = std::make_shared<ngraph::opset4::Reshape>(
+                    ih, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ih_shape.size()}, test_case.new_ih_shape), true);
+        if (test_case.new_ol_shape != DO_NOT_RESHAPE)
+            ol = std::make_shared<ngraph::opset4::Reshape>(
+                    ol, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ol_shape.size()}, test_case.new_ol_shape), true);
+        if (test_case.new_oh_shape != DO_NOT_RESHAPE)
+            oh = std::make_shared<ngraph::opset4::Reshape>(
+                    oh, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_oh_shape.size()}, test_case.new_oh_shape), true);
+
+        auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(reshaped_data, il, ih, ol, oh, 42);
+
+        auto result = std::make_shared<ngraph::op::Result>(fq);
+        ngraph::ParameterVector params = {p_il, p_ih, p_ol, p_oh};
+        ngraph::ResultVector results = {result};
+        return std::make_shared<ngraph::Function>(results, params);
+    }
+};
+
+TEST_P(nGraphFQReshapeFusionTests, ReshapeMatMul) {
+    ngraph::pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::FakeQuantizeReshapeFusion>();
+
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+    auto res = compare_functions(f, ref_f);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+INSTANTIATE_TEST_CASE_P(NGraph, nGraphFQReshapeFusionTests, testing::Values(
+    // positive
+    FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {2, 3}, {2, 1}, {1, 1}, DO_NOT_RESHAPE, {2, 1}, false},
+    FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1},  {1, 1, 1, 1}, DO_NOT_RESHAPE, false},
+    FQReshapeFusionTestCase{{2, 3}, {2, 1}, {1}, {1, 1}, {1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, false},
+    // negative
+    FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 3}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {}, {},  {}, {}, true},
+    FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {6}, {}, {},  {}, {}, true}));
+}  // namespace