--- /dev/null
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/common_optimizations/fq_reshape_fusion.hpp"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+
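+// FakeQuantizeReshapeFusion swaps a Reshape that follows a FakeQuantize over constant data:
+// the Reshape is moved onto the FakeQuantize data input and every per-tensor / per-channel
+// limit input is reshaped accordingly, so the FakeQuantize becomes the last node of the pattern.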
+ngraph::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
+ const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
+ {ngraph::pattern::wrap_type<opset4::Constant>(), // data input must be a Constant: the fusion targets FakeQuantize over weights only
+ ngraph::pattern::any_input(),
+ ngraph::pattern::any_input(),
+ ngraph::pattern::any_input(),
+ ngraph::pattern::any_input()},
+ pattern::consumers_count(1));
+ const auto reshape_node_p = ngraph::pattern::wrap_type<opset4::Reshape>(
+ {fq_node_p, ngraph::pattern::any_input()});
+
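+ // The callback moves the Reshape onto the FakeQuantize data input and re-reshapes each limit
+ // so that it stays per-tensor / per-channel with respect to the new data layout; if any limit
+ // cannot be expressed that way, the fusion is rejected.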
+ ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
+ const auto &pattern_map = m.get_pattern_value_map();
+ const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
+ if (fq_node->is_dynamic())
+ return false;
+ const auto &reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr();
+ const auto &original_data_rank = fq_node->get_input_shape(0).size();
+ OutputVector renewed_inputs = {reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})};
+ // inputs 1..4 are the FakeQuantize limits: input_low, input_high, output_low, output_high
+ for (size_t i = 1; i < 5; ++i) {
+ Output<Node> limit_input = fq_node->input_value(i);
+ auto limit_shape = limit_input.get_shape();
+ NGRAPH_CHECK(limit_shape.size() <= original_data_rank, "FakeQuantize limit input has unexpected rank");
+ if (limit_shape.size() < original_data_rank) // aligning limit rank with data rank
+ limit_shape.insert(limit_shape.begin(), original_data_rank - limit_shape.size(), uint64_t(1));
+ NGRAPH_CHECK(limit_shape.size() == original_data_rank, "FakeQuantize limit input has unexpected rank");
+ const auto &limit_size = shape_size(limit_shape);
+ const auto &max_element = *std::max_element(limit_shape.begin(), limit_shape.end());
+ if (max_element == limit_size) { // limit is per-tensor or per-channel
+ // map the limit onto the reshaped output shape: keep the channel-sized dimensions, collapse the rest to 1
+ auto new_limit_shape = reshape_node->get_output_shape(0);
+ std::transform(new_limit_shape.begin(), new_limit_shape.end(), new_limit_shape.begin(),
+ [max_element](size_t dim) { return dim == max_element ? max_element : 1; });
+ const auto &new_limit_size = shape_size(new_limit_shape);
+ if (new_limit_size == limit_size) { // the limit is still per-tensor / per-channel on the new shape
+ if (new_limit_shape == limit_input.get_shape())
+ renewed_inputs.push_back(limit_input);
+ else
+ renewed_inputs.push_back(reshape_node->copy_with_new_inputs(
+ {limit_input, opset4::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}));
+ continue;
+ }
+ }
+ // the limit is not (or would no longer be) per-tensor / per-channel -- reject the fusion
+ return false;
+ }
+ for (auto &new_input : renewed_inputs)
+ copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr());
+ const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs);
+ replace_node(reshape_node, new_fq_node);
+ new_fq_node->set_friendly_name(fq_node->get_friendly_name());
+ copy_runtime_info({fq_node, reshape_node}, new_fq_node);
+ return true;
+ };
+
+ auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_node_p, "FakeQuantizeReshapeFusion");
+ this->register_matcher(m, callback);
+}
--- /dev/null
+// Copyright (C) 2018-2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <map>
+
+#include <ngraph/opsets/opset4.hpp>
+#include <ngraph/function.hpp>
+#include <common_test_utils/ngraph_test_utils.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/common_optimizations/fq_reshape_fusion.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "cnn_network_ngraph_impl.hpp"
+
+using namespace testing;
+using namespace InferenceEngine;
+
+namespace {
+
+ngraph::Shape DO_NOT_RESHAPE = ngraph::Shape{0};
+
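+// Describes a single fusion scenario: the original FakeQuantize shapes, the Reshape pattern and
+// the expected shapes of the reshaped limits; DO_NOT_RESHAPE marks a limit that must be passed
+// through unchanged, is_negative marks cases where the graph must stay untouched.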
+struct FQReshapeFusionTestCase {
+ ngraph::Shape data_shape, il_shape, ih_shape, ol_shape, oh_shape;
+ std::vector<int64_t> reshape_pattern;
+ ngraph::Shape new_il_shape, new_ih_shape, new_ol_shape, new_oh_shape;
+ bool is_negative;
+};
+
+class nGraphFQReshapeFusionTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<FQReshapeFusionTestCase>> {
+public:
+ std::shared_ptr<ngraph::Function> f, ref_f;
+
+ void SetUp() override {
+ const auto& test_case = std::get<0>(GetParam());
+ f = get_initial_function(test_case);
+ if (test_case.is_negative)
+ ref_f = get_initial_function(test_case);
+ else
+ ref_f = get_reference_function(test_case);
+ }
+
+private:
+ std::shared_ptr<ngraph::Function> get_initial_function(const FQReshapeFusionTestCase & test_case) {
+ const auto & data = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
+ auto il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
+ auto ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
+ auto ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
+ auto oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
+
+ auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, il, ih, ol, oh, 42);
+
+ auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
+ ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
+ auto reshape = std::make_shared<ngraph::opset4::Reshape>(fq, reshape_pattern, true);
+
+ auto result = std::make_shared<ngraph::op::Result>(reshape);
+ ngraph::ParameterVector params = {il, ih, ol, oh};
+ ngraph::ResultVector results = {result};
+ return std::make_shared<ngraph::Function>(results, params);
+ }
+
+ std::shared_ptr<ngraph::Function> get_reference_function(const FQReshapeFusionTestCase & test_case) {
+ const auto & data = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
+ const auto & reshaped_data = std::make_shared<ngraph::opset4::Reshape>(
+ data,
+ std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern),
+ true);
+
+ const auto & p_il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
+ ngraph::Output<ngraph::Node> il = p_il;
+ const auto & p_ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
+ ngraph::Output<ngraph::Node> ih = p_ih;
+ const auto & p_ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
+ ngraph::Output<ngraph::Node> ol = p_ol;
+ const auto & p_oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
+ ngraph::Output<ngraph::Node> oh = p_oh;
+
+ if (test_case.new_il_shape != DO_NOT_RESHAPE)
+ il = std::make_shared<ngraph::opset4::Reshape>(
+ il, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_il_shape.size()}, test_case.new_il_shape), true);
+ if (test_case.new_ih_shape != DO_NOT_RESHAPE)
+ ih = std::make_shared<ngraph::opset4::Reshape>(
+ ih, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ih_shape.size()}, test_case.new_ih_shape), true);
+ if (test_case.new_ol_shape != DO_NOT_RESHAPE)
+ ol = std::make_shared<ngraph::opset4::Reshape>(
+ ol, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ol_shape.size()}, test_case.new_ol_shape), true);
+ if (test_case.new_oh_shape != DO_NOT_RESHAPE)
+ oh = std::make_shared<ngraph::opset4::Reshape>(
+ oh, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_oh_shape.size()}, test_case.new_oh_shape), true);
+
+ auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(reshaped_data, il, ih, ol, oh, 42);
+
+ auto result = std::make_shared<ngraph::op::Result>(fq);
+ ngraph::ParameterVector params = {p_il, p_ih, p_ol, p_oh};
+ ngraph::ResultVector results = {result};
+ return std::make_shared<ngraph::Function>(results, params);
+ }
+};
+
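+// Runs the fusion and compares the result against the reference function;
+// for negative cases the reference equals the initial graph, i.e. the pass must not apply.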
+TEST_P(nGraphFQReshapeFusionTests, FQReshapeFusion) {
+ ngraph::pass::Manager manager;
+ manager.register_pass<ngraph::pass::InitNodeInfo>();
+ manager.register_pass<ngraph::pass::FakeQuantizeReshapeFusion>();
+
+ manager.run_passes(f);
+ ASSERT_NO_THROW(check_rt_info(f));
+ auto res = compare_functions(f, ref_f);
+ ASSERT_TRUE(res.first) << res.second;
+}
+
+INSTANTIATE_TEST_CASE_P(NGraph, nGraphFQReshapeFusionTests, testing::Values(
+ // positive
+ FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {2, 3}, {2, 1}, {1, 1}, DO_NOT_RESHAPE, {2, 1}, false},
+ FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, DO_NOT_RESHAPE, false},
+ FQReshapeFusionTestCase{{2, 3}, {2, 1}, {1}, {1, 1}, {1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, false},
+ // negative
+ FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 3}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {}, {}, {}, {}, true},
+ FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {6}, {}, {}, {}, {}, true}));
+} // namespace