Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / operation / Einsum.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_EINSUM_H__
19 #define __NNFW_CKER_EINSUM_H__
20
21 #include "cker/Types.h"
22 #include "cker/Shape.h"
23 #include "cker/Utils.h"
24
25 #include "cker/operation/Helper/Tensor.h"
26 #include "cker/operation/Helper/MatmulBCast.h"
27
28 #include "Transpose.h"
29 #include "BatchMatMul.h"
30
31 #include <string>
32 #include <vector>
33 #include <map>
34 #include <numeric>
35 #include <algorithm>
36
37 namespace nnfw
38 {
39 namespace cker
40 {
41
42 namespace functor
43 {
44
45 template <typename Device, typename T, int N> struct StrideFunctor
46 {
47   void operator()(const Device &d, typename TTypes<T, N>::ConstTensor input,
48                   const std::vector<int32_t> &strides, typename TTypes<T, N>::Tensor output)
49   {
50
51     Eigen::DSizes<Eigen::DenseIndex, N> dsizes;
52     for (size_t d = 0; d < strides.size(); d++)
53     {
54       dsizes[d] = static_cast<Eigen::DenseIndex>(strides[d]);
55     }
56     for (size_t d = strides.size(); d < N; d++)
57     {
58       dsizes[d] = 1;
59     }
60
61     output.device(d) = input.stride(dsizes);
62   }
63 };
64
65 template <typename Device, typename T, int N> struct InflateFunctor
66 {
67   void operator()(const Device &d, typename TTypes<T, N>::ConstTensor input,
68                   const std::vector<int32_t> &strides, typename TTypes<T, N>::Tensor output)
69   {
70
71     Eigen::DSizes<Eigen::DenseIndex, N> dsizes;
72     for (size_t d = 0; d < strides.size(); d++)
73     {
74       dsizes[d] = static_cast<Eigen::DenseIndex>(strides[d]);
75     }
76     for (size_t d = strides.size(); d < N; d++)
77     {
78       dsizes[d] = 1;
79     }
80
81     output.device(d) = input.inflate(dsizes);
82   }
83 };
84
85 template <typename Device, typename Reducer> struct ReduceFunctor
86 {
87   template <typename OUT_T, typename IN_T, typename ReductionAxes>
88   static void Reduce(const Device &d, OUT_T out, IN_T in, const ReductionAxes &reduction_axes,
89                      const Reducer &reducer)
90   {
91     out.device(d) = in.reduce(reduction_axes, reducer);
92   }
93 };
94
95 template <typename Device, typename T> struct SetZeroFunctor
96 {
97   // Computes on device "d": out = out.setZero(),
98   void operator()(const Device &d, typename TTypes<T>::Flat out)
99   {
100     out.device(d) = out.constant(T(0));
101   }
102 };
103
104 } // namespace functor
105
106 using ShapeVec = std::vector<int32_t>;
107 using Labels = std::vector<int32_t>;
108 using OperandLabels = std::vector<Labels>;
109 using LabelCounts = std::vector<int32_t>;
110 using OperandLabelCounts = std::vector<LabelCounts>;
111 using LabelToDimSizes = std::vector<int32_t>;
112
113 // Each dimension is categorized into exactly one of five types based on
114 // whether its corresponding label is present in the input and/or the output
115 // subscripts.
116 enum DimensionType
117 {
118   // Batch dimensions are those present in two inputs as well as the output.
119   // They are part of the batch dimensions during Tensor contraction.
120   // Such dimensions may be broadcasting dimensions (those mapping to
121   // ellipsis)
122   // or explicit batch dimensions corresponding to named axis labels.
123   kBroadcasting = 0,
124   kBatch = 1,
125   // Free dimensions are present in exactly one of the inputs, and also the
126   // output. These are non-contracted axes in the Tensor contraction.
127   kFree = 2,
128   // Contract dimensions are present in two inputs, but not the output. These
129   // dimensions are contracted in Tensor contraction.
130   kContract = 3,
131   // Reduce dimensions are present in exactly one input; and not in the output
132   // and are summed over prior to Tensor contraction.
133   kReduce = 4,
134 };
135
136 namespace
137 {
138
139 constexpr int kEllipsisLabel = -1;
140
141 std::vector<std::string> strSplit(const std::string &text, const std::string delimiter)
142 {
143   std::vector<std::string> result;
144
145   size_t start = 0;
146   size_t pos = 0;
147
148   do
149   {
150     pos = text.find(delimiter, start);
151     if (pos == std::string::npos)
152     {
153       result.push_back(text.substr(start, text.size() - start));
154       break;
155     }
156
157     result.push_back(text.substr(start, pos - start));
158     start = pos + delimiter.size();
159   } while (pos != std::string::npos);
160
161   return result;
162 }
163
164 inline DimensionType getDimensionType(bool is_removed, bool is_unique)
165 {
166   if (!is_removed && !is_unique)
167     return kBatch;
168   else if (!is_removed && is_unique)
169     return kFree;
170   else if (is_removed && !is_unique)
171     return kContract;
172   else // is_removed && is_unique
173     return kReduce;
174 }
175
176 inline Shape copyShape(const Shape &shape)
177 {
178   return Shape::ExtendedShape(shape.DimensionsCount(), shape);
179 }
180 }
181
182 class Einsum
183 {
184 public:
185   Einsum() : _prepared(false)
186   {
187     // DO NOTHING
188   }
189
190   void prepare(std::string &equation)
191   {
192     if (_prepared)
193     {
194       return;
195     }
196
197     // Parse equation
198     parseEquation(equation);
199     _prepared = true;
200   }
201
202   void operator()(std::string &equation, const std::vector<Shape> &input_shapes,
203                   const std::vector<const float *> &input_data, const Shape &output_shape,
204                   float *output_data)
205   {
206     if (!_prepared)
207     {
208       prepare(equation);
209     }
210
211     const int num_inputs = input_shapes.size();
212     std::vector<InputTensor<float>> inputs(num_inputs);
213     for (int i = 0; i < num_inputs; i++)
214     {
215       inputs[i].shape.ReplaceWith(input_shapes[i].DimensionsCount(), input_shapes[i].DimsData());
216       inputs[i].buffer = input_data[i];
217     }
218
219     OperandLabels input_labels(_input_labels);
220     Labels output_labels(_output_labels);
221     std::vector<DimensionType> label_types(_label_types);
222     OperandLabelCounts input_label_counts(_input_label_counts);
223     LabelCounts output_label_counts(_output_label_counts);
224     LabelToDimSizes label_to_dim_sizes;
225
226     processDimensions(inputs, &input_labels, &output_labels, &label_types, &input_label_counts,
227                       &output_label_counts, &label_to_dim_sizes);
228
229     // The reduction phase (a) sums across reduction dimensions, (b) takes
230     // generalized diagonals, and (c) reshapes it into shape
231     //   [(broadcasting) batch shape] + [F,C]
232     // where F and C denote the total (compacted) size of free and contract
233     // dimensions, respectively.
234
235     OperandLabels free_labels(num_inputs);
236     std::vector<Tensor> inputs_reduced(num_inputs);
237     std::vector<bool> swap_free_and_contract(num_inputs);
238     for (int i = 0; i < num_inputs; ++i)
239     {
240       bool temp_swap_free_and_contract = false;
241       reduceOperand<float>(inputs[i], label_types, input_label_counts[i], &input_labels[i],
242                            &free_labels[i], &temp_swap_free_and_contract, &inputs_reduced[i]);
243       swap_free_and_contract[i] = temp_swap_free_and_contract;
244     }
245
246     // After reduction, the inputs should be reshaped to Tensors suitable for
247     // contraction. If num_inputs is 1, the reduced input is simply forwarded to
248     // the output.
249     Tensor contraction_output_reshaped;
250     contractOperands(inputs_reduced, swap_free_and_contract, &contraction_output_reshaped);
251
252     // Copy the batch labels from the contraction output. Recover the batch
253     // shape, which may have been broadcasted.
254     std::vector<int32_t> result_shape_dims(contraction_output_reshaped.shape.DimensionsCount() - 2);
255
256     for (size_t i = 0; i < result_shape_dims.size(); i++)
257     {
258       result_shape_dims[i] = contraction_output_reshaped.shape.Dims(i);
259     }
260
261     int num_labels = label_types.size();
262     Labels result_labels;
263     // All batch dimensions should be present in the contracted result. First
264     // the broadcasting dimensions, then the named batch dimensions.
265     for (int label = 0; label < num_labels; ++label)
266     {
267       if (label_types[label] == kBroadcasting)
268         result_labels.push_back(label);
269     }
270     for (int label = 0; label < num_labels; ++label)
271     {
272       if (label_types[label] == kBatch)
273         result_labels.push_back(label);
274     }
275     for (int i = 0; i < num_inputs; ++i)
276     {
277       for (int label : free_labels[i])
278       {
279         result_labels.push_back(label);
280         result_shape_dims.push_back(label_to_dim_sizes[label]);
281       }
282     }
283
284     Shape result_shape(result_shape_dims.size(), result_shape_dims.data());
285
286     // Reshape the contraction (or reduction) result to its expanded shape:
287     // [(broadcasted) batch shape] + [free shape 0] + [free shape 1].
288     Tensor contraction_output;
289     copyFrom(contraction_output_reshaped, result_shape, &contraction_output);
290
291     // Inflate the output if necessary. (E.g. for the equation 'i->iii' which
292     // may arise while computing gradient of a regular Einsum).
293     // TODO(anudhyan): It's possible that Eigen's contract and inflate can be
294     // chained here to avoid materializing an intermediate.
295     Tensor output_inflated;
296     strideOrInflate<float>(contraction_output, result_labels, output_label_counts,
297                            true /* should_inflate */, &output_inflated);
298
299     if (output_inflated.shape.DimensionsCount() > contraction_output.shape.DimensionsCount())
300     {
301       // We inflated the output. Modify result labels accordingly.
302       Labels inflated_labels;
303       for (int label : result_labels)
304       {
305         inflated_labels.insert(inflated_labels.end(), output_label_counts[label], label);
306       }
307       result_labels.swap(inflated_labels);
308     }
309
310     // Find the permutation to map the result labels to the output labels. Note
311     // that both the result and the final output may have the repeated labels,
312     // in which case the permutation preserves the left-to-right ordering.
313     // E.g. if result labels are [0, 0, 1] and output is [0, l, 0] then the
314     // permutation should be [0, 2, 1]. We also use the fact that repeated
315     // labels in the result are adjacent to each other.
316     std::vector<int32_t> output_permutation(output_labels.size());
317     std::vector<int32_t> label_to_position(num_labels, -1);
318     for (size_t i = 0; i < result_labels.size(); ++i)
319     {
320       // Remember the position of only the leftmost result label.
321       if (label_to_position[result_labels[i]] == -1)
322       {
323         label_to_position[result_labels[i]] = i;
324       }
325     }
326     for (size_t i = 0; i < output_labels.size(); ++i)
327     {
328       output_permutation[i] = label_to_position[output_labels[i]];
329       // We have found the leftmost occurrence. The next one would be adjacent.
330       label_to_position[output_labels[i]] += 1;
331     }
332
333     InputTensor<float> temp_inflated;
334     temp_inflated.shape.ReplaceWith(output_inflated.shape.DimensionsCount(),
335                                     output_inflated.shape.DimsData());
336     temp_inflated.buffer = (reinterpret_cast<const float *>(output_inflated.buffer));
337     ;
338
339     Tensor output;
340     transposeOperand<float>(temp_inflated, output_permutation, &output);
341
342     memcpy(output_data, output.buffer, output_shape.FlatSize() * sizeof(float));
343
344     temp_operand.clear();
345   }
346
347 private:
348   void parseEquation(std::string &equation)
349   {
350     std::vector<std::string> input_str;
351     std::string output_str;
352
353     parseEinsumEquation(equation, input_str, output_str);
354
355     // Temporary map from single character labels to (consecutive) integer
356     // labels.
357     std::map<char, int> label_mapping;
358     int num_inputs = input_str.size();
359     _input_labels.resize(num_inputs);
360
361     // Map from single characters to integer labels.
362     for (int i = 0; i < num_inputs; ++i)
363     {
364       mapToLabels(input_str[i], _input_labels.at(i), label_mapping);
365     }
366     mapToLabels(output_str, _output_labels, label_mapping);
367
368     // Compute counts for input and output labels.
369     int num_labels = label_mapping.size();
370     _input_label_counts.resize(num_inputs);
371     _input_has_ellipsis.resize(num_inputs);
372     for (int i = 0; i < num_inputs; ++i)
373     {
374       _input_label_counts.at(i).resize(num_labels);
375       for (const int label : _input_labels.at(i))
376       {
377         if (label != kEllipsisLabel)
378           _input_label_counts.at(i)[label] += 1;
379         else
380           _input_has_ellipsis.at(i) = true;
381       }
382     }
383     _output_label_counts.resize(num_labels);
384     for (const int label : _output_labels)
385     {
386       if (label != kEllipsisLabel)
387         _output_label_counts.at(label) += 1;
388       else
389         _output_has_ellipsis = true;
390     }
391
392     // Map each label to a unique DimensionType.
393     _label_types.resize(num_labels);
394     for (int label = 0; label < num_labels; ++label)
395     {
396       bool removed = (_output_label_counts[label] == 0);
397       bool unique =
398         num_inputs == 1 || _input_label_counts[0][label] == 0 || _input_label_counts[1][label] == 0;
399       _label_types[label] = getDimensionType(removed, unique);
400     }
401   }
402
403   void parseEinsumEquation(const std::string &equation, std::vector<std::string> &input_subscripts,
404                            std::string &output_subscript)
405   {
406     std::vector<std::string> inputs_and_output_subscripts = strSplit(equation, "->");
407     if (inputs_and_output_subscripts.size() != 2)
408     {
409       throw std::runtime_error{"Einsum: Expecting exactly one '->' in einsum equation: " +
410                                equation};
411     }
412
413     output_subscript = inputs_and_output_subscripts[1];
414     input_subscripts = strSplit(inputs_and_output_subscripts[0], ",");
415     if (input_subscripts.size() != 1 && input_subscripts.size() != 2)
416     {
417       throw std::runtime_error{"Einsum: Expecting 1 or 2 input subscripts in equation '" +
418                                equation + "' but got: " + std::to_string(input_subscripts.size())};
419     }
420   }
421
422   // Maps the character labels to consecutive integers.
423   void mapToLabels(const std::string &subscript, Labels &labels, std::map<char, int> &label_mapping)
424   {
425     for (size_t i = 0; i < subscript.size(); ++i)
426     {
427       const char label_char = subscript[i];
428       if (label_char == '.')
429       {
430         labels.push_back(kEllipsisLabel);
431         i += 2; // Skip next 2 characters as well.
432         continue;
433       }
434       if (label_mapping.find(label_char) == label_mapping.end())
435       {
436         const int next_label = label_mapping.size();
437         label_mapping[label_char] = next_label;
438       }
439       const int mapped_label = label_mapping[label_char];
440       labels.push_back(mapped_label);
441     }
442   }
443
444   template <typename T>
445   void processDimensions(const std::vector<InputTensor<T>> &inputs, OperandLabels *input_labels,
446                          Labels *output_labels, std::vector<DimensionType> *label_types,
447                          OperandLabelCounts *input_label_counts, LabelCounts *output_label_counts,
448                          LabelToDimSizes *label_to_dim_sizes)
449   {
450     if (inputs.size() != input_labels->size())
451     {
452       throw std::runtime_error{"Expected " + std::to_string(input_labels->size()) +
453                                " inputs but got: " + std::to_string(inputs.size())};
454     }
455     const int num_inputs = inputs.size();
456
457     // We infer the number of broadcasting dimensions by taking the maximum rank
458     // among the broadcasting subshapes of the input.
459     int max_bcast_dims = 0;
460     const int num_named_labels = label_types->size();
461     label_to_dim_sizes->resize(num_named_labels);
462     for (int i = 0; i < num_inputs; ++i)
463     {
464       Labels *labels = &(*input_labels)[i];
465
466       if (!_input_has_ellipsis[i])
467       {
468         if (inputs[i].shape.DimensionsCount() != ((int32_t)labels->size()))
469         {
470           throw std::runtime_error{"Expected input " + std::to_string(i) + " to have rank " +
471                                    std::to_string(labels->size()) + " but got: " +
472                                    std::to_string(inputs[i].shape.DimensionsCount())};
473         }
474         for (size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
475         {
476           const int label = (*labels)[label_idx];
477           recordLabelToDimension(label, label_idx, inputs[i].shape, label_to_dim_sizes);
478         }
479         continue;
480       }
481
482       // Input has an ellipsis.
483       if (inputs[i].shape.DimensionsCount() + 1 < (int32_t)labels->size())
484       {
485         throw std::runtime_error{"Expected input " + std::to_string(i) + " to have rank at least " +
486                                  std::to_string(labels->size() - 1) +
487                                  " but got: " + std::to_string(inputs[i].shape.DimensionsCount())};
488       }
489       int ellipsis_axis = -1;
490       const int num_bcast_dims = inputs[i].shape.DimensionsCount() - labels->size() + 1;
491       for (size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
492       {
493         const int label = (*labels)[label_idx];
494         if (label == kEllipsisLabel)
495         {
496           ellipsis_axis = label_idx;
497           continue;
498         }
499         // Current label is not an ellipsis.
500         const int axis = label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
501         recordLabelToDimension(label, axis, inputs[i].shape, label_to_dim_sizes);
502       }
503       // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting
504       // dimensions.
505       if (ellipsis_axis != -1)
506       {
507         insertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, labels,
508                               &input_label_counts->at(i));
509         max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
510       }
511     }
512
513     std::vector<bool>::iterator it_input =
514       std::find(_input_has_ellipsis.begin(), _input_has_ellipsis.end(), true);
515     if (it_input == _input_has_ellipsis.end() && !_output_has_ellipsis)
516     {
517       return;
518     }
519     // Insert broadcasting dimensions in the output labels.
520     auto it = std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
521     if (it != output_labels->end())
522     {
523       const int ellipsis_axis = it - output_labels->begin();
524       insertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, output_labels,
525                             output_label_counts);
526     }
527     else if (max_bcast_dims > 0)
528     {
529       std::runtime_error{"Output contains " + std::to_string(max_bcast_dims) +
530                          " broadcasting dimension(s) but no ellipsis " +
531                          "(...) was found in the output subscripts."};
532     }
533     // Populate DimensionType for the new broadcasting labels.
534     label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting);
535   }
536
537   void recordLabelToDimension(const int32_t label, const int axis, const Shape &input_shape,
538                               LabelToDimSizes *label_to_dim_sizes)
539   {
540     const int32_t input_dim = input_shape.Dims(axis);
541     // We know that label_to_dim_sizes has the size to accommodate named labels.
542     if (label_to_dim_sizes->at(label) != 0 && label_to_dim_sizes->at(label) != input_dim)
543     {
544       std::runtime_error{"Expected dimension " + std::to_string(label_to_dim_sizes->at(label)) +
545                          " at axis " + std::to_string(axis) +
546                          " of the input shaped but got dimension " + std::to_string(input_dim)};
547     }
548     (*label_to_dim_sizes)[label] = input_dim;
549   }
550
551   void insertBroadcastLabels(int num_bcast_dims, int num_named_labels, int ellipsis_axis,
552                              Labels *labels, LabelCounts *label_counts)
553   {
554     labels->erase(labels->begin() + ellipsis_axis);
555     labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
556     std::iota(labels->begin() + ellipsis_axis, labels->begin() + ellipsis_axis + num_bcast_dims,
557               num_named_labels);
558     // Increment label counts. Since these are new labels, the count is set
559     // to 1.
560     label_counts->resize(num_named_labels + num_bcast_dims, 1);
561   }
562
563   template <typename T>
564   void reduceOperand(const InputTensor<T> &input, const std::vector<DimensionType> &label_types,
565                      const LabelCounts &label_counts, Labels *labels, Labels *free_labels,
566                      bool *swap_free_and_contract, Tensor *output)
567   {
568     // Find the permutation to transpose the input dimensions in the order of
569     // DimensionType; i.e. batch, free, contract and reduce dimensions. This
570     // makes it more convenient to invoke Reduce/Contract operations.
571     std::vector<int32_t> permutation(input.shape.DimensionsCount());
572     std::iota(permutation.begin(), permutation.end(), 0);
573     Tensor input_transposed;
574
575     // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y)
576     // flag during BatchMatMul. This is an extra optimization not necessary for
577     // correctness.
578     if (shouldSwapFreeAndContract(*labels, label_types))
579     {
580       *swap_free_and_contract = true;
581     }
582     else
583     {
584       std::sort(permutation.begin(), permutation.end(), [&](int i, int j) {
585         int label_i = (*labels)[i];
586         int label_j = (*labels)[j];
587         return std::tie(label_types[label_i], label_i) < std::tie(label_types[label_j], label_j);
588       });
589     }
590     // Transpose the input so that DimensionTypes are in order.
591     transposeOperand<T>(input, permutation, &input_transposed);
592
593     permuteLabels(permutation, labels);
594
595     // Take the generalized diagonal for dimensions with repeated axis labels.
596     Tensor input_deduped;
597     labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
598     strideOrInflate<T>(input_transposed, *labels, label_counts, false /* should_inflate */,
599                        &input_deduped);
600
601     // Reshape denotes the rank-5 shape [broadcast, batch, free, contract,
602     // reduce] where we've compacted the dimensions of each DimensionType.
603     std::vector<int32_t> reshape(5, 1);
604
605     // The output shape is [batch shape] + [free size, contract size]
606     // That is, the batch shape is preserved (for broadcasting while
607     // contracting) while the free dims and contract dims are compressed to one
608     // dimension each.
609     Shape output_shape;
610     std::vector<int32_t> output_shape_dims;
611     for (size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
612     {
613       const int label = labels->at(label_idx);
614       int32_t dim = input_deduped.shape.Dims(label_idx);
615       if (label_types[label] == kBroadcasting || label_types[label] == kBatch)
616       {
617         output_shape_dims.push_back(dim);
618       }
619       else if (label_types[label] == kFree)
620       {
621         free_labels->push_back(label);
622       }
623       reshape[label_types[label]] *= dim;
624     }
625
626     if (*swap_free_and_contract)
627       std::swap(reshape[kFree], reshape[kContract]);
628
629     output_shape_dims.push_back(reshape[kFree]);
630     output_shape_dims.push_back(reshape[kContract]);
631
632     output_shape.ReplaceWith(output_shape_dims.size(), output_shape_dims.data());
633
634     if (reshape[kReduce] == 1)
635     { // No need to actually reduce.
636       return copyFrom(input_deduped, output_shape, output);
637     }
638
639     allocateTemp(output_shape, output);
640
641     using Reducer = Eigen::internal::SumReducer<T>;
642     using Index = typename TTypes<T>::Tensor::Index;
643
644     const Eigen::ThreadPoolDevice &device = *eigen_support::GetThreadPoolDevice();
645
646     // Reduce along the last axis (i.e axis 1) of the rank-2 Tensor.
647     const int32_t output_size =
648       reshape[kBroadcasting] * reshape[kBatch] * reshape[kFree] * reshape[kContract];
649     functor::ReduceFunctor<Eigen::ThreadPoolDevice, Reducer>::Reduce(
650       device, output->shaped<T, 1>({output_size}),
651       input_deduped.shaped<T, 2>({output_size, reshape[kReduce]}), Eigen::array<Index, 1>({1}),
652       Reducer());
653   }
654
655   bool shouldSwapFreeAndContract(const Labels &labels,
656                                  const std::vector<DimensionType> &label_types)
657   {
658     // Check that ordering is according to dimension type, with the role of
659     // free and contract dimensions swapped.
660     std::vector<int> remap = {0, 1, 3, 2, 4};
661     for (size_t i = 0; i + 1 < labels.size(); ++i)
662     {
663       const int dimtype_a = remap[label_types[labels[i]]];
664       const int dimtype_b = remap[label_types[labels[i + 1]]];
665       if (dimtype_a > dimtype_b || (dimtype_a == dimtype_b && labels[i] > labels[i + 1]))
666       {
667         return false;
668       }
669     }
670     return true;
671   }
672
673   template <typename T>
674   void transposeOperand(const InputTensor<T> &input, const std::vector<int32_t> &permutation,
675                         Tensor *output)
676   {
677     if (!shouldTranspose(input.shape, permutation))
678     {
679       copyFrom(input, input.shape, output);
680       return;
681     }
682     Shape transposed_shape(input.shape.DimensionsCount());
683     for (int i = 0; i < input.shape.DimensionsCount(); ++i)
684     {
685       transposed_shape.SetDim(i, input.shape.Dims(permutation[i]));
686     }
687     // For empty Tensors, just change the shape. E.g. we may need to transpose
688     // from shape [1, 0, 5] to [5, 1, 0].
689     if (input.shape.FlatSize() == 0)
690     {
691       copyFrom(input, transposed_shape, output);
692       return;
693     }
694
695     temp_operand.emplace_back(std::make_unique<T[]>(transposed_shape.FlatSize()));
696     T *new_buffer = temp_operand.back().get();
697
698     TransposeParams transpose_params;
699     transpose_params.perm_count = permutation.size();
700     for (size_t i = 0; i < permutation.size(); i++)
701     {
702       transpose_params.perm[i] = permutation[i];
703     }
704
705     Transpose<T>(transpose_params, input.shape, input.buffer, transposed_shape, new_buffer);
706
707     output->shape.ReplaceWith(transposed_shape.DimensionsCount(), transposed_shape.DimsData());
708     output->buffer = new_buffer;
709   }
710
711   bool shouldTranspose(const Shape &input_shape, const std::vector<int32_t> &permutation)
712   {
713     if (input_shape.DimensionsCount() < 2)
714       return false;
715     for (size_t i = 0; i < permutation.size(); ++i)
716     {
717       if (permutation[i] != (int32_t)i)
718         return true;
719     }
720     return false;
721   }
722
723   template <typename T>
724   void copyFrom(const InputTensor<T> &input, const Shape &shape, Tensor *output)
725   {
726     Tensor temp_tensor;
727     temp_tensor.shape.ReplaceWith(input.shape.DimensionsCount(), input.shape.DimsData());
728     temp_operand.emplace_back(std::make_unique<float[]>(input.shape.FlatSize()));
729     temp_tensor.buffer = temp_operand.back().get();
730     memcpy(temp_tensor.buffer, input.buffer, input.shape.FlatSize() * sizeof(float));
731
732     copyFrom(temp_tensor, shape, output);
733   }
734
735   void copyFrom(const Tensor &input, const Shape &shape, Tensor *output)
736   {
737     if (output->copyFrom(input, shape))
738       return;
739
740     throw std::runtime_error{"Einsum: Encountered error while reshaping a Tensor"};
741   }
742
743   // Permutes the labels according to the given permutation.
744   void permuteLabels(const std::vector<int32_t> &permutation, Labels *labels)
745   {
746     Labels permuted_labels(labels->size());
747     for (size_t i = 0; i < labels->size(); ++i)
748     {
749       permuted_labels[i] = (*labels)[permutation[i]];
750     }
751     labels->swap(permuted_labels);
752   }
753
754   // If there are repeated labels in either the input or output, then this
755   // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively.
756   template <typename T>
757   void strideOrInflate(const Tensor &input, const Labels &labels, const LabelCounts &label_counts,
758                        const bool should_inflate, Tensor *output)
759   {
760     // Return early if there are no repeated indices.
761     if (std::all_of(label_counts.begin(), label_counts.end(), [](int c) { return c <= 1; }))
762     {
763       return copyFrom(input, input.shape, output);
764     }
765     // We reshape so that each repeated label is compressed to one dimension.
766     // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27,
767     // 5]. Striding appropriately (in this case with strides 14 (=1+3+9) and 1)
768     // recovers the generalized diagonal of shape [3, 5].
769     std::vector<int32_t> reshape;
770     std::vector<int32_t> strides;
771     // Strided and inflated shapes correspond to input and output shapes,
772     // respectively, should_inflate is true (vice-versa if should_inflate is
773     // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example.
774     Shape strided_shape;
775     Shape inflated_shape;
776     std::vector<int32_t> strided_shape_dims;
777     std::vector<int32_t> inflated_shape_dims;
778     for (int label : labels)
779     {
780       const int32_t count = label_counts[label];
781       const int current_axis =
782         should_inflate ? strided_shape_dims.size() : inflated_shape_dims.size();
783       const int32_t dim = input.shape.Dims(current_axis);
784       strided_shape_dims.push_back(dim);
785       inflated_shape_dims.insert(inflated_shape_dims.end(), count, dim);
786       const int32_t reshape_dim = std::pow(dim, count);
787       reshape.push_back(reshape_dim);
788       // While taking the d-diagonal in a rank k Tensor, we take d
789       // equally-spaced elements including the first and last element. Then, (k
790       // - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1).
791       const int32_t stride = (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
792       strides.push_back(stride);
793     }
794
795     strided_shape.ReplaceWith(strided_shape_dims.size(), strided_shape_dims.data());
796     inflated_shape.ReplaceWith(inflated_shape_dims.size(), inflated_shape_dims.data());
797
798     Shape output_shape = Shape(should_inflate ? inflated_shape : strided_shape);
799
800     output->shape.ReplaceWith(output_shape.DimensionsCount(), output_shape.DimsData());
801     temp_operand.emplace_back(std::make_unique<float[]>(output_shape.FlatSize()));
802     output->buffer = temp_operand.back().get();
803
804     const Eigen::ThreadPoolDevice &device = *eigen_support::GetThreadPoolDevice();
805
806     switch (reshape.size())
807     {
808 #define NDIMS_CASE(N)                                                                      \
809   case N:                                                                                  \
810   {                                                                                        \
811     if (should_inflate)                                                                    \
812     {                                                                                      \
813       auto output_map = output->shaped<T, N>(reshape);                                     \
814       auto input_map = input.shaped<T, N>(strided_shape_dims);                             \
815       functor::InflateFunctor<Eigen::ThreadPoolDevice, T, N>()(device, input_map, strides, \
816                                                                output_map);                \
817     }                                                                                      \
818     else                                                                                   \
819     {                                                                                      \
820       auto input_map = input.shaped<T, N>(reshape);                                        \
821       auto output_map = output->shaped<T, N>(strided_shape_dims);                          \
822       functor::StrideFunctor<Eigen::ThreadPoolDevice, T, N>()(device, input_map, strides,  \
823                                                               output_map);                 \
824     }                                                                                      \
825   }                                                                                        \
826   break;
827       NDIMS_CASE(1);
828       NDIMS_CASE(2);
829       NDIMS_CASE(3);
830       NDIMS_CASE(4);
831       NDIMS_CASE(5);
832       NDIMS_CASE(6);
833       default:
834         throw std::runtime_error{"Unsupported rank: " + std::to_string(reshape.size()) +
835                                  " while handling repeated indices. Up to rank 6 is supported."};
836 #undef NDIMS_CASE
837     }
838   }
839
840   void allocateTemp(const Shape &shape, Tensor *output)
841   {
842     output->shape.ReplaceWith(shape.DimensionsCount(), shape.DimsData());
843     temp_operand.emplace_back(std::make_unique<float[]>(shape.FlatSize()));
844     output->buffer = temp_operand.back().get();
845   }
846
847   // Contracts the inputs along the last axis. (or the second last if the
848   // corresponding value of swap_free_and_contract is true). The batch
849   // dimensions are broadcast to the output shape.
850   // TODO(anudhyan): Factor this function into a BatchMatMul functor and support
851   // transpose_x and transpose_y attributes (in addition to adj_x and adj_y).
852   // Also, the BatchMatMul might devolve into a component-wise multiplication
853   // when the matrix shape is [1,1]; in this case BatchMatMul functor would be
854   // very inefficient. The functor should detect if this is the case and perform
855   // componentwise multiplication functor instead.
856   void contractOperands(std::vector<Tensor> &inputs, std::vector<bool> &swap_free_and_contract,
857                         Tensor *output)
858   {
859     if (inputs.size() == 1)
860       return copyFrom(inputs[0], inputs[0].shape, output);
861
862     MatMulBCast bcast(inputs[0].shape, inputs[1].shape);
863     if (!bcast.IsValid())
864     {
865       throw std::runtime_error{"Einsum: Invalid broadcasting dimensions"};
866     }
867
868     Tensor lhs;
869     reshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs);
870     Tensor rhs;
871     reshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs);
872     Shape old_output_shape = bcast.output_batch_shape();
873     Shape output_shape(old_output_shape.DimensionsCount() + inputs.size());
874     for (int i = 0; i < old_output_shape.DimensionsCount(); i++)
875     {
876       output_shape.SetDim(i, old_output_shape.Dims(i));
877     }
878
879     for (size_t i = 0; i < inputs.size(); ++i)
880     {
881       const int32_t free_axis =
882         inputs[i].shape.DimensionsCount() - (swap_free_and_contract[i] ? 1 : 2);
883       output_shape.SetDim(i + old_output_shape.DimensionsCount(), inputs[i].shape.Dims(free_axis));
884     }
885     bool adj_x = swap_free_and_contract[0];
886     bool adj_y = !swap_free_and_contract[1];
887
888     allocateTemp(output_shape, output);
889
890     const Eigen::ThreadPoolDevice &device = *eigen_support::GetThreadPoolDevice();
891
892     if (lhs.shape.FlatSize() == 0 || rhs.shape.FlatSize() == 0)
893     {
894       functor::SetZeroFunctor<Eigen::ThreadPoolDevice, float> set_zero;
895       set_zero(device,
896                typename TTypes<float, 1>::Tensor(output->base<float>(), output->shape.FlatSize()));
897       return;
898     }
899
900     Tensor output_reshaped;
901     reshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped);
902
903     // LaunchBatchMatMul::Launch(lhs, rhs, adj_x, adj_y, bcast, &output_reshaped);
904     BatchMatMul batchMatMul;
905     batchMatMul.prepare(lhs.shape, rhs.shape, adj_x, adj_y);
906     batchMatMul(lhs.shape, lhs.base<float>(), rhs.shape, rhs.base<float>(), adj_x, adj_y,
907                 output_reshaped.shape, output_reshaped.base<float>());
908   }
909
910   void reshapeToRank3(const Tensor &input, int batch_size, Tensor *output)
911   {
912     const int rank = input.shape.DimensionsCount();
913     Shape output_shape({batch_size, input.shape.Dims(rank - 2), input.shape.Dims(rank - 1)});
914     copyFrom(input, output_shape, output);
915   }
916
917 private:
918   bool _prepared;
919
920   OperandLabels _input_labels;
921   Labels _output_labels;
922   std::vector<DimensionType> _label_types;
923   OperandLabelCounts _input_label_counts;
924   LabelCounts _output_label_counts;
925   std::vector<bool> _input_has_ellipsis;
926   bool _output_has_ellipsis = false;
927
928   std::vector<std::unique_ptr<float[]>> temp_operand;
929 };
930
931 } // namespace cker
932 } // namespace nnfw
933
934 #endif // __NNFW_CKER_EINSUM_H__