/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NNFW_CKER_TYPES_H__
#define __NNFW_CKER_TYPES_H__

#include <cstdint>
#include <type_traits>
#include <limits>
#include <string>
#include <cassert>

namespace nnfw
{
namespace cker
{

enum class FusedActivationFunctionType
{
  kNone = 0,
  kRelu6 = 1,
  kRelu1 = 2,
  kRelu = 3,
  kTanh = 4,
  kSigmoid = 6,
};

enum class PaddingType
{
  kNone = 0,
  kSame = 1,
  kValid = 2,
};

enum class BinaryArithmeticOpType
{
  ADD = 0,
  SUB = 1,
  MUL = 2,
  DIV = 3,
  POW = 4,
};

enum class ComparisonOpType
{
  Equal,
  NotEqual,
  Greater,
  GreaterEqual,
  Less,
  LessEqual
};

struct PaddingValues
{
  int16_t width;
  int16_t height;
};

enum class BroadcastableOpCategory : uint8_t
{
  kNone,
  kNonBroadcast,              // Matching input shapes.
  kFirstInputBroadcastsFast,  // Fivefold nested loops.
  kSecondInputBroadcastsFast, // Fivefold nested loops.
  kGenericBroadcast,          // Fall-back.
};

struct PoolParams
{
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct SoftmaxParams
{
  // beta is not really used (not a TensorFlow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8 inference params. Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
  int diff_min;
  int32_t zero_point;
  float scale;
  float table[256];
  uint8_t *uint8_table1;
  uint8_t *uint8_table2;
};

struct PackParams
{
  int8_t axis;
  // zeropoint and scale were only used to implement PackWithScaling in the
  // legacy code of TensorFlow.
  // const int32_t* input_zeropoint;
  // const float* input_scale;
  uint16_t inputs_count;
  // int32_t output_zeropoint;
  // float output_scale;
};

struct ConvParams
{
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int32_t output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  bool is_replaced_weights{false};
};
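
// Illustrative sketch (not part of the original header): for uint8 inference,
// these fields are conventionally derived from the tensors' quantization
// parameters, e.g. with hypothetical zero points:
//
//   ConvParams op_params;
//   op_params.input_offset = -input_zero_point;
//   op_params.weights_offset = -filter_zero_point;
//   op_params.output_offset = output_zero_point;
//
// (output_multiplier, output_shift) then encode the real multiplier
// input_scale * filter_scale / output_scale as a Q31 fixed-point mantissa
// and a power-of-two exponent.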

struct ComparisonParams
{
  ComparisonOpType type;
  int left_shift;
  int input1_shift;
  int input2_shift;
  // uint8 inference params.
  int32_t input1_offset;
  int32_t input1_multiplier;
  int32_t input2_offset;
  int32_t input2_multiplier;
  bool is_broadcast;
};

struct BinaryArithmeticOpParam
{
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category{BroadcastableOpCategory::kNone};
  // uint8 inference params.
  int32_t input1_offset = 0;
  int32_t input2_offset = 0;
  int32_t output_offset = 0;
  int32_t output_multiplier = 0;
  int32_t output_shift = 0;
  // Add / Sub, not Mul, uint8 inference params.
  int32_t left_shift = 0;
  int32_t input1_multiplier = 0;
  int32_t input1_shift = 0;
  int32_t input2_multiplier = 0;
  int32_t input2_shift = 0;
  // uint8, etc, activation params.
  int32_t quantized_activation_min = 0;
  int32_t quantized_activation_max = 0;
  // float activation params.
  float float_activation_min = 0;
  float float_activation_max = 0;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4}:
  //   broadcast_shape[4] = b0 = a0,
  //   broadcast_shape[3] = b1 (with a1 = 1),
  //   broadcast_shape[2] = b2 = a2,
  //   broadcast_shape[1] = a3 (with b3 = 1),
  //   broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5] = {};
};
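
// Worked example (illustrative): for coalesced shapes
//   a = {2, 1, 3, 5, 4} and b = {2, 7, 3, 1, 4},
// we have a1 = 1 and b3 = 1, so applying the scheme above gives
//   broadcast_shape = {a4, a3, a2, b1, a0} = {4, 5, 3, 7, 2}.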

struct TransposeParams
{
  int8_t perm_count;
  int32_t perm[4];
};

struct ConcatenationParams
{
  int8_t axis;
  const int32_t *input_zeropoint;
  const float *input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct DepthwiseConvParams
{
  PaddingType padding_type;
  PaddingValues padding_values;
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  int16_t depth_multiplier;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int32_t output_shift;
  // uint8, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct FullyConnectedParams
{
  FusedActivationFunctionType activation{FusedActivationFunctionType::kNone};
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params - no one uses these params, but ruy might use them later.
  // float float_activation_min;
  // float float_activation_max;
  // FullyConnectedWeightsFormat weights_format;
};

struct L2NormParams
{
  // uint8 inference params.
  int32_t input_zero_point;
};

enum LSTMKernelType
{
  kTfLiteLSTMFullKernel = 0,
  kTfLiteLSTMBasicKernel
};

struct LSTMParams
{
  // Parameters for LSTM version 1.
  FusedActivationFunctionType activation{FusedActivationFunctionType::kNone};
  float cell_clip;
  float proj_clip;

  // Parameters for LSTM version 2.
  // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
  LSTMKernelType kernel_type;

  // Parameters for LSTM version 4.
  bool asymmetric_quantize_inputs;
};

struct InstanceNormParams
{
  float epsilon;
  float float_activation_min;
  float float_activation_max;
};

struct ResizeBilinearParams
{
  int32_t output_height;
  int32_t output_width;
  bool align_corners;
  bool half_pixel_centers;
};
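
// Note (illustrative; standard resize-bilinear conventions): for an output
// coordinate x, the source coordinate is typically computed as
//   align_corners:       x * (in_size - 1) / (out_size - 1)
//   half_pixel_centers:  (x + 0.5) * in_size / out_size - 0.5
// and plain x * in_size / out_size when neither flag is set.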

struct TransposeConvParams
{
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int32_t output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct StridedSliceParams
{
  int8_t start_indices_count;
  int16_t start_indices[4];
  int8_t stop_indices_count;
  int16_t stop_indices[4];
  int8_t strides_count;
  int16_t strides[4];

  int16_t begin_mask;
  int16_t ellipsis_mask;
  int16_t end_mask;
  int16_t new_axis_mask;
  int16_t shrink_axis_mask;
};
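
// Note (standard strided-slice semantics, for orientation): if bit i of
// begin_mask / end_mask is set, start_indices[i] / stop_indices[i] is
// ignored and the full range is used for dimension i instead; if bit i of
// shrink_axis_mask is set, dimension i is squeezed out of the result.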

struct FusedBatchNormParams
{
  bool is_training;
  std::string data_format; // UNKNOWN(0), NHWC(1), NCHW(2)
  float epsilon;
};

struct SpaceToBatchParams
{
  // "Zero" padding for uint8 means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams
{
  int32_t block_size;
};

struct LeakyReluParams
{
  float alpha;
};

enum class Order
{
  kColMajor,
  kRowMajor
};

enum class CachePolicy : std::uint8_t
{
  kNeverCache,
  kCacheIfLargeSpeedup,
  kAlwaysCache
};

// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: doing so
// requires complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
template <typename Scalar> struct MatrixParams
{
  // Storage layout order. For now we only do plain linear non-strided
  // layout. It would be easy to support a stride if needed.
  Order order = Order::kColMajor;
  // Number of rows of the matrix.
  int rows = 0;
  // Number of columns of the matrix.
  int cols = 0;
  // The zero_point, i.e. which Scalar value is to be interpreted as zero.
  // When Scalar is floating-point, this must be 0.
  Scalar zero_point = 0;
  // When the data pointed to by this matrix is constant data, so that it is
  // valid to assume that equality of pointers implies equality of data,
  // a CachePolicy may be used instead of the default kNeverCache,
  // which will enable ruy to take advantage of this constancy of the data to
  // cache the packing work, which can be a large speedup in matrix*vector
  // and other narrow shapes.
  CachePolicy cache_policy = CachePolicy::kNeverCache;
};
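
// Example usage (illustrative; the dimensions and zero point are
// hypothetical): describing a constant row-major uint8 weights matrix whose
// packed form may be cached across Gemm calls:
//
//   MatrixParams<uint8_t> lhs_params;
//   lhs_params.order = Order::kRowMajor;
//   lhs_params.rows = 64;
//   lhs_params.cols = 128;
//   lhs_params.zero_point = 128;
//   lhs_params.cache_policy = CachePolicy::kAlwaysCache;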

// Enumeration of broad categories of Gemm.
//
// The primary reason for this to exist is to allow Gemm to compile
// only uniform-quantized or only per-channel-quantized code paths.
// This is unneeded with ruy as the back-end, as this is only a runtime
// difference in ruy, but with gemmlowp these really are separate code
// paths and templatizing in a QuantizationFlavor is necessary to avoid
// compiling unused gemmlowp code. Indeed, TFLite currently uses
// uint8 with uniform quantization and int8 with per-channel quantization,
// and does not use uint8 with per-channel. We want to avoid compiling
// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
//
// It's possible to drop this in the future if gemmlowp goes away and no
// other then-relevant backend library handles quantized paths in a way that
// requires knowing this at compile-time.
enum class QuantizationFlavor
{
  // Floating-point Gemm: the accumulators are not multiplied by any
  // multiplier.
  kFloatingPoint,
  // Quantized Gemm using a single multiplier for all accumulators.
  kIntegerWithUniformMultiplier,
  // Quantized Gemm using separate multipliers for accumulators of each
  // row of the destination matrix. This is what is called 'per-channel'
  // in GemmParams. Here we use the more specific 'per-row' terminology
  // to allow for the possibility of 'per-column' in the future, and to
  // allow for that to be a separate code path in some back-end such as
  // gemmlowp.
  kIntegerWithPerRowMultiplier
};

// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from that)
// is useful future-proofing. Think of a float16 path using float32 accum.
//
// QuantizationFlavor is passed here even though it's technically not used
// in this class. This is so that we retain the ability in the future to
// specialize this class for quantization flavor, and this allows for
// Gemm to be templatized in quantization_flavor via the GemmParams that it
// takes, allowing for automatic template parameter deduction to take place,
// so that most call sites don't need to specify a QuantizationFlavor
// (only those that need per-channel quantization do).
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor =
            std::is_floating_point<AccumScalar>::value
              ? QuantizationFlavor::kFloatingPoint
              : QuantizationFlavor::kIntegerWithUniformMultiplier>
struct GemmParams
{
  // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
  // of the multiplier by which accumulators are multiplied before being cast
  // to the destination type.
  AccumScalar multiplier_fixedpoint = 0;
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent = 0;
  // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_fixedpoint.
  const AccumScalar *multiplier_fixedpoint_perchannel = nullptr;
  // Per-channel variant of multiplier_exponent. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_exponent.
  //
  // Either none or both of multiplier_exponent_perchannel and
  // multiplier_fixedpoint_perchannel must be nullptr.
  const int *multiplier_exponent_perchannel = nullptr;
  // The bias vector data, if not null.
  const AccumScalar *bias = nullptr;
  // min clamp bound of destination values.
  DstScalar clamp_min = std::is_floating_point<DstScalar>::value
                          ? -std::numeric_limits<DstScalar>::infinity()
                          : std::numeric_limits<DstScalar>::lowest();
  // max clamp bound of destination values.
  DstScalar clamp_max = std::is_floating_point<DstScalar>::value
                          ? std::numeric_limits<DstScalar>::infinity()
                          : std::numeric_limits<DstScalar>::max();
};
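
// Worked example (illustrative): to scale int32_t accumulators by an
// effective real multiplier of 0.0235, factor it as 0.752 * 2^-5 and store
// the mantissa in Q31 fixed point:
//
//   GemmParams<int32_t, uint8_t> params;
//   params.multiplier_fixedpoint = 1614907703; // round(0.752 * 2^31)
//   params.multiplier_exponent = -5;           // net right-shift by 5
//
// Each accumulator is then fixed-point-multiplied by the mantissa and
// shifted by the exponent before clamping to [clamp_min, clamp_max] and
// casting to DstScalar.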

// Validates self-consistency of GemmParams.
template <typename AccumScalar, typename DstScalar, QuantizationFlavor quantization_flavor>
void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
{
  // Guard consistency of the quantized multiplier fields.
  if (quantization_flavor == QuantizationFlavor::kFloatingPoint)
  {
    assert(!params.multiplier_fixedpoint);
    assert(!params.multiplier_exponent);
    assert(!params.multiplier_fixedpoint_perchannel);
    assert(!params.multiplier_exponent_perchannel);
  }
  else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
           !std::is_same<DstScalar, int32_t>::value)
  {
    assert(params.multiplier_fixedpoint);
    // Nothing to check about multiplier_exponent
    assert(!params.multiplier_fixedpoint_perchannel);
    assert(!params.multiplier_exponent_perchannel);
  }
  else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
           !std::is_same<DstScalar, int32_t>::value)
  {
    assert(!params.multiplier_fixedpoint);
    assert(!params.multiplier_exponent);
    assert(params.multiplier_fixedpoint_perchannel);
    assert(params.multiplier_exponent_perchannel);
  }
  else
  {
    // For the raw-accumulator case (DstScalar == int32_t), we should make
    // sure none of the quantization params are set.
    assert(!params.multiplier_fixedpoint);
    assert(!params.multiplier_exponent);
    assert(!params.multiplier_fixedpoint_perchannel);
    assert(!params.multiplier_exponent_perchannel);
  }
  UNUSED_RELEASE(params);
}
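
// Example usage (illustrative): a float Gemm deduces
// QuantizationFlavor::kFloatingPoint, so validation requires that all
// quantized-multiplier fields keep their null defaults:
//
//   GemmParams<float, float> params;
//   params.bias = bias_data; // hypothetical pointer to a float bias vector
//   ValidateGemmParams(params); // all asserts hold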

} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_TYPES_H__