/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NNFW_CKER_TYPES_H__
#define __NNFW_CKER_TYPES_H__

#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <type_traits>

namespace nnfw
{
namespace cker
{

enum class FusedActivationFunctionType
{
  kNone,
  kRelu6,
  kRelu1,
  kRelu,
};

enum class PaddingType
{
  kNone,
  kSame,
  kValid,
};

enum class BinaryArithmeticOpType
{
  ADD,
  SUB,
  MUL,
  DIV,
};

enum class ComparisonOpType
{
  Equal,
  NotEqual,
  Greater,
  GreaterEqual,
  Less,
  LessEqual
};

struct PaddingValues
{
  int16_t width;
  int16_t height;
};

enum class BroadcastableOpCategory : uint8_t
{
  kNone,
  kNonBroadcast,              // Matching input shapes.
  kFirstInputBroadcastsFast,  // Fivefold nested loops.
  kSecondInputBroadcastsFast, // Fivefold nested loops.
  kGenericBroadcast,          // Fall-back.
};

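// Illustrative examples of the categories above (for exposition only, not
// taken from this file): for an element-wise Add, shapes {2, 3, 4} + {2, 3, 4}
// are kNonBroadcast; {2, 3, 4} + {1, 3, 1} coalesces into the fivefold-loop
// pattern, so one of the k*InputBroadcastsFast categories applies; broadcast
// patterns that do not coalesce that way fall back to kGenericBroadcast.
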
struct PoolParams
{
  FusedActivationFunctionType activation;
  PaddingType padding_type;
  PaddingValues padding_values;
  int stride_height;
  int stride_width;
  int filter_height;
  int filter_width;
  // uint8, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct SoftmaxParams
{
  // beta is not really used (not a Tensorflow parameter) and not implemented
  // for LogSoftmax.
  double beta;
  // uint8 inference params. Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
};

struct PackParams
{
  int8_t axis;
  // zeropoint and scale were only used to implement PackWithScaling in the legacy code of
  // tensorflow
  // const int32_t* input_zeropoint;
  // const float* input_scale;
  uint16_t inputs_count;
  // int32_t output_zeropoint;
  // float output_scale;
};

struct ConvParams
{
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  bool is_replaced_weights{false};
};

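// Background note on the quantized fields above (TFLite-style convention,
// assumed here rather than stated in this file): the offsets typically hold
// the (negated) zero points of the corresponding tensors, and
// output_multiplier / output_shift encode the real rescaling factor
// input_scale * filter_scale / output_scale as a fixed-point mantissa plus a
// power-of-two exponent; quantized_activation_min/max then clamp the
// requantized output.
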
struct ComparisonParams
{
  ComparisonOpType type;
  int32_t input1_offset;
  int32_t input1_multiplier;
  int32_t input2_offset;
  int32_t input2_multiplier;
};

struct BinaryArithmeticOpParam
{
  // Shape dependent / common to data / op types.
  BroadcastableOpCategory broadcast_category;
  // uint8 inference params.
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int32_t output_shift;
  // Add / Sub, not Mul, uint8 inference params.
  int32_t left_shift;
  int32_t input1_multiplier;
  int32_t input1_shift;
  int32_t input2_multiplier;
  int32_t input2_shift;
  // uint8, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
  // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
  // {b0, b1, b2, b3, b4},
  // broadcast_shape[4] = b0 = a0.
  // broadcast_shape[3] = b1; a1 = 1.
  // broadcast_shape[2] = b2 = a2.
  // broadcast_shape[1] = a3; b3 = 1.
  // broadcast_shape[0] = b4 = a4.
  int broadcast_shape[5] = {};
};

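// Worked example of the broadcast_shape coalescing rules above (illustrative
// only): with a = {2, 1, 4, 3, 5} and b = {2, 6, 4, 1, 5}, the coalesced
// result is broadcast_shape = {5, 3, 4, 6, 2},
// i.e. [4] = 2, [3] = 6, [2] = 4, [1] = 3, [0] = 5.
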
struct TransposeParams
{
  int8_t perm_count;
  int32_t perm[4];
};

struct ConcatenationParams
{
  int8_t axis;
  const int32_t *input_zeropoint;
  const float *input_scale;
  uint16_t inputs_count;
  int32_t output_zeropoint;
  float output_scale;
};

struct DepthwiseConvParams
{
  PaddingType padding_type;
  PaddingValues padding_values;
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  int16_t depth_multiplier;
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct FullyConnectedParams
{
  FusedActivationFunctionType activation{FusedActivationFunctionType::kNone};
  // uint8 inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // FullyConnectedWeightsFormat weights_format;
};

struct L2NormParams
{
  // uint8 inference params.
  int32_t input_zero_point;
};

struct InstanceNormParams
{
  float epsilon;
  float float_activation_min;
  float float_activation_max;
};

struct ResizeBilinearParams
{
  int32_t output_height;
  int32_t output_width;
  bool align_corners;
  bool half_pixel_centers;
};

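// Background note on the flags above (standard TensorFlow resize semantics,
// stated here for context): align_corners maps the corner input pixels
// exactly onto the corner output pixels, while half_pixel_centers computes
// source positions from pixel centers (the 0.5 offset convention); the two
// options are normally mutually exclusive.
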
struct TransposeConvParams
{
  PaddingType padding_type;
  PaddingValues padding_values;
  // TODO(starka): This was just "stride", so check that width+height is OK.
  int16_t stride_width;
  int16_t stride_height;
  int16_t dilation_width_factor;
  int16_t dilation_height_factor;
  // uint8_t inference params.
  // TODO(b/65838351): Use smaller types if appropriate.
  int32_t input_offset;
  int32_t weights_offset;
  int32_t output_offset;
  int32_t output_multiplier;
  int output_shift;
  // uint8_t, etc, activation params.
  int32_t quantized_activation_min;
  int32_t quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
};

struct StridedSliceParams
{
  int8_t start_indices_count;
  int16_t start_indices[4];
  int8_t stop_indices_count;
  int16_t stop_indices[4];
  int8_t strides_count;
  int16_t strides[4];

  int16_t begin_mask;
  int16_t ellipsis_mask;
  int16_t end_mask;
  int16_t new_axis_mask;
  int16_t shrink_axis_mask;
};

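// Background note on the masks above (standard TensorFlow StridedSlice
// semantics, for context): each *_mask is a bitfield indexed by dimension,
// e.g. bit i of begin_mask set means "ignore start_indices[i] and slice from
// the beginning of dimension i"; shrink_axis_mask bit i set means dimension i
// is removed from the output.
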
struct FusedBatchNormParams
{
  bool is_training;
  std::string data_format; // UNKNOWN(0), NHWC(1), NCHW(2)
  float epsilon;
};

struct SpaceToBatchParams
{
  // "Zero" padding for uint8 means padding with the output offset.
  int32_t output_offset;
};

struct SpaceToDepthParams
{
  int32_t block_size;
};

enum class Order
{
  kColMajor,
  kRowMajor
};

// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: doing so
// requires complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
template <typename Scalar> struct MatrixParams
{
  // Storage layout order. For now we only do plain linear non-strided
  // layout. It would be easy to support a stride if needed.
  Order order = Order::kColMajor;
  // Number of rows of the matrix.
  int rows = 0;
  // Number of columns of the matrix.
  int cols = 0;
  // The zero_point, i.e. which Scalar value is to be interpreted as zero.
  // When Scalar is floating-point, this must be 0.
  Scalar zero_point = 0;
  // Indicates whether the underlying data will remain unchanged for
  // some period of time. Defaults to false, but should be set to true
  // for unchanging data (e.g. weights buffers in many cases).
  bool cacheable = false;
};

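// Illustrative usage of MatrixParams above (a sketch with assumed values, not
// prescribed by this header): describing a 64x128 column-major int8 weights
// matrix whose buffer does not change between calls might look like
//   MatrixParams<int8_t> lhs_params;
//   lhs_params.order = Order::kColMajor;
//   lhs_params.rows = 64;
//   lhs_params.cols = 128;
//   lhs_params.zero_point = 0;
//   lhs_params.cacheable = true;
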
// Enumeration of broad categories of Gemm.
//
// The primary reason for this to exist is to allow Gemm to compile
// only uniform-quantized or only per-channel-quantized code paths.
// This is unneeded with ruy as the back-end, as this is only a runtime
// difference in ruy, but with gemmlowp these really are separate code
// paths and templatizing in a QuantizationFlavor is necessary to avoid
// compiling unused gemmlowp code. Indeed, TFLite currently uses
// uint8 with uniform quantization and int8 with per-channel quantization,
// and does not use uint8 with per-channel. We want to avoid compiling
// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
//
// It's possible to drop this in the future if gemmlowp goes away and no
// other then-relevant backend library handles quantized paths in a way that
// requires knowing this at compile-time.
enum class QuantizationFlavor
{
  // Floating-point Gemm: the accumulators are not multiplied by any
  // quantized multiplier.
  kFloatingPoint,
  // Quantized Gemm using a single multiplier for all accumulators.
  kIntegerWithUniformMultiplier,
  // Quantized Gemm using separate multipliers for accumulators of each
  // row of the destination matrix. This is what is called 'per-channel'
  // in GemmParams. Here we use the more specific 'per-row' terminology
  // to allow for the possibility of 'per-column' in the future, and to
  // allow for that to be a separate code path in some back-end such as
  // gemmlowp.
  kIntegerWithPerRowMultiplier
};

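// For example (restating the note above): a uint8 uniformly quantized Gemm
// corresponds to kIntegerWithUniformMultiplier, an int8 per-channel-quantized
// Gemm to kIntegerWithPerRowMultiplier, and a float Gemm to kFloatingPoint.
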
// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from that)
// is useful future-proofing. Think of a float16 path using float32 accum.
//
// QuantizationFlavor is passed here even though it's technically not used
// in this class. This is so that we retain the ability in the future to
// specialize this class for quantization flavor, and this allows for
// Gemm to be templatized in quantization_flavor via the GemmParams that it
// takes, allowing for automatic template parameter deduction to take place,
// so that most call sites don't need to specify a QuantizationFlavor
// (only those that need per-channel quantization do).
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor =
            std::is_floating_point<AccumScalar>::value
              ? QuantizationFlavor::kFloatingPoint
              : QuantizationFlavor::kIntegerWithUniformMultiplier>
struct GemmParams
{
  // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
  // of the multiplier by which accumulators are multiplied before being cast
  // to the destination type.
  AccumScalar multiplier_fixedpoint = 0;
  // Only for non-floating-point cases. The exponent part of the aforementioned
  // multiplier.
  int multiplier_exponent = 0;
  // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_fixedpoint.
  const AccumScalar *multiplier_fixedpoint_perchannel = nullptr;
  // Per-channel variant of multiplier_exponent. If not nullptr, this must
  // point to a buffer of as many values as there are rows in the destination
  // matrix. Each row of the destination matrix will use the corresponding
  // buffer element instead of multiplier_exponent.
  //
  // Either none or both of multiplier_exponent_perchannel and
  // multiplier_fixedpoint_perchannel must be nullptr.
  const int *multiplier_exponent_perchannel = nullptr;
  // The bias vector data, if not null.
  const AccumScalar *bias = nullptr;
  // min clamp bound of destination values.
  DstScalar clamp_min = std::is_floating_point<DstScalar>::value
                          ? -std::numeric_limits<DstScalar>::infinity()
                          : std::numeric_limits<DstScalar>::lowest();
  // max clamp bound of destination values.
  DstScalar clamp_max = std::is_floating_point<DstScalar>::value
                          ? std::numeric_limits<DstScalar>::infinity()
                          : std::numeric_limits<DstScalar>::max();
};

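// Illustrative usage of GemmParams above (a sketch with assumed values,
// bias_data being a hypothetical per-row int32 bias buffer): a uniformly
// quantized Gemm producing uint8 output might fill
//   GemmParams<int32_t, uint8_t> params;
//   params.multiplier_fixedpoint = output_multiplier; // fixed-point mantissa
//   params.multiplier_exponent = output_shift;        // power-of-two exponent
//   params.bias = bias_data;
//   params.clamp_min = quantized_activation_min;
//   params.clamp_max = quantized_activation_max;
// and then call ValidateGemmParams(params) (declared below) before
// dispatching the Gemm.
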
// Validates self-consistency of GemmParams.
template <typename AccumScalar, typename DstScalar, QuantizationFlavor quantization_flavor>
void ValidateGemmParams(const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
{
  // Guard consistency of the quantized multiplier fields.
  if (quantization_flavor == QuantizationFlavor::kFloatingPoint)
  {
    assert(!params.multiplier_fixedpoint);
    assert(!params.multiplier_exponent);
    assert(!params.multiplier_fixedpoint_perchannel);
    assert(!params.multiplier_exponent_perchannel);
  }
  else if (quantization_flavor == QuantizationFlavor::kIntegerWithUniformMultiplier &&
           !std::is_same<DstScalar, int32_t>::value)
  {
    assert(params.multiplier_fixedpoint);
    // Nothing to check about multiplier_exponent
    assert(!params.multiplier_fixedpoint_perchannel);
    assert(!params.multiplier_exponent_perchannel);
  }
  else if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier &&
           !std::is_same<DstScalar, int32_t>::value)
  {
    assert(!params.multiplier_fixedpoint);
    assert(!params.multiplier_exponent);
    assert(params.multiplier_fixedpoint_perchannel);
    assert(params.multiplier_exponent_perchannel);
  }
  else
  {
    // For the get raw accumulator case, we should make sure none of the
    // quantization params are set.
    assert(!params.multiplier_fixedpoint);
    assert(!params.multiplier_exponent);
    assert(!params.multiplier_fixedpoint_perchannel);
    assert(!params.multiplier_exponent_perchannel);
  }
  UNUSED_RELEASE(params);
}

} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_TYPES_H__