1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
24 #include "ngraph/axis_vector.hpp"
25 #include "ngraph/builder/autobroadcast.hpp"
26 #include "ngraph/runtime/opt_kernel/reshape.hpp"
27 #include "ngraph/runtime/reference/broadcast.hpp"
28 #include "ngraph/runtime/reference/dot.hpp"
29 #include "ngraph/shape_util.hpp"
31 NGRAPH_SUPPRESS_DEPRECATED_START
41 /// \brief Reference kernel for (optionally batched) matmul computation.
43 /// \tparam T Type of input and output tensors.
45 /// \param arg0 Pointer to the buffer for left operand input tensor.
46 /// \param arg1 Pointer to the buffer for right operand input tensor.
47 /// \param out Pointer to the buffer for output tensor. This must be pre-allocated by
48 /// the caller, and must be large enough to hold a tensor of shape out_shape.
50 /// \param arg0_shape Shape of arg0.
51 /// \param arg1_shape Shape of arg1.
52 /// \param out_shape Shape of out.
53 /// \param transpose_arg0 If true, swap the two innermost axes of arg0 first (no-op for 1-D arg0).
54 /// \param transpose_arg1 If true, swap the two innermost axes of arg1 first (no-op for 1-D arg1).
56 void matmul(const T* arg0,
59 const Shape& arg0_shape,
60 const Shape& arg1_shape,
61 const Shape& out_shape,
65 // Steps to compute matmul:
66 // 1) Check inputs and perform transpose on arg if applicable
67 // 2) If ranks of both args are 2D and below (no batch dim),
68 // perform dot and return result; otherwise, continue next
69 // 3) Check if auto broadcast is needed on args or transposed args,
70 // and perform broadcast if applicable
71 // 4) Perform dot on the args or updated args and return result
73 size_t arg0_rank = arg0_shape.size();
74 size_t arg1_rank = arg1_shape.size();
75 size_t out_rank = out_shape.size();
77 // vectors holding the potential intermediate transpose and broadcast results
79 vector<T> arg0_transpose_vec;
80 vector<T> arg1_transpose_vec;
81 vector<T> arg0_broadcast_vec;
82 vector<T> arg1_broadcast_vec;
84 // pointers to updated inputs
85 const T* arg0_update = arg0;
86 const T* arg1_update = arg1;
88 // vars for updated inputs shapes
89 Shape wip_arg0_shape = arg0_shape;
90 Shape wip_arg1_shape = arg1_shape;
92 auto get_transpose_order = [](const Shape& input_shape) {
93 size_t rank = input_shape.size();
94 NGRAPH_CHECK(rank > 1, "Invalid input for transpose");
95 vector<size_t> axes_order(rank);
96 iota(axes_order.begin(), axes_order.end(), 0);
97 swap(axes_order[rank - 1], axes_order[rank - 2]);
98 return AxisVector{begin(axes_order), end(axes_order)};
101 auto get_broadcast_axes = [](const Shape& marker_shape, const Shape& target_shape) {
102 NGRAPH_CHECK(marker_shape.size() == target_shape.size(),
103 "Incompatible input shapes");
104 AxisSet broadcast_axes;
105 for (size_t i = 0; i < marker_shape.size(); i++)
107 if (marker_shape[i] == 1 && target_shape[i] != 1)
109 broadcast_axes.insert(i);
112 return broadcast_axes;
115 // Perform transpose if requested
116 if (transpose_arg0 && arg0_rank > 1)
118 arg0_transpose_vec.resize(shape_size(arg0_shape)); // resize, not reserve: reshape writes into data(), so size() must cover every element
119 auto axis_vector = get_transpose_order(arg0_shape);
120 swap(wip_arg0_shape[arg0_rank - 1], wip_arg0_shape[arg0_rank - 2]);
121 opt_kernel::reshape(reinterpret_cast<const char*>(arg0),
122 reinterpret_cast<char*>(arg0_transpose_vec.data()),
128 arg0_update = arg0_transpose_vec.data();
131 if (transpose_arg1 && arg1_rank > 1)
133 arg1_transpose_vec.resize(shape_size(arg1_shape)); // resize, not reserve: reshape writes into data(), so size() must cover every element
134 auto axis_vector = get_transpose_order(arg1_shape);
135 swap(wip_arg1_shape[arg1_rank - 1], wip_arg1_shape[arg1_rank - 2]);
136 opt_kernel::reshape(reinterpret_cast<const char*>(arg1),
137 reinterpret_cast<char*>(arg1_transpose_vec.data()),
143 arg1_update = arg1_transpose_vec.data();
146 // Inputs are 2D and below, perform dot directly
147 if (arg0_rank <= 2 && arg1_rank <= 2)
149 return dot(arg0_update,
158 // Check and perform auto-broadcast if needed
159 // If one of the args is 2D or below, no need to
160 // do broadcast on it, just use its value for
161 // every batch of dot computation later
163 if (arg0_rank > 2 && arg1_rank > 2)
165 const auto& broadcast_shapes = builder::get_numpy_broadcast_shapes(
166 {Shape{begin(wip_arg0_shape), next(end(wip_arg0_shape), -2)},
167 Shape{begin(wip_arg1_shape), next(end(wip_arg1_shape), -2)}});
169 Shape arg0_br_target_shape = broadcast_shapes.first;
170 Shape arg1_br_target_shape = broadcast_shapes.first;
171 Shape arg0_br_marker_shape = broadcast_shapes.second.at(0);
172 Shape arg1_br_marker_shape = broadcast_shapes.second.at(1);
174 arg0_br_target_shape.insert(
175 end(arg0_br_target_shape),
176 next(begin(wip_arg0_shape), wip_arg0_shape.size() - 2),
177 end(wip_arg0_shape));
178 arg1_br_target_shape.insert(
179 end(arg1_br_target_shape),
180 next(begin(wip_arg1_shape), wip_arg1_shape.size() - 2),
181 end(wip_arg1_shape));
183 arg0_br_marker_shape.insert(
184 end(arg0_br_marker_shape),
185 next(begin(wip_arg0_shape), wip_arg0_shape.size() - 2),
186 end(wip_arg0_shape));
187 arg1_br_marker_shape.insert(
188 end(arg1_br_marker_shape),
189 next(begin(wip_arg1_shape), wip_arg1_shape.size() - 2),
190 end(wip_arg1_shape));
192 if (arg0_br_target_shape != wip_arg0_shape)
194 auto broadcast_axes =
195 get_broadcast_axes(arg0_br_marker_shape, arg0_br_target_shape);
196 if (!broadcast_axes.empty())
198 arg0_broadcast_vec.resize(shape_size(arg0_br_target_shape)); // resize, not reserve: broadcast writes into data(), so size() must cover every element
199 broadcast(arg0_update,
200 arg0_broadcast_vec.data(),
202 arg0_br_target_shape,
205 arg0_update = arg0_broadcast_vec.data();
206 wip_arg0_shape = arg0_br_target_shape;
207 arg0_rank = wip_arg0_shape.size();
211 if (arg1_br_target_shape != wip_arg1_shape)
213 auto broadcast_axes =
214 get_broadcast_axes(arg1_br_marker_shape, arg1_br_target_shape);
215 if (!broadcast_axes.empty())
217 arg1_broadcast_vec.resize(shape_size(arg1_br_target_shape)); // resize, not reserve: broadcast writes into data(), so size() must cover every element
218 broadcast(arg1_update,
219 arg1_broadcast_vec.data(),
221 arg1_br_target_shape,
224 arg1_update = arg1_broadcast_vec.data();
225 wip_arg1_shape = arg1_br_target_shape;
226 arg1_rank = wip_arg1_shape.size();
231 // Perform batched dot
233 size_t output_batch_size = 1;
235 // Calculate number of batches
238 // Output is {batch_size, dot_result}, i.e.,
239 // arg0 shape {2}, arg1 shape {3, 2, 1}, output shape {3, 1}
240 output_batch_size = out_shape[0];
244 for (size_t i = 0; i < (out_rank - 2); i++)
246 output_batch_size *= out_shape[i];
250 Shape dot_arg0_shape = (arg0_rank > 2) ? Shape{wip_arg0_shape[arg0_rank - 2],
251 wip_arg0_shape[arg0_rank - 1]}
253 Shape dot_arg1_shape = (arg1_rank > 2) ? Shape{wip_arg1_shape[arg1_rank - 2],
254 wip_arg1_shape[arg1_rank - 1]}
256 Shape dot_output_shape =
257 (out_rank > 2) ? Shape{out_shape[out_rank - 2], out_shape[out_rank - 1]}
258 : Shape{out_shape[out_rank - 1]};
260 const size_t arg0_offset = (arg0_rank > 2) ? shape_size(dot_arg0_shape) : 0; // offset 0: a rank <= 2 operand is reused for every batch
261 const size_t arg1_offset = (arg1_rank > 2) ? shape_size(dot_arg1_shape) : 0; // offset 0: a rank <= 2 operand is reused for every batch
262 const size_t output_offset = shape_size(dot_output_shape);
263 for (size_t i = 0; i < output_batch_size; i++)
265 dot(arg0_update + i * arg0_offset,
266 arg1_update + i * arg1_offset,
267 out + i * output_offset,
278 NGRAPH_SUPPRESS_DEPRECATED_END