2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "OpPrinter.h"
22 #include <flatbuffers/flexbuffers.h>
24 using std::make_unique;
// TODO move to some header
// Stream insertion for an int32 vector. Only declared here; the definition lives
// outside this section (presumably used when printing ReshapePrinter's NewShape).
std::ostream &operator<<(std::ostream &out, const std::vector<int32_t> &values);
32 // TODO Re-arrange in alphabetical order
34 class AddPrinter : public OpPrinter
37 void options(const circle::Operator *op, std::ostream &os) const override
39 if (auto *params = op->builtin_options_as_AddOptions())
42 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
49 class ArgMaxPrinter : public OpPrinter
52 void options(const circle::Operator *op, std::ostream &os) const override
54 if (auto *params = op->builtin_options_as_ArgMaxOptions())
57 os << "OutputType(" << EnumNameTensorType(params->output_type()) << ") ";
63 class ArgMinPrinter : public OpPrinter
66 void options(const circle::Operator *op, std::ostream &os) const override
68 if (auto *params = op->builtin_options_as_ArgMinOptions())
71 os << "OutputType(" << EnumNameTensorType(params->output_type()) << ") ";
77 class BatchMatMulPrinter : public OpPrinter
80 void options(const circle::Operator *op, std::ostream &os) const override
82 if (auto *params = op->builtin_options_as_BatchMatMulOptions())
86 os << "adjoint_lhs(" << params->adjoint_lhs() << ") ";
87 os << "adjoint_rhs(" << params->adjoint_rhs() << ") ";
93 class CastPrinter : public OpPrinter
96 void options(const circle::Operator *op, std::ostream &os) const override
98 if (auto cast_params = op->builtin_options_as_CastOptions())
101 os << "in_data_type(" << circle::EnumNameTensorType(cast_params->in_data_type()) << ") ";
102 os << "out_data_type(" << circle::EnumNameTensorType(cast_params->out_data_type()) << ") ";
108 class Conv2DPrinter : public OpPrinter
111 void options(const circle::Operator *op, std::ostream &os) const override
113 if (auto conv_params = op->builtin_options_as_Conv2DOptions())
116 os << "Padding(" << conv_params->padding() << ") ";
117 os << "Stride.W(" << conv_params->stride_w() << ") ";
118 os << "Stride.H(" << conv_params->stride_h() << ") ";
119 os << "Dilation.W(" << conv_params->dilation_w_factor() << ") ";
120 os << "Dilation.H(" << conv_params->dilation_h_factor() << ") ";
122 << EnumNameActivationFunctionType(conv_params->fused_activation_function()) << ")";
128 class DepthToSpacePrinter : public OpPrinter
131 void options(const circle::Operator *op, std::ostream &os) const override
133 if (auto *std_params = op->builtin_options_as_DepthToSpaceOptions())
136 os << "BlockSize(" << std_params->block_size() << ")";
142 class DivPrinter : public OpPrinter
145 void options(const circle::Operator *op, std::ostream &os) const override
147 if (auto *params = op->builtin_options_as_DivOptions())
150 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
157 class Pool2DPrinter : public OpPrinter
160 void options(const circle::Operator *op, std::ostream &os) const override
162 if (auto pool_params = op->builtin_options_as_Pool2DOptions())
165 os << "Padding(" << pool_params->padding() << ") ";
166 os << "Stride.W(" << pool_params->stride_w() << ") ";
167 os << "Stride.H(" << pool_params->stride_h() << ") ";
168 os << "Filter.W(" << pool_params->filter_width() << ") ";
169 os << "Filter.H(" << pool_params->filter_height() << ") ";
171 << EnumNameActivationFunctionType(pool_params->fused_activation_function()) << ")";
177 class ConcatenationPrinter : public OpPrinter
180 void options(const circle::Operator *op, std::ostream &os) const override
182 if (auto *concatenation_params = op->builtin_options_as_ConcatenationOptions())
186 << EnumNameActivationFunctionType(concatenation_params->fused_activation_function())
188 os << "Axis(" << concatenation_params->axis() << ")";
194 class ReducerPrinter : public OpPrinter
197 void options(const circle::Operator *op, std::ostream &os) const override
199 if (auto reducer_params = op->builtin_options_as_ReducerOptions())
202 os << "keep_dims(" << reducer_params->keep_dims() << ") ";
208 class ReshapePrinter : public OpPrinter
211 void options(const circle::Operator *op, std::ostream &os) const override
213 if (auto *reshape_params = op->builtin_options_as_ReshapeOptions())
215 auto new_shape = circleread::as_index_vector(reshape_params->new_shape());
217 os << "NewShape(" << new_shape << ")";
223 class ResizeBilinearPrinter : public OpPrinter
226 void options(const circle::Operator *op, std::ostream &os) const override
228 if (auto *resize_params = op->builtin_options_as_ResizeBilinearOptions())
231 os << std::boolalpha;
232 os << "align_corners(" << resize_params->align_corners() << ")";
233 os << "half_pixel_centers(" << resize_params->half_pixel_centers() << ")";
239 class ResizeNearestNeighborPrinter : public OpPrinter
242 void options(const circle::Operator *op, std::ostream &os) const override
244 if (auto *resize_params = op->builtin_options_as_ResizeNearestNeighborOptions())
247 os << std::boolalpha;
248 os << "align_corners(" << resize_params->align_corners() << ")";
254 class ReverseSequencePrinter : public OpPrinter
257 void options(const circle::Operator *op, std::ostream &os) const override
259 if (auto *params = op->builtin_options_as_ReverseSequenceOptions())
262 os << "seq_dim(" << params->seq_dim() << ") ";
263 os << "batch_dim(" << params->batch_dim() << ") ";
269 class DepthwiseConv2DPrinter : public OpPrinter
272 void options(const circle::Operator *op, std::ostream &os) const override
274 if (auto conv_params = op->builtin_options_as_DepthwiseConv2DOptions())
277 os << "Padding(" << conv_params->padding() << ") ";
278 os << "Stride.W(" << conv_params->stride_w() << ") ";
279 os << "Stride.H(" << conv_params->stride_h() << ") ";
280 os << "DepthMultiplier(" << conv_params->depth_multiplier() << ") ";
281 os << "Dilation.W(" << conv_params->dilation_w_factor() << ") ";
282 os << "Dilation.H(" << conv_params->dilation_h_factor() << ")";
284 << EnumNameActivationFunctionType(conv_params->fused_activation_function()) << ") ";
290 class FullyConnectedPrinter : public OpPrinter
293 void options(const circle::Operator *op, std::ostream &os) const override
295 if (auto *params = op->builtin_options_as_FullyConnectedOptions())
298 os << "WeightFormat(" << EnumNameFullyConnectedOptionsWeightsFormat(params->weights_format())
300 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
308 class GatherPrinter : public OpPrinter
311 void options(const circle::Operator *op, std::ostream &os) const override
313 if (auto *params = op->builtin_options_as_GatherOptions())
316 os << "Axis(" << params->axis() << ") ";
323 class IfPrinter : public OpPrinter
326 void options(const circle::Operator *op, std::ostream &os) const override
328 if (auto *params = op->builtin_options_as_IfOptions())
331 os << "then_subgraph_index(" << params->then_subgraph_index() << ") ";
332 os << "else_subgraph_index(" << params->else_subgraph_index() << ") ";
338 class L2NormPrinter : public OpPrinter
341 void options(const circle::Operator *op, std::ostream &os) const override
343 if (auto *params = op->builtin_options_as_L2NormOptions())
346 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
353 class LeakyReluPrinter : public OpPrinter
356 void options(const circle::Operator *op, std::ostream &os) const override
358 if (auto *params = op->builtin_options_as_LeakyReluOptions())
361 os << "alpha(" << params->alpha() << ") ";
366 class LocalResponseNormalizationPrinter : public OpPrinter
369 void options(const circle::Operator *op, std::ostream &os) const override
371 if (auto *params = op->builtin_options_as_LocalResponseNormalizationOptions())
374 os << "radius(" << params->radius() << ") ";
375 os << "bias(" << params->bias() << ") ";
376 os << "alpha(" << params->alpha() << ") ";
377 os << "beta(" << params->beta() << ") ";
383 class MirrorPadPrinter : public OpPrinter
386 void options(const circle::Operator *op, std::ostream &os) const override
388 if (auto *params = op->builtin_options_as_MirrorPadOptions())
391 os << "mode(" << EnumNameMirrorPadMode(params->mode()) << ") ";
397 class MulPrinter : public OpPrinter
400 void options(const circle::Operator *op, std::ostream &os) const override
402 if (auto *params = op->builtin_options_as_MulOptions())
405 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
412 class OneHotPrinter : public OpPrinter
415 void options(const circle::Operator *op, std::ostream &os) const override
417 if (auto *params = op->builtin_options_as_OneHotOptions())
420 os << "Axis(" << params->axis() << ") ";
427 class PackPrinter : public OpPrinter
430 void options(const circle::Operator *op, std::ostream &os) const override
432 if (auto *params = op->builtin_options_as_PackOptions())
435 os << "ValuesCount(" << params->values_count() << ") ";
436 os << "Axis(" << params->axis() << ") ";
442 class ShapePrinter : public OpPrinter
445 void options(const circle::Operator *op, std::ostream &os) const override
447 if (auto *params = op->builtin_options_as_ShapeOptions())
450 os << "out_type(" << EnumNameTensorType(params->out_type()) << ") ";
456 class SoftmaxPrinter : public OpPrinter
459 void options(const circle::Operator *op, std::ostream &os) const override
461 if (auto *softmax_params = op->builtin_options_as_SoftmaxOptions())
464 os << "Beta(" << softmax_params->beta() << ")";
470 class SpaceToDepthPrinter : public OpPrinter
473 void options(const circle::Operator *op, std::ostream &os) const override
475 if (auto *std_params = op->builtin_options_as_SpaceToDepthOptions())
478 os << "BlockSize(" << std_params->block_size() << ")";
484 class SparseToDensePrinter : public OpPrinter
487 void options(const circle::Operator *op, std::ostream &os) const override
489 if (auto *std_params = op->builtin_options_as_SparseToDenseOptions())
492 os << "ValidateIndices(" << std_params->validate_indices() << ")";
498 class SplitPrinter : public OpPrinter
501 void options(const circle::Operator *op, std::ostream &os) const override
503 if (auto *params = op->builtin_options_as_SplitOptions())
506 os << "num_splits(" << params->num_splits() << ") ";
512 class SplitVPrinter : public OpPrinter
515 void options(const circle::Operator *op, std::ostream &os) const override
517 if (auto *params = op->builtin_options_as_SplitVOptions())
520 os << "num_splits(" << params->num_splits() << ") ";
526 class SqueezePrinter : public OpPrinter
529 void options(const circle::Operator *op, std::ostream &os) const override
531 if (auto *params = op->builtin_options_as_SqueezeOptions())
534 os << "SqueezeDims(";
535 for (int i = 0; i < params->squeeze_dims()->size(); ++i)
539 os << params->squeeze_dims()->Get(i);
547 class StridedSlicePrinter : public OpPrinter
550 void options(const circle::Operator *op, std::ostream &os) const override
552 if (auto *strided_slice_params = op->builtin_options_as_StridedSliceOptions())
555 os << "begin_mask(" << strided_slice_params->begin_mask() << ") ";
556 os << "end_mask(" << strided_slice_params->end_mask() << ") ";
557 os << "ellipsis_mask(" << strided_slice_params->ellipsis_mask() << ") ";
558 os << "new_axis_mask(" << strided_slice_params->new_axis_mask() << ") ";
559 os << "shrink_axis_mask(" << strided_slice_params->shrink_axis_mask() << ") ";
565 class SubPrinter : public OpPrinter
568 void options(const circle::Operator *op, std::ostream &os) const override
570 if (auto *params = op->builtin_options_as_SubOptions())
573 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
580 class TransposeConvPrinter : public OpPrinter
583 void options(const circle::Operator *op, std::ostream &os) const override
585 if (auto conv_params = op->builtin_options_as_TransposeConvOptions())
588 os << "Padding(" << conv_params->padding() << ") ";
589 os << "Stride.W(" << conv_params->stride_w() << ") ";
590 os << "Stride.H(" << conv_params->stride_h() << ") ";
596 class UniquePrinter : public OpPrinter
599 void options(const circle::Operator *op, std::ostream &os) const override
601 if (auto *params = op->builtin_options_as_UniqueOptions())
604 os << "idx_out_type(" << EnumNameTensorType(params->idx_out_type()) << ") ";
610 class WhilePrinter : public OpPrinter
613 void options(const circle::Operator *op, std::ostream &os) const override
615 if (auto *params = op->builtin_options_as_WhileOptions())
618 os << "cond_subgraph_index(" << params->cond_subgraph_index() << ") ";
619 os << "body_subgraph_index(" << params->body_subgraph_index() << ") ";
625 class CustomOpPrinter : public OpPrinter
628 void options(const circle::Operator *op, std::ostream &os) const override
630 if (op->custom_options_format() != circle::CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS)
633 os << "Unknown custom option format";
637 const flatbuffers::Vector<uint8_t> *option_buf = op->custom_options();
639 if (option_buf == nullptr || option_buf->size() == 0)
641 os << "No attrs found." << std::endl;
646 // attrs of custom ops are encoded in flexbuffer format
647 auto attr_map = flexbuffers::GetRoot(option_buf->data(), option_buf->size()).AsMap();
650 auto keys = attr_map.Keys();
651 for (int i = 0; i < keys.size(); i++)
653 auto key = keys[i].ToString();
654 os << key << "(" << attr_map[key].ToString() << ") ";
657 // Note: attr in "Shape" type does not seem to be converted by circle_convert.
658 // When the converted circle file (with custom op) is opened with hexa editory,
659 // attrs names can be found but attr name in "Shape" type is not found.
665 class BCQFullyConnectedPrinter : public OpPrinter
668 void options(const circle::Operator *op, std::ostream &os) const override
670 if (auto *params = op->builtin_options_as_BCQFullyConnectedOptions())
673 os << "Activation(" << EnumNameActivationFunctionType(params->fused_activation_function())
675 os << "weights_hidden_size(" << params->weights_hidden_size() << ") ";
681 class BCQGatherPrinter : public OpPrinter
684 void options(const circle::Operator *op, std::ostream &os) const override
686 if (auto *params = op->builtin_options_as_BCQGatherOptions())
689 os << "axis(" << params->axis() << ") ";
690 os << "weights_hidden_size(" << params->input_hidden_size() << ") ";
696 OpPrinterRegistry::OpPrinterRegistry()
698 _op_map[circle::BuiltinOperator_ADD] = make_unique<AddPrinter>();
699 // There is no Option for ADD_N
700 _op_map[circle::BuiltinOperator_ARG_MAX] = make_unique<ArgMaxPrinter>();
701 _op_map[circle::BuiltinOperator_ARG_MIN] = make_unique<ArgMinPrinter>();
702 _op_map[circle::BuiltinOperator_AVERAGE_POOL_2D] = make_unique<Pool2DPrinter>();
703 _op_map[circle::BuiltinOperator_BATCH_MATMUL] = make_unique<BatchMatMulPrinter>();
704 _op_map[circle::BuiltinOperator_CAST] = make_unique<CastPrinter>();
705 // There is no Option for CEIL
706 _op_map[circle::BuiltinOperator_CONCATENATION] = make_unique<ConcatenationPrinter>();
707 _op_map[circle::BuiltinOperator_CONV_2D] = make_unique<Conv2DPrinter>();
708 _op_map[circle::BuiltinOperator_DEPTH_TO_SPACE] = make_unique<DepthToSpacePrinter>();
709 _op_map[circle::BuiltinOperator_DEPTHWISE_CONV_2D] = make_unique<DepthwiseConv2DPrinter>();
710 _op_map[circle::BuiltinOperator_DIV] = make_unique<DivPrinter>();
711 // There is no Option for FLOOR
712 // There is no Option for FLOOR_MOD
713 _op_map[circle::BuiltinOperator_FULLY_CONNECTED] = make_unique<FullyConnectedPrinter>();
714 _op_map[circle::BuiltinOperator_GATHER] = make_unique<GatherPrinter>();
715 _op_map[circle::BuiltinOperator_IF] = make_unique<IfPrinter>();
716 _op_map[circle::BuiltinOperator_L2_NORMALIZATION] = make_unique<L2NormPrinter>();
717 _op_map[circle::BuiltinOperator_L2_POOL_2D] = make_unique<Pool2DPrinter>();
718 _op_map[circle::BuiltinOperator_LEAKY_RELU] = make_unique<LeakyReluPrinter>();
719 _op_map[circle::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION] =
720 make_unique<LocalResponseNormalizationPrinter>();
721 // There is no Option for LOG
722 // There is no Option for LOGISTIC
723 // There is no Option for LOG_SOFTMAX
724 _op_map[circle::BuiltinOperator_MAX_POOL_2D] = make_unique<Pool2DPrinter>();
725 _op_map[circle::BuiltinOperator_MIRROR_PAD] = make_unique<MirrorPadPrinter>();
726 _op_map[circle::BuiltinOperator_MUL] = make_unique<MulPrinter>();
727 // There is no Option for NON_MAX_SUPPRESSION_V4
728 // There is no Option for NON_MAX_SUPPRESSION_V5
729 _op_map[circle::BuiltinOperator_ONE_HOT] = make_unique<OneHotPrinter>();
730 _op_map[circle::BuiltinOperator_PACK] = make_unique<PackPrinter>();
731 // There is no Option for PAD
732 // There is no Option for PADV2
733 // There is no Option for PRELU
734 // There is no Option for RELU
735 // There is no Option for RELU6
736 // There is no Option for RELU_N1_TO_1
737 _op_map[circle::BuiltinOperator_REDUCE_ANY] = make_unique<ReducerPrinter>();
738 _op_map[circle::BuiltinOperator_REDUCE_MAX] = make_unique<ReducerPrinter>();
739 _op_map[circle::BuiltinOperator_REDUCE_MIN] = make_unique<ReducerPrinter>();
740 _op_map[circle::BuiltinOperator_REDUCE_PROD] = make_unique<ReducerPrinter>();
741 _op_map[circle::BuiltinOperator_RESHAPE] = make_unique<ReshapePrinter>();
742 _op_map[circle::BuiltinOperator_RESIZE_BILINEAR] = make_unique<ResizeBilinearPrinter>();
743 _op_map[circle::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR] =
744 make_unique<ResizeNearestNeighborPrinter>();
745 _op_map[circle::BuiltinOperator_REVERSE_SEQUENCE] = make_unique<ReverseSequencePrinter>();
746 // There is no Option for ROUND
747 // There is no Option for SELECT
748 // There is no Option for SELECT_V2
749 _op_map[circle::BuiltinOperator_SHAPE] = make_unique<ShapePrinter>();
750 // There is no Option for SIN
751 // There is no Option for SLICE
752 _op_map[circle::BuiltinOperator_SOFTMAX] = make_unique<SoftmaxPrinter>();
753 _op_map[circle::BuiltinOperator_SPACE_TO_DEPTH] = make_unique<SpaceToDepthPrinter>();
754 // There is no Option for SPACE_TO_BATCH_ND
755 _op_map[circle::BuiltinOperator_SPARSE_TO_DENSE] = make_unique<SparseToDensePrinter>();
756 _op_map[circle::BuiltinOperator_SPLIT] = make_unique<SplitPrinter>();
757 _op_map[circle::BuiltinOperator_SPLIT_V] = make_unique<SplitVPrinter>();
758 _op_map[circle::BuiltinOperator_SQUEEZE] = make_unique<SqueezePrinter>();
759 _op_map[circle::BuiltinOperator_STRIDED_SLICE] = make_unique<StridedSlicePrinter>();
760 _op_map[circle::BuiltinOperator_SUB] = make_unique<SubPrinter>();
761 _op_map[circle::BuiltinOperator_SUM] = make_unique<ReducerPrinter>();
762 _op_map[circle::BuiltinOperator_TRANSPOSE_CONV] = make_unique<TransposeConvPrinter>();
763 // There is no Option for TOPK_V2
764 _op_map[circle::BuiltinOperator_UNIQUE] = make_unique<UniquePrinter>();
765 _op_map[circle::BuiltinOperator_WHILE] = make_unique<WhilePrinter>();
766 _op_map[circle::BuiltinOperator_CUSTOM] = make_unique<CustomOpPrinter>();
769 _op_map[circle::BuiltinOperator_BCQ_FULLY_CONNECTED] = make_unique<BCQFullyConnectedPrinter>();
770 _op_map[circle::BuiltinOperator_BCQ_GATHER] = make_unique<BCQGatherPrinter>();
773 } // namespace circledump