2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2015 The TensorFlow Authors. All Rights Reserved.
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 #ifndef __NNFW_CKER_HELPER_BCAST_H__
19 #define __NNFW_CKER_HELPER_BCAST_H__
 * ToDo : This file will be moved into an upper folder when integrated with the
 * other helpers. And it should be merged with EinsumHelper's BCast.
27 #include "cker/Shape.h"
28 #include "cker/eigen/EigenSupport.h"
// Returns the mapping from the output batch indices to the corresponding
// input's batch indices, given the input's "reshape" and "bcast" shapes as
// returned by the BCastList helper class. The i'th element denotes the
// (flattened) batch index of the input that must be used to compute the i'th
// batch output.
inline void ComputeBatchIndices(const int32_t output_batch_size,
                                const std::vector<int32_t> &reshape,
                                const std::vector<int32_t> &bcast,
                                std::vector<int32_t> *out_indices)
{
  // Conceptually equivalent to:
  //   - Reshape the iota {0, 1, ..., input_batch_size - 1} to the input shape.
  //   - Broadcast it to the output shape.
  //   - Flatten back to a 1-D vector.
  // The loop builds that mapping incrementally, inner-most dimension first.
  out_indices->resize(output_batch_size);
  int32_t out_elems = 1;
  int32_t in_elems = 1;
  for (int32_t d = static_cast<int32_t>(reshape.size()) - 1; d >= 0; --d)
  {
    // Replicate the already populated prefix an additional (dim - 1) times.
    // A broadcast dimension repeats the mapping unchanged; a genuine input
    // dimension shifts each replica by the input elements covered so far.
    const int32_t dim = std::max(reshape[d], bcast[d]);
    const int32_t step = (bcast[d] > 1) ? 0 : in_elems;
    for (int32_t k = 0; k < (dim - 1) * out_elems; ++k)
    {
      (*out_indices)[out_elems + k] = (*out_indices)[k] + step;
    }
    out_elems *= dim;
    in_elems *= reshape[d];
  }
}
// Helper that computes the shapes needed to broadcast N input shapes against
// each other, following numpy-style broadcasting rules.
template <int N> class BCastList
{
public:
  // A vector of int32_t representing the shape of a tensor. The 0-th
  // element is the outer-most dimension and the last element is the
  // inner-most dimension. Note that we do not use Shape since
  // it's more convenient to manipulate Vec directly for this module.
  typedef std::vector<int32_t> Vec;

  // Constructs all helper shapes, following the aforementioned rules.
  //
  // If "fewer_dims_optimization" is set to true (the default), the
  // implementation tries to reduce intermediate dimensions needed to be more
  // efficient. This is transparent to the caller.
  //
  // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
  // the same number of dimensions as the larger of the two inputs.
  //
  // If return_flattened_batch_indices is true, the implementation will compute
  // for each output member of the flattened output, which batch indices of
  // each input correspond to it. This is disabled by default.
  explicit BCastList(const Vec (&x)[N], const bool fewer_dims_optimization = true,
                     const bool return_flattened_batch_indices = false);
  ~BCastList() {}

  // Returns true iff the operands are compatible according to the
  // broadcasting rule.
  bool IsValid() const { return valid_; }
  bool IsBroadcastingRequired() const { return broadcasting_required_; }

  // If and only if IsValid(), the following fields can be used in
  // implementing a broadcasted binary tensor operation according to
  // the broadcasting rule.
  const Vec &reshape(int i) const { return reshape_[i]; }
  const Vec &bcast(int i) const { return bcast_[i]; }
  const Vec &result_shape() const { return result_; }
  const Vec &output_shape() const { return output_; }
  const Vec &grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; }
  int32_t output_batch_size() const { return output_batch_size_; }

  // Returns the mapping from the flattened output batch indices to the i-th
  // input's flattened batch indices. The result is a vector of length
  // output_batch_size(). To compute the k'th batch output, a binary matmul-like
  // operation should use the `batch_indices(i)[k]`th batch index of input i.
  // Note: Returns an empty vector if broadcasting is not required. Callers
  // should only use this when IsBroadcastingRequired() returns true.
  const std::vector<int32_t> &batch_indices(int i) const { return batch_indices_[i]; }

protected:
  // All members below are only meaningful when valid_ is true.
  bool valid_ = true;
  bool broadcasting_required_ = true;
  Vec reshape_[N];
  Vec bcast_[N];
  Vec result_;
  Vec output_;
  Vec grad_reduce_idx_[N];

  int32_t output_batch_size_;
  std::vector<int32_t> batch_indices_[N];

  static void Reverse(Vec *shape) { std::reverse(shape->begin(), shape->end()); }
};
133 BCastList<N>::BCastList(const BCastList::Vec (&x)[N], const bool fewer_dims_optimization,
134 const bool return_flattened_batch_indices)
136 typedef BCastList::Vec Vec;
137 bool all_equal = true;
138 size_t largest_rank = 0;
139 output_batch_size_ = 1;
140 for (int i = 0; i < N; ++i)
146 if (x[i].size() > largest_rank)
148 largest_rank = x[i].size();
153 broadcasting_required_ = false;
155 if (all_equal && fewer_dims_optimization)
157 // Fast path for common case of identical shapes.
158 int32_t elements = 1;
159 const int rank = x[0].size();
160 output_.resize(rank);
161 for (int i = 0; i < rank; i++)
163 const int32_t dim = x[0][i];
167 result_.push_back(elements);
168 output_batch_size_ = elements;
169 for (int i = 0; i < N; ++i)
171 reshape_[i].push_back(elements);
172 bcast_[i].push_back(1);
174 // grad_reduce_ is left as empty
178 // Reverse all the shapes for convenience
179 // After the reverse, 0-th is the inner-most dimension.
181 for (int i = 0; i < N; ++i)
187 // 1-extend and align all vectors.
188 for (int i = 0; i < N; ++i)
190 if (copy[i].size() < largest_rank)
192 copy[i].resize(largest_rank, 1);
195 // Going through each dimension starting from the inner-most
196 // dimension, compares dimension of x and y. They are compatible if
197 // they are equal or either is 1.
199 // indices of j-th component of each input.
201 bool current_is_one[N];
202 for (int i = 0; i < N; ++i)
204 prev_is_one[i] = false;
205 current_is_one[i] = false;
208 bool output_dim_set = false;
210 bool none_is_one = true;
211 bool set_one = false;
212 for (size_t j = 0; j < largest_rank; ++j)
215 output_dim_set = false;
217 // Find which indices are 1.
218 for (int i = 0; i < N; ++i)
220 // Keep track of which indices are 1.
223 current_is_one[i] = true;
228 current_is_one[i] = false;
229 if (!output_dim_set || copy[i][j] == output_dim)
231 output_dim = copy[i][j];
232 output_dim_set = true;
241 output_.push_back(output_dim_set ? output_dim : 1);
242 output_batch_size_ *= output_.back();
243 // All dimensions are 1.
246 if (!fewer_dims_optimization)
248 for (int i = 0; i < N; ++i)
250 bcast_[i].push_back(1);
251 reshape_[i].push_back(1);
253 result_.push_back(1);
255 for (int i = 0; i < N; ++i)
257 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
259 // This will skip updating the previous state to the current one. We'll
260 // explain why this is safe below.
261 // Consider the previous state P, current state C and the next state N.
262 // In the case where N also is all ones (N == C), we'll do the same
263 // optimization here (push back one dimensions if we need to), which is
264 // safe and is expected.
266 // When N != C, we'll continue as usual. However, we might trigger the
267 // next block if N == P (because we didn't update the previous state).
268 // We trigger the next block if `fewer_dims_optimization` is true.
269 // This means that we did not modify and broadcast / reshapes in this
270 // block (we skipped updating, since the one dimensions can be ignored).
271 // In essence, we only need to check whether the previous non-one state is
272 // equal to the current non-one state.
276 else if ((fewer_dims_optimization) &&
277 std::equal(current_is_one, current_is_one + N, prev_is_one) && set_one)
279 // It is a run of the same broadcasting case as last time.
280 // We can reshape the input so that fewer dimensions
281 // are involved in the intermediate computation.
282 result_.back() *= output_dim;
283 for (int i = 0; i < N; ++i)
285 reshape_[i].back() *= copy[i][j];
286 bcast_[i].back() *= current_is_one[i] ? output_dim : 1;
287 if (current_is_one[i] && !none_is_one)
289 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
295 result_.push_back(output_dim);
296 for (int i = 0; i < N; ++i)
298 reshape_[i].push_back(copy[i][j]);
299 bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
300 if (current_is_one[i] && !none_is_one)
302 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
307 for (int i = 0; i < N; ++i)
309 prev_is_one[i] = current_is_one[i];
314 result_.push_back(1);
315 for (int i = 0; i < N; ++i)
317 reshape_[i].push_back(1);
318 bcast_[i].push_back(1);
321 // Do something about batches.
322 for (int i = 0; i < N; ++i)
324 Reverse(&reshape_[i]);
326 Reverse(&grad_reduce_idx_[i]);
330 // Only compute batch indices when we need broadcasting, and we aren't doing
331 // needless work (when the output size is 0 or the
332 // return_flattened_batch_indices isn't enabled).
333 if (return_flattened_batch_indices && broadcasting_required_ && output_batch_size_ > 0)
335 for (int i = 0; i < N; ++i)
337 ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i], &batch_indices_[i]);
342 // BCast is a helper for broadcasting binary tensor operation.
343 // TensorFlow's broadcasting rule follows that of numpy (See
344 // http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
346 // The rule has the following properties:
348 // 1. suffix matching: the rule starts with the right-most
349 // dimension, and works towards the left-most dimension. Since
350 // TensorFlow is row-major, the right-most dimension (the last
351 // element in the shape of a tensor) is the inner-most, a.k.a.
352 // the fastest changing, dimension.
354 // 2. Two dimensions are compatible for broadcasting if both are the
355 // same or either is 1.
357 // BCast takes the shape of two tensors and computes a few vectors of
358 // int32 that are useful for the caller to reshape the tensors, apply
359 // the right broadcasts to them, compute the broadcasted operation,
360 // and possibly the gradients. In a nutshell, the caller is expected
361 // to compute the broadcasted operation as following:
363 // BCast b(x.shape(), y.shape());
//   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
//            op
//            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
368 // For the gradient computation,
369 // grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
370 // .reshape(x.shape())
371 // grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
372 // .reshape(y.shape())
373 // backprop_x and backprop_y are functionals of the binary function "op",
375 // for +, backprop_x(x, y) = backprop_y(x, y) = 1;
376 // for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
377 // for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
379 // The multiplication in the grad * backprop_x itself is also
380 // broadcasting following the same rule.
381 class BCast : public BCastList<2>
384 // Constructs all helper shapes, following the aforementioned rules.
386 // If "fewer_dims_optimization" is set to true (the default), the
387 // implementation tries to reduce intermediate dimensions needed to be more
388 // efficient. This is transparent to the caller.
390 // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
391 // the same number of dimensions as the larger of the two inputs.
392 typedef std::vector<int32_t> Vec;
394 BCast(const Vec &x, const Vec &y, const bool fewer_dims_optimization = true,
395 const bool return_flattened_batch_indices = false)
396 : BCastList<2>({x, y}, fewer_dims_optimization, return_flattened_batch_indices)
402 // If and only if IsValid(), the following fields can be used in
403 // implementing a broadcasted binary tensor operation according to
404 // the broadcasting rule.
405 const Vec &x_reshape() const { return reshape_[0]; }
406 const Vec &x_bcast() const { return bcast_[0]; }
407 const Vec &y_reshape() const { return reshape_[1]; }
408 const Vec &y_bcast() const { return bcast_[1]; }
409 const Vec &result_shape() const { return result_; }
410 const Vec &output_shape() const { return output_; }
411 const Vec &grad_x_reduce_idx() const { return grad_reduce_idx_[0]; }
412 const Vec &grad_y_reduce_idx() const { return grad_reduce_idx_[1]; }
414 // Returns the mapping from the flattened output batch indices to x's
415 // flattened batch indices. The result is a vector of length
416 // output_batch_size(). To compute the i'th batch output, a binary matmul-like
417 // operation should use the `x_batch_indices()[i]`th batch index of `x`.
418 // Note: Returns an empty vector if broadcasting is not required. Callers
419 // should only use this when IsBroadcastingRequired() returns true.
420 const std::vector<int32_t> &x_batch_indices() const { return batch_indices_[0]; }
421 // Returns the mapping from the flattened output batch indices to y's
422 // flattened batch indices. Similar to x_batch_indices().
423 // Note: Returns an empty vector if broadcasting is not required. Callers
424 // should only use this when IsBroadcastingRequired() returns true.
425 const std::vector<int32_t> &y_batch_indices() const { return batch_indices_[1]; }
427 template <typename IndexType, int NDIMS>
428 static Eigen::array<IndexType, NDIMS> ToIndexArrayType(const BCast::Vec &vec)
430 assert(vec.size() == NDIMS);
431 Eigen::array<IndexType, NDIMS> ret;
432 for (int i = 0; i < NDIMS; ++i)
438 static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(const BCast::Vec &vec)
440 return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
444 static Vec FromShape(const Shape &shape)
446 const int N = shape.DimensionsCount();
447 BCastList::Vec ret(N);
448 for (int i = 0; i < N; ++i)
450 ret[i] = shape.Dims(i);
455 static Shape ToShape(const BCastList::Vec &vec)
457 const int N = vec.size();
460 for (int i = 0; i < N; ++i)
462 shape.SetDim(i, vec[i]);
471 #endif // __NNFW_CKER_HELPER_BCAST_H__