1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "constant_folding.hpp"
18 #include "ngraph/log.hpp"
19 #include "ngraph/op/dequantize.hpp"
20 #include "ngraph/runtime/reference/dequantize.hpp"
22 NGRAPH_SUPPRESS_DEPRECATED_START
25 using namespace ngraph;
// Computes the result of a Dequantize over constant inputs and wraps it in a new Constant.
//
// Template parameters:
//   QUANT - element type used to read the data of `constant` and `offset`.
//   REAL  - element type used to read the data of `scale` and to type the output buffer.
//
// Parameters:
//   constant - input constant; its data is read as QUANT and its shape is reused as the
//              shape of the result.
//   dequant  - the Dequantize node being folded; supplies the element type of the result.
//   scale    - constant whose data is read as REAL.
//   offset   - constant whose data is read as QUANT.
//
// Returns a new op::Constant created with dequant's element type, constant's shape and the
// locally computed data pointer.
//
// NOTE(review): this view of the file appears to have lost some short lines (the function's
// braces and part of the argument list of the reference call below) — confirm against the
// complete source before relying on the exact call shape.
27 template <class QUANT, class REAL>
28 shared_ptr<op::Constant> fold_constant_dequantize(shared_ptr<op::Constant> constant,
29 shared_ptr<op::Dequantize> dequant,
30 shared_ptr<op::Constant> scale,
31 shared_ptr<op::Constant> offset)
// The folded result takes the shape of the input constant.
33 const Shape& out_shape = constant->get_shape();
// Local buffer sized as shape_size(out_shape) * sizeof(REAL) bytes; data_ptr views it as REAL.
34 runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(REAL));
35 REAL* data_ptr = buffer.get_ptr<REAL>();
// Invoke the reference dequantize implementation on the raw constant data.
// NOTE(review): only some of the call's arguments are visible here; presumably data_ptr is
// passed as the output and further shape/axes arguments follow — TODO confirm.
37 runtime::reference::dequantize<QUANT, REAL>(constant->get_data_ptr<QUANT>(),
38 scale->get_data_ptr<REAL>(),
39 offset->get_data_ptr<QUANT>(),
41 constant->get_shape(),
// Build the replacement constant from data_ptr.
// NOTE(review): `buffer` is a local, so this presumably relies on the Constant constructor
// copying from data_ptr — confirm.
45 return make_shared<op::Constant>(dequant->get_element_type(), out_shape, data_ptr);
// Builds the pattern, callback and matcher used by ConstantFolding to replace a Dequantize
// node whose data input is a Constant with the result of fold_constant_dequantize.
//
// NOTE(review): this view of the file appears to have lost some short lines (braces, the
// left-hand sides of a few declarations, the bodies of some `if` statements and the callee
// of the final registration call). Comments below describe only what the visible lines
// show; points that depend on the missing lines are marked for confirmation.
48 void pass::ConstantFolding::construct_constant_dequantize()
// Pattern label for the data input: declared as u8 with Shape{2} and restricted by the
// has_class<op::Constant> predicate. (Presumably bound to `constant_label`, which is used
// below — the declaration head is not visible here.)
51 make_shared<pattern::op::Label>(element::u8, Shape{2}, pattern::has_class<op::Constant>());
// Scalar f32 scale and scalar u8 offset constants used as the other two pattern inputs.
52 auto dq_scale = op::Constant::create(element::f32, Shape{}, {1});
53 auto dq_offset = op::Constant::create(element::u8, Shape{}, {1});
// Dequantize node of the pattern, built from the label and the two constants above.
// (Presumably bound to `dequant_op`, which is used on the next line — TODO confirm.)
55 make_shared<op::Dequantize>(constant_label, dq_scale, dq_offset, element::f32, AxisSet{});
// Label wrapping the pattern's Dequantize node so the matched node can be looked up in the
// pattern map from inside the callback.
56 auto dequant = make_shared<pattern::op::Label>(dequant_op, nullptr, NodeVector{dequant_op});
// Callback run by the matcher; captures the two labels to index the pattern map.
58 auto constant_dequantize_callback = [this, constant_label, dequant](pattern::Matcher& m) {
59 NGRAPH_DEBUG << "In callback for constant_dequantize_callback against node = "
60 << m.get_match_root()->get_name();
62 auto pattern_map = m.get_pattern_map();
// Retrieve the matched data constant and the matched Dequantize node.
64 auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
65 auto dequant_match = pattern_map[dequant];
66 auto dequantize_op = as_type_ptr<op::Dequantize>(dequant_match);
// Check whether constant folding is disabled for this op.
// (The guarded statement is not visible in this view; presumably an early return — confirm.)
68 if (cf_is_disabled(dequantize_op))
// Inputs 1 and 2 of the matched node, cast to Constant, are forwarded to the fold helper as
// scale and offset. (The second cast is presumably assigned to `offset`, which is used
// below — its declaration head is not visible here.)
71 auto scale = as_type_ptr<op::Constant>(dequant_match->input_value(1).get_node_shared_ptr());
73 as_type_ptr<op::Constant>(dequant_match->input_value(2).get_node_shared_ptr());
// Revalidate the op and require that revalidate_and_ensure_static reports success.
75 NGRAPH_CHECK(revalidate_and_ensure_static(dequantize_op));
76 auto type = constant_match->get_element_type();
// Test for an output element type other than f32; both fold calls below use float as REAL.
// (The guarded statement is not visible in this view — TODO confirm it bails out.)
78 if (dequant_match->get_element_type() != element::f32)
// Dispatch on the element type of the matched data constant: u8 folds with QUANT = uint8_t.
83 if (type == element::u8)
85 replace_node(m.get_match_root(),
86 fold_constant_dequantize<uint8_t, float>(
87 constant_match, dequantize_op, scale, offset));
// i8 folds with QUANT = int8_t.
90 else if (type == element::i8)
92 replace_node(m.get_match_root(),
93 fold_constant_dequantize<int8_t, float>(
94 constant_match, dequantize_op, scale, offset));
// Matcher rooted at the `dequant` label, named "ConstantFolding.ConstantDequantize".
101 auto dequantize_matcher =
102 make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize");
// The matcher and callback are handed to a registration call wrapped in the
// deprecation-suppression macros. (The callee's name is not visible in this view;
// presumably add_matcher — TODO confirm.)
103 NGRAPH_SUPPRESS_DEPRECATED_START
105 dequantize_matcher, constant_dequantize_callback, PassProperty::CHANGE_DYNAMIC_STATE);
106 NGRAPH_SUPPRESS_DEPRECATED_END