Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / src / builder / matmul_factory.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <cstddef>
18 #include <iterator>
19 #include <memory>
20 #include <numeric>
21
22 #include "ngraph/builder/autobroadcast.hpp"
23 #include "ngraph/builder/make_constant.hpp"
24 #include "ngraph/builder/matmul_factory.hpp"
25 #include "ngraph/builder/reshape.hpp"
26 #include "ngraph/op/concat.hpp"
27 #include "ngraph/op/dot.hpp"
28 #include "ngraph/op/quantized_dot.hpp"
29 #include "ngraph/op/reshape.hpp"
30 #include "ngraph/op/slice.hpp"
31
32 NGRAPH_SUPPRESS_DEPRECATED_START
33
34 using namespace ngraph;
35 using namespace std;
36
37 /// \brief      Slice the sub matrix from the input tensor.
38 ///
39 /// \param[in]  node  The input tensor. Must be at most of rank 3.
40 /// \param[in]  idx   The index on the first axis, at which to slice sub-matrix.
41 ///
42 /// \return     The node representing sub matrix.
43 ///
44 static Output<Node> get_sub_matrix(const Output<Node>& node, size_t idx)
45 {
46     const Shape& shape{node.get_shape()};
47     if (shape.size() < 3)
48     {
49         return node.get_node_shared_ptr();
50     }
51     // Below bounds defines the sub_matrix through ranges for each input node axis.
52     Coordinate lower_bounds(shape.size());
53     Coordinate upper_bounds = shape;
54     // We assume `node` tensor is of rank equal 3, thus we slice the sub-matrix lying in the last
55     // two dimensions at index `idx` of first axis.
56     lower_bounds.at(0) = idx;
57     upper_bounds.at(0) = idx + 1;
58
59     auto sub_matrix = Output<Node>{make_shared<op::Slice>(node, lower_bounds, upper_bounds)};
60     // Remove first single entry dim.
61     return builder::opset1::squeeze(sub_matrix);
62 }
63
64 Output<Node> builder::MatmulFactory::get_left()
65 {
66     return m_inputs.at(0);
67 }
68
69 Output<Node> builder::MatmulFactory::get_right()
70 {
71     return m_inputs.at(1);
72 }
73
/// \brief  Build a graph computing the (possibly batched) matrix product of
///         the factory's two inputs, following numpy-style semantics:
///         rank <= 2 pairs map to a single Dot, higher ranks are broadcast,
///         collapsed to rank 3, multiplied matrix-by-matrix and reassembled.
///
/// \return Vector with a single output holding the product.
OutputVector builder::MatmulFactory::make_matmul_op()
{
    // Collapses axes [start_axis, end_axis] of `value` into one dimension whose
    // size is the product of the collapsed extents; trailing axes are kept.
    auto collapse = [](const Output<Node>& value, const size_t start_axis, const size_t end_axis) {
        auto shape = value.get_shape();
        size_t collapsed_axis_size = accumulate(next(begin(shape), start_axis),
                                                next(begin(shape), end_axis + 1),
                                                1UL,
                                                multiplies<size_t>());

        Shape output_shape{collapsed_axis_size};
        output_shape.insert(end(output_shape), next(begin(shape), end_axis + 1), end(shape));
        return make_shared<op::Reshape>(
                   value, get_default_order(value.get_shape().size()), output_shape)
            ->add_provenance_group_members_above({value});
    };
    auto left = get_left();
    auto right = get_right();

    size_t left_rank{left.get_shape().size()};
    size_t right_rank{right.get_shape().size()};

    // First (easy) case that is already internally handled by Ngraph Dot operator.
    // Multiply two tensors where both of them has rank lower equal 2.
    if (left_rank <= 2 && right_rank <= 2)
    {
        return {make_dot(left, right)
                    .get_node_shared_ptr()
                    ->add_provenance_group_members_above(m_inputs)};
    }

    // Second case:
    // Multiply two tensors where at least one of them is rank greater equal 3.

    // Broadcast input arguments only if both of them are not vectors.
    if (left_rank > 1 && right_rank > 1)
    {
        const OutputVector& broadcasted_nodes =
            builder::numpy_broadcast_for_matmul_operation(left, right);

        left = broadcasted_nodes.at(0);
        right = broadcasted_nodes.at(1);
    }
    // NOTE: these references are bound BEFORE the collapse reassignments below,
    // so they keep the (post-broadcast) shapes; the final reshape at the end of
    // this function relies on that to restore the full stack-of-matrices axes.
    const auto& left_shape = left.get_shape();
    const auto& right_shape = right.get_shape();

    // Collapse both tensors _stack of matrices_ axes (all except the last two).
    // This will make easier further dot product calculations.
    if (left_shape.size() > 3)
    {
        left = collapse(left, 0, left_shape.size() - 3);
    }
    if (right_shape.size() > 3)
    {
        right = collapse(right, 0, right_shape.size() - 3);
    }

    // Perform multiple small dot products
    size_t groups = left.get_shape().at(0);
    // If we haven't broadcast earlier this means that one of the inputs is a vector,
    // thus the number of groups is defined by the shape of the bigger tensor.
    if (right.get_shape().size() > left.get_shape().size())
    {
        groups = right.get_shape().at(0);
    }
    NodeVector small_dots(groups);

    for (size_t g = 0; g < groups; ++g)
    {
        // get_sub_matrix returns its argument unchanged for rank <= 2 inputs,
        // which covers the non-broadcast vector case.
        const auto sliced_left = get_sub_matrix(left, g);
        const auto sliced_right = get_sub_matrix(right, g);
        auto sub_dot = make_dot(sliced_left, sliced_right);

        // Expand sub_dot result with single empty outermost axis, in order to
        // later concatenate sub_dots at this axis.
        small_dots.at(g) = builder::opset1::expand_dims(sub_dot);
    }

    // Concatenate sub_dots on groups axis.
    auto result = make_shared<op::Concat>(small_dots, 0);

    if (left_shape.size() <= 3 && right_shape.size() <= 3)
    {
        return {result->add_provenance_group_members_above(m_inputs)};
    }
    // Expand result _stack of matrices_ axes to get expected result shape.
    else
    {
        const Shape& shape{result->get_shape()};
        // Keep the two matrix dimensions of the concatenated result and prepend
        // the original (pre-collapse) leading axes taken from left_shape.
        Shape result_shape(next(begin(shape)), end(shape));
        result_shape.insert(
            begin(result_shape), begin(left_shape), next(begin(left_shape), left_shape.size() - 2));
        return {make_shared<op::Reshape>(result, get_default_order(shape.size()), result_shape)
                    ->add_provenance_group_members_above(m_inputs)};
    }
}
169
170 Output<Node> builder::MatmulFactory::make_dot(const Output<Node>& left, const Output<Node>& right)
171 {
172     return make_shared<op::Dot>(left, right);
173 }
174
175 Output<Node> builder::QLinearMatmulFactory::get_right()
176 {
177     return m_inputs.at(3);
178 }
179
180 Output<Node> builder::QLinearMatmulFactory::make_dot(const Output<Node>& left,
181                                                      const Output<Node>& right)
182 {
183     ngraph::element::Type output_type;
184
185     if (left.get_element_type() == ngraph::element::u8 &&
186         right.get_element_type() == ngraph::element::i8)
187     {
188         output_type = ngraph::element::i8;
189     }
190     else if (left.get_element_type() == ngraph::element::u8 &&
191              right.get_element_type() == ngraph::element::u8)
192     {
193         output_type = ngraph::element::u8;
194     }
195
196     return std::make_shared<ngraph::op::QuantizedDot>(left,
197                                                       right,
198                                                       1,
199                                                       m_inputs.at(1),
200                                                       m_inputs.at(2),
201                                                       m_inputs.at(4),
202                                                       m_inputs.at(5),
203                                                       m_inputs.at(6),
204                                                       m_inputs.at(7),
205                                                       output_type,
206                                                       ngraph::AxisSet{},
207                                                       ngraph::AxisSet{},
208                                                       ngraph::AxisSet{});
209 }
210
211 Output<Node> builder::MatmulIntegerFactory::make_dot(const Output<Node>& left,
212                                                      const Output<Node>& right)
213 {
214     auto num_inputs = m_inputs.size();
215     auto scale_one = ngraph::builder::make_constant(ngraph::element::f32, Shape{}, 1);
216     auto output_zero_point = ngraph::builder::make_constant(ngraph::element::i32, Shape{}, 0);
217     auto left_zero_point = ngraph::builder::make_constant(left.get_element_type(), Shape{}, 0);
218     auto right_zero_point = ngraph::builder::make_constant(right.get_element_type(), Shape{}, 0);
219     if (num_inputs == 2)
220     {
221         return std::make_shared<ngraph::op::QuantizedDot>(left,
222                                                           right,
223                                                           1,
224                                                           scale_one,
225                                                           left_zero_point,
226                                                           scale_one,
227                                                           right_zero_point,
228                                                           scale_one,
229                                                           output_zero_point,
230                                                           ngraph::element::i32,
231                                                           ngraph::AxisSet{},
232                                                           ngraph::AxisSet{},
233                                                           ngraph::AxisSet{});
234     }
235
236     left_zero_point = m_inputs.at(2).get_node_shared_ptr();
237     if (num_inputs == 4)
238     {
239         right_zero_point = m_inputs.at(3).get_node_shared_ptr();
240     }
241
242     return std::make_shared<ngraph::op::QuantizedDot>(left,
243                                                       right,
244                                                       1,
245                                                       scale_one,
246                                                       left_zero_point,
247                                                       scale_one,
248                                                       right_zero_point,
249                                                       scale_one,
250                                                       output_zero_point,
251                                                       ngraph::element::i32,
252                                                       ngraph::AxisSet{},
253                                                       ngraph::AxisSet{},
254                                                       ngraph::AxisSet{});
255 }