[LPT] integration: issue #42391 & issue #43001 (#3201)
[platform/upstream/dldt.git] / inference-engine / src / low_precision_transformations / src / group_convolution.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "low_precision/group_convolution.hpp"
6
7 #include <memory>
8 #include <string>
9 #include <vector>
10
11 #include "low_precision/network_helper.hpp"
12
13 namespace ngraph {
14 namespace pass {
15 namespace low_precision {
16
17 GroupConvolutionTransformation::GroupConvolutionTransformation(const Params& params) : ConvolutionTransformation(params) {
18 }
19
20 void GroupConvolutionTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
21     // question to nGraph: why it doesn't work
22     // addPattern(
23     //    pass,
24     //    context,
25     //    make_op_pattern<opset1::GroupConvolution>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::FakeQuantize>()}));
26
27     addSingleNodePattern<opset1::GroupConvolution>(pass, context);
28 }
29
30 bool GroupConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) const noexcept {
31     return WeightableLayerTransformation::isQuantized(layer, true);
32 }
33
34 bool GroupConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
35     auto convolution = m.get_match_root();
36
37     if (!GroupConvolutionTransformation::canBeTransformed(context, convolution)) {
38         return false;
39     }
40
41     ConvolutionTransformation::transform(context, m);
42     return true;
43 }
44
45 } // namespace low_precision
46 } // namespace pass
47 } // namespace ngraph