2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "OpPrinter.h"
22 #include <flatbuffers/flexbuffers.h>
24 using std::make_unique;
29 // TODO move to some header
// Stream-insertion overload for a std::vector<int32_t>. Only the declaration
// appears here; the definition must be provided elsewhere.
30 std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect);
32 // TODO Re-arrange in alphabetical order
// OpPrinter for operators carrying AddOptions.
34 class AddPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_AddOptions() returns a non-null pointer.
37 void options(const tflite::Operator *op, std::ostream &os) const override
39 if (auto *params = op->builtin_options_as_AddOptions())
// Writes "Activation(" followed by the name of the fused activation function.
42 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// OpPrinter for operators carrying ArgMaxOptions.
49 class ArgMaxPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ArgMaxOptions() returns a non-null pointer.
52 void options(const tflite::Operator *op, std::ostream &os) const override
54 if (auto *params = op->builtin_options_as_ArgMaxOptions())
// Writes "OutputType(<tensor type name>) ".
57 os << "OutputType(" << EnumNameTensorType(params->output_type()) << ") ";
// OpPrinter for operators carrying ArgMinOptions.
63 class ArgMinPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ArgMinOptions() returns a non-null pointer.
66 void options(const tflite::Operator *op, std::ostream &os) const override
68 if (auto *params = op->builtin_options_as_ArgMinOptions())
// Writes "OutputType(<tensor type name>) ".
71 os << "OutputType(" << EnumNameTensorType(params->output_type()) << ") ";
// OpPrinter for operators carrying BidirectionalSequenceLSTMOptions.
77 class BidirectionalSequenceLSTMPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_BidirectionalSequenceLSTMOptions() returns non-null.
80 void options(const tflite::Operator *op, std::ostream &os) const override
82 if (auto *params = op->builtin_options_as_BidirectionalSequenceLSTMOptions())
// Fused activation is printed by name via EnumNameActivationFunctionType.
85 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// Remaining fields are streamed as returned by their accessors, each as "name(value) ".
87 os << "cell_clip(" << params->cell_clip() << ") ";
88 os << "proj_clip(" << params->proj_clip() << ") ";
// NOTE(review): no std::boolalpha is set here; if the next three accessors
// return bool they will print as 0/1 — confirm this is intended.
89 os << "time_major(" << params->time_major() << ") ";
90 os << "asymmetric_quantize_inputs(" << params->asymmetric_quantize_inputs() << ") ";
91 os << "merge_outputs(" << params->merge_outputs() << ") ";
// OpPrinter for operators carrying CastOptions.
97 class CastPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_CastOptions() returns a non-null pointer.
100 void options(const tflite::Operator *op, std::ostream &os) const override
102 if (auto cast_params = op->builtin_options_as_CastOptions())
// Prints the source and destination tensor types by name.
105 os << "in_data_type(" << tflite::EnumNameTensorType(cast_params->in_data_type()) << ") ";
106 os << "out_data_type(" << tflite::EnumNameTensorType(cast_params->out_data_type()) << ") ";
// OpPrinter for operators carrying Conv2DOptions.
112 class Conv2DPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_Conv2DOptions() returns a non-null pointer.
115 void options(const tflite::Operator *op, std::ostream &os) const override
117 if (auto conv_params = op->builtin_options_as_Conv2DOptions())
// NOTE(review): padding() is streamed directly, without an EnumName* helper
// (unlike the activation below) — presumably prints the raw enum value; confirm.
120 os << "Padding(" << conv_params->padding() << ") ";
121 os << "Stride.W(" << conv_params->stride_w() << ") ";
122 os << "Stride.H(" << conv_params->stride_h() << ") ";
123 os << "Dilation.W(" << conv_params->dilation_w_factor() << ") ";
124 os << "Dilation.H(" << conv_params->dilation_h_factor() << ") ";
// Tail of a streaming statement: the fused activation name followed by ")".
126 << EnumNameActivationFunctionType(conv_params->fused_activation_function()) << ")";
// OpPrinter for operators carrying DivOptions.
132 class DivPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_DivOptions() returns a non-null pointer.
135 void options(const tflite::Operator *op, std::ostream &os) const override
137 if (auto *params = op->builtin_options_as_DivOptions())
// Writes "Activation(" followed by the name of the fused activation function.
140 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// OpPrinter for operators carrying Pool2DOptions.
147 class Pool2DPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_Pool2DOptions() returns a non-null pointer.
150 void options(const tflite::Operator *op, std::ostream &os) const override
152 if (auto pool_params = op->builtin_options_as_Pool2DOptions())
// NOTE(review): padding() is streamed directly, without an EnumName* helper
// — presumably prints the raw enum value; confirm.
155 os << "Padding(" << pool_params->padding() << ") ";
156 os << "Stride.W(" << pool_params->stride_w() << ") ";
157 os << "Stride.H(" << pool_params->stride_h() << ") ";
158 os << "Filter.W(" << pool_params->filter_width() << ") ";
159 os << "Filter.H(" << pool_params->filter_height() << ") ";
// Tail of a streaming statement: the fused activation name followed by ")".
161 << EnumNameActivationFunctionType(pool_params->fused_activation_function()) << ")";
// OpPrinter for operators carrying ConcatenationOptions.
167 class ConcatenationPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ConcatenationOptions() returns a non-null pointer.
170 void options(const tflite::Operator *op, std::ostream &os) const override
172 if (auto *concatenation_params = op->builtin_options_as_ConcatenationOptions())
// Middle of a streaming statement: inserts the fused activation name.
176 << EnumNameActivationFunctionType(concatenation_params->fused_activation_function())
// Writes "Axis(<axis>)" with no trailing space.
178 os << "Axis(" << concatenation_params->axis() << ")";
// OpPrinter for operators carrying ReducerOptions.
184 class ReducerPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ReducerOptions() returns a non-null pointer.
187 void options(const tflite::Operator *op, std::ostream &os) const override
189 if (auto reducer_params = op->builtin_options_as_ReducerOptions())
// NOTE(review): no std::boolalpha here; if keep_dims() returns bool it prints as 0/1.
192 os << "keep_dims(" << reducer_params->keep_dims() << ") ";
// OpPrinter for operators carrying ReshapeOptions.
198 class ReshapePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ReshapeOptions() returns a non-null pointer.
201 void options(const tflite::Operator *op, std::ostream &os) const override
203 if (auto *reshape_params = op->builtin_options_as_ReshapeOptions())
// Converts the stored shape with tflread::as_index_vector before streaming it.
// NOTE(review): presumably this yields a std::vector<int32_t> printed through
// the operator<< declared near the top of this file — confirm.
205 auto new_shape = tflread::as_index_vector(reshape_params->new_shape());
207 os << "NewShape(" << new_shape << ")";
// OpPrinter for operators carrying ResizeBilinearOptions.
213 class ResizeBilinearPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ResizeBilinearOptions() returns a non-null pointer.
216 void options(const tflite::Operator *op, std::ostream &os) const override
218 if (auto *resize_params = op->builtin_options_as_ResizeBilinearOptions())
// Switch the stream to textual booleans for the two flags below.
221 os << std::boolalpha;
// NOTE(review): the first field ends with ")" and the second starts right
// after it, so no separator is written between them (other printers in this
// file emit ") "). Confirm whether that is intended.
222 os << "align_corners(" << resize_params->align_corners() << ")";
223 os << "half_pixel_centers(" << resize_params->half_pixel_centers() << ")";
// Sets noboolalpha afterwards, regardless of the flag's state on entry.
224 os << std::noboolalpha;
// OpPrinter for operators carrying ResizeNearestNeighborOptions.
230 class ResizeNearestNeighborPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ResizeNearestNeighborOptions() returns a non-null pointer.
233 void options(const tflite::Operator *op, std::ostream &os) const override
235 if (auto *resize_params = op->builtin_options_as_ResizeNearestNeighborOptions())
// Switch the stream to textual booleans for the flag below.
238 os << std::boolalpha;
239 os << "align_corners(" << resize_params->align_corners() << ")";
// Sets noboolalpha afterwards, regardless of the flag's state on entry.
240 os << std::noboolalpha;
// OpPrinter for operators carrying ReverseSequenceOptions.
246 class ReverseSequencePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ReverseSequenceOptions() returns a non-null pointer.
249 void options(const tflite::Operator *op, std::ostream &os) const override
251 if (auto *std_params = op->builtin_options_as_ReverseSequenceOptions())
// Writes "seq_dim(<value>) " then "batch_dim(<value>) ".
254 os << "seq_dim(" << std_params->seq_dim() << ") ";
255 os << "batch_dim(" << std_params->batch_dim() << ") ";
// OpPrinter for operators carrying DepthToSpaceOptions.
261 class DepthToSpacePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_DepthToSpaceOptions() returns a non-null pointer.
264 void options(const tflite::Operator *op, std::ostream &os) const override
266 if (auto *std_params = op->builtin_options_as_DepthToSpaceOptions())
// Writes "BlockSize(<block_size>)" with no trailing space.
269 os << "BlockSize(" << std_params->block_size() << ")";
// OpPrinter for operators carrying SparseToDenseOptions.
275 class SparseToDensePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SparseToDenseOptions() returns a non-null pointer.
278 void options(const tflite::Operator *op, std::ostream &os) const override
280 if (auto *std_params = op->builtin_options_as_SparseToDenseOptions())
// NOTE(review): no std::boolalpha here; if validate_indices() returns bool it prints as 0/1.
283 os << "ValidateIndices(" << std_params->validate_indices() << ")";
// OpPrinter for operators carrying DepthwiseConv2DOptions.
289 class DepthwiseConv2DPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_DepthwiseConv2DOptions() returns a non-null pointer.
292 void options(const tflite::Operator *op, std::ostream &os) const override
294 if (auto conv_params = op->builtin_options_as_DepthwiseConv2DOptions())
// NOTE(review): padding() is streamed directly, without an EnumName* helper
// — presumably prints the raw enum value; confirm.
297 os << "Padding(" << conv_params->padding() << ") ";
298 os << "Stride.W(" << conv_params->stride_w() << ") ";
299 os << "Stride.H(" << conv_params->stride_h() << ") ";
300 os << "DepthMultiplier(" << conv_params->depth_multiplier() << ") ";
301 os << "Dilation.W(" << conv_params->dilation_w_factor() << ") ";
302 os << "Dilation.H(" << conv_params->dilation_h_factor() << ") ";
// Tail of a streaming statement: the fused activation name followed by ") ".
304 << EnumNameActivationFunctionType(conv_params->fused_activation_function()) << ") ";
// OpPrinter for operators carrying FakeQuantOptions.
310 class FakeQuantPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_FakeQuantOptions() returns a non-null pointer.
313 void options(const tflite::Operator *op, std::ostream &os) const override
315 if (auto *params = op->builtin_options_as_FakeQuantOptions())
// Writes min, max and num_bits as "Name(value) ".
318 os << "Min(" << params->min() << ") ";
319 os << "Max(" << params->max() << ") ";
320 os << "NumBits(" << params->num_bits() << ") ";
// boolalpha is enabled only around narrow_range and set back to noboolalpha
// afterwards, regardless of the flag's state on entry.
321 os << std::boolalpha;
322 os << "NarrowRange(" << params->narrow_range() << ") ";
323 os << std::noboolalpha;
// OpPrinter for operators carrying FullyConnectedOptions.
329 class FullyConnectedPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_FullyConnectedOptions() returns a non-null pointer.
332 void options(const tflite::Operator *op, std::ostream &os) const override
334 if (auto *params = op->builtin_options_as_FullyConnectedOptions())
// Weights format and fused activation are both printed by enum name.
337 os << "WeightFormat(" << EnumNameFullyConnectedOptionsWeightsFormat(params->weights_format())
339 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// OpPrinter for operators carrying GatherOptions.
347 class GatherPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_GatherOptions() returns a non-null pointer.
350 void options(const tflite::Operator *op, std::ostream &os) const override
352 if (auto *params = op->builtin_options_as_GatherOptions())
// Writes "Axis(<axis>) ".
355 os << "Axis(" << params->axis() << ") ";
// OpPrinter for operators carrying IfOptions.
362 class IfPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_IfOptions() returns a non-null pointer.
365 void options(const tflite::Operator *op, std::ostream &os) const override
367 if (auto *params = op->builtin_options_as_IfOptions())
// Prints the subgraph indices stored for the "then" and "else" branches.
370 os << "then_subgraph_index(" << params->then_subgraph_index() << ") ";
371 os << "else_subgraph_index(" << params->else_subgraph_index() << ") ";
// OpPrinter for operators carrying L2NormOptions.
377 class L2NormPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_L2NormOptions() returns a non-null pointer.
380 void options(const tflite::Operator *op, std::ostream &os) const override
382 if (auto *params = op->builtin_options_as_L2NormOptions())
// Writes "Activation(" followed by the name of the fused activation function.
385 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// OpPrinter for operators carrying LeakyReluOptions.
392 class LeakyReluPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_LeakyReluOptions() returns a non-null pointer.
395 void options(const tflite::Operator *op, std::ostream &os) const override
397 if (auto *params = op->builtin_options_as_LeakyReluOptions())
// Writes "alpha(<alpha>) ".
400 os << "alpha(" << params->alpha() << ") ";
// OpPrinter for operators carrying LocalResponseNormalizationOptions.
405 class LocalResponseNormalizationPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_LocalResponseNormalizationOptions() returns non-null.
408 void options(const tflite::Operator *op, std::ostream &os) const override
410 if (auto *params = op->builtin_options_as_LocalResponseNormalizationOptions())
// Each of the four fields is written as "name(value) ".
413 os << "radius(" << params->radius() << ") ";
414 os << "bias(" << params->bias() << ") ";
415 os << "alpha(" << params->alpha() << ") ";
416 os << "beta(" << params->beta() << ") ";
// OpPrinter for operators carrying MirrorPadOptions.
422 class MirrorPadPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_MirrorPadOptions() returns a non-null pointer.
425 void options(const tflite::Operator *op, std::ostream &os) const override
427 if (auto *params = op->builtin_options_as_MirrorPadOptions())
// The mode is printed by name via EnumNameMirrorPadMode.
430 os << "mode(" << EnumNameMirrorPadMode(params->mode()) << ") ";
// OpPrinter for operators carrying MulOptions.
436 class MulPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_MulOptions() returns a non-null pointer.
439 void options(const tflite::Operator *op, std::ostream &os) const override
441 if (auto *params = op->builtin_options_as_MulOptions())
// Writes "Activation(" followed by the name of the fused activation function.
444 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// OpPrinter for operators carrying PackOptions.
451 class PackPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_PackOptions() returns a non-null pointer.
454 void options(const tflite::Operator *op, std::ostream &os) const override
456 if (auto *params = op->builtin_options_as_PackOptions())
// Writes "ValuesCount(<values_count>) " then "Axis(<axis>) ".
459 os << "ValuesCount(" << params->values_count() << ") ";
460 os << "Axis(" << params->axis() << ") ";
// OpPrinter for operators carrying OneHotOptions.
466 class OneHotPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_OneHotOptions() returns a non-null pointer.
469 void options(const tflite::Operator *op, std::ostream &os) const override
471 if (auto *params = op->builtin_options_as_OneHotOptions())
// Writes "Axis(<axis>) ".
474 os << "Axis(" << params->axis() << ") ";
// OpPrinter for operators carrying ShapeOptions.
481 class ShapePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_ShapeOptions() returns a non-null pointer.
484 void options(const tflite::Operator *op, std::ostream &os) const override
486 if (auto *params = op->builtin_options_as_ShapeOptions())
// The output type is printed by name via EnumNameTensorType.
489 os << "out_type(" << EnumNameTensorType(params->out_type()) << ") ";
// OpPrinter for operators carrying SoftmaxOptions.
495 class SoftmaxPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SoftmaxOptions() returns a non-null pointer.
498 void options(const tflite::Operator *op, std::ostream &os) const override
500 if (auto *softmax_params = op->builtin_options_as_SoftmaxOptions())
// Writes "Beta(<beta>)" with no trailing space.
503 os << "Beta(" << softmax_params->beta() << ")";
// OpPrinter for operators carrying SpaceToDepthOptions.
509 class SpaceToDepthPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SpaceToDepthOptions() returns a non-null pointer.
512 void options(const tflite::Operator *op, std::ostream &os) const override
514 if (auto *std_params = op->builtin_options_as_SpaceToDepthOptions())
// Writes "BlockSize(<block_size>)" with no trailing space.
517 os << "BlockSize(" << std_params->block_size() << ")";
// OpPrinter for operators carrying SqueezeOptions.
523 class SqueezePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SqueezeOptions() returns a non-null pointer.
526 void options(const tflite::Operator *op, std::ostream &os) const override
528 if (auto *params = op->builtin_options_as_SqueezeOptions())
531 os << "SqueezeDims(";
// Walks squeeze_dims() by index and streams each element.
// NOTE(review): `i` is a signed int compared against size() (presumably an
// unsigned type), and squeeze_dims() is dereferenced without a null check —
// confirm the vector is always present when SqueezeOptions exist.
532 for (int i = 0; i < params->squeeze_dims()->size(); ++i)
536 os << params->squeeze_dims()->Get(i);
// OpPrinter for operators carrying StridedSliceOptions.
544 class StridedSlicePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_StridedSliceOptions() returns a non-null pointer.
547 void options(const tflite::Operator *op, std::ostream &os) const override
549 if (auto *strided_slice_params = op->builtin_options_as_StridedSliceOptions())
// The five mask fields are streamed with default formatting, each as "name(value) ".
552 os << "begin_mask(" << strided_slice_params->begin_mask() << ") ";
553 os << "end_mask(" << strided_slice_params->end_mask() << ") ";
554 os << "ellipsis_mask(" << strided_slice_params->ellipsis_mask() << ") ";
555 os << "new_axis_mask(" << strided_slice_params->new_axis_mask() << ") ";
556 os << "shrink_axis_mask(" << strided_slice_params->shrink_axis_mask() << ") ";
// OpPrinter for operators carrying SplitOptions.
562 class SplitPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SplitOptions() returns a non-null pointer.
565 void options(const tflite::Operator *op, std::ostream &os) const override
567 if (auto *params = op->builtin_options_as_SplitOptions())
// Writes "num_splits(<num_splits>) ".
570 os << "num_splits(" << params->num_splits() << ") ";
// OpPrinter for operators carrying SplitVOptions.
576 class SplitVPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SplitVOptions() returns a non-null pointer.
579 void options(const tflite::Operator *op, std::ostream &os) const override
581 if (auto *params = op->builtin_options_as_SplitVOptions())
// Writes "num_splits(<num_splits>) ".
584 os << "num_splits(" << params->num_splits() << ") ";
// OpPrinter for operators carrying SubOptions.
590 class SubPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SubOptions() returns a non-null pointer.
593 void options(const tflite::Operator *op, std::ostream &os) const override
595 if (auto *params = op->builtin_options_as_SubOptions())
// Writes "Activation(" followed by the name of the fused activation function.
598 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// OpPrinter for operators carrying SVDFOptions.
605 class SVDFPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_SVDFOptions() returns a non-null pointer.
608 void options(const tflite::Operator *op, std::ostream &os) const override
610 if (auto *params = op->builtin_options_as_SVDFOptions())
613 os << "rank(" << params->rank() << ") ";
// Lower-case "activation(" label here, unlike the "Activation(" label used by
// most other printers in this file.
614 os << "activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
// NOTE(review): no std::boolalpha here; if this accessor returns bool it prints as 0/1.
616 os << "asymmetric_quantize_inputs(" << params->asymmetric_quantize_inputs() << ") ";
// OpPrinter for operators carrying TransposeConvOptions.
622 class TransposeConvPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_TransposeConvOptions() returns a non-null pointer.
625 void options(const tflite::Operator *op, std::ostream &os) const override
627 if (auto *params = op->builtin_options_as_TransposeConvOptions())
// NOTE(review): padding() is streamed directly, without an EnumName* helper
// — presumably prints the raw enum value; confirm.
630 os << "Padding(" << params->padding() << ") ";
631 os << "Stride.W(" << params->stride_w() << ") ";
632 os << "Stride.H(" << params->stride_h() << ") ";
// OpPrinter for operators carrying WhileOptions.
638 class WhilePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_WhileOptions() returns a non-null pointer.
641 void options(const tflite::Operator *op, std::ostream &os) const override
643 if (auto *params = op->builtin_options_as_WhileOptions())
// Prints the subgraph indices stored for the condition and the body.
646 os << "cond_subgraph_index(" << params->cond_subgraph_index() << ") ";
647 os << "body_subgraph_index(" << params->body_subgraph_index() << ") ";
// OpPrinter for operators carrying UnidirectionalSequenceLSTMOptions.
653 class UnidirectionalSequenceLSTMPrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_UnidirectionalSequenceLSTMOptions() returns non-null.
656 void options(const tflite::Operator *op, std::ostream &os) const override
658 if (auto *params = op->builtin_options_as_UnidirectionalSequenceLSTMOptions())
// Fused activation is printed by name via EnumNameActivationFunctionType.
661 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
663 os << "cell_clip(" << params->cell_clip() << ") ";
664 os << "proj_clip(" << params->proj_clip() << ") ";
// NOTE(review): no std::boolalpha is set here; if the next two accessors
// return bool they will print as 0/1 — confirm this is intended.
665 os << "time_major(" << params->time_major() << ") ";
666 os << "asymmetric_quantize_inputs(" << params->asymmetric_quantize_inputs() << ") ";
// OpPrinter for operators carrying UniqueOptions.
672 class UniquePrinter : public OpPrinter
// Overrides OpPrinter::options. Output is produced only when
// builtin_options_as_UniqueOptions() returns a non-null pointer.
675 void options(const tflite::Operator *op, std::ostream &os) const override
677 if (auto *params = op->builtin_options_as_UniqueOptions())
// The index output type is printed by name via EnumNameTensorType.
680 os << "idx_out_type(" << EnumNameTensorType(params->idx_out_type()) << ") ";
// OpPrinter for custom operators: decodes the operator's custom_options
// buffer as a FlexBuffers map and prints every entry as "key(value) ".
686 class CustomOpPrinter : public OpPrinter
// Overrides OpPrinter::options.
689 void options(const tflite::Operator *op, std::ostream &os) const override
// Only the FLEXBUFFERS custom-options format is decoded; any other format
// gets a fixed message instead.
691 if (op->custom_options_format() != tflite::CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS)
694 os << "Unknown custom option format";
698 const flatbuffers::Vector<uint8_t> *option_buf = op->custom_options();
// A missing or empty buffer is reported rather than decoded.
700 if (option_buf == nullptr || option_buf->size() == 0)
702 os << "No attrs found." << std::endl;
707 // attrs of custom ops are encoded in flexbuffer format
708 auto attr_map = flexbuffers::GetRoot(option_buf->data(), option_buf->size()).AsMap();
711 auto keys = attr_map.Keys();
// NOTE(review): `i` is a signed int compared against keys.size() (presumably
// an unsigned type) — confirm and consider an unsigned index.
712 for (int i = 0; i < keys.size(); i++)
// Each value is looked up again by its key string and printed via ToString().
714 auto key = keys[i].ToString();
715 os << key << "(" << attr_map[key].ToString() << ") ";
718 // Note: attr in "Shape" type does not seem to be converted by tflite_convert.
719 // When the converted tflite file (with custom op) is opened with a hex editor,
720 // attr names can be found but the attr name in "Shape" type is not found.
// Constructor: fills _op_map with one printer instance per builtin operator
// code, each created with make_unique. Operators whose options are not
// printed are listed in "There is no Option for ..." comments instead.
// Several operators use the same printer class: the three *_POOL_2D operators
// use Pool2DPrinter, and MEAN, REDUCE_* and SUM use ReducerPrinter.
726 OpPrinterRegistry::OpPrinterRegistry()
728 _op_map[tflite::BuiltinOperator_ADD] = make_unique<AddPrinter>();
729 // There is no Option for ADD_N
730 _op_map[tflite::BuiltinOperator_ARG_MAX] = make_unique<ArgMaxPrinter>();
731 _op_map[tflite::BuiltinOperator_ARG_MIN] = make_unique<ArgMinPrinter>();
732 _op_map[tflite::BuiltinOperator_AVERAGE_POOL_2D] = make_unique<Pool2DPrinter>();
733 _op_map[tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM] =
734 make_unique<BidirectionalSequenceLSTMPrinter>();
735 _op_map[tflite::BuiltinOperator_CAST] = make_unique<CastPrinter>();
736 // There is no Option for CEIL
737 _op_map[tflite::BuiltinOperator_CONCATENATION] = make_unique<ConcatenationPrinter>();
738 _op_map[tflite::BuiltinOperator_CONV_2D] = make_unique<Conv2DPrinter>();
739 // There is no Option for DENSIFY
740 _op_map[tflite::BuiltinOperator_DEPTH_TO_SPACE] = make_unique<DepthToSpacePrinter>();
741 _op_map[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = make_unique<DepthwiseConv2DPrinter>();
742 // There is no Option for DEQUANTIZE
743 _op_map[tflite::BuiltinOperator_DIV] = make_unique<DivPrinter>();
744 _op_map[tflite::BuiltinOperator_FAKE_QUANT] = make_unique<FakeQuantPrinter>();
745 // There is no Option for FLOOR
746 // There is no Option for FLOOR_MOD
747 _op_map[tflite::BuiltinOperator_FULLY_CONNECTED] = make_unique<FullyConnectedPrinter>();
748 _op_map[tflite::BuiltinOperator_GATHER] = make_unique<GatherPrinter>();
749 _op_map[tflite::BuiltinOperator_IF] = make_unique<IfPrinter>();
750 _op_map[tflite::BuiltinOperator_L2_POOL_2D] = make_unique<Pool2DPrinter>();
751 _op_map[tflite::BuiltinOperator_L2_NORMALIZATION] = make_unique<L2NormPrinter>();
752 _op_map[tflite::BuiltinOperator_LEAKY_RELU] = make_unique<LeakyReluPrinter>();
753 _op_map[tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION] =
754 make_unique<LocalResponseNormalizationPrinter>();
755 // There is no Option for LOG
756 // There is no Option for LOGISTIC
757 // There is no Option for LOG_SOFTMAX
758 _op_map[tflite::BuiltinOperator_MAX_POOL_2D] = make_unique<Pool2DPrinter>();
759 _op_map[tflite::BuiltinOperator_MEAN] = make_unique<ReducerPrinter>();
760 _op_map[tflite::BuiltinOperator_MIRROR_PAD] = make_unique<MirrorPadPrinter>();
761 _op_map[tflite::BuiltinOperator_MUL] = make_unique<MulPrinter>();
762 // There is no Option for NON_MAX_SUPPRESSION_V4
763 // There is no Option for NON_MAX_SUPPRESSION_V5
764 _op_map[tflite::BuiltinOperator_ONE_HOT] = make_unique<OneHotPrinter>();
765 _op_map[tflite::BuiltinOperator_PACK] = make_unique<PackPrinter>();
766 // There is no Option for PAD
767 // There is no Option for PADV2
768 // There is no Option for PRELU
769 // There is no Option for RELU
770 // There is no Option for RELU6
771 // There is no Option for RELU_N1_TO_1
772 _op_map[tflite::BuiltinOperator_REDUCE_ANY] = make_unique<ReducerPrinter>();
773 _op_map[tflite::BuiltinOperator_REDUCE_MAX] = make_unique<ReducerPrinter>();
774 _op_map[tflite::BuiltinOperator_REDUCE_MIN] = make_unique<ReducerPrinter>();
775 _op_map[tflite::BuiltinOperator_REDUCE_PROD] = make_unique<ReducerPrinter>();
776 _op_map[tflite::BuiltinOperator_RESHAPE] = make_unique<ReshapePrinter>();
777 _op_map[tflite::BuiltinOperator_RESIZE_BILINEAR] = make_unique<ResizeBilinearPrinter>();
778 _op_map[tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR] =
779 make_unique<ResizeNearestNeighborPrinter>();
780 _op_map[tflite::BuiltinOperator_REVERSE_SEQUENCE] = make_unique<ReverseSequencePrinter>();
781 // There is no Option for ROUND
782 // There is no Option for SELECT
783 // There is no Option for SELECT_V2
784 _op_map[tflite::BuiltinOperator_SHAPE] = make_unique<ShapePrinter>();
785 // There is no Option for SIN
786 // There is no Option for SLICE
787 _op_map[tflite::BuiltinOperator_SOFTMAX] = make_unique<SoftmaxPrinter>();
788 _op_map[tflite::BuiltinOperator_SPACE_TO_DEPTH] = make_unique<SpaceToDepthPrinter>();
789 // There is no Option for SPACE_TO_BATCH_ND
790 _op_map[tflite::BuiltinOperator_SPARSE_TO_DENSE] = make_unique<SparseToDensePrinter>();
791 _op_map[tflite::BuiltinOperator_SPLIT] = make_unique<SplitPrinter>();
792 _op_map[tflite::BuiltinOperator_SPLIT_V] = make_unique<SplitVPrinter>();
793 _op_map[tflite::BuiltinOperator_SQUEEZE] = make_unique<SqueezePrinter>();
794 _op_map[tflite::BuiltinOperator_STRIDED_SLICE] = make_unique<StridedSlicePrinter>();
795 _op_map[tflite::BuiltinOperator_SUB] = make_unique<SubPrinter>();
796 _op_map[tflite::BuiltinOperator_SUM] = make_unique<ReducerPrinter>();
797 _op_map[tflite::BuiltinOperator_SVDF] = make_unique<SVDFPrinter>();
798 _op_map[tflite::BuiltinOperator_TRANSPOSE_CONV] = make_unique<TransposeConvPrinter>();
799 // There is no Option for TOPK_V2
800 _op_map[tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM] =
801 make_unique<UnidirectionalSequenceLSTMPrinter>();
802 _op_map[tflite::BuiltinOperator_UNIQUE] = make_unique<UniquePrinter>();
803 _op_map[tflite::BuiltinOperator_WHILE] = make_unique<WhilePrinter>();
// Custom operators are handled by CustomOpPrinter.
804 _op_map[tflite::BuiltinOperator_CUSTOM] = make_unique<CustomOpPrinter>();
807 } // namespace tfldump