a0abf29351b23ac3925296391c325c8ab1ce341a
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / Helper / BCast.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2015 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_HELPER_BCAST_H__
19 #define __NNFW_CKER_HELPER_BCAST_H__
20
21 /**
22  * ToDo : This file will be moved into upper folder when integrate with other
23  *        custom operations.
24  *        And It should merged with EinsumHelper's BCast.
25 **/
26
27 #include "cker/Shape.h"
28 #include "cker/eigen/EigenSupport.h"
29
30 namespace nnfw
31 {
32 namespace cker
33 {
34 // Returns the mapping from the output batch indices to the corresponding
35 // input's batch indices, given the input's "reshape" and "bcast" shapes as
36 // returned by the BCastList helper class. The i'th element denotes the
37 // (flattened) batch index of the input that must be used to compute the i'th
38 // batch output.
39 //
40 inline void ComputeBatchIndices(const int32_t output_batch_size,
41                                 const std::vector<int32_t> &reshape,
42                                 const std::vector<int32_t> &bcast,
43                                 std::vector<int32_t> *out_indices)
44 {
45   // Populates the mapping in out_indices. This algorithm is identical to
46   // the following steps:
47   //  - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
48   //  - Broadcast to the output shape.
49   //  - Reshape back to a flat 1D vector.
50   out_indices->resize(output_batch_size);
51   int32_t num_output_elements = 1;
52   int32_t num_input_elements = 1;
53   for (int32_t i = reshape.size() - 1; i >= 0; --i)
54   {
55     // Replicate the already populated mapping an additional (dim - 1) times.
56     // If we are broadcasting, just copy the existing mapping.
57     // Otherwise, add another dimension from the input shape.
58     const int32_t dim = std::max(reshape[i], bcast[i]);
59     const int32_t incr = bcast[i] > 1 ? 0 : num_input_elements;
60     for (int32_t k = 0; k < (dim - 1) * num_output_elements; ++k)
61     {
62       (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
63     }
64     num_output_elements *= dim;
65     num_input_elements *= reshape[i];
66   }
67 }
68
69 template <int N> class BCastList
70 {
71 public:
72   // A vector of int32_t representing the shape of tensor. The 0-th
73   // element is the outer-most dimension and the last element is the
74   // inner-most dimension. Note that we do not use Shape since
75   // it's more convenient to manipulate Vec directly for this module.
76   typedef std::vector<int32_t> Vec;
77
78   // Constructs all helper shapes, following the aforementioned rules.
79   //
80   // If "fewer_dims_optimization" is set to true (the default), the
81   // implementation tries to reduce intermediate dimensions needed to be more
82   // efficient.  This is transparent to the caller.
83   //
84   // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
85   // the same number of dimensions as the larger of the two inputs.
86   //
87   // If return_flattened_batch_indices is true, the implementation will compute
88   // for each output member of the flattened output, which batch indices of
89   // each input correspond to it. This is disabled by default.
90   explicit BCastList(const Vec (&x)[N], const bool fewer_dims_optimization = true,
91                      const bool return_flattened_batch_indices = false);
92   ~BCastList() {}
93
94   // Returns true iff two operands are compatible according to the
95   // broadcasting rule.
96   bool IsValid() const { return valid_; }
97   bool IsBroadcastingRequired() const { return broadcasting_required_; }
98
99   // If and only if IsValid(), the following fields can be used in
100   // implementing a broadcasted binary tensor operation according to
101   // the broadcasting rule.
102   const Vec &reshape(int i) const { return reshape_[i]; }
103   const Vec &bcast(int i) const { return bcast_[i]; }
104   const Vec &result_shape() const { return result_; }
105   const Vec &output_shape() const { return output_; }
106   const Vec &grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; }
107   int32_t output_batch_size() const { return output_batch_size_; }
108
109   // Returns the mapping from the flattened output batch indices to x's
110   // flattened batch indices. The result is a vector of length
111   // output_batch_size(). To compute the i'th batch output, a binary matmul-like
112   // operation should use the `x_batch_indices()[i]`th batch index of `x`.
113   // Note: Returns an empty vector if broadcasting is not required. Callers
114   // should only use this when IsBroadcastingRequired() returns true.
115   const std::vector<int32_t> &batch_indices(int i) const { return batch_indices_[i]; }
116
117 protected:
118   bool valid_ = true;
119   bool broadcasting_required_ = true;
120   Vec reshape_[N];
121   Vec bcast_[N];
122   Vec result_;
123   Vec output_;
124   Vec grad_reduce_idx_[N];
125
126   int32_t output_batch_size_;
127   std::vector<int32_t> batch_indices_[N];
128
129   static void Reverse(Vec *shape) { std::reverse(shape->begin(), shape->end()); }
130 }; //  BCastList<N>
131
132 template <int N>
133 BCastList<N>::BCastList(const BCastList::Vec (&x)[N], const bool fewer_dims_optimization,
134                         const bool return_flattened_batch_indices)
135 {
136   typedef BCastList::Vec Vec;
137   bool all_equal = true;
138   size_t largest_rank = 0;
139   output_batch_size_ = 1;
140   for (int i = 0; i < N; ++i)
141   {
142     if (x[i] != x[0])
143     {
144       all_equal = false;
145     }
146     if (x[i].size() > largest_rank)
147     {
148       largest_rank = x[i].size();
149     }
150   }
151   if (all_equal)
152   {
153     broadcasting_required_ = false;
154   }
155   if (all_equal && fewer_dims_optimization)
156   {
157     // Fast path for common case of identical shapes.
158     int32_t elements = 1;
159     const int rank = x[0].size();
160     output_.resize(rank);
161     for (int i = 0; i < rank; i++)
162     {
163       const int32_t dim = x[0][i];
164       elements *= dim;
165       output_[i] = dim;
166     }
167     result_.push_back(elements);
168     output_batch_size_ = elements;
169     for (int i = 0; i < N; ++i)
170     {
171       reshape_[i].push_back(elements);
172       bcast_[i].push_back(1);
173     }
174     // grad_reduce_ is left as empty
175     return;
176   }
177
178   // Reverse all the shapes for convenience
179   // After the reverse, 0-th is the inner-most dimension.
180   Vec copy[N];
181   for (int i = 0; i < N; ++i)
182   {
183     copy[i] = x[i];
184     Reverse(&copy[i]);
185   }
186
187   // 1-extend and align all vectors.
188   for (int i = 0; i < N; ++i)
189   {
190     if (copy[i].size() < largest_rank)
191     {
192       copy[i].resize(largest_rank, 1);
193     }
194   }
195   // Going through each dimension starting from the inner-most
196   // dimension, compares dimension of x and y. They are compatible if
197   // they are equal or either is 1.
198
199   // indices of j-th component of each input.
200   bool prev_is_one[N];
201   bool current_is_one[N];
202   for (int i = 0; i < N; ++i)
203   {
204     prev_is_one[i] = false;
205     current_is_one[i] = false;
206   }
207   Vec output;
208   bool output_dim_set = false;
209   int output_dim = -1;
210   bool none_is_one = true;
211   bool set_one = false;
212   for (size_t j = 0; j < largest_rank; ++j)
213   {
214     output_dim = -1;
215     output_dim_set = false;
216     none_is_one = true;
217     // Find which indices are 1.
218     for (int i = 0; i < N; ++i)
219     {
220       // Keep track of which indices are 1.
221       if (copy[i][j] == 1)
222       {
223         current_is_one[i] = true;
224         none_is_one = false;
225       }
226       else
227       {
228         current_is_one[i] = false;
229         if (!output_dim_set || copy[i][j] == output_dim)
230         {
231           output_dim = copy[i][j];
232           output_dim_set = true;
233         }
234         else
235         {
236           valid_ = false;
237           return;
238         }
239       }
240     }
241     output_.push_back(output_dim_set ? output_dim : 1);
242     output_batch_size_ *= output_.back();
243     // All dimensions are 1.
244     if (!output_dim_set)
245     {
246       if (!fewer_dims_optimization)
247       {
248         for (int i = 0; i < N; ++i)
249         {
250           bcast_[i].push_back(1);
251           reshape_[i].push_back(1);
252         }
253         result_.push_back(1);
254       }
255       for (int i = 0; i < N; ++i)
256       {
257         grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
258       }
259       // This will skip updating the previous state to the current one. We'll
260       // explain why this is safe below.
261       // Consider the previous state P, current state C and the next state N.
262       // In the case where N also is all ones (N == C), we'll do the same
263       // optimization here (push back one dimensions if we need to), which is
264       // safe and is expected.
265       //
266       // When N != C, we'll continue as usual. However, we might trigger the
267       // next block if N == P (because we didn't update the previous state).
268       // We trigger the next block if `fewer_dims_optimization` is true.
269       // This means that we did not modify and broadcast / reshapes in this
270       // block (we skipped updating, since the one dimensions can be ignored).
271       // In essence, we only need to check whether the previous non-one state is
272       // equal to the current non-one state.
273
274       continue;
275     }
276     else if ((fewer_dims_optimization) &&
277              std::equal(current_is_one, current_is_one + N, prev_is_one) && set_one)
278     {
279       // It is a run of the same broadcasting case as last time.
280       // We can reshape the input so that fewer dimensions
281       // are involved in the intermediate computation.
282       result_.back() *= output_dim;
283       for (int i = 0; i < N; ++i)
284       {
285         reshape_[i].back() *= copy[i][j];
286         bcast_[i].back() *= current_is_one[i] ? output_dim : 1;
287         if (current_is_one[i] && !none_is_one)
288         {
289           grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
290         }
291       }
292     }
293     else
294     {
295       result_.push_back(output_dim);
296       for (int i = 0; i < N; ++i)
297       {
298         reshape_[i].push_back(copy[i][j]);
299         bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
300         if (current_is_one[i] && !none_is_one)
301         {
302           grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
303         }
304       }
305     }
306     set_one = true;
307     for (int i = 0; i < N; ++i)
308     {
309       prev_is_one[i] = current_is_one[i];
310     }
311   }
312   if (result_.empty())
313   {
314     result_.push_back(1);
315     for (int i = 0; i < N; ++i)
316     {
317       reshape_[i].push_back(1);
318       bcast_[i].push_back(1);
319     }
320   }
321   // Do something about batches.
322   for (int i = 0; i < N; ++i)
323   {
324     Reverse(&reshape_[i]);
325     Reverse(&bcast_[i]);
326     Reverse(&grad_reduce_idx_[i]);
327   }
328   Reverse(&result_);
329   Reverse(&output_);
330   // Only compute batch indices when we need broadcasting, and we aren't doing
331   // needless work (when the output size is 0 or the
332   // return_flattened_batch_indices isn't enabled).
333   if (return_flattened_batch_indices && broadcasting_required_ && output_batch_size_ > 0)
334   {
335     for (int i = 0; i < N; ++i)
336     {
337       ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i], &batch_indices_[i]);
338     }
339   }
340 }
341
342 // BCast is a helper for broadcasting binary tensor operation.
343 // TensorFlow's broadcasting rule follows that of numpy (See
344 // http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
345 //
346 // The rule has the following properties:
347 //
348 //   1. suffix matching: the rule starts with the right-most
349 //      dimension, and works towards the left-most dimension. Since
350 //      TensorFlow is row-major, the right-most dimension (the last
351 //      element in the shape of a tensor) is the inner-most, a.k.a.
352 //      the fastest changing, dimension.
353 //
354 //   2. Two dimensions are compatible for broadcasting if both are the
355 //      same or either is 1.
356 //
357 // BCast takes the shape of two tensors and computes a few vectors of
358 // int32 that are useful for the caller to reshape the tensors, apply
359 // the right broadcasts to them, compute the broadcasted operation,
360 // and possibly the gradients. In a nutshell, the caller is expected
361 // to compute the broadcasted operation as following:
362 //
363 //   BCast b(x.shape(), y.shape());
364 //   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
365 //            _op_
366 //            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
367 //
368 // For the gradient computation,
369 //   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
370 //            .reshape(x.shape())
371 //   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
372 //            .reshape(y.shape())
373 // backprop_x and backprop_y are functionals of the binary function "op",
374 // e.g.,
375 //   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
376 //   for *, backprop_x(x, y) =  y, backprop_y(x, y) = x;
377 //   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
378 //
379 // The multiplication in the grad * backprop_x itself is also
380 // broadcasting following the same rule.
381 class BCast : public BCastList<2>
382 {
383 public:
384   // Constructs all helper shapes, following the aforementioned rules.
385   //
386   // If "fewer_dims_optimization" is set to true (the default), the
387   // implementation tries to reduce intermediate dimensions needed to be more
388   // efficient.  This is transparent to the caller.
389   //
390   // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
391   // the same number of dimensions as the larger of the two inputs.
392   typedef std::vector<int32_t> Vec;
393
394   BCast(const Vec &x, const Vec &y, const bool fewer_dims_optimization = true,
395         const bool return_flattened_batch_indices = false)
396       : BCastList<2>({x, y}, fewer_dims_optimization, return_flattened_batch_indices)
397   {
398   }
399
400   ~BCast() {}
401
402   // If and only if IsValid(), the following fields can be used in
403   // implementing a broadcasted binary tensor operation according to
404   // the broadcasting rule.
405   const Vec &x_reshape() const { return reshape_[0]; }
406   const Vec &x_bcast() const { return bcast_[0]; }
407   const Vec &y_reshape() const { return reshape_[1]; }
408   const Vec &y_bcast() const { return bcast_[1]; }
409   const Vec &result_shape() const { return result_; }
410   const Vec &output_shape() const { return output_; }
411   const Vec &grad_x_reduce_idx() const { return grad_reduce_idx_[0]; }
412   const Vec &grad_y_reduce_idx() const { return grad_reduce_idx_[1]; }
413
414   // Returns the mapping from the flattened output batch indices to x's
415   // flattened batch indices. The result is a vector of length
416   // output_batch_size(). To compute the i'th batch output, a binary matmul-like
417   // operation should use the `x_batch_indices()[i]`th batch index of `x`.
418   // Note: Returns an empty vector if broadcasting is not required. Callers
419   // should only use this when IsBroadcastingRequired() returns true.
420   const std::vector<int32_t> &x_batch_indices() const { return batch_indices_[0]; }
421   // Returns the mapping from the flattened output batch indices to y's
422   // flattened batch indices. Similar to x_batch_indices().
423   // Note: Returns an empty vector if broadcasting is not required. Callers
424   // should only use this when IsBroadcastingRequired() returns true.
425   const std::vector<int32_t> &y_batch_indices() const { return batch_indices_[1]; }
426
427   template <typename IndexType, int NDIMS>
428   static Eigen::array<IndexType, NDIMS> ToIndexArrayType(const BCast::Vec &vec)
429   {
430     assert(vec.size() == NDIMS);
431     Eigen::array<IndexType, NDIMS> ret;
432     for (int i = 0; i < NDIMS; ++i)
433       ret[i] = vec[i];
434     return ret;
435   }
436
437   template <int NDIMS>
438   static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(const BCast::Vec &vec)
439   {
440     return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
441   }
442
443   // Static helpers.
444   static Vec FromShape(const Shape &shape)
445   {
446     const int N = shape.DimensionsCount();
447     BCastList::Vec ret(N);
448     for (int i = 0; i < N; ++i)
449     {
450       ret[i] = shape.Dims(i);
451     }
452     return ret;
453   }
454
455   static Shape ToShape(const BCastList::Vec &vec)
456   {
457     const int N = vec.size();
458     Shape shape(N);
459
460     for (int i = 0; i < N; ++i)
461     {
462       shape.SetDim(i, vec[i]);
463     }
464     return shape;
465   }
466
467 }; // BCast
468 } // namespace cker
469 } // namespace nnfw
470
471 #endif // __NNFW_CKER_HELPER_BCAST_H__