Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / core / reference / include / ngraph / runtime / reference / matmul.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include <cmath>
20 #include <numeric>
21 #include <utility>
22 #include <vector>
23
24 #include "ngraph/axis_vector.hpp"
25 #include "ngraph/builder/autobroadcast.hpp"
26 #include "ngraph/runtime/opt_kernel/reshape.hpp"
27 #include "ngraph/runtime/reference/broadcast.hpp"
28 #include "ngraph/runtime/reference/dot.hpp"
29 #include "ngraph/shape_util.hpp"
30
31 NGRAPH_SUPPRESS_DEPRECATED_START
32
33 using namespace std;
34
35 namespace ngraph
36 {
37     namespace runtime
38     {
39         namespace reference
40         {
41             /// \brief Reference kernel for matmul computation.
42             ///
43             /// \tparam T Type of input and output tensors.
44             ///
45             /// \param arg0 Pointer to the buffer for left operand input tensor.
46             /// \param arg1 Pointer to the buffer for right operand input tensor.
47             /// \param out Pointer to the buffer for output tensor. This must be pre-allocated by
48             ///            the caller, and must be large enough to hold a tensor of the correct
49             ///            shape.
50             /// \param arg0_shape Shape of arg0.
51             /// \param arg1_shape Shape of arg1.
52             /// \param out_shape Shape of out.
53             /// \param transpose_arg0 Flag to indicate if transpose on arg0.
54             /// \param transpose_arg1 Flag to indicate if transpose on arg1.
55             template <typename T>
56             void matmul(const T* arg0,
57                         const T* arg1,
58                         T* out,
59                         const Shape& arg0_shape,
60                         const Shape& arg1_shape,
61                         const Shape& out_shape,
62                         bool transpose_arg0,
63                         bool transpose_arg1)
64             {
65                 // Steps to compute matmul:
66                 // 1) Check inputs and perform transpose on arg if applicable
67                 // 2) If ranks of both args are 2D and below (no batch dim),
68                 //    perform dot and return result; otherwise, continue next
69                 // 3) Check if auto broadcast is needed on args or transposed args,
70                 //    and perform broadcast if applicable
71                 // 4) Perform dot on the args or updated args and return result
72
73                 size_t arg0_rank = arg0_shape.size();
74                 size_t arg1_rank = arg1_shape.size();
75                 size_t out_rank = out_shape.size();
76
77                 // vector vars to hold pontential intermediate transpose,
78                 // broadcast result
79                 vector<T> arg0_transpose_vec;
80                 vector<T> arg1_transpose_vec;
81                 vector<T> arg0_broadcast_vec;
82                 vector<T> arg1_broadcast_vec;
83
84                 // pointers to updated inputs
85                 const T* arg0_update = arg0;
86                 const T* arg1_update = arg1;
87
88                 // vars for updated inputs shapes
89                 Shape wip_arg0_shape = arg0_shape;
90                 Shape wip_arg1_shape = arg1_shape;
91
92                 auto get_transpose_order = [](const Shape& input_shape) {
93                     size_t rank = input_shape.size();
94                     NGRAPH_CHECK(rank > 1, "Invalid input for transpose");
95                     vector<size_t> axes_order(rank);
96                     iota(axes_order.begin(), axes_order.end(), 0);
97                     swap(axes_order[rank - 1], axes_order[rank - 2]);
98                     return AxisVector{begin(axes_order), end(axes_order)};
99                 };
100
101                 auto get_broadcast_axes = [](const Shape& marker_shape, const Shape& target_shape) {
102                     NGRAPH_CHECK(marker_shape.size() == target_shape.size(),
103                                  "Incompatible input shapes");
104                     AxisSet broadcast_axes;
105                     for (size_t i = 0; i < marker_shape.size(); i++)
106                     {
107                         if (marker_shape[i] == 1 && target_shape[i] != 1)
108                         {
109                             broadcast_axes.insert(i);
110                         }
111                     }
112                     return broadcast_axes;
113                 };
114
115                 // Perform transpose if requested
116                 if (transpose_arg0 && arg0_rank > 1)
117                 {
118                     arg0_transpose_vec.reserve(shape_size(arg0_shape));
119                     auto axis_vector = get_transpose_order(arg0_shape);
120                     swap(wip_arg0_shape[arg0_rank - 1], wip_arg0_shape[arg0_rank - 2]);
121                     opt_kernel::reshape(reinterpret_cast<const char*>(arg0),
122                                         reinterpret_cast<char*>(arg0_transpose_vec.data()),
123                                         arg0_shape,
124                                         axis_vector,
125                                         wip_arg0_shape,
126                                         sizeof(T));
127
128                     arg0_update = arg0_transpose_vec.data();
129                 }
130
131                 if (transpose_arg1 && arg1_rank > 1)
132                 {
133                     arg1_transpose_vec.reserve(shape_size(arg1_shape));
134                     auto axis_vector = get_transpose_order(arg1_shape);
135                     swap(wip_arg1_shape[arg1_rank - 1], wip_arg1_shape[arg1_rank - 2]);
136                     opt_kernel::reshape(reinterpret_cast<const char*>(arg1),
137                                         reinterpret_cast<char*>(arg1_transpose_vec.data()),
138                                         arg1_shape,
139                                         axis_vector,
140                                         wip_arg1_shape,
141                                         sizeof(T));
142
143                     arg1_update = arg1_transpose_vec.data();
144                 }
145
146                 // Inputs are 2D and below, perform dot directly
147                 if (arg0_rank <= 2 && arg1_rank <= 2)
148                 {
149                     return dot(arg0_update,
150                                arg1_update,
151                                out,
152                                wip_arg0_shape,
153                                wip_arg1_shape,
154                                out_shape,
155                                1);
156                 }
157
158                 // Check and perform auto-broadcast if needed
159                 // If one of the arg is 2D or below, no need to
160                 // do broadcast on it, just use its value for
161                 // every batch of dot compuatation later
162
163                 if (arg0_rank > 2 && arg1_rank > 2)
164                 {
165                     const auto& broadcast_shapes = builder::get_numpy_broadcast_shapes(
166                         {Shape{begin(wip_arg0_shape), next(end(wip_arg0_shape), -2)},
167                          Shape{begin(wip_arg1_shape), next(end(wip_arg1_shape), -2)}});
168
169                     Shape arg0_br_target_shape = broadcast_shapes.first;
170                     Shape arg1_br_target_shape = broadcast_shapes.first;
171                     Shape arg0_br_marker_shape = broadcast_shapes.second.at(0);
172                     Shape arg1_br_marker_shape = broadcast_shapes.second.at(1);
173
174                     arg0_br_target_shape.insert(
175                         end(arg0_br_target_shape),
176                         next(begin(wip_arg0_shape), wip_arg0_shape.size() - 2),
177                         end(wip_arg0_shape));
178                     arg1_br_target_shape.insert(
179                         end(arg1_br_target_shape),
180                         next(begin(wip_arg1_shape), wip_arg1_shape.size() - 2),
181                         end(wip_arg1_shape));
182
183                     arg0_br_marker_shape.insert(
184                         end(arg0_br_marker_shape),
185                         next(begin(wip_arg0_shape), wip_arg0_shape.size() - 2),
186                         end(wip_arg0_shape));
187                     arg1_br_marker_shape.insert(
188                         end(arg1_br_marker_shape),
189                         next(begin(wip_arg1_shape), wip_arg1_shape.size() - 2),
190                         end(wip_arg1_shape));
191
192                     if (arg0_br_target_shape != wip_arg0_shape)
193                     {
194                         auto broadcast_axes =
195                             get_broadcast_axes(arg0_br_marker_shape, arg0_br_target_shape);
196                         if (!broadcast_axes.empty())
197                         {
198                             arg0_broadcast_vec.reserve(shape_size(arg0_br_target_shape));
199                             broadcast(arg0_update,
200                                       arg0_broadcast_vec.data(),
201                                       wip_arg0_shape,
202                                       arg0_br_target_shape,
203                                       broadcast_axes);
204
205                             arg0_update = arg0_broadcast_vec.data();
206                             wip_arg0_shape = arg0_br_target_shape;
207                             arg0_rank = wip_arg0_shape.size();
208                         }
209                     }
210
211                     if (arg1_br_target_shape != wip_arg1_shape)
212                     {
213                         auto broadcast_axes =
214                             get_broadcast_axes(arg1_br_marker_shape, arg1_br_target_shape);
215                         if (!broadcast_axes.empty())
216                         {
217                             arg1_broadcast_vec.reserve(shape_size(arg1_br_target_shape));
218                             broadcast(arg1_update,
219                                       arg1_broadcast_vec.data(),
220                                       wip_arg1_shape,
221                                       arg1_br_target_shape,
222                                       broadcast_axes);
223
224                             arg1_update = arg1_broadcast_vec.data();
225                             wip_arg1_shape = arg1_br_target_shape;
226                             arg1_rank = wip_arg1_shape.size();
227                         }
228                     }
229                 }
230
231                 // Perform batched dot
232
233                 size_t output_batch_size = 1;
234
235                 // Calculate number of batches
236                 if (out_rank < 3)
237                 {
238                     // Output is {batch_size, dot_result}, i.e.,
239                     // arg 0 shape {2}, arg1 shape {3, 2, 1}, output shape {3, 1}
240                     output_batch_size = out_shape[0];
241                 }
242                 else
243                 {
244                     for (size_t i = 0; i < (out_rank - 2); i++)
245                     {
246                         output_batch_size *= out_shape[i];
247                     }
248                 }
249
250                 Shape dot_arg0_shape = (arg0_rank > 2) ? Shape{wip_arg0_shape[arg0_rank - 2],
251                                                                wip_arg0_shape[arg0_rank - 1]}
252                                                        : wip_arg0_shape;
253                 Shape dot_arg1_shape = (arg1_rank > 2) ? Shape{wip_arg1_shape[arg1_rank - 2],
254                                                                wip_arg1_shape[arg1_rank - 1]}
255                                                        : wip_arg1_shape;
256                 Shape dot_output_shape =
257                     (out_rank > 2) ? Shape{out_shape[out_rank - 2], out_shape[out_rank - 1]}
258                                    : Shape{out_shape[out_rank - 1]};
259
260                 const size_t arg0_offset = (arg0_rank > 2) ? shape_size(dot_arg0_shape) : 0;
261                 const size_t arg1_offset = (arg1_rank > 2) ? shape_size(dot_arg1_shape) : 0;
262                 const size_t output_offset = shape_size(dot_output_shape);
263                 for (size_t i = 0; i < output_batch_size; i++)
264                 {
265                     dot(arg0_update + i * arg0_offset,
266                         arg1_update + i * arg1_offset,
267                         out + i * output_offset,
268                         dot_arg0_shape,
269                         dot_arg1_shape,
270                         dot_output_shape,
271                         1);
272                 }
273             }
274         }
275     }
276 }
277
278 NGRAPH_SUPPRESS_DEPRECATED_END