Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / Types.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_TYPES_H__
19 #define __NNFW_CKER_TYPES_H__
20
21 #include <cstdint>
22 #include <type_traits>
23 #include <limits>
24 #include <string>
25
26 namespace nnfw
27 {
28 namespace cker
29 {
30
31 enum class FusedActivationFunctionType
32 {
33   kNone = 0,
34   kRelu6 = 1,
35   kRelu1 = 2,
36   kRelu = 3,
37   kTanh = 4,
38   kSigmoid = 6,
39 };
40 enum class PaddingType
41 {
42   kNone = 0,
43   kSame = 1,
44   kValid = 2,
45 };
46
47 enum class BinaryArithmeticOpType
48 {
49   ADD = 0,
50   SUB = 1,
51   MUL = 2,
52   DIV = 3,
53   POW = 4,
54 };
55
56 enum class ComparisonOpType
57 {
58   Equal,
59   NotEqual,
60   Greater,
61   GreaterEqual,
62   Less,
63   LessEqual
64 };
65
66 struct PaddingValues
67 {
68   int16_t width;
69   int16_t height;
70 };
71
72 enum class BroadcastableOpCategory : uint8_t
73 {
74   kNone,
75   kNonBroadcast,              // Matching input shapes.
76   kFirstInputBroadcastsFast,  // Fivefold nested loops.
77   kSecondInputBroadcastsFast, // Fivefold nested loops.
78   kGenericBroadcast,          // Fall-back.
79 };
80
81 struct PoolParams
82 {
83   PaddingValues padding_values;
84   int stride_height;
85   int stride_width;
86   int filter_height;
87   int filter_width;
88   // uint8, etc, activation params.
89   int32_t quantized_activation_min;
90   int32_t quantized_activation_max;
91   // float activation params.
92   float float_activation_min;
93   float float_activation_max;
94 };
95
96 struct SoftmaxParams
97 {
98   // beta is not really used (not a Tensorflow parameter) and not implemented
99   // for LogSoftmax.
100   double beta;
101   int axis;
102   // uint8 inference params.  Used even when beta defaults to 1.0.
103   int32_t input_multiplier;
104   int32_t input_left_shift;
105   // Reverse scaling is only used by LogSoftmax.
106   int32_t reverse_scaling_divisor;
107   int32_t reverse_scaling_right_shift;
108   int diff_min;
109   int32_t zero_point;
110   float scale;
111   float *table;
112   uint8_t *uint8_table1;
113   uint8_t *uint8_table2;
114 };
115
116 struct PackParams
117 {
118   int8_t axis;
119   // zeropoint and scale were only used to implement PackWithScaling in the legacy code of
120   // tensorflow
121   // const int32_t* input_zeropoint;
122   // const float* input_scale;
123   uint16_t inputs_count;
124   // int32_t output_zeropoint;
125   // float output_scale;
126 };
127
128 struct UnpackParams
129 {
130   uint16_t num_split;
131   int16_t axis;
132 };
133
134 struct ConvParams
135 {
136   PaddingType padding_type;
137   PaddingValues padding_values;
138   // TODO(starka): This was just "stride", so check that width+height is OK.
139   int16_t stride_width;
140   int16_t stride_height;
141   int16_t dilation_width_factor;
142   int16_t dilation_height_factor;
143   // uint8_t inference params.
144   // TODO(b/65838351): Use smaller types if appropriate.
145   int32_t input_offset;
146   int32_t weights_offset;
147   int32_t output_offset;
148   int32_t output_multiplier;
149   int output_shift;
150   // uint8_t, etc, activation params.
151   int32_t quantized_activation_min;
152   int32_t quantized_activation_max;
153   // float activation params.
154   float float_activation_min;
155   float float_activation_max;
156   bool is_replaced_weights{false};
157 };
158
159 struct ComparisonParams
160 {
161   ComparisonOpType type;
162   int left_shift;
163   int input1_shift;
164   int input2_shift;
165   int32_t input1_offset;
166   int32_t input1_multiplier;
167   int32_t input2_offset;
168   int32_t input2_multiplier;
169   bool is_broadcast;
170 };
171
172 struct BinaryArithmeticOpParam
173 {
174   // Shape dependent / common to data / op types.
175   BroadcastableOpCategory broadcast_category{BroadcastableOpCategory::kNone};
176   // uint8 inference params.
177   int32_t input1_offset = 0;
178   int32_t input2_offset = 0;
179   int32_t output_offset = 0;
180   int32_t output_multiplier = 0;
181   int32_t output_shift = 0;
182   // Add / Sub, not Mul, uint8 inference params.
183   int32_t left_shift = 0;
184   int32_t input1_multiplier = 0;
185   int32_t input1_shift = 0;
186   int32_t input2_multiplier = 0;
187   int32_t input2_shift = 0;
188   // uint8, etc, activation params.
189   int32_t quantized_activation_min = 0;
190   int32_t quantized_activation_max = 0;
191   // float activation params.
192   float float_activation_min = 0;
193   float float_activation_max = 0;
194
195   // Processed output dimensions.
196   // Let input "a" be the one that broadcasts in the faster-changing dimension.
197   // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
198   // {b0, b1, b2, b3, b4},
199   // broadcast_shape[4] = b0 = a0.
200   // broadcast_shape[3] = b1; a1 = 1.
201   // broadcast_shape[2] = b2 = a2.
202   // broadcast_shape[1] = a3; b3 = 1.
203   // broadcast_shape[0] = b4 = a4.
204   int broadcast_shape[5] = {};
205 };
206
207 struct TransposeParams
208 {
209   int8_t perm_count;
210   int32_t perm[4];
211 };
212
213 struct ConcatenationParams
214 {
215   int8_t axis;
216   const int32_t *input_zeropoint;
217   const float *input_scale;
218   uint16_t inputs_count;
219   int32_t output_zeropoint;
220   float output_scale;
221 };
222
223 struct DepthwiseConvParams
224 {
225   PaddingType padding_type;
226   PaddingValues padding_values;
227   int16_t stride_width;
228   int16_t stride_height;
229   int16_t dilation_width_factor;
230   int16_t dilation_height_factor;
231   int16_t depth_multiplier;
232   // uint8 inference params.
233   // TODO(b/65838351): Use smaller types if appropriate.
234   int32_t input_offset;
235   int32_t weights_offset;
236   int32_t output_offset;
237   int32_t output_multiplier;
238   int output_shift;
239   // uint8, etc, activation params.
240   int32_t quantized_activation_min;
241   int32_t quantized_activation_max;
242   // float activation params.
243   float float_activation_min;
244   float float_activation_max;
245 };
246
247 struct FullyConnectedParams
248 {
249   FusedActivationFunctionType activation{FusedActivationFunctionType::kNone};
250   // uint8 inference params.
251   // TODO(b/65838351): Use smaller types if appropriate.
252   int32_t input_offset;
253   int32_t weights_offset;
254   float weights_scale;
255   int32_t output_offset;
256   int32_t output_multiplier;
257   int output_shift;
258   // uint8, etc, activation params.
259   int32_t quantized_activation_min;
260   int32_t quantized_activation_max;
261   // float activation params
262   float float_activation_min;
263   float float_activation_max;
264   // Mark the operands as cacheable if they are unchanging, e.g. weights.
265   bool lhs_cacheable;
266   bool rhs_cacheable;
267   // FullyConnectedWeightsFormat weights_format;
268 };
269
270 struct L2NormParams
271 {
272   // uint8 inference params.
273   int32_t input_zero_point;
274 };
275
276 enum LSTMKernelType
277 {
278   kTfLiteLSTMFullKernel = 0,
279   kTfLiteLSTMBasicKernel
280 };
281
282 struct LSTMParams
283 {
284   // Parameters for LSTM version 1.
285   FusedActivationFunctionType activation{FusedActivationFunctionType::kNone};
286   float cell_clip;
287   float proj_clip;
288
289   // Parameters for LSTM version 2.
290   // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
291   LSTMKernelType kernel_type;
292
293   // Parameters for LSTM version 4.
294   bool asymmetric_quantize_inputs;
295 };
296
297 struct GatherParams
298 {
299   int32_t axis;
300 };
301
302 struct InstanceNormParams
303 {
304   float epsilon;
305   float float_activation_min;
306   float float_activation_max;
307 };
308
309 struct ResizeBilinearParams
310 {
311   int32_t output_height;
312   int32_t output_width;
313   bool align_corners;
314   bool half_pixel_centers;
315 };
316
317 struct TransposeConvParams
318 {
319   PaddingType padding_type;
320   PaddingValues padding_values;
321   // TODO(starka): This was just "stride", so check that width+height is OK.
322   int16_t stride_width;
323   int16_t stride_height;
324   int16_t dilation_width_factor;
325   int16_t dilation_height_factor;
326   // uint8_t inference params.
327   // TODO(b/65838351): Use smaller types if appropriate.
328   int32_t input_offset;
329   int32_t weights_offset;
330   int32_t output_offset;
331   int32_t output_multiplier;
332   int output_shift;
333   // uint8_t, etc, activation params.
334   int32_t quantized_activation_min;
335   int32_t quantized_activation_max;
336   // float activation params.
337   float float_activation_min;
338   float float_activation_max;
339 };
340
341 struct SliceParams
342 {
343   int8_t begin_count;
344   int32_t begin[4];
345   int8_t size_count;
346   int32_t size[4];
347 };
348
349 struct StridedSliceParams
350 {
351   int8_t start_indices_count;
352   int16_t start_indices[4];
353   int8_t stop_indices_count;
354   int16_t stop_indices[4];
355   int8_t strides_count;
356   int16_t strides[4];
357
358   int16_t begin_mask;
359   int16_t ellipsis_mask;
360   int16_t end_mask;
361   int16_t new_axis_mask;
362   int16_t shrink_axis_mask;
363 };
364
365 struct SplitParams
366 {
367   uint16_t num_split;
368   int16_t axis;
369 };
370
371 struct SplitVParams
372 {
373   uint16_t num_split;
374   int16_t axis;
375 };
376
377 struct FusedBatchNormParams
378 {
379   bool is_training;
380   std::string data_format; // UNKNOWN(0), NHWC(1), NCHW(2)
381   float epsilon;
382 };
383
384 struct SpaceToBatchParams
385 {
386   // "Zero" padding for uint8 means padding with the output offset.
387   int32_t output_offset;
388 };
389
390 struct SpaceToDepthParams
391 {
392   int32_t block_size;
393 };
394
395 struct LeakyReluParams
396 {
397   float alpha;
398 };
399
400 enum class Order
401 {
402   kColMajor,
403   kRowMajor
404 };
405
406 enum class CachePolicy : std::uint8_t
407 {
408   kNeverCache,
409   kCacheIfLargeSpeedup,
410   kAlwaysCache,
411 };
412
413 // MatrixParams encapsulates the parameters that Gemm needs about each
414 // matrix, besides the buffer data pointer.
415 // Compare to ruy::Matrix, which also encapsulates the data pointer.
416 // Rationale for leaving the data pointer out of here: doing so
417 // requires complicated const-correctness mechanics. See
418 // ruy::ConstCheckingPtr.
419 template <typename Scalar> struct MatrixParams
420 {
421   // Storage layout order. For now we only do plain linear non-strided
422   // layout. It would be easy to support a stride if needed.
423   Order order = Order::kColMajor;
424   // Number of rows of the matrix.
425   int rows = 0;
426   // Number of columns of the matrix.
427   int cols = 0;
428   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
429   // When Scalar is floating-point, this must be 0.
430   Scalar zero_point = 0;
431   // When the data pointed to by this matrix is constant data, so that it is
432   // valid to assume that equality of pointers implies equality of data,
433   // a CachePolicy may be used instead of the default kNeverCache,
434   // which will enable ruy to take advantage of this constancy of the data to
435   // cache the packing work, which can be a large speedup in matrix*vector
436   // and other narrow shapes.
437   CachePolicy cache_policy = CachePolicy::kNeverCache;
438 };
439
440 // Enumeration of broad categories of Gemm.
441 //
442 // The primary reason for this to exist is to allow Gemm to compile
443 // only uniform-quantized or only per-channel-quantized code paths.
444 // This is unneeded with ruy as the back-end, as this is only a runtime
445 // difference in ruy, but with gemmlowp these really are separate code
446 // paths and templatizing in a QuantizationFlavor is necessary to avoid
447 // compiling unused gemmlowp code. Indeed, TFLite currently uses
448 // uint8 with uniform quantization and int8 with per-channel quantization,
449 // and does not use uint8 with per-channel. We want to avoid compiling
450 // the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
451 //
452 // It's possible to drop this in the future if gemmlowp goes away and no
453 // other then-relevant backend library handles quantized paths in a way that
454 // requires knowing this at compile-time.
455 enum class QuantizationFlavor
456 {
457   // Floating-point Gemm: the accumulators are not multiplied by any
458   // 'multiplier'.
459   kFloatingPoint,
460   // Quantized Gemm using a single multiplier for all accumulators.
461   kIntegerWithUniformMultiplier,
462   // Quantized Gemm using a separate multipliers for accumulators of each
463   // row of the destination matrix. This is what is called 'per-channel'
464   // in GemmParams. Here we use the more specific 'per-row' terminology
465   // to allow for the possibility of 'per-column' in the future, and to
466   // allow for that to be a separate code path in some back-end such as
467   // gemmlowp.
468   kIntegerWithPerRowMultiplier
469 };
470
471 // Additional parameters that Gemm needs, beyond what falls into
472 // the MatrixParams that it takes. Compare to ruy::Spec.
473 //
474 // Decoupling AccumScalar from DstScalar (rather than deducing it from that)
475 // is useful future-proofing. Think of a float16 path using float32 accum.
476 //
477 // QuantizationFlavor is passed here even though it's technically not used
478 // in this class. This is so that we retain the ability in the future to
479 // specialize this class for quantization flavor, and this allows for
480 // Gemm to be templatized in quantization_flavor via the GemmParams that it
481 // takes, allowing for automatic template parameter deduction to take place,
482 // so that most call sites don't need to specify a QuantizationFlavor
483 // (only those that need perchannel quantization do).
484 template <typename AccumScalar, typename DstScalar,
485           QuantizationFlavor quantization_flavor =
486             std::is_floating_point<AccumScalar>::value
487               ? QuantizationFlavor::kFloatingPoint
488               : QuantizationFlavor::kIntegerWithUniformMultiplier>
489 struct GemmParams
490 {
491   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
492   // of the multiplier by which accumulators are multiplied before being casted
493   // to the destination type.
494   AccumScalar multiplier_fixedpoint = 0;
495   // Only for non-floating-point cases. The exponent part of the aforementioned
496   // multiplier.
497   int multiplier_exponent = 0;
498   // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
499   // point to a buffer of as many values as there are rows in the destination
500   // matrix. Each row of the destination matrix will use the corresponding
501   // buffer element instead of multiplier_fixedpoint.
502   const AccumScalar *multiplier_fixedpoint_perchannel = nullptr;
503   // Per-channel variant of multiplier_exponent. If not nullptr, this must
504   // point to a buffer of as many values as there are rows in the destination
505   // matrix. Each row of the destination matrix will use the corresponding
506   // buffer element instead of multiplier_exponent.
507   //
508   // Either none or both of multiplier_exponent_perchannel and
509   // multiplier_fixedpoint_perchannel must be nullptr.
510   const int *multiplier_exponent_perchannel = nullptr;
511   // The bias vector data, if not null.
512   const AccumScalar *bias = nullptr;
513   // min clamp bound of destination values.
514   DstScalar clamp_min = std::is_floating_point<DstScalar>::value
515                           ? -std::numeric_limits<DstScalar>::infinity()
516                           : std::numeric_limits<DstScalar>::lowest();
517   // max clamp bound of destination values.
518   DstScalar clamp_max = std::is_floating_point<DstScalar>::value
519                           ? std::numeric_limits<DstScalar>::infinity()
520                           : std::numeric_limits<DstScalar>::max();
521 };
522
523 // Validates self-consistency of GemmParams.
524 template <typename AccumScalar, typename DstScalar, QuantizationFlavor quantization_flavor>
525 void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
526 {
527   // Guard consistency of the quantized multiplier fields.
528   if (quantization_flavor == QuantizationFlavor::kFloatingPoint)
529   {
530     assert(!params.multiplier_fixedpoint);
531     assert(!params.multiplier_exponent);
532     assert(!params.multiplier_fixedpoint_perchannel);
533     assert(!params.multiplier_exponent_perchannel);
534   }
535   else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
536            !std::is_same<DstScalar, int32_t>::value)
537   {
538     assert(params.multiplier_fixedpoint);
539     // Nothing to check about multiplier_exponent
540     assert(!params.multiplier_fixedpoint_perchannel);
541     assert(!params.multiplier_exponent_perchannel);
542   }
543   else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
544            !std::is_same<DstScalar, int32_t>::value)
545   {
546     assert(!params.multiplier_fixedpoint);
547     assert(!params.multiplier_exponent);
548     assert(params.multiplier_fixedpoint_perchannel);
549     assert(params.multiplier_exponent_perchannel);
550   }
551   else
552   {
553     // For the get raw accumulator case, we should make sure none of the
554     // quantization params are set.
555     assert(!params.multiplier_fixedpoint);
556     assert(!params.multiplier_exponent);
557     assert(!params.multiplier_fixedpoint_perchannel);
558     assert(!params.multiplier_exponent_perchannel);
559   }
560   UNUSED_RELEASE(params);
561 }
562
563 } // namespace cker
564 } // namespace nnfw
565
566 #endif // __NNFW_CKER_TYPES_H__