Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / pass / constant_folding_dequantize.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "constant_folding.hpp"
18 #include "ngraph/log.hpp"
19 #include "ngraph/op/dequantize.hpp"
20 #include "ngraph/runtime/reference/dequantize.hpp"
21
22 NGRAPH_SUPPRESS_DEPRECATED_START
23
24 using namespace std;
25 using namespace ngraph;
26
27 template <class QUANT, class REAL>
28 shared_ptr<op::Constant> fold_constant_dequantize(shared_ptr<op::Constant> constant,
29                                                   shared_ptr<op::Dequantize> dequant,
30                                                   shared_ptr<op::Constant> scale,
31                                                   shared_ptr<op::Constant> offset)
32 {
33     const Shape& out_shape = constant->get_shape();
34     runtime::AlignedBuffer buffer(shape_size(out_shape) * sizeof(REAL));
35     REAL* data_ptr = buffer.get_ptr<REAL>();
36
37     runtime::reference::dequantize<QUANT, REAL>(constant->get_data_ptr<QUANT>(),
38                                                 scale->get_data_ptr<REAL>(),
39                                                 offset->get_data_ptr<QUANT>(),
40                                                 data_ptr,
41                                                 constant->get_shape(),
42                                                 scale->get_shape(),
43                                                 dequant->get_axes());
44
45     return make_shared<op::Constant>(dequant->get_element_type(), out_shape, data_ptr);
46 }
47
48 void pass::ConstantFolding::construct_constant_dequantize()
49 {
50     auto constant_label =
51         make_shared<pattern::op::Label>(element::u8, Shape{2}, pattern::has_class<op::Constant>());
52     auto dq_scale = op::Constant::create(element::f32, Shape{}, {1});
53     auto dq_offset = op::Constant::create(element::u8, Shape{}, {1});
54     auto dequant_op =
55         make_shared<op::Dequantize>(constant_label, dq_scale, dq_offset, element::f32, AxisSet{});
56     auto dequant = make_shared<pattern::op::Label>(dequant_op, nullptr, NodeVector{dequant_op});
57
58     auto constant_dequantize_callback = [this, constant_label, dequant](pattern::Matcher& m) {
59         NGRAPH_DEBUG << "In callback for constant_dequantize_callback against node = "
60                      << m.get_match_root()->get_name();
61
62         auto pattern_map = m.get_pattern_map();
63
64         auto constant_match = as_type_ptr<op::Constant>(pattern_map[constant_label]);
65         auto dequant_match = pattern_map[dequant];
66         auto dequantize_op = as_type_ptr<op::Dequantize>(dequant_match);
67
68         if (cf_is_disabled(dequantize_op))
69             return false;
70
71         auto scale = as_type_ptr<op::Constant>(dequant_match->input_value(1).get_node_shared_ptr());
72         auto offset =
73             as_type_ptr<op::Constant>(dequant_match->input_value(2).get_node_shared_ptr());
74
75         NGRAPH_CHECK(revalidate_and_ensure_static(dequantize_op));
76         auto type = constant_match->get_element_type();
77
78         if (dequant_match->get_element_type() != element::f32)
79         {
80             return false;
81         }
82
83         if (type == element::u8)
84         {
85             replace_node(m.get_match_root(),
86                          fold_constant_dequantize<uint8_t, float>(
87                              constant_match, dequantize_op, scale, offset));
88             return true;
89         }
90         else if (type == element::i8)
91         {
92             replace_node(m.get_match_root(),
93                          fold_constant_dequantize<int8_t, float>(
94                              constant_match, dequantize_op, scale, offset));
95             return true;
96         }
97
98         return false;
99     };
100
101     auto dequantize_matcher =
102         make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize");
103     NGRAPH_SUPPRESS_DEPRECATED_START
104     this->add_matcher(
105         dequantize_matcher, constant_dequantize_callback, PassProperty::CHANGE_DYNAMIC_STATE);
106     NGRAPH_SUPPRESS_DEPRECATED_END
107 }