Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / common / ProcessBroadcastShapes.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H
19 #define LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H
20
21 namespace luci_interpreter_pal
22 {
23
24 // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
25 // BROADCASTING.
26 //
27 // NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
28 // rectangular array of numbers.
29 //
30 // NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
31 // However, as Dims<N> is to be deprecated, this class exists as an adaptor
32 // to enable simple unoptimized implementations of element-wise broadcasting
33 // operations.
34 template <int N> struct NdArrayDesc
35 {
36   // The "extent" of each dimension. Indices along dimension d must be in the
37   // half-open interval [0, extents[d]).
38   int extents[N];
39
40   // The number of *elements* (not bytes) between consecutive indices of each
41   // dimension.
42   int strides[N];
43 };
44
45 // Copies dims to desc, calculating strides.
46 template <int N>
47 inline void copyDimsToDesc(const luci_interpreter::RuntimeShape &input_shape,
48                            NdArrayDesc<N> *desc_out)
49 {
50   int desc_stride = 1;
51   for (int i = N - 1; i >= 0; --i)
52   {
53     desc_out->extents[i] = input_shape.dims(i);
54     desc_out->strides[i] = desc_stride;
55     desc_stride *= input_shape.dims(i);
56   }
57 }
58
59 template <int N, int DIM, typename Calc>
60 typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(const NdArrayDesc<N> &output,
61                                                                   const Calc &calc, int indexes[N])
62 {
63   for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
64   {
65     calc(indexes);
66   }
67 }
68
69 template <int N, int DIM, typename Calc>
70 typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(const NdArrayDesc<N> &output,
71                                                                   const Calc &calc, int indexes[N])
72 {
73   for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
74   {
75     NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
76   }
77 }
78
79 // Execute the calc function in the innermost iteration based on the shape of
80 // the output. The calc function should take a single argument of type int[N].
81 template <int N, typename Calc>
82 inline void NDOpsHelper(const NdArrayDesc<N> &output, const Calc &calc)
83 {
84   int indexes[N] = {0};
85   NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
86 }
87
88 template <int N>
89 inline void NdArrayDescsForElementwiseBroadcast(const luci_interpreter::RuntimeShape &input0_shape,
90                                                 const luci_interpreter::RuntimeShape &input1_shape,
91                                                 NdArrayDesc<N> *desc0_out,
92                                                 NdArrayDesc<N> *desc1_out)
93 {
94
95   auto extended_input0_shape = luci_interpreter::RuntimeShape::extendedShape(N, input0_shape);
96   auto extended_input1_shape = luci_interpreter::RuntimeShape::extendedShape(N, input1_shape);
97
98   // Copy dims to desc, calculating strides.
99   copyDimsToDesc<N>(extended_input0_shape, desc0_out);
100   copyDimsToDesc<N>(extended_input1_shape, desc1_out);
101
102   // Walk over each dimension. If the extents are equal do nothing.
103   // Otherwise, set the desc with extent 1 to have extent equal to the other and
104   // stride 0.
105   for (int i = 0; i < N; ++i)
106   {
107     const int extent0 = extended_input0_shape.dims(i);
108     const int extent1 = extended_input1_shape.dims(i);
109     if (extent0 != extent1)
110     {
111       if (extent0 == 1)
112       {
113         desc0_out->strides[i] = 0;
114         desc0_out->extents[i] = extent1;
115       }
116       else
117       {
118         desc1_out->strides[i] = 0;
119         desc1_out->extents[i] = extent0;
120       }
121     }
122   }
123 }
124
125 inline int subscriptToIndex(const NdArrayDesc<4> &desc, int i0, int i1, int i2, int i3)
126 {
127   return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + i3 * desc.strides[3];
128 }
129
130 inline int subscriptToIndex(const NdArrayDesc<5> &desc, int indexes[5])
131 {
132   return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
133          indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] + indexes[4] * desc.strides[4];
134 }
135
136 // Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
137 //
138 // For example, if sequence of dimensions of one input is
139 // ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...
140 // we can consolidate these as
141 // ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
142 //
143 // The category is updated in the less-frequent case of shapes that are
144 // not suited to a fivefold-loop broadcast.
145 //
146 // Falls back to generic pattern when it does not know how to process properly.
147 //
148 // Returns true iff there is some sort of broadcast, which includes five-fold
149 // patterns and falling back to generic broadcast.
150 inline bool ProcessBroadcastShapes(const luci_interpreter::RuntimeShape &shape0,
151                                    const luci_interpreter::RuntimeShape &shape1,
152                                    luci_interpreter_pal::ArithmeticParams *params)
153 {
154   const int dims_count = std::max(shape0.dimensionsCount(), shape1.dimensionsCount());
155
156   params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
157
158   auto extended_shape0 = luci_interpreter::RuntimeShape::extendedShape(dims_count, shape0);
159   auto extended_shape1 = luci_interpreter::RuntimeShape::extendedShape(dims_count, shape1);
160
161   // Check for "exact" match, implicitly accepting any scalar shapes.
162   if (extended_shape0 == extended_shape1)
163   {
164     params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
165     return false;
166   }
167
168   if (shape0.flatSize() == 1)
169   {
170     params->broadcast_category = BroadcastableOpCategory::kScalarFirstBroadcast;
171     return true;
172   }
173   else if (shape1.flatSize() == 1)
174   {
175     params->broadcast_category = BroadcastableOpCategory::kScalarSecondBroadcast;
176     return true;
177   }
178
179   for (int i = dims_count - 1; i >= 0; --i)
180   {
181     if (extended_shape0.dims(i) == extended_shape1.dims(i))
182     {
183       continue;
184     }
185     else if (extended_shape0.dims(i) == 1)
186     {
187       params->broadcast_category = BroadcastableOpCategory::kFirstInputBroadcastsFast;
188       return true;
189     }
190     else if (extended_shape1.dims(i) == 1)
191     {
192       params->broadcast_category = BroadcastableOpCategory::kSecondInputBroadcastsFast;
193       return true;
194     }
195     else
196     {
197       // This case is erroneous: there is a dimension that does not match and
198       // is not a broadcast from one shape to the other.
199       params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
200       return true;
201     }
202   }
203
204   return false;
205 }
206
207 } // namespace luci_interpreter_pal
208
209 #endif // LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H