Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / Types.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_TYPES_H__
19 #define __NNFW_CKER_TYPES_H__
20
21 #include <cstdint>
22 #include <type_traits>
23 #include <limits>
24 #include <string>
25
26 namespace nnfw
27 {
28 namespace cker
29 {
30
31 enum class FusedActivationFunctionType
32 {
33   kNone = 0,
34   kRelu6 = 1,
35   kRelu1 = 2,
36   kRelu = 3,
37 };
38 enum class PaddingType
39 {
40   kNone = 0,
41   kSame = 1,
42   kValid = 2,
43 };
44
45 enum class BinaryArithmeticOpType
46 {
47   ADD = 0,
48   SUB = 1,
49   MUL = 2,
50   DIV = 3,
51   POW = 4,
52 };
53
54 enum class ComparisonOpType
55 {
56   Equal,
57   NotEqual,
58   Greater,
59   GreaterEqual,
60   Less,
61   LessEqual
62 };
63
64 struct PaddingValues
65 {
66   int16_t width;
67   int16_t height;
68 };
69
70 enum class BroadcastableOpCategory : uint8_t
71 {
72   kNone,
73   kNonBroadcast,              // Matching input shapes.
74   kFirstInputBroadcastsFast,  // Fivefold nested loops.
75   kSecondInputBroadcastsFast, // Fivefold nested loops.
76   kGenericBroadcast,          // Fall-back.
77 };
78
79 struct PoolParams
80 {
81   FusedActivationFunctionType activation;
82   PaddingType padding_type;
83   PaddingValues padding_values;
84   int stride_height;
85   int stride_width;
86   int filter_height;
87   int filter_width;
88   // uint8, etc, activation params.
89   int32_t quantized_activation_min;
90   int32_t quantized_activation_max;
91   // float activation params.
92   float float_activation_min;
93   float float_activation_max;
94 };
95
96 struct SoftmaxParams
97 {
98   // beta is not really used (not a Tensorflow parameter) and not implemented
99   // for LogSoftmax.
100   double beta;
101   int axis;
102   // uint8 inference params.  Used even when beta defaults to 1.0.
103   int32_t input_multiplier;
104   int32_t input_left_shift;
105   // Reverse scaling is only used by LogSoftmax.
106   int32_t reverse_scaling_divisor;
107   int32_t reverse_scaling_right_shift;
108   int diff_min;
109 };
110
111 struct PackParams
112 {
113   int8_t axis;
114   // zeropoint and scale were only used to implement PackWithScaling in the legacy code of
115   // tensorflow
116   // const int32_t* input_zeropoint;
117   // const float* input_scale;
118   uint16_t inputs_count;
119   // int32_t output_zeropoint;
120   // float output_scale;
121 };
122
123 struct UnpackParams
124 {
125   uint16_t num_split;
126   int16_t axis;
127 };
128
129 struct ConvParams
130 {
131   PaddingType padding_type;
132   PaddingValues padding_values;
133   // TODO(starka): This was just "stride", so check that width+height is OK.
134   int16_t stride_width;
135   int16_t stride_height;
136   int16_t dilation_width_factor;
137   int16_t dilation_height_factor;
138   // uint8_t inference params.
139   // TODO(b/65838351): Use smaller types if appropriate.
140   int32_t input_offset;
141   int32_t weights_offset;
142   int32_t output_offset;
143   int32_t output_multiplier;
144   int output_shift;
145   // uint8_t, etc, activation params.
146   int32_t quantized_activation_min;
147   int32_t quantized_activation_max;
148   // float activation params.
149   float float_activation_min;
150   float float_activation_max;
151   bool is_replaced_weights{false};
152 };
153
154 struct ComparisonParams
155 {
156   ComparisonOpType type;
157   int left_shift;
158   int input1_shift;
159   int input2_shift;
160   int32_t input1_offset;
161   int32_t input1_multiplier;
162   int32_t input2_offset;
163   int32_t input2_multiplier;
164   bool is_broadcast;
165 };
166
167 struct BinaryArithmeticOpParam
168 {
169   // Shape dependent / common to data / op types.
170   BroadcastableOpCategory broadcast_category;
171   // uint8 inference params.
172   int32_t input1_offset;
173   int32_t input2_offset;
174   int32_t output_offset;
175   int32_t output_multiplier;
176   int32_t output_shift;
177   // Add / Sub, not Mul, uint8 inference params.
178   int32_t left_shift;
179   int32_t input1_multiplier;
180   int32_t input1_shift;
181   int32_t input2_multiplier;
182   int32_t input2_shift;
183   // uint8, etc, activation params.
184   int32_t quantized_activation_min;
185   int32_t quantized_activation_max;
186   // float activation params.
187   float float_activation_min;
188   float float_activation_max;
189
190   // Processed output dimensions.
191   // Let input "a" be the one that broadcasts in the faster-changing dimension.
192   // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
193   // {b0, b1, b2, b3, b4},
194   // broadcast_shape[4] = b0 = a0.
195   // broadcast_shape[3] = b1; a1 = 1.
196   // broadcast_shape[2] = b2 = a2.
197   // broadcast_shape[1] = a3; b3 = 1.
198   // broadcast_shape[0] = b4 = a4.
199   int broadcast_shape[5] = {};
200 };
201
202 struct TransposeParams
203 {
204   int8_t perm_count;
205   int32_t perm[4];
206 };
207
208 struct ConcatenationParams
209 {
210   int8_t axis;
211   const int32_t *input_zeropoint;
212   const float *input_scale;
213   uint16_t inputs_count;
214   int32_t output_zeropoint;
215   float output_scale;
216 };
217
218 struct DepthwiseConvParams
219 {
220   PaddingType padding_type;
221   PaddingValues padding_values;
222   int16_t stride_width;
223   int16_t stride_height;
224   int16_t dilation_width_factor;
225   int16_t dilation_height_factor;
226   int16_t depth_multiplier;
227   // uint8 inference params.
228   // TODO(b/65838351): Use smaller types if appropriate.
229   int32_t input_offset;
230   int32_t weights_offset;
231   int32_t output_offset;
232   int32_t output_multiplier;
233   int output_shift;
234   // uint8, etc, activation params.
235   int32_t quantized_activation_min;
236   int32_t quantized_activation_max;
237   // float activation params.
238   float float_activation_min;
239   float float_activation_max;
240 };
241
242 struct FullyConnectedParams
243 {
244   FusedActivationFunctionType activation{FusedActivationFunctionType::kNone};
245   // uint8 inference params.
246   // TODO(b/65838351): Use smaller types if appropriate.
247   int32_t input_offset;
248   int32_t weights_offset;
249   float weights_scale;
250   int32_t output_offset;
251   int32_t output_multiplier;
252   int output_shift;
253   // uint8, etc, activation params.
254   int32_t quantized_activation_min;
255   int32_t quantized_activation_max;
256   // float activation params.
257   float float_activation_min;
258   float float_activation_max;
259   // FullyConnectedWeightsFormat weights_format;
260 };
261
262 struct L2NormParams
263 {
264   // uint8 inference params.
265   int32_t input_zero_point;
266 };
267
268 struct GatherParams
269 {
270   int32_t axis;
271 };
272
273 struct InstanceNormParams
274 {
275   float epsilon;
276   float float_activation_min;
277   float float_activation_max;
278 };
279
280 struct ResizeBilinearParams
281 {
282   int32_t output_height;
283   int32_t output_width;
284   bool align_corners;
285   bool half_pixel_centers;
286 };
287
288 struct TransposeConvParams
289 {
290   PaddingType padding_type;
291   PaddingValues padding_values;
292   // TODO(starka): This was just "stride", so check that width+height is OK.
293   int16_t stride_width;
294   int16_t stride_height;
295   int16_t dilation_width_factor;
296   int16_t dilation_height_factor;
297   // uint8_t inference params.
298   // TODO(b/65838351): Use smaller types if appropriate.
299   int32_t input_offset;
300   int32_t weights_offset;
301   int32_t output_offset;
302   int32_t output_multiplier;
303   int output_shift;
304   // uint8_t, etc, activation params.
305   int32_t quantized_activation_min;
306   int32_t quantized_activation_max;
307   // float activation params.
308   float float_activation_min;
309   float float_activation_max;
310 };
311
312 struct SliceParams
313 {
314   int8_t begin_count;
315   int32_t begin[4];
316   int8_t size_count;
317   int32_t size[4];
318 };
319
320 struct StridedSliceParams
321 {
322   int8_t start_indices_count;
323   int16_t start_indices[4];
324   int8_t stop_indices_count;
325   int16_t stop_indices[4];
326   int8_t strides_count;
327   int16_t strides[4];
328
329   int16_t begin_mask;
330   int16_t ellipsis_mask;
331   int16_t end_mask;
332   int16_t new_axis_mask;
333   int16_t shrink_axis_mask;
334 };
335
336 struct SplitParams
337 {
338   uint16_t num_split;
339   int16_t axis;
340 };
341
342 struct SplitVParams
343 {
344   uint16_t num_split;
345   int16_t axis;
346 };
347
348 struct FusedBatchNormParams
349 {
350   bool is_training;
351   std::string data_format; // UNKNOWN(0), NHWC(1), NCHW(2)
352   float epsilon;
353 };
354
355 struct SpaceToBatchParams
356 {
357   // "Zero" padding for uint8 means padding with the output offset.
358   int32_t output_offset;
359 };
360
361 struct SpaceToDepthParams
362 {
363   int32_t block_size;
364 };
365
366 enum class Order
367 {
368   kColMajor,
369   kRowMajor
370 };
371
372 // MatrixParams encapsulates the parameters that Gemm needs about each
373 // matrix, besides the buffer data pointer.
374 // Compare to ruy::Matrix, which also encapsulates the data pointer.
375 // Rationale for leaving the data pointer out of here: doing so
376 // requires complicated const-correctness mechanics. See
377 // ruy::ConstCheckingPtr.
378 template <typename Scalar> struct MatrixParams
379 {
380   // Storage layout order. For now we only do plain linear non-strided
381   // layout. It would be easy to support a stride if needed.
382   Order order = Order::kColMajor;
383   // Number of rows of the matrix.
384   int rows = 0;
385   // Number of columns of the matrix.
386   int cols = 0;
387   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
388   // When Scalar is floating-point, this must be 0.
389   Scalar zero_point = 0;
390   // Indicate whether the underlying data will remain unchanged for
391   // some period of time. Defaults to false, but should be set to true
392   // for unchanging data (e.g. weights buffers in many cases)
393   bool cacheable = false;
394 };
395
396 // Enumeration of broad categories of Gemm.
397 //
398 // The primary reason for this to exist is to allow Gemm to compile
399 // only uniform-quantized or only per-channel-quantized code paths.
400 // This is unneeded with ruy as the back-end, as this is only a runtime
401 // difference in ruy, but with gemmlowp these really are separate code
402 // paths and templatizing in a QuantizationFlavor is necessary to avoid
403 // compiling unused gemmlowp code. Indeed, TFLite currently uses
404 // uint8 with uniform quantization and int8 with per-channel quantization,
405 // and does not use uint8 with per-channel. We want to avoid compiling
406 // the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
407 //
408 // It's possible to drop this in the future if gemmlowp goes away and no
409 // other then-relevant backend library handles quantized paths in a way that
410 // requires knowing this at compile-time.
411 enum class QuantizationFlavor
412 {
413   // Floating-point Gemm: the accumulators are not multiplied by any
414   // 'multiplier'.
415   kFloatingPoint,
416   // Quantized Gemm using a single multiplier for all accumulators.
417   kIntegerWithUniformMultiplier,
418   // Quantized Gemm using a separate multipliers for accumulators of each
419   // row of the destination matrix. This is what is called 'per-channel'
420   // in GemmParams. Here we use the more specific 'per-row' terminology
421   // to allow for the possibility of 'per-column' in the future, and to
422   // allow for that to be a separate code path in some back-end such as
423   // gemmlowp.
424   kIntegerWithPerRowMultiplier
425 };
426
427 // Additional parameters that Gemm needs, beyond what falls into
428 // the MatrixParams that it takes. Compare to ruy::Spec.
429 //
430 // Decoupling AccumScalar from DstScalar (rather than deducing it from that)
431 // is useful future-proofing. Think of a float16 path using float32 accum.
432 //
433 // QuantizationFlavor is passed here even though it's technically not used
434 // in this class. This is so that we retain the ability in the future to
435 // specialize this class for quantization flavor, and this allows for
436 // Gemm to be templatized in quantization_flavor via the GemmParams that it
437 // takes, allowing for automatic template parameter deduction to take place,
438 // so that most call sites don't need to specify a QuantizationFlavor
439 // (only those that need perchannel quantization do).
440 template <typename AccumScalar, typename DstScalar,
441           QuantizationFlavor quantization_flavor =
442               std::is_floating_point<AccumScalar>::value
443                   ? QuantizationFlavor::kFloatingPoint
444                   : QuantizationFlavor::kIntegerWithUniformMultiplier>
445 struct GemmParams
446 {
447   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
448   // of the multiplier by which accumulators are multiplied before being casted
449   // to the destination type.
450   AccumScalar multiplier_fixedpoint = 0;
451   // Only for non-floating-point cases. The exponent part of the aforementioned
452   // multiplier.
453   int multiplier_exponent = 0;
454   // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
455   // point to a buffer of as many values as there are rows in the destination
456   // matrix. Each row of the destination matrix will use the corresponding
457   // buffer element instead of multiplier_fixedpoint.
458   const AccumScalar *multiplier_fixedpoint_perchannel = nullptr;
459   // Per-channel variant of multiplier_exponent. If not nullptr, this must
460   // point to a buffer of as many values as there are rows in the destination
461   // matrix. Each row of the destination matrix will use the corresponding
462   // buffer element instead of multiplier_exponent.
463   //
464   // Either none or both of multiplier_exponent_perchannel and
465   // multiplier_fixedpoint_perchannel must be nullptr.
466   const int *multiplier_exponent_perchannel = nullptr;
467   // The bias vector data, if not null.
468   const AccumScalar *bias = nullptr;
469   // min clamp bound of destination values.
470   DstScalar clamp_min = std::is_floating_point<DstScalar>::value
471                             ? -std::numeric_limits<DstScalar>::infinity()
472                             : std::numeric_limits<DstScalar>::lowest();
473   // max clamp bound of destination values.
474   DstScalar clamp_max = std::is_floating_point<DstScalar>::value
475                             ? std::numeric_limits<DstScalar>::infinity()
476                             : std::numeric_limits<DstScalar>::max();
477 };
478
479 // Validates self-consistency of GemmParams.
480 template <typename AccumScalar, typename DstScalar, QuantizationFlavor quantization_flavor>
481 void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
482 {
483   // Guard consistency of the quantized multiplier fields.
484   if (quantization_flavor == QuantizationFlavor::kFloatingPoint)
485   {
486     assert(!params.multiplier_fixedpoint);
487     assert(!params.multiplier_exponent);
488     assert(!params.multiplier_fixedpoint_perchannel);
489     assert(!params.multiplier_exponent_perchannel);
490   }
491   else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
492            !std::is_same<DstScalar, int32_t>::value)
493   {
494     assert(params.multiplier_fixedpoint);
495     // Nothing to check about multiplier_exponent
496     assert(!params.multiplier_fixedpoint_perchannel);
497     assert(!params.multiplier_exponent_perchannel);
498   }
499   else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
500            !std::is_same<DstScalar, int32_t>::value)
501   {
502     assert(!params.multiplier_fixedpoint);
503     assert(!params.multiplier_exponent);
504     assert(params.multiplier_fixedpoint_perchannel);
505     assert(params.multiplier_exponent_perchannel);
506   }
507   else
508   {
509     // For the get raw accumulator case, we should make sure none of the
510     // quantization params are set.
511     assert(!params.multiplier_fixedpoint);
512     assert(!params.multiplier_exponent);
513     assert(!params.multiplier_fixedpoint_perchannel);
514     assert(!params.multiplier_exponent_perchannel);
515   }
516   UNUSED_RELEASE(params);
517 }
518
519 } // namespace cker
520 } // namespace nnfw
521
522 #endif // __NNFW_CKER_TYPES_H__