Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compute / cker / include / cker / Types.h
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef __NNFW_CKER_TYPES_H__
19 #define __NNFW_CKER_TYPES_H__
20
21 #include <cstdint>
22 #include <type_traits>
23 #include <limits>
24 #include <string>
25
26 namespace nnfw
27 {
28 namespace cker
29 {
30
31 enum class FusedActivationFunctionType
32 {
33   kNone = 0,
34   kRelu6 = 1,
35   kRelu1 = 2,
36   kRelu = 3,
37 };
38 enum class PaddingType
39 {
40   kNone = 0,
41   kSame = 1,
42   kValid = 2,
43 };
44
45 enum class BinaryArithmeticOpType
46 {
47   ADD = 0,
48   SUB = 1,
49   MUL = 2,
50   DIV = 3,
51   POW = 4,
52 };
53
54 enum class ComparisonOpType
55 {
56   Equal,
57   NotEqual,
58   Greater,
59   GreaterEqual,
60   Less,
61   LessEqual
62 };
63
64 struct PaddingValues
65 {
66   int16_t width;
67   int16_t height;
68 };
69
70 enum class BroadcastableOpCategory : uint8_t
71 {
72   kNone,
73   kNonBroadcast,              // Matching input shapes.
74   kFirstInputBroadcastsFast,  // Fivefold nested loops.
75   kSecondInputBroadcastsFast, // Fivefold nested loops.
76   kGenericBroadcast,          // Fall-back.
77 };
78
79 struct PoolParams
80 {
81   FusedActivationFunctionType activation;
82   PaddingType padding_type;
83   PaddingValues padding_values;
84   int stride_height;
85   int stride_width;
86   int filter_height;
87   int filter_width;
88   // uint8, etc, activation params.
89   int32_t quantized_activation_min;
90   int32_t quantized_activation_max;
91   // float activation params.
92   float float_activation_min;
93   float float_activation_max;
94 };
95
96 struct SoftmaxParams
97 {
98   // beta is not really used (not a Tensorflow parameter) and not implemented
99   // for LogSoftmax.
100   double beta;
101   int axis;
102   // uint8 inference params.  Used even when beta defaults to 1.0.
103   int32_t input_multiplier;
104   int32_t input_left_shift;
105   // Reverse scaling is only used by LogSoftmax.
106   int32_t reverse_scaling_divisor;
107   int32_t reverse_scaling_right_shift;
108   int diff_min;
109   int32_t zero_point;
110   float scale;
111   float *table;
112 };
113
114 struct PackParams
115 {
116   int8_t axis;
117   // zeropoint and scale were only used to implement PackWithScaling in the legacy code of
118   // tensorflow
119   // const int32_t* input_zeropoint;
120   // const float* input_scale;
121   uint16_t inputs_count;
122   // int32_t output_zeropoint;
123   // float output_scale;
124 };
125
126 struct UnpackParams
127 {
128   uint16_t num_split;
129   int16_t axis;
130 };
131
132 struct ConvParams
133 {
134   PaddingType padding_type;
135   PaddingValues padding_values;
136   // TODO(starka): This was just "stride", so check that width+height is OK.
137   int16_t stride_width;
138   int16_t stride_height;
139   int16_t dilation_width_factor;
140   int16_t dilation_height_factor;
141   // uint8_t inference params.
142   // TODO(b/65838351): Use smaller types if appropriate.
143   int32_t input_offset;
144   int32_t weights_offset;
145   int32_t output_offset;
146   int32_t output_multiplier;
147   int output_shift;
148   // uint8_t, etc, activation params.
149   int32_t quantized_activation_min;
150   int32_t quantized_activation_max;
151   // float activation params.
152   float float_activation_min;
153   float float_activation_max;
154   bool is_replaced_weights{false};
155 };
156
157 struct ComparisonParams
158 {
159   ComparisonOpType type;
160   int left_shift;
161   int input1_shift;
162   int input2_shift;
163   int32_t input1_offset;
164   int32_t input1_multiplier;
165   int32_t input2_offset;
166   int32_t input2_multiplier;
167   bool is_broadcast;
168 };
169
170 struct BinaryArithmeticOpParam
171 {
172   // Shape dependent / common to data / op types.
173   BroadcastableOpCategory broadcast_category;
174   // uint8 inference params.
175   int32_t input1_offset;
176   int32_t input2_offset;
177   int32_t output_offset;
178   int32_t output_multiplier;
179   int32_t output_shift;
180   // Add / Sub, not Mul, uint8 inference params.
181   int32_t left_shift;
182   int32_t input1_multiplier;
183   int32_t input1_shift;
184   int32_t input2_multiplier;
185   int32_t input2_shift;
186   // uint8, etc, activation params.
187   int32_t quantized_activation_min;
188   int32_t quantized_activation_max;
189   // float activation params.
190   float float_activation_min;
191   float float_activation_max;
192
193   // Processed output dimensions.
194   // Let input "a" be the one that broadcasts in the faster-changing dimension.
195   // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
196   // {b0, b1, b2, b3, b4},
197   // broadcast_shape[4] = b0 = a0.
198   // broadcast_shape[3] = b1; a1 = 1.
199   // broadcast_shape[2] = b2 = a2.
200   // broadcast_shape[1] = a3; b3 = 1.
201   // broadcast_shape[0] = b4 = a4.
202   int broadcast_shape[5] = {};
203 };
204
205 struct TransposeParams
206 {
207   int8_t perm_count;
208   int32_t perm[4];
209 };
210
211 struct ConcatenationParams
212 {
213   int8_t axis;
214   const int32_t *input_zeropoint;
215   const float *input_scale;
216   uint16_t inputs_count;
217   int32_t output_zeropoint;
218   float output_scale;
219 };
220
221 struct DepthwiseConvParams
222 {
223   PaddingType padding_type;
224   PaddingValues padding_values;
225   int16_t stride_width;
226   int16_t stride_height;
227   int16_t dilation_width_factor;
228   int16_t dilation_height_factor;
229   int16_t depth_multiplier;
230   // uint8 inference params.
231   // TODO(b/65838351): Use smaller types if appropriate.
232   int32_t input_offset;
233   int32_t weights_offset;
234   int32_t output_offset;
235   int32_t output_multiplier;
236   int output_shift;
237   // uint8, etc, activation params.
238   int32_t quantized_activation_min;
239   int32_t quantized_activation_max;
240   // float activation params.
241   float float_activation_min;
242   float float_activation_max;
243 };
244
245 struct FullyConnectedParams
246 {
247   FusedActivationFunctionType activation{FusedActivationFunctionType::kNone};
248   // uint8 inference params.
249   // TODO(b/65838351): Use smaller types if appropriate.
250   int32_t input_offset;
251   int32_t weights_offset;
252   float weights_scale;
253   int32_t output_offset;
254   int32_t output_multiplier;
255   int output_shift;
256   // uint8, etc, activation params.
257   int32_t quantized_activation_min;
258   int32_t quantized_activation_max;
259   // float activation params.
260   float float_activation_min;
261   float float_activation_max;
262   // FullyConnectedWeightsFormat weights_format;
263 };
264
265 struct L2NormParams
266 {
267   // uint8 inference params.
268   int32_t input_zero_point;
269 };
270
271 struct GatherParams
272 {
273   int32_t axis;
274 };
275
276 struct InstanceNormParams
277 {
278   float epsilon;
279   float float_activation_min;
280   float float_activation_max;
281 };
282
283 struct ResizeBilinearParams
284 {
285   int32_t output_height;
286   int32_t output_width;
287   bool align_corners;
288   bool half_pixel_centers;
289 };
290
291 struct TransposeConvParams
292 {
293   PaddingType padding_type;
294   PaddingValues padding_values;
295   // TODO(starka): This was just "stride", so check that width+height is OK.
296   int16_t stride_width;
297   int16_t stride_height;
298   int16_t dilation_width_factor;
299   int16_t dilation_height_factor;
300   // uint8_t inference params.
301   // TODO(b/65838351): Use smaller types if appropriate.
302   int32_t input_offset;
303   int32_t weights_offset;
304   int32_t output_offset;
305   int32_t output_multiplier;
306   int output_shift;
307   // uint8_t, etc, activation params.
308   int32_t quantized_activation_min;
309   int32_t quantized_activation_max;
310   // float activation params.
311   float float_activation_min;
312   float float_activation_max;
313 };
314
315 struct SliceParams
316 {
317   int8_t begin_count;
318   int32_t begin[4];
319   int8_t size_count;
320   int32_t size[4];
321 };
322
323 struct StridedSliceParams
324 {
325   int8_t start_indices_count;
326   int16_t start_indices[4];
327   int8_t stop_indices_count;
328   int16_t stop_indices[4];
329   int8_t strides_count;
330   int16_t strides[4];
331
332   int16_t begin_mask;
333   int16_t ellipsis_mask;
334   int16_t end_mask;
335   int16_t new_axis_mask;
336   int16_t shrink_axis_mask;
337 };
338
339 struct SplitParams
340 {
341   uint16_t num_split;
342   int16_t axis;
343 };
344
345 struct SplitVParams
346 {
347   uint16_t num_split;
348   int16_t axis;
349 };
350
351 struct FusedBatchNormParams
352 {
353   bool is_training;
354   std::string data_format; // UNKNOWN(0), NHWC(1), NCHW(2)
355   float epsilon;
356 };
357
358 struct SpaceToBatchParams
359 {
360   // "Zero" padding for uint8 means padding with the output offset.
361   int32_t output_offset;
362 };
363
364 struct SpaceToDepthParams
365 {
366   int32_t block_size;
367 };
368
369 enum class Order
370 {
371   kColMajor,
372   kRowMajor
373 };
374
375 // MatrixParams encapsulates the parameters that Gemm needs about each
376 // matrix, besides the buffer data pointer.
377 // Compare to ruy::Matrix, which also encapsulates the data pointer.
378 // Rationale for leaving the data pointer out of here: doing so
379 // requires complicated const-correctness mechanics. See
380 // ruy::ConstCheckingPtr.
381 template <typename Scalar> struct MatrixParams
382 {
383   // Storage layout order. For now we only do plain linear non-strided
384   // layout. It would be easy to support a stride if needed.
385   Order order = Order::kColMajor;
386   // Number of rows of the matrix.
387   int rows = 0;
388   // Number of columns of the matrix.
389   int cols = 0;
390   // The zero_point, i.e. which Scalar value is to be interpreted as zero.
391   // When Scalar is floating-point, this must be 0.
392   Scalar zero_point = 0;
393   // Indicate whether the underlying data will remain unchanged for
394   // some period of time. Defaults to false, but should be set to true
395   // for unchanging data (e.g. weights buffers in many cases)
396   bool cacheable = false;
397 };
398
399 // Enumeration of broad categories of Gemm.
400 //
401 // The primary reason for this to exist is to allow Gemm to compile
402 // only uniform-quantized or only per-channel-quantized code paths.
403 // This is unneeded with ruy as the back-end, as this is only a runtime
404 // difference in ruy, but with gemmlowp these really are separate code
405 // paths and templatizing in a QuantizationFlavor is necessary to avoid
406 // compiling unused gemmlowp code. Indeed, TFLite currently uses
407 // uint8 with uniform quantization and int8 with per-channel quantization,
408 // and does not use uint8 with per-channel. We want to avoid compiling
409 // the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
410 //
411 // It's possible to drop this in the future if gemmlowp goes away and no
412 // other then-relevant backend library handles quantized paths in a way that
413 // requires knowing this at compile-time.
414 enum class QuantizationFlavor
415 {
416   // Floating-point Gemm: the accumulators are not multiplied by any
417   // 'multiplier'.
418   kFloatingPoint,
419   // Quantized Gemm using a single multiplier for all accumulators.
420   kIntegerWithUniformMultiplier,
421   // Quantized Gemm using a separate multipliers for accumulators of each
422   // row of the destination matrix. This is what is called 'per-channel'
423   // in GemmParams. Here we use the more specific 'per-row' terminology
424   // to allow for the possibility of 'per-column' in the future, and to
425   // allow for that to be a separate code path in some back-end such as
426   // gemmlowp.
427   kIntegerWithPerRowMultiplier
428 };
429
430 // Additional parameters that Gemm needs, beyond what falls into
431 // the MatrixParams that it takes. Compare to ruy::Spec.
432 //
433 // Decoupling AccumScalar from DstScalar (rather than deducing it from that)
434 // is useful future-proofing. Think of a float16 path using float32 accum.
435 //
436 // QuantizationFlavor is passed here even though it's technically not used
437 // in this class. This is so that we retain the ability in the future to
438 // specialize this class for quantization flavor, and this allows for
439 // Gemm to be templatized in quantization_flavor via the GemmParams that it
440 // takes, allowing for automatic template parameter deduction to take place,
441 // so that most call sites don't need to specify a QuantizationFlavor
442 // (only those that need perchannel quantization do).
443 template <typename AccumScalar, typename DstScalar,
444           QuantizationFlavor quantization_flavor =
445               std::is_floating_point<AccumScalar>::value
446                   ? QuantizationFlavor::kFloatingPoint
447                   : QuantizationFlavor::kIntegerWithUniformMultiplier>
448 struct GemmParams
449 {
450   // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
451   // of the multiplier by which accumulators are multiplied before being casted
452   // to the destination type.
453   AccumScalar multiplier_fixedpoint = 0;
454   // Only for non-floating-point cases. The exponent part of the aforementioned
455   // multiplier.
456   int multiplier_exponent = 0;
457   // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
458   // point to a buffer of as many values as there are rows in the destination
459   // matrix. Each row of the destination matrix will use the corresponding
460   // buffer element instead of multiplier_fixedpoint.
461   const AccumScalar *multiplier_fixedpoint_perchannel = nullptr;
462   // Per-channel variant of multiplier_exponent. If not nullptr, this must
463   // point to a buffer of as many values as there are rows in the destination
464   // matrix. Each row of the destination matrix will use the corresponding
465   // buffer element instead of multiplier_exponent.
466   //
467   // Either none or both of multiplier_exponent_perchannel and
468   // multiplier_fixedpoint_perchannel must be nullptr.
469   const int *multiplier_exponent_perchannel = nullptr;
470   // The bias vector data, if not null.
471   const AccumScalar *bias = nullptr;
472   // min clamp bound of destination values.
473   DstScalar clamp_min = std::is_floating_point<DstScalar>::value
474                             ? -std::numeric_limits<DstScalar>::infinity()
475                             : std::numeric_limits<DstScalar>::lowest();
476   // max clamp bound of destination values.
477   DstScalar clamp_max = std::is_floating_point<DstScalar>::value
478                             ? std::numeric_limits<DstScalar>::infinity()
479                             : std::numeric_limits<DstScalar>::max();
480 };
481
482 // Validates self-consistency of GemmParams.
483 template <typename AccumScalar, typename DstScalar, QuantizationFlavor quantization_flavor>
484 void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
485 {
486   // Guard consistency of the quantized multiplier fields.
487   if (quantization_flavor == QuantizationFlavor::kFloatingPoint)
488   {
489     assert(!params.multiplier_fixedpoint);
490     assert(!params.multiplier_exponent);
491     assert(!params.multiplier_fixedpoint_perchannel);
492     assert(!params.multiplier_exponent_perchannel);
493   }
494   else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
495            !std::is_same<DstScalar, int32_t>::value)
496   {
497     assert(params.multiplier_fixedpoint);
498     // Nothing to check about multiplier_exponent
499     assert(!params.multiplier_fixedpoint_perchannel);
500     assert(!params.multiplier_exponent_perchannel);
501   }
502   else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
503            !std::is_same<DstScalar, int32_t>::value)
504   {
505     assert(!params.multiplier_fixedpoint);
506     assert(!params.multiplier_exponent);
507     assert(params.multiplier_fixedpoint_perchannel);
508     assert(params.multiplier_exponent_perchannel);
509   }
510   else
511   {
512     // For the get raw accumulator case, we should make sure none of the
513     // quantization params are set.
514     assert(!params.multiplier_fixedpoint);
515     assert(!params.multiplier_exponent);
516     assert(!params.multiplier_fixedpoint_perchannel);
517     assert(!params.multiplier_exponent_perchannel);
518   }
519   UNUSED_RELEASE(params);
520 }
521
522 } // namespace cker
523 } // namespace nnfw
524
525 #endif // __NNFW_CKER_TYPES_H__