1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "low_precision/convolution.hpp"
13 #include "low_precision/network_helper.hpp"
14 #include "low_precision/common/dequantization_op.hpp"
18 namespace low_precision {
// Constructor: forwards the shared transformation parameters to the
// WeightableLayerTransformation base; no convolution-specific state is set here.
20 ConvolutionTransformation::ConvolutionTransformation(const Params& params) : WeightableLayerTransformation(params) {
// Registers the graph pattern this transformation reacts to:
// a Convolution whose data input is a Multiply (tail of a dequantization chain)
// and whose weights input is a FakeQuantize.
// NOTE(review): the addPattern(...) call preamble is not visible in this view of
// the file (original lines 24-26 are elided); only the pattern argument is shown.
23 void ConvolutionTransformation::registerMatcherIn(GraphRewrite &pass, TransformationContext &context) const {
27 make_op_pattern<opset1::Convolution>({ make_op_label<opset1::Multiply>(), make_op_label<opset1::FakeQuantize>()}));
// Reports whether the layer is quantized. Delegates to the base-class check with
// the second argument `false` — presumably "do not require per-tensor/reshaped
// weights" or similar; see WeightableLayerTransformation::isQuantized for the
// exact meaning of the flag (not visible from this file).
30 bool ConvolutionTransformation::isQuantized(std::shared_ptr<Node> layer) const noexcept {
31 return WeightableLayerTransformation::isQuantized(layer, false);
// Core rewrite: moves the dequantization operations (Multiply / Subtract /
// Convert) from the Convolution's data and weights inputs to a single Multiply
// placed after the Convolution's output, so the Convolution itself can execute
// on integer inputs. The original Convolution node is repeatedly replaced and
// the `convolution` pointer re-bound to the fresh copy after each graph edit —
// the statement order below is load-bearing.
// NOTE(review): this view of the file is gapped (original line numbers are
// non-contiguous), so several closing braces, early returns and argument lines
// are not visible here.
34 bool ConvolutionTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const {
35 auto convolution = m.get_match_root();
// --- Preconditions: bail out (branch bodies elided from this view) when the
// base class rejects the layer, the Subtract cannot be handled, asymmetric
// quantization is unsupported but weights carry a zero point, or precisions
// must be updated while the dequantization is not yet in low precision.
37 if (!WeightableLayerTransformation::canBeTransformed(context, convolution)) {
41 FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(convolution);
42 if (!canSubtractBeHandled(convolution, dequantization)) {
46 if ((!supportAsymmetricQuantization) && getDataPrecisionOnWeights(convolution).hasZeroPoint) {
50 if (updatePrecisions && !dequantization.empty() && !dequantization.isLowPrecision()) {
// Detach the convolution into its own branch so edits below cannot affect
// other consumers of the shared dequantization subgraph, then re-read the
// dequantization from the fresh node.
54 convolution = separateInStandaloneBranch(convolution);
55 dequantization = NetworkHelper::getDequantization(convolution);
// --- Data-path Subtract: try to optimize it away; if that is not possible,
// keep the original Subtract node.
58 std::shared_ptr<opset1::Subtract> subtract;
59 if (dequantization.subtract != nullptr) {
60 std::shared_ptr<ngraph::Node> layer = dequantization.subtract;
61 ngraph::pass::low_precision::NetworkHelper::cleanRunTimeInfo(layer);
63 auto optimizedSubtract = NetworkHelper::optimizeSubtract(dequantization.subtract);
64 if (optimizedSubtract == nullptr) {
65 optimizedSubtract = dequantization.subtract;
67 subtract = as_type_ptr<opset1::Subtract>(optimizedSubtract);
70 // workaround normalizes shape of Subtract to match CPU plugin expectations
71 if (subtract && subtract->get_output_partial_shape(0) != subtract->get_input_partial_shape(1)) {
72 size_t length = subtract->get_output_partial_shape(0).rank().get_length();
74 // Insert explicit broadcast for channel dimension [1] and immediately fold it
75 Shape broadcastShape(subtract->get_output_partial_shape(0).rank().get_length(), 1);
76 broadcastShape[1] = subtract->get_output_shape(0)[1];
78 std::shared_ptr<Node> newShift = fold<opset1::Broadcast>(
79 subtract->input_value(1).get_node_shared_ptr(),
80 std::make_shared<opset1::Constant>(
// Rebuild the Subtract with the broadcast shift, keep the original output
// element type, and swap it into the graph.
85 const auto newSubtract = as_type_ptr<opset1::Subtract>(subtract->clone_with_new_inputs({
86 subtract->input_value(0).get_node_shared_ptr(),
88 replace_node(subtract, newSubtract);
90 newSubtract->set_output_type(0, subtract->get_output_element_type(0), newSubtract->get_output_partial_shape(0));
91 subtract = newSubtract;
// --- Build the constant for the Multiply that will sit AFTER the convolution.
// For grouped convolutions the per-input-channel scales must be re-expanded to
// per-output-channel: one scale is read per group (first input channel of the
// group) and replicated across that group's output channels.
94 const size_t groupsCount = NetworkHelper::getGroupsCount(convolution);
95 std::shared_ptr<Node> newMultiplyAfterConst;
96 if (groupsCount > 1ul) {
97 std::shared_ptr<opset1::Constant> multiplyConst = as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
99 const std::vector<float> scales = multiplyConst->cast_vector<float>();
100 if (scales.size() == 1ul) {
// Scalar scale: a plain clone is enough.
101 newMultiplyAfterConst = dequantization.multiply->input_value(1).get_node_shared_ptr()->clone_with_new_inputs({});
103 const ngraph::Shape inputShape = convolution->get_input_shape(0);
104 const size_t inputChannelsInGroup = inputShape[1] / groupsCount;
105 const ngraph::Shape outputShape = convolution->get_output_shape(0);
106 std::vector<float> outputScales(outputShape[1]);
108 const size_t outputChannelsInGroup = outputShape[1] / groupsCount;
109 for (size_t group = 0; group < groupsCount; ++group) {
// NOTE(review): only the group's first input-channel scale is used; this
// assumes scales are uniform within a group — confirm against the checks
// performed in canBeTransformed (not visible here).
110 const float scaleValue = scales[group * inputChannelsInGroup];
112 for (size_t i = 0; i < outputChannelsInGroup; ++i) {
113 size_t index = group * outputChannelsInGroup + i;
114 outputScales[index] = scaleValue;
// Shape the new constant as [C, 1, 1, ...] matching the output rank minus
// batch and channel dimensions.
118 auto newMulShape = Shape{ outputScales.size() };
119 for (size_t i = 0; i < convolution->get_output_shape(0).size() - 2; ++i) {
120 newMulShape.push_back(1ul);
123 newMultiplyAfterConst = std::make_shared<opset1::Constant>(
124 dequantization.multiply->get_output_element_type(0),
// Non-grouped case (branch header elided): collapse the dequantization
// multiply constant to a scalar constant of the same element type.
129 std::shared_ptr<opset1::Constant> reducedConstant = as_type_ptr<opset1::Constant>(
130 dequantization.multiply->input_value(1).get_node_shared_ptr());
131 newMultiplyAfterConst = std::make_shared<opset1::Constant>(
132 reducedConstant->get_output_element_type(0),
134 reducedConstant->cast_vector<float>()[0]);
// --- Re-wire: convolution now consumes the Multiply's INPUT (skipping the
// data-path dequantization Multiply), and a TypeRelaxed DequantizationMultiply
// is appended after it; f32 is forced on the relaxed inputs/outputs.
137 auto newConvolution = convolution->copy_with_new_inputs({ dequantization.multiply->input_value(0), convolution->input_value(1) });
138 std::shared_ptr<ngraph::opset1::Multiply> newMultiplyAfter = std::make_shared<op::TypeRelaxed<DequantizationMultiply>>(
139 std::vector<element::Type>{ element::f32, element::f32 }, std::vector<element::Type>{ element::f32 },
140 ngraph::op::TemporaryReplaceOutputType(newConvolution, element::f32).get(),
141 ngraph::op::TemporaryReplaceOutputType(newMultiplyAfterConst, element::f32).get());
143 replace_node(convolution, newMultiplyAfter);
144 convolution = newMultiplyAfter->input_value(0).get_node_shared_ptr();
// If a Convert still feeds the data input, bypass it by cloning the
// convolution directly onto the Convert's source.
146 if (is_type<opset1::Convert>(convolution->get_input_node_ptr(0))) {
147 auto newConvolution = convolution->clone_with_new_inputs({
148 convolution->get_input_node_ptr(0)->get_input_node_shared_ptr(0),
149 convolution->get_input_node_shared_ptr(1) });
150 replace_node(convolution, newConvolution);
151 convolution = newConvolution;
// --- Weights path: decompose the weights FakeQuantize, then locate the
// (optional Reshape) -> Multiply -> (optional Subtract) chain on input 1.
156 decomposeFakeQuantizeForWeightsPath(convolution);
158 std::shared_ptr<opset1::Reshape> reshapeFromWeights = as_type_ptr<opset1::Reshape>(convolution->input_value(1).get_node_shared_ptr());
159 std::shared_ptr<opset1::Multiply> multiplyFromWeights = as_type_ptr<opset1::Multiply>(
160 reshapeFromWeights == nullptr ?
161 convolution->input_value(1).get_node_shared_ptr() :
162 convolution->get_input_node_ptr(1)->get_input_node_shared_ptr(0));
163 std::shared_ptr<opset1::Subtract> subtractFromWeights = as_type_ptr<opset1::Subtract>(multiplyFromWeights->get_input_node_shared_ptr(0));
// Drop the trailing unit dimension from the weights scale shape so the scale
// matches the convolution OUTPUT layout when moved after it.
166 Shape newScaleShape = multiplyFromWeights->get_input_shape(1);
167 // that's all we need: [C, 1, 1, 1] => [C, 1, 1]
168 newScaleShape.pop_back();
// Re-point the weights Reshape below the Multiply (it now reshapes the raw
// quantized weights instead of the dequantized ones).
170 if (reshapeFromWeights != nullptr) {
171 reshapeFromWeights = as_type_ptr<opset1::Reshape>(reshapeFromWeights->copy_with_new_inputs({
172 multiplyFromWeights->input_value(0),
173 reshapeFromWeights->input_value(1) }));
// Move the weights scale after the convolution too: the new convolution takes
// the Multiply's input as weights, and the folded/reshaped scale constant is
// merged into the post-convolution DequantizationMultiply.
// NOTE(review): one ternary branch and closing argument lines are elided from
// this view (original lines 180/182/186).
176 auto newMultiplyAfter = std::make_shared<DequantizationMultiply>(
177 convolution->copy_with_new_inputs({
178 convolution->input_value(0),
179 reshapeFromWeights != nullptr ?
181 multiplyFromWeights->input_value(0)
183 fold_reshape<opset1::Reshape>(
184 multiplyFromWeights->input_value(1),
185 std::make_shared<opset1::Constant>(element::u64, Shape{ newScaleShape.size() }, newScaleShape),
187 replace_node(convolution, newMultiplyAfter);
188 convolution = newMultiplyAfter->input_value(0).get_node_shared_ptr();
// --- Weights zero point: try to optimize the Subtract away; if it must stay,
// broadcast its constant to [O, 1, 1, ...] (per-output-channel of weights).
191 if (subtractFromWeights != nullptr) {
192 auto optimizedSubtract = NetworkHelper::optimizeSubtract(subtractFromWeights);
193 // TODO: handle optimizedSubtract == nullptr;
194 if (optimizedSubtract != nullptr) {
195 subtractFromWeights = as_type_ptr<opset1::Subtract>(optimizedSubtract);
197 const Shape weightsShape = subtractFromWeights->input(0).get_shape();
198 Shape zeroPointShape(weightsShape.size(), 1ul);
199 zeroPointShape[0] = weightsShape[0];
201 auto zeroPointConstant = fold<opset1::Broadcast>(
202 subtractFromWeights->get_input_node_shared_ptr(1),
203 std::make_shared<opset1::Constant>(element::i32, Shape{ zeroPointShape.size() }, zeroPointShape));
204 replace_node(subtractFromWeights->get_input_node_shared_ptr(1), zeroPointConstant);
// --- Weights Convert: bypass it so the convolution consumes the original
// low-precision weights (possibly through the re-built Reshape).
208 std::shared_ptr<opset1::Convert> convertFromWeights = as_type_ptr<opset1::Convert>(subtractFromWeights == nullptr ?
209 multiplyFromWeights->get_input_node_shared_ptr(0) :
210 subtractFromWeights->get_input_node_shared_ptr(0));
212 if (convertFromWeights != nullptr) {
213 std::shared_ptr<Node> childNode = reshapeFromWeights == nullptr ? convolution : reshapeFromWeights;
215 auto newConvolution = convolution->clone_with_new_inputs({
216 convolution->get_input_node_shared_ptr(0),
217 childNode.get() == convolution.get() ?
218 convolution->get_input_node_ptr(1)->get_input_node_shared_ptr(0) :
219 childNode->copy_with_new_inputs({convertFromWeights->input_value(0), childNode->input_value(1)})});
220 replace_node(convolution, newConvolution);
221 convolution = newConvolution;
// Fold a remaining weights Reshape into a constant (argument lines at original
// 229-230 elided from this view).
224 reshapeFromWeights = as_type_ptr<opset1::Reshape>(convolution->get_input_node_shared_ptr(1));
225 if (reshapeFromWeights != nullptr) {
226 const std::shared_ptr<Node> newWeights = fold_reshape<opset1::Reshape>(
227 reshapeFromWeights->input_value(0),
228 reshapeFromWeights->input_value(1),
231 replace_node(reshapeFromWeights, newWeights);
// --- Finalize: merge consecutive Multiplies after the convolution and record
// the resulting dequantization as this layer's output in the context.
// NOTE(review): takes the FIRST target input only — assumes the convolution has
// a single consumer at this point (it was cloned into a standalone branch).
235 std::shared_ptr<ngraph::opset1::Multiply> finalDequantization = NetworkHelper::optimizeMultipliesAfter(
236 convolution->output(0).get_target_inputs().begin()->get_node()->shared_from_this());
238 updateOutput(context, finalDequantization, convolution);
// Keep a surviving weights zero-point Subtract from being constant-folded by
// later passes: tag it (or the node under a weights Reshape) in runtime info.
240 auto onWeights = convolution->get_input_node_shared_ptr(1);
241 if (is_type<opset1::Reshape>(onWeights)) {
242 onWeights = onWeights->get_input_node_shared_ptr(0);
245 if (is_type<opset1::Subtract>(onWeights)) {
246 auto& rt = onWeights->get_rt_info();
247 rt["DISABLED_CONSTANT_FOLDING"] = std::make_shared<ngraph::VariantWrapper<std::string>>("");
252 } // namespace low_precision
254 } // namespace ngraph