/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 #ifndef LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H
19 #define LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H
21 namespace luci_interpreter_pal
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
// ELEMENT-WISE BROADCASTING.
//
// NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
// rectangular array of numbers.
//
// NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
// However, as Dims<N> is to be deprecated, this class exists as an adaptor
// to enable simple unoptimized implementations of element-wise broadcasting
// operations.
template <int N> struct NdArrayDesc
{
  // The "extent" of each dimension. Indices along dimension d must be in the
  // half-open interval [0, extents[d]).
  int extents[N];

  // The number of *elements* (not bytes) between consecutive indices of each
  // dimension.
  int strides[N];
};
45 // Copies dims to desc, calculating strides.
47 inline void copyDimsToDesc(const luci_interpreter::RuntimeShape &input_shape,
48 NdArrayDesc<N> *desc_out)
51 for (int i = N - 1; i >= 0; --i)
53 desc_out->extents[i] = input_shape.dims(i);
54 desc_out->strides[i] = desc_stride;
55 desc_stride *= input_shape.dims(i);
59 template <int N, int DIM, typename Calc>
60 typename std::enable_if<DIM == N - 1, void>::type NDOpsHelperImpl(const NdArrayDesc<N> &output,
61 const Calc &calc, int indexes[N])
63 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
69 template <int N, int DIM, typename Calc>
70 typename std::enable_if<DIM != N - 1, void>::type NDOpsHelperImpl(const NdArrayDesc<N> &output,
71 const Calc &calc, int indexes[N])
73 for (indexes[DIM] = 0; indexes[DIM] < output.extents[DIM]; ++indexes[DIM])
75 NDOpsHelperImpl<N, DIM + 1, Calc>(output, calc, indexes);
79 // Execute the calc function in the innermost iteration based on the shape of
80 // the output. The calc function should take a single argument of type int[N].
81 template <int N, typename Calc>
82 inline void NDOpsHelper(const NdArrayDesc<N> &output, const Calc &calc)
85 NDOpsHelperImpl<N, 0, Calc>(output, calc, indexes);
89 inline void NdArrayDescsForElementwiseBroadcast(const luci_interpreter::RuntimeShape &input0_shape,
90 const luci_interpreter::RuntimeShape &input1_shape,
91 NdArrayDesc<N> *desc0_out,
92 NdArrayDesc<N> *desc1_out)
95 auto extended_input0_shape = luci_interpreter::RuntimeShape::extendedShape(N, input0_shape);
96 auto extended_input1_shape = luci_interpreter::RuntimeShape::extendedShape(N, input1_shape);
98 // Copy dims to desc, calculating strides.
99 copyDimsToDesc<N>(extended_input0_shape, desc0_out);
100 copyDimsToDesc<N>(extended_input1_shape, desc1_out);
102 // Walk over each dimension. If the extents are equal do nothing.
103 // Otherwise, set the desc with extent 1 to have extent equal to the other and
105 for (int i = 0; i < N; ++i)
107 const int extent0 = extended_input0_shape.dims(i);
108 const int extent1 = extended_input1_shape.dims(i);
109 if (extent0 != extent1)
113 desc0_out->strides[i] = 0;
114 desc0_out->extents[i] = extent1;
118 desc1_out->strides[i] = 0;
119 desc1_out->extents[i] = extent0;
125 inline int subscriptToIndex(const NdArrayDesc<4> &desc, int i0, int i1, int i2, int i3)
127 return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + i3 * desc.strides[3];
130 inline int subscriptToIndex(const NdArrayDesc<5> &desc, int indexes[5])
132 return indexes[0] * desc.strides[0] + indexes[1] * desc.strides[1] +
133 indexes[2] * desc.strides[2] + indexes[3] * desc.strides[3] + indexes[4] * desc.strides[4];
136 // Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
138 // For example, if sequence of dimensions of one input is
139 // ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...
140 // we can consolidate these as
141 // ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
143 // The category is updated in the less-frequent case of shapes that are
144 // not suited to a fivefold-loop broadcast.
146 // Falls back to generic pattern when it does not know how to process properly.
148 // Returns true iff there is some sort of broadcast, which includes five-fold
149 // patterns and falling back to generic broadcast.
150 inline bool ProcessBroadcastShapes(const luci_interpreter::RuntimeShape &shape0,
151 const luci_interpreter::RuntimeShape &shape1,
152 luci_interpreter_pal::ArithmeticParams *params)
154 const int dims_count = std::max(shape0.dimensionsCount(), shape1.dimensionsCount());
156 params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
158 auto extended_shape0 = luci_interpreter::RuntimeShape::extendedShape(dims_count, shape0);
159 auto extended_shape1 = luci_interpreter::RuntimeShape::extendedShape(dims_count, shape1);
161 // Check for "exact" match, implicitly accepting any scalar shapes.
162 if (extended_shape0 == extended_shape1)
164 params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
168 if (shape0.flatSize() == 1)
170 params->broadcast_category = BroadcastableOpCategory::kScalarFirstBroadcast;
173 else if (shape1.flatSize() == 1)
175 params->broadcast_category = BroadcastableOpCategory::kScalarSecondBroadcast;
179 for (int i = dims_count - 1; i >= 0; --i)
181 if (extended_shape0.dims(i) == extended_shape1.dims(i))
185 else if (extended_shape0.dims(i) == 1)
187 params->broadcast_category = BroadcastableOpCategory::kFirstInputBroadcastsFast;
190 else if (extended_shape1.dims(i) == 1)
192 params->broadcast_category = BroadcastableOpCategory::kSecondInputBroadcastsFast;
197 // This case is erroneous: there is a dimension that does not match and
198 // is not a broadcast from one shape to the other.
199 params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
207 } // namespace luci_interpreter_pal
209 #endif // LUCI_INTERPRETER_PAL_PROCESS_BROADCAST_SHAPES_H